Repository: Hanbo-Cheng/DAWN-pytorch Branch: main Commit: 76c44b8396c3 Files: 291 Total size: 72.5 MB Directory structure: gitextract_jmo9ls54/ ├── .gitignore ├── DAWN_256.yaml ├── DM_3/ │ ├── datasets_hdtf_wpose_lmk_block_lmk.py │ ├── datasets_hdtf_wpose_lmk_block_lmk_rand.py │ ├── modules/ │ │ ├── local_attention.py │ │ ├── text.py │ │ ├── video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_6D.py │ │ ├── video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_rand_6D.py │ │ ├── video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_test.py │ │ ├── video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi.py │ │ ├── video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test.py │ │ └── video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py │ ├── test_lr.py │ ├── train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D.py │ ├── train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D_s2.py │ └── utils.py ├── LFG/ │ ├── __init__.py │ ├── augmentation.py │ ├── frames_dataset.py │ ├── hdtf_dataset.py │ ├── modules/ │ │ ├── avd_network.py │ │ ├── bg_motion_predictor.py │ │ ├── flow_autoenc.py │ │ ├── generator.py │ │ ├── model.py │ │ ├── pixelwise_flow_predictor.py │ │ ├── region_predictor.py │ │ └── util.py │ ├── run_hdtf.py │ ├── run_hdtf_crema.py │ ├── sync_batchnorm/ │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py │ ├── test_flowautoenc_crema_video.py │ ├── test_flowautoenc_hdtf_video.py │ ├── test_flowautoenc_hdtf_video_256.py │ ├── train.py │ └── vis_flow.py ├── PBnet/ │ ├── run_cvae_h_ann_reemb_rope_eye_3.sh │ └── src/ │ ├── __init__.py │ ├── config.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── datasets_hdtf_pos_chunk_norm_2_fast.py │ │ ├── datasets_hdtf_pos_chunk_norm_eye_fast.py │ │ ├── datasets_hdtf_pos_df.py │ │ ├── datasets_hdtf_pos_dict_norm_2.py │ │ ├── datasets_hdtf_wpose_lmk_block.py │ │ ├── get_dataset.py │ │ └── tools.py │ ├── evaluate/ │ │ ├── __init__.py │ │ ├── action2motion/ │ │ │ ├── accuracy.py │ │ │ ├── diversity.py │ │ │ ├── evaluate.py │ │ │ ├── fid.py │ │ │ └── models.py │ │ ├── evaluate_cvae.py │ │ ├── evaluate_cvae_debug.py │ │ ├── evaluate_cvae_f3.py │ │ ├── evaluate_cvae_f3_debug.py │ │ ├── evaluate_cvae_f3_mel.py │ │ ├── evaluate_cvae_norm.py │ │ ├── evaluate_cvae_norm_all.py │ │ ├── evaluate_cvae_norm_all_seg.py │ │ ├── evaluate_cvae_norm_all_seg_weye.py │ │ ├── evaluate_cvae_norm_all_seg_weye2.py │ │ ├── evaluate_cvae_norm_eye_pose.py │ │ ├── evaluate_cvae_norm_eye_pose_test.py │ │ ├── evaluate_cvae_onlyeye_all_seg.py │ │ ├── othermetrics/ │ │ │ ├── acceleration.py │ │ │ └── evaluation.py │ │ ├── stgcn/ │ │ │ ├── accuracy.py │ │ │ ├── diversity.py │ │ │ ├── evaluate.py │ │ │ └── fid.py │ │ ├── tables/ │ │ │ ├── archtable.py │ │ │ ├── bstable.py │ │ │ ├── easy_table.py │ │ │ ├── easy_table_A2M.py │ │ │ ├── kltable.py │ │ │ ├── latexmodela2m.py │ │ │ ├── latexmodelsa2m.py │ │ │ ├── latexmodelsstgcn.py │ │ │ ├── losstable.py │ │ │ ├── maketable.py │ │ │ ├── numlayertable.py │ │ │ └── posereptable.py │ │ ├── tools.py │ │ ├── tvae_eval.py │ │ ├── tvae_eval_norm.py │ │ ├── tvae_eval_norm_all.py │ │ ├── tvae_eval_norm_eye_pose.py │ │ ├── tvae_eval_norm_eye_pose_seg.py │ │ ├── tvae_eval_norm_seg.py │ │ ├── tvae_eval_onlyeye_all_seg.py │ │ ├── tvae_eval_single.py │ │ ├── tvae_eval_single_both_eye_pose.py │ │ ├── tvae_eval_std.py │ │ ├── tvae_eval_train.py │ │ ├── 
tvae_eval_train_norm.py │ │ └── tvae_eval_train_std.py │ ├── generate/ │ │ └── generate_sequences.py │ ├── models/ │ │ ├── __init__.py │ │ ├── architectures/ │ │ │ ├── __init__.py │ │ │ ├── autotrans.py │ │ │ ├── fc.py │ │ │ ├── gru.py │ │ │ ├── grutrans.py │ │ │ ├── mlp.py │ │ │ ├── resnet34.py │ │ │ ├── tools/ │ │ │ │ ├── embeddings.py │ │ │ │ ├── resnet.py │ │ │ │ ├── transformer_layers.py │ │ │ │ └── util.py │ │ │ ├── transformer.py │ │ │ ├── transformerdecoder.py │ │ │ ├── transformerdecoder4.py │ │ │ ├── transformerdecoder5.py │ │ │ ├── transformerreemb.py │ │ │ ├── transformerreemb5.py │ │ │ ├── transformerreemb6.py │ │ │ └── transgru.py │ │ ├── get_model.py │ │ ├── modeltype/ │ │ │ ├── __init__.py │ │ │ ├── cae.py │ │ │ ├── cae_0.py │ │ │ ├── cvae.py │ │ │ └── lstm.py │ │ ├── rotation2xyz.py │ │ ├── smpl.py │ │ └── tools/ │ │ ├── __init__.py │ │ ├── graphconv.py │ │ ├── hessian_penalty.py │ │ ├── losses.py │ │ ├── mmd.py │ │ ├── msssim_loss.py │ │ ├── normalize_data.py │ │ ├── ssim_loss.py │ │ └── tools.py │ ├── parser/ │ │ ├── base.py │ │ ├── checkpoint.py │ │ ├── dataset.py │ │ ├── evaluation.py │ │ ├── finetunning.py │ │ ├── generate.py │ │ ├── model.py │ │ ├── recognition.py │ │ ├── tools.py │ │ ├── training.py │ │ └── visualize.py │ ├── preprocess/ │ │ ├── humanact12_process.py │ │ ├── phspdtools.py │ │ └── uestc_vibe_postprocessing.py │ ├── recognition/ │ │ ├── compute_accuracy.py │ │ ├── get_model.py │ │ └── models/ │ │ ├── stgcn.py │ │ └── stgcnutils/ │ │ ├── graph.py │ │ └── tgcn.py │ ├── render/ │ │ ├── renderer.py │ │ └── rendermotion.py │ ├── train/ │ │ ├── __init__.py │ │ ├── train_cvae_ganloss_ann_eye.py │ │ ├── train_cvae_ganloss_ann_fast.py │ │ ├── trainer.py │ │ ├── trainer_gan.py │ │ └── trainer_gan_ann.py │ ├── utils/ │ │ ├── PYTORCH3D_LICENSE │ │ ├── __init__.py │ │ ├── fixseed.py │ │ ├── get_model_and_data.py │ │ ├── misc.py │ │ ├── rotation_conversions.py │ │ ├── tensors.py │ │ ├── tensors_eye.py │ │ ├── tensors_eye_eval.py │ │ ├── tensors_hdtf.py │ │ ├── tensors_onlyeye.py │ │ ├── utils.py │ │ └── video.py │ └── visualize/ │ ├── __init__.py │ ├── anim.py │ ├── visualize.py │ ├── visualize_checkpoint.py │ ├── visualize_dataset.py │ ├── visualize_latent_space.py │ ├── visualize_nturefined.py │ └── visualize_sequence.py ├── README.md ├── README_CN.md ├── config/ │ ├── DAWN_128.yaml │ ├── DAWN_256.yaml │ ├── hdtf128.yaml │ ├── hdtf128_1000ep.yaml │ ├── hdtf128_1000ep_crema.yaml │ ├── hdtf256.yaml │ └── hdtf256_400ep.yaml ├── extract_init_states/ │ ├── FaceBoxes/ │ │ ├── FaceBoxes.py │ │ ├── FaceBoxes_ONNX.py │ │ ├── __init__.py │ │ ├── build_cpu_nms.sh │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ └── faceboxes.py │ │ ├── onnx.py │ │ ├── readme.md │ │ ├── utils/ │ │ │ ├── .gitignore │ │ │ ├── __init__.py │ │ │ ├── box_utils.py │ │ │ ├── build.py │ │ │ ├── config.py │ │ │ ├── functions.py │ │ │ ├── nms/ │ │ │ │ ├── .gitignore │ │ │ │ ├── __init__.py │ │ │ │ ├── cpu_nms.cp38-win_amd64.pyd │ │ │ │ ├── cpu_nms.pyx │ │ │ │ └── py_cpu_nms.py │ │ │ ├── nms_wrapper.py │ │ │ ├── prior_box.py │ │ │ └── timer.py │ │ └── weights/ │ │ ├── .gitignore │ │ ├── FaceBoxesProd.pth │ │ └── readme.md │ ├── TDDFA_ONNX.py │ ├── bfm/ │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── bfm.py │ │ ├── bfm_onnx.py │ │ └── readme.md │ ├── build.sh │ ├── configs/ │ │ ├── .gitignore │ │ ├── BFM_UV.mat │ │ ├── bfm_noneck_v3.onnx │ │ ├── bfm_noneck_v3.pkl │ │ ├── indices.npy │ │ ├── mb05_120x120.yml │ │ ├── mb1_120x120.yml │ │ ├── ncc_code.npy │ │ ├── param_mean_std_62d_120x120.pkl │ │ ├── readme.md │ 
│ ├── resnet_120x120.yml │ │ └── tri.pkl │ ├── demo_pose_extract_2d_lmk_img.py │ ├── functions.py │ ├── models/ │ │ ├── __init__.py │ │ ├── mobilenet_v1.py │ │ ├── mobilenet_v3.py │ │ └── resnet.py │ ├── pose.py │ ├── readme.md │ ├── utils/ │ │ ├── __init__.py │ │ ├── asset/ │ │ │ ├── .gitignore │ │ │ ├── build_render_ctypes.sh │ │ │ └── render.c │ │ ├── depth.py │ │ ├── functions.py │ │ ├── io.py │ │ ├── onnx.py │ │ ├── pncc.py │ │ ├── pose.py │ │ ├── render.py │ │ ├── render_ctypes.py │ │ ├── serialization.py │ │ ├── tddfa_util.py │ │ └── uv.py │ └── weights/ │ ├── .gitignore │ ├── mb05_120x120.pth │ ├── mb1_120x120.onnx │ ├── mb1_120x120.pth │ └── readme.md ├── filter_fourier.py ├── hubert_extract/ │ └── data_gen/ │ └── process_lrs3/ │ ├── binarizer.py │ ├── process_audio_hubert.py │ ├── process_audio_hubert_interpolate.py │ ├── process_audio_hubert_interpolate_batch.py │ ├── process_audio_hubert_interpolate_demo.py │ ├── process_audio_hubert_interpolate_single.py │ └── process_audio_mel_f0.py ├── misc.py ├── requirements.txt ├── run_ood_test/ │ ├── run_DM_v0_df_test_128_both_pose_blink.sh │ ├── run_DM_v0_df_test_128_separate_pose_blink.sh │ ├── run_DM_v0_df_test_256.sh │ ├── run_DM_v0_df_test_256_1.sh │ └── run_DM_v0_df_test_256_1_separate_pose_blink.sh ├── sync_batchnorm/ │ ├── __init__.py │ ├── batchnorm.py │ ├── comm.py │ ├── replicate.py │ ├── replicate_ddp.py │ └── unittest.py └── unified_video_generator.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .idea/ __pycache__ **/__pycache__ cache submit* pretrain_models/* *.mp3 *.mp4 tmp* output/* ================================================ FILE: DAWN_256.yaml ================================================ input_size: 256 max_n_frames: 200 random_seed: 1234 mean: [0.0, 0.0, 0.0] win_width: 40 sampling_step: 20 ddim_sampling_eta: 1.0 cond_scale: 1.0 model_config: is_train: true pose_dim: 6 config_pth: './config/hdtf256.yaml' ae_pretrained_pth: './pretrain_models/LFG_256_400ep.pth' diffusion_pretrained_pth: './pretrain_models/DAWN_256.pth' ================================================ FILE: DM_3/datasets_hdtf_wpose_lmk_block_lmk.py ================================================ # dataset for HDTF, stage 1 from os import name import sys sys.path.append('your_path') import os import random import torch import numpy as np import torch.utils.data as data import torch.nn.functional as Ft import imageio.v2 as imageio import cv2 import torchvision.transforms.functional as F import matplotlib.pyplot as plt from PIL import Image from scipy.interpolate import interp1d import decord from torchvision.transforms.functional import to_pil_image from torchvision import transforms import time import pickle as pkl decord.bridge.set_bridge('torch') def resize(im, desired_size, interpolation): old_size = im.shape[:2] ratio = float(desired_size)/max(old_size) new_size = tuple(int(x*ratio) for x in old_size) im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation) delta_w = desired_size - new_size[1] delta_h = desired_size - new_size[0] top, bottom = delta_h//2, delta_h-(delta_h//2) left, right = delta_w//2, delta_w-(delta_w//2) color = [0, 0, 0] new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) return new_im class HDTF(data.Dataset): def __init__(self, data_dir, pose_dir, eye_blink_dir, max_num_frames=80, 
image_size=128,mode='train', mean=(128, 128, 128), color_jitter=True): super(HDTF, self).__init__() self.mean = torch.tensor(mean)[None,:,None,None] self.data_dir = data_dir self.pose_dir = pose_dir self.eye_blink_dir = eye_blink_dir self.is_jitter = color_jitter self.max_num_frames = max_num_frames self.image_size = image_size self.mode = mode vid_list = [] # # crema # self.hubert_dir = '/train20/intern/permanent/lmlin2/data/crema_wav_hubert' # if mode == 'train': # for id_name in os.listdir(data_dir): # if id_name in ['s64','s76','s88','s90','s91']: # continue # vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ]) # if mode == 'test': # for id_name in ['s64','s76','s88','s90','s91']: # vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ]) # self.videos = vid_list # hdtf vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\ 'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\ 'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\ 'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\ 'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\ 'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\ 'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\ 'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\ 'WRA_MarcoRubio_000'] bad_id_name = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000','RD_Radio39_000'] # vid_id_name_list = [item + '.mp4' for item in vid_id_name_list] # bad_id_name = [item + '.mp4' for item in bad_id_name] # hdtf self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate_chunk' self.mouth_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/mouth_ratio_bar' self.lmk_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/lmk_25hz_chunk' with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f: self.len_dict = pkl.load(f) # vid_id_name_list = ['RD_Radio47_000','WDA_CatherineCortezMasto_000','WDA_JoeNeguse_001','WDA_MichelleLujanGrisham_000','WRA_ErikPaulsen_002', \ # 'WDA_ZoeLofgren_000','WRA_JebHensarling2_003','WRA_MichaelSteele_000', 'WRA_ToddYoung_000', 'WRA_VickyHartzler_000'] if mode == 'train': for id_name in os.listdir(data_dir): # id_name = id_name[:-4] if id_name in vid_id_name_list or id_name in bad_id_name: continue vid_list.append(id_name) self.videos = vid_list if mode == 'test': self.videos = vid_id_name_list def check_head(self, frame_list, video_name, start, end): ''' Check if the desired pose address exists. 
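Returns True only when a pose file exists on disk for both the start and the end frame index (editorial note, hedged: this relies on self.get_pose_path, which does not appear to be defined in this class in this snapshot).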
''' start_path = self.get_pose_path(frame_list, video_name, start) end_path = self.get_pose_path(frame_list, video_name, end) if os.path.exists(start_path) and os.path.exists(end_path): return True else: return False def get_block_data_for_two(self, path, start, end): # TODO: id function ''' input: start: start id end: end id output: the data from block ''' block_st = start//25 block_ed = end//25 st_pos = block_st % 25 ed_pos = block_ed % 25 block_st_name = 'chunk_%04d.npy' % (block_st) block_ed_name = 'chunk_%04d.npy' % (block_ed) if block_st != block_ed: block_st_path = os.path.join(path, block_st_name) block_ed_path = os.path.join(path, block_ed_name) block_st = np.load(block_st_path) block_ed = np.load(block_ed_path) return np.concatenate((block_st[st_pos:], block_ed[:ed_pos])) else: block_st_path = os.path.join(path, block_st_name) block_st = np.load(block_st_path) return block_st[st_pos, ed_pos] def get_block_data(self, path, start, end): # TODO: id function ''' input: start: start id end: end id output: the data from block ''' block_st = start//25 block_ed = end//25 st_pos = start % 25 ed_pos = end % 25 block_list = [os.path.join(path,'chunk_%04d.npy' % (i)) for i in range(block_st, block_ed+1)] if block_st != block_ed: arr_list = [] block_st = np.load(block_list[0]) arr_list.append(block_st[st_pos:]) for path in block_list[1:-1]: arr_list.append(np.load(path)) block_ed = np.load(block_list[-1]) arr_list.append(block_ed[:ed_pos]) return np.concatenate(arr_list) else: block_st_path = os.path.join(path, block_list[0]) block_st = np.load(block_st_path) return block_st[st_pos: ed_pos] def check_len(self, name): return self.len_dict[name] def __len__(self): return len(self.videos) def __getitem__(self, idx): video_name = self.videos[idx] path = os.path.join(self.data_dir, video_name) hubert_path = os.path.join(self.hubert_dir, video_name) lmk_path = os.path.join(self.lmk_dir, video_name) pose_path = os.path.join(self.pose_dir, video_name) eye_blink_path = os.path.join(self.eye_blink_dir, video_name) total_num_frames = self.check_len(video_name) if total_num_frames <= self.max_num_frames: sample_frames = total_num_frames start = 0 else: sample_frames = self.max_num_frames start = np.random.randint(total_num_frames-self.max_num_frames) start=start stop=sample_frames+start sample_frame_npy = self.get_block_data(path = path, start = start, end = stop) sample_hubert_feature_npy = self.get_block_data(path = hubert_path, start = start, end = stop).astype(np.float32) sample_pose_list_npy = self.get_block_data(path = pose_path, start = start, end = stop).astype(np.float32) sample_eye_blink_list_npy = self.get_block_data(path = eye_blink_path, start = start, end = stop).astype(np.float32) sample_frame_list = torch.tensor(sample_frame_npy).permute(0,3,1,2) sample_hubert_feature_tensor = torch.tensor(sample_hubert_feature_npy) sample_frame_list = sample_frame_list - self.mean # 20, 3, 128, 128 # sample_frame_list = [np.transpose(x, (2, 0, 1)) for x in sample_frame_list] # sample_frame_list_npy = np.stack(sample_frame_list, axis=1) # sample_pose_list_npy = np.stack(sample_pose_list, axis = 1) # sample_eye_blink_list_npy = np.stack(sample_eye_blink_list, axis = 1) # change to float32 sample_frame_list = sample_frame_list.permute(1, 0, 2, 3) # sample_frame_list = np.array(sample_frame_list/255.0, dtype=np.float32) #3, 40, 128, 128 # sample_frame_list = sample_frame_list/255. 
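# Editorial note (hedged reading of the shapes above): sample_frame_list is now
# mean-subtracted and permuted to (C, T, H, W), e.g. (3, 80, 128, 128) when
# max_num_frames == 80 and image_size == 128; sample_hubert_feature_tensor keeps a
# per-frame (T, feat_dim) layout, and the pose / eye-blink arrays are transposed to
# (dim, T) a few lines below "for compatibility" with the downstream model.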
# put to mode l forward # added to change the video_name of crema video_name = video_name.replace('/','_') sample_pose_list_npy = sample_pose_list_npy.transpose(1,0) # for compatibility sample_eye_blink_list_npy = sample_eye_blink_list_npy.transpose(1,0) # if __debug__: # end_time = time.time() # end # print(f'process time {end_time- start_time}') # spend lot of time # start_time = end_time if self.mode == 'test': return sample_frame_list, sample_hubert_feature_tensor, sample_pose_list_npy, sample_eye_blink_list_npy, video_name, start return sample_frame_list, sample_hubert_feature_tensor, sample_pose_list_npy, sample_eye_blink_list_npy, video_name, total_num_frames if __name__ == "__main__": # hdtf data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" pose_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/pose" # crema # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images' dataset = HDTF(data_dir=data_dir, pose_dir=pose_dir ,mode='train') for i in range(10): dataset.__getitem__(i) print('------') test_dataset = data.DataLoader(dataset=dataset, batch_size=10, num_workers=8, shuffle=False) for i, batch in enumerate(test_dataset): print(i) ================================================ FILE: DM_3/datasets_hdtf_wpose_lmk_block_lmk_rand.py ================================================ # dataset for HDTF, stage 2 from os import name import sys sys.path.append('your_path') import os import random import torch import numpy as np import torch.utils.data as data import torch.nn.functional as Ft import imageio.v2 as imageio import cv2 import torchvision.transforms.functional as F import matplotlib.pyplot as plt from PIL import Image from scipy.interpolate import interp1d import decord from torchvision.transforms.functional import to_pil_image from torchvision import transforms import time import pickle as pkl decord.bridge.set_bridge('torch') def resize(im, desired_size, interpolation): old_size = im.shape[:2] ratio = float(desired_size)/max(old_size) new_size = tuple(int(x*ratio) for x in old_size) im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation) delta_w = desired_size - new_size[1] delta_h = desired_size - new_size[0] top, bottom = delta_h//2, delta_h-(delta_h//2) left, right = delta_w//2, delta_w-(delta_w//2) color = [0, 0, 0] new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) return new_im class HDTF(data.Dataset): def __init__(self, data_dir, pose_dir, eye_blink_dir, max_num_frames=80, image_size=128, audio_dir=None, ref_id = None, mode='train', mean=(128, 128, 128), color_jitter=True): super(HDTF, self).__init__() self.mean = torch.tensor(mean)[None,:,None,None] self.data_dir = data_dir self.pose_dir = pose_dir self.eye_blink_dir = eye_blink_dir self.is_jitter = color_jitter self.max_num_frames = max_num_frames self.image_size = image_size self.mode = mode vid_list = [] # # crema # self.hubert_dir = '/train20/intern/permanent/lmlin2/data/crema_wav_hubert' # if mode == 'train': # for id_name in os.listdir(data_dir): # if id_name in ['s64','s76','s88','s90','s91']: # continue # vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ]) # if mode == 'test': # for id_name in ['s64','s76','s88','s90','s91']: # vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ]) # self.videos = vid_list # hdtf vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\ 
'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\ 'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\ 'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\ 'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\ 'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\ 'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\ 'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\ 'WRA_MarcoRubio_000'] bad_id_name = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000'] # vid_id_name_list = [item + '.mp4' for item in vid_id_name_list] # bad_id_name = [item + '.mp4' for item in bad_id_name] # hdtf if audio_dir == None: self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate_chunk' else: self.hubert_dir = audio_dir self.ref_id = ref_id self.mouth_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/mouth_ratio_bar' self.lmk_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/lmk_25hz_chunk' with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f: self.len_dict = pkl.load(f) # vid_id_name_list = ['RD_Radio47_000','WDA_CatherineCortezMasto_000','WDA_JoeNeguse_001','WDA_MichelleLujanGrisham_000','WRA_ErikPaulsen_002', \ # 'WDA_ZoeLofgren_000','WRA_JebHensarling2_003','WRA_MichaelSteele_000', 'WRA_ToddYoung_000', 'WRA_VickyHartzler_000'] if mode == 'train': for id_name in os.listdir(data_dir): # id_name = id_name[:-4] if id_name in vid_id_name_list or id_name in bad_id_name: continue vid_list.append(id_name) self.videos = vid_list if mode == 'test': self.videos = vid_id_name_list def check_head(self, frame_list, video_name, start, end): ''' Check if the desired pose address exists. 
''' start_path = self.get_pose_path(frame_list, video_name, start) end_path = self.get_pose_path(frame_list, video_name, end) if os.path.exists(start_path) and os.path.exists(end_path): return True else: return False def get_block_data_for_two(self, path, start, end): # TODO: id function ''' input: start: start id end: end id output: the data from block ''' block_st = start//25 block_ed = end//25 st_pos = block_st % 25 ed_pos = block_ed % 25 block_st_name = 'chunk_%04d.npy' % (block_st) block_ed_name = 'chunk_%04d.npy' % (block_ed) if block_st != block_ed: block_st_path = os.path.join(path, block_st_name) block_ed_path = os.path.join(path, block_ed_name) block_st = np.load(block_st_path) block_ed = np.load(block_ed_path) return np.concatenate((block_st[st_pos:], block_ed[:ed_pos])) else: block_st_path = os.path.join(path, block_st_name) block_st = np.load(block_st_path) return block_st[st_pos, ed_pos] def get_block_data(self, path, start, end): # TODO: id function ''' input: start: start id end: end id output: the data from block ''' block_st = start//25 block_ed = end//25 st_pos = start % 25 ed_pos = end % 25 block_list = [os.path.join(path,'chunk_%04d.npy' % (i)) for i in range(block_st, block_ed+1)] if block_st != block_ed: arr_list = [] block_st = np.load(block_list[0]) arr_list.append(block_st[st_pos:]) for path in block_list[1:-1]: arr_list.append(np.load(path)) block_ed = np.load(block_list[-1]) arr_list.append(block_ed[:ed_pos]) return np.concatenate(arr_list) else: block_st_path = os.path.join(path, block_list[0]) block_st = np.load(block_st_path) return block_st[st_pos: ed_pos] def check_len(self, name): return self.len_dict[name] def __len__(self): return len(self.videos) def __getitem__(self, idx): video_name = self.videos[idx] path = os.path.join(self.data_dir, video_name) hubert_path = os.path.join(self.hubert_dir, video_name) lmk_path = os.path.join(self.lmk_dir, video_name) pose_path = os.path.join(self.pose_dir, video_name) eye_blink_path = os.path.join(self.eye_blink_dir, video_name) total_num_frames = self.check_len(video_name) if total_num_frames <= self.max_num_frames: sample_frames = total_num_frames start = 0 else: sample_frames = self.max_num_frames start = np.random.randint(total_num_frames-self.max_num_frames) start=start stop=sample_frames+start if self.ref_id == None: ref_id = np.random.randint(total_num_frames) elif self.ref_id == "clip": ref_id = np.random.randint(sample_frames) + start else: ref_id = 0 sample_frame_npy = self.get_block_data(path = path, start = start, end = stop) sample_hubert_feature_npy = self.get_block_data(path = hubert_path, start = start, end = stop).astype(np.float32) sample_lmk_npy = self.get_block_data(path = lmk_path, start = start, end = stop).astype(np.float32) sample_pose_list_npy = self.get_block_data(path = pose_path, start = start, end = stop).astype(np.float32) sample_eye_blink_list_npy = self.get_block_data(path = eye_blink_path, start = start, end = stop).astype(np.float32) ref_frame_npy = self.get_block_data(path = path, start = ref_id, end = ref_id + 1) ref_hubert_feature_npy = self.get_block_data(path = hubert_path, start = ref_id, end = ref_id + 1).astype(np.float32) ref_pose_list_npy = self.get_block_data(path = pose_path, start = ref_id, end = ref_id + 1).astype(np.float32) ref_eye_blink_list_npy = self.get_block_data(path = eye_blink_path, start = ref_id, end = ref_id + 1).astype(np.float32) # mouth_path = os.path.join(self.mouth_dir, video_name+'.npy') # mouth_seq = np.load(mouth_path).astype(np.float32) # ref_mouth 
= mouth_seq[ref_id] # mouth_seq = mouth_seq[start:stop] mouth_lmk_tensor = torch.tensor(sample_lmk_npy[:,48:67]) sample_frame_list = torch.tensor(sample_frame_npy).permute(0,3,1,2) sample_hubert_feature_tensor = torch.tensor(sample_hubert_feature_npy) sample_frame_list = sample_frame_list - self.mean # 20, 3, 128, 128 ref_frame_npy = torch.tensor(ref_frame_npy).permute(0,3,1,2) ref_hubert_feature_npy = torch.tensor(ref_hubert_feature_npy) ref_frame_npy = ref_frame_npy - self.mean # 20, 3, 128, 128 sample_hubert_feature_tensor = torch.concat([ref_hubert_feature_npy, sample_hubert_feature_tensor], dim = 0) sample_frame_list = torch.concat([ref_frame_npy, sample_frame_list], dim = 0) # sample_frame_list = [np.transpose(x, (2, 0, 1)) for x in sample_frame_list] # sample_frame_list_npy = np.stack(sample_frame_list, axis=1) # sample_pose_list_npy = np.stack(sample_pose_list, axis = 1) # sample_eye_blink_list_npy = np.stack(sample_eye_blink_list, axis = 1) # change to float32 sample_frame_list = sample_frame_list.permute(1, 0, 2, 3) # sample_frame_list = np.array(sample_frame_list/255.0, dtype=np.float32) #3, 40, 128, 128 # sample_frame_list = sample_frame_list/255. # put to mode l forward # added to change the video_name of crema video_name = video_name.replace('/','_') sample_pose_list_npy = np.concatenate([ref_pose_list_npy, sample_pose_list_npy], axis = 0) sample_eye_blink_list_npy = np.concatenate([ref_eye_blink_list_npy, sample_eye_blink_list_npy], axis = 0) sample_pose_list_npy = sample_pose_list_npy.transpose(1,0) # for compatibility sample_eye_blink_list_npy = sample_eye_blink_list_npy.transpose(1,0) # mouth_seq = np.concatenate([ref_mouth[None], mouth_seq], axis = 0) # mouth_seq_npy = mouth_seq.transpose(1,0) # if __debug__: # end_time = time.time() # end # print(f'load data time {end_time- start_time}') # spend lot of time # start_time = end_time if self.mode == 'test': return sample_frame_list, sample_hubert_feature_tensor, sample_pose_list_npy, sample_eye_blink_list_npy, mouth_lmk_tensor, video_name, start return sample_frame_list, sample_hubert_feature_tensor, sample_pose_list_npy, sample_eye_blink_list_npy, mouth_lmk_tensor, video_name, total_num_frames if __name__ == "__main__": # hdtf data_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/images_25hz_128_chunk" pose_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar_chunk" eye_blink_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/eye_blink_bbox_from_xpc_bar_2_chunk" # crema # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images' dataset = HDTF(data_dir=data_dir, pose_dir=pose_dir, eye_blink_dir = eye_blink_dir, image_size=128, max_num_frames=30, color_jitter=True) for i in range(10): dataset.__getitem__(i) print('------') test_dataset = data.DataLoader(dataset=dataset, batch_size=10, num_workers=8, shuffle=False) for i, batch in enumerate(test_dataset): print(i) ================================================ FILE: DM_3/modules/local_attention.py ================================================ import sys # sys.path.append('your/path/DAWN-pytorch') import torch import torch.nn as nn import torch.nn.functional as F import time from einops import rearrange from rotary_embedding_torch import RotaryEmbedding from einops_exts import rearrange_many import math import concurrent.futures # from local_attn_cuda_pkg.test_cuda import attn_forward # from local_attn_cuda_pkg.test_cuda import attn_forward, compute_res_forward # def attn_forward(x, y, batch_size, hw, seq_len, k_size, head, d, 
device): # attn = torch.zeros(batch_size, hw, seq_len, k_size, head, device=device) # local_attn_res.attn_cuda(x, y, attn, batch_size, hw, seq_len, k_size, head, d) # return attn # def compute_res_forward(attn, z, batch_size, hw, seq_len, head, head_dim, k_size, device): # res = torch.zeros(batch_size, hw, seq_len, head, head_dim, device=device) # local_attn_res.compute_res_cuda(attn, z, res, batch_size, hw, seq_len, head, head_dim, k_size) # return res def exists(x): return x is not None def to_mask(x, mask, mode='mul'): if mask is None: return x else: while x.dim() > mask.dim(): mask = mask.unsqueeze(-1) if mode == 'mul': return x * mask else: return x + mask # def extract_seq_patches(x, kernel_size, rate): # """x.shape = [batch_size, seq_len, seq_dim]""" # seq_len = x.size(1) # seq_dim = x.size(2) # k_size = kernel_size + (rate - 1) * (kernel_size - 1) # p_right = (k_size - 1) // 2 # p_left = k_size - 1 - p_right # x = F.pad(x, (0, 0, p_left, p_right), mode='constant', value=0) # xs = [x[:, i: i + seq_len] for i in range(0, k_size, rate)] # x = torch.cat(xs, dim=2) # return x.reshape(-1, seq_len, kernel_size, seq_dim) def extract_seq_patches(x, kernel_size, rate): """x.shape = [batch_size, hw, seq_len, seq_dim]""" # batch_size, hw, seq_len, seq_dim = x.size() # Calculate the size of the expanded kernel and the number of padding to be added on both sides. k_size = kernel_size + (rate - 1) * (kernel_size - 1) p_right = (k_size - 1) // 2 p_left = k_size - 1 - p_right # padding x = F.pad(x, (0, 0, p_left, p_right), mode='constant', value=0) # pad only the second dimension # Use the unfold method to extract sliding windows. x_unfold = x.unfold(dimension=2, size=k_size, step=rate) # x, window, k_size, step, rate x_unfold = x_unfold.transpose(-1, -2) # reshape (batch_size, hw, seq_len, kernel_size, seq_dim) x_patches = x_unfold[:, :, :, ::rate] return x_patches def window_attn(x, y, z, kernel_size, mask, rate): """y.shape x.shape = [batch_size, hw, seq_len, self.heads, dim_head]""" batch_size, hw, seq_len, head, head_dim = x.size() device = x.device # Calculate the size of the expanded kernel and the number of padding to be added on both sides. k_size = kernel_size + (rate - 1) * (kernel_size - 1) p_right = (k_size - 1) // 2 p_left = k_size - 1 - p_right # padding y = F.pad(y, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0) # pad only the second dimension z = F.pad(z, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0) attn = torch.zeros(batch_size, hw, seq_len, k_size, head).to(device) for i in range(seq_len): # torch.matmul(x[:,:,i].unsqueeze(2), y[:,:,i:i + k_size].transpose()) # b, hw, 1, d ; b, hw, w, d attn[:,:, i] = torch.einsum('b n h d, b n w h d -> b n w h', x[:,:,i], y[:,:,i:i + k_size]) # reshape (batch_size, hw, seq_len, kernel_size, seq_dim) # res = rearrange(res, 'b n l w h -> b n h l w') attn = to_mask(attn, mask.unsqueeze(0), 'add') attn = attn - attn.amax(dim=-2, keepdim=True).detach() attn = F.softmax(attn, dim=-2) res = torch.zeros(batch_size, hw, seq_len, head, head_dim).to(device) for i in range(seq_len): res[:,:,i] = torch.einsum('b n w h, b n w h d -> b n h d', attn[:,:,i], z[:,:,i : i +k_size]) # attn[:,:,i] * z[:,:,i : i +k_size] res = res.view(batch_size, hw, seq_len, -1) return res def window_attn_2(x, y, z, kernel_size, mask, rate): # bad optimization """ The optimized window_attn function eliminates two explicit for loops and utilizes tensor operations for parallel computation. 
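Note (editorial, hedged): the sliding windows are built with as_strided views over the padded y / z tensors, so the k_size-wide windows overlap in memory rather than being materialised as copies.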
param: x (Tensor): [batch_size, hw, seq_len, heads, dim_head] y (Tensor): [batch_size, hw, seq_len, heads, dim_head] z (Tensor): [batch_size, hw, seq_len, heads, dim_head] kernel_size (int): window size mask (Tensor) rate (int) return: Tensor: [batch_size, hw, seq_len, heads * dim_head] """ batch_size, hw, seq_len, head, head_dim = x.size() k_size = kernel_size + (rate - 1) * (kernel_size - 1) p_right = (k_size - 1) // 2 p_left = k_size - 1 - p_right y_padded = F.pad(y, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0) # [batch_size, hw, seq_len + p_left + p_right, heads, dim_head] z_padded = F.pad(z, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0) # y_windows z_windows [batch_size, hw, seq_len, k_size, heads, dim_head] y_windows = y_padded.as_strided( size=(batch_size, hw, seq_len, k_size, head, head_dim), stride=( y_padded.stride(0), y_padded.stride(1), y_padded.stride(2), y_padded.stride(2), y_padded.stride(3), y_padded.stride(4) ) ) z_windows = z_padded.as_strided( size=(batch_size, hw, seq_len, k_size, head, head_dim), stride=( z_padded.stride(0), z_padded.stride(1), z_padded.stride(2), z_padded.stride(2), z_padded.stride(3), z_padded.stride(4) ) ) # x: [batch_size, hw, seq_len, heads, dim_head] -> [batch_size, hw, seq_len, 1, heads, dim_head] x_expanded = x #.unsqueeze(3) # [batch_size, hw, seq_len, 1, heads, dim_head] attn_scores = torch.einsum('b n l h d, b n l s h d -> b n l s h', x_expanded, y_windows) attn = to_mask(attn_scores, mask.unsqueeze(0), 'add') attn = attn - attn.amax(dim=-2, keepdim=True).detach() attn = F.softmax(attn, dim=-2) # Softmax on k_size res = (attn.unsqueeze(-1) * z_windows).sum(dim=-3) res = res.view(batch_size, hw, seq_len, head * head_dim) return res def window_attn_stream(x, y, z, kernel_size, mask, rate): # bad optimization """y.shape x.shape = [batch_size, hw, seq_len, self.heads, dim_head]""" batch_size, hw, seq_len, head, head_dim = x.size() # Calculate the size of the expanded kernel and the number of padding to be added on both sides. 
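# A worked example of the arithmetic below (editorial, using the defaults from
# __main__: window_size=20 -> neighbors=20 -> kernel_size=41, rate=1):
#   k_size  = 41 + (1 - 1) * (41 - 1) = 41
#   p_right = (41 - 1) // 2 = 20,  p_left = 40 - 20 = 20
# so each time step attends to a symmetric 41-frame window after padding; a rate > 1
# would dilate the window, skipping (rate - 1) positions between sampled frames.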
k_size = kernel_size + (rate - 1) * (kernel_size - 1) p_right = (k_size - 1) // 2 p_left = k_size - 1 - p_right # padding y = F.pad(y, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0) # pad only the second dimension z = F.pad(z, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0) attn = torch.zeros(batch_size, hw, seq_len, k_size, head, device=x.device) res = torch.zeros(batch_size, hw, seq_len, head, head_dim, device=x.device) streams = [torch.cuda.Stream() for _ in range(seq_len)] def compute_attn(i): with torch.cuda.stream(streams[i]): attn[:, :, i] = torch.einsum('b n h d, b n w h d -> b n w h', x[:, :, i], y[:, :, i:i + k_size]) def compute_res(i): with torch.cuda.stream(streams[i]): res[:, :, i] = torch.einsum('b n w h, b n w h d -> b n h d', attn[:, :, i], z[:, :, i:i + k_size]) for i in range(seq_len): compute_attn(i) for stream in streams: stream.synchronize() attn = to_mask(attn, mask.unsqueeze(0), 'add') attn = attn - attn.amax(dim=-2, keepdim=True).detach() attn = F.softmax(attn, dim=-2) for i in range(seq_len): compute_res(i) for stream in streams: stream.synchronize() res = res.view(batch_size, hw, seq_len, -1) return res def create_sliding_window_mask(x, win_size, rate): # mask (len, len, head) # assert mask.dim() == 3, "The input mask must be of shape (len, len, head)" k_size = win_size + (rate - 1) * (win_size - 1) p_right = (k_size - 1) // 2 p_left = k_size - 1 - p_right # padding x = F.pad(x, (p_left, p_right), mode='constant', value=-1e10) # pad only the second dimension res = [] for i in range(x.shape[1]): res.append(x[:, i , i :i +k_size]) return torch.stack(res, dim = 1) # len k_size, head class OurLayer(nn.Module): def reuse(self, layer, *args, **kwargs): outputs = layer(*args, **kwargs) return outputs def heavy_computation(x, y, attn, k_size, i): attn[:,:, i] = torch.einsum('b n h d, b n w h d -> b n w h', x[:,:,i], y[:,:,i:i + k_size]) def heavy_computation2(res, z, attn, k_size, i): res[:,:,i] = torch.einsum('b n w h, b n w h d -> b n h d', attn[:,:,i], z[:,:,i : i +k_size]) # attn[:,:,i] * z[:,:,i : i +k_size] from functools import partial def window_attn_mp(x, y, z, kernel_size, mask, rate): """y.shape x.shape = [batch_size, hw, seq_len, self.heads, dim_head]""" batch_size, hw, seq_len, head, head_dim = x.size() device = x.device # Calculate the size of the expanded kernel and the number of padding to be added on both sides. 
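# Editorial note (hedged): F.pad with (0, 0, 0, 0, p_left, p_right) pads the
# third-from-last axis, i.e. seq_len, so in the padded tensors the slice
# [i : i + k_size] covers original frames i - p_left .. i + p_right, a window
# roughly centred on frame i.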
k_size = kernel_size + (rate - 1) * (kernel_size - 1) p_right = (k_size - 1) // 2 p_left = k_size - 1 - p_right # padding y = F.pad(y, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0) # pad only the second dimension z = F.pad(z, (0, 0, 0, 0, p_left, p_right), mode='constant', value=0) attn = torch.zeros(batch_size, hw, seq_len, k_size, head).to(device) unary = partial(heavy_computation, x,y,attn, k_size) with concurrent.futures.ProcessPoolExecutor() as executor: executor.map(unary, list(range(seq_len))) # reshape (batch_size, hw, seq_len, kernel_size, seq_dim) # res = rearrange(res, 'b n l w h -> b n h l w') attn = to_mask(attn, mask.unsqueeze(0), 'add') attn = attn - attn.amax(dim=-2, keepdim=True).detach() attn = F.softmax(attn, dim=-2) res = torch.zeros(batch_size, hw, seq_len, head, head_dim).to(device) unary2 = partial(heavy_computation2, res,z,attn, k_size) with concurrent.futures.ProcessPoolExecutor() as executor: executor.map(unary2, list(range(seq_len))) res = res.view(batch_size, hw, seq_len, -1) return res class LocalSelfAttention_opt(OurLayer): def __init__(self, d_model, heads, size_per_head, neighbors=3, rate=1, rotary_emb=None, key_size=None, mask_right=False): super(LocalSelfAttention_opt, self).__init__() self.heads = heads self.size_per_head = size_per_head self.out_dim = heads * size_per_head self.key_size = key_size if key_size else size_per_head self.neighbors = neighbors self.rate = rate self.mask_right = mask_right self.rotary_emb = rotary_emb # self.q_dense = nn.Linear(self.key_size * self.heads, self.key_size * self.heads, bias=False) # self.k_dense = nn.Linear(self.key_size * self.heads, self.key_size * self.heads, bias=False) # self.v_dense = nn.Linear(self.key_size * self.heads, self.key_size * self.heads, bias=False) # self.q_dense.weight.data.fill_(1) # self.k_dense.weight.data.fill_(1) # self.v_dense.weight.data.fill_(1) self.to_qkv = nn.Linear(d_model, self.key_size * self.heads * 3, bias=False) self.to_out = nn.Linear(self.key_size * self.heads, d_model, bias=False) # self.to_qkv.weight.data.fill_(1) # self.to_out.weight.data.fill_(1) def forward(self, inputs, pos_bias, focus_present_mask=None,): # if isinstance(inputs, list): # x, x_mask = inputs # else: # x, x_mask = inputs, None x = inputs x_mask = pos_bias kernel_size = 1 + 2 * self.neighbors # if x_mask is not None: # xp_mask = create_sliding_window_mask(x_mask, kernel_size, self.rate) # b, hw, seq, d_model -> b, hw, seq, win, d_model batch_size, hw, seq_len, seq_dim = x.size() if x_mask is not None: xp_mask = x_mask.unsqueeze(0) # b, hw, seq, win, 1 v_mask = xp_mask else: v_mask = None # k = self.k_dense(x) # v = self.v_dense(x) qw, k, v = self.to_qkv(x).chunk(3, dim=-1) # qw: b, hw, seq_len, d_model qw = qw/ (self.key_size ** 0.5) qw = qw.view(batch_size, hw, seq_len, self.heads, self.key_size) k = k.view(batch_size, hw, seq_len, self.heads, self.key_size) # b, hw, seq_len,h, d_head v = v.view(batch_size, hw, seq_len, self.heads, self.key_size) st = time.time() if exists(self.rotary_emb): qw = self.rotary_emb.rotate_queries_or_keys(qw) k = self.rotary_emb.rotate_queries_or_keys(k) ed = time.time() # print("rope local: ", ed - st) st = time.time() # qw = qw.view(batch_size * hw, seq_len, seq_dim) # b * hw, seq, d_model # k = k.view(batch_size, hw, seq_len, self.key_size * self.heads) res = window_attn(qw, k, v, kernel_size, v_mask.permute(0, 2, 3, 1), rate = 1) ed = time.time() # print("rope local: ", ed - st) return self.to_out(res) class MultiHeadLocalAttention(nn.Module): def __init__(self, 
d_model, num_heads, window_size): super(MultiHeadLocalAttention, self).__init__() self.d_model = d_model self.num_heads = num_heads self.window_size = window_size assert d_model % num_heads == 0 self.depth = d_model // num_heads self.query = nn.Linear(d_model, d_model) self.key = nn.Linear(d_model, d_model) self.value = nn.Linear(d_model, d_model) # self.out = nn.Linear(d_model, d_model) self.query.weight.data.fill_(1) self.key.weight.data.fill_(1) self.value.weight.data.fill_(1) self.query.bias.data.fill_(0) self.key.bias.data.fill_(0) self.value.bias.data.fill_(0) def split_heads(self, x, batch_size): """Split the last dimension into (num_heads, depth).""" x = x.reshape(batch_size, -1, self.num_heads, self.depth) return x.permute(0, 2, 1, 3) # (batch_size, num_heads, seq_len, depth) def forward(self, x): batch_size, seq_len, d_model = x.size() assert d_model == self.d_model Q = self.split_heads(self.query(x), batch_size) K = self.split_heads(self.key(x), batch_size) V = self.split_heads(self.value(x), batch_size) # Create the attention scores attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.depth ** 0.5) # (batch_size, num_heads, seq_len, seq_len) # Create the mask mask = (torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)).abs() mask = (mask > self.window_size).unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) mask = mask.to(x.device) # Apply the mask to the attention scores attn_scores = attn_scores.masked_fill(mask, float('-inf')) # Compute the attention weights attn_weights = F.softmax(attn_scores, dim=-1) # (batch_size, num_heads, seq_len, seq_len) # Compute the output output = torch.matmul(attn_weights, V) # (batch_size, num_heads, seq_len, depth) output = output.permute(0, 2, 1, 3) # (batch_size, seq_len, num_heads, depth) output = output.reshape(batch_size, seq_len, d_model) return output class Attention(nn.Module): def __init__( self, dim, heads=4, dim_head=32, rotary_emb=None ): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.rotary_emb = rotary_emb self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) self.to_out = nn.Linear(hidden_dim, dim, bias=False) self.to_qkv.weight.data.fill_(1) self.to_out.weight.data.fill_(1) def forward( self, x, pos_bias=None, focus_present_mask=None ): n, device = x.shape[-2], x.device qkv = self.to_qkv(x).chunk(3, dim=-1) if exists(focus_present_mask) and focus_present_mask.all(): # if all batch samples are focusing on present # it would be equivalent to passing that token's values through to the output values = qkv[-1] return self.to_out(values) # split out heads q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) # scale q = q * self.scale # rotate positions into queries and keys for time attention st = time.time() if exists(self.rotary_emb): q = self.rotary_emb.rotate_queries_or_keys(q) k = self.rotary_emb.rotate_queries_or_keys(k) ed = time.time() print("rope normal: ", ed - st) # similarity sim = torch.einsum('... h i d, ... h j d -> ... 
h i j', q, k) # relative positional bias if exists(pos_bias): sim = sim + pos_bias if exists(focus_present_mask) and not (~focus_present_mask).all(): attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool) attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) mask = torch.where( rearrange(focus_present_mask, 'b -> b 1 1 1 1'), rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), ) sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) # numerical stability sim = sim - sim.amax(dim=-1, keepdim=True).detach() attn = sim.softmax(dim=-1) # aggregate values out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v) out = rearrange(out, '... h n d -> ... n (h d)') # return self.to_out(out) return out class RelativePositionBias(nn.Module): def __init__( self, heads=8, num_buckets=32, max_distance=128 ): super().__init__() self.num_buckets = num_buckets self.max_distance = max_distance self.relative_attention_bias = nn.Embedding(num_buckets, heads) @staticmethod def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): ret = 0 n = -relative_position num_buckets //= 2 ret += (n < 0).long() * num_buckets n = torch.abs(n) max_exact = num_buckets // 2 is_small = n < max_exact val_if_large = max_exact + ( torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).long() val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) ret += torch.where(is_small, n, val_if_large) return ret def forward(self, n, device): q_pos = torch.arange(n, dtype=torch.long, device=device) k_pos = torch.arange(n, dtype=torch.long, device=device) rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance) mask = - (((rel_pos > 20) + (rel_pos < - 20)) * (1e10)) values = self.relative_attention_bias(rp_bucket) return mask + rearrange(values, 'i j h -> h i j') if __name__ == "__main__": # Example usage: d_model = 256 window_size = 20 seq_len = 200 batch_size = 1 head = 4 res_pos = RelativePositionBias(heads=4, max_distance=32) rope = RotaryEmbedding(min(64, d_model//head), seq_before_head_dim = True) rope2 = RotaryEmbedding(min(64, d_model//head)) model = LocalSelfAttention_opt(d_model, head, d_model//head, window_size, rotary_emb=rope) model_2 = Attention(d_model, head, dim_head= d_model//head, rotary_emb = rope2) rp = res_pos(200, 'cpu') xp_mask = create_sliding_window_mask(rp, 2 * window_size + 1, 1) for i in range(5): x = torch.randn(batch_size, 9, seq_len, d_model) st = time.time() output = model([x, xp_mask]) ed = time.time() print("optimized: ", ed - st) st = time.time() output_2 = model_2(x , pos_bias = rp) ed = time.time() print("origin: ", ed - st) print(((output - output_2)**2).mean()) ================================================ FILE: DM_3/modules/text.py ================================================ # the code from https://github.com/lucidrains/video-diffusion-pytorch import torch from einops import rearrange def exists(val): return val is not None # singleton globals MODEL = None TOKENIZER = None HUBERT_MODEL_DIM = 20*1024 # BERT_MODEL_DIM = 768 def get_tokenizer(): global TOKENIZER if not exists(TOKENIZER): TOKENIZER = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased') return TOKENIZER def get_bert(): global MODEL if not exists(MODEL): MODEL = 
torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-cased') if torch.cuda.is_available(): MODEL = MODEL.cuda() return MODEL # tokenize def tokenize(texts, add_special_tokens=True): if not isinstance(texts, (list, tuple)): texts = [texts] tokenizer = get_tokenizer() encoding = tokenizer.batch_encode_plus( texts, add_special_tokens=add_special_tokens, padding=True, return_tensors='pt' ) token_ids = encoding.input_ids return token_ids # embedding function @torch.no_grad() def bert_embed( token_ids, return_cls_repr=False, eps=1e-8, pad_id=0. ): model = get_bert() mask = token_ids != pad_id if torch.cuda.is_available(): token_ids = token_ids.cuda() mask = mask.cuda() outputs = model( input_ids=token_ids, attention_mask=mask, output_hidden_states=True ) hidden_state = outputs.hidden_states[-1] if return_cls_repr: return hidden_state[:, 0] # return [cls] as representation if not exists(mask): return hidden_state.mean(dim=1) mask = mask[:, 1:] # mean all tokens excluding [cls], accounting for length mask = rearrange(mask, 'b n -> b n 1') numer = (hidden_state[:, 1:] * mask).sum(dim=1) denom = mask.sum(dim=1) masked_mean = numer / (denom + eps) return masked_mean ================================================ FILE: DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_6D.py ================================================ ''' stage 1: using 0th as the reference, short and fixed clip with lip loss, 6D pose, conditioned by cross attention ''' import os import torch import torch.nn as nn import sys sys.path.append('your/path') from LFG.modules.generator import Generator from LFG.modules.bg_motion_predictor import BGMotionPredictor from LFG.modules.region_predictor import RegionPredictor from DM_3.modules.video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi import DynamicNfUnet3D, DynamicNfGaussianDiffusion import yaml from sync_batchnorm import DataParallelWithCallback from filter_fourier import * from DM_v0_loss_exp.modules.util import AntiAliasInterpolation2d from torchvision import models import numpy as np import time from einops import rearrange class Attention(nn.Module): def __init__(self, params): super(Attention, self).__init__() self.fc_query = nn.Linear(params['n'], params['dim_attention'], bias=False) self.fc_attention = nn.Linear(params['dim_attention'], 1) def forward(self, ctx_val, ctx_key, ctx_mask, ht_query): ht_query = self.fc_query(ht_query) attention_score = torch.tanh(ctx_key + ht_query[:, None, None, :]) attention_score = self.fc_attention(attention_score).squeeze(3) attention_score = attention_score - attention_score.max() attention_score = torch.exp(attention_score) * ctx_mask attention_score = attention_score / (attention_score.sum(2).sum(1)[:, None, None] + 1e-10) ct = (ctx_val * attention_score[:, None, :, :]).sum(3).sum(2) return ct, attention_score class Face_loc_Encoder(nn.Module): def __init__(self, dim = 1): super(Face_loc_Encoder, self).__init__() self.conv1 = nn.Conv2d(dim, 8, kernel_size=3, stride=2, padding=1) self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1) def forward(self, x): x = self.conv1(x) x = nn.functional.relu(x) x = self.conv2(x) x = nn.functional.relu(x) return x class Vgg19(torch.nn.Module): """ Vgg19 network for perceptual loss. 
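The five returned activations correspond to relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1 of torchvision's pretrained VGG-19 (feature indices 0-1, 2-6, 7-11, 12-20 and 21-29).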
""" def __init__(self, requires_grad=False): super(Vgg19, self).__init__() vgg_pretrained_features = models.vgg19(pretrained=True).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() for x in range(2): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(2, 7): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(7, 12): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(12, 21): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(21, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), requires_grad=False) self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), requires_grad=False) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, x): x = (x - self.mean) / self.std h_relu1 = self.slice1(x) h_relu2 = self.slice2(h_relu1) h_relu3 = self.slice3(h_relu2) h_relu4 = self.slice4(h_relu3) h_relu5 = self.slice5(h_relu4) out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] return out class ImagePyramide(torch.nn.Module): """ Create image pyramide for computing pyramide perceptual loss. """ def __init__(self, scales, num_channels): super(ImagePyramide, self).__init__() downs = {} for scale in scales: downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale) self.downs = nn.ModuleDict(downs) def forward(self, x): out_dict = {} for scale, down_module in self.downs.items(): out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x) return out_dict class FlowDiffusion(nn.Module): def __init__(self, img_size=32, num_frames=40, sampling_timesteps=250, null_cond_prob=0.1, ddim_sampling_eta=1., dim_mults=(1, 2, 4, 8), is_train=True, use_residual_flow=False, learn_null_cond=False, use_deconv=True, padding_mode="zeros", pretrained_pth="", config_pth=""): super(FlowDiffusion, self).__init__() self.use_residual_flow = use_residual_flow checkpoint = torch.load(pretrained_pth) with open(config_pth) as f: config = yaml.safe_load(f) self.generator = Generator(num_regions=config['model_params']['num_regions'], num_channels=config['model_params']['num_channels'], revert_axis_swap=config['model_params']['revert_axis_swap'], **config['model_params']['generator_params']).cuda() self.generator.load_state_dict(checkpoint['generator']) self.generator.eval() self.set_requires_grad(self.generator, False) self.region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'], num_channels=config['model_params']['num_channels'], estimate_affine=config['model_params']['estimate_affine'], **config['model_params']['region_predictor_params']).cuda() self.region_predictor.load_state_dict(checkpoint['region_predictor']) self.region_predictor.eval() self.set_requires_grad(self.region_predictor, False) self.bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'], **config['model_params']['bg_predictor_params']) self.bg_predictor.load_state_dict(checkpoint['bg_predictor']) self.bg_predictor.eval() self.set_requires_grad(self.bg_predictor, False) self.scales = config['train_params']['scales'] self.pyramid = ImagePyramide(self.scales, self.generator.num_channels) if torch.cuda.is_available(): self.pyramid = 
self.pyramid.cuda() # self.vgg_loss_weights = config['train_params']['loss_weights']['rec_vgg'] # if sum(self.vgg_loss_weights) != 0: # self.vgg = Vgg19() # if torch.cuda.is_available(): # self.vgg = self.vgg.cuda() self.unet = DynamicNfUnet3D(dim=64, cond_dim=1024 + 6 + 2, cond_aud=1024, cond_pose=6, cond_eye=2, num_frames=num_frames, channels=3 + 256 + 16, out_grid_dim=2, out_conf_dim=1, dim_mults=dim_mults, use_hubert_audio_cond=True, learn_null_cond=learn_null_cond, use_final_activation=False, use_deconv=use_deconv, padding_mode=padding_mode) self.diffusion = DynamicNfGaussianDiffusion( denoise_fn = self.unet, num_frames=num_frames, image_size=img_size, sampling_timesteps=sampling_timesteps, timesteps=1000, # number of steps loss_type='l2', # L1 or L2 use_dynamic_thres=True, null_cond_prob=null_cond_prob, ddim_sampling_eta=ddim_sampling_eta ) self.face_loc_emb = Face_loc_Encoder() # training self.is_train = is_train if self.is_train: self.unet.train() self.diffusion.train() def update_num_frames(self, new_num_frames): # to update num_frames of Unet3D and GaussianDiffusion self.unet.update_num_frames(new_num_frames) self.diffusion.update_num_frames(new_num_frames) def generate_bbox_mask(self, bbox, size = 32): # b = bbox.shape[0] b, c, fn = bbox.size() bbox = bbox[:,:,0] # b, c, fn bbox[:, :2] = (bbox[:, :2]/bbox[:, 4].unsqueeze(1)) * size # rescale to 32* 32 for 128, and 64 * 64 for 256 bbox[:,2:4] = (bbox[:, 2:4]/bbox[:, 5].unsqueeze(1) )* size bbox_left_top = bbox[:, :4:2].to(torch.int32) # left up bbox_right_bottom = (bbox[:, 1:4:2] +1).to(torch.int32) # right down # generating 2D index row_indices = torch.arange(size).view(1, size, 1).expand(b, size, size).to(torch.uint8).cuda() col_indices = torch.arange(size).view(1, 1, size).expand(b, size, size).to(torch.uint8).cuda() # set the face bbox as 1, the first channel is y, the second is x mask = (row_indices >= bbox_left_top[:, 1].view(b, 1, 1)) & (row_indices <= bbox_right_bottom[:, 1].view(b, 1, 1)) & \ (col_indices >= bbox_left_top[:, 0].view(b, 1, 1)) & (col_indices <= bbox_right_bottom[:, 0].view(b, 1, 1)) # mask : b,32,32 bbox_mask = mask.unsqueeze(1).float() # b, 1, 32, 32 return bbox_mask def generate_mouth_mask(self, mouth_lmk, origin_size, size = 32): b, fn, pn, c = mouth_lmk.size() # b, fn 12, 2 origin_size = origin_size.unsqueeze(1) mouth_lmk = (mouth_lmk/origin_size) * size ld_coner = mouth_lmk.max(dim=-2)[0] ru_coner = mouth_lmk.min(dim=-2)[0] row_indices = torch.arange(size).view(1, size, 1).expand(b, fn, size, size).to(torch.uint8).cuda() col_indices = torch.arange(size).view(1, 1, size).expand(b, fn, size, size).to(torch.uint8).cuda() mask = (row_indices >= ru_coner[:,:, 1].view(b, fn, 1, 1)) & (row_indices <= ld_coner[:,:, 1].view(b, fn, 1, 1)) & \ (col_indices >= ru_coner[:,:, 0].view(b, fn, 1, 1)) & (col_indices <= ld_coner[:,:, 0].view(b, fn, 1, 1)) bbox_mask = (mask).float() # b, 1, 32, 32 return bbox_mask def forward(self, real_vid, ref_img, ref_text, ref_pose, ref_eye_blink, bbox, mouth_lmk, is_eval=False, ref_id = 0): if True: b,c,f,h,w = real_vid.size() real_vid = rearrange(real_vid, 'b c f h w -> (b f) c h w') bright = 64. / 255 contrast = 0.25 sat = 0.25 hue = 0.04 color_jitters = transforms.ColorJitter(hue = (-hue, hue), \ contrast = (max(0, 1 - contrast), 1 + contrast), saturation = (max(0, 1 - sat), 1 + sat), brightness = (max(0, 1 - bright), 1 + bright)) # mast have shape : [..., 1 or 3, H, W] real_vid = real_vid/255. 
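# Editorial note (hedged): ColorJitter samples one set of brightness / contrast /
# saturation / hue factors per call, so the flattened (b*f, c, h, w) tensor gets the
# same jitter applied to every frame of every clip in the batch, which keeps the
# augmentation temporally consistent within each clip.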
# because the img are floats, so need to scale to 0-1 real_vid = color_jitters(real_vid) # shape need be checked real_vid = rearrange(real_vid, '(b f) c h w -> b c f h w', b = b, f = f) ref_img = real_vid[:,:,ref_id,:,:].clone().detach() b, _, nf, H, W = real_vid.size() ref_pose = ref_pose.squeeze(1).permute(0, 2, 1)[:, :, :-1] ref_eye_blink = ref_eye_blink.squeeze(1).permute(0, 2, 1) init_pose = ref_pose[:, ref_id].unsqueeze(1).repeat(1, nf, 1) # b, fn, 7 init state init_eye = ref_eye_blink[:, ref_id].unsqueeze(1).repeat(1, nf, 1) # b, fn, 2 ref_text = torch.concat([ref_text, (ref_pose-init_pose), (ref_eye_blink-init_eye)], dim=-1) bbox_mask = self.generate_bbox_mask(bbox, size = real_vid.shape[-1]) # b, 1, 32, 32 bbox_mask = self.face_loc_emb(bbox_mask) # conv encoder for face mask mouth_mask = self.generate_mouth_mask(mouth_lmk, bbox[:,None,-2:,0], size = real_vid.shape[-1]//4) real_grid_list = [] real_conf_list = [] real_out_img_list = [] real_warped_img_list = [] output_dict = {} with torch.no_grad(): b,c,f,h,w = real_vid.size() real_vid_tmp = rearrange(real_vid, 'b c f h w -> (b f) c h w') # real_vid.reshape(b * f, c, h, w) ref_img_tmp = ref_img.unsqueeze(1).repeat(1,f,1,1,1).reshape(-1, 3, h, w) source_region_params = self.region_predictor(ref_img_tmp) driving_region_params = self.region_predictor(real_vid_tmp) bg_params = self.bg_predictor(ref_img_tmp, real_vid_tmp) generated = self.generator(ref_img_tmp, source_region_params=source_region_params, driving_region_params=driving_region_params, bg_params=bg_params) output_dict["real_vid_grid"] = rearrange(generated["optical_flow"], '(b f) h w c -> b c f h w', b = b, f = f) output_dict["real_vid_conf"] = rearrange(generated["occlusion_map"], '(b f) c h w -> b c f h w', b = b, f = f) output_dict["real_out_vid"] = rearrange(generated["prediction"], '(b f) c h w -> b c f h w', b = b, f = f) output_dict["real_warped_vid"] = rearrange(generated["deformed"], '(b f) c h w -> b c f h w', b = b, f = f) ref_img_fea = generated["bottle_neck_feat"][::f].clone().detach() #bs, 256, 32, 32 del real_vid_tmp, ref_img_tmp del generated if self.is_train: if self.use_residual_flow: h, w, = H // 4, W // 4 identity_grid = self.get_grid(b, nf, h, w, normalize=True).cuda() output_dict["loss"], output_dict["null_cond_mask"] = self.diffusion( torch.cat((output_dict["real_vid_grid"] - identity_grid, output_dict["real_vid_conf"] * 2 - 1), dim=1), ref_img_fea, bbox_mask, ref_text) else: output_dict["loss"], output_dict["null_cond_mask"] = self.diffusion( torch.cat((output_dict["real_vid_grid"], output_dict["real_vid_conf"] * 2 - 1), dim=1), ref_img_fea, bbox_mask, ref_text) pred = self.diffusion.pred_x0 pred_flow = pred[:, :2, :, :, :] pred_conf = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5 # loss_high_freq = hf_loss(fea = pred_flow, mask = self.gaussian_mask.cuda(), dim = 2) loss_high_freq = nn.MSELoss(reduce = False)(pred_flow, output_dict["real_vid_grid"]) + nn.MSELoss(reduce = False)(pred_conf, output_dict["real_vid_conf"]) output_dict['mouth_loss'] = ((output_dict["loss"] * mouth_mask.unsqueeze(1)).sum())/(mouth_mask.sum()) output_dict["loss"] = output_dict["loss"].mean(1) output_dict["floss"] = loss_high_freq.mean(1) if(is_eval): with torch.no_grad(): fake_out_img_list = [] fake_warped_img_list = [] pred = self.diffusion.pred_x0 # bs, 3, nf, 32, 32 if self.use_residual_flow: output_dict["fake_vid_grid"] = pred[:, :2, :, :, :] + identity_grid else: output_dict["fake_vid_grid"] = pred[:, :2, :, :, :] # optical flow predicted by DM_2 bs, 2, nf, 32, 32 
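# --- Minimal illustrative sketch (standalone; names here are hypothetical) ---
# The denoised latent is a 3-channel map per frame: channels 0-1 hold the
# warping grid (normalised sampling coordinates, roughly in [-1, 1]) and
# channel 2 holds the occlusion/confidence map, stored as conf * 2 - 1 during
# training so that it shares the [-1, 1] range of the flow. Decoding a sample
# therefore splits the channels and maps the last one back to [0, 1].
import torch

def decode_flow_and_conf(pred: torch.Tensor):
    """pred: (b, 3, nf, h, w) sample or pred_x0 from the diffusion model."""
    flow = pred[:, :2]                           # (b, 2, nf, h, w)
    conf = (pred[:, 2].unsqueeze(1) + 1) * 0.5   # (b, 1, nf, h, w), back to [0, 1]
    return flow, conf

# e.g. flow, conf = decode_flow_and_conf(torch.rand(1, 3, 40, 32, 32) * 2 - 1)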
output_dict["fake_vid_conf"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5 # occlusion map predicted by DM_2 bs, 1, nf, 32, 32 for idx in range(nf): fake_grid = output_dict["fake_vid_grid"][:, :, idx, :, :].permute(0, 2, 3, 1) #bs, 32, 32, 2 fake_conf = output_dict["fake_vid_conf"][:, :, idx, :, :] #bs, 1, 32, 32 # predict fake out image and fake warped image generated = self.generator.forward_with_flow(source_image=ref_img, optical_flow=fake_grid, occlusion_map=fake_conf) fake_out_img_list.append(generated["prediction"]) fake_warped_img_list.append(generated["deformed"].detach()) del generated output_dict["fake_out_vid"] = torch.stack(fake_out_img_list, dim=2) output_dict["fake_warped_vid"] = torch.stack(fake_warped_img_list, dim=2).detach() # output_dict["rec_loss"] = nn.L1Loss(reduce=False)(real_vid, output_dict["fake_out_vid"]) # output_dict["rec_warp_loss"] = nn.L1Loss(reduce=False)(real_vid, output_dict["fake_warped_vid"]) # b,c,f,h,w = real_vid.size() # real_vid_tensor = real_vid.permute(0,2,1,3,4).reshape(b*f,c,h,w).detach() # fake_out_vid_tensor = output_dict["fake_out_vid"].permute(0,2,1,3,4).reshape(b*f,c,h,w) # if sum(self.vgg_loss_weights) != 0: # pyramide_real = self.pyramid(real_vid_tensor) # pyramide_generated = self.pyramid(fake_out_vid_tensor) # value_total = 0 # for scale in self.scales: # x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) # y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) # for i, weight in enumerate(self.vgg_loss_weights): # value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() # value_total += self.vgg_loss_weights[i] * value # output_dict['rec_vgg_loss'] = value_total # if __debug__: # end_time = time.time() # end # # print(f'forward for eval part time {end_time- start_time}') # start_time = end_time return output_dict def sample_one_video(self, real_vid, sample_img, sample_audio_hubert, sample_pose, sample_eye, sample_bbox, cond_scale): output_dict = {} sample_img_fea = self.generator.compute_fea(sample_img) # sample_img: bs,3,128,128 sample_img_fea: 1,256,32,32 bbox_mask = self.generate_bbox_mask(sample_bbox, size = sample_img.shape[-1]) bbox_mask = self.face_loc_emb(bbox_mask) # conv encoder for face mask ref_pose = sample_pose.permute(0, 2, 1)[:,:,:-1] ref_eye_blink = sample_eye.permute(0, 2, 1) init_pose = ref_pose[:, 0].unsqueeze(1).repeat(1,ref_pose.shape[1], 1) init_eye = ref_eye_blink[:, 0].unsqueeze(1).repeat(1,ref_eye_blink.shape[1], 1) # b, fn, 2 ref_text = torch.concat([sample_audio_hubert, (ref_pose - init_pose), (ref_eye_blink - init_eye)], dim=-1) bs = sample_img_fea.size(0) # if cond_scale = 1.0, not using unconditional model # pred bs, 3, nf, 32, 32 pred = self.diffusion.sample(sample_img_fea, bbox_mask, cond=ref_text, batch_size=bs, cond_scale=cond_scale) if self.use_residual_flow: b, _, nf, h, w = pred[:, :2, :, :, :].size() identity_grid = self.get_grid(b, nf, h, w, normalize=True).cuda() output_dict["sample_vid_grid"] = pred[:, :2, :, :, :] + identity_grid else: output_dict["sample_vid_grid"] = pred[:, :2, :, :, :] # bs, 2, nf, 32, 32 output_dict["sample_vid_conf"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5 # bs, 1, nf, 32, 32 nf = output_dict["sample_vid_grid"].size(2) with torch.no_grad(): sample_out_img_list = [] sample_warped_img_list = [] for idx in range(nf): sample_grid = output_dict["sample_vid_grid"][:, :, idx, :, :].permute(0, 2, 3, 1) sample_conf = output_dict["sample_vid_conf"][:, :, idx, :, :] # predict fake out image and fake warped image generated = 
self.generator.forward_with_flow(source_image=sample_img, optical_flow=sample_grid, occlusion_map=sample_conf) sample_out_img_list.append(generated["prediction"]) sample_warped_img_list.append(generated["deformed"]) output_dict["sample_out_vid"] = torch.stack(sample_out_img_list, dim=2) output_dict["sample_warped_vid"] = torch.stack(sample_warped_img_list, dim=2) output_dict["rec_loss"] = nn.L1Loss(reduce=False)(real_vid, output_dict["sample_out_vid"]) output_dict["rec_warp_loss"] = nn.L1Loss(reduce=False)(real_vid, output_dict["sample_warped_vid"]) # b,c,f,h,w = real_vid[0].unsqueeze(dim=0).size() # real_vid_tensor = real_vid[0].unsqueeze(dim=0).permute(0,2,1,3,4).reshape(b*f,c,h,w) # fake_out_vid_tensor = output_dict["sample_out_vid"].permute(0,2,1,3,4).reshape(b*f,c,h,w) # pyramide_real = self.pyramid(real_vid_tensor) # pyramide_generated = self.pyramid(fake_out_vid_tensor) # if sum(self.vgg_loss_weights) != 0: # value_total = 0 # for scale in self.scales: # x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) # y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) # for i, weight in enumerate(self.vgg_loss_weights): # value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() # value_total += self.vgg_loss_weights[i] * value # output_dict['rec_vgg_loss'] = value_total return output_dict def get_grid(self, b, nf, H, W, normalize=True): if normalize: h_range = torch.linspace(-1, 1, H) w_range = torch.linspace(-1, 1, W) else: h_range = torch.arange(0, H) w_range = torch.arange(0, W) grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).repeat(b, 1, 1, 1).flip(3).float() # flip h,w to x,y return grid.permute(0, 3, 1, 2).unsqueeze(dim=2).repeat(1, 1, nf, 1, 1) def set_requires_grad(self, nets, requires_grad=False): """Set requies_grad=Fasle for all the networks to avoid unnecessary computations Parameters: nets (network list) -- a list of networks requires_grad (bool) -- whether the networks require gradients or not """ if not isinstance(nets, list): nets = [nets] for net in nets: if net is not None: for param in net.parameters(): param.requires_grad = requires_grad if __name__ == "__main__": os.environ["CUDA_VISIBLE_DEVICES"] = "3" bs = 5 img_size = 128 num_frames = 10 ref_text = ["play basketball"] * bs ref_img = torch.rand((bs, 3, img_size, img_size), dtype=torch.float32).cuda() real_vid = torch.rand((bs, 3, num_frames, img_size, img_size), dtype=torch.float32).cuda() model = FlowDiffusion(num_frames=num_frames, use_residual_flow=False, sampling_timesteps=10, dim_mults=(1, 2, 4, 8, 16)) model.cuda() # embedding ref_text # cond = bert_embed(tokenize(ref_text), return_cls_repr=model.diffusion.text_use_bert_cls).cuda() # to simulate the situation of hubert embedding cond = torch.rand((bs,10,1024), dtype=torch.float32).cuda() model = DataParallelWithCallback(model) output_dict = model.forward(real_vid=real_vid, ref_img=ref_img, ref_text=cond) model.module.sample_one_video(sample_img=ref_img[0].unsqueeze(dim=0), sample_audio_hubert=cond[0].unsqueeze(dim=0), cond_scale=1.0) ================================================ FILE: DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_rand_6D.py ================================================ ''' stage 2: using random reference, long and dynamic clip with lip loss, 6D pose, conditioned by cross attention ''' import os import torch import torch.nn as nn import sys sys.path.append('your/path') from LFG.modules.generator import Generator from LFG.modules.bg_motion_predictor 
import BGMotionPredictor from LFG.modules.region_predictor import RegionPredictor from DM_3.modules.video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi import DynamicNfUnet3D, DynamicNfGaussianDiffusion import yaml from sync_batchnorm import DataParallelWithCallback from filter_fourier import * from torchvision import models import numpy as np import time from einops import rearrange class Attention(nn.Module): def __init__(self, params): super(Attention, self).__init__() self.fc_query = nn.Linear(params['n'], params['dim_attention'], bias=False) self.fc_attention = nn.Linear(params['dim_attention'], 1) def forward(self, ctx_val, ctx_key, ctx_mask, ht_query): ht_query = self.fc_query(ht_query) attention_score = torch.tanh(ctx_key + ht_query[:, None, None, :]) attention_score = self.fc_attention(attention_score).squeeze(3) attention_score = attention_score - attention_score.max() attention_score = torch.exp(attention_score) * ctx_mask attention_score = attention_score / (attention_score.sum(2).sum(1)[:, None, None] + 1e-10) ct = (ctx_val * attention_score[:, None, :, :]).sum(3).sum(2) return ct, attention_score class Face_loc_Encoder(nn.Module): def __init__(self, dim = 1): super(Face_loc_Encoder, self).__init__() self.conv1 = nn.Conv2d(dim, 8, kernel_size=3, stride=2, padding=1) self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1) def forward(self, x): x = self.conv1(x) x = nn.functional.relu(x) x = self.conv2(x) x = nn.functional.relu(x) return x class Vgg19(torch.nn.Module): """ Vgg19 network for perceptual loss. """ def __init__(self, requires_grad=False): super(Vgg19, self).__init__() vgg_pretrained_features = models.vgg19(pretrained=True).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() for x in range(2): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(2, 7): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(7, 12): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(12, 21): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(21, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), requires_grad=False) self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), requires_grad=False) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, x): x = (x - self.mean) / self.std h_relu1 = self.slice1(x) h_relu2 = self.slice2(h_relu1) h_relu3 = self.slice3(h_relu2) h_relu4 = self.slice4(h_relu3) h_relu5 = self.slice5(h_relu4) out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] return out class FlowDiffusion(nn.Module): def __init__(self, img_size=32, num_frames=40, sampling_timesteps=250, null_cond_prob=0.1, ddim_sampling_eta=1., dim_mults=(1, 2, 4, 8), is_train=True, use_residual_flow=False, learn_null_cond=False, use_deconv=True, padding_mode="zeros", pretrained_pth="your_path/data/log-hdtf/hdtf128_2023-11-17_20:13/snapshots/RegionMM.pth", config_pth="your/path/DAWN-pytorch/config/hdtf128.yaml"): super(FlowDiffusion, self).__init__() self.use_residual_flow = use_residual_flow checkpoint = torch.load(pretrained_pth) with open(config_pth) as f: config = yaml.safe_load(f) self.generator = 
Generator(num_regions=config['model_params']['num_regions'], num_channels=config['model_params']['num_channels'], revert_axis_swap=config['model_params']['revert_axis_swap'], **config['model_params']['generator_params']).cuda() self.generator.load_state_dict(checkpoint['generator']) self.generator.eval() self.set_requires_grad(self.generator, False) self.region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'], num_channels=config['model_params']['num_channels'], estimate_affine=config['model_params']['estimate_affine'], **config['model_params']['region_predictor_params']).cuda() self.region_predictor.load_state_dict(checkpoint['region_predictor']) self.region_predictor.eval() self.set_requires_grad(self.region_predictor, False) self.bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'], **config['model_params']['bg_predictor_params']) self.bg_predictor.load_state_dict(checkpoint['bg_predictor']) self.bg_predictor.eval() self.set_requires_grad(self.bg_predictor, False) self.scales = config['train_params']['scales'] self.unet = DynamicNfUnet3D(dim=64, cond_dim=1024 + 6 + 2, cond_aud=1024, cond_pose=6, cond_eye=2, num_frames=num_frames, channels=3 + 256 + 16, out_grid_dim=2, out_conf_dim=1, dim_mults=dim_mults, use_hubert_audio_cond=True, learn_null_cond=learn_null_cond, use_final_activation=False, use_deconv=use_deconv, padding_mode=padding_mode) self.diffusion = DynamicNfGaussianDiffusion( denoise_fn = self.unet, num_frames=num_frames, image_size=img_size, sampling_timesteps=sampling_timesteps, timesteps=1000, # number of steps loss_type='l2', # L1 or L2 use_dynamic_thres=True, null_cond_prob=null_cond_prob, ddim_sampling_eta=ddim_sampling_eta ) self.face_loc_emb = Face_loc_Encoder() # training self.is_train = is_train if self.is_train: self.unet.train() self.diffusion.train() def update_num_frames(self, new_num_frames): # to update num_frames of Unet3D and GaussianDiffusion self.unet.update_num_frames(new_num_frames) self.diffusion.update_num_frames(new_num_frames) def generate_bbox_mask(self, bbox, size = 32): # b = bbox.shape[0] b, c, fn = bbox.size() bbox = bbox[:,:,0] # b, c, fn bbox[:, :2] = (bbox[:, :2]/bbox[:, 4].unsqueeze(1)) * size # rescale 32* 32 bbox[:,2:4] = (bbox[:, 2:4]/bbox[:, 5].unsqueeze(1) )* size bbox_left_top = bbox[:, :4:2].to(torch.int32) bbox_right_bottom = (bbox[:, 1:4:2] +1).to(torch.int32) row_indices = torch.arange(size).view(1, size, 1).expand(b, size, size).to(torch.uint8).cuda() col_indices = torch.arange(size).view(1, 1, size).expand(b, size, size).to(torch.uint8).cuda() mask = (row_indices >= bbox_left_top[:, 1].view(b, 1, 1)) & (row_indices <= bbox_right_bottom[:, 1].view(b, 1, 1)) & \ (col_indices >= bbox_left_top[:, 0].view(b, 1, 1)) & (col_indices <= bbox_right_bottom[:, 0].view(b, 1, 1)) bbox_mask = mask.unsqueeze(1).float() # b, 1, 32, 32 return bbox_mask def generate_mouth_mask(self, mouth_lmk, origin_size, size = 32): b, fn, pn, c = mouth_lmk.size() # b, fn 12, 2 origin_size = origin_size.unsqueeze(1) mouth_lmk = (mouth_lmk/origin_size) * size ld_coner = mouth_lmk.max(dim=-2)[0] ru_coner = mouth_lmk.min(dim=-2)[0] row_indices = torch.arange(size).view(1, size, 1).expand(b, fn, size, size).to(torch.uint8).cuda() col_indices = torch.arange(size).view(1, 1, size).expand(b, fn, size, size).to(torch.uint8).cuda() mask = (row_indices >= ru_coner[:,:, 1].view(b, fn, 1, 1)) & (row_indices <= ld_coner[:,:, 1].view(b, fn, 1, 1)) & \ (col_indices >= ru_coner[:,:, 0].view(b, fn, 1, 1)) & (col_indices 
<= ld_coner[:,:, 0].view(b, fn, 1, 1)) bbox_mask = (mask).float() # b, 1, 32, 32 return bbox_mask def forward(self, real_vid, ref_img, ref_text, ref_pose, ref_eye_blink, bbox, mouth_lmk, is_eval=False): if True: b,c,f,h,w = real_vid.size() real_vid = rearrange(real_vid, 'b c f h w -> (b f) c h w') bright = 64. / 255 contrast = 0.25 sat = 0.25 hue = 0.04 # bright_f = random.uniform(max(0, 1 - bright), 1 + bright) # contrast_f = random.uniform(max(0, 1 - contrast), 1 + contrast) # sat_f = random.uniform(max(0, 1 - sat), 1 + sat) # hue_f = random.uniform(-hue, hue) color_jitters = transforms.ColorJitter(hue = (-hue, hue), \ contrast = (max(0, 1 - contrast), 1 + contrast), saturation = (max(0, 1 - sat), 1 + sat), brightness = (max(0, 1 - bright), 1 + bright)) # mast have shape : [..., 1 or 3, H, W] real_vid = real_vid/255. # because the img are floats, so need to scale to 0-1 real_vid = color_jitters(real_vid) # shape need be checked real_vid = rearrange(real_vid, '(b f) c h w -> b c f h w', b = b, f = f) ref_img = real_vid[:,:,0,:,:].clone().detach() real_vid = real_vid[:,:,1:,:,:] # if __debug__: # end_time = time.time() # end # print(f'data augment time 1 {end_time- start_time}') # start_time = end_time # sample_frame_list = color_jitters(sample_frame_list) # if __debug__: # end_time = time.time() # end # print(f'data augment time 2 {end_time- start_time}') # start_time = end_time # else: # real_vid = rearrange(real_vid, 'b f c h w -> b c f h w') # else: # real_vid = real_vid/255. # ref_img = ref_img/255. b, _, _, H, W = real_vid.size() _, nf, _ = ref_text.size() ref_pose = ref_pose.squeeze(1).permute(0, 2, 1)[:, :, :-1] ref_eye_blink = ref_eye_blink.squeeze(1).permute(0, 2, 1) init_pose = ref_pose[:, 0].unsqueeze(1).repeat(1, nf, 1) # b, fn, 7 init state init_eye = ref_eye_blink[:, 0].unsqueeze(1).repeat(1, nf, 1) # b, fn, 2 ref_text = torch.concat([ref_text, (ref_pose-init_pose), (ref_eye_blink-init_eye)], dim=-1) ref_text = ref_text[:, 1:] bbox_mask = self.generate_bbox_mask(bbox, size = real_vid.shape[-1]) # b, 1, 32, 32 bbox_mask = self.face_loc_emb(bbox_mask) # conv encoder for face mask mouth_mask = self.generate_mouth_mask(mouth_lmk, bbox[:,None,-2:,0], size = real_vid.shape[-1]//4) real_grid_list = [] real_conf_list = [] real_out_img_list = [] real_warped_img_list = [] output_dict = {} # if __debug__: # end_time = time.time() # end # # print(f'forward process time {end_time- start_time}') # start_time = end_time with torch.no_grad(): # for idx in range(nf): # driving_region_params = self.region_predictor(real_vid[:, :, idx, :, :]) # bg_params = self.bg_predictor(ref_img, real_vid[:, :, idx, :, :]) # generated = self.generator(ref_img, source_region_params=source_region_params, # driving_region_params=driving_region_params, bg_params=bg_params) # generated.update({'source_region_params': source_region_params, # 'driving_region_params': driving_region_params}) # real_grid_list.append(generated["optical_flow"].permute(0, 3, 1, 2)) # # normalized occlusion map # real_conf_list.append(generated["occlusion_map"]) # real_out_img_list.append(generated["prediction"]) # real_warped_img_list.append(generated["deformed"]) b,c,f,h,w = real_vid.size() real_vid_tmp = rearrange(real_vid, 'b c f h w -> (b f) c h w')# real_vid.reshape(b * f, c, h, w) ref_img_tmp = ref_img.unsqueeze(1).repeat(1,f,1,1,1).reshape(-1, 3, h, w) source_region_params = self.region_predictor(ref_img_tmp) driving_region_params = self.region_predictor(real_vid_tmp) bg_params = self.bg_predictor(ref_img_tmp, real_vid_tmp) 
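# --- Minimal illustrative sketch (standalone; names here are hypothetical) ---
# The block above batches every frame through the frozen per-image networks in
# a single call instead of looping over frames (the commented-out loop shows
# the older per-frame version): the clip is flattened to a ((b*f), c, h, w)
# batch, the reference image is repeated once per frame, and the outputs are
# folded back to (b, c, f, h, w) with einops.
import torch
import torch.nn as nn
from einops import rearrange

def run_per_frame(net: nn.Module, video: torch.Tensor, ref_img: torch.Tensor) -> torch.Tensor:
    """video: (b, c, f, h, w); ref_img: (b, c, h, w); net: any image-to-image module."""
    b, c, f, h, w = video.shape
    frames = rearrange(video, 'b c f h w -> (b f) c h w')
    refs = ref_img.unsqueeze(1).repeat(1, f, 1, 1, 1).reshape(-1, c, h, w)
    out = net(torch.cat([refs, frames], dim=1))   # concat reference and driving frame
    return rearrange(out, '(b f) c h w -> b c f h w', b=b, f=f)

# e.g. with a toy 6-channel-in conv:
# run_per_frame(nn.Conv2d(6, 3, 3, padding=1), torch.rand(2, 3, 8, 64, 64),
#               torch.rand(2, 3, 64, 64)).shape -> (2, 3, 8, 64, 64)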
generated = self.generator(ref_img_tmp, source_region_params=source_region_params, driving_region_params=driving_region_params, bg_params=bg_params) output_dict["real_vid_grid"] = rearrange(generated["optical_flow"], '(b f) h w c -> b c f h w', b = b, f = f) # .permute(0,3,1,2).reshape(b, 2, f, 32, 32) output_dict["real_vid_conf"] = rearrange(generated["occlusion_map"], '(b f) c h w -> b c f h w', b = b, f = f) # generated["occlusion_map"].reshape(b, 1, f, 32, 32) output_dict["real_out_vid"] = rearrange(generated["prediction"], '(b f) c h w -> b c f h w', b = b, f = f) # generated["prediction"].reshape(b, 3, f, h, w) output_dict["real_warped_vid"] = rearrange(generated["deformed"], '(b f) c h w -> b c f h w', b = b, f = f) # generated["deformed"].reshape(b, 3, f, h, w) # output_dict["real_vid_grid"] = torch.stack(real_grid_list, dim=2) # bs,2,num_frames,32,32 # output_dict["real_vid_conf"] = torch.stack(real_conf_list, dim=2) # bs,1,num_frames,32,32 # output_dict["real_out_vid"] = torch.stack(real_out_img_list, dim=2) # bs,3,num_frames,128,128 # output_dict["real_warped_vid"] = torch.stack(real_warped_img_list, dim=2) # bs,3,num_frames,128,128 # reference images are the same for different time steps, just pick the final one # ref_img_fea = generated["bottle_neck_feat"].clone().detach() #bs, 256, 32, 32 ref_img_fea = generated["bottle_neck_feat"][::f].clone().detach() #bs, 256, 32, 32 del real_vid_tmp, ref_img_tmp del generated # if __debug__: # end_time = time.time() # end # # print(f'generate gt flow time {end_time- start_time}') # start_time = end_time if self.is_train: if self.use_residual_flow: h, w, = H // 4, W // 4 identity_grid = self.get_grid(b, nf, h, w, normalize=True).cuda() output_dict["loss"], output_dict["null_cond_mask"] = self.diffusion( torch.cat((output_dict["real_vid_grid"] - identity_grid, output_dict["real_vid_conf"] * 2 - 1), dim=1), ref_img_fea, bbox_mask, ref_text) else: output_dict["loss"], output_dict["null_cond_mask"] = self.diffusion( torch.cat((output_dict["real_vid_grid"], output_dict["real_vid_conf"] * 2 - 1), dim=1), ref_img_fea, bbox_mask, ref_text) pred = self.diffusion.pred_x0 pred_flow = pred[:, :2, :, :, :] pred_conf = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5 # loss_high_freq = hf_loss(fea = pred_flow, mask = self.gaussian_mask.cuda(), dim = 2) # loss_high_freq = hf_loss_2(pred_flow, output_dict["real_vid_grid"], dim=2) loss_high_freq = nn.MSELoss(reduce = False)(pred_flow, output_dict["real_vid_grid"]) + nn.MSELoss(reduce = False)(pred_conf, output_dict["real_vid_conf"]) output_dict['mouth_loss'] = ((output_dict["loss"] * mouth_mask.unsqueeze(1)).sum())/(mouth_mask.sum()) output_dict["loss"] = output_dict["loss"].mean(1) output_dict["floss"] = loss_high_freq.mean(1) # if __debug__: # end_time = time.time() # end # # print(f'forward diffusion time {end_time- start_time}') # start_time = end_time if(is_eval): with torch.no_grad(): fake_out_img_list = [] fake_warped_img_list = [] pred = self.diffusion.pred_x0 # bs, 3, nf, 32, 32 if self.use_residual_flow: output_dict["fake_vid_grid"] = pred[:, :2, :, :, :] + identity_grid else: output_dict["fake_vid_grid"] = pred[:, :2, :, :, :] # optical flow predicted by DM_2 bs, 2, nf, 32, 32 output_dict["fake_vid_conf"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5 # occlusion map predicted by DM_2 bs, 1, nf, 32, 32 for idx in range(nf - 1): fake_grid = output_dict["fake_vid_grid"][:, :, idx, :, :].permute(0, 2, 3, 1) #bs, 32, 32, 2 fake_conf = output_dict["fake_vid_conf"][:, :, idx, :, :] #bs, 1, 32, 32 
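# --- Minimal illustrative sketch (standalone; names here are hypothetical) ---
# 'mouth_loss' above is a masked mean: the per-pixel diffusion loss is weighted
# by the binary mouth-region mask and normalised by the number of masked
# pixels, so the lip area contributes a dedicated term regardless of how small
# the mouth box is relative to the frame.
import torch

def masked_mean(per_pixel_loss: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """per_pixel_loss: (b, c, f, h, w); mask: (b, f, h, w) with 1 inside the mouth box."""
    weighted = per_pixel_loss * mask.unsqueeze(1)        # broadcast over channels
    return weighted.sum() / (mask.sum() + eps)

# e.g. masked_mean(torch.rand(2, 3, 8, 32, 32), torch.ones(2, 8, 32, 32))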
# predict fake out image and fake warped image generated = self.generator.forward_with_flow(source_image=ref_img, optical_flow=fake_grid, occlusion_map=fake_conf) fake_out_img_list.append(generated["prediction"]) fake_warped_img_list.append(generated["deformed"].detach()) del generated output_dict["fake_out_vid"] = torch.stack(fake_out_img_list, dim=2) output_dict["fake_warped_vid"] = torch.stack(fake_warped_img_list, dim=2).detach() # output_dict["rec_loss"] = nn.L1Loss(reduce=False)(real_vid, output_dict["fake_out_vid"]) # output_dict["rec_warp_loss"] = nn.L1Loss(reduce=False)(real_vid, output_dict["fake_warped_vid"]) # b,c,f,h,w = real_vid.size() # if __debug__: # end_time = time.time() # end # # print(f'forward for eval part time {end_time- start_time}') # start_time = end_time return output_dict def sample_one_video(self, real_vid, sample_img, sample_audio_hubert, sample_pose, sample_eye, sample_bbox, cond_scale): output_dict = {} sample_img_fea = self.generator.compute_fea(sample_img) # sample_img: bs,3,128,128 sample_img_fea: 1,256,32,32 bbox_mask = self.generate_bbox_mask(sample_bbox, size = sample_img.shape[-1]) bbox_mask = self.face_loc_emb(bbox_mask) # conv encoder for face mask ref_pose = sample_pose.permute(0, 2, 1)[:,:,:-1] ref_eye_blink = sample_eye.permute(0, 2, 1) init_pose = ref_pose[:, 0].unsqueeze(1).repeat(1,ref_pose.shape[1], 1) init_eye = ref_eye_blink[:, 0].unsqueeze(1).repeat(1,ref_eye_blink.shape[1], 1) # b, fn, 2 ref_text = torch.concat([sample_audio_hubert, (ref_pose - init_pose), (ref_eye_blink - init_eye)], dim=-1) ref_text = ref_text[:, 1:] bs = sample_img_fea.size(0) # if cond_scale = 1.0, not using unconditional model # pred bs, 3, nf, 32, 32 pred = self.diffusion.sample(sample_img_fea, bbox_mask, cond=ref_text, batch_size=bs, cond_scale=cond_scale) if self.use_residual_flow: b, _, nf, h, w = pred[:, :2, :, :, :].size() identity_grid = self.get_grid(b, nf, h, w, normalize=True).cuda() output_dict["sample_vid_grid"] = pred[:, :2, :, :, :] + identity_grid else: output_dict["sample_vid_grid"] = pred[:, :2, :, :, :] # bs, 2, nf, 32, 32 output_dict["sample_vid_conf"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5 # bs, 1, nf, 32, 32 nf = output_dict["sample_vid_grid"].size(2) with torch.no_grad(): sample_out_img_list = [] sample_warped_img_list = [] for idx in range(nf): sample_grid = output_dict["sample_vid_grid"][:, :, idx, :, :].permute(0, 2, 3, 1) sample_conf = output_dict["sample_vid_conf"][:, :, idx, :, :] # predict fake out image and fake warped image generated = self.generator.forward_with_flow(source_image=sample_img, optical_flow=sample_grid, occlusion_map=sample_conf) sample_out_img_list.append(generated["prediction"]) sample_warped_img_list.append(generated["deformed"]) output_dict["sample_out_vid"] = torch.stack(sample_out_img_list, dim=2) output_dict["sample_warped_vid"] = torch.stack(sample_warped_img_list, dim=2) output_dict["rec_loss"] = nn.L1Loss(reduce=False)(real_vid, output_dict["sample_out_vid"]) output_dict["rec_warp_loss"] = nn.L1Loss(reduce=False)(real_vid, output_dict["sample_warped_vid"]) return output_dict def get_grid(self, b, nf, H, W, normalize=True): if normalize: h_range = torch.linspace(-1, 1, H) w_range = torch.linspace(-1, 1, W) else: h_range = torch.arange(0, H) w_range = torch.arange(0, W) grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).repeat(b, 1, 1, 1).flip(3).float() # flip h,w to x,y return grid.permute(0, 3, 1, 2).unsqueeze(dim=2).repeat(1, 1, nf, 1, 1) def set_requires_grad(self, nets, 
requires_grad=False): """Set requies_grad=Fasle for all the networks to avoid unnecessary computations Parameters: nets (network list) -- a list of networks requires_grad (bool) -- whether the networks require gradients or not """ if not isinstance(nets, list): nets = [nets] for net in nets: if net is not None: for param in net.parameters(): param.requires_grad = requires_grad if __name__ == "__main__": os.environ["CUDA_VISIBLE_DEVICES"] = "3" bs = 5 img_size = 128 num_frames = 10 ref_text = ["play basketball"] * bs ref_img = torch.rand((bs, 3, img_size, img_size), dtype=torch.float32).cuda() real_vid = torch.rand((bs, 3, num_frames, img_size, img_size), dtype=torch.float32).cuda() model = FlowDiffusion(num_frames=num_frames, use_residual_flow=False, sampling_timesteps=10, dim_mults=(1, 2, 4, 8, 16)) model.cuda() # embedding ref_text # cond = bert_embed(tokenize(ref_text), return_cls_repr=model.diffusion.text_use_bert_cls).cuda() # to simulate the situation of hubert embedding cond = torch.rand((bs,10,1024), dtype=torch.float32).cuda() model = DataParallelWithCallback(model) output_dict = model.forward(real_vid=real_vid, ref_img=ref_img, ref_text=cond) model.module.sample_one_video(sample_img=ref_img[0].unsqueeze(dim=0), sample_audio_hubert=cond[0].unsqueeze(dim=0), cond_scale=1.0) ================================================ FILE: DM_3/modules/video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_test.py ================================================ import os import torch import torch.nn as nn import sys from LFG.modules.generator import Generator from LFG.modules.bg_motion_predictor import BGMotionPredictor from LFG.modules.region_predictor import RegionPredictor from DM_3.modules.video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test import DynamicNfUnet3D, DynamicNfGaussianDiffusion import yaml from sync_batchnorm import DataParallelWithCallback from filter_fourier import * from torchvision import models import numpy as np import time from einops import rearrange class Attention(nn.Module): def __init__(self, params): super(Attention, self).__init__() self.fc_query = nn.Linear(params['n'], params['dim_attention'], bias=False) self.fc_attention = nn.Linear(params['dim_attention'], 1) def forward(self, ctx_val, ctx_key, ctx_mask, ht_query): ht_query = self.fc_query(ht_query) attention_score = torch.tanh(ctx_key + ht_query[:, None, None, :]) attention_score = self.fc_attention(attention_score).squeeze(3) attention_score = attention_score - attention_score.max() attention_score = torch.exp(attention_score) * ctx_mask attention_score = attention_score / (attention_score.sum(2).sum(1)[:, None, None] + 1e-10) ct = (ctx_val * attention_score[:, None, :, :]).sum(3).sum(2) return ct, attention_score class Face_loc_Encoder(nn.Module): def __init__(self, dim = 1): super(Face_loc_Encoder, self).__init__() self.conv1 = nn.Conv2d(dim, 8, kernel_size=3, stride=2, padding=1) self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1) def forward(self, x): x = self.conv1(x) x = nn.functional.relu(x) x = self.conv2(x) x = nn.functional.relu(x) return x class Vgg19(torch.nn.Module): """ Vgg19 network for perceptual loss. 
""" def __init__(self, requires_grad=False): super(Vgg19, self).__init__() vgg_pretrained_features = models.vgg19(pretrained=True).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() for x in range(2): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(2, 7): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(7, 12): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(12, 21): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(21, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), requires_grad=False) self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), requires_grad=False) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, x): x = (x - self.mean) / self.std h_relu1 = self.slice1(x) h_relu2 = self.slice2(h_relu1) h_relu3 = self.slice3(h_relu2) h_relu4 = self.slice4(h_relu3) h_relu5 = self.slice5(h_relu4) out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] return out class FlowDiffusion(nn.Module): def __init__(self, img_size=32, num_frames=40, sampling_timesteps=250, win_width = 40, null_cond_prob=0.1, ddim_sampling_eta=1., pose_dim = 7, dim_mults=(1, 2, 4, 8), is_train=True, use_residual_flow=False, learn_null_cond=False, use_deconv=True, padding_mode="zeros", pretrained_pth="your_path/data/log-hdtf/hdtf128_2023-11-17_20:13/snapshots/RegionMM.pth", config_pth="your/path/DAWN-pytorch/config/hdtf128.yaml"): super(FlowDiffusion, self).__init__() self.use_residual_flow = use_residual_flow checkpoint = torch.load(pretrained_pth) with open(config_pth) as f: config = yaml.safe_load(f) self.generator = Generator(num_regions=config['model_params']['num_regions'], num_channels=config['model_params']['num_channels'], revert_axis_swap=config['model_params']['revert_axis_swap'], **config['model_params']['generator_params']).cuda() self.generator.load_state_dict(checkpoint['generator']) self.generator.eval() self.set_requires_grad(self.generator, False) self.region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'], num_channels=config['model_params']['num_channels'], estimate_affine=config['model_params']['estimate_affine'], **config['model_params']['region_predictor_params']).cuda() self.region_predictor.load_state_dict(checkpoint['region_predictor']) self.region_predictor.eval() self.set_requires_grad(self.region_predictor, False) self.bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'], **config['model_params']['bg_predictor_params']) self.bg_predictor.load_state_dict(checkpoint['bg_predictor']) self.bg_predictor.eval() self.set_requires_grad(self.bg_predictor, False) self.scales = config['train_params']['scales'] self.pose_dim = pose_dim self.unet = DynamicNfUnet3D(dim=64, cond_dim=1024 + self.pose_dim + 2, cond_aud=1024, cond_pose=self.pose_dim, cond_eye=2, num_frames=num_frames, channels=3 + 256 + 16, out_grid_dim=2, out_conf_dim=1, dim_mults=dim_mults, use_hubert_audio_cond=True, learn_null_cond=learn_null_cond, use_final_activation=False, use_deconv=use_deconv, padding_mode=padding_mode, win_width = win_width) self.diffusion = DynamicNfGaussianDiffusion( denoise_fn = self.unet, 
num_frames=num_frames, image_size=img_size, sampling_timesteps=sampling_timesteps, timesteps=1000, # number of steps loss_type='l2', # L1 or L2 use_dynamic_thres=True, null_cond_prob=null_cond_prob, ddim_sampling_eta=ddim_sampling_eta ) self.face_loc_emb = Face_loc_Encoder() # training self.is_train = is_train if self.is_train: self.unet.train() self.diffusion.train() def update_num_frames(self, new_num_frames): # to update num_frames of Unet3D and GaussianDiffusion self.unet.update_num_frames(new_num_frames) self.diffusion.update_num_frames(new_num_frames) def generate_bbox_mask(self, bbox, size = 32): # b = bbox.shape[0] b, c, fn = bbox.size() bbox = bbox[:,:,0] # b, c, fn bbox[:, :2] = (bbox[:, :2]/bbox[:, 4].unsqueeze(1)) * size bbox[:,2:4] = (bbox[:, 2:4]/bbox[:, 5].unsqueeze(1) )* size bbox_left_top = bbox[:, :4:2].to(torch.int32) bbox_right_bottom = (bbox[:, 1:4:2] +1).to(torch.int32) row_indices = torch.arange(size).view(1, size, 1).expand(b, size, size).to(torch.uint8).cuda() col_indices = torch.arange(size).view(1, 1, size).expand(b, size, size).to(torch.uint8).cuda() mask = (row_indices >= bbox_left_top[:, 1].view(b, 1, 1)) & (row_indices <= bbox_right_bottom[:, 1].view(b, 1, 1)) & \ (col_indices >= bbox_left_top[:, 0].view(b, 1, 1)) & (col_indices <= bbox_right_bottom[:, 0].view(b, 1, 1)) bbox_mask = mask.unsqueeze(1).float() # b, 1, 32, 32 return bbox_mask def forward(self, real_vid, ref_img, ref_text, ref_pose, ref_eye_blink, bbox, is_eval=False, ref_id = 0): if True: b,c,f,h,w = real_vid.size() real_vid = rearrange(real_vid, 'b c f h w -> (b f) c h w') bright = 64. / 255 contrast = 0.25 sat = 0.25 hue = 0.04 color_jitters = transforms.ColorJitter(hue = (-hue, hue), \ contrast = (max(0, 1 - contrast), 1 + contrast), saturation = (max(0, 1 - sat), 1 + sat), brightness = (max(0, 1 - bright), 1 + bright)) # mast have shape : [..., 1 or 3, H, W] real_vid = real_vid/255. 
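# --- Minimal illustrative sketch (standalone; names here are hypothetical) ---
# generate_bbox_mask above rasterises a face bounding box into a binary map on
# the latent grid: the corner coordinates are rescaled from the original frame
# size to the target resolution and compared against row/column index grids.
# The channel layout assumed here ([x0, x1, y0, y1, frame_w, frame_h]) is an
# inference from the slicing in that method.
import torch

def bbox_to_mask(x0: float, y0: float, x1: float, y1: float,
                 frame_w: int, frame_h: int, size: int = 32) -> torch.Tensor:
    """Return a (1, 1, size, size) mask with 1 inside the rescaled box."""
    left, right = int(x0 / frame_w * size), int(x1 / frame_w * size) + 1
    top, bottom = int(y0 / frame_h * size), int(y1 / frame_h * size) + 1
    rows = torch.arange(size).view(1, size, 1)   # y index of every cell
    cols = torch.arange(size).view(1, 1, size)   # x index of every cell
    mask = (rows >= top) & (rows <= bottom) & (cols >= left) & (cols <= right)
    return mask.unsqueeze(1).float()

# e.g. bbox_to_mask(30, 40, 90, 110, frame_w=128, frame_h=128, size=32).sum()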
# because the img are floats, so need to scale to 0-1 real_vid = color_jitters(real_vid) # shape need be checked real_vid = rearrange(real_vid, '(b f) c h w -> b c f h w', b = b, f = f) ref_img = real_vid[:,:,ref_id,:,:].clone().detach() b, _, nf, H, W = real_vid.size() ref_pose = ref_pose.squeeze(1).permute(0, 2, 1) ref_eye_blink = ref_eye_blink.squeeze(1).permute(0, 2, 1) init_pose = ref_pose[:, ref_id].unsqueeze(1).repeat(1, nf, 1) # b, fn, 7 init state init_eye = ref_eye_blink[:, ref_id].unsqueeze(1).repeat(1, nf, 1) # b, fn, 2 ref_text = torch.concat([ref_text, (ref_pose-init_pose), (ref_eye_blink-init_eye)], dim=-1) bbox_mask = self.generate_bbox_mask(bbox, size = H) # b, 1, 32, 32 bbox_mask = self.face_loc_emb(bbox_mask) # conv encoder for face mask real_grid_list = [] real_conf_list = [] real_out_img_list = [] real_warped_img_list = [] output_dict = {} with torch.no_grad(): b,c,f,h,w = real_vid.size() real_vid_tmp = rearrange(real_vid, 'b c f h w -> (b f) c h w')# real_vid.reshape(b * f, c, h, w) ref_img_tmp = ref_img.unsqueeze(1).repeat(1,f,1,1,1).reshape(-1, 3, 128, 128) source_region_params = self.region_predictor(ref_img_tmp) driving_region_params = self.region_predictor(real_vid_tmp) bg_params = self.bg_predictor(ref_img_tmp, real_vid_tmp) generated = self.generator(ref_img_tmp, source_region_params=source_region_params, driving_region_params=driving_region_params, bg_params=bg_params) output_dict["real_vid_grid"] = rearrange(generated["optical_flow"], '(b f) h w c -> b c f h w', b = b, f = f) # .permute(0,3,1,2).reshape(b, 2, f, 32, 32) output_dict["real_vid_conf"] = rearrange(generated["occlusion_map"], '(b f) c h w -> b c f h w', b = b, f = f) # generated["occlusion_map"].reshape(b, 1, f, 32, 32) output_dict["real_out_vid"] = rearrange(generated["prediction"], '(b f) c h w -> b c f h w', b = b, f = f) # generated["prediction"].reshape(b, 3, f, h, w) output_dict["real_warped_vid"] = rearrange(generated["deformed"], '(b f) c h w -> b c f h w', b = b, f = f) # generated["deformed"].reshape(b, 3, f, h, w) ref_img_fea = generated["bottle_neck_feat"][::f].clone().detach() #bs, 256, 32, 32 del real_vid_tmp, ref_img_tmp del generated if self.is_train: if self.use_residual_flow: h, w, = H // 4, W // 4 identity_grid = self.get_grid(b, nf, h, w, normalize=True).cuda() output_dict["loss"], output_dict["null_cond_mask"] = self.diffusion( torch.cat((output_dict["real_vid_grid"] - identity_grid, output_dict["real_vid_conf"] * 2 - 1), dim=1), ref_img_fea, bbox_mask, ref_text) else: output_dict["loss"], output_dict["null_cond_mask"] = self.diffusion( torch.cat((output_dict["real_vid_grid"], output_dict["real_vid_conf"] * 2 - 1), dim=1), ref_img_fea, bbox_mask, ref_text) pred = self.diffusion.pred_x0 pred_flow = pred[:, :2, :, :, :] # loss_high_freq = hf_loss(fea = pred_flow, mask = self.gaussian_mask.cuda(), dim = 2) loss_high_freq = hf_loss_2(pred_flow, output_dict["real_vid_grid"], dim=2) output_dict["loss"] = output_dict["loss"].mean(1) output_dict["floss"] = loss_high_freq.mean(1) # if __debug__: # end_time = time.time() # end # # print(f'forward diffusion time {end_time- start_time}') # start_time = end_time if(is_eval): with torch.no_grad(): fake_out_img_list = [] fake_warped_img_list = [] pred = self.diffusion.pred_x0 # bs, 3, nf, 32, 32 if self.use_residual_flow: output_dict["fake_vid_grid"] = pred[:, :2, :, :, :] + identity_grid else: output_dict["fake_vid_grid"] = pred[:, :2, :, :, :] # optical flow predicted by DM_2 bs, 2, nf, 32, 32 output_dict["fake_vid_conf"] = (pred[:, 2, :, 
:, :].unsqueeze(dim=1) + 1) * 0.5 # occlusion map predicted by DM_2 bs, 1, nf, 32, 32 for idx in range(nf): fake_grid = output_dict["fake_vid_grid"][:, :, idx, :, :].permute(0, 2, 3, 1) #bs, 32, 32, 2 fake_conf = output_dict["fake_vid_conf"][:, :, idx, :, :] #bs, 1, 32, 32 # predict fake out image and fake warped image generated = self.generator.forward_with_flow(source_image=ref_img, optical_flow=fake_grid, occlusion_map=fake_conf) fake_out_img_list.append(generated["prediction"]) fake_warped_img_list.append(generated["deformed"].detach()) del generated output_dict["fake_out_vid"] = torch.stack(fake_out_img_list, dim=2) output_dict["fake_warped_vid"] = torch.stack(fake_warped_img_list, dim=2).detach() return output_dict def sample_one_video(self, sample_img, sample_audio_hubert, sample_pose, sample_eye, sample_bbox, cond_scale, init_pose = None, init_eye = None, real_vid = None): output_dict = {} sample_img_fea = self.generator.compute_fea(sample_img) # sample_img: bs,3,128,128 sample_img_fea: 1,256,32,32 bbox_mask = self.generate_bbox_mask(sample_bbox, size = sample_img.shape[-1]) bbox_mask = self.face_loc_emb(bbox_mask) # conv encoder for face mask sample_pose = sample_pose[:,:self.pose_dim] ref_pose = sample_pose.permute(0, 2, 1) ref_eye_blink = sample_eye.permute(0, 2, 1) if init_pose == None: init_pose = ref_pose[:, 0].unsqueeze(1).repeat(1,ref_pose.shape[1], 1) else: init_pose = init_pose.unsqueeze(1).repeat(1,ref_pose.shape[1], 1) init_pose = init_pose[:,:,:self.pose_dim] if init_eye == None: init_eye = ref_eye_blink[:, 0].unsqueeze(1).repeat(1,ref_eye_blink.shape[1], 1) else: init_eye = init_eye.unsqueeze(1).repeat(1,ref_eye_blink.shape[1], 1) if ref_pose.shape[-1] != init_pose.shape[-1]: ref_pose = torch.concat([ref_pose, init_pose[:,:,-1].unsqueeze(-1)], dim = -1) ref_text = torch.concat([sample_audio_hubert, (ref_pose - init_pose), (ref_eye_blink - init_eye)], dim=-1) bs = sample_img_fea.size(0) # if cond_scale = 1.0, not using unconditional model # pred bs, 3, nf, 32, 32 start_time = time.time() # end start_time_total = time.time() # end pred = self.diffusion.sample(sample_img_fea, bbox_mask, cond=ref_text, batch_size=bs, cond_scale=cond_scale) if self.use_residual_flow: b, _, nf, h, w = pred[:, :2, :, :, :].size() identity_grid = self.get_grid(b, nf, h, w, normalize=True).cuda() output_dict["sample_vid_grid"] = pred[:, :2, :, :, :] + identity_grid else: output_dict["sample_vid_grid"] = pred[:, :2, :, :, :] # bs, 2, nf, 32, 32 output_dict["sample_vid_conf"] = (pred[:, 2, :, :, :].unsqueeze(dim=1) + 1) * 0.5 # bs, 1, nf, 32, 32 nf = output_dict["sample_vid_grid"].size(2) end_time = time.time() # end print(f'DDIM time {end_time- start_time}') start_time = end_time with torch.no_grad(): sample_out_img_list = [] sample_warped_img_list = [] for idx in range(nf): sample_grid = output_dict["sample_vid_grid"][:, :, idx, :, :].permute(0, 2, 3, 1) sample_conf = output_dict["sample_vid_conf"][:, :, idx, :, :] # predict fake out image and fake warped image generated = self.generator.forward_with_flow(source_image=sample_img, optical_flow=sample_grid, occlusion_map=sample_conf) sample_out_img_list.append(generated["prediction"]) sample_warped_img_list.append(generated["deformed"]) output_dict["sample_out_vid"] = torch.stack(sample_out_img_list, dim=2) output_dict["sample_warped_vid"] = torch.stack(sample_warped_img_list, dim=2) # real_vid_tmp = rearrange(real_vids, 'b c f h w -> (b f) c h w')# real_vid.reshape(b * f, c, h, w) # with torch.no_grad(): # sample_grid = 
output_dict["sample_vid_grid"] # sample_grid = rearrange(sample_grid, 'b c f h w -> (b f) h w c') # sample_conf = output_dict["sample_vid_conf"] # sample_conf = rearrange(sample_conf, 'b c f h w -> (b f) c h w') # sample_img = sample_img.repeat(nf, 1, 1, 1) # generated = self.generator.forward_with_flow(source_image=sample_img, # optical_flow=sample_grid, # occlusion_map=sample_conf) # output_dict["sample_out_vid"] = rearrange(generated["prediction"], '(b f) c h w -> b c f h w', b = 1, f =nf) end_time = time.time() # end # with open('your/path/DAWN-pytorch/speed_test.txt', 'a') as f: # f.write(f'AE time {end_time- start_time}\n') # f.write(f'Total time {end_time- start_time_total}') # print(f'AE time {end_time- start_time}') # print(f'Total time {end_time- start_time_total}') start_time = end_time return output_dict def get_grid(self, b, nf, H, W, normalize=True): if normalize: h_range = torch.linspace(-1, 1, H) w_range = torch.linspace(-1, 1, W) else: h_range = torch.arange(0, H) w_range = torch.arange(0, W) grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).repeat(b, 1, 1, 1).flip(3).float() # flip h,w to x,y return grid.permute(0, 3, 1, 2).unsqueeze(dim=2).repeat(1, 1, nf, 1, 1) def set_requires_grad(self, nets, requires_grad=False): """Set requies_grad=Fasle for all the networks to avoid unnecessary computations Parameters: nets (network list) -- a list of networks requires_grad (bool) -- whether the networks require gradients or not """ if not isinstance(nets, list): nets = [nets] for net in nets: if net is not None: for param in net.parameters(): param.requires_grad = requires_grad if __name__ == "__main__": os.environ["CUDA_VISIBLE_DEVICES"] = "3" bs = 5 img_size = 128 num_frames = 10 ref_text = ["play basketball"] * bs ref_img = torch.rand((bs, 3, img_size, img_size), dtype=torch.float32).cuda() real_vid = torch.rand((bs, 3, num_frames, img_size, img_size), dtype=torch.float32).cuda() model = FlowDiffusion(num_frames=num_frames, use_residual_flow=False, sampling_timesteps=10, dim_mults=(1, 2, 4, 8, 16)) model.cuda() # embedding ref_text # cond = bert_embed(tokenize(ref_text), return_cls_repr=model.diffusion.text_use_bert_cls).cuda() # to simulate the situation of hubert embedding cond = torch.rand((bs,10,1024), dtype=torch.float32).cuda() model = DataParallelWithCallback(model) output_dict = model.forward(real_vid=real_vid, ref_img=ref_img, ref_text=cond) model.module.sample_one_video(sample_img=ref_img[0].unsqueeze(dim=0), sample_audio_hubert=cond[0].unsqueeze(dim=0), cond_scale=1.0) ================================================ FILE: DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi.py ================================================ ''' adding pose condtioning on baseline using cross attention to add different condition for training ''' import math import torch from torch import nn, einsum import torch.nn.functional as F from functools import partial from torchvision import transforms as T from PIL import Image from tqdm import tqdm from einops import rearrange, repeat, reduce, pack, unpack from einops_exts import rearrange_many from rotary_embedding_torch import RotaryEmbedding # from DM.modules.text import tokenize, bert_embed, HUBERT_MODEL_DIM # helpers functions def exists(x): return x is not None def noop(*args, **kwargs): pass def is_odd(n): return (n % 2) == 1 def default(val, d): if exists(val): return val return d() if callable(d) else d def cycle(dl): while True: for data in dl: yield data def num_to_groups(num, divisor): groups = num 
// divisor remainder = num % divisor arr = [divisor] * groups if remainder > 0: arr.append(remainder) return arr def prob_mask_like(shape, prob, device): if prob == 1: return torch.ones(shape, device=device, dtype=torch.bool) elif prob == 0: return torch.zeros(shape, device=device, dtype=torch.bool) else: return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob def is_list_str(x): if not isinstance(x, (list, tuple)): return False return all([type(el) == str for el in x]) # relative positional bias class RelativePositionBias(nn.Module): def __init__( self, heads=8, num_buckets=32, max_distance=128 ): super().__init__() self.num_buckets = num_buckets self.max_distance = max_distance self.relative_attention_bias = nn.Embedding(num_buckets, heads) @staticmethod def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): ret = 0 n = -relative_position num_buckets //= 2 ret += (n < 0).long() * num_buckets n = torch.abs(n) max_exact = num_buckets // 2 is_small = n < max_exact val_if_large = max_exact + ( torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).long() val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) ret += torch.where(is_small, n, val_if_large) return ret def forward(self, n, device): q_pos = torch.arange(n, dtype=torch.long, device=device) k_pos = torch.arange(n, dtype=torch.long, device=device) rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance) # mask = -(((rel_pos > 35) + (rel_pos < -35)) * (1e8)) # -(((rp_bucket ==15) + (rp_bucket >= 30)) * (1e8)) values = self.relative_attention_bias(rp_bucket) return rearrange(values, 'i j h -> h i j') # + mask # small helper modules class EMA(): def __init__(self, beta): super().__init__() self.beta = beta def update_model_average(self, ma_model, current_model): for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): old_weight, up_weight = ma_params.data, current_params.data ma_params.data = self.update_average(old_weight, up_weight) def update_average(self, old, new): if old is None: return new return old * self.beta + (1 - self.beta) * new class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, *args, **kwargs): return self.fn(x, *args, **kwargs) + x class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, x): device = x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) emb = x[:, None] * emb[None, :] emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb def Upsample(dim, use_deconv=True, padding_mode="reflect"): if use_deconv: return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) else: return nn.Sequential( nn.Upsample(scale_factor=(1, 2, 2), mode='nearest'), nn.Conv3d(dim, dim, (1, 3, 3), (1, 1, 1), (0, 1, 1), padding_mode=padding_mode) ) def Downsample(dim): return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) class LayerNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) def forward(self, x): var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) return (x - mean) / (var + self.eps).sqrt() * self.gamma class 
LayerNorm_img(nn.Module): def __init__(self, dim, stable = False): super().__init__() self.stable = stable self.g = nn.Parameter(torch.ones(dim)) def forward(self, x): if self.stable: x = x / x.amax(dim = -1, keepdim = True).detach() eps = 1e-5 if x.dtype == torch.float32 else 1e-3 var = torch.var(x, dim = -1, unbiased = False, keepdim = True) mean = torch.mean(x, dim = -1, keepdim = True) return (x - mean) * (var + eps).rsqrt() * self.g class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.fn = fn self.norm = LayerNorm(dim) def forward(self, x, **kwargs): x = self.norm(x) return self.fn(x, **kwargs) class Identity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() def forward(self, x, *args, **kwargs): return x def l2norm(t): return F.normalize(t, dim = -1) # building block modules class Block(nn.Module): def __init__(self, dim, dim_out, groups=8): super().__init__() self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1)) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.SiLU() def forward(self, x, time_scale_shift=None, audio_scale_shift=None): x = self.proj(x) x = self.norm(x) if exists(time_scale_shift): time_scale, time_shift = time_scale_shift x = x * (time_scale + 1) + time_shift # added by lml to change the control method of audio embedding, inspired by diffusedhead # if exists(audio_scale_shift): # # audio_scale and audio_shift:(bs, 64, nf, 1, 1) # # x:(bs, 64, nf, 32, 32) # audio_scale, audio_shift = audio_scale_shift # x = x * (audio_scale + 1) + audio_shift return self.act(x) class ResnetBlock(nn.Module): def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8): super().__init__() self.time_mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None self.audio_mlp = nn.Sequential( nn.SiLU(), nn.Linear(audio_emb_dim, dim_out * 2) ) if exists(audio_emb_dim) else None self.block1 = Block(dim, dim_out, groups=groups) self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None, audio_emb=None): time_scale_shift = None audio_scale_shift = None if exists(self.time_mlp): assert exists(time_emb), 'time emb must be passed in' time_emb = self.time_mlp(time_emb) time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 # added by lml to get audio embedding if exists(self.audio_mlp): assert exists(audio_emb), 'audio emb must be passed in' audio_emb = self.audio_mlp(audio_emb) audio_emb = rearrange(audio_emb, 'b n c -> b c n 1 1') # bs, 128, nf, 1, 1 audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1 h = self.block1(x, time_scale_shift=time_scale_shift, audio_scale_shift=audio_scale_shift) h = self.block2(h) return h + self.res_conv(x) class ResnetBlock_ca(nn.Module): def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8): super().__init__() self.time_mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None self.audio_mlp = nn.Sequential( nn.SiLU(), nn.Linear(audio_emb_dim, dim_out * 2) ) if exists(audio_emb_dim) else None # self.audio_mlp_2 = nn.Sequential( # nn.SiLU(), # nn.Linear(dim_out, dim_out * 2) # ) if exists(audio_emb_dim) else None attn_klass = CrossAttention self.cross_attn = attn_klass( dim = dim, context_dim = dim_out * 2 ) self.block1 = Block(dim, dim_out, groups=groups) self.block2 = 
Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None, audio_emb=None): time_scale_shift = None audio_scale_shift = None b, c, f, H, W = x.size() if exists(self.time_mlp): assert exists(time_emb), 'time emb must be passed in' time_emb = self.time_mlp(time_emb) time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 # added by lml to get audio embedding if exists(self.audio_mlp): assert exists(audio_emb), 'audio emb must be passed in' audio_emb = self.audio_mlp(audio_emb) if exists(self.cross_attn): # h = rearrange(x, 'b c f ... -> (b f) ... c') # h, ps = pack([h], 'b * c') # audio_emb = rearrange(audio_emb, 'b f ... -> (b f) ...') # audio_emb = self.cross_attn(h, context = audio_emb) # # h, = unpack(h, ps, 'b * c') # # h = rearrange(h, '(b f) ... c -> b c f ...', b = b, f = f, c = c) # # audio_emb = self.audio_mlp_2(audio_emb) # audio_emb = rearrange(audio_emb, '(b f) ... -> b f ...', b = b, f = f) assert exists(audio_emb) h = rearrange(x, 'b c f ... -> (b f) ... c') # h = rearrange(x, 'b c ... -> b ... c') h, ps = pack([h], 'b * c') h = self.cross_attn(h, context = audio_emb) + h h, = unpack(h, ps, 'b * c') # h = rearrange(h, 'b ... c -> b c ...') h = rearrange(h, '(b f) ... c -> b f c ...', b = b, f = f) # audio_emb = rearrange(audio_emb, 'b f (h w) c -> b c f h w', w = W, h = H) # bs, 128, nf, 1, 1 # audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1 h = self.block1(x, time_scale_shift=time_scale_shift) h = self.block2(h) return h + self.res_conv(x) class ResnetBlock_ca_mul(nn.Module): def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, pose_emb_dim=None, eye_emb_dim=None, groups=8): super().__init__() self.time_mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None self.audio_mlp = nn.Sequential( nn.SiLU(), nn.Linear(audio_emb_dim, dim_out * 2) ) if exists(audio_emb_dim) else None self.pose_mlp = nn.Sequential( nn.SiLU(), nn.Linear(pose_emb_dim, dim_out * 2) ) if exists(pose_emb_dim) else None self.eye_mlp = nn.Sequential( nn.SiLU(), nn.Linear(eye_emb_dim, dim_out * 2) ) if exists(eye_emb_dim) else None self.audio_emb_dim = audio_emb_dim self.pose_emb_dim = pose_emb_dim self.eye_emb_dim = eye_emb_dim # self.audio_mlp_2 = nn.Sequential( # nn.SiLU(), # nn.Linear(dim_out, dim_out * 2) # ) if exists(audio_emb_dim) else None attn_klass = CrossAttention self.cross_attn_aud = attn_klass( dim = dim, context_dim = dim_out * 2, out_dim = dim_out ) self.cross_attn_pose = attn_klass( dim = dim, context_dim = dim_out * 2, out_dim = dim_out ) self.cross_attn_eye = attn_klass( dim = dim, context_dim = dim_out * 2, out_dim = dim_out ) self.block1 = Block(dim, dim_out, groups=groups) self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None, audio_emb=None): time_scale_shift = None audio_scale_shift = None ''' need seperate 3 diffiserent condition ''' if exists(audio_emb): pose_emb = audio_emb[:,:,self.audio_emb_dim:self.audio_emb_dim + self.pose_emb_dim] eye_emb = audio_emb[:,:,self.audio_emb_dim + self.pose_emb_dim: ] audio_emb = audio_emb[:,:,:self.audio_emb_dim] b, c, f, H, W = x.size() if exists(self.time_mlp): assert exists(time_emb), 'time emb must be passed in' time_emb = self.time_mlp(time_emb) time_emb = rearrange(time_emb, 'b c -> 
b c 1 1 1') # bs, 128, 1, 1 time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 # added by lml to get audio embedding if exists(self.audio_mlp): # mouth lmk + audio emb assert exists(audio_emb), 'audio emb must be passed in' audio_emb = self.audio_mlp(audio_emb) pose_emb = self.pose_mlp(pose_emb) # TODO: embedding eye_emb = self.eye_mlp(eye_emb) if exists(self.cross_attn_aud): # h = rearrange(x, 'b c f ... -> (b f) ... c') # h, ps = pack([h], 'b * c') # audio_emb = rearrange(audio_emb, 'b f ... -> (b f) ...') # audio_emb = self.cross_attn(h, context = audio_emb) # # h, = unpack(h, ps, 'b * c') # # h = rearrange(h, '(b f) ... c -> b c f ...', b = b, f = f, c = c) # # audio_emb = self.audio_mlp_2(audio_emb) # audio_emb = rearrange(audio_emb, '(b f) ... -> b f ...', b = b, f = f) assert exists(audio_emb) h_cond = rearrange(x, 'b c f ... -> (b f) ... c') # h = rearrange(x, 'b c ... -> b ... c') h_cond, ps = pack([h_cond], 'b * c') h_pose = self.cross_attn_pose(h_cond, context = pose_emb) h_aud = self.cross_attn_aud(h_cond, context = audio_emb) h_eye = self.cross_attn_eye(h_cond, context = eye_emb) h_cond = h_pose + h_aud + h_eye h_cond, = unpack(h_cond, ps, 'b * c') # h = rearrange(h, 'b ... c -> b c ...') h_cond = rearrange(h_cond, '(b f) ... c -> b c f ...', b = b, f = f) # audio_emb = rearrange(audio_emb, 'b f (h w) c -> b c f h w', w = W, h = H) # bs, 128, nf, 1, 1 # audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1 h = self.block1(x, time_scale_shift=time_scale_shift) if exists(self.audio_mlp): h = h_cond + h h = self.block2(h) return h + self.res_conv(x) class CrossAttention(nn.Module): def __init__( self, dim, out_dim, *, context_dim = None, dim_head = 8, heads = 8, norm_context = False, scale = 8 ): super().__init__() self.scale = scale self.heads = heads inner_dim = dim_head * heads context_dim = default(context_dim, dim) self.norm = LayerNorm_img(dim) self.norm_context = LayerNorm_img(context_dim) if norm_context else Identity() self.null_kv = nn.Parameter(torch.randn(2, dim_head)) self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) self.q_scale = nn.Parameter(torch.ones(dim_head)) self.k_scale = nn.Parameter(torch.ones(dim_head)) self.to_out = nn.Sequential( nn.Linear(inner_dim, out_dim, bias = False), LayerNorm_img(out_dim) ) def forward(self, x, context, mask = None): b, n, device = *x.shape[:2], x.device x = self.norm(x) # bn * fn ? 
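# ---------------------------------------------------------------------------
# Hedged sketch (not from the original module): how the fused condition tensor consumed by
# ResnetBlock_ca_mul.forward above is assumed to be laid out. Per-frame audio, pose and eye
# features are concatenated along the last axis and sliced back apart with
# audio_emb_dim / pose_emb_dim / eye_emb_dim. The sizes 1024 / 7 / 2 mirror the Unet3D defaults
# (cond_aud, cond_pose, cond_eye); batch and frame counts are arbitrary placeholders.
# ---------------------------------------------------------------------------
import torch

def _sketch_fused_condition(bs=2, nf=40, aud_dim=1024, pose_dim=7, eye_dim=2):
    audio = torch.randn(bs, nf, aud_dim)   # per-frame audio features (e.g. HuBERT)
    pose = torch.randn(bs, nf, pose_dim)   # per-frame pose code
    eye = torch.randn(bs, nf, eye_dim)     # per-frame eye/blink code
    cond = torch.cat([audio, pose, eye], dim=-1)  # (bs, nf, aud_dim + pose_dim + eye_dim)
    # the same slicing performed inside ResnetBlock_ca_mul.forward
    audio_back = cond[:, :, :aud_dim]
    pose_back = cond[:, :, aud_dim:aud_dim + pose_dim]
    eye_back = cond[:, :, aud_dim + pose_dim:]
    assert audio_back.shape == audio.shape
    assert pose_back.shape == pose.shape
    assert eye_back.shape == eye.shape
    return cond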
        # context: b, fn, c
        context = rearrange(context, 'b f c -> (b f) c')
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net
        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # cosine sim attention
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # similarities
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # masking
        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class LinearCrossAttention(CrossAttention):
    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = rearrange(context, 'b f c -> (b f) c')
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net
        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)  # b * nf * h, 2, c//h
        v = torch.cat((nv, v), dim = -2)

        # masking
        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b n -> b 1 n')
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)
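# ---------------------------------------------------------------------------
# Hedged usage sketch for the CrossAttention module above, as it appears to be driven from
# ResnetBlock_ca / ResnetBlock_ca_mul: queries are per-frame spatial tokens of shape
# ((b f), h*w, dim) and the context is one conditioning vector per frame, (b, f, context_dim);
# forward() flattens the context to ((b f), 1, context_dim) and prepends a learned null
# key/value for classifier-free guidance. All sizes below are illustrative assumptions.
# ---------------------------------------------------------------------------
import torch
from einops import rearrange

def _sketch_cross_attention(attn, b=2, f=8, h=16, w=16, dim=64, context_dim=128):
    x = torch.randn(b, dim, f, h, w)                       # UNet feature volume
    context = torch.randn(b, f, context_dim)               # one condition vector per frame
    q_tokens = rearrange(x, 'b c f h w -> (b f) (h w) c')  # per-frame spatial tokens
    out = attn(q_tokens, context=context)                  # ((b f), h*w, out_dim)
    return rearrange(out, '(b f) (h w) c -> b c f h w', b=b, f=f, h=h, w=w)

# e.g. _sketch_cross_attention(CrossAttention(dim=64, out_dim=64, context_dim=128))
# returns a tensor of shape (2, 64, 8, 16, 16).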
# linear attention q = q.softmax(dim = -1) # # b * nf * h, 32*32, c//h, k = k.softmax(dim = -2) q = q * self.scale context = einsum('b n d, b n e -> b d e', k, v) # b * nf * h, 2, c//h, b * nf * h, 2, c//h out = einsum('b n d, b d e -> b n e', q, context) out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads) return self.to_out(out) class SpatialLinearAttention(nn.Module): def __init__(self, dim, heads=4, dim_head=32): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, f, h, w = x.shape x = rearrange(x, 'b c f h w -> (b f) c h w') qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = rearrange_many(qkv, 'b (h c) x y -> b h c (x y)', h=self.heads) q = q.softmax(dim=-2) k = k.softmax(dim=-1) q = q * self.scale context = torch.einsum('b h d n, b h e n -> b h d e', k, v) out = torch.einsum('b h d e, b h d n -> b h e n', context, q) out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w) out = self.to_out(out) return rearrange(out, '(b f) c h w -> b c f h w', b=b) # attention along space and time class EinopsToAndFrom(nn.Module): def __init__(self, from_einops, to_einops, fn): super().__init__() self.from_einops = from_einops self.to_einops = to_einops self.fn = fn def forward(self, x, **kwargs): shape = x.shape reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape))) x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') x = self.fn(x, **kwargs) x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) return x class Attention(nn.Module): def __init__( self, dim, heads=4, dim_head=32, rotary_emb=None ): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.rotary_emb = rotary_emb self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) self.to_out = nn.Linear(hidden_dim, dim, bias=False) def forward( self, x, pos_bias=None, focus_present_mask=None ): # temperal: 'b (h w) f c' ; spatial : 'b f (h w) c' n, device = x.shape[-2], x.device qkv = self.to_qkv(x).chunk(3, dim=-1) if exists(focus_present_mask) and focus_present_mask.all(): # if all batch samples are focusing on present # it would be equivalent to passing that token's values through to the output values = qkv[-1] return self.to_out(values) # split out heads q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) # scale q = q * self.scale # rotate positions into queries and keys for time attention if exists(self.rotary_emb): q = self.rotary_emb.rotate_queries_or_keys(q) k = self.rotary_emb.rotate_queries_or_keys(k) # similarity sim = einsum('... h i d, ... h j d -> ... h i j', q, k) # relative positional bias if exists(pos_bias): sim = sim + pos_bias if exists(focus_present_mask) and not (~focus_present_mask).all(): attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool) attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) mask = torch.where( rearrange(focus_present_mask, 'b -> b 1 1 1 1'), rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), ) sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) # numerical stability sim = sim - sim.amax(dim=-1, keepdim=True).detach() attn = sim.softmax(dim=-1) # aggregate values out = einsum('... h i j, ... h j d -> ... h i d', attn, v) out = rearrange(out, '... h n d -> ... 
n (h d)') return self.to_out(out) # model class Unet3D(nn.Module): def __init__( self, dim, cond_aud=1024, cond_pose=7, cond_eye=2, cond_dim=None, out_grid_dim=2, out_conf_dim=1, num_frames=40, dim_mults=(1, 2, 4, 8), channels=3, attn_heads=8, attn_dim_head=32, use_hubert_audio_cond=False, init_dim=None, init_kernel_size=7, use_sparse_linear_attn=True, resnet_groups=8, use_final_activation=False, learn_null_cond=False, use_deconv=True, padding_mode="zeros", ): super().__init__() self.null_cond_mask = None self.channels = channels self.num_frames = num_frames self.HUBERT_MODEL_DIM = 1024 # temporal attention and its relative positional encoding rotary_emb = RotaryEmbedding(min(32, attn_dim_head)) temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention(dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb)) self.time_rel_pos_bias = RelativePositionBias(heads=attn_heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet # initial conv init_dim = default(init_dim, dim) assert is_odd(init_kernel_size) init_padding = init_kernel_size // 2 self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, init_kernel_size), padding=(0, init_padding, init_padding)) self.init_temporal_attn = Residual(PreNorm(init_dim, temporal_attn(init_dim))) # dimensions dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) # time conditioning time_dim = dim * 4 self.time_mlp = nn.Sequential( SinusoidalPosEmb(dim), nn.Linear(dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim) ) # audio conditioning self.has_cond = exists(cond_dim) or use_hubert_audio_cond self.cond_dim = cond_dim self.cond_aud_dim = cond_aud self.cond_pose_dim = cond_pose self.cond_eye_dim = cond_eye # modified by lml self.learn_null_cond = learn_null_cond # cat(t,cond) is not suitable # cond_dim = time_dim + int(cond_dim or 0) # layers self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) num_resolutions = len(in_out) # block type block_klass = partial(ResnetBlock_ca_mul, groups=resnet_groups) block_klass_cond = partial(block_klass, time_emb_dim=time_dim, audio_emb_dim=self.cond_aud_dim, pose_emb_dim=self.cond_pose_dim, eye_emb_dim=self.cond_eye_dim) # block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) # cat embedding # modules for all layers for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) self.downs.append(nn.ModuleList([ block_klass_cond(dim_in, dim_out), block_klass_cond(dim_out, dim_out), Residual(PreNorm(dim_out, SpatialLinearAttention(dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), Residual(PreNorm(dim_out, temporal_attn(dim_out))), Downsample(dim_out) if not is_last else nn.Identity() ])) mid_dim = dims[-1] self.mid_block1 = block_klass_cond(mid_dim, mid_dim) spatial_attn = EinopsToAndFrom('b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads)) self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn)) self.mid_temporal_attn = Residual(PreNorm(mid_dim, temporal_attn(mid_dim))) self.mid_block2 = block_klass_cond(mid_dim, mid_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): is_last = ind >= (num_resolutions - 1) self.ups.append(nn.ModuleList([ block_klass_cond(dim_out * 2, dim_in), block_klass_cond(dim_in, dim_in), Residual(PreNorm(dim_in, SpatialLinearAttention(dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), Residual(PreNorm(dim_in, temporal_attn(dim_in))), 
Upsample(dim_in, use_deconv, padding_mode) if not is_last else nn.Identity() ])) # out_dim = default(out_grid_dim, channels) self.final_conv = nn.Sequential( block_klass(dim * 2, dim), nn.Conv3d(dim, out_grid_dim, 1) ) # added by nhm self.use_final_activation = use_final_activation if self.use_final_activation: self.final_activation = nn.Tanh() else: self.final_activation = nn.Identity() # added by nhm for predicting occlusion mask self.occlusion_map = nn.Sequential( block_klass(dim * 2, dim), nn.Conv3d(dim, out_conf_dim, 1) ) def forward_with_cond_scale( self, *args, cond_scale=2., **kwargs ): logits = self.forward(*args, null_cond_prob=0., **kwargs) if cond_scale == 1 or not self.has_cond: return logits null_logits = self.forward(*args, null_cond_prob=1., **kwargs) return null_logits + (logits - null_logits) * cond_scale def forward( self, x, time, cond=None, null_cond_prob=0., focus_present_mask=None, prob_focus_present=0. # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) ): assert not (self.has_cond and not exists(cond)), 'cond must be passed in if cond_dim specified' batch, device = x.shape[0], x.device focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device=device)) time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device) x = self.init_conv(x) r = x.clone() x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias) t = self.time_mlp(time) if exists(self.time_mlp) else None if self.learn_null_cond: self.null_cond_emb = nn.Parameter(torch.randn(1, self.num_frames, self.cond_dim)) if self.has_cond else None else: self.null_cond_emb = torch.zeros(1, self.num_frames, self.cond_dim) if self.has_cond else None # classifier free guidance if self.has_cond: batch, device = x.shape[0], x.device self.null_cond_mask = prob_mask_like((batch, self.num_frames,), null_cond_prob, device=device) cond = torch.where(rearrange(self.null_cond_mask, 'b n -> b n 1'), self.null_cond_emb.to(cond.device), cond) # t (bs, 256) cond (bs, nf*1024)->(bs, nf, 1024) in this version # it's the original cond embedding method used in LFDM # t = torch.cat((t, cond), dim=-1) h = [] for block1, block2, spatial_attn, temporal_attn, downsample in self.downs: x = block1(x, t, cond) x = block2(x, t, cond) x = spatial_attn(x) x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) h.append(x) x = downsample(x) x = self.mid_block1(x, t, cond) x = self.mid_spatial_attn(x) x = self.mid_temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) x = self.mid_block2(x, t, cond) for block1, block2, spatial_attn, temporal_attn, upsample in self.ups: x = torch.cat((x, h.pop()), dim=1) x = block1(x, t, cond) x = block2(x, t, cond) x = spatial_attn(x) x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) x = upsample(x) x = torch.cat((x, r), dim=1) return torch.cat((self.final_conv(x), self.occlusion_map(x)), dim=1) # to dynamically change num_frames of Unet3D class DynamicNfUnet3D(Unet3D): def __init__(self, default_num_frames=20, *args, **kwargs): super(DynamicNfUnet3D, self).__init__(*args, **kwargs) self.default_num_frames = default_num_frames self.num_frames = default_num_frames def update_num_frames(self, new_num_frames): self.num_frames = new_num_frames # gaussian diffusion trainer class def extract(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) return out.reshape(b, *((1,) * 
(len(x_shape) - 1))) def cosine_beta_schedule(timesteps, s=0.008): """ cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ """ steps = timesteps + 1 x = torch.linspace(0, timesteps, steps, dtype=torch.float64) alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.9999) class GaussianDiffusion(nn.Module): def __init__( self, denoise_fn, *, image_size, num_frames, text_use_bert_cls=False, channels=3, timesteps=1000, sampling_timesteps=250, ddim_sampling_eta=1., loss_type='l1', use_dynamic_thres=False, # from the Imagen paper dynamic_thres_percentile=0.9, null_cond_prob=0.1 ): super().__init__() self.null_cond_prob = null_cond_prob self.channels = channels self.image_size = image_size self.num_frames = num_frames self.denoise_fn = denoise_fn betas = cosine_beta_schedule(timesteps) alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) timesteps, = betas.shape self.num_timesteps = int(timesteps) self.loss_type = loss_type self.sampling_timesteps = default(sampling_timesteps, timesteps) self.is_ddim_sampling = self.sampling_timesteps < timesteps self.ddim_sampling_eta = ddim_sampling_eta # register buffer helper function that casts float64 to float32 register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) register_buffer('betas', betas) register_buffer('alphas_cumprod', alphas_cumprod) register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) # calculations for diffusion q(x_t | x_{t-1}) and others register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) register_buffer('posterior_variance', posterior_variance) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20))) register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) # text conditioning parameters self.text_use_bert_cls = text_use_bert_cls # dynamic thresholding when sampling self.use_dynamic_thres = use_dynamic_thres self.dynamic_thres_percentile = dynamic_thres_percentile def q_mean_variance(self, x_start, t): mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = extract(1. 
- self.alphas_cumprod, t, x_start.shape) log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) return mean, variance, log_variance def predict_start_from_noise(self, x_t, t, noise): return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) def q_posterior(self, x_start, x_t, t): posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = extract(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance(self, x, t, fea, clip_denoised: bool, cond=None, cond_scale=1.): fea = fea.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1) x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_fn.forward_with_cond_scale(torch.cat([x, fea], dim=1), t, cond=cond, cond_scale=cond_scale)) if clip_denoised: s = 1. if self.use_dynamic_thres: s = torch.quantile( rearrange(x_recon, 'b ... -> b (...)').abs(), self.dynamic_thres_percentile, dim=-1 ) s.clamp_(min=1.) s = s.view(-1, *((1,) * (x_recon.ndim - 1))) # clip by threshold, depending on whether static or dynamic x_recon = x_recon.clamp(-s, s) / s model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance @torch.inference_mode() def p_sample(self, x, t, fea, cond=None, cond_scale=1., clip_denoised=True): b, *_, device = *x.shape, x.device model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, fea=fea, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale) noise = torch.randn_like(x) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.inference_mode() def p_sample_loop(self, fea, shape, cond=None, cond_scale=1.): device = self.betas.device b = shape[0] img = torch.randn(shape, device=device) for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), fea, cond=cond, cond_scale=cond_scale) return img # return unnormalize_img(img) @torch.inference_mode() def sample(self, fea, bbox_mask, cond=None, cond_scale=1., batch_size=16): # text bert: cond 1,768 # device = next(self.denoise_fn.parameters()).device # if is_list_str(cond): # cond = torch.rand((1 ,768), dtype=torch.float32).cuda() #used to debug # cond = bert_embed(tokenize(cond), return_cls_repr=self.text_use_bert_cls).to(device) batch_size = cond.shape[0] if exists(cond) else batch_size # batch_size = 1 if exists(cond) else batch_size image_size = self.image_size channels = self.channels num_frames = self.num_frames sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample fea = torch.cat([fea, bbox_mask], dim=1) return sample_fn(fea, (batch_size, channels, num_frames, image_size, image_size), cond=cond, cond_scale=cond_scale) # add by nhm @torch.no_grad() def ddim_sample(self, fea, shape, cond=None, cond_scale=1., clip_denoised=True): batch, device, total_timesteps, sampling_timesteps, eta = \ shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta times = torch.linspace(0., 
total_timesteps, steps=sampling_timesteps + 2)[:-1] times = list(reversed(times.int().tolist())) time_pairs = list(zip(times[:-1], times[1:])) img = torch.randn(shape, device=device) # bs, 3, nf, 32, 32 fea = fea.unsqueeze(dim=2).repeat(1, 1, img.size(2), 1, 1) #bs, 256, nf, 32, 32 for time, time_next in tqdm(time_pairs, desc='sampling loop time step'): alpha = self.alphas_cumprod_prev[time] alpha_next = self.alphas_cumprod_prev[time_next] time_cond = torch.full((batch,), time, device=device, dtype=torch.long) # pred_noise, x_start, *_ = self.model_predictions(img, time_cond, fea) pred_noise = self.denoise_fn.forward_with_cond_scale( torch.cat([img, fea], dim=1), time_cond, cond=cond, cond_scale=cond_scale) x_start = self.predict_start_from_noise(img, t=time_cond, noise=pred_noise) if clip_denoised: s = 1. if self.use_dynamic_thres: s = torch.quantile( rearrange(x_start, 'b ... -> b (...)').abs(), self.dynamic_thres_percentile, dim=-1 ) s.clamp_(min=1.) s = s.view(-1, *((1,) * (x_start.ndim - 1))) # clip by threshold, depending on whether static or dynamic x_start = x_start.clamp(-s, s) / s sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c = ((1 - alpha_next) - sigma ** 2).sqrt() noise = torch.randn_like(img) if time_next > 0 else 0. img = x_start * alpha_next.sqrt() + \ c * pred_noise + \ sigma * noise # img = unnormalize_to_zero_to_one(img) return img @torch.inference_mode() def interpolate(self, x1, x2, t=None, lam=0.5): b, *_, device = *x1.shape, x1.device t = default(t, self.num_timesteps - 1) assert x1.shape == x2.shape t_batched = torch.stack([torch.tensor(t, device=device)] * b) xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) img = (1 - lam) * xt1 + lam * xt2 for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t): img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) return img def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) return ( extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) def p_losses(self, x_start, t, fea, bbox_mask, cond=None, noise=None, clip_denoised=True, **kwargs): # x_start: bs, 3, num_frame, 32, 32 # t: bs # fea: bs, 256, num_frame, 32, 32 # cond: bs, 768 b, c, f, h, w, device = *x_start.shape, x_start.device noise = default(noise, lambda: torch.randn_like(x_start)) # bs, 3, nf, 32, 32 x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)# bs, 3, nf, 32, 32 pred_noise = self.denoise_fn.forward(torch.cat([x_noisy, fea, bbox_mask], dim=1), t, cond=cond, null_cond_prob=self.null_cond_prob, **kwargs) if self.loss_type == 'l1': loss = F.l1_loss(noise, pred_noise, reduce=False) elif self.loss_type == 'l2': loss = F.mse_loss(noise, pred_noise, reduce=False) else: raise NotImplementedError() pred_x0 = self.predict_start_from_noise(x_noisy, t, pred_noise) if clip_denoised: s = 1. if self.use_dynamic_thres: s = torch.quantile( rearrange(pred_x0, 'b ... -> b (...)').abs(), self.dynamic_thres_percentile, dim=-1 ) s.clamp_(min=1.) 
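# ---------------------------------------------------------------------------
# Self-contained sketch of the forward-diffusion step implemented by q_sample above, using the
# same cosine beta schedule: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, with
# alpha_bar_t the cumulative product of (1 - beta). Tensor sizes are illustrative only; the
# module itself keeps these quantities as precomputed registered buffers.
# ---------------------------------------------------------------------------
import torch

def _sketch_q_sample(timesteps=1000, t=250, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = (1 - alphas_cumprod[1:] / alphas_cumprod[:-1]).clamp(0, 0.9999)
    alpha_bar = torch.cumprod(1. - betas, dim=0).float()

    x0 = torch.randn(1, 3, 40, 32, 32)   # latent flow volume: bs, c, nf, h, w
    noise = torch.randn_like(x0)
    x_t = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * noise
    return x_t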
s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) # clip by threshold, depending on whether static or dynamic self.pred_x0 = pred_x0.clamp(-s, s) / s return loss, self.denoise_fn.null_cond_mask def forward(self, x, fea, bbox_mask, cond, *args, **kwargs): b, device, img_size, = x.shape[0], x.device, self.image_size # check_shape(x, 'b c f h w', c=self.channels, f=self.num_frames, h=img_size, w=img_size) t = torch.randint(0, self.num_timesteps, (b,), device=device).long() fea = fea.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1) bbox_mask = bbox_mask.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1) return self.p_losses(x, t, fea, bbox_mask, cond, *args, **kwargs) # trainer class CHANNELS_TO_MODE = { 1: 'L', 3: 'RGB', 4: 'RGBA' } def seek_all_images(img, channels=3): assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' mode = CHANNELS_TO_MODE[channels] i = 0 while True: try: img.seek(i) yield img.convert(mode) except EOFError: break i += 1 # to dynamically change num_frames of GaussianDiffusion class DynamicNfGaussianDiffusion(GaussianDiffusion): def __init__(self, default_num_frames=20, *args, **kwargs): super(DynamicNfGaussianDiffusion, self).__init__(*args, **kwargs) self.default_num_frames = default_num_frames self.num_frames = default_num_frames def update_num_frames(self, new_num_frames): self.num_frames = new_num_frames # tensor of shape (channels, frames, height, width) -> gif def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True): images = map(T.ToPILImage(), tensor.unbind(dim=1)) first_img, *rest_imgs = images first_img.save(path, save_all=True, append_images=rest_imgs, duration=duration, loop=loop, optimize=optimize) return images # gif -> (channels, frame, height, width) tensor def gif_to_tensor(path, channels=3, transform=T.ToTensor()): img = Image.open(path) tensors = tuple(map(transform, seek_all_images(img, channels=channels))) return torch.stack(tensors, dim=1) def identity(t, *args, **kwargs): return t def normalize_img(t): return t * 2 - 1 # def unnormalize_img(t): # return (t + 1) * 0.5 def cast_num_frames(t, *, frames): f = t.shape[1] if f == frames: return t if f > frames: return t[:, :frames] return F.pad(t, (0, 0, 0, 0, 0, frames - f)) ================================================ FILE: DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test.py ================================================ ''' adding pose condtioning on baseline using cross attention to add different condition using local attention, for inference, faster cost more ram ''' import math import torch from torch import nn, einsum import torch.nn.functional as F from functools import partial from torchvision import transforms as T from PIL import Image from tqdm import tqdm from einops import rearrange, repeat, reduce, pack, unpack from einops_exts import rearrange_many from rotary_embedding_torch import RotaryEmbedding # from DM.modules.text import tokenize, bert_embed, HUBERT_MODEL_DIM # helpers functions def exists(x): return x is not None def noop(*args, **kwargs): pass def is_odd(n): return (n % 2) == 1 def default(val, d): if exists(val): return val return d() if callable(d) else d def cycle(dl): while True: for data in dl: yield data def num_to_groups(num, divisor): groups = num // divisor remainder = num % divisor arr = [divisor] * groups if remainder > 0: arr.append(remainder) return arr def prob_mask_like(shape, prob, device): if prob == 1: return torch.ones(shape, device=device, dtype=torch.bool) elif prob == 0: return torch.zeros(shape, 
device=device, dtype=torch.bool) else: return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob def is_list_str(x): if not isinstance(x, (list, tuple)): return False return all([type(el) == str for el in x]) # relative positional bias class RelativePositionBias(nn.Module): def __init__( self, heads=8, num_buckets=32, max_distance=128, window_width = 20 ): super().__init__() self.num_buckets = num_buckets self.max_distance = max_distance self.relative_attention_bias = nn.Embedding(num_buckets, heads) self.window_width = window_width @staticmethod def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): ret = 0 n = -relative_position num_buckets //= 2 ret += (n < 0).long() * num_buckets n = torch.abs(n) max_exact = num_buckets // 2 is_small = n < max_exact val_if_large = max_exact + ( torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).long() val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) ret += torch.where(is_small, n, val_if_large) return ret def forward(self, n, device): q_pos = torch.arange(n, dtype=torch.long, device=device) k_pos = torch.arange(n, dtype=torch.long, device=device) rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance) mask = -(((rel_pos > self.window_width) + (rel_pos < -self.window_width)) * (1e8)) # -(((rp_bucket ==15) + (rp_bucket >= 30)) * (1e8)) values = self.relative_attention_bias(rp_bucket) return rearrange(values, 'i j h -> h i j') + mask # small helper modules class EMA(): def __init__(self, beta): super().__init__() self.beta = beta def update_model_average(self, ma_model, current_model): for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): old_weight, up_weight = ma_params.data, current_params.data ma_params.data = self.update_average(old_weight, up_weight) def update_average(self, old, new): if old is None: return new return old * self.beta + (1 - self.beta) * new class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, *args, **kwargs): return self.fn(x, *args, **kwargs) + x class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, x): device = x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) emb = x[:, None] * emb[None, :] emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb def Upsample(dim, use_deconv=True, padding_mode="reflect"): if use_deconv: return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) else: return nn.Sequential( nn.Upsample(scale_factor=(1, 2, 2), mode='nearest'), nn.Conv3d(dim, dim, (1, 3, 3), (1, 1, 1), (0, 1, 1), padding_mode=padding_mode) ) def Downsample(dim): return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) class LayerNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) def forward(self, x): var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) return (x - mean) / (var + self.eps).sqrt() * self.gamma class LayerNorm_img(nn.Module): def __init__(self, dim, stable = False): super().__init__() self.stable = stable self.g = nn.Parameter(torch.ones(dim)) def forward(self, x): if self.stable: x = x / 
x.amax(dim = -1, keepdim = True).detach() eps = 1e-5 if x.dtype == torch.float32 else 1e-3 var = torch.var(x, dim = -1, unbiased = False, keepdim = True) mean = torch.mean(x, dim = -1, keepdim = True) return (x - mean) * (var + eps).rsqrt() * self.g class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.fn = fn self.norm = LayerNorm(dim) def forward(self, x, **kwargs): x = self.norm(x) return self.fn(x, **kwargs) class Identity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() def forward(self, x, *args, **kwargs): return x def l2norm(t): return F.normalize(t, dim = -1) # building block modules class Block(nn.Module): def __init__(self, dim, dim_out, groups=8): super().__init__() self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1)) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.SiLU() def forward(self, x, time_scale_shift=None, audio_scale_shift=None): x = self.proj(x) x = self.norm(x) if exists(time_scale_shift): time_scale, time_shift = time_scale_shift x = x * (time_scale + 1) + time_shift # added by lml to change the control method of audio embedding, inspired by diffusedhead # if exists(audio_scale_shift): # # audio_scale and audio_shift:(bs, 64, nf, 1, 1) # # x:(bs, 64, nf, 32, 32) # audio_scale, audio_shift = audio_scale_shift # x = x * (audio_scale + 1) + audio_shift return self.act(x) class ResnetBlock(nn.Module): def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8): super().__init__() self.time_mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None self.audio_mlp = nn.Sequential( nn.SiLU(), nn.Linear(audio_emb_dim, dim_out * 2) ) if exists(audio_emb_dim) else None self.block1 = Block(dim, dim_out, groups=groups) self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None, audio_emb=None): time_scale_shift = None audio_scale_shift = None if exists(self.time_mlp): assert exists(time_emb), 'time emb must be passed in' time_emb = self.time_mlp(time_emb) time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 # added by lml to get audio embedding if exists(self.audio_mlp): assert exists(audio_emb), 'audio emb must be passed in' audio_emb = self.audio_mlp(audio_emb) audio_emb = rearrange(audio_emb, 'b n c -> b c n 1 1') # bs, 128, nf, 1, 1 audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1 h = self.block1(x, time_scale_shift=time_scale_shift, audio_scale_shift=audio_scale_shift) h = self.block2(h) return h + self.res_conv(x) class ResnetBlock_ca(nn.Module): def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8): super().__init__() self.time_mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None self.audio_mlp = nn.Sequential( nn.SiLU(), nn.Linear(audio_emb_dim, dim_out * 2) ) if exists(audio_emb_dim) else None # self.audio_mlp_2 = nn.Sequential( # nn.SiLU(), # nn.Linear(dim_out, dim_out * 2) # ) if exists(audio_emb_dim) else None attn_klass = CrossAttention self.cross_attn = attn_klass( dim = dim, context_dim = dim_out * 2 ) self.block1 = Block(dim, dim_out, groups=groups) self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None, audio_emb=None): time_scale_shift = 
None audio_scale_shift = None b, c, f, H, W = x.size() if exists(self.time_mlp): assert exists(time_emb), 'time emb must be passed in' time_emb = self.time_mlp(time_emb) time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 # added by lml to get audio embedding if exists(self.audio_mlp): assert exists(audio_emb), 'audio emb must be passed in' audio_emb = self.audio_mlp(audio_emb) if exists(self.cross_attn): # h = rearrange(x, 'b c f ... -> (b f) ... c') # h, ps = pack([h], 'b * c') # audio_emb = rearrange(audio_emb, 'b f ... -> (b f) ...') # audio_emb = self.cross_attn(h, context = audio_emb) # # h, = unpack(h, ps, 'b * c') # # h = rearrange(h, '(b f) ... c -> b c f ...', b = b, f = f, c = c) # # audio_emb = self.audio_mlp_2(audio_emb) # audio_emb = rearrange(audio_emb, '(b f) ... -> b f ...', b = b, f = f) assert exists(audio_emb) h = rearrange(x, 'b c f ... -> (b f) ... c') # h = rearrange(x, 'b c ... -> b ... c') h, ps = pack([h], 'b * c') h = self.cross_attn(h, context = audio_emb) + h h, = unpack(h, ps, 'b * c') # h = rearrange(h, 'b ... c -> b c ...') h = rearrange(h, '(b f) ... c -> b f c ...', b = b, f = f) # audio_emb = rearrange(audio_emb, 'b f (h w) c -> b c f h w', w = W, h = H) # bs, 128, nf, 1, 1 # audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1 h = self.block1(x, time_scale_shift=time_scale_shift) h = self.block2(h) return h + self.res_conv(x) class ResnetBlock_ca_mul(nn.Module): def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, pose_emb_dim=None, eye_emb_dim=None, groups=8): super().__init__() self.time_mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None self.audio_mlp = nn.Sequential( nn.SiLU(), nn.Linear(audio_emb_dim, dim_out * 2) ) if exists(audio_emb_dim) else None self.pose_mlp = nn.Sequential( nn.SiLU(), nn.Linear(pose_emb_dim, dim_out * 2) ) if exists(pose_emb_dim) else None self.eye_mlp = nn.Sequential( nn.SiLU(), nn.Linear(eye_emb_dim, dim_out * 2) ) if exists(eye_emb_dim) else None self.audio_emb_dim = audio_emb_dim self.pose_emb_dim = pose_emb_dim self.eye_emb_dim = eye_emb_dim # self.audio_mlp_2 = nn.Sequential( # nn.SiLU(), # nn.Linear(dim_out, dim_out * 2) # ) if exists(audio_emb_dim) else None attn_klass = CrossAttention self.cross_attn_aud = attn_klass( dim = dim, context_dim = dim_out * 2, out_dim = dim_out ) self.cross_attn_pose = attn_klass( dim = dim, context_dim = dim_out * 2, out_dim = dim_out ) self.cross_attn_eye = attn_klass( dim = dim, context_dim = dim_out * 2, out_dim = dim_out ) self.block1 = Block(dim, dim_out, groups=groups) self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None, audio_emb=None): time_scale_shift = None audio_scale_shift = None ''' need seperate 3 diffiserent condition ''' if exists(audio_emb): pose_emb = audio_emb[:,:,self.audio_emb_dim:self.audio_emb_dim + self.pose_emb_dim] eye_emb = audio_emb[:,:,self.audio_emb_dim + self.pose_emb_dim: ] audio_emb = audio_emb[:,:,:self.audio_emb_dim] b, c, f, H, W = x.size() if exists(self.time_mlp): assert exists(time_emb), 'time emb must be passed in' time_emb = self.time_mlp(time_emb) time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 # added by lml to get audio embedding if exists(self.audio_mlp): # mouth lmk + audio emb assert 
exists(audio_emb), 'audio emb must be passed in' audio_emb = self.audio_mlp(audio_emb) pose_emb = self.pose_mlp(pose_emb) # TODO: embedding eye_emb = self.eye_mlp(eye_emb) if exists(self.cross_attn_aud): # h = rearrange(x, 'b c f ... -> (b f) ... c') # h, ps = pack([h], 'b * c') # audio_emb = rearrange(audio_emb, 'b f ... -> (b f) ...') # audio_emb = self.cross_attn(h, context = audio_emb) # # h, = unpack(h, ps, 'b * c') # # h = rearrange(h, '(b f) ... c -> b c f ...', b = b, f = f, c = c) # # audio_emb = self.audio_mlp_2(audio_emb) # audio_emb = rearrange(audio_emb, '(b f) ... -> b f ...', b = b, f = f) assert exists(audio_emb) h_cond = rearrange(x, 'b c f ... -> (b f) ... c') # h = rearrange(x, 'b c ... -> b ... c') h_cond, ps = pack([h_cond], 'b * c') h_pose = self.cross_attn_pose(h_cond, context = pose_emb) h_aud = self.cross_attn_aud(h_cond, context = audio_emb) h_eye = self.cross_attn_eye(h_cond, context = eye_emb) h_cond = h_pose + h_aud + h_eye h_cond, = unpack(h_cond, ps, 'b * c') # h = rearrange(h, 'b ... c -> b c ...') h_cond = rearrange(h_cond, '(b f) ... c -> b c f ...', b = b, f = f) # audio_emb = rearrange(audio_emb, 'b f (h w) c -> b c f h w', w = W, h = H) # bs, 128, nf, 1, 1 # audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1 h = self.block1(x, time_scale_shift=time_scale_shift) if exists(self.audio_mlp): h = h_cond + h h = self.block2(h) return h + self.res_conv(x) class CrossAttention(nn.Module): def __init__( self, dim, out_dim, *, context_dim = None, dim_head = 8, heads = 8, norm_context = False, scale = 8 ): super().__init__() self.scale = scale self.heads = heads inner_dim = dim_head * heads context_dim = default(context_dim, dim) self.norm = LayerNorm_img(dim) self.norm_context = LayerNorm_img(context_dim) if norm_context else Identity() self.null_kv = nn.Parameter(torch.randn(2, dim_head)) self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) self.q_scale = nn.Parameter(torch.ones(dim_head)) self.k_scale = nn.Parameter(torch.ones(dim_head)) self.to_out = nn.Sequential( nn.Linear(inner_dim, out_dim, bias = False), LayerNorm_img(out_dim) ) def forward(self, x, context, mask = None): b, n, device = *x.shape[:2], x.device x = self.norm(x) # bn * fn ? 
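# ---------------------------------------------------------------------------
# Illustrative sketch of the windowed temporal mask built by this file's RelativePositionBias
# (window_width): query/key frames further apart than window_width receive a large negative
# additive bias, so temporal attention in this "local" test variant only attends within a
# +/- window_width frame neighbourhood. The expression is an equivalent, more explicit form of
# the one used in the module; sizes are arbitrary.
# ---------------------------------------------------------------------------
import torch
from einops import rearrange

def _sketch_window_mask(n=40, window_width=20):
    q_pos = torch.arange(n)
    k_pos = torch.arange(n)
    rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
    outside = (rel_pos > window_width) | (rel_pos < -window_width)
    return outside.float() * -1e8  # (n, n): 0 inside the window, -1e8 outside

# _sketch_window_mask(6, 2) is zero on the band |i - j| <= 2 and -1e8 elsewhere.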
# context: b, fn, c context = rearrange(context, 'b f c -> (b f) c') context = self.norm_context(context) q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v)) # add null key / value for classifier free guidance in prior net nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2)) k = torch.cat((nk, k), dim = -2) v = torch.cat((nv, v), dim = -2) # cosine sim attention q, k = map(l2norm, (q, k)) q = q * self.q_scale k = k * self.k_scale # similarities sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale # masking max_neg_value = -torch.finfo(sim.dtype).max if exists(mask): mask = F.pad(mask, (1, 0), value = True) mask = rearrange(mask, 'b j -> b 1 1 j') sim = sim.masked_fill(~mask, max_neg_value) attn = sim.softmax(dim = -1, dtype = torch.float32) attn = attn.to(sim.dtype) out = einsum('b h i j, b h j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out) class LinearCrossAttention(CrossAttention): def forward(self, x, context, mask = None): b, n, c = x.size() b, n, device = *x.shape[:2], x.device # x : b * fn, 32*32, c x = self.norm(x) context = self.norm_context(context) q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1)) # b*fn, 32*32, c, b*fn, 1, c * 2, q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads, d = c//self.heads), (q, k, v)) # head * b*fn, n, c//head # add null key / value for classifier free guidance in prior net nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2)) k = torch.cat((nk, k), dim = -2) # b * nf * h, 2, c//h v = torch.cat((nv, v), dim = -2) # masking max_neg_value = -torch.finfo(x.dtype).max if exists(mask): mask = F.pad(mask, (1, 0), value = True) mask = rearrange(mask, 'b n -> b n 1') k = k.masked_fill(~mask, max_neg_value) v = v.masked_fill(~mask, 0.) 
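# ---------------------------------------------------------------------------
# Hedged sketch of the classifier-free guidance mix applied by Unet3D.forward_with_cond_scale
# later in this file, which the null key/value and null-condition embedding above exist to
# support: a conditional and an unconditional prediction are blended as
# uncond + (cond - uncond) * cond_scale. Inputs here are placeholder tensors, not model outputs.
# ---------------------------------------------------------------------------
import torch

def _sketch_cfg_mix(cond_pred, uncond_pred, cond_scale=2.):
    # cond_scale == 1 returns the conditional prediction unchanged;
    # larger values push the output further away from the unconditional one.
    return uncond_pred + (cond_pred - uncond_pred) * cond_scale

# e.g. _sketch_cfg_mix(torch.randn(1, 3, 40, 32, 32), torch.randn(1, 3, 40, 32, 32), 2.)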
# linear attention q = q.softmax(dim = -1) # # b * nf * h, 32*32, c//h, k = k.softmax(dim = -2) q = q * self.scale context = einsum('b n d, b n e -> b d e', k, v) # b * nf * h, 2, c//h, b * nf * h, 2, c//h out = einsum('b n d, b d e -> b n e', q, context) out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads) return self.to_out(out) class SpatialLinearAttention(nn.Module): def __init__(self, dim, heads=4, dim_head=32): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, f, h, w = x.shape x = rearrange(x, 'b c f h w -> (b f) c h w') qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = rearrange_many(qkv, 'b (h c) x y -> b h c (x y)', h=self.heads) q = q.softmax(dim=-2) k = k.softmax(dim=-1) q = q * self.scale context = torch.einsum('b h d n, b h e n -> b h d e', k, v) out = torch.einsum('b h d e, b h d n -> b h e n', context, q) out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w) out = self.to_out(out) return rearrange(out, '(b f) c h w -> b c f h w', b=b) # attention along space and time class EinopsToAndFrom(nn.Module): def __init__(self, from_einops, to_einops, fn): super().__init__() self.from_einops = from_einops self.to_einops = to_einops self.fn = fn def forward(self, x, **kwargs): shape = x.shape reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape))) x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') x = self.fn(x, **kwargs) x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) return x class Attention(nn.Module): def __init__( self, dim, heads=4, dim_head=32, rotary_emb=None ): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.rotary_emb = rotary_emb self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) self.to_out = nn.Linear(hidden_dim, dim, bias=False) def forward( self, x, pos_bias=None, focus_present_mask=None ): # temperal: 'b (h w) f c' ; spatial : 'b f (h w) c' n, device = x.shape[-2], x.device qkv = self.to_qkv(x).chunk(3, dim=-1) if exists(focus_present_mask) and focus_present_mask.all(): # if all batch samples are focusing on present # it would be equivalent to passing that token's values through to the output values = qkv[-1] return self.to_out(values) # split out heads q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) # scale q = q * self.scale # rotate positions into queries and keys for time attention if exists(self.rotary_emb): q = self.rotary_emb.rotate_queries_or_keys(q) k = self.rotary_emb.rotate_queries_or_keys(k) # similarity sim = einsum('... h i d, ... h j d -> ... h i j', q, k) # relative positional bias if exists(pos_bias): sim = sim + pos_bias if exists(focus_present_mask) and not (~focus_present_mask).all(): attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool) attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) mask = torch.where( rearrange(focus_present_mask, 'b -> b 1 1 1 1'), rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), ) sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) # numerical stability sim = sim - sim.amax(dim=-1, keepdim=True).detach() attn = sim.softmax(dim=-1) # aggregate values out = einsum('... h i j, ... h j d -> ... h i d', attn, v) out = rearrange(out, '... h n d -> ... 
n (h d)') return self.to_out(out) # model class Unet3D(nn.Module): def __init__( self, dim, cond_aud=1024, cond_pose=7, cond_eye=2, cond_dim=None, out_grid_dim=2, out_conf_dim=1, num_frames=40, dim_mults=(1, 2, 4, 8), channels=3, attn_heads=8, attn_dim_head=32, use_hubert_audio_cond=False, init_dim=None, init_kernel_size=7, use_sparse_linear_attn=True, resnet_groups=8, use_final_activation=False, learn_null_cond=False, use_deconv=True, padding_mode="zeros", win_width = 20 ): super().__init__() self.null_cond_mask = None self.channels = channels self.num_frames = num_frames self.HUBERT_MODEL_DIM = 1024 # temporal attention and its relative positional encoding rotary_emb = RotaryEmbedding(min(32, attn_dim_head)) temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention(dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb)) self.time_rel_pos_bias = RelativePositionBias(heads=attn_heads, max_distance=32, window_width = win_width) # realistically will not be able to generate that many frames of video... yet # initial conv init_dim = default(init_dim, dim) assert is_odd(init_kernel_size) init_padding = init_kernel_size // 2 self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, init_kernel_size), padding=(0, init_padding, init_padding)) self.init_temporal_attn = Residual(PreNorm(init_dim, temporal_attn(init_dim))) # dimensions dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) # time conditioning time_dim = dim * 4 self.time_mlp = nn.Sequential( SinusoidalPosEmb(dim), nn.Linear(dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim) ) # audio conditioning self.has_cond = exists(cond_dim) or use_hubert_audio_cond self.cond_dim = cond_dim self.cond_aud_dim = cond_aud self.cond_pose_dim = cond_pose self.cond_eye_dim = cond_eye # modified by lml self.learn_null_cond = learn_null_cond # cat(t,cond) is not suitable # cond_dim = time_dim + int(cond_dim or 0) # layers self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) num_resolutions = len(in_out) # block type block_klass = partial(ResnetBlock_ca_mul, groups=resnet_groups) block_klass_cond = partial(block_klass, time_emb_dim=time_dim, audio_emb_dim=self.cond_aud_dim, pose_emb_dim=self.cond_pose_dim, eye_emb_dim=self.cond_eye_dim) # block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) # cat embedding # modules for all layers for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) self.downs.append(nn.ModuleList([ block_klass_cond(dim_in, dim_out), block_klass_cond(dim_out, dim_out), Residual(PreNorm(dim_out, SpatialLinearAttention(dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), Residual(PreNorm(dim_out, temporal_attn(dim_out))), Downsample(dim_out) if not is_last else nn.Identity() ])) mid_dim = dims[-1] self.mid_block1 = block_klass_cond(mid_dim, mid_dim) spatial_attn = EinopsToAndFrom('b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads)) self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn)) self.mid_temporal_attn = Residual(PreNorm(mid_dim, temporal_attn(mid_dim))) self.mid_block2 = block_klass_cond(mid_dim, mid_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): is_last = ind >= (num_resolutions - 1) self.ups.append(nn.ModuleList([ block_klass_cond(dim_out * 2, dim_in), block_klass_cond(dim_in, dim_in), Residual(PreNorm(dim_in, SpatialLinearAttention(dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), 
Residual(PreNorm(dim_in, temporal_attn(dim_in))), Upsample(dim_in, use_deconv, padding_mode) if not is_last else nn.Identity() ])) # out_dim = default(out_grid_dim, channels) self.final_conv = nn.Sequential( block_klass(dim * 2, dim), nn.Conv3d(dim, out_grid_dim, 1) ) # added by nhm self.use_final_activation = use_final_activation if self.use_final_activation: self.final_activation = nn.Tanh() else: self.final_activation = nn.Identity() # added by nhm for predicting occlusion mask self.occlusion_map = nn.Sequential( block_klass(dim * 2, dim), nn.Conv3d(dim, out_conf_dim, 1) ) def forward_with_cond_scale( self, *args, cond_scale=2., **kwargs ): logits = self.forward(*args, null_cond_prob=0., **kwargs) if cond_scale == 1 or not self.has_cond: return logits null_logits = self.forward(*args, null_cond_prob=1., **kwargs) return null_logits + (logits - null_logits) * cond_scale def forward( self, x, time, cond=None, null_cond_prob=0., focus_present_mask=None, prob_focus_present=0. # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) ): assert not (self.has_cond and not exists(cond)), 'cond must be passed in if cond_dim specified' batch, device = x.shape[0], x.device focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device=device)) time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device) x = self.init_conv(x) r = x.clone() x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias) t = self.time_mlp(time) if exists(self.time_mlp) else None if self.learn_null_cond: self.null_cond_emb = nn.Parameter(torch.randn(1, self.num_frames, self.cond_dim)) if self.has_cond else None else: self.null_cond_emb = torch.zeros(1, self.num_frames, self.cond_dim) if self.has_cond else None # classifier free guidance if self.has_cond: batch, device = x.shape[0], x.device self.null_cond_mask = prob_mask_like((batch, self.num_frames,), null_cond_prob, device=device) cond = torch.where(rearrange(self.null_cond_mask, 'b n -> b n 1'), self.null_cond_emb.to(cond.device), cond) # t (bs, 256) cond (bs, nf*1024)->(bs, nf, 1024) in this version # it's the original cond embedding method used in LFDM # t = torch.cat((t, cond), dim=-1) h = [] for block1, block2, spatial_attn, temporal_attn, downsample in self.downs: x = block1(x, t, cond) x = block2(x, t, cond) x = spatial_attn(x) x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) h.append(x) x = downsample(x) x = self.mid_block1(x, t, cond) x = self.mid_spatial_attn(x) x = self.mid_temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) x = self.mid_block2(x, t, cond) for block1, block2, spatial_attn, temporal_attn, upsample in self.ups: x = torch.cat((x, h.pop()), dim=1) x = block1(x, t, cond) x = block2(x, t, cond) x = spatial_attn(x) x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) x = upsample(x) x = torch.cat((x, r), dim=1) return torch.cat((self.final_conv(x), self.occlusion_map(x)), dim=1) # to dynamically change num_frames of Unet3D class DynamicNfUnet3D(Unet3D): def __init__(self, default_num_frames=20, *args, **kwargs): super(DynamicNfUnet3D, self).__init__(*args, **kwargs) self.default_num_frames = default_num_frames self.num_frames = default_num_frames def update_num_frames(self, new_num_frames): self.num_frames = new_num_frames # gaussian diffusion trainer class def extract(a, t, x_shape): b, *_ = t.shape out = 
a.gather(-1, t) return out.reshape(b, *((1,) * (len(x_shape) - 1))) def cosine_beta_schedule(timesteps, s=0.008): """ cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ """ steps = timesteps + 1 x = torch.linspace(0, timesteps, steps, dtype=torch.float64) alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.9999) class GaussianDiffusion(nn.Module): def __init__( self, denoise_fn, *, image_size, num_frames, text_use_bert_cls=False, channels=3, timesteps=1000, sampling_timesteps=250, ddim_sampling_eta=1., loss_type='l1', use_dynamic_thres=False, # from the Imagen paper dynamic_thres_percentile=0.9, null_cond_prob=0.1 ): super().__init__() self.null_cond_prob = null_cond_prob self.channels = channels self.image_size = image_size self.num_frames = num_frames self.denoise_fn = denoise_fn betas = cosine_beta_schedule(timesteps) alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) timesteps, = betas.shape self.num_timesteps = int(timesteps) self.loss_type = loss_type self.sampling_timesteps = default(sampling_timesteps, timesteps) self.is_ddim_sampling = self.sampling_timesteps < timesteps self.ddim_sampling_eta = ddim_sampling_eta # register buffer helper function that casts float64 to float32 register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) register_buffer('betas', betas) register_buffer('alphas_cumprod', alphas_cumprod) register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) # calculations for diffusion q(x_t | x_{t-1}) and others register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) register_buffer('posterior_variance', posterior_variance) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20))) register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) # text conditioning parameters self.text_use_bert_cls = text_use_bert_cls # dynamic thresholding when sampling self.use_dynamic_thres = use_dynamic_thres self.dynamic_thres_percentile = dynamic_thres_percentile def q_mean_variance(self, x_start, t): mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = extract(1. 
- self.alphas_cumprod, t, x_start.shape) log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) return mean, variance, log_variance def predict_start_from_noise(self, x_t, t, noise): return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) def q_posterior(self, x_start, x_t, t): posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = extract(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance(self, x, t, fea, clip_denoised: bool, cond=None, cond_scale=1.): fea = fea.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1) x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_fn.forward_with_cond_scale(torch.cat([x, fea], dim=1), t, cond=cond, cond_scale=cond_scale)) if clip_denoised: s = 1. if self.use_dynamic_thres: s = torch.quantile( rearrange(x_recon, 'b ... -> b (...)').abs(), self.dynamic_thres_percentile, dim=-1 ) s.clamp_(min=1.) s = s.view(-1, *((1,) * (x_recon.ndim - 1))) # clip by threshold, depending on whether static or dynamic x_recon = x_recon.clamp(-s, s) / s model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance @torch.inference_mode() def p_sample(self, x, t, fea, cond=None, cond_scale=1., clip_denoised=True): b, *_, device = *x.shape, x.device model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, fea=fea, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale) noise = torch.randn_like(x) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.inference_mode() def p_sample_loop(self, fea, shape, cond=None, cond_scale=1.): device = self.betas.device b = shape[0] img = torch.randn(shape, device=device) for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), fea, cond=cond, cond_scale=cond_scale) return img # return unnormalize_img(img) @torch.inference_mode() def sample(self, fea, bbox_mask, cond=None, cond_scale=1., batch_size=16): # text bert: cond 1,768 # device = next(self.denoise_fn.parameters()).device # if is_list_str(cond): # cond = torch.rand((1 ,768), dtype=torch.float32).cuda() #used to debug # cond = bert_embed(tokenize(cond), return_cls_repr=self.text_use_bert_cls).to(device) batch_size = cond.shape[0] if exists(cond) else batch_size # batch_size = 1 if exists(cond) else batch_size image_size = self.image_size channels = self.channels num_frames = self.num_frames sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample fea = torch.cat([fea, bbox_mask], dim=1) return sample_fn(fea, (batch_size, channels, num_frames, fea.shape[-1], fea.shape[-1]), cond=cond, cond_scale=cond_scale) # add by nhm @torch.no_grad() def ddim_sample(self, fea, shape, cond=None, cond_scale=1., clip_denoised=True): batch, device, total_timesteps, sampling_timesteps, eta = \ shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta times = torch.linspace(0., 
total_timesteps, steps=sampling_timesteps + 2)[:-1] times = list(reversed(times.int().tolist())) time_pairs = list(zip(times[:-1], times[1:])) img = torch.randn(shape, device=device) # bs, 3, nf, 32, 32 fea = fea.unsqueeze(dim=2).repeat(1, 1, img.size(2), 1, 1) #bs, 256, nf, 32, 32 for time, time_next in tqdm(time_pairs, desc='sampling loop time step'): alpha = self.alphas_cumprod_prev[time] alpha_next = self.alphas_cumprod_prev[time_next] time_cond = torch.full((batch,), time, device=device, dtype=torch.long) # pred_noise, x_start, *_ = self.model_predictions(img, time_cond, fea) pred_noise = self.denoise_fn.forward_with_cond_scale( torch.cat([img, fea], dim=1), time_cond, cond=cond, cond_scale=cond_scale) x_start = self.predict_start_from_noise(img, t=time_cond, noise=pred_noise) if clip_denoised: s = 1. if self.use_dynamic_thres: s = torch.quantile( rearrange(x_start, 'b ... -> b (...)').abs(), self.dynamic_thres_percentile, dim=-1 ) s.clamp_(min=1.) s = s.view(-1, *((1,) * (x_start.ndim - 1))) # clip by threshold, depending on whether static or dynamic x_start = x_start.clamp(-s, s) / s sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c = ((1 - alpha_next) - sigma ** 2).sqrt() noise = torch.randn_like(img) if time_next > 0 else 0. img = x_start * alpha_next.sqrt() + \ c * pred_noise + \ sigma * noise # img = unnormalize_to_zero_to_one(img) return img @torch.inference_mode() def interpolate(self, x1, x2, t=None, lam=0.5): b, *_, device = *x1.shape, x1.device t = default(t, self.num_timesteps - 1) assert x1.shape == x2.shape t_batched = torch.stack([torch.tensor(t, device=device)] * b) xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) img = (1 - lam) * xt1 + lam * xt2 for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t): img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) return img def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) return ( extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) def p_losses(self, x_start, t, fea, bbox_mask, cond=None, noise=None, clip_denoised=True, **kwargs): # x_start: bs, 3, num_frame, 32, 32 # t: bs # fea: bs, 256, num_frame, 32, 32 # cond: bs, 768 b, c, f, h, w, device = *x_start.shape, x_start.device noise = default(noise, lambda: torch.randn_like(x_start)) # bs, 3, nf, 32, 32 x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)# bs, 3, nf, 32, 32 pred_noise = self.denoise_fn.forward(torch.cat([x_noisy, fea, bbox_mask], dim=1), t, cond=cond, null_cond_prob=self.null_cond_prob, **kwargs) if self.loss_type == 'l1': loss = F.l1_loss(noise, pred_noise, reduce=False) elif self.loss_type == 'l2': loss = F.mse_loss(noise, pred_noise, reduce=False) else: raise NotImplementedError() pred_x0 = self.predict_start_from_noise(x_noisy, t, pred_noise) if clip_denoised: s = 1. if self.use_dynamic_thres: s = torch.quantile( rearrange(pred_x0, 'b ... -> b (...)').abs(), self.dynamic_thres_percentile, dim=-1 ) s.clamp_(min=1.) 
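# Imagen-style dynamic thresholding: s now holds the per-sample quantile of |pred_x0| (floored at 1.0).
# Below, s is reshaped to broadcast over (c, f, h, w), pred_x0 is clamped to [-s, s] and divided by s,
# so the cached self.pred_x0 stays within [-1, 1].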
s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) # clip by threshold, depending on whether static or dynamic self.pred_x0 = pred_x0.clamp(-s, s) / s return loss, self.denoise_fn.null_cond_mask def forward(self, x, fea, bbox_mask, cond, *args, **kwargs): b, device, img_size, = x.shape[0], x.device, self.image_size # check_shape(x, 'b c f h w', c=self.channels, f=self.num_frames, h=img_size, w=img_size) t = torch.randint(0, self.num_timesteps, (b,), device=device).long() fea = fea.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1) bbox_mask = bbox_mask.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1) return self.p_losses(x, t, fea, bbox_mask, cond, *args, **kwargs) # trainer class CHANNELS_TO_MODE = { 1: 'L', 3: 'RGB', 4: 'RGBA' } def seek_all_images(img, channels=3): assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' mode = CHANNELS_TO_MODE[channels] i = 0 while True: try: img.seek(i) yield img.convert(mode) except EOFError: break i += 1 # to dynamically change num_frames of GaussianDiffusion class DynamicNfGaussianDiffusion(GaussianDiffusion): def __init__(self, default_num_frames=20, *args, **kwargs): super(DynamicNfGaussianDiffusion, self).__init__(*args, **kwargs) self.default_num_frames = default_num_frames self.num_frames = default_num_frames def update_num_frames(self, new_num_frames): self.num_frames = new_num_frames # tensor of shape (channels, frames, height, width) -> gif def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True): images = map(T.ToPILImage(), tensor.unbind(dim=1)) first_img, *rest_imgs = images first_img.save(path, save_all=True, append_images=rest_imgs, duration=duration, loop=loop, optimize=optimize) return images # gif -> (channels, frame, height, width) tensor def gif_to_tensor(path, channels=3, transform=T.ToTensor()): img = Image.open(path) tensors = tuple(map(transform, seek_all_images(img, channels=channels))) return torch.stack(tensors, dim=1) def identity(t, *args, **kwargs): return t def normalize_img(t): return t * 2 - 1 # def unnormalize_img(t): # return (t + 1) * 0.5 def cast_num_frames(t, *, frames): f = t.shape[1] if f == frames: return t if f > frames: return t[:, :frames] return F.pad(t, (0, 0, 0, 0, 0, frames - f)) ================================================ FILE: DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py ================================================ ''' adds pose conditioning on top of the baseline, using cross attention to inject the different conditions, with RAM-optimized local attention; intended for inference (slower, but uses less RAM) ''' import math import torch from torch import nn, einsum import torch.nn.functional as F from functools import partial from torchvision import transforms as T from PIL import Image from tqdm import tqdm from einops import rearrange, repeat, reduce, pack, unpack from einops_exts import rearrange_many from rotary_embedding_torch import RotaryEmbedding from DM_3.modules.local_attention import LocalSelfAttention_opt, create_sliding_window_mask # from DM.modules.text import tokenize, bert_embed, HUBERT_MODEL_DIM # helper functions def exists(x): return x is not None def noop(*args, **kwargs): pass def is_odd(n): return (n % 2) == 1 def default(val, d): if exists(val): return val return d() if callable(d) else d def cycle(dl): while True: for data in dl: yield data def num_to_groups(num, divisor): groups = num // divisor remainder = num % divisor arr = [divisor] * groups if remainder > 0: arr.append(remainder) return arr def prob_mask_like(shape, prob,
device): if prob == 1: return torch.ones(shape, device=device, dtype=torch.bool) elif prob == 0: return torch.zeros(shape, device=device, dtype=torch.bool) else: return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob def is_list_str(x): if not isinstance(x, (list, tuple)): return False return all([type(el) == str for el in x]) # relative positional bias class RelativePositionBias(nn.Module): def __init__( self, heads=8, num_buckets=32, max_distance=128, window_width = 20 ): super().__init__() self.num_buckets = num_buckets self.max_distance = max_distance self.relative_attention_bias = nn.Embedding(num_buckets, heads) self.window_width = window_width @staticmethod def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): ret = 0 n = -relative_position num_buckets //= 2 ret += (n < 0).long() * num_buckets n = torch.abs(n) max_exact = num_buckets // 2 is_small = n < max_exact val_if_large = max_exact + ( torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).long() val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) ret += torch.where(is_small, n, val_if_large) return ret def forward(self, n, device): q_pos = torch.arange(n, dtype=torch.long, device=device) k_pos = torch.arange(n, dtype=torch.long, device=device) rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance) mask = -(((rel_pos > self.window_width) + (rel_pos < -self.window_width)) * (1e8)) # -(((rp_bucket ==15) + (rp_bucket >= 30)) * (1e8)) values = self.relative_attention_bias(rp_bucket) return rearrange(values, 'i j h -> h i j') + mask # small helper modules class EMA(): def __init__(self, beta): super().__init__() self.beta = beta def update_model_average(self, ma_model, current_model): for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): old_weight, up_weight = ma_params.data, current_params.data ma_params.data = self.update_average(old_weight, up_weight) def update_average(self, old, new): if old is None: return new return old * self.beta + (1 - self.beta) * new class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, *args, **kwargs): return self.fn(x, *args, **kwargs) + x class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, x): device = x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=device) * -emb) emb = x[:, None] * emb[None, :] emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb def Upsample(dim, use_deconv=True, padding_mode="reflect"): if use_deconv: return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) else: return nn.Sequential( nn.Upsample(scale_factor=(1, 2, 2), mode='nearest'), nn.Conv3d(dim, dim, (1, 3, 3), (1, 1, 1), (0, 1, 1), padding_mode=padding_mode) ) def Downsample(dim): return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) class LayerNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) def forward(self, x): var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) return (x - mean) / (var + self.eps).sqrt() * self.gamma class LayerNorm_img(nn.Module): def __init__(self, dim, stable = False): 
super().__init__() self.stable = stable self.g = nn.Parameter(torch.ones(dim)) def forward(self, x): if self.stable: x = x / x.amax(dim = -1, keepdim = True).detach() eps = 1e-5 if x.dtype == torch.float32 else 1e-3 var = torch.var(x, dim = -1, unbiased = False, keepdim = True) mean = torch.mean(x, dim = -1, keepdim = True) return (x - mean) * (var + eps).rsqrt() * self.g class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.fn = fn self.norm = LayerNorm(dim) def forward(self, x, **kwargs): x = self.norm(x) return self.fn(x, **kwargs) class Identity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() def forward(self, x, *args, **kwargs): return x def l2norm(t): return F.normalize(t, dim = -1) # building block modules class Block(nn.Module): def __init__(self, dim, dim_out, groups=8): super().__init__() self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1)) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.SiLU() def forward(self, x, time_scale_shift=None, audio_scale_shift=None): x = self.proj(x) x = self.norm(x) if exists(time_scale_shift): time_scale, time_shift = time_scale_shift x = x * (time_scale + 1) + time_shift # added by lml to change the control method of audio embedding, inspired by diffusedhead # if exists(audio_scale_shift): # # audio_scale and audio_shift:(bs, 64, nf, 1, 1) # # x:(bs, 64, nf, 32, 32) # audio_scale, audio_shift = audio_scale_shift # x = x * (audio_scale + 1) + audio_shift return self.act(x) class ResnetBlock(nn.Module): def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8): super().__init__() self.time_mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None self.audio_mlp = nn.Sequential( nn.SiLU(), nn.Linear(audio_emb_dim, dim_out * 2) ) if exists(audio_emb_dim) else None self.block1 = Block(dim, dim_out, groups=groups) self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None, audio_emb=None): time_scale_shift = None audio_scale_shift = None if exists(self.time_mlp): assert exists(time_emb), 'time emb must be passed in' time_emb = self.time_mlp(time_emb) time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 # added by lml to get audio embedding if exists(self.audio_mlp): assert exists(audio_emb), 'audio emb must be passed in' audio_emb = self.audio_mlp(audio_emb) audio_emb = rearrange(audio_emb, 'b n c -> b c n 1 1') # bs, 128, nf, 1, 1 audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1 h = self.block1(x, time_scale_shift=time_scale_shift, audio_scale_shift=audio_scale_shift) h = self.block2(h) return h + self.res_conv(x) class ResnetBlock_ca(nn.Module): def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, groups=8): super().__init__() self.time_mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None self.audio_mlp = nn.Sequential( nn.SiLU(), nn.Linear(audio_emb_dim, dim_out * 2) ) if exists(audio_emb_dim) else None # self.audio_mlp_2 = nn.Sequential( # nn.SiLU(), # nn.Linear(dim_out, dim_out * 2) # ) if exists(audio_emb_dim) else None attn_klass = CrossAttention self.cross_attn = attn_klass( dim = dim, context_dim = dim_out * 2 ) self.block1 = Block(dim, dim_out, groups=groups) self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = 
nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None, audio_emb=None): time_scale_shift = None audio_scale_shift = None b, c, f, H, W = x.size() if exists(self.time_mlp): assert exists(time_emb), 'time emb must be passed in' time_emb = self.time_mlp(time_emb) time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 time_scale_shift = time_emb.chunk(2, dim=1) # bs, 64, 1, 1 # added by lml to get audio embedding if exists(self.audio_mlp): assert exists(audio_emb), 'audio emb must be passed in' audio_emb = self.audio_mlp(audio_emb) if exists(self.cross_attn): # h = rearrange(x, 'b c f ... -> (b f) ... c') # h, ps = pack([h], 'b * c') # audio_emb = rearrange(audio_emb, 'b f ... -> (b f) ...') # audio_emb = self.cross_attn(h, context = audio_emb) # # h, = unpack(h, ps, 'b * c') # # h = rearrange(h, '(b f) ... c -> b c f ...', b = b, f = f, c = c) # # audio_emb = self.audio_mlp_2(audio_emb) # audio_emb = rearrange(audio_emb, '(b f) ... -> b f ...', b = b, f = f) assert exists(audio_emb) h = rearrange(x, 'b c f ... -> (b f) ... c') # h = rearrange(x, 'b c ... -> b ... c') h, ps = pack([h], 'b * c') h = self.cross_attn(h, context = audio_emb) + h h, = unpack(h, ps, 'b * c') # h = rearrange(h, 'b ... c -> b c ...') h = rearrange(h, '(b f) ... c -> b f c ...', b = b, f = f) # audio_emb = rearrange(audio_emb, 'b f (h w) c -> b c f h w', w = W, h = H) # bs, 128, nf, 1, 1 # audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1 h = self.block1(x, time_scale_shift=time_scale_shift) h = self.block2(h) return h + self.res_conv(x) class ResnetBlock_ca_mul(nn.Module): def __init__(self, dim, dim_out, *, time_emb_dim=None, audio_emb_dim=None, pose_emb_dim=None, eye_emb_dim=None, groups=8): super().__init__() self.time_mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None self.audio_mlp = nn.Sequential( nn.SiLU(), nn.Linear(audio_emb_dim, dim_out * 2) ) if exists(audio_emb_dim) else None self.pose_mlp = nn.Sequential( nn.SiLU(), nn.Linear(pose_emb_dim, dim_out * 2) ) if exists(pose_emb_dim) else None self.eye_mlp = nn.Sequential( nn.SiLU(), nn.Linear(eye_emb_dim, dim_out * 2) ) if exists(eye_emb_dim) else None self.audio_emb_dim = audio_emb_dim self.pose_emb_dim = pose_emb_dim self.eye_emb_dim = eye_emb_dim # self.audio_mlp_2 = nn.Sequential( # nn.SiLU(), # nn.Linear(dim_out, dim_out * 2) # ) if exists(audio_emb_dim) else None attn_klass = CrossAttention self.cross_attn_aud = attn_klass( dim = dim, context_dim = dim_out * 2, out_dim = dim_out ) self.cross_attn_pose = attn_klass( dim = dim, context_dim = dim_out * 2, out_dim = dim_out ) self.cross_attn_eye = attn_klass( dim = dim, context_dim = dim_out * 2, out_dim = dim_out ) self.block1 = Block(dim, dim_out, groups=groups) self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None, audio_emb=None): time_scale_shift = None audio_scale_shift = None ''' need seperate 3 diffiserent condition ''' if exists(audio_emb): pose_emb = audio_emb[:,:,self.audio_emb_dim:self.audio_emb_dim + self.pose_emb_dim] eye_emb = audio_emb[:,:,self.audio_emb_dim + self.pose_emb_dim: ] audio_emb = audio_emb[:,:,:self.audio_emb_dim] b, c, f, H, W = x.size() if exists(self.time_mlp): assert exists(time_emb), 'time emb must be passed in' time_emb = self.time_mlp(time_emb) time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # bs, 128, 1, 1 time_scale_shift = 
time_emb.chunk(2, dim=1) # bs, 64, 1, 1 # added by lml to get audio embedding if exists(self.audio_mlp): # mouth lmk + audio emb assert exists(audio_emb), 'audio emb must be passed in' audio_emb = self.audio_mlp(audio_emb) pose_emb = self.pose_mlp(pose_emb) # TODO: embedding eye_emb = self.eye_mlp(eye_emb) if exists(self.cross_attn_aud): # h = rearrange(x, 'b c f ... -> (b f) ... c') # h, ps = pack([h], 'b * c') # audio_emb = rearrange(audio_emb, 'b f ... -> (b f) ...') # audio_emb = self.cross_attn(h, context = audio_emb) # # h, = unpack(h, ps, 'b * c') # # h = rearrange(h, '(b f) ... c -> b c f ...', b = b, f = f, c = c) # # audio_emb = self.audio_mlp_2(audio_emb) # audio_emb = rearrange(audio_emb, '(b f) ... -> b f ...', b = b, f = f) assert exists(audio_emb) h_cond = rearrange(x, 'b c f ... -> (b f) ... c') # h = rearrange(x, 'b c ... -> b ... c') h_cond, ps = pack([h_cond], 'b * c') h_pose = self.cross_attn_pose(h_cond, context = pose_emb) h_aud = self.cross_attn_aud(h_cond, context = audio_emb) h_eye = self.cross_attn_eye(h_cond, context = eye_emb) h_cond = h_pose + h_aud + h_eye h_cond, = unpack(h_cond, ps, 'b * c') # h = rearrange(h, 'b ... c -> b c ...') h_cond = rearrange(h_cond, '(b f) ... c -> b c f ...', b = b, f = f) # audio_emb = rearrange(audio_emb, 'b f (h w) c -> b c f h w', w = W, h = H) # bs, 128, nf, 1, 1 # audio_scale_shift = audio_emb.chunk(2, dim=1) # bs, 64, nf, 1, 1 h = self.block1(x, time_scale_shift=time_scale_shift) if exists(self.audio_mlp): h = h_cond + h h = self.block2(h) return h + self.res_conv(x) class CrossAttention(nn.Module): def __init__( self, dim, out_dim, *, context_dim = None, dim_head = 8, heads = 8, norm_context = False, scale = 8 ): super().__init__() self.scale = scale self.heads = heads inner_dim = dim_head * heads context_dim = default(context_dim, dim) self.norm = LayerNorm_img(dim) self.norm_context = LayerNorm_img(context_dim) if norm_context else Identity() self.null_kv = nn.Parameter(torch.randn(2, dim_head)) self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) self.q_scale = nn.Parameter(torch.ones(dim_head)) self.k_scale = nn.Parameter(torch.ones(dim_head)) self.to_out = nn.Sequential( nn.Linear(inner_dim, out_dim, bias = False), LayerNorm_img(out_dim) ) def forward(self, x, context, mask = None): b, n, device = *x.shape[:2], x.device x = self.norm(x) # bn * fn ? 
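# In this CrossAttention the context is the per-frame condition embedding (audio / pose / eye in the
# _ca_mul blocks). Below it is flattened to (b*f, c) and layer-normalized; a learned null key/value
# pair is prepended so a "dropped" condition (classifier-free guidance) still has a token to attend to;
# queries and keys are L2-normalized with learned scales (cosine-similarity attention) and the softmax
# is computed in float32 for stability. Each spatial location therefore attends to just two tokens:
# the null token and that frame's condition vector.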
# context: b, fn, c context = rearrange(context, 'b f c -> (b f) c') context = self.norm_context(context) q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v)) # add null key / value for classifier free guidance in prior net nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2)) k = torch.cat((nk, k), dim = -2) v = torch.cat((nv, v), dim = -2) # cosine sim attention q, k = map(l2norm, (q, k)) q = q * self.q_scale k = k * self.k_scale # similarities sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale # masking max_neg_value = -torch.finfo(sim.dtype).max if exists(mask): mask = F.pad(mask, (1, 0), value = True) mask = rearrange(mask, 'b j -> b 1 1 j') sim = sim.masked_fill(~mask, max_neg_value) attn = sim.softmax(dim = -1, dtype = torch.float32) attn = attn.to(sim.dtype) out = einsum('b h i j, b h j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out) class LinearCrossAttention(CrossAttention): def forward(self, x, context, mask = None): b, n, c = x.size() b, n, device = *x.shape[:2], x.device # x : b * fn, 32*32, c x = self.norm(x) context = self.norm_context(context) q, k, v = (self.to_q(x), *self.to_kv(context[:, None, :]).chunk(2, dim = -1)) # b*fn, 32*32, c, b*fn, 1, c * 2, q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads, d = c//self.heads), (q, k, v)) # head * b*fn, n, c//head # add null key / value for classifier free guidance in prior net nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2)) k = torch.cat((nk, k), dim = -2) # b * nf * h, 2, c//h v = torch.cat((nv, v), dim = -2) # masking max_neg_value = -torch.finfo(x.dtype).max if exists(mask): mask = F.pad(mask, (1, 0), value = True) mask = rearrange(mask, 'b n -> b n 1') k = k.masked_fill(~mask, max_neg_value) v = v.masked_fill(~mask, 0.) 
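# Minimal standalone sketch (illustrative only, not part of the model) of the linear-attention trick
# applied below: softmax is taken over the feature dimension of q and the sequence dimension of k,
# which makes the product associative, so k can be contracted with v first and the cost grows
# linearly in sequence length instead of quadratically.
def _linear_attention_sketch(q, k, v):
    # q, k, v: (batch, n, dim_head)
    q = q.softmax(dim=-1)
    k = k.softmax(dim=-2)
    context = einsum('b n d, b n e -> b d e', k, v)  # contract over the sequence dim first
    return einsum('b n d, b d e -> b n e', q, context)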
# linear attention q = q.softmax(dim = -1) # # b * nf * h, 32*32, c//h, k = k.softmax(dim = -2) q = q * self.scale context = einsum('b n d, b n e -> b d e', k, v) # b * nf * h, 2, c//h, b * nf * h, 2, c//h out = einsum('b n d, b d e -> b n e', q, context) out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads) return self.to_out(out) class SpatialLinearAttention(nn.Module): def __init__(self, dim, heads=4, dim_head=32): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, f, h, w = x.shape x = rearrange(x, 'b c f h w -> (b f) c h w') qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = rearrange_many(qkv, 'b (h c) x y -> b h c (x y)', h=self.heads) q = q.softmax(dim=-2) k = k.softmax(dim=-1) q = q * self.scale context = torch.einsum('b h d n, b h e n -> b h d e', k, v) out = torch.einsum('b h d e, b h d n -> b h e n', context, q) out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w) out = self.to_out(out) return rearrange(out, '(b f) c h w -> b c f h w', b=b) # attention along space and time class EinopsToAndFrom(nn.Module): def __init__(self, from_einops, to_einops, fn): super().__init__() self.from_einops = from_einops self.to_einops = to_einops self.fn = fn def forward(self, x, **kwargs): shape = x.shape reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape))) x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') x = self.fn(x, **kwargs) x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) return x class Attention(nn.Module): def __init__( self, dim, heads=4, dim_head=32, rotary_emb=None ): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.rotary_emb = rotary_emb self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) self.to_out = nn.Linear(hidden_dim, dim, bias=False) def forward( self, x, pos_bias=None, focus_present_mask=None ): # temperal: 'b (h w) f c' ; spatial : 'b f (h w) c' n, device = x.shape[-2], x.device qkv = self.to_qkv(x).chunk(3, dim=-1) if exists(focus_present_mask) and focus_present_mask.all(): # if all batch samples are focusing on present # it would be equivalent to passing that token's values through to the output values = qkv[-1] return self.to_out(values) # split out heads q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) # scale q = q * self.scale # rotate positions into queries and keys for time attention if exists(self.rotary_emb): q = self.rotary_emb.rotate_queries_or_keys(q) k = self.rotary_emb.rotate_queries_or_keys(k) # similarity sim = einsum('... h i d, ... h j d -> ... h i j', q, k) # relative positional bias if exists(pos_bias): sim = sim + pos_bias if exists(focus_present_mask) and not (~focus_present_mask).all(): attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool) attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) mask = torch.where( rearrange(focus_present_mask, 'b -> b 1 1 1 1'), rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), ) sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) # numerical stability sim = sim - sim.amax(dim=-1, keepdim=True).detach() attn = sim.softmax(dim=-1) # aggregate values out = einsum('... h i j, ... h j d -> ... h i d', attn, v) out = rearrange(out, '... h n d -> ... 
n (h d)') return self.to_out(out) # model class Unet3D(nn.Module): def __init__( self, dim, cond_aud=1024, cond_pose=7, cond_eye=2, cond_dim=None, out_grid_dim=2, out_conf_dim=1, num_frames=40, dim_mults=(1, 2, 4, 8), channels=3, attn_heads=8, attn_dim_head=32, use_hubert_audio_cond=False, init_dim=None, init_kernel_size=7, use_sparse_linear_attn=True, resnet_groups=8, use_final_activation=False, learn_null_cond=False, use_deconv=True, padding_mode="zeros", win_width = 20 ): super().__init__() self.null_cond_mask = None self.channels = channels self.num_frames = num_frames self.HUBERT_MODEL_DIM = 1024 self.win_width = win_width # temporal attention and its relative positional encoding rotary_emb = RotaryEmbedding(min(32, attn_dim_head), seq_before_head_dim = True) temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', 'b (h w) f c', LocalSelfAttention_opt(dim, heads=attn_heads, size_per_head=attn_dim_head, neighbors=self.win_width, rotary_emb=rotary_emb)) self.time_rel_pos_bias = RelativePositionBias(heads=attn_heads, max_distance=32, window_width = self.win_width ) # realistically will not be able to generate that many frames of video... yet # initial conv init_dim = default(init_dim, dim) assert is_odd(init_kernel_size) init_padding = init_kernel_size // 2 self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, init_kernel_size), padding=(0, init_padding, init_padding)) self.init_temporal_attn = Residual(PreNorm(init_dim, temporal_attn(init_dim))) # dimensions dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) # time conditioning time_dim = dim * 4 self.time_mlp = nn.Sequential( SinusoidalPosEmb(dim), nn.Linear(dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim) ) # audio conditioning self.has_cond = exists(cond_dim) or use_hubert_audio_cond self.cond_dim = cond_dim self.cond_aud_dim = cond_aud self.cond_pose_dim = cond_pose self.cond_eye_dim = cond_eye # modified by lml self.learn_null_cond = learn_null_cond # cat(t,cond) is not suitable # cond_dim = time_dim + int(cond_dim or 0) # layers self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) num_resolutions = len(in_out) # block type block_klass = partial(ResnetBlock_ca_mul, groups=resnet_groups) block_klass_cond = partial(block_klass, time_emb_dim=time_dim, audio_emb_dim=self.cond_aud_dim, pose_emb_dim=self.cond_pose_dim, eye_emb_dim=self.cond_eye_dim) # block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) # cat embedding # modules for all layers for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) self.downs.append(nn.ModuleList([ block_klass_cond(dim_in, dim_out), block_klass_cond(dim_out, dim_out), Residual(PreNorm(dim_out, SpatialLinearAttention(dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), Residual(PreNorm(dim_out, temporal_attn(dim_out))), Downsample(dim_out) if not is_last else nn.Identity() ])) mid_dim = dims[-1] self.mid_block1 = block_klass_cond(mid_dim, mid_dim) spatial_attn = EinopsToAndFrom('b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads)) self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn)) self.mid_temporal_attn = Residual(PreNorm(mid_dim, temporal_attn(mid_dim))) self.mid_block2 = block_klass_cond(mid_dim, mid_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): is_last = ind >= (num_resolutions - 1) self.ups.append(nn.ModuleList([ block_klass_cond(dim_out * 2, dim_in), block_klass_cond(dim_in, dim_in), Residual(PreNorm(dim_in, 
SpatialLinearAttention(dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), Residual(PreNorm(dim_in, temporal_attn(dim_in))), Upsample(dim_in, use_deconv, padding_mode) if not is_last else nn.Identity() ])) # out_dim = default(out_grid_dim, channels) self.final_conv = nn.Sequential( block_klass(dim * 2, dim), nn.Conv3d(dim, out_grid_dim, 1) ) # added by nhm self.use_final_activation = use_final_activation if self.use_final_activation: self.final_activation = nn.Tanh() else: self.final_activation = nn.Identity() # added by nhm for predicting occlusion mask self.occlusion_map = nn.Sequential( block_klass(dim * 2, dim), nn.Conv3d(dim, out_conf_dim, 1) ) def forward_with_cond_scale( self, *args, cond_scale=2., **kwargs ): logits = self.forward(*args, null_cond_prob=0., **kwargs) if cond_scale == 1 or not self.has_cond: return logits null_logits = self.forward(*args, null_cond_prob=1., **kwargs) return null_logits + (logits - null_logits) * cond_scale def forward( self, x, time, cond=None, null_cond_prob=0., focus_present_mask=None, prob_focus_present=0. # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) ): assert not (self.has_cond and not exists(cond)), 'cond must be passed in if cond_dim specified' batch, device = x.shape[0], x.device focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device=device)) time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device) time_rel_pos_bias = create_sliding_window_mask(time_rel_pos_bias, 2 * self.win_width + 1, 1) x = self.init_conv(x) r = x.clone() x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias) t = self.time_mlp(time) if exists(self.time_mlp) else None if self.learn_null_cond: self.null_cond_emb = nn.Parameter(torch.randn(1, self.num_frames, self.cond_dim)) if self.has_cond else None else: self.null_cond_emb = torch.zeros(1, self.num_frames, self.cond_dim) if self.has_cond else None # classifier free guidance if self.has_cond: batch, device = x.shape[0], x.device self.null_cond_mask = prob_mask_like((batch, self.num_frames,), null_cond_prob, device=device) cond = torch.where(rearrange(self.null_cond_mask, 'b n -> b n 1'), self.null_cond_emb.to(cond.device), cond) # t (bs, 256) cond (bs, nf*1024)->(bs, nf, 1024) in this version # it's the original cond embedding method used in LFDM # t = torch.cat((t, cond), dim=-1) h = [] for block1, block2, spatial_attn, temporal_attn, downsample in self.downs: x = block1(x, t, cond) x = block2(x, t, cond) x = spatial_attn(x) x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) h.append(x) x = downsample(x) x = self.mid_block1(x, t, cond) x = self.mid_spatial_attn(x) x = self.mid_temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) x = self.mid_block2(x, t, cond) for block1, block2, spatial_attn, temporal_attn, upsample in self.ups: x = torch.cat((x, h.pop()), dim=1) x = block1(x, t, cond) x = block2(x, t, cond) x = spatial_attn(x) x = temporal_attn(x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) x = upsample(x) x = torch.cat((x, r), dim=1) return torch.cat((self.final_conv(x), self.occlusion_map(x)), dim=1) # to dynamically change num_frames of Unet3D class DynamicNfUnet3D(Unet3D): def __init__(self, default_num_frames=20, *args, **kwargs): super(DynamicNfUnet3D, self).__init__(*args, **kwargs) self.default_num_frames = default_num_frames self.num_frames = 
default_num_frames def update_num_frames(self, new_num_frames): self.num_frames = new_num_frames # gaussian diffusion trainer class def extract(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) return out.reshape(b, *((1,) * (len(x_shape) - 1))) def cosine_beta_schedule(timesteps, s=0.008): """ cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ """ steps = timesteps + 1 x = torch.linspace(0, timesteps, steps, dtype=torch.float64) alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.9999) class GaussianDiffusion(nn.Module): def __init__( self, denoise_fn, *, image_size, num_frames, text_use_bert_cls=False, channels=3, timesteps=1000, sampling_timesteps=250, ddim_sampling_eta=1., loss_type='l1', use_dynamic_thres=False, # from the Imagen paper dynamic_thres_percentile=0.9, null_cond_prob=0.1 ): super().__init__() self.null_cond_prob = null_cond_prob self.channels = channels self.image_size = image_size self.num_frames = num_frames self.denoise_fn = denoise_fn betas = cosine_beta_schedule(timesteps) alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) timesteps, = betas.shape self.num_timesteps = int(timesteps) self.loss_type = loss_type self.sampling_timesteps = default(sampling_timesteps, timesteps) self.is_ddim_sampling = self.sampling_timesteps < timesteps self.ddim_sampling_eta = ddim_sampling_eta # register buffer helper function that casts float64 to float32 register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) register_buffer('betas', betas) register_buffer('alphas_cumprod', alphas_cumprod) register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) # calculations for diffusion q(x_t | x_{t-1}) and others register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) register_buffer('posterior_variance', posterior_variance) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20))) register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) # text conditioning parameters self.text_use_bert_cls = text_use_bert_cls # dynamic thresholding when sampling self.use_dynamic_thres = use_dynamic_thres self.dynamic_thres_percentile = dynamic_thres_percentile def q_mean_variance(self, x_start, t): mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = extract(1. 
- self.alphas_cumprod, t, x_start.shape) log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) return mean, variance, log_variance def predict_start_from_noise(self, x_t, t, noise): return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) def q_posterior(self, x_start, x_t, t): posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = extract(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance(self, x, t, fea, clip_denoised: bool, cond=None, cond_scale=1.): fea = fea.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1) x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_fn.forward_with_cond_scale(torch.cat([x, fea], dim=1), t, cond=cond, cond_scale=cond_scale)) if clip_denoised: s = 1. if self.use_dynamic_thres: s = torch.quantile( rearrange(x_recon, 'b ... -> b (...)').abs(), self.dynamic_thres_percentile, dim=-1 ) s.clamp_(min=1.) s = s.view(-1, *((1,) * (x_recon.ndim - 1))) # clip by threshold, depending on whether static or dynamic x_recon = x_recon.clamp(-s, s) / s model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance @torch.inference_mode() def p_sample(self, x, t, fea, cond=None, cond_scale=1., clip_denoised=True): b, *_, device = *x.shape, x.device model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, fea=fea, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale) noise = torch.randn_like(x) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.inference_mode() def p_sample_loop(self, fea, shape, cond=None, cond_scale=1.): device = self.betas.device b = shape[0] img = torch.randn(shape, device=device) for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), fea, cond=cond, cond_scale=cond_scale) return img # return unnormalize_img(img) @torch.inference_mode() def sample(self, fea, bbox_mask, cond=None, cond_scale=1., batch_size=16): # text bert: cond 1,768 # device = next(self.denoise_fn.parameters()).device # if is_list_str(cond): # cond = torch.rand((1 ,768), dtype=torch.float32).cuda() #used to debug # cond = bert_embed(tokenize(cond), return_cls_repr=self.text_use_bert_cls).to(device) batch_size = cond.shape[0] if exists(cond) else batch_size # batch_size = 1 if exists(cond) else batch_size image_size = self.image_size channels = self.channels num_frames = self.num_frames sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample fea = torch.cat([fea, bbox_mask], dim=1) return sample_fn(fea, (batch_size, channels, num_frames, fea.shape[-1], fea.shape[-1]), cond=cond, cond_scale=cond_scale) # add by nhm @torch.no_grad() def ddim_sample(self, fea, shape, cond=None, cond_scale=1., clip_denoised=True): batch, device, total_timesteps, sampling_timesteps, eta = \ shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta times = torch.linspace(0., 
total_timesteps, steps=sampling_timesteps + 2)[:-1] times = list(reversed(times.int().tolist())) time_pairs = list(zip(times[:-1], times[1:])) img = torch.randn(shape, device=device) # bs, 3, nf, 32, 32 fea = fea.unsqueeze(dim=2).repeat(1, 1, img.size(2), 1, 1) #bs, 256, nf, 32, 32 for time, time_next in tqdm(time_pairs, desc='sampling loop time step'): alpha = self.alphas_cumprod_prev[time] alpha_next = self.alphas_cumprod_prev[time_next] time_cond = torch.full((batch,), time, device=device, dtype=torch.long) # pred_noise, x_start, *_ = self.model_predictions(img, time_cond, fea) pred_noise = self.denoise_fn.forward_with_cond_scale( torch.cat([img, fea], dim=1), time_cond, cond=cond, cond_scale=cond_scale) x_start = self.predict_start_from_noise(img, t=time_cond, noise=pred_noise) if clip_denoised: s = 1. if self.use_dynamic_thres: s = torch.quantile( rearrange(x_start, 'b ... -> b (...)').abs(), self.dynamic_thres_percentile, dim=-1 ) s.clamp_(min=1.) s = s.view(-1, *((1,) * (x_start.ndim - 1))) # clip by threshold, depending on whether static or dynamic x_start = x_start.clamp(-s, s) / s sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() c = ((1 - alpha_next) - sigma ** 2).sqrt() noise = torch.randn_like(img) if time_next > 0 else 0. img = x_start * alpha_next.sqrt() + \ c * pred_noise + \ sigma * noise # img = unnormalize_to_zero_to_one(img) return img @torch.inference_mode() def interpolate(self, x1, x2, t=None, lam=0.5): b, *_, device = *x1.shape, x1.device t = default(t, self.num_timesteps - 1) assert x1.shape == x2.shape t_batched = torch.stack([torch.tensor(t, device=device)] * b) xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) img = (1 - lam) * xt1 + lam * xt2 for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t): img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) return img def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) return ( extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) def p_losses(self, x_start, t, fea, bbox_mask, cond=None, noise=None, clip_denoised=True, **kwargs): # x_start: bs, 3, num_frame, 32, 32 # t: bs # fea: bs, 256, num_frame, 32, 32 # cond: bs, 768 b, c, f, h, w, device = *x_start.shape, x_start.device noise = default(noise, lambda: torch.randn_like(x_start)) # bs, 3, nf, 32, 32 x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)# bs, 3, nf, 32, 32 pred_noise = self.denoise_fn.forward(torch.cat([x_noisy, fea, bbox_mask], dim=1), t, cond=cond, null_cond_prob=self.null_cond_prob, **kwargs) if self.loss_type == 'l1': loss = F.l1_loss(noise, pred_noise, reduce=False) elif self.loss_type == 'l2': loss = F.mse_loss(noise, pred_noise, reduce=False) else: raise NotImplementedError() pred_x0 = self.predict_start_from_noise(x_noisy, t, pred_noise) if clip_denoised: s = 1. if self.use_dynamic_thres: s = torch.quantile( rearrange(pred_x0, 'b ... -> b (...)').abs(), self.dynamic_thres_percentile, dim=-1 ) s.clamp_(min=1.) 
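# Note: the L1 / L2 losses above are built with reduce=False (the deprecated spelling of
# reduction='none'), so p_losses returns the per-element noise-prediction loss together with
# null_cond_mask and leaves the reduction (e.g. a plain .mean() in the training scripts) to the caller.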
s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) # clip by threshold, depending on whether static or dynamic self.pred_x0 = pred_x0.clamp(-s, s) / s return loss, self.denoise_fn.null_cond_mask def forward(self, x, fea, bbox_mask, cond, *args, **kwargs): b, device, img_size, = x.shape[0], x.device, self.image_size # check_shape(x, 'b c f h w', c=self.channels, f=self.num_frames, h=img_size, w=img_size) t = torch.randint(0, self.num_timesteps, (b,), device=device).long() fea = fea.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1) bbox_mask = bbox_mask.unsqueeze(dim=2).repeat(1, 1, x.size(2), 1, 1) return self.p_losses(x, t, fea, bbox_mask, cond, *args, **kwargs) # trainer class CHANNELS_TO_MODE = { 1: 'L', 3: 'RGB', 4: 'RGBA' } def seek_all_images(img, channels=3): assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' mode = CHANNELS_TO_MODE[channels] i = 0 while True: try: img.seek(i) yield img.convert(mode) except EOFError: break i += 1 # to dynamically change num_frames of GaussianDiffusion class DynamicNfGaussianDiffusion(GaussianDiffusion): def __init__(self, default_num_frames=20, *args, **kwargs): super(DynamicNfGaussianDiffusion, self).__init__(*args, **kwargs) self.default_num_frames = default_num_frames self.num_frames = default_num_frames def update_num_frames(self, new_num_frames): self.num_frames = new_num_frames # tensor of shape (channels, frames, height, width) -> gif def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True): images = map(T.ToPILImage(), tensor.unbind(dim=1)) first_img, *rest_imgs = images first_img.save(path, save_all=True, append_images=rest_imgs, duration=duration, loop=loop, optimize=optimize) return images # gif -> (channels, frame, height, width) tensor def gif_to_tensor(path, channels=3, transform=T.ToTensor()): img = Image.open(path) tensors = tuple(map(transform, seek_all_images(img, channels=channels))) return torch.stack(tensors, dim=1) def identity(t, *args, **kwargs): return t def normalize_img(t): return t * 2 - 1 # def unnormalize_img(t): # return (t + 1) * 0.5 def cast_num_frames(t, *, frames): f = t.shape[1] if f == frames: return t if f > frames: return t[:, :frames] return F.pad(t, (0, 0, 0, 0, 0, frames - f)) ================================================ FILE: DM_3/test_lr.py ================================================ import torch import torch.optim as optim from torch.optim.lr_scheduler import CosineAnnealingLR model = torch.nn.Linear(2,4) optimizer = optim.Adam(model.parameters(), lr=2e-5) scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-5) for epoch in range(100): scheduler.step() print(optimizer.param_groups[0]['lr']) ================================================ FILE: DM_3/train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D.py ================================================ import sys sys.path.append('your/path/') import argparse from datetime import datetime, time import imageio import torch from torch.utils import data import numpy as np import torch.backends.cudnn as cudnn import os import os.path as osp import timeit import math from PIL import Image from misc import Logger, grid2fig, conf2fig from DM_3.datasets_hdtf_wpose_lmk_block_lmk import HDTF import sys import random from torch.utils.tensorboard import SummaryWriter from DM_3.utils import MultiEpochsDataLoader as DataLoader import time from DM_3.modules.video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_6D import FlowDiffusion from torch.optim.lr_scheduler import MultiStepLR 
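# The constants below configure the training run. Strings such as 'your/path', 'your/image/path',
# 'your/pose/path', 'your/blink/path' and AE_RESTORE_FROM ('LFG/path', the LFG checkpoint) are
# placeholders that must be replaced with local dataset / checkpoint locations before launching;
# GPU selects the visible CUDA device via CUDA_VISIBLE_DEVICES in main().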
from sync_batchnorm import DataParallelWithCallback import torch.multiprocessing as mp start = timeit.default_timer() BATCH_SIZE = 20 # crema settings # MAX_EPOCH = 300 # epoch_milestones = [800, 1000] # hdtf MAX_EPOCH = 500 * 25 epoch_milestones = [800000, 10000000] root_dir = 'your/path' data_dir = "your/image/path" pose_dir = "your/pose/path" eye_blink_dir = "your/blink/path" GPU = "2" postfix = "-j-of" joint = "joint" in postfix or "-j" in postfix # allow joint training with unconditional model only_use_flow = "onlyflow" in postfix or "-of" in postfix # whether only use flow loss vgg_weight = 0 floss_weight = 0.15 if joint: null_cond_prob = 0.1 else: null_cond_prob = 0.0 if "upconv" in postfix: use_deconv = False padding_mode = "reflect" else: use_deconv = True padding_mode = "zeros" use_residual_flow = "-rf" in postfix learn_null_cond = "-lnc" in postfix INPUT_SIZE = 128 MAX_N_FRAMES = 20 LEARNING_RATE = 2e-4 RANDOM_SEED = 1234 clip_c = 2. print('use grad clip, clip = ', clip_c) MEAN = (0.0, 0.0, 0.0) config_pth = "./config/hdtf128.yaml" # PATH of LFG checkpoint AE_RESTORE_FROM = 'LFG/path' RESTORE_FROM = '' # use existing checkpoint DM_LOG_PATH = os.path.join(root_dir,'data','HDTF_wpose_faceemb_newae_6Dpose', 'ca_init_cond_liploss','stage1_0ref_1000epae_v0_lr_N='+str(MAX_N_FRAMES)) print(DM_LOG_PATH) SNAPSHOT_DIR = os.path.join(DM_LOG_PATH, 'snapshots' + postfix) IMGSHOT_DIR = os.path.join(DM_LOG_PATH, 'imgshots' + postfix) VIDSHOT_DIR = os.path.join(DM_LOG_PATH, "vidshots" + postfix) SAMPLE_DIR = os.path.join(DM_LOG_PATH, 'sample' + postfix) NUM_EXAMPLES_PER_EPOCH = 400 NUM_STEPS_PER_EPOCH = math.ceil(NUM_EXAMPLES_PER_EPOCH / float(BATCH_SIZE)) MAX_ITER = max(NUM_EXAMPLES_PER_EPOCH * MAX_EPOCH + 1, NUM_STEPS_PER_EPOCH * BATCH_SIZE * MAX_EPOCH + 1) SAVE_MODEL_EVERY = int(250000) SAVE_VID_EVERY = 4000 SAMPLE_VID_EVERY = 2000 UPDATE_MODEL_EVERY = 500 os.makedirs(SNAPSHOT_DIR, exist_ok=True) os.makedirs(IMGSHOT_DIR, exist_ok=True) os.makedirs(VIDSHOT_DIR, exist_ok=True) os.makedirs(SAMPLE_DIR, exist_ok=True) LOG_PATH = SNAPSHOT_DIR + "/B" + format(BATCH_SIZE, "04d") + "E" + format(MAX_EPOCH, "04d") + ".log" sys.stdout = Logger(LOG_PATH, sys.stdout) print(root_dir) print("update saved model every:", UPDATE_MODEL_EVERY) print("save model every:", SAVE_MODEL_EVERY) print("save video every:", SAVE_VID_EVERY) print("sample video every:", SAMPLE_VID_EVERY) print(postfix) print("RESTORE_FROM", RESTORE_FROM) print("num examples per epoch:", NUM_EXAMPLES_PER_EPOCH) print("max epoch:", MAX_EPOCH) print("image size", INPUT_SIZE) print("epoch milestones:", epoch_milestones) print("only use flow loss:", only_use_flow) print("null_cond_prob:", null_cond_prob) print("use residual flow:", use_residual_flow) print("learn null cond:", learn_null_cond) print("use deconv:", use_deconv) def get_arguments(): """Parse all the arguments provided from the CLI. Returns: A list of parsed arguments. 
""" parser = argparse.ArgumentParser(description="Flow Diffusion") parser.add_argument("--fine-tune", default=False) parser.add_argument("--set-start", default=True) parser.add_argument("--start-step", default=0, type=int) parser.add_argument("--img-dir", type=str, default=IMGSHOT_DIR, help="Where to save images of the model.") parser.add_argument("--num-workers", default=8) parser.add_argument("--final-step", type=int, default=int(NUM_STEPS_PER_EPOCH * MAX_EPOCH), help="Number of training steps.") parser.add_argument("--gpu", default=GPU, help="choose gpu device.") parser.add_argument('--print-freq', '-p', default=2, type=int, metavar='N', help='print frequency') parser.add_argument('--save-img-freq', default=2000, type=int, metavar='N', help='save image frequency') parser.add_argument('--save-vid-freq', default=SAVE_VID_EVERY, type=int) parser.add_argument('--sample-vid-freq', default=SAMPLE_VID_EVERY, type=int) parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Number of images sent to the network in one step.") parser.add_argument("--input-size", type=str, default=INPUT_SIZE, help="Comma-separated string with height and width of images.") parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, help="Base learning rate for training with polynomial decay.") parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, help="Random seed to have reproducible results.") parser.add_argument("--restore-from", default=RESTORE_FROM) parser.add_argument("--save-pred-every", type=int, default=SAVE_MODEL_EVERY, help="Save checkpoint every often.") parser.add_argument("--update-pred-every", type=int, default=UPDATE_MODEL_EVERY) parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR, help="Where to save snapshots of the model.") parser.add_argument("--fp16", default=True) parser.add_argument("--cosin", default=True) return parser.parse_args() args = get_arguments() def sample_img(rec_img_batch, idx=0): rec_img = rec_img_batch[idx].permute(1, 2, 0).data.cpu().numpy().copy() rec_img += np.array(MEAN) / 255.0 rec_img[rec_img < 0] = 0 rec_img[rec_img > 1] = 1 rec_img *= 255 return np.array(rec_img, np.uint8) def main(): """Create the model and start the training.""" os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu cudnn.enabled = True cudnn.benchmark = True setup_seed(args.random_seed) writer = SummaryWriter(os.path.join(DM_LOG_PATH, 'tensorboard')) model = FlowDiffusion(is_train=True, img_size=INPUT_SIZE // 4, num_frames=MAX_N_FRAMES, null_cond_prob=null_cond_prob, sampling_timesteps=20, use_residual_flow=use_residual_flow, learn_null_cond=learn_null_cond, use_deconv=use_deconv, padding_mode=padding_mode, config_pth=config_pth, pretrained_pth=AE_RESTORE_FROM) model.cuda() scaler = torch.cuda.amp.GradScaler(enabled=args.fp16) # Not set model to be train mode! 
Because pretrained flow autoenc need to be eval (BatchNorm) # create optimizer optimizer_diff = torch.optim.Adam(model.diffusion.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.99)) if args.fine_tune: pass elif args.restore_from: if os.path.isfile(args.restore_from): print("=> loading checkpoint '{}'".format(args.restore_from)) checkpoint = torch.load(args.restore_from) if args.set_start: args.start_step = int(math.ceil(checkpoint['example'] / args.batch_size)) model_ckpt = model.diffusion.state_dict() for name, _ in model_ckpt.items(): model_ckpt[name].copy_(checkpoint['diffusion'][name]) model.diffusion.load_state_dict(model_ckpt) print("=> loaded checkpoint '{}'".format(args.restore_from)) if args.set_start: if "optimizer_diff" in list(checkpoint.keys()): optimizer_diff.load_state_dict(checkpoint['optimizer_diff']) else: print("=> no checkpoint found at '{}'".format(args.restore_from)) else: print("NO checkpoint found!") # enable the usage of multi-GPU model = DataParallelWithCallback(model) setup_seed(args.random_seed) trainloader = DataLoader(HDTF(data_dir=data_dir, pose_dir=pose_dir, eye_blink_dir = eye_blink_dir, image_size=INPUT_SIZE, max_num_frames=MAX_N_FRAMES, color_jitter=True, mean=MEAN), batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,# args.num_workers, pin_memory=True) batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() losses_rec = AverageMeter() losses_warp = AverageMeter() losses_vgg = AverageMeter() cnt = 0 actual_step = args.start_step start_epoch = int(math.ceil((args.start_step * args.batch_size) / NUM_EXAMPLES_PER_EPOCH)) epoch_cnt = start_epoch if(not args.cosin): scheduler = MultiStepLR(optimizer_diff, epoch_milestones, gamma=0.1, last_epoch=start_epoch - 1) else: scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_diff, T_max=MAX_EPOCH, eta_min=1e-6) print("epoch %d, lr= %.7f" % (epoch_cnt, optimizer_diff.param_groups[0]["lr"])) # start = time.time() torch.inverse(torch.ones((1,1), device = "cuda:0")) while actual_step < args.final_step: iter_end = timeit.default_timer() start_time = time.time() # start load_sum = 0 calculate_sum = 0 for i_iter, batch in enumerate(trainloader): if __debug__: end_time = time.time() # end # print(f'load time {end_time- start_time}') load_sum += end_time - start_time if end_time - start_time > 1: print('unnormal load: \t',i_iter) start_time = end_time actual_step = int(args.start_step + cnt) data_time.update(timeit.default_timer() - iter_end) real_vids, ref_hubert, real_poses, real_blink_bbox, mouth_lmk_tensor, real_names, _ = batch # ref_hubert, real_poses, real_blink_bbox : b, c, fn # use first frame of each video as reference frame ref_id = 0 # random.randint(0, real_vids.shape[2] - 1) ref_imgs = real_vids[:, :, ref_id, :, :].clone().detach() bs = real_vids.size(0) new_num_frames = real_vids.size(2) model.module.update_num_frames(new_num_frames) # end_time = time.time() # end # print(f'preprocess time {end_time- start_time}') # start_time = end_time # encode text # cond = bert_embed(tokenize(ref_texts), return_cls_repr=model.module.diffusion.text_use_bert_cls).cuda() is_eval = actual_step % args.save_vid_freq == 0 or actual_step % args.sample_vid_freq == 0 with torch.cuda.amp.autocast(enabled=args.fp16): train_output_dict = model.forward(real_vid=real_vids, ref_img=ref_imgs, ref_text=ref_hubert, ref_pose=real_poses, ref_eye_blink = real_blink_bbox[:, :2], bbox=real_blink_bbox[:, 2:], mouth_lmk = mouth_lmk_tensor, is_eval = is_eval, ref_id = ref_id) # optimize model 
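# AMP update below: the total objective is the diffusion noise loss plus floss_weight (0.15) times the
# flow loss plus 0.15 times the mouth (lip) loss; GradScaler scales it for fp16 backprop, gradients are
# unscaled, clipped to clip_c = 2.0 with clip_grad_norm_, and applied via scaler.step / scaler.update.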
optimizer_diff.zero_grad() # if only_use_flow: # scaler.scale(train_output_dict["loss"].mean()).backward() # else: # scaler.scale((train_output_dict["loss"].mean() + train_output_dict["rec_loss"].mean() + # train_output_dict["rec_warp_loss"].mean())).backward() scaler.scale(train_output_dict["loss"].mean() + floss_weight * train_output_dict['floss'].mean() + 0.15 * train_output_dict['mouth_loss'].mean()).backward() # optimizer_diff.step() scaler.unscale_(optimizer_diff) # loss.backward() if clip_c > 0.: torch.nn.utils.clip_grad_norm_(model.parameters(), clip_c) scaler.step(optimizer_diff) scaler.update() batch_time.update(timeit.default_timer() - iter_end) iter_end = timeit.default_timer() losses.update(train_output_dict["loss"].mean().item(), bs) losses_rec.update(train_output_dict["floss"].mean().item(), bs) losses_warp.update(train_output_dict["mouth_loss"].mean().item(), bs) # losses_vgg.update(train_output_dict["rec_vgg_loss"].mean().item(), bs) writer.add_scalar('train/loss', train_output_dict["loss"].mean().item(),actual_step) writer.add_scalar('train/floss', train_output_dict["floss"].mean().item(),actual_step) writer.add_scalar('train/mouth_loss', train_output_dict["mouth_loss"].mean().item(),actual_step) # writer.add_scalar('train/rec_loss', train_output_dict["rec_loss"].mean().item(),actual_step) # writer.add_scalar('train/rec_warp_loss', train_output_dict["rec_warp_loss"].mean().item(),actual_step) # writer.add_scalar('train/rec_vgg_loss', train_output_dict["rec_vgg_loss"].mean().item(),actual_step) if __debug__: end_time = time.time() # end # print(f'forward time {end_time- start_time}') calculate_sum += end_time - start_time start_time = end_time # if actual_step % 100 == 0: # end = time.time() # print("100 iter time:{0}".format(end-start)) if actual_step % args.print_freq == 0: current_time = datetime.now() current_time_str = current_time.strftime("%Y-%m-%d %H:%M:%S") print("Current time is:", current_time_str) print('iter: [{0}]{1}/{2}\t' 'loss {loss.val:.7f} ({loss.avg:.7f})\t' 'loss_rec {loss_rec.val:.4f} ({loss_rec.avg:.4f})\t' 'loss_warp {loss_warp.val:.4f} ({loss_warp.avg:.4f})' .format( cnt, actual_step, args.final_step, batch_time=batch_time, data_time=data_time, loss=losses, loss_rec=losses_rec, loss_warp=losses_warp, )) null_cond_mask = np.array(train_output_dict["null_cond_mask"].data.cpu().numpy(), dtype=np.uint8) if actual_step % args.save_vid_freq == 0: # and cnt != 0: print("saving video...") num_frames = real_vids.size(2) msk_size = ref_imgs.shape[-1] new_im_arr_list = [] save_src_img = sample_img(ref_imgs/255.) for nf in range(num_frames): save_tar_img = sample_img(real_vids[:, :, nf, :, :]/255.) 
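# Each frame of the saved GIF is a 5x2 montage; columns left to right hold
#   the reference frame over the ground-truth frame,
#   real_out_vid over real_warped_vid (results driven by the real flow),
#   fake_out_vid over fake_warped_vid (results driven by the predicted flow),
#   real_vid_grid over fake_vid_grid (flow grids rendered with grid2fig),
#   real_vid_conf over fake_vid_conf (confidence maps rendered with conf2fig, grayscale).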
# adapt fast version save_real_out_img = sample_img(train_output_dict["real_out_vid"][:, :, nf, :, :]) save_real_warp_img = sample_img(train_output_dict["real_warped_vid"][:, :, nf, :, :]) save_fake_out_img = sample_img(train_output_dict["fake_out_vid"][:, :, nf, :, :]) save_fake_warp_img = sample_img(train_output_dict["fake_warped_vid"][:, :, nf, :, :]) save_real_grid = grid2fig( train_output_dict["real_vid_grid"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(), grid_size=32, img_size=msk_size) save_fake_grid = grid2fig( train_output_dict["fake_vid_grid"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(), grid_size=32, img_size=msk_size) save_real_conf = conf2fig(train_output_dict["real_vid_conf"][0, :, nf]) save_fake_conf = conf2fig(train_output_dict["fake_vid_conf"][0, :, nf]) new_im = Image.new('RGB', (msk_size * 5, msk_size * 2)) new_im.paste(Image.fromarray(save_src_img, 'RGB'), (0, 0)) new_im.paste(Image.fromarray(save_tar_img, 'RGB'), (0, msk_size)) new_im.paste(Image.fromarray(save_real_out_img, 'RGB'), (msk_size, 0)) new_im.paste(Image.fromarray(save_real_warp_img, 'RGB'), (msk_size, msk_size)) new_im.paste(Image.fromarray(save_fake_out_img, 'RGB'), (msk_size * 2, 0)) new_im.paste(Image.fromarray(save_fake_warp_img, 'RGB'), (msk_size * 2, msk_size)) new_im.paste(Image.fromarray(save_real_grid, 'RGB'), (msk_size * 3, 0)) new_im.paste(Image.fromarray(save_fake_grid, 'RGB'), (msk_size * 3, msk_size)) new_im.paste(Image.fromarray(save_real_conf, 'L'), (msk_size * 4, 0)) new_im.paste(Image.fromarray(save_fake_conf, 'L'), (msk_size * 4, msk_size)) new_im_arr = np.array(new_im) new_im_arr_list.append(new_im_arr) new_vid_name = 'B' + format(args.batch_size, "04d") + '_S' + format(actual_step, "06d") \ + '_' + real_names[0] + "_%d.gif" % (null_cond_mask[0][0]) new_vid_file = os.path.join(VIDSHOT_DIR, new_vid_name) imageio.mimsave(new_vid_file, new_im_arr_list) new_im_arr_list = None new_im_arr = None new_im = None del new_im_arr_list, new_im_arr, new_im # sampling if actual_step % args.sample_vid_freq == 0: print("sampling video...") with torch.no_grad(): # cond = torch.concat([ref_hubert[0].unsqueeze(dim=0), real_poses[0].permute(1,0).unsqueeze(0), real_blink_bbox[0][:2].permute(1,0).unsqueeze(0)], dim=-1).cuda() sample_output_dict = model.module.sample_one_video(real_vid=real_vids.cuda()/255., sample_img=ref_imgs[0].unsqueeze(dim=0).cuda()/255., sample_audio_hubert = ref_hubert[0].unsqueeze(dim=0).cuda(), sample_pose = real_poses[0].unsqueeze(0).cuda(), sample_eye = real_blink_bbox[0][:2].unsqueeze(0).cuda(), sample_bbox = real_blink_bbox[0,2:].unsqueeze(0).cuda(), cond_scale=1.0) num_frames = real_vids.size(2) msk_size = ref_imgs.shape[-1] new_im_arr_list = [] save_src_img = sample_img(ref_imgs/255.) for nf in range(num_frames): save_tar_img = sample_img(real_vids[:, :, nf, :, :]/255.) 
save_real_out_img = sample_img(train_output_dict["real_out_vid"][:, :, nf, :, :]) save_real_warp_img = sample_img(train_output_dict["real_warped_vid"][:, :, nf, :, :]) save_sample_out_img = sample_img(sample_output_dict["sample_out_vid"][:, :, nf, :, :]) save_sample_warp_img = sample_img(sample_output_dict["sample_warped_vid"][:, :, nf, :, :]) save_real_grid = grid2fig( train_output_dict["real_vid_grid"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(), grid_size=32, img_size=msk_size) save_fake_grid = grid2fig( sample_output_dict["sample_vid_grid"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(), grid_size=32, img_size=msk_size) save_real_conf = conf2fig(train_output_dict["real_vid_conf"][0, :, nf]) save_fake_conf = conf2fig(sample_output_dict["sample_vid_conf"][0, :, nf]) new_im = Image.new('RGB', (msk_size * 5, msk_size * 2)) new_im.paste(Image.fromarray(save_src_img, 'RGB'), (0, 0)) new_im.paste(Image.fromarray(save_tar_img, 'RGB'), (0, msk_size)) new_im.paste(Image.fromarray(save_real_out_img, 'RGB'), (msk_size, 0)) new_im.paste(Image.fromarray(save_real_warp_img, 'RGB'), (msk_size, msk_size)) new_im.paste(Image.fromarray(save_sample_out_img, 'RGB'), (msk_size * 2, 0)) new_im.paste(Image.fromarray(save_sample_warp_img, 'RGB'), (msk_size * 2, msk_size)) new_im.paste(Image.fromarray(save_real_grid, 'RGB'), (msk_size * 3, 0)) new_im.paste(Image.fromarray(save_fake_grid, 'RGB'), (msk_size * 3, msk_size)) new_im.paste(Image.fromarray(save_real_conf, 'L'), (msk_size * 4, 0)) new_im.paste(Image.fromarray(save_fake_conf, 'L'), (msk_size * 4, msk_size)) new_im_arr = np.array(new_im) new_im_arr_list.append(new_im_arr) new_vid_name = 'B' + format(args.batch_size, "04d") + '_S' + format(actual_step, "06d") \ + '_' + real_names[0] + ".gif" new_vid_file = os.path.join(SAMPLE_DIR, new_vid_name) imageio.mimsave(new_vid_file, new_im_arr_list) new_im_arr_list = None new_im_arr = None new_im = None del new_im_arr_list, new_im_arr, new_im # save model at i-th step if actual_step % args.save_pred_every == 0 and cnt != 0: print('taking snapshot ...') torch.save({'example': actual_step * args.batch_size, 'diffusion': model.module.diffusion.state_dict(), 'optimizer_diff': optimizer_diff.state_dict()}, osp.join(args.snapshot_dir, 'flowdiff_' + format(args.batch_size, "04d") + '_S' + format(actual_step, "06d") + '.pth')) # update saved model if actual_step % args.update_pred_every == 0 and cnt != 0: print('updating saved snapshot ...') torch.save({'example': actual_step * args.batch_size, 'diffusion': model.module.diffusion.state_dict(), 'optimizer_diff': optimizer_diff.state_dict()}, osp.join(args.snapshot_dir, 'flowdiff.pth')) if actual_step >= args.final_step: break cnt += 1 # if __debug__: # end_time = time.time() # end # print(f'orther time 1: {end_time- start_time}') # start_time = end_time del real_vids, ref_imgs, ref_hubert, real_names, null_cond_mask del train_output_dict del batch # torch.cuda.empty_cache() # if __debug__: # end_time = time.time() # end # print(f'orther time time {end_time- start_time}') # start_time = end_time scheduler.step() epoch_cnt += 1 print("epoch %d, lr= %.7f" % (epoch_cnt, optimizer_diff.param_groups[0]["lr"])) if __debug__: print('load_sum: ', load_sum) print('calculate_sum: ', calculate_sum) print('save the final model ...') torch.save({'example': actual_step * args.batch_size, 'diffusion': model.module.diffusion.state_dict(), 'optimizer_diff': optimizer_diff.state_dict()}, osp.join(args.snapshot_dir, 'flowdiff_' + format(args.batch_size, "04d") + '_S' + 
format(actual_step, "06d") + '.pth')) end = timeit.default_timer() print(end - start, 'seconds') class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True if __name__ == '__main__': # mp.set_start_method('spawn') # torch.multiprocessing.set_start_method("spawn") main() ================================================ FILE: DM_3/train_vdm_hdtf_wpose_plus_faceemb_init_cond_liploss_6D_s2.py ================================================ import sys sys.path.append('your/path/') import argparse from datetime import datetime, time import imageio import torch from torch.utils import data import numpy as np import torch.backends.cudnn as cudnn import os import os.path as osp import timeit import math from PIL import Image from misc import Logger, grid2fig, conf2fig from DM_3.datasets_hdtf_wpose_lmk_block_lmk_rand import HDTF import sys import random from torch.utils.tensorboard import SummaryWriter from DM_3.utils import MultiEpochsDataLoader as DataLoader import time from DM_3.modules.video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_mouth_mask_rand_6D import FlowDiffusion from torch.optim.lr_scheduler import MultiStepLR from sync_batchnorm import DataParallelWithCallback import torch.multiprocessing as mp start = timeit.default_timer() BATCH_SIZE = 40 # crema settings # MAX_EPOCH = 300 # epoch_milestones = [800, 1000] # hdtf MAX_EPOCH = 500 * 25 epoch_milestones = [800000, 10000000] root_dir = 'your/path' data_dir = "your/image/path" pose_dir = "your/pose/path" eye_blink_dir = "your/blink/path" GPU = "2" postfix = "-j-of-s2" joint = "joint" in postfix or "-j" in postfix # allow joint training with unconditional model only_use_flow = "onlyflow" in postfix or "-of" in postfix # whether only use flow loss DYNAMIC_FRAMES = '-s2' in postfix vgg_weight = 0 floss_weight = 0.15 if joint: null_cond_prob = 0.1 else: null_cond_prob = 0.0 if "upconv" in postfix: use_deconv = False padding_mode = "reflect" else: use_deconv = True padding_mode = "zeros" use_residual_flow = "-rf" in postfix learn_null_cond = "-lnc" in postfix INPUT_SIZE = 128 if DYNAMIC_FRAMES: MAX_N_FRAMES = 40 MIN_N_FRAMES = 30 else: MAX_N_FRAMES = 20 LEARNING_RATE = 2e-4 RANDOM_SEED = 1234 clip_c = 2. 
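# Stage-2 ("-s2") settings: this variant trains on variable-length clips. With
# DYNAMIC_FRAMES enabled, each batch is cropped to a random window before the
# forward pass; for illustration:
#   selct_frames = random.randint(MIN_N_FRAMES, MAX_N_FRAMES) + 1  # 31..41 frames incl. the init frame
#   real_vids    = real_vids[:, :, :selct_frames]                  # (B, C, T, H, W)
#   ref_hubert   = ref_hubert[:, :selct_frames]                    # per-frame audio features
# The first frame always serves as the init/conditioning frame, batches whose
# gradients contain NaNs skip the optimizer step, and numbered snapshots are only
# written during the second half of training (see the loop below).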
print('use grad clip, clip = ', clip_c) MEAN = (0.0, 0.0, 0.0) config_pth = "./config/hdtf128.yaml" # PATH of LFG checkpoint AE_RESTORE_FROM = 'LFG/path' RESTORE_FROM = '' # use existing checkpoint DM_LOG_PATH = os.path.join(root_dir,'data','HDTF_wpose_faceemb_newae_6Dpose', 'ca_init_cond_liploss','fromstart_rand_df_liploss_lr_N='+str(MAX_N_FRAMES)) print(DM_LOG_PATH) SNAPSHOT_DIR = os.path.join(DM_LOG_PATH, 'snapshots' + postfix) IMGSHOT_DIR = os.path.join(DM_LOG_PATH, 'imgshots' + postfix) VIDSHOT_DIR = os.path.join(DM_LOG_PATH, "vidshots" + postfix) SAMPLE_DIR = os.path.join(DM_LOG_PATH, 'sample' + postfix) NUM_EXAMPLES_PER_EPOCH = 400 NUM_STEPS_PER_EPOCH = math.ceil(NUM_EXAMPLES_PER_EPOCH / float(BATCH_SIZE)) MAX_ITER = max(NUM_EXAMPLES_PER_EPOCH * MAX_EPOCH + 1, NUM_STEPS_PER_EPOCH * BATCH_SIZE * MAX_EPOCH + 1) SAVE_MODEL_EVERY = int(100000/10) SAVE_VID_EVERY = 4000 SAMPLE_VID_EVERY = 2000 UPDATE_MODEL_EVERY = 500 os.makedirs(SNAPSHOT_DIR, exist_ok=True) os.makedirs(IMGSHOT_DIR, exist_ok=True) os.makedirs(VIDSHOT_DIR, exist_ok=True) os.makedirs(SAMPLE_DIR, exist_ok=True) LOG_PATH = SNAPSHOT_DIR + "/B" + format(BATCH_SIZE, "04d") + "E" + format(MAX_EPOCH, "04d") + ".log" sys.stdout = Logger(LOG_PATH, sys.stdout) print(root_dir) print("update saved model every:", UPDATE_MODEL_EVERY) print("save model every:", SAVE_MODEL_EVERY) print("save video every:", SAVE_VID_EVERY) print("sample video every:", SAMPLE_VID_EVERY) print(postfix) print("RESTORE_FROM", RESTORE_FROM) print("num examples per epoch:", NUM_EXAMPLES_PER_EPOCH) print("max epoch:", MAX_EPOCH) print("image size", INPUT_SIZE) print("epoch milestones:", epoch_milestones) print("only use flow loss:", only_use_flow) print("null_cond_prob:", null_cond_prob) print("use residual flow:", use_residual_flow) print("learn null cond:", learn_null_cond) print("use deconv:", use_deconv) def get_arguments(): """Parse all the arguments provided from the CLI. Returns: A list of parsed arguments. 
""" parser = argparse.ArgumentParser(description="Flow Diffusion") parser.add_argument("--fine-tune", default=False) parser.add_argument("--set-start", default=True) parser.add_argument("--start-step", default=0, type=int) parser.add_argument("--img-dir", type=str, default=IMGSHOT_DIR, help="Where to save images of the model.") parser.add_argument("--num-workers", default=8) parser.add_argument("--final-step", type=int, default=int(NUM_STEPS_PER_EPOCH * MAX_EPOCH), help="Number of training steps.") parser.add_argument("--gpu", default=GPU, help="choose gpu device.") parser.add_argument('--print-freq', '-p', default=2, type=int, metavar='N', help='print frequency') parser.add_argument('--save-img-freq', default=2000, type=int, metavar='N', help='save image frequency') parser.add_argument('--save-vid-freq', default=SAVE_VID_EVERY, type=int) parser.add_argument('--sample-vid-freq', default=SAMPLE_VID_EVERY, type=int) parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Number of images sent to the network in one step.") parser.add_argument("--input-size", type=str, default=INPUT_SIZE, help="Comma-separated string with height and width of images.") parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, help="Base learning rate for training with polynomial decay.") parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, help="Random seed to have reproducible results.") parser.add_argument("--restore-from", default=RESTORE_FROM) parser.add_argument("--save-pred-every", type=int, default=SAVE_MODEL_EVERY, help="Save checkpoint every often.") parser.add_argument("--update-pred-every", type=int, default=UPDATE_MODEL_EVERY) parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR, help="Where to save snapshots of the model.") parser.add_argument("--fp16", default=True) parser.add_argument("--cosin", default=True) return parser.parse_args() args = get_arguments() def sample_img(rec_img_batch, idx=0): rec_img = rec_img_batch[idx].permute(1, 2, 0).data.cpu().numpy().copy() rec_img += np.array(MEAN) / 255.0 rec_img[rec_img < 0] = 0 rec_img[rec_img > 1] = 1 rec_img *= 255 return np.array(rec_img, np.uint8) def main(): """Create the model and start the training.""" os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu cudnn.enabled = True cudnn.benchmark = True setup_seed(args.random_seed) writer = SummaryWriter(os.path.join(DM_LOG_PATH, 'tensorboard')) model = FlowDiffusion(is_train=True, img_size=INPUT_SIZE // 4, num_frames=MAX_N_FRAMES, null_cond_prob=null_cond_prob, sampling_timesteps=20, use_residual_flow=use_residual_flow, learn_null_cond=learn_null_cond, use_deconv=use_deconv, padding_mode=padding_mode, config_pth=config_pth, pretrained_pth=AE_RESTORE_FROM) model.cuda() scaler = torch.cuda.amp.GradScaler(enabled=args.fp16) # Not set model to be train mode! 
Because pretrained flow autoenc need to be eval (BatchNorm) # create optimizer optimizer_diff = torch.optim.Adam(model.diffusion.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.99)) if args.fine_tune: pass elif args.restore_from: if os.path.isfile(args.restore_from): print("=> loading checkpoint '{}'".format(args.restore_from)) checkpoint = torch.load(args.restore_from) if args.set_start: args.start_step = int(math.ceil(checkpoint['example'] / args.batch_size)) model_ckpt = model.diffusion.state_dict() for name, _ in model_ckpt.items(): model_ckpt[name].copy_(checkpoint['diffusion'][name]) model.diffusion.load_state_dict(model_ckpt) print("=> loaded checkpoint '{}'".format(args.restore_from)) if args.set_start: if "optimizer_diff" in list(checkpoint.keys()): optimizer_diff.load_state_dict(checkpoint['optimizer_diff']) else: print("=> no checkpoint found at '{}'".format(args.restore_from)) else: print("NO checkpoint found!") # enable the usage of multi-GPU model = DataParallelWithCallback(model) setup_seed(args.random_seed) trainloader = DataLoader(HDTF(data_dir=data_dir, pose_dir=pose_dir, eye_blink_dir = eye_blink_dir, image_size=INPUT_SIZE, max_num_frames=MAX_N_FRAMES, color_jitter=True, mean=MEAN), batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,# args.num_workers, pin_memory=True) batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() losses_rec = AverageMeter() losses_warp = AverageMeter() losses_vgg = AverageMeter() cnt = 0 actual_step = args.start_step start_epoch = int(math.ceil((args.start_step * args.batch_size) / NUM_EXAMPLES_PER_EPOCH)) epoch_cnt = start_epoch if(not args.cosin): scheduler = MultiStepLR(optimizer_diff, epoch_milestones, gamma=0.1, last_epoch=start_epoch - 1) else: scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_diff, T_max=MAX_EPOCH, eta_min=1e-6) print("epoch %d, lr= %.7f" % (epoch_cnt, optimizer_diff.param_groups[0]["lr"])) # start = time.time() torch.inverse(torch.ones((1,1), device = "cuda:0")) while actual_step < args.final_step: iter_end = timeit.default_timer() start_time = time.time() load_sum = 0 calculate_sum = 0 for i_iter, batch in enumerate(trainloader): if __debug__: end_time = time.time() # end # print(f'load time {end_time- start_time}') load_sum += end_time - start_time if end_time - start_time > 1: print('unnormal load: \t',i_iter) start_time = end_time actual_step = int(args.start_step + cnt) data_time.update(timeit.default_timer() - iter_end) real_vids, ref_hubert, real_poses, real_blink_bbox, mouth_lmk_tensor, real_names, _ = batch if(DYNAMIC_FRAMES == True): selct_frames = random.randint(MIN_N_FRAMES, MAX_N_FRAMES) + 1 selct_start = 0 real_vids = real_vids[:,:,selct_start:selct_start+selct_frames,:,:] ref_hubert = ref_hubert[:,selct_start:selct_start+selct_frames,:] real_poses = real_poses[:,:,selct_start:selct_start+selct_frames] mouth_lmk_tensor = mouth_lmk_tensor[:,selct_start:selct_start+selct_frames -1] real_blink_bbox = real_blink_bbox[:,:,selct_start:selct_start+selct_frames] # ref_hubert, real_poses, real_blink_bbox : b, c, fn # use first frame of each video as reference frame ref_id = 0# random.randint(0, real_vids.shape[2] - 1) ref_imgs = real_vids[:, :, ref_id, :, :].clone().detach() bs = real_vids.size(0) new_num_frames = real_vids.size(2) -1 model.module.update_num_frames(new_num_frames) # end_time = time.time() # end # print(f'preprocess time {end_time- start_time}') # start_time = end_time # encode text # cond = bert_embed(tokenize(ref_texts), 
return_cls_repr=model.module.diffusion.text_use_bert_cls).cuda() is_eval = actual_step % args.save_vid_freq == 0 or actual_step % args.sample_vid_freq == 0 with torch.cuda.amp.autocast(enabled=args.fp16): train_output_dict = model.forward(real_vid=real_vids, ref_img=ref_imgs, ref_text=ref_hubert, ref_pose=real_poses, ref_eye_blink = real_blink_bbox[:, :2], bbox=real_blink_bbox[:, 2:], mouth_lmk = mouth_lmk_tensor, is_eval = is_eval) # optimize model optimizer_diff.zero_grad() # if only_use_flow: # scaler.scale(train_output_dict["loss"].mean()).backward() # else: # scaler.scale((train_output_dict["loss"].mean() + train_output_dict["rec_loss"].mean() + # train_output_dict["rec_warp_loss"].mean())).backward() scaler.scale(train_output_dict["loss"].mean() + floss_weight * train_output_dict['floss'].mean() + 0.15 * train_output_dict['mouth_loss'].mean()).backward() # optimizer_diff.step() scaler.unscale_(optimizer_diff) # loss.backward() if clip_c > 0.: torch.nn.utils.clip_grad_norm_(model.parameters(), clip_c) has_nan_grad = False for name, param in model.named_parameters(): if param.grad != None and torch.isnan(param.grad).any(): has_nan_grad = True print(name) break if has_nan_grad: print("grad NaN, dont update param") scaler.update() continue scaler.step(optimizer_diff) scaler.update() batch_time.update(timeit.default_timer() - iter_end) iter_end = timeit.default_timer() losses.update(train_output_dict["loss"].mean().item(), bs) losses_rec.update(train_output_dict["floss"].mean().item(), bs) losses_warp.update(train_output_dict["mouth_loss"].mean().item(), bs) # losses_vgg.update(train_output_dict["rec_vgg_loss"].mean().item(), bs) writer.add_scalar('train/loss', train_output_dict["loss"].mean().item(),actual_step) writer.add_scalar('train/floss', train_output_dict["floss"].mean().item(),actual_step) writer.add_scalar('train/mouth_loss', train_output_dict["mouth_loss"].mean().item(),actual_step) # writer.add_scalar('train/rec_loss', train_output_dict["rec_loss"].mean().item(),actual_step) # writer.add_scalar('train/rec_warp_loss', train_output_dict["rec_warp_loss"].mean().item(),actual_step) # writer.add_scalar('train/rec_vgg_loss', train_output_dict["rec_vgg_loss"].mean().item(),actual_step) if __debug__: end_time = time.time() # end # print(f'forward time {end_time- start_time}') calculate_sum += end_time - start_time start_time = end_time # if actual_step % 100 == 0: # end = time.time() # print("100 iter time:{0}".format(end-start)) if actual_step % args.print_freq == 0: current_time = datetime.now() current_time_str = current_time.strftime("%Y-%m-%d %H:%M:%S") print("Current time is:", current_time_str) print('iter: [{0}]{1}/{2}\t' 'loss {loss.val:.7f} ({loss.avg:.7f})\t' 'loss_mse {loss_rec.val:.4f} ({loss_rec.avg:.4f})\t' 'loss_lip {loss_warp.val:.4f} ({loss_warp.avg:.4f})' .format( cnt, actual_step, args.final_step, batch_time=batch_time, data_time=data_time, loss=losses, loss_rec=losses_rec, loss_warp=losses_warp, )) null_cond_mask = np.array(train_output_dict["null_cond_mask"].data.cpu().numpy(), dtype=np.uint8) if actual_step % args.save_vid_freq == 0: # and cnt != 0: torch.cuda.empty_cache() print("saving video...") num_frames = real_vids.size(2) - 1 msk_size = ref_imgs.shape[-1] new_im_arr_list = [] save_src_img = sample_img(ref_imgs/255.) for nf in range(num_frames): save_tar_img = sample_img(real_vids[:, :, nf, :, :]/255.) 
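# Note: in this stage-2 script the first frame of each clip is treated as the
# init/reference frame, so the model is set to predict real_vids.size(2) - 1 frames
# (update_num_frames above) and the montage below iterates over that reduced frame
# count; the sampling branch further drops frame 0 via real_vids[:, :, 1:].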
# adapt fast version save_real_out_img = sample_img(train_output_dict["real_out_vid"][:, :, nf, :, :]) save_real_warp_img = sample_img(train_output_dict["real_warped_vid"][:, :, nf, :, :]) save_fake_out_img = sample_img(train_output_dict["fake_out_vid"][:, :, nf, :, :]) save_fake_warp_img = sample_img(train_output_dict["fake_warped_vid"][:, :, nf, :, :]) save_real_grid = grid2fig( train_output_dict["real_vid_grid"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(), grid_size=32, img_size=msk_size) save_fake_grid = grid2fig( train_output_dict["fake_vid_grid"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(), grid_size=32, img_size=msk_size) save_real_conf = conf2fig(train_output_dict["real_vid_conf"][0, :, nf]) save_fake_conf = conf2fig(train_output_dict["fake_vid_conf"][0, :, nf]) new_im = Image.new('RGB', (msk_size * 5, msk_size * 2)) new_im.paste(Image.fromarray(save_src_img, 'RGB'), (0, 0)) new_im.paste(Image.fromarray(save_tar_img, 'RGB'), (0, msk_size)) new_im.paste(Image.fromarray(save_real_out_img, 'RGB'), (msk_size, 0)) new_im.paste(Image.fromarray(save_real_warp_img, 'RGB'), (msk_size, msk_size)) new_im.paste(Image.fromarray(save_fake_out_img, 'RGB'), (msk_size * 2, 0)) new_im.paste(Image.fromarray(save_fake_warp_img, 'RGB'), (msk_size * 2, msk_size)) new_im.paste(Image.fromarray(save_real_grid, 'RGB'), (msk_size * 3, 0)) new_im.paste(Image.fromarray(save_fake_grid, 'RGB'), (msk_size * 3, msk_size)) new_im.paste(Image.fromarray(save_real_conf, 'L'), (msk_size * 4, 0)) new_im.paste(Image.fromarray(save_fake_conf, 'L'), (msk_size * 4, msk_size)) new_im_arr = np.array(new_im) new_im_arr_list.append(new_im_arr) new_vid_name = 'B' + format(args.batch_size, "04d") + '_S' + format(actual_step, "06d") \ + '_' + real_names[0] + "_%d.gif" % (null_cond_mask[0][0]) new_vid_file = os.path.join(VIDSHOT_DIR, new_vid_name) imageio.mimsave(new_vid_file, new_im_arr_list) new_im_arr_list = None new_im_arr = None new_im = None del new_im_arr_list, new_im_arr, new_im # sampling if actual_step % args.sample_vid_freq == 0: torch.cuda.empty_cache() real_vids = real_vids[:,:,1:,:,:] print("sampling video...") with torch.no_grad(): # cond = torch.concat([ref_hubert[0].unsqueeze(dim=0), real_poses[0].permute(1,0).unsqueeze(0), real_blink_bbox[0][:2].permute(1,0).unsqueeze(0)], dim=-1).cuda() sample_output_dict = model.module.sample_one_video(real_vid=real_vids.cuda()/255., sample_img=ref_imgs[0].unsqueeze(dim=0).cuda()/255., sample_audio_hubert = ref_hubert[0].unsqueeze(dim=0).cuda(), sample_pose = real_poses[0].unsqueeze(0).cuda(), sample_eye = real_blink_bbox[0][:2].unsqueeze(0).cuda(), sample_bbox = real_blink_bbox[0,2:].unsqueeze(0).cuda(), cond_scale=1.0) num_frames = real_vids.size(2) msk_size = ref_imgs.shape[-1] new_im_arr_list = [] save_src_img = sample_img(ref_imgs/255.) for nf in range(num_frames): save_tar_img = sample_img(real_vids[:, :, nf, :, :]/255.) 
save_real_out_img = sample_img(train_output_dict["real_out_vid"][:, :, nf, :, :]) save_real_warp_img = sample_img(train_output_dict["real_warped_vid"][:, :, nf, :, :]) save_sample_out_img = sample_img(sample_output_dict["sample_out_vid"][:, :, nf, :, :]) save_sample_warp_img = sample_img(sample_output_dict["sample_warped_vid"][:, :, nf, :, :]) save_real_grid = grid2fig( train_output_dict["real_vid_grid"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(), grid_size=32, img_size=msk_size) save_fake_grid = grid2fig( sample_output_dict["sample_vid_grid"][0, :, nf].permute((1, 2, 0)).data.cpu().numpy(), grid_size=32, img_size=msk_size) save_real_conf = conf2fig(train_output_dict["real_vid_conf"][0, :, nf]) save_fake_conf = conf2fig(sample_output_dict["sample_vid_conf"][0, :, nf]) new_im = Image.new('RGB', (msk_size * 5, msk_size * 2)) new_im.paste(Image.fromarray(save_src_img, 'RGB'), (0, 0)) new_im.paste(Image.fromarray(save_tar_img, 'RGB'), (0, msk_size)) new_im.paste(Image.fromarray(save_real_out_img, 'RGB'), (msk_size, 0)) new_im.paste(Image.fromarray(save_real_warp_img, 'RGB'), (msk_size, msk_size)) new_im.paste(Image.fromarray(save_sample_out_img, 'RGB'), (msk_size * 2, 0)) new_im.paste(Image.fromarray(save_sample_warp_img, 'RGB'), (msk_size * 2, msk_size)) new_im.paste(Image.fromarray(save_real_grid, 'RGB'), (msk_size * 3, 0)) new_im.paste(Image.fromarray(save_fake_grid, 'RGB'), (msk_size * 3, msk_size)) new_im.paste(Image.fromarray(save_real_conf, 'L'), (msk_size * 4, 0)) new_im.paste(Image.fromarray(save_fake_conf, 'L'), (msk_size * 4, msk_size)) new_im_arr = np.array(new_im) new_im_arr_list.append(new_im_arr) new_vid_name = 'B' + format(args.batch_size, "04d") + '_S' + format(actual_step, "06d") \ + '_' + real_names[0] + ".gif" new_vid_file = os.path.join(SAMPLE_DIR, new_vid_name) imageio.mimsave(new_vid_file, new_im_arr_list) new_im_arr_list = None new_im_arr = None new_im = None del new_im_arr_list, new_im_arr, new_im # save model at i-th step if actual_step % args.save_pred_every == 0 and cnt != 0 and actual_step > args.final_step // 2: print('taking snapshot ...') torch.save({'example': actual_step * args.batch_size, 'diffusion': model.module.diffusion.state_dict(), 'optimizer_diff': optimizer_diff.state_dict()}, osp.join(args.snapshot_dir, 'flowdiff_' + format(args.batch_size, "04d") + '_S' + format(actual_step, "06d") + '.pth')) # update saved model if actual_step % args.update_pred_every == 0 and cnt != 0: print('updating saved snapshot ...') torch.save({'example': actual_step * args.batch_size, 'diffusion': model.module.diffusion.state_dict(), 'optimizer_diff': optimizer_diff.state_dict()}, osp.join(args.snapshot_dir, 'flowdiff.pth')) if actual_step >= args.final_step: break cnt += 1 del real_vids, ref_imgs, ref_hubert, real_names, null_cond_mask del train_output_dict del batch # torch.cuda.empty_cache() scheduler.step() epoch_cnt += 1 print("epoch %d, lr= %.7f" % (epoch_cnt, optimizer_diff.param_groups[0]["lr"])) if __debug__: print('load_sum: ', load_sum) print('calculate_sum: ', calculate_sum) print('save the final model ...') torch.save({'example': actual_step * args.batch_size, 'diffusion': model.module.diffusion.state_dict(), 'optimizer_diff': optimizer_diff.state_dict()}, osp.join(args.snapshot_dir, 'flowdiff_' + format(args.batch_size, "04d") + '_S' + format(actual_step, "06d") + '.pth')) end = timeit.default_timer() print(end - start, 'seconds') class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def 
reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True if __name__ == '__main__': # mp.set_start_method('spawn') # torch.multiprocessing.set_start_method("spawn") main() ================================================ FILE: DM_3/utils.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F class MultiEpochsDataLoader(torch.utils.data.DataLoader): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._DataLoader__initialized = False self.batch_sampler = _RepeatSampler(self.batch_sampler) self._DataLoader__initialized = True self.iterator = super().__iter__() def __len__(self): return len(self.batch_sampler.sampler) def __iter__(self): for i in range(len(self)): yield next(self.iterator) class _RepeatSampler(object): """ Sampler that repeats forever. Args: sampler (Sampler) """ def __init__(self, sampler): self.sampler = sampler def __iter__(self): while True: yield from iter(self.sampler) ================================================ FILE: LFG/__init__.py ================================================ ================================================ FILE: LFG/augmentation.py ================================================ """ Code from https://github.com/hassony2/torch_videovision """ import numbers import random import numpy as np import PIL from skimage.transform import resize, rotate from numpy import pad import torchvision import warnings from skimage import img_as_ubyte, img_as_float def crop_clip(clip, min_h, min_w, h, w): if isinstance(clip[0], np.ndarray): cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] elif isinstance(clip[0], PIL.Image.Image): cropped = [ img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip ] else: raise TypeError('Expected numpy.ndarray or PIL.Image' + 'but got list of {0}'.format(type(clip[0]))) return cropped def pad_clip(clip, h, w): im_h, im_w = clip[0].shape[:2] pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2) pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2) return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge') def resize_clip(clip, size, interpolation='bilinear'): if isinstance(clip[0], np.ndarray): if isinstance(size, numbers.Number): im_h, im_w, im_c = clip[0].shape # Min spatial dim already matches minimal size if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): return clip new_h, new_w = get_resize_sizes(im_h, im_w, size) size = (new_w, new_h) else: size = size[1], size[0] scaled = [ resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True, mode='constant', anti_aliasing=True) for img in clip ] elif isinstance(clip[0], PIL.Image.Image): if isinstance(size, numbers.Number): im_w, im_h = clip[0].size # Min spatial dim already matches minimal size if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): return clip new_h, new_w = get_resize_sizes(im_h, im_w, size) size = (new_w, new_h) else: size = size[1], size[0] if interpolation == 'bilinear': pil_inter = PIL.Image.NEAREST else: pil_inter = PIL.Image.BILINEAR scaled = [img.resize(size, pil_inter) for img in clip] else: raise TypeError('Expected numpy.ndarray or PIL.Image' + 'but got list of 
{0}'.format(type(clip[0]))) return scaled def get_resize_sizes(im_h, im_w, size): if im_w < im_h: ow = size oh = int(size * im_h / im_w) else: oh = size ow = int(size * im_w / im_h) return oh, ow class RandomFlip(object): def __init__(self, time_flip=False, horizontal_flip=False): self.time_flip = time_flip self.horizontal_flip = horizontal_flip def __call__(self, clip): if random.random() < 0.5 and self.time_flip: return clip[::-1] if random.random() < 0.5 and self.horizontal_flip: return [np.fliplr(img) for img in clip] return clip class RandomResize(object): """Resizes a list of (H x W x C) numpy.ndarray to the final size The larger the original image is, the more times it takes to interpolate Args: interpolation (str): Can be one of 'nearest', 'bilinear' defaults to nearest size (tuple): (widht, height) """ def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'): self.ratio = ratio self.interpolation = interpolation def __call__(self, clip): scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) if isinstance(clip[0], np.ndarray): im_h, im_w, im_c = clip[0].shape elif isinstance(clip[0], PIL.Image.Image): im_w, im_h = clip[0].size new_w = int(im_w * scaling_factor) new_h = int(im_h * scaling_factor) new_size = (new_w, new_h) resized = resize_clip( clip, new_size, interpolation=self.interpolation) return resized class RandomCrop(object): """Extract random crop at the same location for a list of videos Args: size (sequence or int): Desired output size for the crop in format (h, w) """ def __init__(self, size): if isinstance(size, numbers.Number): size = (size, size) self.size = size def __call__(self, clip): """ Args: img (PIL.Image or numpy.ndarray): List of videos to be cropped in format (h, w, c) in numpy.ndarray Returns: PIL.Image or numpy.ndarray: Cropped list of videos """ h, w = self.size if isinstance(clip[0], np.ndarray): im_h, im_w, im_c = clip[0].shape elif isinstance(clip[0], PIL.Image.Image): im_w, im_h = clip[0].size else: raise TypeError('Expected numpy.ndarray or PIL.Image' + 'but got list of {0}'.format(type(clip[0]))) clip = pad_clip(clip, h, w) im_h, im_w = clip.shape[1:3] x1 = 0 if h == im_h else random.randint(0, im_w - w) y1 = 0 if w == im_w else random.randint(0, im_h - h) cropped = crop_clip(clip, y1, x1, h, w) return cropped class RandomRotation(object): """Rotate entire clip randomly by a random angle within given bounds Args: degrees (sequence or int): Range of degrees to select from If degrees is a number instead of sequence like (min, max), the range of degrees, will be (-degrees, +degrees). 
""" def __init__(self, degrees): if isinstance(degrees, numbers.Number): if degrees < 0: raise ValueError('If degrees is a single number,' 'must be positive') degrees = (-degrees, degrees) else: if len(degrees) != 2: raise ValueError('If degrees is a sequence,' 'it must be of len 2.') self.degrees = degrees def __call__(self, clip): """ Args: img (PIL.Image or numpy.ndarray): List of videos to be cropped in format (h, w, c) in numpy.ndarray Returns: PIL.Image or numpy.ndarray: Cropped list of videos """ angle = random.uniform(self.degrees[0], self.degrees[1]) if isinstance(clip[0], np.ndarray): rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip] elif isinstance(clip[0], PIL.Image.Image): rotated = [img.rotate(angle) for img in clip] else: raise TypeError('Expected numpy.ndarray or PIL.Image' + 'but got list of {0}'.format(type(clip[0]))) return rotated class ColorJitter(object): """Randomly change the brightness, contrast and saturation and hue of the clip Args: brightness (float): How much to jitter brightness. brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. contrast (float): How much to jitter contrast. contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. saturation (float): How much to jitter saturation. saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. hue(float): How much to jitter hue. hue_factor is chosen uniformly from [-hue, hue]. Should be >=0 and <= 0.5. """ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): self.brightness = brightness self.contrast = contrast self.saturation = saturation self.hue = hue def get_params(self, brightness, contrast, saturation, hue): if brightness > 0: brightness_factor = random.uniform( max(0, 1 - brightness), 1 + brightness) else: brightness_factor = None if contrast > 0: contrast_factor = random.uniform( max(0, 1 - contrast), 1 + contrast) else: contrast_factor = None if saturation > 0: saturation_factor = random.uniform( max(0, 1 - saturation), 1 + saturation) else: saturation_factor = None if hue > 0: hue_factor = random.uniform(-hue, hue) else: hue_factor = None return brightness_factor, contrast_factor, saturation_factor, hue_factor def __call__(self, clip): """ Args: clip (list): list of PIL.Image Returns: list PIL.Image : list of transformed PIL.Image """ if isinstance(clip[0], np.ndarray): brightness, contrast, saturation, hue = self.get_params( self.brightness, self.contrast, self.saturation, self.hue) # Create img transform function sequence img_transforms = [] if brightness is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) if saturation is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) if hue is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) if contrast is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) random.shuffle(img_transforms) img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array, img_as_float] with warnings.catch_warnings(): warnings.simplefilter("ignore") jittered_clip = [] for img in clip: jittered_img = img for func in img_transforms: jittered_img = func(jittered_img) jittered_clip.append(jittered_img.astype('float32')) elif isinstance(clip[0], PIL.Image.Image): brightness, contrast, saturation, 
hue = self.get_params( self.brightness, self.contrast, self.saturation, self.hue) # Create img transform function sequence img_transforms = [] if brightness is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) if saturation is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) if hue is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) if contrast is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) random.shuffle(img_transforms) # Apply to all videos jittered_clip = [] for img in clip: for func in img_transforms: jittered_img = func(img) jittered_clip.append(jittered_img) else: raise TypeError('Expected numpy.ndarray or PIL.Image' + 'but got list of {0}'.format(type(clip[0]))) return jittered_clip class AllAugmentationTransform: def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None): self.transforms = [] if flip_param is not None: self.transforms.append(RandomFlip(**flip_param)) if rotation_param is not None: self.transforms.append(RandomRotation(**rotation_param)) if resize_param is not None: self.transforms.append(RandomResize(**resize_param)) if crop_param is not None: self.transforms.append(RandomCrop(**crop_param)) if jitter_param is not None: self.transforms.append(ColorJitter(**jitter_param)) def __call__(self, clip): for t in self.transforms: clip = t(clip) return clip ================================================ FILE: LFG/frames_dataset.py ================================================ """ Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 
""" import os from skimage import io, img_as_float32 from skimage.color import gray2rgb from skimage.transform import resize from sklearn.model_selection import train_test_split from imageio import mimread import numpy as np from torch.utils.data import Dataset import pandas as pd from augmentation import AllAugmentationTransform import glob from functools import partial def read_video(name, frame_shape): """ Read video which can be: - an image of concatenated frames - '.mp4' and'.gif' - folder with videos """ if os.path.isdir(name): frames = sorted(os.listdir(name)) num_frames = len(frames) video_array = [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)] if frame_shape is not None: video_array = np.array([resize(frame, frame_shape) for frame in video_array]) elif name.lower().endswith('.png') or name.lower().endswith('.jpg'): image = io.imread(name) if frame_shape is None: raise ValueError('Frame shape can not be None for stacked png format.') frame_shape = tuple(frame_shape) if len(image.shape) == 2 or image.shape[2] == 1: image = gray2rgb(image) if image.shape[2] == 4: image = image[..., :3] image = img_as_float32(image) video_array = np.moveaxis(image, 1, 0) video_array = video_array.reshape((-1,) + frame_shape + (3, )) video_array = np.moveaxis(video_array, 1, 2) elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'): video = mimread(name) if len(video[0].shape) == 2: video = [gray2rgb(frame) for frame in video] if frame_shape is not None: video = np.array([resize(frame, frame_shape) for frame in video]) video = np.array(video) if video.shape[-1] == 4: video = video[..., :3] video_array = img_as_float32(video) else: raise Exception("Unknown file extensions %s" % name) return video_array class FramesDataset(Dataset): """ Dataset of videos, each video can be represented as: - an image of concatenated frames - '.mp4' or '.gif' - folder with all frames """ def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True, random_seed=0, pairs_list=None, augmentation_params=None): self.root_dir = root_dir self.videos = os.listdir(root_dir) self.frame_shape = frame_shape self.pairs_list = pairs_list self.id_sampling = id_sampling if os.path.exists(os.path.join(root_dir, 'train')): assert os.path.exists(os.path.join(root_dir, 'test')) print("Use predefined train-test split.") if id_sampling: train_videos = {os.path.basename(video).split('#')[0] for video in os.listdir(os.path.join(root_dir, 'train'))} train_videos = list(train_videos) else: train_videos = os.listdir(os.path.join(root_dir, 'train')) test_videos = os.listdir(os.path.join(root_dir, 'test')) self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test') else: print("Use random train-test split.") train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2) if is_train: self.videos = train_videos else: self.videos = test_videos self.is_train = is_train if self.is_train: self.transform = AllAugmentationTransform(**augmentation_params) else: self.transform = None def __len__(self): return len(self.videos) def __getitem__(self, idx): if self.is_train and self.id_sampling: name = self.videos[idx] try: path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4'))) except ValueError: raise ValueError("File formatting is not correct for id_sampling=True. 
" "Change file formatting, or set id_sampling=False.") else: name = self.videos[idx] path = os.path.join(self.root_dir, name) video_name = os.path.basename(path) if self.is_train and os.path.isdir(path): frames = os.listdir(path) num_frames = len(frames) frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.frame_shape is not None: resize_fn = partial(resize, output_shape=self.frame_shape) else: resize_fn = img_as_float32 if type(frames[0]) is bytes: video_array = [resize_fn(io.imread(os.path.join(path, frames[idx].decode('utf-8')))) for idx in frame_idx] else: video_array = [resize_fn(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx] else: video_array = read_video(path, frame_shape=self.frame_shape) num_frames = len(video_array) frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range( num_frames) video_array = video_array[frame_idx][..., :3] if self.transform is not None: video_array = self.transform(video_array) out = {} if self.is_train: source = np.array(video_array[0], dtype='float32') driving = np.array(video_array[1], dtype='float32') out['driving'] = driving.transpose((2, 0, 1)) out['source'] = source.transpose((2, 0, 1)) else: video = np.array(video_array, dtype='float32') out['video'] = video.transpose((3, 0, 1, 2)) out['name'] = video_name out['id'] = idx return out class DatasetRepeater(Dataset): """ Pass several times over the same dataset for better i/o performance """ def __init__(self, dataset, num_repeats=100): self.dataset = dataset self.num_repeats = num_repeats def __len__(self): return self.num_repeats * self.dataset.__len__() def __getitem__(self, idx): return self.dataset[idx % self.dataset.__len__()] class PairedDataset(Dataset): """ Dataset of pairs for animation. 
""" def __init__(self, initial_dataset, number_of_pairs, seed=0): self.initial_dataset = initial_dataset pairs_list = self.initial_dataset.pairs_list np.random.seed(seed) if pairs_list is None: max_idx = min(number_of_pairs, len(initial_dataset)) nx, ny = max_idx, max_idx xy = np.mgrid[:nx, :ny].reshape(2, -1).T number_of_pairs = min(xy.shape[0], number_of_pairs) self.pairs = xy.take(np.random.choice(xy.shape[0], number_of_pairs, replace=False), axis=0) else: videos = self.initial_dataset.videos name_to_index = {name: index for index, name in enumerate(videos)} pairs = pd.read_csv(pairs_list) pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))] number_of_pairs = min(pairs.shape[0], number_of_pairs) self.pairs = [] self.start_frames = [] for ind in range(number_of_pairs): self.pairs.append( (name_to_index[pairs['driving'].iloc[ind]], name_to_index[pairs['source'].iloc[ind]])) def __len__(self): return len(self.pairs) def __getitem__(self, idx): pair = self.pairs[idx] first = self.initial_dataset[pair[0]] second = self.initial_dataset[pair[1]] first = {'driving_' + key: value for key, value in first.items()} second = {'source_' + key: value for key, value in second.items()} return {**first, **second} ================================================ FILE: LFG/hdtf_dataset.py ================================================ # build MUG dataset for RegionMM import os import imageio import numpy as np from torch.utils.data import Dataset import yaml from argparse import ArgumentParser from augmentation import AllAugmentationTransform from functools import partial import cv2 import matplotlib.pyplot as plt import imageio.v2 as imageio def resize(im, desired_size, interpolation): old_size = im.shape[:2] ratio = float(desired_size)/max(old_size) new_size = tuple(int(x*ratio) for x in old_size) im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation) delta_w = desired_size - new_size[1] delta_h = desired_size - new_size[0] top, bottom = delta_h//2, delta_h-(delta_h//2) left, right = delta_w//2, delta_w-(delta_w//2) color = [0, 0, 0] new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) return new_im # this is just for training class FramesDataset(Dataset): """ Dataset of videos, each video can be represented as: - an image of concatenated frames - '.mp4' or '.gif' - folder with all frames """ def __init__(self, root_dir, frame_shape=256, id_sampling=False, pairs_list=None, augmentation_params=None): self.root_dir = root_dir self.frame_shape = frame_shape self.pairs_list = pairs_list self.id_sampling = id_sampling vid_list = [] # crema for id_name in os.listdir(root_dir): vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{root_dir}/{id_name}') ]) #hdtf # for id_name in os.listdir(root_dir): # vid_list.append(id_name) self.videos = vid_list self.transform = AllAugmentationTransform(**augmentation_params) def __len__(self): return len(self.videos) def __getitem__(self, idx): if self.id_sampling: raise NotImplementedError else: name = self.videos[idx] path = os.path.join(self.root_dir, name) frames = os.listdir(path) frames.sort() num_frames = len(frames) frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) resize_fn = partial(resize, desired_size=self.frame_shape, interpolation=cv2.INTER_AREA) if type(frames[0]) is bytes: frame_names = [frames[idx].decode('utf-8') for idx in frame_idx] else: frame_names = [frames[idx] for idx in frame_idx] video_array = 
[resize_fn(imageio.imread(os.path.join(path, x))) for x in frame_names] # video_array = [img_as_float32(x) for x in video_array] video_array = self.transform(video_array) out = {} source = np.array(video_array[0], dtype='float32') driving = np.array(video_array[1], dtype='float32') out['driving'] = driving.transpose((2, 0, 1)) out['source'] = source.transpose((2, 0, 1)) out['name'] = name out['frame'] = frame_names out['id'] = idx return out class DatasetRepeater(Dataset): """ Pass several times over the same dataset for better i/o performance """ def __init__(self, dataset, num_repeats=100): self.dataset = dataset self.num_repeats = num_repeats def __len__(self): return self.num_repeats * self.dataset.__len__() def __getitem__(self, idx): return self.dataset[idx % self.dataset.__len__()] if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--config", default="/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/config/hdtf128.yaml", help="path to config") opt = parser.parse_args() with open(opt.config) as f: config = yaml.safe_load(f) data = FramesDataset(**config['dataset_params']) # data.__getitem__(0) # print('_------') # data.__getitem__(1) # print('------') data.__getitem__(2) print('------') data.__getitem__(3) print('------') ================================================ FILE: LFG/modules/avd_network.py ================================================ """ Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 
""" import torch from torch import nn class AVDNetwork(nn.Module): """ Animation via Disentanglement network """ def __init__(self, num_regions, id_bottle_size=64, pose_bottle_size=64, revert_axis_swap=True): super(AVDNetwork, self).__init__() input_size = (2 + 4) * num_regions self.num_regions = num_regions self.revert_axis_swap = revert_axis_swap self.id_encoder = nn.Sequential( nn.Linear(input_size, 256), nn.BatchNorm1d(256), nn.ReLU(inplace=True), nn.Linear(256, 512), nn.BatchNorm1d(512), nn.ReLU(inplace=True), nn.Linear(512, 1024), nn.BatchNorm1d(1024), nn.ReLU(inplace=True), nn.Linear(1024, id_bottle_size) ) self.pose_encoder = nn.Sequential( nn.Linear(input_size, 256), nn.BatchNorm1d(256), nn.ReLU(inplace=True), nn.Linear(256, 512), nn.BatchNorm1d(512), nn.ReLU(inplace=True), nn.Linear(512, 1024), nn.BatchNorm1d(1024), nn.ReLU(inplace=True), nn.Linear(1024, pose_bottle_size) ) self.decoder = nn.Sequential( nn.Linear(pose_bottle_size + id_bottle_size, 1024), nn.BatchNorm1d(1024), nn.ReLU(), nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Linear(256, input_size) ) @staticmethod def region_params_to_emb(x): mean = x['shift'] jac = x['affine'] emb = torch.cat([mean, jac.view(jac.shape[0], jac.shape[1], -1)], dim=-1) emb = emb.view(emb.shape[0], -1) return emb def emb_to_region_params(self, emb): emb = emb.view(emb.shape[0], self.num_regions, 6) mean = emb[:, :, :2] jac = emb[:, :, 2:].view(emb.shape[0], emb.shape[1], 2, 2) return {'shift': mean, 'affine': jac} def forward(self, x_id, x_pose, alpha=0.2): if self.revert_axis_swap: affine = torch.matmul(x_id['affine'], torch.inverse(x_pose['affine'])) sign = torch.sign(affine[:, :, 0:1, 0:1]) x_id = {'affine': x_id['affine'] * sign, 'shift': x_id['shift']} pose_emb = self.pose_encoder(self.region_params_to_emb(x_pose)) id_emb = self.id_encoder(self.region_params_to_emb(x_id)) rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1)) rec = self.emb_to_region_params(rec) rec['covar'] = torch.matmul(rec['affine'], rec['affine'].permute(0, 1, 3, 2)) return rec ================================================ FILE: LFG/modules/bg_motion_predictor.py ================================================ """ Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. """ from torch import nn import torch from LFG.modules.util import Encoder class BGMotionPredictor(nn.Module): """ Module for background estimation, return single transformation, parametrized as 3x3 matrix. 
""" def __init__(self, block_expansion, num_channels, max_features, num_blocks, bg_type='zero'): super(BGMotionPredictor, self).__init__() assert bg_type in ['zero', 'shift', 'affine', 'perspective'] self.bg_type = bg_type if self.bg_type != 'zero': self.encoder = Encoder(block_expansion, in_features=num_channels * 2, max_features=max_features, num_blocks=num_blocks) in_features = min(max_features, block_expansion * (2 ** num_blocks)) if self.bg_type == 'perspective': self.fc = nn.Linear(in_features, 8) self.fc.weight.data.zero_() self.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0], dtype=torch.float)) elif self.bg_type == 'affine': self.fc = nn.Linear(in_features, 6) self.fc.weight.data.zero_() self.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) elif self.bg_type == 'shift': self.fc = nn.Linear(in_features, 2) self.fc.weight.data.zero_() self.fc.bias.data.copy_(torch.tensor([0, 0], dtype=torch.float)) def forward(self, source_image, driving_image): bs = source_image.shape[0] out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type()) if self.bg_type != 'zero': prediction = self.encoder(torch.cat([source_image, driving_image], dim=1)) prediction = prediction[-1].mean(dim=(2, 3)) prediction = self.fc(prediction) if self.bg_type == 'shift': out[:, :2, 2] = prediction elif self.bg_type == 'affine': out[:, :2, :] = prediction.view(bs, 2, 3) elif self.bg_type == 'perspective': out[:, :2, :] = prediction[:, :6].view(bs, 2, 3) out[:, 2, :2] = prediction[:, 6:].view(bs, 2) return out ================================================ FILE: LFG/modules/flow_autoenc.py ================================================ # utilize RegionMM to design a flow auto-encoder import torch import torch.nn as nn import sys sys.path.append('/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main') from LFG.modules.generator import Generator from LFG.modules.bg_motion_predictor import BGMotionPredictor from LFG.modules.region_predictor import RegionPredictor import yaml # based on RegionMM class FlowAE(nn.Module): def __init__(self, is_train=False, config_pth="/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/config/mug128.yaml"): super(FlowAE, self).__init__() with open(config_pth) as f: config = yaml.safe_load(f) self.generator = Generator(num_regions=config['model_params']['num_regions'], num_channels=config['model_params']['num_channels'], revert_axis_swap=config['model_params']['revert_axis_swap'], **config['model_params']['generator_params']).cuda() self.region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'], num_channels=config['model_params']['num_channels'], estimate_affine=config['model_params']['estimate_affine'], **config['model_params']['region_predictor_params']).cuda() self.bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'], **config['model_params']['bg_predictor_params']) self.is_train = is_train self.ref_img = None self.dri_img = None self.generated = None def forward(self): source_region_params = self.region_predictor(self.ref_img) self.driving_region_params = self.region_predictor(self.dri_img) bg_params = self.bg_predictor(self.ref_img, self.dri_img) self.generated = self.generator(self.ref_img, source_region_params=source_region_params, driving_region_params=self.driving_region_params, bg_params=bg_params) self.generated.update({'source_region_params': source_region_params, 'driving_region_params': self.driving_region_params}) def set_train_input(self, ref_img, dri_img): 
self.ref_img = ref_img.cuda() self.dri_img = dri_img.cuda() if __name__ == "__main__": # default image size is 128 # import os # os.environ["CUDA_VISIBLE_DEVICES"] = "0" ref_img = torch.rand((5, 3, 128, 128), dtype=torch.float32) dri_img = torch.rand((5, 3, 128, 128), dtype=torch.float32) model = FlowAE(is_train=True).cuda() model.train() model.set_train_input(ref_img=ref_img, dri_img=dri_img) model.forward() print("___finished___") ================================================ FILE: LFG/modules/generator.py ================================================ """ Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. """ import time import torch from torch import nn import torch.nn.functional as F import sys sys.path.append("/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main") from LFG.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d from LFG.modules.pixelwise_flow_predictor import PixelwiseFlowPredictor class Generator(nn.Module): """ Generator that, given a source image and region parameters, transforms the image according to the movement trajectories induced by the region parameters. The generator follows the Johnson architecture.
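    Example (illustrative sketch only; the constructor arguments and tensor sizes are
    hypothetical and do not correspond to the repository configs):
        >>> gen = Generator(num_channels=3, num_regions=10, block_expansion=64,
        ...                 max_features=512, num_down_blocks=2, num_bottleneck_blocks=6)
        >>> img = torch.rand(1, 3, 128, 128)
        >>> grid = torch.rand(1, 32, 32, 2) * 2 - 1   # sampling grid in [-1, 1], as consumed by F.grid_sample
        >>> occ = torch.ones(1, 1, 32, 32)            # occlusion map, 1 = fully visible
        >>> out = gen.forward_with_flow(img, grid, occ)  # out['prediction'] has shape (1, 3, 128, 128)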
""" def __init__(self, num_channels, num_regions, block_expansion, max_features, num_down_blocks, num_bottleneck_blocks, pixelwise_flow_predictor_params=None, skips=False, revert_axis_swap=True): super(Generator, self).__init__() if pixelwise_flow_predictor_params is not None: self.pixelwise_flow_predictor = PixelwiseFlowPredictor(num_regions=num_regions, num_channels=num_channels, revert_axis_swap=revert_axis_swap, **pixelwise_flow_predictor_params) else: self.pixelwise_flow_predictor = None self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) down_blocks = [] for i in range(num_down_blocks): in_features = min(max_features, block_expansion * (2 ** i)) out_features = min(max_features, block_expansion * (2 ** (i + 1))) down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) self.down_blocks = nn.ModuleList(down_blocks) up_blocks = [] for i in range(num_down_blocks): in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i))) out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1))) up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) self.up_blocks = nn.ModuleList(up_blocks) self.bottleneck = torch.nn.Sequential() in_features = min(max_features, block_expansion * (2 ** num_down_blocks)) for i in range(num_bottleneck_blocks): self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3)) self.num_channels = num_channels self.skips = skips @staticmethod def deform_input(inp, optical_flow): _, h_old, w_old, _ = optical_flow.shape _, _, h, w = inp.shape if h_old != h or w_old != w: optical_flow = optical_flow.permute(0, 3, 1, 2) optical_flow = F.interpolate(optical_flow, size=(h, w), mode='bilinear') optical_flow = optical_flow.permute(0, 2, 3, 1) return F.grid_sample(inp, optical_flow) def apply_optical(self, input_previous=None, input_skip=None, motion_params=None): if motion_params is not None: if 'occlusion_map' in motion_params: occlusion_map = motion_params['occlusion_map'] else: occlusion_map = None deformation = motion_params['optical_flow'] input_skip = self.deform_input(input_skip, deformation) if occlusion_map is not None: if input_skip.shape[2] != occlusion_map.shape[2] or input_skip.shape[3] != occlusion_map.shape[3]: occlusion_map = F.interpolate(occlusion_map, size=input_skip.shape[2:], mode='bilinear') if input_previous is not None: input_skip = input_skip * occlusion_map + input_previous * (1 - occlusion_map) else: input_skip = input_skip * occlusion_map out = input_skip else: out = input_previous if input_previous is not None else input_skip return out def forward(self, source_image, driving_region_params, source_region_params, bg_params=None): out = self.first(source_image) skips = [out] for i in range(len(self.down_blocks)): out = self.down_blocks[i](out) skips.append(out) output_dict = {} output_dict["bottle_neck_feat"] = out if self.pixelwise_flow_predictor is not None: motion_params = self.pixelwise_flow_predictor(source_image=source_image, driving_region_params=driving_region_params, source_region_params=source_region_params, bg_params=bg_params) output_dict["deformed"] = self.deform_input(source_image, motion_params['optical_flow']) output_dict["optical_flow"] = motion_params['optical_flow'] if 'occlusion_map' in motion_params: output_dict['occlusion_map'] = 
motion_params['occlusion_map'] else: motion_params = None out = self.apply_optical(input_previous=None, input_skip=out, motion_params=motion_params) out = self.bottleneck(out) for i in range(len(self.up_blocks)): if self.skips: out = self.apply_optical(input_skip=skips[-(i + 1)], input_previous=out, motion_params=motion_params) out = self.up_blocks[i](out) if self.skips: out = self.apply_optical(input_skip=skips[0], input_previous=out, motion_params=motion_params) out = self.final(out) out = torch.sigmoid(out) if self.skips: out = self.apply_optical(input_skip=source_image, input_previous=out, motion_params=motion_params) output_dict["prediction"] = out return output_dict def compute_fea(self, source_image): out = self.first(source_image) for i in range(len(self.down_blocks)): out = self.down_blocks[i](out) return out def forward_with_flow(self, source_image, optical_flow, occlusion_map): start_time = time.time() # end out = self.first(source_image) end_time = time.time() # end # print(f'img fea extract time surplus {end_time- start_time}') skips = [out] for i in range(len(self.down_blocks)): out = self.down_blocks[i](out) skips.append(out) output_dict = {} motion_params = {} motion_params["optical_flow"] = optical_flow motion_params["occlusion_map"] = occlusion_map output_dict["deformed"] = self.deform_input(source_image, motion_params['optical_flow']) out = self.apply_optical(input_previous=None, input_skip=out, motion_params=motion_params) out = self.bottleneck(out) for i in range(len(self.up_blocks)): if self.skips: out = self.apply_optical(input_skip=skips[-(i + 1)], input_previous=out, motion_params=motion_params) out = self.up_blocks[i](out) if self.skips: out = self.apply_optical(input_skip=skips[0], input_previous=out, motion_params=motion_params) out = self.final(out) out = torch.sigmoid(out) if self.skips: out = self.apply_optical(input_skip=source_image, input_previous=out, motion_params=motion_params) output_dict["prediction"] = out return output_dict ================================================ FILE: LFG/modules/model.py ================================================ """ Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. """ from torch import nn import torch import torch.nn.functional as F from LFG.modules.util import AntiAliasInterpolation2d, make_coordinate_grid from torchvision import models import numpy as np from torch.autograd import grad class Vgg19(torch.nn.Module): """ Vgg19 network for perceptual loss. 
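    Example (sketch; the first call downloads the torchvision VGG-19 weights):
        >>> vgg = Vgg19()
        >>> feats = vgg(torch.rand(1, 3, 128, 128))  # list of five relu activations
        >>> # each element feeds one weighted term of the pyramid perceptual loss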
""" def __init__(self, requires_grad=False): super(Vgg19, self).__init__() vgg_pretrained_features = models.vgg19(pretrained=True).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() for x in range(2): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(2, 7): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(7, 12): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(12, 21): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(21, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), requires_grad=False) self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), requires_grad=False) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, x): x = (x - self.mean) / self.std h_relu1 = self.slice1(x) h_relu2 = self.slice2(h_relu1) h_relu3 = self.slice3(h_relu2) h_relu4 = self.slice4(h_relu3) h_relu5 = self.slice5(h_relu4) out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] return out class ImagePyramide(torch.nn.Module): """ Create image pyramide for computing pyramide perceptual loss. """ def __init__(self, scales, num_channels): super(ImagePyramide, self).__init__() downs = {} for scale in scales: downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale) self.downs = nn.ModuleDict(downs) def forward(self, x): out_dict = {} for scale, down_module in self.downs.items(): out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x) return out_dict class Transform: """ Random tps transformation for equivariance constraints. 
""" def __init__(self, bs, **kwargs): noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3])) self.theta = noise + torch.eye(2, 3).view(1, 2, 3) self.bs = bs if ('sigma_tps' in kwargs) and ('points_tps' in kwargs): self.tps = True self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=self.theta.type()) self.control_points = self.control_points.unsqueeze(0) self.control_params = torch.normal(mean=0, std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2])) else: self.tps = False def transform_frame(self, frame): grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) grid = grid.view(1, frame.shape[2] * frame.shape[3], 2) grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2) return F.grid_sample(frame, grid, padding_mode="reflection") def warp_coordinates(self, coordinates): theta = self.theta.type(coordinates.type()) theta = theta.unsqueeze(1) transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:] transformed = transformed.squeeze(-1) if self.tps: control_points = self.control_points.type(coordinates.type()) control_params = self.control_params.type(coordinates.type()) distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2) distances = torch.abs(distances).sum(-1) # TODO this part may have bugs result = distances ** 2 result = result * torch.log(distances + 1e-6) result = result * control_params result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) transformed = transformed + result return transformed def jacobian(self, coordinates): new_coordinates = self.warp_coordinates(coordinates) grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True) grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True) jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2) return jacobian def detach_kp(kp): return {key: value.detach() for key, value in kp.items()} class ReconstructionModel(torch.nn.Module): """ Merge all updates into single model for better multi-gpu usage """ def __init__(self, region_predictor, bg_predictor, generator, train_params): super(ReconstructionModel, self).__init__() self.region_predictor = region_predictor self.bg_predictor = bg_predictor self.generator = generator self.train_params = train_params self.scales = train_params['scales'] self.pyramid = ImagePyramide(self.scales, generator.num_channels) if torch.cuda.is_available(): self.pyramid = self.pyramid.cuda() self.loss_weights = train_params['loss_weights'] if sum(self.loss_weights['perceptual']) != 0: self.vgg = Vgg19() if torch.cuda.is_available(): self.vgg = self.vgg.cuda() def forward(self, x): source_region_params = self.region_predictor(x['source']) driving_region_params = self.region_predictor(x['driving']) bg_params = self.bg_predictor(x['source'], x['driving']) #background generated = self.generator(x['source'], source_region_params=source_region_params, driving_region_params=driving_region_params, bg_params=bg_params) generated.update({'source_region_params': source_region_params, 'driving_region_params': driving_region_params}) loss_values = {} pyramide_real = self.pyramid(x['driving']) pyramide_generated = self.pyramid(generated['prediction']) if sum(self.loss_weights['perceptual']) != 0: value_total = 0 for scale in self.scales: x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) y_vgg = 
self.vgg(pyramide_real['prediction_' + str(scale)]) for i, weight in enumerate(self.loss_weights['perceptual']): value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() value_total += self.loss_weights['perceptual'][i] * value loss_values['perceptual'] = value_total if (self.loss_weights['equivariance_shift'] + self.loss_weights['equivariance_affine']) != 0: transform = Transform(x['driving'].shape[0], **self.train_params['transform_params']) transformed_frame = transform.transform_frame(x['driving']) transformed_region_params = self.region_predictor(transformed_frame) generated['transformed_frame'] = transformed_frame generated['transformed_region_params'] = transformed_region_params if self.loss_weights['equivariance_shift'] != 0: value = torch.abs(driving_region_params['shift'] - transform.warp_coordinates(transformed_region_params['shift'])).mean() loss_values['equivariance_shift'] = self.loss_weights['equivariance_shift'] * value if self.loss_weights['equivariance_affine'] != 0: affine_transformed = torch.matmul(transform.jacobian(transformed_region_params['shift']), transformed_region_params['affine']) normed_driving = torch.inverse(driving_region_params['affine']) normed_transformed = affine_transformed value = torch.matmul(normed_driving, normed_transformed) eye = torch.eye(2).view(1, 1, 2, 2).type(value.type()) if self.generator.pixelwise_flow_predictor.revert_axis_swap: value = value * torch.sign(value[:, :, 0:1, 0:1]) value = torch.abs(eye - value).mean() loss_values['equivariance_affine'] = self.loss_weights['equivariance_affine'] * value return loss_values, generated ================================================ FILE: LFG/modules/pixelwise_flow_predictor.py ================================================ """ Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 
""" from torch import nn import torch.nn.functional as F import torch from LFG.modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, region2gaussian from LFG.modules.util import to_homogeneous, from_homogeneous class PixelwiseFlowPredictor(nn.Module): """ Module that predicts a pixelwise flow from sparse motion representation given by source_region_params and driving_region_params """ def __init__(self, block_expansion, num_blocks, max_features, num_regions, num_channels, estimate_occlusion_map=False, scale_factor=1, region_var=0.01, use_covar_heatmap=False, use_deformed_source=True, revert_axis_swap=False): super(PixelwiseFlowPredictor, self).__init__() self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_regions + 1) * (num_channels * use_deformed_source + 1), max_features=max_features, num_blocks=num_blocks) self.mask = nn.Conv2d(self.hourglass.out_filters, num_regions + 1, kernel_size=(7, 7), padding=(3, 3)) if estimate_occlusion_map: self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3)) else: self.occlusion = None self.num_regions = num_regions self.scale_factor = scale_factor self.region_var = region_var self.use_covar_heatmap = use_covar_heatmap self.use_deformed_source = use_deformed_source self.revert_axis_swap = revert_axis_swap if self.scale_factor != 1: self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) def create_heatmap_representations(self, source_image, driving_region_params, source_region_params): """ Eq 6. in the paper H_k(z) """ spatial_size = source_image.shape[2:] covar = self.region_var if not self.use_covar_heatmap else driving_region_params['covar'] gaussian_driving = region2gaussian(driving_region_params['shift'], covar=covar, spatial_size=spatial_size) covar = self.region_var if not self.use_covar_heatmap else source_region_params['covar'] gaussian_source = region2gaussian(source_region_params['shift'], covar=covar, spatial_size=spatial_size) heatmap = gaussian_driving - gaussian_source # adding background feature zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]) heatmap = torch.cat([zeros.type(heatmap.type()), heatmap], dim=1) heatmap = heatmap.unsqueeze(2) return heatmap def create_sparse_motions(self, source_image, driving_region_params, source_region_params, bg_params=None): bs, _, h, w = source_image.shape identity_grid = make_coordinate_grid((h, w), type=source_region_params['shift'].type()) identity_grid = identity_grid.view(1, 1, h, w, 2) coordinate_grid = identity_grid - driving_region_params['shift'].view(bs, self.num_regions, 1, 1, 2) if 'affine' in driving_region_params: affine = torch.matmul(source_region_params['affine'], torch.inverse(driving_region_params['affine'].float())) if self.revert_axis_swap: affine = affine * torch.sign(affine[:, :, 0:1, 0:1]) affine = affine.unsqueeze(-3).unsqueeze(-3) affine = affine.repeat(1, 1, h, w, 1, 1) coordinate_grid = torch.matmul(affine, coordinate_grid.unsqueeze(-1)) coordinate_grid = coordinate_grid.squeeze(-1) driving_to_source = coordinate_grid + source_region_params['shift'].view(bs, self.num_regions, 1, 1, 2) # adding background feature if bg_params is None: bg_grid = identity_grid.repeat(bs, 1, 1, 1, 1) else: bg_grid = identity_grid.repeat(bs, 1, 1, 1, 1) bg_grid = to_homogeneous(bg_grid) bg_grid = torch.matmul(bg_params.view(bs, 1, 1, 1, 3, 3), bg_grid.unsqueeze(-1)).squeeze(-1) bg_grid = from_homogeneous(bg_grid) sparse_motions = torch.cat([bg_grid, driving_to_source], 
dim=1) return sparse_motions def create_deformed_source_image(self, source_image, sparse_motions): bs, _, h, w = source_image.shape source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_regions + 1, 1, 1, 1, 1) source_repeat = source_repeat.view(bs * (self.num_regions + 1), -1, h, w) sparse_motions = sparse_motions.view((bs * (self.num_regions + 1), h, w, -1)) sparse_deformed = F.grid_sample(source_repeat, sparse_motions) sparse_deformed = sparse_deformed.view((bs, self.num_regions + 1, -1, h, w)) return sparse_deformed def forward(self, source_image, driving_region_params, source_region_params, bg_params=None): if self.scale_factor != 1: source_image = self.down(source_image) bs, _, h, w = source_image.shape out_dict = dict() heatmap_representation = self.create_heatmap_representations(source_image, driving_region_params, source_region_params) sparse_motion = self.create_sparse_motions(source_image, driving_region_params, source_region_params, bg_params=bg_params) deformed_source = self.create_deformed_source_image(source_image, sparse_motion) if self.use_deformed_source: predictor_input = torch.cat([heatmap_representation, deformed_source], dim=2) else: predictor_input = heatmap_representation predictor_input = predictor_input.view(bs, -1, h, w) prediction = self.hourglass(predictor_input) mask = self.mask(prediction) mask = F.softmax(mask, dim=1) mask = mask.unsqueeze(2) sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3) deformation = (sparse_motion * mask).sum(dim=1) deformation = deformation.permute(0, 2, 3, 1) out_dict['optical_flow'] = deformation if self.occlusion: occlusion_map = torch.sigmoid(self.occlusion(prediction)) out_dict['occlusion_map'] = occlusion_map return out_dict ================================================ FILE: LFG/modules/region_predictor.py ================================================ """ Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. """ from torch import nn import torch import torch.nn.functional as F from LFG.modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d, Encoder def svd(covar, fast=False): if fast: from torch_batch_svd import svd as fast_svd return fast_svd(covar) else: u, s, v = torch.svd(covar.cpu()) s = s.to(covar.device) u = u.to(covar.device) v = v.to(covar.device) return u, s, v class RegionPredictor(nn.Module): """ Region estimating. Estimate affine parameters of the region. 
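    Example (rough sketch with hypothetical hyper-parameters; the real values come from
    config['model_params']['region_predictor_params']):
        >>> rp = RegionPredictor(block_expansion=32, num_regions=10, num_channels=3,
        ...                      max_features=1024, num_blocks=5, temperature=0.1,
        ...                      estimate_affine=True)
        >>> params = rp(torch.rand(1, 3, 64, 64))
        >>> # params['shift']: (1, 10, 2) region centres; params['affine']: (1, 10, 2, 2) local affines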
""" def __init__(self, block_expansion, num_regions, num_channels, max_features, num_blocks, temperature, estimate_affine=False, scale_factor=1, pca_based=False, fast_svd=False, pad=3): super(RegionPredictor, self).__init__() self.predictor = Hourglass(block_expansion, in_features=num_channels, max_features=max_features, num_blocks=num_blocks) self.regions = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_regions, kernel_size=(7, 7), padding=pad) # FOMM-like regression based representation if estimate_affine and not pca_based: self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=4, kernel_size=(7, 7), padding=pad) self.jacobian.weight.data.zero_() self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1], dtype=torch.float)) else: self.jacobian = None self.temperature = temperature self.scale_factor = scale_factor self.pca_based = pca_based self.fast_svd = fast_svd if self.scale_factor != 1: self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) def region2affine(self, region): shape = region.shape region = region.unsqueeze(-1) grid = make_coordinate_grid(shape[2:], region.type()).unsqueeze_(0).unsqueeze_(0) mean = (region * grid).sum(dim=(2, 3)) region_params = {'shift': mean} if self.pca_based: mean_sub = grid - mean.unsqueeze(-2).unsqueeze(-2) covar = torch.matmul(mean_sub.unsqueeze(-1), mean_sub.unsqueeze(-2)) covar = covar * region.unsqueeze(-1) covar = covar.sum(dim=(2, 3)) region_params['covar'] = covar return region_params def forward(self, x): if self.scale_factor != 1: x = self.down(x) feature_map = self.predictor(x) prediction = self.regions(feature_map) final_shape = prediction.shape region = prediction.view(final_shape[0], final_shape[1], -1) region = F.softmax(region / self.temperature, dim=2) region = region.view(*final_shape) region_params = self.region2affine(region) region_params['heatmap'] = region # Regression-based estimation if self.jacobian is not None: jacobian_map = self.jacobian(feature_map) jacobian_map = jacobian_map.reshape(final_shape[0], 1, 4, final_shape[2], final_shape[3]) region = region.unsqueeze(2) jacobian = region * jacobian_map jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) jacobian = jacobian.sum(dim=-1) jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) region_params['affine'] = jacobian region_params['covar'] = torch.matmul(jacobian, jacobian.permute(0, 1, 3, 2)) elif self.pca_based: covar = region_params['covar'] shape = covar.shape covar = covar.view(-1, 2, 2) u, s, v = svd(covar, self.fast_svd) d = torch.diag_embed(s ** 0.5) sqrt = torch.matmul(u, d) sqrt = sqrt.view(*shape) region_params['affine'] = sqrt region_params['u'] = u region_params['d'] = d return region_params ================================================ FILE: LFG/modules/util.py ================================================ """ Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. In no event will Snap Inc. 
be liable for any damages or losses of any kind arising from the sample code or your use thereof. """ from torch import nn import torch.nn.functional as F import torch from LFG.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d import numpy as np import matplotlib.pyplot as plt from skimage.draw import disk as circle import math def region2gaussian(center, covar, spatial_size): """ Transform a region parameters into gaussian like heatmap """ mean = center coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) number_of_leading_dimensions = len(mean.shape) - 1 shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape coordinate_grid = coordinate_grid.view(*shape) repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1) coordinate_grid = coordinate_grid.repeat(*repeats) # Preprocess kp shape shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2) mean = mean.view(*shape) mean_sub = (coordinate_grid - mean) if type(covar) == float: out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / covar) else: shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2, 2) covar_inverse = torch.inverse(covar).view(*shape) under_exp = torch.matmul(torch.matmul(mean_sub.unsqueeze(-2), covar_inverse), mean_sub.unsqueeze(-1)) out = torch.exp(-0.5 * under_exp.sum(dim=(-1, -2))) return out def make_coordinate_grid(spatial_size, type): """ Create a meshgrid [-1,1] x [-1,1] of given spatial_size. """ h, w = spatial_size x = torch.arange(w).type(type) y = torch.arange(h).type(type) x = (2 * (x / (w - 1)) - 1) y = (2 * (y / (h - 1)) - 1) yy = y.view(-1, 1).repeat(1, w) xx = x.view(1, -1).repeat(h, 1) meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) return meshed class ResBlock2d(nn.Module): """ Res block, preserve spatial resolution. """ def __init__(self, in_features, kernel_size, padding): super(ResBlock2d, self).__init__() self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) self.norm1 = BatchNorm2d(in_features, affine=True) self.norm2 = BatchNorm2d(in_features, affine=True) def forward(self, x): out = self.norm1(x) out = F.relu(out) out = self.conv1(out) out = self.norm2(out) out = F.relu(out) out = self.conv2(out) out += x return out class UpBlock2d(nn.Module): """ Upsampling block for use in decoder. """ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): super(UpBlock2d, self).__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = BatchNorm2d(out_features, affine=True) def forward(self, x): out = F.interpolate(x, scale_factor=2) out = self.conv(out) out = self.norm(out) out = F.relu(out) return out class DownBlock2d(nn.Module): """ Downsampling block for use in encoder. """ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): super(DownBlock2d, self).__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = BatchNorm2d(out_features, affine=True) self.pool = nn.AvgPool2d(kernel_size=(2, 2)) def forward(self, x): out = self.conv(x) out = self.norm(out) out = F.relu(out) out = self.pool(out) return out class SameBlock2d(nn.Module): """ Simple block, preserve spatial resolution. 
""" def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): super(SameBlock2d, self).__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = BatchNorm2d(out_features, affine=True) def forward(self, x): out = self.conv(x) out = self.norm(out) out = F.relu(out) return out class Encoder(nn.Module): """ Hourglass Encoder """ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): super(Encoder, self).__init__() down_blocks = [] for i in range(num_blocks): down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1)) self.down_blocks = nn.ModuleList(down_blocks) def forward(self, x): outs = [x] for down_block in self.down_blocks: outs.append(down_block(outs[-1])) return outs class Decoder(nn.Module): """ Hourglass Decoder """ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): super(Decoder, self).__init__() up_blocks = [] for i in range(num_blocks)[::-1]: in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) out_filters = min(max_features, block_expansion * (2 ** i)) up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1)) self.up_blocks = nn.ModuleList(up_blocks) self.out_filters = block_expansion + in_features def forward(self, x): out = x.pop() for up_block in self.up_blocks: out = up_block(out) skip = x.pop() out = torch.cat([out, skip], dim=1) return out class Hourglass(nn.Module): """ Hourglass architecture. """ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): super(Hourglass, self).__init__() self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) self.out_filters = self.decoder.out_filters def forward(self, x): return self.decoder(self.encoder(x)) class AntiAliasInterpolation2d(nn.Module): """ Band-limited downsampling, for better preservation of the input signal. """ def __init__(self, channels, scale): super(AntiAliasInterpolation2d, self).__init__() sigma = (1 / scale - 1) / 2 kernel_size = 2 * round(sigma * 4) + 1 self.ka = kernel_size // 2 self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka kernel_size = [kernel_size, kernel_size] sigma = [sigma, sigma] # The gaussian kernel is the product of the # gaussian function of each dimension. kernel = 1 meshgrids = torch.meshgrid( [ torch.arange(size, dtype=torch.float32) for size in kernel_size ] ) for size, std, mgrid in zip(kernel_size, sigma, meshgrids): mean = (size - 1) / 2 kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) # Make sure sum of values in gaussian kernel equals 1. 
kernel = kernel / torch.sum(kernel) # Reshape to depthwise convolutional weight kernel = kernel.view(1, 1, *kernel.size()) kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) self.register_buffer('weight', kernel) self.groups = channels self.scale = scale inv_scale = 1 / scale self.int_inv_scale = int(inv_scale) def forward(self, input): if self.scale == 1.0: return input out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) out = F.conv2d(out, weight=self.weight, groups=self.groups) out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale] return out def to_homogeneous(coordinates): ones_shape = list(coordinates.shape) ones_shape[-1] = 1 ones = torch.ones(ones_shape).type(coordinates.type()) return torch.cat([coordinates, ones], dim=-1) def from_homogeneous(coordinates): return coordinates[..., :2] / coordinates[..., 2:3] def draw_colored_heatmap(heatmap, colormap, bg_color): parts = [] weights = [] bg_color = np.array(bg_color).reshape((1, 1, 1, 3)) num_regions = heatmap.shape[-1] for i in range(num_regions): color = np.array(colormap(i / num_regions))[:3] color = color.reshape((1, 1, 1, 3)) part = heatmap[:, :, :, i:(i + 1)] part = part / np.max(part, axis=(1, 2), keepdims=True) weights.append(part) color_part = part * color parts.append(color_part) weight = sum(weights) bg_weight = 1 - np.minimum(1, weight) weight = np.maximum(1, weight) result = sum(parts) / weight + bg_weight * bg_color return result class Visualizer: def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow', region_bg_color=(0, 0, 0)): self.kp_size = kp_size self.draw_border = draw_border self.colormap = plt.get_cmap(colormap) self.region_bg_color = np.array(region_bg_color) def draw_image_with_kp(self, image, kp_array): image = np.copy(image) spatial_size = np.array(image.shape[:2][::-1])[np.newaxis] kp_array = spatial_size * (kp_array + 1) / 2 num_regions = kp_array.shape[0] for kp_ind, kp in enumerate(kp_array): rr, cc = circle((kp[1], kp[0]), self.kp_size, shape=image.shape[:2]) image[rr, cc] = np.array(self.colormap(kp_ind / num_regions))[:3] return image def create_image_column_with_kp(self, images, kp): image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)]) return self.create_image_column(image_array) def create_image_column(self, images): if self.draw_border: images = np.copy(images) images[:, :, [0, -1]] = (1, 1, 1) images[:, :, [0, -1]] = (1, 1, 1) return np.concatenate(list(images), axis=0) def create_image_grid(self, *args): out = [] for arg in args: if type(arg) == tuple: out.append(self.create_image_column_with_kp(arg[0], arg[1])) else: out.append(self.create_image_column(arg)) return np.concatenate(out, axis=1) @staticmethod def sample(x, index): return x[index].unsqueeze(dim=0).clone().detach() def visualize(self, driving, source, out, index=0): images = [] # Source image with region centers source = self.sample(source, index) source = source.data.cpu() source_region_params = self.sample(out['source_region_params']['shift'], index) source_region_params = source_region_params.data.cpu().numpy() source = np.transpose(source, [0, 2, 3, 1]) images.append((source, source_region_params)) if 'heatmap' in out['source_region_params']: source_heatmap = self.sample(out['source_region_params']['heatmap'], index) source_heatmap = F.interpolate(source_heatmap, size=source.shape[1:3]) source_heatmap = np.transpose(source_heatmap.data.cpu().numpy(), [0, 2, 3, 1]) images.append(draw_colored_heatmap(source_heatmap, self.colormap, self.region_bg_color)) # 
Deformed image if 'deformed' in out: deformed = self.sample(out['deformed'], index) deformed = deformed.data.cpu().numpy() deformed = np.transpose(deformed, [0, 2, 3, 1]) images.append(deformed) # Equivariance visualization if 'transformed_frame' in out: transformed = self.sample(out['transformed_frame'], index) transformed = transformed.data.cpu().numpy() transformed = np.transpose(transformed, [0, 2, 3, 1]) transformed_kp = self.sample(out['transformed_region_params']['shift'], index) transformed_kp = transformed_kp.data.cpu().numpy() images.append((transformed, transformed_kp)) # Driving image with region centers driving_region_params = self.sample(out['driving_region_params']['shift'], index) driving_region_params = driving_region_params.data.cpu().numpy() driving = self.sample(driving, index) driving = driving.data.cpu().numpy() driving = np.transpose(driving, [0, 2, 3, 1]) images.append((driving, driving_region_params)) # Heatmaps visualizations if 'heatmap' in out['driving_region_params']: driving_heatmap = self.sample(out['driving_region_params']['heatmap'], index) driving_heatmap = F.interpolate(driving_heatmap, size=source.shape[1:3]) driving_heatmap = np.transpose(driving_heatmap.data.cpu().numpy(), [0, 2, 3, 1]) images.append(draw_colored_heatmap(driving_heatmap, self.colormap, self.region_bg_color)) # Result prediction = self.sample(out['prediction'], index) prediction = prediction.data.cpu().numpy() prediction = np.transpose(prediction, [0, 2, 3, 1]) images.append(prediction) # Occlusion map if 'occlusion_map' in out: occlusion_map = self.sample(out['occlusion_map'], index) occlusion_map = occlusion_map.data.cpu().repeat(1, 3, 1, 1) occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy() occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1]) images.append(occlusion_map) image = self.create_image_grid(*images) # reshape (1, 8) to (2, 4) H, W, _ = image.shape base = H num_image = W // H row, col = 2, math.ceil(num_image/2) new_image = np.zeros((row*base, col*base, 3), dtype=np.float32) cnt = 0 for ii in range(row): for jj in range(col): try: new_image[ii*base:(ii+1)*base, jj*base:(jj+1)*base, :] = image[:, cnt*base:(cnt+1)*base, :] except: pass cnt += 1 new_image = (255 * new_image).astype(np.uint8) return new_image ================================================ FILE: LFG/run_hdtf.py ================================================ # Estimate flow and occlusion mask via RegionMM (or called MRAA) for MUG dataset # this code is based on RegionMM from Snap Inc. 
# https://github.com/snap-research/articulated-animation import os import sys sys.path.append("your/path/DAWN-pytorch") # change this path to your current work directory import math import yaml from argparse import ArgumentParser from shutil import copy from datetime import datetime from LFG.hdtf_dataset import FramesDataset from LFG.modules.generator import Generator from LFG.modules.bg_motion_predictor import BGMotionPredictor from LFG.modules.region_predictor import RegionPredictor from LFG.modules.avd_network import AVDNetwork import torch import torch.backends.cudnn as cudnn import numpy as np import random from LFG.train import train class Logger(object): def __init__(self, filename='default.log', stream=sys.stdout): self.terminal = stream self.log = open(filename, 'w') def write(self, message): self.terminal.write(message) self.log.write(message) def flush(self): pass def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True if __name__ == "__main__": # os.environ["CUDA_VISIBLE_DEVICES"] = "3" cudnn.enabled = True cudnn.benchmark = True if sys.version_info[0] < 3: raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7") parser = ArgumentParser() parser.add_argument("--postfix", default="") # indicate different settings parser.add_argument("--random-seed", default=1234) parser.add_argument("--set-start", default=False) parser.add_argument("--config", default="your/path/DAWN-pytorch/config/hdtf128_llm.yaml", help="path to config") parser.add_argument("--mode", default="train", choices=["train"]) parser.add_argument("--log_dir", default='your/path/DAWN-pytorch/AE/data/log-hdtf', help="path to log into") parser.add_argument("--checkpoint", # use the pretrained VOX model given by Snap default="/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/data/ckp/vox256.pth", help="path to checkpoint to restore") parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))), help="Names of the devices comma separated.") parser.add_argument("--verbose", dest="verbose", default=False, help="Print model architecture") parser.set_defaults(verbose=False) opt = parser.parse_args() setup_seed(opt.random_seed) with open(opt.config) as f: config = yaml.safe_load(f) current_time = datetime.now() current_time = current_time.strftime("%Y-%m-%d_%H:%M") log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]+opt.postfix+'_'+current_time) if not os.path.exists(log_dir): os.makedirs(log_dir) if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))): copy(opt.config, log_dir) # the directory to save checkpoints config["snapshots"] = os.path.join(log_dir, 'snapshots'+opt.postfix) os.makedirs(config["snapshots"], exist_ok=True) # the directory to save images of training results config["imgshots"] = os.path.join(log_dir, 'imgshots'+opt.postfix) os.makedirs(config["imgshots"], exist_ok=True) config["set_start"] = opt.set_start log_txt = os.path.join(log_dir, "B"+format(config['train_params']['batch_size'], "04d")+ "E"+format(config['train_params']['max_epochs'], "04d")+".log") sys.stdout = Logger(log_txt, sys.stdout) print("postfix:", opt.postfix) print("checkpoint:", opt.checkpoint) print("batch size:", config['train_params']['batch_size']) generator = Generator(num_regions=config['model_params']['num_regions'], num_channels=config['model_params']['num_channels'], 
revert_axis_swap=config['model_params']['revert_axis_swap'], **config['model_params']['generator_params']) if torch.cuda.is_available(): generator.to(opt.device_ids[0]) if opt.verbose: print(generator) region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'], num_channels=config['model_params']['num_channels'], estimate_affine=config['model_params']['estimate_affine'], **config['model_params']['region_predictor_params']) if torch.cuda.is_available(): region_predictor.to(opt.device_ids[0]) if opt.verbose: print(region_predictor) bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'], **config['model_params']['bg_predictor_params']) if torch.cuda.is_available(): bg_predictor.to(opt.device_ids[0]) if opt.verbose: print(bg_predictor) avd_network = AVDNetwork(num_regions=config['model_params']['num_regions'], **config['model_params']['avd_network_params']) if torch.cuda.is_available(): avd_network.to(opt.device_ids[0]) if opt.verbose: print(avd_network) dataset = FramesDataset(**config['dataset_params']) config["num_example_per_epoch"] = config['train_params']['num_repeats'] * len(dataset) config["num_step_per_epoch"] = math.ceil(config["num_example_per_epoch"]/float(config['train_params']['batch_size'])) # save 10 checkpoints in total config["save_ckpt_freq"] = config["num_step_per_epoch"] * (config['train_params']['max_epochs'] // 10) print("save ckpt freq:", config["save_ckpt_freq"]) print("Training...") train(config, generator, region_predictor, bg_predictor, opt.checkpoint, log_dir, dataset, opt.device_ids) ================================================ FILE: LFG/run_hdtf_crema.py ================================================ # Estimate flow and occlusion mask via RegionMM (or called MRAA) for MUG dataset # this code is based on RegionMM from Snap Inc. # https://github.com/snap-research/articulated-animation import os import sys sys.path.append("your/path/DAWN-pytorch") # change this path to your current work directory import math import yaml from argparse import ArgumentParser from shutil import copy from datetime import datetime from LFG.frames_dataset import FramesDataset from LFG.modules.generator import Generator from LFG.modules.bg_motion_predictor import BGMotionPredictor from LFG.modules.region_predictor import RegionPredictor from LFG.modules.avd_network import AVDNetwork import torch import torch.backends.cudnn as cudnn import numpy as np import random from LFG.train import train class Logger(object): def __init__(self, filename='default.log', stream=sys.stdout): self.terminal = stream self.log = open(filename, 'w') def write(self, message): self.terminal.write(message) self.log.write(message) def flush(self): pass def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True if __name__ == "__main__": # os.environ["CUDA_VISIBLE_DEVICES"] = "3" cudnn.enabled = True cudnn.benchmark = True if sys.version_info[0] < 3: raise Exception("You must use Python 3 or higher. 
Recommended version is Python 3.7") parser = ArgumentParser() parser.add_argument("--postfix", default="") # indicate different settings parser.add_argument("--random-seed", default=1234) parser.add_argument("--set-start", default=False) parser.add_argument("--config", default="your/path/DAWN-pytorch/config/hdtf128_llm.yaml", help="path to config") parser.add_argument("--mode", default="train", choices=["train"]) parser.add_argument("--log_dir", default='your/path/DAWN-pytorch/AE/data/log-hdtf', help="path to log into") parser.add_argument("--checkpoint", # use the pretrained VOX model given by Snap default="/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/data/ckp/vox256.pth", help="path to checkpoint to restore") parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))), help="Names of the devices comma separated.") parser.add_argument("--verbose", dest="verbose", default=False, help="Print model architecture") parser.set_defaults(verbose=False) opt = parser.parse_args() setup_seed(opt.random_seed) with open(opt.config) as f: config = yaml.safe_load(f) current_time = datetime.now() current_time = current_time.strftime("%Y-%m-%d_%H:%M") log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]+opt.postfix+'_'+current_time) if not os.path.exists(log_dir): os.makedirs(log_dir) if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))): copy(opt.config, log_dir) # the directory to save checkpoints config["snapshots"] = os.path.join(log_dir, 'snapshots'+opt.postfix) os.makedirs(config["snapshots"], exist_ok=True) # the directory to save images of training results config["imgshots"] = os.path.join(log_dir, 'imgshots'+opt.postfix) os.makedirs(config["imgshots"], exist_ok=True) config["set_start"] = opt.set_start log_txt = os.path.join(log_dir, "B"+format(config['train_params']['batch_size'], "04d")+ "E"+format(config['train_params']['max_epochs'], "04d")+".log") sys.stdout = Logger(log_txt, sys.stdout) print("postfix:", opt.postfix) print("checkpoint:", opt.checkpoint) print("batch size:", config['train_params']['batch_size']) generator = Generator(num_regions=config['model_params']['num_regions'], num_channels=config['model_params']['num_channels'], revert_axis_swap=config['model_params']['revert_axis_swap'], **config['model_params']['generator_params']) if torch.cuda.is_available(): generator.to(opt.device_ids[0]) if opt.verbose: print(generator) region_predictor = RegionPredictor(num_regions=config['model_params']['num_regions'], num_channels=config['model_params']['num_channels'], estimate_affine=config['model_params']['estimate_affine'], **config['model_params']['region_predictor_params']) if torch.cuda.is_available(): region_predictor.to(opt.device_ids[0]) if opt.verbose: print(region_predictor) bg_predictor = BGMotionPredictor(num_channels=config['model_params']['num_channels'], **config['model_params']['bg_predictor_params']) if torch.cuda.is_available(): bg_predictor.to(opt.device_ids[0]) if opt.verbose: print(bg_predictor) avd_network = AVDNetwork(num_regions=config['model_params']['num_regions'], **config['model_params']['avd_network_params']) if torch.cuda.is_available(): avd_network.to(opt.device_ids[0]) if opt.verbose: print(avd_network) dataset = FramesDataset(**config['dataset_params']) config["num_example_per_epoch"] = config['train_params']['num_repeats'] * len(dataset) config["num_step_per_epoch"] = math.ceil(config["num_example_per_epoch"]/float(config['train_params']['batch_size'])) # save 10 
checkpoints in total config["save_ckpt_freq"] = config["num_step_per_epoch"] * (config['train_params']['max_epochs'] // 10) print("save ckpt freq:", config["save_ckpt_freq"]) print("Training...") train(config, generator, region_predictor, bg_predictor, opt.checkpoint, log_dir, dataset, opt.device_ids) ================================================ FILE: LFG/sync_batchnorm/__init__.py ================================================ # -*- coding: utf-8 -*- # File : __init__.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d from .replicate import DataParallelWithCallback, patch_replication_callback ================================================ FILE: LFG/sync_batchnorm/batchnorm.py ================================================ # -*- coding: utf-8 -*- # File : batchnorm.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import collections import torch import torch.nn.functional as F from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast from .comm import SyncMaster __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] def _sum_ft(tensor): """sum over the first and last dimention""" return tensor.sum(dim=0).sum(dim=-1) def _unsqueeze_ft(tensor): """add new dementions at the front and the tail""" return tensor.unsqueeze(0).unsqueeze(-1) _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) class _SynchronizedBatchNorm(_BatchNorm): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) self._sync_master = SyncMaster(self._data_parallel_master) self._is_parallel = False self._parallel_id = None self._slave_pipe = None def forward(self, input): # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. if not (self._is_parallel and self.training): return F.batch_norm( input, self.running_mean, self.running_var, self.weight, self.bias, self.training, self.momentum, self.eps) # Resize the input to (B, C, -1). input_shape = input.size() input = input.view(input.size(0), self.num_features, -1) # Compute the sum and square-sum. sum_size = input.size(0) * input.size(2) input_sum = _sum_ft(input) input_ssum = _sum_ft(input ** 2) # Reduce-and-broadcast the statistics. if self._parallel_id == 0: mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) else: mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) # Compute the output. if self.affine: # MJY:: Fuse the multiplication for speed. output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) else: output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) # Reshape it. 
return output.view(input_shape) def __data_parallel_replicate__(self, ctx, copy_id): self._is_parallel = True self._parallel_id = copy_id # parallel_id == 0 means master device. if self._parallel_id == 0: ctx.sync_master = self._sync_master else: self._slave_pipe = ctx.sync_master.register_slave(copy_id) def _data_parallel_master(self, intermediates): """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" # Always using same "device order" makes the ReduceAdd operation faster. # Thanks to:: Tete Xiao (http://tetexiao.com/) intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) to_reduce = [i[1][:2] for i in intermediates] to_reduce = [j for i in to_reduce for j in i] # flatten target_gpus = [i[1].sum.get_device() for i in intermediates] sum_size = sum([i[1].sum_size for i in intermediates]) sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) broadcasted = Broadcast.apply(target_gpus, mean, inv_std) outputs = [] for i, rec in enumerate(intermediates): outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) return outputs def _compute_mean_std(self, sum_, ssum, size): """Compute the mean and standard-deviation with sum and square-sum. This method also maintains the moving average on the master device.""" assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' mean = sum_ / size sumvar = ssum - sum_ * mean unbias_var = sumvar / (size - 1) bias_var = sumvar / size self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data return mean, bias_var.clamp(self.eps) ** -0.5 class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a mini-batch. .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm1d as the mean and standard-deviation are reduced across all devices during training. For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm Args: num_features: num_features from an expected input of size `batch_size x num_features [x width]` eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. 
Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape: - Input: :math:`(N, C)` or :math:`(N, C, L)` - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm1d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm1d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: raise ValueError('expected 2D or 3D input (got {}D input)' .format(input.dim())) super(SynchronizedBatchNorm1d, self)._check_input_dim(input) class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch of 3d inputs .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm2d as the mean and standard-deviation are reduced across all devices during training. For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm Args: num_features: num_features from an expected input of size batch_size x num_features x height x width eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape: - Input: :math:`(N, C, H, W)` - Output: :math:`(N, C, H, W)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm2d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm2d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 4: raise ValueError('expected 4D input (got {}D input)' .format(input.dim())) super(SynchronizedBatchNorm2d, self)._check_input_dim(input) class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch of 4d inputs .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm3d as the mean and standard-deviation are reduced across all devices during training. 
    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)


================================================
FILE: LFG/sync_batchnorm/comm.py
================================================
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback on each module,
      all slave devices should call `register(id)` and obtain a `SlavePipe` to communicate
      with the master.
    - During the forward pass, the master device invokes `run_master`, all messages from
      slave devices will be collected, and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and
      determine the message to be passed back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected messages from
                slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns:
            a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device),
        and then a callback will be invoked to compute the message to be sent back to
        each device (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be
                placed as the first message when calling `master_callback`. For detailed
                usage, see `_SynchronizedBatchNorm` for an example.

        Returns:
            the message to be sent back to the master device.
        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)


================================================
FILE: LFG/sync_batchnorm/replicate.py
================================================
# -*- coding: utf-8 -*-
# File : replicate.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created
    by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all modules are isomorphic, we assign each sub-module with a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead
    of calling the callback of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked
    after being created by the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate


================================================
FILE: LFG/sync_batchnorm/unittest.py
================================================
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
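# --- Editor's illustration (not part of the original file) -------------------
# A minimal sketch of how the classes above fit together: SynchronizedBatchNorm*
# layers only synchronize their statistics when the model is replicated through
# DataParallelWithCallback (or a plain DataParallel patched with
# patch_replication_callback). This sketch assumes two visible GPUs and that
# LFG/sync_batchnorm/__init__.py re-exports these names; otherwise import them
# from .batchnorm and .replicate directly.
def _sync_bn_usage_example():
    import torch
    import torch.nn as nn
    from LFG.sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

    block = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        SynchronizedBatchNorm2d(16),   # statistics reduced over all replicas
        nn.ReLU(),
    ).cuda()

    # The replication callback registers every copy with one SyncMaster, so the
    # forward pass reduces mean/inv_std across devices instead of per device.
    model = DataParallelWithCallback(block, device_ids=[0, 1])
    return model(torch.randn(8, 3, 64, 64).cuda())
# -----------------------------------------------------------------------------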
import unittest import numpy as np from torch.autograd import Variable def as_numpy(v): if isinstance(v, Variable): v = v.data return v.cpu().numpy() class TorchTestCase(unittest.TestCase): def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): npa, npb = as_numpy(a), as_numpy(b) self.assertTrue( np.allclose(npa, npb, atol=atol), 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) ) ================================================ FILE: LFG/test_flowautoenc_crema_video.py ================================================ # use LFG to reconstruct testing videos and measure the loss in video domain # using RegionMM import argparse import imageio import torch from torch.utils import data import numpy as np import torch.backends.cudnn as cudnn import os import timeit from PIL import Image import sys sys.path.append("your/path/DAWN-pytorch") from misc import grid2fig from DM_2.datasets_crema_wpose_lmk_block import HDTF import random from LFG.modules.flow_autoenc import FlowAE import torch.nn.functional as F from LFG.modules.util import Visualizer import json_tricks as json import cv2 import tempfile from subprocess import call from pydub import AudioSegment from einops import rearrange from tqdm import tqdm start = timeit.default_timer() BATCH_SIZE = 1 INPUT_SIZE = 128 root_dir = 'your/path/DAWN-pytorch/AE' # your work directory data_dir = "/train20/intern/permanent/hbcheng2/data/crema/images_25hz_128_chunk" pose_dir = '/train20/intern/permanent/hbcheng2/data/crema/pose_bar_chunk' eye_blink_dir = '/train20/intern/permanent/hbcheng2/data/crema/eye_blink_bbox_bar_2_chunk' DATASAVE_DIR = '/train20/intern/permanent/hbcheng2/data' CKPT_DIR = os.path.join(DATASAVE_DIR, 'mraa_result_crema', str(INPUT_SIZE),'video') os.makedirs(CKPT_DIR, exist_ok=True) IMG_DIR = os.path.join(DATASAVE_DIR, 'mraa_result_crema', str(INPUT_SIZE),'img') os.makedirs(IMG_DIR, exist_ok=True) # GPU = "6" postfix = "" N_FRAMES = 40 NUM_VIDEOS = 10 SAVE_VIDEO = True NUM_ITER = NUM_VIDEOS // BATCH_SIZE RANDOM_SEED = 1234 MEAN = (0.0, 0.0, 0.0) # the path to trained LFG model RESTORE_FROM ='your_path/data/log-hdtf/hdtf128_2024-02-11_15:45/snapshots/RegionMM_0100_S074360.pth' # RESTORE_FROM = "/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/data/log-hdtf/hdtf256_2023-11-21_16:49/snapshots/RegionMM_0020_S080000.pth" config_pth = "your/path/DAWN-pytorch/AE/data/log-hdtf/hdtf128_llm_2024-07-26_12:54/hdtf128_llm.yaml" json_path = os.path.join(CKPT_DIR, "loss%d%s.json" % (NUM_VIDEOS, postfix)) visualizer = Visualizer() print(root_dir) print(postfix) print("RESTORE_FROM:", RESTORE_FROM) print("config_path:", config_pth) print(json_path) print("save video:", SAVE_VIDEO) def get_arguments(): """Parse all the arguments provided from the CLI. Returns: A list of parsed arguments. 
""" parser = argparse.ArgumentParser(description="Flow Autoencoder") parser.add_argument("--num-workers", default=8) parser.add_argument("--gpu", default=0, help="choose gpu device.") parser.add_argument('--print-freq', '-p', default=1, type=int, metavar='N', help='print frequency') parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Number of images sent to the network in one step.") parser.add_argument("--input-size", type=str, default=INPUT_SIZE, help="Comma-separated string with height and width of images.") parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, help="Random seed to have reproducible results.") parser.add_argument("--restore-from", default=RESTORE_FROM) parser.add_argument("--fp16", default=False) return parser.parse_args() args = get_arguments() def extract_audio_by_frames(input_wav_path, start_frame_index, num_frames, frame_rate, output_wav_path): # audio = AudioSegment.from_wav(input_wav_path) # frame_duration = 1000 / frame_rate # # start_time_ms = start_frame_index * frame_duration end_time_ms = (start_frame_index + num_frames) * frame_duration # selected_audio = audio[start_time_ms:end_time_ms] # selected_audio.export(output_wav_path, format="wav") def sample_img(rec_img_batch): rec_img = rec_img_batch.permute(1, 2, 0).data.cpu().numpy().copy() rec_img += np.array(MEAN)/255.0 rec_img[rec_img < 0] = 0 rec_img[rec_img > 1] = 1 rec_img *= 255 return np.array(rec_img, np.uint8) def main(): """Create the model and start the training.""" os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) cudnn.enabled = True cudnn.benchmark = True setup_seed(args.random_seed) model = FlowAE(is_train=False, config_pth=config_pth) model.cuda() if os.path.isfile(args.restore_from): print("=> loading checkpoint '{}'".format(args.restore_from)) checkpoint = torch.load(args.restore_from) model.generator.load_state_dict(checkpoint['generator']) model.region_predictor.load_state_dict(checkpoint['region_predictor']) model.bg_predictor.load_state_dict(checkpoint['bg_predictor']) print("=> loaded checkpoint '{}'".format(args.restore_from)) else: print("=> no checkpoint found at '{}'".format(args.restore_from)) exit(-1) model.eval() setup_seed(args.random_seed) testloader = data.DataLoader(HDTF(data_dir=data_dir, pose_dir=pose_dir, eye_blink_dir = eye_blink_dir, image_size=INPUT_SIZE, mode='test', max_num_frames=1e8, color_jitter=True, mean=MEAN), batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True) batch_time = AverageMeter() data_time = AverageMeter() iter_end = timeit.default_timer() cnt = 0 out_loss = 0.0 warp_loss = 0.0 num_sample = 0.0 l1_loss = torch.nn.L1Loss(reduction='sum') global_iter = 0 while global_iter < NUM_ITER: for i_iter, batch in enumerate(testloader): # if i_iter < NUM_ITER: # break # if global_iter < NUM_ITER: # break data_time.update(timeit.default_timer() - iter_end) real_vids, ref_hubert, real_poses, real_blink_bbox, mouth_lmk_tensor, real_names, _ = batch # use first frame of each video as reference frame real_vids = real_vids/255. 
ref_imgs = real_vids[:, :, 0, :, :].clone().detach() bs = real_vids.size(0) batch_time.update(timeit.default_timer() - iter_end) nf = real_vids.size(2) out_img_list = [] warped_img_list = [] warped_grid_list = [] conf_map_list = [] segment_length = 80 b,c,f,h,w = real_vids.size() real_vid_tmp = rearrange(real_vids, 'b c f h w -> (b f) c h w')# real_vid.reshape(b * f, c, h, w) ref_img_tmp = ref_imgs.repeat(segment_length,1,1,1).reshape(-1, 3, INPUT_SIZE, INPUT_SIZE) for frame_idx in tqdm(range(0, nf, segment_length)): end_fn = min(nf, frame_idx + segment_length) dri_imgs = real_vid_tmp[frame_idx : end_fn, :, :, :] if end_fn == nf: ref_img_tmp = ref_imgs.repeat(dri_imgs.shape[0],1,1,1).reshape(-1, 3, INPUT_SIZE, INPUT_SIZE) with torch.no_grad(): model.set_train_input(ref_img=ref_img_tmp, dri_img=dri_imgs) model.forward() out_img_list.append(model.generated['prediction'].clone().detach().cpu()) # warped_img_list.append(model.generated['deformed'].clone().detach()) out_img_list_tensor = torch.concat(out_img_list, dim = 0) # out_loss += l1_loss(real_vids.permute(2, 0, 1, 3, 4).cpu(), out_img_list_tensor.cpu()).item() # warp_loss += l1_loss(real_vids.permute(2, 0, 1, 3, 4).cpu(), warped_img_list_tensor.cpu()).item() num_sample += bs if SAVE_VIDEO: for batch_idx in range(bs): msk_size = ref_imgs.shape[-1] new_im_list = [] img_dir_name = "%04d_%s" % (i_iter, real_names[batch_idx]) cur_img_dir_gt = os.path.join(IMG_DIR, img_dir_name,'gt') os.makedirs(cur_img_dir_gt, exist_ok=True) cur_img_dir_samp = os.path.join(IMG_DIR, img_dir_name,'mraa') os.makedirs(cur_img_dir_samp, exist_ok=True) fps = 25 tmp_video_file_pred = tempfile.NamedTemporaryFile('w', suffix='.mp4', dir='your/path/DAWN-pytorch/demo') output_wav_path = tempfile.NamedTemporaryFile('w', suffix='.wav', dir='your/path/DAWN-pytorch/demo').name fourcc = cv2.VideoWriter_fourcc(*'mp4v') video_writer = cv2.VideoWriter(tmp_video_file_pred.name, fourcc, fps, (INPUT_SIZE, INPUT_SIZE)) SAV_DIR = os.path.join(CKPT_DIR, str(i_iter)+'_'+real_names[0] + '.mp4') wav_path = os.path.join('/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/audio', real_names[0].replace('_','/',1)+'.wav') extract_audio_by_frames(wav_path, 0, nf, fps, output_wav_path) for frame_idx in range(nf): new_im_gt = Image.new('RGB', (msk_size, msk_size)) new_im_sample = Image.new('RGB', (msk_size, msk_size)) save_tar_img = sample_img(real_vids[0, :, frame_idx]) save_out_img = sample_img(out_img_list_tensor[frame_idx]) # save_warped_img = sample_img(warped_img_list_tensor[frame_idx], batch_idx) # save_warped_grid = grid2fig(warped_grid_list_tensor[frame_idx, batch_idx].data.cpu().numpy(), # grid_size=32, img_size=msk_size) # save_conf_map = conf_map_list_tensor[frame_idx, batch_idx].unsqueeze(dim=0) # save_conf_map = save_conf_map.data.cpu() # save_conf_map = F.interpolate(save_conf_map, size=real_vids.shape[3:5]).numpy() # save_conf_map = np.transpose(save_conf_map, [0, 2, 3, 1]) # save_conf_map = np.array(save_conf_map[0, :, :, 0]*255, dtype=np.uint8) frame_rgb = np.uint8(save_out_img) frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR) video_writer.write(frame_bgr) # save sample and gt imgs new_im_gt.paste(Image.fromarray(save_tar_img, 'RGB'), (0, 0)) new_im_sample.paste(Image.fromarray(save_out_img, 'RGB'), (0, 0)) new_im_arr_gt = np.array(new_im_gt) new_im_arr_sample = np.array(new_im_sample) new_im_name = "%03d_%s.png" % (frame_idx, real_names[batch_idx]) imageio.imsave(os.path.join(cur_img_dir_gt,new_im_name), new_im_arr_gt) 
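                    # The ground-truth frame was written above; the reconstructed frame is
                    # written next. After the frame loop, the frames pushed to the
                    # cv2.VideoWriter are muxed with the trimmed wav through ffmpeg
                    # (-vcodec copy keeps the mp4v stream as-is, -shortest cuts the output
                    # to the shorter of the audio and video tracks).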
imageio.imsave(os.path.join(cur_img_dir_samp,new_im_name), new_im_arr_sample) # new_im = Image.new('RGB', (msk_size * 5, msk_size)) # new_im.paste(Image.fromarray(save_tar_img, 'RGB'), (0, 0)) # new_im.paste(Image.fromarray(save_out_img, 'RGB'), (msk_size, 0)) # new_im.paste(Image.fromarray(save_warped_img, 'RGB'), (msk_size * 2, 0)) # new_im.paste(Image.fromarray(save_warped_grid), (msk_size * 3, 0)) # new_im.paste(Image.fromarray(save_conf_map, "L"), (msk_size * 4, 0)) # new_im_list.append(new_im) # video_name = "%04d_%s.gif" % (cnt, real_names[batch_idx]) # imageio.mimsave(os.path.join(CKPT_DIR, video_name), new_im_list) cnt += 1 video_writer.release() cmd = ('ffmpeg -y ' + ' -i {0} -i {1} -vcodec copy -ac 2 -channel_layout stereo -pix_fmt yuv420p {2} -shortest'.format( output_wav_path, tmp_video_file_pred.name, SAV_DIR)).split() call(cmd) try: os.remove(tmp_video_file_pred.name) os.remove(output_wav_path) except OSError as e: print(f'Error: {e.strerror}') iter_end = timeit.default_timer() if global_iter % args.print_freq == 0: print('Test:[{0}/{1}]\t' 'Time {batch_time.val:.3f}({batch_time.avg:.3f})' .format(global_iter, NUM_ITER, batch_time=batch_time)) global_iter += 1 print("loss for prediction: %.5f" % (out_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3))) print("loss for warping: %.5f" % (warp_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3))) res_dict = {} res_dict["out_loss"] = out_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3) res_dict["warp_loss"] = warp_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3) with open(json_path, "w") as f: json.dump(res_dict, f) end = timeit.default_timer() print(end - start, 'seconds') class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True if __name__ == '__main__': main() ================================================ FILE: LFG/test_flowautoenc_hdtf_video.py ================================================ # use LFG to reconstruct testing videos and measure the loss in video domain # using RegionMM import argparse import imageio import torch from torch.utils import data import numpy as np import torch.backends.cudnn as cudnn import os import timeit from PIL import Image import sys sys.path.append("your/path/DAWN-pytorch") from misc import grid2fig from DM.datasets_hdtf_wpose_lmk_mo_block import HDTF import random from LFG.modules.flow_autoenc import FlowAE import torch.nn.functional as F from LFG.modules.util import Visualizer import json_tricks as json import cv2 import tempfile from subprocess import call from pydub import AudioSegment from einops import rearrange from tqdm import tqdm start = timeit.default_timer() BATCH_SIZE = 1 INPUT_SIZE = 128 root_dir = 'your/path/DAWN-pytorch/AE' # your work directory data_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/images_25hz_128_chunk" pose_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar_chunk" eye_blink_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/eye_blink_bbox_from_xpc_bar_2_chunk" DATASAVE_DIR = '/train20/intern/permanent/hbcheng2/data' CKPT_DIR = os.path.join(DATASAVE_DIR, 'mraa_result', str(INPUT_SIZE) + '_1000ep','video') os.makedirs(CKPT_DIR, exist_ok=True) IMG_DIR = os.path.join(DATASAVE_DIR, 
'mraa_result', str(INPUT_SIZE) + '_1000ep','img') os.makedirs(IMG_DIR, exist_ok=True) # GPU = "6" postfix = "" N_FRAMES = 40 NUM_VIDEOS = 10 SAVE_VIDEO = True NUM_ITER = NUM_VIDEOS // BATCH_SIZE RANDOM_SEED = 1234 MEAN = (0.0, 0.0, 0.0) # the path to trained LFG model RESTORE_FROM ='your/path/DAWN-pytorch/AE/data/log-hdtf-cosin/hdtf128_1000ep_2024-08-08_15:04/snapshots/RegionMM.pth' # RESTORE_FROM = "/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/data/log-hdtf/hdtf256_2023-11-21_16:49/snapshots/RegionMM_0020_S080000.pth" config_pth = "your/path/DAWN-pytorch/AE/data/log-hdtf/hdtf128_llm_2024-07-26_12:54/hdtf128_llm.yaml" json_path = os.path.join(CKPT_DIR, "loss%d%s.json" % (NUM_VIDEOS, postfix)) visualizer = Visualizer() print(root_dir) print(postfix) print("RESTORE_FROM:", RESTORE_FROM) print("config_path:", config_pth) print(json_path) print("save video:", SAVE_VIDEO) def get_arguments(): """Parse all the arguments provided from the CLI. Returns: A list of parsed arguments. """ parser = argparse.ArgumentParser(description="Flow Autoencoder") parser.add_argument("--num-workers", default=8) parser.add_argument("--gpu", default=0, help="choose gpu device.") parser.add_argument('--print-freq', '-p', default=1, type=int, metavar='N', help='print frequency') parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Number of images sent to the network in one step.") parser.add_argument("--input-size", type=str, default=INPUT_SIZE, help="Comma-separated string with height and width of images.") parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, help="Random seed to have reproducible results.") parser.add_argument("--restore-from", default=RESTORE_FROM) parser.add_argument("--fp16", default=False) return parser.parse_args() args = get_arguments() def extract_audio_by_frames(input_wav_path, start_frame_index, num_frames, frame_rate, output_wav_path): # audio = AudioSegment.from_wav(input_wav_path) # frame_duration = 1000 / frame_rate # # start_time_ms = start_frame_index * frame_duration end_time_ms = (start_frame_index + num_frames) * frame_duration # selected_audio = audio[start_time_ms:end_time_ms] # selected_audio.export(output_wav_path, format="wav") def sample_img(rec_img_batch): rec_img = rec_img_batch.permute(1, 2, 0).data.cpu().numpy().copy() rec_img += np.array(MEAN)/255.0 rec_img[rec_img < 0] = 0 rec_img[rec_img > 1] = 1 rec_img *= 255 return np.array(rec_img, np.uint8) def main(): """Create the model and start the training.""" os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) cudnn.enabled = True cudnn.benchmark = True setup_seed(args.random_seed) model = FlowAE(is_train=False, config_pth=config_pth) model.cuda() if os.path.isfile(args.restore_from): print("=> loading checkpoint '{}'".format(args.restore_from)) checkpoint = torch.load(args.restore_from) model.generator.load_state_dict(checkpoint['generator']) model.region_predictor.load_state_dict(checkpoint['region_predictor']) model.bg_predictor.load_state_dict(checkpoint['bg_predictor']) print("=> loaded checkpoint '{}'".format(args.restore_from)) else: print("=> no checkpoint found at '{}'".format(args.restore_from)) exit(-1) model.eval() setup_seed(args.random_seed) testloader = data.DataLoader(HDTF(data_dir=data_dir, pose_dir=pose_dir, eye_blink_dir = eye_blink_dir, image_size=INPUT_SIZE, mode='test', max_num_frames=1e8, color_jitter=True, mean=MEAN), batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True) batch_time = AverageMeter() data_time = AverageMeter() iter_end = 
timeit.default_timer() cnt = 0 out_loss = 0.0 warp_loss = 0.0 num_sample = 0.0 l1_loss = torch.nn.L1Loss(reduction='sum') global_iter = 0 while global_iter < NUM_ITER: for i_iter, batch in enumerate(testloader): # if i_iter < NUM_ITER: # break # if global_iter < NUM_ITER: # break data_time.update(timeit.default_timer() - iter_end) real_vids, ref_hubert, real_poses, real_blink_bbox, real_mouth_ratio, real_names, start_frame_index = batch # use first frame of each video as reference frame real_vids = real_vids/255. ref_imgs = real_vids[:, :, 0, :, :].clone().detach() bs = real_vids.size(0) batch_time.update(timeit.default_timer() - iter_end) nf = real_vids.size(2) out_img_list = [] warped_img_list = [] warped_grid_list = [] conf_map_list = [] segment_length = 120 b,c,f,h,w = real_vids.size() real_vid_tmp = rearrange(real_vids, 'b c f h w -> (b f) c h w')# real_vid.reshape(b * f, c, h, w) ref_img_tmp = ref_imgs.repeat(segment_length,1,1,1).reshape(-1, 3, INPUT_SIZE, INPUT_SIZE) for frame_idx in tqdm(range(0, nf, segment_length)): end_fn = min(nf, frame_idx + segment_length) dri_imgs = real_vid_tmp[frame_idx : end_fn, :, :, :] if end_fn == nf: ref_img_tmp = ref_imgs.repeat(dri_imgs.shape[0],1,1,1).reshape(-1, 3, INPUT_SIZE, INPUT_SIZE) with torch.no_grad(): model.set_train_input(ref_img=ref_img_tmp, dri_img=dri_imgs) model.forward() out_img_list.append(model.generated['prediction'].clone().detach().cpu()) # warped_img_list.append(model.generated['deformed'].clone().detach()) out_img_list_tensor = torch.concat(out_img_list, dim = 0) # out_loss += l1_loss(real_vids.permute(2, 0, 1, 3, 4).cpu(), out_img_list_tensor.cpu()).item() # warp_loss += l1_loss(real_vids.permute(2, 0, 1, 3, 4).cpu(), warped_img_list_tensor.cpu()).item() num_sample += bs if SAVE_VIDEO: for batch_idx in range(bs): msk_size = ref_imgs.shape[-1] new_im_list = [] img_dir_name = "%04d_%s" % (i_iter, real_names[batch_idx]) cur_img_dir_gt = os.path.join(IMG_DIR, img_dir_name,'gt') os.makedirs(cur_img_dir_gt, exist_ok=True) cur_img_dir_samp = os.path.join(IMG_DIR, img_dir_name,'mraa') os.makedirs(cur_img_dir_samp, exist_ok=True) fps = 25 # tmp_video_file_pred = tempfile.NamedTemporaryFile('w', suffix='.mp4', dir='your/path/DAWN-pytorch/demo') output_wav_path = tempfile.NamedTemporaryFile('w', suffix='.wav', dir='your/path/DAWN-pytorch/demo').name fourcc = cv2.VideoWriter_fourcc(*'mp4v') video_writer = cv2.VideoWriter(tmp_video_file_pred.name, fourcc, fps, (INPUT_SIZE, INPUT_SIZE)) SAV_DIR = os.path.join(CKPT_DIR, str(i_iter)+'_'+real_names[0] + '.mp4') wav_path = os.path.join("/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz".replace('/images_25hz','/image_audio'), real_names[0]+'.wav') extract_audio_by_frames(wav_path, 0, nf, fps, output_wav_path) for frame_idx in range(nf): new_im_gt = Image.new('RGB', (msk_size, msk_size)) new_im_sample = Image.new('RGB', (msk_size, msk_size)) save_tar_img = sample_img(real_vids[0, :, frame_idx]) save_out_img = sample_img(out_img_list_tensor[frame_idx]) # save_warped_img = sample_img(warped_img_list_tensor[frame_idx], batch_idx) # save_warped_grid = grid2fig(warped_grid_list_tensor[frame_idx, batch_idx].data.cpu().numpy(), # grid_size=32, img_size=msk_size) # save_conf_map = conf_map_list_tensor[frame_idx, batch_idx].unsqueeze(dim=0) # save_conf_map = save_conf_map.data.cpu() # save_conf_map = F.interpolate(save_conf_map, size=real_vids.shape[3:5]).numpy() # save_conf_map = np.transpose(save_conf_map, [0, 2, 3, 1]) # save_conf_map = np.array(save_conf_map[0, :, :, 0]*255, dtype=np.uint8) frame_rgb 
= np.uint8(save_out_img) frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR) video_writer.write(frame_bgr) # save sample and gt imgs new_im_gt.paste(Image.fromarray(save_tar_img, 'RGB'), (0, 0)) new_im_sample.paste(Image.fromarray(save_out_img, 'RGB'), (0, 0)) new_im_arr_gt = np.array(new_im_gt) new_im_arr_sample = np.array(new_im_sample) new_im_name = "%03d_%s.png" % (frame_idx, real_names[batch_idx]) imageio.imsave(os.path.join(cur_img_dir_gt,new_im_name), new_im_arr_gt) imageio.imsave(os.path.join(cur_img_dir_samp,new_im_name), new_im_arr_sample) # new_im = Image.new('RGB', (msk_size * 5, msk_size)) # new_im.paste(Image.fromarray(save_tar_img, 'RGB'), (0, 0)) # new_im.paste(Image.fromarray(save_out_img, 'RGB'), (msk_size, 0)) # new_im.paste(Image.fromarray(save_warped_img, 'RGB'), (msk_size * 2, 0)) # new_im.paste(Image.fromarray(save_warped_grid), (msk_size * 3, 0)) # new_im.paste(Image.fromarray(save_conf_map, "L"), (msk_size * 4, 0)) # new_im_list.append(new_im) # video_name = "%04d_%s.gif" % (cnt, real_names[batch_idx]) # imageio.mimsave(os.path.join(CKPT_DIR, video_name), new_im_list) cnt += 1 video_writer.release() cmd = ('ffmpeg -y ' + ' -i {0} -i {1} -vcodec copy -ac 2 -channel_layout stereo -pix_fmt yuv420p {2} -shortest'.format( output_wav_path, tmp_video_file_pred.name, SAV_DIR)).split() call(cmd) try: os.remove(tmp_video_file_pred.name) os.remove(output_wav_path) except OSError as e: print(f'Error: {e.strerror}') iter_end = timeit.default_timer() if global_iter % args.print_freq == 0: print('Test:[{0}/{1}]\t' 'Time {batch_time.val:.3f}({batch_time.avg:.3f})' .format(global_iter, NUM_ITER, batch_time=batch_time)) global_iter += 1 print("loss for prediction: %.5f" % (out_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3))) print("loss for warping: %.5f" % (warp_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3))) res_dict = {} res_dict["out_loss"] = out_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3) res_dict["warp_loss"] = warp_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3) with open(json_path, "w") as f: json.dump(res_dict, f) end = timeit.default_timer() print(end - start, 'seconds') class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True if __name__ == '__main__': main() ================================================ FILE: LFG/test_flowautoenc_hdtf_video_256.py ================================================ # use LFG to reconstruct testing videos and measure the loss in video domain # using RegionMM import argparse import imageio import torch from torch.utils import data import numpy as np import torch.backends.cudnn as cudnn import os import timeit from PIL import Image import sys sys.path.append("your/path/DAWN-pytorch") from misc import grid2fig from DM.datasets_hdtf_wpose_lmk_mo_block_mraa import HDTF_test as HDTF import random from LFG.modules.flow_autoenc import FlowAE import torch.nn.functional as F from LFG.modules.util import Visualizer import json_tricks as json import cv2 import tempfile from subprocess import call from pydub import AudioSegment from einops import rearrange from tqdm import tqdm start = timeit.default_timer() BATCH_SIZE = 1 INPUT_SIZE = 256 root_dir = 
'your/path/DAWN-pytorch/AE' # your work directory data_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/images_25hz_256_chunk" pose_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar_chunk" eye_blink_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/eye_blink_bbox_from_xpc_bar_2_chunk" DATASAVE_DIR = '/train20/intern/permanent/hbcheng2/data' CKPT_DIR = os.path.join(DATASAVE_DIR, 'mraa_result', str(INPUT_SIZE) + '_400ep','video') os.makedirs(CKPT_DIR, exist_ok=True) IMG_DIR = os.path.join(DATASAVE_DIR, 'mraa_result', str(INPUT_SIZE) + '_400ep','img') os.makedirs(IMG_DIR, exist_ok=True) # GPU = "6" postfix = "" N_FRAMES = 40 NUM_VIDEOS = 10 SAVE_VIDEO = True NUM_ITER = NUM_VIDEOS // BATCH_SIZE RANDOM_SEED = 1234 MEAN = (0.0, 0.0, 0.0) # the path to trained LFG model RESTORE_FROM ='your/path/DAWN-pytorch/AE/data/log-hdtf-256-cosin/hdtf256_400ep_2024-08-08_00:15/snapshots/RegionMM.pth' # RESTORE_FROM = "/train20/intern/permanent/lmlin2/Flow/CVPR23_LFDM-main/data/log-hdtf/hdtf256_2023-11-21_16:49/snapshots/RegionMM_0020_S080000.pth" config_pth = "your/path/DAWN-pytorch/config/hdtf256.yaml" json_path = os.path.join(CKPT_DIR, "loss%d%s.json" % (NUM_VIDEOS, postfix)) visualizer = Visualizer() print(root_dir) print(postfix) print("RESTORE_FROM:", RESTORE_FROM) print("config_path:", config_pth) print(json_path) print("save video:", SAVE_VIDEO) def get_arguments(): """Parse all the arguments provided from the CLI. Returns: A list of parsed arguments. """ parser = argparse.ArgumentParser(description="Flow Autoencoder") parser.add_argument("--num-workers", default=8) parser.add_argument("--gpu", default=0, help="choose gpu device.") parser.add_argument('--print-freq', '-p', default=1, type=int, metavar='N', help='print frequency') parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Number of images sent to the network in one step.") parser.add_argument("--input-size", type=str, default=INPUT_SIZE, help="Comma-separated string with height and width of images.") parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, help="Random seed to have reproducible results.") parser.add_argument("--restore-from", default=RESTORE_FROM) parser.add_argument("--fp16", default=False) return parser.parse_args() args = get_arguments() def extract_audio_by_frames(input_wav_path, start_frame_index, num_frames, frame_rate, output_wav_path): # audio = AudioSegment.from_wav(input_wav_path) # frame_duration = 1000 / frame_rate # # start_time_ms = start_frame_index * frame_duration end_time_ms = (start_frame_index + num_frames) * frame_duration # selected_audio = audio[start_time_ms:end_time_ms] # selected_audio.export(output_wav_path, format="wav") def sample_img(rec_img_batch): rec_img = rec_img_batch.permute(1, 2, 0).data.cpu().numpy().copy() rec_img += np.array(MEAN)/255.0 rec_img[rec_img < 0] = 0 rec_img[rec_img > 1] = 1 rec_img *= 255 return np.array(rec_img, np.uint8) def main(): """Create the model and start the training.""" os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) cudnn.enabled = True cudnn.benchmark = True setup_seed(args.random_seed) model = FlowAE(is_train=False, config_pth=config_pth) model.cuda() if os.path.isfile(args.restore_from): print("=> loading checkpoint '{}'".format(args.restore_from)) checkpoint = torch.load(args.restore_from) model.generator.load_state_dict(checkpoint['generator']) model.region_predictor.load_state_dict(checkpoint['region_predictor']) model.bg_predictor.load_state_dict(checkpoint['bg_predictor']) print("=> loaded checkpoint 
'{}'".format(args.restore_from)) else: print("=> no checkpoint found at '{}'".format(args.restore_from)) exit(-1) model.eval() setup_seed(args.random_seed) testloader = data.DataLoader(HDTF(data_dir=data_dir, pose_dir=pose_dir, eye_blink_dir = eye_blink_dir, image_size=INPUT_SIZE, mode='test', max_num_frames=1e8, color_jitter=True, mean=MEAN), batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True) batch_time = AverageMeter() data_time = AverageMeter() iter_end = timeit.default_timer() cnt = 0 out_loss = 0.0 warp_loss = 0.0 num_sample = 0.0 l1_loss = torch.nn.L1Loss(reduction='sum') global_iter = 0 while global_iter < NUM_ITER: for i_iter, batch in enumerate(testloader): # if i_iter < NUM_ITER: # break # if global_iter < NUM_ITER: # break data_time.update(timeit.default_timer() - iter_end) block_path_list, real_names, total_num_frames = batch out_img_list = [] # use first frame of each video as reference frame ref_path = block_path_list[0][0] ref_imgs = np.load(ref_path) # 25, 256, 256, 3 ref_imgs = torch.tensor(ref_imgs).permute(0, 3, 1, 2) ref_imgs = ref_imgs[0].clone().detach().to(torch.float32)/255. ref_img_tmp = ref_imgs.repeat(25,1,1,1).reshape(-1, 3, INPUT_SIZE, INPUT_SIZE) msk_size = ref_imgs.shape[-1] new_im_list = [] img_dir_name = "%04d_%s" % (i_iter, real_names[0]) cur_img_dir_gt = os.path.join(IMG_DIR, img_dir_name,'gt') os.makedirs(cur_img_dir_gt, exist_ok=True) cur_img_dir_samp = os.path.join(IMG_DIR, img_dir_name,'mraa') os.makedirs(cur_img_dir_samp, exist_ok=True) fps = 25 # tmp_video_file_pred = tempfile.NamedTemporaryFile('w', suffix='.mp4', dir='your/path/DAWN-pytorch/demo') output_wav_path = tempfile.NamedTemporaryFile('w', suffix='.wav', dir='your/path/DAWN-pytorch/demo').name fourcc = cv2.VideoWriter_fourcc(*'mp4v') video_writer = cv2.VideoWriter(tmp_video_file_pred.name, fourcc, fps, (INPUT_SIZE, INPUT_SIZE)) SAV_DIR = os.path.join(CKPT_DIR, str(i_iter)+'_'+real_names[0] + '.mp4') wav_path = os.path.join("/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz".replace('/images_25hz','/image_audio'), real_names[0]+'.wav') extract_audio_by_frames(wav_path, 0, total_num_frames, fps, output_wav_path) batch_time.update(timeit.default_timer() - iter_end) new_im_gt = Image.new('RGB', (msk_size, msk_size)) new_im_sample = Image.new('RGB', (msk_size, msk_size)) frame_cnt = 0 for id in range(len(block_path_list)): block_path = block_path_list[id][0] real_vids = np.load(block_path) if real_vids.shape[0] !=ref_img_tmp.shape[0]: ref_img_tmp = ref_imgs.repeat(real_vids.shape[0],1,1,1).reshape(-1, 3, INPUT_SIZE, INPUT_SIZE) real_vids = torch.tensor(real_vids).permute(0, 3, 1, 2) # 25, 256, 256, 3 - > 25, 3, 256, 256 dri_imgs = real_vids.to(torch.float32)/255. 
with torch.no_grad(): model.set_train_input(ref_img=ref_img_tmp, dri_img=dri_imgs) model.forward() out_img_tensor = (model.generated['prediction'] * 255.).to(torch.uint8).clone().detach().cpu() # save real_vids for i in range(real_vids.shape[0]): save_tar_img = np.array(real_vids[i].permute(1, 2, 0), np.uint8) save_out_img = np.array(out_img_tensor[i].permute(1, 2, 0), np.uint8) frame_rgb = np.uint8(save_out_img) frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR) video_writer.write(frame_bgr) new_im_gt.paste(Image.fromarray(save_tar_img, 'RGB'), (0, 0)) new_im_sample.paste(Image.fromarray(save_out_img, 'RGB'), (0, 0)) new_im_arr_gt = np.array(new_im_gt) new_im_arr_sample = np.array(new_im_sample) new_im_name = "%03d_%s.png" % (frame_cnt, real_names[0]) imageio.imsave(os.path.join(cur_img_dir_gt,new_im_name), new_im_arr_gt) imageio.imsave(os.path.join(cur_img_dir_samp,new_im_name), new_im_arr_sample) frame_cnt += 1 # out_loss += l1_loss(real_vids.permute(2, 0, 1, 3, 4).cpu(), out_img_list_tensor.cpu()).item() # warp_loss += l1_loss(real_vids.permute(2, 0, 1, 3, 4).cpu(), warped_img_list_tensor.cpu()).item() video_writer.release() cmd = ('ffmpeg -y ' + ' -i {0} -i {1} -vcodec copy -ac 2 -channel_layout stereo -pix_fmt yuv420p {2} -shortest'.format( output_wav_path, tmp_video_file_pred.name, SAV_DIR)).split() call(cmd) try: os.remove(tmp_video_file_pred.name) os.remove(output_wav_path) except OSError as e: print(f'Error: {e.strerror}') iter_end = timeit.default_timer() if global_iter % args.print_freq == 0: print('Test:[{0}/{1}]\t' 'Time {batch_time.val:.3f}({batch_time.avg:.3f})' .format(global_iter, NUM_ITER, batch_time=batch_time)) global_iter += 1 print("loss for prediction: %.5f" % (out_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3))) print("loss for warping: %.5f" % (warp_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3))) res_dict = {} res_dict["out_loss"] = out_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3) res_dict["warp_loss"] = warp_loss/(num_sample*INPUT_SIZE*INPUT_SIZE*3) with open(json_path, "w") as f: json.dump(res_dict, f) end = timeit.default_timer() print(end - start, 'seconds') class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True if __name__ == '__main__': main() ================================================ FILE: LFG/train.py ================================================ # train a LFAE # this code is based on RegionMM (MRAA): https://github.com/snap-research/articulated-animation import os.path import torch from torch.utils.data import DataLoader from modules.model import ReconstructionModel from torch.optim.lr_scheduler import MultiStepLR from sync_batchnorm import DataParallelWithCallback from frames_dataset import DatasetRepeater import timeit from modules.util import Visualizer import imageio import math class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def train(config, generator, region_predictor, bg_predictor, checkpoint, 
log_dir, dataset, device_ids): train_params = config['train_params'] optimizer = torch.optim.Adam(list(generator.parameters()) + list(region_predictor.parameters()) + list(bg_predictor.parameters()), lr=train_params['lr'], betas=(0.5, 0.999)) start_epoch = 0 start_step = 0 if checkpoint is not None: ckpt = torch.load(checkpoint) if config["set_start"]: start_step = int(math.ceil(ckpt['example'] / config['train_params']['batch_size'])) start_epoch = ckpt['epoch'] generator.load_state_dict(ckpt['generator']) region_predictor.load_state_dict(ckpt['region_predictor']) bg_predictor.load_state_dict(ckpt['bg_predictor']) if 'optimizer' in list(ckpt.keys()): try: optimizer.load_state_dict(ckpt['optimizer']) except: optimizer.load_state_dict(ckpt['optimizer'].state_dict()) # scheduler = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1, last_epoch=start_epoch - 1) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_params["max_epochs"], eta_min=2e-6) if 'num_repeats' in train_params or train_params['num_repeats'] != 1: dataset = DatasetRepeater(dataset, train_params['num_repeats']) dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=8, drop_last=True) model = ReconstructionModel(region_predictor, bg_predictor, generator, train_params) visualizer = Visualizer(**config['visualizer_params']) if torch.cuda.is_available(): if ('use_sync_bn' in train_params) and train_params['use_sync_bn']: model = DataParallelWithCallback(model, device_ids=device_ids) else: model = torch.nn.DataParallel(model, device_ids=device_ids) # rewritten by nhm batch_time = AverageMeter() data_time = AverageMeter() total_losses = AverageMeter() losses_perc = AverageMeter() losses_equiv_shift = AverageMeter() losses_equiv_affine = AverageMeter() cnt = 0 epoch_cnt = start_epoch actual_step = start_step final_step = config["num_step_per_epoch"] * train_params["max_epochs"] while actual_step < final_step: iter_end = timeit.default_timer() for i_iter, x in enumerate(dataloader): actual_step = int(start_step + cnt) data_time.update(timeit.default_timer() - iter_end) optimizer.zero_grad() losses, generated = model(x) loss_values = [val.mean() for val in losses.values()] loss = sum(loss_values) loss.backward() optimizer.step() batch_time.update(timeit.default_timer() - iter_end) iter_end = timeit.default_timer() bs = x['source'].size(0) total_losses.update(loss.item(), bs) losses_perc.update(loss_values[0].item(), bs) losses_equiv_shift.update(loss_values[1].item(), bs) losses_equiv_affine.update(loss_values[2].item(), bs) if actual_step % train_params["print_freq"] == 0: print('iter: [{0}]{1}/{2}\t' 'loss {loss.val:.4f} ({loss.avg:.4f})\t' 'loss_perc {loss_perc.val:.4f} ({loss_perc.avg:.4f})\n' 'loss_shift {loss_shift.val:.4f} ({loss_shift.avg:.4f})\t' 'loss_affine {loss_affine.val:.4f} ({loss_affine.avg:.4f})' .format( cnt, actual_step, final_step, loss=total_losses, loss_perc=losses_perc, loss_shift=losses_equiv_shift, loss_affine=losses_equiv_affine )) if actual_step % train_params['save_img_freq'] == 0: save_image = visualizer.visualize(x['driving'], x['source'], generated, index=0) save_name = 'B' + format(train_params["batch_size"], "04d") + '_S' + format(actual_step, "06d") \ + '_' + x["frame"][0][0][:-4] + '_to_' + x["frame"][1][0][-7:] save_file = os.path.join(config["imgshots"], save_name) imageio.imsave(save_file, save_image) if actual_step % config["save_ckpt_freq"] == 0 and cnt != 0: print('taking snapshot...') torch.save({'example': 
actual_step * train_params["batch_size"], 'epoch': epoch_cnt, 'generator': generator.state_dict(), 'bg_predictor': bg_predictor.state_dict(), 'region_predictor': region_predictor.state_dict(), 'optimizer': optimizer.state_dict()}, os.path.join(config["snapshots"], 'RegionMM_' + format(train_params["batch_size"], "04d") + '_S' + format(actual_step, "06d") + '.pth')) if actual_step % train_params["update_ckpt_freq"] == 0 and cnt != 0: print('updating snapshot...') torch.save({'example': actual_step * train_params["batch_size"], 'epoch': epoch_cnt, 'generator': generator.state_dict(), 'bg_predictor': bg_predictor.state_dict(), 'region_predictor': region_predictor.state_dict(), 'optimizer': optimizer.state_dict()}, os.path.join(config["snapshots"],'RegionMM.pth')) if actual_step >= final_step: break cnt += 1 scheduler.step() epoch_cnt += 1 # print lr print("epoch %d, lr= %.7f" % (epoch_cnt, optimizer.param_groups[0]["lr"])) print('save the final model...') torch.save({'example': actual_step * train_params["batch_size"], 'epoch': epoch_cnt, 'generator': generator.state_dict(), 'bg_predictor': bg_predictor.state_dict(), 'region_predictor': region_predictor.state_dict(), 'optimizer': optimizer.state_dict()}, os.path.join(config["snapshots"], 'RegionMM_' + format(train_params["batch_size"], "04d") + '_S' + format(actual_step, "06d") + '.pth')) ================================================ FILE: LFG/vis_flow.py ================================================ import torch import numpy as np import matplotlib.pyplot as plt def visualize_dense_optical_flow(flow_tensor, save_path): flow_np = flow_tensor.cpu().numpy() flow_tensor = flow_tensor + 1e-7 magnitude = np.sqrt(flow_np[0]**2 + flow_np[1]**2) # mask = magnitude > 1/64 magnitude = magnitude # * mask angle = np.arctan2(flow_np[1], flow_np[0]) angle = angle # * mask plt.figure() plt.imshow(magnitude, cmap='BuPu', alpha=0.8) plt.imshow(angle, cmap='hsv', alpha=0.2) plt.title('Dense Optical Flow') plt.axis('off') plt.savefig(save_path) plt.close() def grid2flow(warped_grid, grid_size=64, img_size=256): dpi = 1000 # plt.ioff() h_range = torch.linspace(-1, 1, grid_size) w_range = torch.linspace(-1, 1, grid_size) grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).flip(2) out = warped_grid - grid return out if __name__ == '__main__': dense_flow_tensor = torch.zeros(2, 100, 100) visualize_dense_optical_flow(dense_flow_tensor, 'test.jpg') ================================================ FILE: PBnet/run_cvae_h_ann_reemb_rope_eye_3.sh ================================================ source /home4/intern/lmlin2/.bashrc conda activate actor # crema rc delta pose export CUDA_VISIBLE_DEVICES="0" # python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/train/train_cvae.py\ # --num_frames 40\ # --lambda_kl 1\ # --lambda_ssim 1\ # --lambda_freq 1\ # --modelname cvae_transformer_ssim_kl_freq\ # --dataset hdtf\ # --num_epochs 10000\ # --folder exps_delta_pose/HDTF_nf40_kl1_ssim1_freq_128_w5_1w_6 python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/train/train_cvae_ganloss_ann_eye.py\ --num_frames 200\ --eye True\ --lr 0.0004 \ --batch_size 40\ --lambda_kl 0.004\ --lambda_reg 0.0005\ --lambda_rc 1\ --ff_size 128\ --max_distance 128\ --num_buckets 128\ --num_layers 2\ --audio_latent_dim 256\ --snapshot 10000\ --modelname cvae_transformerreemb8_rc_kl_reg\ --dataset hdtf\ --num_epochs 100000\ --folder exps_delta_pose_rope_eye/HDTF_b40_200_eye_kl4e3_lr4e-4_reg5e-4_rope16_3 # > output.log & # nohup python 
/train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/train/train_cvae_ganloss_first3.py\ # --num_frames 40\ # --batch_size 20\ # --lambda_kl 1\ # --lambda_rc 1\ # --num_layers 4\ # --modelname cvae_transformerold_kl_ssim\ # --dataset hdtf\ # --num_epochs 30000\ # --folder exps_delta_pose_f3/HDTF_l2_nf40_kl1_ssim_norm_w5_1w_b20_first_3 > output.log & ================================================ FILE: PBnet/src/__init__.py ================================================ ================================================ FILE: PBnet/src/config.py ================================================ import os SMPL_DATA_PATH = "models/smpl/" SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl") SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl") JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy') ================================================ FILE: PBnet/src/datasets/__init__.py ================================================ ================================================ FILE: PBnet/src/datasets/datasets_hdtf_pos_chunk_norm_2_fast.py ================================================ from os import name import sys sys.path.append('your_path') import os import random import torch import numpy as np import torch.utils.data as data import torch.nn.functional as Ft import imageio.v2 as imageio import cv2 import torchvision.transforms.functional as F import matplotlib.pyplot as plt from PIL import Image from scipy.interpolate import interp1d # import decord from torchvision.transforms.functional import to_pil_image from torchvision import transforms import time import pickle as pkl # decord.bridge.set_bridge('torch') def resize(im, desired_size, interpolation): old_size = im.shape[:2] ratio = float(desired_size)/max(old_size) new_size = tuple(int(x*ratio) for x in old_size) im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation) delta_w = desired_size - new_size[1] delta_h = desired_size - new_size[0] top, bottom = delta_h//2, delta_h-(delta_h//2) left, right = delta_w//2, delta_w-(delta_w//2) color = [0, 0, 0] new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) return new_im class HDTF(data.Dataset): def __init__(self, data_dir, max_num_frames=80, mode='train'): super(HDTF, self).__init__() self.data_dir = data_dir self.max_num_frames = max_num_frames self.mode = mode self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate_chunk' self.pose_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar_chunk' # self.max_vals = torch.tensor([20, 10, 10, 7e-4, # 7e+1, 9e+1]).to(torch.float32) # self.min_vals = torch.tensor([-20, -10, -10, 4e-4, # 5e+1, 6e+1]).to(torch.float32) self.max_vals = torch.tensor([90, 90, 90, 1, 720, 1080]).to(torch.float32) self.min_vals = torch.tensor([-90, -90, -90, 0, 0, 0]).to(torch.float32) vid_list = [] # hdtf vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\ 'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\ 'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\ 'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\ 'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\ 
'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\ 'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\ 'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\ 'WRA_MarcoRubio_000'] bad_id_name = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000'] # vid_id_name_list = [item + '.mp4' for item in vid_id_name_list] # bad_id_name = [item + '.mp4' for item in bad_id_name] with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f: self.len_dict = pkl.load(f) if mode == 'train': for id_name in os.listdir(data_dir): # id_name = id_name[:-4] if id_name in vid_id_name_list or id_name in bad_id_name: continue vid_list.append(id_name) self.videos = vid_list if mode == 'test': self.videos = vid_id_name_list def check_head(self, frame_list, video_name, start, end): start_path = self.get_pose_path(frame_list, video_name, start) end_path = self.get_pose_path(frame_list, video_name, end) if os.path.exists(start_path) and os.path.exists(end_path): return True else: return False def get_block_data_for_two(self, path, start, end): # TODO: id function ''' input: start: start id end: end id output: the data from block ''' block_st = start//25 block_ed = end//25 st_pos = block_st % 25 ed_pos = block_ed % 25 block_st_name = 'chunk_%04d.npy' % (block_st) block_ed_name = 'chunk_%04d.npy' % (block_ed) if block_st != block_ed: block_st_path = os.path.join(path, block_st_name) block_ed_path = os.path.join(path, block_ed_name) block_st = np.load(block_st_path) block_ed = np.load(block_ed_path) return np.concatenate((block_st[st_pos:], block_ed[:ed_pos])) else: block_st_path = os.path.join(path, block_st_name) block_st = np.load(block_st_path) return block_st[st_pos, ed_pos] def get_block_data(self, path, start, end): # TODO: id function ''' input: start: start id end: end id output: the data from block ''' block_st = start//25 block_ed = end//25 st_pos = start % 25 ed_pos = end % 25 block_list = [os.path.join(path,'chunk_%04d.npy' % (i)) for i in range(block_st, block_ed+1)] if block_st != block_ed: arr_list = [] block_st = np.load(block_list[0]) arr_list.append(block_st[st_pos:]) for path in block_list[1:-1]: arr_list.append(np.load(path)) block_ed = np.load(block_list[-1]) arr_list.append(block_ed[:ed_pos]) return np.concatenate(arr_list) else: block_st_path = os.path.join(path, block_list[0]) block_st = np.load(block_st_path) return block_st[st_pos: ed_pos] def check_len(self, name): return self.len_dict[name] def __len__(self): return len(self.videos) def __getitem__(self, idx): video_name = self.videos[idx] path = os.path.join(self.data_dir, video_name) hubert_path = os.path.join(self.hubert_dir, video_name) pose_path = os.path.join(self.pose_dir, video_name) # eye_blink_path = os.path.join(self.eye_blink_dir, video_name) total_num_frames = self.check_len(video_name) if total_num_frames <= self.max_num_frames: sample_frames = total_num_frames start = 0 else: sample_frames = self.max_num_frames start = np.random.randint(total_num_frames-self.max_num_frames) start=start stop=sample_frames+start sample_hubert_feature_npy = self.get_block_data(path = hubert_path, start = start, end = stop).astype(np.float32) sample_pose_list_npy = self.get_block_data(path = pose_path, start = start, end = stop).astype(np.float32) sample_hubert_feature_tensor = torch.tensor(sample_hubert_feature_npy) 
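        # (Editor's note) The pose features loaded above are min-max normalized just
        # below with self.min_vals / self.max_vals (the trailing column is dropped so
        # the six remaining channels line up with the six bounds):
        #     normed = (pose - min_vals) / (max_vals - min_vals)
        # e.g. a rotation channel bounded by [-90, 90] maps 45 -> (45 + 90) / 180 = 0.75;
        # the inverse, pose = normed * (max_vals - min_vals) + min_vals, recovers the
        # raw values.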
sample_pos_feature_tensor = torch.tensor(sample_pose_list_npy)[:,:-1] sample_pos_feature_tensor = (sample_pos_feature_tensor - self.min_vals)/ (self.max_vals - self.min_vals) video_name = video_name.replace('/','_') # sample_pose_list_npy = sample_pose_list_npy.transpose(1,0) # for compatibility return sample_hubert_feature_tensor, sample_pos_feature_tensor, video_name, start def update_parameters(self, parameters): _, self.pos_dim = self[0][1].shape _, self.audio_dim = self[0][0].shape parameters["audio_dim"] = self.audio_dim parameters["pos_dim"] = self.pos_dim # parameters["njoints"] = self.njoints if __name__ == "__main__": # hdtf data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" # pose_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/pose" # crema # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images' dataset = HDTF(data_dir=data_dir, mode='train') for i in range(10): dataset.__getitem__(i) print('------') test_dataset = data.DataLoader(dataset=dataset, batch_size=10, num_workers=8, shuffle=False) for i, batch in enumerate(test_dataset): print(i) ================================================ FILE: PBnet/src/datasets/datasets_hdtf_pos_chunk_norm_eye_fast.py ================================================ from os import name import sys # sys.path.append('your_path') import os import random import torch import numpy as np import torch.utils.data as data import torch.nn.functional as Ft import imageio.v2 as imageio import cv2 import torchvision.transforms.functional as F import matplotlib.pyplot as plt from PIL import Image from scipy.interpolate import interp1d # import decord from torchvision.transforms.functional import to_pil_image from torchvision import transforms import time import pickle as pkl # decord.bridge.set_bridge('torch') def resize(im, desired_size, interpolation): old_size = im.shape[:2] ratio = float(desired_size)/max(old_size) new_size = tuple(int(x*ratio) for x in old_size) im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation) delta_w = desired_size - new_size[1] delta_h = desired_size - new_size[0] top, bottom = delta_h//2, delta_h-(delta_h//2) left, right = delta_w//2, delta_w-(delta_w//2) color = [0, 0, 0] new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) return new_im class HDTF(data.Dataset): def __init__(self, data_dir, max_num_frames=80, mode='train'): super(HDTF, self).__init__() self.data_dir = data_dir self.max_num_frames = max_num_frames self.mode = mode self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate' #hdtf hubert # self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wavlm_interpolate_chunk' self.pose_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar' self.eye_blink_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/eye_blink_bbox_from_xpc_bar' # self.max_vals = torch.tensor([20, 10, 10, 7e-4, # 7e+1, 9e+1]).to(torch.float32) # self.min_vals = torch.tensor([-20, -10, -10, 4e-4, # 5e+1, 6e+1]).to(torch.float32) self.max_vals = torch.tensor([90, 90, 90, 1, 720, 1080]).to(torch.float32) self.min_vals = torch.tensor([-90, -90, -90, 0, 0, 0]).to(torch.float32) vid_list = [] # hdtf vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\ 'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\ 
'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\ 'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\ 'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\ 'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\ 'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\ 'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\ 'WRA_MarcoRubio_000'] bad_id_name = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000'] # vid_id_name_list = [item + '.mp4' for item in vid_id_name_list] # bad_id_name = [item + '.mp4' for item in bad_id_name] with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f: self.len_dict = pkl.load(f) if mode == 'train': for id_name in os.listdir(data_dir): # id_name = id_name[:-4] if id_name in vid_id_name_list or id_name in bad_id_name: continue vid_list.append(id_name) self.videos = vid_list if mode == 'test': self.videos = vid_id_name_list self.cache_audio = {} self.cache_eye = {} self.cache_pose = {} for video in self.videos: hubert_path = os.path.join(self.hubert_dir, video) + '.npy' pose_path = os.path.join(self.pose_dir, video) + '.npy' eye_blink_path = os.path.join(self.eye_blink_dir, video) + '.npy' hubert_fea = np.load(hubert_path) pose_fea = np.load(pose_path) blink_fea = np.load(eye_blink_path) self.cache_audio[video] = hubert_fea self.cache_pose[video] = pose_fea self.cache_eye[video] = blink_fea def check_head(self, frame_list, video_name, start, end): start_path = self.get_pose_path(frame_list, video_name, start) end_path = self.get_pose_path(frame_list, video_name, end) if os.path.exists(start_path) and os.path.exists(end_path): return True else: return False def get_block_data_for_two(self, path, start, end): # TODO: id function ''' input: start: start id end: end id output: the data from block ''' block_st = start//25 block_ed = end//25 st_pos = block_st % 25 ed_pos = block_ed % 25 block_st_name = 'chunk_%04d.npy' % (block_st) block_ed_name = 'chunk_%04d.npy' % (block_ed) if block_st != block_ed: block_st_path = os.path.join(path, block_st_name) block_ed_path = os.path.join(path, block_ed_name) block_st = np.load(block_st_path) block_ed = np.load(block_ed_path) return np.concatenate((block_st[st_pos:], block_ed[:ed_pos])) else: block_st_path = os.path.join(path, block_st_name) block_st = np.load(block_st_path) return block_st[st_pos, ed_pos] def get_block_data(self, path, start, end): # TODO: id function ''' input: start: start id end: end id output: the data from block ''' block_st = start//25 block_ed = end//25 st_pos = start % 25 ed_pos = end % 25 block_list = [os.path.join(path,'chunk_%04d.npy' % (i)) for i in range(block_st, block_ed+1)] if block_st != block_ed: arr_list = [] block_st = np.load(block_list[0]) arr_list.append(block_st[st_pos:]) for path in block_list[1:-1]: arr_list.append(np.load(path)) block_ed = np.load(block_list[-1]) arr_list.append(block_ed[:ed_pos]) return np.concatenate(arr_list) else: block_st_path = os.path.join(path, block_list[0]) block_st = np.load(block_st_path) return block_st[st_pos: ed_pos] def check_len(self, name): return self.len_dict[name] def __len__(self): return len(self.videos) def __getitem__(self, 
idx): video_name = self.videos[idx] # path = os.path.join(self.data_dir, video_name) # hubert_path = os.path.join(self.hubert_dir, video_name) # pose_path = os.path.join(self.pose_dir, video_name) # eye_blink_path = os.path.join(self.eye_blink_dir, video_name) total_num_frames = self.check_len(video_name) if total_num_frames <= self.max_num_frames: sample_frames = total_num_frames start = 0 else: sample_frames = self.max_num_frames start = np.random.randint(total_num_frames-self.max_num_frames) start=start stop=sample_frames+start start_time = time.time() # sample_hubert_feature_npy = self.get_block_data(path = hubert_path, start = start, end = stop).astype(np.float32) # sample_pose_list_npy = self.get_block_data(path = pose_path, start = start, end = stop).astype(np.float32) sample_hubert_feature_npy = self.cache_audio[video_name][start:stop].astype(np.float32) sample_pose_list_npy = self.cache_pose[video_name][start:stop].astype(np.float32) sample_eye_blink_list_npy = self.cache_eye[video_name][start:stop].astype(np.float32) # end_time = time.time() # print("dataset_audiopose_cost: ", - start_time + end_time) # start_time = time.time() # sample_eye_blink_list_npy = self.get_block_data(path = eye_blink_path, start = start, end = stop).astype(np.float32) # end_time = time.time() # print("dataset_eye_cost: ", - start_time + end_time) # start_time = time.time() sample_hubert_feature_tensor = torch.tensor(sample_hubert_feature_npy) sample_pos_feature_tensor = torch.tensor(sample_pose_list_npy)[:,:-1] sample_pos_feature_tensor = (sample_pos_feature_tensor - self.min_vals)/ (self.max_vals - self.min_vals) # end_time = time.time() # print("dataset_audiopose_cost2: ", - start_time + end_time) # start_time = time.time() sample_eye_feature_tensor = torch.tensor(sample_eye_blink_list_npy)[:,:2] # end_time = time.time() # print("dataset_eye_cost2: ", - start_time + end_time) # start_time = time.time() sample_pos_eye_cat_tensor = torch.cat((sample_pos_feature_tensor,sample_eye_feature_tensor),dim=1) # end_time = time.time() # print("dataset_eye_cost3: ", - start_time + end_time) # start_time = time.time() video_name = video_name.replace('/','_') # sample_pose_list_npy = sample_pose_list_npy.transpose(1,0) # for compatibility return sample_hubert_feature_tensor, sample_pos_feature_tensor, sample_eye_feature_tensor, video_name, start, sample_pos_eye_cat_tensor def update_parameters(self, parameters): _, self.pos_dim = self[0][1].shape _, self.eye_dim = self[0][2].shape _, self.audio_dim = self[0][0].shape parameters["audio_dim"] = self.audio_dim parameters["pos_dim"] = self.pos_dim parameters["eye_dim"] = self.eye_dim # parameters["njoints"] = self.njoints if __name__ == "__main__": # hdtf data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" # pose_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/pose" # crema # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images' dataset = HDTF(data_dir=data_dir, mode='train') for i in range(10): dataset.__getitem__(i) print('------') test_dataset = data.DataLoader(dataset=dataset, batch_size=10, num_workers=8, shuffle=False) for i, batch in enumerate(test_dataset): print(i) ================================================ FILE: PBnet/src/datasets/datasets_hdtf_pos_df.py ================================================ from os import name from src.datasets.datasets_hdtf_pos import HDTF import sys sys.path.append('your_path') import os import random import torch import numpy as np import torch.utils.data as data import imageio.v2 
as imageio import cv2 import torchvision.transforms.functional as F import matplotlib.pyplot as plt from PIL import Image from scipy.interpolate import interp1d # from ..utils.tensors import collate def resize(im, desired_size, interpolation): old_size = im.shape[:2] ratio = float(desired_size)/max(old_size) new_size = tuple(int(x*ratio) for x in old_size) im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation) delta_w = desired_size - new_size[1] delta_h = desired_size - new_size[0] top, bottom = delta_h//2, delta_h-(delta_h//2) left, right = delta_w//2, delta_w-(delta_w//2) color = [0, 0, 0] new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) return new_im class HDTF(data.Dataset): def __init__(self, data_dir, max_num_frames=80, min_num_frames=40, mode='train'): super(HDTF, self).__init__() self.data_dir = data_dir self.max_num_frames = max_num_frames self.min_num_frames = min_num_frames self.mode = mode # self.hubert_dir = '/train20/intern/permanent/lmlin2/data/crema_wav_hubert' # self.pose_dir = '/train20/intern/permanent/hbcheng2/data/crema/pose' self.pose_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar' self.hubert_dir = '/train20/intern/permanent/lmlin2/data/hdtf_wav_hubert' vid_list = [] # # crema # if mode == 'train': # for id_name in os.listdir(data_dir): # if id_name in ['s15','s20','s21','s30','s33','s52','s62','s81','s82','s89']: #['s64','s76','s88','s90','s91'] # continue # vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ]) # # crema # if mode == 'test': # for id_name in ['s15','s20','s21','s30','s33','s52','s62','s81','s82','s89']: # vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ]) # hdtf vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\ 'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\ 'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\ 'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\ 'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\ 'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\ 'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\ 'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\ 'WRA_MarcoRubio_000'] bad_id_name_list = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000', 'RD_Radio39_000'] if mode == 'train': for id_name in os.listdir(data_dir): if id_name in vid_id_name_list or id_name in bad_id_name_list: continue vid_list.append(id_name) self.videos = vid_list if mode == 'test': self.videos = vid_id_name_list # def __len__(self): # return len(self.videos) def __len__(self): num_seq_max = getattr(self, "num_seq_max", -1) if num_seq_max == -1: from math import inf num_seq_max = inf return min(len(self.videos), num_seq_max) def __getitem__(self, idx): video_name = self.videos[idx] path = os.path.join(self.data_dir, video_name) # path_pose = os.path.join(self.pose_dir, video_name) frame_path_list = os.listdir(path) frame_path_list.sort() total_num_frames = 
len(frame_path_list) # pose_path_list = os.listdir(path_pose) # pose_path_list.sort() hubert_path = os.path.join(self.hubert_dir, video_name+'.npy') hubert_feature = np.load(hubert_path) Nframes_hubert = hubert_feature.shape[0] interp_func = interp1d(np.arange(Nframes_hubert), hubert_feature, kind='linear', axis=0) hubert_feature = interp_func(np.linspace(0, Nframes_hubert - 1, total_num_frames)).astype(np.float32) pose_path = os.path.join(self.pose_dir, video_name+'.npy') pose_seq = np.load(pose_path).astype(np.float32) cur_num_frames = np.random.randint(self.min_num_frames, self.max_num_frames+1) if total_num_frames <= cur_num_frames: sample_frames = total_num_frames start = 0 else: sample_frames = cur_num_frames start = np.random.randint(total_num_frames-cur_num_frames) sample_idx_list = np.linspace(start=start, stop=sample_frames+start-1, num=sample_frames, dtype=int) # sample_frame_path_list = [frame_path_list[x] for x in sample_idx_list] # sample_pose_path_list = [pose_path_list[x] for x in sample_idx_list] sample_hubert_feature_list = [hubert_feature[x,:] for x in sample_idx_list] # nf,1024 sample_hubert_feature_tensor = [torch.from_numpy(arr) for arr in sample_hubert_feature_list] sample_hubert_feature_tensor = torch.stack(sample_hubert_feature_tensor) # sample_hubert_feature_list = np.stack(sample_hubert_feature_list).reshape(-1) # (nf*1024) # load pose try: # sample_pose_list = [np.load(os.path.join(path_pose, x))[0][:-1].astype(np.float32) for x in sample_pose_path_list] sample_pose_list = [pose_seq[x,:] for x in sample_idx_list] sample_pos_feature_tensor = [torch.from_numpy(arr) for arr in sample_pose_list] sample_pos_feature_tensor = torch.stack(sample_pos_feature_tensor) # nf, 6 except Exception: # print(os.path.join(path_pose, x)) print("load fail !! 
") print(pose_path) print(sample_idx_list) sample_pose_list = [pose_seq[x,:] for x in sample_idx_list] sample_pos_feature_tensor = [torch.from_numpy(arr) for arr in sample_pose_list] sample_pos_feature_tensor = torch.stack(sample_pos_feature_tensor) # nf, 6 # added to change the video_name of crema video_name = video_name.replace('/','_') # sample_class_tensor = torch.tensor(0) return sample_hubert_feature_tensor, sample_pos_feature_tensor, video_name def update_parameters(self, parameters): _, self.pos_dim = self[0][1].shape _, self.audio_dim = self[0][0].shape parameters["audio_dim"] = self.audio_dim parameters["pos_dim"] = self.pos_dim # parameters["njoints"] = self.njoints if __name__ == "__main__": data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images' dataset = HDTF(data_dir=data_dir,mode='test') for i in range(100): dataset.__getitem__(i) print('------') test_dataset = data.DataLoader(dataset=dataset, batch_size=10, num_workers=8, shuffle=False) for i, batch in enumerate(test_dataset): print(i) ================================================ FILE: PBnet/src/datasets/datasets_hdtf_pos_dict_norm_2.py ================================================ from os import name import sys sys.path.append('your_path') import os import random import torch import numpy as np import torch.utils.data as data import torch.nn.functional as Ft import imageio.v2 as imageio import cv2 import torchvision.transforms.functional as F import matplotlib.pyplot as plt from PIL import Image from scipy.interpolate import interp1d # import decord from torchvision.transforms.functional import to_pil_image from torchvision import transforms import time import pickle as pkl # decord.bridge.set_bridge('torch') def resize(im, desired_size, interpolation): old_size = im.shape[:2] ratio = float(desired_size)/max(old_size) new_size = tuple(int(x*ratio) for x in old_size) im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation) delta_w = desired_size - new_size[1] delta_h = desired_size - new_size[0] top, bottom = delta_h//2, delta_h-(delta_h//2) left, right = delta_w//2, delta_w-(delta_w//2) color = [0, 0, 0] new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) return new_im class HDTF(data.Dataset): def __init__(self, data_dir, max_num_frames=80, mode='train'): super(HDTF, self).__init__() self.data_dir = data_dir self.max_num_frames = max_num_frames self.mode = mode self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate' self.pose_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/pose_bar' # self.max_vals = torch.tensor([20, 10, 10, 7e-4, # 7e+1, 9e+1]).to(torch.float32) # self.min_vals = torch.tensor([-20, -10, -10, 4e-4, # 5e+1, 6e+1]).to(torch.float32) self.max_vals = torch.tensor([90, 90, 90, 1, 720, 1080]).to(torch.float32).cuda() self.min_vals = torch.tensor([-90, -90, -90, 0, 0, 0]).to(torch.float32).cuda() vid_list = [] # hdtf vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\ 'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\ 'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\ 'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\ 
'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\ 'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\ 'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\ 'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\ 'WRA_MarcoRubio_000'] bad_id_name = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000'] # vid_id_name_list = [item + '.mp4' for item in vid_id_name_list] # bad_id_name = [item + '.mp4' for item in bad_id_name] with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f: self.len_dict = pkl.load(f) if mode == 'train': for id_name in os.listdir(data_dir): # id_name = id_name[:-4] if id_name in vid_id_name_list or id_name in bad_id_name: continue vid_list.append(id_name) self.videos = vid_list if mode == 'test': self.videos = vid_id_name_list self.audio_dict = {} self.pose_dict = {} for video_name in self.videos: hubert_path = os.path.join(self.hubert_dir, video_name + '.npy') pose_path = os.path.join(self.pose_dir, video_name + '.npy') hubert_npy = torch.tensor(np.load(hubert_path).astype(np.float32)) pose_npy = torch.tensor(np.load(pose_path).astype(np.float32)) self.audio_dict[video_name] = hubert_npy self.pose_dict[video_name] = pose_npy def check_head(self, frame_list, video_name, start, end): start_path = self.get_pose_path(frame_list, video_name, start) end_path = self.get_pose_path(frame_list, video_name, end) if os.path.exists(start_path) and os.path.exists(end_path): return True else: return False def get_block_data_for_two(self, path, start, end): # TODO: id function ''' input: start: start id end: end id output: the data from block ''' block_st = start//25 block_ed = end//25 st_pos = block_st % 25 ed_pos = block_ed % 25 block_st_name = 'chunk_%04d.npy' % (block_st) block_ed_name = 'chunk_%04d.npy' % (block_ed) if block_st != block_ed: block_st_path = os.path.join(path, block_st_name) block_ed_path = os.path.join(path, block_ed_name) block_st = np.load(block_st_path) block_ed = np.load(block_ed_path) return np.concatenate((block_st[st_pos:], block_ed[:ed_pos])) else: block_st_path = os.path.join(path, block_st_name) block_st = np.load(block_st_path) return block_st[st_pos, ed_pos] def get_block_data(self, path, start, end): # TODO: id function ''' input: start: start id end: end id output: the data from block ''' block_st = start//25 block_ed = end//25 st_pos = start % 25 ed_pos = end % 25 block_list = [os.path.join(path,'chunk_%04d.npy' % (i)) for i in range(block_st, block_ed+1)] if block_st != block_ed: arr_list = [] block_st = np.load(block_list[0]) arr_list.append(block_st[st_pos:]) for path in block_list[1:-1]: arr_list.append(np.load(path)) block_ed = np.load(block_list[-1]) arr_list.append(block_ed[:ed_pos]) return np.concatenate(arr_list) else: block_st_path = os.path.join(path, block_list[0]) block_st = np.load(block_st_path) return block_st[st_pos: ed_pos] def check_len(self, name): return self.len_dict[name] def __len__(self): return len(self.videos) def __getitem__(self, idx): # if __debug__: video_name = self.videos[idx] # path = os.path.join(self.data_dir, video_name) # hubert_path = os.path.join(self.hubert_dir, video_name) # pose_path = os.path.join(self.pose_dir, video_name) # eye_blink_path = os.path.join(self.eye_blink_dir, video_name) total_num_frames = 
self.check_len(video_name) if total_num_frames <= self.max_num_frames: sample_frames = total_num_frames start = 0 else: sample_frames = self.max_num_frames start = np.random.randint(total_num_frames-self.max_num_frames) start=start stop=sample_frames+start # end_time = time.time() # print("indexing: ", - start_time + end_time) # start_time = time.time() sample_hubert_feature_tensor = self.audio_dict[video_name][start:stop].cuda() # self.get_block_data(path = hubert_path, start = start, end = stop).astype(np.float32) sample_pos_feature_tensor = self.pose_dict[video_name][start:stop][:,:-1].cuda() # self.get_block_data(path = pose_path, start = start, end = stop).astype(np.float32) # end_time = time.time() # print("loading: ", - start_time + end_time) # start_time = time.time() # sample_hubert_feature_tensor = torch.tensor(sample_hubert_feature_npy) # sample_pos_feature_tensor = torch.tensor(sample_pose_list_npy)[:,:-1] # end_time = time.time() # print("converting: ", - start_time + end_time) # start_time = time.time() sample_pos_feature_tensor = (sample_pos_feature_tensor - self.min_vals)/ (self.max_vals - self.min_vals) video_name = video_name.replace('/','_') # end_time = time.time() # print("processing: ", - start_time + end_time) # start_time = time.time() # sample_pose_list_npy = sample_pose_list_npy.transpose(1,0) # for compatibility return sample_hubert_feature_tensor, sample_pos_feature_tensor, video_name, start def update_parameters(self, parameters): _, self.pos_dim = self[0][1].shape _, self.audio_dim = self[0][0].shape parameters["audio_dim"] = self.audio_dim parameters["pos_dim"] = self.pos_dim # parameters["njoints"] = self.njoints if __name__ == "__main__": # hdtf data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" # pose_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/pose" # crema # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images' dataset = HDTF(data_dir=data_dir, mode='train') for i in range(10): dataset.__getitem__(i) print('------') test_dataset = data.DataLoader(dataset=dataset, batch_size=10, num_workers=8, shuffle=False) for i, batch in enumerate(test_dataset): print(i) ================================================ FILE: PBnet/src/datasets/datasets_hdtf_wpose_lmk_block.py ================================================ from os import name import sys sys.path.append('your_path') import os import random import torch import numpy as np import torch.utils.data as data import torch.nn.functional as Ft import imageio.v2 as imageio import cv2 import torchvision.transforms.functional as F import matplotlib.pyplot as plt from PIL import Image from scipy.interpolate import interp1d import decord from torchvision.transforms.functional import to_pil_image from torchvision import transforms import time import pickle as pkl decord.bridge.set_bridge('torch') def resize(im, desired_size, interpolation): old_size = im.shape[:2] ratio = float(desired_size)/max(old_size) new_size = tuple(int(x*ratio) for x in old_size) im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation) delta_w = desired_size - new_size[1] delta_h = desired_size - new_size[0] top, bottom = delta_h//2, delta_h-(delta_h//2) left, right = delta_w//2, delta_w-(delta_w//2) color = [0, 0, 0] new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) return new_im class HDTF(data.Dataset): def __init__(self, data_dir, pose_dir, eye_blink_dir, max_num_frames=80, image_size=128,mode='train', mean=(128, 128, 128), 
color_jitter=True): super(HDTF, self).__init__() self.mean = torch.tensor(mean)[None,:,None,None] self.data_dir = data_dir self.pose_dir = pose_dir self.eye_blink_dir = eye_blink_dir self.is_jitter = color_jitter self.max_num_frames = max_num_frames self.image_size = image_size self.mode = mode vid_list = [] # # crema # self.hubert_dir = '/train20/intern/permanent/lmlin2/data/crema_wav_hubert' # if mode == 'train': # for id_name in os.listdir(data_dir): # if id_name in ['s64','s76','s88','s90','s91']: # continue # vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ]) # if mode == 'test': # for id_name in ['s64','s76','s88','s90','s91']: # vid_list.extend([os.path.join(id_name, sent) for sent in os.listdir(f'{data_dir}/{id_name}') ]) # self.videos = vid_list # hdtf vid_id_name_list = ['RD_Radio14_000','RD_Radio30_000','RD_Radio47_000','RD_Radio56_000','WDA_AmyKlobuchar1_001',\ 'WDA_BarbaraLee0_000','WDA_BobCasey0_000','WDA_CatherineCortezMasto_000','WDA_DebbieDingell1_000','WDA_DonaldMcEachin_000',\ 'WDA_EricSwalwell_000','WDA_HenryWaxman_000','WDA_JanSchakowsky1_000','WDA_JoeDonnelly_000','WDA_JohnSarbanes1_000',\ 'WDA_JoeNeguse_001','WDA_KatieHill_000','WDA_LucyMcBath_000','WDA_MazieHirono0_000','WDA_NancyPelosi1_000',\ 'WDA_PattyMurray0_000','WDA_RaulRuiz_000','WDA_SeanPatrickMaloney_000','WDA_TammyBaldwin0_000','WDA_TerriSewell0_000',\ 'WDA_TomCarper_000','WDA_WhipJimClyburn_000','WRA_AdamKinzinger0_000','WRA_AnnWagner_000','WRA_BobCorker_000',\ 'WRA_CandiceMiller0_000','WRA_CathyMcMorrisRodgers2_000','WRA_CoryGardner1_000','WRA_DebFischer1_000','WRA_DianeBlack1_000',\ 'WRA_ErikPaulsen_000','WRA_GeorgeLeMieux_000','WRA_JebHensarling0_001','WRA_JoeHeck1_000','WRA_JohnKasich1_001',\ 'WRA_MarcoRubio_000'] bad_id_name = ['WDA_DanKildee_000', 'WDA_PatrickLeahy1_000', 'WRA_KristiNoem2_000'] # vid_id_name_list = [item + '.mp4' for item in vid_id_name_list] # bad_id_name = [item + '.mp4' for item in bad_id_name] # hdtf self.hubert_dir = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate_chunk' with open('/train20/intern/permanent/hbcheng2/data/HDTF/length_dict.pkl', 'rb') as f: self.len_dict = pkl.load(f) # vid_id_name_list = ['RD_Radio47_000','WDA_CatherineCortezMasto_000','WDA_JoeNeguse_001','WDA_MichelleLujanGrisham_000','WRA_ErikPaulsen_002', \ # 'WDA_ZoeLofgren_000','WRA_JebHensarling2_003','WRA_MichaelSteele_000', 'WRA_ToddYoung_000', 'WRA_VickyHartzler_000'] if mode == 'train': for id_name in os.listdir(data_dir): # id_name = id_name[:-4] if id_name in vid_id_name_list or id_name in bad_id_name: continue vid_list.append(id_name) self.videos = vid_list if mode == 'test': self.videos = vid_id_name_list def check_head(self, frame_list, video_name, start, end): start_path = self.get_pose_path(frame_list, video_name, start) end_path = self.get_pose_path(frame_list, video_name, end) if os.path.exists(start_path) and os.path.exists(end_path): return True else: return False def get_block_data_for_two(self, path, start, end): # TODO: id function ''' input: start: start id end: end id output: the data from block ''' block_st = start//25 block_ed = end//25 st_pos = block_st % 25 ed_pos = block_ed % 25 block_st_name = 'chunk_%04d.npy' % (block_st) block_ed_name = 'chunk_%04d.npy' % (block_ed) if block_st != block_ed: block_st_path = os.path.join(path, block_st_name) block_ed_path = os.path.join(path, block_ed_name) block_st = np.load(block_st_path) block_ed = np.load(block_ed_path) return np.concatenate((block_st[st_pos:], 
block_ed[:ed_pos])) else: block_st_path = os.path.join(path, block_st_name) block_st = np.load(block_st_path) return block_st[st_pos, ed_pos] def get_block_data(self, path, start, end): # TODO: id function ''' input: start: start id end: end id output: the data from block ''' block_st = start//25 block_ed = end//25 st_pos = start % 25 ed_pos = end % 25 block_list = [os.path.join(path,'chunk_%04d.npy' % (i)) for i in range(block_st, block_ed+1)] if block_st != block_ed: arr_list = [] block_st = np.load(block_list[0]) arr_list.append(block_st[st_pos:]) for path in block_list[1:-1]: arr_list.append(np.load(path)) block_ed = np.load(block_list[-1]) arr_list.append(block_ed[:ed_pos]) return np.concatenate(arr_list) else: block_st_path = os.path.join(path, block_list[0]) block_st = np.load(block_st_path) return block_st[st_pos: ed_pos] def check_len(self, name): return self.len_dict[name] def __len__(self): return len(self.videos) def __getitem__(self, idx): video_name = self.videos[idx] path = os.path.join(self.data_dir, video_name) hubert_path = os.path.join(self.hubert_dir, video_name) pose_path = os.path.join(self.pose_dir, video_name) eye_blink_path = os.path.join(self.eye_blink_dir, video_name) total_num_frames = self.check_len(video_name) if total_num_frames <= self.max_num_frames: sample_frames = total_num_frames start = 0 else: sample_frames = self.max_num_frames start = np.random.randint(total_num_frames-self.max_num_frames) start=start stop=sample_frames+start sample_frame_npy = self.get_block_data(path = path, start = start, end = stop) sample_hubert_feature_npy = self.get_block_data(path = hubert_path, start = start, end = stop).astype(np.float32) sample_pose_list_npy = self.get_block_data(path = pose_path, start = start, end = stop).astype(np.float32) sample_eye_blink_list_npy = self.get_block_data(path = eye_blink_path, start = start, end = stop).astype(np.float32) sample_frame_list = torch.tensor(sample_frame_npy).permute(0,3,1,2) sample_hubert_feature_tensor = torch.tensor(sample_hubert_feature_npy) sample_frame_list = sample_frame_list - self.mean # 20, 3, 128, 128 # sample_frame_list = [np.transpose(x, (2, 0, 1)) for x in sample_frame_list] # sample_frame_list_npy = np.stack(sample_frame_list, axis=1) # sample_pose_list_npy = np.stack(sample_pose_list, axis = 1) # sample_eye_blink_list_npy = np.stack(sample_eye_blink_list, axis = 1) # change to float32 sample_frame_list = sample_frame_list.permute(1, 0, 2, 3) # sample_frame_list = np.array(sample_frame_list/255.0, dtype=np.float32) #3, 40, 128, 128 # sample_frame_list = sample_frame_list/255. 
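# --- Illustrative sketch (not part of the original file): the frame-tensor
# layout produced in __getitem__ above. Raw chunks arrive as (T, H, W, C)
# uint8 images; they are permuted to (T, C, H, W), centred with
# mean=(128, 128, 128), then permuted again to (C, T, H, W) for the video
# model. Shapes below are assumptions chosen for a runnable example.
import torch

def frames_to_model_layout(frames_thwc, mean=(128, 128, 128)):
    mean_t = torch.tensor(mean, dtype=torch.float32)[None, :, None, None]  # (1, C, 1, 1)
    x = frames_thwc.permute(0, 3, 1, 2).float()   # (T, H, W, C) -> (T, C, H, W); float cast for this sketch
    x = x - mean_t                                # centre pixel values around zero
    return x.permute(1, 0, 2, 3)                  # (T, C, H, W) -> (C, T, H, W)

dummy = torch.randint(0, 256, (40, 128, 128, 3))  # 40 frames of 128x128 RGB
print(frames_to_model_layout(dummy).shape)        # torch.Size([3, 40, 128, 128])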
# put to mode l forward # added to change the video_name of crema video_name = video_name.replace('/','_') sample_pose_list_npy = sample_pose_list_npy.transpose(1,0) # for compatibility sample_eye_blink_list_npy = sample_eye_blink_list_npy.transpose(1,0) if self.mode == 'test': return sample_frame_list, sample_hubert_feature_tensor, sample_pose_list_npy, sample_eye_blink_list_npy, video_name, start return sample_frame_list, sample_hubert_feature_tensor, sample_pose_list_npy, sample_eye_blink_list_npy, video_name, total_num_frames if __name__ == "__main__": # hdtf data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" pose_dir = "/train20/intern/permanent/hbcheng2/data/HDTF/pose" # crema # data_dir='/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images' dataset = HDTF(data_dir=data_dir, pose_dir=pose_dir ,mode='train') for i in range(10): dataset.__getitem__(i) print('------') test_dataset = data.DataLoader(dataset=dataset, batch_size=10, num_workers=8, shuffle=False) for i, batch in enumerate(test_dataset): print(i) ================================================ FILE: PBnet/src/datasets/get_dataset.py ================================================ def get_dataset(name="ntu13"): if name == "ntu13": from .ntu13 import NTU13 return NTU13 elif name == "uestc": from .uestc import UESTC return UESTC elif name == "humanact12": from .humanact12poses import HumanAct12Poses return HumanAct12Poses def get_datasets(parameters): name = parameters["dataset"] DATA = get_dataset(name) dataset = DATA(split="train", **parameters) train = dataset # test: shallow copy (share the memory) but set the other indices from copy import copy test = copy(train) test.split = test datasets = {"train": train, "test": test} # add specific parameters from the dataset loading dataset.update_parameters(parameters) return datasets ================================================ FILE: PBnet/src/datasets/tools.py ================================================ import os import string def parse_info_name(path): name = os.path.splitext(os.path.split(path)[-1])[0] info = {} current_letter = None for letter in name: if letter in string.ascii_letters: info[letter] = [] current_letter = letter else: info[current_letter].append(letter) for key in info.keys(): info[key] = "".join(info[key]) return info ================================================ FILE: PBnet/src/evaluate/__init__.py ================================================ ================================================ FILE: PBnet/src/evaluate/action2motion/accuracy.py ================================================ import torch def calculate_accuracy(model, motion_loader, num_labels, classifier, device): confusion = torch.zeros(num_labels, num_labels, dtype=torch.long) with torch.no_grad(): for batch in motion_loader: batch_prob = classifier(batch["output_xyz"], lengths=batch["lengths"]) batch_pred = batch_prob.max(dim=1).indices for label, pred in zip(batch["y"], batch_pred): confusion[label][pred] += 1 accuracy = torch.trace(confusion)/torch.sum(confusion) return accuracy.item(), confusion ================================================ FILE: PBnet/src/evaluate/action2motion/diversity.py ================================================ import torch import numpy as np # from action2motion def calculate_diversity_multimodality(activations, labels, num_labels): diversity_times = 200 multimodality_times = 20 labels = labels.long() num_motions = len(labels) diversity = 0 first_indices = np.random.randint(0, num_motions, diversity_times) 
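# --- Illustrative sketch (not part of the original file): the diversity score
# computed in calculate_diversity_multimodality below is the mean distance
# between randomly paired activation vectors; multimodality is the same thing
# restricted to pairs sharing a label. A tiny standalone version (sizes are
# assumptions):
import torch

def mean_pairwise_distance(activations, n_pairs=200, generator=None):
    n = activations.shape[0]
    idx_a = torch.randint(0, n, (n_pairs,), generator=generator)
    idx_b = torch.randint(0, n, (n_pairs,), generator=generator)
    return (activations[idx_a] - activations[idx_b]).norm(dim=1).mean()

acts = torch.randn(500, 64)                        # 500 motions, 64-d features
print(float(mean_pairwise_distance(acts)))         # higher = more diverse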
second_indices = np.random.randint(0, num_motions, diversity_times) for first_idx, second_idx in zip(first_indices, second_indices): diversity += torch.dist(activations[first_idx, :], activations[second_idx, :]) diversity /= diversity_times multimodality = 0 label_quotas = np.repeat(multimodality_times, num_labels) while np.any(label_quotas > 0): # print(label_quotas) first_idx = np.random.randint(0, num_motions) first_label = labels[first_idx] if not label_quotas[first_label]: continue second_idx = np.random.randint(0, num_motions) second_label = labels[second_idx] while first_label != second_label: second_idx = np.random.randint(0, num_motions) second_label = labels[second_idx] label_quotas[first_label] -= 1 first_activation = activations[first_idx, :] second_activation = activations[second_idx, :] multimodality += torch.dist(first_activation, second_activation) multimodality /= (multimodality_times * num_labels) return diversity.item(), multimodality.item() ================================================ FILE: PBnet/src/evaluate/action2motion/evaluate.py ================================================ import torch import numpy as np from .models import load_classifier, load_classifier_for_fid from .accuracy import calculate_accuracy from .fid import calculate_fid from .diversity import calculate_diversity_multimodality class A2MEvaluation: def __init__(self, dataname, device): dataset_opt = {"ntu13": {"joints_num": 18, "input_size_raw": 54, "num_classes": 13}, 'humanact12': {"input_size_raw": 72, "joints_num": 24, "num_classes": 12}} if dataname != dataset_opt.keys(): assert NotImplementedError(f"{dataname} is not supported.") self.dataname = dataname self.input_size_raw = dataset_opt[dataname]["input_size_raw"] self.num_classes = dataset_opt[dataname]["num_classes"] self.device = device self.gru_classifier_for_fid = load_classifier_for_fid(dataname, self.input_size_raw, self.num_classes, device).eval() self.gru_classifier = load_classifier(dataname, self.input_size_raw, self.num_classes, device).eval() def compute_features(self, model, motionloader): # calculate_activations_labels function from action2motion activations = [] labels = [] with torch.no_grad(): for idx, batch in enumerate(motionloader): activations.append(self.gru_classifier_for_fid(batch["output_xyz"], lengths=batch["lengths"])) labels.append(batch["y"]) activations = torch.cat(activations, dim=0) labels = torch.cat(labels, dim=0) return activations, labels @staticmethod def calculate_activation_statistics(activations): activations = activations.cpu().numpy() mu = np.mean(activations, axis=0) sigma = np.cov(activations, rowvar=False) return mu, sigma def evaluate(self, model, loaders): def print_logs(metric, key): print(f"Computing action2motion {metric} on the {key} loader ...") metrics = {} computedfeats = {} for key, loader in loaders.items(): metric = "accuracy" print_logs(metric, key) mkey = f"{metric}_{key}" metrics[mkey], _ = calculate_accuracy(model, loader, self.num_classes, self.gru_classifier, self.device) # features for diversity print_logs("features", key) feats, labels = self.compute_features(model, loader) print_logs("stats", key) stats = self.calculate_activation_statistics(feats) computedfeats[key] = {"feats": feats, "labels": labels, "stats": stats} print_logs("diversity", key) ret = calculate_diversity_multimodality(feats, labels, self.num_classes) metrics[f"diversity_{key}"], metrics[f"multimodality_{key}"] = ret # taking the stats of the ground truth and remove it from the computed feats gtstats = 
computedfeats["gt"]["stats"] # computing fid for key, loader in computedfeats.items(): metric = "fid" mkey = f"{metric}_{key}" stats = computedfeats[key]["stats"] metrics[mkey] = float(calculate_fid(gtstats, stats)) return metrics ================================================ FILE: PBnet/src/evaluate/action2motion/fid.py ================================================ import numpy as np from scipy import linalg # from action2motion def calculate_fid(statistics_1, statistics_2): return calculate_frechet_distance(statistics_1[0], statistics_1[1], statistics_2[0], statistics_2[1]) def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): """Numpy implementation of the Frechet Distance. The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). Stable version by Dougal J. Sutherland. Params: -- mu1 : Numpy array containing the activations of a layer of the inception net (like returned by the function 'get_predictions') for generated samples. -- mu2 : The sample mean over activations, precalculated on an representative data set. -- sigma1: The covariance matrix over activations for generated samples. -- sigma2: The covariance matrix over activations, precalculated on an representative data set. Returns: -- : The Frechet Distance. """ mu1 = np.atleast_1d(mu1) mu2 = np.atleast_1d(mu2) sigma1 = np.atleast_2d(sigma1) sigma2 = np.atleast_2d(sigma2) assert mu1.shape == mu2.shape, \ 'Training and test mean vectors have different lengths' assert sigma1.shape == sigma2.shape, \ 'Training and test covariances have different dimensions' diff = mu1 - mu2 # Product might be almost singular covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) if not np.isfinite(covmean).all(): msg = ('fid calculation produces singular product; ' 'adding %s to diagonal of cov estimates') % eps print(msg) offset = np.eye(sigma1.shape[0]) * eps covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) # Numerical error might give slight imaginary component if np.iscomplexobj(covmean): if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): m = np.max(np.abs(covmean.imag)) raise ValueError('Imaginary component {}'.format(m)) covmean = covmean.real tr_covmean = np.trace(covmean) return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) ================================================ FILE: PBnet/src/evaluate/action2motion/models.py ================================================ import torch import torch.nn as nn # adapted from action2motion to take inputs of different lengths class MotionDiscriminator(nn.Module): def __init__(self, input_size, hidden_size, hidden_layer, device, output_size=12, use_noise=None): super(MotionDiscriminator, self).__init__() self.device = device self.input_size = input_size self.hidden_size = hidden_size self.hidden_layer = hidden_layer self.use_noise = use_noise self.recurrent = nn.GRU(input_size, hidden_size, hidden_layer) self.linear1 = nn.Linear(hidden_size, 30) self.linear2 = nn.Linear(30, output_size) def forward(self, motion_sequence, lengths=None, hidden_unit=None): # dim (motion_length, num_samples, hidden_size) bs, njoints, nfeats, num_frames = motion_sequence.shape motion_sequence = motion_sequence.reshape(bs, njoints*nfeats, num_frames) motion_sequence = motion_sequence.permute(2, 0, 1) if hidden_unit is None: # motion_sequence = motion_sequence.permute(1, 0, 2) hidden_unit = self.initHidden(motion_sequence.size(1), self.hidden_layer) 
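# --- Illustrative sketch (not part of the original file): the indexing trick
# used below to grab the GRU output at the last *valid* timestep of each
# sequence (instead of gru_o[-1], which would read padding for short clips).
# Shapes are assumptions mirroring the (time, batch, hidden) layout.
import torch

T, B, H = 10, 4, 8
gru_o = torch.randn(T, B, H)                       # (time, batch, hidden)
lengths = torch.tensor([10, 7, 3, 5])              # valid frames per sample

# Equivalent of gru_o[tuple(torch.stack((lengths - 1, torch.arange(B))))]
last_valid = gru_o[lengths - 1, torch.arange(B)]   # (batch, hidden)
assert last_valid.shape == (B, H)
assert torch.equal(last_valid[1], gru_o[6, 1])     # sample 1 ends at frame index 6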
gru_o, _ = self.recurrent(motion_sequence.float(), hidden_unit) # select the last valid, instead of: gru_o[-1, :, :] out = gru_o[tuple(torch.stack((lengths-1, torch.arange(bs, device=self.device))))] # dim (num_samples, 30) lin1 = self.linear1(out) lin1 = torch.tanh(lin1) # dim (num_samples, output_size) lin2 = self.linear2(lin1) return lin2 def initHidden(self, num_samples, layer): return torch.randn(layer, num_samples, self.hidden_size, device=self.device, requires_grad=False) class MotionDiscriminatorForFID(MotionDiscriminator): def forward(self, motion_sequence, lengths=None, hidden_unit=None): # dim (motion_length, num_samples, hidden_size) bs, njoints, nfeats, num_frames = motion_sequence.shape motion_sequence = motion_sequence.reshape(bs, njoints*nfeats, num_frames) motion_sequence = motion_sequence.permute(2, 0, 1) if hidden_unit is None: # motion_sequence = motion_sequence.permute(1, 0, 2) hidden_unit = self.initHidden(motion_sequence.size(1), self.hidden_layer) gru_o, _ = self.recurrent(motion_sequence.float(), hidden_unit) # select the last valid, instead of: gru_o[-1, :, :] out = gru_o[tuple(torch.stack((lengths-1, torch.arange(bs, device=self.device))))] # dim (num_samples, 30) lin1 = self.linear1(out) lin1 = torch.tanh(lin1) return lin1 classifier_model_files = { "ntu13": "models/actionrecognition/ntu13_gru.tar", "humanact12": "models/actionrecognition/humanact12_gru.tar", } def load_classifier(dataset_type, input_size_raw, num_classes, device): model = torch.load(classifier_model_files[dataset_type], map_location=device) classifier = MotionDiscriminator(input_size_raw, 128, 2, device=device, output_size=num_classes).to(device) classifier.load_state_dict(model["model"]) classifier.eval() return classifier def load_classifier_for_fid(dataset_type, input_size_raw, num_classes, device): model = torch.load(classifier_model_files[dataset_type], map_location=device) classifier = MotionDiscriminatorForFID(input_size_raw, 128, 2, device=device, output_size=num_classes).to(device) classifier.load_state_dict(model["model"]) classifier.eval() return classifier def test(): from src.datasets.ntu13 import NTU13 import src.utils.fixseed # noqa classifier = load_classifier("ntu13", input_size_raw=54, num_classes=13, device="cuda").eval() params = {"pose_rep": "rot6d", "translation": True, "glob": True, "jointstype": "a2m", "vertstrans": True, "num_frames": 60, "sampling": "conseq", "sampling_step": 1} dataset = NTU13(**params) from src.models.rotation2xyz import Rotation2xyz rot2xyz = Rotation2xyz(device="cuda") confusion_xyz = torch.zeros(13, 13, dtype=torch.long) confusion = torch.zeros(13, 13, dtype=torch.long) for i in range(1000): dataset.pose_rep = "xyz" data = dataset[i][0].to("cuda") data = data[None] dataset.pose_rep = params["pose_rep"] x = dataset[i][0].to("cuda")[None] mask = torch.ones(1, x.shape[-1], dtype=bool, device="cuda") lengths = mask.sum(1) xyz_t = rot2xyz(x, mask, **params) predicted_cls_xyz = classifier(data, lengths=lengths).argmax().item() predicted_cls = classifier(xyz_t, lengths=lengths).argmax().item() gt_cls = dataset[i][1] confusion_xyz[gt_cls][predicted_cls_xyz] += 1 confusion[gt_cls][predicted_cls] += 1 accuracy_xyz = torch.trace(confusion_xyz)/torch.sum(confusion_xyz).item() accuracy = torch.trace(confusion)/torch.sum(confusion).item() print(f"accuracy: {accuracy:.1%}, accuracy_xyz: {accuracy_xyz:.1%}") if __name__ == "__main__": test() ================================================ FILE: PBnet/src/evaluate/evaluate_cvae.py 
================================================ import sys sys.path.append('/train20/intern/permanent/lmlin2/ReferenceCode/ACTOR-master') from src.parser.evaluation import parser from src.datasets.datasets_crema_pos import CREMA from src.datasets.datasets_hdtf_pos_chunk import HDTF # from src.datasets.datasets_hdtf_pos_chunk_mel_3 import HDTF from src.evaluate.tvae_eval import evaluate def main(): parameters, folder, checkpointname, epoch, niter = parser() # data path dataset_name = parameters["dataset"] if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) else: dataset = None print('Dataset can not be found!!') evaluate(parameters, dataset, folder, checkpointname, epoch, niter) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/evaluate/evaluate_cvae_debug.py ================================================ import sys sys.path.append('your_path/PBnet') from src.parser.evaluation import parser from src.datasets.datasets_crema_pos import CREMA from src.datasets.datasets_hdtf_pos_chunk_norm_2 import HDTF from src.evaluate.tvae_eval_train_norm import evaluate def main(): parameters, folder, checkpointname, epoch, niter = parser() # data path dataset_name = parameters["dataset"] if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) else: dataset = None print('Dataset can not be found!!') evaluate(parameters, dataset, folder, checkpointname, epoch, niter) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/evaluate/evaluate_cvae_f3.py ================================================ import sys sys.path.append('your_path/PBnet') from src.parser.evaluation import parser from src.datasets.datasets_crema_pos import CREMA from src.datasets.datasets_hdtf_pos_chunk_3 import HDTF from src.evaluate.tvae_eval_std import evaluate def main(): parameters, folder, checkpointname, epoch, niter = parser() # data path dataset_name = parameters["dataset"] if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) else: dataset = None print('Dataset can not be found!!') evaluate(parameters, dataset, folder, checkpointname, epoch, niter) if __name__ == '__main__': main() ================================================ FILE: 
PBnet/src/evaluate/evaluate_cvae_f3_debug.py ================================================ import sys sys.path.append('your_path/PBnet') from src.parser.evaluation import parser from src.datasets.datasets_crema_pos import CREMA from src.datasets.datasets_hdtf_pos_chunk_3 import HDTF from src.evaluate.tvae_eval_train_std import evaluate def main(): parameters, folder, checkpointname, epoch, niter = parser() # data path dataset_name = parameters["dataset"] if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) else: dataset = None print('Dataset can not be found!!') evaluate(parameters, dataset, folder, checkpointname, epoch, niter) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/evaluate/evaluate_cvae_f3_mel.py ================================================ import sys sys.path.append('your_path/PBnet') from src.parser.evaluation import parser from src.datasets.datasets_crema_pos import CREMA from src.datasets.datasets_hdtf_pos_chunk_mel_f3 import HDTF from src.evaluate.tvae_eval_std import evaluate def main(): parameters, folder, checkpointname, epoch, niter = parser() # data path dataset_name = parameters["dataset"] if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) else: dataset = None print('Dataset can not be found!!') evaluate(parameters, dataset, folder, checkpointname, epoch, niter) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/evaluate/evaluate_cvae_norm.py ================================================ import sys sys.path.append('your_path/PBnet') from src.parser.evaluation import parser from src.datasets.datasets_crema_pos import CREMA from src.datasets.datasets_hdtf_pos_chunk_norm_2 import HDTF from src.evaluate.tvae_eval_norm import evaluate def main(): parameters, folder, checkpointname, epoch, niter = parser() # data path dataset_name = parameters["dataset"] parameters["eye"] = False if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) else: dataset = None print('Dataset can not be found!!') evaluate(parameters, dataset, folder, checkpointname, epoch, niter) if __name__ == '__main__': main() ================================================ FILE: 
PBnet/src/evaluate/evaluate_cvae_norm_all.py ================================================ import sys sys.path.append('your_path/PBnet') from src.parser.evaluation import parser from src.datasets.datasets_crema_pos import CREMA from src.datasets.datasets_hdtf_pos_chunk_norm_2_all import HDTF from src.evaluate.tvae_eval_norm_all import evaluate def main(): parameters, folder, checkpointname, epoch, niter = parser() # data path dataset_name = parameters["dataset"] if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=1e8, mode = 'test') # dataset.update_parameters(parameters) parameters["audio_dim"] = 1024 parameters["pos_dim"] = 6 else: dataset = None print('Dataset can not be found!!') evaluate(parameters, dataset, folder, checkpointname, epoch, niter) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/evaluate/evaluate_cvae_norm_all_seg.py ================================================ import sys sys.path.append('your_path/PBnet') from src.parser.evaluation import parser from src.datasets.datasets_crema_pos import CREMA from src.datasets.datasets_hdtf_pos_chunk_norm_2_all import HDTF from src.evaluate.tvae_eval_norm_seg import evaluate def main(): parameters, folder, checkpointname, epoch, niter = parser() # data path dataset_name = parameters["dataset"] if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=1e8, mode = 'test') # dataset.update_parameters(parameters) parameters["audio_dim"] = 1024 parameters["pos_dim"] = 6 else: dataset = None print('Dataset can not be found!!') evaluate(parameters, dataset, folder, checkpointname, epoch, niter) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/evaluate/evaluate_cvae_norm_all_seg_weye.py ================================================ import sys sys.path.append('your_path/PBnet') from src.parser.evaluation import parser from src.datasets.datasets_crema_pos_eye_fast_all import CREMA from src.datasets.datasets_hdtf_pos_chunk_norm_2_all import HDTF from src.evaluate.tvae_eval_norm_seg import evaluate def main(): parameters, folder, checkpointname, epoch, niter = parser() # data path dataset_name = parameters["dataset"] if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=1e8, mode = 'test') # dataset.update_parameters(parameters) parameters["audio_dim"] = 1024 parameters["pos_dim"] = 6 else: dataset = None print('Dataset can not be found!!') evaluate(parameters, dataset, folder, 
checkpointname, epoch, niter) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/evaluate/evaluate_cvae_norm_all_seg_weye2.py ================================================ import sys sys.path.append('your_path/PBnet') from src.parser.evaluation import parser from src.datasets.datasets_crema_pos_eye_fast_all import CREMA from src.datasets.datasets_hdtf_pos_chunk_norm_eye_fast import HDTF from src.evaluate.tvae_eval_norm_eye_pose_seg import evaluate def main(): parameters, folder, checkpointname, epoch, niter = parser() # data path dataset_name = parameters["dataset"] if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=1e8, mode = 'test') # dataset.update_parameters(parameters) parameters["audio_dim"] = 1024 parameters["pos_dim"] = 3 parameters['latent_dim'] = 128 else: dataset = None print('Dataset can not be found!!') evaluate(parameters, dataset, folder, checkpointname, epoch, niter) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/evaluate/evaluate_cvae_norm_eye_pose.py ================================================ import sys sys.path.append('your_path/PBnet') from src.parser.evaluation import parser from src.datasets.datasets_crema_pos_eye_fast import CREMA from src.datasets.datasets_hdtf_pos_chunk_norm_eye_fast import HDTF from src.evaluate.tvae_eval_norm_eye_pose import evaluate def main(): parameters, folder, checkpointname, epoch, niter = parser() # data path dataset_name = parameters["dataset"] if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) else: dataset = None print('Dataset can not be found!!') evaluate(parameters, dataset, folder, checkpointname, epoch, niter) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/evaluate/evaluate_cvae_norm_eye_pose_test.py ================================================ import sys sys.path.append('your_path/PBnet') from src.parser.evaluation import parser from src.datasets.datasets_crema_pos import CREMA from src.datasets.datasets_hdtf_pos_chunk_norm_eye_fast import HDTF from src.evaluate.tvae_eval_norm_eye_pose import evaluate def main(): parameters, folder, checkpointname, epoch, niter = parser() # data path dataset_name = parameters["dataset"] if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') 
dataset.update_parameters(parameters) else: dataset = None print('Dataset can not be found!!') evaluate(parameters, dataset, folder, checkpointname, epoch, niter) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/evaluate/evaluate_cvae_onlyeye_all_seg.py ================================================ import sys sys.path.append('/train20/intern/permanent/lmlin2/ReferenceCode/ACTOR-master') from src.parser.evaluation import parser from src.datasets.datasets_crema_pos import CREMA from src.datasets.datasets_hdtf_onlyeye_fast import HDTF from src.evaluate.tvae_eval_onlyeye_all_seg import evaluate def main(): parameters, folder, checkpointname, epoch, niter = parser() # data path dataset_name = parameters["dataset"] if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'test') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=1e8, mode = 'test') dataset.update_parameters(parameters) # parameters["audio_dim"] = 1024 # parameters["pos_dim"] = 6 else: dataset = None print('Dataset can not be found!!') evaluate(parameters, dataset, folder, checkpointname, epoch, niter) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/evaluate/othermetrics/acceleration.py ================================================ import torch import numpy as np from src.utils.tensors import lengths_to_mask def calculate_acceletation(motionloader, device, xyz): # for now even if it is not xyz, the acceleration is one the euclidian/pose outfeat = "output_xyz" if xyz else "output" sum_acc = 0 num_acc = 0 for batch in motionloader: motion = batch[outfeat].permute(0, 3, 1, 2) bs, num_frames, njoints, nfeats = motion.shape velocity = motion[:, 1:] - motion[:, :-1] acceleration = velocity[:, 1:] - velocity[:, :-1] acceleration_normed = torch.linalg.norm(acceleration, axis=3) lengths = batch["lengths"] mask = lengths_to_mask(lengths - 2) # because acceleration usefull_accs_n = acceleration_normed[mask] sum_acc += usefull_accs_n.sum().item() num_acc += np.prod(usefull_accs_n.shape) acceleration = sum_acc/num_acc return acceleration ================================================ FILE: PBnet/src/evaluate/othermetrics/evaluation.py ================================================ import torch import numpy as np from ..action2motion.diversity import calculate_diversity_multimodality from .acceleration import calculate_acceletation class OtherMetricsEvaluation: """ Evaluation of some metrics in output space (not feature space): - Acceleration metrics - Reconstruction loss - Diversity - Multimodality (Not used in the paper) """ def __init__(self, device): self.device = device def compute_features(self, model, motionloader, xyz=True): feat = "output_xyz" if xyz else "output" activations = [] labels = [] for idx, batch in enumerate(motionloader): batch_motion = batch[feat] batch_label = batch["y"] activations.append(batch_motion) labels.append(batch_label) activations = torch.cat(activations, dim=0) activations = activations.reshape(activations.shape[0], -1) labels = torch.cat(labels, dim=0) return activations, labels def reconstructionloss(self, motionloader, xyz=True): infeat = "x_xyz" if xyz else "x" outfeat = "output_xyz" if xyz else "output" 
sum_loss = 0 num_loss = 0 for batch in motionloader: motion_in = batch[infeat].permute(0, 3, 1, 2) motion_out = batch[outfeat].permute(0, 3, 1, 2) mask = batch["mask"] square_diff = (motion_in[mask] - motion_out[mask])**2 sum_loss += square_diff.sum().item() num_loss += np.prod(square_diff.shape) rcloss = sum_loss / num_loss return rcloss def evaluate(self, model, num_classes, loaders, xyz=True): # get the xyz as well model.outputxyz = True metrics = {} repname = "xyz" if xyz else "pose" def print_logs(metric, key): print(f"Computing {metric} on the {key} loader ({repname})...") for key, loader in loaders.items(): # acceleration metric = "acceleration" print_logs(metric, key) mkey = f"{metric}_{key}" metrics[mkey] = calculate_acceletation(loader, device=self.device, xyz=xyz) # features for diversity print_logs("features", key) feats, labels = self.compute_features(model, loader, xyz=xyz) # diversity and multimodality metric = "diversity" print_logs(metric, key) ret = calculate_diversity_multimodality(feats, labels, num_classes) metrics[f"diversity_{key}"], metrics[f"multimodality_{key}"] = ret metric = "rc_recons" print(f"Computing reconstruction loss ({repname})..") rcloss = self.reconstructionloss(loaders["recons"], xyz=xyz) metrics[metric] = rcloss return metrics ================================================ FILE: PBnet/src/evaluate/stgcn/accuracy.py ================================================ import torch def calculate_accuracy(model, motion_loader, num_labels, classifier, device): confusion = torch.zeros(num_labels, num_labels, dtype=torch.long) with torch.no_grad(): for batch in motion_loader: batch_prob = classifier(batch)["yhat"] batch_pred = batch_prob.max(dim=1).indices for label, pred in zip(batch["y"], batch_pred): confusion[label][pred] += 1 accuracy = torch.trace(confusion)/torch.sum(confusion) return accuracy.item(), confusion ================================================ FILE: PBnet/src/evaluate/stgcn/diversity.py ================================================ import torch import numpy as np # from action2motion def calculate_diversity_multimodality(activations, labels, num_labels, seed=None): diversity_times = 200 multimodality_times = 20 labels = labels.long() num_motions = len(labels) diversity = 0 if seed is not None: np.random.seed(seed) first_indices = np.random.randint(0, num_motions, diversity_times) second_indices = np.random.randint(0, num_motions, diversity_times) for first_idx, second_idx in zip(first_indices, second_indices): diversity += torch.dist(activations[first_idx, :], activations[second_idx, :]) diversity /= diversity_times multimodality = 0 label_quotas = np.repeat(multimodality_times, num_labels) while np.any(label_quotas > 0): # print(label_quotas) first_idx = np.random.randint(0, num_motions) first_label = labels[first_idx] if not label_quotas[first_label]: continue second_idx = np.random.randint(0, num_motions) second_label = labels[second_idx] while first_label != second_label: second_idx = np.random.randint(0, num_motions) second_label = labels[second_idx] label_quotas[first_label] -= 1 first_activation = activations[first_idx, :] second_activation = activations[second_idx, :] multimodality += torch.dist(first_activation, second_activation) multimodality /= (multimodality_times * num_labels) return diversity.item(), multimodality.item() ================================================ FILE: PBnet/src/evaluate/stgcn/evaluate.py ================================================ import torch import numpy as np from .accuracy import 
calculate_accuracy from .fid import calculate_fid from .diversity import calculate_diversity_multimodality from src.recognition.models.stgcn import STGCN class Evaluation: def __init__(self, dataname, parameters, device, seed=None): layout = "smpl" if parameters["glob"] else "smpl_noglobal" model = STGCN(in_channels=parameters["nfeats"], num_class=parameters["num_classes"], graph_args={"layout": layout, "strategy": "spatial"}, edge_importance_weighting=True, device=parameters["device"]) model = model.to(parameters["device"]) modelpath = "models/actionrecognition/uestc_rot6d_stgcn.tar" state_dict = torch.load(modelpath, map_location=parameters["device"]) model.load_state_dict(state_dict) model.eval() self.num_classes = parameters["num_classes"] self.model = model self.dataname = dataname self.device = device self.seed = seed def compute_features(self, model, motionloader): # calculate_activations_labels function from action2motion activations = [] labels = [] with torch.no_grad(): for idx, batch in enumerate(motionloader): activations.append(self.model(batch)["features"]) labels.append(batch["y"]) activations = torch.cat(activations, dim=0) labels = torch.cat(labels, dim=0) return activations, labels @staticmethod def calculate_activation_statistics(activations): activations = activations.cpu().numpy() mu = np.mean(activations, axis=0) sigma = np.cov(activations, rowvar=False) return mu, sigma def evaluate(self, model, loaders): def print_logs(metric, key): print(f"Computing stgcn {metric} on the {key} loader ...") metrics_all = {} for sets in ["train", "test"]: computedfeats = {} metrics = {} for key, loaderSets in loaders.items(): loader = loaderSets[sets] metric = "accuracy" print_logs(metric, key) mkey = f"{metric}_{key}" metrics[mkey], _ = calculate_accuracy(model, loader, self.num_classes, self.model, self.device) # features for diversity print_logs("features", key) feats, labels = self.compute_features(model, loader) print_logs("stats", key) stats = self.calculate_activation_statistics(feats) computedfeats[key] = {"feats": feats, "labels": labels, "stats": stats} print_logs("diversity", key) ret = calculate_diversity_multimodality(feats, labels, self.num_classes, seed=self.seed) metrics[f"diversity_{key}"], metrics[f"multimodality_{key}"] = ret # taking the stats of the ground truth and remove it from the computed feats gtstats = computedfeats["gt"]["stats"] # computing fid for key, loader in computedfeats.items(): metric = "fid" mkey = f"{metric}_{key}" stats = computedfeats[key]["stats"] metrics[mkey] = float(calculate_fid(gtstats, stats)) metrics_all[sets] = metrics metrics = {} for sets in ["train", "test"]: for key in metrics_all[sets]: metrics[f"{key}_{sets}"] = metrics_all[sets][key] return metrics ================================================ FILE: PBnet/src/evaluate/stgcn/fid.py ================================================ import numpy as np from scipy import linalg # from action2motion def calculate_fid(statistics_1, statistics_2): return calculate_frechet_distance(statistics_1[0], statistics_1[1], statistics_2[0], statistics_2[1]) def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): """Numpy implementation of the Frechet Distance. The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). Stable version by Dougal J. Sutherland. 
Params: -- mu1 : Numpy array containing the activations of a layer of the inception net (like returned by the function 'get_predictions') for generated samples. -- mu2 : The sample mean over activations, precalculated on an representative data set. -- sigma1: The covariance matrix over activations for generated samples. -- sigma2: The covariance matrix over activations, precalculated on an representative data set. Returns: -- : The Frechet Distance. """ mu1 = np.atleast_1d(mu1) mu2 = np.atleast_1d(mu2) sigma1 = np.atleast_2d(sigma1) sigma2 = np.atleast_2d(sigma2) assert mu1.shape == mu2.shape, \ 'Training and test mean vectors have different lengths' assert sigma1.shape == sigma2.shape, \ 'Training and test covariances have different dimensions' diff = mu1 - mu2 # Product might be almost singular covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) if not np.isfinite(covmean).all(): msg = ('fid calculation produces singular product; ' 'adding %s to diagonal of cov estimates') % eps print(msg) offset = np.eye(sigma1.shape[0]) * eps covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) # Numerical error might give slight imaginary component if np.iscomplexobj(covmean): if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): m = np.max(np.abs(covmean.imag)) raise ValueError('Imaginary component {}'.format(m)) covmean = covmean.real tr_covmean = np.trace(covmean) return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) ================================================ FILE: PBnet/src/evaluate/tables/archtable.py ================================================ import os import glob import math import re import numpy as np from .tools import load_metrics def valformat(val, power=3): p = float(pow(10, power)) # "{:<04}".format(np.round(p*val).astype(int)/p) return str(np.round(p*val).astype(int)/p).ljust(4, "0") def format_values(values, key): mean = np.mean(values) if key == "accuracy": mean = 100*mean values = 100*values smean = valformat(mean, 1) else: smean = valformat(mean, 2) interval = valformat(1.96 * np.var(values), 2) # [1:] # string = rf"${mean:.4}^{{\pm{interval:.3}}}$" # string = rf"${smean}$" # ^{{\pm{interval}}}$" string = rf"${smean}^{{\pm{interval}}}$" return string def construct_table(folder): exppath = folder paths = glob.glob(f"{exppath}/**/evaluation*_all.yaml") keys = ["fid", "accuracy", "diversity", "multimodality"] model_metrics_dataset = {"ntu13": {}, "uestc": {}} epoch_dataset = {"ntu13": 1000, "uestc": 500} model_naming = {"fc": "Fully connected", "gru": "GRU", # "transformer": "Old transformer", "gtransformer": "Transformer"} ablation_naming = {"average_encoder": r"No $\mu_{a}^{token},\Sigma_{a}^{token}$", "time_encoding": r"No Decoder-PE", "zandtime": r"No $b_{a}^{token}$"} for i, path in enumerate(paths): epoch = int(path.split("evaluation_metrics_")[1].split(".")[0].split("_")[0]) modelinfo = os.path.split(os.path.split(path)[0])[1] modelname = modelinfo.split("_")[1] dataset = modelinfo.split("_kl_")[1].split("_")[0] # Take the right epoch if epoch_dataset[dataset] != epoch: continue # Ablation study if "abl" in modelinfo: ablation = modelinfo.split("_abl_")[1].split("_sampling")[0] if ablation not in ablation_naming: continue name = ablation_naming[ablation] else: if modelname not in model_naming: continue name = model_naming[modelname] metrics = load_metrics(path) model_metrics = model_metrics_dataset[dataset] if dataset == "ntu13": a2m = metrics["action2motion"] if "GT" not in model_metrics: a2m["fid_gt"] = a2m["fid_gt2"] 
row = [] for key in keys: ckey = f"{key}_gt" values = np.array([float(x) for x in a2m[ckey]]) string = format_values(values, key) row.append(string) model_metrics["GT"] = row row = [] for key in keys: ckey = f"{key}_gen" values = np.array([float(x) for x in a2m[ckey]]) string = format_values(values, key) row.append(string) model_metrics[name] = row elif dataset == "uestc": stgcn = metrics["stgcn"] if "GT" not in model_metrics: for sets in ["train", "test"]: stgcn[f"fid_gt_{sets}"] = stgcn[f"fid_gt2_{sets}"] stgcnkeys = ["fid_gt_train", "fid_gt_test", "accuracy_gt_train", "diversity_gt_train", "multimodality_gt_train"] row = [] for ckey in stgcnkeys: values = np.array([float(x) for x in stgcn[ckey]]) string = format_values(values, ckey.split("_")[0]) row.append(string) model_metrics["GT"] = row stgcnkeys = ["fid_gen_train", "fid_gen_test", "accuracy_gen_train", "diversity_gen_train", "multimodality_gen_train"] row = [] for ckey in stgcnkeys: values = np.array([float(x) for x in stgcn[ckey]]) string = format_values(values, ckey.split("_")[0]) row.append(string) model_metrics[name] = row archmodels = list(model_naming.values()) ablationmodels = list(ablation_naming.values()) gtvalues = ["GT"] for dataset in ["uestc", "ntu13"]: model_metrics = model_metrics_dataset[dataset] gtvalues.extend(model_metrics["GT"]) gtrow = " & ".join(gtvalues) + r"\\" groupedrows = [] for lst in [archmodels, ablationmodels]: rows = [] for model in lst: if model == "GT": continue values = [model] for dataset in ["uestc", "ntu13"]: model_metrics = model_metrics_dataset[dataset] if model in model_metrics: values.extend(model_metrics[model]) else: dummy = ["" for _ in range(len(model_metrics["GT"]))] values.extend(dummy) row = " & ".join(values) + r"\\" rows.append(row) groupedrows.append("\n".join(rows) + "\n") template = r"""\documentclass{{standalone}} \usepackage{{booktabs}} \usepackage[dvipsnames]{{xcolor}} \begin{{document}} \begin{{tabular}}{{lccccc|cccc}} \toprule Architecture & FID$_{{tr}}$$\downarrow$ & Acc.$\uparrow$ & Div.$\uparrow$ & Multimod.$\uparrow$ & FID$_{{tr}}$$\downarrow$ & FID$_{{test}}$$\downarrow$ & Acc.$\uparrow$ & Div.$\uparrow$ & Multimod.$\uparrow$\\ \midrule {gtrow} \midrule {archrow} \midrule {ablationrow} \bottomrule \end{{tabular}} \end{{document}} """.format(gtrow=gtrow, archrow=groupedrows[0], ablationrow=groupedrows[1]) return template if __name__ == "__main__": import argparse def parse_opts(): parser = argparse.ArgumentParser() parser.add_argument("exppath", help="name of the exp") return parser.parse_args() opt = parse_opts() exppath = opt.exppath folder = exppath tex = construct_table(folder) texpath = os.path.join(folder, "table_arch.tex") with open(texpath, "w") as ftex: ftex.write(tex) print(f"Table saved at {texpath}") ================================================ FILE: PBnet/src/evaluate/tables/bstable.py ================================================ import os import glob import math import re import numpy as np from .tools import load_metrics def valformat(val, power=3): p = float(pow(10, power)) # "{:<04}".format(np.round(p*val).astype(int)/p) return str(np.round(p*val).astype(int)/p).ljust(4, "0") def format_values(values, key): mean = np.mean(values) if key == "accuracy": mean = 100*mean values = 100*values smean = valformat(mean, 1) else: smean = valformat(mean, 2) interval = valformat(1.96 * np.var(values), 2) # [1:] # string = rf"${mean:.4}^{{\pm{interval:.3}}}$" # string = rf"${smean}$" # ^{{\pm{interval}}}$" string = rf"${smean}^{{\pm{interval}}}$" return 
string def construct_table(folder): exppath = folder paths = glob.glob(f"{exppath}/**/evaluation*_all*.yaml") keys = ["fid", "accuracy", "diversity", "multimodality"] model_metrics_dataset = {"ntu13": {}, "uestc": {}} epoch_dataset = {"ntu13": 1000, "uestc": 500} for i, path in enumerate(paths): epoch = int(path.split("evaluation_metrics_")[1].split(".")[0].split("_")[0]) modelinfo = os.path.split(os.path.split(path)[0])[1] dataset = modelinfo.split("_kl_")[1].split("_")[0] # Take the right epoch if epoch_dataset[dataset] != epoch: continue name = "Batch size " + modelinfo.split("bs_")[1] metrics = load_metrics(path) model_metrics = model_metrics_dataset[dataset] if dataset == "ntu13": a2m = metrics["action2motion"] if "GT" not in model_metrics: a2m["fid_gt"] = a2m["fid_gt2"] row = [] for key in keys: ckey = f"{key}_gt" values = np.array([float(x) for x in a2m[ckey]]) string = format_values(values, key) row.append(string) model_metrics["GT"] = row row = [] for key in keys: ckey = f"{key}_gen" values = np.array([float(x) for x in a2m[ckey]]) string = format_values(values, key) row.append(string) model_metrics[name] = row elif dataset == "uestc": stgcn = metrics["stgcn"] if "GT" not in model_metrics: for sets in ["train", "test"]: stgcn[f"fid_gt_{sets}"] = stgcn[f"fid_gt2_{sets}"] stgcnkeys = ["fid_gt_train", "fid_gt_test", "accuracy_gt_train", "diversity_gt_train", "multimodality_gt_train"] row = [] for ckey in stgcnkeys: values = np.array([float(x) for x in stgcn[ckey]]) string = format_values(values, ckey.split("_")[0]) row.append(string) model_metrics["GT"] = row stgcnkeys = ["fid_gen_train", "fid_gen_test", "accuracy_gen_train", "diversity_gen_train", "multimodality_gen_train"] row = [] for ckey in stgcnkeys: values = np.array([float(x) for x in stgcn[ckey]]) string = format_values(values, ckey.split("_")[0]) row.append(string) model_metrics[name] = row gtvalues = ["GT"] for dataset in ["uestc", "ntu13"]: model_metrics = model_metrics_dataset[dataset] gtvalues.extend(model_metrics["GT"]) gtrow = " & ".join(gtvalues) + r"\\" rows = [] modelnames = sorted(list(model_metrics.keys())) for model in modelnames: if model == "GT": continue values = [model] for dataset in ["uestc", "ntu13"]: model_metrics = model_metrics_dataset[dataset] if model in model_metrics: values.extend(model_metrics[model]) else: dummy = ["" for _ in range(len(model_metrics["GT"]))] values.extend(dummy) row = " & ".join(values) + r"\\" rows.append(row) rows = "\n".join(rows) template = r"""\documentclass{{standalone}} \usepackage{{booktabs}} \usepackage[dvipsnames]{{xcolor}} \begin{{document}} \begin{{tabular}}{{lccccc|cccc}} \toprule & \multicolumn{{5}}{{c}}{{UESTC}} & \multicolumn{{4}}{{|c}}{{NTU-13}} \\ Loss & FID$_{{tr}}$$\downarrow$ & FID$_{{test}}$$\downarrow$ & Acc.$\uparrow$ & Div.$\rightarrow$ & Multimod.$\rightarrow$ & FID$_{{tr}}$$\downarrow$ & Acc.$\uparrow$ & Div.$\rightarrow$ & Multimod.$\rightarrow$ \\ \midrule {gtrow} \midrule {rows} \bottomrule \end{{tabular}} \end{{document}} """.format(rows=rows, gtrow=gtrow) return template if __name__ == "__main__": import argparse def parse_opts(): parser = argparse.ArgumentParser() parser.add_argument("exppath", help="name of the exp") return parser.parse_args() opt = parse_opts() exppath = opt.exppath folder = exppath tex = construct_table(folder) texpath = os.path.join(folder, "table_loss.tex") with open(texpath, "w") as ftex: ftex.write(tex) print(f"Table saved at {texpath}") ================================================ FILE: 
PBnet/src/evaluate/tables/easy_table.py ================================================ import os import glob import math import numpy as np from ..tools import load_metrics def get_gtname(mname): return mname + "_gt" def get_genname(mname): return mname + "_gen" def get_reconsname(mname): return mname + "_recons" def valformat(val, power=3): p = float(pow(10, power)) # "{:<04}".format(np.round(p*val).astype(int)/p) return str(np.round(p*val).astype(int)/p).ljust(4, "0") def format_values(values, key, latex=True): mean = np.mean(values) if "accuracy" in key: mean = 100*mean values = 100*values smean = valformat(mean, 1) else: smean = valformat(mean, 2) interval = valformat(1.96 * np.var(values), 2) # [1:] if latex: string = rf"${smean}^{{\pm{interval}}}$" else: string = rf"{smean} +/- {interval}" return string def print_results(folder, evaluation): evalpath = os.path.join(folder, evaluation) metrics = load_metrics(evalpath) a2m = metrics["feats"] if "fid_gen_test" in a2m: keys = ["fid_{}_train", "fid_{}_test", "accuracy_{}_train", "diversity_{}_train", "multimodality_{}_train"] else: keys = ["fid_{}", "accuracy_{}", "diversity_{}", "multimodality_{}"] lines = ["gen", "recons"] # print the GT, only if it is computed with respect to "another" GT if "fid_gt2" in a2m: a2m["fid_gt"] = a2m["fid_gt2"] lines = ["gt"] + lines rows = [] rows_latex = [] for model in lines: row = ["{:6}".format(model)] row_latex = ["{:6}".format(model)] try: for key in keys: ckey = key.format(model) values = np.array([float(x) for x in a2m[ckey]]) string_latex = format_values(values, key, latex=True) string = format_values(values, key, latex=False) row.append(string) row_latex.append(string_latex) rows.append(" | ".join(row)) rows_latex.append(" & ".join(row_latex) + r"\\") except KeyError: continue table = "\n".join(rows) table_latex = "\n".join(rows_latex) print("Results") print(table) print() print("Latex table") print(table_latex) if __name__ == "__main__": import argparse def parse_opts(): parser = argparse.ArgumentParser() parser.add_argument("evalpath", help="name of the evaluation") return parser.parse_args() opt = parse_opts() evalpath = opt.evalpath folder, evaluation = os.path.split(evalpath) print_results(folder, evaluation) ================================================ FILE: PBnet/src/evaluate/tables/easy_table_A2M.py ================================================ import os import glob import math import numpy as np from ..tools import load_metrics def valformat(val, power=3): p = float(pow(10, power)) # "{:<04}".format(np.round(p*val).astype(int)/p) return str(np.round(p*val).astype(int)/p).ljust(4, "0") def construct_table(folder, evaluation): evalpath = os.path.join(folder, evaluation) metrics = load_metrics(evalpath) a2m = metrics["feats"] keys = ["fid", "accuracy", "diversity", "multimodality"] a2m["fid_gt"] = a2m["fid_gt2"] values = [] rows = [] for model in ["gt", "gen", "genden"]: row = ["{:6}".format(model)] for key in keys: ckey = f"{key}_{model}" values = np.array([float(x) for x in a2m[ckey]]) mean = np.mean(values) if key == "accuracy": mean = 100*mean values = 100*values smean = valformat(mean, 1) else: smean = valformat(mean, 2) mean = np.mean(values) interval = valformat(1.96 * np.var(values), 2) # [1:] string = rf"${smean}^{{\pm{interval}}}$" # string = rf"{mean:.4}" #^{{\pm{interval:.1}}}" row.append(string) rows.append(" & ".join(row) + r"\\") test = "\n".join(rows) print(test) import ipdb; ipdb.set_trace() bodylist.append(r"\bottomrule") body = "\n".join(bodylist) ncols = 5 title 
= f"Evaluation TODO name" template = r"""\documentclass{{standalone}} \usepackage{{booktabs}} \usepackage[dvipsnames]{{xcolor}} \begin{{document}} \begin{{tabular}}{{{ncolsl}}} \multicolumn{{{ncols}}}{{c}}{{{title}}} \\ \multicolumn{{{ncols}}}{{c}}{{}} \\ & \multicolumn{{{nbcolsxyz}}}{{c}}{{xyz}} & & \multicolumn{{{nbcolspose}}}{{c}}{{{pose_rep}}} & & \multicolumn{{{nbcolsa2m}}}{{c}}{{action2motion}} \\ {firstrow} \midrule {body} \end{{tabular}} \end{{document}} """.format(ncolsl="l"+"c"*(ncols-1), ncols=ncols, pose_rep=pose_rep, title=title, firstrow=firstrow, nbcolsxyz=len(METRICS["joints"]), nbcolspose=len(METRICS[pose_rep]), nbcolsa2m=len(METRICS["action2motion"]), body=body) return template if __name__ == "__main__": import argparse def parse_opts(): parser = argparse.ArgumentParser() parser.add_argument("evalpath", help="name of the evaluation") return parser.parse_args() opt = parse_opts() evalpath = opt.evalpath folder, evaluation = os.path.split(evalpath) tex = construct_table(folder, evaluation) texpath = os.path.join(folder, os.path.splitext(evaluation)[0] + ".tex") with open(texpath, "w") as ftex: ftex.write(tex) print(f"Table saved at {texpath}") ================================================ FILE: PBnet/src/evaluate/tables/kltable.py ================================================ import os import glob import math import re import numpy as np from .tools import load_metrics def valformat(val, power=3): p = float(pow(10, power)) # "{:<04}".format(np.round(p*val).astype(int)/p) return str(np.round(p*val).astype(int)/p).ljust(4, "0") def format_values(values, key): mean = np.mean(values) if key == "accuracy": mean = 100*mean values = 100*values smean = valformat(mean, 1) else: smean = valformat(mean, 2) interval = valformat(1.96 * np.var(values), 2) # [1:] # string = rf"${mean:.4}^{{\pm{interval:.3}}}$" # string = rf"${smean}$" # ^{{\pm{interval}}}$" string = rf"${smean}^{{\pm{interval}}}$" return string def construct_table(folder): exppath = folder paths = glob.glob(f"{exppath}/**/evaluation*_all*.yaml") keys = ["fid", "accuracy", "diversity", "multimodality"] model_metrics_dataset = {"ntu13": {}, "uestc": {}} epoch_dataset = {"ntu13": 1000, "uestc": 500} for i, path in enumerate(paths): epoch = int(path.split("evaluation_metrics_")[1].split(".")[0].split("_")[0]) modelinfo = os.path.split(os.path.split(path)[0])[1] dataset = modelinfo.split("_kl_")[1].split("_")[0] # Take the right epoch if epoch_dataset[dataset] != epoch: continue name = modelinfo.split("samplingstep_1_")[1].split("_gelu")[0].replace("_", " ") metrics = load_metrics(path) model_metrics = model_metrics_dataset[dataset] if dataset == "ntu13": a2m = metrics["action2motion"] if "GT" not in model_metrics: a2m["fid_gt"] = a2m["fid_gt2"] row = [] for key in keys: ckey = f"{key}_gt" values = np.array([float(x) for x in a2m[ckey]]) string = format_values(values, key) row.append(string) model_metrics["GT"] = row row = [] for key in keys: ckey = f"{key}_gen" values = np.array([float(x) for x in a2m[ckey]]) string = format_values(values, key) row.append(string) model_metrics[name] = row elif dataset == "uestc": stgcn = metrics["stgcn"] if "GT" not in model_metrics: for sets in ["train", "test"]: stgcn[f"fid_gt_{sets}"] = stgcn[f"fid_gt2_{sets}"] stgcnkeys = ["fid_gt_train", "fid_gt_test", "accuracy_gt_train", "diversity_gt_train", "multimodality_gt_train"] row = [] for ckey in stgcnkeys: values = np.array([float(x) for x in stgcn[ckey]]) string = format_values(values, ckey.split("_")[0]) row.append(string) 
model_metrics["GT"] = row stgcnkeys = ["fid_gen_train", "fid_gen_test", "accuracy_gen_train", "diversity_gen_train", "multimodality_gen_train"] row = [] for ckey in stgcnkeys: values = np.array([float(x) for x in stgcn[ckey]]) string = format_values(values, ckey.split("_")[0]) row.append(string) model_metrics[name] = row gtvalues = ["GT"] for dataset in ["uestc", "ntu13"]: model_metrics = model_metrics_dataset[dataset] gtvalues.extend(model_metrics["GT"]) gtrow = " & ".join(gtvalues) + r"\\" rows = [] for model in model_metrics: if model == "GT": continue values = [model] for dataset in ["uestc", "ntu13"]: model_metrics = model_metrics_dataset[dataset] if model in model_metrics: values.extend(model_metrics[model]) else: dummy = ["" for _ in range(len(model_metrics["GT"]))] values.extend(dummy) row = " & ".join(values) + r"\\" rows.append(row) rows = "\n".join(rows) template = r"""\documentclass{{standalone}} \usepackage{{booktabs}} \usepackage[dvipsnames]{{xcolor}} \begin{{document}} \begin{{tabular}}{{lccccc|cccc}} \toprule & \multicolumn{{5}}{{c}}{{UESTC}} & \multicolumn{{4}}{{|c}}{{NTU-13}} \\ Loss & FID$_{{tr}}$$\downarrow$ & FID$_{{test}}$$\downarrow$ & Acc.$\uparrow$ & Div.$\rightarrow$ & Multimod.$\rightarrow$ & FID$_{{tr}}$$\downarrow$ & Acc.$\uparrow$ & Div.$\rightarrow$ & Multimod.$\rightarrow$ \\ \midrule {gtrow} \midrule {rows} \bottomrule \end{{tabular}} \end{{document}} """.format(rows=rows, gtrow=gtrow) return template if __name__ == "__main__": import argparse def parse_opts(): parser = argparse.ArgumentParser() parser.add_argument("exppath", help="name of the exp") return parser.parse_args() opt = parse_opts() exppath = opt.exppath folder = exppath tex = construct_table(folder) texpath = os.path.join(folder, "table_loss.tex") with open(texpath, "w") as ftex: ftex.write(tex) print(f"Table saved at {texpath}") ================================================ FILE: PBnet/src/evaluate/tables/latexmodela2m.py ================================================ import os import glob import math import numpy as np from .tools import load_metrics def get_gtname(mname): return mname + "_gt" def get_genname(mname): return mname + "_gen" def get_reconsname(mname): return mname + "_recons" def construct_table(folder, evaluation): evalpath = os.path.join(folder, evaluation) metrics = load_metrics(evalpath) a2m = metrics["action2motion"] keys = ["fid", "accuracy", "diversity", "multimodality"] a2m["fid_gt"] = a2m["fid_gt2"] modelname = os.path.split(folder)[1] modelname = modelname.replace("_ntu13_vibe_rot6d_glob_translation_numlayers_8_numframes_60_sampling_conseq_samplingstep_1_kl_1e-05_gelu", "") modelname = modelname.replace("_", " ") def valformat(val, power=3): p = float(pow(10, power)) # "{:<04}".format(np.round(p*val).astype(int)/p) return str(np.round(p*val).astype(int)/p).ljust(5, "0") values = [] rows = [] for model in ["gt", "gen", "recons"]: row = ["{} {}".format(modelname, model)] for key in keys: ckey = f"{key}_{model}" values = np.array([float(x) for x in a2m[ckey]]) mean = valformat(np.mean(values)) interval = valformat(1.96 * np.var(values))[1:] # string = rf"${mean:.4}^{{\pm{interval:.3}}}$" string = rf"${mean}^{{\pm{interval}}}$" row.append(string) row = " & ".join(row) + r"\\" rows.append(row) MODELS = "\n ".join(rows) template = r"""\documentclass{{standalone}} \usepackage{{booktabs}} \usepackage[dvipsnames]{{xcolor}} \begin{{document}} \begin{{tabular}}{{lcccc}} \toprule Architecture & FID$\downarrow$ & Acc.$\uparrow$ & Div.$\uparrow$ & Multimod.$\uparrow$\\ \midrule 
action2motion ground truth & $0.031^{{\pm.004}}$ & $0.999^{{\pm.001}}$ & $7.108^{{\pm.048}}$ & $2.194^{{\pm.025}}$ \\ action2motion lie model & $0.330^{{\pm.008}}$ & $0.949^{{\pm.001}}$ & $7.065^{{\pm.043}}$ & $2.052^{{\pm.030}}$ \\ \midrule {MODELS} \bottomrule \end{{tabular}} \end{{document}} """.format(MODELS=MODELS) return template if __name__ == "__main__": import argparse def parse_opts(): parser = argparse.ArgumentParser() parser.add_argument("evalpath", help="name of the evaluation") return parser.parse_args() opt = parse_opts() evalpath = opt.evalpath folder, evaluation = os.path.split(evalpath) tex = construct_table(folder, evaluation) texpath = os.path.join(folder, os.path.splitext(evaluation)[0] + ".tex") with open(texpath, "w") as ftex: ftex.write(tex) print(f"Table saved at {texpath}") ================================================ FILE: PBnet/src/evaluate/tables/latexmodelsa2m.py ================================================ import os import glob import math import re import numpy as np from .tools import load_metrics def valformat(val, power=3): p = float(pow(10, power)) # "{:<04}".format(np.round(p*val).astype(int)/p) return str(np.round(p*val).astype(int)/p).ljust(5, "0") def construct_table(folder): exppath = folder paths = glob.glob(f"{exppath}/**/evaluation*_all*.yaml") keys = ["fid", "accuracy", "diversity", "multimodality"] models_results = [] for i, path in enumerate(paths): metrics = load_metrics(path) a2m = metrics["action2motion"] a2m["fid_gt"] = a2m["fid_gt2"] modelname = os.path.split(os.path.split(path)[0])[1] for info in ["vibe", "rot6d", "glob", "translation", "numlayers_8", "numframes_60", "sampling_conseq", "samplingstep_1", "jointstype", "gelu", "kl_1e-05", "cvae", "ntu13"]: modelname = modelname.replace(info, "") modelname = re.sub("_{1,}", " ", modelname) # takin GT only for the first one if i == 0: gtrow = ["Our GT"] for key in keys: ckey = f"{key}_gt" values = np.array([float(x) for x in a2m[ckey]]) mean = valformat(np.mean(values)) interval = valformat(1.96 * np.var(values))[1:] # string = rf"${mean:.4}^{{\pm{interval:.3}}}$" string = rf"${mean}$" # ^{{\pm{interval}}}$" gtrow.append(string) gtrow = " & ".join(gtrow) + r"\\" rows = [] for model in ["gen"]: # ["gt", "gen", "recons"]: # row = ["{} {}".format(modelname, model)] row = [modelname] for key in keys: ckey = f"{key}_{model}" values = np.array([float(x) for x in a2m[ckey]]) mean = valformat(np.mean(values)) interval = valformat(1.96 * np.var(values))[1:] # string = rf"${mean:.4}^{{\pm{interval:.3}}}$" string = rf"${mean}$" # ^{{\pm{interval}}}$" row.append(string) row = " & ".join(row) + r"\\" rows.append(row) models_result = "\n ".join(rows) models_results.append(models_result) sorting = ["former rc kl", "former rcxyz kl", "former rc rcxyz kl", "former rc rcxyz vel kl", "former rc rcxyz velxyz kl", "former rc rcxyz vel velxyz kl"] changing = {"rc": r"$\mathcal{L}_{R}$", "rcxyz": r"$\mathcal{L}_{O}$", "vel": r"$\mathcal{L}_{\Delta R}$", "velxyz": r"$\mathcal{L}_{\Delta O}$"} changing_jointstype = {"smpl": "J", "vertices": "V"} sorted_models = [gtrow, " \\midrule\n"] for sortkey in sorting: for models_result in models_results: if sortkey in models_result: modelsname = models_result.split("&")[0].rstrip() losses = sortkey.split(" ")[1:-1] # remove former and kl wlosses = [] for loss in losses: renaming = changing[loss] jtype = modelsname.split(" ")[-1] if jtype in changing_jointstype: renaming = renaming.replace("O", changing_jointstype[jtype]) wlosses.append(renaming) models_result = 
models_result.replace(modelsname, " + ".join(wlosses)) sorted_models.append(models_result) # MODELS = "\n \\midrule\n".join(sorted_models) MODELS = "\n".join(sorted_models) + "\n" template = r"""\documentclass{{standalone}} \usepackage{{booktabs}} \usepackage[dvipsnames]{{xcolor}} \begin{{document}} \begin{{tabular}}{{lcccc}} \toprule Architecture & FID$\downarrow$ & Acc.$\uparrow$ & Div.$\uparrow$ & Multimod.$\uparrow$\\ \midrule action2motion ground truth & $0.031^{{\pm.004}}$ & $0.999^{{\pm.001}}$ & $7.108^{{\pm.048}}$ & $2.194^{{\pm.025}}$ \\ action2motion lie model & $0.330^{{\pm.008}}$ & $0.949^{{\pm.001}}$ & $7.065^{{\pm.043}}$ & $2.052^{{\pm.030}}$ \\ \midrule {MODELS} \bottomrule \end{{tabular}} \end{{document}} """.format(MODELS=MODELS) return template if __name__ == "__main__": import argparse def parse_opts(): parser = argparse.ArgumentParser() parser.add_argument("exppath", help="name of the exp") return parser.parse_args() opt = parse_opts() exppath = opt.exppath folder = exppath tex = construct_table(folder) texpath = os.path.join(folder, "table.tex") with open(texpath, "w") as ftex: ftex.write(tex) print(f"Table saved at {texpath}") ================================================ FILE: PBnet/src/evaluate/tables/latexmodelsstgcn.py ================================================ import os import glob import math import re import numpy as np from .tools import load_metrics def get_gtname(mname): return mname + "_gt" def get_genname(mname): return mname + "_gen" def get_reconsname(mname): return mname + "_recons" def valformat(val, power=3): p = float(pow(10, power)) # "{:<04}".format(np.round(p*val).astype(int)/p) return str(np.round(p*val).astype(int)/p).ljust(5, "0") def construct_table(folder): exppath = folder paths = glob.glob(f"{exppath}/**/evaluation*0500_all.yaml") keys = ["fid", "accuracy", "diversity", "multimodality"] models_results = [] for i, path in enumerate(paths): metrics = load_metrics(path) stgcn = metrics["stgcn"] # easy fid gt for sets in ["train", "test"]: stgcn[f"fid_gt_{sets}"] = stgcn[f"fid_gt2_{sets}"] modelname = os.path.split(os.path.split(path)[0])[1] for info in ["vibe", "rot6d", "glob", "translation", "numlayers_8", "numframes_60", "sampling_conseq", "samplingstep_1", "jointstype", "gelu", "kl_1e-05", "cvae", "uestc"]: modelname = modelname.replace(info, "") modelname = re.sub("_{1,}", " ", modelname) # takin GT only for the first one if i == 0: gtrow = ["Our GT"] for sets in ["train", "test"]: for key in keys: ckey = f"{key}_gt_{sets}" values = np.array([float(x) for x in stgcn[ckey]]) mean = valformat(np.mean(values)) interval = valformat(1.96 * np.var(values))[1:] # string = rf"${mean:.4}^{{\pm{interval:.3}}}$" string = rf"${mean}$" # ^{{\pm{interval}}}$" gtrow.append(string) gtrow.append("") gtrow = " & ".join(gtrow[:-1]) + r"\\" rows = [] for model in ["gen"]: # ["gt", "gen", "recons"]: # row = ["{} {}".format(modelname, model)] row = [modelname] for sets in ["train", "test"]: for key in keys: ckey = f"{key}_{model}_{sets}" values = np.array([float(x) for x in stgcn[ckey]]) mean = valformat(np.mean(values)) interval = valformat(1.96 * np.var(values))[1:] # string = rf"${mean:.4}^{{\pm{interval:.3}}}$" string = rf"${mean}$" # ^{{\pm{interval}}}$" row.append(string) row.append("") row = " & ".join(row[:-1]) + r"\\" rows.append(row) models_result = "\n ".join(rows) models_results.append(models_result) sorting = ["former rc kl", "former rcxyz kl", "former rc rcxyz kl", "former rc rcxyz vel kl", "former rc rcxyz velxyz kl", "former rc rcxyz 
vel velxyz kl"] changing = {"rc": r"$\mathcal{L}_{R}$", "rcxyz": r"$\mathcal{L}_{O}$", "vel": r"$\mathcal{L}_{\Delta R}$", "velxyz": r"$\mathcal{L}_{\Delta O}$"} changing_jointstype = {"smpl": "J", "vertices": "V"} sorted_models = [gtrow, " \\midrule\n"] for sortkey in sorting: for models_result in models_results: if sortkey in models_result: modelsname = models_result.split("&")[0].rstrip() losses = sortkey.split(" ")[1:-1] # remove former and kl wlosses = [] for loss in losses: renaming = changing[loss] jtype = modelsname.split(" ")[-1] if jtype in changing_jointstype: renaming = renaming.replace("O", changing_jointstype[jtype]) wlosses.append(renaming) models_result = models_result.replace(modelsname, " + ".join(wlosses)) sorted_models.append(models_result) # MODELS = "\n \\midrule\n".join(sorted_models) MODELS = "\n".join(sorted_models) + "\n" template = r"""\documentclass{{standalone}} \usepackage{{booktabs}} \usepackage[dvipsnames]{{xcolor}} \begin{{document}} \begin{{tabular}}{{lccccccccc}} Architecture & FID$\downarrow$ & Acc.$\uparrow$ & Div.$\uparrow$ & Multimod.$\uparrow$ & & FID$\downarrow$ & Acc.$\uparrow$ & Div.$\uparrow$ & Multimod.$\uparrow$\\ \midrule \toprule {MODELS} \bottomrule \end{{tabular}} \end{{document}} """.format(MODELS=MODELS) return template if __name__ == "__main__": import argparse def parse_opts(): parser = argparse.ArgumentParser() parser.add_argument("exppath", help="name of the exp") return parser.parse_args() opt = parse_opts() exppath = opt.exppath folder = exppath tex = construct_table(folder) texpath = os.path.join(folder, "table.tex") with open(texpath, "w") as ftex: ftex.write(tex) print(f"Table saved at {texpath}") ================================================ FILE: PBnet/src/evaluate/tables/losstable.py ================================================ import os import glob import math import re import numpy as np from .tools import load_metrics def valformat(val, power=3): p = float(pow(10, power)) # "{:<04}".format(np.round(p*val).astype(int)/p) return str(np.round(p*val).astype(int)/p).ljust(4, "0") def format_values(values, key): mean = np.mean(values) if key == "accuracy": mean = 100*mean values = 100*values smean = valformat(mean, 1) else: smean = valformat(mean, 2) interval = valformat(1.96 * np.var(values), 2) # [1:] # string = rf"${mean:.4}^{{\pm{interval:.3}}}$" # string = rf"${smean}$" # ^{{\pm{interval}}}$" string = rf"${smean}^{{\pm{interval}}}$" return string def construct_table(folder): exppath = folder paths = glob.glob(f"{exppath}/**/evaluation*_all*.yaml") keys = ["fid", "accuracy", "diversity", "multimodality"] model_metrics_dataset = {"ntu13": {}, "uestc": {}} epoch_dataset = {"ntu13": 1000, "uestc": 500} for i, path in enumerate(paths): epoch = int(path.split("evaluation_metrics_")[1].split(".")[0].split("_")[0]) modelinfo = os.path.split(os.path.split(path)[0])[1] dataset = modelinfo.split("_kl_")[1].split("_")[0] # Take the right epoch if epoch_dataset[dataset] != epoch: continue if "vel" in modelinfo: continue if "rc_rcxyz_kl" in modelinfo: if "vertices" in modelinfo: name = r"$\mathcal{L}_{P}$ + $\mathcal{L}_{V}$" else: # name = r"$\mathcal{L}_{P}$ + $\mathcal{L}_{J}$" continue elif "rc_kl" in modelinfo: name = r"$\mathcal{L}_{P}$" elif "rcxyz_kl" in modelinfo: if "vertices" in modelinfo: name = r"$\mathcal{L}_{V}$" else: name = r"$\mathcal{L}_{J}$" else: print(f"weird: {modelinfo}") metrics = load_metrics(path) model_metrics = model_metrics_dataset[dataset] if dataset == "ntu13": a2m = metrics["action2motion"] if "GT" 
not in model_metrics: a2m["fid_gt"] = a2m["fid_gt2"] row = [] for key in keys: ckey = f"{key}_gt" values = np.array([float(x) for x in a2m[ckey]]) string = format_values(values, key) row.append(string) model_metrics["GT"] = row row = [] for key in keys: ckey = f"{key}_gen" values = np.array([float(x) for x in a2m[ckey]]) string = format_values(values, key) row.append(string) model_metrics[name] = row elif dataset == "uestc": stgcn = metrics["stgcn"] if "GT" not in model_metrics: for sets in ["train", "test"]: stgcn[f"fid_gt_{sets}"] = stgcn[f"fid_gt2_{sets}"] stgcnkeys = ["fid_gt_train", "fid_gt_test", "accuracy_gt_train", "diversity_gt_train", "multimodality_gt_train"] row = [] for ckey in stgcnkeys: values = np.array([float(x) for x in stgcn[ckey]]) string = format_values(values, ckey.split("_")[0]) row.append(string) model_metrics["GT"] = row stgcnkeys = ["fid_gen_train", "fid_gen_test", "accuracy_gen_train", "diversity_gen_train", "multimodality_gen_train"] row = [] for ckey in stgcnkeys: values = np.array([float(x) for x in stgcn[ckey]]) string = format_values(values, ckey.split("_")[0]) row.append(string) model_metrics[name] = row lossmodels = [r"$\mathcal{L}_{J}$", r"$\mathcal{L}_{P}$", r"$\mathcal{L}_{V}$", # r"$\mathcal{L}_{P}$ + $\mathcal{L}_{J}$", r"$\mathcal{L}_{P}$ + $\mathcal{L}_{V}$"] gtvalues = ["GT"] for dataset in ["uestc", "ntu13"]: model_metrics = model_metrics_dataset[dataset] gtvalues.extend(model_metrics["GT"]) gtrow = " & ".join(gtvalues) + r"\\" rows = [] for model in lossmodels: if model == "GT": continue values = [model] for dataset in ["uestc", "ntu13"]: model_metrics = model_metrics_dataset[dataset] if model in model_metrics: values.extend(model_metrics[model]) else: dummy = ["" for _ in range(len(model_metrics["GT"]))] values.extend(dummy) row = " & ".join(values) + r"\\" rows.append(row) rows = "\n".join(rows) template = r"""\documentclass{{standalone}} \usepackage{{booktabs}} \usepackage[dvipsnames]{{xcolor}} \begin{{document}} \begin{{tabular}}{{lccccc|cccc}} \toprule & \multicolumn{{5}}{{c}}{{UESTC}} & \multicolumn{{4}}{{|c}}{{NTU-13}} \\ Loss & FID$_{{tr}}$$\downarrow$ & FID$_{{test}}$$\downarrow$ & Acc.$\uparrow$ & Div.$\rightarrow$ & Multimod.$\rightarrow$ & FID$_{{tr}}$$\downarrow$ & Acc.$\uparrow$ & Div.$\rightarrow$ & Multimod.$\rightarrow$ \\ \midrule {gtrow} \midrule {rows} \bottomrule \end{{tabular}} \end{{document}} """.format(rows=rows, gtrow=gtrow) return template if __name__ == "__main__": import argparse def parse_opts(): parser = argparse.ArgumentParser() parser.add_argument("exppath", help="name of the exp") return parser.parse_args() opt = parse_opts() exppath = opt.exppath folder = exppath tex = construct_table(folder) texpath = os.path.join(folder, "table_loss.tex") with open(texpath, "w") as ftex: ftex.write(tex) print(f"Table saved at {texpath}") ================================================ FILE: PBnet/src/evaluate/tables/maketable.py ================================================ import os import glob import math from .tools import load_metrics METRICS = {"joints": ["acceleration", "rc", "diversity", "multimodality"], "action2motion": ["accuracy", "fid", "diversity", "multimodality"]} UP = r"$\uparrow$" DOWN = r"$\downarrow$" RIGHT = r"$\rightarrow$" ARROWS = {"accuracy": UP, "acceleration": RIGHT, "rc": DOWN, "fid": DOWN, "diversity": RIGHT, "multimodality": RIGHT} POSE_ORDER = ["xyz", "rotvec", "rotquat", "rotmat", "rot6d"] for pose in POSE_ORDER: METRICS[pose] = METRICS["joints"] GROUPORDER = POSE_ORDER + ["action2motion"] 
GREEN = "Green" RED = "Mahogany" def bold(string): return r"\textbf{{" + string + r"}}" def colorize_template(string, color): return r"\textcolor{{" + color + r"}}{{" + string + r"}}" def colorize_bold_template(string, color): return bold(colorize_template(string, color)) def format_table(val, gtval, mname): value = float(val) try: exp = math.floor(math.log10(value)) except ValueError: exp = 0 value = 0 if mname == "rc": formatter = "{:.1e}" if value >= 1: formatter = colorize_bold_template(formatter, RED) elif mname in ["diversity", "multimodality"]: if exp < -1: formatter = "{:.1e}" else: formatter = "{:.3g}" if gtval is not None: gtval = float(gtval) if value > 0.8*gtval: formatter = colorize_bold_template(formatter, GREEN) elif value < 0.3*gtval: formatter = colorize_bold_template(formatter, RED) elif mname == "accuracy": formatter = "{:.1%}" if value > 0.65: formatter = colorize_bold_template(formatter, GREEN) elif value < 0.35: formatter = colorize_bold_template(formatter, RED) elif mname == "acceleration": formatter = "{:.1e}" if gtval is not None: gtval = float(gtval) diff = math.log10(value/gtval) # below acceleration if diff < 0.05: formatter = colorize_bold_template(formatter, GREEN) elif diff > 0.3: formatter = colorize_bold_template(formatter, RED) else: formatter = "{:.2f}" formatter = bold(formatter) return formatter.format(value).replace("%", r"\%") def get_gtname(mname): return mname + "_gt" def get_genname(mname): return mname + "_gen" def get_reconsname(mname): return mname + "_recons" def collect_tables(folder, expname, lastepoch=False, norecons=False): exppath = os.path.join(folder, expname) paths = glob.glob(f"{exppath}/**/evaluation*") if len(paths) == 0: raise ValueError("No evaluation founds.") pose_rep, *losses = expname.split("_") expname = expname.replace("_", "\\_") models_kl = {} allkls = set() models_epochs = {} for path in paths: metrics = load_metrics(path) fname = os.path.split(path)[0] modelname = fname.split("cvae_")[1].split("_rc")[0] kl_loss = float(fname.split("_kl_")[2].split("_")[0]) epoch = os.path.split(path)[1].split("evaluation_metrics_")[1].split(".")[0] if lastepoch: if modelname not in models_epochs: models_epochs[modelname] = epoch else: if models_epochs[modelname] > epoch: continue else: models_epochs[modelname] = epoch modelname = rf"{modelname}" else: modelname = rf"{modelname}\_{epoch}" if "numlayers" in fname: nlayers = int(fname.split("numlayers")[1].split("_")[1]) modelname += rf"\_nlayer\_{nlayers}" if "relu" in fname: activation = "relu" elif "gelu" in fname: activation = "gelu" else: activation = "" modelname += rf"\_{activation}" try: ablation = fname.split("abl_")[1].split("_sampling")[0] ablation = ablation.replace("_", r"\_") modelname += rf"\_{ablation}" except IndexError: modelname += r"\_noablation" if modelname not in models_kl: models_kl[modelname] = {} models_kl[modelname][kl_loss] = metrics allkls.add(kl_loss) lambdas_sorted = sorted(list(allkls), reverse=True) gtrowl = ["ground truth"] for group in GROUPORDER: if group in metrics: for mname in METRICS[group]: gtname = get_gtname(mname) if gtname in metrics[group]: val = format_table(metrics[group][gtname], None, mname) gtrowl.append(val) else: gtrowl.append("") gtrowl.append("") gtrowl.pop() gtrow = " & ".join(gtrowl) + r"\\" bodylist = [gtrow] bodylist.append(r"\midrule") modelnames = sorted(list(models_kl.keys())) # compute first rows # to add a first col firstrow = [""] for group in GROUPORDER: if group in metrics: for mname in METRICS[group]: mname = f"{mname} 
{ARROWS[mname]}" firstrow.append(mname) firstrow.append("") firstrow.pop() firstrow = " & ".join(firstrow) + r"\\" for lam in lambdas_sorted: for modelname in modelnames: if lam in models_kl[modelname]: metrics = models_kl[modelname][lam] row = [f"{modelname} {lam}"] for group in GROUPORDER: if group in metrics: for mname in METRICS[group]: gtname = get_gtname(mname) gtval = metrics[group][gtname] if gtname in metrics[group] else None genname = get_genname(mname) reconsname = get_reconsname(mname) if not norecons and genname in metrics[group] and reconsname in metrics[group]: genval = format_table(metrics[group][genname], gtval, mname) reconsval = format_table(metrics[group][reconsname], gtval, mname) row.append(f"{genval}/{reconsval}") elif genname in metrics[group]: genval = format_table(metrics[group][genname], gtval, mname) row.append(f"{genval}") elif reconsname in metrics[group]: reconsval = format_table(metrics[group][reconsname], gtval, mname) row.append(f"{reconsval}") else: print(f"{mname} is not present in this evaluation") row.append("") row.pop() row = " & ".join(row) + r"\\" bodylist.append(row) # bodylist.append(emptyrow) bodylist.append(r"\midrule") bodylist.append(r"\bottomrule") body = "\n".join(bodylist) ncols = len(gtrowl) title = f"Evaluation of {expname} experiment" template = r"""\documentclass{{standalone}} \usepackage{{booktabs}} \usepackage[dvipsnames]{{xcolor}} \begin{{document}} \begin{{tabular}}{{{ncolsl}}} \multicolumn{{{ncols}}}{{c}}{{{title}}} \\ \multicolumn{{{ncols}}}{{c}}{{}} \\ & \multicolumn{{{nbcolsxyz}}}{{c}}{{xyz}} & & \multicolumn{{{nbcolspose}}}{{c}}{{{pose_rep}}} & & \multicolumn{{{nbcolsa2m}}}{{c}}{{action2motion}} \\ {firstrow} \midrule {body} \end{{tabular}} \end{{document}} """.format(ncolsl="l"+"c"*(ncols-1), ncols=ncols, pose_rep=pose_rep, title=title, firstrow=firstrow, nbcolsxyz=len(METRICS["joints"]), nbcolspose=len(METRICS[pose_rep]), nbcolsa2m=len(METRICS["action2motion"]), body=body) return template if __name__ == "__main__": import argparse def parse_opts(): parser = argparse.ArgumentParser() parser.add_argument("exppath", help="name of the exp") parser.add_argument("--outpath", default="tex", help="name of the exp") parser.add_argument("--norecons", dest='norecons', action='store_true') parser.set_defaults(norecons=False) parser.add_argument("--lastepoch", dest='lastepoch', action='store_true') parser.set_defaults(lastepoch=False) return parser.parse_args() opt = parse_opts() exppath = opt.exppath norecons = opt.norecons lastepoch = opt.lastepoch folder, expname = os.path.split(exppath) template = collect_tables(folder, expname, lastepoch=lastepoch, norecons=norecons) # os.makedirs(opt.outpath, exist_ok=True) name = expname if norecons: name += "_norecons" texpath = os.path.join(exppath, name + ".tex") with open(texpath, "w") as ftex: ftex.write(template) print(f"Table saved at {texpath}") ================================================ FILE: PBnet/src/evaluate/tables/numlayertable.py ================================================ import os import glob import math import re import numpy as np from .tools import load_metrics def valformat(val, power=3): p = float(pow(10, power)) # "{:<04}".format(np.round(p*val).astype(int)/p) return str(np.round(p*val).astype(int)/p).ljust(4, "0") def format_values(values, key): mean = np.mean(values) if key == "accuracy": mean = 100*mean values = 100*values smean = valformat(mean, 1) else: smean = valformat(mean, 2) interval = valformat(1.96 * np.var(values), 2) # [1:] # string = 
rf"${mean:.4}^{{\pm{interval:.3}}}$" # string = rf"${smean}$" # ^{{\pm{interval}}}$" string = rf"${smean}^{{\pm{interval}}}$" return string def construct_table(folder): exppath = folder paths = glob.glob(f"{exppath}/**/evaluation*_all*.yaml") keys = ["fid", "accuracy", "diversity", "multimodality"] model_metrics_dataset = {"ntu13": {}, "uestc": {}} epoch_dataset = {"ntu13": 1000, "uestc": 500} for i, path in enumerate(paths): epoch = int(path.split("evaluation_metrics_")[1].split(".")[0].split("_")[0]) modelinfo = os.path.split(os.path.split(path)[0])[1] dataset = modelinfo.split("_kl_")[1].split("_")[0] # Take the right epoch if epoch_dataset[dataset] != epoch: continue name = "numlayers " + modelinfo.split("numlayers_")[1].split("_")[0] metrics = load_metrics(path) model_metrics = model_metrics_dataset[dataset] if dataset == "ntu13": a2m = metrics["action2motion"] if "GT" not in model_metrics: a2m["fid_gt"] = a2m["fid_gt2"] row = [] for key in keys: ckey = f"{key}_gt" values = np.array([float(x) for x in a2m[ckey]]) string = format_values(values, key) row.append(string) model_metrics["GT"] = row row = [] for key in keys: ckey = f"{key}_gen" values = np.array([float(x) for x in a2m[ckey]]) string = format_values(values, key) row.append(string) model_metrics[name] = row elif dataset == "uestc": stgcn = metrics["stgcn"] if "GT" not in model_metrics: for sets in ["train", "test"]: stgcn[f"fid_gt_{sets}"] = stgcn[f"fid_gt2_{sets}"] stgcnkeys = ["fid_gt_train", "fid_gt_test", "accuracy_gt_train", "diversity_gt_train", "multimodality_gt_train"] row = [] for ckey in stgcnkeys: values = np.array([float(x) for x in stgcn[ckey]]) string = format_values(values, ckey.split("_")[0]) row.append(string) model_metrics["GT"] = row stgcnkeys = ["fid_gen_train", "fid_gen_test", "accuracy_gen_train", "diversity_gen_train", "multimodality_gen_train"] row = [] for ckey in stgcnkeys: values = np.array([float(x) for x in stgcn[ckey]]) string = format_values(values, ckey.split("_")[0]) row.append(string) model_metrics[name] = row gtvalues = ["GT"] for dataset in ["uestc", "ntu13"]: model_metrics = model_metrics_dataset[dataset] gtvalues.extend(model_metrics["GT"]) gtrow = " & ".join(gtvalues) + r"\\" rows = [] modelnames = sorted(list(model_metrics.keys())) for model in modelnames: if model == "GT": continue values = [model] for dataset in ["uestc", "ntu13"]: model_metrics = model_metrics_dataset[dataset] if model in model_metrics: values.extend(model_metrics[model]) else: dummy = ["" for _ in range(len(model_metrics["GT"]))] values.extend(dummy) row = " & ".join(values) + r"\\" rows.append(row) rows = "\n".join(rows) template = r"""\documentclass{{standalone}} \usepackage{{booktabs}} \usepackage[dvipsnames]{{xcolor}} \begin{{document}} \begin{{tabular}}{{lccccc|cccc}} \toprule & \multicolumn{{5}}{{c}}{{UESTC}} & \multicolumn{{4}}{{|c}}{{NTU-13}} \\ Loss & FID$_{{tr}}$$\downarrow$ & FID$_{{test}}$$\downarrow$ & Acc.$\uparrow$ & Div.$\rightarrow$ & Multimod.$\rightarrow$ & FID$_{{tr}}$$\downarrow$ & Acc.$\uparrow$ & Div.$\rightarrow$ & Multimod.$\rightarrow$ \\ \midrule {gtrow} \midrule {rows} \bottomrule \end{{tabular}} \end{{document}} """.format(rows=rows, gtrow=gtrow) return template if __name__ == "__main__": import argparse def parse_opts(): parser = argparse.ArgumentParser() parser.add_argument("exppath", help="name of the exp") return parser.parse_args() opt = parse_opts() exppath = opt.exppath folder = exppath tex = construct_table(folder) texpath = os.path.join(folder, "table_loss.tex") with 
open(texpath, "w") as ftex: ftex.write(tex) print(f"Table saved at {texpath}") ================================================ FILE: PBnet/src/evaluate/tables/posereptable.py ================================================ import os import glob import math import re import numpy as np from .tools import load_metrics def valformat(val, power=3): p = float(pow(10, power)) # "{:<04}".format(np.round(p*val).astype(int)/p) return str(np.round(p*val).astype(int)/p).ljust(4, "0") def format_values(values, key): mean = np.mean(values) if key == "accuracy": mean = 100*mean values = 100*values smean = valformat(mean, 1) else: smean = valformat(mean, 2) interval = valformat(1.96 * np.var(values), 2) # [1:] # string = rf"${mean:.4}^{{\pm{interval:.3}}}$" # string = rf"${smean}$" # ^{{\pm{interval}}}$" string = rf"${smean}^{{\pm{interval}}}$" return string def construct_table(folder): exppath = folder paths = glob.glob(f"{exppath}/**/evaluation*_all*.yaml") keys = ["fid", "accuracy", "diversity", "multimodality"] model_metrics_dataset = {"ntu13": {}, "uestc": {}} epoch_dataset = {"ntu13": 1000, "uestc": 500} for i, path in enumerate(paths): epoch = int(path.split("evaluation_metrics_")[1].split(".")[0].split("_")[0]) modelinfo = os.path.split(os.path.split(path)[0])[1] dataset = modelinfo.split("_kl_")[1].split("_")[0] # Take the right epoch if epoch_dataset[dataset] != epoch: continue name = "Pose rep " + modelinfo.split("_vibe_")[1].split("_")[0] if "xyz" in name: continue metrics = load_metrics(path) model_metrics = model_metrics_dataset[dataset] if dataset == "ntu13": a2m = metrics["action2motion"] if "GT" not in model_metrics: a2m["fid_gt"] = a2m["fid_gt2"] row = [] for key in keys: ckey = f"{key}_gt" values = np.array([float(x) for x in a2m[ckey]]) string = format_values(values, key) row.append(string) model_metrics["GT"] = row row = [] for key in keys: ckey = f"{key}_gen" values = np.array([float(x) for x in a2m[ckey]]) string = format_values(values, key) row.append(string) model_metrics[name] = row elif dataset == "uestc": stgcn = metrics["stgcn"] if "GT" not in model_metrics: for sets in ["train", "test"]: stgcn[f"fid_gt_{sets}"] = stgcn[f"fid_gt2_{sets}"] stgcnkeys = ["fid_gt_train", "fid_gt_test", "accuracy_gt_train", "diversity_gt_train", "multimodality_gt_train"] row = [] for ckey in stgcnkeys: values = np.array([float(x) for x in stgcn[ckey]]) string = format_values(values, ckey.split("_")[0]) row.append(string) model_metrics["GT"] = row stgcnkeys = ["fid_gen_train", "fid_gen_test", "accuracy_gen_train", "diversity_gen_train", "multimodality_gen_train"] row = [] for ckey in stgcnkeys: values = np.array([float(x) for x in stgcn[ckey]]) string = format_values(values, ckey.split("_")[0]) row.append(string) model_metrics[name] = row gtvalues = ["GT"] for dataset in ["uestc", "ntu13"]: model_metrics = model_metrics_dataset[dataset] if "GT" not in model_metrics and dataset == "uestc": gtvalues.extend([" "] * (5 if dataset == "uestc" else 4)) else: gtvalues.extend(model_metrics["GT"]) gtrow = " & ".join(gtvalues) + r"\\" rows = [] modelnames = sorted(list(model_metrics_dataset["ntu13"].keys())) for model in modelnames: if model == "GT": continue values = [model] for dataset in ["uestc", "ntu13"]: model_metrics = model_metrics_dataset[dataset] if model in model_metrics: values.extend(model_metrics[model]) else: dummy = ["" for _ in range(5 if dataset == "uestc" else 4)] values.extend(dummy) row = " & ".join(values) + r"\\" rows.append(row) rows = "\n".join(rows)
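# Layout note (inferred from the template below, not part of the original source): the
# {lccccc|cccc} column spec expects ten ampersand-joined cells per row -- one left-aligned
# model-name cell, five UESTC metrics (FID_train, FID_test, accuracy, diversity,
# multimodality) and four NTU-13 metrics (FID, accuracy, diversity, multimodality) --
# so `gtrow` and each entry of `rows` assembled above must match that cell count.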
template = r"""\documentclass{{standalone}} \usepackage{{booktabs}} \usepackage[dvipsnames]{{xcolor}} \begin{{document}} \begin{{tabular}}{{lccccc|cccc}} \toprule & \multicolumn{{5}}{{c}}{{UESTC}} & \multicolumn{{4}}{{|c}}{{NTU-13}} \\ Loss & FID$_{{tr}}$$\downarrow$ & FID$_{{test}}$$\downarrow$ & Acc.$\uparrow$ & Div.$\rightarrow$ & Multimod.$\rightarrow$ & FID$_{{tr}}$$\downarrow$ & Acc.$\uparrow$ & Div.$\rightarrow$ & Multimod.$\rightarrow$ \\ \midrule {gtrow} \midrule {rows} \bottomrule \end{{tabular}} \end{{document}} """.format(rows=rows, gtrow=gtrow) return template if __name__ == "__main__": import argparse def parse_opts(): parser = argparse.ArgumentParser() parser.add_argument("exppath", help="name of the exp") return parser.parse_args() opt = parse_opts() exppath = opt.exppath folder = exppath tex = construct_table(folder) texpath = os.path.join(folder, "table_loss.tex") with open(texpath, "w") as ftex: ftex.write(tex) print(f"Table saved at {texpath}") ================================================ FILE: PBnet/src/evaluate/tools.py ================================================ import yaml def format_metrics(metrics, formatter="{:.6}"): newmetrics = {} for key, val in metrics.items(): newmetrics[key] = formatter.format(val) return newmetrics def save_metrics(path, metrics): with open(path, "w") as yfile: yaml.dump(metrics, yfile) def load_metrics(path): with open(path, "r") as yfile: string = yfile.read() return yaml.load(string, yaml.loader.BaseLoader) ================================================ FILE: PBnet/src/evaluate/tvae_eval.py ================================================ import torch from tqdm import tqdm from src.utils.fixseed import fixseed from torch.utils.data import DataLoader from src.utils.tensors_hdtf import collate import os import numpy as np import torch.nn.functional as F # from .tools import save_metrics, format_metrics from src.models.get_model import get_model as get_gen_model def evaluate(parameters, dataset, folder, checkpointname, epoch, niter): # num_frames = 60 device = parameters["device"] # dummy => update parameters info model = get_gen_model(parameters) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) state_dict = torch.load(checkpointpath, map_location=device) model.load_state_dict(state_dict) model.eval() if checkpointname.split("_")[0] == 'retraincheckpoint': save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0]) else: save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0]) os.makedirs(save_folder, exist_ok=True) allseeds = list(range(niter)) try: for index, seed in enumerate(allseeds): print(f"Evaluation number: {index+1}/{niter}") fixseed(seed) save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed)) save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed)) os.makedirs(save_pred_path, exist_ok=True) os.makedirs(save_gt_path, exist_ok=True) dataiterator = DataLoader(dataset, batch_size=parameters["batch_size"], shuffle=False, num_workers=8, collate_fn=collate) with torch.no_grad(): for databatch in tqdm(dataiterator, desc=f"Construct dataloader: generating.."): # batch = {key: val.to(device) for key, val in databatch.items()} pose = databatch["x"] audio = databatch["y"] gendurations = databatch["lengths"] # start = databatch["start"] batch = model.generate(pose, audio, gendurations) batch = {key: val.to(device) for key, val in batch.items()} for pose_pre, pose_gt, mask, filename, 
start_num in zip(batch['output'], databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']): # # x_ref = pose_gt[0,:].unsqueeze(dim=0).cpu() # pose_pre = (pose_pre.cpu()+x_ref - 0.5) * 180 # gtmasked = (pose_gt[mask].cpu() -0.5 ) * 180 # x_ref = pose_gt[0,:].unsqueeze(dim=0) pose_pre = pose_pre.cpu()+x_ref gtmasked = pose_gt[mask].cpu() outmasked = pose_pre[mask].cpu() pred_path = os.path.join(save_pred_path, filename+'_'+str(start_num)) gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt') # np.save(pred_path, pose_pre.cpu()) # np.save(gt_path, pose_gt.cpu()) np.savetxt(pred_path, outmasked) np.savetxt(gt_path, gtmasked) loss = F.mse_loss(gtmasked, outmasked, reduction='mean') print(loss) except KeyboardInterrupt: string = "Saving the evaluation before exiting.." print(string) epoch = checkpointname.split("_")[1].split(".")[0] metricname = "evaluation_metrics_{}_all.yaml".format(epoch) evalpath = os.path.join(folder, metricname) print(f"Saving evaluation: {evalpath}") # save_metrics(evalpath, metrics) ================================================ FILE: PBnet/src/evaluate/tvae_eval_norm.py ================================================ import torch from tqdm import tqdm from src.utils.fixseed import fixseed from src.utils.utils import MultiEpochsDataLoader as DataLoader from src.utils.tensors_hdtf import collate_old import os import numpy as np import torch.nn.functional as F # from .tools import save_metrics, format_metrics from src.models.get_model import get_model as get_gen_model def transform(x, min_val, max_val): out = x * (max_val - min_val) + min_val return out def evaluate(parameters, dataset, folder, checkpointname, epoch, niter): # num_frames = 60 min_val = dataset.min_vals max_val = dataset.max_vals device = parameters["device"] # dummy => update parameters info model = get_gen_model(parameters) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) state_dict = torch.load(checkpointpath, map_location=device) model.load_state_dict(state_dict) model.eval() if checkpointname.split("_")[0] == 'retraincheckpoint': save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0]) else: save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0]) os.makedirs(save_folder, exist_ok=True) allseeds = list(range(niter)) try: for index, seed in enumerate(allseeds): print(f"Evaluation number: {index+1}/{niter}") fixseed(seed) save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed)) save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed)) os.makedirs(save_pred_path, exist_ok=True) os.makedirs(save_gt_path, exist_ok=True) dataiterator = DataLoader(dataset, batch_size=parameters["batch_size"], shuffle=False, num_workers=8, collate_fn=collate_old) with torch.no_grad(): for databatch in tqdm(dataiterator, desc=f"Construct dataloader: generating.."): # batch = {key: val.to(device) for key, val in databatch.items()} pose = databatch["x"] audio = databatch["y"] gendurations = databatch["lengths"] # start = databatch["start"] batch = model.generate(pose, audio, gendurations, fact = 1) batch = {key: val.to(device) for key, val in batch.items()} for pose_pre, pose_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']): x_ref = pose_gt[0,:].unsqueeze(dim=0) pose_pre = pose_pre.cpu()+x_ref gtmasked = pose_gt[mask].cpu() outmasked = pose_pre[mask].cpu() 
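# Note (inferred from `transform` defined above; the normalization side is an assumption):
# predictions and ground truth are handled in min-max normalized form, and this step maps
# them back to the original value range:
#   x_norm = (x - min_val) / (max_val - min_val)      # presumably applied when the dataset is built
#   transform(x_norm, min_val, max_val) == x
# with min_val / max_val taken from dataset.min_vals / dataset.max_vals.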
gtmasked = transform(gtmasked, min_val, max_val) outmasked = transform(outmasked, min_val, max_val) pred_path = os.path.join(save_pred_path, filename+'_'+str(start_num)) gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt') # np.save(pred_path, pose_pre.cpu()) # np.save(gt_path, pose_gt.cpu()) np.savetxt(pred_path, outmasked) np.savetxt(gt_path, gtmasked) loss = F.mse_loss(gtmasked, outmasked, reduction='mean') print(loss) except KeyboardInterrupt: string = "Saving the evaluation before exiting.." print(string) epoch = checkpointname.split("_")[1].split(".")[0] metricname = "evaluation_metrics_{}_all.yaml".format(epoch) evalpath = os.path.join(folder, metricname) print(f"Saving evaluation: {evalpath}") # save_metrics(evalpath, metrics) ================================================ FILE: PBnet/src/evaluate/tvae_eval_norm_all.py ================================================ import torch from tqdm import tqdm from src.utils.fixseed import fixseed from src.utils.utils import MultiEpochsDataLoader as DataLoader from src.utils.tensors_hdtf import collate_old import os import numpy as np import torch.nn.functional as F # from .tools import save_metrics, format_metrics from src.models.get_model import get_model as get_gen_model def save_images_as_npy(input_data, output_file): # save_npy = np.zeros(input_data.shape[0], 7) # save_npy[:,:, :-1] = input_data # save_npy[:, -1] = ref[:,:, -1] # images_array = np.array(images) np.save(output_file, input_data) def save_as_chunk(dir, data): if not os.path.exists(dir): os.makedirs(dir) chunks = [data[i:min(i + 25, data.shape[0])] for i in range(0, data.shape[0], 25)] for i, chunk in enumerate(chunks): output_file = os.path.join(dir, f'chunk_%04d.npy' % (i)) # chunk = np.stack(chunk, axis = 0) save_images_as_npy(chunk, output_file) def transform(x, min_val, max_val): out = x * (max_val - min_val) + min_val return out def evaluate(parameters, dataset, folder, checkpointname, epoch, niter): # num_frames = 60 min_val = dataset.min_vals max_val = dataset.max_vals device = parameters["device"] # dummy => update parameters info model = get_gen_model(parameters) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) model_ckpt = model.state_dict() state_dict = torch.load(checkpointpath, map_location=device) for name, _ in model_ckpt.items(): if model_ckpt[name].shape == state_dict[name].shape: model_ckpt[name].copy_(state_dict[name]) model.load_state_dict(model_ckpt) # model.load_state_dict(state_dict) model.eval() if checkpointname.split("_")[0] == 'retraincheckpoint': save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0]) else: save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0]) os.makedirs(save_folder, exist_ok=True) allseeds = list(range(niter)) try: for index, seed in enumerate(allseeds): print(f"Evaluation number: {index+1}/{niter}") fixseed(seed) save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed)) save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed)) os.makedirs(save_pred_path, exist_ok=True) os.makedirs(save_gt_path, exist_ok=True) dataiterator = DataLoader(dataset, batch_size=parameters["batch_size"], shuffle=False, num_workers=8, collate_fn=collate_old) with torch.no_grad(): for databatch in tqdm(dataiterator, desc=f"Construct dataloader: generating.."): # batch = {key: val.to(device) for key, val in databatch.items()} pose = databatch["x"][:,:,:-1] # b, len, c ref = 
databatch['x'][:,:, -1] audio = databatch["y"] # b, len, c gendurations = databatch["lengths"] # start = databatch["start"] batch = model.generate(pose, audio, gendurations, fact = 1) batch = {key: val.to(device) for key, val in batch.items()} for pose_pre, pose_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']): x_ref = pose_gt[0,:].unsqueeze(dim=0) pose_pre = pose_pre.cpu() padding_vec = torch.zeros(pose_pre.shape[0], 1) pose_pre = torch.concat([pose_pre, padding_vec], dim = -1) pose_pre = pose_pre.cpu()+x_ref gtmasked = pose_gt[mask].cpu() outmasked = pose_pre[mask].cpu() gtmasked = transform(gtmasked, min_val, max_val) outmasked = transform(outmasked, min_val, max_val) pred_dir = os.path.join(save_pred_path, filename) save_as_chunk(pred_dir, outmasked) # np.save(pred_path, pose_pre.cpu()) # np.save(gt_path, pose_gt.cpu()) # np.savetxt(pred_path, outmasked) # np.savetxt(gt_path, gtmasked) loss = F.mse_loss(gtmasked, outmasked, reduction='mean') print(loss) except KeyboardInterrupt: string = "Saving the evaluation before exiting.." print(string) epoch = checkpointname.split("_")[1].split(".")[0] metricname = "evaluation_metrics_{}_all.yaml".format(epoch) evalpath = os.path.join(folder, metricname) print(f"Saving evaluation: {evalpath}") # save_metrics(evalpath, metrics) ================================================ FILE: PBnet/src/evaluate/tvae_eval_norm_eye_pose.py ================================================ import torch from tqdm import tqdm from src.utils.fixseed import fixseed from src.utils.utils import MultiEpochsDataLoader as DataLoader from src.utils.tensors_eye_eval import collate import os import numpy as np import torch.nn.functional as F import time # from .tools import save_metrics, format_metrics from src.models.get_model import get_model as get_gen_model def transform(x, min_val, max_val): out = x * (max_val - min_val) + min_val return out def evaluate(parameters, dataset, folder, checkpointname, epoch, niter): # num_frames = 60 min_val = dataset.min_vals max_val = dataset.max_vals device = parameters["device"] # dummy => update parameters info model = get_gen_model(parameters) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) state_dict = torch.load(checkpointpath, map_location=device) model.load_state_dict(state_dict) model.eval() if checkpointname.split("_")[0] == 'retraincheckpoint': save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0]) else: save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0]) os.makedirs(save_folder, exist_ok=True) allseeds = list(range(niter)) try: for index, seed in enumerate(allseeds): print(f"Evaluation number: {index+1}/{niter}") fixseed(seed) save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed)) save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed)) os.makedirs(save_pred_path, exist_ok=True) os.makedirs(save_gt_path, exist_ok=True) dataiterator = DataLoader(dataset, batch_size=parameters["batch_size"], shuffle=False, num_workers=8, collate_fn=collate) with torch.no_grad(): for databatch in tqdm(dataiterator, desc=f"Construct dataloader: generating.."): # batch = {key: val.to(device) for key, val in databatch.items()} pose_eye = databatch["x"] audio = databatch["y"] gendurations = databatch["lengths"] # start = databatch["start"] start_time = time.time() batch = model.generate(pose_eye, audio, 
gendurations, fact = 1) end_time = time.time() print(f'generate audio time {end_time- start_time}') start_time = end_time # exit() batch = {key: val.to(device) for key, val in batch.items()} for pose_eye_pre, pose_eye_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']): x_ref = pose_eye_gt[0,:].unsqueeze(dim=0) pose_eye_pre = pose_eye_pre.cpu()+x_ref gtmasked = pose_eye_gt[mask].cpu() outmasked = pose_eye_pre[mask].cpu() gtmasked[:,:-2] = transform(gtmasked[:,:-2], min_val, max_val) outmasked[:,:-2] = transform(outmasked[:,:-2], min_val, max_val) pred_path = os.path.join(save_pred_path, filename+'_'+str(start_num)) gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt') # np.save(pred_path, pose_pre.cpu()) # np.save(gt_path, pose_gt.cpu()) np.savetxt(pred_path, outmasked) np.savetxt(gt_path, gtmasked) loss = F.mse_loss(gtmasked[:,:3], outmasked[:,:3], reduction='mean') print(loss) except KeyboardInterrupt: string = "Saving the evaluation before exiting.." print(string) epoch = checkpointname.split("_")[1].split(".")[0] metricname = "evaluation_metrics_{}_all.yaml".format(epoch) evalpath = os.path.join(folder, metricname) print(f"Saving evaluation: {evalpath}") # save_metrics(evalpath, metrics) ================================================ FILE: PBnet/src/evaluate/tvae_eval_norm_eye_pose_seg.py ================================================ import torch from tqdm import tqdm from src.utils.fixseed import fixseed from src.utils.utils import MultiEpochsDataLoader as DataLoader from src.utils.tensors_eye_eval import collate import os import numpy as np import torch.nn.functional as F import time # from .tools import save_metrics, format_metrics from src.models.get_model import get_model as get_gen_model INF_LENGTH = 200 def transform(x, min_val, max_val): out = x * (max_val - min_val) + min_val return out def save_images_as_npy(input_data, output_file): # save_npy = np.zeros(input_data.shape[0], 7) # save_npy[:,:, :-1] = input_data # save_npy[:, -1] = ref[:,:, -1] # images_array = np.array(images) np.save(output_file, input_data) def save_as_chunk(dir, data): if not os.path.exists(dir): os.makedirs(dir) chunks = [data[i:min(i + 25, data.shape[0])] for i in range(0, data.shape[0], 25)] for i, chunk in enumerate(chunks): output_file = os.path.join(dir, f'chunk_%04d.npy' % (i)) # chunk = np.stack(chunk, axis = 0) save_images_as_npy(chunk, output_file) def evaluate(parameters, dataset, folder, checkpointname, epoch, niter): # num_frames = 60 min_val = dataset.min_vals max_val = dataset.max_vals device = parameters["device"] # dummy => update parameters info model = get_gen_model(parameters) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) state_dict = torch.load(checkpointpath, map_location=device) model.load_state_dict(state_dict) model.eval() if checkpointname.split("_")[0] == 'retraincheckpoint': save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0]) else: save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0]) os.makedirs(save_folder, exist_ok=True) allseeds = list(range(niter)) try: for index, seed in enumerate(allseeds): print(f"Evaluation number: {index+1}/{niter}") fixseed(seed) save_pred_path_pose = os.path.join(save_folder, 'eval_pred', str(seed),'pose') save_gt_path_pose = os.path.join(save_folder, 'eval_gt', str(seed),'pose') save_pred_path_eye = 
os.path.join(save_folder, 'eval_pred', str(seed),'eye') save_gt_path_eye = os.path.join(save_folder, 'eval_gt', str(seed),'eye') os.makedirs(save_pred_path_pose, exist_ok=True) os.makedirs(save_gt_path_pose, exist_ok=True) os.makedirs(save_pred_path_eye, exist_ok=True) os.makedirs(save_gt_path_eye, exist_ok=True) dataiterator = DataLoader(dataset, batch_size=parameters["batch_size"], shuffle=False, num_workers=8, collate_fn=collate) with torch.no_grad(): for databatch in tqdm(dataiterator, desc=f"Construct dataloader: generating.."): # batch = {key: val.to(device) for key, val in databatch.items()} pose_eye = databatch["x"] audio = databatch["y"] gendurations = databatch["lengths"] # start = databatch["start"] # start_time = time.time() # batch = model.generate(pose_eye, audio, gendurations, fact = 1) # end_time = time.time() # print(f'generate audio time {end_time- start_time}') # start_time = end_time # # exit() # batch = {key: val.to(device) for key, val in batch.items()} output = None for i in range(0, pose_eye.shape[1], INF_LENGTH): # step 1: seg start = i end = min(pose_eye.shape[1], i + INF_LENGTH) pose_seg = pose_eye[:, start:end] audio_seg = audio[:, start:end] gendurations_seg = torch.tensor([end - start]) # step 2: predict batch = model.generate(pose_seg, audio_seg, gendurations_seg, fact = 1) # step 3: merge if output == None: output = batch['output'].detach().cpu() else: output = torch.concat([output, batch['output'].detach().cpu()], dim= 1) for pose_pre, pose_gt, mask, filename, start_num in zip(output, databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']): pose_pre = pose_pre.cpu() # padding_vec = torch.zeros(pose_pre.shape[0], 1) # pose_pre = torch.concat([pose_pre], dim = -1) for i in range(0, pose_gt.shape[0], INF_LENGTH): start = i end = min(pose_gt.shape[0], i + INF_LENGTH) x_ref = pose_gt[i,:].unsqueeze(dim=0) pose_pre[start:end] = pose_pre[start:end]+x_ref gtmasked = pose_gt[mask].cpu() outmasked = pose_pre[mask].cpu() gtmasked[:,:-2] = transform(gtmasked[:,:-2], min_val, max_val) outmasked[:,:-2] = transform(outmasked[:,:-2], min_val, max_val) pred_dir_pose = os.path.join(save_pred_path_pose, filename) pred_dir_eye = os.path.join(save_pred_path_eye, filename) out_eye = outmasked[:, 6:] out_pose = outmasked[:, :6] save_as_chunk(pred_dir_pose, out_pose) save_as_chunk(pred_dir_eye, out_eye) # save_as_chunk(pred_dir, outmasked) # np.save(pred_path, pose_pre.cpu()) # np.save(gt_path, pose_gt.cpu()) # np.savetxt(pred_path, outmasked) # np.savetxt(gt_path, gtmasked) loss = F.mse_loss(gtmasked, outmasked, reduction='mean') print(loss) except KeyboardInterrupt: string = "Saving the evaluation before exiting.." 
print(string) epoch = checkpointname.split("_")[1].split(".")[0] metricname = "evaluation_metrics_{}_all.yaml".format(epoch) evalpath = os.path.join(folder, metricname) print(f"Saving evaluation: {evalpath}") # save_metrics(evalpath, metrics) ================================================ FILE: PBnet/src/evaluate/tvae_eval_norm_seg.py ================================================ import torch from tqdm import tqdm from src.utils.fixseed import fixseed from src.utils.utils import MultiEpochsDataLoader as DataLoader from src.utils.tensors_hdtf import collate, collate_old import os import numpy as np import torch.nn.functional as F # from .tools import save_metrics, format_metrics from src.models.get_model import get_model as get_gen_model INF_LENGTH = 600 def save_images_as_npy(input_data, output_file): # save_npy = np.zeros(input_data.shape[0], 7) # save_npy[:,:, :-1] = input_data # save_npy[:, -1] = ref[:,:, -1] # images_array = np.array(images) np.save(output_file, input_data) def save_as_chunk(dir, data): if not os.path.exists(dir): os.makedirs(dir) chunks = [data[i:min(i + 25, data.shape[0])] for i in range(0, data.shape[0], 25)] for i, chunk in enumerate(chunks): output_file = os.path.join(dir, f'chunk_%04d.npy' % (i)) # chunk = np.stack(chunk, axis = 0) save_images_as_npy(chunk, output_file) def transform(x, min_val, max_val): out = x * (max_val - min_val) + min_val return out def evaluate(parameters, dataset, folder, checkpointname, epoch, niter): # num_frames = 60 min_val = dataset.min_vals max_val = dataset.max_vals device = parameters["device"] # dummy => update parameters info model = get_gen_model(parameters) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) model_ckpt = model.state_dict() state_dict = torch.load(checkpointpath, map_location=device) for name, _ in model_ckpt.items(): if model_ckpt[name].shape == state_dict[name].shape: model_ckpt[name].copy_(state_dict[name]) model.load_state_dict(model_ckpt) # model.load_state_dict(state_dict) model.eval() if checkpointname.split("_")[0] == 'retraincheckpoint': save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0]) else: save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0]) os.makedirs(save_folder, exist_ok=True) allseeds = list(range(niter)) try: for index, seed in enumerate(allseeds): print(f"Evaluation number: {index+1}/{niter}") fixseed(seed) save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed)) save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed)) os.makedirs(save_pred_path, exist_ok=True) os.makedirs(save_gt_path, exist_ok=True) dataiterator = DataLoader(dataset, batch_size=parameters["batch_size"], shuffle=False, num_workers=8, collate_fn=collate_old) with torch.no_grad(): for databatch in tqdm(dataiterator, desc=f"Construct dataloader: generating.."): # batch = {key: val.to(device) for key, val in databatch.items()} pose = databatch["x"][:,:,:-1] # b, len, c ref = databatch['x'][:,:, -1] audio = databatch["y"] # b, len, c gendurations = databatch["lengths"] # start = databatch["start"] output = None for i in range(0, pose.shape[1], INF_LENGTH): # step 1: seg start = i end = min(pose.shape[1], i + INF_LENGTH) pose_seg = pose[:, start:end] audio_seg = audio[:, start:end] gendurations_seg = torch.tensor([end - start]) # step 2: predict batch = model.generate(pose_seg, audio_seg, gendurations_seg, fact = 1) # step 3: merge if output == None: output = 
batch['output'].detach().cpu() else: output = torch.concat([output, batch['output'].detach().cpu()], dim= 1) # batch = model.generate(pose, audio, gendurations, fact = 1) # batch = {key: val.to(device) for key, val in batch.items()} for pose_pre, pose_gt, mask, filename, start_num in zip(output, databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']): pose_pre = pose_pre.cpu() padding_vec = torch.zeros(pose_pre.shape[0], 1) pose_pre = torch.concat([pose_pre, padding_vec], dim = -1) for i in range(0, pose_gt.shape[0], INF_LENGTH): start = i end = min(pose_gt.shape[0], i + INF_LENGTH) x_ref = pose_gt[i,:].unsqueeze(dim=0) pose_pre[start:end] = pose_pre[start:end]+x_ref gtmasked = pose_gt[mask].cpu() outmasked = pose_pre[mask].cpu() gtmasked = transform(gtmasked, min_val, max_val) outmasked = transform(outmasked, min_val, max_val) pred_dir = os.path.join(save_pred_path, filename) save_as_chunk(pred_dir, outmasked) # np.save(pred_path, pose_pre.cpu()) # np.save(gt_path, pose_gt.cpu()) # np.savetxt(pred_path, outmasked) # np.savetxt(gt_path, gtmasked) loss = F.mse_loss(gtmasked, outmasked, reduction='mean') print(loss) except KeyboardInterrupt: string = "Saving the evaluation before exiting.." print(string) epoch = checkpointname.split("_")[1].split(".")[0] metricname = "evaluation_metrics_{}_all.yaml".format(epoch) evalpath = os.path.join(folder, metricname) print(f"Saving evaluation: {evalpath}") # save_metrics(evalpath, metrics) ================================================ FILE: PBnet/src/evaluate/tvae_eval_onlyeye_all_seg.py ================================================ import torch from tqdm import tqdm from src.utils.fixseed import fixseed from src.utils.utils import MultiEpochsDataLoader as DataLoader from src.utils.tensors_onlyeye import collate_eval import os import numpy as np import torch.nn.functional as F # from .tools import save_metrics, format_metrics from src.models.get_model import get_model as get_gen_model def save_images_as_npy(input_data, output_file): # save_npy = np.zeros(input_data.shape[0], 7) # save_npy[:,:, :-1] = input_data # save_npy[:, -1] = ref[:,:, -1] # images_array = np.array(images) np.save(output_file, input_data) def save_as_chunk(dir, data): if not os.path.exists(dir): os.makedirs(dir) chunks = [data[i:min(i + 25, data.shape[0])] for i in range(0, data.shape[0], 25)] for i, chunk in enumerate(chunks): output_file = os.path.join(dir, f'chunk_%04d.npy' % (i)) # chunk = np.stack(chunk, axis = 0) save_images_as_npy(chunk, output_file) # def transform(x, min_val, max_val): # out = x * (max_val - min_val) + min_val # return out def evaluate(parameters, dataset, folder, checkpointname, epoch, niter): # num_frames = 60 # min_val = dataset.min_vals # max_val = dataset.max_vals device = parameters["device"] # dummy => update parameters info model = get_gen_model(parameters) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) model_ckpt = model.state_dict() state_dict = torch.load(checkpointpath, map_location=device) for name, _ in model_ckpt.items(): if model_ckpt[name].shape == state_dict[name].shape: model_ckpt[name].copy_(state_dict[name]) model.load_state_dict(model_ckpt) # model.load_state_dict(state_dict) model.eval() if checkpointname.split("_")[0] == 'retraincheckpoint': save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0]) else: save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0]) 
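# Note (inferred from the string splits above): a checkpoint named
# "checkpoint_<step>.pth.tar" is evaluated under <folder>/nofinetune/<step>, while a
# "retraincheckpoint_*" file is split on "_" and its results saved under
# <folder>/fintune/<field2>_<field4> (the "fintune" spelling follows the source).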
os.makedirs(save_folder, exist_ok=True) allseeds = list(range(niter)) try: for index, seed in enumerate(allseeds): print(f"Evaluation number: {index+1}/{niter}") fixseed(seed) save_pred_path = os.path.join(save_folder, 'seg_all', 'eval_pred', str(seed)) save_gt_path = os.path.join(save_folder, 'seg_all', 'eval_gt', str(seed)) os.makedirs(save_pred_path, exist_ok=True) os.makedirs(save_gt_path, exist_ok=True) dataiterator = DataLoader(dataset, batch_size=parameters["batch_size"], shuffle=False, num_workers=8, collate_fn=collate_eval) with torch.no_grad(): for databatch in tqdm(dataiterator, desc=f"Construct dataloader: generating.."): # batch = {key: val.to(device) for key, val in databatch.items()} pose = databatch["x"] # b, len, c audio = databatch["y"] # b, len, c gendurations = databatch["lengths"] # start = databatch["start"] output = None for i in range(0, pose.shape[1], 200): # step 1: seg start = i end = min(pose.shape[1], i + 200) pose_seg = pose[:, start:end] audio_seg = audio[:, start:end] # gendurations_seg = gendurations[:, start:end] gendurations_seg = torch.tensor([end - start]) # step 2: predict batch = model.generate(pose_seg, audio_seg, gendurations_seg, fact = 1) # step 3: merge if output == None: output = batch['output'].detach().cpu() else: output = torch.concat([output, batch['output'].detach().cpu()], dim= 1) # batch = model.generate(pose, audio, gendurations, fact = 1) # batch = {key: val.to(device) for key, val in batch.items()} for pose_pre, pose_gt, mask, filename, start_num in zip(output, databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']): pose_pre = pose_pre.cpu() # padding_vec = torch.zeros(pose_pre.shape[0], 1) # pose_pre = torch.concat([pose_pre, padding_vec], dim = -1) for i in range(0, pose_gt.shape[0], 200): start = i end = min(pose_gt.shape[0], i + 200) x_ref = pose_gt[i,:].unsqueeze(dim=0) pose_pre[start:end] = pose_pre[start:end]+x_ref gtmasked = pose_gt[mask].cpu() outmasked = pose_pre[mask].cpu() # gtmasked = transform(gtmasked, min_val, max_val) # outmasked = transform(outmasked, min_val, max_val) pred_dir = os.path.join(save_pred_path, filename) save_as_chunk(pred_dir, outmasked) # np.save(pred_path, pose_pre.cpu()) # np.save(gt_path, pose_gt.cpu()) # np.savetxt(pred_path, outmasked) # np.savetxt(gt_path, gtmasked) loss = F.mse_loss(gtmasked, outmasked, reduction='mean') print(loss) except KeyboardInterrupt: string = "Saving the evaluation before exiting.." 
print(string) epoch = checkpointname.split("_")[1].split(".")[0] metricname = "evaluation_metrics_{}_all.yaml".format(epoch) evalpath = os.path.join(folder, metricname) print(f"Saving evaluation: {evalpath}") # save_metrics(evalpath, metrics) ================================================ FILE: PBnet/src/evaluate/tvae_eval_single.py ================================================ import torch from tqdm import tqdm import sys import os current_dir = os.path.dirname(os.path.abspath(__file__)) parent_dir = os.path.dirname(os.path.dirname(current_dir)) if parent_dir not in sys.path: sys.path.append(parent_dir) print(parent_dir) from src.utils.fixseed import fixseed from src.parser.tools import load_args import os import numpy as np import torch.nn.functional as F # from .tools import save_metrics, format_metrics from src.models.get_model import get_model as get_gen_model import argparse max_vals = torch.tensor([90, 90, 90, 1, 720, 1080]).to(torch.float32).reshape(1, 1, 6) min_vals = torch.tensor([-90, -90, -90, 0, 0, 0]).to(torch.float32).reshape(1, 1, 6) def inv_transform(x, min_val, max_val): out = x * (max_val - min_val) + min_val return out def save_images_as_npy(input_data, output_file): # save_npy = np.zeros(input_data.shape[0], 7) # save_npy[:,:, :-1] = input_data # save_npy[:, -1] = ref[:,:, -1] # images_array = np.array(images) np.save(output_file, input_data) # def transform(x, min_val, max_val): # out = x * (max_val - min_val) + min_val # return out def evaluate(parameters_pose, parameters_blink, audio_path, init_pose_path, init_blink_path, checkpoint_p_path, checkpoint_b_path, output_path): # num_frames = 60 # min_val = dataset.min_vals # max_val = dataset.max_vals device = "cuda:0" pose_dim = parameters_pose['pos_dim'] eye_dim = parameters_blink['eye_dim'] # dummy => update parameters info model_p = get_gen_model(parameters_pose) model_b = get_gen_model(parameters_blink) print("Restore weights..") # checkpointpath = os.path.join(folder, checkpointname) # model_p_ckpt = model_p.state_dict() # model_b_ckpt = model_b.state_dict() state_dict_p = torch.load(checkpoint_p_path, map_location=device) state_dict_b = torch.load(checkpoint_b_path, map_location=device) # for name, _ in model_ckpt.items(): # if model_ckpt[name].shape == state_dict[name].shape: # model_ckpt[name].copy_(state_dict[name]) # model.load_state_dict(model_ckpt) model_p.load_state_dict(state_dict_p) model_b.load_state_dict(state_dict_b) model_p.eval() model_b.eval() os.makedirs(output_path, exist_ok=True) try: init_pose = torch.from_numpy(np.load(init_pose_path))[:,:pose_dim].unsqueeze(0).to(torch.float32) init_blink = torch.from_numpy(np.load(init_blink_path))[:,:eye_dim].unsqueeze(0).to(torch.float32) audio = torch.from_numpy(np.load(audio_path)).unsqueeze(0).to(torch.float32) except Exception: # the 3ddfa fail to extract valid pose, using typical value instead init_pose = torch.from_numpy(np.array([[0, 0, 0, 4.79e-04, 5.65e+01, 6.49e+01,]]))[:,:pose_dim].unsqueeze(0).to(torch.float32) init_blink = torch.from_numpy(np.array([[0.3,0.3]]))[:,:eye_dim].unsqueeze(0).to(torch.float32) audio = torch.from_numpy(np.load(audio_path)).unsqueeze(0).to(torch.float32) init_pose = (init_pose - min_vals)/ (max_vals - min_vals) fixseed(1234) with torch.no_grad(): # batch = {key: val.to(device) for key, val in databatch.items()} # step 1: seg pose_seg = init_pose blink_seg = init_blink audio_seg = audio # gendurations_seg = gendurations[:, start:end] gendurations_seg = torch.tensor([audio.shape[1] - 0]) # step 2: predict batch_p = 
model_p.generate(pose_seg, audio_seg, gendurations_seg, fact = 1) batch_b = model_b.generate(blink_seg, audio_seg, gendurations_seg, fact = 1) # step 3: merge output_p = batch_p['output'].detach().cpu() output_b = batch_b['output'].detach().cpu() output_p = output_p + pose_seg output_p = inv_transform(output_p, min_vals, max_vals) output_b = output_b + blink_seg output_pose_path = os.path.join(output_path, 'dri_pose.npy') output_blink_path = os.path.join(output_path, 'dri_blink.npy') np.save(output_pose_path , output_p[0]) np.save(output_blink_path, output_b[0]) def get_arguments(): """Parse all the arguments provided from the CLI. Returns: A list of parsed arguments. """ parser = argparse.ArgumentParser(description="PBnet") parser.add_argument("--audio_path", default='/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate/RD_Radio54_000.npy') parser.add_argument("--ckpt_pose", default='your_path/pretrain_models/pbnet_seperate/pose/checkpoint_40000.pth.tar', help="ckpt of PoseNet") parser.add_argument("--ckpt_blink", default='your_path/pretrain_models/pbnet_seperate/blink/checkpoint_95000.pth.tar', help="ckpt of BlinkNet") parser.add_argument("--init_pose_blink", default='your/path/DAWN-pytorch/ood_data/ood_test_material/cache_2', help="dir of init pose/blink") parser.add_argument("--output", default='/train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/demo_output', help="output_dir") return parser.parse_args() if __name__ == '__main__': args = get_arguments() audio_path = args.audio_path ckpt_pose = args.ckpt_pose ckpt_blink = args.ckpt_blink output_dir = args.output init_blink = os.path.join(args.init_pose_blink, 'init_eye_bbox.npy') # init_eye_bbox.npy init_pose = os.path.join(args.init_pose_blink, 'init_pose.npy') folder_p, _ = os.path.split(ckpt_pose) parameters_p = load_args(os.path.join(folder_p, "opt.yaml")) parameters_p['device'] = 'cuda:0' parameters_p["audio_dim"] = 1024 parameters_p["pos_dim"] = 6 parameters_p["eye_dim"] = 0 folder_b, _ = os.path.split(ckpt_blink) parameters_b = load_args(os.path.join(folder_b, "opt.yaml")) parameters_b['device'] = 'cuda:0' parameters_b["audio_dim"] = 1024 parameters_b["pos_dim"] = 0 parameters_b["eye_dim"] = 2 evaluate(parameters_pose = parameters_p, parameters_blink = parameters_b, audio_path = audio_path, init_pose_path = init_pose, init_blink_path = init_blink, checkpoint_p_path = ckpt_pose, checkpoint_b_path = ckpt_blink, output_path = output_dir) ================================================ FILE: PBnet/src/evaluate/tvae_eval_single_both_eye_pose.py ================================================ import torch from tqdm import tqdm import os import sys # adding path of PBnet current_dir = os.path.dirname(os.path.abspath(__file__)) parent_dir = os.path.dirname(os.path.dirname(current_dir)) if parent_dir not in sys.path: sys.path.append(parent_dir) print(parent_dir) from src.utils.fixseed import fixseed from src.parser.tools import load_args import numpy as np import torch.nn.functional as F # from .tools import save_metrics, format_metrics from src.models.get_model import get_model as get_gen_model import argparse max_vals = torch.tensor([90, 90, 90, 1, 720, 1080, 1, 1]).to(torch.float32).reshape(1, 1, 8) min_vals = torch.tensor([-90, -90, -90, 0, 0, 0, 0, 0]).to(torch.float32).reshape(1, 1, 8) def inv_transform(x, min_val, max_val): out = x * (max_val - min_val) + min_val return out def save_images_as_npy(input_data, output_file): # save_npy = np.zeros(input_data.shape[0], 7) # save_npy[:,:, :-1] = input_data 
# save_npy[:, -1] = ref[:,:, -1] # images_array = np.array(images) np.save(output_file, input_data) # def transform(x, min_val, max_val): # out = x * (max_val - min_val) + min_val # return out def evaluate(parameters, audio_path, init_pose_path, init_blink_path, checkpoint_path, output_path): # num_frames = 60 # min_val = dataset.min_vals # max_val = dataset.max_vals device = parameters["device"] pose_dim = parameters['pos_dim'] eye_dim = parameters['eye_dim'] # dummy => update parameters info model = get_gen_model(parameters) print("Restore weights..") # checkpointpath = os.path.join(folder, checkpointname) # model_p_ckpt = model_p.state_dict() # model_b_ckpt = model_b.state_dict() state_dict_p = torch.load(checkpoint_path, map_location=device) # for name, _ in model_ckpt.items(): # if model_ckpt[name].shape == state_dict[name].shape: # model_ckpt[name].copy_(state_dict[name]) # model.load_state_dict(model_ckpt) model.load_state_dict(state_dict_p) model.eval() os.makedirs(output_path, exist_ok=True) try: init_pose = torch.from_numpy(np.load(init_pose_path))[:,:pose_dim].unsqueeze(0).to(torch.float32) init_blink = torch.from_numpy(np.load(init_blink_path))[:,:eye_dim].unsqueeze(0).to(torch.float32) audio = torch.from_numpy(np.load(audio_path)).unsqueeze(0).to(torch.float32) except Exception: # the 3ddfa fail to extract valid pose, using typical value instead init_pose = torch.from_numpy(np.array([[0, 0, 0, 4.79e-04, 5.65e+01, 6.49e+01,]]))[:,:pose_dim].unsqueeze(0).to(torch.float32) init_blink = torch.from_numpy(np.array([[0.3,0.3]]))[:,:eye_dim].unsqueeze(0).to(torch.float32) audio = torch.from_numpy(np.load(audio_path)).unsqueeze(0).to(torch.float32) pose_seg = init_pose blink_seg = init_blink init_pose = torch.concat([pose_seg, blink_seg], dim = -1) init_pose = (init_pose - min_vals)/ (max_vals - min_vals) fixseed(1234) with torch.no_grad(): # batch = {key: val.to(device) for key, val in databatch.items()} # step 1: seg audio_seg = audio # gendurations_seg = gendurations[:, start:end] gendurations_seg = torch.tensor([audio.shape[1] - 0]) # step 2: predict batch = model.generate(init_pose, audio_seg, gendurations_seg, fact = 1) # step 3: merge output = batch['output'].detach().cpu() output = output + init_pose output = inv_transform(output, min_vals, max_vals) output_pose_path = os.path.join(output_path, 'dri_pose.npy') output_blink_path = os.path.join(output_path, 'dri_blink.npy') np.save(output_pose_path , output[0,:,:pose_dim]) np.save(output_blink_path, output[0,:,pose_dim:]) def get_arguments(): """Parse all the arguments provided from the CLI. Returns: A list of parsed arguments. 
""" parser = argparse.ArgumentParser(description="PBnet") parser.add_argument("--audio_path", default='/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate/RD_Radio54_000.npy') parser.add_argument("--ckpt", default='../pretrain_models/pbnet_both/checkpoint_100000.pth.tar', help="ckpt of PoseNet") parser.add_argument("--init_pose_blink", default='your/path/DAWN-pytorch/ood_data/ood_test_material/cache_2', help="dir of init pose/blink") parser.add_argument("--output", default='/train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/demo_output', help="output_dir") return parser.parse_args() if __name__ == '__main__': args = get_arguments() audio_path = args.audio_path ckpt_pose = args.ckpt output_dir = args.output init_blink = os.path.join(args.init_pose_blink, 'init_eye_bbox.npy') # init_eye_bbox.npy init_pose = os.path.join(args.init_pose_blink, 'init_pose.npy') folder_p, _ = os.path.split(ckpt_pose) parameters = load_args(os.path.join(folder_p, "opt.yaml")) parameters['device'] = 'cuda:0' parameters["audio_dim"] = 1024 parameters["pos_dim"] = 6 parameters["eye_dim"] = 2 evaluate(parameters = parameters, audio_path = audio_path, init_pose_path = init_pose, init_blink_path = init_blink, checkpoint_path = ckpt_pose, output_path = output_dir) ================================================ FILE: PBnet/src/evaluate/tvae_eval_std.py ================================================ import torch from tqdm import tqdm from src.utils.fixseed import fixseed from src.utils.utils import MultiEpochsDataLoader as DataLoader from src.utils.tensors_hdtf import collate import os import numpy as np import torch.nn.functional as F # from .tools import save_metrics, format_metrics from src.models.get_model import get_model as get_gen_model def evaluate(parameters, dataset, folder, checkpointname, epoch, niter): # num_frames = 60 device = parameters["device"] # dummy => update parameters info model = get_gen_model(parameters) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) model_ckpt = model.state_dict() state_dict = torch.load(checkpointpath, map_location=device) for name, _ in model_ckpt.items(): if model_ckpt[name].shape == state_dict[name].shape: model_ckpt[name].copy_(state_dict[name]) model.load_state_dict(model_ckpt) model.eval() if checkpointname.split("_")[0] == 'retraincheckpoint': save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0]) else: save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0]) os.makedirs(save_folder, exist_ok=True) allseeds = list(range(niter)) try: for index, seed in enumerate(allseeds): print(f"Evaluation number: {index+1}/{niter}") fixseed(seed) save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed)) save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed)) os.makedirs(save_pred_path, exist_ok=True) os.makedirs(save_gt_path, exist_ok=True) dataiterator = DataLoader(dataset, batch_size=parameters["batch_size"], shuffle=False, num_workers=8, collate_fn=collate) with torch.no_grad(): for databatch in tqdm(dataiterator, desc=f"Construct dataloader: generating.."): # batch = {key: val.to(device) for key, val in databatch.items()} pose = databatch["x"] audio = databatch["y"] gendurations = databatch["lengths"] # start = databatch["start"] batch = model.generate(pose, audio, gendurations, fact = 1) batch = {key: val.to(device) for key, val in batch.items()} for pose_pre, pose_gt, mask, filename, start_num 
in zip(batch['output'], databatch['x'], databatch['mask'], databatch['videoname'], databatch['start']): x_ref = pose_gt[0,:].unsqueeze(dim=0).cpu() pose_pre = (pose_pre.cpu()+x_ref - 0.5) * 180 gtmasked = (pose_gt[mask].cpu() -0.5 ) * 180 outmasked = pose_pre[mask].cpu() pred_path = os.path.join(save_pred_path, filename+'_'+str(start_num)) gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt') # np.save(pred_path, pose_pre.cpu()) # np.save(gt_path, pose_gt.cpu()) np.savetxt(pred_path, outmasked) np.savetxt(gt_path, gtmasked) loss = F.mse_loss(gtmasked, outmasked, reduction='mean') print(loss) except KeyboardInterrupt: string = "Saving the evaluation before exiting.." print(string) epoch = checkpointname.split("_")[1].split(".")[0] metricname = "evaluation_metrics_{}_all.yaml".format(epoch) evalpath = os.path.join(folder, metricname) print(f"Saving evaluation: {evalpath}") # save_metrics(evalpath, metrics) ================================================ FILE: PBnet/src/evaluate/tvae_eval_train.py ================================================ import torch from tqdm import tqdm from src.utils.fixseed import fixseed from src.utils.utils import MultiEpochsDataLoader as DataLoader from src.utils.tensors_hdtf import collate import os import numpy as np import torch.nn.functional as F # from .tools import save_metrics, format_metrics from src.models.get_model import get_model as get_gen_model def evaluate(parameters, dataset, folder, checkpointname, epoch, niter): # num_frames = 60 device = parameters["device"] # dummy => update parameters info model = get_gen_model(parameters) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) state_dict = torch.load(checkpointpath, map_location=device) model.load_state_dict(state_dict) model.eval() if checkpointname.split("_")[0] == 'retraincheckpoint': save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0]) else: save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0]) os.makedirs(save_folder, exist_ok=True) allseeds = list(range(niter)) try: for index, seed in enumerate(allseeds): print(f"Evaluation number: {index+1}/{niter}") fixseed(seed) save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed)) save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed)) os.makedirs(save_pred_path, exist_ok=True) os.makedirs(save_gt_path, exist_ok=True) dataiterator = DataLoader(dataset, batch_size=parameters["batch_size"], shuffle=False, num_workers=8, collate_fn=collate) with torch.no_grad(): for databatch in tqdm(dataiterator, desc=f"Construct dataloader: generating.."): name_list = databatch['videoname'] start_list = databatch['start'] databatch = {key: val.to(device) for key, val in databatch.items() if key!='videoname' and key!='start'} pose = databatch["x"] audio = databatch["y"] gendurations = databatch["lengths"] # start = databatch["start"] batch = model.forward(databatch) batch = {key: val.to(device) for key, val in batch.items()} for pose_pre, pose_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], name_list, start_list): x_ref = pose_gt[0,:].unsqueeze(dim=0) pose_pre = pose_pre.cpu()+x_ref.cpu() gtmasked = pose_gt[mask].cpu() outmasked = pose_pre[mask].cpu() pred_path = os.path.join(save_pred_path, filename+'_'+str(start_num)) gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt') # np.save(pred_path, pose_pre.cpu()) # np.save(gt_path, 
pose_gt.cpu()) np.savetxt(pred_path, outmasked) np.savetxt(gt_path, gtmasked) loss = F.mse_loss(gtmasked, outmasked, reduction='mean') print(loss) except KeyboardInterrupt: string = "Saving the evaluation before exiting.." print(string) epoch = checkpointname.split("_")[1].split(".")[0] metricname = "evaluation_metrics_{}_all.yaml".format(epoch) evalpath = os.path.join(folder, metricname) print(f"Saving evaluation: {evalpath}") # save_metrics(evalpath, metrics) ================================================ FILE: PBnet/src/evaluate/tvae_eval_train_norm.py ================================================ import torch from tqdm import tqdm from src.utils.fixseed import fixseed from src.utils.utils import MultiEpochsDataLoader as DataLoader from src.utils.tensors_hdtf import collate import os import numpy as np import torch.nn.functional as F # from .tools import save_metrics, format_metrics from src.models.get_model import get_model as get_gen_model def transform(x, min_val, max_val): out = x * (max_val - min_val) + min_val return out def evaluate(parameters, dataset, folder, checkpointname, epoch, niter): # num_frames = 60 min_val = dataset.min_vals max_val = dataset.max_vals device = parameters["device"] # dummy => update parameters info model = get_gen_model(parameters) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) state_dict = torch.load(checkpointpath, map_location=device) model.load_state_dict(state_dict) model.eval() if checkpointname.split("_")[0] == 'retraincheckpoint': save_folder = os.path.join(folder, 'fintune', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0]) else: save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0]) os.makedirs(save_folder, exist_ok=True) allseeds = list(range(niter)) try: for index, seed in enumerate(allseeds): print(f"Evaluation number: {index+1}/{niter}") fixseed(seed) save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed)) save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed)) os.makedirs(save_pred_path, exist_ok=True) os.makedirs(save_gt_path, exist_ok=True) dataiterator = DataLoader(dataset, batch_size=parameters["batch_size"], shuffle=False, num_workers=8, collate_fn=collate) with torch.no_grad(): for databatch in tqdm(dataiterator, desc=f"Construct dataloader: generating.."): # batch = {key: val.to(device) for key, val in databatch.items()} name_list = databatch['videoname'] start_list = databatch['start'] databatch = {key: val.to(device) for key, val in databatch.items() if key!='videoname' and key!='start'} pose = databatch["x"] audio = databatch["y"] gendurations = databatch["lengths"] # start = databatch["start"] batch = model.forward(databatch) batch = {key: val.to(device) for key, val in batch.items()} for pose_pre, pose_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], name_list, start_list): x_ref = pose_gt[0,:].unsqueeze(dim=0).cpu() pose_pre = pose_pre.cpu()+x_ref gtmasked = pose_gt[mask].cpu() outmasked = pose_pre[mask].cpu() gtmasked = transform(gtmasked, min_val, max_val) outmasked = transform(outmasked, min_val, max_val) pred_path = os.path.join(save_pred_path, filename+'_'+str(start_num)) gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt') # np.save(pred_path, pose_pre.cpu()) # np.save(gt_path, pose_gt.cpu()) np.savetxt(pred_path, outmasked) np.savetxt(gt_path, gtmasked) loss = F.mse_loss(gtmasked, outmasked, reduction='mean') print('all loss: ',loss) loss_f3 
= F.mse_loss(gtmasked[:, :3], outmasked[:, :3], reduction='mean') print('f3 loss: ',loss_f3) loss_ls = F.mse_loss(gtmasked[:, 3:], outmasked[:, 3:], reduction='mean') print('ls loss: ',loss_ls) except KeyboardInterrupt: string = "Saving the evaluation before exiting.." print(string) epoch = checkpointname.split("_")[1].split(".")[0] metricname = "evaluation_metrics_{}_all.yaml".format(epoch) evalpath = os.path.join(folder, metricname) print(f"Saving evaluation: {evalpath}") # save_metrics(evalpath, metrics) ================================================ FILE: PBnet/src/evaluate/tvae_eval_train_std.py ================================================ import torch from tqdm import tqdm from src.utils.fixseed import fixseed from src.utils.utils import MultiEpochsDataLoader as DataLoader from src.utils.tensors_hdtf import collate import os import numpy as np import torch.nn.functional as F # from .tools import save_metrics, format_metrics from src.models.get_model import get_model as get_gen_model def evaluate(parameters, dataset, folder, checkpointname, epoch, niter): # num_frames = 60 device = parameters["device"] # dummy => update parameters info model = get_gen_model(parameters) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) state_dict = torch.load(checkpointpath, map_location=device) model.load_state_dict(state_dict) model.eval() if checkpointname.split("_")[0] == 'retraincheckpoint': save_folder = os.path.join(folder, 'fintune_train', checkpointname.split('_')[2]+'_'+checkpointname.split('_')[4].split('.')[0]) else: save_folder = os.path.join(folder, 'nofinetune', checkpointname.split('_')[1].split('.')[0]) os.makedirs(save_folder, exist_ok=True) allseeds = list(range(niter)) try: for index, seed in enumerate(allseeds): print(f"Evaluation number: {index+1}/{niter}") fixseed(seed) save_pred_path = os.path.join(save_folder, 'eval_pred', str(seed)) save_gt_path = os.path.join(save_folder, 'eval_gt', str(seed)) os.makedirs(save_pred_path, exist_ok=True) os.makedirs(save_gt_path, exist_ok=True) dataiterator = DataLoader(dataset, batch_size=parameters["batch_size"], shuffle=False, num_workers=8, collate_fn=collate) with torch.no_grad(): for databatch in tqdm(dataiterator, desc=f"Construct dataloader: generating.."): name_list = databatch['videoname'] start_list = databatch['start'] databatch = {key: val.to(device) for key, val in databatch.items() if key!='videoname' and key!='start'} pose = databatch["x"] audio = databatch["y"] gendurations = databatch["lengths"] # start = databatch["start"] batch = model.forward(databatch) batch = {key: val.to(device) for key, val in batch.items()} for pose_pre, pose_gt, mask, filename, start_num in zip(batch['output'], databatch['x'], databatch['mask'], name_list, start_list): x_ref = pose_gt[0,:].unsqueeze(dim=0).cpu() pose_pre = (pose_pre.cpu()+x_ref - 0.5) * 180 gtmasked = (pose_gt[mask].cpu() -0.5 ) * 180 outmasked = pose_pre[mask].cpu() pred_path = os.path.join(save_pred_path, filename+'_'+str(start_num)) gt_path = os.path.join(save_gt_path, filename+'_'+str(start_num)+'_gt') # np.save(pred_path, pose_pre.cpu()) # np.save(gt_path, pose_gt.cpu()) np.savetxt(pred_path, outmasked) np.savetxt(gt_path, gtmasked) loss = F.mse_loss(gtmasked, outmasked, reduction='mean') print(loss) except KeyboardInterrupt: string = "Saving the evaluation before exiting.." 
print(string) epoch = checkpointname.split("_")[1].split(".")[0] metricname = "evaluation_metrics_{}_all.yaml".format(epoch) evalpath = os.path.join(folder, metricname) print(f"Saving evaluation: {evalpath}") # save_metrics(evalpath, metrics) ================================================ FILE: PBnet/src/generate/generate_sequences.py ================================================ import os import matplotlib.pyplot as plt import torch import numpy as np from src.utils.get_model_and_data import get_model_and_data from src.models.get_model import get_model from src.parser.generate import parser import src.utils.fixseed # noqa plt.switch_backend('agg') def generate_actions(beta, model, dataset, epoch, params, folder, num_frames=60, durationexp=False, vertstrans=True, onlygen=False, nspa=10, inter=False, writer=None): """ Generate & viz samples """ # visualize with joints3D model.outputxyz = True # print("remove smpl") model.param2xyz["jointstype"] = "vertices" print(f"Visualization of the epoch {epoch}") fact = params["fact_latent"] num_classes = dataset.num_classes classes = torch.arange(num_classes) if not onlygen: nspa = 1 nats = num_classes if durationexp: nspa = 4 durations = [40, 60, 80, 100] gendurations = torch.tensor([[dur for cl in classes] for dur in durations], dtype=int) else: gendurations = torch.tensor([num_frames for cl in classes], dtype=int) if not onlygen: # extract the real samples real_samples, mask_real, real_lengths = dataset.get_label_sample_batch(classes.numpy()) # to visualize directly # Visualizaion of real samples visualization = {"x": real_samples.to(model.device), "y": classes.to(model.device), "mask": mask_real.to(model.device), "lengths": real_lengths.to(model.device), "output": real_samples.to(model.device)} reconstruction = {"x": real_samples.to(model.device), "y": classes.to(model.device), "lengths": real_lengths.to(model.device), "mask": mask_real.to(model.device)} print("Computing the samples poses..") # generate the repr (joints3D/pose etc) model.eval() with torch.no_grad(): if not onlygen: # Get xyz for the real ones visualization["output_xyz"] = model.rot2xyz(visualization["output"], visualization["mask"], vertstrans=vertstrans, beta=beta) # Reconstruction of the real data reconstruction = model(reconstruction) # update reconstruction dicts noise_same_action = "random" noise_diff_action = "random" # Generate the new data generation = model.generate(classes, gendurations, nspa=nspa, noise_same_action=noise_same_action, noise_diff_action=noise_diff_action, fact=fact) generation["output_xyz"] = model.rot2xyz(generation["output"], generation["mask"], vertstrans=vertstrans, beta=beta) outxyz = model.rot2xyz(reconstruction["output"], reconstruction["mask"], vertstrans=vertstrans, beta=beta) reconstruction["output_xyz"] = outxyz else: if inter: noise_same_action = "interpolate" else: noise_same_action = "random" noise_diff_action = "random" # Generate the new data generation = model.generate(classes, gendurations, nspa=nspa, noise_same_action=noise_same_action, noise_diff_action=noise_diff_action, fact=fact) generation["output_xyz"] = model.rot2xyz(generation["output"], generation["mask"], vertstrans=vertstrans, beta=beta) output = generation["output_xyz"].reshape(nspa, nats, *generation["output_xyz"].shape[1:]).cpu().numpy() if not onlygen: output = np.stack([visualization["output_xyz"].cpu().numpy(), generation["output_xyz"].cpu().numpy(), reconstruction["output_xyz"].cpu().numpy()]) return output def main(): parameters, folder, checkpointname, epoch = 
parser() nspa = parameters["num_samples_per_action"] # no dataset needed if parameters["mode"] in []: # ["gen", "duration", "interpolate"]: model = get_model(parameters) else: model, datasets = get_model_and_data(parameters) dataset = datasets["train"] # same for ntu print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) state_dict = torch.load(checkpointpath, map_location=parameters["device"]) model.load_state_dict(state_dict) from src.utils.fixseed import fixseed # noqa for seed in [1]: # [0, 1, 2]: fixseed(seed) # visualize_params onlygen = True vertstrans = False inter = True and onlygen varying_beta = False if varying_beta: betas = [-2, -1, 0, 1, 2] else: betas = [0] for beta in betas: output = generate_actions(beta, model, dataset, epoch, parameters, folder, inter=inter, vertstrans=vertstrans, nspa=nspa, onlygen=onlygen) if varying_beta: filename = "generation_beta_{}.npy".format(beta) else: filename = "generation.npy" filename = os.path.join(folder, filename) np.save(filename, output) print("Saved at: " + filename) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/models/__init__.py ================================================ ================================================ FILE: PBnet/src/models/architectures/__init__.py ================================================ ================================================ FILE: PBnet/src/models/architectures/autotrans.py ================================================ from .transformer import Encoder_TRANSFORMER as Encoder_AUTOTRANS # noqa import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from .tools.transformer_layers import PositionalEncoding from .tools.transformer_layers import TransformerDecoderLayer # taken from joeynmt repo def subsequent_mask(size: int): """ Mask out subsequent positions (to prevent attending to future positions) Transformer helper function. 
:param size: size of mask (2nd and 3rd dim) :return: Tensor with 0s and 1s of shape (1, size, size) """ mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8') return torch.from_numpy(mask) == 0 def augment_x(x, y, mask, lengths, num_classes, concatenate_time): bs, nframes, njoints, nfeats = x.size() x = x.reshape(bs, nframes, njoints*nfeats) if len(y.shape) == 1: # can give on hot encoded as input y = F.one_hot(y, num_classes) y = y.to(dtype=x.dtype) y = y[:, None, :].repeat((1, nframes, 1)) if concatenate_time: # Time embedding time = mask * 1/(lengths[..., None]-1) time = (time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :])[:, 0] time = time[..., None] x_augmented = torch.cat((x, y, time), 2) else: x_augmented = torch.cat((x, y), 2) return x_augmented def augment_z(z, y, mask, lengths, num_classes, concatenate_time): if len(y.shape) == 1: # can give on hot encoded as input y = F.one_hot(y, num_classes) y = y.to(dtype=z.dtype) # concatenete z and y and repeat the input z_augmented = torch.cat((z, y), 1)[:, None].repeat((1, mask.shape[1], 1)) # Time embedding if concatenate_time: time = mask * 1/(lengths[..., None]-1) time = (time[:, None] * torch.arange(time.shape[1], device=z.device)[None, :])[:, 0] z_augmented = torch.cat((z_augmented, time[..., None]), 2) return z_augmented class Decoder_AUTOTRANS(nn.Module): def __init__(self, modeltype, njoints, nfeats, num_frames, num_classes, translation, pose_rep, glob, glob_rot, concatenate_time=True, positional_encoding=True, latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, emb_dropout=0.1, teacher_forcing=True, **kargs): super().__init__() self.modeltype = modeltype self.njoints = njoints self.nfeats = nfeats self.num_frames = num_frames self.num_classes = num_classes self.pose_rep = pose_rep self.glob = glob self.glob_rot = glob_rot self.translation = translation self.concatenate_time = concatenate_time self.positional_encoding = positional_encoding self.latent_dim = latent_dim self.ff_size = ff_size self.num_layers = num_layers self.num_heads = num_heads self.dropout = dropout self.emb_dropout = emb_dropout self.teacher_forcing = teacher_forcing self.input_feats = self.latent_dim + self.num_classes self.input_feats_x = self.njoints*self.nfeats + self.num_classes if self.concatenate_time: self.input_feats += 1 self.input_feats_x += 1 self.embedding = nn.Linear(self.input_feats, self.latent_dim) self.embedding_x = nn.Linear(self.input_feats_x, self.latent_dim) self.output_feats = self.njoints*self.nfeats # create num_layers decoder layers and put them in a list self.layers = nn.ModuleList([TransformerDecoderLayer(size=self.latent_dim, ff_size=self.ff_size, num_heads=self.num_heads, dropout=self.dropout) for _ in range(self.num_layers)]) self.pe = PositionalEncoding(self.latent_dim) self.layer_norm = nn.LayerNorm(self.latent_dim, eps=1e-6) self.emb_dropout = nn.Dropout(p=self.emb_dropout) self.output_layer = nn.Linear(self.latent_dim, self.output_feats, bias=False) def forward(self, batch): z, y, mask = batch["z"], batch["y"], batch["mask"] lengths = mask.sum(1) lenseqmax = mask.shape[1] bs, njoints, nfeats = len(z), self.njoints, self.nfeats z_augmented = augment_z(z, y, mask, lengths, self.num_classes, self.concatenate_time) src = self.embedding(z_augmented) src_mask = mask.unsqueeze(1) # Check if using teacher forcing or not # if it is allowed and possible teacher_forcing = self.teacher_forcing and "x" in batch # in eval mode, by default it it not unless it is "forced" teacher_forcing = 
teacher_forcing and (self.training or batch.get("teacher_force", False)) if teacher_forcing: x = batch["x"].permute((0, 3, 1, 2)) # shift the input x = torch.cat((x.new_zeros((x.shape[0], 1, *x.shape[2:])), x[:, :-1]), axis=1) # Embedding of the input x_augmented = augment_x(x, y, mask, lengths, self.num_classes, self.concatenate_time) trg = self.embedding_x(x_augmented) trg_mask = (mask[:, None] * subsequent_mask(lenseqmax).type_as(mask)) # shape: torch.Size([48, 183, 183]) if self.positional_encoding: trg = self.pe(trg) trg = self.emb_dropout(trg) val = trg for layer in self.layers: val = layer(val, src, src_mask=src_mask, trg_mask=trg_mask) val = self.layer_norm(val) val = self.output_layer(val) # pad the output val[~mask] = 0 val = val.reshape((bs, lenseqmax, njoints, nfeats)) batch["output"] = val.permute(0, 2, 3, 1) else: # Create the first input x/src_mask x = torch.Tensor.new_zeros(z, (bs, 1, njoints, nfeats)) for index in range(lenseqmax): # change it to speed up current_mask = mask[:, :index+1] x_augmented = augment_x(x, y, current_mask, lengths, self.num_classes, self.concatenate_time) trg = self.embedding_x(x_augmented) trg_mask = (current_mask[:, None] * subsequent_mask(index+1).type_as(mask)) if self.positional_encoding: trg = self.pe(trg) trg = self.emb_dropout(trg) val = trg for layer in self.layers: val = layer(val, src, src_mask=src_mask, trg_mask=trg_mask) val = self.layer_norm(val) val = self.output_layer(val) # pad the output val[~current_mask] = 0 val = val.reshape((bs, index+1, njoints, nfeats)) # extract the last output last_out = val[:, -1] # concatenate it to input x x = torch.cat((x, last_out[:, None]), 1) # remove the dummy first input (BOS) batch["output"] = x[:, 1:].permute(0, 2, 3, 1) return batch ================================================ FILE: PBnet/src/models/architectures/fc.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F class Encoder_FC(nn.Module): def __init__(self, modeltype, njoints, nfeats, num_frames, num_classes, translation, pose_rep, glob, glob_rot, latent_dim=256, **kargs): super().__init__() self.modeltype = modeltype self.njoints = njoints self.nfeats = nfeats self.num_frames = num_frames self.num_classes = num_classes self.translation = translation self.pose_rep = pose_rep self.glob = glob self.glob_rot = glob_rot self.latent_dim = latent_dim self.activation = nn.GELU() self.input_dim = self.njoints*self.nfeats*self.num_frames+self.num_classes self.fully_connected = nn.Sequential(nn.Linear(self.input_dim, 512), nn.GELU(), nn.Linear(512, 256), nn.GELU()) if self.modeltype == "cvae": self.mu = nn.Linear(256, self.latent_dim) self.var = nn.Linear(256, self.latent_dim) else: self.final = nn.Linear(256, self.latent_dim) def forward(self, batch): x, y = batch["x"], batch["y"] bs, njoints, feats, nframes = x.size() if (njoints * feats * nframes) != self.njoints*self.nfeats*self.num_frames: raise ValueError("This model is not adapted with this input") if len(y.shape) == 1: # can give on hot encoded as input y = F.one_hot(y, self.num_classes) y = y.to(dtype=x.dtype) x = x.reshape(bs, njoints*feats*nframes) x = torch.cat((x, y), 1) x = self.fully_connected(x) if self.modeltype == "cvae": return {"mu": self.mu(x), "logvar": self.var(x)} else: return {"z": self.final(x)} class Decoder_FC(nn.Module): def __init__(self, modeltype, njoints, nfeats, num_frames, num_classes, translation, pose_rep, glob, glob_rot, latent_dim=256, **kargs): super().__init__() self.modeltype = 
modeltype self.njoints = njoints self.nfeats = nfeats self.num_frames = num_frames self.num_classes = num_classes self.translation = translation self.pose_rep = pose_rep self.glob = glob self.glob_rot = glob_rot self.latent_dim = latent_dim self.input_dim = self.latent_dim + self.num_classes self.output_dim = self.njoints*self.nfeats*self.num_frames self.fully_connected = nn.Sequential(nn.Linear(self.input_dim, 256), nn.GELU(), nn.Linear(256, 512), nn.GELU(), nn.Linear(512, self.output_dim), nn.GELU()) def forward(self, batch): z, y = batch["z"], batch["y"] # z: [batch_size, latent_dim] # y: [batch_size] if len(y.shape) == 1: # can give on hot encoded as input y = F.one_hot(y, self.num_classes) y = y.to(dtype=z.dtype) # y: [batch_size, num_classes] # z: [batch_size, latent_dim+num_classes] z = torch.cat((z, y), dim=1) z = self.fully_connected(z) bs, _ = z.size() z = z.reshape(bs, self.njoints, self.nfeats, self.num_frames) batch["output"] = z return batch ================================================ FILE: PBnet/src/models/architectures/gru.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F def augment_x(x, y, mask, lengths, num_classes, concatenate_time): bs, nframes, njoints, nfeats = x.size() x = x.reshape(bs, nframes, njoints*nfeats) if len(y.shape) == 1: # can give on hot encoded as input y = F.one_hot(y, num_classes) y = y.to(dtype=x.dtype) y = y[:, None, :].repeat((1, nframes, 1)) if concatenate_time: # Time embedding time = mask * 1/(lengths[..., None]-1) time = (time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :])[:, 0] time = time[..., None] x_augmented = torch.cat((x, y, time), 2) else: x_augmented = torch.cat((x, y), 2) return x_augmented def augment_z(z, y, mask, lengths, num_classes, concatenate_time): if len(y.shape) == 1: # can give on hot encoded as input y = F.one_hot(y, num_classes) y = y.to(dtype=z.dtype) # concatenete z and y and repeat the input z_augmented = torch.cat((z, y), 1)[:, None].repeat((1, mask.shape[1], 1)) # Time embedding if concatenate_time: time = mask * 1/(lengths[..., None]-1) time = (time[:, None] * torch.arange(time.shape[1], device=z.device)[None, :])[:, 0] z_augmented = torch.cat((z_augmented, time[..., None]), 2) return z_augmented class Encoder_GRU(nn.Module): def __init__(self, modeltype, njoints, nfeats, num_frames, num_classes, translation, pose_rep, glob, glob_rot, concatenate_time=True, latent_dim=256, num_layers=4, **kargs): super().__init__() self.modeltype = modeltype self.njoints = njoints self.nfeats = nfeats self.num_frames = num_frames self.num_classes = num_classes self.pose_rep = pose_rep self.glob = glob self.glob_rot = glob_rot self.translation = translation self.concatenate_time = concatenate_time self.latent_dim = latent_dim self.num_layers = num_layers # Layers self.input_feats = self.njoints*self.nfeats + self.num_classes if self.concatenate_time: self.input_feats += 1 self.feats_embedding = nn.Linear(self.input_feats, self.latent_dim) self.gru = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True) if self.modeltype == "cvae": self.mu = nn.Linear(self.latent_dim, self.latent_dim) self.var = nn.Linear(self.latent_dim, self.latent_dim) else: self.final = nn.Linear(self.latent_dim, self.latent_dim) def forward(self, batch): x, y, mask, lengths = batch["x"], batch["y"], batch["mask"], batch["lengths"] bs = len(y) x = x.permute((0, 3, 1, 2)) x = augment_x(x, y, mask, lengths, self.num_classes, 
self.concatenate_time) # Model x = self.feats_embedding(x) x = self.gru(x)[0] # Get last valid input x = x[tuple(torch.stack((torch.arange(bs, device=x.device), lengths-1)))] if self.modeltype == "cvae": return {"mu": self.mu(x), "logvar": self.var(x)} else: return {"z": self.final(x)} class Decoder_GRU(nn.Module): def __init__(self, modeltype, njoints, nfeats, num_frames, num_classes, translation, pose_rep, glob, glob_rot, concatenate_time=True, latent_dim=256, num_layers=4, **kargs): super().__init__() self.modeltype = modeltype self.njoints = njoints self.nfeats = nfeats self.num_frames = num_frames self.num_classes = num_classes self.pose_rep = pose_rep self.glob = glob self.glob_rot = glob_rot self.translation = translation self.concatenate_time = concatenate_time self.latent_dim = latent_dim self.num_layers = num_layers # Layers self.input_feats = self.latent_dim + self.num_classes if self.concatenate_time: self.input_feats += 1 self.feats_embedding = nn.Linear(self.input_feats, self.latent_dim) self.gru = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True) self.output_feats = self.njoints*self.nfeats self.final_layer = nn.Linear(self.latent_dim, self.output_feats) def forward(self, batch): z, y, mask, lengths = batch["z"], batch["y"], batch["mask"], batch["lengths"] bs, nframes = mask.shape z = augment_z(z, y, mask, lengths, self.num_classes, self.concatenate_time) # Model z = self.feats_embedding(z) z = self.gru(z)[0] z = self.final_layer(z) # Post process z = z.reshape(bs, nframes, self.njoints, self.nfeats) # 0 for padded sequences z[~mask] = 0 z = z.permute(0, 2, 3, 1) batch["output"] = z return batch ================================================ FILE: PBnet/src/models/architectures/grutrans.py ================================================ from .gru import Encoder_GRU as Encoder_GRUTRANS # noqa from .transformer import Decoder_TRANSFORMER as Decoder_GRUTRANS # noqa ================================================ FILE: PBnet/src/models/architectures/mlp.py ================================================ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class Upsample(nn.Module): def __init__(self, input_dim, output_dim, kernel, stride): super(Upsample, self).__init__() self.upsample = nn.ConvTranspose2d( input_dim, output_dim, kernel_size=kernel, stride=stride ) def forward(self, x): return self.upsample(x) class ResidualConv(nn.Module): def __init__(self, input_dim, output_dim, stride, padding): super(ResidualConv, self).__init__() self.conv_block = nn.Sequential( nn.BatchNorm2d(input_dim), nn.ReLU(), nn.Conv2d( input_dim, output_dim, kernel_size=3, stride=stride, padding=padding ), nn.BatchNorm2d(output_dim), nn.ReLU(), nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), ) self.conv_skip = nn.Sequential( nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), nn.BatchNorm2d(output_dim), ) def forward(self, x): return self.conv_block(x) + self.conv_skip(x) class PositionalEncoding(nn.Module): def __init__(self, d_model, dropout=0.1, max_len=5000): super(PositionalEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) self.register_buffer('pe', pe) def 
forward(self, x): # not used in the final model x = x + self.pe[:x.shape[0], :] return self.dropout(x) class RelativePositionBias(nn.Module): def __init__( self, heads=8, num_buckets=32, max_distance=128 ): super().__init__() self.num_buckets = num_buckets self.max_distance = max_distance self.relative_attention_bias = nn.Embedding(num_buckets, heads) @staticmethod def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): ret = 0 n = -relative_position num_buckets //= 2 ret += (n < 0).long() * num_buckets n = torch.abs(n) max_exact = num_buckets // 2 is_small = n < max_exact val_if_large = max_exact + ( torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).long() val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) ret += torch.where(is_small, n, val_if_large) return ret def forward(self, n, device): q_pos = torch.arange(n, dtype=torch.long, device=device) k_pos = torch.arange(n, dtype=torch.long, device=device) rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance) values = self.relative_attention_bias(rp_bucket) return rearrange(values, 'i j h -> h i j') # only for ablation / not used in the final model class TimeEncoding(nn.Module): def __init__(self, d_model, dropout=0.1, max_len=5000): super(TimeEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) def forward(self, x, mask, lengths): time = mask * 1/(lengths[..., None]-1) time = time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :] time = time[:, 0].T # add the time encoding x = x + time[..., None] return self.dropout(x) class ResUnet(nn.Module): def __init__(self, channel=1, filters=[32, 64, 128, 256]): super(ResUnet, self).__init__() self.input_layer = nn.Sequential( nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), nn.BatchNorm2d(filters[0]), nn.ReLU(), nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), ) self.input_skip = nn.Sequential( nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) ) self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1) self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1) self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1) self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1)) self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1) self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1)) self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1) self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1)) self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1) self.output_layer = nn.Sequential( nn.Conv2d(filters[0], 1, 1, 1), nn.Sigmoid(), ) def forward(self, x): # Encode x1 = self.input_layer(x) + self.input_skip(x) x2 = self.residual_conv_1(x1) x3 = self.residual_conv_2(x2) # Bridge x4 = self.bridge(x3) # Decode x4 = self.upsample_1(x4) x5 = torch.cat([x4, x3], dim=1) x6 = self.up_residual_conv1(x5) x6 = self.upsample_2(x6) x7 = torch.cat([x6, x2], dim=1) x8 = self.up_residual_conv2(x7) x8 = self.upsample_3(x8) x9 = torch.cat([x8, x1], dim=1) x10 = self.up_residual_conv3(x9) output = self.output_layer(x10) return output class Encoder_MLP(nn.Module): def 
__init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, pose_latent_dim=64,
                 audio_latent_dim=256, ff_size=128, num_layers=4, num_heads=4, dropout=0.1,
                 ablation=None, activation="gelu", **kargs):
        super().__init__()
        self.modeltype = modeltype
        self.resunet = ResUnet()
        self.audio_latent_dim = audio_latent_dim
        # self.num_classes = num_classes
        self.seq_len = num_frames
        self.pose_latent_dim = pose_latent_dim

        self.MLP = nn.Sequential()
        layer_sizes = [pos_dim + self.seq_len * pos_dim + self.seq_len * self.audio_latent_dim, ff_size]
        # layer_sizes[0] = self.audio_latent_dim + self.pose_latent_dim*2
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())

        self.linear_means = nn.Linear(layer_sizes[-1], ff_size)
        self.linear_logvar = nn.Linear(layer_sizes[-1], ff_size)
        self.linear_audio = nn.Linear(audio_dim, self.audio_latent_dim)
        # self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):
        # class_id = batch['class']
        pose_motion_gt_ori = batch["x"]   # bs seq_len 6
        ref = pose_motion_gt_ori[:, 0, :]  # bs 6
        batch['x_delta'] = pose_motion_gt_ori - ref[:, None, :]
        pose_motion_gt = batch['x_delta']
        bs = pose_motion_gt_ori.shape[0]
        audio_in = batch["y"]  # bs seq_len audio_emb_in_size

        # pose encode
        pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))  # bs 1 seq_len 6
        pose_emb = pose_emb.reshape(bs, -1)  # bs seq_len*6

        # audio mapping
        # print(audio_in.shape)
        audio_out = self.linear_audio(audio_in)  # bs seq_len audio_emb_out_size
        audio_out = audio_out.reshape(bs, -1)

        # class_bias = self.classbias[class_id]  # bs latent_size
        x_in = torch.cat([ref, pose_emb, audio_out], dim=-1)  # bs seq_len*(audio_emb_out_size+6)+latent_size
        x_out = self.MLP(x_in)

        mu = self.linear_means(x_out)
        logvar = self.linear_logvar(x_out)  # bs latent_size (linear_logvar is the log-variance head)
        batch.update({'mu': mu, 'logvar': logvar})
        return batch


class Decoder_MLP(nn.Module):
    def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, pose_latent_dim=64,
                 audio_latent_dim=256, ff_size=128, num_layers=4, num_heads=4, dropout=0.1,
                 activation="gelu", ablation=None, **kargs):
        super().__init__()
        self.resunet = ResUnet()
        # self.num_classes = num_classes
        self.seq_len = num_frames
        self.MLP = nn.Sequential()
        self.audio_latent_dim = audio_latent_dim
        layer_sizes = [ff_size, self.seq_len * pos_dim]
        input_size = ff_size + self.seq_len * audio_latent_dim + pos_dim
        for i, (in_size, out_size) in enumerate(zip([input_size] + layer_sizes[:-1], layer_sizes)):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            if i + 1 < len(layer_sizes):
                self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
            else:
                self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())

        self.pose_linear = nn.Linear(pos_dim, pos_dim)
        self.linear_audio = nn.Linear(audio_dim, audio_latent_dim)
        # self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):
        z = batch['z']  # bs latent_size
        bs = z.shape[0]
        pose_motion_gt = batch["x"]  # bs seq_len 6
        ref = pose_motion_gt[:, 0, :]  # bs 6
        # class_id = batch['class']
        audio_in = batch['y']  # bs seq_len audio_emb_in_size
        # print('audio_in: ', audio_in[:, :, :10])

        audio_out = self.linear_audio(audio_in)  # bs seq_len audio_emb_out_size
        # print('audio_out: ', audio_out[:, :, :10])
        audio_out = audio_out.reshape([bs, -1])  # bs seq_len*audio_emb_out_size

        # class_bias = self.classbias[class_id]  # bs latent_size
        z = z  # + class_bias
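        # The decoder input concatenates the first-frame reference pose (pos_dim),
        # the latent code z (ff_size) and the flattened per-frame audio features
        # (seq_len * audio_latent_dim), matching the
        # input_size = ff_size + self.seq_len*audio_latent_dim + pos_dim used to build
        # self.MLP above; the MLP maps this vector to seq_len * pos_dim values, which
        # the ResUnet refines before pose_linear produces the per-frame pose offsets.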
x_in = torch.cat([ref, z, audio_out], dim=-1) x_out = self.MLP(x_in) # bs layer_sizes[-1] x_out = x_out.reshape((bs, self.seq_len, -1)) #print('x_out: ', x_out) pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6 pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6 pose_motion_pred = pose_motion_pred batch.update({'output':pose_motion_pred}) return batch ================================================ FILE: PBnet/src/models/architectures/resnet34.py ================================================ import torch import torch.nn as nn from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: raise ValueError('BasicBlock only supports groups=1 and base_width=64') if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class ResNet(nn.Module): def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None,input_channel = 3): super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer self.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: # each element in the tuple indicates if 
we should replace # the 2x2 stride with a dilated convolution instead replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError("replace_stride_with_dilation should be None " "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d(input_channel, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0) elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) def _make_layer(self, block, planes, blocks, stride=1, dilate=False): norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation if dilate: self.dilation *= stride stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer)) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer)) return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.fc(x) return x def _resnet(arch, block, layers, pretrained, progress, **kwargs): model = ResNet(block, layers, **kwargs) return model def resnet34(pretrained=False, progress=True, **kwargs): r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) class MyResNet34(nn.Module): def __init__(self,embedding_dim,input_channel = 3): super(MyResNet34, self).__init__() self.resnet = resnet34(norm_layer = BatchNorm2d,num_classes=embedding_dim,input_channel = input_channel) def forward(self, x): return self.resnet(x) ================================================ FILE: 
PBnet/src/models/architectures/tools/embeddings.py ================================================ # This file is taken from signjoey repository import math import torch from torch import nn, Tensor from ....tools.tools import freeze_params def get_activation(activation_type): if activation_type == "relu": return nn.ReLU() elif activation_type == "relu6": return nn.ReLU6() elif activation_type == "prelu": return nn.PReLU() elif activation_type == "selu": return nn.SELU() elif activation_type == "celu": return nn.CELU() elif activation_type == "gelu": return nn.GELU() elif activation_type == "sigmoid": return nn.Sigmoid() elif activation_type == "softplus": return nn.Softplus() elif activation_type == "softshrink": return nn.Softshrink() elif activation_type == "softsign": return nn.Softsign() elif activation_type == "tanh": return nn.Tanh() elif activation_type == "tanhshrink": return nn.Tanhshrink() else: raise ValueError("Unknown activation type {}".format(activation_type)) class MaskedNorm(nn.Module): """ Original Code from: https://discuss.pytorch.org/t/batchnorm-for-different-sized-samples-in-batch/44251/8 """ def __init__(self, norm_type, num_groups, num_features): super().__init__() self.norm_type = norm_type if self.norm_type == "batch": self.norm = nn.BatchNorm1d(num_features=num_features) elif self.norm_type == "group": self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features) elif self.norm_type == "layer": self.norm = nn.LayerNorm(normalized_shape=num_features) else: raise ValueError("Unsupported Normalization Layer") self.num_features = num_features def forward(self, x: Tensor, mask: Tensor): if self.training: reshaped = x.reshape([-1, self.num_features]) reshaped_mask = mask.reshape([-1, 1]) > 0 selected = torch.masked_select(reshaped, reshaped_mask).reshape( [-1, self.num_features] ) batch_normed = self.norm(selected) scattered = reshaped.masked_scatter(reshaped_mask, batch_normed) return scattered.reshape([x.shape[0], -1, self.num_features]) else: reshaped = x.reshape([-1, self.num_features]) batched_normed = self.norm(reshaped) return batched_normed.reshape([x.shape[0], -1, self.num_features]) # TODO (Cihan): Spatial and Word Embeddings are pretty much the same # We might as well convert them into a single module class. # Only difference is the lut vs linear layers. class Embeddings(nn.Module): """ Simple embeddings class """ # pylint: disable=unused-argument def __init__( self, embedding_dim: int = 64, num_heads: int = 8, scale: bool = False, scale_factor: float = None, norm_type: str = None, activation_type: str = None, vocab_size: int = 0, padding_idx: int = 1, freeze: bool = False, **kwargs ): """ Create new embeddings for the vocabulary. Use scaling for the Transformer. 
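        When ``scale`` is enabled, the lookup output is multiplied by ``scale_factor``
        (defaulting to ``sqrt(embedding_dim)``); ``norm_type`` and ``activation_type``
        optionally apply a masked normalization and an activation on top of the lookup.

        Illustrative example (the sizes below are arbitrary, not taken from this repo's configs)::

            emb = Embeddings(embedding_dim=64, vocab_size=100)
            emb(torch.tensor([[4, 2, 7]])).shape   # torch.Size([1, 3, 64])
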
:param embedding_dim: :param scale: :param vocab_size: :param padding_idx: :param freeze: freeze the embeddings during training """ super().__init__() self.embedding_dim = embedding_dim self.vocab_size = vocab_size self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx) self.norm_type = norm_type if self.norm_type: self.norm = MaskedNorm( norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim ) self.activation_type = activation_type if self.activation_type: self.activation = get_activation(activation_type) self.scale = scale if self.scale: if scale_factor: self.scale_factor = scale_factor else: self.scale_factor = math.sqrt(self.embedding_dim) if freeze: freeze_params(self) # pylint: disable=arguments-differ def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: """ Perform lookup for input `x` in the embedding table. :param mask: token masks :param x: index in the vocabulary :return: embedded representation for `x` """ x = self.lut(x) if self.norm_type: x = self.norm(x, mask) if self.activation_type: x = self.activation(x) if self.scale: return x * self.scale_factor else: return x def __repr__(self): return "%s(embedding_dim=%d, vocab_size=%d)" % ( self.__class__.__name__, self.embedding_dim, self.vocab_size, ) class SpatialEmbeddings(nn.Module): """ Simple Linear Projection Layer (For encoder outputs to predict glosses) """ # pylint: disable=unused-argument def __init__( self, embedding_dim: int, input_size: int, num_heads: int, freeze: bool = False, norm_type: str = "batch", activation_type: str = "softsign", scale: bool = False, scale_factor: float = None, **kwargs ): """ Create new embeddings for the vocabulary. Use scaling for the Transformer. :param embedding_dim: :param input_size: :param freeze: freeze the embeddings during training """ super().__init__() self.embedding_dim = embedding_dim self.input_size = input_size self.ln = nn.Linear(self.input_size, self.embedding_dim) self.norm_type = norm_type if self.norm_type: self.norm = MaskedNorm( norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim ) self.activation_type = activation_type if self.activation_type: self.activation = get_activation(activation_type) self.scale = scale if self.scale: if scale_factor: self.scale_factor = scale_factor else: self.scale_factor = math.sqrt(self.embedding_dim) if freeze: freeze_params(self) # pylint: disable=arguments-differ def forward(self, x: Tensor, mask: Tensor) -> Tensor: """ :param mask: frame masks :param x: input frame features :return: embedded representation for `x` """ x = self.ln(x) if self.norm_type: x = self.norm(x, mask) if self.activation_type: x = self.activation(x) if self.scale: return x * self.scale_factor else: return x def __repr__(self): return "%s(embedding_dim=%d, input_size=%d)" % ( self.__class__.__name__, self.embedding_dim, self.input_size, ) ================================================ FILE: PBnet/src/models/architectures/tools/resnet.py ================================================ import torch import torch.nn as nn def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, 
downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: raise ValueError('BasicBlock only supports groups=1 and base_width=64') if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.relu(out) return out class ResNet(nn.Module): def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None,input_channel = 3): super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layer self.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: # each element in the tuple indicates if we should replace # the 2x2 stride with a dilated convolution instead replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError("replace_stride_with_dilation should be None " "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d(input_channel, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', 
nonlinearity='relu') elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0) elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) def _make_layer(self, block, planes, blocks, stride=1, dilate=False): norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation if dilate: self.dilation *= stride stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer)) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer)) return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.fc(x) return x def _resnet(arch, block, layers, pretrained, progress, **kwargs): model = ResNet(block, layers, **kwargs) return model def resnet34(pretrained=False, progress=True, **kwargs): r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) ================================================ FILE: PBnet/src/models/architectures/tools/transformer_layers.py ================================================ # -*- coding: utf-8 -*- import math import torch import torch.nn as nn from torch import Tensor # Took from https://github.com/joeynmt/joeynmt/blob/fb66afcbe1beef9acd59283bcc084c4d4c1e6343/joeynmt/transformer_layers.py # pylint: disable=arguments-differ class MultiHeadedAttention(nn.Module): """ Multi-Head Attention module from "Attention is All You Need" Implementation modified from OpenNMT-py. https://github.com/OpenNMT/OpenNMT-py """ def __init__(self, num_heads: int, size: int, dropout: float = 0.1): """ Create a multi-headed attention layer. :param num_heads: the number of heads :param size: model size (must be divisible by num_heads) :param dropout: probability of dropping a unit """ super().__init__() assert size % num_heads == 0 self.head_size = head_size = size // num_heads self.model_size = size self.num_heads = num_heads self.k_layer = nn.Linear(size, num_heads * head_size) self.v_layer = nn.Linear(size, num_heads * head_size) self.q_layer = nn.Linear(size, num_heads * head_size) self.output_layer = nn.Linear(size, size) self.softmax = nn.Softmax(dim=-1) self.dropout = nn.Dropout(dropout) def forward(self, k: Tensor, v: Tensor, q: Tensor, mask: Tensor = None): """ Computes multi-headed attention. :param k: keys [B, M, D] with M being the sentence length. 
:param v: values [B, M, D] :param q: query [B, M, D] :param mask: optional mask [B, 1, M] or [B, M, M] :return: """ batch_size = k.size(0) num_heads = self.num_heads # project the queries (q), keys (k), and values (v) k = self.k_layer(k) v = self.v_layer(v) q = self.q_layer(q) # reshape q, k, v for our computation to [batch_size, num_heads, ..] k = k.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) v = v.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) q = q.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) # compute scores q = q / math.sqrt(self.head_size) # batch x num_heads x query_len x key_len scores = torch.matmul(q, k.transpose(2, 3)) # torch.Size([48, 8, 183, 183]) # apply the mask (if we have one) # we add a dimension for the heads to it below: [B, 1, 1, M] if mask is not None: scores = scores.masked_fill(~mask.unsqueeze(1), float('-inf')) # apply attention dropout and compute context vectors. attention = self.softmax(scores) attention = self.dropout(attention) # torch.Size([48, 8, 183, 183]) [bs, nheads, time, time] (for decoding) # v: torch.Size([48, 8, 183, 32]) (32 is 256/8) # get context vector (select values with attention) and reshape # back to [B, M, D] context = torch.matmul(attention, v) # torch.Size([48, 8, 183, 32]) context = context.transpose(1, 2).contiguous().view( batch_size, -1, num_heads * self.head_size) # torch.Size([48, 183, 256]) put back to 256 (combine the heads) output = self.output_layer(context) # torch.Size([48, 183, 256]): 1 output per time step return output # pylint: disable=arguments-differ class PositionwiseFeedForward(nn.Module): """ Position-wise Feed-forward layer Projects to ff_size and then back down to input_size. """ def __init__(self, input_size, ff_size, dropout=0.1): """ Initializes position-wise feed-forward layer. :param input_size: dimensionality of the input. :param ff_size: dimensionality of intermediate representation :param dropout: """ super().__init__() self.layer_norm = nn.LayerNorm(input_size, eps=1e-6) self.pwff_layer = nn.Sequential( nn.Linear(input_size, ff_size), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ff_size, input_size), nn.Dropout(dropout), ) def forward(self, x): x_norm = self.layer_norm(x) return self.pwff_layer(x_norm) + x # pylint: disable=arguments-differ class PositionalEncoding(nn.Module): """ Pre-compute position encodings (PE). In forward pass, this adds the position-encodings to the input for as many time steps as necessary. Implementation based on OpenNMT-py. https://github.com/OpenNMT/OpenNMT-py """ def __init__(self, size: int = 0, max_len: int = 5000): """ Positional Encoding with maximum length max_len :param size: :param max_len: :param dropout: """ if size % 2 != 0: raise ValueError("Cannot use sin/cos positional encoding with " "odd dim (got dim={:d})".format(size)) pe = torch.zeros(max_len, size) position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.exp((torch.arange(0, size, 2, dtype=torch.float) * -(math.log(10000.0) / size))) pe[:, 0::2] = torch.sin(position.float() * div_term) pe[:, 1::2] = torch.cos(position.float() * div_term) pe = pe.unsqueeze(0) # shape: [1, size, max_len] super().__init__() self.register_buffer('pe', pe) self.dim = size def forward(self, emb): """Embed inputs. 
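        The registered buffer ``pe`` has shape ``(1, max_len, size)`` and is sliced
        along dimension 1 by ``emb.size(1)``, so callers are expected to pass
        batch-first inputs of shape ``(batch_size, seq_len, size)``, as the decoders
        in this package do.
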
Args: emb (FloatTensor): Sequence of word vectors ``(seq_len, batch_size, self.dim)`` """ # Add position encodings return emb + self.pe[:, :emb.size(1)] class TransformerEncoderLayer(nn.Module): """ One Transformer encoder layer has a Multi-head attention layer plus a position-wise feed-forward layer. """ def __init__(self, size: int = 0, ff_size: int = 0, num_heads: int = 0, dropout: float = 0.1): """ A single Transformer layer. :param size: :param ff_size: :param num_heads: :param dropout: """ super().__init__() self.layer_norm = nn.LayerNorm(size, eps=1e-6) self.src_src_att = MultiHeadedAttention(num_heads, size, dropout=dropout) self.feed_forward = PositionwiseFeedForward(size, ff_size=ff_size, dropout=dropout) self.dropout = nn.Dropout(dropout) self.size = size # pylint: disable=arguments-differ def forward(self, x: Tensor, mask: Tensor) -> Tensor: """ Forward pass for a single transformer encoder layer. First applies layer norm, then self attention, then dropout with residual connection (adding the input to the result), and then a position-wise feed-forward layer. :param x: layer input :param mask: input mask :return: output tensor """ x_norm = self.layer_norm(x) h = self.src_src_att(x_norm, x_norm, x_norm, mask) h = self.dropout(h) + x o = self.feed_forward(h) return o class TransformerDecoderLayer(nn.Module): """ Transformer decoder layer. Consists of self-attention, source-attention, and feed-forward. """ def __init__(self, size: int = 0, ff_size: int = 0, num_heads: int = 0, dropout: float = 0.1): """ Represents a single Transformer decoder layer. It attends to the source representation and the previous decoder states. :param size: model dimensionality :param ff_size: size of the feed-forward intermediate layer :param num_heads: number of heads :param dropout: dropout to apply to input """ super().__init__() self.size = size self.trg_trg_att = MultiHeadedAttention(num_heads, size, dropout=dropout) self.src_trg_att = MultiHeadedAttention(num_heads, size, dropout=dropout) self.feed_forward = PositionwiseFeedForward(size, ff_size=ff_size, dropout=dropout) self.x_layer_norm = nn.LayerNorm(size, eps=1e-6) self.dec_layer_norm = nn.LayerNorm(size, eps=1e-6) self.dropout = nn.Dropout(dropout) # pylint: disable=arguments-differ def forward(self, x: Tensor = None, memory: Tensor = None, src_mask: Tensor = None, trg_mask: Tensor = None) -> Tensor: """ Forward pass of a single Transformer decoder layer. 
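        Applies pre-norm target self-attention, then source-target attention over
        ``memory``, then the position-wise feed-forward block, with residual
        connections throughout.
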
:param x: inputs
        :param memory: source representations
        :param src_mask: source mask
        :param trg_mask: target mask (so as to not condition on future steps)
        :return: output tensor
        """
        # decoder/target self-attention
        x_norm = self.x_layer_norm(x)  # torch.Size([48, 183, 256])
        h1 = self.trg_trg_att(x_norm, x_norm, x_norm, mask=trg_mask)
        h1 = self.dropout(h1) + x

        # source-target attention
        h1_norm = self.dec_layer_norm(h1)  # torch.Size([48, 183, 256]) (same for memory)
        h2 = self.src_trg_att(memory, memory, h1_norm, mask=src_mask)

        # final position-wise feed-forward layer
        o = self.feed_forward(self.dropout(h2) + h1)
        return o


================================================ FILE: PBnet/src/models/architectures/tools/util.py ================================================
from torch import nn
import torch.nn.functional as F
import torch
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
from src.models.architectures.tools.resnet import resnet34


class MyResNet34(nn.Module):
    def __init__(self, embedding_dim, input_channel=3):
        super(MyResNet34, self).__init__()
        self.resnet = resnet34(norm_layer=BatchNorm2d, num_classes=embedding_dim, input_channel=input_channel)

    def forward(self, x):
        return self.resnet(x)


================================================ FILE: PBnet/src/models/architectures/transformer.py ================================================
import math  # needed by RelativePositionBias._relative_position_bucket below

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange  # used by RelativePositionBias.forward; einops is assumed to already be a dependency of this repo


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # not used in the final model
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)


class RelativePositionBias(nn.Module):
    def __init__(self, heads=8, num_buckets=32, max_distance=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
        ret = 0
        n = -relative_position
        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)
        max_exact = num_buckets // 2
        is_small = n < max_exact
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, n, device):
        q_pos = torch.arange(n, dtype=torch.long, device=device)
        k_pos = torch.arange(n, dtype=torch.long, device=device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        return rearrange(values, 'i j h -> h i j')


# only for ablation / not used in the final model
class TimeEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(TimeEncoding, self).__init__()
        self.dropout =
nn.Dropout(p=dropout) def forward(self, x, mask, lengths): time = mask * 1/(lengths[..., None]-1) time = time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :] time = time[:, 0].T # add the time encoding x = x + time[..., None] return self.dropout(x) class Encoder_TRANSFORMER(nn.Module): def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, pose_latent_dim=64, audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, ablation=None, activation="gelu", **kargs): super().__init__() self.modeltype = modeltype self.pos_dim = pos_dim self.num_frames = num_frames self.audio_dim = audio_dim self.pose_latent_dim = pose_latent_dim self.audio_latent_dim = audio_latent_dim self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2 self.ff_size = ff_size self.num_layers = num_layers self.num_heads = num_heads self.dropout = dropout self.ablation = ablation self.activation = activation # if self.ablation == "average_encoder": # self.mu_layer = nn.Linear(self.latent_dim, self.latent_dim) # self.sigma_layer = nn.Linear(self.latent_dim, self.latent_dim) # else: # self.muQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim)) # self.sigmaQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim)) # # there's no class of our dataset CREMA/HDTF, so noly dont need to use nn.parameter self.mu_layer = nn.Linear(self.latent_dim, self.audio_latent_dim) self.sigma_layer = nn.Linear(self.latent_dim, self.audio_latent_dim) self.poseEmbedding = nn.Linear(self.pos_dim, self.pose_latent_dim) #6,64 self.firstposeEmbedding = nn.Linear(self.pos_dim, self.pose_latent_dim) #6,64 self.audioEmbedding = nn.Linear(self.audio_dim, self.audio_latent_dim) #1024, 256 self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) def forward(self, batch): ''' x: 6-dim pos, (bs, max_num_frames, 6) y: 1024-dim audio embbeding, (bs, max_num_frames, 1024) ''' x, y, mask = batch["x"], batch["y"], batch["mask"] # bs, njoints, nfeats, nframes = x.shape # x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints*nfeats) x_ref = x[:,0,:].unsqueeze(dim=1) # The pose information of the first frame(refrence img) x = x-x_ref.repeat(1,x.size(1),1) # bs, nf, 6 Obtain the difference from the first frame batch['x_delta'] = x x_ref = x_ref.permute((1,0,2)) #1, bs, 6 x = x.permute((1, 0, 2)) #nf, bs, 6 y = y.permute((1, 0, 2)) #nf, bs, 1024 # embedding of the pose/audio x_ref = self.firstposeEmbedding(x_ref).repeat(x.size(0),1,1) #nf, bs, 64 x = self.poseEmbedding(x) #nf, bs, 64 y = self.audioEmbedding(y) #nf, bs, 256 x = torch.cat([x_ref, x, y],dim=-1) # nf, bs, 64+64+256 # only use the "average_encoder" mode # add positional encoding x = self.sequence_pos_encoder(x) # transformer layers final = self.seqTransEncoder(x, src_key_padding_mask=~mask) #nu_frames, bs, 64+64+256 # get the average of the output z = final# final.mean(axis=0) # nf, bs, 64+64+256 # extract mu and logvar mu = self.mu_layer(z) # nf, bs, 256 logvar = self.sigma_layer(z) # nf, bs, 256 # logvar = - torch.ones_like(logvar) * 1e10 return {"mu": mu, "logvar": logvar} class Decoder_TRANSFORMER(nn.Module): def __init__(self, modeltype, num_frames, 
audio_dim=1024, pos_dim=7, pose_latent_dim=64, audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, activation="gelu", ablation=None, **kargs): super().__init__() self.modeltype = modeltype self.pos_dim = pos_dim self.num_frames = num_frames self.audio_dim = audio_dim self.pose_latent_dim = pose_latent_dim self.audio_latent_dim = audio_latent_dim self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2 self.ff_size = ff_size self.num_layers = num_layers self.num_heads = num_heads self.dropout = dropout self.ablation = ablation self.activation = activation self.firstposeEmbedding = nn.Linear(self.pos_dim, self.pose_latent_dim) #6,64 self.audioEmbedding = nn.Linear(self.audio_dim, self.audio_latent_dim) #1024, 256 self.ztimelinear = nn.Linear(self.audio_latent_dim*2+self.pose_latent_dim, self.pose_latent_dim) #256*2+64,64 # self.input_feats = self.njoints*self.nfeats # # only for ablation / not used in the final model # if self.ablation == "zandtime": # self.ztimelinear = nn.Linear(self.latent_dim + self.num_classes, self.latent_dim) # else: # self.actionBiases = nn.Parameter(torch.randn(1024, self.latent_dim)) # self.actionBiases = nn.Parameter(torch.randn(self.num_classes, self.latent_dim)) # # only for ablation / not used in the final model # if self.ablation == "time_encoding": # self.sequence_pos_encoder = TimeEncoding(self.dropout) # else: # self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) self.sequence_pos_encoder = PositionalEncoding(self.pose_latent_dim, self.dropout) # self.sequence_pos_encoder = TimeEncoding(self.dropout) #time_encoding seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.pose_latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=activation) self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer, num_layers=self.num_layers) self.finallayer = nn.Linear(self.pose_latent_dim, self.pos_dim) def forward(self, batch): ''' z: bs, audio_latent_dim(256) y: bs, num_frames, 1024 mask: bs, num_frames lengths: [num_frames,...] 
''' x, z, y, mask, lengths = batch["x"], batch["z"], batch["y"], batch["mask"], batch["lengths"] bs, nframes = mask.shape # first img x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64 x_ref = self.firstposeEmbedding(x_ref.repeat(1, nframes, 1)) #bs, nf, 64 y = self.audioEmbedding(y) #bs, num_frames, 256 z = z.permute(1, 0, 2) #z = z.unsqueeze(dim=1).repeat(1, nframes, 1) #bs, num_frames, 256 z = torch.cat([x_ref, z, y], dim=-1) # bs, num_frames, 256*2+64 z = self.ztimelinear(z) z = z.permute((1, 0, 2)) # nf, bs, 64 pose_latent_dim = z.shape[2] # z = z[None] # sequence of size 1 # # only for ablation / not used in the final model # if self.ablation == "zandtime": # yoh = F.one_hot(y, self.num_classes) # z = torch.cat((z, yoh), axis=1) # z = self.ztimelinear(z) # z = z[None] # sequence of size 1 # else: # # only for ablation / not used in the final model # if self.ablation == "concat_bias": # # sequence of size 2 # z = torch.stack((z, self.actionBiases[y]), axis=0) # else: # # shift the latent noise vector to be the action noise # z = z + self.actionBiases[y.long()] # NEED CHECK # z = z[None] # sequence of size 1 timequeries = torch.zeros(nframes, bs, pose_latent_dim, device=z.device) timequeries = self.sequence_pos_encoder(timequeries) # timequeries = self.sequence_pos_encoder(timequeries, mask, lengths) #time_encoding # # only for ablation / not used in the final model # if self.ablation == "time_encoding": # timequeries = self.sequence_pos_encoder(timequeries, mask, lengths) # else: # timequeries = self.sequence_pos_encoder(timequeries) # num_frames, bs, 64 output = self.seqTransDecoder(tgt=timequeries, memory=z, tgt_key_padding_mask=~mask) output = self.finallayer(output).reshape(nframes, bs, self.pos_dim) # num_frames, bs, 6 # output = self.finallayer(output).reshape(nframes, bs, njoints, nfeats) # zero for padded area output[~mask.T] = 0 #nf, bs, 6 output = output.permute(1,0,2)#bs, nf, 6 batch["output"] = output return batch ================================================ FILE: PBnet/src/models/architectures/transformerdecoder.py ================================================ import copy from typing import Optional, Any, Union, Callable import torch from torch import Tensor from torch.nn.functional import dropout from torch.nn import functional as F from torch.nn.modules import Module from torch.nn.modules.container import ModuleList from torch.nn.modules.activation import MultiheadAttention as MultiheadAttention0 from torch.nn.init import xavier_uniform_ from torch.nn.modules.dropout import Dropout from torch.nn.modules.linear import Linear from torch.nn.modules.normalization import LayerNorm import torch.nn as nn class MultiheadAttention(nn.Module): def __init__(self, embed_size, heads, dropout = None, batch_first = None): super(MultiheadAttention, self).__init__() self.embed_size = embed_size self.heads = heads self.head_dim = embed_size // heads assert ( self.head_dim * heads == embed_size ), "Embedding size needs to be divisible by heads" self.values = nn.Linear(embed_size, self.head_dim * heads, bias=False) self.keys = nn.Linear(embed_size, self.head_dim * heads, bias=False) self.queries = nn.Linear(embed_size, self.head_dim * heads, bias=False) self.fc_out = nn.Linear(heads * self.head_dim, embed_size) self.dropout = nn.Dropout(dropout) def sinusoidal_position_embedding(self, batch_size, nums_head, max_len, output_dim, device): position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1) ids = torch.arange(0, output_dim // 2, dtype=torch.float) theta = torch.pow(10000, -2 
* ids / output_dim) embeddings = position * theta embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape)))) embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim)) embeddings = embeddings.to(device) return embeddings def RoPE(self, q, k): # q,k: (B, H, L, D) batch_size = q.shape[0] nums_head = q.shape[1] max_len = q.shape[2] output_dim = q.shape[-1] pos_emb = self.sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device) cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1) sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1) # q,k: (B, H, L, D) q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1) q2 = q2.reshape(q.shape) q = q * cos_pos + q2 * sin_pos k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1) k2 = k2.reshape(k.shape) k = k * cos_pos + k2 * sin_pos return q, k def forward(self, q, k, v, attn_mask = None, key_padding_mask=None, need_weights = None): B = q.shape[0] use_rope = True len = q.shape[1] values = self.values(v).view(B, len, self.heads, self.head_dim) keys = self.keys(k).view(B, len, self.heads, self.head_dim) queries = self.queries(q).view(B, len, self.heads, self.head_dim) values = values.permute(0, 2, 1, 3) keys = keys.permute(0, 2, 1, 3) queries = queries.permute(0, 2, 1, 3) # [B, H, L, D] if use_rope: queries, keys = self.RoPE(queries, keys) energy = torch.matmul(queries, keys.permute(0, 1, 3, 2)) # if attn_mask is not None: # energy = energy.masked_fill(attn_mask == 1, float("-1e20")) if attn_mask is None: attn_mask = 0 attention = F.softmax(energy / (self.head_dim ** (1 / 2) + attn_mask), dim=-1) attention = self.dropout(attention) out = torch.matmul(attention, values) out = out.permute(0, 2, 1, 3).contiguous().view(B, len, self.heads * self.head_dim) out = self.fc_out(out) return out class Transformer(Module): r"""A transformer model. User is able to modify the attributes as needed. The architecture is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. Args: d_model: the number of expected features in the encoder/decoder inputs (default=512). nhead: the number of heads in the multiheadattention models (default=8). num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). activation: the activation function of encoder/decoder intermediate layer, can be a string ("relu" or "gelu") or a unary callable. Default: relu custom_encoder: custom encoder (default=None). custom_decoder: custom decoder (default=None). layer_norm_eps: the eps value in layer normalization components (default=1e-5). batch_first: If ``True``, then the input and output tensors are provided as (batch, seq, feature). Default: ``False`` (seq, batch, feature). norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before other attention and feedforward operations, otherwise after. Default: ``False`` (after). 
Examples:: >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12) >>> src = torch.rand((10, 32, 512)) >>> tgt = torch.rand((20, 32, 512)) >>> out = transformer_model(src, tgt) Note: A full example to apply nn.Transformer module for the word language model is available in https://github.com/pytorch/examples/tree/master/word_language_model """ def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super(Transformer, self).__init__() if custom_encoder is not None: self.encoder = custom_encoder else: encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps, batch_first, norm_first, **factory_kwargs) encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) if custom_decoder is not None: self.decoder = custom_decoder else: decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation, layer_norm_eps, batch_first, norm_first, **factory_kwargs) decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) self._reset_parameters() self.d_model = d_model self.nhead = nhead self.batch_first = batch_first def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: r"""Take in and process masked source/target sequences. Args: src: the sequence to the encoder (required). tgt: the sequence to the decoder (required). src_mask: the additive mask for the src sequence (optional). tgt_mask: the additive mask for the tgt sequence (optional). memory_mask: the additive mask for the encoder output (optional). src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). Shape: - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or `(N, S, E)` if `batch_first=True`. - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or `(N, T, E)` if `batch_first=True`. - src_mask: :math:`(S, S)`. - tgt_mask: :math:`(T, T)`. - memory_mask: :math:`(T, S)`. - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`. - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`. - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`. Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged. 
If a FloatTensor is provided, it will be added to the attention weight. [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or `(N, T, E)` if `batch_first=True`. Note: Due to the multi-head attention architecture in the transformer model, the output sequence length of a transformer is same as the input sequence (i.e. target) length of the decode. where S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number Examples: >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) """ is_batched = src.dim() == 3 if not self.batch_first and src.size(1) != tgt.size(1) and is_batched: raise RuntimeError("the batch number of src and tgt must be equal") elif self.batch_first and src.size(0) != tgt.size(0) and is_batched: raise RuntimeError("the batch number of src and tgt must be equal") if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model: raise RuntimeError("the feature number of src and tgt must be equal to d_model") memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask) output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask) return output @staticmethod def generate_square_subsequent_mask(sz: int) -> Tensor: r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0). """ return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1) def _reset_parameters(self): r"""Initiate parameters in the transformer model.""" for p in self.parameters(): if p.dim() > 1: xavier_uniform_(p) class TransformerEncoder(Module): r"""TransformerEncoder is a stack of N encoder layers Args: encoder_layer: an instance of the TransformerEncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). norm: the layer normalization component (optional). Examples:: >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) >>> src = torch.rand(10, 32, 512) >>> out = transformer_encoder(src) """ __constants__ = ['norm'] def __init__(self, encoder_layer, num_layers, norm=None): super(TransformerEncoder, self).__init__() self.layers = _get_clones(encoder_layer, num_layers) self.num_layers = num_layers self.norm = norm def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: r"""Pass the input through the encoder layers in turn. Args: src: the sequence to the encoder (required). mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). Shape: see the docs in Transformer class. 
""" output = src for mod in self.layers: output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask) if self.norm is not None: output = self.norm(output) return output class TransformerDecoder(Module): r"""TransformerDecoder is a stack of N decoder layers Args: decoder_layer: an instance of the TransformerDecoderLayer() class (required). num_layers: the number of sub-decoder-layers in the decoder (required). norm: the layer normalization component (optional). Examples:: >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) >>> memory = torch.rand(10, 32, 512) >>> tgt = torch.rand(20, 32, 512) >>> out = transformer_decoder(tgt, memory) """ __constants__ = ['norm'] def __init__(self, decoder_layer, num_layers, norm=None): super(TransformerDecoder, self).__init__() self.layers = _get_clones(decoder_layer, num_layers) self.num_layers = num_layers self.norm = norm def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: r"""Pass the inputs (and mask) through the decoder layer in turn. Args: tgt: the sequence to the decoder (required). memory: the sequence from the last layer of the encoder (required). tgt_mask: the mask for the tgt sequence (optional). memory_mask: the mask for the memory sequence (optional). tgt_key_padding_mask: the mask for the tgt keys per batch (optional). memory_key_padding_mask: the mask for the memory keys per batch (optional). Shape: see the docs in Transformer class. """ output = tgt for mod in self.layers: output = mod(output, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask) if self.norm is not None: output = self.norm(output) return output class TransformerEncoderLayer(Module): r"""TransformerEncoderLayer is made up of self-attn and feedforward network. This standard encoder layer is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information Processing Systems, pages 6000-6010. Users may modify or implement in a different way during application. Args: d_model: the number of expected features in the input (required). nhead: the number of heads in the multiheadattention models (required). dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable. Default: relu layer_norm_eps: the eps value in layer normalization components (default=1e-5). batch_first: If ``True``, then the input and output tensors are provided as (batch, seq, feature). Default: ``False`` (seq, batch, feature). norm_first: if ``True``, layer norm is done prior to attention and feedforward operations, respectivaly. Otherwise it's done after. Default: ``False`` (after). 
Examples:: >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) >>> src = torch.rand(10, 32, 512) >>> out = encoder_layer(src) Alternatively, when ``batch_first`` is ``True``: >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True) >>> src = torch.rand(32, 10, 512) >>> out = encoder_layer(src) """ __constants__ = ['batch_first', 'norm_first'] def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super(TransformerEncoderLayer, self).__init__() self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs) # Implementation of Feedforward model self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) self.dropout = Dropout(dropout) self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) self.norm_first = norm_first self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.dropout1 = Dropout(dropout) self.dropout2 = Dropout(dropout) # Legacy string support for activation function. if isinstance(activation, str): self.activation = _get_activation_fn(activation) else: self.activation = activation def __setstate__(self, state): if 'activation' not in state: state['activation'] = F.relu super(TransformerEncoderLayer, self).__setstate__(state) def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: r"""Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). Shape: see the docs in Transformer class. """ # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf x = src if self.norm_first: x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) x = x + self._ff_block(self.norm2(x)) else: x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) x = self.norm2(x + self._ff_block(x)) return x # self-attention block def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor: x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0] return self.dropout1(x) # feed forward block def _ff_block(self, x: Tensor) -> Tensor: x = self.linear2(self.dropout(self.activation(self.linear1(x)))) return self.dropout2(x) class TransformerDecoderLayer(Module): r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. This standard decoder layer is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information Processing Systems, pages 6000-6010. Users may modify or implement in a different way during application. Args: d_model: the number of expected features in the input (required). nhead: the number of heads in the multiheadattention models (required). dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). 
activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable. Default: relu layer_norm_eps: the eps value in layer normalization components (default=1e-5). batch_first: If ``True``, then the input and output tensors are provided as (batch, seq, feature). Default: ``False`` (seq, batch, feature). norm_first: if ``True``, layer norm is done prior to self attention, multihead attention and feedforward operations, respectivaly. Otherwise it's done after. Default: ``False`` (after). Examples:: >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) >>> memory = torch.rand(10, 32, 512) >>> tgt = torch.rand(20, 32, 512) >>> out = decoder_layer(tgt, memory) Alternatively, when ``batch_first`` is ``True``: >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True) >>> memory = torch.rand(32, 10, 512) >>> tgt = torch.rand(32, 20, 512) >>> out = decoder_layer(tgt, memory) """ __constants__ = ['batch_first', 'norm_first'] def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super(TransformerDecoderLayer, self).__init__() self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, ) self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, ) # Implementation of Feedforward model self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) self.dropout = Dropout(dropout) self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) self.norm_first = norm_first self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) self.dropout1 = Dropout(dropout) self.dropout2 = Dropout(dropout) self.dropout3 = Dropout(dropout) # Legacy string support for activation function. if isinstance(activation, str): self.activation = _get_activation_fn(activation) else: self.activation = activation def __setstate__(self, state): if 'activation' not in state: state['activation'] = F.relu super(TransformerDecoderLayer, self).__setstate__(state) def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: r"""Pass the inputs (and mask) through the decoder layer. Args: tgt: the sequence to the decoder layer (required). memory: the sequence from the last layer of the encoder (required). tgt_mask: the mask for the tgt sequence (optional). memory_mask: the mask for the memory sequence (optional). tgt_key_padding_mask: the mask for the tgt keys per batch (optional). memory_key_padding_mask: the mask for the memory keys per batch (optional). Shape: see the docs in Transformer class. """ # see Fig. 
1 of https://arxiv.org/pdf/2002.04745v1.pdf x = tgt if self.norm_first: x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask) x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask) x = x + self._ff_block(self.norm3(x)) else: x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask)) x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask)) x = self.norm3(x + self._ff_block(x)) return x # self-attention block def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor: x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0] return self.dropout1(x) # multihead attention block def _mha_block(self, x: Tensor, mem: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor: x = self.multihead_attn(x, mem, mem, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0] return self.dropout2(x) # feed forward block def _ff_block(self, x: Tensor) -> Tensor: x = self.linear2(self.dropout(self.activation(self.linear1(x)))) return self.dropout3(x) def _get_clones(module, N): return ModuleList([copy.deepcopy(module) for i in range(N)]) def _get_activation_fn(activation): if activation == "relu": return F.relu elif activation == "gelu": return F.gelu raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) ================================================ FILE: PBnet/src/models/architectures/transformerdecoder4.py ================================================ import copy from typing import Optional, Any, Union, Callable import torch from torch import Tensor from torch.nn.functional import dropout from torch.nn import functional as F from torch.nn.modules import Module from torch.nn.modules.container import ModuleList from torch.nn.modules.activation import MultiheadAttention as MultiheadAttention0 from torch.nn.init import xavier_uniform_ from torch.nn.modules.dropout import Dropout from torch.nn.modules.linear import Linear from torch.nn.modules.normalization import LayerNorm import torch.nn as nn from einops import rearrange, repeat, reduce, pack, unpack from torch import einsum from einops_exts import rearrange_many from rotary_embedding_torch import RotaryEmbedding def exists(x): return x is not None class Attention(nn.Module): def __init__( self, dim, heads=4, dim_head=32, rotary_emb=None ): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.rotary_emb = rotary_emb self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) self.to_out = nn.Linear(hidden_dim, dim, bias=False) def forward( self, x, pos_bias=None, ): # temperal: 'b (h w) f c' ; spatial : 'b f (h w) c' n, device = x.shape[-2], x.device qkv = self.to_qkv(x).chunk(3, dim=-1) # if exists(focus_present_mask) and focus_present_mask.all(): # # if all batch samples are focusing on present # # it would be equivalent to passing that token's values through to the output # values = qkv[-1] # return self.to_out(values) # split out heads q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) # scale q = q * self.scale # rotate positions into queries and keys for time attention if exists(self.rotary_emb): q = self.rotary_emb.rotate_queries_or_keys(q) k = self.rotary_emb.rotate_queries_or_keys(k) # similarity sim = einsum('... h i d, ... h j d -> ... 
h i j', q, k) # relative positional bias if exists(pos_bias): sim = sim + pos_bias # if exists(focus_present_mask) and not (~focus_present_mask).all(): # attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool) # attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) # mask = torch.where( # rearrange(focus_present_mask, 'b -> b 1 1 1 1'), # rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), # rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), # ) # sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) # numerical stability sim = sim - sim.amax(dim=-1, keepdim=True).detach() attn = sim.softmax(dim=-1) # aggregate values out = einsum('... h i j, ... h j d -> ... h i d', attn, v) out = rearrange(out, '... h n d -> ... n (h d)') return self.to_out(out) class Attention_2(nn.Module): def __init__( self, dim, heads=4, dim_head=32, rotary_emb=None ): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.rotary_emb = rotary_emb self.to_q = nn.Linear(dim, hidden_dim, bias=False) self.to_k = nn.Linear(dim, hidden_dim, bias=False) self.to_v = nn.Linear(dim, hidden_dim, bias=False) self.to_out = nn.Linear(hidden_dim, dim, bias=False) def forward( self, q, k, v, pos_bias=None, focus_present_mask=None ): # temperal: 'b (h w) f c' ; spatial : 'b f (h w) c' q = self.to_q(q) k = self.to_k(k) v = self.to_v(v) # split out heads q = rearrange(q, '... n (h d) -> ... h n d', h=self.heads) # b, head, fn, c k = rearrange(k, '... n (h d) -> ... h n d', h=self.heads) v = rearrange(v, '... n (h d) -> ... h n d', h=self.heads) # q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) # scale q = q * self.scale # rotate positions into queries and keys for time attention if exists(self.rotary_emb): q = self.rotary_emb.rotate_queries_or_keys(q) k = self.rotary_emb.rotate_queries_or_keys(k) # similarity sim = einsum('... h i d, ... h j d -> ... h i j', q, k) # relative positional bias if exists(pos_bias): sim = sim + pos_bias # numerical stability sim = sim - sim.amax(dim=-1, keepdim=True).detach() attn = sim.softmax(dim=-1) # aggregate values out = einsum('... h i j, ... h j d -> ... h i d', attn, v) out = rearrange(out, '... h n d -> ... 
n (h d)') return self.to_out(out) class PositionwiseFeedforwardLayer(nn.Module): def __init__(self, d_model, d_ff, dropout): super(PositionwiseFeedforwardLayer, self).__init__() self.linear1 = nn.Linear(d_model, d_ff) self.linear2 = nn.Linear(d_ff, d_model) self.dropout = nn.Dropout(dropout) def forward(self, x): x = F.gelu(self.linear1(x)) x = self.dropout(x) x = self.linear2(x) return x class DecoderLayer(nn.Module): def __init__(self, d_model, num_heads, d_ff, dropout, rotary_emb): super(DecoderLayer, self).__init__() self.self_attn = Attention(dim = d_model, heads = num_heads, rotary_emb = rotary_emb) self.multihead_attn = Attention_2(dim = d_model, heads = num_heads, rotary_emb = rotary_emb) self.ffn = PositionwiseFeedforwardLayer(d_model, d_ff, dropout) self.layer_norm1 = nn.LayerNorm(d_model) self.layer_norm2 = nn.LayerNorm(d_model) self.layer_norm3 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, tgt, memory, tgt_mask=None, memory_mask=None): tgt = self.layer_norm1(tgt + self.dropout(self.self_attn(tgt, tgt_mask))) tgt = self.layer_norm2(tgt + self.dropout(self.multihead_attn(tgt, memory, memory, memory_mask))) tgt = self.layer_norm3(tgt + self.dropout(self.ffn(tgt))) return tgt class TransformerDecoder(nn.Module): def __init__(self, num_layers, d_model, num_heads, dim_feedforward, dropout): super(TransformerDecoder, self).__init__() self.num_layers = num_layers rotary_emb = RotaryEmbedding(min(32, num_heads)) self.decoder_layers = nn.ModuleList([DecoderLayer(d_model = d_model, num_heads = num_heads, d_ff = dim_feedforward, dropout = dropout, rotary_emb = rotary_emb) for _ in range(num_layers)]) def forward(self, tgt, memory, tgt_mask=None, memory_mask=None): output = tgt for layer in self.decoder_layers: output = layer(output, memory, tgt_mask, memory_mask) return output ================================================ FILE: PBnet/src/models/architectures/transformerdecoder5.py ================================================ import copy from typing import Optional, Any, Union, Callable import torch from torch import Tensor from torch.nn.functional import dropout from torch.nn import functional as F from torch.nn.modules import Module from torch.nn.modules.container import ModuleList from torch.nn.init import xavier_uniform_ from torch.nn.modules.dropout import Dropout from torch.nn.modules.linear import Linear from torch.nn.modules.normalization import LayerNorm import torch.nn as nn from einops import rearrange, repeat, reduce, pack, unpack from torch import einsum from einops_exts import rearrange_many from rotary_embedding_torch import RotaryEmbedding def exists(x): return x is not None class Attention(nn.Module): def __init__( self, dim, heads=4, dim_head=32, rotary_emb=None ): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.rotary_emb = rotary_emb self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) self.to_out = nn.Linear(hidden_dim, dim, bias=False) def forward( self, x, pos_bias=None, ): # temperal: 'b (h w) f c' ; spatial : 'b f (h w) c' n, device = x.shape[-2], x.device qkv = self.to_qkv(x).chunk(3, dim=-1) # if exists(focus_present_mask) and focus_present_mask.all(): # # if all batch samples are focusing on present # # it would be equivalent to passing that token's values through to the output # values = qkv[-1] # return self.to_out(values) # split out heads q, k, v = rearrange_many(qkv, '... n (h d) -> ... 
h n d', h=self.heads) # scale q = q * self.scale # rotate positions into queries and keys for time attention if exists(self.rotary_emb): q = self.rotary_emb.rotate_queries_or_keys(q) k = self.rotary_emb.rotate_queries_or_keys(k) # similarity sim = einsum('... h i d, ... h j d -> ... h i j', q, k) # relative positional bias if exists(pos_bias): sim = sim + pos_bias # if exists(focus_present_mask) and not (~focus_present_mask).all(): # attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool) # attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) # mask = torch.where( # rearrange(focus_present_mask, 'b -> b 1 1 1 1'), # rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), # rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), # ) # sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) # numerical stability sim = sim - sim.amax(dim=-1, keepdim=True).detach() attn = sim.softmax(dim=-1) # aggregate values out = einsum('... h i j, ... h j d -> ... h i d', attn, v) out = rearrange(out, '... h n d -> ... n (h d)') return self.to_out(out) class Attention_2(nn.Module): def __init__( self, dim, heads=4, dim_head=32, rotary_emb=None ): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.rotary_emb = rotary_emb self.to_q = nn.Linear(dim, hidden_dim, bias=False) self.to_k = nn.Linear(dim, hidden_dim, bias=False) self.to_v = nn.Linear(dim, hidden_dim, bias=False) self.to_out = nn.Linear(hidden_dim, dim, bias=False) def forward( self, q, k, v, pos_bias=None, focus_present_mask=None ): # temperal: 'b (h w) f c' ; spatial : 'b f (h w) c' q = self.to_q(q) k = self.to_k(k) v = self.to_v(v) # split out heads q = rearrange(q, '... n (h d) -> ... h n d', h=self.heads) # b, head, fn, c k = rearrange(k, '... n (h d) -> ... h n d', h=self.heads) v = rearrange(v, '... n (h d) -> ... h n d', h=self.heads) # q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) # scale q = q * self.scale # rotate positions into queries and keys for time attention if exists(self.rotary_emb): q = self.rotary_emb.rotate_queries_or_keys(q) k = self.rotary_emb.rotate_queries_or_keys(k) # similarity sim = einsum('... h i d, ... h j d -> ... h i j', q, k) # relative positional bias if exists(pos_bias): sim = sim + pos_bias # numerical stability sim = sim - sim.amax(dim=-1, keepdim=True).detach() attn = sim.softmax(dim=-1) # aggregate values out = einsum('... h i j, ... h j d -> ... h i d', attn, v) out = rearrange(out, '... h n d -> ... 
n (h d)') return self.to_out(out) class PositionwiseFeedforwardLayer(nn.Module): def __init__(self, d_model, d_ff, dropout): super(PositionwiseFeedforwardLayer, self).__init__() self.linear1 = nn.Linear(d_model, d_ff) self.linear2 = nn.Linear(d_ff, d_model) self.dropout = nn.Dropout(dropout) def forward(self, x): x = F.gelu(self.linear1(x)) x = self.dropout(x) x = self.linear2(x) return x class DecoderLayer(nn.Module): def __init__(self, d_model, num_heads, d_ff, dropout, rotary_emb): super(DecoderLayer, self).__init__() self.self_attn = Attention(dim = d_model, heads = num_heads, rotary_emb = rotary_emb) # , rotary_emb = rotary_emb) self.multihead_attn = Attention_2(dim = d_model, heads = num_heads, rotary_emb = rotary_emb) self.ffn = PositionwiseFeedforwardLayer(d_model, d_ff, dropout) self.layer_norm1 = nn.LayerNorm(d_model) self.layer_norm2 = nn.LayerNorm(d_model) self.layer_norm3 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.dropout3 = nn.Dropout(dropout) def forward(self, tgt, memory, tgt_mask=None, memory_mask=None): tgt = self.layer_norm1(tgt + self.dropout1(self.self_attn(tgt, tgt_mask))) tgt = self.layer_norm2(tgt + self.dropout2(self.multihead_attn(tgt, memory, memory, memory_mask))) tgt = self.layer_norm3(tgt + self.dropout3(self.ffn(tgt))) return tgt class TransformerDecoder(nn.Module): def __init__(self, num_layers, d_model, num_heads, dim_feedforward, dropout): super(TransformerDecoder, self).__init__() self.num_layers = num_layers rotary_emb = RotaryEmbedding(min(32, num_heads)) self.decoder_layers = nn.ModuleList([DecoderLayer(d_model = d_model, num_heads = num_heads, d_ff = dim_feedforward, dropout = dropout, rotary_emb = rotary_emb) for _ in range(num_layers)]) def forward(self, tgt, memory, tgt_mask=None, memory_mask=None): output = tgt for layer in self.decoder_layers: output = layer(output, memory, tgt_mask, memory_mask) return output ================================================ FILE: PBnet/src/models/architectures/transformerreemb.py ================================================ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat, reduce, pack, unpack from einops_exts import rearrange_many from torch import einsum from rotary_embedding_torch import RotaryEmbedding import math def exists(x): return x is not None class LayerNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.gamma = nn.Parameter(torch.ones(1, 1, dim)) def forward(self, x): var = torch.var(x, dim=-1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=-1, keepdim=True) return (x - mean) / (var + self.eps).sqrt() * self.gamma class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.fn = fn self.norm = LayerNorm(dim) def forward(self, x, **kwargs): x = self.norm(x) return self.fn(x, **kwargs) class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, *args, **kwargs): return self.fn(x, *args, **kwargs) + x class EinopsToAndFrom(nn.Module): def __init__(self, from_einops, to_einops, fn): super().__init__() self.from_einops = from_einops self.to_einops = to_einops self.fn = fn def forward(self, x, **kwargs): shape = x.shape reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape))) x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') x = self.fn(x, **kwargs) x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) return 
x class Attention(nn.Module): def __init__( self, dim, heads=4, dim_head=32, rotary_emb=None ): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.rotary_emb = rotary_emb self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) self.to_out = nn.Linear(hidden_dim, dim, bias=False) def forward( self, x, pos_bias=None, focus_present_mask=None ): # temperal: 'b (h w) f c' ; spatial : 'b f (h w) c' n, device = x.shape[-2], x.device qkv = self.to_qkv(x).chunk(3, dim=-1) if exists(focus_present_mask) and focus_present_mask.all(): # if all batch samples are focusing on present # it would be equivalent to passing that token's values through to the output values = qkv[-1] return self.to_out(values) # split out heads q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) # scale q = q * self.scale # rotate positions into queries and keys for time attention if exists(self.rotary_emb): q = self.rotary_emb.rotate_queries_or_keys(q) k = self.rotary_emb.rotate_queries_or_keys(k) # similarity sim = einsum('... h i d, ... h j d -> ... h i j', q, k) # relative positional bias if exists(pos_bias): sim = sim + pos_bias if exists(focus_present_mask) and not (~focus_present_mask).all(): attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool) attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) mask = torch.where( rearrange(focus_present_mask, 'b -> b 1 1 1 1'), rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), ) sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) # numerical stability sim = sim - sim.amax(dim=-1, keepdim=True).detach() attn = sim.softmax(dim=-1) # aggregate values out = einsum('... h i j, ... h j d -> ... h i d', attn, v) out = rearrange(out, '... h n d -> ... 
n (h d)') return self.to_out(out) class PositionalEncoding(nn.Module): def __init__(self, d_model, dropout=0.1, max_len=20000): super(PositionalEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) self.register_buffer('pe', pe) def forward(self, x): # not used in the final model x = x + self.pe[:x.shape[0], :] return self.dropout(x) class RelativePositionBias(nn.Module): def __init__( self, heads=8, num_buckets=32, max_distance=128 ): super().__init__() self.num_buckets = num_buckets self.max_distance = max_distance self.relative_attention_bias = nn.Embedding(num_buckets, heads) @staticmethod def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): ret = 0 n = -relative_position num_buckets //= 2 ret += (n < 0).long() * num_buckets n = torch.abs(n) max_exact = num_buckets // 2 is_small = n < max_exact val_if_large = max_exact + ( torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).long() val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) ret += torch.where(is_small, n, val_if_large) return ret def forward(self, n, device, eval = False): q_pos = torch.arange(n, dtype=torch.long, device=device) k_pos = torch.arange(n, dtype=torch.long, device=device) rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance) if True: mask = - (((rel_pos > 32) + (rel_pos < -32)) * (1e8)) values = self.relative_attention_bias(rp_bucket) return rearrange(values, 'i j h -> h i j') + mask else: values = self.relative_attention_bias(rp_bucket) return rearrange(values, 'i j h -> h i j') # only for ablation / not used in the final model class TimeEncoding(nn.Module): def __init__(self, d_model, dropout=0.1, max_len=5000): super(TimeEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) def forward(self, x, mask, lengths): time = mask * 1/(lengths[..., None]-1) time = time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :] time = time[:, 0].T # add the time encoding x = x + time[..., None] return self.dropout(x) class Encoder_TRANSFORMERREEMB(nn.Module): def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, pose_latent_dim=64, audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, ablation=None, activation="gelu", **kargs): super().__init__() self.modeltype = modeltype self.pos_dim = pos_dim self.num_frames = num_frames self.audio_dim = audio_dim self.pose_latent_dim = pose_latent_dim self.audio_latent_dim = audio_latent_dim self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2 self.ff_size = ff_size self.num_layers = num_layers self.num_heads = num_heads self.dropout = dropout self.ablation = ablation self.activation = activation # if self.ablation == "average_encoder": # self.mu_layer = nn.Linear(self.latent_dim, self.latent_dim) # self.sigma_layer = nn.Linear(self.latent_dim, self.latent_dim) # else: # self.muQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim)) # self.sigmaQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim)) # # there's no class of our 
dataset CREMA/HDTF, so noly dont need to use nn.parameter self.mu_layer = nn.Linear(self.latent_dim, self.audio_latent_dim) self.sigma_layer = nn.Linear(self.latent_dim, self.audio_latent_dim) self.poseEmbedding = nn.Linear(self.pos_dim, self.pose_latent_dim) #6,64 self.firstposeEmbedding = nn.Linear(self.pos_dim, self.pose_latent_dim) #6,64 self.audioEmbedding = nn.Linear(self.audio_dim, self.audio_latent_dim) #1024, 256 self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) def forward(self, batch): ''' x: 6-dim pos, (bs, max_num_frames, 6) y: 1024-dim audio embbeding, (bs, max_num_frames, 1024) ''' x, y, mask = batch["x"], batch["y"], batch["mask"] # bs, njoints, nfeats, nframes = x.shape # x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints*nfeats) x_ref = x[:,0,:].unsqueeze(dim=1) # The pose information of the first frame(refrence img) x = x-x_ref.repeat(1,x.size(1),1) # bs, nf, 6 Obtain the difference from the first frame batch['x_delta'] = x x_ref = x_ref.permute((1,0,2)) #1, bs, 6 x = x.permute((1, 0, 2)) #nf, bs, 6 y = y.permute((1, 0, 2)) #nf, bs, 1024 # embedding of the pose/audio x_ref = self.firstposeEmbedding(x_ref).repeat(x.size(0),1,1) #nf, bs, 64 x = self.poseEmbedding(x) #nf, bs, 64 y = self.audioEmbedding(y) #nf, bs, 256 x = torch.cat([x_ref, x, y],dim=-1) # nf, bs, 64+64+256 # only use the "average_encoder" mode # add positional encoding x = self.sequence_pos_encoder(x) # transformer layers final = self.seqTransEncoder(x, src_key_padding_mask=~mask) #nu_frames, bs, 64+64+256 # get the average of the output z = final# final.mean(axis=0) # nf, bs, 64+64+256 # extract mu and logvar mu = self.mu_layer(z) # nf, bs, 256 logvar = self.sigma_layer(z) # nf, bs, 256 # logvar = - torch.ones_like(logvar) * 1e10 return {"mu": mu, "logvar": logvar} class Decoder_TRANSFORMERREEMB(nn.Module): def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=7, pose_latent_dim=64, audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, activation="gelu", ablation=None, num_buckets = 32, max_distance = 32,**kargs): super().__init__() self.modeltype = modeltype self.pos_dim = pos_dim self.num_frames = num_frames self.audio_dim = audio_dim self.pose_latent_dim = pose_latent_dim self.audio_latent_dim = audio_latent_dim self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2 self.ff_size = ff_size self.num_layers = num_layers self.num_heads = num_heads self.dropout = dropout self.ablation = ablation self.activation = activation self.firstposeEmbedding = nn.Linear(self.pos_dim, self.pose_latent_dim) #6,64 self.audioEmbedding = nn.Linear(self.audio_dim, self.audio_latent_dim) #1024, 256 self.ztimelinear = nn.Linear(self.audio_latent_dim*2+self.pose_latent_dim, self.pose_latent_dim) #256*2+64,64 self.init_proj = nn.Linear(self.pose_latent_dim, self.pose_latent_dim) # self.input_feats = self.njoints*self.nfeats # # only for ablation / not used in the final model # if self.ablation == "zandtime": # self.ztimelinear = nn.Linear(self.latent_dim + self.num_classes, self.latent_dim) # else: # self.actionBiases = nn.Parameter(torch.randn(1024, self.latent_dim)) # self.actionBiases = 
nn.Parameter(torch.randn(self.num_classes, self.latent_dim)) # # only for ablation / not used in the final model # if self.ablation == "time_encoding": # self.sequence_pos_encoder = TimeEncoding(self.dropout) # else: # self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) self.sequence_pos_encoder = PositionalEncoding(self.pose_latent_dim, self.dropout) rotary_emb = RotaryEmbedding(min(32, num_heads)) self.time_rel_pos_bias = RelativePositionBias(heads=num_heads, num_buckets=num_buckets, max_distance=max_distance) temporal_attn = lambda dim: EinopsToAndFrom('l b c', 'b l c', # len, b, c Attention(dim, heads=num_heads, dim_head=32, rotary_emb=rotary_emb)) self.init_temporal_attn = Residual(PreNorm(self.pose_latent_dim, temporal_attn(self.pose_latent_dim))) # self.sequence_pos_encoder = TimeEncoding(self.dropout) #time_encoding seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.pose_latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=activation) self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer, num_layers=self.num_layers) self.finallayer = nn.Linear(self.pose_latent_dim, self.pos_dim) def forward(self, batch): ''' z: bs, audio_latent_dim(256) y: bs, num_frames, 1024 mask: bs, num_frames lengths: [num_frames,...] ''' x, z, y, mask, lengths = batch["x"], batch["z"], batch["y"], batch["mask"], batch["lengths"] bs, nframes = mask.shape # first img x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64 x_ref = self.firstposeEmbedding(x_ref.repeat(1, nframes, 1)) #bs, nf, 64 y = self.audioEmbedding(y) #bs, num_frames, 256 z = z.permute(1, 0, 2) #z = z.unsqueeze(dim=1).repeat(1, nframes, 1) #bs, num_frames, 256 z = torch.cat([x_ref, z, y], dim=-1) # bs, num_frames, 256*2+64 z = self.ztimelinear(z) z = z.permute((1, 0, 2)) # nf, bs, 64 pose_latent_dim = z.shape[2] # z = z[None] # sequence of size 1 # # only for ablation / not used in the final model # if self.ablation == "zandtime": # yoh = F.one_hot(y, self.num_classes) # z = torch.cat((z, yoh), axis=1) # z = self.ztimelinear(z) # z = z[None] # sequence of size 1 # else: # # only for ablation / not used in the final model # if self.ablation == "concat_bias": # # sequence of size 2 # z = torch.stack((z, self.actionBiases[y]), axis=0) # else: # # shift the latent noise vector to be the action noise # z = z + self.actionBiases[y.long()] # NEED CHECK # z = z[None] # sequence of size 1 timequeries = torch.zeros(nframes, bs, pose_latent_dim, device=z.device) # len, b, c timequeries = self.sequence_pos_encoder(timequeries) time_rel_pos_bias = self.time_rel_pos_bias(timequeries.shape[0], device=x.device) timequeries = self.init_proj(timequeries) timequeries = self.init_temporal_attn(timequeries, pos_bias=time_rel_pos_bias) # timequeries = self.sequence_pos_encoder(timequeries, mask, lengths) #time_encoding # # only for ablation / not used in the final model # if self.ablation == "time_encoding": # timequeries = self.sequence_pos_encoder(timequeries, mask, lengths) # else: # timequeries = self.sequence_pos_encoder(timequeries) # num_frames, bs, 64 output = self.seqTransDecoder(tgt=timequeries, memory=z, tgt_mask=time_rel_pos_bias.repeat(bs, 1, 1), tgt_key_padding_mask=~mask) output = self.finallayer(output).reshape(nframes, bs, self.pos_dim) # num_frames, bs, 6 # output = self.finallayer(output).reshape(nframes, bs, njoints, nfeats) # zero for padded area output[~mask.T] = 0 #nf, bs, 6 output = output.permute(1,0,2)#bs, nf, 6 batch["output"] = output return batch 
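The decoder above initializes its time queries with rotary-embedding temporal self-attention and reuses a T5-style bucketed relative-position bias twice: once as pos_bias inside init_temporal_attn, and once (repeated over the batch) as the additive float tgt_mask of the stock nn.TransformerDecoder. As a reference for how such a bucketed bias behaves, here is a minimal, self-contained sketch; the class name T5RelativeBias and the toy sizes (8 heads, 16 frames, head dim 32) are illustrative assumptions, not values taken from this repository:

# --- illustrative sketch only (assumed names/sizes); not part of the PBnet source ---
import math
import torch
import torch.nn as nn

class T5RelativeBias(nn.Module):
    """Bucketed relative-position bias: one learned scalar per (bucket, head)."""
    def __init__(self, heads=8, num_buckets=32, max_distance=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _bucket(rel_pos, num_buckets=32, max_distance=128):
        # sign of the offset selects the first/second half of the buckets,
        # its magnitude selects an exact bucket (small offsets) or a log-spaced one
        n = -rel_pos
        num_buckets //= 2
        ret = (n < 0).long() * num_buckets
        n = n.abs()
        max_exact = num_buckets // 2
        is_small = n < max_exact
        large = max_exact + (
            torch.log(n.float().clamp(min=1) / max_exact)
            / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        large = torch.min(large, torch.full_like(large, num_buckets - 1))
        return ret + torch.where(is_small, n, large)

    def forward(self, n, device):
        q = torch.arange(n, device=device)
        k = torch.arange(n, device=device)
        rel = k[None, :] - q[:, None]                # (n, n) signed offsets
        buckets = self._bucket(rel, self.num_buckets, self.max_distance)
        return self.bias(buckets).permute(2, 0, 1)   # (heads, n, n)

# toy usage: add the bias to raw attention scores before the softmax
heads, n = 8, 16
scores = torch.randn(heads, n, n)                    # stands in for q @ k^T / sqrt(d)
bias = T5RelativeBias(heads=heads)(n, device=scores.device)
attn = (scores + bias).softmax(dim=-1)
print(attn.shape)                                    # torch.Size([8, 16, 16])
# --- end of illustrative sketch ---

In the repository code the same (heads, n, n) bias is repeated across the batch and handed to the PyTorch decoder as a 3-D float attention mask, which PyTorch adds to the attention logits exactly as in the last lines of the sketch.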
================================================ FILE: PBnet/src/models/architectures/transformerreemb5.py ================================================ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat, reduce, pack, unpack from einops_exts import rearrange_many from torch import einsum from rotary_embedding_torch import RotaryEmbedding import math from src.models.architectures.transformerdecoder4 import * def exists(x): return x is not None class LayerNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.gamma = nn.Parameter(torch.ones(1, 1, dim)) def forward(self, x): var = torch.var(x, dim=-1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=-1, keepdim=True) return (x - mean) / (var + self.eps).sqrt() * self.gamma class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.fn = fn self.norm = LayerNorm(dim) def forward(self, x, **kwargs): x = self.norm(x) return self.fn(x, **kwargs) class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, *args, **kwargs): return self.fn(x, *args, **kwargs) + x class EinopsToAndFrom(nn.Module): def __init__(self, from_einops, to_einops, fn): super().__init__() self.from_einops = from_einops self.to_einops = to_einops self.fn = fn def forward(self, x, **kwargs): shape = x.shape reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape))) x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') x = self.fn(x, **kwargs) x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) return x class PositionalEncoding(nn.Module): def __init__(self, d_model, dropout=0.1, max_len=20000): super(PositionalEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) self.register_buffer('pe', pe) def forward(self, x): # not used in the final model x = x + self.pe[:x.shape[0], :] return self.dropout(x) class RelativePositionBias(nn.Module): def __init__( self, heads=8, num_buckets=32, max_distance=128 ): super().__init__() self.num_buckets = num_buckets self.max_distance = max_distance self.relative_attention_bias = nn.Embedding(num_buckets, heads) @staticmethod def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): ret = 0 n = -relative_position num_buckets //= 2 ret += (n < 0).long() * num_buckets n = torch.abs(n) max_exact = num_buckets // 2 is_small = n < max_exact val_if_large = max_exact + ( torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).long() val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) ret += torch.where(is_small, n, val_if_large) return ret def forward(self, n, device, eval = False): q_pos = torch.arange(n, dtype=torch.long, device=device) k_pos = torch.arange(n, dtype=torch.long, device=device) rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance) if not self.relative_attention_bias.training: print('eval!') mask = - (((rel_pos > 200) + (rel_pos < -200)) * (1e8)) values 
= self.relative_attention_bias(rp_bucket) return rearrange(values, 'i j h -> h i j') + mask else: # values = self.relative_attention_bias(rp_bucket) # return rearrange(values, 'i j h -> h i j') # mask = - (((rel_pos > 100) + (rel_pos < -100)) * (1e8)) values = self.relative_attention_bias(rp_bucket) return rearrange(values, 'i j h -> h i j') # + mask # only for ablation / not used in the final model class TimeEncoding(nn.Module): def __init__(self, d_model, dropout=0.1, max_len=5000): super(TimeEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) def forward(self, x, mask, lengths): time = mask * 1/(lengths[..., None]-1) time = time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :] time = time[:, 0].T # add the time encoding x = x + time[..., None] return self.dropout(x) class Encoder_TRANSFORMERREEMB5(nn.Module): def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=6, eye_dim=2, pose_latent_dim=64, audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, ablation=None, activation="gelu", **kargs): super().__init__() self.modeltype = modeltype self.pos_dim = pos_dim self.eye_dim = eye_dim self.num_frames = num_frames self.audio_dim = audio_dim self.pose_latent_dim = pose_latent_dim self.audio_latent_dim = audio_latent_dim self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2 self.ff_size = ff_size self.num_layers = num_layers self.num_heads = num_heads self.dropout = dropout self.ablation = ablation self.activation = activation # if self.ablation == "average_encoder": # self.mu_layer = nn.Linear(self.latent_dim, self.latent_dim) # self.sigma_layer = nn.Linear(self.latent_dim, self.latent_dim) # else: # self.muQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim)) # self.sigmaQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim)) # # there's no class of our dataset CREMA/HDTF, so noly dont need to use nn.parameter self.mu_layer = nn.Linear(self.latent_dim, self.audio_latent_dim) self.sigma_layer = nn.Linear(self.latent_dim, self.audio_latent_dim) self.poseEmbedding = nn.Linear(self.pos_dim+self.eye_dim, self.pose_latent_dim) #6,64 self.firstposeEmbedding = nn.Linear(self.pos_dim+self.eye_dim, self.pose_latent_dim) #6,64 self.audioEmbedding = nn.Linear(self.audio_dim, self.audio_latent_dim) #1024, 256 self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) def forward(self, batch): ''' x: 6-dim pos, (bs, max_num_frames, 6) y: 1024-dim audio embbeding, (bs, max_num_frames, 1024) ''' x, y, mask = batch["x"], batch["y"], batch["mask"] # bs, njoints, nfeats, nframes = x.shape # x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints*nfeats) x_ref = x[:,0,:].unsqueeze(dim=1) # The pose information of the first frame(refrence img) x = x-x_ref.repeat(1,x.size(1),1) # bs, nf, 6 Obtain the difference from the first frame batch['x_delta'] = x x_ref = x_ref.permute((1,0,2)) #1, bs, 6 x = x.permute((1, 0, 2)) #nf, bs, 6 y = y.permute((1, 0, 2)) #nf, bs, 1024 # embedding of the pose/audio x_ref = self.firstposeEmbedding(x_ref).repeat(x.size(0),1,1) #nf, bs, 64 x = self.poseEmbedding(x) #nf, bs, 64 y = self.audioEmbedding(y) #nf, bs, 256 x = 
torch.cat([x_ref, x, y],dim=-1) # nf, bs, 64+64+256 # only use the "average_encoder" mode # add positional encoding x = self.sequence_pos_encoder(x) # transformer layers final = self.seqTransEncoder(x, src_key_padding_mask=~mask) #nu_frames, bs, 64+64+256 # get the average of the output z = final# final.mean(axis=0) # nf, bs, 64+64+256 # extract mu and logvar mu = self.mu_layer(z) # nf, bs, 256 logvar = self.sigma_layer(z) # nf, bs, 256 # logvar = - torch.ones_like(logvar) * 1e10 return {"mu": mu, "logvar": logvar} class Decoder_TRANSFORMERREEMB5(nn.Module): def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=6, eye_dim=2, pose_latent_dim=64, audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, activation="gelu", ablation=None, num_buckets = 32, max_distance = 32,**kargs): super().__init__() self.modeltype = modeltype self.pos_dim = pos_dim self.eye_dim = eye_dim self.num_frames = num_frames self.audio_dim = audio_dim self.pose_latent_dim = pose_latent_dim self.audio_latent_dim = audio_latent_dim self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2 self.ff_size = ff_size self.num_layers = num_layers self.num_heads = num_heads self.dropout = dropout self.ablation = ablation self.activation = activation self.firstposeEmbedding = nn.Linear(self.pos_dim+self.eye_dim, self.pose_latent_dim) #6,64 self.audioEmbedding = nn.Linear(self.audio_dim, self.audio_latent_dim) #1024, 256 self.ztimelinear = nn.Linear(self.audio_latent_dim*2+self.pose_latent_dim, self.pose_latent_dim) #256*2+64,64 self.init_proj = nn.Linear(self.pose_latent_dim, self.pose_latent_dim) # self.input_feats = self.njoints*self.nfeats # # only for ablation / not used in the final model # if self.ablation == "zandtime": # self.ztimelinear = nn.Linear(self.latent_dim + self.num_classes, self.latent_dim) # else: # self.actionBiases = nn.Parameter(torch.randn(1024, self.latent_dim)) # self.actionBiases = nn.Parameter(torch.randn(self.num_classes, self.latent_dim)) # # only for ablation / not used in the final model # if self.ablation == "time_encoding": # self.sequence_pos_encoder = TimeEncoding(self.dropout) # else: # self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) self.sequence_pos_encoder = PositionalEncoding(self.pose_latent_dim, self.dropout) rotary_emb = RotaryEmbedding(min(32, num_heads)) self.time_rel_pos_bias_tgt = RelativePositionBias(heads=num_heads, num_buckets=num_buckets, max_distance=max_distance) self.time_rel_pos_bias_mem = RelativePositionBias(heads=num_heads, num_buckets=num_buckets, max_distance=max_distance) temporal_attn = lambda dim: Attention(dim, heads=num_heads, dim_head=32, rotary_emb=rotary_emb) self.init_temporal_attn = Residual(PreNorm(self.pose_latent_dim, temporal_attn(self.pose_latent_dim))) # self.sequence_pos_encoder = TimeEncoding(self.dropout) #time_encoding # seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.pose_latent_dim, # nhead=self.num_heads, # dim_feedforward=self.ff_size, # dropout=self.dropout, # activation=activation) self.seqTransDecoder = TransformerDecoder(d_model = self.pose_latent_dim, num_heads = self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, num_layers=self.num_layers) self.finallayer = nn.Linear(self.pose_latent_dim, self.pos_dim+self.eye_dim) self.q_dropout = nn.Dropout(self.dropout) def forward(self, batch): ''' z: bs, audio_latent_dim(256) y: bs, num_frames, 1024 mask: bs, num_frames lengths: [num_frames,...] 
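note: in this re-embedding variant z is produced per frame by the encoder,
i.e. (num_frames, bs, audio_latent_dim); it is permuted to
(bs, num_frames, audio_latent_dim) below before being fused with the
reference-pose and audio embeddings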
''' x, z, y, mask, lengths = batch["x"], batch["z"], batch["y"], batch["mask"], batch["lengths"] bs, nframes = mask.shape # first img x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64 x_ref = self.firstposeEmbedding(x_ref.repeat(1, nframes, 1)) #bs, nf, 64 y = self.audioEmbedding(y) #bs, num_frames, 256 z = z.permute(1, 0, 2) #z = z.unsqueeze(dim=1).repeat(1, nframes, 1) #bs, num_frames, 256 z = torch.cat([x_ref, z, y], dim=-1) # bs, num_frames, 256*2+64 z = self.ztimelinear(z) # z = z.permute((1, 0, 2)) # nf, bs, 64 pose_latent_dim = z.shape[2] # z = z[None] # sequence of size 1 # # only for ablation / not used in the final model # if self.ablation == "zandtime": # yoh = F.one_hot(y, self.num_classes) # z = torch.cat((z, yoh), axis=1) # z = self.ztimelinear(z) # z = z[None] # sequence of size 1 # else: # # only for ablation / not used in the final model # if self.ablation == "concat_bias": # # sequence of size 2 # z = torch.stack((z, self.actionBiases[y]), axis=0) # else: # # shift the latent noise vector to be the action noise # z = z + self.actionBiases[y.long()] # NEED CHECK # z = z[None] # sequence of size 1 timequeries = torch.zeros(bs, nframes, pose_latent_dim, device=z.device) # len, b, c # timequeries = self.sequence_pos_encoder(timequeries) #time_encoding time_rel_pos_bias_tgt = self.time_rel_pos_bias_tgt(nframes, device=x.device) time_rel_pos_bias_mem = self.time_rel_pos_bias_mem(nframes, device=x.device) timequeries = self.init_proj(timequeries) timequeries = self.init_temporal_attn(timequeries, pos_bias=time_rel_pos_bias_tgt.repeat(bs, 1, 1, 1)) # # only for ablation / not used in the final model # if self.ablation == "time_encoding": # timequeries = self.sequence_pos_encoder(timequeries, mask, lengths) # else: # timequeries = self.sequence_pos_encoder(timequeries) # num_frames, bs, 64 output = self.seqTransDecoder(tgt=timequeries, memory=z, tgt_mask=time_rel_pos_bias_tgt.repeat(bs, 1, 1, 1), memory_mask = time_rel_pos_bias_mem.repeat(bs, 1, 1, 1), ) output = self.finallayer(output) # .reshape(nframes, bs, self.pos_dim) # num_frames, bs, 6 # output = self.finallayer(output).reshape(nframes, bs, njoints, nfeats) # zero for padded area output[~mask] = 0 #nf, bs, 6 batch["out_pose"] = output[:,:,:6] # .permute(1,0,2)#bs, nf, 6 batch["out_eye"] = output[:,:,6:] # .permute(1,0,2)#bs, nf, 6 batch["output"] = output return batch ================================================ FILE: PBnet/src/models/architectures/transformerreemb6.py ================================================ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat, reduce, pack, unpack from einops_exts import rearrange_many from torch import einsum from rotary_embedding_torch import RotaryEmbedding import math from src.models.architectures.transformerdecoder5 import * def exists(x): return x is not None class LayerNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.gamma = nn.Parameter(torch.ones(1, 1, dim)) def forward(self, x): var = torch.var(x, dim=-1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=-1, keepdim=True) return (x - mean) / (var + self.eps).sqrt() * self.gamma class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.fn = fn self.norm = LayerNorm(dim) def forward(self, x, **kwargs): x = self.norm(x) return self.fn(x, **kwargs) class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, *args, **kwargs): return self.fn(x, 
*args, **kwargs) + x class EinopsToAndFrom(nn.Module): def __init__(self, from_einops, to_einops, fn): super().__init__() self.from_einops = from_einops self.to_einops = to_einops self.fn = fn def forward(self, x, **kwargs): shape = x.shape reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape))) x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') x = self.fn(x, **kwargs) x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) return x class PositionalEncoding(nn.Module): def __init__(self, d_model, dropout=0.1, max_len=20000): super(PositionalEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) self.register_buffer('pe', pe) def forward(self, x): # not used in the final model x = x + self.pe[:x.shape[0], :] return self.dropout(x) class RelativePositionBias(nn.Module): def __init__( self, heads=8, num_buckets=32, max_distance=128 ): super().__init__() self.num_buckets = num_buckets self.max_distance = max_distance self.relative_attention_bias = nn.Embedding(num_buckets, heads) @staticmethod def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): ret = 0 n = -relative_position num_buckets //= 2 ret += (n < 0).long() * num_buckets n = torch.abs(n) max_exact = num_buckets // 2 is_small = n < max_exact val_if_large = max_exact + ( torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) ).long() val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) ret += torch.where(is_small, n, val_if_large) return ret def forward(self, n, device, eval = False): q_pos = torch.arange(n, dtype=torch.long, device=device) k_pos = torch.arange(n, dtype=torch.long, device=device) rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance) if not self.relative_attention_bias.training: print('eval!') mask = - (((rel_pos > 100) + (rel_pos < -100)) * (1e8)) values = self.relative_attention_bias(rp_bucket) return rearrange(values, 'i j h -> h i j') + mask else: # values = self.relative_attention_bias(rp_bucket) # return rearrange(values, 'i j h -> h i j') # mask = - (((rel_pos > 100) + (rel_pos < -100)) * (1e8)) values = self.relative_attention_bias(rp_bucket) return rearrange(values, 'i j h -> h i j') # + mask # only for ablation / not used in the final model class TimeEncoding(nn.Module): def __init__(self, d_model, dropout=0.1, max_len=5000): super(TimeEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) def forward(self, x, mask, lengths): time = mask * 1/(lengths[..., None]-1) time = time[:, None] * torch.arange(time.shape[1], device=x.device)[None, :] time = time[:, 0].T # add the time encoding x = x + time[..., None] return self.dropout(x) class Encoder_TRANSFORMERREEMB6(nn.Module): def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=6, eye_dim=2, pose_latent_dim=64, audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, ablation=None, activation="gelu", **kargs): super().__init__() self.modeltype = modeltype self.pos_dim = pos_dim self.eye_dim = 0 self.num_frames = 
num_frames self.audio_dim = audio_dim self.pose_latent_dim = pose_latent_dim self.audio_latent_dim = audio_latent_dim self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2 self.ff_size = ff_size self.num_layers = num_layers self.num_heads = num_heads self.dropout = dropout self.ablation = ablation self.activation = activation # if self.ablation == "average_encoder": # self.mu_layer = nn.Linear(self.latent_dim, self.latent_dim) # self.sigma_layer = nn.Linear(self.latent_dim, self.latent_dim) # else: # self.muQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim)) # self.sigmaQuery = nn.Parameter(torch.randn(self.num_classes, self.latent_dim)) # # there's no class of our dataset CREMA/HDTF, so noly dont need to use nn.parameter self.mu_layer = nn.Linear(self.latent_dim, self.audio_latent_dim) self.sigma_layer = nn.Linear(self.latent_dim, self.audio_latent_dim) self.poseEmbedding = nn.Linear(self.pos_dim+self.eye_dim, self.pose_latent_dim) #6,64 self.firstposeEmbedding = nn.Linear(self.pos_dim+self.eye_dim, self.pose_latent_dim) #6,64 self.audioEmbedding = nn.Linear(self.audio_dim, self.audio_latent_dim) #1024, 256 self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) def forward(self, batch): ''' x: 6-dim pos, (bs, max_num_frames, 6) y: 1024-dim audio embbeding, (bs, max_num_frames, 1024) ''' x, y, mask = batch["x"], batch["y"], batch["mask"] # bs, njoints, nfeats, nframes = x.shape # x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints*nfeats) x_ref = x[:,0,:].unsqueeze(dim=1) # The pose information of the first frame(refrence img) x = x-x_ref.repeat(1,x.size(1),1) # bs, nf, 6 Obtain the difference from the first frame batch['x_delta'] = x x_ref = x_ref.permute((1,0,2)) #1, bs, 6 x = x.permute((1, 0, 2)) #nf, bs, 6 y = y.permute((1, 0, 2)) #nf, bs, 1024 # embedding of the pose/audio x_ref = self.firstposeEmbedding(x_ref).repeat(x.size(0),1,1) #nf, bs, 64 x = self.poseEmbedding(x) #nf, bs, 64 y = self.audioEmbedding(y) #nf, bs, 256 x = torch.cat([x_ref, x, y],dim=-1) # nf, bs, 64+64+256 # only use the "average_encoder" mode # add positional encoding x = self.sequence_pos_encoder(x) # transformer layers final = self.seqTransEncoder(x, src_key_padding_mask=~mask) #nu_frames, bs, 64+64+256 # get the average of the output z = final# final.mean(axis=0) # nf, bs, 64+64+256 # extract mu and logvar mu = self.mu_layer(z) # nf, bs, 256 logvar = self.sigma_layer(z) # nf, bs, 256 # logvar = - torch.ones_like(logvar) * 1e10 return {"mu": mu, "logvar": logvar} class Decoder_TRANSFORMERREEMB6(nn.Module): def __init__(self, modeltype, num_frames, audio_dim=1024, pos_dim=6, eye_dim=2, pose_latent_dim=64, audio_latent_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, activation="gelu", ablation=None, num_buckets = 32, max_distance = 32,**kargs): super().__init__() self.modeltype = modeltype self.pos_dim = pos_dim self.eye_dim = 0 self.num_frames = num_frames self.audio_dim = audio_dim self.pose_latent_dim = pose_latent_dim self.audio_latent_dim = audio_latent_dim self.latent_dim = self.audio_latent_dim + self.pose_latent_dim*2 self.ff_size = ff_size self.num_layers = num_layers self.num_heads = num_heads 
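# unlike the reemb5 variant, eye_dim is fixed to 0 in this file, so the
# final layer predicts the pos_dim head-pose values only (no eye channels)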
self.dropout = dropout self.ablation = ablation self.activation = activation self.firstposeEmbedding = nn.Linear(self.pos_dim+self.eye_dim, self.pose_latent_dim) #6,64 self.audioEmbedding = nn.Linear(self.audio_dim, self.audio_latent_dim) #1024, 256 self.ztimelinear = nn.Linear(self.audio_latent_dim*2+self.pose_latent_dim, self.pose_latent_dim) #256*2+64,64 self.init_proj = nn.Linear(self.pose_latent_dim, self.pose_latent_dim) # self.input_feats = self.njoints*self.nfeats # # only for ablation / not used in the final model # if self.ablation == "zandtime": # self.ztimelinear = nn.Linear(self.latent_dim + self.num_classes, self.latent_dim) # else: # self.actionBiases = nn.Parameter(torch.randn(1024, self.latent_dim)) # self.actionBiases = nn.Parameter(torch.randn(self.num_classes, self.latent_dim)) # # only for ablation / not used in the final model # if self.ablation == "time_encoding": # self.sequence_pos_encoder = TimeEncoding(self.dropout) # else: # self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) self.sequence_pos_encoder = PositionalEncoding(self.pose_latent_dim, self.dropout) rotary_emb = RotaryEmbedding(min(32, num_heads)) self.time_rel_pos_bias_tgt = RelativePositionBias(heads=num_heads, num_buckets=num_buckets, max_distance=max_distance) self.time_rel_pos_bias_mem = RelativePositionBias(heads=num_heads, num_buckets=num_buckets, max_distance=max_distance) temporal_attn = lambda dim: Attention(dim, heads=num_heads, dim_head=32, rotary_emb=rotary_emb) self.init_temporal_attn = Residual(PreNorm(self.pose_latent_dim, temporal_attn(self.pose_latent_dim))) # self.sequence_pos_encoder = TimeEncoding(self.dropout) #time_encoding # seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.pose_latent_dim, # nhead=self.num_heads, # dim_feedforward=self.ff_size, # dropout=self.dropout, # activation=activation) self.seqTransDecoder = TransformerDecoder(d_model = self.pose_latent_dim, num_heads = self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, num_layers=self.num_layers) self.finallayer = nn.Linear(self.pose_latent_dim, self.pos_dim+self.eye_dim) def forward(self, batch): ''' z: bs, audio_latent_dim(256) y: bs, num_frames, 1024 mask: bs, num_frames lengths: [num_frames,...] 
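output: (bs, num_frames, pos_dim) pose offsets w.r.t. the first frame;
this variant only writes batch["output"] (the out_pose/out_eye split
used in reemb5 is skipped)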
''' x, z, y, mask, lengths = batch["x"], batch["z"], batch["y"], batch["mask"], batch["lengths"] bs, nframes = mask.shape # first img x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64 x_ref = self.firstposeEmbedding(x_ref.repeat(1, nframes, 1)) #bs, nf, 64 y = self.audioEmbedding(y) #bs, num_frames, 256 z = z.permute(1, 0, 2) #z = z.unsqueeze(dim=1).repeat(1, nframes, 1) #bs, num_frames, 256 z = torch.cat([x_ref, z, y], dim=-1) # bs, num_frames, 256*2+64 z = self.ztimelinear(z) # z = z.permute((1, 0, 2)) # nf, bs, 64 pose_latent_dim = z.shape[2] # z = z[None] # sequence of size 1 # # only for ablation / not used in the final model # if self.ablation == "zandtime": # yoh = F.one_hot(y, self.num_classes) # z = torch.cat((z, yoh), axis=1) # z = self.ztimelinear(z) # z = z[None] # sequence of size 1 # else: # # only for ablation / not used in the final model # if self.ablation == "concat_bias": # # sequence of size 2 # z = torch.stack((z, self.actionBiases[y]), axis=0) # else: # # shift the latent noise vector to be the action noise # z = z + self.actionBiases[y.long()] # NEED CHECK # z = z[None] # sequence of size 1 timequeries = torch.zeros(bs, nframes, pose_latent_dim, device=z.device) # len, b, c # timequeries = self.sequence_pos_encoder(timequeries) #time_encoding time_rel_pos_bias_tgt = self.time_rel_pos_bias_tgt(nframes, device=x.device) time_rel_pos_bias_mem = self.time_rel_pos_bias_mem(nframes, device=x.device) timequeries = self.init_proj(timequeries) timequeries = self.init_temporal_attn(timequeries, pos_bias=time_rel_pos_bias_tgt.repeat(bs, 1, 1, 1)) # # only for ablation / not used in the final model # if self.ablation == "time_encoding": # timequeries = self.sequence_pos_encoder(timequeries, mask, lengths) # else: # timequeries = self.sequence_pos_encoder(timequeries) # num_frames, bs, 64 output = self.seqTransDecoder(tgt=timequeries, memory=z, tgt_mask=time_rel_pos_bias_tgt.repeat(bs, 1, 1, 1), memory_mask = time_rel_pos_bias_mem.repeat(bs, 1, 1, 1), ) output = self.finallayer(output) # .reshape(nframes, bs, self.pos_dim) # num_frames, bs, 6 # output = self.finallayer(output).reshape(nframes, bs, njoints, nfeats) # zero for padded area output[~mask] = 0 #nf, bs, 6 # batch["out_pose"] = output[:,:,:6] # .permute(1,0,2)#bs, nf, 6 # batch["out_eye"] = output[:,:,6:] # .permute(1,0,2)#bs, nf, 6 batch["output"] = output return batch ================================================ FILE: PBnet/src/models/architectures/transgru.py ================================================ from .transformer import Encoder_TRANSFORMER as Encoder_TRANSGRU # noqa from .gru import Decoder_GRU as Decoder_TRANSGRU # noqa ================================================ FILE: PBnet/src/models/get_model.py ================================================ import importlib import sys import os current_dir = os.path.dirname(os.path.abspath(__file__)) parent_dir = os.path.dirname(os.path.dirname(current_dir)) if parent_dir not in sys.path: sys.path.append(parent_dir) print(parent_dir) # JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"] LOSSES = ["rc", "kl", "rcw", "ssim", "var", 'reg'] # not used: "hp", "mmd", "vel", "velxyz" MODELTYPES = ["cvae"] # not used: "cae" ARCHINAMES = ["fc", "gru", "transformer","transformerreemb5", "transformerreemb6", "transformerreemb7", "transformerreemb8","transformermel", "transgru", "grutrans", "autotrans"] def get_model(parameters): modeltype = parameters["modeltype"] archiname = parameters["archiname"] archi_module = 
importlib.import_module(f'.architectures.{archiname}', package="src.models") Encoder = archi_module.__getattribute__(f"Encoder_{archiname.upper()}") Decoder = archi_module.__getattribute__(f"Decoder_{archiname.upper()}") model_module = importlib.import_module(f'.modeltype.{modeltype}', package="src.models") Model = model_module.__getattribute__(f"{modeltype.upper()}") encoder = Encoder(**parameters) decoder = Decoder(**parameters) # parameters["outputxyz"] = "rcxyz" in parameters["lambdas"] return Model(encoder, decoder, **parameters).to(parameters["device"]) ================================================ FILE: PBnet/src/models/modeltype/__init__.py ================================================ ================================================ FILE: PBnet/src/models/modeltype/cae.py ================================================ import torch import torch.nn as nn from ..tools.losses import get_loss_function import torch.nn.functional as F # from ..rotation2xyz import Rotation2xyz class CAE(nn.Module): def __init__(self, encoder, decoder, device, lambdas, latent_dim, **kwargs): super().__init__() self.encoder = encoder self.decoder = decoder # self.outputxyz = outputxyz self.lambdas = lambdas self.latent_dim = latent_dim # self.pose_rep = pose_rep # self.glob = glob # self.glob_rot = glob_rot self.device = device # self.translation = translation # self.jointstype = jointstype # self.vertstrans = vertstrans self.losses = list(self.lambdas) + ["mixed"] # self.rotation2xyz = Rotation2xyz(device=self.device) # self.param2xyz = {"pose_rep": self.pose_rep, # "glob_rot": self.glob_rot, # "glob": self.glob, # "jointstype": self.jointstype, # "translation": self.translation, # "vertstrans": self.vertstrans} # def rot2xyz(self, x, mask, **kwargs): # kargs = self.param2xyz.copy() # kargs.update(kwargs) # return self.rotation2xyz(x, mask, **kargs) def forward(self, batch): # if self.outputxyz: # batch["x_xyz"] = self.rot2xyz(batch["x"], batch["mask"]) # elif self.pose_rep == "xyz": # batch["x_xyz"] = batch["x"] # encode batch.update(self.encoder(batch)) # decode batch.update(self.decoder(batch)) # # if we want to output xyz # if self.outputxyz: # batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"]) # elif self.pose_rep == "xyz": # batch["output_xyz"] = batch["output"] return batch def compute_loss(self, batch, epoch = 0): mixed_loss = 0 losses = {} for ltype, lam in self.lambdas.items(): loss_function = get_loss_function(ltype) loss = loss_function(self, batch) if 'kl' in ltype: if epoch < 1e4 and epoch != 0: lam = 0 elif epoch != 0: lam = lam * max(epoch - 1e4, 7e4) / 7e4 mixed_loss += loss*lam losses[ltype] = loss.item() # D_loss, G_loss = self.calculate_GAN_loss(batch) # mixed_loss += G_loss * 0.7 # losses['GAN_D'] = D_loss # losses['GAN_G'] = D_loss losses["mixed"] = mixed_loss.item() return mixed_loss, losses @staticmethod def lengths_to_mask(lengths): max_len = max(lengths) if isinstance(max_len, torch.Tensor): max_len = max_len.item() index = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) mask = index < lengths.unsqueeze(1) return mask def generate_one(self, cls, duration, fact=1, xyz=False): y = torch.tensor([cls], dtype=int, device=self.device)[None] lengths = torch.tensor([duration], dtype=int, device=self.device) mask = self.lengths_to_mask(lengths) z = torch.randn(self.latent_dim, device=self.device)[None] batch = {"z": fact*z, "y": y, "mask": mask, "lengths": lengths} batch = self.decoder(batch) if not xyz: return batch["output"][0] 
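# note: self.rot2xyz is commented out in this class, so the xyz=True branch
# below is not usable here. audio-driven sampling goes through generate()
# instead, roughly (with `model` a trained CVAE instance):
#   batch = model.generate(pose, audio, durations)  # pose (bs, nf, pos+eye dims), audio (bs, nf, 1024), durations (bs,)
#   pred = batch["output"]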
output_xyz = self.rot2xyz(batch["output"], batch["mask"]) return output_xyz[0] def generate(self, pose, audio, durations, noise_same_action="random", noise_diff_action="random", fact=1): ''' audio: hubert embeddbing, (bs, fn, 1024) durations: different num_frames, (bs, ) ''' # if nspa is None: # nspa = 1 bs = len(audio) # y = audio.to(self.device).repeat(nspa) # (view(nspa, nats)) x = pose.to(self.device) y = audio.to(self.device) if len(durations.shape) == 1: lengths = durations.to(self.device) else: lengths = durations.to(self.device).reshape(y.shape) mask = self.lengths_to_mask(lengths) z = torch.randn(audio[0].shape[0], bs, self.latent_dim, device=self.device) # z = torch.randn(1, bs, self.latent_dim, device=self.device).repeat(audio[0].shape[0], 1, 1) # if noise_same_action == "random": # if noise_diff_action == "random": # z = torch.randn(nspa*bs, self.latent_dim, device=self.device) # elif noise_diff_action == "same": # z_same_action = torch.randn(nspa, self.latent_dim, device=self.device) # z = z_same_action.repeat_interleave(bs, axis=0) # else: # raise NotImplementedError("Noise diff action must be random or same.") # elif noise_same_action == "interpolate": # if noise_diff_action == "random": # z_diff_action = torch.randn(bs, self.latent_dim, device=self.device) # elif noise_diff_action == "same": # z_diff_action = torch.randn(1, self.latent_dim, device=self.device).repeat(bs, 1) # else: # raise NotImplementedError("Noise diff action must be random or same.") # interpolation_factors = torch.linspace(-1, 1, nspa, device=self.device) # z = torch.einsum("ij,k->kij", z_diff_action, interpolation_factors).view(nspa*bs, -1) # elif noise_same_action == "same": # if noise_diff_action == "random": # z_diff_action = torch.randn(bs, self.latent_dim, device=self.device) # elif noise_diff_action == "same": # z_diff_action = torch.randn(1, self.latent_dim, device=self.device).repeat(bs, 1) # else: # raise NotImplementedError("Noise diff action must be random or same.") # z = z_diff_action.repeat((nspa, 1)) # else: # raise NotImplementedError("Noise same action must be random, same or interpolate.") batch = {"x": x,"z": fact*z, "y": y, "mask": mask, "lengths": lengths} batch = self.decoder(batch) # if self.outputxyz: # batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"]) # elif self.pose_rep == "xyz": # batch["output_xyz"] = batch["output"] return batch def return_latent(self, batch, seed=None): return self.encoder(batch)["z"] ================================================ FILE: PBnet/src/models/modeltype/cae_0.py ================================================ import torch import torch.nn as nn from ..tools.losses import get_loss_function # from ..rotation2xyz import Rotation2xyz class CAE(nn.Module): def __init__(self, encoder, decoder, device, lambdas, latent_dim, **kwargs): super().__init__() self.encoder = encoder self.decoder = decoder # self.outputxyz = outputxyz self.lambdas = lambdas self.latent_dim = latent_dim # self.pose_rep = pose_rep # self.glob = glob # self.glob_rot = glob_rot self.device = device # self.translation = translation # self.jointstype = jointstype # self.vertstrans = vertstrans self.losses = list(self.lambdas) + ["mixed"] # self.rotation2xyz = Rotation2xyz(device=self.device) # self.param2xyz = {"pose_rep": self.pose_rep, # "glob_rot": self.glob_rot, # "glob": self.glob, # "jointstype": self.jointstype, # "translation": self.translation, # "vertstrans": self.vertstrans} # def rot2xyz(self, x, mask, **kwargs): # kargs = self.param2xyz.copy() # 
kargs.update(kwargs) # return self.rotation2xyz(x, mask, **kargs) def forward(self, batch): # if self.outputxyz: # batch["x_xyz"] = self.rot2xyz(batch["x"], batch["mask"]) # elif self.pose_rep == "xyz": # batch["x_xyz"] = batch["x"] # encode batch.update(self.encoder(batch)) # decode batch.update(self.decoder(batch)) # # if we want to output xyz # if self.outputxyz: # batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"]) # elif self.pose_rep == "xyz": # batch["output_xyz"] = batch["output"] return batch def compute_loss(self, batch): mixed_loss = 0 losses = {} for ltype, lam in self.lambdas.items(): loss_function = get_loss_function(ltype) loss = loss_function(self, batch) mixed_loss += loss*lam losses[ltype] = loss.item() losses["mixed"] = mixed_loss.item() return mixed_loss, losses @staticmethod def lengths_to_mask(lengths): max_len = max(lengths) if isinstance(max_len, torch.Tensor): max_len = max_len.item() index = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) mask = index < lengths.unsqueeze(1) return mask def generate_one(self, cls, duration, fact=1, xyz=False): y = torch.tensor([cls], dtype=int, device=self.device)[None] lengths = torch.tensor([duration], dtype=int, device=self.device) mask = self.lengths_to_mask(lengths) z = torch.randn(self.latent_dim, device=self.device)[None] batch = {"z": fact*z, "y": y, "mask": mask, "lengths": lengths} batch = self.decoder(batch) if not xyz: return batch["output"][0] output_xyz = self.rot2xyz(batch["output"], batch["mask"]) return output_xyz[0] def generate(self, pose, audio, durations, noise_same_action="random", noise_diff_action="random", fact=1): ''' audio: hubert embeddbing, (bs, fn, 1024) durations: different num_frames, (bs, ) ''' # if nspa is None: # nspa = 1 bs = len(audio) # y = audio.to(self.device).repeat(nspa) # (view(nspa, nats)) x = pose.to(self.device) y = audio.to(self.device) if len(durations.shape) == 1: lengths = durations.to(self.device) else: lengths = durations.to(self.device).reshape(y.shape) mask = self.lengths_to_mask(lengths) # z = torch.randn(bs, self.latent_dim, device=self.device) z = torch.randn(audio[0].shape[0], bs, self.latent_dim, device=self.device) # if noise_same_action == "random": # if noise_diff_action == "random": # z = torch.randn(nspa*bs, self.latent_dim, device=self.device) # elif noise_diff_action == "same": # z_same_action = torch.randn(nspa, self.latent_dim, device=self.device) # z = z_same_action.repeat_interleave(bs, axis=0) # else: # raise NotImplementedError("Noise diff action must be random or same.") # elif noise_same_action == "interpolate": # if noise_diff_action == "random": # z_diff_action = torch.randn(bs, self.latent_dim, device=self.device) # elif noise_diff_action == "same": # z_diff_action = torch.randn(1, self.latent_dim, device=self.device).repeat(bs, 1) # else: # raise NotImplementedError("Noise diff action must be random or same.") # interpolation_factors = torch.linspace(-1, 1, nspa, device=self.device) # z = torch.einsum("ij,k->kij", z_diff_action, interpolation_factors).view(nspa*bs, -1) # elif noise_same_action == "same": # if noise_diff_action == "random": # z_diff_action = torch.randn(bs, self.latent_dim, device=self.device) # elif noise_diff_action == "same": # z_diff_action = torch.randn(1, self.latent_dim, device=self.device).repeat(bs, 1) # else: # raise NotImplementedError("Noise diff action must be random or same.") # z = z_diff_action.repeat((nspa, 1)) # else: # raise NotImplementedError("Noise same action must be random, 
same or interpolate.") batch = {"x": x,"z": fact*z, "y": y, "mask": mask, "lengths": lengths} batch = self.decoder(batch) # if self.outputxyz: # batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"]) # elif self.pose_rep == "xyz": # batch["output_xyz"] = batch["output"] return batch def return_latent(self, batch, seed=None): return self.encoder(batch)["z"] ================================================ FILE: PBnet/src/models/modeltype/cvae.py ================================================ import torch from .cae import CAE class CVAE(CAE): def reparameterize(self, batch, seed=None): mu, logvar = batch["mu"], batch["logvar"] std = torch.exp(logvar / 2) if seed is None: eps = std.data.new(std.size()).normal_() else: generator = torch.Generator(device=self.device) generator.manual_seed(seed) eps = std.data.new(std.size()).normal_(generator=generator) z = eps.mul(std).add_(mu) return z def forward(self, batch): # if self.outputxyz: # batch["x_xyz"] = self.rot2xyz(batch["x"], batch["mask"]) # elif self.pose_rep == "xyz": # batch["x_xyz"] = batch["x"] # encode batch.update(self.encoder(batch)) batch["z"] = self.reparameterize(batch) # decode batch.update(self.decoder(batch)) # if we want to output xyz # if self.outputxyz: # batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"]) # elif self.pose_rep == "xyz": # batch["output_xyz"] = batch["output"] return batch def return_latent(self, batch, seed=None): distrib_param = self.encoder(batch) batch.update(distrib_param) return self.reparameterize(batch, seed=seed) ================================================ FILE: PBnet/src/models/modeltype/lstm.py ================================================ import torch import torch.nn as nn from ..tools.losses import get_loss_function import torch.nn.functional as F # from ..rotation2xyz import Rotation2xyz from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d from src.models.architectures.tools.resnet import resnet34 class MyResNet34(nn.Module): def __init__(self,embedding_dim,input_channel = 3): super(MyResNet34, self).__init__() self.resnet = resnet34(norm_layer = BatchNorm2d,num_classes=embedding_dim,input_channel = input_channel) def forward(self, x): return self.resnet(x) class LSTM(nn.Module): def __init__(self, encoder, decoder, device, lambdas, latent_dim, **kwargs): super(LSTM,self).__init__() self.em_audio = MyResNet34(256, 1) self.em_init_pose = nn.Linear(3,256) self.lstm = nn.LSTM(512,256,num_layers=2,bias=True,batch_first=True) self.output = nn.Linear(256,3) self.lambdas = lambdas self.losses = list(self.lambdas) + ["mixed"] self.device = device def compute_loss(self, batch): mixed_loss = 0 losses = {} for ltype, lam in self.lambdas.items(): loss_function = get_loss_function(ltype) loss = loss_function(self, batch) mixed_loss += loss*lam losses[ltype] = loss.item() # D_loss, G_loss = self.calculate_GAN_loss(batch) # mixed_loss += G_loss * 0.7 # losses['GAN_D'] = D_loss # losses['GAN_G'] = D_loss losses["mixed"] = mixed_loss.item() return mixed_loss, losses def forward(self,batch): x, y, mask = batch["x"], batch["y"], batch["mask"] bs = x.shape[0] x_ref = x[:,0,:].unsqueeze(dim=1) # The pose information of the first frame(refrence img) x = x-x_ref.repeat(1,x.size(1),1) # bs, nf, 6 Obtain the difference from the first frame batch['x_delta'] = x ref_pose = self.em_init_pose(batch["x"][:,0,:]) result = [] bs,seqlen,_,_ = batch["y"].shape zero_state = 
torch.zeros((2,bs,256),requires_grad=True).to(ref_pose.device) cur_state = (zero_state,zero_state) audio = batch["y"].reshape(-1, 1, 4, 41) audio_em = self.em_audio(audio).reshape(bs, seqlen, 256) for i in range(seqlen): ref_pose,cur_state = self.lstm(torch.cat((audio_em[:,i:i+1],ref_pose.unsqueeze(1)),dim=2),cur_state) ref_pose = ref_pose.reshape(-1, 256) result.append(self.output(ref_pose).unsqueeze(1)) res = torch.cat(result,dim=1) batch['output'] = res return batch @staticmethod def lengths_to_mask(lengths): max_len = max(lengths) if isinstance(max_len, torch.Tensor): max_len = max_len.item() index = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) mask = index < lengths.unsqueeze(1) return mask def generate(self, pose, audio, durations, noise_same_action="random", noise_diff_action="random", fact=1): x = pose.to(self.device) y = audio.to(self.device) if len(durations.shape) == 1: lengths = durations.to(self.device) else: lengths = durations.to(self.device).reshape(y.shape) mask = self.lengths_to_mask(lengths) batch = {"x": x, "y": y, "mask": mask, "lengths": lengths} batch = self.forward(batch) return batch ================================================ FILE: PBnet/src/models/rotation2xyz.py ================================================ import torch import src.utils.rotation_conversions as geometry from .smpl import SMPL, JOINTSTYPE_ROOT from .get_model import JOINTSTYPES class Rotation2xyz: def __init__(self, device): self.device = device self.smpl_model = SMPL().eval().to(device) def __call__(self, x, mask, pose_rep, translation, glob, jointstype, vertstrans, betas=None, beta=0, glob_rot=None, **kwargs): if pose_rep == "xyz": return x if mask is None: mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device) if not glob and glob_rot is None: raise TypeError("You must specify global rotation if glob is False") if jointstype not in JOINTSTYPES: raise NotImplementedError("This jointstype is not implemented.") if translation: x_translations = x[:, -1, :3] x_rotations = x[:, :-1] else: x_rotations = x x_rotations = x_rotations.permute(0, 3, 1, 2) nsamples, time, njoints, feats = x_rotations.shape # Compute rotations (convert only masked sequences output) if pose_rep == "rotvec": rotations = geometry.axis_angle_to_matrix(x_rotations[mask]) elif pose_rep == "rotmat": rotations = x_rotations[mask].view(-1, njoints, 3, 3) elif pose_rep == "rotquat": rotations = geometry.quaternion_to_matrix(x_rotations[mask]) elif pose_rep == "rot6d": rotations = geometry.rotation_6d_to_matrix(x_rotations[mask]) else: raise NotImplementedError("No geometry for this one.") if not glob: global_orient = torch.tensor(glob_rot, device=x.device) global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3) global_orient = global_orient.repeat(len(rotations), 1, 1, 1) else: global_orient = rotations[:, 0] rotations = rotations[:, 1:] if betas is None: betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas], dtype=rotations.dtype, device=rotations.device) betas[:, 1] = beta # import ipdb; ipdb.set_trace() out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas) # get the desirable joints joints = out[jointstype] x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype) x_xyz[~mask] = 0 x_xyz[mask] = joints x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous() # the first translation root at the origin on the prediction if jointstype != "vertices": rootindex = JOINTSTYPE_ROOT[jointstype] x_xyz 
= x_xyz - x_xyz[:, [rootindex], :, :] if translation and vertstrans: # the first translation root at the origin x_translations = x_translations - x_translations[:, :, [0]] # add the translation to all the joints x_xyz = x_xyz + x_translations[:, None, :, :] return x_xyz ================================================ FILE: PBnet/src/models/smpl.py ================================================ import numpy as np import torch import contextlib from smplx import SMPLLayer as _SMPLLayer from smplx.lbs import vertices2joints from src.datasets.ntu13 import action2motion_joints from src.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA JOINTSTYPE_ROOT = {"a2m": 0, # action2motion "smpl": 0, "a2mpl": 0, # set(smpl, a2m) "vibe": 8} # 0 is the 8 position: OP MidHip below JOINT_MAP = { 'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17, 'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16, 'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0, 'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8, 'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7, 'OP REye': 25, 'OP LEye': 26, 'OP REar': 27, 'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30, 'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34, 'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45, 'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7, 'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17, 'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20, 'Neck (LSP)': 47, 'Top of Head (LSP)': 48, 'Pelvis (MPII)': 49, 'Thorax (MPII)': 50, 'Spine (H36M)': 51, 'Jaw (H36M)': 52, 'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26, 'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27 } JOINT_NAMES = [ 'OP Nose', 'OP Neck', 'OP RShoulder', 'OP RElbow', 'OP RWrist', 'OP LShoulder', 'OP LElbow', 'OP LWrist', 'OP MidHip', 'OP RHip', 'OP RKnee', 'OP RAnkle', 'OP LHip', 'OP LKnee', 'OP LAnkle', 'OP REye', 'OP LEye', 'OP REar', 'OP LEar', 'OP LBigToe', 'OP LSmallToe', 'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel', 'Right Ankle', 'Right Knee', 'Right Hip', 'Left Hip', 'Left Knee', 'Left Ankle', 'Right Wrist', 'Right Elbow', 'Right Shoulder', 'Left Shoulder', 'Left Elbow', 'Left Wrist', 'Neck (LSP)', 'Top of Head (LSP)', 'Pelvis (MPII)', 'Thorax (MPII)', 'Spine (H36M)', 'Jaw (H36M)', 'Head (H36M)', 'Nose', 'Left Eye', 'Right Eye', 'Left Ear', 'Right Ear' ] # adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints class SMPL(_SMPLLayer): """ Extension of the official SMPL implementation to support more joints """ def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs): kwargs["model_path"] = model_path # remove the verbosity for the 10-shapes beta parameters with contextlib.redirect_stdout(None): super(SMPL, self).__init__(**kwargs) J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA) self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)) vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES]) a2m_indexes = vibe_indexes[action2motion_joints] smpl_indexes = np.arange(24) a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes]) self.maps = {"vibe": vibe_indexes, "a2m": a2m_indexes, "smpl": smpl_indexes, "a2mpl": a2mpl_indexes} def forward(self, *args, **kwargs): smpl_output = super(SMPL, self).forward(*args, **kwargs) extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) all_joints = torch.cat([smpl_output.joints, extra_joints], dim=1) output = {"vertices": smpl_output.vertices} for joinstype, indexes in self.maps.items(): output[joinstype] = all_joints[:, 
indexes] return output ================================================ FILE: PBnet/src/models/tools/__init__.py ================================================ ================================================ FILE: PBnet/src/models/tools/graphconv.py ================================================ import math import torch from torch.nn.parameter import Parameter from torch.nn.modules.module import Module class GraphConvolution(Module): """ Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 """ def __init__(self, in_features, out_features, bias=True): super(GraphConvolution, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.FloatTensor(in_features, out_features)) if bias: self.bias = Parameter(torch.FloatTensor(out_features)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): stdv = 1. / math.sqrt(self.weight.size(1)) self.weight.data.uniform_(-stdv, stdv) if self.bias is not None: self.bias.data.uniform_(-stdv, stdv) def forward(self, input, adj): support = torch.mm(input, self.weight) output = torch.spmm(adj, support) if self.bias is not None: return output + self.bias else: return output def __repr__(self): return self.__class__.__name__ + ' (' \ + str(self.in_features) + ' -> ' \ + str(self.out_features) + ')' ================================================ FILE: PBnet/src/models/tools/hessian_penalty.py ================================================ """ ## Adapted to work with our "batches" Official PyTorch implementation of the Hessian Penalty regularization term from https://arxiv.org/pdf/2008.10599.pdf Author: Bill Peebles TensorFlow Implementation (GPU + Multi-Layer): hessian_penalty_tf.py Simple Pure NumPy Implementation: hessian_penalty_np.py Simple use case where you want to apply the Hessian Penalty to the output of net w.r.t. net_input: >>> from hessian_penalty_pytorch import hessian_penalty >>> net = MyNeuralNet() >>> net_input = sample_input() >>> loss = hessian_penalty(net, z=net_input) # Compute hessian penalty of net's output w.r.t. net_input >>> loss.backward() # Compute gradients w.r.t. net's parameters If your network takes multiple inputs, simply supply them to hessian_penalty as you do in the net's forward pass. In the following example, we assume BigGAN.forward takes a second input argument "y". Note that we always take the Hessian Penalty w.r.t. the z argument supplied to hessian_penalty: >>> from hessian_penalty_pytorch import hessian_penalty >>> net = BigGAN() >>> z_input = sample_z_vector() >>> class_label = sample_class_label() >>> loss = hessian_penalty(net, z=net_input, y=class_label) >>> loss.backward() """ import torch def hessian_penalty(G, batch, k=2, epsilon=0.1, reduction=torch.max, return_separately=False, G_z=None, **G_kwargs): """ Official PyTorch Hessian Penalty implementation. Note: If you want to regularize multiple network activations simultaneously, you need to make sure the function G you pass to hessian_penalty returns a list of those activations when it's called with G(z, **G_kwargs). Otherwise, if G returns a tensor the Hessian Penalty will only be computed for the final output of G. 
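Note: in this adaptation the penalty is taken w.r.t. batch["x"]; the remaining
entries of the batch dict are forwarded to G unchanged (see
multi_layer_second_directional_derivative below).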
:param G: Function that maps input z to either a tensor or a list of tensors (activations) :param z: Input to G that the Hessian Penalty will be computed with respect to :param k: Number of Hessian directions to sample (must be >= 2) :param epsilon: Amount to blur G before estimating Hessian (must be > 0) :param reduction: Many-to-one function to reduce each pixel/neuron's individual hessian penalty into a final loss :param return_separately: If False, hessian penalties for each activation output by G are automatically summed into a final loss. If True, the hessian penalties for each layer will be returned in a list instead. If G outputs a single tensor, setting this to True will produce a length-1 list. :param G_z: [Optional small speed-up] If you have already computed G(z, **G_kwargs) for the current training iteration, then you can provide it here to reduce the number of forward passes of this method by 1 :param G_kwargs: Additional inputs to G besides the z vector. For example, in BigGAN you would pass the class label into this function via y= :return: A differentiable scalar (the hessian penalty), or a list of hessian penalties if return_separately is True """ if G_z is None: G_z = G(batch, **G_kwargs) z = batch["x"] rademacher_size = torch.Size((k, *z.size())) # (k, N, z.size()) dzs = epsilon * rademacher(rademacher_size, device=z.device) second_orders = [] for dz in dzs: # Iterate over each (N, z.size()) tensor in xs central_second_order = multi_layer_second_directional_derivative(G, batch, dz, G_z, epsilon, **G_kwargs) second_orders.append(central_second_order) # Appends a tensor with shape equal to G(z).size() loss = multi_stack_var_and_reduce(second_orders, reduction, return_separately) # (k, G(z).size()) --> scalar return loss def rademacher(shape, device='cpu'): """Creates a random tensor of size [shape] under the Rademacher distribution (P(x=1) == P(x=-1) == 0.5)""" x = torch.empty(shape, device=device) x.random_(0, 2) # Creates random tensor of 0s and 1s x[x == 0] = -1 # Turn the 0s into -1s return x def multi_layer_second_directional_derivative(G, batch, dz, G_z, epsilon, **G_kwargs): """Estimates the second directional derivative of G w.r.t. its input at z in the direction x""" batch_plus = {**batch, "x": batch["x"] + dz} batch_moins = {**batch, "x": batch["x"] - dz} G_to_x = G(batch_plus, **G_kwargs) G_from_x = G(batch_moins, **G_kwargs) G_to_x = listify(G_to_x) G_from_x = listify(G_from_x) G_z = listify(G_z) eps_sqr = epsilon ** 2 sdd = [(G2x - 2 * G_z_base + Gfx) / eps_sqr for G2x, G_z_base, Gfx in zip(G_to_x, G_z, G_from_x)] return sdd def stack_var_and_reduce(list_of_activations, reduction=torch.max): """Equation (5) from the paper.""" second_orders = torch.stack(list_of_activations) # (k, N, C, H, W) var_tensor = torch.var(second_orders, dim=0, unbiased=True) # (N, C, H, W) penalty = reduction(var_tensor) # (1,) (scalar) return penalty def multi_stack_var_and_reduce(sdds, reduction=torch.max, return_separately=False): """Iterate over all activations to be regularized, then apply Equation (5) to each.""" sum_of_penalties = 0 if not return_separately else [] for activ_n in zip(*sdds): penalty = stack_var_and_reduce(activ_n, reduction) sum_of_penalties += penalty if not return_separately else [penalty] return sum_of_penalties def listify(x): """If x is already a list, do nothing. Otherwise, wrap x in a list.""" if isinstance(x, list): return x else: return [x] def _test_hessian_penalty(): """ A simple multi-layer test to verify the implementation. 
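Note: this self-test keeps the original tensor-based call signature
(hessian_penalty(G, z, ...)); the adapted function above expects a batch dict
with an "x" entry, so z would need to be wrapped (and G adjusted) for the
test to run against this version.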
Function: G(z) = [z_0 * z_1, z_0**2 * z_1] Ground Truth Hessian Penalty: [4, 16 * z_0**2] """ batch_size = 10 nz = 2 z = torch.randn(batch_size, nz) def reduction(x): return x.abs().mean() def G(z): return [z[:, 0] * z[:, 1], (z[:, 0] ** 2) * z[:, 1]] ground_truth = [4, reduction(16 * z[:, 0] ** 2).item()] # In this simple example, we use k=100 to reduce variance, but when applied to neural networks # you will probably want to use a small k (e.g., k=2) due to memory considerations. predicted = hessian_penalty(G, z, G_z=None, k=100, reduction=reduction, return_separately=True) predicted = [p.item() for p in predicted] print('Ground Truth: %s' % ground_truth) print('Approximation: %s' % predicted) # This should be close to ground_truth, but not exactly correct print('Difference: %s' % [str(100 * abs(p - gt) / gt) + '%' for p, gt in zip(predicted, ground_truth)]) if __name__ == '__main__': _test_hessian_penalty() ================================================ FILE: PBnet/src/models/tools/losses.py ================================================ import torch from einops import rearrange import torch.nn.functional as F from .hessian_penalty import hessian_penalty from .mmd import compute_mmd from .ssim_loss import ssim from .normalize_data import normalize_data def compute_rc_loss(model, batch): # x = batch["x"] #bs, nf, 6 x_delta = batch["x_delta"] output = batch["output"] #bs, nf, 6 mask = batch["mask"] #bs, nf # gtmasked = x[mask] gtmasked = x_delta[mask] outmasked = output[mask] # loss is large in the beginning loss = F.mse_loss(gtmasked, outmasked, reduction='mean') return loss def compute_reg_loss(model, batch): # x = batch["x"] #bs, nf, 6 x_delta = batch["x_delta"] mask = batch["mask"] #bs, nf x_1 = x_delta[:,:-1] x_2 = x_delta[:,1:] # gtmasked = x[mask] # loss is large in the beginning loss = F.mse_loss(x_1, x_2, reduction='mean') return loss def compute_rc_weight_loss(model, batch): x = batch["x"] #bs, nf, 6 x_delta = batch["x_delta"] output = batch["output"] #bs, nf, 6 mask = batch["mask"] #bs, nf # gtmasked = x[mask] #bs*nf, 6 gtmasked = x_delta[mask] #bs*nf, 6 outmasked = output[mask] #bs*nf, 6 if x.size(2) == 6: weights = torch.tensor([3, 3, 3, 1, 1, 1], dtype=torch.float32).cuda() elif x.size(2) == 7: weights = torch.tensor([3, 3, 3, 1, 1, 1, 0.5], dtype=torch.float32).cuda() elif x.size(2) == 8: weights = torch.tensor([3, 3, 3, 0, 0, 0, 3, 3], dtype=torch.float32).cuda() else: weights = torch.ones(x.size(2), dtype=torch.float32).cuda() weights = weights.unsqueeze(0) # loss is large in the beginning loss = F.mse_loss(gtmasked*weights, outmasked*weights, reduction='mean') return loss def compute_hp_loss(model, batch): loss = hessian_penalty(model.return_latent, batch, seed=torch.random.seed()) return loss def compute_kl_loss(model, batch): # mu, logvar: bs, 256 mu, logvar = batch["mu"], batch["logvar"] loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) return loss def compute_ssim_loss(model, batch): x = batch["x"] #bs, nf, 6 x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64 bs = x_ref.shape[0] mask = batch["mask"] #bs, nf x_delta = batch["x_delta"] output = batch["output"] loss = ssimnorm_loss(x_delta, output, mask, bs) return loss def ssimnorm_loss(x, output, mask, bs): min_vals = min(x.min(),output.min()) max_vals = max(x.max(),output.max()) x_norm = normalize_data(x, min_vals, max_vals) out_norm = normalize_data(output, min_vals, max_vals) gtmasked = x_norm[mask] #bs*nf, 6 outmasked = out_norm[mask] #bs*nf, 6 gtmasked = rearrange(gtmasked, '(b f) c -> b f c', b=bs) 
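# the (b, f, c) sequences are unsqueezed to (b, 1, f, c) below so ssim() can
# treat each pose sequence as a single-channel image (frames x channels) with
# a small 3x3 window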
outmasked = rearrange(outmasked, '(b f) c -> b f c', b=bs) gtmasked = gtmasked.unsqueeze(dim=1) # b 1 f c outmasked = outmasked.unsqueeze(dim=1) # b 1 f c loss = 1-ssim(gtmasked, outmasked, val_range=1, window_size=3) return loss def ssimnorm_self_loss(x, output, mask, bs): x_norm = normalize_data(x, x.min(), x.max()) out_norm = normalize_data(output, output.min(),output.max()) gtmasked = x_norm[mask] #bs*nf, 6 outmasked = out_norm[mask] #bs*nf, 6 gtmasked = rearrange(gtmasked, '(b f) c -> b f c', b=bs) outmasked = rearrange(outmasked, '(b f) c -> b f c', b=bs) gtmasked = gtmasked.unsqueeze(dim=1) # b 1 f c outmasked = outmasked.unsqueeze(dim=1) # b 1 f c loss = 1-ssim(gtmasked, outmasked, val_range=1, window_size=5) return loss def ssim255_loss(x, output, mask, bs): gtmasked = x[mask] #bs*nf, 6 outmasked = output[mask] #bs*nf, 6 # add 128 to ensue input range is 0-255 gtmasked = rearrange(gtmasked, '(b f) c -> b f c', b=bs)+128 outmasked = rearrange(outmasked, '(b f) c -> b f c', b=bs)+128 gtmasked = gtmasked.unsqueeze(dim=1) # b 1 f c outmasked = outmasked.unsqueeze(dim=1) # b 1 f c loss = 1-ssim(gtmasked, outmasked, val_range=255, window_size=5) return loss def comput_var_loss(model, batch): output = batch["output"] #bs, nf, 6 mask = batch["mask"] #bs, nf outmasked = output[mask] #bs*nf, 6 batch_size, num_frames, dim = output.size() outmasked = rearrange(outmasked, '(b f) c -> b f c', b=batch_size) variance_loss = 0 zero_loss = torch.tensor(0) for b in range(batch_size): for d in range(dim): dimension_output = outmasked[b, :, d] # shape: (bs, nf) frame_variance = torch.var(dimension_output) variance_loss += frame_variance variance_loss /= (batch_size * dim) if 3>variance_loss>0: return variance_loss else: return zero_loss def compute_mmd_loss(model, batch): z = batch["z"] true_samples = torch.randn(z.shape, requires_grad=False, device=model.device) loss = compute_mmd(true_samples, z) return loss _matching_ = {"rc": compute_rc_loss, "rcw": compute_rc_weight_loss, "kl": compute_kl_loss, "hp": compute_hp_loss, "mmd": compute_mmd_loss, "ssim": compute_ssim_loss, "var": comput_var_loss, 'reg': compute_reg_loss} # _matching_ = {"rc": compute_rc_loss, "kl": compute_kl_loss, "hp": compute_hp_loss, # "mmd": compute_mmd_loss, "rcxyz": compute_rcxyz_loss, # "vel": compute_vel_loss, "velxyz": compute_velxyz_loss} def get_loss_function(ltype): return _matching_[ltype] def get_loss_names(): return list(_matching_.keys()) ================================================ FILE: PBnet/src/models/tools/mmd.py ================================================ import torch # from https://github.com/napsternxg/pytorch-practice/blob/master/Pytorch%20-%20MMD%20VAE.ipynb def compute_kernel(x, y): x_size = x.size(0) y_size = y.size(0) dim = x.size(1) x = x.unsqueeze(1) # (x_size, 1, dim) y = y.unsqueeze(0) # (1, y_size, dim) tiled_x = x.expand(x_size, y_size, dim) tiled_y = y.expand(x_size, y_size, dim) kernel_input = (tiled_x - tiled_y).pow(2).mean(2)/float(dim) return torch.exp(-kernel_input) # (x_size, y_size) def compute_mmd(x, y): x_kernel = compute_kernel(x, x) y_kernel = compute_kernel(y, y) xy_kernel = compute_kernel(x, y) mmd = x_kernel.mean() + y_kernel.mean() - 2*xy_kernel.mean() return mmd ================================================ FILE: PBnet/src/models/tools/msssim_loss.py ================================================ import torch import torch.nn.functional as F from math import exp import numpy as np def gaussian(window_size, sigma): gauss = torch.Tensor([exp(-(x - 
window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) return gauss/gauss.sum() def create_window(window_size, channel=1): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() return window def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). if val_range is None: if torch.max(img1) > 128: max_val = 255 else: max_val = 1 if torch.min(img1) < -0.5: min_val = -1 else: min_val = 0 L = max_val - min_val else: L = val_range padd = 0 (_, channel, height, width) = img1.size() if window is None: real_size = min(window_size, height, width) window = create_window(real_size, channel=channel).to(img1.device) mu1 = F.conv2d(img1, window, padding=padd, groups=channel) mu2 = F.conv2d(img2, window, padding=padd, groups=channel) mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) mu1_mu2 = mu1 * mu2 sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 C1 = (0.01 * L) ** 2 C2 = (0.03 * L) ** 2 v1 = 2.0 * sigma12 + C2 v2 = sigma1_sq + sigma2_sq + C2 cs = v1 / v2 # contrast sensitivity ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) if size_average: cs = cs.mean() ret = ssim_map.mean() else: cs = cs.mean(1).mean(1).mean(1) ret = ssim_map.mean(1).mean(1).mean(1) if full: return ret, cs return ret def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None): device = img1.device weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) levels = weights.size()[0] ssims = [] mcs = [] for _ in range(levels): sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) # Relu normalize (not compliant with original definition) if normalize == "relu": ssims.append(torch.relu(sim)) mcs.append(torch.relu(cs)) else: ssims.append(sim) mcs.append(cs) img1 = F.avg_pool2d(img1, (2, 2)) img2 = F.avg_pool2d(img2, (2, 2)) ssims = torch.stack(ssims) mcs = torch.stack(mcs) # Simple normalize (not compliant with original definition) # TODO: remove support for normalize == True (kept for backward support) if normalize == "simple" or normalize == True: ssims = (ssims + 1) / 2 mcs = (mcs + 1) / 2 pow1 = mcs ** weights pow2 = ssims ** weights # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ output = torch.prod(pow1[:-1]) * pow2[-1] return output # Classes to re-use window class SSIM(torch.nn.Module): def __init__(self, window_size=11, size_average=True, val_range=None): super(SSIM, self).__init__() self.window_size = window_size self.size_average = size_average self.val_range = val_range # Assume 1 channel for SSIM self.channel = 1 self.window = create_window(window_size) def forward(self, img1, img2): (_, channel, _, _) = img1.size() if channel == self.channel and self.window.dtype == img1.dtype: window = self.window else: window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) self.window = window self.channel = channel return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) class MSSSIM(torch.nn.Module): def __init__(self, 
window_size=11, size_average=True, channel=3): super(MSSSIM, self).__init__() self.window_size = window_size self.size_average = size_average self.channel = channel def forward(self, img1, img2): # TODO: store window between calls if possible return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) if __name__ == "__main__": device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') m = MSSSIM() img1 = torch.rand(1, 1, 256, 256) img2 = torch.rand(1, 1, 256, 256) print(msssim(img1, img2)) print(m(img1, img2)) ================================================ FILE: PBnet/src/models/tools/normalize_data.py ================================================ import torch def normalize_data(data, min_vals, max_vals): min_vals = min_vals.unsqueeze(0).unsqueeze(0) max_vals = max_vals.unsqueeze(0).unsqueeze(0) normalized_data = (data - min_vals) / (max_vals - min_vals) return normalized_data if __name__ == "__main__": bs = 32 nf = 10 data = torch.randn((bs, nf, 6)) # means = torch.tensor([2.17239228e-02 -8.76334959e-01 1.83403242e-01 4.68812609e-04 6.09114990e+01 6.82846017e+01]) # stds = torch.tensor([3.95977561e+00 2.74141379e+00 2.70259097e+00 8.42982963e-06 1.71036724e+00 1.94872744e+00]) min_vals = torch.tensor([-1.03461033e+01, -8.08477430e+00, -7.56659334e+00, 4.33026857e-04, 5.68175623e+01, 6.36141304e+01]) max_vals = torch.tensor([1.75214498e+01, 8.44862517e+00, 7.98321722e+00, 6.12732050e-04, 6.88481830e+01, 8.21925801e+01]) normalized_data = normalize_data(data, min_vals, max_vals) print(normalized_data) ================================================ FILE: PBnet/src/models/tools/ssim_loss.py ================================================ import torch import torch.nn.functional as F from torch.autograd import Variable import numpy as np from math import exp def gaussian(window_size, sigma): gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) return gauss/gauss.sum() def create_window(window_size, channel): _1D_window = gaussian(window_size, 0.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) return window def _ssim(img1, img2, window, window_size, channel, val_range = 1, size_average = True): mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) mu1_mu2 = mu1*mu2 sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 C1 = (0.01*val_range)**2 C2 = (0.03*val_range)**2 ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) if size_average: return ssim_map.mean() else: return ssim_map.mean(1).mean(1).mean(1) class SSIM(torch.nn.Module): def __init__(self, window_size = 11, size_average = True): super(SSIM, self).__init__() self.window_size = window_size self.size_average = size_average self.channel = 1 self.window = create_window(window_size, self.channel) def forward(self, img1, img2): (_, channel, _, _) = img1.size() if channel == self.channel and self.window.data.type() == img1.data.type(): window = self.window else: window = create_window(self.window_size, channel) if 
img1.is_cuda: window = window.cuda(img1.get_device()) window = window.type_as(img1) self.window = window self.channel = channel return _ssim(img1, img2, window, self.window_size, channel, size_average=self.size_average) def ssim(img1, img2, window_size = 11, val_range=1, size_average = True): (_, channel, _, _) = img1.size() window = create_window(window_size, channel) if img1.is_cuda: window = window.cuda(img1.get_device()) window = window.type_as(img1) return _ssim(img1, img2, window, window_size, channel, val_range, size_average) def read_pose_from_txt(file_path): data = np.loadtxt(file_path) return data if __name__ == "__main__": pose1 = read_pose_from_txt('/train20/intern/permanent/lmlin2/ReferenceCode/ACTOR-master/exps_delta_pose/HDTF_nf40_kl1_ssim1_128_w5_1w/nofinetune/3500/eval_gt/0/RD_Radio14_000_522_gt') pose2 = read_pose_from_txt('/train20/intern/permanent/lmlin2/ReferenceCode/ACTOR-master/exps_delta_pose/HDTF_nf40_kl1_ssim1_128_w5_1w/nofinetune/3500/eval_pred/0/RD_Radio14_000_522') pose1_tensor = torch.tensor(pose1).unsqueeze(0).unsqueeze(0).float()[:,:,:,:-1] pose2_tensor = torch.tensor(pose2).unsqueeze(0).unsqueeze(0).float()[:,:,:,:-1] # pose1_tensor = torch.tensor(pose1).unsqueeze(0).unsqueeze(0).float()[:,:,:,:-1]+128 # pose2_tensor = torch.tensor(pose2).unsqueeze(0).unsqueeze(0).float()[:,:,:,:-1]+128 pose1_tensor = (pose1_tensor - pose1_tensor.min()) / (pose1_tensor.max() - pose1_tensor.min()) pose2_tensor = (pose2_tensor - pose2_tensor.min()) / (pose2_tensor.max() - pose2_tensor.min()) ssim_loss = 1-ssim(pose1_tensor, pose2_tensor, window_size=3, val_range=1) print(ssim_loss) ================================================ FILE: PBnet/src/models/tools/tools.py ================================================ import torch.nn as nn from torch.nn.modules.module import ModuleAttributeError class AutoParams(nn.Module): def __init__(self, **kargs): try: for param in self.needed_params: if param in kargs: setattr(self, param, kargs[param]) else: raise ValueError(f"{param} is needed.") except ModuleAttributeError: pass try: for param, default in self.optional_params.items(): if param in kargs and kargs[param] is not None: setattr(self, param, kargs[param]) else: setattr(self, param, default) except ModuleAttributeError: pass super().__init__() # taken from joeynmt repo def freeze_params(module: nn.Module) -> None: """ Freeze the parameters of this module, i.e.
do not update them during training :param module: freeze parameters of this module """ for _, p in module.named_parameters(): p.requires_grad = False ================================================ FILE: PBnet/src/parser/base.py ================================================ from argparse import ArgumentParser # noqa def add_misc_options(parser): group = parser.add_argument_group('Miscellaneous options') group.add_argument("--expname", default="exps", help="general directory to this experiments, use it if you don't provide folder name") group.add_argument("--folder", default="exps/default_path", help="directory name to save models") def add_cuda_options(parser): group = parser.add_argument_group('Cuda options') group.add_argument("--cuda", dest='cuda', action='store_true', help="if we want to try to use gpu") group.add_argument('--cpu', dest='cuda', action='store_false', help="if we want to use cpu") group.set_defaults(cuda=True) group.add_argument("--gpu", default='0', help="choose gpu device.") def adding_cuda(parameters): import torch if (parameters["cuda"] or parameters["gpu"]) and torch.cuda.is_available(): parameters["device"] = torch.device("cuda") else: parameters["device"] = torch.device("cpu") ================================================ FILE: PBnet/src/parser/checkpoint.py ================================================ import os from .base import ArgumentParser, adding_cuda from .tools import load_args def parser(): parser = ArgumentParser() parser.add_argument("checkpointname") parser.add_argument("--num_epochs", type=int, default=5000, help="new number of epochs of training") opt = parser.parse_args() folder, checkpoint = os.path.split(opt.checkpointname) parameters = load_args(os.path.join(folder, "opt.yaml")) parameters["num_epochs"] = opt.num_epochs adding_cuda(parameters) epoch = int(checkpoint.split("_")[-1].split('.')[0]) return parameters, folder, checkpoint, epoch def construct_checkpointname(parameters, folder): implist = [parameters["modelname"], parameters["dataset"], parameters["extraction_method"], parameters["pose_rep"]] if parameters["pose_rep"] != "xyz": # [True, ""] to be compatible with generate job if "glob" in parameters: implist.append("glob" if parameters["glob"] in [True, ""] else "noglob") else: implist.append("noglob") if "translation" in parameters: implist.append("translation" if parameters["translation"] in [True, ""] else "notranslation") else: implist.append("notranslation") if "rcxyz" in parameters["modelname"]: implist.append("joinstype_{}".format(parameters["jointstype"])) if "num_layers" in parameters: implist.append("numlayers_{}".format(parameters["num_layers"])) for name in ["num_frames", "min_len", "max_len", "num_seq_max"]: pvalue = parameters[name] pname = name.replace("_", "") if pvalue != -1: implist.append(f"{pname}_{pvalue}") if "view" in parameters: if parameters["view"] == "frontview": implist.append("frontview") if "use_z" in parameters: if parameters["use_z"] != 0: implist.append("usez") else: implist.append("noz") if "vertstrans" in parameters: implist.append("vetr" if parameters["vertstrans"] else "novetr") if "ablation" in parameters: abl = parameters["ablation"] if abl not in ["", None]: implist.append(f"abl_{abl}") if parameters["num_frames"] != -1: implist.append("sampling_{}".format(parameters["sampling"])) if parameters["sampling"] == "conseq": implist.append("samplingstep_{}".format(parameters["sampling_step"])) if "lambda_kl" in parameters: implist.append("kl_{:.0e}".format(float(parameters["lambda_kl"]))) if 
"activation" in parameters: act = parameters["activation"] implist.append(act) implist.append("bs_{}".format(parameters["batch_size"])) implist.append("ldim_{}".format(parameters["latent_dim"])) checkpoint = "_".join(implist) return os.path.join(folder, checkpoint) ================================================ FILE: PBnet/src/parser/dataset.py ================================================ from src.datasets.dataset import POSE_REPS def add_dataset_options(parser): group = parser.add_argument_group('Dataset options') group.add_argument("--dataset", default='crema', help="Name of the dataset") group.add_argument("--num_frames", default=60, type=int, help="number of frames or -1 => whole, -2 => random between min_len and total") # group.add_argument("--sampling", default="conseq", choices=["conseq", "random_conseq", "random"], help="sampling choices") # group.add_argument("--sampling_step", default=1, type=int, help="sampling step") # group.add_argument("--pose_rep", required=True, choices=POSE_REPS, help="xyz or rotvec etc") group.add_argument("--max_len", default=-1, type=int, help="number of frames maximum per sequence or -1") group.add_argument("--min_len", default=-1, type=int, help="number of frames minimum per sequence or -1") group.add_argument("--num_seq_max", default=-1, type=int, help="number of sequences maximum to load or -1") # group.add_argument("--glob", dest='glob', action='store_true', help="if we want global rotation") # group.add_argument('--no-glob', dest='glob', action='store_false', help="if we don't want global rotation") # group.set_defaults(glob=True) # group.add_argument("--glob_rot", type=int, nargs="+", default=[3.141592653589793, 0, 0], # help="Default rotation, usefull if glob is False") # group.add_argument("--translation", dest='translation', action='store_true', # help="if we want to output translation") # group.add_argument('--no-translation', dest='translation', action='store_false', # help="if we don't want to output translation") # group.set_defaults(translation=True) # group.add_argument("--debug", dest='debug', action='store_true', help="if we are in debug mode") # group.set_defaults(debug=False) ================================================ FILE: PBnet/src/parser/evaluation.py ================================================ import argparse import os import sys sys.path.append('/train20/intern/permanent/lmlin2/ReferenceCode/ACTOR-master') from src.parser.tools import load_args from src.parser.base import add_cuda_options, adding_cuda def parser(): parser = argparse.ArgumentParser() parser.add_argument("checkpointname") parser.add_argument("--dataset", default='crema', help="name of dataset") parser.add_argument("--batch_size", type=int, default=32, help="size of the batches") parser.add_argument("--num_frames", default=60, type=int, help="number of frames, if value is bigger than gt nf, load all nf. 
") parser.add_argument("--niter", default=20, type=int, help="number of iterations") parser.add_argument("--num_seq_max", default=3000, type=int, help="number of sequences maximum to load or -1") # cuda options add_cuda_options(parser) opt = parser.parse_args() newparameters = {key: val for key, val in vars(opt).items() if val is not None} folder, checkpoint = os.path.split(newparameters["checkpointname"]) parameters = load_args(os.path.join(folder, "opt.yaml")) parameters.update(newparameters) adding_cuda(parameters) if checkpoint.split("_")[0] == 'retraincheckpoint': epoch = int(checkpoint.split("_")[2])+int(checkpoint.split("_")[4].split('.')[0]) else: epoch = int(checkpoint.split("_")[1].split('.')[0]) return parameters, folder, checkpoint, epoch, opt.niter ================================================ FILE: PBnet/src/parser/finetunning.py ================================================ import os from .base import argparse, adding_cuda, load_args def parser(): parser = argparse.ArgumentParser() parser.add_argument("checkpointname") group = parser.add_argument_group('Finetunning options (what should change)') group.add_argument("--num_epochs", type=int, help="new number of epochs of training") group.add_argument("--batch_size", type=int, help="size of the batches") group.add_argument("--lr", type=float, help="AdamW: learning rate") group.add_argument("--snapshot", type=int, help="frequency of saving model/viz") group.add_argument("--num_frames", default=-2, type=int, help="number of frames or -1 => whole, -2 => random between min_len and total") group.add_argument("--min_len", default=60, type=int, help="number of frames minimum per sequence or -1") group.add_argument("--max_len", default=100, type=int, help="number of frames maximum per sequence or -1") opt = parser.parse_args() folder, checkpoint = os.path.split(opt.checkpointname) parameters = load_args(os.path.join(folder, "opt.yaml")) parameters["folder"] = folder adding_cuda(parameters) epoch = int(checkpoint.split("_")[-1].split('.')[0]) return parameters, folder, checkpoint, epoch ================================================ FILE: PBnet/src/parser/generate.py ================================================ import os from src.models.get_model import JOINTSTYPES from .base import ArgumentParser, add_cuda_options, adding_cuda from .tools import load_args def add_generation_options(parser): group = parser.add_argument_group('Generation options') group.add_argument("--num_samples_per_action", default=5, type=int, help="num samples per action") group.add_argument("--num_frames", default=60, type=int, help="The number of frames considered (overrided if duration mode is chosen)") group.add_argument("--fact_latent", default=1, type=int, help="Fact latent") group.add_argument("--jointstype", default="smpl", choices=JOINTSTYPES, help="Jointstype for training with xyz") group.add_argument('--vertstrans', dest='vertstrans', action='store_true', help="Add the vertex translations") group.add_argument('--no-vertstrans', dest='vertstrans', action='store_false', help="Do not add the vertex translations") group.set_defaults(vertstrans=False) group.add_argument("--mode", default="gen", choices=["interpolate", "gen", "duration", "reconstruction"], help="The kind of generation considered.") def parser(): parser = ArgumentParser() parser.add_argument("checkpointname") # add visualize options back add_generation_options(parser) # cuda options add_cuda_options(parser) opt = parser.parse_args() newparameters = {key: val for key, val in 
vars(opt).items() if val is not None} folder, checkpoint = os.path.split(newparameters["checkpointname"]) parameters = load_args(os.path.join(folder, "opt.yaml")) parameters.update(newparameters) adding_cuda(parameters) epoch = int(checkpoint.split("_")[-1].split('.')[0]) return parameters, folder, checkpoint, epoch ================================================ FILE: PBnet/src/parser/model.py ================================================ from src.models.get_model import LOSSES, MODELTYPES, ARCHINAMES def add_model_options(parser): group = parser.add_argument_group('Model options') group.add_argument("--modelname", default='cvae_transformer_rc_kl', help="Choice of the model, should be like cvae_transformer_rc_rcxyz_kl") group.add_argument("--latent_dim", default=256, type=int, help="dimensionality of the latent space") group.add_argument("--lambda_kl", default=1.0, type=float, help="weight of the kl divergence loss") group.add_argument("--lambda_rcw", default=1.0, type=float, help="weight of the rc divergence loss with weight") group.add_argument("--lambda_rc", default=1.0, type=float, help="weight of the rc divergence loss") group.add_argument("--lambda_ssim", default=1.0, type=float, help="weight of the ssim divergence loss") group.add_argument("--lambda_reg", default=0.1, type=float, help="weight of the reg loss") # group.add_argument("--lambda_var", default=-0.1, type=float, help="weight of the var divergence loss") group.add_argument("--num_layers", default=2, type=int, help="Number of layers for GRU and transformer") group.add_argument("--ff_size", default=128, type=int, help="Size of feedforward for transformer") group.add_argument("--max_distance", default=128, type=int, help="") group.add_argument("--num_buckets", default=128, type=int, help="") group.add_argument("--audio_latent_dim", default=256, type=int, help="Size of audio latent for transformer") group.add_argument("--first3", default=False, help="Dim of pose, 3 or 6") group.add_argument("--eye", default=False, help="eye information") group.add_argument("--activation", default="gelu", help="Activation for function for the transformer layers") group.add_argument("--dropout", default=0.1, type=float, help="Activation for function for the transformer layers") # # Ablations # group.add_argument("--ablation", choices=[None, "average_encoder", "zandtime", "time_encoding", "concat_bias"], # help="Ablations for the transformer architechture") def parse_modelname(modelname): modeltype, archiname, *losses = modelname.split("_") if modeltype not in MODELTYPES: raise NotImplementedError("This type of model is not implemented.") if archiname not in ARCHINAMES: raise NotImplementedError("This architechture is not implemented.") if len(losses) == 0: raise NotImplementedError("You have to specify at least one loss function.") for loss in losses: if loss not in LOSSES: raise NotImplementedError("This loss is not implemented.") return modeltype, archiname, losses ================================================ FILE: PBnet/src/parser/recognition.py ================================================ import os from .base import argparse, add_misc_options, add_cuda_options, adding_cuda from .tools import save_args from .dataset import add_dataset_options from .training import add_training_options from .checkpoint import construct_checkpointname def training_parser(): parser = argparse.ArgumentParser() # misc options add_misc_options(parser) # training options add_training_options(parser) # dataset options add_dataset_options(parser) # model 
options add_cuda_options(parser) opt = parser.parse_args() # remove None params, and create a dictionnary parameters = {key: val for key, val in vars(opt).items() if val is not None} parameters["modelname"] = "recognition" if "folder" not in parameters: parameters["folder"] = construct_checkpointname(parameters, parameters["expname"]) os.makedirs(parameters["folder"], exist_ok=True) save_args(parameters, folder=parameters["folder"]) adding_cuda(parameters) return parameters ================================================ FILE: PBnet/src/parser/tools.py ================================================ import os import yaml def save_args(opt, folder): os.makedirs(folder, exist_ok=True) # Save as yaml optpath = os.path.join(folder, "opt.yaml") with open(optpath, 'w') as opt_file: yaml.dump(opt, opt_file) def load_args(filename): with open(filename, "rb") as optfile: opt = yaml.load(optfile, Loader=yaml.Loader) return opt ================================================ FILE: PBnet/src/parser/training.py ================================================ import os from .base import add_misc_options, add_cuda_options, adding_cuda, ArgumentParser from .tools import save_args from .dataset import add_dataset_options from .model import add_model_options, parse_modelname from .checkpoint import construct_checkpointname def add_training_options(parser): group = parser.add_argument_group('Training options') group.add_argument("--ckpt", default='') group.add_argument("--batch_size", default=100, type=int, help="size of the batches") group.add_argument("--num_epochs", default=5000, type=int, help="number of epochs of training") group.add_argument("--lr", default=0.0004, type=float, help="AdamW: learning rate") group.add_argument("--snapshot", default=2000, type=int, help="frequency of saving model/viz") # ff_size def parser(): parser = ArgumentParser() # misc options add_misc_options(parser) # cuda options add_cuda_options(parser) # training options add_training_options(parser) # dataset options add_dataset_options(parser) # model options add_model_options(parser) opt = parser.parse_args() # remove None params, and create a dictionnary parameters = {key: val for key, val in vars(opt).items() if val is not None} # parse modelname ret = parse_modelname(parameters["modelname"]) parameters["modeltype"], parameters["archiname"], parameters["losses"] = ret # update lambdas params lambdas = {} for loss in parameters["losses"]: lambdas[loss] = opt.__getattribute__(f"lambda_{loss}") parameters["lambdas"] = lambdas if "folder" not in parameters: parameters["folder"] = construct_checkpointname(parameters, parameters["expname"]) os.makedirs(parameters["folder"], exist_ok=True) save_args(parameters, folder=parameters["folder"]) adding_cuda(parameters) return parameters ================================================ FILE: PBnet/src/parser/visualize.py ================================================ import os from src.models.get_model import JOINTSTYPES from .base import ArgumentParser, add_cuda_options, adding_cuda from .tools import load_args from .dataset import add_dataset_options def construct_figname(parameters): figname = "fig_{:03d}" return figname def add_visualize_options(parser): group = parser.add_argument_group('Visualization options') group.add_argument("--num_actions_to_sample", default=5, type=int, help="num actions to sample") group.add_argument("--num_samples_per_action", default=5, type=int, help="num samples per action") group.add_argument("--fps", default=20, type=int, help="FPS for the 
rendering") group.add_argument("--force_visu_joints", dest='force_visu_joints', action='store_true', help="if we want to visualize joints even if it is rotation") group.add_argument('--no-force_visu_joints', dest='force_visu_joints', action='store_false', help="if we don't want to visualize joints even if it is rotation") group.set_defaults(force_visu_joints=True) group.add_argument("--jointstype", default="smpl", choices=JOINTSTYPES, help="Jointstype for training with xyz") group.add_argument('--vertstrans', dest='vertstrans', action='store_true', help="Training with vertex translations") group.add_argument('--no-vertstrans', dest='vertstrans', action='store_false', help="Training without vertex translations") group.set_defaults(vertstrans=False) group.add_argument("--noise_same_action", default="random", choices=["interpolate", "random", "same"], help="inside one action, sample several noise or interpolate it") group.add_argument("--noise_diff_action", default="random", choices=["random", "same"], help="use the same noise or different noise for every actions") group.add_argument("--duration_mode", default="mean", choices=["mean", "interpolate"], help="use the same noise or different noise for every actions") group.add_argument("--reconstruction_mode", default="ntf", choices=["tf", "ntf", "both"], help="reconstruction: teacher forcing or not or both") group.add_argument("--decoder_test", default="new", choices=["new", "diffaction", "diffduration", "interpolate_action"], help="what is the test we want to do") group.add_argument("--fact_latent", type=int, default=1, help="factor for max latent space") def parser(checkpoint=True): parser = ArgumentParser() if checkpoint: parser.add_argument("checkpointname") else: add_dataset_options(parser) # add visualize options back add_visualize_options(parser) # cuda options add_cuda_options(parser) opt = parser.parse_args() if checkpoint: newparameters = {key: val for key, val in vars(opt).items() if val is not None} folder, checkpoint = os.path.split(newparameters["checkpointname"]) parameters = load_args(os.path.join(folder, "opt.yaml")) parameters.update(newparameters) else: parameters = {key: val for key, val in vars(opt).items() if val is not None} adding_cuda(parameters) if checkpoint: parameters["figname"] = construct_figname(parameters) epoch = int(checkpoint.split("_")[-1].split('.')[0]) return parameters, folder, checkpoint, epoch else: return parameters ================================================ FILE: PBnet/src/preprocess/humanact12_process.py ================================================ import os import numpy as np import pickle as pkl from phspdtools import CameraParams def splitname(name): subject = name[1:3] group = name[4:6] time = name[7:9] frame1 = name[10:14] frame2 = name[15:19] action = name[20:24] return subject, group, time, frame1, frame2, action def create_phpsd_name(name): subject, group, time, frame1, frame2, action = splitname(name) phpsdname = f"subject{subject}_group{int(group)}_time{int(time)}" return phpsdname def get_frames(name): subject, group, time, frame1, frame2, action = splitname(name) return int(frame1), int(frame2) def get_action(name, coarse=True): subject, group, time, frame1, frame2, action = splitname(name) if coarse: return action[:2] else: return action humanact12_coarse_action_enumerator = { 1: "warm_up", 2: "walk", 3: "run", 4: "jump", 5: "drink", 6: "lift_dumbbell", 7: "sit", 8: "eat", 9: "turn steering wheel", 10: "phone", 11: "boxing", 12: "throw", } humanact12_coarse_action_to_label = {x: 
x-1 for x in range(1, 13)} def process_datata(savepath, posesfolder="data/PHPSDposes", datapath="data/HumanAct12", campath="data/phspdCameras"): data_list = os.listdir(datapath) data_list.sort() camera_params = CameraParams(campath) vibestyle = {"poses": [], "oldposes": [], "joints3D": [], "y": []} for index, name in enumerate(data_list): foldername = create_phpsd_name(name) subject = foldername.split("_")[0] T = camera_params.get_extrinsic("c2", subject) frame1, frame2 = get_frames(name) # subjecta, groupa, timea, frame1a, frame2a, actiona = splitname(name) posepath = os.path.join(posesfolder, foldername, "pose.txt") smplposepath = os.path.join(posesfolder, foldername, "shape_smpl.txt") npypath = os.path.join(datapath, name) joints3D = np.load(npypath) # take this one to get same number of frames that HumanAct12 joints .npy file # Otherwise we have to much frames (the registration is not perfect) poses = [] goodframes = [] with open(posepath) as f: for line in f.readlines(): tmp = line.split(' ') frame_idx = int(tmp[0]) if frame_idx >= frame1 and frame_idx <= frame2: goodframes.append(frame_idx) pose = np.asarray([float(i) for i in tmp[1:]]).reshape([-1, 3]) poses.append(pose) poses = np.array(poses) # if joints3D.shape[0] == (frame2 - frame1 + 1): # continue smplposes = [] with open(smplposepath) as f: for line in f.readlines(): tmp = line.split(' ') frame_idx = int(tmp[0]) if frame_idx in goodframes: # pose = np.asarray([float(i) for i in tmp[1:]]).reshape([-1, 3]) # poses.append(pose) smplparam = np.asarray([float(i) for i in tmp[1:]]) smplpose = smplparam[13:85] smplposes.append(smplpose) smplposes = np.array(smplposes) oldposes = poses.copy() # rotate to the good camera poses = T.transform(poses) poses = poses - poses[0][0] + joints3D[0][0] # and verify that the pose correspond to the humanact12 data if np.linalg.norm(poses - joints3D) >= 1e-10: print("bad") continue assert np.linalg.norm(poses - joints3D) < 1e-10 rotation = T.getmat4()[:3, :3] import pytorch3d.transforms.rotation_conversions as p3d import torch # rotate the global rotation global_matrix = p3d.axis_angle_to_matrix(torch.from_numpy(smplposes[:, :3])) smplposes[:, :3] = p3d.matrix_to_axis_angle(torch.from_numpy(rotation) @ global_matrix).numpy() assert poses.shape[0] == joints3D.shape[0] assert smplposes.shape[0] == joints3D.shape[0] vibestyle["poses"].append(smplposes) vibestyle["joints3D"].append(joints3D) action = get_action(name, coarse=True) label = humanact12_coarse_action_to_label[int(action)] vibestyle["y"].append(label) pkl.dump(vibestyle, open(savepath, "wb")) if __name__ == "__main__": folder = "data/HumanAct12Poses/" os.makedirs(folder, exist_ok=True) savepath = os.path.join(folder, "humanact12poses.pkl") process_datata(savepath) ================================================ FILE: PBnet/src/preprocess/phspdtools.py ================================================ # taken and adapted from https://github.com/JimmyZou/PolarHumanPoseShape/ import pickle import numpy as np import os class Transform: def __init__(self, R=np.eye(3, dtype='float'), t=np.zeros(3, 'float'), s=np.ones(3, 'float')): self.R = R.copy() # rotation self.t = t.reshape(-1).copy() # translation self.s = s.copy() # scale def __mul__(self, other): # combine two transformation together R = np.dot(self.R, other.R) t = np.dot(self.R, other.t * self.s) + self.t if not hasattr(other, 's'): other.s = np.ones(3, 'float').copy() s = other.s.copy() return Transform(R, t, s) def inv(self): # inverse the rigid tansformation R = self.R.T t = 
-np.dot(self.R.T, self.t) return Transform(R, t) def transform(self, xyz): # transform 3D point if not hasattr(self, 's'): self.s = np.ones(3, 'float').copy() assert xyz.shape[-1] == 3 assert len(self.s) == 3 return np.dot(xyz * self.s, self.R.T) + self.t def getmat4(self): # homogeneous transformation matrix M = np.eye(4) M[:3, :3] = self.R * self.s M[:3, 3] = self.t return M def quat2R(quat): """ Description =========== convert vector q to matrix R Parameters ========== :param quat: (4,) array Returns ======= :return: (3,3) array """ w = quat[0] x = quat[1] y = quat[2] z = quat[3] n = w * w + x * x + y * y + z * z s = 2. / np.clip(n, 1e-7, 1e7) wx = s * w * x wy = s * w * y wz = s * w * z xx = s * x * x xy = s * x * y xz = s * x * z yy = s * y * y yz = s * y * z zz = s * z * z R = np.stack([1 - (yy + zz), xy - wz, xz + wy, xy + wz, 1 - (xx + zz), yz - wx, xz - wy, yz + wx, 1 - (xx + yy)]) return R.reshape((3, 3)) def convert_param2tranform(param, scale=1): R = quat2R(param[0:4]) t = param[4:7] s = scale * np.ones(3, 'float') return Transform(R, t, s) class CameraParams: def __init__(self, cam_folder="data/phspdCameras"): # load camera params, save intrinsic and extrinsic camera parameters as a dictionary # intrinsic ['param_p', 'param_c1', 'param_d1', 'param_c2', 'param_d2', 'param_c3', 'param_d3'] # extrinsic ['d1p', 'd2p', 'd3p', 'cd1', 'cd2', 'cd3'] self.cam_params = [] with open(os.path.join(cam_folder, "CamParams0906.pkl"), 'rb') as f: self.cam_params.append(pickle.load(f)) with open(os.path.join(cam_folder, "CamParams0909.pkl"), 'rb') as f: self.cam_params.append(pickle.load(f)) # corresponding cam params to each subject self.name_cam_params = {} # {"name": 0 or 1} for name in ['subject06', 'subject09', 'subject11', 'subject05', 'subject12', 'subject04']: self.name_cam_params[name] = 0 for name in ['subject03', 'subject01', 'subject02', 'subject10', 'subject07', 'subject08']: self.name_cam_params[name] = 1 # corresponding cam params to each subject self.name_gender = {} # {"name": 0 or 1} for name in ['subject02', 'subject03', 'subject04', 'subject05', 'subject06', 'subject08', 'subject09', 'subject11', 'subject12']: self.name_gender[name] = 0 # male for name in ['subject01', 'subject07', 'subject10']: self.name_gender[name] = 1 # female def get_intrinsic(self, cam_name, subject_no): """ 'p': polarization camera, color 'c1': color camera for the 1st Kinect 'd1': depth (ToF) camera for the 1st Kinect ... return (fx, fy, cx, cy) """ assert cam_name in ['p', 'c1', 'd1', 'c2', 'd2', 'c3', 'd3'] assert subject_no in ['subject06', 'subject09', 'subject11', 'subject05', 'subject12', 'subject04', 'subject03', 'subject01', 'subject02', 'subject10', 'subject07', 'subject08'] fx, fy, cx, cy, _, _, _ = self.cam_params[self.name_cam_params[subject_no]]['param_%s' % cam_name] intrinsic = (fx, fy, cx, cy) return intrinsic def get_extrinsic(self, cams_name, subject_no): """ The annotated poses and shapes are saved in polarization camera coordinate. 'd1p': transform from polarization camera to 1st Kinect depth image 'c1p': transform from polarization camera to 1st Kinect color image ... 
return transform class """ assert cams_name in ['d1', 'd2', 'd3', 'c1', 'c2', 'c3'] assert subject_no in ['subject06', 'subject09', 'subject11', 'subject05', 'subject12', 'subject04', 'subject03', 'subject01', 'subject02', 'subject10', 'subject07', 'subject08'] if cams_name in ['d1p', 'd2p', 'd3p']: T = convert_param2tranform(self.cam_params[self.name_cam_params[subject_no]][cams_name]) else: i = cams_name[1] T_dp = convert_param2tranform(self.cam_params[self.name_cam_params[subject_no]]['d%sp' % i]) T_cd = convert_param2tranform(self.cam_params[self.name_cam_params[subject_no]]['cd%s' % i]) T = T_cd * T_dp return T def get_gender(self, subject_no): return self.name_gender[subject_no] if __name__ == '__main__': # test camera_params = CameraParams(data_dir='../..//data') T = camera_params.get_extrinsic('c2', 'subject01') print(T.getmat4()) ================================================ FILE: PBnet/src/preprocess/uestc_vibe_postprocessing.py ================================================ import numpy as np import pickle as pkl import tarfile import os import scipy.io as sio from tqdm import tqdm import src.utils.rotation_conversions as geometry import torch W = 960 H = 540 def get_kinect_motion(tar, videos, index): # skeleton loading video = videos[index] skeleton_name = video.replace("color.avi", "skeleton.mat") skeleton_path = os.path.join("mat_from_skeleton", skeleton_name) ffile = tar.extractfile(skeleton_path) skeleton = sio.loadmat(ffile, variable_names=["v"])["v"] skeleton = skeleton.reshape(-1, 25, 3) return skeleton def motionto2d(motion, W=960, H=540): K = np.array(((540, 0, W / 2), (0, 540, H / 2), (0, 0, 1))) motion[..., 1] = -motion[..., 1] motion2d = np.einsum("tjk,lk->tjl", motion, K) nonzeroix = np.where(motion2d[..., 2] != 0) motion2d[nonzeroix] = motion2d[nonzeroix] / motion2d[(*nonzeroix, 2)][..., None] return motion2d[..., :2] def motionto2dvibe(motion, cam): sx, sy, tx, ty = cam return (motion[..., :2] + [tx, ty]) * [W/2*sx, H/2*sy] + [W/2, H/2] def get_kcenter(tar, videos, index): kmotion2d = motionto2d(get_kinect_motion(tar, videos, index)) kboxes = np.hstack((kmotion2d.min(1), kmotion2d.max(1))) x1, y1, x2, y2 = kboxes.T kcenter = np.stack(((x1 + x2)/2, (y1 + y2)/2)).T return kcenter def get_concat_goodtracks(allvibe, tar, videos, index): idxall = allvibe[index] kcenter = get_kcenter(tar, videos, index) tracks = np.array(list(idxall.keys())) if len(tracks) == 1: return idxall[tracks[0]], tracks remainingmask = np.ones(len(tracks), dtype=bool) currenttrack = None vibetracks = [] while remainingmask.any(): # find new track # first look at the closest new track in time candidate = np.argmin([idxall[track]["frame_ids"][0] for track in tracks[remainingmask]]) candidate_max = idxall[tracks[remainingmask][candidate]]["frame_ids"][-1] # look for other candidate which intersect with the candidate (conflict) candidates = np.where(np.array([idxall[track]["frame_ids"][0] <= candidate_max for track in tracks[remainingmask]]))[0] # if the candidate is alone, take it if len(candidates) == 1: idx = np.where(remainingmask)[0][candidate] # if there are conflit, find the closest match else: # take the closest one in distance to the last center observed if currenttrack is None: # take the kinect output lastbox = kcenter[0] else: # take the last boxe output lastbox = idxall[currenttrack]["bboxes"][-1, :2] dists = np.linalg.norm([idxall[tracks[remainingmask][candidate]]["bboxes"][0, :2] - lastbox for candidate in candidates], axis=1) idx = 
np.where(remainingmask)[0][candidates[np.argmin(dists)]] # compute informations currenttrack = tracks[idx] vibetracks.append(currenttrack) lastframe = idxall[currenttrack]["frame_ids"][-1] # filter overlapping frames remainingmask = np.array([idxall[track]["frame_ids"][0] > lastframe for track in tracks]) & remainingmask goodvibe = {key: [] for key in ['pred_cam', 'orig_cam', 'pose', 'betas', 'joints3d', 'bboxes', 'frame_ids']} for key in goodvibe: goodvibe[key] = np.concatenate([idxall[track][key] for track in vibetracks]) return goodvibe, vibetracks def interpolate_track(gvibe): # interpolation starting = np.where((gvibe["frame_ids"][1:] - gvibe["frame_ids"][:-1]) != 1)[0] + 1 lastend = 0 saveall = {key: [] for key in gvibe.keys() if key != "joints2d"} for start in starting: begin = start - 1 end = start lastgoodidx = gvibe["frame_ids"][begin] firstnewgoodidx = gvibe["frame_ids"][end] for key in saveall.keys(): # save the segment before the cut saveall[key].append(gvibe[key][lastend:begin+1]) # extract the last good info lastgoodinfo = gvibe[key][begin] # extract the first regood info newfirstgoodinfo = gvibe[key][end] if key == "pose": # interpolate in quaternions q0 = geometry.axis_angle_to_quaternion(torch.from_numpy(lastgoodinfo.reshape(24, 3))) q1 = geometry.axis_angle_to_quaternion(torch.from_numpy(newfirstgoodinfo.reshape(24, 3))) q2 = geometry.axis_angle_to_quaternion(-torch.from_numpy(newfirstgoodinfo.reshape(24, 3))) # Help when the interpolation is between pi and -pi # It avoid the problem of inverting people with global rotation # It is not optimal but it is better than nothing # newfirstgoodinfo = torch.where((torch.argmin(torch.stack((torch.linalg.norm(q0-q1, axis=1), # torch.linalg.norm(q0-q2, axis=1))), axis=0) == 0)[:, None], q1, q2) first = [q1[0], q2[0]][np.argmin((torch.linalg.norm(q0[0]-q1[0]), torch.linalg.norm(q0[0]-q2[0])))] newfirstgoodinfo = q1 newfirstgoodinfo[0] = first lastgoodinfo = q0 # interpolate in between interinfo = [] for x in range(lastgoodidx+1, firstnewgoodidx): # linear coeficient w2 = x - lastgoodidx w1 = firstnewgoodidx - x w1, w2 = w1/(w1+w2), w2/(w1+w2) inter = lastgoodinfo * w1 + newfirstgoodinfo * w2 if key == "pose": # interpolate in quaternions # normalize the quaternion inter = inter/torch.linalg.norm(inter, axis=1)[:, None] inter = geometry.quaternion_to_axis_angle(inter).numpy().reshape(-1) interinfo.append(inter) saveall[key].append(interinfo) lastend = end for key in saveall.keys(): saveall[key].append(gvibe[key][lastend:]) saveall[key] = np.concatenate(saveall[key]) saveall["frame_ids"] = np.round(saveall["frame_ids"]).astype(int) # make sure the interpolation was fine => looking at a whole frame_ids assert (saveall["frame_ids"] == np.arange(gvibe["frame_ids"].min(), gvibe["frame_ids"].max()+1)).all() return saveall if __name__ == "__main__": datapath = "datasets/uestc/" allpath = os.path.join(datapath, "vibe_cache_all_tracks.pkl") oldpath = os.path.join(datapath, "vibe_cache.pkl") videopath = os.path.join(datapath, 'info', 'names.txt') kinectpath = os.path.join(datapath, "mat_from_skeleton.tar") allvibe = pkl.load(open(allpath, "rb")) oldvibe = pkl.load(open(oldpath, "rb")) videos = open(videopath, 'r').read().splitlines() tar = tarfile.open(kinectpath, "r") newvibelst = [] allvtracks = [] for index in tqdm(range(len(videos))): gvibe, vtracks = get_concat_goodtracks(allvibe, tar, videos, index) allvtracks.append(vtracks) newvibelst.append(interpolate_track(gvibe)) newvibe = {key: [] for key in newvibelst[0].keys()} for nvibe in 
newvibelst: for key in newvibe: newvibe[key].append(nvibe[key]) pkl.dump(newvibe, open("newvibe.pkl", "wb")) ================================================ FILE: PBnet/src/recognition/compute_accuracy.py ================================================ import os import torch from torch.utils.data import DataLoader from tqdm import tqdm from src.utils.get_model_and_data import get_model_and_data from src.utils.tensors import collate from src.evaluate.tools import save_metrics from src.parser.checkpoint import parser import src.utils.fixseed # noqa def compute_accuracy(model, datasets, parameters): device = parameters["device"] iterators = {key: DataLoader(datasets[key], batch_size=parameters["batch_size"], shuffle=False, num_workers=8, collate_fn=collate) for key in datasets.keys()} model.eval() num_labels = parameters["num_classes"] accuracies = {} with torch.no_grad(): for key, iterator in iterators.items(): confusion = torch.zeros(num_labels, num_labels, dtype=torch.long) for batch in tqdm(iterator, desc=f"Computing {key} batch"): # Put everything in device batch = {key: val.to(device) for key, val in batch.items()} # forward pass batch = model(batch) yhat = batch["yhat"].max(dim=1).indices ygt = batch["y"] for label, pred in zip(ygt, yhat): confusion[label][pred] += 1 accuracy = (torch.trace(confusion)/torch.sum(confusion)).item() accuracies[key] = accuracy return accuracies def main(): # parse options parameters, folder, checkpointname, epoch = parser() model, datasets = get_model_and_data(parameters) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) state_dict = torch.load(checkpointpath, map_location=parameters["device"]) model.load_state_dict(state_dict) accuracies = compute_accuracy(model, datasets, parameters) metricname = "recognition_accuracies_on_samedata_{}.yaml".format(epoch) evalpath = os.path.join(folder, metricname) print(f"Saving score: {evalpath}") save_metrics(evalpath, accuracies) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/recognition/get_model.py ================================================ from .models.stgcn import STGCN def get_model(parameters): layout = "smpl" if parameters["glob"] else "smpl_noglobal" model = STGCN(in_channels=parameters["nfeats"], num_class=parameters["num_classes"], graph_args={"layout": layout, "strategy": "spatial"}, edge_importance_weighting=True, device=parameters["device"]) model = model.to(parameters["device"]) return model ================================================ FILE: PBnet/src/recognition/models/stgcn.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from .stgcnutils.tgcn import ConvTemporalGraphical from .stgcnutils.graph import Graph __all__ = ["STGCN"] class STGCN(nn.Module): r"""Spatial temporal graph convolutional networks. Args: in_channels (int): Number of channels in the input data num_class (int): Number of classes for the classification task graph_args (dict): The arguments for building the graph edge_importance_weighting (bool): If ``True``, adds a learnable importance weighting to the edges of the graph **kwargs (optional): Other parameters for graph convolution units Shape: - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})` - Output: :math:`(N, num_class)` where :math:`N` is a batch size, :math:`T_{in}` is a length of input sequence, :math:`V_{in}` is the number of graph nodes, :math:`M_{in}` is the number of instance in a frame. 
""" def __init__(self, in_channels, num_class, graph_args, edge_importance_weighting, device, **kwargs): super().__init__() self.device = device self.num_class = num_class self.losses = ["accuracy", "cross_entropy", "mixed"] self.criterion = torch.nn.CrossEntropyLoss(reduction='mean') # load graph self.graph = Graph(**graph_args) A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False) self.register_buffer('A', A) # build networks spatial_kernel_size = A.size(0) temporal_kernel_size = 9 kernel_size = (temporal_kernel_size, spatial_kernel_size) self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'} self.st_gcn_networks = nn.ModuleList(( st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs0), st_gcn(64, 64, kernel_size, 1, **kwargs), st_gcn(64, 64, kernel_size, 1, **kwargs), st_gcn(64, 64, kernel_size, 1, **kwargs), st_gcn(64, 128, kernel_size, 2, **kwargs), st_gcn(128, 128, kernel_size, 1, **kwargs), st_gcn(128, 128, kernel_size, 1, **kwargs), st_gcn(128, 256, kernel_size, 2, **kwargs), st_gcn(256, 256, kernel_size, 1, **kwargs), st_gcn(256, 256, kernel_size, 1, **kwargs), )) # initialize parameters for edge importance weighting if edge_importance_weighting: self.edge_importance = nn.ParameterList([ nn.Parameter(torch.ones(self.A.size())) for i in self.st_gcn_networks ]) else: self.edge_importance = [1] * len(self.st_gcn_networks) # fcn for prediction self.fcn = nn.Conv2d(256, num_class, kernel_size=1) def forward(self, batch): # TODO: use mask # Received batch["x"] as # Batch(48), Joints(23), Quat(4), Time(157 # Expecting: # Batch, Quat:4, Time, Joints, 1 x = batch["x"].permute(0, 2, 3, 1).unsqueeze(4).contiguous() # data normalization N, C, T, V, M = x.size() x = x.permute(0, 4, 3, 1, 2).contiguous() x = x.view(N * M, V * C, T) x = self.data_bn(x) x = x.view(N, M, V, C, T) x = x.permute(0, 1, 3, 4, 2).contiguous() x = x.view(N * M, C, T, V) # forward for gcn, importance in zip(self.st_gcn_networks, self.edge_importance): x, _ = gcn(x, self.A * importance) # compute feature # _, c, t, v = x.size() # features = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1) # batch["features"] = features # global pooling x = F.avg_pool2d(x, x.size()[2:]) x = x.view(N, M, -1, 1, 1).mean(dim=1) # features batch["features"] = x.squeeze() # prediction x = self.fcn(x) x = x.view(x.size(0), -1) batch["yhat"] = x return batch def compute_accuracy(self, batch): confusion = torch.zeros(self.num_class, self.num_class, dtype=int) yhat = batch["yhat"].max(dim=1).indices ygt = batch["y"] for label, pred in zip(ygt, yhat): confusion[label][pred] += 1 accuracy = torch.trace(confusion)/torch.sum(confusion) return accuracy def compute_loss(self, batch): cross_entropy = self.criterion(batch["yhat"], batch["y"]) mixed_loss = cross_entropy acc = self.compute_accuracy(batch) losses = {"cross_entropy": cross_entropy.item(), "mixed": mixed_loss.item(), "accuracy": acc.item()} return mixed_loss, losses class st_gcn(nn.Module): r"""Applies a spatial temporal graph convolution over an input graph sequence. Args: in_channels (int): Number of channels in the input sequence data out_channels (int): Number of channels produced by the convolution kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel stride (int, optional): Stride of the temporal convolution. Default: 1 dropout (int, optional): Dropout rate of the final output. Default: 0 residual (bool, optional): If ``True``, applies a residual mechanism. 
Default: ``True`` Shape: - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format where :math:`N` is a batch size, :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`, :math:`T_{in}/T_{out}` is a length of input/output sequence, :math:`V` is the number of graph nodes. """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dropout=0, residual=True): super().__init__() assert len(kernel_size) == 2 assert kernel_size[0] % 2 == 1 padding = ((kernel_size[0] - 1) // 2, 0) self.gcn = ConvTemporalGraphical(in_channels, out_channels, kernel_size[1]) self.tcn = nn.Sequential( nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d( out_channels, out_channels, (kernel_size[0], 1), (stride, 1), padding, ), nn.BatchNorm2d(out_channels), nn.Dropout(dropout, inplace=True), ) if not residual: self.residual = lambda x: 0 elif (in_channels == out_channels) and (stride == 1): self.residual = lambda x: x else: self.residual = nn.Sequential( nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=(stride, 1)), nn.BatchNorm2d(out_channels), ) self.relu = nn.ReLU(inplace=True) def forward(self, x, A): res = self.residual(x) x, A = self.gcn(x, A) x = self.tcn(x) + res return self.relu(x), A if __name__ == "__main__": model = STGCN(in_channels=3, num_class=60, edge_importance_weighting=True, graph_args={"layout": "smpl_noglobal", "strategy": "spatial"}) # Batch, in_channels, time, vertices, M inp = torch.rand(10, 3, 16, 23, 1) out = model(inp) print(out.shape) import pdb pdb.set_trace() ================================================ FILE: PBnet/src/recognition/models/stgcnutils/graph.py ================================================ import numpy as np import pickle as pkl from src.config import SMPL_KINTREE_PATH class Graph: """ The Graph to model the skeletons extracted by the openpose Args: strategy (string): must be one of the follow candidates - uniform: Uniform Labeling - distance: Distance Partitioning - spatial: Spatial Configuration For more information, please refer to the section 'Partition Strategies' in our paper (https://arxiv.org/abs/1801.07455). layout (string): must be one of the follow candidates - openpose: Is consists of 18 joints. For more information, please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose#output - ntu-rgb+d: Is consists of 25 joints. For more information, please refer to https://github.com/shahroudy/NTURGB-D - smpl: Consists of 24/23 joints with without global rotation. 
max_hop (int): the maximal distance between two connected nodes dilation (int): controls the spacing between the kernel points """ def __init__(self, layout='openpose', strategy='uniform', kintree_path=SMPL_KINTREE_PATH, max_hop=1, dilation=1): self.max_hop = max_hop self.dilation = dilation self.kintree_path = kintree_path self.get_edge(layout) self.hop_dis = get_hop_distance( self.num_node, self.edge, max_hop=max_hop) self.get_adjacency(strategy) def __str__(self): return self.A def get_edge(self, layout): if layout == 'openpose': self.num_node = 18 self_link = [(i, i) for i in range(self.num_node)] neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12, 11), (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1), (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)] self.edge = self_link + neighbor_link self.center = 1 elif layout == 'smpl': self.num_node = 24 self_link = [(i, i) for i in range(self.num_node)] kt = pkl.load(open(self.kintree_path, "rb")) neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])] self.edge = self_link + neighbor_link self.center = 0 elif layout == 'smpl_noglobal': self.num_node = 23 self_link = [(i, i) for i in range(self.num_node)] kt = pkl.load(open(self.kintree_path, "rb")) neighbor_link = [(k, kt[1][i + 1]) for i, k in enumerate(kt[0][1:])] # remove the root joint neighbor_1base = [n for n in neighbor_link if n[0] != 0 and n[1] != 0] neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] self.edge = self_link + neighbor_link self.center = 0 elif layout == 'ntu-rgb+d': self.num_node = 25 self_link = [(i, i) for i in range(self.num_node)] neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6), (8, 7), (9, 21), (10, 9), (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), (18, 17), (19, 18), (20, 19), (22, 23), (23, 8), (24, 25), (25, 12)] neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] self.edge = self_link + neighbor_link self.center = 21 - 1 elif layout == 'ntu_edge': self.num_node = 24 self_link = [(i, i) for i in range(self.num_node)] neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6), (8, 7), (9, 2), (10, 9), (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), (18, 17), (19, 18), (20, 19), (21, 22), (22, 8), (23, 24), (24, 12)] neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] self.edge = self_link + neighbor_link self.center = 2 # elif layout=='customer settings' # pass else: raise NotImplementedError("This Layout is not supported") def get_adjacency(self, strategy): valid_hop = range(0, self.max_hop + 1, self.dilation) adjacency = np.zeros((self.num_node, self.num_node)) for hop in valid_hop: adjacency[self.hop_dis == hop] = 1 normalize_adjacency = normalize_digraph(adjacency) if strategy == 'uniform': A = np.zeros((1, self.num_node, self.num_node)) A[0] = normalize_adjacency self.A = A elif strategy == 'distance': A = np.zeros((len(valid_hop), self.num_node, self.num_node)) for i, hop in enumerate(valid_hop): A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop] self.A = A elif strategy == 'spatial': A = [] for hop in valid_hop: a_root = np.zeros((self.num_node, self.num_node)) a_close = np.zeros((self.num_node, self.num_node)) a_further = np.zeros((self.num_node, self.num_node)) for i in range(self.num_node): for j in range(self.num_node): if self.hop_dis[j, i] == hop: if self.hop_dis[j, self.center] == self.hop_dis[ i, self.center]: a_root[j, i] = normalize_adjacency[j, i] elif self.hop_dis[j, self. 
center] > self.hop_dis[i, self. center]: a_close[j, i] = normalize_adjacency[j, i] else: a_further[j, i] = normalize_adjacency[j, i] if hop == 0: A.append(a_root) else: A.append(a_root + a_close) A.append(a_further) A = np.stack(A) self.A = A else: raise NotImplementedError("This Strategy is not supported") def get_hop_distance(num_node, edge, max_hop=1): A = np.zeros((num_node, num_node)) for i, j in edge: A[j, i] = 1 A[i, j] = 1 # compute hop steps hop_dis = np.zeros((num_node, num_node)) + np.inf transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)] arrive_mat = (np.stack(transfer_mat) > 0) for d in range(max_hop, -1, -1): hop_dis[arrive_mat[d]] = d return hop_dis def normalize_digraph(A): Dl = np.sum(A, 0) num_node = A.shape[0] Dn = np.zeros((num_node, num_node)) for i in range(num_node): if Dl[i] > 0: Dn[i, i] = Dl[i]**(-1) AD = np.dot(A, Dn) return AD def normalize_undigraph(A): Dl = np.sum(A, 0) num_node = A.shape[0] Dn = np.zeros((num_node, num_node)) for i in range(num_node): if Dl[i] > 0: Dn[i, i] = Dl[i]**(-0.5) DAD = np.dot(np.dot(Dn, A), Dn) return DAD ================================================ FILE: PBnet/src/recognition/models/stgcnutils/tgcn.py ================================================ # The based unit of graph convolutional networks. import torch import torch.nn as nn class ConvTemporalGraphical(nn.Module): r"""The basic module for applying a graph convolution. Args: in_channels (int): Number of channels in the input sequence data out_channels (int): Number of channels produced by the convolution kernel_size (int): Size of the graph convolving kernel t_kernel_size (int): Size of the temporal convolving kernel t_stride (int, optional): Stride of the temporal convolution. Default: 1 t_padding (int, optional): Temporal zero-padding added to both sides of the input. Default: 0 t_dilation (int, optional): Spacing between temporal kernel elements. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` Shape: - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format where :math:`N` is a batch size, :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`, :math:`T_{in}/T_{out}` is a length of input/output sequence, :math:`V` is the number of graph nodes. 
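# [Editor's note] Minimal, self-contained sketch, not part of the original file, illustrating
# the shape contract documented above: after the 1x1 convolution the channel axis holds K groups
# of out_channels, and the einsum 'nkctv,kvw->nctw' aggregates neighbours with one V x V
# adjacency matrix per partition and sums over the K partitions.
import torch

N, C_out, T, V, K = 2, 16, 10, 23, 3                 # assumed toy sizes
x = torch.randn(N, K * C_out, T, V)                  # conv output carrying K * out_channels channels
A = torch.rand(K, V, V)                              # adjacency stack, one matrix per partition
x = x.view(N, K, C_out, T, V)                        # split channels into the K partition groups
out = torch.einsum('nkctv,kvw->nctw', x, A)          # aggregate over neighbours, sum over K
assert out.shape == (N, C_out, T, V)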
""" def __init__(self, in_channels, out_channels, kernel_size, t_kernel_size=1, t_stride=1, t_padding=0, t_dilation=1, bias=True): super().__init__() self.kernel_size = kernel_size self.conv = nn.Conv2d( in_channels, out_channels * kernel_size, kernel_size=(t_kernel_size, 1), padding=(t_padding, 0), stride=(t_stride, 1), dilation=(t_dilation, 1), bias=bias) def forward(self, x, A): assert A.size(0) == self.kernel_size x = self.conv(x) n, kc, t, v = x.size() x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v) x = torch.einsum('nkctv,kvw->nctw', (x, A)) return x.contiguous(), A ================================================ FILE: PBnet/src/render/renderer.py ================================================ """ This script is borrowed from https://github.com/mkocabas/VIBE Adhere to their licence to use this script It has been modified """ import math import trimesh import pyrender import numpy as np from pyrender.constants import RenderFlags import os os.environ['PYOPENGL_PLATFORM'] = 'egl' SMPL_MODEL_DIR = "models/smpl/" def get_smpl_faces(): return np.load(os.path.join(SMPL_MODEL_DIR, "smplfaces.npy")) class WeakPerspectiveCamera(pyrender.Camera): def __init__(self, scale, translation, znear=pyrender.camera.DEFAULT_Z_NEAR, zfar=None, name=None): super(WeakPerspectiveCamera, self).__init__( znear=znear, zfar=zfar, name=name, ) self.scale = scale self.translation = translation def get_projection_matrix(self, width=None, height=None): P = np.eye(4) P[0, 0] = self.scale[0] P[1, 1] = self.scale[1] P[0, 3] = self.translation[0] * self.scale[0] P[1, 3] = -self.translation[1] * self.scale[1] P[2, 2] = -1 return P class Renderer: def __init__(self, background=None, resolution=(224, 224), bg_color=[0, 0, 0, 0.5], orig_img=False, wireframe=False): width, height = resolution self.background = np.zeros((height, width, 3)) self.resolution = resolution self.faces = get_smpl_faces() self.orig_img = orig_img self.wireframe = wireframe self.renderer = pyrender.OffscreenRenderer( viewport_width=self.resolution[0], viewport_height=self.resolution[1], point_size=0.5 ) # set the scene self.scene = pyrender.Scene(bg_color=bg_color, ambient_light=(0.4, 0.4, 0.4)) light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=4) light_pose = np.eye(4) light_pose[:3, 3] = [0, -1, 1] self.scene.add(light, pose=light_pose.copy()) light_pose[:3, 3] = [0, 1, 1] self.scene.add(light, pose=light_pose.copy()) light_pose[:3, 3] = [1, 1, 2] self.scene.add(light, pose=light_pose.copy()) """ok light_pose = np.eye(4) light_pose[:3, 3] = [0, -1, 1] self.scene.add(light, pose=light_pose) light_pose[:3, 3] = [0, 1, 1] self.scene.add(light, pose=light_pose) light_pose[:3, 3] = [1, 1, 2] self.scene.add(light, pose=light_pose) """ # light_pose[:3, 3] = [0, -2, 2] # [droite, hauteur, profondeur camera] """ light_pose = np.eye(4) light_pose[:3, 3] = [0, -1, 1] self.scene.add(light, pose=light_pose) light_pose[:3, 3] = [0, 1, 1] self.scene.add(light, pose=light_pose) light_pose[:3, 3] = [1, 1, 2] self.scene.add(light, pose=light_pose) """ def render(self, img, verts, cam, angle=None, axis=None, mesh_filename=None, color=[1.0, 1.0, 0.9]): mesh = trimesh.Trimesh(vertices=verts, faces=self.faces, process=False) Rx = trimesh.transformations.rotation_matrix(math.radians(180), [1, 0, 0]) mesh.apply_transform(Rx) if mesh_filename is not None: mesh.export(mesh_filename) if angle and axis: R = trimesh.transformations.rotation_matrix(math.radians(angle), axis) mesh.apply_transform(R) sx, sy, tx, ty = cam camera = WeakPerspectiveCamera( 
scale=[sx, sy], translation=[tx, ty], zfar=1000. ) material = pyrender.MetallicRoughnessMaterial( metallicFactor=0.7, alphaMode='OPAQUE', baseColorFactor=(color[0], color[1], color[2], 1.0) ) mesh = pyrender.Mesh.from_trimesh(mesh, material=material) mesh_node = self.scene.add(mesh, 'mesh') camera_pose = np.eye(4) cam_node = self.scene.add(camera, pose=camera_pose) if self.wireframe: render_flags = RenderFlags.RGBA | RenderFlags.ALL_WIREFRAME else: render_flags = RenderFlags.RGBA rgb, _ = self.renderer.render(self.scene, flags=render_flags) valid_mask = (rgb[:, :, -1] > 0)[:, :, np.newaxis] output_img = rgb[:, :, :-1] * valid_mask + (1 - valid_mask) * img image = output_img.astype(np.uint8) self.scene.remove_node(mesh_node) self.scene.remove_node(cam_node) return image def get_renderer(width, height): renderer = Renderer(resolution=(width, height), bg_color=[1, 1, 1, 0.5], orig_img=False, wireframe=False) return renderer ================================================ FILE: PBnet/src/render/rendermotion.py ================================================ import numpy as np import imageio import os import argparse from tqdm import tqdm from .renderer import get_renderer def get_rotation(theta=np.pi/3): import src.utils.rotation_conversions as geometry import torch axis = torch.tensor([0, 1, 0], dtype=torch.float) axisangle = theta*axis matrix = geometry.axis_angle_to_matrix(axisangle) return matrix.numpy() def render_video(meshes, key, action, renderer, savepath, background, cam=(0.75, 0.75, 0, 0.10), color=[0.11, 0.53, 0.8]): writer = imageio.get_writer(savepath, fps=30) # center the first frame meshes = meshes - meshes[0].mean(axis=0) # matrix = get_rotation(theta=np.pi/4) # meshes = meshes[45:] # meshes = np.einsum("ij,lki->lkj", matrix, meshes) imgs = [] for mesh in tqdm(meshes, desc=f"Visualize {key}, action {action}"): img = renderer.render(background, mesh, cam, color=color) imgs.append(img) # show(img) imgs = np.array(imgs) masks = ~(imgs/255. 
> 0.96).all(-1) coords = np.argwhere(masks.sum(axis=0)) y1, x1 = coords.min(axis=0) y2, x2 = coords.max(axis=0) for cimg in imgs[:, y1:y2, x1:x2]: writer.append_data(cimg) writer.close() def main(): parser = argparse.ArgumentParser() parser.add_argument("filename") opt = parser.parse_args() filename = opt.filename savefolder = os.path.splitext(filename)[0] os.makedirs(savefolder, exist_ok=True) output = np.load(filename) if output.shape[0] == 3: visualization, generation, reconstruction = output output = {"visualization": visualization, "generation": generation, "reconstruction": reconstruction} else: # output = {f"generation_{key}": output[key] for key in range(2)} # len(output))} # output = {f"generation_{key}": output[key] for key in range(len(output))} output = {f"generation_{key}": output[key] for key in range(len(output))} width = 1024 height = 1024 background = np.zeros((height, width, 3)) renderer = get_renderer(width, height) # if duration mode, put back durations if output["generation_3"].shape[-1] == 100: output["generation_0"] = output["generation_0"][:, :, :, :40] output["generation_1"] = output["generation_1"][:, :, :, :60] output["generation_2"] = output["generation_2"][:, :, :, :80] output["generation_3"] = output["generation_3"][:, :, :, :100] elif output["generation_3"].shape[-1] == 160: print("160 mode") output["generation_0"] = output["generation_0"][:, :, :, :100] output["generation_1"] = output["generation_1"][:, :, :, :120] output["generation_2"] = output["generation_2"][:, :, :, :140] output["generation_3"] = output["generation_3"][:, :, :, :160] # if str(action) == str(1) and str(key) == "generation_4": for key in output: vidmeshes = output[key] for action in range(len(vidmeshes)): meshes = vidmeshes[action].transpose(2, 0, 1) path = os.path.join(savefolder, "action{}_{}.mp4".format(action, key)) render_video(meshes, key, action, renderer, path, background) if __name__ == "__main__": main() ================================================ FILE: PBnet/src/train/__init__.py ================================================ ================================================ FILE: PBnet/src/train/train_cvae_ganloss_ann_eye.py ================================================ import sys sys.path.append('your_path/PBnet') import os import torch from torch.utils.tensorboard import SummaryWriter from src.utils.utils import MultiEpochsDataLoader as DataLoader from src.utils.utils import CudaDataLoader # import torch.utils.data.dataloader as DataLoader from src.train.trainer_gan_ann import train import src.utils.fixseed # noqa from src.parser.training import parser import torch import torch.nn as nn import torch.nn.functional as F import importlib import random import numpy as np JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"] LOSSES = ["rc", "kl", "rcw", "ssim"] # not used: "hp", "mmd", "vel", "velxyz" MODELTYPES = ["cvae"] # not used: "cae" ARCHINAMES = ["fc", "gru", "transformer", "transformerreemb5", "transformermel", "transgru", "grutrans", "autotrans"] class ConvNormRelu(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride, padding, norm='batch', leaky=True): super(ConvNormRelu, self).__init__() layers = [] layers.append(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)) if norm == 'batch': layers.append(nn.BatchNorm1d(out_channels)) if leaky: layers.append(nn.LeakyReLU(0.2)) else: layers.append(nn.ReLU()) self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x) class D_patchgan(nn.Module): def 
__init__(self, n_downsampling=2, pos_dim=6, eye_dim=0, norm='batch'): super(D_patchgan, self).__init__() ndf = 64 self.eye_dim = eye_dim self.dim = pos_dim + self.eye_dim self.conv1 = nn.Conv1d(self.dim, ndf, kernel_size=4, stride=2, padding=1) self.leaky_relu = nn.LeakyReLU(0.2) layers = [] for n in range(0, n_downsampling): nf_mult = min(2**n, 8) layers.append(ConvNormRelu(ndf * nf_mult, ndf * nf_mult * 2, kernel_size=4, stride=2, padding=1, norm=norm)) nf_mult = min(2**n_downsampling, 8) layers.append(ConvNormRelu(ndf * nf_mult, ndf * nf_mult, kernel_size=4, stride=1, padding=1, norm=norm)) layers.append(nn.Conv1d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1)) self.model = nn.Sequential(*layers) def forward(self, x): out = self.conv1(x) out = self.leaky_relu(out) out = self.model(out) return out def calculate_GAN_loss(self, batch): x = batch["x"] #bs, nf, 6 x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64 output = batch["output"]+x_ref #bs, nf, 6 real_pose_score = self.forward(x.permute(0,2,1)) fake_pose_score = self.forward(output.permute(0,2,1)) D_loss = F.binary_cross_entropy_with_logits(real_pose_score, torch.ones_like(real_pose_score)) + F.binary_cross_entropy_with_logits(fake_pose_score, torch.zeros_like(fake_pose_score)) G_loss = F.binary_cross_entropy_with_logits(fake_pose_score, torch.ones_like(fake_pose_score)) return D_loss.mean(), G_loss.mean() def get_model(parameters): modeltype = parameters["modeltype"] archiname = parameters["archiname"] archi_module = importlib.import_module(f'.architectures.{archiname}', package="src.models") Encoder = archi_module.__getattribute__(f"Encoder_{archiname.upper()}") Decoder = archi_module.__getattribute__(f"Decoder_{archiname.upper()}") model_module = importlib.import_module(f'.modeltype.{modeltype}', package="src.models") Model = model_module.__getattribute__(f"{modeltype.upper()}") encoder = Encoder(**parameters) decoder = Decoder(**parameters) # parameters["outputxyz"] = "rcxyz" in parameters["lambdas"] return Model(encoder, decoder, **parameters).to(parameters["device"]) def do_epochs(model, model_d, dataset, parameters, optimizer_g, optimizer_d, scheduler_g, scheduler_d, writer): # train_iterator = DataLoader(dataset, batch_size=parameters["batch_size"], # shuffle=True, num_workers=8, pin_memory=True) train_iterator = DataLoader(dataset, batch_size=parameters["batch_size"], shuffle=True, num_workers=16, collate_fn=collate, pin_memory=True) train_iterator = CudaDataLoader(train_iterator, device = 'cuda:0') logpath = os.path.join(parameters["folder"], "training.log") with open(logpath, "w") as logfile: for epoch in range(1, parameters["num_epochs"]+1): dict_loss = train(model, model_d, optimizer_g, optimizer_d, train_iterator, model.device, epoch) for key in dict_loss.keys(): dict_loss[key] /= len(train_iterator) writer.add_scalar(f"Loss/{key}", dict_loss[key], epoch) epochlog = f"Epoch {epoch}, train losses: {dict_loss}" print(epochlog) print(epochlog, file=logfile) scheduler_g.step() scheduler_d.step() if ((epoch % parameters["snapshot"]) == 0) or (epoch == parameters["num_epochs"]): checkpoint_path = os.path.join(parameters["folder"], 'checkpoint_{:04d}.pth.tar'.format(epoch)) print('Saving checkpoint {}'.format(checkpoint_path)) torch.save(model.state_dict(), checkpoint_path) writer.flush() if __name__ == '__main__': # setup_seed(1234) # parse options parameters = parser() # logging tensorboard writer = SummaryWriter(log_dir=parameters["folder"]) os.environ["CUDA_VISIBLE_DEVICES"] = parameters["gpu"] dataset_name = 
parameters["dataset"] if dataset_name == 'crema': from src.datasets.datasets_crema_pos_eye_fast import CREMA from src.utils.tensors_eye import collate # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'train') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" if parameters["first3"]=='True': if parameters["eye"]=='True': from src.utils.tensors_eye import collate from src.datasets.datasets_hdtf_pos_chunk_norm_eye_first3 import HDTF else: from src.utils.tensors import collate from src.datasets.datasets_hdtf_pos_chunk_norm_2_first3 import HDTF else: if parameters["eye"]=='True': from src.utils.tensors_eye import collate from src.datasets.datasets_hdtf_pos_chunk_norm_eye_fast import HDTF else: from src.utils.tensors_eye import collate from src.datasets.datasets_hdtf_pos_chunk_norm_2 import HDTF dataset = HDTF(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'train') dataset.update_parameters(parameters) else: dataset = None print('Dataset can not be found!!') model = get_model(parameters) if parameters['eye']=='True': model_d = D_patchgan(pos_dim=parameters["pos_dim"], eye_dim=parameters["eye_dim"]).to(parameters["device"]) else: model_d = D_patchgan(pos_dim=parameters["pos_dim"]).to(parameters["device"]) # optimizer optimizer_g = torch.optim.AdamW(model.parameters(), lr=parameters["lr"]) optimizer_d = torch.optim.AdamW(model_d.parameters(), lr=parameters["lr"]) scheduler_g = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_g, T_max=parameters["num_epochs"], eta_min=2e-5) scheduler_d = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_d, T_max=parameters["num_epochs"], eta_min=2e-5) print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) print("Training model..") do_epochs(model, model_d, dataset, parameters, optimizer_g, optimizer_d, scheduler_g, scheduler_d, writer) writer.close() ================================================ FILE: PBnet/src/train/train_cvae_ganloss_ann_fast.py ================================================ import sys sys.path.append('your_path/PBnet') import os import torch from torch.utils.tensorboard import SummaryWriter from src.utils.utils import MultiEpochsDataLoader as DataLoader from src.train.trainer_gan_ann import train from src.utils.tensors import collate import src.utils.fixseed # noqa from src.parser.training import parser from src.datasets.datasets_crema_pos import CREMA from src.datasets.datasets_hdtf_pos_chunk_norm_2 import HDTF import torch import torch.nn as nn import torch.nn.functional as F import importlib JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"] LOSSES = ["rc", "kl", "rcw", "ssim"] # not used: "hp", "mmd", "vel", "velxyz" MODELTYPES = ["cvae"] # not used: "cae" ARCHINAMES = ["fc", "gru", "transformer", "transgru", "grutrans", "autotrans"] class ConvNormRelu(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride, padding, norm='batch', leaky=True): super(ConvNormRelu, self).__init__() layers = [] layers.append(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)) if norm == 'batch': layers.append(nn.BatchNorm1d(out_channels)) if leaky: layers.append(nn.LeakyReLU(0.2)) else: layers.append(nn.ReLU()) self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x) class D_patchgan(nn.Module): def 
__init__(self, n_downsampling=2, norm='batch'): super(D_patchgan, self).__init__() ndf = 64 self.conv1 = nn.Conv1d(6, ndf, kernel_size=4, stride=2, padding=1) self.leaky_relu = nn.LeakyReLU(0.2) layers = [] for n in range(0, n_downsampling): nf_mult = min(2**n, 8) layers.append(ConvNormRelu(ndf * nf_mult, ndf * nf_mult * 2, kernel_size=4, stride=2, padding=1, norm=norm)) nf_mult = min(2**n_downsampling, 8) layers.append(ConvNormRelu(ndf * nf_mult, ndf * nf_mult, kernel_size=4, stride=1, padding=1, norm=norm)) layers.append(nn.Conv1d(ndf * nf_mult, 1, kernel_size=4, stride=1, padding=1)) self.model = nn.Sequential(*layers) def forward(self, x): out = self.conv1(x) out = self.leaky_relu(out) out = self.model(out) return out def calculate_GAN_loss(self, batch): x = batch["x"] #bs, nf, 6 x_ref = x[:,0,:].unsqueeze(dim=1) #bs, 1, 64 output = batch["output"]+x_ref #bs, nf, 6 real_pose_score = self.forward(x.permute(0,2,1)) fake_pose_score = self.forward(output.permute(0,2,1)) D_loss = F.binary_cross_entropy_with_logits(real_pose_score, torch.ones_like(real_pose_score)) + F.binary_cross_entropy_with_logits(fake_pose_score, torch.zeros_like(fake_pose_score)) G_loss = F.binary_cross_entropy_with_logits(fake_pose_score, torch.ones_like(fake_pose_score)) return D_loss.mean(), G_loss.mean() def get_model(parameters): modeltype = parameters["modeltype"] archiname = parameters["archiname"] archi_module = importlib.import_module(f'.architectures.{archiname}', package="src.models") Encoder = archi_module.__getattribute__(f"Encoder_{archiname.upper()}") Decoder = archi_module.__getattribute__(f"Decoder_{archiname.upper()}") model_module = importlib.import_module(f'.modeltype.{modeltype}', package="src.models") Model = model_module.__getattribute__(f"{modeltype.upper()}") encoder = Encoder(**parameters) decoder = Decoder(**parameters) # parameters["outputxyz"] = "rcxyz" in parameters["lambdas"] return Model(encoder, decoder, **parameters).to(parameters["device"]) def do_epochs(model, model_d, dataset, parameters, optimizer_g, optimizer_d, scheduler_g, scheduler_d, writer): # train_iterator = DataLoader(dataset, batch_size=parameters["batch_size"], # shuffle=True, num_workers=8, pin_memory=True) train_iterator = DataLoader(dataset, batch_size=parameters["batch_size"], shuffle=True, num_workers=8, collate_fn=collate, pin_memory = True) logpath = os.path.join(parameters["folder"], "training.log") with open(logpath, "w") as logfile: for epoch in range(1, parameters["num_epochs"]+1): dict_loss = train(model, model_d, optimizer_g, optimizer_d, train_iterator, model.device, epoch) for key in dict_loss.keys(): dict_loss[key] /= len(train_iterator) writer.add_scalar(f"Loss/{key}", dict_loss[key], epoch) epochlog = f"Epoch {epoch}, train losses: {dict_loss}" print(epochlog) print(epochlog, file=logfile) scheduler_g.step() scheduler_d.step() if ((epoch % parameters["snapshot"]) == 0) or (epoch == parameters["num_epochs"]): checkpoint_path = os.path.join(parameters["folder"], 'checkpoint_{:04d}.pth.tar'.format(epoch)) print('Saving checkpoint {}'.format(checkpoint_path)) torch.save(model.state_dict(), checkpoint_path) writer.flush() if __name__ == '__main__': # parse options parameters = parser() # logging tensorboard writer = SummaryWriter(log_dir=parameters["folder"]) os.environ["CUDA_VISIBLE_DEVICES"] = parameters["gpu"] dataset_name = parameters["dataset"] if dataset_name == 'crema': # data path data_dir = "/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images" # model and dataset dataset = 
CREMA(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'train') dataset.update_parameters(parameters) elif dataset_name == 'hdtf': data_dir = "/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz" dataset = HDTF(data_dir=data_dir, max_num_frames=parameters["num_frames"], mode = 'train') dataset.update_parameters(parameters) else: dataset = None print('Dataset can not be found!!') model = get_model(parameters) model_d = D_patchgan().to(parameters["device"]) # optimizer optimizer_g = torch.optim.AdamW(model.parameters(), lr=parameters["lr"]) optimizer_d = torch.optim.AdamW(model_d.parameters(), lr=parameters["lr"]) scheduler_g = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_g, T_max=parameters["num_epochs"], eta_min=2e-5) scheduler_d = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_d, T_max=parameters["num_epochs"], eta_min=2e-5) print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) print("Training model..") do_epochs(model, model_d, dataset, parameters, optimizer_g, optimizer_d, scheduler_g, scheduler_d, writer) writer.close() ================================================ FILE: PBnet/src/train/trainer.py ================================================ import torch from tqdm import tqdm def train_or_test(model, optimizer, iterator, device, mode="train"): if mode == "train": model.train() grad_env = torch.enable_grad elif mode == "test": model.eval() grad_env = torch.no_grad else: raise ValueError("This mode is not recognized.") # loss of the epoch dict_loss = {loss: 0 for loss in model.losses} with grad_env(): for i, batch in tqdm(enumerate(iterator), desc="Computing batch"): # Put everything in device batch = {key: val.to(device) for key, val in batch.items() if key!='videoname'} if mode == "train": # update the gradients to zero optimizer.zero_grad() # forward pass batch = model(batch) mixed_loss, losses = model.compute_loss(batch) for key in dict_loss.keys(): dict_loss[key] += losses[key] if mode == "train": # backward pass mixed_loss.backward() # update the weights optimizer.step() if i % 10 == 0: print(losses) return dict_loss def train(model, optimizer, iterator, device): return train_or_test(model, optimizer, iterator, device, mode="train") def test(model, optimizer, iterator, device): return train_or_test(model, optimizer, iterator, device, mode="test") ================================================ FILE: PBnet/src/train/trainer_gan.py ================================================ import torch from tqdm import tqdm import time def train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, device, mode="train", epoch = 0): if mode == "train": model.train() model_d.train() grad_env = torch.enable_grad elif mode == "test": model.eval() model_d.eval() grad_env = torch.no_grad else: raise ValueError("This mode is not recognized.") # loss of the epoch dict_loss = {loss: 0 for loss in (model.losses)} dict_loss['Dloss'] = 0 dict_loss['Gloss'] = 0 with grad_env(): start_time = time.time() # end # print(f'load time {end_time- start_time}') for i, batch in tqdm(enumerate(iterator), desc="Computing batch"): # Put everything in device # end_time = time.time() # print("load_cost: ", - start_time + end_time) # start_time = time.time() batch = {key: val.to(device) for key, val in batch.items() if key!='videoname'} if mode == "train": # update the gradients to zero optimizer_g.zero_grad() optimizer_d.zero_grad() # forward pass batch = model(batch) mixed_loss, losses = model.compute_loss(batch, epoch) D_loss, G_loss = 
model_d.calculate_GAN_loss(batch) end_time = time.time() print("forward: ", - start_time + end_time) start_time = time.time() for key in dict_loss.keys(): if key != 'Gloss' and key != 'Dloss': dict_loss[key] += losses[key] dict_loss['Dloss'] += D_loss.item() dict_loss['Gloss'] += G_loss.item() if mode == "train": # backward pass (mixed_loss + (G_loss + D_loss * 0.5) ).backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 2.) # update the weights optimizer_g.step() optimizer_d.step() end_time = time.time() print("back: ", - start_time + end_time) start_time = time.time() # if i % 10 == 0: # print(dict_loss) return dict_loss def train(model, model_d, optimizer_g, optimizer_d, iterator, device): return train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, device, mode="train") def test(model, model_d, optimizer_g, optimizer_d, iterator, device): return train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, device, mode="test") ================================================ FILE: PBnet/src/train/trainer_gan_ann.py ================================================ import torch from tqdm import tqdm import time def train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, device, mode="train", epoch = 0): if mode == "train": model.train() model_d.train() grad_env = torch.enable_grad elif mode == "test": model.eval() model_d.eval() grad_env = torch.no_grad else: raise ValueError("This mode is not recognized.") # loss of the epoch dict_loss = {loss: 0 for loss in (model.losses)} dict_loss['Dloss'] = 0 dict_loss['Gloss'] = 0 with grad_env(): # start_time = time.time() # end # print(f'load time {end_time- start_time}') for i, batch in tqdm(enumerate(iterator), desc="Computing batch"): # Put everything in device # end_time = time.time() # print("load_cost: ", - start_time + end_time) # start_time = time.time() batch = {key: val.to(device) for key, val in batch.items() if key!='videoname'} # end_time = time.time() # print("tocuda_cost: ", - start_time + end_time) # start_time = time.time() if mode == "train": # update the gradients to zero optimizer_g.zero_grad() optimizer_d.zero_grad() # forward pass batch = model(batch) mixed_loss, losses = model.compute_loss(batch, epoch) D_loss, G_loss = model_d.calculate_GAN_loss(batch) # end_time = time.time() # print("forward: ", - start_time + end_time) # start_time = time.time() for key in dict_loss.keys(): if key != 'Gloss' and key != 'Dloss': dict_loss[key] += losses[key] dict_loss['Dloss'] += D_loss.item() dict_loss['Gloss'] += G_loss.item() if mode == "train": # backward pass ((mixed_loss + (G_loss + D_loss) )).backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 2.) # update the weights optimizer_g.step() optimizer_d.step() # end_time = time.time() # print("back: ", - start_time + end_time) # start_time = time.time() # if i % 10 == 0: # print(dict_loss) return dict_loss def train(model, model_d, optimizer_g, optimizer_d, iterator, device, epoch): return train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, device, mode="train", epoch = epoch) def test(model, model_d, optimizer_g, optimizer_d, iterator, device): return train_or_test(model, model_d, optimizer_g, optimizer_d, iterator, device, mode="test") ================================================ FILE: PBnet/src/utils/PYTORCH3D_LICENSE ================================================ BSD License For PyTorch3D software Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: PBnet/src/utils/__init__.py ================================================ ================================================ FILE: PBnet/src/utils/fixseed.py ================================================ import numpy as np import torch import random def fixseed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) SEED = 10 EVALSEED = 0 # Provoc warning: not fully functionnal yet # torch.set_deterministic(True) torch.backends.cudnn.benchmark = False fixseed(SEED) ================================================ FILE: PBnet/src/utils/get_model_and_data.py ================================================ from ..datasets.get_dataset import get_datasets from ..recognition.get_model import get_model as get_rec_model from ..models.get_model import get_model as get_gen_model def get_model_and_data(parameters): datasets = get_datasets(parameters) if parameters["modelname"] == "recognition": model = get_rec_model(parameters) else: model = get_gen_model(parameters) return model, datasets ================================================ FILE: PBnet/src/utils/misc.py ================================================ import torch def to_numpy(tensor): if torch.is_tensor(tensor): return tensor.cpu().numpy() elif type(tensor).__module__ != 'numpy': raise ValueError("Cannot convert {} to numpy array".format( type(tensor))) return tensor def to_torch(ndarray): if type(ndarray).__module__ == 'numpy': return torch.from_numpy(ndarray) elif not torch.is_tensor(ndarray): raise ValueError("Cannot convert {} to torch tensor".format( type(ndarray))) return ndarray def cleanexit(): import sys import os try: sys.exit(0) except SystemExit: os._exit(0) ================================================ FILE: PBnet/src/utils/rotation_conversions.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
# Check PYTORCH3D_LICENCE before use import functools from typing import Optional import torch import torch.nn.functional as F """ The transformation matrices returned from the functions in this file assume the points on which the transformation will be applied are column vectors. i.e. the R matrix is structured as R = [ [Rxx, Rxy, Rxz], [Ryx, Ryy, Ryz], [Rzx, Rzy, Rzz], ] # (3, 3) This matrix can be applied to column vectors by post multiplication by the points e.g. points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point transformed_points = R * points To apply the same matrix to points which are row vectors, the R matrix can be transposed and pre multiplied by the points: e.g. points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point transformed_points = points * R.transpose(1, 0) """ def quaternion_to_matrix(quaternions): """ Convert rotations given as quaternions to rotation matrices. Args: quaternions: quaternions with real part first, as tensor of shape (..., 4). Returns: Rotation matrices as tensor of shape (..., 3, 3). """ r, i, j, k = torch.unbind(quaternions, -1) two_s = 2.0 / (quaternions * quaternions).sum(-1) o = torch.stack( ( 1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j), ), -1, ) return o.reshape(quaternions.shape[:-1] + (3, 3)) def _copysign(a, b): """ Return a tensor where each element has the absolute value taken from the, corresponding element of a, with sign taken from the corresponding element of b. This is like the standard copysign floating-point operation, but is not careful about negative 0 and NaN. Args: a: source tensor. b: tensor whose signs will be used, of the same shape as a. Returns: Tensor of the same shape as a with the signs of b. """ signs_differ = (a < 0) != (b < 0) return torch.where(signs_differ, -a, a) def _sqrt_positive_part(x): """ Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. """ ret = torch.zeros_like(x) positive_mask = x > 0 ret[positive_mask] = torch.sqrt(x[positive_mask]) return ret def matrix_to_quaternion(matrix): """ Convert rotations given as rotation matrices to quaternions. Args: matrix: Rotation matrices as tensor of shape (..., 3, 3). Returns: quaternions with real part first, as tensor of shape (..., 4). """ if matrix.size(-1) != 3 or matrix.size(-2) != 3: raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") m00 = matrix[..., 0, 0] m11 = matrix[..., 1, 1] m22 = matrix[..., 2, 2] o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22) x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22) y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22) z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22) o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2]) o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0]) o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1]) return torch.stack((o0, o1, o2, o3), -1) def _axis_angle_rotation(axis: str, angle): """ Return the rotation matrices for one of the rotations about an axis of which Euler angles describe, for each value of the angle given. Args: axis: Axis label "X" or "Y or "Z". angle: any shape tensor of Euler angles in radians Returns: Rotation matrices as tensor of shape (..., 3, 3). 
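# [Editor's note] Hedged round-trip sketch, not part of the original file, for the two
# conversions defined above: the identity quaternion maps to the identity matrix, and
# matrix_to_quaternion inverts quaternion_to_matrix up to the usual q ~ -q sign ambiguity.
import torch

q = torch.tensor([[1.0, 0.0, 0.0, 0.0]])          # identity rotation, real part first
R = quaternion_to_matrix(q)                        # -> (1, 3, 3)
q_back = matrix_to_quaternion(R)                   # -> (1, 4)
assert torch.allclose(R, torch.eye(3).expand(1, 3, 3), atol=1e-6)
assert torch.allclose(q_back.abs(), q.abs(), atol=1e-6)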
""" cos = torch.cos(angle) sin = torch.sin(angle) one = torch.ones_like(angle) zero = torch.zeros_like(angle) if axis == "X": R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) if axis == "Y": R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) if axis == "Z": R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) def euler_angles_to_matrix(euler_angles, convention: str): """ Convert rotations given as Euler angles in radians to rotation matrices. Args: euler_angles: Euler angles in radians as tensor of shape (..., 3). convention: Convention string of three uppercase letters from {"X", "Y", and "Z"}. Returns: Rotation matrices as tensor of shape (..., 3, 3). """ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: raise ValueError("Invalid input euler angles.") if len(convention) != 3: raise ValueError("Convention must have 3 letters.") if convention[1] in (convention[0], convention[2]): raise ValueError(f"Invalid convention {convention}.") for letter in convention: if letter not in ("X", "Y", "Z"): raise ValueError(f"Invalid letter {letter} in convention string.") matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) return functools.reduce(torch.matmul, matrices) def _angle_from_tan( axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool ): """ Extract the first or third Euler angle from the two members of the matrix which are positive constant times its sine and cosine. Args: axis: Axis label "X" or "Y or "Z" for the angle we are finding. other_axis: Axis label "X" or "Y or "Z" for the middle axis in the convention. data: Rotation matrices as tensor of shape (..., 3, 3). horizontal: Whether we are looking for the angle for the third axis, which means the relevant entries are in the same row of the rotation matrix. If not, they are in the same column. tait_bryan: Whether the first and third axes in the convention differ. Returns: Euler Angles in radians for each matrix in data as a tensor of shape (...). """ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] if horizontal: i2, i1 = i1, i2 even = (axis + other_axis) in ["XY", "YZ", "ZX"] if horizontal == even: return torch.atan2(data[..., i1], data[..., i2]) if tait_bryan: return torch.atan2(-data[..., i2], data[..., i1]) return torch.atan2(data[..., i2], -data[..., i1]) def _index_from_letter(letter: str): if letter == "X": return 0 if letter == "Y": return 1 if letter == "Z": return 2 def matrix_to_euler_angles(matrix, convention: str): """ Convert rotations given as rotation matrices to Euler angles in radians. Args: matrix: Rotation matrices as tensor of shape (..., 3, 3). convention: Convention string of three uppercase letters. Returns: Euler angles in radians as tensor of shape (..., 3). 
""" if len(convention) != 3: raise ValueError("Convention must have 3 letters.") if convention[1] in (convention[0], convention[2]): raise ValueError(f"Invalid convention {convention}.") for letter in convention: if letter not in ("X", "Y", "Z"): raise ValueError(f"Invalid letter {letter} in convention string.") if matrix.size(-1) != 3 or matrix.size(-2) != 3: raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.") i0 = _index_from_letter(convention[0]) i2 = _index_from_letter(convention[2]) tait_bryan = i0 != i2 if tait_bryan: central_angle = torch.asin( matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) ) else: central_angle = torch.acos(matrix[..., i0, i0]) o = ( _angle_from_tan( convention[0], convention[1], matrix[..., i2], False, tait_bryan ), central_angle, _angle_from_tan( convention[2], convention[1], matrix[..., i0, :], True, tait_bryan ), ) return torch.stack(o, -1) def random_quaternions( n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False ): """ Generate random quaternions representing rotations, i.e. versors with nonnegative real part. Args: n: Number of quaternions in a batch to return. dtype: Type to return. device: Desired device of returned tensor. Default: uses the current device for the default tensor type. requires_grad: Whether the resulting tensor should have the gradient flag set. Returns: Quaternions as tensor of shape (N, 4). """ o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) s = (o * o).sum(1) o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] return o def random_rotations( n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False ): """ Generate random rotations as 3x3 rotation matrices. Args: n: Number of rotation matrices in a batch to return. dtype: Type to return. device: Device of returned tensor. Default: if None, uses the current device for the default tensor type. requires_grad: Whether the resulting tensor should have the gradient flag set. Returns: Rotation matrices as tensor of shape (n, 3, 3). """ quaternions = random_quaternions( n, dtype=dtype, device=device, requires_grad=requires_grad ) return quaternion_to_matrix(quaternions) def random_rotation( dtype: Optional[torch.dtype] = None, device=None, requires_grad=False ): """ Generate a single random 3x3 rotation matrix. Args: dtype: Type to return device: Device of returned tensor. Default: if None, uses the current device for the default tensor type requires_grad: Whether the resulting tensor should have the gradient flag set Returns: Rotation matrix as tensor of shape (3, 3). """ return random_rotations(1, dtype, device, requires_grad)[0] def standardize_quaternion(quaternions): """ Convert a unit quaternion to a standard form: one in which the real part is non negative. Args: quaternions: Quaternions with real part first, as tensor of shape (..., 4). Returns: Standardized quaternions as tensor of shape (..., 4). """ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) def quaternion_raw_multiply(a, b): """ Multiply two quaternions. Usual torch rules for broadcasting apply. Args: a: Quaternions as tensor of shape (..., 4), real part first. b: Quaternions as tensor of shape (..., 4), real part first. Returns: The product of a and b, a tensor of quaternions shape (..., 4). 
""" aw, ax, ay, az = torch.unbind(a, -1) bw, bx, by, bz = torch.unbind(b, -1) ow = aw * bw - ax * bx - ay * by - az * bz ox = aw * bx + ax * bw + ay * bz - az * by oy = aw * by - ax * bz + ay * bw + az * bx oz = aw * bz + ax * by - ay * bx + az * bw return torch.stack((ow, ox, oy, oz), -1) def quaternion_multiply(a, b): """ Multiply two quaternions representing rotations, returning the quaternion representing their composition, i.e. the versor with nonnegative real part. Usual torch rules for broadcasting apply. Args: a: Quaternions as tensor of shape (..., 4), real part first. b: Quaternions as tensor of shape (..., 4), real part first. Returns: The product of a and b, a tensor of quaternions of shape (..., 4). """ ab = quaternion_raw_multiply(a, b) return standardize_quaternion(ab) def quaternion_invert(quaternion): """ Given a quaternion representing rotation, get the quaternion representing its inverse. Args: quaternion: Quaternions as tensor of shape (..., 4), with real part first, which must be versors (unit quaternions). Returns: The inverse, a tensor of quaternions of shape (..., 4). """ return quaternion * quaternion.new_tensor([1, -1, -1, -1]) def quaternion_apply(quaternion, point): """ Apply the rotation given by a quaternion to a 3D point. Usual torch rules for broadcasting apply. Args: quaternion: Tensor of quaternions, real part first, of shape (..., 4). point: Tensor of 3D points of shape (..., 3). Returns: Tensor of rotated points of shape (..., 3). """ if point.size(-1) != 3: raise ValueError(f"Points are not in 3D, f{point.shape}.") real_parts = point.new_zeros(point.shape[:-1] + (1,)) point_as_quaternion = torch.cat((real_parts, point), -1) out = quaternion_raw_multiply( quaternion_raw_multiply(quaternion, point_as_quaternion), quaternion_invert(quaternion), ) return out[..., 1:] def axis_angle_to_matrix(axis_angle): """ Convert rotations given as axis/angle to rotation matrices. Args: axis_angle: Rotations given as a vector in axis angle form, as a tensor of shape (..., 3), where the magnitude is the angle turned anticlockwise in radians around the vector's direction. Returns: Rotation matrices as tensor of shape (..., 3, 3). """ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) def matrix_to_axis_angle(matrix): """ Convert rotations given as rotation matrices to axis/angle. Args: matrix: Rotation matrices as tensor of shape (..., 3, 3). Returns: Rotations given as a vector in axis angle form, as a tensor of shape (..., 3), where the magnitude is the angle turned anticlockwise in radians around the vector's direction. """ return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) def axis_angle_to_quaternion(axis_angle): """ Convert rotations given as axis/angle to quaternions. Args: axis_angle: Rotations given as a vector in axis angle form, as a tensor of shape (..., 3), where the magnitude is the angle turned anticlockwise in radians around the vector's direction. Returns: quaternions with real part first, as tensor of shape (..., 4). 
""" angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) half_angles = 0.5 * angles eps = 1e-6 small_angles = angles.abs() < eps sin_half_angles_over_angles = torch.empty_like(angles) sin_half_angles_over_angles[~small_angles] = ( torch.sin(half_angles[~small_angles]) / angles[~small_angles] ) # for x small, sin(x/2) is about x/2 - (x/2)^3/6 # so sin(x/2)/x is about 1/2 - (x*x)/48 sin_half_angles_over_angles[small_angles] = ( 0.5 - (angles[small_angles] * angles[small_angles]) / 48 ) quaternions = torch.cat( [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 ) return quaternions def quaternion_to_axis_angle(quaternions): """ Convert rotations given as quaternions to axis/angle. Args: quaternions: quaternions with real part first, as tensor of shape (..., 4). Returns: Rotations given as a vector in axis angle form, as a tensor of shape (..., 3), where the magnitude is the angle turned anticlockwise in radians around the vector's direction. """ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) half_angles = torch.atan2(norms, quaternions[..., :1]) angles = 2 * half_angles eps = 1e-6 small_angles = angles.abs() < eps sin_half_angles_over_angles = torch.empty_like(angles) sin_half_angles_over_angles[~small_angles] = ( torch.sin(half_angles[~small_angles]) / angles[~small_angles] ) # for x small, sin(x/2) is about x/2 - (x/2)^3/6 # so sin(x/2)/x is about 1/2 - (x*x)/48 sin_half_angles_over_angles[small_angles] = ( 0.5 - (angles[small_angles] * angles[small_angles]) / 48 ) return quaternions[..., 1:] / sin_half_angles_over_angles def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor: """ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix using Gram--Schmidt orthogonalisation per Section B of [1]. Args: d6: 6D rotation representation, of size (*, 6) Returns: batch of rotation matrices of size (*, 3, 3) [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. On the Continuity of Rotation Representations in Neural Networks. IEEE Conference on Computer Vision and Pattern Recognition, 2019. Retrieved from http://arxiv.org/abs/1812.07035 """ a1, a2 = d6[..., :3], d6[..., 3:] b1 = F.normalize(a1, dim=-1) b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 b2 = F.normalize(b2, dim=-1) b3 = torch.cross(b1, b2, dim=-1) return torch.stack((b1, b2, b3), dim=-2) def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: """ Converts rotation matrices to 6D rotation representation by Zhou et al. [1] by dropping the last row. Note that 6D representation is not unique. Args: matrix: batch of rotation matrices of size (*, 3, 3) Returns: 6D rotation representation, of size (*, 6) [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. On the Continuity of Rotation Representations in Neural Networks. IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
Retrieved from http://arxiv.org/abs/1812.07035 """ return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6) ================================================ FILE: PBnet/src/utils/tensors.py ================================================ import torch def lengths_to_mask(lengths): max_len = max(lengths) mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) return mask def collate_tensors(batch): dims = batch[0].dim() max_size = [max([b.size(i) for b in batch]) for i in range(dims)] size = (len(batch),) + tuple(max_size) canvas = batch[0].new_zeros(size=size) for i, b in enumerate(batch): sub_tensor = canvas[i] for d in range(dims): sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) sub_tensor.add_(b) return canvas def collate(batch): posbatch = [b[1] for b in batch] audiobatch = [b[0] for b in batch] lenbatch = [len(b[0]) for b in batch] videonamebatch=[b[2] for b in batch] posbatchTensor = collate_tensors(posbatch) audiobatchTensor = collate_tensors(audiobatch) lenbatchTensor = torch.as_tensor(lenbatch) maskbatchTensor = lengths_to_mask(lenbatchTensor) batch = {"x": posbatchTensor, "y": audiobatchTensor, "mask": maskbatchTensor, "lengths": lenbatchTensor, "videoname": videonamebatch} return batch ================================================ FILE: PBnet/src/utils/tensors_eye.py ================================================ import torch def lengths_to_mask(lengths): max_len = max(lengths) mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) return mask def collate_tensors(batch): dims = batch[0].dim() max_size = [max([b.size(i) for b in batch]) for i in range(dims)] size = (len(batch),) + tuple(max_size) canvas = batch[0].new_zeros(size=size) for i, b in enumerate(batch): sub_tensor = canvas[i] for d in range(dims): sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) sub_tensor.add_(b) return canvas def collate(batch): posbatch = [b[1] for b in batch] audiobatch = [b[0] for b in batch] # eyebatch = [b[2] for b in batch] lenbatch = [len(b[0]) for b in batch] # startbatch = [b[4] for b in batch] videonamebatch=[b[3] for b in batch] poseyebatch=[b[5] for b in batch] posbatchTensor = collate_tensors(posbatch) audiobatchTensor = collate_tensors(audiobatch) # eyebatchTensor = collate_tensors(eyebatch) poseyebatchTensor = collate_tensors(poseyebatch) # startbatchTensor = collate_tensors(startbatch) lenbatchTensor = torch.as_tensor(lenbatch) maskbatchTensor = lengths_to_mask(lenbatchTensor) batch = {"x":poseyebatchTensor,"p": posbatchTensor, "y": audiobatchTensor, "mask": maskbatchTensor, "lengths": lenbatchTensor, "videoname": videonamebatch} # return batch ================================================ FILE: PBnet/src/utils/tensors_eye_eval.py ================================================ import torch def lengths_to_mask(lengths): max_len = max(lengths) mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) return mask def collate_tensors(batch): dims = batch[0].dim() max_size = [max([b.size(i) for b in batch]) for i in range(dims)] size = (len(batch),) + tuple(max_size) canvas = batch[0].new_zeros(size=size) for i, b in enumerate(batch): sub_tensor = canvas[i] for d in range(dims): sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) sub_tensor.add_(b) return canvas def collate(batch): posbatch = [b[1] for b in batch] audiobatch = [b[0] for b in batch] eyebatch = [b[2] for b in batch] lenbatch = [len(b[0]) for b in batch] startbatch = 
[b[4] for b in batch] videonamebatch=[b[3] for b in batch] poseyebatch=[b[5] for b in batch] posbatchTensor = collate_tensors(posbatch) audiobatchTensor = collate_tensors(audiobatch) eyebatchTensor = collate_tensors(eyebatch) poseyebatchTensor = collate_tensors(poseyebatch) # startbatchTensor = collate_tensors(startbatch) lenbatchTensor = torch.as_tensor(lenbatch) maskbatchTensor = lengths_to_mask(lenbatchTensor) batch = {"x":poseyebatchTensor,"p": posbatchTensor, "y": audiobatchTensor, "e": eyebatchTensor, "mask": maskbatchTensor, "lengths": lenbatchTensor, "videoname": videonamebatch, "start": startbatch} return batch ================================================ FILE: PBnet/src/utils/tensors_hdtf.py ================================================ import torch def lengths_to_mask(lengths): max_len = max(lengths) mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) return mask def collate_tensors(batch): dims = batch[0].dim() max_size = [max([b.size(i) for b in batch]) for i in range(dims)] size = (len(batch),) + tuple(max_size) canvas = batch[0].new_zeros(size=size) for i, b in enumerate(batch): sub_tensor = canvas[i] for d in range(dims): sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) sub_tensor.add_(b) return canvas def collate_old(batch): posbatch = [b[1] for b in batch] audiobatch = [b[0] for b in batch] lenbatch = [len(b[0]) for b in batch] startbatch = [b[3] for b in batch] videonamebatch=[b[2] for b in batch] posbatchTensor = collate_tensors(posbatch) audiobatchTensor = collate_tensors(audiobatch) # startbatchTensor = collate_tensors(startbatch) lenbatchTensor = torch.as_tensor(lenbatch) maskbatchTensor = lengths_to_mask(lenbatchTensor) batch = {"x": posbatchTensor, "y": audiobatchTensor, "mask": maskbatchTensor, "lengths": lenbatchTensor, "videoname": videonamebatch, "start": startbatch} return batch def collate(batch): posbatch = [b[1] for b in batch] audiobatch = [b[0] for b in batch] eyebatch = [b[2] for b in batch] lenbatch = [len(b[0]) for b in batch] # startbatch = [b[4] for b in batch] videonamebatch=[b[3] for b in batch] poseyebatch=[b[5] for b in batch] posbatchTensor = collate_tensors(posbatch) audiobatchTensor = collate_tensors(audiobatch) eyebatchTensor = collate_tensors(eyebatch) poseyebatchTensor = collate_tensors(poseyebatch) # startbatchTensor = collate_tensors(startbatch) lenbatchTensor = torch.as_tensor(lenbatch) maskbatchTensor = lengths_to_mask(lenbatchTensor) batch = {"x":poseyebatchTensor,"p": posbatchTensor, "y": audiobatchTensor, "e": eyebatchTensor, "mask": maskbatchTensor, "lengths": lenbatchTensor, "videoname": videonamebatch} return batch ================================================ FILE: PBnet/src/utils/tensors_onlyeye.py ================================================ import torch def lengths_to_mask(lengths): max_len = max(lengths) mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) return mask def collate_tensors(batch): dims = batch[0].dim() max_size = [max([b.size(i) for b in batch]) for i in range(dims)] size = (len(batch),) + tuple(max_size) canvas = batch[0].new_zeros(size=size) for i, b in enumerate(batch): sub_tensor = canvas[i] for d in range(dims): sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) sub_tensor.add_(b) return canvas def collate(batch): # posbatch = [b[1] for b in batch] audiobatch = [b[0] for b in batch] eyebatch = [b[1] for b in batch] lenbatch = [len(b[0]) for b in batch] # startbatch = [b[4] for b in batch] 
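# [Editor's note] Hedged sketch, not part of the original file, of the padding scheme used by
# collate_tensors and lengths_to_mask above: variable-length sequences are zero-padded to the
# longest item in the batch, and the boolean mask marks the valid (non-padded) time steps.
import torch

seqs = [torch.ones(5, 6), torch.ones(8, 6), torch.ones(3, 6)]   # (num_frames, pose_dim)
lengths = torch.as_tensor([s.shape[0] for s in seqs])
padded = collate_tensors(seqs)                                   # -> (3, 8, 6), zero-padded
mask = lengths_to_mask(lengths)                                  # -> (3, 8) boolean mask
assert padded.shape == (3, 8, 6)
assert mask.sum(dim=1).tolist() == [5, 8, 3]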
videonamebatch=[b[2] for b in batch] # poseyebatch=[b[5] for b in batch] # posbatchTensor = collate_tensors(posbatch) audiobatchTensor = collate_tensors(audiobatch) eyebatchTensor = collate_tensors(eyebatch) # poseyebatchTensor = collate_tensors(poseyebatch) # startbatchTensor = collate_tensors(startbatch) lenbatchTensor = torch.as_tensor(lenbatch) maskbatchTensor = lengths_to_mask(lenbatchTensor) batch = {"x":eyebatchTensor, "y": audiobatchTensor, "mask": maskbatchTensor, "lengths": lenbatchTensor, "videoname": videonamebatch} return batch def collate_eval(batch): # posbatch = [b[1] for b in batch] audiobatch = [b[0] for b in batch] eyebatch = [b[1] for b in batch] lenbatch = [len(b[0]) for b in batch] startbatch = [b[3] for b in batch] videonamebatch=[b[2] for b in batch] # poseyebatch=[b[5] for b in batch] # posbatchTensor = collate_tensors(posbatch) audiobatchTensor = collate_tensors(audiobatch) eyebatchTensor = collate_tensors(eyebatch) # poseyebatchTensor = collate_tensors(poseyebatch) # startbatchTensor = collate_tensors(startbatch) lenbatchTensor = torch.as_tensor(lenbatch) maskbatchTensor = lengths_to_mask(lenbatchTensor) batch = {"x":eyebatchTensor, "y": audiobatchTensor, "mask": maskbatchTensor, "lengths": lenbatchTensor, "videoname": videonamebatch,"start": startbatch} return batch ================================================ FILE: PBnet/src/utils/utils.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from threading import Thread from queue import Queue class _RepeatSampler(object): def __init__(self, sampler): self.sampler = sampler def __iter__(self): while True: yield from iter(self.sampler) class MultiEpochsDataLoader(torch.utils.data.DataLoader): """ During multi-epoch training, the DataLoader object does not need to recreate the thread and batch_sampler objects, in order to save the initialization time for each epoch. 
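# [Editor's note] Hedged usage sketch, not part of the original file: wrapping a toy dataset in
# MultiEpochsDataLoader (defined just below). Compared with a plain DataLoader, the batch sampler
# (and, when num_workers > 0, the worker processes) is created once and reused across epochs
# instead of being rebuilt at the start of every epoch.
import torch
from torch.utils.data import TensorDataset

toy = TensorDataset(torch.randn(100, 6))
loader = MultiEpochsDataLoader(toy, batch_size=16, shuffle=True, num_workers=0)
for epoch in range(3):
    for (batch,) in loader:      # 7 batches per epoch; the underlying iterator persists
        pass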
""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) self.iterator = super().__iter__() def __len__(self): return len(self.batch_sampler.sampler) def __iter__(self): for i in range(len(self)): yield next(self.iterator) class CudaDataLoader: def __init__(self, loader, device, queue_size=2): self.device = device self.queue_size = queue_size self.loader = loader self.load_stream = torch.cuda.Stream(device=device) self.queue = Queue(maxsize=self.queue_size) self.idx = 0 self.worker = Thread(target=self.load_loop) self.worker.setDaemon(True) self.worker.start() def load_loop(self): # The loop that will load into the queue in the background while True: for i, sample in enumerate(self.loader): self.queue.put(self.load_instance(sample)) def load_instance(self, sample): if torch.is_tensor(sample): with torch.cuda.stream(self.load_stream): return sample.to(self.device, non_blocking=True) elif sample is None or type(sample) in (list, str): return sample elif isinstance(sample, dict): return {k: self.load_instance(v) for k, v in sample.items()} else: return [self.load_instance(s) for s in sample] def __iter__(self): self.idx = 0 return self def __next__(self): if not self.worker.is_alive() and self.queue.empty(): self.idx = 0 self.queue.join() self.worker.join() raise StopIteration elif self.idx >= len(self.loader): self.idx = 0 raise StopIteration else: out = self.queue.get() self.queue.task_done() self.idx += 1 return out def next(self): return self.__next__() def __len__(self): return len(self.loader) @property def sampler(self): return self.loader.sampler @property def dataset(self): return self.loader.dataset ================================================ FILE: PBnet/src/utils/video.py ================================================ import numpy as np import imageio def load_video(filename): vid = imageio.get_reader(filename, 'ffmpeg') fps = vid.get_meta_data()['fps'] nframes = vid.count_frames() return vid, fps, nframes class SaveVideo: def __init__(self, outname, fps): self.outname = outname self.fps = fps def __enter__(self): self.writter = imageio.get_writer(self.outname, format='FFMPEG', fps=self.fps) return self def __exit__(self, exc_type, exc_value, exc_traceback): self.writter.close() def __iadd__(self, data): if np.max(data) <= 1: data = np.array(255*data, dtype=np.uint8) else: data = np.array(data, dtype=np.uint8) self.writter.append_data(data) return self ================================================ FILE: PBnet/src/visualize/__init__.py ================================================ ================================================ FILE: PBnet/src/visualize/anim.py ================================================ import numpy as np import torch import imageio # from action2motion # Define a kinematic tree for the skeletal struture humanact12_kinematic_chain = [[0, 1, 4, 7, 10], [0, 2, 5, 8, 11], [0, 3, 6, 9, 12, 15], [9, 13, 16, 18, 20, 22], [9, 14, 17, 19, 21, 23]] # same as smpl smpl_kinematic_chain = humanact12_kinematic_chain mocap_kinematic_chain = [[0, 1, 2, 3], [0, 12, 13, 14, 15], [0, 16, 17, 18, 19], [1, 4, 5, 6, 7], [1, 8, 9, 10, 11]] vibe_kinematic_chain = [[0, 12, 13, 14, 15], [0, 9, 10, 11, 16], [0, 1, 8, 17], [1, 5, 6, 7], [1, 2, 3, 4]] action2motion_kinematic_chain = vibe_kinematic_chain def add_shadow(img, shadow=15): img = np.copy(img) mask = img > shadow img[mask] = img[mask] - shadow img[~mask] = 0 return img def load_anim(path, timesize=None): data = 
np.array(imageio.mimread(path, memtest=False))[..., :3] if timesize is None: return data # take the last frame and put shadow repeat the last frame but with a little shadow lastframe = add_shadow(data[-1]) alldata = np.tile(lastframe, (timesize, 1, 1, 1)) # copy the first frames lenanim = data.shape[0] alldata[:lenanim] = data[:lenanim] return alldata def plot_3d_motion(motion, length, save_path, params, title="", interval=50): import matplotlib import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D # noqa: F401 from mpl_toolkits.mplot3d.art3d import Poly3DCollection # noqa: F401 from matplotlib.animation import FuncAnimation, writers # noqa: F401 # import mpl_toolkits.mplot3d.axes3d as p3 matplotlib.use('Agg') pose_rep = params["pose_rep"] fig = plt.figure(figsize=[2.6, 2.8]) ax = fig.add_subplot(111, projection='3d') # ax = p3.Axes3D(fig) # ax = fig.gca(projection='3d') def init(): ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_zticklabels([]) ax.set_xlim(-0.7, 0.7) ax.set_ylim(-0.7, 0.7) ax.set_zlim(-0.7, 0.7) ax.view_init(azim=-90, elev=110) # ax.set_axis_off() ax.xaxis._axinfo["grid"]['color'] = (0.5, 0.5, 0.5, 0.25) ax.yaxis._axinfo["grid"]['color'] = (0.5, 0.5, 0.5, 0.25) ax.zaxis._axinfo["grid"]['color'] = (0.5, 0.5, 0.5, 0.25) colors = ['red', 'magenta', 'black', 'green', 'blue'] if pose_rep != "xyz": raise ValueError("It should already be xyz.") if torch.is_tensor(motion): motion = motion.numpy() # invert axis motion[:, 1, :] = -motion[:, 1, :] motion[:, 2, :] = -motion[:, 2, :] """ Debug: to rotate the bodies import src.utils.rotation_conversions as geometry glob_rot = [0, 1.5707963267948966, 0] global_orient = torch.tensor(glob_rot) rotmat = geometry.axis_angle_to_matrix(global_orient) motion = np.einsum("ikj,ko->ioj", motion, rotmat) """ if motion.shape[0] == 18: kinematic_tree = action2motion_kinematic_chain elif motion.shape[0] == 24: kinematic_tree = smpl_kinematic_chain else: kinematic_tree = None def update(index): ax.lines = [] ax.collections = [] if kinematic_tree is not None: for chain, color in zip(kinematic_tree, colors): ax.plot(motion[chain, 0, index], motion[chain, 1, index], motion[chain, 2, index], linewidth=4.0, color=color) else: ax.scatter(motion[1:, 0, index], motion[1:, 1, index], motion[1:, 2, index], c="red") ax.scatter(motion[:1, 0, index], motion[:1, 1, index], motion[:1, 2, index], c="blue") ax.set_title(title) ani = FuncAnimation(fig, update, frames=length, interval=interval, repeat=False, init_func=init) plt.tight_layout() # pillow have problem droping frames ani.save(save_path, writer='ffmpeg', fps=1000/interval) plt.close() def plot_3d_motion_dico(x): motion, length, save_path, params, kargs = x plot_3d_motion(motion, length, save_path, params, **kargs) ================================================ FILE: PBnet/src/visualize/visualize.py ================================================ import os import imageio import numpy as np import torch import torch.nn.functional as F from tqdm import tqdm from .anim import plot_3d_motion_dico, load_anim def stack_images(real, real_gens, gen): nleft_cols = len(real_gens) + 1 print("Stacking frames..") allframes = np.concatenate((real[:, None, ...], *[x[:, None, ...] 
for x in real_gens], gen), 1) nframes, nspa, nats, h, w, pix = allframes.shape blackborder = np.zeros((w//30, h*nats, pix), dtype=allframes.dtype) frames = [] for frame_idx in tqdm(range(nframes)): columns = np.vstack(allframes[frame_idx].transpose(1, 2, 3, 4, 0)).transpose(3, 1, 0, 2) frame = np.concatenate((*columns[0:nleft_cols], blackborder, *columns[nleft_cols:]), 0).transpose(1, 0, 2) frames.append(frame) return np.stack(frames) def generate_by_video(visualization, reconstructions, generation, label_to_action_name, params, nats, nspa, tmp_path): # shape : (17, 3, 4, 480, 640, 3) # (nframes, row, column, h, w, 3) fps = params["fps"] params = params.copy() if "output_xyz" in visualization: outputkey = "output_xyz" params["pose_rep"] = "xyz" else: outputkey = "poses" keep = [outputkey, "lengths", "y"] visu = {key: visualization[key].data.cpu().numpy() for key in keep} recons = {mode: {key: reconstruction[key].data.cpu().numpy() for key in keep} for mode, reconstruction in reconstructions.items()} gener = {key: generation[key].data.cpu().numpy() for key in keep} lenmax = max(gener["lengths"].max(), visu["lengths"].max()) timesize = lenmax + 5 import multiprocessing def pool_job_with_desc(pool, iterator, desc, max_, save_path_format, isij): with tqdm(total=max_, desc=desc.format("Render")) as pbar: for _ in pool.imap_unordered(plot_3d_motion_dico, iterator): pbar.update() if isij: array = np.stack([[load_anim(save_path_format.format(i, j), timesize) for j in range(nats)] for i in tqdm(range(nspa), desc=desc.format("Load"))]) return array.transpose(2, 0, 1, 3, 4, 5) else: array = np.stack([load_anim(save_path_format.format(i), timesize) for i in tqdm(range(nats), desc=desc.format("Load"))]) return array.transpose(1, 0, 2, 3, 4) with multiprocessing.Pool() as pool: # Generated samples save_path_format = os.path.join(tmp_path, "gen_{}_{}.gif") iterator = ((gener[outputkey][i, j], gener["lengths"][i, j], save_path_format.format(i, j), params, {"title": f"gen: {label_to_action_name(gener['y'][i, j])}", "interval": 1000/fps}) for j in range(nats) for i in range(nspa)) gener["frames"] = pool_job_with_desc(pool, iterator, "{} the generated samples", nats*nspa, save_path_format, True) # Real samples save_path_format = os.path.join(tmp_path, "real_{}.gif") iterator = ((visu[outputkey][i], visu["lengths"][i], save_path_format.format(i), params, {"title": f"real: {label_to_action_name(visu['y'][i])}", "interval": 1000/fps}) for i in range(nats)) visu["frames"] = pool_job_with_desc(pool, iterator, "{} the real samples", nats, save_path_format, False) for mode, recon in recons.items(): # Reconstructed samples save_path_format = os.path.join(tmp_path, f"reconstructed_{mode}_" + "{}.gif") iterator = ((recon[outputkey][i], recon["lengths"][i], save_path_format.format(i), params, {"title": f"recons: {label_to_action_name(recon['y'][i])}", "interval": 1000/fps}) for i in range(nats)) recon["frames"] = pool_job_with_desc(pool, iterator, "{} the reconstructed samples", nats, save_path_format, False) frames = stack_images(visu["frames"], [recon["frames"] for recon in recons.values()], gener["frames"]) return frames def viz_epoch(model, dataset, epoch, params, folder, writer=None): """ Generate & viz samples """ # visualize with joints3D model.outputxyz = True print(f"Visualization of the epoch {epoch}") noise_same_action = params["noise_same_action"] noise_diff_action = params["noise_diff_action"] duration_mode = params["duration_mode"] reconstruction_mode = params["reconstruction_mode"] decoder_test = 
params["decoder_test"] fact = params["fact_latent"] figname = params["figname"].format(epoch) nspa = params["num_samples_per_action"] nats = params["num_actions_to_sample"] num_classes = params["num_classes"] # define some classes classes = torch.randperm(num_classes)[:nats] meandurations = torch.from_numpy(np.array([round(dataset.get_mean_length_label(cl.item())) for cl in classes])) if duration_mode == "interpolate" or decoder_test == "diffduration": points, step = np.linspace(-nspa, nspa, nspa, retstep=True) points = np.round(10*points/step).astype(int) gendurations = meandurations.repeat((nspa, 1)) + points[:, None] else: gendurations = meandurations.repeat((nspa, 1)) # extract the real samples real_samples, mask_real, real_lengths = dataset.get_label_sample_batch(classes.numpy()) # to visualize directly # Visualizaion of real samples visualization = {"x": real_samples.to(model.device), "y": classes.to(model.device), "mask": mask_real.to(model.device), "lengths": real_lengths.to(model.device), "output": real_samples.to(model.device)} # Visualizaion of real samples if reconstruction_mode == "both": reconstructions = {"tf": {"x": real_samples.to(model.device), "y": classes.to(model.device), "lengths": real_lengths.to(model.device), "mask": mask_real.to(model.device), "teacher_force": True}, "ntf": {"x": real_samples.to(model.device), "y": classes.to(model.device), "lengths": real_lengths.to(model.device), "mask": mask_real.to(model.device)}} else: reconstructions = {reconstruction_mode: {"x": real_samples.to(model.device), "y": classes.to(model.device), "lengths": real_lengths.to(model.device), "mask": mask_real.to(model.device), "teacher_force": reconstruction_mode == "tf"}} print("Computing the samples poses..") # generate the repr (joints3D/pose etc) model.eval() with torch.no_grad(): # Reconstruction of the real data for mode in reconstructions: model(reconstructions[mode]) # update reconstruction dicts reconstruction = reconstructions[list(reconstructions.keys())[0]] if decoder_test == "new": # Generate the new data generation = model.generate(classes, gendurations, nspa=nspa, noise_same_action=noise_same_action, noise_diff_action=noise_diff_action, fact=fact) elif decoder_test == "diffaction": assert nats == nspa # keep the same noise for each "sample" z = reconstruction["z"].repeat((nspa, 1)) mask = reconstruction["mask"].repeat((nspa, 1)) lengths = reconstruction["lengths"].repeat(nspa) # but use other labels y = classes.repeat_interleave(nspa).to(model.device) generation = {"z": z, "y": y, "mask": mask, "lengths": lengths} model.decoder(generation) elif decoder_test == "diffduration": z = reconstruction["z"].repeat((nspa, 1)) lengths = gendurations.reshape(-1).to(model.device) mask = model.lengths_to_mask(lengths) y = classes.repeat(nats).to(model.device) generation = {"z": z, "y": y, "mask": mask, "lengths": lengths} model.decoder(generation) elif decoder_test == "interpolate_action": assert nats == nspa # same noise for each sample z_diff_action = torch.randn(1, model.latent_dim, device=model.device).repeat(nats, 1) z = z_diff_action.repeat((nspa, 1)) # but use combination of labels and labels below y = F.one_hot(classes.to(model.device), model.num_classes).to(model.device) y_below = F.one_hot(torch.cat((classes[1:], classes[0:1])), model.num_classes).to(model.device) convex_factors = torch.linspace(0, 1, nspa, device=model.device) y_mixed = torch.einsum("nk,m->mnk", y, 1-convex_factors) + torch.einsum("nk,m->mnk", y_below, convex_factors) y_mixed = y_mixed.reshape(nspa*nats, 
y_mixed.shape[-1]) durations = gendurations[0].to(model.device) durations_below = torch.cat((durations[1:], durations[0:1])) gendurations = torch.einsum("l,k->kl", durations, 1-convex_factors) + torch.einsum("l,k->kl", durations_below, convex_factors) gendurations = gendurations.to(dtype=durations.dtype) lengths = gendurations.to(model.device).reshape(z.shape[0]) mask = model.lengths_to_mask(lengths) generation = {"z": z, "y": y_mixed, "mask": mask, "lengths": lengths} model.decoder(generation) # Get xyz for the real ones visualization["output_xyz"] = model.rot2xyz(visualization["output"], visualization["mask"]) for key, val in generation.items(): if len(generation[key].shape) == 1: generation[key] = val.reshape(nspa, nats) else: generation[key] = val.reshape(nspa, nats, *val.shape[1:]) finalpath = os.path.join(folder, figname + ".gif") tmp_path = os.path.join(folder, f"subfigures_{figname}") os.makedirs(tmp_path, exist_ok=True) print("Generate the videos..") frames = generate_by_video(visualization, reconstructions, generation, dataset.label_to_action_name, params, nats, nspa, tmp_path) print(f"Writing video {finalpath}..") imageio.mimsave(finalpath, frames, fps=params["fps"]) if writer is not None: writer.add_video(f"Video/Epoch {epoch}", frames.transpose(0, 3, 1, 2)[None], epoch, fps=params["fps"]) def viz_dataset(dataset, params, folder): """ Generate & viz samples """ print("Visualization of the dataset") nspa = params["num_samples_per_action"] nats = params["num_actions_to_sample"] num_classes = params["num_classes"] figname = "{}_{}_numframes_{}_sampling_{}_step_{}".format(params["dataset"], params["pose_rep"], params["num_frames"], params["sampling"], params["sampling_step"]) # define some classes classes = torch.randperm(num_classes)[:nats] allclasses = classes.repeat(nspa, 1).reshape(nspa*nats) # extract the real samples real_samples, mask_real, real_lengths = dataset.get_label_sample_batch(allclasses.numpy()) # to visualize directly # Visualizaion of real samples visualization = {"x": real_samples, "y": allclasses, "mask": mask_real, "lengths": real_lengths, "output": real_samples} from src.models.rotation2xyz import Rotation2xyz device = params["device"] rot2xyz = Rotation2xyz(device=device) rot2xyz_params = {"pose_rep": params["pose_rep"], "glob_rot": params["glob_rot"], "glob": params["glob"], "jointstype": params["jointstype"], "translation": params["translation"]} output = visualization["output"] visualization["output_xyz"] = rot2xyz(output.to(device), visualization["mask"].to(device), **rot2xyz_params) for key, val in visualization.items(): if len(visualization[key].shape) == 1: visualization[key] = val.reshape(nspa, nats) else: visualization[key] = val.reshape(nspa, nats, *val.shape[1:]) finalpath = os.path.join(folder, figname + ".gif") tmp_path = os.path.join(folder, f"subfigures_{figname}") os.makedirs(tmp_path, exist_ok=True) print("Generate the videos..") frames = generate_by_video_sequences(visualization, dataset.label_to_action_name, params, nats, nspa, tmp_path) print(f"Writing video {finalpath}..") imageio.mimsave(finalpath, frames, fps=params["fps"]) def generate_by_video_sequences(visualization, label_to_action_name, params, nats, nspa, tmp_path): # shape : (17, 3, 4, 480, 640, 3) # (nframes, row, column, h, w, 3) fps = params["fps"] if "output_xyz" in visualization: outputkey = "output_xyz" params["pose_rep"] = "xyz" else: outputkey = "poses" keep = [outputkey, "lengths", "y"] visu = {key: visualization[key].data.cpu().numpy() for key in keep} lenmax = 
visu["lengths"].max() timesize = lenmax + 5 import multiprocessing def pool_job_with_desc(pool, iterator, desc, max_, save_path_format): with tqdm(total=max_, desc=desc.format("Render")) as pbar: for _ in pool.imap_unordered(plot_3d_motion_dico, iterator): pbar.update() array = np.stack([[load_anim(save_path_format.format(i, j), timesize) for j in range(nats)] for i in tqdm(range(nspa), desc=desc.format("Load"))]) return array.transpose(2, 0, 1, 3, 4, 5) with multiprocessing.Pool() as pool: # Real samples save_path_format = os.path.join(tmp_path, "real_{}_{}.gif") iterator = ((visu[outputkey][i, j], visu["lengths"][i, j], save_path_format.format(i, j), params, {"title": f"real: {label_to_action_name(visu['y'][i, j])}", "interval": 1000/fps}) for j in range(nats) for i in range(nspa)) visu["frames"] = pool_job_with_desc(pool, iterator, "{} the real samples", nats, save_path_format) frames = stack_images_sequence(visu["frames"]) return frames def stack_images_sequence(visu): print("Stacking frames..") allframes = visu nframes, nspa, nats, h, w, pix = allframes.shape frames = [] for frame_idx in tqdm(range(nframes)): columns = np.vstack(allframes[frame_idx].transpose(1, 2, 3, 4, 0)).transpose(3, 1, 0, 2) frame = np.concatenate(columns).transpose(1, 0, 2) frames.append(frame) return np.stack(frames) ================================================ FILE: PBnet/src/visualize/visualize_checkpoint.py ================================================ import os import matplotlib.pyplot as plt import torch from src.utils.get_model_and_data import get_model_and_data from src.parser.visualize import parser from .visualize import viz_epoch import src.utils.fixseed # noqa plt.switch_backend('agg') def main(): # parse options parameters, folder, checkpointname, epoch = parser() model, datasets = get_model_and_data(parameters) dataset = datasets["train"] print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) state_dict = torch.load(checkpointpath, map_location=parameters["device"]) model.load_state_dict(state_dict) # visualize_params viz_epoch(model, dataset, epoch, parameters, folder=folder, writer=None) if __name__ == '__main__': main() ================================================ FILE: PBnet/src/visualize/visualize_dataset.py ================================================ import matplotlib.pyplot as plt # import torch import os from src.datasets.get_dataset import get_dataset from src.utils import optutils from src.utils.visualize import viz_dataset import src.utils.fixseed # noqa plt.switch_backend('agg') if __name__ == '__main__': # parse options parameters = optutils.visualize_dataset_parser() # get device device = parameters["device"] # get data DATA = get_dataset(name=parameters["dataset"]) dataset = DATA(split="train", **parameters) # add specific parameters from the dataset loading dataset.update_parameters(parameters) name = f"{parameters['dataset']}_{parameters['extraction_method']}" folder = os.path.join("datavisualize", name) viz_dataset(dataset, parameters, folder) ================================================ FILE: PBnet/src/visualize/visualize_latent_space.py ================================================ import matplotlib.pyplot as plt import numpy as np import os import scipy import torch import torch.nn.functional as F from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from ..utils import optutils from ..utils.visualize import viz_epoch, viz_fake, viz_real from ..models.get_model import get_model from ..datasets.get_dataset import 
get_dataset from ..utils.trainer import train, test # import ..utils.fixseed # noqa plt.switch_backend('agg') if __name__ == '__main__': # parse options opt, folder, checkpointname, epoch = optutils.parse_load_args() # get device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # get data DATA = get_dataset(name=opt.dataname) dataset = DATA(split="train", **opt.data) test_dataset = train_dataset = dataset # update model parameters opt.model.update({"num_classes": dataset.num_classes, "nfeats": dataset.nfeats, "device": device}) # update visualize params opt.visualize.update({"num_classes": dataset.num_classes, "num_actions_to_sample": min(opt.visualize["num_actions_to_sample"], dataset.num_classes)}) # get model MODEL = get_model(opt.modelname) model = MODEL(**opt.model) model = model.to(device) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) state_dict = torch.load(checkpointpath, map_location=device) model.load_state_dict(state_dict) nexemple = 20 latents = [] labels = [] generats = [] print("Evaluating model..") keep = {"x": [], "y": [], "di": []} num_classes = dataset.num_classes # num_classes = 1 for label in tqdm(range(num_classes)): xcp, ycp, di = dataset.get_label_sample(label, n=nexemple, return_labels=True, return_index=True) keep["x"].append(xcp) keep["y"].append(ycp) keep["di"].append(di) x = torch.from_numpy(xcp).to(device) y = torch.from_numpy(ycp).to(device) h = model.return_latent(x, y) # mu, var = model.encoder(x, y) # h = mu hy = torch.randn(nexemple, model.latent_dim, device=device) hcp = h.data.cpu().numpy() hycp = hy.data.cpu().numpy() latents.append(hcp) generats.append(hycp) labels.append(ycp) latents = np.array(latents) generats = np.array(generats) nclasses, nexemple, latent_dim = latents.shape labels = np.array(labels) all_latents = np.concatenate(latents) all_generats = np.concatenate(generats) nall_latents = len(all_latents) # import ipdb; ipdb.set_trace() print("Computing tsne..") from sklearn.manifold import TSNE all_input = np.concatenate((all_latents, all_generats)) # tsne = TSNE(n_components=2) # all_vizu_concat = tsne.fit_transform(all_input) # import ipdb; ipdb.set_trace() # feats = tuple(np.argsort(all_latents.var(0))[::-1][:2]) feats = tuple(np.argsort(all_latents.min(0)-all_latents.max(0))[::-1][:2] ) all_vizu_concat = all_input[:, feats] all_vizu_vectors = all_vizu_concat[:nall_latents] all_gen_vizu_vectors = all_vizu_concat[nall_latents:] gen_vizu_vectors = all_gen_vizu_vectors.reshape(nclasses, nexemple, 2) vizu_vectors = all_vizu_vectors.reshape(nclasses, nexemple, 2) print("Plotting..") import matplotlib.pyplot as plt import matplotlib.colors as mcolors colors = list(mcolors.TABLEAU_COLORS.values()) + list(mcolors.BASE_COLORS.values()) + list(mcolors.CSS4_COLORS.values()) for label in tqdm(range(num_classes)): color = colors[label] plt.scatter(*gen_vizu_vectors[label].T, color=color, marker="X") for label in tqdm(range(num_classes)): color = colors[label] plt.scatter(*vizu_vectors[label].T, color=color) plt.savefig("tsne_all.png") plt.close() import ipdb; ipdb.set_trace() """ mean = all_vizu_vectors.mean() farthest = np.argsort(np.linalg.norm(mean - all_vizu_vectors, axis=1))[::-1][0] cl_number, exnumber = np.argwhere(np.arange(all_vizu_vectors.shape[0]).reshape(nclasses, nexemple) == farthest)[0] outlier_vid = keep["x"][cl_number][exnumber] nframe = outlier_vid.shape[-1] from ..utils.video import SaveVideo save_path = "outlier.mp4" cl_name = dataset.label_to_action_name(cl_number) with 
SaveVideo(save_path, opt.visualize["fps"]) as outvideo: for frame in range(nframe): outvideo += repr_to_frame(outlier_vid[..., frame], f"{cl_name} outlier", {"pose_rep": "xyz"}) """ ================================================ FILE: PBnet/src/visualize/visualize_nturefined.py ================================================ import matplotlib.pyplot as plt import torch from src.datasets.get_dataset import get_dataset from src.utils.anim import plot_3d_motion import src.utils.fixseed # noqa plt.switch_backend('agg') def viz_ntu13(dataset, device): """ Generate & viz samples """ print("Visualization of the ntu13") from src.models.rotation2xyz import Rotation2xyz rot2xyz = Rotation2xyz(device) realsamples = [] pose18samples = [] pose24samples = [] translation = True dataset.glob = True dataset.translation = translation for i in range(1, 2): dataset.pose_rep = "xyz" x_xyz = dataset[i][0] realsamples.append(x_xyz) dataset.pose_rep = "rotvec" pose = dataset[i][0] mask = torch.ones(pose.shape[2], dtype=bool) # from src.models.smpl import SMPL # smplmodel = SMPL().eval().to(device) # import ipdb; ipdb.set_trace() pose24 = rot2xyz(pose[None], mask[None], pose_rep="rotvec", jointstype="smpl", glob=True, translation=translation)[0] pose18 = rot2xyz(pose[None], mask[None], pose_rep="rotvec", jointstype="a2m", glob=True, translation=translation)[0] translation = True dataset.glob = True dataset.translation = translation # poseT = dataset[i][0] # pose18T = rot2xyz(poseT[None], mask[None], pose_rep="rotvec", jointstype="action2motion", glob=True, translation=translation)[0] # import ipdb; ipdb.set_trace() pose18samples.append(pose18) pose24samples.append(pose24) params = {"pose_rep": "xyz"} for i in [0]: for x_xyz, title in zip([pose24samples[i], pose18samples[i], realsamples[i]], ["pose_to_24", "pose_to_18", "action2motion_18"]): save_path = title + ".gif" plot_3d_motion(x_xyz, x_xyz.shape[-1], save_path, params, title=title) print(f"saving {save_path}") if __name__ == '__main__': # get device device = torch.device('cpu') # get data DATA = get_dataset(name="ntu13") dataset = DATA(split="train") viz_ntu13(dataset, device) ================================================ FILE: PBnet/src/visualize/visualize_sequence.py ================================================ import os import matplotlib.pyplot as plt import torch import numpy as np from src.datasets.get_dataset import get_dataset from src.models.get_model import get_model from src.utils import optutils from src.utils.anim import plot_3d_motion_on_oneframe from src.utils.visualize import process_to_visualize import src.utils.fixseed # noqa plt.switch_backend('agg') if __name__ == '__main__': # parse options opt, folder, checkpointname, epoch = optutils.visualize_parser() # get device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # get data DATA = get_dataset(name=opt.dataset) dataset = DATA(split="train", **opt.data) test_dataset = train_dataset = dataset # update model parameters opt.model.update({"num_classes": dataset.num_classes, "nfeats": dataset.nfeats, "device": device}) # update visualize params opt.visualize.update({"num_classes": dataset.num_classes, "num_actions_to_sample": min(opt.visualize["num_actions_to_sample"], dataset.num_classes)}) # get model MODEL = get_model(opt.modelname) model = MODEL(**opt.model) model = model.to(device) print("Restore weights..") checkpointpath = os.path.join(folder, checkpointname) state_dict = torch.load(checkpointpath, map_location=device) model.load_state_dict(state_dict) 
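    # With the checkpoint restored, generate one sequence for the first action class
    # (action_number = 0) and render it to a single-frame figure via plot_3d_motion_on_oneframe.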
    save_path = os.path.join(folder, f"fig_{epoch}")
    action_number = 0
    actioname = dataset.action_to_action_name(action_number)
    label = dataset.action_to_label(action_number)
    print(f"Generate {actioname}..")
    y = torch.from_numpy(np.array([label], dtype=int)).to(device)
    motion = model.generate(y, fact=1)
    motion = process_to_visualize(motion.data.cpu().numpy(), opt.visualize)[0]
    print("Plot motion..")
    plot_3d_motion_on_oneframe(motion, "motion.png", opt.visualize, title=actioname)

================================================ FILE: README.md ================================================

# 🌅 DAWN: Dynamic Frame Avatar with Non-autoregressive Diffusion Framework for Talking Head Video Generation

[![arXiv](https://img.shields.io/badge/Arxiv-2410.13726-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.13726) [![Demo Page](https://img.shields.io/badge/Demo_Page-blue)](https://hanbo-cheng.github.io/DAWN/) [![zhihu](https://img.shields.io/badge/知乎-0079FF.svg?logo=zhihu&logoColor=white)](https://zhuanlan.zhihu.com/p/2253009511)

[中文文档](README_CN.md)

😊 Please give us a star ⭐ to support our continuous updates 😊
## News
* ```2024.10.14``` 🔥 We release the [Demo page](https://hanbo-cheng.github.io/DAWN/).
* ```2024.10.18``` 🔥 We release the paper [DAWN](https://arxiv.org/abs/2410.13726).
* ```2024.10.21``` 🔥 We update the Chinese introduction on [Zhihu](https://zhuanlan.zhihu.com/p/2253009511).
* ```2024.11.7``` 🔥🔥 We release the pretrained model on [Hugging Face](https://huggingface.co/Hanbo-Cheng/DAWN).
* ```2024.11.9``` 🔥🔥🔥 We release the inference code. We sincerely invite you to experience our model. 😊
* ```2025.2.16``` 🔥🔥🔥 We optimize the unified inference code. Now you can run the test pipeline with only one script. 🚀

## TODO list:
- [x] release the inference code
- [x] release the pretrained model of **128*128**
- [x] release the pretrained model of **256*256**
- [x] release the unified test code
- [ ] in progress ...

## Equipment Requirements

With our VRAM-oriented optimized [code](DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py), the maximum length of video that can be generated is **linearly related** to the size of the GPU VRAM: larger VRAM produces longer videos.

- To generate **128*128** video, we recommend using a GPU with **12GB** or more VRAM. This can generate at least approximately **400 frames**.
- To generate **256*256** video, we recommend using a GPU with **24GB** or more VRAM. This can generate at least approximately **200 frames**.

PS: Although the optimized [code](DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py) improves VRAM utilization, it currently sacrifices inference speed because the optimization of local attention is incomplete. We are actively working on this issue, and if you have a better solution, we welcome your PR. If you wish to achieve faster inference speeds, you can use the [unoptimized code](DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test.py), but this will increase VRAM usage (O(n²) spatial complexity). A rough back-of-the-envelope sketch of this VRAM/length scaling is included below.

## Methodology

### The overall structure of DAWN:

framework
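
As a rough back-of-the-envelope illustration of the VRAM/length scaling described in the Equipment Requirements section above, the sketch below simply scales the two reference points quoted there (12GB ≈ 400 frames at 128*128, 24GB ≈ 200 frames at 256*256). It is only an estimate, not part of the released code; the real limit depends on the model, the attention implementation, and your environment.

```python
# Rough estimate only: proportional scaling from the two reference points in this README.
# (Assumption: the linear relation passes near the origin; actual limits will differ.)
REFERENCE = {128: (12.0, 400), 256: (24.0, 200)}  # resolution -> (VRAM in GB, frames)

def estimate_max_frames(vram_gb: float, resolution: int) -> int:
    ref_gb, ref_frames = REFERENCE[resolution]
    return int(vram_gb / ref_gb * ref_frames)

if __name__ == "__main__":
    for res in (128, 256):
        for gb in (12, 16, 24, 48):
            print(f"{res}x{res} with {gb} GB VRAM -> ~{estimate_max_frames(gb, res)} frames")
```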

## Environment

We highly recommend trying DAWN on a Linux platform. Running on Windows may produce some junk files that need to be deleted manually, and it requires additional effort to deploy the 3DDFA repository (our `extract_init_states` folder); see this [comment](https://github.com/cleardusk/3DDFA_V2/issues/12#issuecomment-697479173).

1. Set up the conda environment:
   ```
   conda create -n DAWN python=3.8
   conda activate DAWN
   pip install -r requirements.txt
   ```
2. Follow the [readme](extract_init_states/readme.md) and [3DDFA_V2](https://github.com/cleardusk/3DDFA_V2) to set up the 3DDFA environment.

## Inference

Since our model **is trained only on the HDTF dataset** and has few parameters, to ensure the best driving effect please provide input images that:
- are standard human portrait photos, preferably without hats or large headgear,
- have a clear boundary between the background and the subject,
- have the face occupying the main position in the image.

Preparation for inference:

1. Download the pretrained checkpoints from [hugging face](https://huggingface.co/Hanbo-Cheng/DAWN). Create the `./pretrain_models` directory and put the checkpoint files into it. Please also download the Hubert model from [facebook/hubert-large-ls960-ft](https://huggingface.co/facebook/hubert-large-ls960-ft/tree/main).
   ```
   directory structure:
   pretrain_models/
   ├── LFG_256_400ep.pth
   ├── LFG_128_1000ep.pth
   ├── DAWN_256.pth
   ├── DAWN_128.pth
   └── hubert-large-ls960-ft/
       ├── .....
   ```
2. Run the inference script:
   ```
   python unified_video_generator.py \
       --audio_path your/audio/path \
       --image_path your/image/path \
       --output_path output/path \
       --cache_path cache/path \
       --resolution 128   # optional: 128 or 256
   ```

***Inference on other datasets:*** By specifying the `audio_path`, `image_path`, and `output_path` of the `VideoGenerator` class for each inference run, and by modifying `directory_name` and `output_video_path` in `unified_video_generator.py` (Lines 310-312 and 393-394), you can control the naming logic for saving images and videos, enabling testing on any dataset. A small batch-driving sketch is included at the end of this README.

## Citing DAWN

If you wish to refer to the baseline results published here, please use the following BibTeX entry:
```BibTeX
@misc{dawn2024,
      title={DAWN: Dynamic Frame Avatar with Non-autoregressive Diffusion Framework for Talking Head Video Generation},
      author={Hanbo Cheng and Limin Lin and Chenyu Liu and Pengcheng Xia and Pengfei Hu and Jiefeng Ma and Jun Du and Jia Pan},
      year={2024},
      eprint={2410.13726},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2410.13726},
}
```

## Acknowledgement

[Limin Lin](https://github.com/LiminLin0) and [Hanbo Cheng](https://github.com/Hanbo-Cheng) contributed equally to the project.

Thank you to the authors of [Diffused Heads](https://github.com/MStypulkowski/diffused-heads) for assisting us in reproducing their work! We also extend our gratitude to the authors of [MRAA](https://github.com/snap-research/articulated-animation), [LFDM](https://github.com/snap-research/articulated-animation), [3DDFA_V2](https://github.com/cleardusk/3DDFA_V2) and [ACTOR](https://github.com/Mathux/ACTOR) for their contributions to the open-source community. Lastly, we thank our mentors and co-authors for their continuous support in our research work!
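
For testing on a whole folder of samples (see ***Inference on other datasets*** above), the following is a minimal, hypothetical batch-driving sketch. It only uses the command-line flags documented above; the folder layout (paired `<name>.wav` audio and `<name>.png` image files) and the output naming are assumptions, not part of the released code.

```python
# Hypothetical batch driver for unified_video_generator.py.
# Assumptions: one .wav per sample in AUDIO_DIR with a matching .png in IMAGE_DIR.
import subprocess
from pathlib import Path

AUDIO_DIR = Path("examples/audio")    # assumed layout
IMAGE_DIR = Path("examples/images")   # assumed layout
OUT_DIR = Path("results")
CACHE_DIR = Path("cache")

for wav in sorted(AUDIO_DIR.glob("*.wav")):
    img = IMAGE_DIR / f"{wav.stem}.png"
    if not img.exists():
        continue
    subprocess.run([
        "python", "unified_video_generator.py",
        "--audio_path", str(wav),
        "--image_path", str(img),
        "--output_path", str(OUT_DIR / wav.stem),
        "--cache_path", str(CACHE_DIR / wav.stem),
        "--resolution", "128",  # 128 or 256, matching the released checkpoints
    ], check=True)
```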
================================================ FILE: README_CN.md ================================================ # 🌅 DAWN:Dynamic Frame Avatar with Non-autoregressive Diffusion Framework for Talking Head Video Generation [![arXiv](https://img.shields.io/badge/Arxiv-2410.13726-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.13726) [![Demo Page](https://img.shields.io/badge/Demo_Page-blue)](https://hanbo-cheng.github.io/DAWN/) [![zhihu](https://img.shields.io/badge/知乎-0079FF.svg?logo=zhihu&logoColor=white)](https://zhuanlan.zhihu.com/p/2253009511)

😊 请给我们一个star⭐支持我们的持续更新 😊 ## 新闻 * ```2024.10.14``` 🔥 我们发布了 [DEMO](https://hanbo-cheng.github.io/DAWN/)。 * ```2024.10.18``` 🔥 我们发布了论文 [DAWN](https://arxiv.org/abs/2410.13726)。 * ```2024.10.21``` 🔥 我们更新了中文介绍 [知乎](https://zhuanlan.zhihu.com/p/2253009511)。 * ```2024.11.7``` 🔥🔥 我们在 [hugging face](https://huggingface.co/Hanbo-Cheng/DAWN) 上发布了预训练模型。 * ```2024.11.9``` 🔥🔥🔥 我们发布了推理代码。我们诚挚邀请您体验我们的模型。😊 * ```2025.2.16``` 🔥🔥🔥 我们优化了统一推理代码。现在您可以仅用一个脚本运行测试流程。🚀 ## 待办事项列表: - [x] 发布推理代码 - [x] 发布 **128*128** 的预训练模型 - [x] 发布 **256*256** 的预训练模型 - [x] 发布统一测试代码 - [ ] 进行中 ... ## 设备要求 使用我们针对VRAM优化的 [代码](DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py),生成的视频最大长度与GPU VRAM的大小 **成线性关系**。更大的VRAM可以生成更长的视频。 - 要生成 **128*128** 视频,我们建议使用 **12GB** 或更多VRAM的GPU。这至少可以生成大约 **400帧** 的视频。 - 要生成 **256*256** 视频,我们建议使用 **24GB** 或更多VRAM的GPU。这至少可以生成大约 **200帧** 的视频。 PS:尽管优化的 [代码](DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test_local_opt.py) 可以提高VRAM利用率,但由于局部注意力的优化尚不完整,目前牺牲了推理速度。我们正在积极解决这个问题,如果您有更好的解决方案,欢迎您提交PR。如果您希望实现更快的推理速度,可以使用 [未优化的代码](DM_3/modules/video_flow_diffusion_multiGPU_v0_crema_plus_faceemb_ca_multi_test.py),但这将增加VRAM使用(O(n²) 空间复杂度)。 ## 方法论 ### DAWN的整体结构:

framework

## 环境 我们强烈建议在Linux平台上尝试DAWN。在Windows上运行可能会产生一些需要手动删除的垃圾文件,并且需要额外的努力来部署3DDFA库(我们的 `extract_init_states` 文件夹) [评论](https://github.com/cleardusk/3DDFA_V2/issues/12#issuecomment-697479173)。 1. 设置conda环境 ``` conda create -n DAWN python=3.8 conda activate DAWN pip install -r requirements.txt ``` 2. 按照 [readme](extract_init_states/readme.md) 和 [3DDFA_V2](https://github.com/cleardusk/3DDFA_V2) 设置3DDFA环境。 ## 推理 由于我们的模型 **仅在HDTF数据集上训练**,并且参数较少,为了确保最佳的驱动效果,请尽量提供以下示例: - 尽量使用标准人像照片,避免佩戴帽子或大型头饰 - 确保背景与主体之间有清晰的边界 - 确保面部在图像中占据主要位置。 推理准备: 1. 从 [hugging face](https://huggingface.co/Hanbo-Cheng/DAWN) 下载预训练检查点。创建 `./pretrain_models` 目录并将检查点文件放入其中。请从 [facebook/hubert-large-ls960-ft](https://huggingface.co/facebook/hubert-large-ls960-ft/tree/main) 下载Hubert模型。 2. 运行推理脚本: ``` python unified_video_generator.py \ --audio_path your/audio/path \ --image_path your/image/path \ --output_path output/path \ --cache_path cache/path ``` ***在其他数据集上的推理:*** 通过在每次推理时指定 `VideoGenerator` 类的 `audio_path`、`image_path` 和 `output_path`,并修改 `unified_video_generator.py` 中第310-312行和393-394行的 `directory_name` 和 `output_video_path` 的内容,您可以控制保存图像和视频的命名逻辑,从而在任何数据集上进行测试。 ================================================ FILE: config/DAWN_128.yaml ================================================ input_size: 128 max_n_frames: 200 random_seed: 1234 mean: [0.0, 0.0, 0.0] win_width: 40 sampling_step: 20 ddim_sampling_eta: 1.0 cond_scale: 1.0 model_config: is_train: true pose_dim: 6 config_pth: './config/hdtf128.yaml' ae_pretrained_pth: './pretrain_models/LFG_128_1000ep.pth' diffusion_pretrained_pth: './pretrain_models/DAWN_128.pth' ================================================ FILE: config/DAWN_256.yaml ================================================ input_size: 256 max_n_frames: 200 random_seed: 1234 mean: [0.0, 0.0, 0.0] win_width: 40 sampling_step: 20 ddim_sampling_eta: 1.0 cond_scale: 1.0 model_config: is_train: true pose_dim: 6 config_pth: './config/hdtf256.yaml' ae_pretrained_pth: './pretrain_models/LFG_256_400ep.pth' diffusion_pretrained_pth: './pretrain_models/DAWN_256.pth' ================================================ FILE: config/hdtf128.yaml ================================================ #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. #In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. # Dataset parameters # Each dataset should contain 2 folders train and test # Each video can be represented as: # - an image of concatenated frames # - '.mp4' or '.gif' # - folder with all frames from a specific video dataset_params: # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames. # Folder with frames is preferred format for training, since it is the fastest. 
root_dir: /work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images # Shape to resize all frames to, specify null if resizing is not needed frame_shape: 128 # In case of Vox or Taichi single video can be splitted in many chunks, or the maybe several videos for single person. # In this case epoch can be a pass over different identities (if id_sampling=True) or over different chunks (if id_sampling=False) # If the name the video '12335#adsbf.mp4' the id is assumed to be 12335 id_sampling: False # List with pairs for animation, null for random pairs pairs_list: null # Augmentation parameters see augmentation.py for all possible augmentations augmentation_params: flip_param: horizontal_flip: True time_flip: True jitter_param: brightness: 0.1 contrast: 0.1 saturation: 0.1 hue: 0.1 # Defines architecture of the model model_params: # Number of regions num_regions: 10 # Number of channels, for RGB image it is always 3 num_channels: 3 # Enable estimation of affine parameters for each region, # set to False if only region centers (keypoints) need to be estimated estimate_affine: True # Svd can perform random axis swap between source and driving if singular values are close to each other # Set to True to avoid axis swap between source and driving revert_axis_swap: True # Parameters of background prediction network based on simple Unet-like encoder. bg_predictor_params: # Number of features multiplier block_expansion: 32 # Maximum allowed number of features max_features: 1024 # Number of block in the Encoder. num_blocks: 5 # Type of background movement model, select one from ['zero', 'shift', 'affine', 'perspective'] bg_type: 'affine' # Parameters of the region prediction network based on Unet region_predictor_params: # Softmax temperature for heatmaps temperature: 0.1 # Number of features multiplier block_expansion: 32 # Maximum allowed number of features max_features: 1024 # Regions is predicted on smaller images for better performance, # scale_factor=0.25 means that 256x256 image will be resized to 64x64 scale_factor: 0.25 # Number of block in Unet. Can be increased or decreased depending or resolution. num_blocks: 5 # Either to use pca_based estimation of affine parameters of regression based pca_based: True # Either to use fast_svd (https://github.com/KinglittleQ/torch-batch-svd) or standard pytorch svd # Fast svd may produce not meaningful regions if used along with revert_axis_swap fast_svd: False # Parameters of Generator, based on Jonson architecture generator_params: # Number of features multiplier block_expansion: 64 # Maximum allowed number of features max_features: 512 # Number of down-sampling blocks in Jonson architecture. # Can be increased or decreased depending or resolution. num_down_blocks: 2 # Number of ResBlocks in Jonson architecture. num_bottleneck_blocks: 6 # To use skip connections or no. skips: True # Parameters of pixelwise flow predictor based on Unet pixelwise_flow_predictor_params: # Number of features multiplier block_expansion: 64 # Maximum allowed number of features max_features: 1024 # Number of block in Unet. Can be increased or decreased depending on resolution. 
num_blocks: 5 # Flow predictor operates on the smaller images for better performance, # scale_factor=0.25 means that 256x256 image will be resized to 64x64 scale_factor: 0.25 # Set to True in order to use deformed source images using sparse flow use_deformed_source: True # Set to False in order to render region heatmaps with fixed covariance # True for covariance estimate using region_predictor use_covar_heatmap: True # Set to False to disable occlusion mask estimation estimate_occlusion_map: True # Parameter for animation-via-disentanglement (avd) network avd_network_params: # Bottleneck for identity branch id_bottle_size: 64 # Bottleneck for pose branch pose_bottle_size: 64 # Parameters of training (reconstruction) train_params: max_epochs: 100 # For better i/o performance when number of videos is small number of epochs can be multiplied by this number. # Thus effectively with num_repeats=100 each epoch is 100 times larger. num_repeats: 100 # Drop learning rate 10 times after this epochs epoch_milestones: [60, 90] # Initial learning rate lr: 2.0e-4 # Batch size. (14 is batch size for one V100 gpu). batch_size: 100 # Either to use sync_bn or not, enabling sync_bn will significantly slow the training time use_sync_bn: False # Dataset preprocessing cpu workers dataloader_workers: 16 print_freq: 10 save_img_freq: 100 # update checkpoint in this frequent update_ckpt_freq: 5000 # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256, # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32. scales: [1, 0.5, 0.25, 0.125] # Parameters of transform for equivariance loss transform_params: sigma_affine: 0.05 sigma_tps: 0.005 points_tps: 5 loss_weights: # Weights for perceptual pyramide loss. Note that here you can only specify weight across the layer, and # weights across the resolution will be the same. perceptual: [10, 10, 10, 10, 10] rec_vgg: [0, 0, 0, 0, 0] # Weights for equivariance loss. equivariance_shift: 10 equivariance_affine: 10 # Parameters of visualization visualizer_params: # Size of keypoints kp_size: 2 # Draw border between images or not draw_border: True # Colormap for regions and keypoints visualization colormap: 'gist_rainbow' # Background color for region visualization region_bg_color: [1, 1, 1] ================================================ FILE: config/hdtf128_1000ep.yaml ================================================ #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. #In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. # Dataset parameters # Each dataset should contain 2 folders train and test # Each video can be represented as: # - an image of concatenated frames # - '.mp4' or '.gif' # - folder with all frames from a specific video dataset_params: # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames. 
# Folder with frames is preferred format for training, since it is the fastest. root_dir: /yrfs2/cv2/pcxia/audiovisual/hdtf/images # Shape to resize all frames to, specify null if resizing is not needed frame_shape: 128 # In case of Vox or Taichi single video can be splitted in many chunks, or the maybe several videos for single person. # In this case epoch can be a pass over different identities (if id_sampling=True) or over different chunks (if id_sampling=False) # If the name the video '12335#adsbf.mp4' the id is assumed to be 12335 id_sampling: False # List with pairs for animation, null for random pairs pairs_list: null # Augmentation parameters see augmentation.py for all possible augmentations augmentation_params: flip_param: horizontal_flip: True time_flip: True jitter_param: brightness: 0.1 contrast: 0.1 saturation: 0.1 hue: 0.1 # Defines architecture of the model model_params: # Number of regions num_regions: 10 # Number of channels, for RGB image it is always 3 num_channels: 3 # Enable estimation of affine parameters for each region, # set to False if only region centers (keypoints) need to be estimated estimate_affine: True # Svd can perform random axis swap between source and driving if singular values are close to each other # Set to True to avoid axis swap between source and driving revert_axis_swap: True # Parameters of background prediction network based on simple Unet-like encoder. bg_predictor_params: # Number of features multiplier block_expansion: 32 # Maximum allowed number of features max_features: 1024 # Number of block in the Encoder. num_blocks: 5 # Type of background movement model, select one from ['zero', 'shift', 'affine', 'perspective'] bg_type: 'affine' # Parameters of the region prediction network based on Unet region_predictor_params: # Softmax temperature for heatmaps temperature: 0.1 # Number of features multiplier block_expansion: 32 # Maximum allowed number of features max_features: 1024 # Regions is predicted on smaller images for better performance, # scale_factor=0.25 means that 256x256 image will be resized to 64x64 scale_factor: 0.25 # Number of block in Unet. Can be increased or decreased depending or resolution. num_blocks: 5 # Either to use pca_based estimation of affine parameters of regression based pca_based: True # Either to use fast_svd (https://github.com/KinglittleQ/torch-batch-svd) or standard pytorch svd # Fast svd may produce not meaningful regions if used along with revert_axis_swap fast_svd: False # Parameters of Generator, based on Jonson architecture generator_params: # Number of features multiplier block_expansion: 64 # Maximum allowed number of features max_features: 512 # Number of down-sampling blocks in Jonson architecture. # Can be increased or decreased depending or resolution. num_down_blocks: 2 # Number of ResBlocks in Jonson architecture. num_bottleneck_blocks: 6 # To use skip connections or no. skips: True # Parameters of pixelwise flow predictor based on Unet pixelwise_flow_predictor_params: # Number of features multiplier block_expansion: 64 # Maximum allowed number of features max_features: 1024 # Number of block in Unet. Can be increased or decreased depending on resolution. 
num_blocks: 5 # Flow predictor operates on the smaller images for better performance, # scale_factor=0.25 means that 256x256 image will be resized to 64x64 scale_factor: 0.25 # Set to True in order to use deformed source images using sparse flow use_deformed_source: True # Set to False in order to render region heatmaps with fixed covariance # True for covariance estimate using region_predictor use_covar_heatmap: True # Set to False to disable occlusion mask estimation estimate_occlusion_map: True # Parameter for animation-via-disentanglement (avd) network avd_network_params: # Bottleneck for identity branch id_bottle_size: 64 # Bottleneck for pose branch pose_bottle_size: 64 # Parameters of training (reconstruction) train_params: max_epochs: 1000 # For better i/o performance when number of videos is small number of epochs can be multiplied by this number. # Thus effectively with num_repeats=100 each epoch is 100 times larger. num_repeats: 100 # Drop learning rate 10 times after this epochs epoch_milestones: [60, 90] # Initial learning rate lr: 4.0e-4 # Batch size. (14 is batch size for one V100 gpu). batch_size: 82 # Either to use sync_bn or not, enabling sync_bn will significantly slow the training time use_sync_bn: False # Dataset preprocessing cpu workers dataloader_workers: 8 print_freq: 10 save_img_freq: 100 # update checkpoint in this frequent update_ckpt_freq: 5000 # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256, # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32. scales: [1, 0.5, 0.25, 0.125] # Parameters of transform for equivariance loss transform_params: sigma_affine: 0.05 sigma_tps: 0.005 points_tps: 5 loss_weights: # Weights for perceptual pyramide loss. Note that here you can only specify weight across the layer, and # weights across the resolution will be the same. perceptual: [10, 10, 10, 10, 10] rec_vgg: [1, 1, 1, 1, 1] # Weights for equivariance loss. equivariance_shift: 10 equivariance_affine: 10 # Parameters of visualization visualizer_params: # Size of keypoints kp_size: 2 # Draw border between images or not draw_border: True # Colormap for regions and keypoints visualization colormap: 'gist_rainbow' # Background color for region visualization region_bg_color: [1, 1, 1] ================================================ FILE: config/hdtf128_1000ep_crema.yaml ================================================ #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. #In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. # Dataset parameters # Each dataset should contain 2 folders train and test # Each video can be represented as: # - an image of concatenated frames # - '.mp4' or '.gif' # - folder with all frames from a specific video dataset_params: # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames. 
# Folder with frames is preferred format for training, since it is the fastest. root_dir: /work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images # Shape to resize all frames to, specify null if resizing is not needed frame_shape: 128 # In case of Vox or Taichi single video can be splitted in many chunks, or the maybe several videos for single person. # In this case epoch can be a pass over different identities (if id_sampling=True) or over different chunks (if id_sampling=False) # If the name the video '12335#adsbf.mp4' the id is assumed to be 12335 id_sampling: False # List with pairs for animation, null for random pairs pairs_list: null # Augmentation parameters see augmentation.py for all possible augmentations augmentation_params: flip_param: horizontal_flip: True time_flip: True jitter_param: brightness: 0.1 contrast: 0.1 saturation: 0.1 hue: 0.1 # Defines architecture of the model model_params: # Number of regions num_regions: 10 # Number of channels, for RGB image it is always 3 num_channels: 3 # Enable estimation of affine parameters for each region, # set to False if only region centers (keypoints) need to be estimated estimate_affine: True # Svd can perform random axis swap between source and driving if singular values are close to each other # Set to True to avoid axis swap between source and driving revert_axis_swap: True # Parameters of background prediction network based on simple Unet-like encoder. bg_predictor_params: # Number of features multiplier block_expansion: 32 # Maximum allowed number of features max_features: 1024 # Number of block in the Encoder. num_blocks: 5 # Type of background movement model, select one from ['zero', 'shift', 'affine', 'perspective'] bg_type: 'affine' # Parameters of the region prediction network based on Unet region_predictor_params: # Softmax temperature for heatmaps temperature: 0.1 # Number of features multiplier block_expansion: 32 # Maximum allowed number of features max_features: 1024 # Regions is predicted on smaller images for better performance, # scale_factor=0.25 means that 256x256 image will be resized to 64x64 scale_factor: 0.25 # Number of block in Unet. Can be increased or decreased depending or resolution. num_blocks: 5 # Either to use pca_based estimation of affine parameters of regression based pca_based: True # Either to use fast_svd (https://github.com/KinglittleQ/torch-batch-svd) or standard pytorch svd # Fast svd may produce not meaningful regions if used along with revert_axis_swap fast_svd: False # Parameters of Generator, based on Jonson architecture generator_params: # Number of features multiplier block_expansion: 64 # Maximum allowed number of features max_features: 512 # Number of down-sampling blocks in Jonson architecture. # Can be increased or decreased depending or resolution. num_down_blocks: 2 # Number of ResBlocks in Jonson architecture. num_bottleneck_blocks: 6 # To use skip connections or no. skips: True # Parameters of pixelwise flow predictor based on Unet pixelwise_flow_predictor_params: # Number of features multiplier block_expansion: 64 # Maximum allowed number of features max_features: 1024 # Number of block in Unet. Can be increased or decreased depending on resolution. 
num_blocks: 5 # Flow predictor operates on the smaller images for better performance, # scale_factor=0.25 means that 256x256 image will be resized to 64x64 scale_factor: 0.25 # Set to True in order to use deformed source images using sparse flow use_deformed_source: True # Set to False in order to render region heatmaps with fixed covariance # True for covariance estimate using region_predictor use_covar_heatmap: True # Set to False to disable occlusion mask estimation estimate_occlusion_map: True # Parameter for animation-via-disentanglement (avd) network avd_network_params: # Bottleneck for identity branch id_bottle_size: 64 # Bottleneck for pose branch pose_bottle_size: 64 # Parameters of training (reconstruction) train_params: max_epochs: 600 # For better i/o performance when number of videos is small number of epochs can be multiplied by this number. # Thus effectively with num_repeats=100 each epoch is 100 times larger. num_repeats: 100 # Drop learning rate 10 times after this epochs epoch_milestones: [60, 90] # Initial learning rate lr: 4.0e-4 # Batch size. (14 is batch size for one V100 gpu). batch_size: 100 # Either to use sync_bn or not, enabling sync_bn will significantly slow the training time use_sync_bn: False # Dataset preprocessing cpu workers dataloader_workers: 8 print_freq: 10 save_img_freq: 100 # update checkpoint in this frequent update_ckpt_freq: 5000 # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256, # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32. scales: [1, 0.5, 0.25, 0.125] # Parameters of transform for equivariance loss transform_params: sigma_affine: 0.05 sigma_tps: 0.005 points_tps: 5 loss_weights: # Weights for perceptual pyramide loss. Note that here you can only specify weight across the layer, and # weights across the resolution will be the same. perceptual: [10, 10, 10, 10, 10] rec_vgg: [1, 1, 1, 1, 1] # Weights for equivariance loss. equivariance_shift: 10 equivariance_affine: 10 # Parameters of visualization visualizer_params: # Size of keypoints kp_size: 2 # Draw border between images or not draw_border: True # Colormap for regions and keypoints visualization colormap: 'gist_rainbow' # Background color for region visualization region_bg_color: [1, 1, 1] ================================================ FILE: config/hdtf256.yaml ================================================ #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. #In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. # Dataset parameters # Each dataset should contain 2 folders train and test # Each video can be represented as: # - an image of concatenated frames # - '.mp4' or '.gif' # - folder with all frames from a specific video dataset_params: # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames. 
# Folder with frames is preferred format for training, since it is the fastest. root_dir: /yrfs2/cv2/pcxia/audiovisual/hdtf/images # Shape to resize all frames to, specify null if resizing is not needed frame_shape: 256 # In case of Vox or Taichi single video can be splitted in many chunks, or the maybe several videos for single person. # In this case epoch can be a pass over different identities (if id_sampling=True) or over different chunks (if id_sampling=False) # If the name the video '12335#adsbf.mp4' the id is assumed to be 12335 id_sampling: False # List with pairs for animation, null for random pairs pairs_list: null # Augmentation parameters see augmentation.py for all possible augmentations augmentation_params: flip_param: horizontal_flip: True time_flip: True jitter_param: brightness: 0.1 contrast: 0.1 saturation: 0.1 hue: 0.1 # Defines architecture of the model model_params: # Number of regions num_regions: 10 # Number of channels, for RGB image it is always 3 num_channels: 3 # Enable estimation of affine parameters for each region, # set to False if only region centers (keypoints) need to be estimated estimate_affine: True # Svd can perform random axis swap between source and driving if singular values are close to each other # Set to True to avoid axis swap between source and driving revert_axis_swap: True # Parameters of background prediction network based on simple Unet-like encoder. bg_predictor_params: # Number of features multiplier block_expansion: 32 # Maximum allowed number of features max_features: 1024 # Number of block in the Encoder. num_blocks: 5 # Type of background movement model, select one from ['zero', 'shift', 'affine', 'perspective'] bg_type: 'affine' # Parameters of the region prediction network based on Unet region_predictor_params: # Softmax temperature for heatmaps temperature: 0.1 # Number of features multiplier block_expansion: 32 # Maximum allowed number of features max_features: 1024 # Regions is predicted on smaller images for better performance, # scale_factor=0.25 means that 256x256 image will be resized to 64x64 scale_factor: 0.25 # Number of block in Unet. Can be increased or decreased depending or resolution. num_blocks: 5 # Either to use pca_based estimation of affine parameters of regression based pca_based: True # Either to use fast_svd (https://github.com/KinglittleQ/torch-batch-svd) or standard pytorch svd # Fast svd may produce not meaningful regions if used along with revert_axis_swap fast_svd: False # Parameters of Generator, based on Jonson architecture generator_params: # Number of features multiplier block_expansion: 64 # Maximum allowed number of features max_features: 512 # Number of down-sampling blocks in Jonson architecture. # Can be increased or decreased depending or resolution. num_down_blocks: 2 # Number of ResBlocks in Jonson architecture. num_bottleneck_blocks: 6 # To use skip connections or no. skips: True # Parameters of pixelwise flow predictor based on Unet pixelwise_flow_predictor_params: # Number of features multiplier block_expansion: 64 # Maximum allowed number of features max_features: 1024 # Number of block in Unet. Can be increased or decreased depending or resolution. 
num_blocks: 5 # Flow predictor operates on the smaller images for better performance, # scale_factor=0.25 means that 256x256 image will be resized to 64x64 scale_factor: 0.25 # Set to True in order to use deformed source images using sparse flow use_deformed_source: True # Set to False in order to render region heatmaps with fixed covariance # True for covariance estimate using region_predictor use_covar_heatmap: True # Set to False to disable occlusion mask estimation estimate_occlusion_map: True # Parameter for animation-via-disentanglement (avd) network avd_network_params: # Bottleneck for identity branch id_bottle_size: 64 # Bottleneck for pose branch pose_bottle_size: 64 # Parameters of training (reconstruction) train_params: max_epochs: 100 # For better i/o performance when number of videos is small number of epochs can be multiplied by this number. # Thus effectively with num_repeats=100 each epoch is 100 times larger. num_repeats: 100 # Drop learning rate 10 times after this epochs epoch_milestones: [60, 90] # Initial learning rate lr: 2.0e-4 # Batch size. (14 is batch size for one V100 gpu). batch_size: 42 # Either to use sync_bn or not, enabling sync_bn will significantly slow the training time use_sync_bn: False # Dataset preprocessing cpu workers dataloader_workers: 12 print_freq: 10 save_img_freq: 100 # update checkpoint in this frequent update_ckpt_freq: 5000 # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256, # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32. scales: [1, 0.5, 0.25, 0.125] # Parameters of transform for equivariance loss transform_params: sigma_affine: 0.05 sigma_tps: 0.005 points_tps: 5 loss_weights: # Weights for perceptual pyramide loss. Note that here you can only specify weight across the layer, and # weights across the resolution will be the same. perceptual: [10, 10, 10, 10, 10] # Weights for equivariance loss. equivariance_shift: 10 equivariance_affine: 10 # Parameters of visualization visualizer_params: # Size of keypoints kp_size: 2 # Draw border between images or not draw_border: True # Colormap for regions and keypoints visualization colormap: 'gist_rainbow' # Background color for region visualization region_bg_color: [1, 1, 1] ================================================ FILE: config/hdtf256_400ep.yaml ================================================ #Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. #No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, #publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. #Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, #title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. #In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. # Dataset parameters # Each dataset should contain 2 folders train and test # Each video can be represented as: # - an image of concatenated frames # - '.mp4' or '.gif' # - folder with all frames from a specific video dataset_params: # Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames. 
# Folder with frames is preferred format for training, since it is the fastest. root_dir: /yrfs2/cv2/pcxia/audiovisual/hdtf/images # Shape to resize all frames to, specify null if resizing is not needed frame_shape: 256 # In case of Vox or Taichi single video can be splitted in many chunks, or the maybe several videos for single person. # In this case epoch can be a pass over different identities (if id_sampling=True) or over different chunks (if id_sampling=False) # If the name the video '12335#adsbf.mp4' the id is assumed to be 12335 id_sampling: False # List with pairs for animation, null for random pairs pairs_list: null # Augmentation parameters see augmentation.py for all possible augmentations augmentation_params: flip_param: horizontal_flip: True time_flip: True jitter_param: brightness: 0.1 contrast: 0.1 saturation: 0.1 hue: 0.1 # Defines architecture of the model model_params: # Number of regions num_regions: 10 # Number of channels, for RGB image it is always 3 num_channels: 3 # Enable estimation of affine parameters for each region, # set to False if only region centers (keypoints) need to be estimated estimate_affine: True # Svd can perform random axis swap between source and driving if singular values are close to each other # Set to True to avoid axis swap between source and driving revert_axis_swap: True # Parameters of background prediction network based on simple Unet-like encoder. bg_predictor_params: # Number of features multiplier block_expansion: 32 # Maximum allowed number of features max_features: 1024 # Number of block in the Encoder. num_blocks: 5 # Type of background movement model, select one from ['zero', 'shift', 'affine', 'perspective'] bg_type: 'affine' # Parameters of the region prediction network based on Unet region_predictor_params: # Softmax temperature for heatmaps temperature: 0.1 # Number of features multiplier block_expansion: 32 # Maximum allowed number of features max_features: 1024 # Regions is predicted on smaller images for better performance, # scale_factor=0.25 means that 256x256 image will be resized to 64x64 scale_factor: 0.25 # Number of block in Unet. Can be increased or decreased depending or resolution. num_blocks: 5 # Either to use pca_based estimation of affine parameters of regression based pca_based: True # Either to use fast_svd (https://github.com/KinglittleQ/torch-batch-svd) or standard pytorch svd # Fast svd may produce not meaningful regions if used along with revert_axis_swap fast_svd: False # Parameters of Generator, based on Jonson architecture generator_params: # Number of features multiplier block_expansion: 64 # Maximum allowed number of features max_features: 512 # Number of down-sampling blocks in Jonson architecture. # Can be increased or decreased depending or resolution. num_down_blocks: 2 # Number of ResBlocks in Jonson architecture. num_bottleneck_blocks: 6 # To use skip connections or no. skips: True # Parameters of pixelwise flow predictor based on Unet pixelwise_flow_predictor_params: # Number of features multiplier block_expansion: 64 # Maximum allowed number of features max_features: 1024 # Number of block in Unet. Can be increased or decreased depending or resolution. 
num_blocks: 5 # Flow predictor operates on the smaller images for better performance, # scale_factor=0.25 means that 256x256 image will be resized to 64x64 scale_factor: 0.25 # Set to True in order to use deformed source images using sparse flow use_deformed_source: True # Set to False in order to render region heatmaps with fixed covariance # True for covariance estimate using region_predictor use_covar_heatmap: True # Set to False to disable occlusion mask estimation estimate_occlusion_map: True # Parameter for animation-via-disentanglement (avd) network avd_network_params: # Bottleneck for identity branch id_bottle_size: 64 # Bottleneck for pose branch pose_bottle_size: 64 # Parameters of training (reconstruction) train_params: max_epochs: 400 # For better i/o performance when number of videos is small number of epochs can be multiplied by this number. # Thus effectively with num_repeats=100 each epoch is 100 times larger. num_repeats: 100 # Drop learning rate 10 times after this epochs epoch_milestones: [60, 90] # Initial learning rate lr: 6.0e-4 # Batch size. (14 is batch size for one V100 gpu). batch_size: 20 # Either to use sync_bn or not, enabling sync_bn will significantly slow the training time use_sync_bn: False # Dataset preprocessing cpu workers dataloader_workers: 12 print_freq: 10 save_img_freq: 100 # update checkpoint in this frequent update_ckpt_freq: 5000 # Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256, # than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32. scales: [1, 0.5, 0.25, 0.125] # Parameters of transform for equivariance loss transform_params: sigma_affine: 0.05 sigma_tps: 0.005 points_tps: 5 loss_weights: # Weights for perceptual pyramide loss. Note that here you can only specify weight across the layer, and # weights across the resolution will be the same. perceptual: [10, 10, 10, 10, 10] # Weights for equivariance loss. 
equivariance_shift: 10 equivariance_affine: 10 # Parameters of visualization visualizer_params: # Size of keypoints kp_size: 2 # Draw border between images or not draw_border: True # Colormap for regions and keypoints visualization colormap: 'gist_rainbow' # Background color for region visualization region_bg_color: [1, 1, 1] ================================================ FILE: extract_init_states/FaceBoxes/FaceBoxes.py ================================================ # coding: utf-8 import os.path as osp import torch import numpy as np import cv2 from .utils.prior_box import PriorBox from .utils.nms_wrapper import nms from .utils.box_utils import decode from .utils.timer import Timer from .utils.functions import check_keys, remove_prefix, load_model from .utils.config import cfg from .models.faceboxes import FaceBoxesNet import torch.backends.cudnn as cudnn # some global configs confidence_threshold = 0.05 top_k = 5000 keep_top_k = 750 nms_threshold = 0.3 vis_thres = 0.5 resize = 1 scale_flag = True HEIGHT, WIDTH = 720, 1080 make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn) pretrained_path = make_abs_path('weights/FaceBoxesProd.pth') def viz_bbox(img, dets, wfp='out.jpg'): # show for b in dets: if b[4] < vis_thres: continue text = "{:.4f}".format(b[4]) b = list(map(int, b)) cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2) cx = b[0] cy = b[1] + 12 cv2.putText(img, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255)) cv2.imwrite(wfp, img) print(f'Viz bbox to {wfp}') class FaceBoxes: def __init__(self, timer_flag=False): torch.set_grad_enabled(False) net = FaceBoxesNet(phase='test', size=None, num_classes=2) # initialize detector self.net = load_model(net, pretrained_path=pretrained_path, load_to_cpu=True) self.net.eval() # print('Finished loading model!') cudnn.benchmark = True self.net = self.net.cuda() self.timer_flag = timer_flag @torch.no_grad() def __call__(self, img_): img_raw = img_.copy() # scaling to speed up scale = 1 if scale_flag: h, w = img_raw.shape[:2] if h > HEIGHT: scale = HEIGHT / h if w * scale > WIDTH: scale *= WIDTH / (w * scale) # print(scale) if scale == 1: img_raw_scale = img_raw else: h_s = int(scale * h) w_s = int(scale * w) # print(h_s, w_s) img_raw_scale = cv2.resize(img_raw, dsize=(w_s, h_s)) # print(img_raw_scale.shape) img = np.float32(img_raw_scale) else: img = np.float32(img_raw) # forward _t = {'forward_pass': Timer(), 'misc': Timer()} im_height, im_width, _ = img.shape scale_bbox = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]).cuda() img -= (104, 117, 123) img = img.transpose(2, 0, 1) img = torch.from_numpy(img).cuda().unsqueeze(0) _t['forward_pass'].tic() loc, conf = self.net(img) # forward pass _t['forward_pass'].toc() _t['misc'].tic() priorbox = PriorBox(image_size=(im_height, im_width)) priors = priorbox.forward() prior_data = priors.data.cuda() boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance']) if scale_flag: boxes = boxes * scale_bbox / scale / resize else: boxes = boxes * scale_bbox / resize boxes = boxes.cpu().numpy() scores = conf.squeeze(0).data.cpu().numpy()[:, 1] # ignore low scores inds = np.where(scores > confidence_threshold)[0] boxes = boxes[inds] scores = scores[inds] # keep top-K before NMS order = scores.argsort()[::-1][:top_k] boxes = boxes[order] scores = scores[order] # do NMS dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) keep = nms(dets, nms_threshold) dets = dets[keep, :] # keep top-K faster NMS dets = 
dets[:keep_top_k, :] _t['misc'].toc() if self.timer_flag: print('Detection: {:d}/{:d} forward_pass_time: {:.4f}s misc: {:.4f}s'.format(1, 1, _t[ 'forward_pass'].average_time, _t['misc'].average_time)) # filter using vis_thres det_bboxes = [] for b in dets: if b[4] > vis_thres: xmin, ymin, xmax, ymax, score = b[0], b[1], b[2], b[3], b[4] bbox = [xmin, ymin, xmax, ymax, score] det_bboxes.append(bbox) return det_bboxes def main(): face_boxes = FaceBoxes(timer_flag=True) fn = 'trump_hillary.jpg' img_fp = f'../examples/inputs/{fn}' img = cv2.imread(img_fp) print(f'input shape: {img.shape}') dets = face_boxes(img) # xmin, ymin, w, h # print(dets) # repeating inference for `n` times n = 10 for i in range(n): dets = face_boxes(img) wfn = fn.replace('.jpg', '_det.jpg') wfp = osp.join('../examples/results', wfn) viz_bbox(img, dets, wfp) if __name__ == '__main__': main() ================================================ FILE: extract_init_states/FaceBoxes/FaceBoxes_ONNX.py ================================================ # coding: utf-8 import os.path as osp import torch import numpy as np import cv2 from .utils.prior_box import PriorBox from .utils.nms_wrapper import nms from .utils.box_utils import decode from .utils.timer import Timer from .utils.config import cfg from .onnx import convert_to_onnx import onnxruntime # some global configs confidence_threshold = 0.05 top_k = 5000 keep_top_k = 750 nms_threshold = 0.3 vis_thres = 0.2 resize = 1 scale_flag = True HEIGHT, WIDTH = 720, 1080 make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn) onnx_path = make_abs_path('weights/FaceBoxesProd.onnx') def viz_bbox(img, dets, wfp='out.jpg'): # show for b in dets: if b[4] < vis_thres: continue text = "{:.4f}".format(b[4]) b = list(map(int, b)) cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2) cx = b[0] cy = b[1] + 12 cv2.putText(img, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255)) cv2.imwrite(wfp, img) print(f'Viz bbox to {wfp}') class FaceBoxes_ONNX(object): def __init__(self, timer_flag=False): if not osp.exists(onnx_path): convert_to_onnx(onnx_path) self.session = onnxruntime.InferenceSession(onnx_path, providers=['CUDAExecutionProvider']) self.timer_flag = timer_flag def __call__(self, img_): img_raw = img_.copy() # scaling to speed up scale = 1 if scale_flag: h, w = img_raw.shape[:2] if h > HEIGHT: scale = HEIGHT / h if w * scale > WIDTH: scale *= WIDTH / (w * scale) # print(scale) if scale == 1: img_raw_scale = img_raw else: h_s = int(scale * h) w_s = int(scale * w) # print(h_s, w_s) img_raw_scale = cv2.resize(img_raw, dsize=(w_s, h_s)) # print(img_raw_scale.shape) img = np.float32(img_raw_scale) else: img = np.float32(img_raw) # forward _t = {'forward_pass': Timer(), 'misc': Timer()} im_height, im_width, _ = img.shape scale_bbox = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) img -= (104, 117, 123) img = img.transpose(2, 0, 1) # img = torch.from_numpy(img).unsqueeze(0) img = img[np.newaxis, ...] 
        _t['forward_pass'].tic()
        # loc, conf = self.net(img)  # forward pass
        out = self.session.run(None, {'input': img})
        loc, conf = out[0], out[1]

        # for compatibility, may need to optimize
        loc = torch.from_numpy(loc)

        _t['forward_pass'].toc()

        _t['misc'].tic()
        priorbox = PriorBox(image_size=(im_height, im_width))
        priors = priorbox.forward()
        prior_data = priors.data
        boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
        if scale_flag:
            boxes = boxes * scale_bbox / scale / resize
        else:
            boxes = boxes * scale_bbox / resize
        boxes = boxes.cpu().numpy()
        scores = conf[0][:, 1]
        # scores = conf.squeeze(0).data.cpu().numpy()[:, 1]

        # ignore low scores
        inds = np.where(scores > confidence_threshold)[0]
        boxes = boxes[inds]
        scores = scores[inds]

        # keep top-K before NMS
        order = scores.argsort()[::-1][:top_k]
        boxes = boxes[order]
        scores = scores[order]

        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
        keep = nms(dets, nms_threshold)
        dets = dets[keep, :]

        # keep top-K faster NMS
        dets = dets[:keep_top_k, :]
        _t['misc'].toc()

        if self.timer_flag:
            print('Detection: {:d}/{:d} forward_pass_time: {:.4f}s misc: {:.4f}s'.format(
                1, 1, _t['forward_pass'].average_time, _t['misc'].average_time))

        # filter using vis_thres
        det_bboxes = []
        for b in dets:
            if b[4] > vis_thres:
                xmin, ymin, xmax, ymax, score = b[0], b[1], b[2], b[3], b[4]
                bbox = [xmin, ymin, xmax, ymax, score]
                det_bboxes.append(bbox)

        return det_bboxes


def main():
    face_boxes = FaceBoxes_ONNX(timer_flag=True)

    fn = 'trump_hillary.jpg'
    img_fp = f'../examples/inputs/{fn}'
    img = cv2.imread(img_fp)
    print(f'input shape: {img.shape}')
    dets = face_boxes(img)  # xmin, ymin, w, h
    # print(dets)

    # repeating inference for `n` times
    n = 10
    for i in range(n):
        dets = face_boxes(img)

    wfn = fn.replace('.jpg', '_det.jpg')
    wfp = osp.join('../examples/results', wfn)
    viz_bbox(img, dets, wfp)


if __name__ == '__main__':
    main()

================================================
FILE: extract_init_states/FaceBoxes/__init__.py
================================================
from .FaceBoxes import FaceBoxes

================================================
FILE: extract_init_states/FaceBoxes/build_cpu_nms.sh
================================================
cd utils
python3 build.py build_ext --inplace
cd ..
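For reference, the detector defined in the two files above is normally driven the way its `main()` does: build the Cython NMS extension first (via `build_cpu_nms.sh`), then call the `FaceBoxes` class on a BGR image. The snippet below is a minimal sketch of that flow, not code from the repo; the image path `face.jpg` is a placeholder, and a CUDA-capable GPU is assumed because `FaceBoxes.__init__` moves the network to `.cuda()`.

```python
# Minimal usage sketch (assumptions: the CPU NMS extension was already built
# with `sh build_cpu_nms.sh`, a CUDA GPU is available, and 'face.jpg' is a
# placeholder path to any test image).
import cv2

from FaceBoxes import FaceBoxes  # re-exported in FaceBoxes/__init__.py

detector = FaceBoxes(timer_flag=False)

img = cv2.imread('face.jpg')   # BGR image, same as in main() above
dets = detector(img)           # list of [xmin, ymin, xmax, ymax, score]

for xmin, ymin, xmax, ymax, score in dets:
    print(f'face ({xmin:.1f}, {ymin:.1f})-({xmax:.1f}, {ymax:.1f}), score={score:.3f}')
```

The ONNX variant `FaceBoxes_ONNX` exposes the same `__call__` interface, so the two detectors can be swapped without changing the calling code.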
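Stepping back to the `config/hdtf*.yaml` files earlier in this dump: they are plain YAML, so the nested `dataset_params` / `model_params` / `train_params` sections can be inspected with PyYAML before launching LFG training. The sketch below is illustrative only; it assumes the script is run from the repository root so that the relative config path resolves.

```python
# Hypothetical sketch: peek at one of the LFG training configs shown above.
import yaml

with open('config/hdtf256.yaml') as f:
    config = yaml.safe_load(f)

print(config['dataset_params']['root_dir'])     # dataset location
print(config['dataset_params']['frame_shape'])  # 256 for hdtf256.yaml
print(config['train_params']['batch_size'])     # 42 in hdtf256.yaml
print(config['train_params']['lr'])             # 2.0e-4 in hdtf256.yaml
```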
================================================ FILE: extract_init_states/FaceBoxes/models/__init__.py ================================================ ================================================ FILE: extract_init_states/FaceBoxes/models/faceboxes.py ================================================ # coding: utf-8 import torch import torch.nn as nn import torch.nn.functional as F class BasicConv2d(nn.Module): def __init__(self, in_channels, out_channels, **kwargs): super(BasicConv2d, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) self.bn = nn.BatchNorm2d(out_channels, eps=1e-5) def forward(self, x): x = self.conv(x) x = self.bn(x) return F.relu(x, inplace=True) class Inception(nn.Module): def __init__(self): super(Inception, self).__init__() self.branch1x1 = BasicConv2d(128, 32, kernel_size=1, padding=0) self.branch1x1_2 = BasicConv2d(128, 32, kernel_size=1, padding=0) self.branch3x3_reduce = BasicConv2d(128, 24, kernel_size=1, padding=0) self.branch3x3 = BasicConv2d(24, 32, kernel_size=3, padding=1) self.branch3x3_reduce_2 = BasicConv2d(128, 24, kernel_size=1, padding=0) self.branch3x3_2 = BasicConv2d(24, 32, kernel_size=3, padding=1) self.branch3x3_3 = BasicConv2d(32, 32, kernel_size=3, padding=1) def forward(self, x): branch1x1 = self.branch1x1(x) branch1x1_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) branch1x1_2 = self.branch1x1_2(branch1x1_pool) branch3x3_reduce = self.branch3x3_reduce(x) branch3x3 = self.branch3x3(branch3x3_reduce) branch3x3_reduce_2 = self.branch3x3_reduce_2(x) branch3x3_2 = self.branch3x3_2(branch3x3_reduce_2) branch3x3_3 = self.branch3x3_3(branch3x3_2) outputs = [branch1x1, branch1x1_2, branch3x3, branch3x3_3] return torch.cat(outputs, 1) class CRelu(nn.Module): def __init__(self, in_channels, out_channels, **kwargs): super(CRelu, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) self.bn = nn.BatchNorm2d(out_channels, eps=1e-5) def forward(self, x): x = self.conv(x) x = self.bn(x) x = torch.cat([x, -x], 1) x = F.relu(x, inplace=True) return x class FaceBoxesNet(nn.Module): def __init__(self, phase, size, num_classes): super(FaceBoxesNet, self).__init__() self.phase = phase self.num_classes = num_classes self.size = size self.conv1 = CRelu(3, 24, kernel_size=7, stride=4, padding=3) self.conv2 = CRelu(48, 64, kernel_size=5, stride=2, padding=2) self.inception1 = Inception() self.inception2 = Inception() self.inception3 = Inception() self.conv3_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0) self.conv3_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1) self.conv4_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0) self.conv4_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1) self.loc, self.conf = self.multibox(self.num_classes) if self.phase == 'test': self.softmax = nn.Softmax(dim=-1) if self.phase == 'train': for m in self.modules(): if isinstance(m, nn.Conv2d): if m.bias is not None: nn.init.xavier_normal_(m.weight.data) m.bias.data.fill_(0.02) else: m.weight.data.normal_(0, 0.01) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() def multibox(self, num_classes): loc_layers = [] conf_layers = [] loc_layers += [nn.Conv2d(128, 21 * 4, kernel_size=3, padding=1)] conf_layers += [nn.Conv2d(128, 21 * num_classes, kernel_size=3, padding=1)] loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)] conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)] loc_layers += 
[nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)] conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)] return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers) def forward(self, x): detection_sources = list() loc = list() conf = list() x = self.conv1(x) x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) x = self.conv2(x) x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) x = self.inception1(x) x = self.inception2(x) x = self.inception3(x) detection_sources.append(x) x = self.conv3_1(x) x = self.conv3_2(x) detection_sources.append(x) x = self.conv4_1(x) x = self.conv4_2(x) detection_sources.append(x) for (x, l, c) in zip(detection_sources, self.loc, self.conf): loc.append(l(x).permute(0, 2, 3, 1).contiguous()) conf.append(c(x).permute(0, 2, 3, 1).contiguous()) loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) if self.phase == "test": output = (loc.view(loc.size(0), -1, 4), self.softmax(conf.view(conf.size(0), -1, self.num_classes))) else: output = (loc.view(loc.size(0), -1, 4), conf.view(conf.size(0), -1, self.num_classes)) return output ================================================ FILE: extract_init_states/FaceBoxes/onnx.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import torch from .models.faceboxes import FaceBoxesNet from .utils.functions import load_model def convert_to_onnx(onnx_path): pretrained_path = onnx_path.replace('.onnx', '.pth') # 1. load model torch.set_grad_enabled(False) net = FaceBoxesNet(phase='test', size=None, num_classes=2) # initialize detector net = load_model(net, pretrained_path=pretrained_path, load_to_cpu=True) net.eval() # 2. convert batch_size = 1 dummy_input = torch.randn(batch_size, 3, 720, 1080) # export with dynamic axes for various input sizes torch.onnx.export( net, (dummy_input,), onnx_path, input_names=['input'], output_names=['output'], dynamic_axes={ 'input': [0, 2, 3], 'output': [0] }, do_constant_folding=True ) print(f'Convert {pretrained_path} to {onnx_path} done.') ================================================ FILE: extract_init_states/FaceBoxes/readme.md ================================================ ## How to fun FaceBoxes ### Build the cpu version of NMS ```shell script cd utils python3 build.py build_ext --inplace ``` or just run ```shell script sh ./build_cpu_nms.sh ``` ### Run the demo of face detection ```shell script python3 FaceBoxes.py ``` ================================================ FILE: extract_init_states/FaceBoxes/utils/.gitignore ================================================ utils/build utils/nms/*.so utils/*.c build/ ================================================ FILE: extract_init_states/FaceBoxes/utils/__init__.py ================================================ ================================================ FILE: extract_init_states/FaceBoxes/utils/box_utils.py ================================================ # coding: utf-8 import torch import numpy as np def point_form(boxes): """ Convert prior_boxes to (xmin, ymin, xmax, ymax) representation for comparison to point form ground truth data. Args: boxes: (tensor) center-size default boxes from priorbox layers. Return: boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 
""" return torch.cat((boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin boxes[:, :2] + boxes[:, 2:] / 2), 1) # xmax, ymax def center_size(boxes): """ Convert prior_boxes to (cx, cy, w, h) representation for comparison to center-size form ground truth data. Args: boxes: (tensor) point_form boxes Return: boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. """ return torch.cat((boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy boxes[:, 2:] - boxes[:, :2], 1) # w, h def intersect(box_a, box_b): """ We resize both tensors to [A,B,2] without new malloc: [A,2] -> [A,1,2] -> [A,B,2] [B,2] -> [1,B,2] -> [A,B,2] Then we compute the area of intersect between box_a and box_b. Args: box_a: (tensor) bounding boxes, Shape: [A,4]. box_b: (tensor) bounding boxes, Shape: [B,4]. Return: (tensor) intersection area, Shape: [A,B]. """ A = box_a.size(0) B = box_b.size(0) max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2)) inter = torch.clamp((max_xy - min_xy), min=0) return inter[:, :, 0] * inter[:, :, 1] def jaccard(box_a, box_b): """Compute the jaccard overlap of two sets of boxes. The jaccard overlap is simply the intersection over union of two boxes. Here we operate on ground truth boxes and default boxes. E.g.: A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) Args: box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] Return: jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] """ inter = intersect(box_a, box_b) area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] union = area_a + area_b - inter return inter / union # [A,B] def matrix_iou(a, b): """ return iou of a and b, numpy version for data augenmentation """ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) return area_i / (area_a[:, np.newaxis] + area_b - area_i) def matrix_iof(a, b): """ return iof of a and b, numpy version for data augenmentation """ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) return area_i / np.maximum(area_a[:, np.newaxis], 1) def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx): """Match each prior box with the ground truth box of the highest jaccard overlap, encode the bounding boxes, then return the matched indices corresponding to both confidence and location preds. Args: threshold: (float) The overlap threshold used when mathing boxes. truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors]. priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. variances: (tensor) Variances corresponding to each prior coord, Shape: [num_priors, 4]. labels: (tensor) All the class labels for the image, Shape: [num_obj]. loc_t: (tensor) Tensor to be filled w/ endcoded location targets. conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. 
idx: (int) current batch index Return: The matched indices corresponding to 1)location and 2)confidence preds. """ # jaccard index overlaps = jaccard( truths, point_form(priors) ) # (Bipartite Matching) # [1,num_objects] best prior for each ground truth best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) # ignore hard gt valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] if best_prior_idx_filter.shape[0] <= 0: loc_t[idx] = 0 conf_t[idx] = 0 return # [1,num_priors] best ground truth for each prior best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) best_truth_idx.squeeze_(0) best_truth_overlap.squeeze_(0) best_prior_idx.squeeze_(1) best_prior_idx_filter.squeeze_(1) best_prior_overlap.squeeze_(1) best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior # TODO refactor: index best_prior_idx with long tensor # ensure every gt matches with its prior of max overlap for j in range(best_prior_idx.size(0)): best_truth_idx[best_prior_idx[j]] = j matches = truths[best_truth_idx] # Shape: [num_priors,4] conf = labels[best_truth_idx] # Shape: [num_priors] conf[best_truth_overlap < threshold] = 0 # label as background loc = encode(matches, priors, variances) loc_t[idx] = loc # [num_priors,4] encoded offsets to learn conf_t[idx] = conf # [num_priors] top class label for each prior def encode(matched, priors, variances): """Encode the variances from the priorbox layers into the ground truth boxes we have matched (based on jaccard overlap) with the prior boxes. Args: matched: (tensor) Coords of ground truth for each prior in point-form Shape: [num_priors, 4]. priors: (tensor) Prior boxes in center-offset form Shape: [num_priors,4]. variances: (list[float]) Variances of priorboxes Return: encoded boxes (tensor), Shape: [num_priors, 4] """ # dist b/t match center and prior's center g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] # encode variance g_cxcy /= (variances[0] * priors[:, 2:]) # match wh / prior wh g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] g_wh = torch.log(g_wh) / variances[1] # return target for smooth_l1_loss return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] # Adapted from https://github.com/Hakuyume/chainer-ssd def decode(loc, priors, variances): """Decode locations from predictions using priors to undo the encoding we did for offset regression at train time. Args: loc (tensor): location predictions for loc layers, Shape: [num_priors,4] priors (tensor): Prior boxes in center-offset form. Shape: [num_priors,4]. variances: (list[float]) Variances of priorboxes Return: decoded bounding box predictions """ boxes = torch.cat(( priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) boxes[:, :2] -= boxes[:, 2:] / 2 boxes[:, 2:] += boxes[:, :2] return boxes def log_sum_exp(x): """Utility function for computing log_sum_exp while determining This will be used to determine unaveraged confidence loss across all examples in a batch. Args: x (Variable(tensor)): conf_preds from conf layers """ x_max = x.data.max() return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max # Original author: Francisco Massa: # https://github.com/fmassa/object-detection.torch # Ported to PyTorch by Max deGroot (02/01/2017) def nms(boxes, scores, overlap=0.5, top_k=200): """Apply non-maximum suppression at test time to avoid detecting too many overlapping bounding boxes for a given object. 
Args: boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. scores: (tensor) The class predscores for the img, Shape:[num_priors]. overlap: (float) The overlap thresh for suppressing unnecessary boxes. top_k: (int) The Maximum number of box preds to consider. Return: The indices of the kept boxes with respect to num_priors. """ keep = torch.Tensor(scores.size(0)).fill_(0).long() if boxes.numel() == 0: return keep x1 = boxes[:, 0] y1 = boxes[:, 1] x2 = boxes[:, 2] y2 = boxes[:, 3] area = torch.mul(x2 - x1, y2 - y1) v, idx = scores.sort(0) # sort in ascending order # I = I[v >= 0.01] idx = idx[-top_k:] # indices of the top-k largest vals xx1 = boxes.new() yy1 = boxes.new() xx2 = boxes.new() yy2 = boxes.new() w = boxes.new() h = boxes.new() # keep = torch.Tensor() count = 0 while idx.numel() > 0: i = idx[-1] # index of current largest val # keep.append(i) keep[count] = i count += 1 if idx.size(0) == 1: break idx = idx[:-1] # remove kept element from view # load bboxes of next highest vals torch.index_select(x1, 0, idx, out=xx1) torch.index_select(y1, 0, idx, out=yy1) torch.index_select(x2, 0, idx, out=xx2) torch.index_select(y2, 0, idx, out=yy2) # store element-wise max with next highest score xx1 = torch.clamp(xx1, min=x1[i]) yy1 = torch.clamp(yy1, min=y1[i]) xx2 = torch.clamp(xx2, max=x2[i]) yy2 = torch.clamp(yy2, max=y2[i]) w.resize_as_(xx2) h.resize_as_(yy2) w = xx2 - xx1 h = yy2 - yy1 # check sizes of xx1 and xx2.. after each iteration w = torch.clamp(w, min=0.0) h = torch.clamp(h, min=0.0) inter = w * h # IoU = i / (area(a) + area(b) - i) rem_areas = torch.index_select(area, 0, idx) # load remaining areas) union = (rem_areas - inter) + area[i] IoU = inter / union # store result in iou # keep only elements with an IoU <= overlap idx = idx[IoU.le(overlap)] return keep, count ================================================ FILE: extract_init_states/FaceBoxes/utils/build.py ================================================ # coding: utf-8 # -------------------------------------------------------- # Fast R-CNN # Copyright (c) 2015 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ross Girshick # -------------------------------------------------------- import os from os.path import join as pjoin import numpy as np from distutils.core import setup from distutils.extension import Extension from Cython.Distutils import build_ext def find_in_path(name, path): "Find a file in a search path" # adapted fom http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/ for dir in path.split(os.pathsep): binpath = pjoin(dir, name) if os.path.exists(binpath): return os.path.abspath(binpath) return None # Obtain the numpy include directory. This logic works across numpy versions. try: numpy_include = np.get_include() except AttributeError: numpy_include = np.get_numpy_include() # run the customize_compiler class custom_build_ext(build_ext): def build_extensions(self): # customize_compiler_for_nvcc(self.compiler) build_ext.build_extensions(self) ext_modules = [ Extension( "nms.cpu_nms", ["nms/cpu_nms.pyx"], # extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]}, # extra_compile_args=["-Wno-cpp", "-Wno-unused-function"], # !!! 
if you are on windows platform, you need to comment this line include_dirs=[numpy_include] ) ] setup( name='mot_utils', ext_modules=ext_modules, # inject our custom trigger cmdclass={'build_ext': custom_build_ext}, ) ================================================ FILE: extract_init_states/FaceBoxes/utils/config.py ================================================ # coding: utf-8 cfg = { 'name': 'FaceBoxes', 'min_sizes': [[32, 64, 128], [256], [512]], 'steps': [32, 64, 128], 'variance': [0.1, 0.2], 'clip': False } ================================================ FILE: extract_init_states/FaceBoxes/utils/functions.py ================================================ # coding: utf-8 import sys import os.path as osp import torch def check_keys(model, pretrained_state_dict): ckpt_keys = set(pretrained_state_dict.keys()) model_keys = set(model.state_dict().keys()) used_pretrained_keys = model_keys & ckpt_keys unused_pretrained_keys = ckpt_keys - model_keys missing_keys = model_keys - ckpt_keys # print('Missing keys:{}'.format(len(missing_keys))) # print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys))) # print('Used keys:{}'.format(len(used_pretrained_keys))) assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' return True def remove_prefix(state_dict, prefix): ''' Old style model is stored with all names of parameters sharing common prefix 'module.' ''' # print('remove prefix \'{}\''.format(prefix)) f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x return {f(key): value for key, value in state_dict.items()} def load_model(model, pretrained_path, load_to_cpu): if not osp.isfile(pretrained_path): print(f'The pre-trained FaceBoxes model {pretrained_path} does not exist') sys.exit('-1') # print('Loading pretrained model from {}'.format(pretrained_path)) if load_to_cpu: pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage) else: device = torch.cuda.current_device() pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device)) if "state_dict" in pretrained_dict.keys(): pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.') else: pretrained_dict = remove_prefix(pretrained_dict, 'module.') check_keys(model, pretrained_dict) model.load_state_dict(pretrained_dict, strict=False) return model ================================================ FILE: extract_init_states/FaceBoxes/utils/nms/.gitignore ================================================ *.c *.so ================================================ FILE: extract_init_states/FaceBoxes/utils/nms/__init__.py ================================================ ================================================ FILE: extract_init_states/FaceBoxes/utils/nms/cpu_nms.pyx ================================================ # -------------------------------------------------------- # Fast R-CNN # Copyright (c) 2015 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ross Girshick # -------------------------------------------------------- import numpy as np cimport numpy as np cdef inline np.float32_t max(np.float32_t a, np.float32_t b): return a if a >= b else b cdef inline np.float32_t min(np.float32_t a, np.float32_t b): return a if a <= b else b def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh): cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0] cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1] cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2] cdef 
np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3] cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4] cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) cdef np.ndarray[np.int64_t, ndim=1] order = scores.argsort()[::-1] cdef int ndets = dets.shape[0] cdef np.ndarray[np.int64_t, ndim=1] suppressed = \ np.zeros((ndets), dtype=np.int64) # nominal indices cdef int _i, _j # sorted indices cdef int i, j # temp variables for box i's (the box currently under consideration) cdef np.float32_t ix1, iy1, ix2, iy2, iarea # variables for computing overlap with box j (lower scoring box) cdef np.float32_t xx1, yy1, xx2, yy2 cdef np.float32_t w, h cdef np.float32_t inter, ovr keep = [] for _i in range(ndets): i = order[_i] if suppressed[i] == 1: continue keep.append(i) ix1 = x1[i] iy1 = y1[i] ix2 = x2[i] iy2 = y2[i] iarea = areas[i] for _j in range(_i + 1, ndets): j = order[_j] if suppressed[j] == 1: continue xx1 = max(ix1, x1[j]) yy1 = max(iy1, y1[j]) xx2 = min(ix2, x2[j]) yy2 = min(iy2, y2[j]) w = max(0.0, xx2 - xx1 + 1) h = max(0.0, yy2 - yy1 + 1) inter = w * h ovr = inter / (iarea + areas[j] - inter) if ovr >= thresh: suppressed[j] = 1 return keep def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0): cdef unsigned int N = boxes.shape[0] cdef float iw, ih, box_area cdef float ua cdef int pos = 0 cdef float maxscore = 0 cdef int maxpos = 0 cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov for i in range(N): maxscore = boxes[i, 4] maxpos = i tx1 = boxes[i,0] ty1 = boxes[i,1] tx2 = boxes[i,2] ty2 = boxes[i,3] ts = boxes[i,4] pos = i + 1 # get max box while pos < N: if maxscore < boxes[pos, 4]: maxscore = boxes[pos, 4] maxpos = pos pos = pos + 1 # add max box as a detection boxes[i,0] = boxes[maxpos,0] boxes[i,1] = boxes[maxpos,1] boxes[i,2] = boxes[maxpos,2] boxes[i,3] = boxes[maxpos,3] boxes[i,4] = boxes[maxpos,4] # swap ith box with position of max box boxes[maxpos,0] = tx1 boxes[maxpos,1] = ty1 boxes[maxpos,2] = tx2 boxes[maxpos,3] = ty2 boxes[maxpos,4] = ts tx1 = boxes[i,0] ty1 = boxes[i,1] tx2 = boxes[i,2] ty2 = boxes[i,3] ts = boxes[i,4] pos = i + 1 # NMS iterations, note that N changes if detection boxes fall below threshold while pos < N: x1 = boxes[pos, 0] y1 = boxes[pos, 1] x2 = boxes[pos, 2] y2 = boxes[pos, 3] s = boxes[pos, 4] area = (x2 - x1 + 1) * (y2 - y1 + 1) iw = (min(tx2, x2) - max(tx1, x1) + 1) if iw > 0: ih = (min(ty2, y2) - max(ty1, y1) + 1) if ih > 0: ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) ov = iw * ih / ua #iou between max box and detection box if method == 1: # linear if ov > Nt: weight = 1 - ov else: weight = 1 elif method == 2: # gaussian weight = np.exp(-(ov * ov)/sigma) else: # original NMS if ov > Nt: weight = 0 else: weight = 1 boxes[pos, 4] = weight*boxes[pos, 4] # if box score falls below threshold, discard the box by swapping with last box # update N if boxes[pos, 4] < threshold: boxes[pos,0] = boxes[N-1, 0] boxes[pos,1] = boxes[N-1, 1] boxes[pos,2] = boxes[N-1, 2] boxes[pos,3] = boxes[N-1, 3] boxes[pos,4] = boxes[N-1, 4] N = N - 1 pos = pos - 1 pos = pos + 1 keep = [i for i in range(N)] return keep ================================================ FILE: extract_init_states/FaceBoxes/utils/nms/py_cpu_nms.py ================================================ # -------------------------------------------------------- # Fast R-CNN # Copyright (c) 2015 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ross Girshick # 
-------------------------------------------------------- import numpy as np def py_cpu_nms(dets, thresh): """Pure Python NMS baseline.""" x1 = dets[:, 0] y1 = dets[:, 1] x2 = dets[:, 2] y2 = dets[:, 3] scores = dets[:, 4] areas = (x2 - x1 + 1) * (y2 - y1 + 1) order = scores.argsort()[::-1] keep = [] while order.size > 0: i = order[0] keep.append(i) xx1 = np.maximum(x1[i], x1[order[1:]]) yy1 = np.maximum(y1[i], y1[order[1:]]) xx2 = np.minimum(x2[i], x2[order[1:]]) yy2 = np.minimum(y2[i], y2[order[1:]]) w = np.maximum(0.0, xx2 - xx1 + 1) h = np.maximum(0.0, yy2 - yy1 + 1) inter = w * h ovr = inter / (areas[i] + areas[order[1:]] - inter) inds = np.where(ovr <= thresh)[0] order = order[inds + 1] return keep ================================================ FILE: extract_init_states/FaceBoxes/utils/nms_wrapper.py ================================================ # coding: utf-8 # -------------------------------------------------------- # Fast R-CNN # Copyright (c) 2015 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ross Girshick # -------------------------------------------------------- from .nms.cpu_nms import cpu_nms, cpu_soft_nms def nms(dets, thresh): """Dispatch to either CPU or GPU NMS implementations.""" if dets.shape[0] == 0: return [] return cpu_nms(dets, thresh) # return gpu_nms(dets, thresh) ================================================ FILE: extract_init_states/FaceBoxes/utils/prior_box.py ================================================ # coding: utf-8 from .config import cfg import torch from itertools import product as product from math import ceil class PriorBox(object): def __init__(self, image_size=None): super(PriorBox, self).__init__() # self.aspect_ratios = cfg['aspect_ratios'] self.min_sizes = cfg['min_sizes'] self.steps = cfg['steps'] self.clip = cfg['clip'] self.image_size = image_size self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps] def forward(self): anchors = [] for k, f in enumerate(self.feature_maps): min_sizes = self.min_sizes[k] for i, j in product(range(f[0]), range(f[1])): for min_size in min_sizes: s_kx = min_size / self.image_size[1] s_ky = min_size / self.image_size[0] if min_size == 32: dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0, j + 0.25, j + 0.5, j + 0.75]] dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0, i + 0.25, i + 0.5, i + 0.75]] for cy, cx in product(dense_cy, dense_cx): anchors += [cx, cy, s_kx, s_ky] elif min_size == 64: dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0, j + 0.5]] dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0, i + 0.5]] for cy, cx in product(dense_cy, dense_cx): anchors += [cx, cy, s_kx, s_ky] else: cx = (j + 0.5) * self.steps[k] / self.image_size[1] cy = (i + 0.5) * self.steps[k] / self.image_size[0] anchors += [cx, cy, s_kx, s_ky] # back to torch land output = torch.Tensor(anchors).view(-1, 4) if self.clip: output.clamp_(max=1, min=0) return output ================================================ FILE: extract_init_states/FaceBoxes/utils/timer.py ================================================ # coding: utf-8 # -------------------------------------------------------- # Fast R-CNN # Copyright (c) 2015 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ross Girshick # -------------------------------------------------------- import time class Timer(object): """A simple timer.""" def __init__(self): self.total_time = 0. 
self.calls = 0 self.start_time = 0. self.diff = 0. self.average_time = 0. def tic(self): # using time.time instead of time.clock because time time.clock # does not normalize for multithreading self.start_time = time.time() def toc(self, average=True): self.diff = time.time() - self.start_time self.total_time += self.diff self.calls += 1 self.average_time = self.total_time / self.calls if average: return self.average_time else: return self.diff def clear(self): self.total_time = 0. self.calls = 0 self.start_time = 0. self.diff = 0. self.average_time = 0. ================================================ FILE: extract_init_states/FaceBoxes/weights/.gitignore ================================================ *.onnx ================================================ FILE: extract_init_states/FaceBoxes/weights/readme.md ================================================ The pre-trained model `FaceBoxesProd.pth` is downloaded from [Google Drive](https://drive.google.com/file/d/1tRVwOlu0QtjvADQ2H7vqrRwsWEmaqioI). The converted `FaceBoxesProd.onnx`: [Google Drive](https://drive.google.com/file/d/1pccQOvYqKh3iCEHc5tSWx2-1fhgxs6rh/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1TJS2wFRLSoWZPR4l9E7G7w) (Password: 9hph) ================================================ FILE: extract_init_states/TDDFA_ONNX.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import os import sys current_dir = os.path.dirname(os.path.abspath(__file__)) if current_dir not in sys.path: sys.path.append(current_dir) print(current_dir) import os.path as osp import numpy as np import cv2 import onnxruntime from utils.onnx import convert_to_onnx from utils.io import _load from utils.functions import ( crop_img, parse_roi_box_from_bbox, parse_roi_box_from_landmark, ) from utils.tddfa_util import _parse_param, similar_transform from bfm.bfm import BFMModel from bfm.bfm_onnx import convert_bfm_to_onnx make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn) class TDDFA_ONNX(object): """TDDFA_ONNX: the ONNX version of Three-D Dense Face Alignment (TDDFA)""" def __init__(self, **kvs): # torch.set_grad_enabled(False) # load onnx version of BFM bfm_fp = make_abs_path(kvs.get('bfm_fp', 'configs/bfm_noneck_v3.pkl')) bfm_onnx_fp = bfm_fp.replace('.pkl', '.onnx') if not osp.exists(bfm_onnx_fp): convert_bfm_to_onnx( bfm_onnx_fp, shape_dim=kvs.get('shape_dim', 40), exp_dim=kvs.get('exp_dim', 10) ) self.bfm_session = onnxruntime.InferenceSession(bfm_onnx_fp, providers=['CUDAExecutionProvider']) # load for optimization bfm = BFMModel(bfm_fp, shape_dim=kvs.get('shape_dim', 40), exp_dim=kvs.get('exp_dim', 10)) self.tri = bfm.tri self.u_base, self.w_shp_base, self.w_exp_base = bfm.u_base, bfm.w_shp_base, bfm.w_exp_base # config self.gpu_mode = kvs.get('gpu_mode', True) self.gpu_id = kvs.get('gpu_id', 0) self.size = kvs.get('size', 120) param_mean_std_fp = make_abs_path(kvs.get( 'param_mean_std_fp', f'configs/param_mean_std_62d_{self.size}x{self.size}.pkl') ) onnx_fp = make_abs_path(kvs.get('onnx_fp', kvs.get('checkpoint_fp').replace('.pth', '.onnx'))) # convert to onnx online if not existed if onnx_fp is None or not osp.exists(onnx_fp): print(f'{onnx_fp} does not exist, try to convert the `.pth` version to `.onnx` online') onnx_fp = convert_to_onnx(**kvs) self.session = onnxruntime.InferenceSession(onnx_fp, providers=['CUDAExecutionProvider']) # params normalization config r = _load(param_mean_std_fp) self.param_mean = r.get('mean') self.param_std = r.get('std') def __call__(self, 
img_ori, objs, **kvs): # Crop image, forward to get the param param_lst = [] roi_box_lst = [] crop_policy = kvs.get('crop_policy', 'box') for obj in objs: if crop_policy == 'box': # by face box roi_box = parse_roi_box_from_bbox(obj) elif crop_policy == 'landmark': # by landmarks roi_box = parse_roi_box_from_landmark(obj) else: raise ValueError(f'Unknown crop policy {crop_policy}') roi_box_lst.append(roi_box) img = crop_img(img_ori, roi_box) img = cv2.resize(img, dsize=(self.size, self.size), interpolation=cv2.INTER_LINEAR) img = img.astype(np.float32).transpose(2, 0, 1)[np.newaxis, ...] img = (img - 127.5) / 128. inp_dct = {'input': img} param = self.session.run(None, inp_dct)[0] param = param.flatten().astype(np.float32) param = param * self.param_std + self.param_mean # re-scale param_lst.append(param) return param_lst, roi_box_lst def recon_vers(self, param_lst, roi_box_lst, **kvs): dense_flag = kvs.get('dense_flag', False) size = self.size ver_lst = [] for param, roi_box in zip(param_lst, roi_box_lst): R, offset, alpha_shp, alpha_exp = _parse_param(param) if dense_flag: inp_dct = { 'R': R, 'offset': offset, 'alpha_shp': alpha_shp, 'alpha_exp': alpha_exp } pts3d = self.bfm_session.run(None, inp_dct)[0] pts3d = similar_transform(pts3d, roi_box, size) else: pts3d = R @ (self.u_base + self.w_shp_base @ alpha_shp + self.w_exp_base @ alpha_exp). \ reshape(3, -1, order='F') + offset pts3d = similar_transform(pts3d, roi_box, size) ver_lst.append(pts3d) return ver_lst ================================================ FILE: extract_init_states/bfm/.gitignore ================================================ *.ply ================================================ FILE: extract_init_states/bfm/__init__.py ================================================ from .bfm import BFMModel ================================================ FILE: extract_init_states/bfm/bfm.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import sys sys.path.append('..') import os.path as osp import numpy as np from utils.io import _load make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn) def _to_ctype(arr): if not arr.flags.c_contiguous: return arr.copy(order='C') return arr class BFMModel(object): def __init__(self, bfm_fp, shape_dim=40, exp_dim=10): bfm = _load(bfm_fp) self.u = bfm.get('u').astype(np.float32) # fix bug self.w_shp = bfm.get('w_shp').astype(np.float32)[..., :shape_dim] self.w_exp = bfm.get('w_exp').astype(np.float32)[..., :exp_dim] if osp.split(bfm_fp)[-1] == 'bfm_noneck_v3.pkl': self.tri = _load(make_abs_path('../configs/tri.pkl')) # this tri/face is re-built for bfm_noneck_v3 else: self.tri = bfm.get('tri') self.tri = _to_ctype(self.tri.T).astype(np.int32) self.keypoints = bfm.get('keypoints').astype(np.int64) # fix bug w = np.concatenate((self.w_shp, self.w_exp), axis=1) self.w_norm = np.linalg.norm(w, axis=0) self.u_base = self.u[self.keypoints].reshape(-1, 1) self.w_shp_base = self.w_shp[self.keypoints] self.w_exp_base = self.w_exp[self.keypoints] ================================================ FILE: extract_init_states/bfm/bfm_onnx.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import sys sys.path.append('..') import os.path as osp import numpy as np import torch import torch.nn as nn from utils.io import _load, _numpy_to_cuda, _numpy_to_tensor make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn) def _to_ctype(arr): if not arr.flags.c_contiguous: return 
arr.copy(order='C') return arr def _load_tri(bfm_fp): if osp.split(bfm_fp)[-1] == 'bfm_noneck_v3.pkl': tri = _load(make_abs_path('../configs/tri.pkl')) # this tri/face is re-built for bfm_noneck_v3 else: tri = _load(bfm_fp).get('tri') tri = _to_ctype(tri.T).astype(np.int32) return tri class BFMModel_ONNX(nn.Module): """BFM serves as a decoder""" def __init__(self, bfm_fp, shape_dim=40, exp_dim=10): super(BFMModel_ONNX, self).__init__() _to_tensor = _numpy_to_tensor # load bfm bfm = _load(bfm_fp) u = _to_tensor(bfm.get('u').astype(np.float32)) self.u = u.view(-1, 3).transpose(1, 0) w_shp = _to_tensor(bfm.get('w_shp').astype(np.float32)[..., :shape_dim]) w_exp = _to_tensor(bfm.get('w_exp').astype(np.float32)[..., :exp_dim]) w = torch.cat((w_shp, w_exp), dim=1) self.w = w.view(-1, 3, w.shape[-1]).contiguous().permute(1, 0, 2) # self.u = _to_tensor(bfm.get('u').astype(np.float32)) # fix bug # w_shp = _to_tensor(bfm.get('w_shp').astype(np.float32)[..., :shape_dim]) # w_exp = _to_tensor(bfm.get('w_exp').astype(np.float32)[..., :exp_dim]) # self.w = torch.cat((w_shp, w_exp), dim=1) # self.keypoints = bfm.get('keypoints').astype(np.long) # fix bug # self.u_base = self.u[self.keypoints].reshape(-1, 1) # self.w_shp_base = self.w_shp[self.keypoints] # self.w_exp_base = self.w_exp[self.keypoints] def forward(self, *inps): R, offset, alpha_shp, alpha_exp = inps alpha = torch.cat((alpha_shp, alpha_exp)) # pts3d = R @ (self.u + self.w_shp.matmul(alpha_shp) + self.w_exp.matmul(alpha_exp)). \ # view(-1, 3).transpose(1, 0) + offset # pts3d = R @ (self.u + self.w.matmul(alpha)).view(-1, 3).transpose(1, 0) + offset pts3d = R @ (self.u + self.w.matmul(alpha).squeeze()) + offset return pts3d def convert_bfm_to_onnx(bfm_onnx_fp, shape_dim=40, exp_dim=10): # print(shape_dim, exp_dim) bfm_fp = bfm_onnx_fp.replace('.onnx', '.pkl') bfm_decoder = BFMModel_ONNX(bfm_fp=bfm_fp, shape_dim=shape_dim, exp_dim=exp_dim) bfm_decoder.eval() # dummy_input = torch.randn(12 + shape_dim + exp_dim) dummy_input = torch.randn(3, 3), torch.randn(3, 1), torch.randn(shape_dim, 1), torch.randn(exp_dim, 1) R, offset, alpha_shp, alpha_exp = dummy_input torch.onnx.export( bfm_decoder, (R, offset, alpha_shp, alpha_exp), bfm_onnx_fp, input_names=['R', 'offset', 'alpha_shp', 'alpha_exp'], output_names=['output'], dynamic_axes={ 'alpha_shp': [0], 'alpha_exp': [0], }, do_constant_folding=True ) print(f'Convert {bfm_fp} to {bfm_onnx_fp} done.') if __name__ == '__main__': convert_bfm_to_onnx('../configs/bfm_noneck_v3.onnx') ================================================ FILE: extract_init_states/bfm/readme.md ================================================ ## Statement The modified BFM2009 face model in `../configs/bfm_noneck_v3.pkl` is only for academic use. For commercial use, you need to apply for the commercial license, some refs are below: [1] https://faces.dmi.unibas.ch/bfm/?nav=1-0&id=basel_face_model [2] https://faces.dmi.unibas.ch/bfm/bfm2019.html If your work benefits from this repo, please cite @PROCEEDINGS{bfm09, title={A 3D Face Model for Pose and Illumination Invariant Face Recognition}, author={P. Paysan and R. Knothe and B. Amberg and S. Romdhani and T. 
Vetter}, journal={Proceedings of the 6th IEEE International Conference on Advanced Video and Signal based Surveillance (AVSS) for Security, Safety and Monitoring in Smart Environments}, organization={IEEE}, year={2009}, address = {Genova, Italy}, } ================================================ FILE: extract_init_states/build.sh ================================================ cd FaceBoxes sh ./build_cpu_nms.sh cd .. # cd Sim3DR # sh ./build_sim3dr.sh # cd .. cd utils/asset gcc -shared -Wall -O3 render.c -o render.so -fPIC cd ../.. ================================================ FILE: extract_init_states/configs/.gitignore ================================================ # *.pkl # *.yml # *.onnx ================================================ FILE: extract_init_states/configs/bfm_noneck_v3.onnx ================================================ [File too large to display: 22.4 MB] ================================================ FILE: extract_init_states/configs/bfm_noneck_v3.pkl ================================================ [File too large to display: 23.3 MB] ================================================ FILE: extract_init_states/configs/mb05_120x120.yml ================================================ arch: mobilenet # MobileNet V1 widen_factor: 0.5 checkpoint_fp: weights/mb05_120x120.pth bfm_fp: configs/bfm_noneck_v3.pkl # or configs/bfm_noneck_v3_slim.pkl size: 120 num_params: 62 ================================================ FILE: extract_init_states/configs/mb1_120x120.yml ================================================ arch: mobilenet # MobileNet V1 widen_factor: 1.0 checkpoint_fp: weights/mb1_120x120.pth bfm_fp: configs/bfm_noneck_v3.pkl # or configs/bfm_noneck_v3_slim.pkl size: 120 num_params: 62 ================================================ FILE: extract_init_states/configs/readme.md ================================================ ## The simplified version of BFM `bfm_noneck_v3_slim.pkl`: [Google Drive](https://drive.google.com/file/d/1iK5lD49E_gCn9voUjWDPj2ItGKvM10GI/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1C_SzYBOG3swZA_EjxpXlAw) (Password: p803) ================================================ FILE: extract_init_states/configs/resnet_120x120.yml ================================================ # before using this config, go through readme.md to find the onnx links to download `resnet22.onnx` arch: resnet22 checkpoint_fp: weights/resnet22.pth bfm_fp: configs/bfm_noneck_v3.pkl size: 120 num_params: 62 ================================================ FILE: extract_init_states/demo_pose_extract_2d_lmk_img.py ================================================ # coding: utf-8 # based on 3DDFA __author__ = 'cleardusk' import sys import os current_dir = os.path.dirname(os.path.abspath(__file__)) print(current_dir) if current_dir not in sys.path: sys.path.append(current_dir) print(current_dir) import argparse import cv2 import yaml import time from yaml import safe_dump from FaceBoxes import FaceBoxes import numpy as np from tqdm import tqdm import copy import time from utils.pose import viz_pose, get_pose from utils.serialization import ser_to_ply, ser_to_obj from utils.functions import draw_landmarks, get_suffix, calculate_eye, calculate_bbox from utils.tddfa_util import str2bool import concurrent.futures from multiprocessing import Pool def main(args,img, save_path, pose_path): # begin = time.time() # face_boxes.eval() # Given a still image path and load to BGR channel # img = cv2.imread(img_path) #args.img_fp # Detect faces, get 3DMM params and roi 
boxes # start_time = time.time() boxes = face_boxes(img) # end_time = time.time() # execution_time = end_time - start_time # print(f'box time: {execution_time}') n = len(boxes) if n == 0: print(f'No face detected, exit') # sys.exit(-1) return None # print(f'Detect {n} faces') # start_time = time.time() param_lst, roi_box_lst = tddfa(img, boxes) # end_time = time.time() # execution_time = end_time - start_time # print(f'tddfa time: {execution_time}') #detection time # detect_time = time.time()-begin # print('detection time: '+str(detect_time), file=open('/mnt/lustre/jixinya/Home/3DDFA_V2/pose.txt', 'a')) # Visualization and serialization dense_flag = args.opt in ('2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'ply', 'obj') # old_suffix = get_suffix(img_path) old_suffix = 'png' new_suffix = f'.{args.opt}' if args.opt in ('ply', 'obj') else '.jpg' wfp = f'examples/results/{args.img_fp.split("/")[-1].replace(old_suffix, "")}_{args.opt}' + new_suffix # start_time = time.time() ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, dense_flag=dense_flag) # end_time = time.time() # execution_time = end_time - start_time # print(f'tddfa.recon_vers time: {execution_time}') # start_time = time.time() all_pose = get_pose(img, param_lst, ver_lst, show_flag=args.show_flag, wfp=save_path, wnp = pose_path) end_time = time.time() return all_pose, ver_lst if __name__ == '__main__': parser = argparse.ArgumentParser(description='The demo of still image of 3DDFA_V2') parser.add_argument('-c', '--config', type=str, default=f'{current_dir}/configs/mb1_120x120.yml') parser.add_argument('-f', '--img_fp', type=str, default='/disk2/pfhu/DAWN-pytorch/images/image/anime_female2.jpeg') parser.add_argument('-m', '--mode', type=str, default='gpu', help='gpu or cpu mode') parser.add_argument('-o', '--opt', type=str, default='pose', choices=['2d_sparse', '2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'pose', 'ply', 'obj']) parser.add_argument('--show_flag', type=str2bool, default='False', help='whether to show the visualization result') parser.add_argument('--onnx', action='store_true', default=True) parser.add_argument('-p', '--part', type=int, default=1) parser.add_argument('-a', '--all', type=int, default=1) parser.add_argument('-i', '--input', type=str) parser.add_argument('-t', '--output', type=str) args = parser.parse_args() part = args.part all_part = args.all filepath = args.input save_path = args.output if not os.path.exists(save_path): os.makedirs(save_path) start_point = 30 #int((part - 1) *duration) cfg = yaml.load(open(args.config), Loader=yaml.SafeLoader) # Init FaceBoxes and TDDFA, recommend using onnx flag if args.onnx: import os os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' # os.environ['OMP_WAIT_POLICY'] = 'PASSIVE' os.environ['OMP_NUM_THREADS'] = '8' from FaceBoxes.FaceBoxes_ONNX import FaceBoxes_ONNX from TDDFA_ONNX import TDDFA_ONNX face_boxes = FaceBoxes_ONNX() tddfa = TDDFA_ONNX(**cfg) else: gpu_mode = args.mode == 'gpu' tddfa = TDDFA(gpu_mode=gpu_mode, **cfg) # tddfa.eval() face_boxes = FaceBoxes() # save_path_pose = os.path.join(save_path, 'tmp.npy') image= cv2.imread(filepath) if image.shape[2] == 4: image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) pose, lmk = main(args,image, save_path = None, pose_path = None) lmk = lmk[0] eye_bbox_result = np.zeros(8) bbox = calculate_bbox(image, lmk) left_ratio, right_ratio = calculate_eye(lmk) eye_bbox_result[0] = left_ratio.sum() eye_bbox_result[1] = right_ratio.sum() eye_bbox_result[2:] = np.array(bbox) pose = pose.reshape(1,7) eye_bbox_result = 
eye_bbox_result.reshape(1, -1) eye_bbox_path = os.path.join(save_path, 'init_eye_bbox.npy') pose_path = os.path.join(save_path, 'init_pose.npy') np.save(eye_bbox_path, eye_bbox_result) np.save(pose_path, pose) ================================================ FILE: extract_init_states/functions.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import numpy as np import cv2 from math import sqrt import matplotlib.pyplot as plt RED = (0, 0, 255) GREEN = (0, 255, 0) BLUE = (255, 0, 0) def get_suffix(filename): """a.jpg -> jpg""" pos = filename.rfind('.') if pos == -1: return '' return filename[pos:] def crop_img(img, roi_box): h, w = img.shape[:2] sx, sy, ex, ey = [int(round(_)) for _ in roi_box] dh, dw = ey - sy, ex - sx if len(img.shape) == 3: res = np.zeros((dh, dw, 3), dtype=np.uint8) else: res = np.zeros((dh, dw), dtype=np.uint8) if sx < 0: sx, dsx = 0, -sx else: dsx = 0 if ex > w: ex, dex = w, dw - (ex - w) else: dex = dw if sy < 0: sy, dsy = 0, -sy else: dsy = 0 if ey > h: ey, dey = h, dh - (ey - h) else: dey = dh res[dsy:dey, dsx:dex] = img[sy:ey, sx:ex] return res def calc_hypotenuse(pts): bbox = [min(pts[0, :]), min(pts[1, :]), max(pts[0, :]), max(pts[1, :])] center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] radius = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 bbox = [center[0] - radius, center[1] - radius, center[0] + radius, center[1] + radius] llength = sqrt((bbox[2] - bbox[0]) ** 2 + (bbox[3] - bbox[1]) ** 2) return llength / 3 def parse_roi_box_from_landmark(pts): """calc roi box from landmark""" bbox = [min(pts[0, :]), min(pts[1, :]), max(pts[0, :]), max(pts[1, :])] center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] radius = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 bbox = [center[0] - radius, center[1] - radius, center[0] + radius, center[1] + radius] llength = sqrt((bbox[2] - bbox[0]) ** 2 + (bbox[3] - bbox[1]) ** 2) center_x = (bbox[2] + bbox[0]) / 2 center_y = (bbox[3] + bbox[1]) / 2 roi_box = [0] * 4 roi_box[0] = center_x - llength / 2 roi_box[1] = center_y - llength / 2 roi_box[2] = roi_box[0] + llength roi_box[3] = roi_box[1] + llength return roi_box def parse_roi_box_from_bbox(bbox): left, top, right, bottom = bbox[:4] old_size = (right - left + bottom - top) / 2 center_x = right - (right - left) / 2.0 center_y = bottom - (bottom - top) / 2.0 + old_size * 0.14 size = int(old_size * 1.58) roi_box = [0] * 4 roi_box[0] = center_x - size / 2 roi_box[1] = center_y - size / 2 roi_box[2] = roi_box[0] + size roi_box[3] = roi_box[1] + size return roi_box def plot_image(img): height, width = img.shape[:2] plt.figure(figsize=(12, height / width * 12)) plt.subplots_adjust(left=0, right=1, top=1, bottom=0) plt.axis('off') plt.imshow(img[..., ::-1]) plt.show() def draw_landmarks(img, pts, style='fancy', wfp=None, show_flag=False, **kwargs): """Draw landmarks using matplotlib""" height, width = img.shape[:2] plt.figure(figsize=(12, height / width * 12)) plt.imshow(img[..., ::-1]) plt.subplots_adjust(left=0, right=1, top=1, bottom=0) plt.axis('off') dense_flag = kwargs.get('dense_flag') if not type(pts) in [tuple, list]: pts = [pts] for i in range(len(pts)): if dense_flag: plt.plot(pts[i][0, ::6], pts[i][1, ::6], 'o', markersize=0.4, color='c', alpha=0.7) else: alpha = 0.8 markersize = 4 lw = 1.5 color = kwargs.get('color', 'w') markeredgecolor = kwargs.get('markeredgecolor', 'black') nums = [0, 17, 22, 27, 31, 36, 42, 48, 60, 68] # close eyes and mouths plot_close = lambda i1, i2: plt.plot([pts[i][0, i1], pts[i][0, 
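# A small read-back sketch for the arrays written by demo_pose_extract_2d_lmk_img.py
# above; 'init_states' stands in for whatever directory was passed via -t/--output.
import os
import numpy as np

out_dir = 'init_states'
pose = np.load(os.path.join(out_dir, 'init_pose.npy'))          # shape (1, 7)
eye_bbox = np.load(os.path.join(out_dir, 'init_eye_bbox.npy'))  # shape (1, 8)

yaw, pitch, roll, scale, tx, ty, tz = pose[0]        # ordering set by get_pose in utils/pose.py
left_open, right_open = eye_bbox[0, :2]              # eye-openness ratios from calculate_eye
x_min, x_max, y_min, y_max, H, W = eye_bbox[0, 2:]   # face bbox and image size from calculate_bbox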
i2]], [pts[i][1, i1], pts[i][1, i2]], color=color, lw=lw, alpha=alpha - 0.1) plot_close(41, 36) plot_close(47, 42) plot_close(59, 48) plot_close(67, 60) for ind in range(len(nums) - 1): l, r = nums[ind], nums[ind + 1] plt.plot(pts[i][0, l:r], pts[i][1, l:r], color=color, lw=lw, alpha=alpha - 0.1) plt.plot(pts[i][0, l:r], pts[i][1, l:r], marker='o', linestyle='None', markersize=markersize, color=color, markeredgecolor=markeredgecolor, alpha=alpha) if wfp is not None: plt.savefig(wfp, dpi=150) print(f'Save visualization result to {wfp}') if show_flag: plt.show() def cv_draw_landmark(img_ori, pts, box=None, color=GREEN, size=1): img = img_ori.copy() n = pts.shape[1] if n <= 106: for i in range(n): cv2.circle(img, (int(round(pts[0, i])), int(round(pts[1, i]))), size, color, -1) else: sep = 1 for i in range(0, n, sep): cv2.circle(img, (int(round(pts[0, i])), int(round(pts[1, i]))), size, color, 1) if box is not None: left, top, right, bottom = np.round(box).astype(np.int32) left_top = (left, top) right_top = (right, top) right_bottom = (right, bottom) left_bottom = (left, bottom) cv2.line(img, left_top, right_top, BLUE, 1, cv2.LINE_AA) cv2.line(img, right_top, right_bottom, BLUE, 1, cv2.LINE_AA) cv2.line(img, right_bottom, left_bottom, BLUE, 1, cv2.LINE_AA) cv2.line(img, left_bottom, left_top, BLUE, 1, cv2.LINE_AA) return img def calculate_bbox(img, lmk): lmk = lmk.transpose(1,0) # point_3d_homo = np.hstack((lmk, np.ones([lmk.shape[0], 1]))) # n x 4 # point_2d = point_3d_homo.dot(P.T)[:, :2] # point_2d[:, 1] = - point_2d[:, 1] # point_2d[:, :2] = point_2d[:, :2] - np.mean(point_2d, 0) + np.mean(lmk[:27,:2], 0) # lmk 0-27 point_2d = lmk[:, :2] point_2d = np.int32(point_2d.reshape(-1, 2)) H = img.shape[0] W = img.shape[1] x_min, x_max = point_2d[:, 0].min(), point_2d[:, 0].max() y_min, y_max = point_2d[:, 1].min(), point_2d[:, 1].max() # cv2.polylines(img, [point_2d], True, (40, 255, 0), 2, cv2.LINE_AA) # points_list = [(p[0], p[1]) for p in point_2d] # for p in points_list: # cv2.circle(img, p, 1, (40, 255, 0), -1) # return img return [x_min, x_max, y_min, y_max, H, W] def calculate_eye(lmk): ''' left right obj ''' lmk = lmk.transpose(1,0) leye_upper = lmk[43] leye_lower = lmk[47] leye_left = lmk[45] leye_right = lmk[42] reye_upper = lmk[37] reye_lower = lmk[41] reye_left = lmk[39] reye_right = lmk[36] left_ratio = np.linalg.norm(leye_upper - leye_lower, 2) / np.linalg.norm(leye_left - leye_right, 2) right_ratio = np.linalg.norm(reye_upper - reye_lower, 2) / np.linalg.norm(reye_left - reye_right, 2) return left_ratio, right_ratio ================================================ FILE: extract_init_states/models/__init__.py ================================================ from .mobilenet_v1 import * from .mobilenet_v3 import * from .resnet import * ================================================ FILE: extract_init_states/models/mobilenet_v1.py ================================================ # coding: utf-8 from __future__ import division """ Creates a MobileNet Model as defined in: Andrew G. Howard Menglong Zhu Bo Chen, et.al. (2017). MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications. 
Copyright (c) Yang Lu, 2017 Modified By cleardusk """ import math import torch.nn as nn __all__ = ['MobileNet', 'mobilenet'] # __all__ = ['mobilenet_2', 'mobilenet_1', 'mobilenet_075', 'mobilenet_05', 'mobilenet_025'] class DepthWiseBlock(nn.Module): def __init__(self, inplanes, planes, stride=1, prelu=False): super(DepthWiseBlock, self).__init__() inplanes, planes = int(inplanes), int(planes) self.conv_dw = nn.Conv2d(inplanes, inplanes, kernel_size=3, padding=1, stride=stride, groups=inplanes, bias=False) self.bn_dw = nn.BatchNorm2d(inplanes) self.conv_sep = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False) self.bn_sep = nn.BatchNorm2d(planes) if prelu: self.relu = nn.PReLU() else: self.relu = nn.ReLU(inplace=True) def forward(self, x): out = self.conv_dw(x) out = self.bn_dw(out) out = self.relu(out) out = self.conv_sep(out) out = self.bn_sep(out) out = self.relu(out) return out class MobileNet(nn.Module): def __init__(self, widen_factor=1.0, num_classes=1000, prelu=False, input_channel=3): """ Constructor Args: widen_factor: config of widen_factor num_classes: number of classes """ super(MobileNet, self).__init__() block = DepthWiseBlock self.conv1 = nn.Conv2d(input_channel, int(32 * widen_factor), kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(int(32 * widen_factor)) if prelu: self.relu = nn.PReLU() else: self.relu = nn.ReLU(inplace=True) self.dw2_1 = block(32 * widen_factor, 64 * widen_factor, prelu=prelu) self.dw2_2 = block(64 * widen_factor, 128 * widen_factor, stride=2, prelu=prelu) self.dw3_1 = block(128 * widen_factor, 128 * widen_factor, prelu=prelu) self.dw3_2 = block(128 * widen_factor, 256 * widen_factor, stride=2, prelu=prelu) self.dw4_1 = block(256 * widen_factor, 256 * widen_factor, prelu=prelu) self.dw4_2 = block(256 * widen_factor, 512 * widen_factor, stride=2, prelu=prelu) self.dw5_1 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) self.dw5_2 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) self.dw5_3 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) self.dw5_4 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) self.dw5_5 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) self.dw5_6 = block(512 * widen_factor, 1024 * widen_factor, stride=2, prelu=prelu) self.dw6 = block(1024 * widen_factor, 1024 * widen_factor, prelu=prelu) self.avgpool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Linear(int(1024 * widen_factor), num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels m.weight.data.normal_(0, math.sqrt(2. / n)) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.dw2_1(x) x = self.dw2_2(x) x = self.dw3_1(x) x = self.dw3_2(x) x = self.dw4_1(x) x = self.dw4_2(x) x = self.dw5_1(x) x = self.dw5_2(x) x = self.dw5_3(x) x = self.dw5_4(x) x = self.dw5_5(x) x = self.dw5_6(x) x = self.dw6(x) x = self.avgpool(x) x = x.view(x.size(0), -1) x = self.fc(x) return x def mobilenet(**kwargs): """ Construct MobileNet. 
widen_factor=1.0 for mobilenet_1 widen_factor=0.75 for mobilenet_075 widen_factor=0.5 for mobilenet_05 widen_factor=0.25 for mobilenet_025 """ # widen_factor = 1.0, num_classes = 1000 # model = MobileNet(widen_factor=widen_factor, num_classes=num_classes) # return model model = MobileNet( widen_factor=kwargs.get('widen_factor', 1.0), num_classes=kwargs.get('num_classes', 62) ) return model def mobilenet_2(num_classes=62, input_channel=3): model = MobileNet(widen_factor=2.0, num_classes=num_classes, input_channel=input_channel) return model def mobilenet_1(num_classes=62, input_channel=3): model = MobileNet(widen_factor=1.0, num_classes=num_classes, input_channel=input_channel) return model def mobilenet_075(num_classes=62, input_channel=3): model = MobileNet(widen_factor=0.75, num_classes=num_classes, input_channel=input_channel) return model def mobilenet_05(num_classes=62, input_channel=3): model = MobileNet(widen_factor=0.5, num_classes=num_classes, input_channel=input_channel) return model def mobilenet_025(num_classes=62, input_channel=3): model = MobileNet(widen_factor=0.25, num_classes=num_classes, input_channel=input_channel) return model ================================================ FILE: extract_init_states/models/mobilenet_v3.py ================================================ # coding: utf-8 import torch.nn as nn import torch.nn.functional as F __all__ = ['MobileNetV3', 'mobilenet_v3'] def conv_bn(inp, oup, stride, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU): return nn.Sequential( conv_layer(inp, oup, 3, stride, 1, bias=False), norm_layer(oup), nlin_layer(inplace=True) ) def conv_1x1_bn(inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU): return nn.Sequential( conv_layer(inp, oup, 1, 1, 0, bias=False), norm_layer(oup), nlin_layer(inplace=True) ) class Hswish(nn.Module): def __init__(self, inplace=True): super(Hswish, self).__init__() self.inplace = inplace def forward(self, x): return x * F.relu6(x + 3., inplace=self.inplace) / 6. class Hsigmoid(nn.Module): def __init__(self, inplace=True): super(Hsigmoid, self).__init__() self.inplace = inplace def forward(self, x): return F.relu6(x + 3., inplace=self.inplace) / 6. class SEModule(nn.Module): def __init__(self, channel, reduction=4): super(SEModule, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel, channel // reduction, bias=False), nn.ReLU(inplace=True), nn.Linear(channel // reduction, channel, bias=False), Hsigmoid() # nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y.expand_as(x) class Identity(nn.Module): def __init__(self, channel): super(Identity, self).__init__() def forward(self, x): return x def make_divisible(x, divisible_by=8): import numpy as np return int(np.ceil(x * 1. 
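# A short construction sketch mirroring how utils/onnx.py and configs/mb1_120x120.yml
# use the mobilenet factory above (arch: mobilenet, widen_factor: 1.0, num_params: 62,
# size: 120); checkpoint loading is omitted, so the output values are meaningless.
import torch
import models  # the extract_init_states/models package; run from that directory

net = getattr(models, 'mobilenet')(num_classes=62, widen_factor=1.0)
net.eval()
with torch.no_grad():
    crop = torch.randn(1, 3, 120, 120)   # one 120x120 face crop, as produced by crop_img + resize
    params = net(crop)                   # (1, 62) 3DMM parameters: 12 pose + 40 shape + 10 expression
print(params.shape)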
/ divisible_by) * divisible_by) class MobileBottleneck(nn.Module): def __init__(self, inp, oup, kernel, stride, exp, se=False, nl='RE'): super(MobileBottleneck, self).__init__() assert stride in [1, 2] assert kernel in [3, 5] padding = (kernel - 1) // 2 self.use_res_connect = stride == 1 and inp == oup conv_layer = nn.Conv2d norm_layer = nn.BatchNorm2d if nl == 'RE': nlin_layer = nn.ReLU # or ReLU6 elif nl == 'HS': nlin_layer = Hswish else: raise NotImplementedError if se: SELayer = SEModule else: SELayer = Identity self.conv = nn.Sequential( # pw conv_layer(inp, exp, 1, 1, 0, bias=False), norm_layer(exp), nlin_layer(inplace=True), # dw conv_layer(exp, exp, kernel, stride, padding, groups=exp, bias=False), norm_layer(exp), SELayer(exp), nlin_layer(inplace=True), # pw-linear conv_layer(exp, oup, 1, 1, 0, bias=False), norm_layer(oup), ) def forward(self, x): if self.use_res_connect: return x + self.conv(x) else: return self.conv(x) class MobileNetV3(nn.Module): def __init__(self, widen_factor=1.0, num_classes=141, num_landmarks=136, input_size=120, mode='small'): super(MobileNetV3, self).__init__() input_channel = 16 last_channel = 1280 if mode == 'large': # refer to Table 1 in paper mobile_setting = [ # k, exp, c, se, nl, s, [3, 16, 16, False, 'RE', 1], [3, 64, 24, False, 'RE', 2], [3, 72, 24, False, 'RE', 1], [5, 72, 40, True, 'RE', 2], [5, 120, 40, True, 'RE', 1], [5, 120, 40, True, 'RE', 1], [3, 240, 80, False, 'HS', 2], [3, 200, 80, False, 'HS', 1], [3, 184, 80, False, 'HS', 1], [3, 184, 80, False, 'HS', 1], [3, 480, 112, True, 'HS', 1], [3, 672, 112, True, 'HS', 1], [5, 672, 160, True, 'HS', 2], [5, 960, 160, True, 'HS', 1], [5, 960, 160, True, 'HS', 1], ] elif mode == 'small': # refer to Table 2 in paper mobile_setting = [ # k, exp, c, se, nl, s, [3, 16, 16, True, 'RE', 2], [3, 72, 24, False, 'RE', 2], [3, 88, 24, False, 'RE', 1], [5, 96, 40, True, 'HS', 2], [5, 240, 40, True, 'HS', 1], [5, 240, 40, True, 'HS', 1], [5, 120, 48, True, 'HS', 1], [5, 144, 48, True, 'HS', 1], [5, 288, 96, True, 'HS', 2], [5, 576, 96, True, 'HS', 1], [5, 576, 96, True, 'HS', 1], ] else: raise NotImplementedError # building first layer assert input_size % 32 == 0 last_channel = make_divisible(last_channel * widen_factor) if widen_factor > 1.0 else last_channel self.features = [conv_bn(3, input_channel, 2, nlin_layer=Hswish)] # self.classifier = [] # building mobile blocks for k, exp, c, se, nl, s in mobile_setting: output_channel = make_divisible(c * widen_factor) exp_channel = make_divisible(exp * widen_factor) self.features.append(MobileBottleneck(input_channel, output_channel, k, s, exp_channel, se, nl)) input_channel = output_channel # building last several layers if mode == 'large': last_conv = make_divisible(960 * widen_factor) self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish)) self.features.append(nn.AdaptiveAvgPool2d(1)) self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) self.features.append(Hswish(inplace=True)) elif mode == 'small': last_conv = make_divisible(576 * widen_factor) self.features.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish)) # self.features.append(SEModule(last_conv)) # refer to paper Table2, but I think this is a mistake self.features.append(nn.AdaptiveAvgPool2d(1)) self.features.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0)) self.features.append(Hswish(inplace=True)) else: raise NotImplementedError # make it nn.Sequential self.features = nn.Sequential(*self.features) # self.fc_param = nn.Linear(int(last_channel), 
num_classes) self.fc = nn.Linear(int(last_channel), num_classes) # self.fc_lm = nn.Linear(int(last_channel), num_landmarks) # building classifier # self.classifier = nn.Sequential( # nn.Dropout(p=dropout), # refer to paper section 6 # nn.Linear(last_channel, n_class), # ) self._initialize_weights() def forward(self, x): x = self.features(x) x_share = x.mean(3).mean(2) # x = self.classifier(x) # print(x_share.shape) # xp = self.fc_param(x_share) # param # xl = self.fc_lm(x_share) # lm xp = self.fc(x_share) # param return xp # , xl def _initialize_weights(self): # weight initialization for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out') if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) if m.bias is not None: nn.init.zeros_(m.bias) def mobilenet_v3(**kwargs): model = MobileNetV3( widen_factor=kwargs.get('widen_factor', 1.0), num_classes=kwargs.get('num_classes', 62), num_landmarks=kwargs.get('num_landmarks', 136), input_size=kwargs.get('size', 128), mode=kwargs.get('mode', 'small') ) return model ================================================ FILE: extract_init_states/models/resnet.py ================================================ #!/usr/bin/env python3 # coding: utf-8 import torch.nn as nn __all__ = ['ResNet', 'resnet22'] def conv3x3(in_planes, out_planes, stride=1): "3x3 convolution with padding" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None): super(BasicBlock, self).__init__() self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class ResNet(nn.Module): """Another Strucutre used in caffe-resnet25""" def __init__(self, block, layers, num_classes=62, num_landmarks=136, input_channel=3, fc_flg=False): self.inplanes = 64 super(ResNet, self).__init__() self.conv1 = nn.Conv2d(input_channel, 32, kernel_size=5, stride=2, padding=2, bias=False) self.bn1 = nn.BatchNorm2d(32) # 32 is input channels number self.relu1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(64) self.relu2 = nn.ReLU(inplace=True) # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 128, layers[0], stride=2) self.layer2 = self._make_layer(block, 256, layers[1], stride=2) self.layer3 = self._make_layer(block, 512, layers[2], stride=2) self.conv_param = nn.Conv2d(512, num_classes, 1) # self.conv_lm = nn.Conv2d(512, num_landmarks, 1) self.avgpool = nn.AdaptiveAvgPool2d(1) # self.fc = nn.Linear(512 * block.expansion, num_classes) self.fc_flg = fc_flg # parameter initialization for m in self.modules(): if isinstance(m, nn.Conv2d): # 1. # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels # m.weight.data.normal_(0, math.sqrt(2. / n)) # 2. 
kaiming normal nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample)) self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(self.inplanes, planes)) return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu1(x) x = self.conv2(x) x = self.bn2(x) x = self.relu2(x) # x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) # if self.fc_flg: # x = self.avgpool(x) # x = x.view(x.size(0), -1) # x = self.fc(x) # else: xp = self.conv_param(x) xp = self.avgpool(xp) xp = xp.view(xp.size(0), -1) # xl = self.conv_lm(x) # xl = self.avgpool(xl) # xl = xl.view(xl.size(0), -1) return xp # , xl def resnet22(**kwargs): model = ResNet( BasicBlock, [3, 4, 3], num_landmarks=kwargs.get('num_landmarks', 136), input_channel=kwargs.get('input_channel', 3), fc_flg=False ) return model def main(): pass if __name__ == '__main__': main() ================================================ FILE: extract_init_states/pose.py ================================================ # coding: utf-8 """ Reference: https://github.com/YadiraF/PRNet/blob/master/utils/estimate_pose.py Calculating pose from the output 3DMM parameters, you can also try to use solvePnP to perform estimation """ __author__ = 'cleardusk' import cv2 import numpy as np from math import cos, sin, atan2, asin, sqrt from .functions import calc_hypotenuse, plot_image def P2sRt(P): """ decompositing camera matrix P. Args: P: (3, 4). Affine Camera Matrix. Returns: s: scale factor. R: (3, 3). rotation matrix. t2d: (2,). 2d translation. """ t3d = P[:, 3] # R1 = P[0:1, :3] R2 = P[1:2, :3] s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0 # r1 = R1 / np.linalg.norm(R1) r2 = R2 / np.linalg.norm(R2) r3 = np.cross(r1, r2) # r1r2,r3 R = np.concatenate((r1, r2, r3), 0) # r 1-3R () return s, R, t3d def matrix2angle(R): """ compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv todo: check and debug Args: R: (3,3). rotation matrix Returns: x: yaw y: pitch z: roll """ if R[2, 0] > 0.998: z = 0 x = np.pi / 2 y = z + atan2(-R[0, 1], -R[0, 2]) elif R[2, 0] < -0.998: z = 0 x = -np.pi / 2 y = -z + atan2(R[0, 1], R[0, 2]) else: x = asin(R[2, 0]) y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x)) z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x)) return x, y, z def angle2matrix(theta): """ compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv todo: check and debug Args: R: (3,3). 
rotation matrix Returns: x: yaw y: pitch z: roll """ R_x = np.array([[1, 0, 0 ], [0, cos(theta[1]), -sin(theta[1]) ], [0, sin(theta[1]), cos(theta[1]) ] ]) R_y = np.array([[cos(theta[0]), 0, sin(-theta[0]) ], [0, 1, 0 ], [-sin(-theta[0]), 0, cos(theta[0]) ] ]) R_z = np.array([[cos(theta[2]), -sin(theta[2]), 0], [sin(theta[2]), cos(theta[2]), 0], [0, 0, 1] ]) R = np.dot(R_z, np.dot( R_y, R_x )) return R def angle2matrix_3ddfa(angles): ''' get rotation matrix from three rotation angles(radian). The same as in 3DDFA. Args: angles: [3,]. x, y, z angles x: pitch. y: yaw. z: roll. Returns: R: 3x3. rotation matrix. ''' # x, y, z = np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2]) x, y, z = angles[1], angles[0], angles[2] # x Rx=np.array([[1, 0, 0], [0, cos(x), sin(x)], [0, -sin(x), cos(x)]]) # y Ry=np.array([[ cos(y), 0, -sin(y)], [ 0, 1, 0], [sin(y), 0, cos(y)]]) # z Rz=np.array([[cos(z), sin(z), 0], [-sin(z), cos(z), 0], [ 0, 0, 1]]) R = Rx.dot(Ry).dot(Rz) return R.astype(np.float32) def calc_pose(param): P = param[:12].reshape(3, -1) # camera matrix s, R, t3d = P2sRt(P) P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) # without scale pose = matrix2angle(R) pose = [p * 180 / np.pi for p in pose] return P, pose def build_camera_box(rear_size=90): point_3d = [] rear_depth = 0 point_3d.append((-rear_size, -rear_size, rear_depth)) point_3d.append((-rear_size, rear_size, rear_depth)) point_3d.append((rear_size, rear_size, rear_depth)) point_3d.append((rear_size, -rear_size, rear_depth)) point_3d.append((-rear_size, -rear_size, rear_depth)) front_size = int(4 / 3 * rear_size) front_depth = int(4 / 3 * rear_size) point_3d.append((-front_size, -front_size, front_depth)) point_3d.append((-front_size, front_size, front_depth)) point_3d.append((front_size, front_size, front_depth)) point_3d.append((front_size, -front_size, front_depth)) point_3d.append((-front_size, -front_size, front_depth)) point_3d = np.array(point_3d, dtype=np.float32).reshape(-1, 3) return point_3d def plot_pose_box(img, P, ver, color=(40, 255, 0), line_width=2): """ Draw a 3D box as annotation of pose. Ref:https://github.com/yinguobing/head-pose-estimation/blob/master/pose_estimator.py Args: img: the input image P: (3, 4). Affine Camera Matrix. 
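# A numeric sketch of the decomposition done by P2sRt/calc_pose above: the first 12
# parameters form a 3x4 affine camera matrix [s*R | t3d], and the scale is the mean
# norm of the first two rotation rows. The matrix below is made up for illustration.
import numpy as np

P = np.array([[2.0, 0.0, 0.0, 10.0],
              [0.0, 2.0, 0.0, 20.0],
              [0.0, 0.0, 2.0, 30.0]], dtype=np.float32)  # here s = 2 and R = identity
t3d = P[:, 3]
R1, R2 = P[0:1, :3], P[1:2, :3]
s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0      # -> 2.0
r1 = R1 / np.linalg.norm(R1)
r2 = R2 / np.linalg.norm(R2)
R = np.concatenate((r1, r2, np.cross(r1, r2)), 0)        # orthonormal rotation, identity here
# matrix2angle(R) then yields (yaw, pitch, roll) = (0, 0, 0) for this example.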
kpt: (2, 68) or (3, 68) """ llength = calc_hypotenuse(ver) point_3d = build_camera_box(llength) # Map to 2d image points point_3d_homo = np.hstack((point_3d, np.ones([point_3d.shape[0], 1]))) # n x 4 point_2d = point_3d_homo.dot(P.T)[:, :2] point_2d[:, 1] = - point_2d[:, 1] point_2d[:, :2] = point_2d[:, :2] - np.mean(point_2d[:4, :2], 0) + np.mean(ver[:2, :27], 1) # lmk 0-27 point_2d = np.int32(point_2d.reshape(-1, 2)) # Draw all the lines cv2.polylines(img, [point_2d], True, color, line_width, cv2.LINE_AA) cv2.line(img, tuple(point_2d[1]), tuple( point_2d[6]), color, line_width, cv2.LINE_AA) cv2.line(img, tuple(point_2d[2]), tuple( point_2d[7]), color, line_width, cv2.LINE_AA) cv2.line(img, tuple(point_2d[3]), tuple( point_2d[8]), color, line_width, cv2.LINE_AA) return img def viz_pose(img, param_lst, ver_lst, show_flag=False, wfp=None): for param, ver in zip(param_lst, ver_lst): P, pose = calc_pose(param) img = plot_pose_box(img, P, ver) # print(P[:, :3]) # print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}') if wfp is not None: cv2.imwrite(wfp, img) print(f'Save visualization result to {wfp}') if show_flag: plot_image(img) return img def pose_6(param): P = param[:12].reshape(3, -1) # camera matrix s, R, t3d = P2sRt(P) P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) # without scale pose = matrix2angle(R) # R,pose # print(t3d) R1 = angle2matrix(pose) # print(R) # print(R1) pose = [p * 180 / np.pi for p in pose] return s, pose, t3d, P # s()、R()、t3d() def smooth_pose(img, param_lst, ver_lst, pose_new, show_flag=False, wfp=None, wnp = None): for param, ver in zip(param_lst, ver_lst): t3d = np.array([pose_new[4],pose_new[5],pose_new[6]]) theta = np.array([pose_new[0],pose_new[1],pose_new[2]]) theta = [p * np.pi / 180 for p in theta] R = angle2matrix(theta) P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) img = plot_pose_box(img, P, ver) # print(P,P.shape,t3d) # print(P,pose_new) # print(f'yaw: {theta[0]:.1f}, pitch: {theta[1]:.1f}, roll: {theta[2]:.1f}') all_pose = [0] all_pose = np.array(all_pose) if wfp is not None: cv2.imwrite(wfp, img) print(f'Save visualization result to {wfp}') if wnp is not None: np.save(wnp, all_pose) print(f'Save visualization result to {wfp}') if show_flag: plot_image(img) return img def get_pose(img, param_lst, ver_lst, show_flag=False, wfp=None, wnp = None): for param, ver in zip(param_lst, ver_lst): # s, pose, t3d, P = pose_6(param) img_1 = plot_pose_box(img.copy(), P, ver) # print(P,P.shape,t3d) # print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}') all_pose = [pose[0],pose[1],pose[2],s,t3d[0],t3d[1],t3d[2]] all_pose = np.array(all_pose) # if wfp is not None: # cv2.imwrite(wfp, img_1) # print(f'Save visualization result to {wfp}') # if wnp is not None: # np.save(wnp, all_pose) # print(f'Save visualization result to {wfp}') if show_flag: plot_image(img) return all_pose ================================================ FILE: extract_init_states/readme.md ================================================ # README The `extract_init_state` is mainly from [3DDFA_v2](https://github.com/cleardusk/3DDFA_V2) with minor revision. We remove the `Sim3DR` in original repo. We add or revise the file of `extract_init_states\demo_pose_extract_2d_lmk_img.py`, `extract_init_states\utils\pose.py`. ## Linux Linux user can follow the installation process on [3DDFA_v2](https://github.com/cleardusk/3DDFA_V2) ## Win For Windows user, be aware to these tips: 1. Installing gcc 2. 
In `extract_init_states\FaceBoxes\utils\build.py`, you need comment line 47 3. Revise the `extract_init_states\FaceBoxes\utils\nms\cpu_nms.pyx` following [comment](https://github.com/cleardusk/3DDFA_V2/issues/12#issuecomment-697479173). 4. Run the command in sh script line by line manually ================================================ FILE: extract_init_states/utils/__init__.py ================================================ ================================================ FILE: extract_init_states/utils/asset/.gitignore ================================================ *.so ================================================ FILE: extract_init_states/utils/asset/build_render_ctypes.sh ================================================ gcc -shared -Wall -O3 render.c -o render.so -fPIC ================================================ FILE: extract_init_states/utils/asset/render.c ================================================ #include #include #include #define max(x, y) (((x) > (y)) ? (x) : (y)) #define min(x, y) (((x) < (y)) ? (x) : (y)) #define clip(_x, _min, _max) min(max(_x, _min), _max) struct Tuple3D { float x; float y; float z; }; void _render(const int *triangles, const int ntri, const float *light, const float *directional, const float *ambient, const float *vertices, const int nver, unsigned char *image, const int h, const int w) { int tri_p0_ind, tri_p1_ind, tri_p2_ind; int color_index; float dot00, dot01, dot11, dot02, dot12; float cos_sum, det; struct Tuple3D p0, p1, p2; struct Tuple3D v0, v1, v2; struct Tuple3D p, start, end; struct Tuple3D ver_max = {-1.0e8, -1.0e8, -1.0e8}; struct Tuple3D ver_min = {1.0e8, 1.0e8, 1.0e8}; struct Tuple3D ver_mean = {0.0, 0.0, 0.0}; float *ver_normal = (float *)calloc(3 * nver, sizeof(float)); float *colors = (float *)malloc(3 * nver * sizeof(float)); float *depth_buffer = (float *)calloc(h * w, sizeof(float)); for (int i = 0; i < ntri; i++) { tri_p0_ind = triangles[3 * i]; tri_p1_ind = triangles[3 * i + 1]; tri_p2_ind = triangles[3 * i + 2]; // counter clockwise order start.x = vertices[tri_p1_ind] - vertices[tri_p0_ind]; start.y = vertices[tri_p1_ind + 1] - vertices[tri_p0_ind + 1]; start.z = vertices[tri_p1_ind + 2] - vertices[tri_p0_ind + 2]; end.x = vertices[tri_p2_ind] - vertices[tri_p0_ind]; end.y = vertices[tri_p2_ind + 1] - vertices[tri_p0_ind + 1]; end.z = vertices[tri_p2_ind + 2] - vertices[tri_p0_ind + 2]; p.x = start.y * end.z - start.z * end.y; p.y = start.z * end.x - start.x * end.z; p.z = start.x * end.y - start.y * end.x; ver_normal[tri_p0_ind] += p.x; ver_normal[tri_p1_ind] += p.x; ver_normal[tri_p2_ind] += p.x; ver_normal[tri_p0_ind + 1] += p.y; ver_normal[tri_p1_ind + 1] += p.y; ver_normal[tri_p2_ind + 1] += p.y; ver_normal[tri_p0_ind + 2] += p.z; ver_normal[tri_p1_ind + 2] += p.z; ver_normal[tri_p2_ind + 2] += p.z; } for (int i = 0; i < nver; ++i) { p.x = ver_normal[3 * i]; p.y = ver_normal[3 * i + 1]; p.z = ver_normal[3 * i + 2]; det = sqrt(p.x * p.x + p.y * p.y + p.z * p.z); if (det <= 0) det = 1e-6; ver_normal[3 * i] /= det; ver_normal[3 * i + 1] /= det; ver_normal[3 * i + 2] /= det; ver_mean.x += p.x; ver_mean.y += p.y; ver_mean.z += p.z; ver_max.x = max(ver_max.x, p.x); ver_max.y = max(ver_max.y, p.y); ver_max.z = max(ver_max.z, p.z); ver_min.x = min(ver_min.x, p.x); ver_min.y = min(ver_min.y, p.y); ver_min.z = min(ver_min.z, p.z); } ver_mean.x /= nver; ver_mean.y /= nver; ver_mean.z /= nver; for (int i = 0; i < nver; ++i) { colors[3 * i] = vertices[3 * i]; colors[3 * i + 1] = vertices[3 * i + 1]; colors[3 * i + 2] = 
vertices[3 * i + 2]; colors[3 * i] -= ver_mean.x; colors[3 * i] /= ver_max.x - ver_min.x; colors[3 * i + 1] -= ver_mean.y; colors[3 * i + 1] /= ver_max.y - ver_min.y; colors[3 * i + 2] -= ver_mean.z; colors[3 * i + 2] /= ver_max.z - ver_min.z; p.x = light[0] - colors[3 * i]; p.y = light[1] - colors[3 * i + 1]; p.z = light[2] - colors[3 * i + 2]; det = sqrt(p.x * p.x + p.y * p.y + p.z * p.z); if (det <= 0) det = 1e-6; colors[3 * i] = p.x / det; colors[3 * i + 1] = p.y / det; colors[3 * i + 2] = p.z / det; colors[3 * i] *= ver_normal[3 * i]; colors[3 * i + 1] *= ver_normal[3 * i + 1]; colors[3 * i + 2] *= ver_normal[3 * i + 2]; cos_sum = colors[3 * i] + colors[3 * i + 1] + colors[3 * i + 2]; colors[3 * i] = clip(cos_sum * directional[0] + ambient[0], 0, 1); colors[3 * i + 1] = clip(cos_sum * directional[1] + ambient[1], 0, 1); colors[3 * i + 2] = clip(cos_sum * directional[2] + ambient[2], 0, 1); } for (int i = 0; i < ntri; ++i) { tri_p0_ind = triangles[3 * i]; tri_p1_ind = triangles[3 * i + 1]; tri_p2_ind = triangles[3 * i + 2]; p0.x = vertices[tri_p0_ind]; p0.y = vertices[tri_p0_ind + 1]; p0.z = vertices[tri_p0_ind + 2]; p1.x = vertices[tri_p1_ind]; p1.y = vertices[tri_p1_ind + 1]; p1.z = vertices[tri_p1_ind + 2]; p2.x = vertices[tri_p2_ind]; p2.y = vertices[tri_p2_ind + 1]; p2.z = vertices[tri_p2_ind + 2]; start.x = max(ceil(min(p0.x, min(p1.x, p2.x))), 0); end.x = min(floor(max(p0.x, max(p1.x, p2.x))), w - 1); start.y = max(ceil(min(p0.y, min(p1.y, p2.y))), 0); end.y = min(floor(max(p0.y, max(p1.y, p2.y))), h - 1); if (end.x < start.x || end.y < start.y) continue; v0.x = p2.x - p0.x; v0.y = p2.y - p0.y; v1.x = p1.x - p0.x; v1.y = p1.y - p0.y; // dot products np.dot(v0.T, v0) dot00 = v0.x * v0.x + v0.y * v0.y; dot01 = v0.x * v1.x + v0.y * v1.y; dot11 = v1.x * v1.x + v1.y * v1.y; // barycentric coordinates start.z = dot00 * dot11 - dot01 * dot01; if (start.z != 0) start.z = 1 / start.z; for (p.y = start.y; p.y <= end.y; p.y += 1.0) { for (p.x = start.x; p.x <= end.x; p.x += 1.0) { v2.x = p.x - p0.x; v2.y = p.y - p0.y; dot02 = v0.x * v2.x + v0.y * v2.y; dot12 = v1.x * v2.x + v1.y * v2.y; v2.z = (dot11 * dot02 - dot01 * dot12) * start.z; v1.z = (dot00 * dot12 - dot01 * dot02) * start.z; v0.z = 1 - v2.z - v1.z; // judge is_point_in_tri by below line of code if (v2.z > 0 && v1.z > 0 && v0.z > 0) { p.z = v0.z * p0.z + v1.z * p1.z + v2.z * p2.z; color_index = p.y * w + p.x; if (p.z > depth_buffer[color_index]) { end.z = v0.z * colors[tri_p0_ind]; end.z += v1.z * colors[tri_p1_ind]; end.z += v2.z * colors[tri_p2_ind]; image[3 * color_index] = end.z * 255; end.z = v0.z * colors[tri_p0_ind + 1]; end.z += v1.z * colors[tri_p1_ind + 1]; end.z += v2.z * colors[tri_p2_ind + 1]; image[3 * color_index + 1] = end.z * 255; end.z = v0.z * colors[tri_p0_ind + 2]; end.z += v1.z * colors[tri_p1_ind + 2]; end.z += v2.z * colors[tri_p2_ind + 2]; image[3 * color_index + 2] = end.z * 255; depth_buffer[color_index] = p.z; } } } } } free(depth_buffer); free(colors); free(ver_normal); } ================================================ FILE: extract_init_states/utils/depth.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import sys sys.path.append('..') import cv2 import numpy as np from Sim3DR import rasterize from utils.functions import plot_image from .tddfa_util import _to_ctype def depth(img, ver_lst, tri, show_flag=False, wfp=None, with_bg_flag=True): if with_bg_flag: overlap = img.copy() else: overlap = np.zeros_like(img) for ver_ in ver_lst: ver = _to_ctype(ver_.T) # 
transpose z = ver[:, 2] z_min, z_max = min(z), max(z) z = (z - z_min) / (z_max - z_min) # expand z = np.repeat(z[:, np.newaxis], 3, axis=1) overlap = rasterize(ver, tri, z, bg=overlap) if wfp is not None: cv2.imwrite(wfp, overlap) print(f'Save visualization result to {wfp}') if show_flag: plot_image(overlap) return overlap ================================================ FILE: extract_init_states/utils/functions.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import numpy as np import cv2 from math import sqrt import matplotlib.pyplot as plt RED = (0, 0, 255) GREEN = (0, 255, 0) BLUE = (255, 0, 0) def get_suffix(filename): """a.jpg -> jpg""" pos = filename.rfind('.') if pos == -1: return '' return filename[pos:] def crop_img(img, roi_box): h, w = img.shape[:2] sx, sy, ex, ey = [int(round(_)) for _ in roi_box] dh, dw = ey - sy, ex - sx if len(img.shape) == 3: res = np.zeros((dh, dw, 3), dtype=np.uint8) else: res = np.zeros((dh, dw), dtype=np.uint8) if sx < 0: sx, dsx = 0, -sx else: dsx = 0 if ex > w: ex, dex = w, dw - (ex - w) else: dex = dw if sy < 0: sy, dsy = 0, -sy else: dsy = 0 if ey > h: ey, dey = h, dh - (ey - h) else: dey = dh res[dsy:dey, dsx:dex] = img[sy:ey, sx:ex] return res def calc_hypotenuse(pts): bbox = [min(pts[0, :]), min(pts[1, :]), max(pts[0, :]), max(pts[1, :])] center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] radius = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 bbox = [center[0] - radius, center[1] - radius, center[0] + radius, center[1] + radius] llength = sqrt((bbox[2] - bbox[0]) ** 2 + (bbox[3] - bbox[1]) ** 2) return llength / 3 def parse_roi_box_from_landmark(pts): """calc roi box from landmark""" bbox = [min(pts[0, :]), min(pts[1, :]), max(pts[0, :]), max(pts[1, :])] center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] radius = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 bbox = [center[0] - radius, center[1] - radius, center[0] + radius, center[1] + radius] llength = sqrt((bbox[2] - bbox[0]) ** 2 + (bbox[3] - bbox[1]) ** 2) center_x = (bbox[2] + bbox[0]) / 2 center_y = (bbox[3] + bbox[1]) / 2 roi_box = [0] * 4 roi_box[0] = center_x - llength / 2 roi_box[1] = center_y - llength / 2 roi_box[2] = roi_box[0] + llength roi_box[3] = roi_box[1] + llength return roi_box def parse_roi_box_from_bbox(bbox): left, top, right, bottom = bbox[:4] old_size = (right - left + bottom - top) / 2 center_x = right - (right - left) / 2.0 center_y = bottom - (bottom - top) / 2.0 + old_size * 0.14 size = int(old_size * 1.58) roi_box = [0] * 4 roi_box[0] = center_x - size / 2 roi_box[1] = center_y - size / 2 roi_box[2] = roi_box[0] + size roi_box[3] = roi_box[1] + size return roi_box def plot_image(img): height, width = img.shape[:2] plt.figure(figsize=(12, height / width * 12)) plt.subplots_adjust(left=0, right=1, top=1, bottom=0) plt.axis('off') plt.imshow(img[..., ::-1]) plt.show() def draw_landmarks(img, pts, style='fancy', wfp=None, show_flag=False, **kwargs): """Draw landmarks using matplotlib""" height, width = img.shape[:2] plt.figure(figsize=(12, height / width * 12)) plt.imshow(img[..., ::-1]) plt.subplots_adjust(left=0, right=1, top=1, bottom=0) plt.axis('off') dense_flag = kwargs.get('dense_flag') if not type(pts) in [tuple, list]: pts = [pts] for i in range(len(pts)): if dense_flag: plt.plot(pts[i][0, ::6], pts[i][1, ::6], 'o', markersize=0.4, color='c', alpha=0.7) else: alpha = 0.8 markersize = 4 lw = 1.5 color = kwargs.get('color', 'w') markeredgecolor = kwargs.get('markeredgecolor', 'black') nums = [0, 17, 
22, 27, 31, 36, 42, 48, 60, 68] # close eyes and mouths plot_close = lambda i1, i2: plt.plot([pts[i][0, i1], pts[i][0, i2]], [pts[i][1, i1], pts[i][1, i2]], color=color, lw=lw, alpha=alpha - 0.1) plot_close(41, 36) plot_close(47, 42) plot_close(59, 48) plot_close(67, 60) for ind in range(len(nums) - 1): l, r = nums[ind], nums[ind + 1] plt.plot(pts[i][0, l:r], pts[i][1, l:r], color=color, lw=lw, alpha=alpha - 0.1) plt.plot(pts[i][0, l:r], pts[i][1, l:r], marker='o', linestyle='None', markersize=markersize, color=color, markeredgecolor=markeredgecolor, alpha=alpha) if wfp is not None: plt.savefig(wfp, dpi=150) print(f'Save visualization result to {wfp}') if show_flag: plt.show() def cv_draw_landmark(img_ori, pts, box=None, color=GREEN, size=1): img = img_ori.copy() n = pts.shape[1] if n <= 106: for i in range(n): cv2.circle(img, (int(round(pts[0, i])), int(round(pts[1, i]))), size, color, -1) else: sep = 1 for i in range(0, n, sep): cv2.circle(img, (int(round(pts[0, i])), int(round(pts[1, i]))), size, color, 1) if box is not None: left, top, right, bottom = np.round(box).astype(np.int32) left_top = (left, top) right_top = (right, top) right_bottom = (right, bottom) left_bottom = (left, bottom) cv2.line(img, left_top, right_top, BLUE, 1, cv2.LINE_AA) cv2.line(img, right_top, right_bottom, BLUE, 1, cv2.LINE_AA) cv2.line(img, right_bottom, left_bottom, BLUE, 1, cv2.LINE_AA) cv2.line(img, left_bottom, left_top, BLUE, 1, cv2.LINE_AA) return img def calculate_bbox(img, lmk): lmk = lmk.transpose(1,0) # point_3d_homo = np.hstack((lmk, np.ones([lmk.shape[0], 1]))) # n x 4 # point_2d = point_3d_homo.dot(P.T)[:, :2] # point_2d[:, 1] = - point_2d[:, 1] # point_2d[:, :2] = point_2d[:, :2] - np.mean(point_2d, 0) + np.mean(lmk[:27,:2], 0) # lmk 0-27 face contour point_2d = lmk[:, :2] point_2d = np.int32(point_2d.reshape(-1, 2)) H = img.shape[0] W = img.shape[1] x_min, x_max = point_2d[:, 0].min(), point_2d[:, 0].max() y_min, y_max = point_2d[:, 1].min(), point_2d[:, 1].max() # cv2.polylines(img, [point_2d], True, (40, 255, 0), 2, cv2.LINE_AA) # points_list = [(p[0], p[1]) for p in point_2d] # for p in points_list: # cv2.circle(img, p, 1, (40, 255, 0), -1) # return img return [x_min, x_max, y_min, y_max, H, W] def calculate_eye(lmk): lmk = lmk.transpose(1,0) leye_upper = lmk[43] leye_lower = lmk[47] leye_left = lmk[45] leye_right = lmk[42] reye_upper = lmk[37] reye_lower = lmk[41] reye_left = lmk[39] reye_right = lmk[36] left_ratio = np.linalg.norm(leye_upper - leye_lower, 2) / np.linalg.norm(leye_left - leye_right, 2) right_ratio = np.linalg.norm(reye_upper - reye_lower, 2) / np.linalg.norm(reye_left - reye_right, 2) return left_ratio, right_ratio ================================================ FILE: extract_init_states/utils/io.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import os import numpy as np import torch import pickle def mkdir(d): os.makedirs(d, exist_ok=True) def _get_suffix(filename): """a.jpg -> jpg""" pos = filename.rfind('.') if pos == -1: return '' return filename[pos + 1:] def _load(fp): suffix = _get_suffix(fp) if suffix == 'npy': return np.load(fp) elif suffix == 'pkl': return pickle.load(open(fp, 'rb')) def _dump(wfp, obj): suffix = _get_suffix(wfp) if suffix == 'npy': np.save(wfp, obj) elif suffix == 'pkl': pickle.dump(obj, open(wfp, 'wb')) else: raise Exception('Unknown Type: {}'.format(suffix)) def _load_tensor(fp, mode='cpu'): if mode.lower() == 'cpu': return torch.from_numpy(_load(fp)) elif mode.lower() == 'gpu': return 
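# A tiny usage sketch of the suffix-dispatching helpers _load/_dump defined above in
# this module; the backup file name is hypothetical.
import numpy as np

bfm = _load('configs/bfm_noneck_v3.pkl')        # .pkl -> pickle.load, .npy -> np.load
keypoints = np.asarray(bfm.get('keypoints'))    # same dict-style access as bfm/bfm.py
_dump('keypoints_backup.npy', keypoints)        # .npy -> np.save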
torch.from_numpy(_load(fp)).cuda() def _tensor_to_cuda(x): if x.is_cuda: return x else: return x.cuda() def _load_gpu(fp): return torch.from_numpy(_load(fp)).cuda() _load_cpu = _load _numpy_to_tensor = lambda x: torch.from_numpy(x) _tensor_to_numpy = lambda x: x.numpy() _numpy_to_cuda = lambda x: _tensor_to_cuda(torch.from_numpy(x)) _cuda_to_tensor = lambda x: x.cpu() _cuda_to_numpy = lambda x: x.cpu().numpy() ================================================ FILE: extract_init_states/utils/onnx.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import sys sys.path.append('..') import torch import models from utils.tddfa_util import load_model def convert_to_onnx(**kvs): # 1. load model size = kvs.get('size', 120) model = getattr(models, kvs.get('arch'))( num_classes=kvs.get('num_params', 62), widen_factor=kvs.get('widen_factor', 1), size=size, mode=kvs.get('mode', 'small') ) checkpoint_fp = kvs.get('checkpoint_fp') model = load_model(model, checkpoint_fp) model.eval() # 2. convert batch_size = 1 dummy_input = torch.randn(batch_size, 3, size, size) wfp = checkpoint_fp.replace('.pth', '.onnx') torch.onnx.export( model, (dummy_input, ), wfp, input_names=['input'], output_names=['output'], do_constant_folding=True ) print(f'Convert {checkpoint_fp} to {wfp} done.') return wfp ================================================ FILE: extract_init_states/utils/pncc.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import sys sys.path.append('..') import cv2 import numpy as np import os.path as osp from Sim3DR import rasterize from utils.functions import plot_image from utils.io import _load, _dump from utils.tddfa_util import _to_ctype make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn) def calc_ncc_code(): from bfm import bfm # formula: ncc_d = ( u_d - min(u_d) ) / ( max(u_d) - min(u_d) ), d = {r, g, b} u = bfm.u u = u.reshape(3, -1, order='F') for i in range(3): u[i] = (u[i] - u[i].min()) / (u[i].max() - u[i].min()) _dump('../configs/ncc_code.npy', u) def pncc(img, ver_lst, tri, show_flag=False, wfp=None, with_bg_flag=True): ncc_code = _load(make_abs_path('../configs/ncc_code.npy')) if with_bg_flag: overlap = img.copy() else: overlap = np.zeros_like(img) # rendering pncc for ver_ in ver_lst: ver = _to_ctype(ver_.T) # transpose overlap = rasterize(ver, tri, ncc_code.T, bg=overlap) # m x 3 if wfp is not None: cv2.imwrite(wfp, overlap) print(f'Save visualization result to {wfp}') if show_flag: plot_image(overlap) return overlap def main(): # `configs/ncc_code.npy` is generated by `calc_nnc_code` function # calc_ncc_code() pass if __name__ == '__main__': main() ================================================ FILE: extract_init_states/utils/pose.py ================================================ # coding: utf-8 """ Reference: https://github.com/YadiraF/PRNet/blob/master/utils/estimate_pose.py Calculating pose from the output 3DMM parameters, you can also try to use solvePnP to perform estimation """ __author__ = 'cleardusk' import cv2 import numpy as np from math import cos, sin, atan2, asin, sqrt from .functions import calc_hypotenuse, plot_image def P2sRt(P): """ decompositing camera matrix P. Args: P: (3, 4). Affine Camera Matrix. Returns: s: scale factor. R: (3, 3). rotation matrix. t2d: (2,). 2d translation. 
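# A hedged call sketch for convert_to_onnx from utils/onnx.py above, filled in with the
# fields of configs/mb1_120x120.yml; the weights file must exist for this to run.
from utils.onnx import convert_to_onnx

onnx_path = convert_to_onnx(arch='mobilenet',
                            widen_factor=1.0,
                            num_params=62,
                            size=120,
                            checkpoint_fp='weights/mb1_120x120.pth')
# Loads the .pth checkpoint, exports the backbone, and returns 'weights/mb1_120x120.onnx'.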
""" t3d = P[:, 3] # shift R1 = P[0:1, :3] R2 = P[1:2, :3] s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0 # r1 = R1 / np.linalg.norm(R1) r2 = R2 / np.linalg.norm(R2) r3 = np.cross(r1, r2) # r1r2,r3 R = np.concatenate((r1, r2, r3), 0) # r 1-3R () return s, R, t3d def matrix2angle(R): """ compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv todo: check and debug Args: R: (3,3). rotation matrix Returns: x: yaw y: pitch z: roll """ if R[2, 0] > 0.998: z = 0 x = np.pi / 2 y = z + atan2(-R[0, 1], -R[0, 2]) elif R[2, 0] < -0.998: z = 0 x = -np.pi / 2 y = -z + atan2(R[0, 1], R[0, 2]) else: x = asin(R[2, 0]) y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x)) z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x)) return x, y, z def angle2matrix(theta): """ compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv todo: check and debug Args: R: (3,3). rotation matrix Returns: x: yaw y: pitch z: roll """ R_x = np.array([[1, 0, 0 ], [0, cos(theta[1]), -sin(theta[1]) ], [0, sin(theta[1]), cos(theta[1]) ] ]) R_y = np.array([[cos(theta[0]), 0, sin(-theta[0]) ], [0, 1, 0 ], [-sin(-theta[0]), 0, cos(theta[0]) ] ]) R_z = np.array([[cos(theta[2]), -sin(theta[2]), 0], [sin(theta[2]), cos(theta[2]), 0], [0, 0, 1] ]) R = np.dot(R_z, np.dot( R_y, R_x )) return R def angle2matrix_3ddfa(angles): ''' get rotation matrix from three rotation angles(radian). The same as in 3DDFA. Args: angles: [3,]. x, y, z angles x: pitch. y: yaw. z: roll. Returns: R: 3x3. rotation matrix. ''' # x, y, z = np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2]) x, y, z = angles[1], angles[0], angles[2] # x Rx=np.array([[1, 0, 0], [0, cos(x), sin(x)], [0, -sin(x), cos(x)]]) # y Ry=np.array([[ cos(y), 0, -sin(y)], [ 0, 1, 0], [sin(y), 0, cos(y)]]) # z Rz=np.array([[cos(z), sin(z), 0], [-sin(z), cos(z), 0], [ 0, 0, 1]]) R = Rx.dot(Ry).dot(Rz) return R.astype(np.float32) def calc_pose(param): P = param[:12].reshape(3, -1) # camera matrix s, R, t3d = P2sRt(P) P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) # without scale pose = matrix2angle(R) pose = [p * 180 / np.pi for p in pose] return P, pose def build_camera_box(rear_size=90): point_3d = [] rear_depth = 0 point_3d.append((-rear_size, -rear_size, rear_depth)) point_3d.append((-rear_size, rear_size, rear_depth)) point_3d.append((rear_size, rear_size, rear_depth)) point_3d.append((rear_size, -rear_size, rear_depth)) point_3d.append((-rear_size, -rear_size, rear_depth)) front_size = int(4 / 3 * rear_size) front_depth = int(4 / 3 * rear_size) point_3d.append((-front_size, -front_size, front_depth)) point_3d.append((-front_size, front_size, front_depth)) point_3d.append((front_size, front_size, front_depth)) point_3d.append((front_size, -front_size, front_depth)) point_3d.append((-front_size, -front_size, front_depth)) point_3d = np.array(point_3d, dtype=np.float32).reshape(-1, 3) return point_3d def plot_pose_box(img, P, ver, color=(40, 255, 0), line_width=2): """ Draw a 3D box as annotation of pose. Ref:https://github.com/yinguobing/head-pose-estimation/blob/master/pose_estimator.py Args: img: the input image P: (3, 4). Affine Camera Matrix. 
kpt: (2, 68) or (3, 68) """ llength = calc_hypotenuse(ver) point_3d = build_camera_box(llength) # Map to 2d image points point_3d_homo = np.hstack((point_3d, np.ones([point_3d.shape[0], 1]))) # n x 4 point_2d = point_3d_homo.dot(P.T)[:, :2] point_2d[:, 1] = - point_2d[:, 1] point_2d[:, :2] = point_2d[:, :2] - np.mean(point_2d[:4, :2], 0) + np.mean(ver[:2, :27], 1) # lmk 0-27 point_2d = np.int32(point_2d.reshape(-1, 2)) # Draw all the lines cv2.polylines(img, [point_2d], True, color, line_width, cv2.LINE_AA) cv2.line(img, tuple(point_2d[1]), tuple( point_2d[6]), color, line_width, cv2.LINE_AA) cv2.line(img, tuple(point_2d[2]), tuple( point_2d[7]), color, line_width, cv2.LINE_AA) cv2.line(img, tuple(point_2d[3]), tuple( point_2d[8]), color, line_width, cv2.LINE_AA) return img def viz_pose(img, param_lst, ver_lst, show_flag=False, wfp=None): for param, ver in zip(param_lst, ver_lst): P, pose = calc_pose(param) img = plot_pose_box(img, P, ver) # print(P[:, :3]) # print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}') if wfp is not None: cv2.imwrite(wfp, img) print(f'Save visualization result to {wfp}') if show_flag: plot_image(img) return img def pose_6(param): P = param[:12].reshape(3, -1) # camera matrix s, R, t3d = P2sRt(P) P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) # without scale pose = matrix2angle(R) # Convert the rotation matrix R to Euler angle form to obtain the pose. # print(t3d) R1 = angle2matrix(pose) # print(R) # print(R1) pose = [p * 180 / np.pi for p in pose] return s, pose, t3d, P # s(scale)、R(roate)、t3d(shift) def smooth_pose(img, param_lst, ver_lst, pose_new, show_flag=False, wfp=None, wnp = None): for param, ver in zip(param_lst, ver_lst): t3d = np.array([pose_new[4],pose_new[5],pose_new[6]]) theta = np.array([pose_new[0],pose_new[1],pose_new[2]]) theta = [p * np.pi / 180 for p in theta] R = angle2matrix(theta) P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) img = plot_pose_box(img, P, ver) # print(P,P.shape,t3d) # print(P,pose_new) # print(f'yaw: {theta[0]:.1f}, pitch: {theta[1]:.1f}, roll: {theta[2]:.1f}') all_pose = [0] all_pose = np.array(all_pose) if wfp is not None: cv2.imwrite(wfp, img) print(f'Save visualization result to {wfp}') if wnp is not None: np.save(wnp, all_pose) print(f'Save visualization result to {wfp}') if show_flag: plot_image(img) return img def get_pose(img, param_lst, ver_lst, show_flag=False, wfp=None, wnp = None): for param, ver in zip(param_lst, ver_lst): # only one loop s, pose, t3d, P = pose_6(param) img_1 = plot_pose_box(img.copy(), P, ver) # print(P,P.shape,t3d) # print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}') all_pose = [pose[0],pose[1],pose[2],s,t3d[0],t3d[1],t3d[2]] all_pose = np.array(all_pose) # if wfp is not None: # cv2.imwrite(wfp, img_1) # print(f'Save visualization result to {wfp}') # if wnp is not None: # np.save(wnp, all_pose) # print(f'Save visualization result to {wfp}') if show_flag: plot_image(img) return all_pose ================================================ FILE: extract_init_states/utils/render.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import sys sys.path.append('..') import cv2 import numpy as np from Sim3DR import RenderPipeline from utils.functions import plot_image from .tddfa_util import _to_ctype cfg = { 'intensity_ambient': 0.3, 'color_ambient': (1, 1, 1), 'intensity_directional': 0.6, 'color_directional': (1, 1, 1), 'intensity_specular': 0.1, 'specular_exp': 5, 'light_pos': (0, 0, 5), 'view_pos': (0, 0, 5) } 
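# Note: this file (utils/render.py), utils/depth.py and utils/pncc.py import Sim3DR,
# which extract_init_states/readme.md states was removed from this repository, so they
# will only import if Sim3DR is installed separately; the init-state extraction demo
# does not depend on them. utils/render_ctypes.py below is the ctypes-based alternative
# that only needs asset/render.so, built by build.sh / build_render_ctypes.sh.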
render_app = RenderPipeline(**cfg) def render(img, ver_lst, tri, alpha=0.6, show_flag=False, wfp=None, with_bg_flag=True): if with_bg_flag: overlap = img.copy() else: overlap = np.zeros_like(img) for ver_ in ver_lst: ver = _to_ctype(ver_.T) # transpose overlap = render_app(ver, tri, overlap) if with_bg_flag: res = cv2.addWeighted(img, 1 - alpha, overlap, alpha, 0) else: res = overlap if wfp is not None: cv2.imwrite(wfp, res) print(f'Save visualization result to {wfp}') if show_flag: plot_image(res) return res ================================================ FILE: extract_init_states/utils/render_ctypes.py ================================================ # coding: utf-8 """ Borrowed from https://github.com/1996scarlet/Dense-Head-Pose-Estimation/blob/main/service/CtypesMeshRender.py To use this render, you should build the clib first: ``` cd utils/asset gcc -shared -Wall -O3 render.c -o render.so -fPIC cd ../.. ``` """ import sys sys.path.append('..') import os.path as osp import cv2 import numpy as np import ctypes from utils.functions import plot_image make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn) class TrianglesMeshRender(object): def __init__( self, clibs, light=(0, 0, 5), direction=(0.6, 0.6, 0.6), ambient=(0.3, 0.3, 0.3) ): if not osp.exists(clibs): raise Exception(f'{clibs} not found, please build it first, by run ' f'"gcc -shared -Wall -O3 render.c -o render.so -fPIC" in utils/asset directory') self._clibs = ctypes.CDLL(clibs) self._light = np.array(light, dtype=np.float32) self._light = np.ctypeslib.as_ctypes(self._light) self._direction = np.array(direction, dtype=np.float32) self._direction = np.ctypeslib.as_ctypes(self._direction) self._ambient = np.array(ambient, dtype=np.float32) self._ambient = np.ctypeslib.as_ctypes(self._ambient) def __call__(self, vertices, triangles, bg): self.triangles = np.ctypeslib.as_ctypes(3 * triangles) # Attention self.tri_nums = triangles.shape[0] self._clibs._render( self.triangles, self.tri_nums, self._light, self._direction, self._ambient, np.ctypeslib.as_ctypes(vertices), vertices.shape[0], np.ctypeslib.as_ctypes(bg), bg.shape[0], bg.shape[1] ) render_app = TrianglesMeshRender(clibs=make_abs_path('asset/render.so')) def render(img, ver_lst, tri, alpha=0.6, show_flag=False, wfp=None, with_bg_flag=True): if with_bg_flag: overlap = img.copy() else: overlap = np.zeros_like(img) for ver_ in ver_lst: ver = np.ascontiguousarray(ver_.T) # transpose render_app(ver, tri, bg=overlap) if with_bg_flag: res = cv2.addWeighted(img, 1 - alpha, overlap, alpha, 0) else: res = overlap if wfp is not None: cv2.imwrite(wfp, res) print(f'Save visualization result to {wfp}') if show_flag: plot_image(res) return res ================================================ FILE: extract_init_states/utils/serialization.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import numpy as np from .tddfa_util import _to_ctype from .functions import get_suffix header_temp = """ply format ascii 1.0 element vertex {} property float x property float y property float z element face {} property list uchar int vertex_indices end_header """ def ser_to_ply_single(ver_lst, tri, height, wfp, reverse=True): suffix = get_suffix(wfp) for i, ver in enumerate(ver_lst): wfp_new = wfp.replace(suffix, f'_{i + 1}{suffix}') n_vertex = ver.shape[1] n_face = tri.shape[0] header = header_temp.format(n_vertex, n_face) with open(wfp_new, 'w') as f: f.write(header + '\n') for i in range(n_vertex): x, y, z = ver[:, i] if reverse: 
f.write(f'{x:.2f} {height-y:.2f} {z:.2f}\n') else: f.write(f'{x:.2f} {y:.2f} {z:.2f}\n') for i in range(n_face): idx1, idx2, idx3 = tri[i] # m x 3 if reverse: f.write(f'3 {idx3} {idx2} {idx1}\n') else: f.write(f'3 {idx1} {idx2} {idx3}\n') print(f'Dump tp {wfp_new}') def ser_to_ply_multiple(ver_lst, tri, height, wfp, reverse=True): n_ply = len(ver_lst) # count ply if n_ply <= 0: return n_vertex = ver_lst[0].shape[1] n_face = tri.shape[0] header = header_temp.format(n_vertex * n_ply, n_face * n_ply) with open(wfp, 'w') as f: f.write(header + '\n') for i in range(n_ply): ver = ver_lst[i] for j in range(n_vertex): x, y, z = ver[:, j] if reverse: f.write(f'{x:.2f} {height - y:.2f} {z:.2f}\n') else: f.write(f'{x:.2f} {y:.2f} {z:.2f}\n') for i in range(n_ply): offset = i * n_vertex for j in range(n_face): idx1, idx2, idx3 = tri[j] # m x 3 if reverse: f.write(f'3 {idx3 + offset} {idx2 + offset} {idx1 + offset}\n') else: f.write(f'3 {idx1 + offset} {idx2 + offset} {idx3 + offset}\n') print(f'Dump tp {wfp}') def get_colors(img, ver): h, w, _ = img.shape ver[0, :] = np.minimum(np.maximum(ver[0, :], 0), w - 1) # x ver[1, :] = np.minimum(np.maximum(ver[1, :], 0), h - 1) # y ind = np.round(ver).astype(np.int32) colors = img[ind[1, :], ind[0, :], :] / 255. # n x 3 return colors.copy() def ser_to_obj_single(img, ver_lst, tri, height, wfp): suffix = get_suffix(wfp) n_face = tri.shape[0] for i, ver in enumerate(ver_lst): colors = get_colors(img, ver) n_vertex = ver.shape[1] wfp_new = wfp.replace(suffix, f'_{i + 1}{suffix}') with open(wfp_new, 'w') as f: for i in range(n_vertex): x, y, z = ver[:, i] f.write( f'v {x:.2f} {height - y:.2f} {z:.2f} {colors[i, 2]:.2f} {colors[i, 1]:.2f} {colors[i, 0]:.2f}\n') for i in range(n_face): idx1, idx2, idx3 = tri[i] # m x 3 f.write(f'f {idx3 + 1} {idx2 + 1} {idx1 + 1}\n') print(f'Dump tp {wfp_new}') def ser_to_obj_multiple(img, ver_lst, tri, height, wfp): n_obj = len(ver_lst) # count obj if n_obj <= 0: return n_vertex = ver_lst[0].shape[1] n_face = tri.shape[0] with open(wfp, 'w') as f: for i in range(n_obj): ver = ver_lst[i] colors = get_colors(img, ver) for j in range(n_vertex): x, y, z = ver[:, j] f.write( f'v {x:.2f} {height - y:.2f} {z:.2f} {colors[j, 2]:.2f} {colors[j, 1]:.2f} {colors[j, 0]:.2f}\n') for i in range(n_obj): offset = i * n_vertex for j in range(n_face): idx1, idx2, idx3 = tri[j] # m x 3 f.write(f'f {idx3 + 1 + offset} {idx2 + 1 + offset} {idx1 + 1 + offset}\n') print(f'Dump tp {wfp}') ser_to_ply = ser_to_ply_multiple # ser_to_ply_single ser_to_obj = ser_to_obj_multiple # ser_to_obj_multiple ================================================ FILE: extract_init_states/utils/tddfa_util.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import sys sys.path.append('..') import argparse import numpy as np import torch def _to_ctype(arr): if not arr.flags.c_contiguous: return arr.copy(order='C') return arr def str2bool(v): if v.lower() in ('yes', 'true', 't', 'y', '1'): return True elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False else: raise argparse.ArgumentTypeError('Boolean value expected') def load_model(model, checkpoint_fp): checkpoint = torch.load(checkpoint_fp, map_location=lambda storage, loc: storage)['state_dict'] model_dict = model.state_dict() # because the model is trained by multiple gpus, prefix module should be removed for k in checkpoint.keys(): kc = k.replace('module.', '') if kc in model_dict.keys(): model_dict[kc] = checkpoint[k] if kc in ['fc_param.bias', 'fc_param.weight']: 
model_dict[kc.replace('_param', '')] = checkpoint[k] model.load_state_dict(model_dict) return model class ToTensorGjz(object): def __call__(self, pic): if isinstance(pic, np.ndarray): img = torch.from_numpy(pic.transpose((2, 0, 1))) return img.float() def __repr__(self): return self.__class__.__name__ + '()' class NormalizeGjz(object): def __init__(self, mean, std): self.mean = mean self.std = std def __call__(self, tensor): tensor.sub_(self.mean).div_(self.std) return tensor def similar_transform(pts3d, roi_box, size): pts3d[0, :] -= 1 # for Python compatibility pts3d[2, :] -= 1 pts3d[1, :] = size - pts3d[1, :] sx, sy, ex, ey = roi_box scale_x = (ex - sx) / size scale_y = (ey - sy) / size pts3d[0, :] = pts3d[0, :] * scale_x + sx pts3d[1, :] = pts3d[1, :] * scale_y + sy s = (scale_x + scale_y) / 2 pts3d[2, :] *= s pts3d[2, :] -= np.min(pts3d[2, :]) return np.array(pts3d, dtype=np.float32) def _parse_param(param): """matrix pose form param: shape=(trans_dim+shape_dim+exp_dim,), i.e., 62 = 12 + 40 + 10 """ # pre-defined templates for parameter n = param.shape[0] if n == 62: trans_dim, shape_dim, exp_dim = 12, 40, 10 elif n == 72: trans_dim, shape_dim, exp_dim = 12, 40, 20 elif n == 141: trans_dim, shape_dim, exp_dim = 12, 100, 29 else: raise Exception(f'Undefined templated param parsing rule') R_ = param[:trans_dim].reshape(3, -1) R = R_[:, :3] offset = R_[:, -1].reshape(3, 1) alpha_shp = param[trans_dim:trans_dim + shape_dim].reshape(-1, 1) alpha_exp = param[trans_dim + shape_dim:].reshape(-1, 1) return R, offset, alpha_shp, alpha_exp ================================================ FILE: extract_init_states/utils/uv.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import sys sys.path.append('..') import cv2 import numpy as np import os.path as osp import scipy.io as sio from Sim3DR import rasterize from utils.functions import plot_image from utils.io import _load from utils.tddfa_util import _to_ctype make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn) def load_uv_coords(fp): C = sio.loadmat(fp) uv_coords = C['UV'].copy(order='C').astype(np.float32) return uv_coords def process_uv(uv_coords, uv_h=256, uv_w=256): uv_coords[:, 0] = uv_coords[:, 0] * (uv_w - 1) uv_coords[:, 1] = uv_coords[:, 1] * (uv_h - 1) uv_coords[:, 1] = uv_h - uv_coords[:, 1] - 1 uv_coords = np.hstack((uv_coords, np.zeros((uv_coords.shape[0], 1), dtype=np.float32))) # add z return uv_coords g_uv_coords = load_uv_coords(make_abs_path('../configs/BFM_UV.mat')) indices = _load(make_abs_path('../configs/indices.npy')) # todo: handle bfm_slim g_uv_coords = g_uv_coords[indices, :] def get_colors(img, ver): # nearest-neighbor sampling [h, w, _] = img.shape ver[0, :] = np.minimum(np.maximum(ver[0, :], 0), w - 1) # x ver[1, :] = np.minimum(np.maximum(ver[1, :], 0), h - 1) # y ind = np.round(ver).astype(np.int32) colors = img[ind[1, :], ind[0, :], :] # n x 3 return colors def bilinear_interpolate(img, x, y): """ https://stackoverflow.com/questions/12729228/simple-efficient-bilinear-interpolation-of-images-in-numpy-and-python """ x0 = np.floor(x).astype(np.int32) x1 = x0 + 1 y0 = np.floor(y).astype(np.int32) y1 = y0 + 1 x0 = np.clip(x0, 0, img.shape[1] - 1) x1 = np.clip(x1, 0, img.shape[1] - 1) y0 = np.clip(y0, 0, img.shape[0] - 1) y1 = np.clip(y1, 0, img.shape[0] - 1) i_a = img[y0, x0] i_b = img[y1, x0] i_c = img[y0, x1] i_d = img[y1, x1] wa = (x1 - x) * (y1 - y) wb = (x1 - x) * (y - y0) wc = (x - x0) * (y1 - y) wd = (x - x0) * (y - y0) return wa[..., np.newaxis] 
* i_a + wb[..., np.newaxis] * i_b + wc[..., np.newaxis] * i_c + wd[..., np.newaxis] * i_d def uv_tex(img, ver_lst, tri, uv_h=256, uv_w=256, uv_c=3, show_flag=False, wfp=None): uv_coords = process_uv(g_uv_coords.copy(), uv_h=uv_h, uv_w=uv_w) res_lst = [] for ver_ in ver_lst: ver = _to_ctype(ver_.T) # transpose to m x 3 colors = bilinear_interpolate(img, ver[:, 0], ver[:, 1]) / 255. # `rasterize` here serves as texture sampling, may need to optimization res = rasterize(uv_coords, tri, colors, height=uv_h, width=uv_w, channel=uv_c) res_lst.append(res) # concat if there more than one image res = np.concatenate(res_lst, axis=1) if len(res_lst) > 1 else res_lst[0] if wfp is not None: cv2.imwrite(wfp, res) print(f'Save visualization result to {wfp}') if show_flag: plot_image(res) return res ================================================ FILE: extract_init_states/weights/.gitignore ================================================ # checkpoints/ # *.pth # *.onnx ================================================ FILE: extract_init_states/weights/mb1_120x120.onnx ================================================ [File too large to display: 12.4 MB] ================================================ FILE: extract_init_states/weights/mb1_120x120.pth ================================================ [File too large to display: 13.1 MB] ================================================ FILE: extract_init_states/weights/readme.md ================================================ ## Pre-converted onnx model | Model | Link | | :-: | :-: | | `mb1_120x120.onnx` | [Google Drive](https://drive.google.com/file/d/1YpO1KfXvJHRmCBkErNa62dHm-CUjsoIk/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1qpQBd5KOS0-5lD6jZKXZ-Q) (Password: cqbx) | | `mb05_120x120.onnx` | [Google Drive](https://drive.google.com/file/d/1orJFiZPshmp7jmCx_D0tvIEtPYtnFvHS/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1sRaBOA5wHu6PFS1Qd-TBFA) (Password: 8qst) | | `resnet22.onnx` | [Google Drive](https://drive.google.com/file/d/1rRyrd7Ar-QYTi1hRHOYHspT8PTyXQ5ds/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1Nzkw7Ie_5trKvi1JYxymJA) (Password: 1op6) | | `resnet22.pth` | [Google Drive](https://drive.google.com/file/d/1dh7JZgkj1IaO4ZcSuBOBZl2suT9EPedV/view?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1IS7ncVxhw0f955ySg67Y4A) (Password: lv1a) | ================================================ FILE: filter_fourier.py ================================================ import torch import torch.fft import torchvision.transforms as transforms import cv2 import numpy as np import matplotlib.pyplot as plt import cv2 import numpy as np # Filtering function: Input optical flow field to filter out high-frequency noise. def gaussian_pdf(x, mean, std): return (1 / (std * torch.sqrt(2 * torch.tensor(3.141592653589793))) * torch.exp(-((x - mean) ** 2) / (2 * std ** 2))) def gaussian_density(length = 20, amplitude = 2, mean = 19, sigma = 3): x = torch.arange(0, length, 1.0) gaussian = amplitude * torch.exp(-(x - mean)**2 / (2 * sigma**2)) gaussian = torch.clip(gaussian, max = 1, min = 0) return gaussian.cuda() def fourier_filter(fea): L, C , H , W = fea.shape mean = 0 std = 3 _x = torch.linspace(-10, 10, H) # Define 128 values within the range of -5 to 5. X, Y = torch.meshgrid(_x, _x) # Generate grid coordinates. 
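# The two 1-D Gaussians evaluated on X and Y below are multiplied into a 2-D bell centred on
# the grid. After fftshift moves the DC component of the spectrum to the centre, this bell acts
# as a soft low-pass mask: frequencies near the centre pass through while high frequencies far
# from the centre are attenuated. The clip((map / max) * 3, 0, 1) step flattens the top of the
# bell so the lowest frequencies are kept exactly and only the tails roll off.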
gaussian_map = (gaussian_pdf(X, mean, std).cuda()) * (gaussian_pdf(Y, mean, std).cuda()) gaussian_map = gaussian_map.unsqueeze(0).repeat(1, C, 1, 1) gaussian_map = torch.clip((gaussian_map)/gaussian_map.max() * 3 , min = 0, max = 1) # lowpass_filter = torch.zeros(H,H).cuda() # for i in range(H): # for j in range(H): # if np.sqrt((i - H//2)**2 + (j - H//2)**2) <= 10: # lowpass_filter[i, j] = 1 x = torch.fft.fft2(fea, dim=(-2, -1)) x_shifted = torch.fft.fftshift(x) # 1,3,128,128 x_shifted = x_shifted * gaussian_map # lowpass_filter # * gaussian_map reconstructed_x = torch.fft.ifftshift(x_shifted) reconstructed_x = torch.fft.ifft2(reconstructed_x, dim=(-2, -1)) reconstructed_x = torch.real(reconstructed_x) return reconstructed_x def fourier_filter_1D(fea, dim): # idex = freq * L / 25 L, C , H , W = fea.shape mean = 0 std = 3 fft_result = torch.fft.rfft(fea, dim=dim) # low-pass filtering cutoff_freq = 10 # keep only the first 10 frequencies # mask = gaussian_density(length = L, mean = 0, sigma = 5, amplitude = 2)[:, None, None, None] # fft_result = mask * fft_result fft_result[L//4:] = 0 # zero out the high-frequency part # apply the inverse Fourier transform along the chosen dimension filtered_tensor = torch.fft.irfft(fft_result, n=L, dim=dim) filtered_tensor = torch.real(filtered_tensor) return filtered_tensor def hf_loss(fea, mask, dim): mask = 1 - mask # gaussian_density(length = L, mean = 0, sigma = 12, amplitude = 2) fft_result = torch.fft.rfft(fea, dim=dim) fft_result = fft_result * mask fft_result = fft_result.abs() return fft_result def hf_loss_2(fea_x, fea_y, dim): ''' Compute a frequency-domain loss against the ground truth (GT). ''' fft_result_x = torch.fft.rfft(fea_x, dim=dim) fft_result_y = torch.fft.rfft(fea_y, dim=dim) # fft_result = fft_result.abs() loss = (fft_result_y - fft_result_x).abs() return loss class KalmanFilter1D: def __init__(self, A, H, Q, R, x_init, P_init): self.A = torch.tensor(A, requires_grad=False) self.H = torch.tensor(H, requires_grad=False) self.Q = torch.tensor(Q, requires_grad=False) self.R = torch.tensor(R, requires_grad=False) self.x = torch.tensor(x_init, requires_grad=True) self.P = torch.tensor(P_init, requires_grad=True) def update(self, z): # prediction step x_pred = self.A * self.x P_pred = self.A * self.P * self.A + self.Q # update step K = P_pred * self.H / (self.H * P_pred * self.H + self.R) self.x = x_pred + K * (z - self.H * x_pred) self.P = (1 - K * self.H) * P_pred return self.x def kalman_filter(observations, dim): kf = KalmanFilter1D(A=1., H=1., Q=0.01, R=0.1, x_init=0., P_init=1.) filtered_values = torch.zeros_like(observations) for idx in range(observations.size(dim)): obs_slice = tuple(slice(None) if i != dim else idx for i in range(len(observations.size()))) obs = observations[obs_slice] filtered_value = kf.update(obs) filtered_values[obs_slice] = filtered_value return filtered_values def naive_filter(fea): L, C , H , W = fea.shape fea_mask = fea.abs() > (1/64) fea = fea * fea_mask return fea # def fourier_filter(x): # L, C , H , W = x.shape # mean = 0 # std = 3 # _x = torch.linspace(-5, 5, H) # define 128 values in the range -5 to 5 # X, Y = torch.meshgrid(_x, _x) # Generate grid coordinates.
# gaussian_map = (gaussian_pdf(X, mean, std).cuda()) * (gaussian_pdf(Y, mean, std).cuda()) # gaussian_map = gaussian_map.unsqueeze(0).repeat(1, C, 1, 1) # gaussian_map = (gaussian_map)/gaussian_map.max() # x = torch.fft.fft2(x, dim=(-2, -1)) # x_shifted = torch.fft.fftshift(x) # 1,3,128,128 # x_shifted = x_shifted # * gaussian_map # reconstructed_x = torch.fft.ifftshift(x_shifted) # reconstructed_x = torch.fft.ifft2(reconstructed_x, dim=(-2, -1)) # reconstructed_x = torch.abs(reconstructed_x) # return reconstructed_x if __name__ == '__main__': # read the video gd = gaussian_density(length = 20, mean = 0, sigma = 5, amplitude = 2) print(gd) print(gd[:10]) # cap = cv2.VideoCapture('your_path/demo/s2_20w_newae_crema_s1_10_s2_11-j-sl-vr-of-tr-rmm-ddim0200_1.00/7_s76_1076_ITH_FEA_XX.mp4') # # build a Gaussian probability-density tensor with mean 0 and std 3 # mean = 0 # std = 3 # x = torch.linspace(-5, 5, 128) # define 128 values in the range -5 to 5 # X, Y = torch.meshgrid(x, x) # Generate grid coordinates. # gaussian_map = gaussian_pdf(X, mean, std) * gaussian_pdf(Y, mean, std) # gaussian_map = gaussian_map.unsqueeze(0).repeat(1, 3, 1, 1) # gaussian_map = ( gaussian_map)/gaussian_map.max() # # input data: assume frames is a numpy array holding L RGB frames with shape (L, 3, H, W) # frames = np.random.randint(0, 255, (100, 3, 256, 256)).astype(np.uint8) # # set the output video name, frame rate and resolution # def generate_video(frames): # video_name = 'output_video.avi' # fps = 25 # resolution = (128, 128) # # create the video writer # fourcc = cv2.VideoWriter_fourcc(*'XVID') # video = cv2.VideoWriter(video_name, fourcc, fps, resolution) # # write the frames to the video one by one # for i in range(frames.shape[0]): # frame = frames[i][:,:,:].transpose(1, 2, 0).astype(np.uint8) # reorder channels to (H, W, 3) # video.write(frame) # # release resources and save the video # video.release() # # store the reconstructed frames # reconstructed_frames = [] # # loop over every frame of the video # while(cap.isOpened()): # ret, frame = cap.read() # if not ret: # break # # convert the current frame to a PyTorch tensor # frame = torch.tensor(frame) # frame = frame.permute(2, 0, 1).unsqueeze(0).float() # # apply a 2D Fourier transform to the current frame # fft_frame = torch.fft.fft2(frame, dim=(-2, -1)) # fft_frame_shifted = torch.fft.fftshift(fft_frame) # 1,3,128,128 # # map the frequency-domain representation back to an image # # fft_frame_shifted = fft_frame_shifted * gaussian_map # reconstructed_frame = torch.fft.ifftshift(fft_frame_shifted) # reconstructed_frame = torch.fft.ifft2(reconstructed_frame, dim=(-2, -1)) # reconstructed_frame = torch.abs(reconstructed_frame) # # append the reconstructed frame to the list # reconstructed_frames.append(reconstructed_frame) # # concatenate the reconstructed frames into one tensor # reconstructed_frames = torch.cat(reconstructed_frames, dim=0) # # convert the reconstructed frames to a numpy array # reconstructed_frames = (reconstructed_frames).to(torch.int32) # reconstructed_frames = reconstructed_frames.squeeze(1).numpy() # # write out the reconstructed video # generate_video(reconstructed_frames) ================================================ FILE: hubert_extract/data_gen/process_lrs3/binarizer.py ================================================ import os import numpy as np from scipy.misc import face import torch from tqdm import trange import pickle from copy import deepcopy from data_util.face3d_helper import Face3DHelper from utils.commons.indexed_datasets import IndexedDataset, IndexedDatasetBuilder def load_video_npy(fn): assert fn.endswith(".npy") ret_dict = np.load(fn, allow_pickle=True).item() video_dict = { 'coeff': ret_dict['coeff'], # [T, h] 'lm68': ret_dict['lm68'], # [T, 68, 2] 'lm5': ret_dict['lm5'], # [T, 5, 2] } return video_dict def cal_lm3d_in_video_dict(video_dict, face3d_helper): coeff = torch.from_numpy(video_dict['coeff']).float() identity = coeff[:, 0:80] exp = coeff[:, 80:144] idexp_lm3d =
face3d_helper.reconstruct_idexp_lm3d(identity, exp).cpu().numpy() video_dict['idexp_lm3d'] = idexp_lm3d def load_audio_npy(fn): assert fn.endswith(".npy") ret_dict = np.load(fn,allow_pickle=True).item() audio_dict = { "mel": ret_dict['mel'], # [T, 80] "f0": ret_dict['f0'], # [T,1] } return audio_dict if __name__ == '__main__': face3d_helper = Face3DHelper(use_gpu=False) import glob,tqdm prefixs = ['val', 'train'] binarized_ds_path = "data/binary/lrs3" os.makedirs(binarized_ds_path, exist_ok=True) for prefix in prefixs: databuilder = IndexedDatasetBuilder(os.path.join(binarized_ds_path, prefix), gzip=False) raw_base_dir = '/home/yezhenhui/datasets/raw/lrs3_raw' spk_ids = sorted([dir_name.split("/")[-1] for dir_name in glob.glob(raw_base_dir + "/*")]) spk_id2spk_idx = {spk_id : i for i,spk_id in enumerate(spk_ids) } np.save(os.path.join(binarized_ds_path, "spk_id2spk_idx.npy"), spk_id2spk_idx, allow_pickle=True) mp4_names = glob.glob(raw_base_dir + "/*/*.mp4") cnt = 0 for i, mp4_name in tqdm.tqdm(enumerate(mp4_names), total=len(mp4_names)): if prefix == 'train': if i % 100 == 0: continue else: if i % 100 != 0: continue lst = mp4_name.split("/") spk_id = lst[-2] clip_id = lst[-1][:-4] audio_npy_name = os.path.join(raw_base_dir, spk_id, clip_id+"_audio.npy") hubert_npy_name = os.path.join(raw_base_dir, spk_id, clip_id+"_hubert.npy") video_npy_name = os.path.join(raw_base_dir, spk_id, clip_id+"_coeff_pt.npy") if (not os.path.exists(audio_npy_name)) or (not os.path.exists(video_npy_name)): print(f"Skip item for not found.") continue if (not os.path.exists(hubert_npy_name)): print(f"Skip item for hubert_npy not found.") continue audio_dict = load_audio_npy(audio_npy_name) hubert = np.load(hubert_npy_name) video_dict = load_video_npy(video_npy_name) cal_lm3d_in_video_dict(video_dict, face3d_helper) mel = audio_dict['mel'] if mel.shape[0] < 64: # the video is shorter than 0.6s print(f"Skip item for too short.") continue audio_dict.update(video_dict) audio_dict['spk_id'] = spk_id audio_dict['spk_idx'] = spk_id2spk_idx[spk_id] audio_dict['item_id'] = spk_id + "_" + clip_id audio_dict['hubert'] = hubert # [T_x, hid=1024] databuilder.add_item(audio_dict) cnt += 1 databuilder.finalize() print(f"{prefix} set has {cnt} samples!") ================================================ FILE: hubert_extract/data_gen/process_lrs3/process_audio_hubert.py ================================================ from genericpath import exists from transformers import Wav2Vec2Processor, HubertModel import soundfile as sf import numpy as np import torch from tqdm import tqdm import fairseq print("Loading the Wav2Vec2 Processor...") # wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") wav2vec2_processor = Wav2Vec2Processor.from_pretrained("/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp") print("Loading the HuBERT Model...") hubert_model = HubertModel.from_pretrained("/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp", from_tf = True) def get_hubert_from_16k_wav(wav_16k_name): speech_16k, _ = sf.read(wav_16k_name) hubert = get_hubert_from_16k_speech(speech_16k) return hubert @torch.no_grad() def get_hubert_from_16k_speech(speech, device="cuda:1"): global hubert_model hubert_model = hubert_model.to(device) if speech.ndim ==2: speech = speech[:, 0] # [T, 2] ==> [T,] input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T] input_values_all = input_values_all.to(device) # 
For long audio sequence, due to the memory limitation, we cannot process them in one run # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320 # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step. # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320 # We have the equation to calculate out time step: T = floor((t-k)/s) # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N kernel = 400 stride = 320 clip_length = stride * 1000 num_iter = input_values_all.shape[1] // clip_length expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride res_lst = [] for i in range(num_iter): if i == 0: start_idx = 0 end_idx = clip_length - stride + kernel else: start_idx = clip_length * i end_idx = start_idx + (clip_length - stride + kernel) input_values = input_values_all[:, start_idx: end_idx] hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] res_lst.append(hidden_states[0]) if num_iter > 0: input_values = input_values_all[:, clip_length * num_iter:] else: input_values = input_values_all # if input_values.shape[1] != 0: if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] res_lst.append(hidden_states[0]) ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024] # assert ret.shape[0] == expected_T assert abs(ret.shape[0] - expected_T) <= 1 if ret.shape[0] < expected_T: ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0])) else: ret = ret[:expected_T] return ret if __name__ == '__main__': ## Process Single Long Audio for NeRF dataset # person_id = 'May' import os # demo test # wav_16k_name = f"/train20/intern/permanent/lmlin2/data/audio_lesson_01-j-w.wav" # npy_name = 'demo_test-j-w' # demo_npy_name = f"/train20/intern/permanent/lmlin2/data/{npy_name}.npy" # speech_16k, _ = sf.read(wav_16k_name) # hubert_hidden = get_hubert_from_16k_speech(speech_16k) # np.save(demo_npy_name, hubert_hidden.detach().numpy()) # hdtf dataset image_path = '/train20/intern/permanent/lmlin2/data/hdtf_image_50hz' image_path = '/yrfs2/cv2/pcxia/audiovisual/hdtf/images_25hz' wav_path = '/yrfs2/cv2/pcxia/audiovisual/hdtf/image_audio' for wavfile in os.listdir(wav_path): frames = os.listdir(os.path.join(image_path, wavfile[0:-4])) frames.sort() num_frames = len(frames) wav_16k_name = os.path.join(wav_path, wavfile) # wav_16k_name = f"/yrfs2/cv2/pcxia/audiovisual/hdtf/image_audio/RD_Radio1_000.wav" #(3749, 1024) # wav_16k_name = f"data/processed/videos/{person_id}/aud.wav" # wav_16k_name = f"/train20/intern/permanent/lmlin2/Flow/GeneFace-main/data/raw/val_wavs/zozo.wav" # 543 1024 # wav_16k_name = f"/train20/intern/permanent/lmlin2/data/audio_lesson_01.wav" npy_name = wavfile[0:-4] hubert_npy_name = f"/train20/intern/permanent/lmlin2/data/hdtf_wav_hubert/{npy_name}.npy" speech_16k, _ = sf.read(wav_16k_name) hubert_hidden = get_hubert_from_16k_speech(speech_16k) print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}') np.save(hubert_npy_name, hubert_hidden.detach().numpy()) # crema dataset # image_path = '/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images' # wav_path = '/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/audio' # save_path = 
'/train20/intern/permanent/lmlin2/data/crema_wav_hubert' # for id_name in tqdm(os.listdir(wav_path)): # for wavfile in os.listdir(os.path.join(wav_path, id_name)): # frame_dir = os.path.join(image_path, id_name, wavfile[0:-4]) # if not exists(frame_dir): # print(f'{frame_dir} does not exist!') # continue # frames= os.listdir(frame_dir) # frames.sort() # num_frames = len(frames) # wav_16k_name = os.path.join(wav_path, id_name, wavfile) # npy_name = wavfile[0:-4] # save_dir = os.path.join(save_path,id_name) # hubert_npy_name = os.path.join(save_dir,npy_name+'.npy') # if exists(hubert_npy_name): # print(f'{hubert_npy_name} exists!') # continue # if not exists(save_dir): # os.makedirs(save_dir) # speech_16k, _ = sf.read(wav_16k_name) # hubert_hidden = get_hubert_from_16k_speech(speech_16k) # print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}') # np.save(hubert_npy_name, hubert_hidden.detach().numpy()) # ## Process short audio clips for LRS3 dataset # import glob, os, tqdm # lrs3_dir = '/home/yezhenhui/datasets/raw/lrs3_raw/' # wav_16k_names = glob.glob(os.path.join(lrs3_dir, '*/*.wav')) # for wav_16k_name in tqdm.tqdm(wav_16k_names, total=len(wav_16k_names)): # spk_id = wav_16k_name.split("/")[-2] # clip_id = wav_16k_name.split("/")[-1][:-4] # out_name = os.path.join(lrs3_dir, spk_id, clip_id+'_hubert.npy') # if os.path.exists(out_name): # continue # speech_16k, _ = sf.read(wav_16k_name) # hubert_hidden = get_hubert_from_16k_speech(speech_16k) # np.save(out_name, hubert_hidden.detach().numpy()) ================================================ FILE: hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate.py ================================================ from genericpath import exists from transformers import Wav2Vec2Processor, HubertModel import soundfile as sf import numpy as np import torch from tqdm import tqdm import fairseq import decord from scipy.interpolate import interp1d print("Loading the Wav2Vec2 Processor...") # wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") wav2vec2_processor = Wav2Vec2Processor.from_pretrained("/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp") print("Loading the HuBERT Model...") hubert_model = HubertModel.from_pretrained("/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp", from_tf = True) def get_hubert_from_16k_wav(wav_16k_name): speech_16k, _ = sf.read(wav_16k_name) hubert = get_hubert_from_16k_speech(speech_16k) return hubert @torch.no_grad() def get_hubert_from_16k_speech(speech, device="cuda:1"): global hubert_model hubert_model = hubert_model.to(device) if speech.ndim ==2: speech = speech[:, 0] # [T, 2] ==> [T,] input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T] input_values_all = input_values_all.to(device) # For long audio sequence, due to the memory limitation, we cannot process them in one run # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320 # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step. 
# So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320 # We have the equation to calculate out time step: T = floor((t-k)/s) # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N kernel = 400 stride = 320 clip_length = stride * 1000 num_iter = input_values_all.shape[1] // clip_length expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride res_lst = [] for i in range(num_iter): if i == 0: start_idx = 0 end_idx = clip_length - stride + kernel else: start_idx = clip_length * i end_idx = start_idx + (clip_length - stride + kernel) input_values = input_values_all[:, start_idx: end_idx] hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] res_lst.append(hidden_states[0]) if num_iter > 0: input_values = input_values_all[:, clip_length * num_iter:] else: input_values = input_values_all # if input_values.shape[1] != 0: if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] res_lst.append(hidden_states[0]) ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024] # assert ret.shape[0] == expected_T assert abs(ret.shape[0] - expected_T) <= 1 if ret.shape[0] < expected_T: ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0])) else: ret = ret[:expected_T] return ret # decord.bridge.set_bridge('torch') from tqdm import tqdm if __name__ == '__main__': ## Process Single Long Audio for NeRF dataset # person_id = 'May' import os # demo test # wav_16k_name = f"/train20/intern/permanent/lmlin2/data/audio_lesson_01-j-w.wav" # npy_name = 'demo_test-j-w' # demo_npy_name = f"/train20/intern/permanent/lmlin2/data/{npy_name}.npy" # speech_16k, _ = sf.read(wav_16k_name) # hubert_hidden = get_hubert_from_16k_speech(speech_16k) # np.save(demo_npy_name, hubert_hidden.detach().numpy()) # hdtf dataset # image_path = '/train20/intern/permanent/lmlin2/data/hdtf_image_50hz' # video_path = '/train20/intern/permanent/hbcheng2/data/HDTF/video_25hz' # wav_path = '/yrfs2/cv2/pcxia/audiovisual/hdtf/image_audio' # save_path = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert' # interpolate_path = "/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate" # if not os.path.exists(interpolate_path): # os.makedirs(interpolate_path) # if not os.path.exists(save_path): # os.makedirs(save_path) # for wavfile in tqdm(os.listdir(wav_path)): # video = os.path.join(video_path, wavfile[0:-4] + '.mp4') # vr = decord.VideoReader(video) # num_frames = len(vr) # wav_16k_name = os.path.join(wav_path, wavfile) # # wav_16k_name = f"/yrfs2/cv2/pcxia/audiovisual/hdtf/image_audio/RD_Radio1_000.wav" #(3749, 1024) # # wav_16k_name = f"data/processed/videos/{person_id}/aud.wav" # # wav_16k_name = f"/train20/intern/permanent/lmlin2/Flow/GeneFace-main/data/raw/val_wavs/zozo.wav" # 543 1024 # # wav_16k_name = f"/train20/intern/permanent/lmlin2/data/audio_lesson_01.wav" # npy_name = wavfile[0:-4] # hubert_npy_name = f"{save_path}/{npy_name}.npy" # hubert_npy_name_interpolate = f"{interpolate_path}/{npy_name}.npy" # if os.path.exists(hubert_npy_name) and os.path.exists(hubert_npy_name_interpolate): # continue # speech_16k, _ = sf.read(wav_16k_name) # hubert_hidden = get_hubert_from_16k_speech(speech_16k) # print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}') # 
hubert_hidden = hubert_hidden.detach().numpy() # interp_func = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0) # hubert_feature_interpolated = interp_func(np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)).astype(np.float32) # # torch.nn.functional.interpolate(hubert_hidden.unsqueeze(0).permute(0,2,1).cuda(), size=num_frames, mode='linear', align_corners=False).squeeze(0).permute(1, 0).cpu() # np.save(hubert_npy_name, hubert_hidden) # np.save(hubert_npy_name_interpolate, hubert_feature_interpolated) # crema dataset image_path = '/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images' wav_path = '/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/audio' save_path = '/train20/intern/permanent/hbcheng2/data/crema/hubert_25hz' for id_name in tqdm(os.listdir(wav_path)): for wavfile in os.listdir(os.path.join(wav_path, id_name)): frame_dir = os.path.join(image_path, id_name, wavfile[0:-4]) if not exists(frame_dir): print(f'{frame_dir} does not exist!') continue frames= os.listdir(frame_dir) frames.sort() num_frames = len(frames) wav_16k_name = os.path.join(wav_path, id_name, wavfile) npy_name = wavfile[0:-4] save_dir = os.path.join(save_path,id_name) hubert_npy_name = os.path.join(save_dir,npy_name+'.npy') if exists(hubert_npy_name): print(f'{hubert_npy_name} exists!') continue if not exists(save_dir): os.makedirs(save_dir) speech_16k, _ = sf.read(wav_16k_name) hubert_hidden = get_hubert_from_16k_speech(speech_16k) hubert_hidden = hubert_hidden.detach().numpy() interp_func = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0) hubert_feature_interpolated = interp_func(np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)).astype(np.float32) print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}') np.save(hubert_npy_name, hubert_feature_interpolated) # ## Process short audio clips for LRS3 dataset # import glob, os, tqdm # lrs3_dir = '/home/yezhenhui/datasets/raw/lrs3_raw/' # wav_16k_names = glob.glob(os.path.join(lrs3_dir, '*/*.wav')) # for wav_16k_name in tqdm.tqdm(wav_16k_names, total=len(wav_16k_names)): # spk_id = wav_16k_name.split("/")[-2] # clip_id = wav_16k_name.split("/")[-1][:-4] # out_name = os.path.join(lrs3_dir, spk_id, clip_id+'_hubert.npy') # if os.path.exists(out_name): # continue # speech_16k, _ = sf.read(wav_16k_name) # hubert_hidden = get_hubert_from_16k_speech(speech_16k) # np.save(out_name, hubert_hidden.detach().numpy()) ================================================ FILE: hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_batch.py ================================================ from genericpath import exists from transformers import Wav2Vec2Processor, HubertModel import soundfile as sf import numpy as np import torch from tqdm import tqdm import fairseq import decord from scipy.interpolate import interp1d print("Loading the Wav2Vec2 Processor...") # wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") wav2vec2_processor = Wav2Vec2Processor.from_pretrained("/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp") print("Loading the HuBERT Model...") hubert_model = HubertModel.from_pretrained("/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp", from_tf = True) def get_hubert_from_16k_wav(wav_16k_name): speech_16k, _ = sf.read(wav_16k_name) hubert = get_hubert_from_16k_speech(speech_16k) return hubert @torch.no_grad() def 
get_hubert_from_16k_speech(speech, device="cuda:3"): global hubert_model hubert_model = hubert_model.to(device) if speech.ndim ==2: speech = speech[:, 0] # [T, 2] ==> [T,] input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T] input_values_all = input_values_all.to(device) # For long audio sequence, due to the memory limitation, we cannot process them in one run # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320 # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step. # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320 # We have the equation to calculate out time step: T = floor((t-k)/s) # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N kernel = 400 stride = 320 clip_length = stride * 1000 num_iter = input_values_all.shape[1] // clip_length expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride res_lst = [] for i in range(num_iter): if i == 0: start_idx = 0 end_idx = clip_length - stride + kernel else: start_idx = clip_length * i end_idx = start_idx + (clip_length - stride + kernel) input_values = input_values_all[:, start_idx: end_idx] hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] res_lst.append(hidden_states[0]) if num_iter > 0: input_values = input_values_all[:, clip_length * num_iter:] else: input_values = input_values_all # if input_values.shape[1] != 0: if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] res_lst.append(hidden_states[0]) ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024] # assert ret.shape[0] == expected_T assert abs(ret.shape[0] - expected_T) <= 1 if ret.shape[0] < expected_T: ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0])) else: ret = ret[:expected_T] return ret # decord.bridge.set_bridge('torch') from tqdm import tqdm if __name__ == '__main__': ## Process Single Long Audio for NeRF dataset # person_id = 'May' import os # demo test # wav_16k_name = f"/train20/intern/permanent/lmlin2/data/audio_lesson_01-j-w.wav" # npy_name = 'demo_test-j-w' # demo_npy_name = f"/train20/intern/permanent/lmlin2/data/{npy_name}.npy" # speech_16k, _ = sf.read(wav_16k_name) # hubert_hidden = get_hubert_from_16k_speech(speech_16k) # np.save(demo_npy_name, hubert_hidden.detach().numpy()) # hdtf dataset # image_path = '/train20/intern/permanent/lmlin2/data/hdtf_image_50hz' # video_path = '/train20/intern/permanent/hbcheng2/data/HDTF/video_25hz' # wav_path = '/yrfs2/cv2/pcxia/audiovisual/hdtf/image_audio' # save_path = '/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert' # interpolate_path = "/train20/intern/permanent/hbcheng2/data/HDTF/hdtf_wav_hubert_interpolate" # if not os.path.exists(interpolate_path): # os.makedirs(interpolate_path) # if not os.path.exists(save_path): # os.makedirs(save_path) # for wavfile in tqdm(os.listdir(wav_path)): # video = os.path.join(video_path, wavfile[0:-4] + '.mp4') # vr = decord.VideoReader(video) # num_frames = len(vr) # wav_16k_name = os.path.join(wav_path, wavfile) # # wav_16k_name = f"/yrfs2/cv2/pcxia/audiovisual/hdtf/image_audio/RD_Radio1_000.wav" #(3749, 1024) # # wav_16k_name = f"data/processed/videos/{person_id}/aud.wav" # # 
wav_16k_name = f"/train20/intern/permanent/lmlin2/Flow/GeneFace-main/data/raw/val_wavs/zozo.wav" # 543 1024 # # wav_16k_name = f"/train20/intern/permanent/lmlin2/data/audio_lesson_01.wav" # npy_name = wavfile[0:-4] # hubert_npy_name = f"{save_path}/{npy_name}.npy" # hubert_npy_name_interpolate = f"{interpolate_path}/{npy_name}.npy" # if os.path.exists(hubert_npy_name) and os.path.exists(hubert_npy_name_interpolate): # continue # speech_16k, _ = sf.read(wav_16k_name) # hubert_hidden = get_hubert_from_16k_speech(speech_16k) # print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}') # hubert_hidden = hubert_hidden.detach().numpy() # interp_func = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0) # hubert_feature_interpolated = interp_func(np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)).astype(np.float32) # # torch.nn.functional.interpolate(hubert_hidden.unsqueeze(0).permute(0,2,1).cuda(), size=num_frames, mode='linear', align_corners=False).squeeze(0).permute(1, 0).cpu() # np.save(hubert_npy_name, hubert_hidden) # np.save(hubert_npy_name_interpolate, hubert_feature_interpolated) # crema dataset # image_path = '/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/images' # wav_path = '/work1/cv2/pcxia/diffusion_v1/diffused-heads-colab-main/datasets/audio' # save_path = '/train20/intern/permanent/hbcheng2/data/crema/hubert_25hz' wav_path = '/train20/intern/permanent/hbcheng2/data/ood_video/audio_clip_2' save_path = '/train20/intern/permanent/hbcheng2/data/ood_video/audio_clip_hubert' # for id_name in tqdm(os.listdir(wav_path)): for wavfile in os.listdir(wav_path): # frame_dir = os.path.join(image_path, id_name, wavfile[0:-4]) # if not exists(frame_dir): # print(f'{frame_dir} does not exist!') # continue # frames= os.listdir(frame_dir) # frames.sort() # num_frames = len(frames) wav_16k_name = os.path.join(wav_path, wavfile) npy_name = wavfile[0:-4] save_dir = os.path.join(save_path) hubert_npy_name = os.path.join(save_dir,npy_name+'.npy') # if exists(hubert_npy_name): # print(f'{hubert_npy_name} exists!') # continue if not exists(save_dir): os.makedirs(save_dir) speech_16k, _ = sf.read(wav_16k_name) num_frames = int((speech_16k.shape[0] / 16000) * 25) hubert_hidden = get_hubert_from_16k_speech(speech_16k) hubert_hidden = hubert_hidden.detach().numpy() interp_func = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0) hubert_feature_interpolated = interp_func(np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)).astype(np.float32) # print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}') np.save(hubert_npy_name, hubert_feature_interpolated) # ## Process short audio clips for LRS3 dataset # import glob, os, tqdm # lrs3_dir = '/home/yezhenhui/datasets/raw/lrs3_raw/' # wav_16k_names = glob.glob(os.path.join(lrs3_dir, '*/*.wav')) # for wav_16k_name in tqdm.tqdm(wav_16k_names, total=len(wav_16k_names)): # spk_id = wav_16k_name.split("/")[-2] # clip_id = wav_16k_name.split("/")[-1][:-4] # out_name = os.path.join(lrs3_dir, spk_id, clip_id+'_hubert.npy') # if os.path.exists(out_name): # continue # speech_16k, _ = sf.read(wav_16k_name) # hubert_hidden = get_hubert_from_16k_speech(speech_16k) # np.save(out_name, hubert_hidden.detach().numpy()) ================================================ FILE: hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py ================================================ import os import sys os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" # adding 
path of PBnet current_dir = os.path.dirname(os.path.abspath(__file__)) parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(current_dir))) if parent_dir not in sys.path: sys.path.append(parent_dir) print(parent_dir) from genericpath import exists from transformers import Wav2Vec2Processor, HubertModel import soundfile as sf import numpy as np import torch from scipy.interpolate import interp1d import subprocess import os from tqdm import tqdm import tempfile print("Loading the Wav2Vec2 Processor...") # wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") wav2vec2_processor = Wav2Vec2Processor.from_pretrained("./pretrain_models/hubert_ckp") print("Loading the HuBERT Model...") hubert_model = HubertModel.from_pretrained("./pretrain_models/hubert_ckp", from_tf = True) def get_hubert_from_16k_wav(wav_16k_name): speech_16k, _ = sf.read(wav_16k_name) hubert = get_hubert_from_16k_speech(speech_16k) return hubert @torch.no_grad() def get_hubert_from_16k_speech(speech, device="cuda:0"): global hubert_model print(f"Current GPU memory allocated: {torch.cuda.memory_allocated()} bytes") print(f"GPU memory reserved (cached): {torch.cuda.memory_reserved()} bytes") torch.cuda.empty_cache() # force-reset PyTorch's CUDA allocator torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() # optional: manually allow a larger initial cache torch.cuda.set_per_process_memory_fraction(0.9) # allow up to 90% of the GPU memory # check the GPU memory state before loading the model print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**2:.2f} MB") print(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") print(f"Cached memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB") print(torch.cuda.memory_summary()) hubert_model = hubert_model.to(device) if speech.ndim == 2: speech = speech[:, 0] # [T, 2] ==> [T,] input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T] input_values_all = input_values_all.to(device) # For long audio sequences, due to the memory limitation, we cannot process them in one run # HuBERT processes the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320 # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step.
# So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320 # We have the equation to calculate out time step: T = floor((t-k)/s) # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N kernel = 400 stride = 320 clip_length = stride * 1000 num_iter = input_values_all.shape[1] // clip_length expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride res_lst = [] for i in range(num_iter): if i == 0: start_idx = 0 end_idx = clip_length - stride + kernel else: start_idx = clip_length * i end_idx = start_idx + (clip_length - stride + kernel) input_values = input_values_all[:, start_idx: end_idx] hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] res_lst.append(hidden_states[0]) if num_iter > 0: input_values = input_values_all[:, clip_length * num_iter:] else: input_values = input_values_all # if input_values.shape[1] != 0: if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] res_lst.append(hidden_states[0]) ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024] # assert ret.shape[0] == expected_T assert abs(ret.shape[0] - expected_T) <= 1 if ret.shape[0] < expected_T: ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0])) else: ret = ret[:expected_T] return ret import argparse def get_arguments(): """Parse all the arguments provided from the CLI. Returns: A list of parsed arguments. """ parser = argparse.ArgumentParser(description="Flow Diffusion") parser.add_argument("--src_audio_path", default='/train20/intern/permanent/hbcheng2/data/test_speed/target_audio.wav') parser.add_argument("--save_path", default='your/path/DAWN-pytorch/ood_data', help="") return parser.parse_args() def convert_wav_to_16k(input_file, output_file): command = [ 'ffmpeg', '-i', input_file, '-ar', '16000', output_file ] subprocess.run(command) def delete_file(file_path): try: os.remove(file_path) print(f"File {file_path} has been deleted successfully.") except FileNotFoundError: print(f"File {file_path} not found.") except PermissionError: print(f"Permission denied: Unable to delete {file_path}.") except Exception as e: print(f"Error occurred while deleting {file_path}: {e}") if __name__ == '__main__': args = get_arguments() wav_path = args.src_audio_path wav_16k_name = wav_path npy_name = args.save_path output_wav_path = tempfile.NamedTemporaryFile('w', suffix='.wav', dir='./') convert_wav_to_16k(wav_path, output_wav_path.name) speech_16k, _ = sf.read(output_wav_path.name) delete_file(output_wav_path.name) # speech_16k, _ = sf.read(wav_path) num_frames = int((speech_16k.shape[0] / 16000) * 25) hubert_hidden = get_hubert_from_16k_speech(speech_16k, device = 'cuda:0') hubert_hidden = hubert_hidden.detach().numpy() interp_func = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0) hubert_feature_interpolated = interp_func(np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)).astype(np.float32) print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}') np.save(npy_name, hubert_feature_interpolated) ================================================ FILE: hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_single.py ================================================ from genericpath import exists from transformers import 
Wav2Vec2Processor, HubertModel import soundfile as sf import numpy as np import torch from tqdm import tqdm import fairseq import decord from scipy.interpolate import interp1d print("Loading the Wav2Vec2 Processor...") # wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") wav2vec2_processor = Wav2Vec2Processor.from_pretrained("/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp") print("Loading the HuBERT Model...") hubert_model = HubertModel.from_pretrained("/train20/intern/permanent/hbcheng2/lmlin2/Flow/GeneFace-main/checkpoints/hubert_ckp", from_tf = True) def get_hubert_from_16k_wav(wav_16k_name): speech_16k, _ = sf.read(wav_16k_name) hubert = get_hubert_from_16k_speech(speech_16k) return hubert @torch.no_grad() def get_hubert_from_16k_speech(speech, device="cuda:1"): global hubert_model hubert_model = hubert_model.to(device) if speech.ndim ==2: speech = speech[:, 0] # [T, 2] ==> [T,] input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T] input_values_all = input_values_all.to(device) # For long audio sequence, due to the memory limitation, we cannot process them in one run # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320 # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step. # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320 # We have the equation to calculate out time step: T = floor((t-k)/s) # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N kernel = 400 stride = 320 clip_length = stride * 1000 num_iter = input_values_all.shape[1] // clip_length expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride res_lst = [] for i in range(num_iter): if i == 0: start_idx = 0 end_idx = clip_length - stride + kernel else: start_idx = clip_length * i end_idx = start_idx + (clip_length - stride + kernel) input_values = input_values_all[:, start_idx: end_idx] hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] res_lst.append(hidden_states[0]) if num_iter > 0: input_values = input_values_all[:, clip_length * num_iter:] else: input_values = input_values_all # if input_values.shape[1] != 0: if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] res_lst.append(hidden_states[0]) ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024] # assert ret.shape[0] == expected_T assert abs(ret.shape[0] - expected_T) <= 1 if ret.shape[0] < expected_T: ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0])) else: ret = ret[:expected_T] return ret from tqdm import tqdm if __name__ == '__main__': wav_path = '/train20/intern/permanent/hbcheng2/data/ood_video/audio_clip/minanyu-xuelingsun_2.wav' wav_16k_name = wav_path hubert_npy_name = '/train20/intern/permanent/hbcheng2/data/ood_video/audio_clip_hubert/minanyu-xuelingsun_2.npy'# os.path.join(wav_path,npy_name+'.npy') speech_16k, _ = sf.read(wav_16k_name) num_frames = int((speech_16k.shape[0] / 16000) * 25) hubert_hidden = get_hubert_from_16k_speech(speech_16k) hubert_hidden = hubert_hidden.detach().numpy() interp_func = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0) 
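# Worked example of the resampling below (hypothetical 4-second clip, not taken from the data):
# 4 s of 16 kHz audio is 64000 samples, so with kernel=400 and stride=320 HuBERT produces
# floor((64000 - 80) / 320) = 199 feature frames, i.e. roughly 50 Hz, while the video timeline
# needs num_frames = int(4 * 25) = 100 frames at 25 fps. The linear interp_func defined above is
# therefore evaluated on np.linspace(0, 198, 100) so the HuBERT features line up with the frames.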
hubert_feature_interpolated = interp_func(np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)).astype(np.float32) print(f'fnum{num_frames},hubersize{hubert_hidden.shape[0]}') np.save(hubert_npy_name, hubert_feature_interpolated) ================================================ FILE: hubert_extract/data_gen/process_lrs3/process_audio_mel_f0.py ================================================ import numpy as np import torch import glob import os import tqdm import librosa import parselmouth from utils.commons.pitch_utils import f0_to_coarse from utils.commons.multiprocess_utils import multiprocess_run_tqdm def librosa_pad_lr(x, fsize, fshift, pad_sides=1): '''compute right padding (final frame) or both sides padding (first and final frames) ''' assert pad_sides in (1, 2) # return int(fsize // 2) pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0] if pad_sides == 1: return 0, pad else: return pad // 2, pad // 2 + pad % 2 def extract_mel_from_fname(wav_path, fft_size=512, hop_size=320, win_length=512, window="hann", num_mels=80, fmin=80, fmax=7600, eps=1e-6, sample_rate=16000, min_level_db=-100): if isinstance(wav_path, str): wav, _ = librosa.core.load(wav_path, sr=sample_rate) else: wav = wav_path # get amplitude spectrogram x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size, win_length=win_length, window=window, center=False) spc = np.abs(x_stft) # (n_bins, T) # get mel basis fmin = 0 if fmin == -1 else fmin fmax = sample_rate / 2 if fmax == -1 else fmax mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax) mel = mel_basis @ spc mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T) mel = mel.T l_pad, r_pad = librosa_pad_lr(wav, fft_size, hop_size, 1) wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0) return wav.T, mel def extract_f0_from_wav_and_mel(wav, mel, hop_size=320, audio_sample_rate=16000, ): time_step = hop_size / audio_sample_rate * 1000 f0_min = 80 f0_max = 750 f0 = parselmouth.Sound(wav, audio_sample_rate).to_pitch_ac( time_step=time_step / 1000, voicing_threshold=0.6, pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency'] delta_l = len(mel) - len(f0) assert np.abs(delta_l) <= 8 if delta_l > 0: f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0) f0 = f0[:len(mel)] pitch_coarse = f0_to_coarse(f0) return f0, pitch_coarse def extract_mel_f0_from_fname(fname, out_name=None): assert fname.endswith(".wav") if out_name is None: out_name = fname[:-4] + '_audio.npy' wav, mel = extract_mel_from_fname(fname) f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel) out_dict = { "mel": mel, # [T, 80] "f0": f0, } np.save(out_name, out_dict) return True if __name__ == '__main__': import os, glob lrs3_dir = "/home/yezhenhui/datasets/raw/lrs3_raw" wav_name_pattern = os.path.join(lrs3_dir, "*/*.wav") wav_names = glob.glob(wav_name_pattern) wav_names = sorted(wav_names) for _ in multiprocess_run_tqdm(extract_mel_f0_from_fname, args=wav_names, num_workers=32,desc='extracting Mel and f0'): pass ================================================ FILE: misc.py ================================================ import cv2 import os import requests import torch import torch.nn.functional as F import torch.distributed as dist import sys import matplotlib.pyplot as plt from matplotlib.collections import LineCollection import numpy as np import flow_vis import cv2 def fig2data(fig): """ @brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it @param fig a matplotlib figure @return a 
numpy 3D array of RGBA values """ # draw the renderer fig.canvas.draw() # Get the RGBA buffer from the figure w, h = fig.canvas.get_width_height() buf = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8) buf.shape = (w, h, 4) # canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode buf = np.roll(buf, 3, axis=2) return buf def plot_grid(x, y, ax=None, **kwargs): ax = ax or plt.gca() segs1 = np.stack((x, y), axis=2) segs2 = segs1.transpose(1, 0, 2) ax.add_collection(LineCollection(segs1, **kwargs)) ax.add_collection(LineCollection(segs2, **kwargs)) ax.autoscale() def grid2fig(warped_grid, grid_size=32, img_size=256): dpi = 1000 # plt.ioff() h_range = torch.linspace(-1, 1, grid_size) w_range = torch.linspace(-1, 1, grid_size) grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).flip(2) flow_uv = grid.cpu().data.numpy() fig, ax = plt.subplots() grid_x, grid_y = warped_grid[..., 0], warped_grid[..., 1] plot_grid(flow_uv[..., 0], flow_uv[..., 1], ax=ax, color="lightgrey") plot_grid(grid_x, grid_y, ax=ax, color="C0") plt.axis("off") plt.tight_layout(pad=0) fig.set_size_inches(img_size/100, img_size/100) fig.set_dpi(100) out = fig2data(fig)[:, :, :3] out = np.flipud(out) out = np.fliplr(out) plt.close() plt.cla() plt.clf() return out def flow2fig(warped_grid, id_grid, grid_size=32, img_size=128): # h_range = torch.linspace(-1, 1, grid_size) # w_range = torch.linspace(-1, 1, grid_size) # id_grid = torch.stack(torch.meshgrid([h_range, w_range]), -1).flip(2) warped_flow = warped_grid - id_grid img = flow_vis.flow_to_color(warped_flow) img = cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_AREA) return img def conf2fig(conf, img_size=128): conf = F.interpolate(conf.unsqueeze(dim=0), size=img_size).data.cpu().numpy() conf = np.transpose(conf, [0, 2, 3, 1]) conf = np.array(conf[0, :, :, 0]*255, dtype=np.uint8) return conf class Logger(object): def __init__(self, filename='default.log', stream=sys.stdout): self.terminal = stream self.log = open(filename, 'w') def write(self, message): self.terminal.write(message) self.log.write(message) def flush(self): pass def resize(im, desired_size, interpolation): old_size = im.shape[:2] ratio = float(desired_size)/max(old_size) new_size = tuple(int(x*ratio) for x in old_size) im = cv2.resize(im, (new_size[1], new_size[0]), interpolation=interpolation) delta_w = desired_size - new_size[1] delta_h = desired_size - new_size[0] top, bottom = delta_h//2, delta_h-(delta_h//2) left, right = delta_w//2, delta_w-(delta_w//2) color = [0, 0, 0] new_im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) return new_im def resample(image, flow): r"""Resamples an image using the provided flow. Args: image (NxCxHxW tensor) : Image to resample. flow (Nx2xHxW tensor) : Optical flow to resample the image. Returns: output (NxCxHxW tensor) : Resampled image. """ assert flow.shape[1] == 2 b, c, h, w = image.size() grid = get_grid(b, (h, w)) flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1) final_grid = (grid + flow).permute(0, 2, 3, 1) try: output = F.grid_sample(image, final_grid, mode='bilinear', padding_mode='border', align_corners=True) except Exception: output = F.grid_sample(image, final_grid, mode='bilinear', padding_mode='border') return output def get_grid(batchsize, size, minval=-1.0, maxval=1.0): r"""Get a grid ranging [-1, 1] of 2D/3D coordinates. Args: batchsize (int) : Batch size. 
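# Editor's note: illustrative sketch, not part of the original repository.
# np.fromstring is deprecated for binary input; a variant of fig2data using
# np.frombuffer and the (h, w) row-major layout that the Agg canvas actually returns
# could look like the following (for the square figures used by grid2fig the result is
# interchangeable).  It keeps the same ARGB -> RGBA channel roll as the original.
import numpy as np

def fig2data_frombuffer(fig):
    """Render a Matplotlib figure and return an (h, w, 4) uint8 RGBA array."""
    fig.canvas.draw()
    w, h = fig.canvas.get_width_height()
    buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8).reshape(h, w, 4)
    return np.roll(buf, 3, axis=2)      # ARGB -> RGBA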
size (tuple) : (height, width) or (depth, height, width). minval (float) : minimum value in returned grid. maxval (float) : maximum value in returned grid. Returns: t_grid (4D tensor) : Grid of coordinates. """ if len(size) == 2: rows, cols = size elif len(size) == 3: deps, rows, cols = size else: raise ValueError('Dimension can only be 2 or 3.') x = torch.linspace(minval, maxval, cols) x = x.view(1, 1, 1, cols) x = x.expand(batchsize, 1, rows, cols) y = torch.linspace(minval, maxval, rows) y = y.view(1, 1, rows, 1) y = y.expand(batchsize, 1, rows, cols) t_grid = torch.cat([x, y], dim=1) if len(size) == 3: z = torch.linspace(minval, maxval, deps) z = z.view(1, 1, deps, 1, 1) z = z.expand(batchsize, 1, deps, rows, cols) t_grid = t_grid.unsqueeze(2).expand(batchsize, 2, deps, rows, cols) t_grid = torch.cat([t_grid, z], dim=1) t_grid.requires_grad = False return t_grid.to('cuda') def get_checkpoint(checkpoint_path, url=''): r"""Get the checkpoint path. If it does not exist yet, download it from the url. Args: checkpoint_path (str): Checkpoint path. url (str): URL to download checkpoint. Returns: (str): Full checkpoint path. """ if 'TORCH_HOME' not in os.environ: os.environ['TORCH_HOME'] = os.getcwd() save_dir = os.path.join(os.environ['TORCH_HOME'], 'checkpoints') os.makedirs(save_dir, exist_ok=True) full_checkpoint_path = os.path.join(save_dir, checkpoint_path) if not os.path.exists(full_checkpoint_path): os.makedirs(os.path.dirname(full_checkpoint_path), exist_ok=True) if is_master(): print('Download {}'.format(url)) download_file_from_google_drive(url, full_checkpoint_path) if dist.is_available() and dist.is_initialized(): dist.barrier() return full_checkpoint_path def download_file_from_google_drive(file_id, destination): r"""Download a file from the google drive by using the file ID. Args: file_id: Google drive file ID destination: Path to save the file. Returns: """ URL = "https://docs.google.com/uc?export=download" session = requests.Session() response = session.get(URL, params={'id': file_id}, stream=True) token = get_confirm_token(response) if token: params = {'id': file_id, 'confirm': token} response = session.get(URL, params=params, stream=True) save_response_content(response, destination) def get_confirm_token(response): r"""Get confirm token Args: response: Check if the file exists. Returns: """ for key, value in response.cookies.items(): if key.startswith('download_warning'): return value return None def save_response_content(response, destination): r"""Save response content Args: response: destination: Path to save the file. 
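# Editor's note: illustrative sketch, not part of the original repository.
# resample() warps an image with a dense flow field expressed in pixels: the flow is
# rescaled to the [-1, 1] grid_sample convention and added to the identity grid from
# get_grid(), so a zero flow returns (approximately) the input image.  This check
# assumes it is run with the functions defined above in scope and that a CUDA device is
# available, since get_grid() places the grid on 'cuda'.
import torch

image = torch.rand(1, 3, 64, 64, device='cuda')
zero_flow = torch.zeros(1, 2, 64, 64, device='cuda')
warped = resample(image, zero_flow)
assert torch.allclose(warped, image, atol=1e-4)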
Returns: """ chunk_size = 32768 with open(destination, "wb") as f: for chunk in response.iter_content(chunk_size): if chunk: f.write(chunk) def get_rank(): r"""Get rank of the thread.""" rank = 0 if dist.is_available(): if dist.is_initialized(): rank = dist.get_rank() return rank def is_master(): r"""check if current process is the master""" return get_rank() == 0 ================================================ FILE: requirements.txt ================================================ absl-py==2.0.0 accelerate==1.0.1 aiofiles==23.2.1 albumentations==1.3.1 annotated-types==0.7.0 antlr4-python3-runtime==4.8 anyio==4.5.2 astunparse==1.6.3 audioread==3.0.1 av==11.0.0 beautifulsoup4==4.12.3 bitarray==2.8.2 boto3==1.28.78 botocore==1.31.78 cachetools==4.2.4 certifi==2023.7.22 cffi==1.16.0 charset-normalizer==3.2.0 click==8.1.7 cmake==3.30.1 colorama==0.4.6 coloredlogs==15.0.1 contourpy==1.1.1 cycler==0.12.1 Cython==3.0.5 decorator==4.4.2 decord==0.6.0 einops==0.7.0 einops-exts==0.0.4 exceptiongroup==1.2.2 fairseq==0.12.2 fastapi==0.115.8 ffmpeg==1.4 ffmpeg-python==0.2.0 ffmpy==0.5.0 filelock==3.13.1 flatbuffers==23.5.26 flow-vis==0.1 fonttools==4.44.0 fsspec==2023.10.0 future==1.0.0 gast==0.4.0 gdown==5.1.0 google-auth==2.32.0 google-auth-oauthlib==0.4.6 google-pasta==0.2.0 gradio==4.44.1 gradio_client==1.3.0 grpcio==1.59.2 h11==0.14.0 h5py==3.10.0 httpcore==1.0.7 httpx==0.28.1 huggingface-hub==0.28.1 humanfriendly==10.0 hydra-core==1.0.7 idna==3.4 imageio==2.31.5 imageio-ffmpeg==0.4.9 importlib-metadata==6.8.0 importlib-resources==6.1.0 Jinja2==3.1.4 jmespath==1.0.1 joblib==1.3.2 json-tricks==3.17.3 keras==2.11.0 kiwisolver==1.4.5 lazy_loader==0.3 libclang==16.0.6 librosa==0.7.1 lit==18.1.8 llvmlite==0.41.0 lpips==0.1.4 lxml==4.9.3 Markdown==3.5.1 markdown-it-py==3.0.0 MarkupSafe==2.1.3 matplotlib==3.7.3 mdurl==0.1.2 moviepy==1.0.3 mpmath==1.3.0 msgpack==1.0.7 natsort==8.4.0 networkx==3.1 numba==0.58.0 numpy==1.24.3 nvidia-cublas-cu11==11.10.3.66 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-nvrtc-cu11==11.7.99 nvidia-cuda-runtime-cu11==11.7.99 nvidia-cudnn-cu11==8.5.0.96 nvidia-nvjitlink-cu12==12.8.61 oauthlib==3.2.2 omegaconf==2.0.5 onnx==1.17.0 onnxruntime==1.19.2 opencv-contrib-python==4.8.0.76 opencv-python==4.7.0.72 opencv-python-headless==4.8.1.78 opt-einsum==3.3.0 orjson==3.10.15 packaging==23.2 pandas==2.0.3 Pillow==10.0.1 platformdirs==3.11.0 pooch==1.7.0 portalocker==2.8.2 proglog==0.1.10 protobuf==3.20.2 psutil==6.1.1 pyasn1==0.5.0 pyasn1-modules==0.3.0 pycparser==2.21 pydantic==2.10.6 pydantic_core==2.27.2 pydub==0.25.1 Pygments==2.19.1 pyparsing==3.1.1 PySocks==1.7.1 pyspng==0.1.1 python-dateutil==2.8.2 python-multipart==0.0.20 python_speech_features==0.6 pytz==2023.3.post1 PyWavelets==1.4.1 PyYAML==6.0.1 qudida==0.0.4 regex==2023.10.3 requests==2.31.0 requests-oauthlib==1.3.1 resampy==0.4.2 rich==13.9.4 rotary-embedding-torch==0.3.5 rsa==4.9 ruff==0.9.6 s3transfer==0.7.0 sacrebleu==2.3.1 sacremoses==0.1.1 safetensors==0.5.2 scenedetect==0.5.1 scikit-image==0.21.0 scikit-learn==1.3.1 scipy==1.9.1 semantic-version==2.10.0 sentencepiece==0.1.99 shellingham==1.5.4 six==1.16.0 sniffio==1.3.1 soundfile==0.12.1 soupsieve==2.5 soxr==0.3.7 starlette==0.44.0 sympy==1.13.1 sync-batchnorm==0.0.1 tabulate==0.9.0 tensorboard==2.11.2 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.1 tensorboardX==2.6.2.2 tensorflow==2.11.1 tensorflow-estimator==2.11.0 tensorflow-io-gcs-filesystem==0.34.0 termcolor==2.3.0 threadpoolctl==3.2.0 tifffile==2023.7.10 tokenizers==0.20.3 tomlkit==0.12.0 
torch==1.13.0 torchaudio==0.13.0 torchvision==0.14.0 tqdm==4.66.1 transformers==4.46.3 triton==2.0.0 typer==0.15.1 typing_extensions==4.12.2 tzdata==2023.3 urllib3==2.2.3 uvicorn==0.33.0 visualize==0.5.1 websockets==12.0 Werkzeug==3.0.1 wrapt==1.15.0 zipp==3.17.0 ================================================ FILE: run_ood_test/run_DM_v0_df_test_128_both_pose_blink.sh ================================================ test_name=ood_test_1009 # $(date +"%Y-%m-%d_%H-%M-%S") time_tag=tmp1009 # $(date +"%Y-%m-%d_%H-%M-%S") audio_path=WRA_MarcoRubio_000.wav image_path=real_female_1.jpeg cache_path=cache/$time_tag audio_emb_path=cache/target_audio.npy video_output_path=cache/ conda activate 3DDFA cd extract_init_states python demo_pose_extract_2d_lmk_img.py \ --input $image_path \ --output $cache_path cd .. conda activate DAWN python ./hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \ --src_audio_path $audio_path \ --save_path $audio_emb_path python ./PBnet/src/evaluate/tvae_eval_single_both_eye_pose.py \ --audio_path $audio_emb_path \ --init_pose_blink $cache_path \ --ckpt './pretrain_models/pbnet_both/checkpoint_100000.pth.tar' \ --output $cache_path python ./DM_3/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_128_2.py --gpu 0 \ --source_img_path $image_path \ --init_state_path $cache_path \ --drive_blink_path $cache_path/dri_blink.npy \ --drive_pose_path $cache_path/dri_pose.npy \ --audio_emb_path $audio_emb_path \ --save_path $video_output_path/$test_name \ --src_audio_path $audio_path ================================================ FILE: run_ood_test/run_DM_v0_df_test_128_separate_pose_blink.sh ================================================ test_name=ood_test_1009 # $(date +"%Y-%m-%d_%H-%M-%S") time_tag=tmp1009 # $(date +"%Y-%m-%d_%H-%M-%S") audio_path=WRA_MarcoRubio_000.wav image_path=real_female_1.jpeg cache_path=cache/$time_tag audio_emb_path=cache/target_audio.npy video_output_path=cache/ source activate conda activate 3DDFA cd extract_init_states python demo_pose_extract_2d_lmk_img.py \ --input ../$image_path \ --output ../$cache_path cd .. 
conda activate DAWN python ./hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \ --src_audio_path $audio_path \ --save_path $audio_emb_path # conda activate LFDM_a40 python ./PBnet/src/evaluate/tvae_eval_single.py \ --audio_path $audio_emb_path \ --init_pose_blink $cache_path \ --output $cache_path \ --ckpt_pose ./pretrain_models/pbnet_seperate/pose/checkpoint_40000.pth.tar \ --ckpt_blink ./pretrain_models/pbnet_seperate/blink/checkpoint_95000.pth.tar python your_path/DM_3/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_128_2.py --gpu 0 \ --source_img_path $image_path \ --init_state_path $cache_path \ --drive_blink_path $cache_path/dri_blink.npy \ --drive_pose_path $cache_path/dri_pose.npy \ --audio_emb_path $audio_emb_path \ --save_path $video_output_path/$test_name \ --src_audio_path $audio_path ================================================ FILE: run_ood_test/run_DM_v0_df_test_256.sh ================================================ source /home4/intern/hbcheng2/.bashrc test_name=ood_test_1006 # $(date +"%Y-%m-%d_%H-%M-%S") time_tag=tmp #$(date +"%Y-%m-%d_%H-%M-%S") audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip7.wav image_path=your/path/DAWN-pytorch/ood_data/ood_select_3/test4.jpeg cache_path=your/path/DAWN-pytorch/ood_data_3/$time_tag audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip7.npy conda activate 3DDFA cd /train20/intern/permanent/hbcheng2/AIGC_related/3DDFA_V2-master python /train20/intern/permanent/hbcheng2/AIGC_related/3DDFA_V2-master/demo_pose_extract_2d_lmk_img.py \ --input $image_path \ --output $cache_path # conda activate LFDM_chb # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \ # --src_audio_path $audio_path \ # --save_path $audio_emb_path conda activate LFDM_chb cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \ --audio_path $audio_emb_path \ --init_pose_blink $cache_path \ --output $cache_path cd your/path/DAWN-pytorch # source /home4/intern/hbcheng2/.bashrc # echo 'finish extracting init state' python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0 \ --source_img_path $image_path \ --init_state_path $cache_path \ --drive_blink_path $cache_path/dri_blink.npy \ --drive_pose_path $cache_path/dri_pose.npy \ --audio_emb_path $audio_emb_path \ --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \ --src_audio_path $audio_path # audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip1.wav # # image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png # # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name # audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip1.npy # # conda activate LFDM_chb # # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main # # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \ # # --src_audio_path $audio_path \ # # --save_path $audio_emb_path # cd 
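# Editor's note: illustrative sketch, not part of the original repository.
# The run_ood_test/*.sh scripts in this directory all chain the same four stages:
#   1) extract_init_states: initial head pose and eye/bbox state from the source image
#   2) hubert_extract: 16 kHz HuBERT features interpolated to the 25 fps video timeline
#   3) PBnet: driving pose / blink sequences conditioned on the audio features
#   4) DM_3 test script: diffusion-based video generation
# A minimal Python driver with the same structure; paths and flags are placeholders
# adapted from the 128 "both" script above (the originals cd into extract_init_states
# and switch conda environments between stages, which this sketch does not reproduce):
import subprocess

def run(cmd):
    subprocess.run(cmd, check=True)

cache = "cache/tmp"
run(["python", "extract_init_states/demo_pose_extract_2d_lmk_img.py",
     "--input", "real_female_1.jpeg", "--output", cache])
run(["python", "hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py",
     "--src_audio_path", "WRA_MarcoRubio_000.wav", "--save_path", f"{cache}/target_audio.npy"])
run(["python", "PBnet/src/evaluate/tvae_eval_single_both_eye_pose.py",
     "--audio_path", f"{cache}/target_audio.npy", "--init_pose_blink", cache,
     "--ckpt", "./pretrain_models/pbnet_both/checkpoint_100000.pth.tar", "--output", cache])
run(["python", "DM_3/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_128_2.py",
     "--gpu", "0", "--source_img_path", "real_female_1.jpeg", "--init_state_path", cache,
     "--drive_blink_path", f"{cache}/dri_blink.npy", "--drive_pose_path", f"{cache}/dri_pose.npy",
     "--audio_emb_path", f"{cache}/target_audio.npy", "--save_path", "cache/ood_test",
     "--src_audio_path", "WRA_MarcoRubio_000.wav"])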
/train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master # python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \ # --audio_path $audio_emb_path \ # --init_pose_blink $cache_path \ # --output $cache_path # cd your/path/DAWN-pytorch # # source /home4/intern/hbcheng2/.bashrc # # conda activate LFDM_a40 # # echo 'finish extracting init state' # python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0 \ # --source_img_path $image_path \ # --init_state_path $cache_path \ # --drive_blink_path $cache_path/dri_blink.npy \ # --drive_pose_path $cache_path/dri_pose.npy \ # --audio_emb_path $audio_emb_path \ # --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \ # --src_audio_path $audio_path # audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip2.wav # # image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png # # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name # audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip2.npy # # conda activate LFDM_chb # # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main # # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \ # # --src_audio_path $audio_path \ # # --save_path $audio_emb_path # cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master # python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \ # --audio_path $audio_emb_path \ # --init_pose_blink $cache_path \ # --output $cache_path # cd your/path/DAWN-pytorch # # source /home4/intern/hbcheng2/.bashrc # # conda activate LFDM_a40 # # echo 'finish extracting init state' # python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0 \ # --source_img_path $image_path \ # --init_state_path $cache_path \ # --drive_blink_path $cache_path/dri_blink.npy \ # --drive_pose_path $cache_path/dri_pose.npy \ # --audio_emb_path $audio_emb_path \ # --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \ # --src_audio_path $audio_path # audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip3.wav # # image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png # # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name # audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip3.npy # # conda activate LFDM_chb # # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main # # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \ # # --src_audio_path $audio_path \ # # --save_path $audio_emb_path # cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master # python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \ # --audio_path $audio_emb_path \ # --init_pose_blink $cache_path \ # --output $cache_path # cd your/path/DAWN-pytorch # # source /home4/intern/hbcheng2/.bashrc # # conda activate LFDM_a40 # # echo 'finish extracting init state' # python 
your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0 \ # --source_img_path $image_path \ # --init_state_path $cache_path \ # --drive_blink_path $cache_path/dri_blink.npy \ # --drive_pose_path $cache_path/dri_pose.npy \ # --audio_emb_path $audio_emb_path \ # --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \ # --src_audio_path $audio_path # audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip4.wav # # image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png # # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name # audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip4.npy # # conda activate LFDM_chb # # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main # # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \ # # --src_audio_path $audio_path \ # # --save_path $audio_emb_path # cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master # python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \ # --audio_path $audio_emb_path \ # --init_pose_blink $cache_path \ # --output $cache_path # cd your/path/DAWN-pytorch # # source /home4/intern/hbcheng2/.bashrc # # conda activate LFDM_a40 # # echo 'finish extracting init state' # python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0 \ # --source_img_path $image_path \ # --init_state_path $cache_path \ # --drive_blink_path $cache_path/dri_blink.npy \ # --drive_pose_path $cache_path/dri_pose.npy \ # --audio_emb_path $audio_emb_path \ # --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \ # --src_audio_path $audio_path # audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip5.wav # # image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png # # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name # audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip5.npy # # conda activate LFDM_chb # # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main # # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \ # # --src_audio_path $audio_path \ # # --save_path $audio_emb_path # cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master # python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \ # --audio_path $audio_emb_path \ # --init_pose_blink $cache_path \ # --output $cache_path # cd your/path/DAWN-pytorch # # source /home4/intern/hbcheng2/.bashrc # # conda activate LFDM_a40 # # echo 'finish extracting init state' # python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0 \ # --source_img_path $image_path \ # --init_state_path $cache_path \ # --drive_blink_path $cache_path/dri_blink.npy \ # --drive_pose_path $cache_path/dri_pose.npy \ # --audio_emb_path $audio_emb_path \ # --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \ # --src_audio_path $audio_path # 
audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip6.wav # # image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png # # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name # audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip6.npy # # conda activate LFDM_chb # # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main # # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \ # # --src_audio_path $audio_path \ # # --save_path $audio_emb_path # cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master # python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \ # --audio_path $audio_emb_path \ # --init_pose_blink $cache_path \ # --output $cache_path # cd your/path/DAWN-pytorch # # source /home4/intern/hbcheng2/.bashrc # # conda activate LFDM_a40 # # echo 'finish extracting init state' # python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0 \ # --source_img_path $image_path \ # --init_state_path $cache_path \ # --drive_blink_path $cache_path/dri_blink.npy \ # --drive_pose_path $cache_path/dri_pose.npy \ # --audio_emb_path $audio_emb_path \ # --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \ # --src_audio_path $audio_path # audio_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_clip_vocal_origin/Taylor-Swift-You-Belong-With-Me-vocal_clip0.wav # # image_path=your/path/DAWN-pytorch/ood_data/ood_select/images/draw_female_test1.png # # cache_path=your/path/DAWN-pytorch/ood_data_3/$test_name # audio_emb_path=your/path/DAWN-pytorch/ood_data/ood_select/audio_embedding_vocal/Taylor-Swift-You-Belong-With-Me-vocal_clip0.npy # # conda activate LFDM_chb # # cd /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main # # python /train20/intern/permanent/hbcheng2/AIGC_related/GeneFace-main/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \ # # --src_audio_path $audio_path \ # # --save_path $audio_emb_path # cd /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master # python /train20/intern/permanent/hbcheng2/AIGC_related/ACTOR-master/src/evaluate/tvae_eval_signal.py \ # --audio_path $audio_emb_path \ # --init_pose_blink $cache_path \ # --output $cache_path # cd your/path/DAWN-pytorch # # source /home4/intern/hbcheng2/.bashrc # # conda activate LFDM_a40 # # echo 'finish extracting init state' # python your/path/DAWN-pytorch/DM_1/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0 \ # --source_img_path $image_path \ # --init_state_path $cache_path \ # --drive_blink_path $cache_path/dri_blink.npy \ # --drive_pose_path $cache_path/dri_pose.npy \ # --audio_emb_path $audio_emb_path \ # --save_path /train20/intern/permanent/hbcheng2/data/ood_test_3/$test_name \ # --src_audio_path $audio_path ================================================ FILE: run_ood_test/run_DM_v0_df_test_256_1.sh ================================================ test_name=ood_test_1009 # $(date +"%Y-%m-%d_%H-%M-%S") time_tag=tmp1009 # $(date +"%Y-%m-%d_%H-%M-%S") audio_path=WRA_MarcoRubio_000.wav image_path=real_female_1.jpeg cache_path=cache/$time_tag audio_emb_path=cache/target_audio.npy video_output_path=cache/ conda activate 3DDFA cd extract_init_states python demo_pose_extract_2d_lmk_img.py \ 
--input ../$image_path \ --output ../$cache_path cd .. conda activate DAWN python ./hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \ --src_audio_path $audio_path \ --save_path $audio_emb_path python ./PBnet/src/evaluate/tvae_eval_single_both_eye_pose.py \ --audio_path $audio_emb_path \ --init_pose_blink $cache_path \ --ckpt './pretrain_models/pbnet_both/checkpoint_100000.pth.tar' \ --output $cache_path python ./DM_3/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0 \ --source_img_path $image_path \ --init_state_path $cache_path \ --drive_blink_path $cache_path/dri_blink.npy \ --drive_pose_path $cache_path/dri_pose.npy \ --audio_emb_path $audio_emb_path \ --save_path $video_output_path/$test_name \ --src_audio_path $audio_path ================================================ FILE: run_ood_test/run_DM_v0_df_test_256_1_separate_pose_blink.sh ================================================ test_name=ood_test_1009 # $(date +"%Y-%m-%d_%H-%M-%S") time_tag=tmp1009 # $(date +"%Y-%m-%d_%H-%M-%S") audio_path=WRA_MarcoRubio_000.wav image_path=real_female_1.jpeg cache_path=cache/$time_tag audio_emb_path=cache/target_audio.npy video_output_path=cache/ source activate # conda activate 3DDFA # cd extract_init_states # python demo_pose_extract_2d_lmk_img.py \ # --input ../$image_path \ # --output ../$cache_path # cd .. conda activate DAWN python ./hubert_extract/data_gen/process_lrs3/process_audio_hubert_interpolate_demo.py \ --src_audio_path $audio_path \ --save_path $audio_emb_path # python ./PBnet/src/evaluate/tvae_eval_single.py \ # --audio_path $audio_emb_path \ # --init_pose_blink $cache_path \ # --output $cache_path \ # --ckpt_pose ./pretrain_models/pbnet_seperate/pose/checkpoint_40000.pth.tar \ # --ckpt_blink ./pretrain_models/pbnet_seperate/blink/checkpoint_95000.pth.tar # python ./DM_3/test_demo/test_VIDEO_hdtf_df_wpose_face_cond_init_ca_newae_ood_256_2.py --gpu 0 \ # --source_img_path $image_path \ # --init_state_path $cache_path \ # --drive_blink_path $cache_path/dri_blink.npy \ # --drive_pose_path $cache_path/dri_pose.npy \ # --audio_emb_path $audio_emb_path \ # --save_path $video_output_path/$test_name \ # --src_audio_path $audio_path ================================================ FILE: sync_batchnorm/__init__.py ================================================ # -*- coding: utf-8 -*- # File : __init__.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d from .replicate import DataParallelWithCallback, patch_replication_callback ================================================ FILE: sync_batchnorm/batchnorm.py ================================================ # -*- coding: utf-8 -*- # File : batchnorm.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. 
import collections import torch import torch.nn.functional as F from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast from .comm import SyncMaster __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] def _sum_ft(tensor): """sum over the first and last dimention""" return tensor.sum(dim=0).sum(dim=-1) def _unsqueeze_ft(tensor): """add new dementions at the front and the tail""" return tensor.unsqueeze(0).unsqueeze(-1) _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) class _SynchronizedBatchNorm(_BatchNorm): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) self._sync_master = SyncMaster(self._data_parallel_master) self._is_parallel = False self._parallel_id = None self._slave_pipe = None def forward(self, input): # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. if not (self._is_parallel and self.training): return F.batch_norm( input, self.running_mean, self.running_var, self.weight, self.bias, self.training, self.momentum, self.eps) # Resize the input to (B, C, -1). input_shape = input.size() input = input.view(input.size(0), self.num_features, -1) # Compute the sum and square-sum. sum_size = input.size(0) * input.size(2) input_sum = _sum_ft(input) input_ssum = _sum_ft(input ** 2) # Reduce-and-broadcast the statistics. if self._parallel_id == 0: mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) else: mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) # Compute the output. if self.affine: # MJY:: Fuse the multiplication for speed. output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) else: output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) # Reshape it. return output.view(input_shape) def __data_parallel_replicate__(self, ctx, copy_id): self._is_parallel = True self._parallel_id = copy_id # parallel_id == 0 means master device. if self._parallel_id == 0: ctx.sync_master = self._sync_master else: self._slave_pipe = ctx.sync_master.register_slave(copy_id) def _data_parallel_master(self, intermediates): """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" # Always using same "device order" makes the ReduceAdd operation faster. # Thanks to:: Tete Xiao (http://tetexiao.com/) intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) to_reduce = [i[1][:2] for i in intermediates] to_reduce = [j for i in to_reduce for j in i] # flatten target_gpus = [i[1].sum.get_device() for i in intermediates] sum_size = sum([i[1].sum_size for i in intermediates]) sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) broadcasted = Broadcast.apply(target_gpus, mean, inv_std) outputs = [] for i, rec in enumerate(intermediates): outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) return outputs def _compute_mean_std(self, sum_, ssum, size): """Compute the mean and standard-deviation with sum and square-sum. 
This method also maintains the moving average on the master device.""" assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' mean = sum_ / size sumvar = ssum - sum_ * mean unbias_var = sumvar / (size - 1) bias_var = sumvar / size self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data return mean, bias_var.clamp(self.eps) ** -0.5 class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a mini-batch. .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm1d as the mean and standard-deviation are reduced across all devices during training. For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm Args: num_features: num_features from an expected input of size `batch_size x num_features [x width]` eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape: - Input: :math:`(N, C)` or :math:`(N, C, L)` - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm1d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm1d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: raise ValueError('expected 2D or 3D input (got {}D input)' .format(input.dim())) super(SynchronizedBatchNorm1d, self)._check_input_dim(input) class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch of 3d inputs .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm2d as the mean and standard-deviation are reduced across all devices during training. 
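# Editor's note: illustrative sketch, not part of the original repository.
# _compute_mean_std above implements the textbook identities
#   mean = sum / n,  biased_var = (ssum - sum * mean) / n,  inv_std = clamp(var, eps) ** -0.5,
# where sum and ssum are the per-channel reductions gathered from every device.
# A quick single-process sanity check of that algebra:
import torch

x = torch.randn(8, 4, 16)                       # (N, C, L) as seen by BatchNorm1d
flat = x.transpose(0, 1).reshape(4, -1)         # per-channel view, n = 8 * 16 samples
s, ssum, n = flat.sum(1), (flat ** 2).sum(1), flat.shape[1]
mean = s / n
biased_var = (ssum - s * mean) / n
assert torch.allclose(mean, flat.mean(1), atol=1e-5)
assert torch.allclose(biased_var, flat.var(1, unbiased=False), atol=1e-5)
inv_std = biased_var.clamp(min=1e-5) ** -0.5    # same form as bias_var.clamp(eps) ** -0.5 above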
For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm Args: num_features: num_features from an expected input of size batch_size x num_features x height x width eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape: - Input: :math:`(N, C, H, W)` - Output: :math:`(N, C, H, W)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm2d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm2d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 4: raise ValueError('expected 4D input (got {}D input)' .format(input.dim())) super(SynchronizedBatchNorm2d, self)._check_input_dim(input) class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch of 4d inputs .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm3d as the mean and standard-deviation are reduced across all devices during training. For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. 
Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm or Spatio-temporal BatchNorm Args: num_features: num_features from an expected input of size batch_size x num_features x depth x height x width eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape: - Input: :math:`(N, C, D, H, W)` - Output: :math:`(N, C, D, H, W)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm3d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm3d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 5: raise ValueError('expected 5D input (got {}D input)' .format(input.dim())) super(SynchronizedBatchNorm3d, self)._check_input_dim(input) ================================================ FILE: sync_batchnorm/comm.py ================================================ # -*- coding: utf-8 -*- # File : comm.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import queue import collections import threading __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] class FutureResult(object): """A thread-safe future implementation. Used only as one-to-one pipe.""" def __init__(self): self._result = None self._lock = threading.Lock() self._cond = threading.Condition(self._lock) def put(self, result): with self._lock: assert self._result is None, 'Previous result has\'t been fetched.' self._result = result self._cond.notify() def get(self): with self._lock: if self._result is None: self._cond.wait() res = self._result self._result = None return res _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) class SlavePipe(_SlavePipeBase): """Pipe for master-slave communication.""" def run_slave(self, msg): self.queue.put((self.identifier, msg)) ret = self.result.get() self.queue.put(True) return ret class SyncMaster(object): """An abstract `SyncMaster` object. - During the replication, as the data parallel will trigger an callback of each module, all slave devices should call `register(id)` and obtain an `SlavePipe` to communicate with the master. - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, and passed to a registered callback. - After receiving the messages, the master device should gather the information and determine to message passed back to each slave devices. """ def __init__(self, master_callback): """ Args: master_callback: a callback to be invoked after having collected messages from slave devices. """ self._master_callback = master_callback self._queue = queue.Queue() self._registry = collections.OrderedDict() self._activated = False def __getstate__(self): return {'master_callback': self._master_callback} def __setstate__(self, state): self.__init__(state['master_callback']) def register_slave(self, identifier): """ Register an slave device. 
Args: identifier: an identifier, usually is the device id. Returns: a `SlavePipe` object which can be used to communicate with the master device. """ if self._activated: assert self._queue.empty(), 'Queue is not clean before next initialization.' self._activated = False self._registry.clear() future = FutureResult() self._registry[identifier] = _MasterRegistry(future) return SlavePipe(identifier, self._queue, future) def run_master(self, master_msg): """ Main entry for the master device in each forward pass. The messages were first collected from each devices (including the master device), and then an callback will be invoked to compute the message to be sent back to each devices (including the master device). Args: master_msg: the message that the master want to send to itself. This will be placed as the first message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. Returns: the message to be sent back to the master device. """ self._activated = True intermediates = [(0, master_msg)] for i in range(self.nr_slaves): intermediates.append(self._queue.get()) results = self._master_callback(intermediates) assert results[0][0] == 0, 'The first result should belongs to the master.' for i, res in results: if i == 0: continue self._registry[i].result.put(res) for i in range(self.nr_slaves): assert self._queue.get() is True return results[0][1] @property def nr_slaves(self): return len(self._registry) ================================================ FILE: sync_batchnorm/replicate.py ================================================ # -*- coding: utf-8 -*- # File : replicate.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import functools from torch.nn.parallel.data_parallel import DataParallel __all__ = [ 'CallbackContext', 'execute_replication_callbacks', 'DataParallelWithCallback', 'patch_replication_callback' ] class CallbackContext(object): pass def execute_replication_callbacks(modules): """ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` Note that, as all modules are isomorphism, we assign each sub-module with a context (shared among multiple copies of this module on different devices). Through this context, different copies can share some information. We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback of any slave copies. """ master_copy = modules[0] nr_modules = len(list(master_copy.modules())) ctxs = [CallbackContext() for _ in range(nr_modules)] for i, module in enumerate(modules): for j, m in enumerate(module.modules()): if hasattr(m, '__data_parallel_replicate__'): m.__data_parallel_replicate__(ctxs[j], i) class DataParallelWithCallback(DataParallel): """ Data Parallel with a replication callback. An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by original `replicate` function. The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` Examples: > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) # sync_bn.__data_parallel_replicate__ will be invoked. 
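# Editor's note: illustrative sketch, not part of the original repository.
# SyncMaster / SlavePipe implement a future-based rendezvous: each slave thread pushes
# its message into the shared queue and blocks on its FutureResult, while run_master()
# gathers nr_slaves messages, invokes the callback on [(0, master_msg), (id_1, msg_1), ...]
# and sends one result back per copy.  A tiny single-machine demonstration (assumes the
# repository root is on PYTHONPATH so sync_batchnorm.comm is importable):
import threading
from sync_batchnorm.comm import SyncMaster

def callback(intermediates):
    # intermediates: list of (copy_id, message); reply with the total to every copy
    total = sum(msg for _, msg in intermediates)
    return [(copy_id, total) for copy_id, _ in intermediates]

master = SyncMaster(callback)
pipes = {i: master.register_slave(i) for i in (1, 2)}
replies = {}

def slave(i):
    replies[i] = pipes[i].run_slave(10 * i)     # slave i contributes 10 * i

threads = [threading.Thread(target=slave, args=(i,)) for i in (1, 2)]
for t in threads:
    t.start()
replies[0] = master.run_master(5)               # master contributes 5
for t in threads:
    t.join()
assert replies == {0: 35, 1: 35, 2: 35}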
""" def replicate(self, module, device_ids): modules = super(DataParallelWithCallback, self).replicate(module, device_ids) execute_replication_callbacks(modules) return modules def update_num_frames(self, new_num_frames): self.unet.update_num_frames(new_num_frames) self.gaussian_diffusion.update_num_frames(new_num_frames) def patch_replication_callback(data_parallel): """ Monkey-patch an existing `DataParallel` object. Add the replication callback. Useful when you have customized `DataParallel` implementation. Examples: > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) > patch_replication_callback(sync_bn) # this is equivalent to > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) """ assert isinstance(data_parallel, DataParallel) old_replicate = data_parallel.replicate @functools.wraps(old_replicate) def new_replicate(module, device_ids): modules = old_replicate(module, device_ids) execute_replication_callbacks(modules) return modules data_parallel.replicate = new_replicate ================================================ FILE: sync_batchnorm/replicate_ddp.py ================================================ # -*- coding: utf-8 -*- # File : replicate.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import functools from torch.nn.parallel.data_parallel import DataParallel from torch.nn.parallel import DistributedDataParallel __all__ = [ 'CallbackContext', 'execute_replication_callbacks', 'DataParallelWithCallback_ddp', 'patch_replication_callback_ddp' ] class CallbackContext(object): pass def execute_replication_callbacks(modules): """ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` Note that, as all modules are isomorphism, we assign each sub-module with a context (shared among multiple copies of this module on different devices). Through this context, different copies can share some information. We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback of any slave copies. """ master_copy = modules[0] nr_modules = len(list(master_copy.modules())) ctxs = [CallbackContext() for _ in range(nr_modules)] for i, module in enumerate(modules): for j, m in enumerate(module.modules()): if hasattr(m, '__data_parallel_replicate__'): m.__data_parallel_replicate__(ctxs[j], i) class DataParallelWithCallback_ddp(DistributedDataParallel): """ Data Parallel with a replication callback. An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by original `replicate` function. The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` Examples: > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) # sync_bn.__data_parallel_replicate__ will be invoked. 
""" def replicate(self, module, device_ids): modules = super(DataParallelWithCallback_ddp, self).replicate(module, device_ids) execute_replication_callbacks(modules) return modules def update_num_frames(self, new_num_frames): self.unet.update_num_frames(new_num_frames) self.gaussian_diffusion.update_num_frames(new_num_frames) def patch_replication_callback_ddp(data_parallel): """ Monkey-patch an existing `DataParallel` object. Add the replication callback. Useful when you have customized `DataParallel` implementation. Examples: > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) > patch_replication_callback(sync_bn) # this is equivalent to > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) """ assert isinstance(data_parallel, DistributedDataParallel) old_replicate = data_parallel.replicate @functools.wraps(old_replicate) def new_replicate(module, device_ids): modules = old_replicate(module, device_ids) execute_replication_callbacks(modules) return modules data_parallel.replicate = new_replicate ================================================ FILE: sync_batchnorm/unittest.py ================================================ # -*- coding: utf-8 -*- # File : unittest.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import unittest import numpy as np from torch.autograd import Variable def as_numpy(v): if isinstance(v, Variable): v = v.data return v.cpu().numpy() class TorchTestCase(unittest.TestCase): def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): npa, npb = as_numpy(a), as_numpy(b) self.assertTrue( np.allclose(npa, npb, atol=atol), 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) ) ================================================ FILE: unified_video_generator.py ================================================ import os import os.path as osp import argparse from pathlib import Path import subprocess import sys sys.path.append('.') import os import cv2 import yaml import tempfile import numpy as np import torch import soundfile as sf from scipy.interpolate import interp1d from extract_init_states.FaceBoxes.FaceBoxes_ONNX import FaceBoxes_ONNX from extract_init_states.TDDFA_ONNX import TDDFA_ONNX from extract_init_states.utils.pose import get_pose from extract_init_states.utils.functions import calculate_eye, calculate_bbox from transformers import AutoProcessor, HubertModel from PBnet.src.models.get_model import get_model as get_gen_model from PIL import Image from torchvision import transforms from pydub import AudioSegment def inv_transform(x, min_vals, max_vals): return x * (max_vals - min_vals) + min_vals def load_args(filename): with open(filename, "rb") as optfile: opt = yaml.load(optfile, Loader=yaml.Loader) return opt class VideoGenerator: def __init__(self, args): self.audio_path = args.audio_path self.image_path = args.image_path self.output_path = args.output_path self.cache_path = args.cache_path self.resolution = args.resolution # Ensure output directories exist os.makedirs(self.cache_path, exist_ok=True) os.makedirs(self.output_path, exist_ok=True) # Set intermediate file paths self.audio_emb_path = os.path.join(self.cache_path, 'target_audio.npy') # Set ONNX 
runtime environment for 3DDFA os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' os.environ['OMP_NUM_THREADS'] = '8' # Initialize configuration self.config_path = './extract_init_states/configs/mb1_120x120.yml' self.cfg = yaml.load(open(self.config_path), Loader=yaml.SafeLoader) # Initialize models self.face_boxes = FaceBoxes_ONNX() self.tddfa = TDDFA_ONNX(**self.cfg) # HuBERT model configuration print("Loading the Wav2Vec2 Processor...") self.wav2vec2_processor = AutoProcessor.from_pretrained("./pretrain_models/hubert-large-ls960-ft") print("Loading the HuBERT Model...") self.hubert_model = HubertModel.from_pretrained("./pretrain_models/hubert-large-ls960-ft") self.hubert_model.eval() # PBnet related configuration self.pbnet_pose_ckpt = './pretrain_models/pbnet_seperate/pose/checkpoint_40000.pth.tar' self.pbnet_blink_ckpt = './pretrain_models/pbnet_seperate/blink/checkpoint_95000.pth.tar' self.device = 'cuda:0' # PBnet model parameters folder_p, _ = os.path.split(self.pbnet_pose_ckpt) self.pose_params = load_args(os.path.join(folder_p, "opt.yaml")) self.pose_params['device'] = self.device self.pose_params['audio_dim'] = 1024 self.pose_params['pos_dim'] = 6 self.pose_params['eye_dim'] = 0 folder_b, _ = os.path.split(self.pbnet_blink_ckpt) self.blink_params = load_args(os.path.join(folder_b, "opt.yaml")) self.blink_params['device'] = self.device self.blink_params['audio_dim'] = 1024 self.blink_params['pos_dim'] = 0 self.blink_params['eye_dim'] = 2 # Add normalization parameters self.max_vals = torch.tensor([90, 90, 90, 1, 720, 1080]).to(torch.float32).reshape(1, 1, 6) self.min_vals = torch.tensor([-90, -90, -90, 0, 0, 0]).to(torch.float32).reshape(1, 1, 6) # Load models model_p = get_gen_model(self.pose_params) model_b = get_gen_model(self.blink_params) # Load pretrained weights state_dict_p = torch.load(self.pbnet_pose_ckpt, map_location=self.device) state_dict_b = torch.load(self.pbnet_blink_ckpt, map_location=self.device) model_p.load_state_dict(state_dict_p) model_b.load_state_dict(state_dict_b) model_p.eval() model_b.eval() self.model_p = model_p self.model_b = model_b # Add default video generation configuration current_dir = osp.dirname(osp.abspath(__file__)) # Load configuration file config_path = osp.join(current_dir, 'config', f'DAWN_{int(self.resolution)}.yaml') with open(config_path, 'r') as f: self.video_config = yaml.safe_load(f) # Initialize video generation model as None for lazy loading self.video_model = self._init_video_model(self.video_config['model_config']) # def switch_conda_env(self, env_name): # """切换 conda 环境的函数""" # # 这里需要使用 subprocess 来执行 conda 命令 # subprocess.run(f"conda activate {env_name}", shell=True) def extract_pose(self): """Extract facial pose and landmark information from input image. This function uses 3DDFA-V2 model for face detection and pose estimation. Main steps include: 1. Load and initialize face detection and pose estimation models 2. Process input image 3. Extract facial pose and landmark information 4. 
Save results to specified paths Output files: - init_pose.npy: Numpy array file containing pose information - init_eye_bbox.npy: Numpy array file containing eye and bounding box information """ # Set ONNX runtime environment os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' os.environ['OMP_NUM_THREADS'] = '8' # Initialize configuration # config_path = 'configs/mb1_120x120.yml' # Make sure path is correct cfg = yaml.load(open(self.config_path), Loader=yaml.SafeLoader) # Initialize models face_boxes = FaceBoxes_ONNX() tddfa = TDDFA_ONNX(**cfg) # Read input image image = cv2.imread(self.image_path) if image.shape[2] == 4: # Handle RGBA images image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) # Face detection boxes = face_boxes(image) if len(boxes) == 0: raise ValueError(f'No face detected in image: {self.image_path}') return None # Get 3DMM parameters and ROI boxes param_lst, roi_box_lst = tddfa(image, boxes) # Reconstruct vertices dense_flag = True # For generating dense landmarks ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, dense_flag=dense_flag) # Get pose information pose = get_pose(image, param_lst, ver_lst, show_flag=False, wfp=None, wnp=None) # Calculate eye and bounding box information lmk = ver_lst[0] eye_bbox_result = np.zeros(8) bbox = calculate_bbox(image, lmk) left_ratio, right_ratio = calculate_eye(lmk) # Organize result data eye_bbox_result[0] = left_ratio.sum() eye_bbox_result[1] = right_ratio.sum() eye_bbox_result[2:] = np.array(bbox) # Reshape arrays pose = pose.reshape(1, 7) eye_bbox_result = eye_bbox_result.reshape(1, -1) # Set save paths eye_bbox_path = os.path.join(self.cache_path, 'init_eye_bbox.npy') pose_path = os.path.join(self.cache_path, 'init_pose.npy') # Save results np.save(eye_bbox_path, eye_bbox_result) np.save(pose_path, pose) def process_audio(self): """Process audio file and extract HuBERT features. This method performs the following steps: 1. Convert input audio to 16kHz sampling rate 2. Extract audio features using HuBERT model 3. Interpolate features to match video frame rate 4. Save processed features Output files: - target_audio.npy: Numpy array containing interpolated HuBERT features Raises: RuntimeError: If audio processing fails """ # self.switch_conda_env("DAWN") try: # Create temp file for 16kHz audio with tempfile.NamedTemporaryFile('w', suffix='.wav', dir='./') as temp_wav: # Convert audio sampling rate to 16kHz self._convert_wav_to_16k(self.audio_path, temp_wav.name) # Read 16kHz audio speech_16k, _ = sf.read(temp_wav.name) # Calculate target frame count (based on 25fps video) num_frames = int((speech_16k.shape[0] / 16000) * 25) # Extract HuBERT features hubert_hidden = self._get_hubert_from_16k_speech(speech_16k, device=self.device) hubert_hidden = hubert_hidden.detach().numpy() # Linear interpolation of features interp_func = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0) hubert_feature_interpolated = interp_func( np.linspace(0, hubert_hidden.shape[0] - 1, num_frames) ).astype(np.float32) print(f'Frame count: {num_frames}, HuBERT size: {hubert_hidden.shape[0]}') # Save processed features np.save(self.audio_emb_path, hubert_feature_interpolated) except Exception as e: raise RuntimeError(f"Audio processing failed: {str(e)}") def generate_pose_blink(self): """Generate pose and blink data. This function uses the PBnet model to generate driving pose and blink data. Main steps include: 1. Load pretrained pose and blink models 2. Process input data (audio features, initial pose, initial blink) 3. 
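# Editor's note: illustrative sketch, not part of the original repository.
# process_audio() resamples the HuBERT feature sequence (one step per 320 samples of
# 16 kHz audio) onto the 25 fps video timeline with linear interpolation.  For a 4 s
# clip: 64000 samples -> (64000 - 80) // 320 = 199 HuBERT steps, interpolated down to
# 4 * 25 = 100 frames.  The same resampling step in isolation, with dummy features:
import numpy as np
from scipy.interpolate import interp1d

hubert_hidden = np.random.randn(199, 1024).astype(np.float32)
num_frames = 100
interp = interp1d(np.arange(hubert_hidden.shape[0]), hubert_hidden, kind='linear', axis=0)
features_25fps = interp(np.linspace(0, hubert_hidden.shape[0] - 1, num_frames)).astype(np.float32)
assert features_25fps.shape == (100, 1024)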
    def generate_pose_blink(self):
        """Generate driving pose and blink data.

        This function uses the PBnet models to generate audio-driven pose and
        blink sequences. Main steps:
        1. Load the pretrained pose and blink models' inputs
        2. Process the input data (audio features, initial pose, initial blink)
        3. Generate the driving data
        4. Save the results

        Output files:
            - dri_pose.npy: generated pose sequence
            - dri_blink.npy: generated blink sequence
        """
        # Set input paths
        init_pose_path = os.path.join(self.cache_path, 'init_pose.npy')
        init_blink_path = os.path.join(self.cache_path, 'init_eye_bbox.npy')

        try:
            # Load input data
            init_pose = torch.from_numpy(np.load(init_pose_path))[:, :self.pose_params['pos_dim']].unsqueeze(0).to(torch.float32)
            init_blink = torch.from_numpy(np.load(init_blink_path))[:, :self.blink_params['eye_dim']].unsqueeze(0).to(torch.float32)
            audio = torch.from_numpy(np.load(self.audio_emb_path)).unsqueeze(0).to(torch.float32)
        except Exception:
            # Fall back to default values when 3DDFA extraction fails
            init_pose = torch.from_numpy(np.array([[0, 0, 0, 4.79e-04, 5.65e+01, 6.49e+01]]))[:, :self.pose_params['pos_dim']].unsqueeze(0).to(torch.float32)
            init_blink = torch.from_numpy(np.array([[0.3, 0.3]]))[:, :self.blink_params['eye_dim']].unsqueeze(0).to(torch.float32)
            audio = torch.from_numpy(np.load(self.audio_emb_path)).unsqueeze(0).to(torch.float32)

        # Normalize the initial pose to [0, 1]
        init_pose = (init_pose - self.min_vals) / (self.max_vals - self.min_vals)

        with torch.no_grad():
            # Generate driving data for the full audio length
            gendurations_seg = torch.tensor([audio.shape[1]])
            batch_p = self.model_p.generate(init_pose, audio, gendurations_seg, fact=1)
            batch_b = self.model_b.generate(init_blink, audio, gendurations_seg, fact=1)

        # Post-process the outputs: add the initial state back and un-normalize the pose
        output_p = batch_p['output'].detach().cpu()
        output_b = batch_b['output'].detach().cpu()
        output_p = output_p + init_pose
        output_p = inv_transform(output_p, self.min_vals, self.max_vals)
        output_b = output_b + init_blink

        # Save results
        output_pose_path = os.path.join(self.cache_path, 'dri_pose.npy')
        output_blink_path = os.path.join(self.cache_path, 'dri_blink.npy')
        np.save(output_pose_path, output_p[0])
        np.save(output_blink_path, output_b[0])
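    # Rough guide to the cached driving signals (assuming PBnet's `generate`
    # returns a (1, T, dim) tensor, where T is the number of audio frames):
    #   dri_pose.npy  -> (T, 6)  head pose per frame
    #   dri_blink.npy -> (T, 2)  left/right eye-openness per frame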
    def generate_final_video(self):
        """Generate the final video.

        Raises:
            RuntimeError: If an error occurs during video generation
        """
        try:
            # Prepare the output directories
            directory_name = os.path.splitext(os.path.basename(self.image_path))[0]
            video_dir = os.path.join(self.output_path, directory_name, 'video')
            img_dir = os.path.join(self.output_path, directory_name, 'img')
            os.makedirs(video_dir, exist_ok=True)
            os.makedirs(img_dir, exist_ok=True)

            # Prepare the input image
            image = Image.open(self.image_path).convert("RGB")
            transform = transforms.Compose([
                transforms.Resize((self.video_config['input_size'], self.video_config['input_size'])),
                transforms.ToTensor()
            ])
            image_tensor = transform(image) * 255

            # Load the audio embedding and the driving conditions (pose, blink)
            hubert_npy = np.load(self.audio_emb_path)
            max_frames = min(self.video_config['max_n_frames'], hubert_npy.shape[0])
            ref_hubert = torch.from_numpy(hubert_npy[:max_frames]).to(torch.float32)
            drive_poses = torch.from_numpy(np.load(os.path.join(self.cache_path, 'dri_pose.npy'))[:max_frames]).to(torch.float32)
            drive_blink = torch.from_numpy(np.load(os.path.join(self.cache_path, 'dri_blink.npy'))[:max_frames]).to(torch.float32)

            try:
                real_poses = torch.from_numpy(np.load(os.path.join(self.cache_path, 'init_pose.npy'))).to(torch.float32)
                real_blink_bbox = torch.from_numpy(np.load(os.path.join(self.cache_path, 'init_eye_bbox.npy'))).to(torch.float32)
            except Exception:
                # Fall back to default values
                real_poses = torch.zeros(1, 7)
                real_blink_bbox = torch.tensor([[0.3, 0.3, 64, 64, 192, 192, 256, 256]]).reshape(1, -1).to(torch.float32)

            # Prepare the initial state
            init_pose = real_poses[0].unsqueeze(0)
            init_blink = real_blink_bbox[0, :2].unsqueeze(0)

            # Reorder the driving signals to (dim, T)
            drive_poses = drive_poses.permute(1, 0)
            drive_blink = drive_blink.permute(1, 0)
            real_blink_bbox = real_blink_bbox.permute(1, 0)

            # Temp files for the trimmed audio and the silent video
            with tempfile.NamedTemporaryFile('w', suffix='.wav') as temp_wav, \
                 tempfile.NamedTemporaryFile('w', suffix='.mp4') as temp_video:
                # Extract the matching audio segment
                self._extract_audio_segment(self.audio_path, 0, max_frames, 25, temp_wav.name)

                # Video writer
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                video_writer = cv2.VideoWriter(
                    temp_video.name, fourcc, 25,
                    (self.video_config['input_size'], self.video_config['input_size'])
                )

                # DDIM generation
                with torch.no_grad():
                    self.video_model.update_num_frames(max_frames)
                    sample_output = self.video_model.sample_one_video(
                        sample_img=image_tensor.unsqueeze(dim=0).cuda() / 255.,
                        sample_audio_hubert=ref_hubert.unsqueeze(dim=0).cuda(),
                        sample_pose=drive_poses.unsqueeze(0).cuda(),
                        sample_eye=drive_blink[:2].unsqueeze(0).cuda(),
                        sample_bbox=real_blink_bbox[2:].unsqueeze(0).cuda(),
                        init_pose=init_pose.cuda(),
                        init_eye=init_blink.cuda(),
                        cond_scale=self.video_config['cond_scale']
                    )

                # Write the frames
                for frame_idx in range(max_frames):
                    frame = self._process_output_frame(
                        sample_output["sample_out_vid"][:, :, frame_idx],
                        mean=self.video_config['mean']
                    )
                    video_writer.write(frame)

                    # Also save each frame as a PNG
                    frame_name = f"{frame_idx:03d}.png"
                    frame_path = os.path.join(img_dir, frame_name)
                    cv2.imwrite(frame_path, frame)

                video_writer.release()

                # Mux the audio into the final video
                output_video_path = os.path.join(video_dir, f"{directory_name}.mp4")
                self._combine_video_audio(temp_wav.name, temp_video.name, output_video_path)
        except Exception as e:
            raise RuntimeError(f"Video generation failed: {str(e)}")
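    # Resulting output layout for an input image named `real_female_1.jpeg`
    # (the name is just the argparse default used further below):
    #   output/real_female_1/video/real_female_1.mp4     final video with audio
    #   output/real_female_1/img/000.png, 001.png, ...   individual frames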
    def run(self):
        """Execute the complete generation pipeline."""
        print("1. Extracting pose information...")
        self.extract_pose()
        print("2. Processing audio...")
        self.process_audio()
        print("3. Generating pose and blink data...")
        self.generate_pose_blink()
        print("4. Generating final video...")
        self.generate_final_video()

    @staticmethod
    def _convert_wav_to_16k(input_file, output_file):
        """Convert an audio file to a 16 kHz sampling rate.

        Args:
            input_file (str): Path to the input audio file
            output_file (str): Path to the output audio file
        """
        command = [
            'ffmpeg',
            '-i', input_file,
            '-ar', '16000',
            '-y',  # Overwrite existing files without asking
            output_file
        ]
        subprocess.run(command)

    @torch.no_grad()
    def _get_hubert_from_16k_speech(self, speech, device="cuda:0"):
        """Extract HuBERT features from 16 kHz audio.

        Args:
            speech (numpy.ndarray): Input audio data
            device (str): Computing device, defaults to "cuda:0"

        Returns:
            torch.Tensor: HuBERT feature tensor

        Notes:
            The HuBERT front-end is a multi-layer CNN:
            - Total stride is 320 (5*2*2*2*2*2)
            - Kernel size is 400
            Long audio is processed in segments to avoid memory issues.
        """
        self.hubert_model = self.hubert_model.to(device)
        if speech.ndim == 2:
            speech = speech[:, 0]  # [T, 2] ==> [T,]
        input_values_all = self.wav2vec2_processor(
            speech, return_tensors="pt", sampling_rate=16000
        ).input_values.to(device)

        # Parameters for segmented processing
        kernel = 400
        stride = 320
        clip_length = stride * 1000
        num_iter = input_values_all.shape[1] // clip_length
        expected_T = (input_values_all.shape[1] - (kernel - stride)) // stride

        # Process the audio in segments
        res_lst = []
        for i in range(num_iter):
            if i == 0:
                start_idx = 0
                end_idx = clip_length - stride + kernel
            else:
                start_idx = clip_length * i
                end_idx = start_idx + (clip_length - stride + kernel)
            input_values = input_values_all[:, start_idx: end_idx]
            hidden_states = self.hubert_model(input_values).last_hidden_state
            res_lst.append(hidden_states[0])

        # Process the last segment
        if num_iter > 0:
            input_values = input_values_all[:, clip_length * num_iter:]
        else:
            input_values = input_values_all
        if input_values.shape[1] >= kernel:
            hidden_states = self.hubert_model(input_values).last_hidden_state
            res_lst.append(hidden_states[0])

        # Concatenate the features
        ret = torch.cat(res_lst, dim=0).cpu()

        # Check the length and pad/trim to the expected number of frames
        assert abs(ret.shape[0] - expected_T) <= 1
        if ret.shape[0] < expected_T:
            ret = torch.nn.functional.pad(ret, (0, 0, 0, expected_T - ret.shape[0]))
        else:
            ret = ret[:expected_T]
        return ret

    def _init_video_model(self, model_config):
        """Initialize the video generation model.

        Args:
            model_config (dict): Model configuration dictionary

        Returns:
            FlowDiffusion: Initialized video generation model
        """
        from DM_3.modules.video_flow_diffusion_model_multiGPU_v0_crema_vgg_floss_plus_faceemb_flow_fast_init_cond_test import FlowDiffusion

        model = FlowDiffusion(
            is_train=model_config['is_train'],
            sampling_timesteps=self.video_config['sampling_step'],
            ddim_sampling_eta=self.video_config['ddim_sampling_eta'],
            pose_dim=model_config['pose_dim'],
            config_pth=model_config['config_pth'],
            pretrained_pth=model_config['ae_pretrained_pth'],
            win_width=self.video_config['win_width']
        )
        model.cuda()

        # Load the pretrained diffusion weights
        checkpoint = torch.load(model_config['diffusion_pretrained_pth'])
        model.diffusion.load_state_dict(checkpoint['diffusion'])
        model.eval()
        return model
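    # Worked example for the expected_T arithmetic in _get_hubert_from_16k_speech:
    # one second of 16 kHz speech is 16000 samples, so with kernel=400 and
    # stride=320 the model emits (16000 - (400 - 320)) // 320 = 49 feature
    # frames, i.e. roughly one frame every 20 ms (~50 Hz).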
    def _process_output_frame(self, frame_batch, mean=(0.0, 0.0, 0.0), index=0):
        """Process an output frame from the model.

        Args:
            frame_batch (torch.Tensor): Batch of frame data
            mean (tuple): Per-channel mean values to add back
            index (int): Batch index

        Returns:
            numpy.ndarray: Processed frame in BGR format
        """
        frame = frame_batch[index].permute(1, 2, 0).data.cpu().numpy().copy()
        frame += np.array(mean) / 255.0
        frame = np.clip(frame, 0, 1)
        frame = (frame * 255).astype(np.uint8)
        return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

    def _extract_audio_segment(self, input_wav, start_frame, num_frames, fps, output_wav):
        """Extract an audio segment matching a range of video frames.

        Args:
            input_wav (str): Input audio path
            start_frame (int): Start frame
            num_frames (int): Number of frames
            fps (int): Frames per second
            output_wav (str): Output audio path
        """
        audio = AudioSegment.from_wav(input_wav)
        frame_duration = 1000 / fps  # milliseconds per frame
        start_time = start_frame * frame_duration
        end_time = (start_frame + num_frames) * frame_duration
        audio[start_time:end_time].export(output_wav, format="wav")

    def _combine_video_audio(self, audio_path, video_path, output_path):
        """Combine the video and audio tracks.

        Args:
            audio_path (str): Path to the audio file
            video_path (str): Path to the video file
            output_path (str): Path to the output file
        """
        cmd = [
            'ffmpeg', '-y',
            '-i', audio_path,
            '-i', video_path,
            '-vcodec', 'copy',
            '-ac', '2',
            '-channel_layout', 'stereo',
            '-pix_fmt', 'yuv420p',
            output_path,
            '-shortest'
        ]
        subprocess.run(cmd)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--audio_path', type=str, default='WRA_MarcoRubio_000.wav',
                        help='Input audio path')
    parser.add_argument('--image_path', type=str, default='real_female_1.jpeg',
                        help='Input image path')
    parser.add_argument('--output_path', type=str, default='output',
                        help='Output video path')
    parser.add_argument('--cache_path', type=str, default='cache/tmp',
                        help='Cache file path')
    parser.add_argument('--resolution', type=int, default=128,
                        help='Output resolution (128 or 256)')
    return parser.parse_args()


def main():
    args = parse_args()
    generator = VideoGenerator(args)
    generator.run()


if __name__ == "__main__":
    main()
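# Example invocation (a sketch; the file names are only the argparse defaults
# above, and the script name is a placeholder for this file):
#
#   python run_demo.py \
#       --audio_path WRA_MarcoRubio_000.wav \
#       --image_path real_female_1.jpeg \
#       --output_path output \
#       --cache_path cache/tmp \
#       --resolution 128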