Repository: jixinya/EAMM Branch: main Commit: e176a79865f7 Files: 53 Total size: 345.8 KB Directory structure: gitextract_7pp944hx/ ├── 3DDFA_V2/ │ ├── demo.py │ └── utils/ │ └── pose.py ├── LICENSE ├── M003_template.npy ├── README.md ├── augmentation.py ├── config/ │ ├── MEAD_emo_video_aug_delta_4_crop_random_crop.yaml │ ├── train_part1.yaml │ ├── train_part1_fine_tune.yaml │ └── train_part2.yaml ├── dataset/ │ ├── LRW/ │ │ ├── MFCC/ │ │ │ └── ABOUT/ │ │ │ └── ABOUT_00001.npy │ │ └── Pose/ │ │ └── ABOUT/ │ │ └── ABOUT_00001.npy │ └── MEAD/ │ └── list/ │ └── MEAD_fomm_neu_dic_crop.npy ├── demo.py ├── filter1.py ├── frames_dataset.py ├── logger.py ├── modules/ │ ├── dense_motion.py │ ├── discriminator.py │ ├── function.py │ ├── generator.py │ ├── keypoint_detector.py │ ├── model.py │ ├── model_delta_map.py │ ├── model_gen.py │ ├── ops.py │ ├── stylegan2.py │ └── util.py ├── ops.py ├── process_data.py ├── requirements.txt ├── run.py ├── sync_batchnorm/ │ ├── __init__.py │ ├── batchnorm.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── test/ │ ├── pose/ │ │ ├── 14.npy │ │ ├── 21.npy │ │ ├── 60.npy │ │ ├── 7.npy │ │ ├── anne.npy │ │ ├── brade2.npy │ │ ├── dune_1.npy │ │ ├── dune_2.npy │ │ ├── jake4.npy │ │ ├── mona.npy │ │ ├── paint1.npy │ │ └── scarlett.npy │ └── pose_long/ │ ├── 0zn70Ak8lRc_Daniel_Auteuil_0zn70Ak8lRc_0002.npy │ ├── 1hEr7qKRKL4_Daniel_Dae_Kim_1hEr7qKRKL4_0004.npy │ └── 50IAfJCypFI_Alex_Kingston_50IAfJCypFI_0001.npy └── train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: 3DDFA_V2/demo.py ================================================ # coding: utf-8 __author__ = 'cleardusk' import sys import argparse import cv2 import yaml import os import time from FaceBoxes import FaceBoxes from TDDFA import TDDFA from utils.render import render #from utils.render_ctypes import render # faster from utils.depth import depth from utils.pncc import pncc from utils.uv import uv_tex from utils.pose import viz_pose, get_pose from utils.serialization import ser_to_ply, ser_to_obj from utils.functions import draw_landmarks, get_suffix from utils.tddfa_util import str2bool import numpy as np from tqdm import tqdm import copy import concurrent.futures from multiprocessing import Pool def main(args,img, save_path, pose_path): # begin = time.time() cfg = yaml.load(open(args.config), Loader=yaml.SafeLoader) # Init FaceBoxes and TDDFA, recommend using onnx flag if args.onnx: import os os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' os.environ['OMP_NUM_THREADS'] = '4' from FaceBoxes.FaceBoxes_ONNX import FaceBoxes_ONNX from TDDFA_ONNX import TDDFA_ONNX face_boxes = FaceBoxes_ONNX() tddfa = TDDFA_ONNX(**cfg) else: gpu_mode = args.mode == 'gpu' tddfa = TDDFA(gpu_mode=gpu_mode, **cfg) face_boxes = FaceBoxes() # Given a still image path and load to BGR channel # img = cv2.imread(img_path) #args.img_fp # Detect faces, get 3DMM params and roi boxes boxes = face_boxes(img) n = len(boxes) if n == 0: print(f'No face detected, exit') # sys.exit(-1) return None print(f'Detect {n} faces') param_lst, roi_box_lst = tddfa(img, boxes) #detection time # detect_time = time.time()-begin # print('detection time: '+str(detect_time), file=open('/mnt/lustre/jixinya/Home/3DDFA_V2/pose.txt', 'a')) # Visualization and serialization dense_flag = args.opt in ('2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'ply', 'obj') # old_suffix = get_suffix(img_path) old_suffix = 'png' new_suffix = f'.{args.opt}' if args.opt in ('ply', 'obj') else '.jpg' wfp = f'examples/results/{args.img_fp.split("/")[-1].replace(old_suffix, "")}_{args.opt}' + new_suffix ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, dense_flag=dense_flag) if args.opt == '2d_sparse': draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp) elif args.opt == '2d_dense': draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp) elif args.opt == '3d': render(img, ver_lst, tddfa.tri, alpha=0.6, show_flag=args.show_flag, wfp=wfp) elif args.opt == 'depth': # if `with_bf_flag` is False, the background is black depth(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True) elif args.opt == 'pncc': pncc(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True) elif args.opt == 'uv_tex': uv_tex(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp) elif args.opt == 'pose': all_pose = get_pose(img, param_lst, ver_lst, show_flag=args.show_flag, wfp=save_path, wnp = pose_path) elif args.opt == 'ply': ser_to_ply(ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp) elif args.opt == 'obj': ser_to_obj(img, ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp) else: raise ValueError(f'Unknown opt {args.opt}') return all_pose def process_word(i): path = '/media/xinya/Backup Plus/sense_shixi_data/new_crop/MEAD_fomm_video_6/' save = '/media/xinya/Backup Plus/sense_shixi_data/new_crop/MEAD_fomm_pose_im/' pose = '/media/xinya/Backup Plus/sense_shixi_data/new_crop/MEAD_fomm_pose/' start = time.time() Dir = os.listdir(path) Dir.sort() word = Dir[i] wpath = os.path.join(path, word) print(wpath) pathDir = os.listdir(wpath) pose_file = os.path.join(pose,word) if not os.path.exists(pose_file): os.makedirs(pose_file) for j in range(len(pathDir)): name = pathDir[j] # save_file = os.path.join(save,word,name) # if not os.path.exists(save_file): # os.makedirs(save_file) fpath = os.path.join(wpath,name) image_all = [] videoCapture = cv2.VideoCapture(fpath) success, frame = videoCapture.read() n = 0 while success : image_all.append(frame) n = n + 1 success, frame = videoCapture.read() # fDir = os.listdir(fpath) pose_all = np.zeros((len(image_all),7)) for k in range(len(image_all)): # index = fDir[k].split('.')[0] # img_path = os.path.join(fpath,str(k)+'.png') # pose_all[k] = main(args,image_all[k], os.path.join(save_file,str(k)+'.jpg'), None) pose_all[k] = main(args,image_all[k], None, None) np.save(os.path.join(pose,word,name.split('.')[0]+'.npy'),pose_all) st = time.time()-start print(str(i)+' '+word+' '+str(j)+' '+name+' '+str(k)+'time: '+str(st), file=open('/media/thea/Backup Plus/sense_shixi_data/new_crop/pose_mead6.txt', 'a')) print(i,word,j,name,k) if __name__ == '__main__': parser = argparse.ArgumentParser(description='The demo of still image of 3DDFA_V2') parser.add_argument('-c', '--config', type=str, default='configs/mb1_120x120.yml') parser.add_argument('-f', '--img_fp', type=str, default='examples/inputs/0.png') parser.add_argument('-m', '--mode', type=str, default='cpu', help='gpu or cpu mode') parser.add_argument('-o', '--opt', type=str, default='pose', choices=['2d_sparse', '2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'pose', 'ply', 'obj']) parser.add_argument('--show_flag', type=str2bool, default='False', help='whether to show the visualization result') parser.add_argument('--onnx', action='store_true', default=False) args = parser.parse_args() filepath = 'test/image/' pathDir = os.listdir(filepath) for i in range(len(pathDir)): image= cv2.imread(os.path.join(filepath,pathDir[i])) pose = main(args,image, None, None).reshape(1,7) np.save('test/pose/'+pathDir[i].split('.')[0]+'.npy',pose) print(i,pathDir[i]) ''' def main(args): cfg = yaml.load(open(args.config), Loader=yaml.SafeLoader) # Init FaceBoxes and TDDFA, recommend using onnx flag if args.onnx: import os os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' os.environ['OMP_NUM_THREADS'] = '4' from FaceBoxes.FaceBoxes_ONNX import FaceBoxes_ONNX from TDDFA_ONNX import TDDFA_ONNX face_boxes = FaceBoxes_ONNX() tddfa = TDDFA_ONNX(**cfg) else: gpu_mode = args.mode == 'gpu' tddfa = TDDFA(gpu_mode=gpu_mode, **cfg) face_boxes = FaceBoxes() # Given a still image path and load to BGR channel img = cv2.imread(args.img_fp) # Detect faces, get 3DMM params and roi boxes boxes = face_boxes(img) n = len(boxes) if n == 0: print(f'No face detected, exit') sys.exit(-1) print(f'Detect {n} faces') param_lst, roi_box_lst = tddfa(img, boxes) # Visualization and serialization dense_flag = args.opt in ('2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'ply', 'obj') old_suffix = get_suffix(args.img_fp) new_suffix = f'.{args.opt}' if args.opt in ('ply', 'obj') else '.jpg' wfp = f'examples/results/{args.img_fp.split("/")[-1].replace(old_suffix, "")}_{args.opt}' + new_suffix ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, dense_flag=dense_flag) if args.opt == '2d_sparse': draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp) elif args.opt == '2d_dense': draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp) elif args.opt == '3d': render(img, ver_lst, tddfa.tri, alpha=0.6, show_flag=args.show_flag, wfp=wfp) elif args.opt == 'depth': # if `with_bf_flag` is False, the background is black depth(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True) elif args.opt == 'pncc': pncc(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True) elif args.opt == 'uv_tex': uv_tex(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp) elif args.opt == 'pose': viz_pose(img, param_lst, ver_lst, show_flag=args.show_flag, wfp=wfp) elif args.opt == 'ply': ser_to_ply(ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp) elif args.opt == 'obj': ser_to_obj(img, ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp) else: raise ValueError(f'Unknown opt {args.opt}') ''' ================================================ FILE: 3DDFA_V2/utils/pose.py ================================================ # coding: utf-8 """ Reference: https://github.com/YadiraF/PRNet/blob/master/utils/estimate_pose.py Calculating pose from the output 3DMM parameters, you can also try to use solvePnP to perform estimation """ __author__ = 'cleardusk' import cv2 import numpy as np from math import cos, sin, atan2, asin, sqrt from .functions import calc_hypotenuse, plot_image def P2sRt(P): """ decompositing camera matrix P. Args: P: (3, 4). Affine Camera Matrix. Returns: s: scale factor. R: (3, 3). rotation matrix. t2d: (2,). 2d translation. """ t3d = P[:, 3] R1 = P[0:1, :3] R2 = P[1:2, :3] s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0 r1 = R1 / np.linalg.norm(R1) r2 = R2 / np.linalg.norm(R2) r3 = np.cross(r1, r2) R = np.concatenate((r1, r2, r3), 0) return s, R, t3d def matrix2angle(R): """ compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv todo: check and debug Args: R: (3,3). rotation matrix Returns: x: yaw y: pitch z: roll """ if R[2, 0] > 0.998: z = 0 x = np.pi / 2 y = z + atan2(-R[0, 1], -R[0, 2]) elif R[2, 0] < -0.998: z = 0 x = -np.pi / 2 y = -z + atan2(R[0, 1], R[0, 2]) else: x = asin(R[2, 0]) y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x)) z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x)) return x, y, z def angle2matrix(theta): """ compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv todo: check and debug Args: R: (3,3). rotation matrix Returns: x: yaw y: pitch z: roll """ R_x = np.array([[1, 0, 0 ], [0, cos(theta[1]), -sin(theta[1]) ], [0, sin(theta[1]), cos(theta[1]) ] ]) R_y = np.array([[cos(theta[0]), 0, sin(-theta[0]) ], [0, 1, 0 ], [-sin(-theta[0]), 0, cos(theta[0]) ] ]) R_z = np.array([[cos(theta[2]), -sin(theta[2]), 0], [sin(theta[2]), cos(theta[2]), 0], [0, 0, 1] ]) R = np.dot(R_z, np.dot( R_y, R_x )) return R def angle2matrix_3ddfa(angles): ''' get rotation matrix from three rotation angles(radian). The same as in 3DDFA. Args: angles: [3,]. x, y, z angles x: pitch. y: yaw. z: roll. Returns: R: 3x3. rotation matrix. ''' # x, y, z = np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2]) x, y, z = angles[1], angles[0], angles[2] # x Rx=np.array([[1, 0, 0], [0, cos(x), sin(x)], [0, -sin(x), cos(x)]]) # y Ry=np.array([[ cos(y), 0, -sin(y)], [ 0, 1, 0], [sin(y), 0, cos(y)]]) # z Rz=np.array([[cos(z), sin(z), 0], [-sin(z), cos(z), 0], [ 0, 0, 1]]) R = Rx.dot(Ry).dot(Rz) return R.astype(np.float32) def calc_pose(param): P = param[:12].reshape(3, -1) # camera matrix s, R, t3d = P2sRt(P) P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) # without scale pose = matrix2angle(R) pose = [p * 180 / np.pi for p in pose] return P, pose def build_camera_box(rear_size=90): point_3d = [] rear_depth = 0 point_3d.append((-rear_size, -rear_size, rear_depth)) point_3d.append((-rear_size, rear_size, rear_depth)) point_3d.append((rear_size, rear_size, rear_depth)) point_3d.append((rear_size, -rear_size, rear_depth)) point_3d.append((-rear_size, -rear_size, rear_depth)) front_size = int(4 / 3 * rear_size) front_depth = int(4 / 3 * rear_size) point_3d.append((-front_size, -front_size, front_depth)) point_3d.append((-front_size, front_size, front_depth)) point_3d.append((front_size, front_size, front_depth)) point_3d.append((front_size, -front_size, front_depth)) point_3d.append((-front_size, -front_size, front_depth)) point_3d = np.array(point_3d, dtype=np.float32).reshape(-1, 3) return point_3d def plot_pose_box(img, P, ver, color=(40, 255, 0), line_width=2): """ Draw a 3D box as annotation of pose. Ref:https://github.com/yinguobing/head-pose-estimation/blob/master/pose_estimator.py Args: img: the input image P: (3, 4). Affine Camera Matrix. kpt: (2, 68) or (3, 68) """ llength = calc_hypotenuse(ver) point_3d = build_camera_box(llength) # Map to 2d image points point_3d_homo = np.hstack((point_3d, np.ones([point_3d.shape[0], 1]))) # n x 4 point_2d = point_3d_homo.dot(P.T)[:, :2] point_2d[:, 1] = - point_2d[:, 1] point_2d[:, :2] = point_2d[:, :2] - np.mean(point_2d[:4, :2], 0) + np.mean(ver[:2, :27], 1) point_2d = np.int32(point_2d.reshape(-1, 2)) # Draw all the lines cv2.polylines(img, [point_2d], True, color, line_width, cv2.LINE_AA) cv2.line(img, tuple(point_2d[1]), tuple( point_2d[6]), color, line_width, cv2.LINE_AA) cv2.line(img, tuple(point_2d[2]), tuple( point_2d[7]), color, line_width, cv2.LINE_AA) cv2.line(img, tuple(point_2d[3]), tuple( point_2d[8]), color, line_width, cv2.LINE_AA) return img def viz_pose(img, param_lst, ver_lst, show_flag=False, wfp=None): for param, ver in zip(param_lst, ver_lst): P, pose = calc_pose(param) img = plot_pose_box(img, P, ver) # print(P[:, :3]) print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}') if wfp is not None: cv2.imwrite(wfp, img) print(f'Save visualization result to {wfp}') if show_flag: plot_image(img) return img def pose_6(param): P = param[:12].reshape(3, -1) # camera matrix s, R, t3d = P2sRt(P) P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) # without scale pose = matrix2angle(R) print(t3d) R1 = angle2matrix(pose) print(R) print(R1) pose = [p * 180 / np.pi for p in pose] return s, pose, t3d, P def smooth_pose(img, param_lst, ver_lst, pose_new, show_flag=False, wfp=None, wnp = None): for param, ver in zip(param_lst, ver_lst): t3d = np.array([pose_new[4],pose_new[5],pose_new[6]]) theta = np.array([pose_new[0],pose_new[1],pose_new[2]]) theta = [p * np.pi / 180 for p in theta] R = angle2matrix(theta) P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) img = plot_pose_box(img, P, ver) # print(P,P.shape,t3d) print(P,pose_new) print(f'yaw: {theta[0]:.1f}, pitch: {theta[1]:.1f}, roll: {theta[2]:.1f}') all_pose = [0] all_pose = np.array(all_pose) if wfp is not None: cv2.imwrite(wfp, img) print(f'Save visualization result to {wfp}') if wnp is not None: np.save(wnp, all_pose) print(f'Save visualization result to {wfp}') if show_flag: plot_image(img) return img def get_pose(img, param_lst, ver_lst, show_flag=False, wfp=None, wnp = None): for param, ver in zip(param_lst, ver_lst): s, pose, t3d, P = pose_6(param) img = plot_pose_box(img, P, ver) # print(P,P.shape,t3d) print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}') all_pose = [pose[0],pose[1],pose[2],s,t3d[0],t3d[1],t3d[2]] all_pose = np.array(all_pose) if wfp is not None: cv2.imwrite(wfp, img) print(f'Save visualization result to {wfp}') if wnp is not None: np.save(wnp, all_pose) print(f'Save visualization result to {wfp}') if show_flag: plot_image(img) return all_pose ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2022 jixinya Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # EAMM: One-Shot Emotional Talking Face via Audio-Based Emotion-Aware Motion Model [SIGGRAPH 2022 Conference] Xinya Ji, [Hang Zhou](https://hangz-nju-cuhk.github.io/), Kaisiyuan Wang, [Qianyi Wu](https://wuqianyi.top/), [Wayne Wu](http://wywu.github.io/), [Feng Xu](http://xufeng.site/), [Xun Cao](https://cite.nju.edu.cn/People/Faculty/20190621/i5054.html) [[Project]](https://jixinya.github.io/projects/EAMM/) [[Paper]](https://arxiv.org/abs/2205.15278) ![visualization](demo/teaser-1.png) Given a single portrait image, we can synthesize emotional talking faces, where mouth movements match the input audio and facial emotion dynamics follow the emotion source video. ## Installation We train and test based on Python3.6 and Pytorch. To install the dependencies run: ``` pip install -r requirements.txt ``` ## Testing - Download the pre-trained models and data under the following link: [google-drive](https://drive.google.com/file/d/1IL9LjH3JegyMqJABqMxrX3StAq_v8Gtp/view?usp=sharing) and put the file in corresponding places. - Run the demo: `python demo.py --source_image path/to/image --driving_video path/to/emotion_video --pose_file path/to/pose --in_file path/to/audio --emotion emotion_type` - Prepare testing data: prepare source_image -- crop_image in process_data.py prepare driving_video -- crop_image_tem in process_data.py prepare pose -- detect pose using [3DDFA_V2](https://github.com/cleardusk/3DDFA_V2) ## Training - Training data structure: ``` ./data/ ├──fomm_crop │ ├──id/file_name # cropped images │ │ ├──0.png │ │ ├──... ├──fomm_pose_crop │ ├──id │ │ ├──file_name.npy # pose of the cropped images │ │ ├──... ├──MFCC │ ├──id │ │ ├──file_name.npy # MFCC of the audio │ │ ├──... *The cropped images are generated by 'crop_image_tem' in process_data.py *The pose of the cropped video are generated by 3DDFA_V2/demo.py *The MFCC of the audio are generated by 'audio2mfcc' in process_data.py ``` - Step 1 : Train the Audio2Facial-Dynamics Module using LRW dataset `python run.py --config config/train_part1.yaml --mode train_part1 --checkpoint log/124_52000.pth.tar ` - Step 2 : Fine-tune the Audio2Facial-Dynamics Module after getting stable results from step1 `python run.py --config config/train_part1_fine_tune.yaml --mode train_part1_fine_tune --checkpoint log/124_52000.pth.tar --audio_chechpoint checkpoint/from/step_1` - Setp 3 : Train the Implicit Emotion Displacement Learner `python run.py --config config/train_part2.yaml --mode train_part2 --checkpoint log/124_52000.pth.tar --audio_chechpoint checkpoint/from/step_2` ## Citation ``` @inproceedings{10.1145/3528233.3530745, author = {Ji, Xinya and Zhou, Hang and Wang, Kaisiyuan and Wu, Qianyi and Wu, Wayne and Xu, Feng and Cao, Xun}, title = {EAMM: One-Shot Emotional Talking Face via Audio-Based Emotion-Aware Motion Model}, year = {2022}, isbn = {9781450393379}, url = {https://doi.org/10.1145/3528233.3530745}, doi = {10.1145/3528233.3530745}, booktitle = {ACM SIGGRAPH 2022 Conference Proceedings}, series = {SIGGRAPH '22} } ``` ================================================ FILE: augmentation.py ================================================ """ Code from https://github.com/hassony2/torch_videovision """ import numbers import math import random import numpy as np import PIL import cv2 from skimage.transform import resize, rotate, AffineTransform, warp from skimage.util import pad import torchvision import warnings from skimage import img_as_ubyte, img_as_float def crop_clip(clip, min_h, min_w, h, w): if isinstance(clip[0], np.ndarray): cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] elif isinstance(clip[0], PIL.Image.Image): cropped = [ img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip ] else: raise TypeError('Expected numpy.ndarray or PIL.Image' + 'but got list of {0}'.format(type(clip[0]))) return cropped def pad_clip(clip, h, w): im_h, im_w = clip[0].shape[:2] pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2) pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2) return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge') def resize_clip(clip, size, interpolation='bilinear'): if isinstance(clip[0], np.ndarray): if isinstance(size, numbers.Number): im_h, im_w, im_c = clip[0].shape # Min spatial dim already matches minimal size if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): return clip new_h, new_w = get_resize_sizes(im_h, im_w, size) size = (new_w, new_h) else: size = size[1], size[0] scaled = [ resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True, mode='constant', anti_aliasing=True) for img in clip ] elif isinstance(clip[0], PIL.Image.Image): if isinstance(size, numbers.Number): im_w, im_h = clip[0].size # Min spatial dim already matches minimal size if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): return clip new_h, new_w = get_resize_sizes(im_h, im_w, size) size = (new_w, new_h) else: size = size[1], size[0] if interpolation == 'bilinear': pil_inter = PIL.Image.NEAREST else: pil_inter = PIL.Image.BILINEAR scaled = [img.resize(size, pil_inter) for img in clip] else: raise TypeError('Expected numpy.ndarray or PIL.Image' + 'but got list of {0}'.format(type(clip[0]))) return scaled def get_resize_sizes(im_h, im_w, size): if im_w < im_h: ow = size oh = int(size * im_h / im_w) else: oh = size ow = int(size * im_w / im_h) return oh, ow class RandomFlip(object): def __init__(self, time_flip=False, horizontal_flip=False): self.time_flip = time_flip self.horizontal_flip = horizontal_flip def __call__(self, clip): if random.random() < 0.5 and self.time_flip: return clip[::-1] if random.random() < 0.5 and self.horizontal_flip: return [np.fliplr(img) for img in clip] return clip class RandomResize(object): """Resizes a list of (H x W x C) numpy.ndarray to the final size The larger the original image is, the more times it takes to interpolate Args: interpolation (str): Can be one of 'nearest', 'bilinear' defaults to nearest size (tuple): (widht, height) """ def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'): self.ratio = ratio self.interpolation = interpolation def __call__(self, clip): scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) if isinstance(clip[0], np.ndarray): im_h, im_w, im_c = clip[0].shape elif isinstance(clip[0], PIL.Image.Image): im_w, im_h = clip[0].size new_w = int(im_w * scaling_factor) new_h = int(im_h * scaling_factor) new_size = (new_w, new_h) resized = resize_clip( clip, new_size, interpolation=self.interpolation) return resized class RandomCrop(object): """Extract random crop at the same location for a list of videos Args: size (sequence or int): Desired output size for the crop in format (h, w) """ def __init__(self, size): if isinstance(size, numbers.Number): size = (size, size) self.size = size def __call__(self, clip): """ Args: img (PIL.Image or numpy.ndarray): List of videos to be cropped in format (h, w, c) in numpy.ndarray Returns: PIL.Image or numpy.ndarray: Cropped list of videos """ h, w = self.size if isinstance(clip[0], np.ndarray): im_h, im_w, im_c = clip[0].shape elif isinstance(clip[0], PIL.Image.Image): im_w, im_h = clip[0].size else: raise TypeError('Expected numpy.ndarray or PIL.Image' + 'but got list of {0}'.format(type(clip[0]))) clip = pad_clip(clip, h, w) im_h, im_w = clip.shape[1:3] x1 = 0 if h == im_h else random.randint(0, im_w - w) y1 = 0 if w == im_w else random.randint(0, im_h - h) cropped = crop_clip(clip, y1, x1, h, w) return cropped class MouthCrop(object): """Extract random crop at the same location for a list of videos Args: size (sequence or int): Desired output size for the crop in format (h, w) """ def __init__(self, center_x, center_y, mask_width, mask_height): self.center_x = center_x self.center_y = center_y self.mask_width = mask_width self.mask_height = mask_height def __call__(self, clip): """ Args: img (PIL.Image or numpy.ndarray): List of videos to be cropped in format (h, w, c) in numpy.ndarray Returns: PIL.Image or numpy.ndarray: Cropped list of videos """ start_x = self.center_x - int(self.mask_width/2) start_y = self.center_y - int(self.mask_height/2) end_x = start_x + self.mask_width end_y = start_y + self.mask_height # mask is all white # mask = 255*np.ones((mask_height, mask_width, 3), dtype=np.uint8) # mask is uniform noise cropped = [] for i in range(len(clip)): mask = np.random.rand(self.mask_height, self.mask_width, 3) img = clip[i].copy() img[start_y:end_y, start_x:end_x, :] = mask cropped.append(img) cropped = np.array(cropped) return cropped class RandomRotation(object): """Rotate entire clip randomly by a random angle within given bounds Args: degrees (sequence or int): Range of degrees to select from If degrees is a number instead of sequence like (min, max), the range of degrees, will be (-degrees, +degrees). """ def __init__(self, degrees): if isinstance(degrees, numbers.Number): if degrees < 0: raise ValueError('If degrees is a single number,' 'must be positive') degrees = (-degrees, degrees) else: if len(degrees) != 2: raise ValueError('If degrees is a sequence,' 'it must be of len 2.') self.degrees = degrees def __call__(self, clip): """ Args: img (PIL.Image or numpy.ndarray): List of videos to be cropped in format (h, w, c) in numpy.ndarray Returns: PIL.Image or numpy.ndarray: Cropped list of videos """ angle = random.uniform(self.degrees[0], self.degrees[1]) if isinstance(clip[0], np.ndarray): rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip] elif isinstance(clip[0], PIL.Image.Image): rotated = [img.rotate(angle) for img in clip] else: raise TypeError('Expected numpy.ndarray or PIL.Image' + 'but got list of {0}'.format(type(clip[0]))) return rotated class RandomPerspective(object): """Rotate entire clip randomly by a random angle within given bounds Args: degrees (sequence or int): Range of degrees to select from If degrees is a number instead of sequence like (min, max), the range of degrees, will be (-degrees, +degrees). """ def __init__(self, pers_num, enlarge_num): self.pers_num = pers_num self.enlarge_num = enlarge_num def __call__(self, clip): """ Args: img (PIL.Image or numpy.ndarray): List of videos to be cropped in format (h, w, c) in numpy.ndarray Returns: PIL.Image or numpy.ndarray: Cropped list of videos """ out = clip for i in range(len(clip)): self.pers_size = np.random.randint(20, self.pers_num) * pow(-1, np.random.randint(2)) self.enlarge_size = np.random.randint(20, self.enlarge_num) * pow(-1, np.random.randint(2)) h, w, c = clip[i].shape crop_size=256 dst = np.array([ [-self.enlarge_size, -self.enlarge_size], [-self.enlarge_size + self.pers_size, w + self.enlarge_size], [h + self.enlarge_size, -self.enlarge_size], [h + self.enlarge_size - self.pers_size, w + self.enlarge_size],], dtype=np.float32) src = np.array([[-self.enlarge_size, -self.enlarge_size], [-self.enlarge_size, w + self.enlarge_size], [h + self.enlarge_size, -self.enlarge_size], [h + self.enlarge_size, w + self.enlarge_size]]).astype(np.float32()) M = cv2.getPerspectiveTransform(src, dst) warped = cv2.warpPerspective(clip[i], M, (crop_size, crop_size), borderMode=cv2.BORDER_REPLICATE) out[i] = warped return out class ColorJitter(object): """Randomly change the brightness, contrast and saturation and hue of the clip Args: brightness (float): How much to jitter brightness. brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. contrast (float): How much to jitter contrast. contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. saturation (float): How much to jitter saturation. saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. hue(float): How much to jitter hue. hue_factor is chosen uniformly from [-hue, hue]. Should be >=0 and <= 0.5. """ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): self.brightness = brightness self.contrast = contrast self.saturation = saturation self.hue = hue def get_params(self, brightness, contrast, saturation, hue): if brightness > 0: brightness_factor = random.uniform( max(0, 1 - brightness), 1 + brightness) else: brightness_factor = None if contrast > 0: contrast_factor = random.uniform( max(0, 1 - contrast), 1 + contrast) else: contrast_factor = None if saturation > 0: saturation_factor = random.uniform( max(0, 1 - saturation), 1 + saturation) else: saturation_factor = None if hue > 0: hue_factor = random.uniform(-hue, hue) else: hue_factor = None return brightness_factor, contrast_factor, saturation_factor, hue_factor def __call__(self, clip): """ Args: clip (list): list of PIL.Image Returns: list PIL.Image : list of transformed PIL.Image """ if isinstance(clip[0], np.ndarray): brightness, contrast, saturation, hue = self.get_params( self.brightness, self.contrast, self.saturation, self.hue) # Create img transform function sequence img_transforms = [] if brightness is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) if saturation is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) if hue is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) if contrast is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) random.shuffle(img_transforms) img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array, img_as_float] with warnings.catch_warnings(): warnings.simplefilter("ignore") jittered_clip = [] for img in clip: jittered_img = img for func in img_transforms: jittered_img = func(jittered_img) jittered_clip.append(jittered_img.astype('float32')) elif isinstance(clip[0], PIL.Image.Image): brightness, contrast, saturation, hue = self.get_params( self.brightness, self.contrast, self.saturation, self.hue) # Create img transform function sequence img_transforms = [] if brightness is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) if saturation is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) if hue is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) if contrast is not None: img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) random.shuffle(img_transforms) # Apply to all videos jittered_clip = [] for img in clip: for func in img_transforms: jittered_img = func(img) jittered_clip.append(jittered_img) else: raise TypeError('Expected numpy.ndarray or PIL.Image' + 'but got list of {0}'.format(type(clip[0]))) return jittered_clip class AllAugmentationTransform: def __init__(self, crop_mouth_param = None, resize_param=None, rotation_param=None, perspective_param=None, flip_param=None, crop_param=None, jitter_param=None): self.transforms = [] if crop_mouth_param is not None: self.transforms.append(MouthCrop(**crop_mouth_param)) if flip_param is not None: self.transforms.append(RandomFlip(**flip_param)) if rotation_param is not None: self.transforms.append(RandomRotation(**rotation_param)) if perspective_param is not None: self.transforms.append(RandomPerspective(**perspective_param)) if resize_param is not None: self.transforms.append(RandomResize(**resize_param)) if crop_param is not None: self.transforms.append(RandomCrop(**crop_param)) if jitter_param is not None: self.transforms.append(ColorJitter(**jitter_param)) def __call__(self, clip): for t in self.transforms: clip = t(clip) return clip ================================================ FILE: config/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml ================================================ dataset_params: root_dir: /mnt/lustre/share_data/jixinya/MEAD/ frame_shape: [256, 256, 3] id_sampling: False pairs_list: Random_choice augmentation_params: crop_mouth_param: center_x: 135 center_y: 190 mask_width: 100 mask_height: 60 rotation_param: degrees: 30 perspective_param: pers_num: 30 enlarge_num: 40 flip_param: horizontal_flip: True time_flip: False jitter_param: brightness: 0 contrast: 0 saturation: 0 hue: 0 model_params: common_params: num_kp: 10 num_channels: 3 estimate_jacobian: True audio_params: num_kp: 10 num_channels : 3 num_channels_a : 3 estimate_jacobian: True kp_detector_params: temperature: 0.1 block_expansion: 32 max_features: 1024 scale_factor: 0.25 num_blocks: 5 generator_params: block_expansion: 64 max_features: 512 num_down_blocks: 2 num_bottleneck_blocks: 6 estimate_occlusion_map: True dense_motion_params: block_expansion: 64 max_features: 1024 num_blocks: 5 scale_factor: 0.25 discriminator_params: scales: [1] block_expansion: 32 max_features: 512 num_blocks: 4 sn: True train_params: type: linear_4 smooth: False jaco_net: cnn ldmark: fake generator: not train_generator: False num_epochs: 300 num_repeats: 1 epoch_milestones: [60, 90] lr_generator: 2.0e-4 lr_discriminator: 2.0e-4 lr_kp_detector: 2.0e-4 lr_audio_feature: 2.0e-4 batch_size: 16 scales: [1, 0.5, 0.25, 0.125] checkpoint_freq: 1 transform_params: sigma_affine: 0.05 sigma_tps: 0.005 points_tps: 5 loss_weights: generator_gan: 0 discriminator_gan: 1 feature_matching: [10, 10, 10, 10] perceptual: [10, 10, 10, 10, 10] equivariance_value: 0 equivariance_jacobian: 0 emo: 10 reconstruction_params: num_videos: 1000 format: '.mp4' animate_params: num_pairs: 50 format: '.mp4' normalization_params: adapt_movement_scale: False use_relative_movement: True use_relative_jacobian: True visualizer_params: kp_size: 5 draw_border: True colormap: 'gist_rainbow' ================================================ FILE: config/train_part1.yaml ================================================ dataset_params: name: Vox root_dir: dataset/LRW/ frame_shape: [256, 256, 3] id_sampling: False augmentation_params: flip_param: horizontal_flip: False time_flip: False jitter_param: brightness: 0.1 contrast: 0.1 saturation: 0.1 hue: 0.1 model_params: common_params: num_kp: 10 num_channels: 3 estimate_jacobian: True audio_params: num_kp: 10 num_channels : 3 num_channels_a : 3 estimate_jacobian: True kp_detector_params: temperature: 0.1 block_expansion: 32 max_features: 1024 scale_factor: 0.25 num_blocks: 5 generator_params: block_expansion: 64 max_features: 512 num_down_blocks: 2 num_bottleneck_blocks: 6 estimate_occlusion_map: True dense_motion_params: block_expansion: 64 max_features: 1024 num_blocks: 5 scale_factor: 0.25 discriminator_params: scales: [1] block_expansion: 32 max_features: 512 num_blocks: 4 sn: True train_params: jaco_net: cnn ldmark: fake generator: not num_epochs: 300 num_repeats: 1 epoch_milestones: [60, 90] lr_generator: 2.0e-4 lr_discriminator: 2.0e-4 lr_kp_detector: 2.0e-4 lr_audio_feature: 2.0e-4 batch_size: 8 scales: [1, 0.5, 0.25, 0.125] checkpoint_freq: 1 transform_params: sigma_affine: 0.05 sigma_tps: 0.005 points_tps: 5 loss_weights: generator_gan: 0 discriminator_gan: 0 feature_matching: [10, 10, 10, 10] perceptual: [10, 10, 10, 10, 10] equivariance_value: 0 equivariance_jacobian: 0 audio: 10 visualizer_params: kp_size: 5 draw_border: True colormap: 'gist_rainbow' ================================================ FILE: config/train_part1_fine_tune.yaml ================================================ dataset_params: name: LRW root_dir: dataset/LRW/ frame_shape: [256, 256, 3] id_sampling: False augmentation_params: flip_param: horizontal_flip: False time_flip: False jitter_param: brightness: 0.1 contrast: 0.1 saturation: 0.1 hue: 0.1 model_params: common_params: num_kp: 10 num_channels: 3 estimate_jacobian: True audio_params: num_kp: 10 num_channels : 3 num_channels_a : 3 estimate_jacobian: True kp_detector_params: temperature: 0.1 block_expansion: 32 max_features: 1024 scale_factor: 0.25 num_blocks: 5 generator_params: block_expansion: 64 max_features: 512 num_down_blocks: 2 num_bottleneck_blocks: 6 estimate_occlusion_map: True dense_motion_params: block_expansion: 64 max_features: 1024 num_blocks: 5 scale_factor: 0.25 discriminator_params: scales: [1] block_expansion: 32 max_features: 512 num_blocks: 4 sn: True train_params: jaco_net: cnn ldmark: fake generator: audio num_epochs: 300 num_repeats: 1 epoch_milestones: [60, 90] lr_generator: 2.0e-4 lr_discriminator: 2.0e-4 lr_kp_detector: 2.0e-4 lr_audio_feature: 2.0e-4 batch_size: 6 scales: [1, 0.5, 0.25, 0.125] checkpoint_freq: 1 transform_params: sigma_affine: 0.05 sigma_tps: 0.005 points_tps: 5 loss_weights: generator_gan: 0 discriminator_gan: 0 feature_matching: [10, 10, 10, 10] perceptual: [0.1, 0.1, 0.1, 0.1, 0.1] equivariance_value: 0 equivariance_jacobian: 0 audio: 10 visualizer_params: kp_size: 5 draw_border: True colormap: 'gist_rainbow' ================================================ FILE: config/train_part2.yaml ================================================ dataset_params: name: MEAD root_dir: dataset/MEAD/ frame_shape: [256, 256, 3] id_sampling: False augmentation_params: crop_mouth_param: center_x: 135 center_y: 190 mask_width: 100 mask_height: 60 rotation_param: degrees: 30 perspective_param: pers_num: 30 enlarge_num: 40 flip_param: horizontal_flip: True time_flip: False jitter_param: brightness: 0 contrast: 0 saturation: 0 hue: 0 model_params: common_params: num_kp: 10 num_channels: 3 estimate_jacobian: True audio_params: num_kp: 10 num_channels : 3 num_channels_a : 3 estimate_jacobian: True kp_detector_params: temperature: 0.1 block_expansion: 32 max_features: 1024 scale_factor: 0.25 num_blocks: 5 generator_params: block_expansion: 64 max_features: 512 num_down_blocks: 2 num_bottleneck_blocks: 6 estimate_occlusion_map: True dense_motion_params: block_expansion: 64 max_features: 1024 num_blocks: 5 scale_factor: 0.25 discriminator_params: scales: [1] block_expansion: 32 max_features: 512 num_blocks: 4 sn: True train_params: type: linear_4 smooth: False jaco_net: cnn ldmark: fake generator: not num_epochs: 300 num_repeats: 1 epoch_milestones: [60, 90] lr_generator: 2.0e-4 lr_discriminator: 2.0e-4 lr_kp_detector: 2.0e-4 lr_audio_feature: 2.0e-4 batch_size: 16 scales: [1, 0.5, 0.25, 0.125] checkpoint_freq: 1 transform_params: sigma_affine: 0.05 sigma_tps: 0.005 points_tps: 5 loss_weights: generator_gan: 0 discriminator_gan: 0 feature_matching: [10, 10, 10, 10] perceptual: [10, 10, 10, 10, 10] equivariance_value: 0 equivariance_jacobian: 0 emo: 10 visualizer_params: kp_size: 5 draw_border: True colormap: 'gist_rainbow' ================================================ FILE: demo.py ================================================ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Wed Oct 6 20:57:27 2021 @author: thea """ import matplotlib matplotlib.use('Agg') import os,sys import yaml from argparse import ArgumentParser from tqdm import tqdm from skimage import io, img_as_float32 import imageio import numpy as np from skimage.transform import resize from skimage import img_as_ubyte import torch from filter1 import OneEuroFilter import torch.utils from torch.autograd import Variable from modules.generator import OcclusionAwareGenerator from modules.keypoint_detector import KPDetector, KPDetector_a from modules.util import AT_net, Emotion_k, Emotion_map, AT_net2 from augmentation import AllAugmentationTransform from scipy.spatial import ConvexHull import python_speech_features from pathlib import Path import dlib import cv2 import librosa from skimage import transform as tf #from audiolm.models import AT_emoiton #from audiolm.utils import plot_flmarks if sys.version_info[0] < 3: raise Exception("You must use Python 3 or higher. Recommended version is Python 3.6") detector = dlib.get_frontal_face_detector() predictor = dlib.shape_predictor('./shape_predictor_68_face_landmarks.dat') def load_checkpoints(opt, checkpoint_path, audio_checkpoint_path, emo_checkpoint_path, cpu=False): with open(opt.config) as f: config = yaml.load(f, Loader=yaml.FullLoader) generator = OcclusionAwareGenerator(**config['model_params']['generator_params'], **config['model_params']['common_params']) if not cpu: generator.cuda() kp_detector = KPDetector(**config['model_params']['kp_detector_params'], **config['model_params']['common_params']) if not cpu: kp_detector.cuda() kp_detector_a = KPDetector_a(**config['model_params']['kp_detector_params'], **config['model_params']['audio_params']) audio_feature = AT_net2() if opt.type.startswith('linear'): emo_detector = Emotion_k(block_expansion=32, num_channels=3, max_features=1024, num_blocks=5, scale_factor=0.25, num_classes=8) elif opt.type.startswith('map'): emo_detector = Emotion_map(block_expansion=32, num_channels=3, max_features=1024, num_blocks=5, scale_factor=0.25, num_classes=8) if not cpu: kp_detector_a.cuda() audio_feature.cuda() emo_detector.cuda() if cpu: checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu')) audio_checkpoint = torch.load(audio_checkpoint_path, map_location=torch.device('cpu')) emo_checkpoint = torch.load(emo_checkpoint_path, map_location=torch.device('cpu')) else: checkpoint = torch.load(checkpoint_path) audio_checkpoint = torch.load(audio_checkpoint_path) emo_checkpoint = torch.load(emo_checkpoint_path) generator.load_state_dict(checkpoint['generator']) kp_detector.load_state_dict(checkpoint['kp_detector']) audio_feature.load_state_dict(audio_checkpoint['audio_feature']) kp_detector_a.load_state_dict(audio_checkpoint['kp_detector_a']) emo_detector.load_state_dict(emo_checkpoint['emo_detector']) if not cpu: generator = generator.cuda() kp_detector = kp_detector.cuda() audio_feature = audio_feature.cuda() kp_detector_a = kp_detector_a.cuda() emo_detector = emo_detector.cuda() generator.eval() kp_detector.eval() audio_feature.eval() kp_detector_a.eval() emo_detector.eval() return generator, kp_detector, kp_detector_a, audio_feature, emo_detector def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False, use_relative_movement=False, use_relative_jacobian=False): if adapt_movement_scale: source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) else: adapt_movement_scale = 1 kp_new = {k: v for k, v in kp_driving.items()} if use_relative_movement: kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) kp_value_diff *= adapt_movement_scale kp_new['value'] = kp_value_diff + kp_source['value'] if use_relative_jacobian: jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) return kp_new def shape_to_np(shape, dtype="int"): # initialize the list of (x, y)-coordinates coords = np.zeros((shape.num_parts, 2), dtype=dtype) # loop over all facial landmarks and convert them # to a 2-tuple of (x, y)-coordinates for i in range(0, shape.num_parts): coords[i] = (shape.part(i).x, shape.part(i).y) # return the list of (x, y)-coordinates return coords def get_aligned_image(driving_video, opt): aligned_array = [] video_array = np.array(driving_video) source_image=video_array[0] # aligned_array.append(source_image) source_image = np.array(source_image * 255, dtype=np.uint8) gray = cv2.cvtColor(source_image, cv2.COLOR_BGR2GRAY) rects = detector(gray, 1) #detect human face for (i, rect) in enumerate(rects): template = predictor(gray, rect) #detect 68 points template = shape_to_np(template) if opt.emotion == 'surprised' or opt.emotion == 'fear': template = template-[0,10] for i in range(len(video_array)): image=np.array(video_array[i] * 255, dtype=np.uint8) gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) rects = detector(gray, 1) #detect human face for (j, rect) in enumerate(rects): shape = predictor(gray, rect) #detect 68 points shape = shape_to_np(shape) pts2 = np.float32(template[:35,:]) pts1 = np.float32(shape[:35,:]) #eye and nose # pts2 = np.float32(np.concatenate((template[:16,:],template[27:36,:]),axis = 0)) # pts1 = np.float32(np.concatenate((shape[:16,:],shape[27:36,:]),axis = 0)) #eye and nose # pts1 = np.float32(landmark[17:35,:]) tform = tf.SimilarityTransform() tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters. dst = tf.warp(image, tform, output_shape=(256, 256)) dst = np.array(dst, dtype=np.float32) aligned_array.append(dst) return aligned_array def get_transformed_image(driving_video, opt): video_array = np.array(driving_video) with open(opt.config) as f: config = yaml.load(f, Loader=yaml.FullLoader) transformations = AllAugmentationTransform(**config['dataset_params']['augmentation_params']) transformed_array = transformations(video_array) return transformed_array def make_animation_smooth(source_image, driving_video, transformed_video, deco_out, kp_loss, generator, kp_detector, kp_detector_a, emo_detector, opt, relative=True, adapt_movement_scale=True, cpu=False): with torch.no_grad(): predictions = [] source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2) if not cpu: source = source.cuda() driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3) transformed_driving = torch.tensor(np.array(transformed_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3) kp_source = kp_detector(source) kp_driving_initial = kp_detector_a(deco_out[:,0]) emo_driving_all = [] features = [] kp_driving_all = [] for frame_idx in tqdm(range(len(deco_out[0]))): driving_frame = driving[:, :, frame_idx] transformed_frame = transformed_driving[:, :, frame_idx] if not cpu: driving_frame = driving_frame.cuda() transformed_frame = transformed_frame.cuda() kp_driving = kp_detector_a(deco_out[:,frame_idx]) kp_driving_all.append(kp_driving) if opt.add_emo: value = kp_driving['value'] jacobian = kp_driving['jacobian'] if opt.type == 'linear_3': emo_driving,_ = emo_detector(transformed_frame,value,jacobian) features.append(emo_detector.feature(transformed_frame).data.cpu().numpy()) emo_driving_all.append(emo_driving) features = np.array(features) if opt.add_emo: one_euro_filter_v = OneEuroFilter(mincutoff=1, beta=0.2, dcutoff=1.0, freq=100)#1 0.4 one_euro_filter_j = OneEuroFilter(mincutoff=1, beta=0.2, dcutoff=1.0, freq=100)#1 0.4 for j in range(len(emo_driving_all)): emo_driving_all[j]['value']=one_euro_filter_v.process(emo_driving_all[j]['value'].cpu()*100)/100 emo_driving_all[j]['value'] = emo_driving_all[j]['value'].cuda() emo_driving_all[j]['jacobian']=one_euro_filter_j.process(emo_driving_all[j]['jacobian'].cpu()*100)/100 emo_driving_all[j]['jacobian'] = emo_driving_all[j]['jacobian'].cuda() one_euro_filter_v = OneEuroFilter(mincutoff=0.05, beta=8, dcutoff=1.0, freq=100) one_euro_filter_j = OneEuroFilter(mincutoff=0.05, beta=8, dcutoff=1.0, freq=100) for j in range(len(kp_driving_all)): kp_driving_all[j]['value']=one_euro_filter_v.process(kp_driving_all[j]['value'].cpu()*10)/10 kp_driving_all[j]['value'] = kp_driving_all[j]['value'].cuda() kp_driving_all[j]['jacobian']=one_euro_filter_j.process(kp_driving_all[j]['jacobian'].cpu()*10)/10 kp_driving_all[j]['jacobian'] = kp_driving_all[j]['jacobian'].cuda() for frame_idx in tqdm(range(len(deco_out[0]))): if opt.check_add: kp_driving = kp_detector_a(deco_out[:,0]) else: kp_driving = kp_driving_all[frame_idx] # kp_driving_real = kp_detector(driving_frame) # kp_driving['value'] = (1-opt.weight)*kp_driving['value'] + opt.weight*kp_driving_real['value'] # kp_driving['jacobian'] = (1-opt.weight)*kp_driving['jacobian'] + opt.weight*kp_driving_real['jacobian'] if opt.add_emo: emo_driving = emo_driving_all[frame_idx] if opt.type == 'linear_3': kp_driving['value'][:,1] = kp_driving['value'][:,1] + emo_driving['value'][:,0]*0.2 kp_driving['jacobian'][:,1] = kp_driving['jacobian'][:,1] + emo_driving['jacobian'][:,0]*0.2 kp_driving['value'][:,4] = kp_driving['value'][:,4] + emo_driving['value'][:,1] kp_driving['jacobian'][:,4] = kp_driving['jacobian'][:,4] + emo_driving['jacobian'][:,1] kp_driving['value'][:,6] = kp_driving['value'][:,6] + emo_driving['value'][:,2] kp_driving['jacobian'][:,6] = kp_driving['jacobian'][:,6] + emo_driving['jacobian'][:,2] # kp_driving['value'][:,8] = kp_driving['value'][:,8] + emo_driving['value'][:,3] # kp_driving['jacobian'][:,8] = kp_driving['jacobian'][:,8] + emo_driving['jacobian'][:,3] kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving, kp_driving_initial=kp_driving_initial, use_relative_movement=relative, use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale) out = generator(source, kp_source=kp_source, kp_driving=kp_norm) predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]) return predictions, features def test_auido(example_image, audio_feature, all_pose, opt): with open(opt.config) as f: para = yaml.load(f, Loader=yaml.FullLoader) # encoder = audio_feature() if not opt.cpu: audio_feature = audio_feature.cuda() audio_feature.eval() # decoder.eval() test_file = opt.in_file pose = all_pose[:,:6] if len(pose) == 1: pose = np.repeat(pose,100,0) elif opt.smooth_pose: one_euro_filter = OneEuroFilter(mincutoff=0.004, beta=0.7, dcutoff=1.0, freq=100) for j in range(len(pose)): pose[j]=one_euro_filter.process(pose[j]) # pose[j]=pose[0] example_image = np.array(example_image, dtype='float32').transpose((2, 0, 1)) speech, sr = librosa.load(test_file, sr=16000) # mfcc = python_speech_features.mfcc(speech ,16000,winstep=0.01) speech = np.insert(speech, 0, np.zeros(1920)) speech = np.append(speech, np.zeros(1920)) mfcc = python_speech_features.mfcc(speech,16000,winstep=0.01) print ('=======================================') print ('Start to generate images') ind = 3 with torch.no_grad(): fake_lmark = [] input_mfcc = [] while ind <= int(mfcc.shape[0]/4) - 4: t_mfcc =mfcc[( ind - 3)*4: (ind + 4)*4, 1:] t_mfcc = torch.FloatTensor(t_mfcc).cuda() input_mfcc.append(t_mfcc) ind += 1 input_mfcc = torch.stack(input_mfcc,dim = 0) if (len(pose)len(input_mfcc)): pose = pose[:len(input_mfcc),:] if not opt.cpu: example_image = Variable(torch.FloatTensor(example_image.astype(float)) ).cuda() example_image = torch.unsqueeze(example_image,0) pose = Variable(torch.FloatTensor(pose.astype(float)) ).cuda() pose = pose.unsqueeze(0) input_mfcc = input_mfcc.unsqueeze(0) deco_out = audio_feature(example_image,input_mfcc,pose,para['train_params']['jaco_net'],1.6) return deco_out def save(path, frames, format): if format == '.png': if not os.path.exists(path): os.makedirs(path) for j, frame in enumerate(frames): imageio.imsave(path+'/'+str(j)+'.png',frame) # imageio.imsave(os.path.join(path, str(j) + '.png'), frames[j]) else: print ("Unknown format %s" % format) exit() class VideoWriter(object): def __init__(self, path, width, height, fps): fourcc = cv2.VideoWriter_fourcc(*'XVID') self.path = path self.out = cv2.VideoWriter(self.path, fourcc, fps, (width, height)) def write_frame(self, frame): self.out.write(frame) def end(self): self.out.release() def concatenate(number, imgs, save_path): width, height = imgs.shape[-3:-1] imgs = imgs.reshape(number,-1,width,height,3) if number == 2: left = imgs[0] right = imgs[1] im_all = [] for i in range(len(left)): im = np.concatenate((left[i],right[i]),axis = 1) im_all.append(im) if number == 3: left = imgs[0] middle = imgs[1] right = imgs[2] im_all = [] for i in range(len(left)): im = np.concatenate((left[i],middle[i],right[i]),axis = 1) im_all.append(im) if number == 4: left = imgs[0] left2 = imgs[1] right = imgs[2] right2 = imgs[3] im_all = [] for i in range(len(left)): im = np.concatenate((left[i],left2[i],right[i],right2[i]),axis = 1) im_all.append(im) if number == 5: left = imgs[0] left2 = imgs[1] middle = imgs[2] right = imgs[3] right2 = imgs[4] im_all = [] for i in range(len(left)): im = np.concatenate((left[i],left2[i],middle[i],right[i],right2[i]),axis = 1) im_all.append(im) imageio.mimsave(save_path, [img_as_ubyte(frame) for frame in im_all], fps=25) def add_audio(video_name=None, audio_dir = None): command = 'ffmpeg -i ' + video_name + ' -i ' + audio_dir + ' -vcodec copy -acodec copy -y ' + video_name.replace('.mp4','.mov') print (command) os.system(command) def crop_image(source_image): template = np.load('./M003_template.npy') image= cv2.imread(source_image) gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) rects = detector(gray, 1) #detect human face if len(rects) != 1: return 0 for (j, rect) in enumerate(rects): shape = predictor(gray, rect) #detect 68 points shape = shape_to_np(shape) pts2 = np.float32(template[:47,:]) pts1 = np.float32(shape[:47,:]) #eye and nose # pts1 = np.float32(landmark[17:35,:]) tform = tf.SimilarityTransform() tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters. dst = tf.warp(image, tform, output_shape=(256, 256)) dst = np.array(dst * 255, dtype=np.uint8) return dst def smooth_pose(pose_file, pose_long): start = np.load(pose_file) video_pose = np.load(pose_long) delta = video_pose - video_pose[0,:] print(len(delta)) pose = np.repeat(start,len(delta),axis = 0) all_pose = pose + delta return all_pose def test(opt, name): all_pose = np.load(opt.pose_file).reshape(-1,7) if opt.pose_long: all_pose = smooth_pose(opt.pose_file,opt.pose_given) # source_image = img_as_float32(io.imread(opt.source_image)) source_image = img_as_float32(crop_image(opt.source_image)) source_image = resize(source_image, (256, 256))[..., :3] reader = imageio.get_reader(opt.driving_video) fps = reader.get_meta_data()['fps'] driving_video = [] try: for im in reader: driving_video.append(im) except RuntimeError: pass reader.close() driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video] driving_video = get_aligned_image(driving_video, opt) transformed_video = get_transformed_image(driving_video, opt) transformed_video = np.array(transformed_video) generator, kp_detector,kp_detector_a, audio_feature, emo_detector = load_checkpoints(opt=opt, checkpoint_path=opt.checkpoint, audio_checkpoint_path=opt.audio_checkpoint, emo_checkpoint_path = opt.emo_checkpoint, cpu=opt.cpu) deco_out = test_auido(source_image, audio_feature, all_pose, opt) if len(driving_video) < len(deco_out[0]): driving_video = np.resize(driving_video,(len(deco_out[0]),256,256,3)) transformed_video = np.resize(transformed_video,(len(deco_out[0]),256,256,3)) else: driving_video = driving_video[:len(deco_out[0])] opt.add_emo = False predictions, _ = make_animation_smooth(source_image, driving_video, transformed_video, deco_out, opt.kp_loss, generator, kp_detector, kp_detector_a, emo_detector, opt, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu) imageio.mimsave(os.path.join(opt.result_path,'neutral.mp4'), [img_as_ubyte(frame) for frame in predictions], fps=fps) predictions = np.array(predictions) opt.add_emo = True predictions1,_ = make_animation_smooth(source_image, driving_video, transformed_video, deco_out, opt.kp_loss, generator, kp_detector, kp_detector_a, emo_detector, opt, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu) imageio.mimsave(os.path.join(opt.result_path,'emotion.mp4'), [img_as_ubyte(frame) for frame in predictions1], fps=fps) add_audio(os.path.join(opt.result_path,'emotion.mp4'),opt.in_file) predictions1 = np.array(predictions1) all_imgs = np.concatenate((driving_video,predictions,predictions1),axis = 0) save_path = os.path.join(opt.result_path, 'all.mp4') concatenate(3, all_imgs, save_path) add_audio(save_path,opt.in_file) if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--config", default ='config/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml', help="path to config")#required=True default ='config/vox-256.yaml' parser.add_argument("--audio_checkpoint", default='log/1-6000.pth.tar', help="path to checkpoint to restore") parser.add_argument("--checkpoint", default='log/124_52000.pth.tar', help="path to checkpoint to restore") # parser.add_argument("--emo_checkpoint", default='ablation/ablation/ten/10-6000.pth.tar', help="path to checkpoint to restore") parser.add_argument("--emo_checkpoint", default='log/5-3000.pth.tar', help="path to checkpoint to restore") parser.add_argument("--source_image", default='test/image/21.png', help="path to source image") parser.add_argument("--driving_video", default='test/video/disgusted.mp4', help="path to driving video")#data/M030/video/M030_angry_ parser.add_argument('--in_file', type=str, default='test/audio/sample1.mov') parser.add_argument('--pose_file', type=str, default='test/pose/21.npy') parser.add_argument('--pose_given', type=str, default='test/pose_long/0zn70Ak8lRc_Daniel_Auteuil_0zn70Ak8lRc_0002.npy') parser.add_argument("--result_path", default='result/', help="path to output")#'/media/thea/新加卷/fomm/Exp/'+emotion+'.mp4' parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates") parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints") parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.") parser.add_argument("--kp_loss", default=0, help="keypoint loss.") parser.add_argument("--smooth_pose", default=True, help="cpu mode.") parser.add_argument("--pose_long", default=False, help="use given long poses.") parser.add_argument("--weight", default=0, help="cpu mode.") parser.add_argument("--add_emo", default=False, help="add emotion.") parser.add_argument("--check_add", default=False, help="check emotion displacement.") parser.add_argument("--type", default='linear_3', help="add emotion type.") parser.add_argument("--emotion", default='disgusted', help="emotion category, 'angry', 'contempt','disgusted','fear','happy','neutral','sad','surprised'.") parser.set_defaults(relative=False) parser.set_defaults(adapt_scale=False) opt = parser.parse_args() # opt.cpu = True test(opt,'test') ================================================ FILE: filter1.py ================================================ import cv2 #import pickle import time import numpy as np import copy from matplotlib import pyplot as plt from tqdm import tqdm class LowPassFilter: def __init__(self): self.prev_raw_value = None self.prev_filtered_value = None def process(self, value, alpha): if self.prev_raw_value is None: s = value else: s = alpha * value + (1.0 - alpha) * self.prev_filtered_value self.prev_raw_value = value self.prev_filtered_value = s return s class OneEuroFilter: def __init__(self, mincutoff=1.0, beta=0.0, dcutoff=1.0, freq=30): self.freq = freq self.mincutoff = mincutoff self.beta = beta self.dcutoff = dcutoff self.x_filter = LowPassFilter() self.dx_filter = LowPassFilter() def compute_alpha(self, cutoff): te = 1.0 / self.freq tau = 1.0 / (2 * np.pi * cutoff) return 1.0 / (1.0 + tau / te) def process(self, x): prev_x = self.x_filter.prev_raw_value dx = 0.0 if prev_x is None else (x - prev_x) * self.freq edx = self.dx_filter.process(dx, self.compute_alpha(self.dcutoff)) cutoff = self.mincutoff + self.beta * np.abs(edx) return self.x_filter.process(x, self.compute_alpha(cutoff)) ================================================ FILE: frames_dataset.py ================================================ import os from skimage import io, img_as_float32, transform from skimage.color import gray2rgb from sklearn.model_selection import train_test_split from imageio import mimread import numpy as np from torch.utils.data import Dataset import pandas as pd from augmentation import AllAugmentationTransform import glob import pickle import random from filter1 import OneEuroFilter def read_video(name, frame_shape): """ Read video which can be: - an image of concatenated frames - '.mp4' and'.gif' - folder with videos """ if os.path.isdir(name): frames = sorted(os.listdir(name)) num_frames = len(frames) video_array = np.array( [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)]) elif name.lower().endswith('.png') or name.lower().endswith('.jpg'): image = io.imread(name) if len(image.shape) == 2 or image.shape[2] == 1: image = gray2rgb(image) if image.shape[2] == 4: image = image[..., :3] image = img_as_float32(image) video_array = np.moveaxis(image, 1, 0) video_array = video_array.reshape((-1,) + frame_shape) video_array = np.moveaxis(video_array, 1, 2) elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'): video = np.array(mimread(name)) if len(video.shape) == 3: video = np.array([gray2rgb(frame) for frame in video]) if video.shape[-1] == 4: video = video[..., :3] video_array = img_as_float32(video) else: raise Exception("Unknown file extensions %s" % name) return video_array def get_list(ipath,base_name): #ipath = '/mnt/lustre/share/jixinya/LRW/pose/train_fo/' ipath = os.path.join(ipath,base_name) name_list = os.listdir(ipath) image_path = os.path.join('/mnt/lustre/share/jixinya/LRW/Image/',base_name) all = [] for k in range(len(name_list)): name = name_list[k] path_ = os.path.join(ipath,name) Dir = os.listdir(path_) for i in range(len(Dir)): word = Dir[i] path = os.path.join(path_, word) if os.path.exists(os.path.join(image_path,name,word.split('.')[0])): all.append(name+'/'+word.split('.')[0]) #print(k,name,i,word) print('get list '+os.path.basename(ipath)) return all class AudioDataset(Dataset): """ Dataset of videos, each video can be represented as: - an image of concatenated frames - '.mp4' or '.gif' - folder with all frames """ def __init__(self, name, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True, random_seed=0, augmentation_params=None): self.root_dir = root_dir self.audio_dir = os.path.join(root_dir,'MFCC') self.image_dir = os.path.join(root_dir,'Image') self.pose_dir = os.path.join(root_dir,'pose') # assert len(os.listdir(self.audio_dir)) == len(os.listdir(self.image_dir)), 'audio and image length not equal' # self.videos=np.load('../LRW/list/train_fo.npy') # self.videos = os.listdir(self.landmark_dir) self.frame_shape = tuple(frame_shape) self.id_sampling = id_sampling if os.path.exists(os.path.join(self.pose_dir, 'train_fo')): assert os.path.exists(os.path.join(self.pose_dir, 'test_fo')) print("Use predefined train-test split.") if id_sampling: train_videos = {os.path.basename(video).split('#')[0] for video in os.listdir(os.path.join(self.image_dir, 'train'))} train_videos = list(train_videos) else: train_videos = np.load('../LRW/list/train_fo.npy')# get_list(self.pose_dir, 'train_fo') # df=open('../LRW/list/test_fo.txt','rb') test_videos=np.load('../LRW/list/test_fo.npy') # df.close() # test_videos = np.load('../LRW/list/train_fo.npy') #get_list(self.pose_dir, 'test_fo') # self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test') self.image_dir = os.path.join(self.image_dir, 'train_fo' if is_train else 'test_fo') self.audio_dir = os.path.join(self.audio_dir, 'train' if is_train else 'test') self.pose_dir = os.path.join(self.pose_dir, 'train_fo' if is_train else 'test_fo') else: print("Use random train-test split.") train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2) if is_train: self.videos = train_videos else: self.videos = test_videos self.is_train = is_train if self.is_train: self.transform = AllAugmentationTransform(**augmentation_params) else: self.transform = None def __len__(self): return len(self.videos) def __getitem__(self, idx): if self.is_train and self.id_sampling: name = self.videos[idx].split('.')[0] path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4'))) else: name = self.videos[idx].split('.')[0] audio_path = os.path.join(self.audio_dir, name) pose_path = os.path.join(self.pose_dir,name) path = os.path.join(self.image_dir, name) video_name = os.path.basename(path) if os.path.isdir(path): # if self.is_train and os.path.isdir(path): # mfcc loading r = random.choice([x for x in range(3, 8)]) example_image = img_as_float32(io.imread(os.path.join(path, str(r)+'.png'))) mfccs = [] for ind in range(1, 17): # t_mfcc = mfcc[(r + ind - 3) * 4: (r + ind + 4) * 4, 1:] t_mfcc = np.load(os.path.join(audio_path,str(r + ind)+'.npy'),allow_pickle=True)[:, 1:] mfccs.append(t_mfcc) mfccs = np.array(mfccs) poses = [] video_array = [] for ind in range(1, 17): t_pose = np.load(os.path.join(self.pose_dir,name+'.npy'))[r+ind,:-1] poses.append(t_pose) image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png'))) video_array.append(image) poses = np.array(poses) video_array = np.array(video_array) else: print('Wrong, data path not an existing file.') if self.transform is not None: video_array = self.transform(video_array) out = {} driving = np.array(video_array, dtype='float32') spatial_size = np.array(driving.shape[1:3][::-1])[np.newaxis] driving_pose = np.array(poses, dtype='float32') example_image = np.array(example_image, dtype='float32') out['example_image'] = example_image.transpose((2, 0, 1)) out['driving_pose'] = driving_pose out['driving'] = driving.transpose((0, 3, 1, 2)) out['driving_audio'] = np.array(mfccs, dtype='float32') # out['name'] = video_name return out class VoxDataset(Dataset): """ Dataset of videos, each video can be represented as: - an image of concatenated frames - '.mp4' or '.gif' - folder with all frames """ def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True, random_seed=0, pairs_list=None, augmentation_params=None): self.root_dir = root_dir self.audio_dir = os.path.join(root_dir,'MFCC') self.image_dir = os.path.join(root_dir,'align_img') self.pose_dir = os.path.join(root_dir,'align_pose') # assert len(os.listdir(self.audio_dir)) == len(os.listdir(self.image_dir)), 'audio and image length not equal' # df=open('../LRW/list/test_fo.txt','rb') # self.videos=pickle.load(df) # df.close() self.videos=np.load('/mnt/lustre/share_data/jixinya/VoxCeleb1_Cut/right.npy') # self.videos = os.listdir(self.landmark_dir) self.frame_shape = tuple(frame_shape) self.pairs_list = pairs_list self.id_sampling = id_sampling if os.path.exists(os.path.join(self.pose_dir, 'train_fo')): assert os.path.exists(os.path.join(self.pose_dir, 'test_fo')) print("Use predefined train-test split.") if id_sampling: train_videos = {os.path.basename(video).split('#')[0] for video in os.listdir(os.path.join(self.image_dir, 'train'))} train_videos = list(train_videos) else: train_videos = np.load('/mnt/lustre/share_data/jixinya/VoxCeleb1_Cut/right.npy')# get_list(self.pose_dir, 'train_fo') self.image_dir = os.path.join(self.image_dir, 'train_fo' if is_train else 'test_fo') self.audio_dir = os.path.join(self.audio_dir, 'train' if is_train else 'test') self.pose_dir = os.path.join(self.pose_dir, 'train_fo' if is_train else 'test_fo') else: print("Use random train-test split.") train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2) if is_train: self.videos = train_videos else: self.videos = test_videos self.is_train = is_train if self.is_train: self.transform = AllAugmentationTransform(**augmentation_params) else: self.transform = None def __len__(self): return len(self.videos) def __getitem__(self, idx): if self.is_train and self.id_sampling: name = self.videos[idx].split('.')[0] path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4'))) else: name = self.videos[idx].split('.')[0] audio_path = os.path.join(self.audio_dir, name+'.npy') pose_path = os.path.join(self.pose_dir,name+'.npy') path = os.path.join(self.image_dir, name) video_name = os.path.basename(path) if os.path.isdir(path): # if self.is_train and os.path.isdir(path): frames = os.listdir(path) num_frames = len(frames) frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx] mfcc = np.load(audio_path) pose = np.load(pose_path) # print(audio_path,pose_path,len(mfcc)) try: len(mfcc) > 16 except: print('wrongmfcc len:',audio_path) if 16 < len(mfcc) < 24 : r = 0 else: r = random.choice([x for x in range(3, len(mfcc)-20)]) mfccs = [] poses = [] video_array = [] for ind in range(1, 17): t_mfcc = mfcc[r+ind][:, 1:] mfccs.append(t_mfcc) t_pose = pose[r+ind,:-1] poses.append(t_pose) image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png'))) video_array.append(image) mfccs = np.array(mfccs) poses = np.array(poses) video_array = np.array(video_array) example_image = img_as_float32(io.imread(os.path.join(path, str(r)+'.png'))) else: print('Wrong, data path not an existing file.') if self.transform is not None: video_array = self.transform(video_array) out = {} driving = np.array(video_array, dtype='float32') spatial_size = np.array(driving.shape[1:3][::-1])[np.newaxis] driving_pose = np.array(poses, dtype='float32') example_image = np.array(example_image, dtype='float32') out['example_image'] = example_image.transpose((2, 0, 1)) out['driving_pose'] = driving_pose out['driving'] = driving.transpose((0, 3, 1, 2)) out['driving_audio'] = np.array(mfccs, dtype='float32') # out['name'] = video_name return out class MeadDataset(Dataset): """ Dataset of videos, each video can be represented as: - an image of concatenated frames - '.mp4' or '.gif' - folder with all frames """ def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True, random_seed=0, augmentation_params=None): self.root_dir = root_dir self.audio_dir = os.path.join(root_dir,'MEAD_MFCC') self.image_dir = os.path.join(root_dir,'MEAD_fomm_crop') self.pose_dir = os.path.join(root_dir,'MEAD_fomm_pose_crop') self.videos = np.load('/mnt/lustre/share_data/jixinya/MEAD/MEAD_fomm_audio_less_crop.npy') self.dict = np.load('/mnt/lustre/share_data/jixinya/MEAD/MEAD_fomm_neu_dic_crop.npy',allow_pickle=True).item() # self.videos = os.listdir(root_dir) self.frame_shape = tuple(frame_shape) self.id_sampling = id_sampling if os.path.exists(os.path.join(root_dir, 'train')): assert os.path.exists(os.path.join(root_dir, 'test')) print("Use predefined train-test split.") if id_sampling: train_videos = {os.path.basename(video).split('#')[0] for video in os.listdir(os.path.join(root_dir, 'train'))} train_videos = list(train_videos) else: train_videos = os.listdir(os.path.join(root_dir, 'train')) test_videos = os.listdir(os.path.join(root_dir, 'test')) self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test') else: print("Use random train-test split.") train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2) if is_train: self.videos = train_videos else: self.videos = test_videos self.is_train = is_train if self.is_train: self.transform = AllAugmentationTransform(**augmentation_params) else: self.transform = None def __len__(self): return len(self.videos) def __getitem__(self, idx): if self.is_train and self.id_sampling: name = self.videos[idx] path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4'))) else: name = self.videos[idx] path = os.path.join(self.image_dir, name) video_name = os.path.basename(path) id_name = path.split('/')[-2] neu_list = self.dict[id_name] neu_path = os.path.join(self.image_dir, np.random.choice(neu_list)) audio_path = os.path.join(self.audio_dir, name+'.npy') pose_path = os.path.join(self.pose_dir,name+'.npy') if self.is_train and os.path.isdir(path): mfcc = np.load(audio_path) pose_raw = np.load(pose_path) one_euro_filter = OneEuroFilter(mincutoff=0.01, beta=0.7, dcutoff=1.0, freq=100) pose = np.zeros((len(pose_raw),7)) for j in range(len(pose_raw)): pose[j]=one_euro_filter.process(pose_raw[j]) # print(audio_path,pose_path,len(mfcc)) neu_frames = os.listdir(neu_path) num_neu_frames = len(neu_frames) frame_idx = np.random.choice(num_neu_frames) example_image = img_as_float32(io.imread(os.path.join(neu_path, neu_frames[frame_idx]))) try: len(mfcc) > 16 except: print('wrongmfcc len:',audio_path) if 16 < len(mfcc) < 24 : r = 0 else: r = random.choice([x for x in range(3, len(mfcc)-20)]) mfccs = [] poses = [] video_array = [] for ind in range(1, 17): t_mfcc = mfcc[r+ind][:, 1:] mfccs.append(t_mfcc) t_pose = pose[r+ind,:-1] poses.append(t_pose) image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png'))) video_array.append(image) mfccs = np.array(mfccs) poses = np.array(poses) video_array = np.array(video_array) else: print('Wrong, data path not an existing file.') if self.transform is not None: video_array = self.transform(video_array) out = {} if self.is_train: driving = np.array(video_array, dtype='float32') driving_pose = np.array(poses, dtype='float32') example_image = np.array(example_image, dtype='float32') out['example_image'] = example_image.transpose((2, 0, 1)) out['driving_pose'] = driving_pose out['driving'] = driving.transpose((0, 3, 1, 2)) out['driving_audio'] = np.array(mfccs, dtype='float32') # out['name'] = id_name+'/'+video_name return out class DatasetRepeater(Dataset): """ Pass several times over the same dataset for better i/o performance """ def __init__(self, dataset, num_repeats=100): self.dataset = dataset # self.dataset2 = dataset2 self.num_repeats = num_repeats def __len__(self): return self.num_repeats * self.dataset.__len__() def __getitem__(self, idx): # if idx % 5 == 0: # return self.dataset2[idx % self.dataset2.__len__()]#% self.dataset.__len__() # else: # return self.dataset[idx % self.dataset.__len__()] return self.dataset[idx % self.dataset.__len__()] class TestsetRepeater(Dataset): """ Pass several times over the same dataset for better i/o performance """ def __init__(self, dataset, num_repeats=100): self.dataset = dataset self.num_repeats = num_repeats def __len__(self): return self.num_repeats * self.dataset.__len__() def __getitem__(self, idx): return self.dataset[idx % self.dataset.__len__()]#% self.dataset.__len__() class PairedDataset(Dataset): """ Dataset of pairs for animation. """ def __init__(self, initial_dataset, number_of_pairs, seed=0): self.initial_dataset = initial_dataset pairs_list = self.initial_dataset.pairs_list np.random.seed(seed) if pairs_list is None: max_idx = min(number_of_pairs, len(initial_dataset)) nx, ny = max_idx, max_idx xy = np.mgrid[:nx, :ny].reshape(2, -1).T number_of_pairs = min(xy.shape[0], number_of_pairs) self.pairs = xy.take(np.random.choice(xy.shape[0], number_of_pairs, replace=False), axis=0) else: videos = self.initial_dataset.videos name_to_index = {name: index for index, name in enumerate(videos)} pairs = pd.read_csv(pairs_list) pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))] number_of_pairs = min(pairs.shape[0], number_of_pairs) self.pairs = [] self.start_frames = [] for ind in range(number_of_pairs): self.pairs.append( (name_to_index[pairs['driving'].iloc[ind]], name_to_index[pairs['source'].iloc[ind]])) def __len__(self): return len(self.pairs) def __getitem__(self, idx): pair = self.pairs[idx] first = self.initial_dataset[pair[0]] second = self.initial_dataset[pair[1]] first = {'driving_' + key: value for key, value in first.items()} second = {'source_' + key: value for key, value in second.items()} return {**first, **second} ================================================ FILE: logger.py ================================================ import numpy as np import torch import torch.nn.functional as F import imageio import os from skimage.draw import circle import matplotlib.pyplot as plt import collections class Logger: def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name='log.txt'): self.loss_list = [] self.cpk_dir = log_dir self.visualizations_dir = os.path.join(log_dir, 'train-vis') if not os.path.exists(self.visualizations_dir): os.makedirs(self.visualizations_dir) self.log_file = open(os.path.join(log_dir, log_file_name), 'a') self.zfill_num = zfill_num self.visualizer = Visualizer(**visualizer_params) self.checkpoint_freq = checkpoint_freq self.epoch = 0 self.best_loss = float('inf') self.names = None def log_scores(self, loss_names): loss_mean = np.array(self.loss_list).mean(axis=0) loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)]) loss_string = str(str(self.epoch)+str(self.step).zfill(self.zfill_num)) + ") " + loss_string print(loss_string, file=self.log_file) self.loss_list = [] self.log_file.flush() def visualize_rec(self, inp, out): # image = self.visualizer.visualize(inp['driving'], inp['source'], out) image = self.visualizer.visualize(inp['driving'][:,-1], inp['transformed_driving'][:,-1], inp['example_image'], out) imageio.imsave(os.path.join(self.visualizations_dir, "%s-%s-rec.png" % (str(self.epoch),str(self.step).zfill(self.zfill_num))), image) def save_cpk(self, emergent=False): cpk = {k: v.state_dict() for k, v in self.models.items()} cpk['epoch'] = self.epoch cpk['step'] = self.step cpk_path = os.path.join(self.cpk_dir, '%s-%s-checkpoint.pth.tar' % (str(self.epoch),str(self.step).zfill(self.zfill_num))) if not (os.path.exists(cpk_path) and emergent): torch.save(cpk, cpk_path) @staticmethod def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_detector=None, audio_feature=None, optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None, optimizer_audio_feature = None): checkpoint = torch.load(checkpoint_path) if generator is not None: generator.load_state_dict(checkpoint['generator']) if kp_detector is not None: kp_detector.load_state_dict(checkpoint['kp_detector']) if discriminator is not None: try: discriminator.load_state_dict(checkpoint['discriminator']) except: print ('No discriminator in the state-dict. Dicriminator will be randomly initialized') # if audio_feature is not None: # audio_feature.load_state_dict(checkpoint['audio_feature']) if optimizer_generator is not None: optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) if optimizer_discriminator is not None: try: optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) except RuntimeError as e: print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized') if optimizer_kp_detector is not None: optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) # if optimizer_audio_feature is not None: # a = checkpoint['optimizer_kp_detector']['param_groups'] # a[0].pop('params') # optimizer_audio_feature.load_state_dict(checkpoint['optimizer_audio_feature']) return checkpoint['epoch'] def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): if 'models' in self.__dict__: self.save_cpk() self.log_file.close() def log_iter(self, losses): losses = collections.OrderedDict(losses.items()) if self.names is None: self.names = list(losses.keys()) self.loss_list.append(list(losses.values())) def log_epoch(self, epoch, step, models, inp, out): self.epoch = epoch self.step = step self.models = models if (self.epoch + 1) % self.checkpoint_freq == 0: self.save_cpk() self.log_scores(self.names) self.visualize_rec(inp, out) class Visualizer: def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'): self.kp_size = kp_size self.draw_border = draw_border self.colormap = plt.get_cmap(colormap) def draw_image_with_kp(self, image, kp_array): image = np.copy(image) spatial_size = np.array(image.shape[:2][::-1])[np.newaxis] kp_array = spatial_size * (kp_array + 1) / 2 num_kp = kp_array.shape[0] for kp_ind, kp in enumerate(kp_array): rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2]) image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3] return image def create_image_column_with_kp(self, images, kp): image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)]) return self.create_image_column(image_array) def create_image_column(self, images): if self.draw_border: images = np.copy(images) images[:, :, [0, -1]] = (1, 1, 1) images[:, :, [0, -1]] = (1, 1, 1) return np.concatenate(list(images), axis=0) def create_image_grid(self, *args): out = [] for arg in args: if type(arg) == tuple: out.append(self.create_image_column_with_kp(arg[0], arg[1])) else: out.append(self.create_image_column(arg)) return np.concatenate(out, axis=1) def visualize(self, driving, transformed_driving, source, out): images = [] # Source image with keypoints source = source.data.cpu() kp_source = out['kp_source']['value'].data.cpu().numpy() source = np.transpose(source, [0, 2, 3, 1]) images.append((source, kp_source)) # Equivariance visualization if 'transformed_frame' in out: transformed = out['transformed_frame'].data.cpu().numpy() transformed = np.transpose(transformed, [0, 2, 3, 1]) transformed_kp = out['transformed_kp']['value'].data.cpu().numpy() images.append((transformed, transformed_kp)) # Equivariance visualization transformed_driving = transformed_driving.data.cpu().numpy() transformed_driving = np.transpose(transformed_driving, [0, 2, 3, 1]) images.append(transformed_driving) # Driving image with keypoints kp_driving = out['kp_driving'][-1]['value'].data.cpu().numpy() #[-1]['value'] driving = driving.data.cpu().numpy() driving = np.transpose(driving, [0, 2, 3, 1]) images.append((driving, kp_driving)) # Deformed image if 'deformed' in out: deformed = out['deformed'].data.cpu().numpy() deformed = np.transpose(deformed, [0, 2, 3, 1]) images.append(deformed) # Result with and without keypoints prediction = out['prediction'].data.cpu().numpy() prediction = np.transpose(prediction, [0, 2, 3, 1]) if 'kp_norm' in out: kp_norm = out['kp_norm']['value'].data.cpu().numpy() images.append((prediction, kp_norm)) images.append(prediction) ## Occlusion map if 'occlusion_map' in out: occlusion_map = out['occlusion_map'].data.cpu().repeat(1, 3, 1, 1) occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy() occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1]) images.append(occlusion_map) # Deformed images according to each individual transform if 'sparse_deformed' in out: full_mask = [] for i in range(out['sparse_deformed'].shape[1]): image = out['sparse_deformed'][:, i].data.cpu() image = F.interpolate(image, size=source.shape[1:3]) mask = out['mask'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1) mask = F.interpolate(mask, size=source.shape[1:3]) image = np.transpose(image.numpy(), (0, 2, 3, 1)) mask = np.transpose(mask.numpy(), (0, 2, 3, 1)) if i != 0: color = np.array(self.colormap((i - 1) / (out['sparse_deformed'].shape[1] - 1)))[:3] else: color = np.array((0, 0, 0)) color = color.reshape((1, 1, 1, 3)) images.append(image) if i != 0: images.append(mask * color) else: images.append(mask) full_mask.append(mask * color) images.append(sum(full_mask)) image = self.create_image_grid(*images) image = (255 * image).astype(np.uint8) return image ================================================ FILE: modules/dense_motion.py ================================================ from torch import nn import torch.nn.functional as F import torch from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian class DenseMotionNetwork(nn.Module): """ Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving """ def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False, scale_factor=1, kp_variance=0.01): super(DenseMotionNetwork, self).__init__() self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1), max_features=max_features, num_blocks=num_blocks) self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3)) if estimate_occlusion_map: self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3)) else: self.occlusion = None self.num_kp = num_kp self.scale_factor = scale_factor self.kp_variance = kp_variance if self.scale_factor != 1: self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) def create_heatmap_representations(self, source_image, kp_driving, kp_source): """ Eq 6. in the paper H_k(z) """ spatial_size = source_image.shape[2:] gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance) gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance) heatmap = gaussian_driving - gaussian_source #[4,10,H,W] #adding background feature zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()) heatmap = torch.cat([zeros, heatmap], dim=1) heatmap = heatmap.unsqueeze(2) #[4,11,1,h,w] return heatmap def create_sparse_motions(self, source_image, kp_driving, kp_source): """ Eq 4. in the paper T_{s<-d}(z) """ bs, _, h, w = source_image.shape identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type()) identity_grid = identity_grid.view(1, 1, h, w, 2) coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2) #[4,10,64,64,2] if 'jacobian' in kp_driving: jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian'])) jacobian = jacobian.unsqueeze(-3).unsqueeze(-3) jacobian = jacobian.repeat(1, 1, h, w, 1, 1) coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) coordinate_grid = coordinate_grid.squeeze(-1) driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2) #adding background feature identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1) sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) return sparse_motions def create_deformed_source_image(self, source_image, sparse_motions): """ Eq 7. in the paper \hat{T}_{s<-d}(z) """ bs, _, h, w = source_image.shape source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1) source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w) sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1)) sparse_deformed = F.grid_sample(source_repeat, sparse_motions) sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w)) return sparse_deformed def forward(self, source_image, kp_driving, kp_source): if self.scale_factor != 1: source_image = self.down(source_image) #[4,3,H*scale,W*scale] bs, _, h, w = source_image.shape out_dict = dict() heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source) #[4,11,1,64,64] sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source) #[4,11,64,64,2] deformed_source = self.create_deformed_source_image(source_image, sparse_motion) #[4,11,3,64,64] out_dict['sparse_deformed'] = deformed_source input = torch.cat([heatmap_representation, deformed_source], dim=2) input = input.view(bs, -1, h, w) #[4,11*4,64,64] prediction = self.hourglass(input) #[4,108,64,64] mask = self.mask(prediction) mask = F.softmax(mask, dim=1) #[4,11,64,64] out_dict['mask'] = mask mask = mask.unsqueeze(2) sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3) deformation = (sparse_motion * mask).sum(dim=1) deformation = deformation.permute(0, 2, 3, 1) #[4,64,64,2] out_dict['deformation'] = deformation # Sec. 3.2 in the paper if self.occlusion: occlusion_map = torch.sigmoid(self.occlusion(prediction)) out_dict['occlusion_map'] = occlusion_map #[4,1,64,64] return out_dict ================================================ FILE: modules/discriminator.py ================================================ from torch import nn import torch.nn.functional as F from modules.util import kp2gaussian import torch class DownBlock2d(nn.Module): """ Simple block for processing video (encoder). """ def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): super(DownBlock2d, self).__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) if sn: self.conv = nn.utils.spectral_norm(self.conv) if norm: self.norm = nn.InstanceNorm2d(out_features, affine=True) else: self.norm = None self.pool = pool def forward(self, x): out = x out = self.conv(out) if self.norm: out = self.norm(out) out = F.leaky_relu(out, 0.2) if self.pool: out = F.avg_pool2d(out, (2, 2)) return out class Discriminator(nn.Module): """ Discriminator similar to Pix2Pix """ def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs): super(Discriminator, self).__init__() down_blocks = [] for i in range(num_blocks): down_blocks.append( DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) self.down_blocks = nn.ModuleList(down_blocks) self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) if sn: self.conv = nn.utils.spectral_norm(self.conv) self.use_kp = use_kp self.kp_variance = kp_variance def forward(self, x, kp=None): feature_maps = [] out = x if self.use_kp: heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance) out = torch.cat([out, heatmap], dim=1) for down_block in self.down_blocks: feature_maps.append(down_block(out)) out = feature_maps[-1] prediction_map = self.conv(out) return feature_maps, prediction_map class MultiScaleDiscriminator(nn.Module): """ Multi-scale (scale) discriminator """ def __init__(self, scales=(), **kwargs): super(MultiScaleDiscriminator, self).__init__() self.scales = scales discs = {} for scale in scales: discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) self.discs = nn.ModuleDict(discs) def forward(self, x, kp=None): out_dict = {} for scale, disc in self.discs.items(): scale = str(scale).replace('-', '.') key = 'prediction_' + scale feature_maps, prediction_map = disc(x[key], kp) out_dict['feature_maps_' + scale] = feature_maps out_dict['prediction_map_' + scale] = prediction_map return out_dict ================================================ FILE: modules/function.py ================================================ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Thu Sep 30 17:45:24 2021 @author: SENSETIME\jixinya1 """ import torch def calc_mean_std(feat, eps=1e-5): # eps is a small value added to the variance to avoid divide-by-zero. size = feat.size() assert (len(size) == 4) N, C = size[:2] feat_var = feat.view(N, C, -1).var(dim=2) + eps feat_std = feat_var.sqrt().view(N, C, 1, 1) feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) return feat_mean, feat_std def adaptive_instance_normalization(content_feat, style_feat): assert (content_feat.size()[:2] == style_feat.size()[:2]) size = content_feat.size() style_mean, style_std = calc_mean_std(style_feat) content_mean, content_std = calc_mean_std(content_feat) normalized_feat = (content_feat - content_mean.expand( size)) / content_std.expand(size) return normalized_feat * style_std.expand(size) + style_mean.expand(size) def _calc_feat_flatten_mean_std(feat): # takes 3D feat (C, H, W), return mean and std of array within channels assert (feat.size()[0] == 3) assert (isinstance(feat, torch.FloatTensor)) feat_flatten = feat.view(3, -1) mean = feat_flatten.mean(dim=-1, keepdim=True) std = feat_flatten.std(dim=-1, keepdim=True) return feat_flatten, mean, std def _mat_sqrt(x): U, D, V = torch.svd(x) return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) def coral(source, target): # assume both source and target are 3D array (C, H, W) # Note: flatten -> f source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source) source_f_norm = (source_f - source_f_mean.expand_as( source_f)) / source_f_std.expand_as(source_f) source_f_cov_eye = \ torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3) target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target) target_f_norm = (target_f - target_f_mean.expand_as( target_f)) / target_f_std.expand_as(target_f) target_f_cov_eye = \ torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3) source_f_norm_transfer = torch.mm( _mat_sqrt(target_f_cov_eye), torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), source_f_norm) ) source_f_transfer = source_f_norm_transfer * \ target_f_std.expand_as(source_f_norm) + \ target_f_mean.expand_as(source_f_norm) return source_f_transfer.view(source.size()) ================================================ FILE: modules/generator.py ================================================ import torch from torch import nn import torch.nn.functional as F from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d from modules.dense_motion import DenseMotionNetwork class OcclusionAwareGenerator(nn.Module): """ Generator that given source image and and keypoints try to transform image according to movement trajectories induced by keypoints. Generator follows Johnson architecture. """ def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks, num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): super(OcclusionAwareGenerator, self).__init__() if dense_motion_params is not None: self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels, estimate_occlusion_map=estimate_occlusion_map, **dense_motion_params) else: self.dense_motion_network = None self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) down_blocks = [] for i in range(num_down_blocks): in_features = min(max_features, block_expansion * (2 ** i)) out_features = min(max_features, block_expansion * (2 ** (i + 1))) down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) self.down_blocks = nn.ModuleList(down_blocks) up_blocks = [] for i in range(num_down_blocks): in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i))) out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1))) up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) self.up_blocks = nn.ModuleList(up_blocks) self.bottleneck = torch.nn.Sequential() in_features = min(max_features, block_expansion * (2 ** num_down_blocks)) for i in range(num_bottleneck_blocks): self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3)) self.estimate_occlusion_map = estimate_occlusion_map self.num_channels = num_channels def deform_input(self, inp, deformation): _, h_old, w_old, _ = deformation.shape _, _, h, w = inp.shape if h_old != h or w_old != w: deformation = deformation.permute(0, 3, 1, 2) deformation = F.interpolate(deformation, size=(h, w), mode='bilinear') deformation = deformation.permute(0, 2, 3, 1) return F.grid_sample(inp, deformation) def forward(self, source_image, kp_driving, kp_source): # Encoding (downsampling) part out = self.first(source_image) #[4,64,H,W] for i in range(len(self.down_blocks)): out = self.down_blocks[i](out) #[4,256,H/4,W/4] # Transforming feature representation according to deformation and occlusion output_dict = {} if self.dense_motion_network is not None: dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving, kp_source=kp_source) output_dict['mask'] = dense_motion['mask'] output_dict['sparse_deformed'] = dense_motion['sparse_deformed'] if 'occlusion_map' in dense_motion: occlusion_map = dense_motion['occlusion_map'] output_dict['occlusion_map'] = occlusion_map else: occlusion_map = None deformation = dense_motion['deformation'] out = self.deform_input(out, deformation) if occlusion_map is not None: if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') out = out * occlusion_map output_dict["deformed"] = self.deform_input(source_image, deformation) # Decoding part out = self.bottleneck(out) #[4,256,64,64] for i in range(len(self.up_blocks)): out = self.up_blocks[i](out) out = self.final(out) out = torch.sigmoid(out) #[4,3,256,256] output_dict["prediction"] = out return output_dict ================================================ FILE: modules/keypoint_detector.py ================================================ from torch import nn import torch import torch.nn.functional as F from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d, Ct_encoder, EmotionNet, AF2F, AF2F_s, draw_heatmap class KPDetector(nn.Module): """ Detecting a keypoints. Return keypoint position and jacobian near each keypoint. """ def __init__(self, block_expansion, num_kp, num_channels, max_features, num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False, pad=0): super(KPDetector, self).__init__() self.predictor = Hourglass(block_expansion, in_features=num_channels, max_features=max_features, num_blocks=num_blocks) self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7), padding=pad) if estimate_jacobian: self.num_jacobian_maps = 1 if single_jacobian_map else num_kp self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad) self.jacobian.weight.data.zero_() self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) else: self.jacobian = None self.temperature = temperature self.scale_factor = scale_factor if self.scale_factor != 1: self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) def gaussian2kp(self, heatmap): """ Extract the mean and from a heatmap """ shape = heatmap.shape heatmap = heatmap.unsqueeze(-1) #[4,10,58,58,1] grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) #[1,1,58,58,2] value = (heatmap * grid).sum(dim=(2, 3)) #[4,10,2] kp = {'value': value} return kp def audio_feature(self, x, heatmap): # prediction = self.kp(x) #[4,10,H/4-6, W/4-6] # final_shape = prediction.shape # heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58] # heatmap = F.softmax(heatmap / self.temperature, dim=2) # heatmap = heatmap.view(*final_shape) #[4,10,58,58] # out = self.gaussian2kp(heatmap) final_shape = heatmap.squeeze(2).shape if self.jacobian is not None: jacobian_map = self.jacobian(x) ##[4,40,H/4-6, W/4-6] jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], final_shape[3]) heatmap = heatmap.unsqueeze(2) jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6] jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) jacobian = jacobian.sum(dim=-1) #[4,10,4] jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2] return jacobian def forward(self, x): #torch.Size([4, 3, H, W]) if self.scale_factor != 1: x = self.down(x) # 0.25 [4, 3, H/4, W/4] feature_map = self.predictor(x) #[4,3+32,H/4, W/4] prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6] final_shape = prediction.shape heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58] heatmap = F.softmax(heatmap / self.temperature, dim=2) heatmap = heatmap.view(*final_shape) #[4,10,58,58] out = self.gaussian2kp(heatmap) out['heatmap'] = heatmap if self.jacobian is not None: jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6] jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], final_shape[3]) heatmap = heatmap.unsqueeze(2) jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6] jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) jacobian = jacobian.sum(dim=-1) #[4,10,4] jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2] out['jacobian'] = jacobian return out class KPDetector_a(nn.Module): """ Detecting a keypoints. Return keypoint position and jacobian near each keypoint. """ def __init__(self, block_expansion, num_kp, num_channels,num_channels_a, max_features, num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False, pad=0): super(KPDetector_a, self).__init__() self.predictor = Hourglass(block_expansion, in_features=num_channels_a, max_features=max_features, num_blocks=num_blocks) self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7), padding=pad) if estimate_jacobian: self.num_jacobian_maps = 1 if single_jacobian_map else num_kp self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad) self.jacobian.weight.data.zero_() self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) else: self.jacobian = None self.temperature = temperature self.scale_factor = scale_factor if self.scale_factor != 1: self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) def gaussian2kp(self, heatmap): """ Extract the mean and from a heatmap """ shape = heatmap.shape heatmap = heatmap.unsqueeze(-1) #[4,10,58,58,1] grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) #[1,1,58,58,2] value = (heatmap * grid).sum(dim=(2, 3)) #[4,10,2] kp = {'value': value} return kp def audio_feature(self, x, heatmap): # prediction = self.kp(x) #[4,10,H/4-6, W/4-6] # final_shape = prediction.shape # heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58] # heatmap = F.softmax(heatmap / self.temperature, dim=2) # heatmap = heatmap.view(*final_shape) #[4,10,58,58] # out = self.gaussian2kp(heatmap) final_shape = heatmap.squeeze(2).shape if self.jacobian is not None: jacobian_map = self.jacobian(x) ##[4,40,H/4-6, W/4-6] jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], final_shape[3]) heatmap = heatmap.unsqueeze(2) jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6] jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) jacobian = jacobian.sum(dim=-1) #[4,10,4] jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2] return jacobian def forward(self, feature_map): #torch.Size([4, 3, H, W]) prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6] final_shape = prediction.shape heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58] heatmap = F.softmax(heatmap / self.temperature, dim=2) heatmap = heatmap.view(*final_shape) #[4,10,58,58] out = self.gaussian2kp(heatmap) out['heatmap'] = heatmap if self.jacobian is not None: jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6] jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], final_shape[3]) heatmap = heatmap.unsqueeze(2) jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6] jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) jacobian = jacobian.sum(dim=-1) #[4,10,4] jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2] out['jacobian'] = jacobian return out class Audio_Feature(nn.Module): def __init__(self): super(Audio_Feature, self).__init__() self.con_encoder = Ct_encoder() self.emo_encoder = EmotionNet() self.decoder = AF2F_s() def forward(self, x): x = x.unsqueeze(1) c = self.con_encoder(x) e = self.emo_encoder(x) # d = torch.cat([c, e], dim=1) d = self.decoder(c) return d ''' def forward(self, x, cube, audio): #torch.Size([4, 3, H, W]) if self.scale_factor != 1: x = self.down(x) # 0.25 [4, 3, H/4, W/4] cube = cube.unsqueeze(1) feature = torch.cat([x,cube,audio],dim=1) feature_map = self.predictor(feature) #[4,3+32,H/4, W/4] prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6] final_shape = prediction.shape heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58] heatmap = F.softmax(heatmap / self.temperature, dim=2) heatmap = heatmap.view(*final_shape) #[4,10,58,58] out = self.gaussian2kp(heatmap) out['heatmap'] = heatmap if self.jacobian is not None: jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6] jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], final_shape[3]) heatmap = heatmap.unsqueeze(2) jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6] jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) jacobian = jacobian.sum(dim=-1) #[4,10,4] jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2] out['jacobian'] = jacobian return out ''' ================================================ FILE: modules/model.py ================================================ from torch import nn import torch import torch.nn.functional as F from modules.util import AntiAliasInterpolation2d, make_coordinate_grid from torchvision import models import numpy as np from torch.autograd import grad class Vgg19(torch.nn.Module): """ Vgg19 network for perceptual loss. See Sec 3.3. """ def __init__(self, requires_grad=False): super(Vgg19, self).__init__() vgg_pretrained_features = models.vgg19(pretrained=True).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() for x in range(2): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(2, 7): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(7, 12): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(12, 21): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(21, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), requires_grad=False) self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), requires_grad=False) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): X = (X - self.mean) / self.std h_relu1 = self.slice1(X) h_relu2 = self.slice2(h_relu1) h_relu3 = self.slice3(h_relu2) h_relu4 = self.slice4(h_relu3) h_relu5 = self.slice5(h_relu4) out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] return out class ImagePyramide(torch.nn.Module): """ Create image pyramide for computing pyramide perceptual loss. See Sec 3.3 """ def __init__(self, scales, num_channels): super(ImagePyramide, self).__init__() downs = {} for scale in scales: downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale) self.downs = nn.ModuleDict(downs) def forward(self, x): out_dict = {} for scale, down_module in self.downs.items(): out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x) return out_dict class Transform: """ Random tps transformation for equivariance constraints. See Sec 3.3 """ def __init__(self, bs, **kwargs): noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3])) self.theta = noise + torch.eye(2, 3).view(1, 2, 3) self.bs = bs if ('sigma_tps' in kwargs) and ('points_tps' in kwargs): self.tps = True self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type()) self.control_points = self.control_points.unsqueeze(0) self.control_params = torch.normal(mean=0, std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2])) else: self.tps = False def transform_frame(self, frame): grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2] grid = grid.view(1, frame.shape[2] * frame.shape[3], 2) grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2) return F.grid_sample(frame, grid, padding_mode="reflection") def inverse_transform_frame(self, frame): grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2] grid = grid.view(1, frame.shape[2] * frame.shape[3], 2) grid = self.inverse_warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2) return F.grid_sample(frame, grid, padding_mode="reflection") def warp_coordinates(self, coordinates): theta = self.theta.type(coordinates.type()) theta = theta.unsqueeze(1) transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:] transformed = transformed.squeeze(-1) if self.tps: control_points = self.control_points.type(coordinates.type()) control_params = self.control_params.type(coordinates.type()) distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2) distances = torch.abs(distances).sum(-1) result = distances ** 2 result = result * torch.log(distances + 1e-6) result = result * control_params result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) transformed = transformed + result return transformed def inverse_warp_coordinates(self, coordinates): theta = self.theta.type(coordinates.type()) theta = theta.unsqueeze(1) a = torch.FloatTensor([[[[0,0,1]]]]).repeat([self.bs,1,1,1]).cuda() c = torch.cat((theta,a),2) d = c.inverse()[:,:,:2,:] d = d.type(coordinates.type()) transformed = torch.matmul(d[:, :, :, :2], coordinates.unsqueeze(-1)) + d[:, :, :, 2:] transformed = transformed.squeeze(-1) if self.tps: control_points = self.control_points.type(coordinates.type()) control_params = self.control_params.type(coordinates.type()) distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2) distances = torch.abs(distances).sum(-1) result = distances ** 2 result = result * torch.log(distances + 1e-6) result = result * control_params result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) transformed = transformed + result return transformed def jacobian(self, coordinates): coordinates.requires_grad=True new_coordinates = self.warp_coordinates(coordinates)#[4,10,2] grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True) grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True) jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2) return jacobian def detach_kp(kp): return {key: value.detach() for key, value in kp.items()} class TrainPart1Model(torch.nn.Module): """ Merge all generator related updates into single model for better multi-gpu usage """ def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids): super(TrainFullModel, self).__init__() self.kp_extractor = kp_extractor self.kp_extractor_a = kp_extractor_a self.audio_feature = audio_feature self.generator = generator self.discriminator = discriminator self.train_params = train_params self.scales = train_params['scales'] self.disc_scales = self.discriminator.scales self.pyramid = ImagePyramide(self.scales, generator.num_channels) if torch.cuda.is_available(): self.pyramid = self.pyramid.cuda() self.loss_weights = train_params['loss_weights'] if sum(self.loss_weights['perceptual']) != 0: self.vgg = Vgg19() if torch.cuda.is_available(): self.vgg = self.vgg.cuda() self.mse_loss_fn = nn.MSELoss().cuda() def forward(self, x): kp_source = self.kp_extractor(x['example_image']) kp_driving = [] for i in range(16): kp_driving.append(self.kp_extractor(x['driving'][:,i])) kp_driving_a = [] #x['example_image'], deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) loss_values = {} if self.loss_weights['audio'] != 0: kp_driving_a = [] for i in range(16): kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))# loss_value = 0 loss_heatmap = 0 loss_jacobian = 0 loss_perceptual = 0 for i in range(len(kp_driving)): loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian']).mean())*self.loss_weights['audio'] # loss_jacobian = loss_jacobian*self.loss_weights['audio'] loss_heatmap += (torch.abs(kp_driving[i]['heatmap'] - kp_driving_a[i]['heatmap']).mean())*self.loss_weights['audio']*100 loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']).mean())*self.loss_weights['audio'] loss_values['loss_value'] = loss_value/len(kp_driving) loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving) loss_values['loss_jacobian'] = loss_jacobian/len(kp_driving) if self.train_params['generator'] == 'not': # loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out) for i in range(1): #0,len(kp_driving),4 generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i]) generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a}) elif self.train_params['generator'] == 'visual': for i in range(0,len(kp_driving),4): #0,len(kp_driving),4 generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i]) generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) pyramide_real = self.pyramid(x['driving'][:,i]) pyramide_generated = self.pyramid(generated['prediction']) if sum(self.loss_weights['perceptual']) != 0: value_total = 0 for scale in self.scales: x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) for i, weight in enumerate(self.loss_weights['perceptual']): value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() value_total += self.loss_weights['perceptual'][i] * value loss_perceptual += value_total length = int((len(kp_driving)-1)/4)+1 loss_values['perceptual'] = loss_perceptual/length elif self.train_params['generator'] == 'audio': for i in range(0,len(kp_driving),4): #0,len(kp_driving),4 generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i]) generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a}) pyramide_real = self.pyramid(x['driving'][:,i]) pyramide_generated = self.pyramid(generated['prediction']) if sum(self.loss_weights['perceptual']) != 0: value_total = 0 for scale in self.scales: x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) for i, weight in enumerate(self.loss_weights['perceptual']): value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() value_total += self.loss_weights['perceptual'][i] * value loss_perceptual += value_total length = int((len(kp_driving)-1)/4)+1 loss_values['perceptual'] = loss_perceptual/length else: print('wrong train_params: ', self.train_params['generator']) return loss_values,generated class TrainPart2Model(torch.nn.Module): """ Merge all generator related updates into single model for better multi-gpu usage """ def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids): super(TrainFullModel, self).__init__() self.kp_extractor = kp_extractor self.kp_extractor_a = kp_extractor_a self.audio_feature = audio_feature self.emo_feature = emo_feature self.generator = generator self.discriminator = discriminator self.train_params = train_params self.scales = train_params['scales'] self.disc_scales = self.discriminator.scales self.pyramid = ImagePyramide(self.scales, generator.num_channels) if torch.cuda.is_available(): self.pyramid = self.pyramid.cuda() self.loss_weights = train_params['loss_weights'] if sum(self.loss_weights['perceptual']) != 0: self.vgg = Vgg19() if torch.cuda.is_available(): self.vgg = self.vgg.cuda() self.mse_loss_fn = nn.MSELoss().cuda() self.CroEn_loss = nn.CrossEntropyLoss().cuda() def forward(self, x): kp_source = self.kp_extractor(x['example_image']) kp_driving = [] kp_emo = [] for i in range(16): kp_driving.append(self.kp_extractor(x['driving'][:,i])) # kp_emo.append(self.emo_detector(x['driving'][:,i])) kp_driving_a = [] #x['example_image'], deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) # emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) loss_values = {} if self.loss_weights['emo'] != 0: kp_driving_a = [] fakes = [] for i in range(16): kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))# value = self.kp_extractor_a(deco_out[:,i])['value'] jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian'] if self.train_params['type'] == 'linear_4' : out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian) kp_emo.append(out) fakes.append(fake) # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian)) elif self.train_params['type'] == 'linear_10': # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) out, fake = self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian) kp_emo.append(out) fakes.append(fake) elif self.train_params['type'] == 'linear_4_new': # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) out, fake = self.emo_feature.linear_4(x['transformed_driving'][:,i],value,jacobian) kp_emo.append(out) fakes.append(fake) elif self.train_params['type'] == 'linear_np_4': # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) out, fake = self.emo_feature.linear_np_4(x['transformed_driving'][:,i],value,jacobian) kp_emo.append(out) fakes.append(fake) elif self.train_params['type'] == 'linear_np_10': # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) out, fake = self.emo_feature.linear_np_10(x['transformed_driving'][:,i],value,jacobian) kp_emo.append(out) fakes.append(fake) loss_value = 0 loss_jacobian = 0 loss_classify = 0 kp_all = kp_driving_a for i in range(len(kp_driving)): if self.train_params['type'] == 'linear_4' or self.train_params['type'] == 'linear_4_new' or self.train_params['type'] == 'linear_np_4': loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,1] - kp_driving_a[i]['jacobian'][:,1] -kp_emo[i]['jacobian'][:,0]).mean())*self.loss_weights['emo'] loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,4] - kp_driving_a[i]['jacobian'][:,4] -kp_emo[i]['jacobian'][:,1]).mean())*self.loss_weights['emo'] loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,6] - kp_driving_a[i]['jacobian'][:,6] -kp_emo[i]['jacobian'][:,2]).mean())*self.loss_weights['emo'] loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,8] - kp_driving_a[i]['jacobian'][:,8] -kp_emo[i]['jacobian'][:,3]).mean())*self.loss_weights['emo'] loss_classify += self.CroEn_loss(fakes[i],x['emotion']) loss_value += (torch.abs(kp_driving[i]['value'][:,1] .detach() - kp_driving_a[i]['value'][:,1] - kp_emo[i]['value'][:,0] ).mean())*self.loss_weights['emo'] loss_value += (torch.abs(kp_driving[i]['value'][:,4] .detach() - kp_driving_a[i]['value'][:,4] - kp_emo[i]['value'][:,1] ).mean())*self.loss_weights['emo'] loss_value += (torch.abs(kp_driving[i]['value'][:,6] .detach() - kp_driving_a[i]['value'][:,6] - kp_emo[i]['value'][:,2] ).mean())*self.loss_weights['emo'] loss_value += (torch.abs(kp_driving[i]['value'][:,8] .detach() - kp_driving_a[i]['value'][:,8] - kp_emo[i]['value'][:,3] ).mean())*self.loss_weights['emo'] kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1] kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4] kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6] kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8] kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1] kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + kp_driving_a[i]['value'][:,4] kp_all[i]['value'][:,6] = kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6] kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8] elif self.train_params['type'] == 'linear_10' or self.train_params['type'] == 'linear_np_10': loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian'] -kp_emo[i]['jacobian']).mean())*self.loss_weights['emo'] loss_classify += self.CroEn_loss(fakes[i],x['emotion']) loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value'] - kp_emo[i]['value'] ).mean())*self.loss_weights['emo'] # kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value'] loss_values['loss_value'] = loss_value/len(kp_driving) # loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving) loss_values['loss_jacobian'] = loss_jacobian/len(kp_driving) if self.train_params['classify'] == True: loss_values['loss_classify'] = loss_classify/len(kp_driving) else: loss_values['loss_classify'] = torch.tensor(0, device = loss_values['loss_value'].device) return loss_values,generated class GeneratorFullModel(torch.nn.Module): """ Merge all generator related updates into single model for better multi-gpu usage """ def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params): super(GeneratorFullModel, self).__init__() self.kp_extractor = kp_extractor self.kp_extractor_a = kp_extractor_a # self.content_encoder = content_encoder # self.emotion_encoder = emotion_encoder self.audio_feature = audio_feature self.generator = generator self.discriminator = discriminator self.train_params = train_params self.scales = train_params['scales'] self.disc_scales = self.discriminator.scales self.pyramid = ImagePyramide(self.scales, generator.num_channels) if torch.cuda.is_available(): self.pyramid = self.pyramid.cuda() self.loss_weights = train_params['loss_weights'] if sum(self.loss_weights['perceptual']) != 0: self.vgg = Vgg19() if torch.cuda.is_available(): self.vgg = self.vgg.cuda() self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda() self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda() def forward(self, x): # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[]) # source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1))) # kp_source = self.kp_extractor(x['source']) # kp_source_a = self.kp_extractor_a(x['source'], x['source_cube'], source_a_f) # driving_a_f = self.audio_feature(self.content_encoder(x['driving_audio'].unsqueeze(1)), self.emotion_encoder(x['driving_audio'].unsqueeze(1))) # driving_a_f = self.audio_feature(x['driving_audio']) # kp_driving = self.kp_extractor(x['driving']) # kp_driving_a = self.kp_extractor_a(x['driving'], x['driving_cube'], driving_a_f) kp_driving = [] for i in range(16): kp_driving.append(self.kp_extractor(x['driving'][:,i],x['driving_landmark'][:,i],self.loss_weights['equivariance_value'])) kp_driving_a = [] fc_out, deco_out = self.audio_feature(x['example_landmark'], x['driving_audio'], x['driving_pose']) fake_lmark=fc_out + x['example_landmark'].expand_as(fc_out) fake_lmark = torch.mm( fake_lmark, self.pca.t() ) fake_lmark = fake_lmark + self.mean.expand_as(fake_lmark) fake_lmark = fake_lmark.unsqueeze(0) # for i in range(16): # kp_driving_a.append() # generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving) # generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) loss_values = {} pyramide_real = self.pyramid(x['driving']) pyramide_generated = self.pyramid(generated['prediction']) if self.loss_weights['audio'] != 0: value = torch.abs(kp_source['jacobian'].detach() - kp_source_a['jacobian'].detach()).mean() + torch.abs(kp_driving['jacobian'].detach() - kp_driving_a['jacobian']).mean() value = value/2 loss_values['jacobian'] = value*self.loss_weights['audio'] value = torch.abs(kp_source['heatmap'].detach() - kp_source_a['heatmap'].detach()).mean() + torch.abs(kp_driving['heatmap'].detach() - kp_driving_a['heatmap']).mean() value = value/2 loss_values['heatmap'] = value*self.loss_weights['audio'] value = torch.abs(kp_source['value'].detach() - kp_source_a['value'].detach()).mean() + torch.abs(kp_driving['value'].detach() - kp_driving_a['value']).mean() value = value/2 loss_values['value'] = value*self.loss_weights['audio'] if sum(self.loss_weights['perceptual']) != 0: value_total = 0 for scale in self.scales: x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) for i, weight in enumerate(self.loss_weights['perceptual']): value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() value_total += self.loss_weights['perceptual'][i] * value loss_values['perceptual'] = value_total if self.loss_weights['generator_gan'] != 0: discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) value_total = 0 for scale in self.disc_scales: key = 'prediction_map_%s' % scale value = ((1 - discriminator_maps_generated[key]) ** 2).mean() value_total += self.loss_weights['generator_gan'] * value loss_values['gen_gan'] = value_total if sum(self.loss_weights['feature_matching']) != 0: value_total = 0 for scale in self.disc_scales: key = 'feature_maps_%s' % scale for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])): if self.loss_weights['feature_matching'][i] == 0: continue value = torch.abs(a - b).mean() value_total += self.loss_weights['feature_matching'][i] * value loss_values['feature_matching'] = value_total if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0: transform = Transform(x['driving'].shape[0], **self.train_params['transform_params']) transformed_frame = transform.transform_frame(x['driving']) transformed_landmark = transform.inverse_warp_coordinates(x['driving_landmark']) transformed_kp = self.kp_extractor(transformed_frame) generated['transformed_frame'] = transformed_frame generated['transformed_kp'] = transformed_kp ## Value loss part if self.loss_weights['equivariance_value'] != 0: value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean() loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value ## jacobian loss part if self.loss_weights['equivariance_jacobian'] != 0: jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']), transformed_kp['jacobian']) normed_driving = torch.inverse(kp_driving['jacobian']) normed_transformed = jacobian_transformed value = torch.matmul(normed_driving, normed_transformed) eye = torch.eye(2).view(1, 1, 2, 2).type(value.type()) value = torch.abs(eye - value).mean() loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value return loss_values, generated class DiscriminatorFullModel(torch.nn.Module): """ Merge all discriminator related updates into single model for better multi-gpu usage """ def __init__(self, kp_extractor, generator, discriminator, train_params): super(DiscriminatorFullModel, self).__init__() self.kp_extractor = kp_extractor self.generator = generator self.discriminator = discriminator self.train_params = train_params self.scales = self.discriminator.scales self.pyramid = ImagePyramide(self.scales, generator.num_channels) if torch.cuda.is_available(): self.pyramid = self.pyramid.cuda() self.loss_weights = train_params['loss_weights'] def forward(self, x, generated): pyramide_real = self.pyramid(x['driving']) pyramide_generated = self.pyramid(generated['prediction'].detach()) kp_driving = generated['kp_driving'] discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) loss_values = {} value_total = 0 for scale in self.scales: key = 'prediction_map_%s' % scale value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2 value_total += self.loss_weights['discriminator_gan'] * value.mean() loss_values['disc_gan'] = value_total return loss_values ================================================ FILE: modules/model_delta_map.py ================================================ from torch import nn import torch import torch.nn.functional as F from modules.util import AntiAliasInterpolation2d, make_coordinate_grid from torchvision import models import numpy as np from torch.autograd import grad class Vgg19(torch.nn.Module): """ Vgg19 network for perceptual loss. See Sec 3.3. """ def __init__(self, requires_grad=False): super(Vgg19, self).__init__() vgg_pretrained_features = models.vgg19(pretrained=True).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() for x in range(2): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(2, 7): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(7, 12): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(12, 21): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(21, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), requires_grad=False) self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), requires_grad=False) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): X = (X - self.mean) / self.std h_relu1 = self.slice1(X) h_relu2 = self.slice2(h_relu1) h_relu3 = self.slice3(h_relu2) h_relu4 = self.slice4(h_relu3) h_relu5 = self.slice5(h_relu4) out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] return out class ImagePyramide(torch.nn.Module): """ Create image pyramide for computing pyramide perceptual loss. See Sec 3.3 """ def __init__(self, scales, num_channels): super(ImagePyramide, self).__init__() downs = {} for scale in scales: downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale) self.downs = nn.ModuleDict(downs) def forward(self, x): out_dict = {} for scale, down_module in self.downs.items(): out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x) return out_dict class Transform: """ Random tps transformation for equivariance constraints. See Sec 3.3 """ def __init__(self, bs, **kwargs): noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3])) self.theta = noise + torch.eye(2, 3).view(1, 2, 3) self.bs = bs if ('sigma_tps' in kwargs) and ('points_tps' in kwargs): self.tps = True self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type()) self.control_points = self.control_points.unsqueeze(0) self.control_params = torch.normal(mean=0, std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2])) else: self.tps = False def transform_frame(self, frame): grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2] grid = grid.view(1, frame.shape[2] * frame.shape[3], 2) grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2) return F.grid_sample(frame, grid, padding_mode="reflection") def inverse_transform_frame(self, frame): grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2] grid = grid.view(1, frame.shape[2] * frame.shape[3], 2) grid = self.inverse_warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2) return F.grid_sample(frame, grid, padding_mode="reflection") def warp_coordinates(self, coordinates): theta = self.theta.type(coordinates.type()) theta = theta.unsqueeze(1) transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:] transformed = transformed.squeeze(-1) if self.tps: control_points = self.control_points.type(coordinates.type()) control_params = self.control_params.type(coordinates.type()) distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2) distances = torch.abs(distances).sum(-1) result = distances ** 2 result = result * torch.log(distances + 1e-6) result = result * control_params result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) transformed = transformed + result return transformed def inverse_warp_coordinates(self, coordinates): theta = self.theta.type(coordinates.type()) theta = theta.unsqueeze(1) a = torch.FloatTensor([[[[0,0,1]]]]).repeat([self.bs,1,1,1]).cuda() c = torch.cat((theta,a),2) d = c.inverse()[:,:,:2,:] d = d.type(coordinates.type()) transformed = torch.matmul(d[:, :, :, :2], coordinates.unsqueeze(-1)) + d[:, :, :, 2:] transformed = transformed.squeeze(-1) if self.tps: control_points = self.control_points.type(coordinates.type()) control_params = self.control_params.type(coordinates.type()) distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2) distances = torch.abs(distances).sum(-1) result = distances ** 2 result = result * torch.log(distances + 1e-6) result = result * control_params result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) transformed = transformed + result return transformed def jacobian(self, coordinates): coordinates.requires_grad=True new_coordinates = self.warp_coordinates(coordinates)#[4,10,2] grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True) grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True) jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2) return jacobian def detach_kp(kp): return {key: value.detach() for key, value in kp.items()} class TrainFullModel(torch.nn.Module): """ Merge all generator related updates into single model for better multi-gpu usage """ def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids): super(TrainFullModel, self).__init__() self.kp_extractor = kp_extractor self.kp_extractor_a = kp_extractor_a # self.emo_detector = emo_detector # self.content_encoder = content_encoder # self.emotion_encoder = emotion_encoder self.audio_feature = audio_feature self.emo_feature = emo_feature self.generator = generator self.discriminator = discriminator self.train_params = train_params self.scales = train_params['scales'] self.disc_scales = self.discriminator.scales self.pyramid = ImagePyramide(self.scales, generator.num_channels) if torch.cuda.is_available(): self.pyramid = self.pyramid.cuda() self.loss_weights = train_params['loss_weights'] if sum(self.loss_weights['perceptual']) != 0: self.vgg = Vgg19() if torch.cuda.is_available(): self.vgg = self.vgg.cuda() # self.pca = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/U_106.npy'))[:, :16].to(device_ids[0]) # self.mean = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/mean_106.npy')).to(device_ids[0]) self.mse_loss_fn = nn.MSELoss().cuda() self.CroEn_loss = nn.CrossEntropyLoss().cuda() def forward(self, x): # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[]) # source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1))) kp_source = self.kp_extractor(x['example_image']) kp_driving = [] kp_emo = [] for i in range(16): kp_driving.append(self.kp_extractor(x['driving'][:,i])) # kp_emo.append(self.emo_detector(x['driving'][:,i])) # print('KP_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a')) kp_driving_a = [] #x['example_image'], deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) # emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) loss_values = {} if self.loss_weights['emo'] != 0: kp_driving_a = [] fakes = [] for i in range(16): kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))# value = self.kp_extractor_a(deco_out[:,i])['value'] jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian'] if self.train_params['type'] == 'map_4': out, fake = self.emo_feature.map_4(x['transformed_driving'][:,i],value,jacobian) kp_emo.append(out) fakes.append(fake) # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian)) elif self.train_params['type'] == 'map_10': # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian) kp_emo.append(out) fakes.append(fake) # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian)) # print('Kp_audio_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a')) loss_value = 0 # loss_heatmap = 0 loss_jacobian = 0 loss_perceptual = 0 loss_classify = 0 kp_all = kp_driving_a for i in range(len(kp_driving)): if self.train_params['type'] == 'map_4': loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,1] - kp_driving_a[i]['jacobian'][:,1] -kp_emo[i]['jacobian'][:,0]).mean())*self.loss_weights['emo'] loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,4] - kp_driving_a[i]['jacobian'][:,4] -kp_emo[i]['jacobian'][:,1]).mean())*self.loss_weights['emo'] loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,6] - kp_driving_a[i]['jacobian'][:,6] -kp_emo[i]['jacobian'][:,2]).mean())*self.loss_weights['emo'] loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,8] - kp_driving_a[i]['jacobian'][:,8] -kp_emo[i]['jacobian'][:,3]).mean())*self.loss_weights['emo'] loss_classify += self.CroEn_loss(fakes[i],x['emotion']) loss_value += (torch.abs(kp_driving[i]['value'][:,1] .detach() - kp_driving_a[i]['value'][:,1] - kp_emo[i]['value'][:,0] ).mean())*self.loss_weights['emo'] loss_value += (torch.abs(kp_driving[i]['value'][:,4] .detach() - kp_driving_a[i]['value'][:,4] - kp_emo[i]['value'][:,1] ).mean())*self.loss_weights['emo'] loss_value += (torch.abs(kp_driving[i]['value'][:,6] .detach() - kp_driving_a[i]['value'][:,6] - kp_emo[i]['value'][:,2] ).mean())*self.loss_weights['emo'] loss_value += (torch.abs(kp_driving[i]['value'][:,8] .detach() - kp_driving_a[i]['value'][:,8] - kp_emo[i]['value'][:,3] ).mean())*self.loss_weights['emo'] kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1] kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4] kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6] kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8] kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1] kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + kp_driving_a[i]['value'][:,4] kp_all[i]['value'][:,6] = kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6] kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8] elif self.train_params['type'] == 'map_10': loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian'] -kp_emo[i]['jacobian']).mean())*self.loss_weights['emo'] loss_classify += self.CroEn_loss(fakes[i],x['emotion']) loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value'] - kp_emo[i]['value'] ).mean())*self.loss_weights['emo'] # kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value'] loss_values['loss_value'] = loss_value/len(kp_driving) # loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving) loss_values['loss_jacobian'] = loss_jacobian/len(kp_driving) loss_values['loss_classify'] = loss_classify/len(kp_driving) if self.train_params['generator'] == 'not': loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out) for i in range(1): #0,len(kp_driving),4 generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i]) generated.update({'kp_source': kp_source, 'kp_driving': kp_all}) elif self.train_params['generator'] == 'visual': for i in range(0,len(kp_driving),4): #0,len(kp_driving),4 generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i]) generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) pyramide_real = self.pyramid(x['driving'][:,i]) pyramide_generated = self.pyramid(generated['prediction']) if sum(self.loss_weights['perceptual']) != 0: value_total = 0 for scale in self.scales: x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) for i, weight in enumerate(self.loss_weights['perceptual']): value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() value_total += self.loss_weights['perceptual'][i] * value loss_perceptual += value_total length = int((len(kp_driving)-1)/4)+1 loss_values['perceptual'] = loss_perceptual/length elif self.train_params['generator'] == 'audio': for i in range(0,len(kp_driving),4): #0,len(kp_driving),4 generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i]) generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a}) pyramide_real = self.pyramid(x['driving'][:,i]) pyramide_generated = self.pyramid(generated['prediction']) if sum(self.loss_weights['perceptual']) != 0: value_total = 0 for scale in self.scales: x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) for i, weight in enumerate(self.loss_weights['perceptual']): value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() value_total += self.loss_weights['perceptual'][i] * value loss_perceptual += value_total length = int((len(kp_driving)-1)/4)+1 loss_values['perceptual'] = loss_perceptual/length else: print('wrong train_params: ', self.train_params['generator']) return loss_values,generated class GeneratorFullModel(torch.nn.Module): """ Merge all generator related updates into single model for better multi-gpu usage """ def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params): super(GeneratorFullModel, self).__init__() self.kp_extractor = kp_extractor self.kp_extractor_a = kp_extractor_a # self.content_encoder = content_encoder # self.emotion_encoder = emotion_encoder self.audio_feature = audio_feature self.generator = generator self.discriminator = discriminator self.train_params = train_params self.scales = train_params['scales'] self.disc_scales = self.discriminator.scales self.pyramid = ImagePyramide(self.scales, generator.num_channels) if torch.cuda.is_available(): self.pyramid = self.pyramid.cuda() self.loss_weights = train_params['loss_weights'] if sum(self.loss_weights['perceptual']) != 0: self.vgg = Vgg19() if torch.cuda.is_available(): self.vgg = self.vgg.cuda() self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda() self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda() def forward(self, x): # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[]) # source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1))) # kp_source = self.kp_extractor(x['source']) # kp_source_a = self.kp_extractor_a(x['source'], x['source_cube'], source_a_f) # driving_a_f = self.audio_feature(self.content_encoder(x['driving_audio'].unsqueeze(1)), self.emotion_encoder(x['driving_audio'].unsqueeze(1))) # driving_a_f = self.audio_feature(x['driving_audio']) # kp_driving = self.kp_extractor(x['driving']) # kp_driving_a = self.kp_extractor_a(x['driving'], x['driving_cube'], driving_a_f) kp_driving = [] for i in range(16): kp_driving.append(self.kp_extractor(x['driving'][:,i],x['driving_landmark'][:,i],self.loss_weights['equivariance_value'])) kp_driving_a = [] fc_out, deco_out = self.audio_feature(x['example_landmark'], x['driving_audio'], x['driving_pose']) fake_lmark=fc_out + x['example_landmark'].expand_as(fc_out) fake_lmark = torch.mm( fake_lmark, self.pca.t() ) fake_lmark = fake_lmark + self.mean.expand_as(fake_lmark) fake_lmark = fake_lmark.unsqueeze(0) # for i in range(16): # kp_driving_a.append() # generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving) # generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) loss_values = {} pyramide_real = self.pyramid(x['driving']) pyramide_generated = self.pyramid(generated['prediction']) if self.loss_weights['audio'] != 0: value = torch.abs(kp_source['jacobian'].detach() - kp_source_a['jacobian'].detach()).mean() + torch.abs(kp_driving['jacobian'].detach() - kp_driving_a['jacobian']).mean() value = value/2 loss_values['jacobian'] = value*self.loss_weights['audio'] value = torch.abs(kp_source['heatmap'].detach() - kp_source_a['heatmap'].detach()).mean() + torch.abs(kp_driving['heatmap'].detach() - kp_driving_a['heatmap']).mean() value = value/2 loss_values['heatmap'] = value*self.loss_weights['audio'] value = torch.abs(kp_source['value'].detach() - kp_source_a['value'].detach()).mean() + torch.abs(kp_driving['value'].detach() - kp_driving_a['value']).mean() value = value/2 loss_values['value'] = value*self.loss_weights['audio'] if sum(self.loss_weights['perceptual']) != 0: value_total = 0 for scale in self.scales: x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) for i, weight in enumerate(self.loss_weights['perceptual']): value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() value_total += self.loss_weights['perceptual'][i] * value loss_values['perceptual'] = value_total if self.loss_weights['generator_gan'] != 0: discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) value_total = 0 for scale in self.disc_scales: key = 'prediction_map_%s' % scale value = ((1 - discriminator_maps_generated[key]) ** 2).mean() value_total += self.loss_weights['generator_gan'] * value loss_values['gen_gan'] = value_total if sum(self.loss_weights['feature_matching']) != 0: value_total = 0 for scale in self.disc_scales: key = 'feature_maps_%s' % scale for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])): if self.loss_weights['feature_matching'][i] == 0: continue value = torch.abs(a - b).mean() value_total += self.loss_weights['feature_matching'][i] * value loss_values['feature_matching'] = value_total if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0: transform = Transform(x['driving'].shape[0], **self.train_params['transform_params']) transformed_frame = transform.transform_frame(x['driving']) transformed_landmark = transform.inverse_warp_coordinates(x['driving_landmark']) transformed_kp = self.kp_extractor(transformed_frame) generated['transformed_frame'] = transformed_frame generated['transformed_kp'] = transformed_kp ## Value loss part if self.loss_weights['equivariance_value'] != 0: value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean() loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value ## jacobian loss part if self.loss_weights['equivariance_jacobian'] != 0: jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']), transformed_kp['jacobian']) normed_driving = torch.inverse(kp_driving['jacobian']) normed_transformed = jacobian_transformed value = torch.matmul(normed_driving, normed_transformed) eye = torch.eye(2).view(1, 1, 2, 2).type(value.type()) value = torch.abs(eye - value).mean() loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value return loss_values, generated class DiscriminatorFullModel(torch.nn.Module): """ Merge all discriminator related updates into single model for better multi-gpu usage """ def __init__(self, kp_extractor, generator, discriminator, train_params): super(DiscriminatorFullModel, self).__init__() self.kp_extractor = kp_extractor self.generator = generator self.discriminator = discriminator self.train_params = train_params self.scales = self.discriminator.scales self.pyramid = ImagePyramide(self.scales, generator.num_channels) if torch.cuda.is_available(): self.pyramid = self.pyramid.cuda() self.loss_weights = train_params['loss_weights'] def forward(self, x, generated): pyramide_real = self.pyramid(x['driving']) pyramide_generated = self.pyramid(generated['prediction'].detach()) kp_driving = generated['kp_driving'] discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) loss_values = {} value_total = 0 for scale in self.scales: key = 'prediction_map_%s' % scale value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2 value_total += self.loss_weights['discriminator_gan'] * value.mean() loss_values['disc_gan'] = value_total return loss_values ================================================ FILE: modules/model_gen.py ================================================ from torch import nn import torch import torch.nn.functional as F from modules.util import AntiAliasInterpolation2d, make_coordinate_grid from torchvision import models import numpy as np from torch.autograd import grad class Vgg19(torch.nn.Module): """ Vgg19 network for perceptual loss. See Sec 3.3. """ def __init__(self, requires_grad=False): super(Vgg19, self).__init__() vgg_pretrained_features = models.vgg19(pretrained=True).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() for x in range(2): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(2, 7): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(7, 12): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(12, 21): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(21, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), requires_grad=False) self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), requires_grad=False) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): X = (X - self.mean) / self.std h_relu1 = self.slice1(X) h_relu2 = self.slice2(h_relu1) h_relu3 = self.slice3(h_relu2) h_relu4 = self.slice4(h_relu3) h_relu5 = self.slice5(h_relu4) out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] return out class ImagePyramide(torch.nn.Module): """ Create image pyramide for computing pyramide perceptual loss. See Sec 3.3 """ def __init__(self, scales, num_channels): super(ImagePyramide, self).__init__() downs = {} for scale in scales: downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale) self.downs = nn.ModuleDict(downs) def forward(self, x): out_dict = {} for scale, down_module in self.downs.items(): out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x) return out_dict class Transform: """ Random tps transformation for equivariance constraints. See Sec 3.3 """ def __init__(self, bs, **kwargs): noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3])) self.theta = noise + torch.eye(2, 3).view(1, 2, 3) self.bs = bs if ('sigma_tps' in kwargs) and ('points_tps' in kwargs): self.tps = True self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type()) self.control_points = self.control_points.unsqueeze(0) self.control_params = torch.normal(mean=0, std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2])) else: self.tps = False def transform_frame(self, frame): grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2] grid = grid.view(1, frame.shape[2] * frame.shape[3], 2) grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2) return F.grid_sample(frame, grid, padding_mode="reflection") def inverse_transform_frame(self, frame): grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2] grid = grid.view(1, frame.shape[2] * frame.shape[3], 2) grid = self.inverse_warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2) return F.grid_sample(frame, grid, padding_mode="reflection") def warp_coordinates(self, coordinates): theta = self.theta.type(coordinates.type()) theta = theta.unsqueeze(1) transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:] transformed = transformed.squeeze(-1) if self.tps: control_points = self.control_points.type(coordinates.type()) control_params = self.control_params.type(coordinates.type()) distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2) distances = torch.abs(distances).sum(-1) result = distances ** 2 result = result * torch.log(distances + 1e-6) result = result * control_params result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) transformed = transformed + result return transformed def inverse_warp_coordinates(self, coordinates): theta = self.theta.type(coordinates.type()) theta = theta.unsqueeze(1) a = torch.FloatTensor([[[[0,0,1]]]]).repeat([self.bs,1,1,1]).cuda() c = torch.cat((theta,a),2) d = c.inverse()[:,:,:2,:] d = d.type(coordinates.type()) transformed = torch.matmul(d[:, :, :, :2], coordinates.unsqueeze(-1)) + d[:, :, :, 2:] transformed = transformed.squeeze(-1) if self.tps: control_points = self.control_points.type(coordinates.type()) control_params = self.control_params.type(coordinates.type()) distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2) distances = torch.abs(distances).sum(-1) result = distances ** 2 result = result * torch.log(distances + 1e-6) result = result * control_params result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) transformed = transformed + result return transformed def jacobian(self, coordinates): coordinates.requires_grad=True new_coordinates = self.warp_coordinates(coordinates)#[4,10,2] grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True) grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True) jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2) return jacobian def detach_kp(kp): return {key: value.detach() for key, value in kp.items()} class TrainFullModel(torch.nn.Module): """ Merge all generator related updates into single model for better multi-gpu usage """ def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids): super(TrainFullModel, self).__init__() self.kp_extractor = kp_extractor self.kp_extractor_a = kp_extractor_a # self.emo_detector = emo_detector # self.content_encoder = content_encoder # self.emotion_encoder = emotion_encoder self.audio_feature = audio_feature self.emo_feature = emo_feature self.generator = generator self.discriminator = discriminator self.train_params = train_params self.scales = train_params['scales'] self.disc_scales = self.discriminator.scales self.pyramid = ImagePyramide(self.scales, generator.num_channels) if torch.cuda.is_available(): self.pyramid = self.pyramid.cuda() self.loss_weights = train_params['loss_weights'] if sum(self.loss_weights['perceptual']) != 0: self.vgg = Vgg19() if torch.cuda.is_available(): self.vgg = self.vgg.cuda() # self.pca = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/U_106.npy'))[:, :16].to(device_ids[0]) # self.mean = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/mean_106.npy')).to(device_ids[0]) self.mse_loss_fn = nn.MSELoss().cuda() self.CroEn_loss = nn.CrossEntropyLoss().cuda() def forward(self, x): # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[]) # source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1))) kp_source = self.kp_extractor(x['example_image']) # print(x['name'],len(x['name'])) kp_driving = [] kp_emo = [] for i in range(16): kp_driving.append(self.kp_extractor(x['driving'][:,i])) # kp_emo.append(self.emo_detector(x['driving'][:,i])) # print('KP_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a')) kp_driving_a = [] #x['example_image'], deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) # emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net']) loss_values = {} if self.loss_weights['emo'] != 0: kp_driving_a = [] fakes = [] for i in range(16): kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))# value = self.kp_extractor_a(deco_out[:,i])['value'] jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian'] if self.train_params['type'] == 'linear_4' and x['name'][0] == 0: out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian) kp_emo.append(out) fakes.append(fake) # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian)) elif self.train_params['type'] == 'linear_10' and x['name'][0] == 0: # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) out, fake = self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian) kp_emo.append(out) fakes.append(fake) elif self.train_params['type'] == 'linear_4_new' and x['name'][0] == 0: # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) out, fake = self.emo_feature.linear_4(x['transformed_driving'][:,i],value,jacobian) kp_emo.append(out) fakes.append(fake) elif self.train_params['type'] == 'linear_np_4': # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) out, fake = self.emo_feature.linear_np_4(x['transformed_driving'][:,i],value,jacobian) kp_emo.append(out) fakes.append(fake) elif self.train_params['type'] == 'linear_np_10': # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)) out, fake = self.emo_feature.linear_np_10(x['transformed_driving'][:,i],value,jacobian) kp_emo.append(out) fakes.append(fake) # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian)) # print('Kp_audio_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a')) loss_perceptual = 0 kp_all = kp_driving_a if self.train_params['smooth'] == True: value_all = torch.randn(len(kp_driving),out['value'].shape[0],out['value'].shape[1],out['value'].shape[2]).cuda() jacobian_all = torch.randn(len(kp_driving),out['jacobian'].shape[0],out['jacobian'].shape[1],2,2).cuda() print(len(kp_driving)) for i in range(len(kp_driving)): # if x['name'][i] == 'LRW': # loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian']).mean())*self.loss_weights['emo'] # loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']).mean())*self.loss_weights['emo'] # loss_classify += self.mse_loss_fn(deco_out,deco_out) if self.train_params['type'] == 'linear_4' and x['name'][0] == 0: kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1] kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4] kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6] kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8] kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1] kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + kp_driving_a[i]['value'][:,4] kp_all[i]['value'][:,6] = kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6] kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8] # kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value'] if self.train_params['smooth'] == True: loss_smooth = 0 loss_smooth += (torch.abs(value_all[2:,:,:,:] + value_all[:-2,:,:,:].detach() -2*value_all[1:-1,:,:,:].detach()).mean())*self.loss_weights['emo'] *100 loss_smooth += (torch.abs(jacobian_all[2:,:,:,:] + jacobian_all[:-2,:,:,:].detach() -2*jacobian_all[1:-1,:,:,:].detach()).mean())*self.loss_weights['emo'] *100 loss_values['loss_smooth'] = loss_smooth/len(kp_driving) else: loss_values['loss_smooth'] = self.mse_loss_fn(deco_out,deco_out) if self.train_params['generator'] == 'not': loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out) for i in range(1): #0,len(kp_driving),4 generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i]) generated.update({'kp_source': kp_source, 'kp_driving': kp_all}) elif self.train_params['generator'] == 'visual': for i in range(0,len(kp_driving),4): #0,len(kp_driving),4 generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i]) generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) pyramide_real = self.pyramid(x['driving'][:,i]) pyramide_generated = self.pyramid(generated['prediction']) if sum(self.loss_weights['perceptual']) != 0: value_total = 0 for scale in self.scales: x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) for i, weight in enumerate(self.loss_weights['perceptual']): value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() value_total += self.loss_weights['perceptual'][i] * value loss_perceptual += value_total length = int((len(kp_driving)-1)/4)+1 loss_values['perceptual'] = loss_perceptual/length elif self.train_params['generator'] == 'audio': for i in range(0,len(kp_driving),4): #0,len(kp_driving),4 generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i]) generated.update({'kp_source': kp_source, 'kp_driving': kp_all}) pyramide_real = self.pyramid(x['driving'][:,i]) pyramide_generated = self.pyramid(generated['prediction']) # loss_mse = nn.MSELoss(generated['prediction'],x['driving'][:,i]) if sum(self.loss_weights['perceptual']) != 0: value_total = 0 for scale in self.scales: x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) for i, weight in enumerate(self.loss_weights['perceptual']): value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() value_total += self.loss_weights['perceptual'][i] * value loss_perceptual += value_total length = int((len(kp_driving)-1)/4)+1 loss_values['perceptual'] = loss_perceptual/length # loss_values['mse'] = loss_mse/length else: print('wrong train_params: ', self.train_params['generator']) return loss_values,generated class GeneratorFullModel(torch.nn.Module): """ Merge all generator related updates into single model for better multi-gpu usage """ def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params): super(GeneratorFullModel, self).__init__() self.kp_extractor = kp_extractor self.kp_extractor_a = kp_extractor_a # self.content_encoder = content_encoder # self.emotion_encoder = emotion_encoder self.audio_feature = audio_feature self.generator = generator self.discriminator = discriminator self.train_params = train_params self.scales = train_params['scales'] self.disc_scales = self.discriminator.scales self.pyramid = ImagePyramide(self.scales, generator.num_channels) if torch.cuda.is_available(): self.pyramid = self.pyramid.cuda() self.loss_weights = train_params['loss_weights'] if sum(self.loss_weights['perceptual']) != 0: self.vgg = Vgg19() if torch.cuda.is_available(): self.vgg = self.vgg.cuda() self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda() self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda() def forward(self, x): # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[]) # source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1))) # kp_source = self.kp_extractor(x['source']) # kp_source_a = self.kp_extractor_a(x['source'], x['source_cube'], source_a_f) # driving_a_f = self.audio_feature(self.content_encoder(x['driving_audio'].unsqueeze(1)), self.emotion_encoder(x['driving_audio'].unsqueeze(1))) # driving_a_f = self.audio_feature(x['driving_audio']) # kp_driving = self.kp_extractor(x['driving']) # kp_driving_a = self.kp_extractor_a(x['driving'], x['driving_cube'], driving_a_f) kp_driving = [] for i in range(16): kp_driving.append(self.kp_extractor(x['driving'][:,i],x['driving_landmark'][:,i],self.loss_weights['equivariance_value'])) kp_driving_a = [] fc_out, deco_out = self.audio_feature(x['example_landmark'], x['driving_audio'], x['driving_pose']) fake_lmark=fc_out + x['example_landmark'].expand_as(fc_out) fake_lmark = torch.mm( fake_lmark, self.pca.t() ) fake_lmark = fake_lmark + self.mean.expand_as(fake_lmark) fake_lmark = fake_lmark.unsqueeze(0) # for i in range(16): # kp_driving_a.append() # generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving) # generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) loss_values = {} pyramide_real = self.pyramid(x['driving']) pyramide_generated = self.pyramid(generated['prediction']) if self.loss_weights['audio'] != 0: value = torch.abs(kp_source['jacobian'].detach() - kp_source_a['jacobian'].detach()).mean() + torch.abs(kp_driving['jacobian'].detach() - kp_driving_a['jacobian']).mean() value = value/2 loss_values['jacobian'] = value*self.loss_weights['audio'] value = torch.abs(kp_source['heatmap'].detach() - kp_source_a['heatmap'].detach()).mean() + torch.abs(kp_driving['heatmap'].detach() - kp_driving_a['heatmap']).mean() value = value/2 loss_values['heatmap'] = value*self.loss_weights['audio'] value = torch.abs(kp_source['value'].detach() - kp_source_a['value'].detach()).mean() + torch.abs(kp_driving['value'].detach() - kp_driving_a['value']).mean() value = value/2 loss_values['value'] = value*self.loss_weights['audio'] if sum(self.loss_weights['perceptual']) != 0: value_total = 0 for scale in self.scales: x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) for i, weight in enumerate(self.loss_weights['perceptual']): value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() value_total += self.loss_weights['perceptual'][i] * value loss_values['perceptual'] = value_total if self.loss_weights['generator_gan'] != 0: discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) value_total = 0 for scale in self.disc_scales: key = 'prediction_map_%s' % scale value = ((1 - discriminator_maps_generated[key]) ** 2).mean() value_total += self.loss_weights['generator_gan'] * value loss_values['gen_gan'] = value_total if sum(self.loss_weights['feature_matching']) != 0: value_total = 0 for scale in self.disc_scales: key = 'feature_maps_%s' % scale for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])): if self.loss_weights['feature_matching'][i] == 0: continue value = torch.abs(a - b).mean() value_total += self.loss_weights['feature_matching'][i] * value loss_values['feature_matching'] = value_total if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0: transform = Transform(x['driving'].shape[0], **self.train_params['transform_params']) transformed_frame = transform.transform_frame(x['driving']) transformed_landmark = transform.inverse_warp_coordinates(x['driving_landmark']) transformed_kp = self.kp_extractor(transformed_frame) generated['transformed_frame'] = transformed_frame generated['transformed_kp'] = transformed_kp ## Value loss part if self.loss_weights['equivariance_value'] != 0: value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean() loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value ## jacobian loss part if self.loss_weights['equivariance_jacobian'] != 0: jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']), transformed_kp['jacobian']) normed_driving = torch.inverse(kp_driving['jacobian']) normed_transformed = jacobian_transformed value = torch.matmul(normed_driving, normed_transformed) eye = torch.eye(2).view(1, 1, 2, 2).type(value.type()) value = torch.abs(eye - value).mean() loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value return loss_values, generated class DiscriminatorFullModel(torch.nn.Module): """ Merge all discriminator related updates into single model for better multi-gpu usage """ def __init__(self, kp_extractor, generator, discriminator, train_params): super(DiscriminatorFullModel, self).__init__() self.kp_extractor = kp_extractor self.generator = generator self.discriminator = discriminator self.train_params = train_params self.scales = self.discriminator.scales self.pyramid = ImagePyramide(self.scales, generator.num_channels) if torch.cuda.is_available(): self.pyramid = self.pyramid.cuda() self.loss_weights = train_params['loss_weights'] def forward(self, x, generated): pyramide_real = self.pyramid(x['driving']) pyramide_generated = self.pyramid(generated['prediction'].detach()) kp_driving = generated['kp_driving'] discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) loss_values = {} value_total = 0 for scale in self.scales: key = 'prediction_map_%s' % scale value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2 value_total += self.loss_weights['discriminator_gan'] * value.mean() loss_values['disc_gan'] = value_total return loss_values ================================================ FILE: modules/ops.py ================================================ import torch import torchvision import torch.nn as nn import torch.nn.init as init from torch.autograd import Variable def linear(channel_in, channel_out, activation=nn.ReLU, normalizer=nn.BatchNorm1d): layer = list() bias = True if not normalizer else False layer.append(nn.Linear(channel_in, channel_out, bias=bias)) _apply(layer, activation, normalizer, channel_out) # init.kaiming_normal(layer[0].weight) return nn.Sequential(*layer) def conv2d(channel_in, channel_out, ksize=3, stride=1, padding=1, activation=nn.ReLU, normalizer=nn.BatchNorm2d): layer = list() bias = True if not normalizer else False layer.append(nn.Conv2d(channel_in, channel_out, ksize, stride, padding, bias=bias)) _apply(layer, activation, normalizer, channel_out) # init.kaiming_normal(layer[0].weight) return nn.Sequential(*layer) def conv_transpose2d(channel_in, channel_out, ksize=4, stride=2, padding=1, activation=nn.ReLU, normalizer=nn.BatchNorm2d): layer = list() bias = True if not normalizer else False layer.append(nn.ConvTranspose2d(channel_in, channel_out, ksize, stride, padding, bias=bias)) _apply(layer, activation, normalizer, channel_out) # init.kaiming_normal(layer[0].weight) return nn.Sequential(*layer) def nn_conv2d(channel_in, channel_out, ksize=3, stride=1, padding=1, scale_factor=2, activation=nn.ReLU, normalizer=nn.BatchNorm2d): layer = list() bias = True if not normalizer else False layer.append(nn.UpsamplingNearest2d(scale_factor=scale_factor)) layer.append(nn.Conv2d(channel_in, channel_out, ksize, stride, padding, bias=bias)) _apply(layer, activation, normalizer, channel_out) # init.kaiming_normal(layer[1].weight) return nn.Sequential(*layer) def _apply(layer, activation, normalizer, channel_out=None): if normalizer: layer.append(normalizer(channel_out)) if activation: layer.append(activation()) return layer ================================================ FILE: modules/stylegan2.py ================================================ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Thu Jul 8 01:03:50 2021 @author: thea """ """ The network architectures is based on PyTorch implemenation of StyleGAN2Encoder. Original PyTorch repo: https://github.com/rosinality/style-based-gan-pytorch Origianl StyelGAN2 paper: https://github.com/NVlabs/stylegan2 We use the network architeture for our single-image traning setting. """ import math import numpy as np import random import torch from torch import nn from torch.nn import functional as F def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): return F.leaky_relu(input + bias, negative_slope) * scale class FusedLeakyReLU(nn.Module): def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): super().__init__() self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) self.negative_slope = negative_slope self.scale = scale def forward(self, input): # print("FusedLeakyReLU: ", input.abs().mean()) out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) # print("FusedLeakyReLU: ", out.abs().mean()) return out def upfirdn2d_native( input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 ): _, minor, in_h, in_w = input.shape kernel_h, kernel_w = kernel.shape out = input.view(-1, minor, in_h, 1, in_w, 1) out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) out = out.view(-1, minor, in_h * up_y, in_w * up_x) out = F.pad( out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] ) out = out[ :, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ] # out = out.permute(0, 3, 1, 2) out = out.reshape( [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] ) w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) out = F.conv2d(out, w) out = out.reshape( -1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) # out = out.permute(0, 2, 3, 1) return out[:, :, ::down_y, ::down_x] def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) class PixelNorm(nn.Module): def __init__(self): super().__init__() def forward(self, input): return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) def make_kernel(k): k = torch.tensor(k, dtype=torch.float32) if len(k.shape) == 1: k = k[None, :] * k[:, None] k /= k.sum() return k class Upsample(nn.Module): def __init__(self, kernel, factor=2): super().__init__() self.factor = factor kernel = make_kernel(kernel) * (factor ** 2) self.register_buffer('kernel', kernel) p = kernel.shape[0] - factor pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 self.pad = (pad0, pad1) def forward(self, input): out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) return out class Downsample(nn.Module): def __init__(self, kernel, factor=2): super().__init__() self.factor = factor kernel = make_kernel(kernel) self.register_buffer('kernel', kernel) p = kernel.shape[0] - factor pad0 = (p + 1) // 2 pad1 = p // 2 self.pad = (pad0, pad1) def forward(self, input): out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) return out class Blur(nn.Module): def __init__(self, kernel, pad, upsample_factor=1): super().__init__() kernel = make_kernel(kernel) if upsample_factor > 1: kernel = kernel * (upsample_factor ** 2) self.register_buffer('kernel', kernel) self.pad = pad def forward(self, input): out = upfirdn2d(input, self.kernel, pad=self.pad) return out class EqualConv2d(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True ): super().__init__() self.weight = nn.Parameter( torch.randn(out_channel, in_channel, kernel_size, kernel_size) ) self.scale = math.sqrt(1) / math.sqrt(in_channel * (kernel_size ** 2)) self.stride = stride self.padding = padding if bias: self.bias = nn.Parameter(torch.zeros(out_channel)) else: self.bias = None def forward(self, input): # print("Before EqualConv2d: ", input.abs().mean()) out = F.conv2d( input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding, ) # print("After EqualConv2d: ", out.abs().mean(), (self.weight * self.scale).abs().mean()) return out def __repr__(self): return ( f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' ) class EqualLinear(nn.Module): def __init__( self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None ): super().__init__() self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) if bias: self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) else: self.bias = None self.activation = activation self.scale = (math.sqrt(1) / math.sqrt(in_dim)) * lr_mul self.lr_mul = lr_mul def forward(self, input): if self.activation: out = F.linear(input, self.weight * self.scale) out = fused_leaky_relu(out, self.bias * self.lr_mul) else: out = F.linear( input, self.weight * self.scale, bias=self.bias * self.lr_mul ) return out def __repr__(self): return ( f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' ) class ScaledLeakyReLU(nn.Module): def __init__(self, negative_slope=0.2): super().__init__() self.negative_slope = negative_slope def forward(self, input): out = F.leaky_relu(input, negative_slope=self.negative_slope) return out * math.sqrt(2) class ModulatedConv2d(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, style_dim, demodulate=True, upsample=False, downsample=False, blur_kernel=[1, 3, 3, 1], ): super().__init__() self.eps = 1e-8 self.kernel_size = kernel_size self.in_channel = in_channel self.out_channel = out_channel self.upsample = upsample self.downsample = downsample if upsample: factor = 2 p = (len(blur_kernel) - factor) - (kernel_size - 1) pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 + 1 self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 self.blur = Blur(blur_kernel, pad=(pad0, pad1)) fan_in = in_channel * kernel_size ** 2 self.scale = math.sqrt(1) / math.sqrt(fan_in) self.padding = kernel_size // 2 self.weight = nn.Parameter( torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) ) if style_dim is not None and style_dim > 0: self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) self.demodulate = demodulate def __repr__(self): return ( f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' f'upsample={self.upsample}, downsample={self.downsample})' ) def forward(self, input, style): batch, in_channel, height, width = input.shape if style is not None: style = self.modulation(style).view(batch, 1, in_channel, 1, 1) else: style = torch.ones(batch, 1, in_channel, 1, 1).cuda() weight = self.scale * self.weight * style if self.demodulate: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) weight = weight.view( batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size ) if self.upsample: input = input.view(1, batch * in_channel, height, width) weight = weight.view( batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size ) weight = weight.transpose(1, 2).reshape( batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size ) out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) out = self.blur(out) elif self.downsample: input = self.blur(input) _, _, height, width = input.shape input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) else: input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=self.padding, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) return out class NoiseInjection(nn.Module): def __init__(self): super().__init__() self.weight = nn.Parameter(torch.zeros(1)) def forward(self, image, noise=None): if noise is None: batch, _, height, width = image.shape noise = image.new_empty(batch, 1, height, width).normal_() return image + self.weight * noise class ConstantInput(nn.Module): def __init__(self, channel, size=4): super().__init__() self.input = nn.Parameter(torch.randn(1, channel, size, size)) def forward(self, input): batch = input.shape[0] out = self.input.repeat(batch, 1, 1, 1) return out class StyledConv(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, style_dim=None, upsample=False, blur_kernel=[1, 3, 3, 1], demodulate=True, inject_noise=False, #True ): super().__init__() self.inject_noise = inject_noise self.conv = ModulatedConv2d( in_channel, out_channel, kernel_size, style_dim, upsample=upsample, blur_kernel=blur_kernel, demodulate=demodulate, ) self.noise = NoiseInjection() # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) # self.activate = ScaledLeakyReLU(0.2) self.activate = FusedLeakyReLU(out_channel) def forward(self, input, style=None, noise=None): out = self.conv(input, style) if self.inject_noise: out = self.noise(out, noise=noise) # out = out + self.bias out = self.activate(out) return out class ToRGB(nn.Module): def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): super().__init__() if upsample: self.upsample = Upsample(blur_kernel) self.conv = ModulatedConv2d(in_channel, 3+32, 1, style_dim, demodulate=False) self.bias = nn.Parameter(torch.zeros(1, 3+32, 1, 1)) def forward(self, input, style, skip=None): out = self.conv(input, style) out = out + self.bias if skip is not None: skip = self.upsample(skip) out = out + skip return out class Generator(nn.Module): def __init__( self, size, style_dim, n_mlp, channel_multiplier=1, blur_kernel=[1, 3, 3, 1], lr_mlp=0.01, ): super().__init__() self.size = size self.style_dim = style_dim layers = [PixelNorm()] for i in range(n_mlp): layers.append( EqualLinear( style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' ) ) self.style = nn.Sequential(*layers) self.channels = { 4: 256, 8: 256, 16: 128, 32: 64, 64: 32 * channel_multiplier, 128: 16 * channel_multiplier, 256: 8 * channel_multiplier, 512: 4 * channel_multiplier, 1024: 2 * channel_multiplier, } self.input = ConstantInput(self.channels[4]) self.conv1 = StyledConv( self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel ) self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) self.log_size = int(math.log(size, 2)) self.num_layers = (self.log_size - 2) * 2 + 1 self.convs = nn.ModuleList() self.upsamples = nn.ModuleList() self.to_rgbs = nn.ModuleList() self.noises = nn.Module() in_channel = self.channels[4] for layer_idx in range(self.num_layers): res = (layer_idx + 5) // 2 shape = [1, 1, 2 ** res, 2 ** res] self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) for i in range(3, self.log_size + 1): out_channel = self.channels[2 ** i] self.convs.append( StyledConv( in_channel, out_channel, 3, style_dim, upsample=True, blur_kernel=blur_kernel, ) ) self.convs.append( StyledConv( out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel ) ) self.to_rgbs.append(ToRGB(out_channel, style_dim)) in_channel = out_channel self.n_latent = self.log_size * 2 - 2 def make_noise(self): device = self.input.input.device noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] for i in range(3, self.log_size + 1): for _ in range(2): noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) return noises def mean_latent(self, n_latent): latent_in = torch.randn( n_latent, self.style_dim, device=self.input.input.device ) latent = self.style(latent_in).mean(0, keepdim=True) return latent def get_latent(self, input): return self.style(input) def forward( self, styles, return_latents=False, inject_index=None, truncation=1, truncation_latent=None, input_is_latent=False, noise=None, randomize_noise=True, ): if not input_is_latent: styles = [self.style(s) for s in styles] if noise is None: if randomize_noise: noise = [None] * self.num_layers else: noise = [ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) ] if truncation < 1: style_t = [] for style in styles: style_t.append( truncation_latent + truncation * (style - truncation_latent) ) styles = style_t if len(styles) < 2: inject_index = self.n_latent if len(styles[0].shape) < 3: latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) else: latent = styles[0] else: if inject_index is None: inject_index = random.randint(1, self.n_latent - 1) latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) latent = torch.cat([latent, latent2], 1) # out = self.input(latent) out = styles[0].unsqueeze(-1).unsqueeze(-1).repeat(1,1,4,4) out = self.conv1(out, latent[:, 0], noise=noise[0]) skip = self.to_rgb1(out, latent[:, 1]) i = 1 for conv1, conv2, noise1, noise2, to_rgb in zip( self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs ): out = conv1(out, latent[:, i], noise=noise1) out = conv2(out, latent[:, i + 1], noise=noise2) skip = to_rgb(out, latent[:, i + 2], skip) i += 2 image = skip if return_latents: return image, latent else: return image, None class ConvLayer(nn.Sequential): def __init__( self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True, ): layers = [] if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 layers.append(Blur(blur_kernel, pad=(pad0, pad1))) stride = 2 self.padding = 0 else: stride = 1 self.padding = kernel_size // 2 layers.append( EqualConv2d( in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, bias=bias and not activate, ) ) if activate: if bias: layers.append(FusedLeakyReLU(out_channel)) else: layers.append(ScaledLeakyReLU(0.2)) super().__init__(*layers) class ResBlock(nn.Module): def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample=True, skip_gain=1.0): super().__init__() self.skip_gain = skip_gain self.conv1 = ConvLayer(in_channel, in_channel, 3) self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample, blur_kernel=blur_kernel) if in_channel != out_channel or downsample: self.skip = ConvLayer( in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False ) else: self.skip = nn.Identity() def forward(self, input): out = self.conv1(input) out = self.conv2(out) skip = self.skip(input) out = (out * self.skip_gain + skip) / math.sqrt(self.skip_gain ** 2 + 1.0) return out class StyleGAN2Discriminator(nn.Module): def __init__(self, input_nc, ndf=64, n_layers=3, no_antialias=False, size=None, opt=None): super().__init__() self.opt = opt self.stddev_group = 16 if size is None: size = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) if "patch" in self.opt.netD and self.opt.D_patch_size is not None: size = 2 ** int(np.log2(self.opt.D_patch_size)) blur_kernel = [1, 3, 3, 1] channel_multiplier = ndf / 64 channels = { 4: min(384, int(4096 * channel_multiplier)), 8: min(384, int(2048 * channel_multiplier)), 16: min(384, int(1024 * channel_multiplier)), 32: min(384, int(512 * channel_multiplier)), 64: int(256 * channel_multiplier), 128: int(128 * channel_multiplier), 256: int(64 * channel_multiplier), 512: int(32 * channel_multiplier), 1024: int(16 * channel_multiplier), } convs = [ConvLayer(3, channels[size], 1)] log_size = int(math.log(size, 2)) in_channel = channels[size] if "smallpatch" in self.opt.netD: final_res_log2 = 4 elif "patch" in self.opt.netD: final_res_log2 = 3 else: final_res_log2 = 2 for i in range(log_size, final_res_log2, -1): out_channel = channels[2 ** (i - 1)] convs.append(ResBlock(in_channel, out_channel, blur_kernel)) in_channel = out_channel self.convs = nn.Sequential(*convs) if False and "tile" in self.opt.netD: in_channel += 1 self.final_conv = ConvLayer(in_channel, channels[4], 3) if "patch" in self.opt.netD: self.final_linear = ConvLayer(channels[4], 1, 3, bias=False, activate=False) else: self.final_linear = nn.Sequential( EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), EqualLinear(channels[4], 1), ) def forward(self, input, get_minibatch_features=False): if "patch" in self.opt.netD and self.opt.D_patch_size is not None: h, w = input.size(2), input.size(3) y = torch.randint(h - self.opt.D_patch_size, ()) x = torch.randint(w - self.opt.D_patch_size, ()) input = input[:, :, y:y + self.opt.D_patch_size, x:x + self.opt.D_patch_size] out = input for i, conv in enumerate(self.convs): out = conv(out) # print(i, out.abs().mean()) # out = self.convs(input) batch, channel, height, width = out.shape if False and "tile" in self.opt.netD: group = min(batch, self.stddev_group) stddev = out.view( group, -1, 1, channel // 1, height, width ) stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2) stddev = stddev.repeat(group, 1, height, width) out = torch.cat([out, stddev], 1) out = self.final_conv(out) # print(out.abs().mean()) if "patch" not in self.opt.netD: out = out.view(batch, -1) out = self.final_linear(out) return out class TileStyleGAN2Discriminator(StyleGAN2Discriminator): def forward(self, input): B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3) size = self.opt.D_patch_size Y = H // size X = W // size input = input.view(B, C, Y, size, X, size) input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size) return super().forward(input) class StyleGAN2Encoder(nn.Module): def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None): super().__init__() assert opt is not None self.opt = opt channel_multiplier = ngf / 32 channels = { 4: min(512, int(round(4096 * channel_multiplier))), 8: min(512, int(round(2048 * channel_multiplier))), 16: min(512, int(round(1024 * channel_multiplier))), 32: min(512, int(round(512 * channel_multiplier))), 64: int(round(256 * channel_multiplier)), 128: int(round(128 * channel_multiplier)), 256: int(round(64 * channel_multiplier)), 512: int(round(32 * channel_multiplier)), 1024: int(round(16 * channel_multiplier)), } blur_kernel = [1, 3, 3, 1] cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) convs = [nn.Identity(), ConvLayer(3, channels[cur_res], 1)] num_downsampling = self.opt.stylegan2_G_num_downsampling for i in range(num_downsampling): in_channel = channels[cur_res] out_channel = channels[cur_res // 2] convs.append(ResBlock(in_channel, out_channel, blur_kernel, downsample=True)) cur_res = cur_res // 2 for i in range(n_blocks // 2): n_channel = channels[cur_res] convs.append(ResBlock(n_channel, n_channel, downsample=False)) self.convs = nn.Sequential(*convs) def forward(self, input, layers=[], get_features=False): feat = input feats = [] if -1 in layers: layers.append(len(self.convs) - 1) for layer_id, layer in enumerate(self.convs): feat = layer(feat) # print(layer_id, " features ", feat.abs().mean()) if layer_id in layers: feats.append(feat) if get_features: return feat, feats else: return feat class StyleGAN2Decoder(nn.Module): def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None): super().__init__() assert opt is not None self.opt = opt blur_kernel = [1, 3, 3, 1] channel_multiplier = ngf / 32 channels = { 4: min(512, int(round(4096 * channel_multiplier))), 8: min(512, int(round(2048 * channel_multiplier))), 16: min(512, int(round(1024 * channel_multiplier))), 32: min(512, int(round(512 * channel_multiplier))), 64: int(round(256 * channel_multiplier)), 128: int(round(128 * channel_multiplier)), 256: int(round(64 * channel_multiplier)), 512: int(round(32 * channel_multiplier)), 1024: int(round(16 * channel_multiplier)), } num_downsampling = self.opt.stylegan2_G_num_downsampling cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) // (2 ** num_downsampling) convs = [] for i in range(n_blocks // 2): n_channel = channels[cur_res] convs.append(ResBlock(n_channel, n_channel, downsample=False)) for i in range(num_downsampling): in_channel = channels[cur_res] out_channel = channels[cur_res * 2] inject_noise = "small" not in self.opt.netG convs.append( StyledConv(in_channel, out_channel, 3, upsample=True, blur_kernel=blur_kernel, inject_noise=inject_noise) ) cur_res = cur_res * 2 convs.append(ConvLayer(channels[cur_res], 3, 1)) self.convs = nn.Sequential(*convs) def forward(self, input): return self.convs(input) class StyleGAN2Generator(nn.Module): def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None): super().__init__() self.opt = opt self.encoder = StyleGAN2Encoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt) self.decoder = StyleGAN2Decoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt) def forward(self, input, layers=[], encode_only=False): feat, feats = self.encoder(input, layers, True) if encode_only: return feats else: fake = self.decoder(feat) if len(layers) > 0: return fake, feats else: return fake ================================================ FILE: modules/util.py ================================================ from torch import nn import torch.nn.functional as F import torch import numpy as np import cv2 from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d from modules.stylegan2 import Generator import torch.nn as nn import math import torch.utils.model_zoo as model_zoo from modules.function import adaptive_instance_normalization as adain import pdb # Misc img2mse = lambda x, y : torch.mean((x - y) ** 2) mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) class InstanceNorm(nn.Module): def __init__(self, epsilon=1e-8): """ @notice: avoid in-place ops. https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3 """ super(InstanceNorm, self).__init__() self.epsilon = epsilon def forward(self, x): x = x - torch.mean(x, (2, 3), True) tmp = torch.mul(x, x) # or x ** 2 tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon) return x * tmp class ApplyStyle(nn.Module): """ @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb """ def __init__(self, latent_size, channels, use_wscale): super(ApplyStyle, self).__init__() self.linear = FC(latent_size, channels * 2, gain=1.0, use_wscale=use_wscale) def forward(self, x, latent): style = self.linear(latent) # style => [batch_size, n_channels*2] shape = [-1, 2, x.size(1), 1, 1] style = style.view(shape) # [batch_size, 2, n_channels, ...] x = x * (style[:, 0] + 1.) + style[:, 1] return x class FC(nn.Module): def __init__(self, in_channels, out_channels, gain=2**(0.5), use_wscale=False, lrmul=1.0, bias=True): """ The complete conversion of Dense/FC/Linear Layer of original Tensorflow version. """ super(FC, self).__init__() he_std = gain * in_channels ** (-0.5) # He init if use_wscale: init_std = 1.0 / lrmul self.w_lrmul = he_std * lrmul else: init_std = he_std / lrmul self.w_lrmul = lrmul self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels) * init_std) if bias: self.bias = torch.nn.Parameter(torch.zeros(out_channels)) self.b_lrmul = lrmul else: self.bias = None def forward(self, x): if self.bias is not None: out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul) else: out = F.linear(x, self.weight * self.w_lrmul) out = F.leaky_relu(out, 0.2, inplace=True) return out # Positional encoding (section 5.1) class Embedder: def __init__(self, **kwargs): self.kwargs = kwargs self.create_embedding_fn() def create_embedding_fn(self): embed_fns = [] d = self.kwargs['input_dims'] out_dim = 0 if self.kwargs['include_input']: embed_fns.append(lambda x : x) out_dim += d max_freq = self.kwargs['max_freq_log2'] N_freqs = self.kwargs['num_freqs'] if self.kwargs['log_sampling']: freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) else: freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) for freq in freq_bands: for p_fn in self.kwargs['periodic_fns']: embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) out_dim += d self.embed_fns = embed_fns self.out_dim = out_dim def embed(self, inputs): return torch.cat([fn(inputs) for fn in self.embed_fns], -1) def get_embedder(multires, i=0): if i == -1: return nn.Identity(), 6 embed_kwargs = { 'include_input' : True, 'input_dims' : 6, 'max_freq_log2' : multires-1, 'num_freqs' : multires, 'log_sampling' : True, 'periodic_fns' : [torch.sin, torch.cos], } embedder_obj = Embedder(**embed_kwargs) embed = lambda x, eo=embedder_obj : eo.embed(x) return embed, embedder_obj.out_dim def draw_heatmap(landmark, width, height): batch = landmark.shape[0] number = landmark.shape[1] heatmap = np.zeros((batch, number,width, height), dtype=np.float32) # draw mouth from mouth landmarks, landmarks: mouth landmark points, format: x1, y1, x2, y2, ..., x20, landmark = (landmark+1)*29 for i in range(batch): for pts_idx in range(number): if int(landmark[i,pts_idx,0])<0: landmark[i,pts_idx,0] = 0 if int(landmark[i,pts_idx,1])<0: landmark[i,pts_idx,1] = 0 if int(landmark[i,pts_idx,0])>57: landmark[i,pts_idx,0] = 57 if int(landmark[i,pts_idx,1])>57: landmark[i,pts_idx,1] = 57 heatmap[i,pts_idx, int(landmark[i,pts_idx,1]), int(landmark[i,pts_idx,0])]=1 if heatmap[i,pts_idx].sum()== 1 : heatmap[i,pts_idx] = cv2.GaussianBlur(heatmap[i,pts_idx], ksize=(3, 3), sigmaX=1, sigmaY=1) heatmap = torch.tensor(heatmap).cuda() return heatmap class NA_net(nn.Module): def __init__(self): super(NA_net, self).__init__() self.decon = nn.Sequential( nn.ConvTranspose2d(1, 16, kernel_size=(2,3), stride=2, padding=(2,1), bias=True),#16,16 nn.BatchNorm2d(16), nn.ReLU(True), nn.ConvTranspose2d(16, 32, kernel_size=4, stride=2, padding=1, bias=True),#8,8 nn.BatchNorm2d(32), nn.ReLU(True), nn.ConvTranspose2d(32, 32+3, kernel_size=4, stride=2, padding=1, bias=True)#16,16 ) def forward(self, neutral): feature = neutral.unsqueeze(1) current_feature = self.decon(feature) return current_feature class AT_net(nn.Module): def __init__(self): super(AT_net, self).__init__() down_blocks = [] for i in range(8): down_blocks.append(DownBlock2d(3 if i == 0 else 2 * (2 ** i), 2 * (2 ** (i + 1)), kernel_size=3, padding=1)) self.down_blocks = nn.ModuleList(down_blocks) # self.lmark_encoder = nn.Sequential( # nn.Linear(16,256), # nn.ReLU(True), # nn.Linear(256,512), # nn.ReLU(True), # ) self.pose_encoder = nn.Sequential( nn.Linear(6,128), nn.ReLU(True), nn.Linear(128,256), nn.ReLU(True), ) self.audio_eocder = nn.Sequential( conv2d(1,64,3,1,1), conv2d(64,128,3,1,1), nn.MaxPool2d(3, stride=(1,2)), conv2d(128,256,3,1,1), conv2d(256,256,3,1,1), conv2d(256,512,3,1,1), nn.MaxPool2d(3, stride=(2,2)) ) self.audio_eocder_fc = nn.Sequential( nn.Linear(1024 *12,2048), nn.ReLU(True), nn.Linear(2048,256), nn.ReLU(True), ) self.lstm = nn.LSTM(256*4,256,3,batch_first = True) # self.lstm_fc = nn.Sequential( # nn.Linear(256,16), # ) self.decon = nn.Sequential( nn.ConvTranspose2d(256, 256, kernel_size=6, stride=2, padding=1, bias=True),#4,4 nn.BatchNorm2d(256), nn.ReLU(True), nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True), #16,16 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),#32,32 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 32+3, kernel_size=4, stride=2, padding=1, bias=True),#64,64 # nn.ConvTranspose2d(128, 32*4, kernel_size=2, stride=2, padding=3, bias=True),#64,64 ) self.generator = Generator(64,256,8) def forward(self, example_image, audio, pose, jaco_net): hidden = ( torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()), torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda())) outs = example_image for down_block in self.down_blocks: outs = down_block(outs) image_feature = outs image_feature = image_feature.view(image_feature.shape[0], -1) lstm_input = [] for step_t in range(audio.size(1)): current_audio = audio[ : ,step_t , :, :].unsqueeze(1) current_feature = self.audio_eocder(current_audio) current_feature = current_feature.view(current_feature.size(0), -1) current_feature = self.audio_eocder_fc(current_feature) pose_f = self.pose_encoder(pose[:,step_t]) features = torch.cat([image_feature, current_feature, pose_f], 1) lstm_input.append(features) lstm_input = torch.stack(lstm_input, dim = 1) lstm_out, _ = self.lstm(lstm_input, hidden) fc_out = [] deco_out = [] for step_t in range(audio.size(1)): fc_in = lstm_out[:,step_t,:] # fc_out.append(self.lstm_fc(fc_in)) if jaco_net == 'cnn': fc_feature = torch.unsqueeze(fc_in,2) fc_feature = torch.unsqueeze(fc_feature,3) deco_out.append(self.decon(fc_feature)) elif jaco_net == 'gan': result,_ = self.generator([fc_in]) deco_out.append(result) else: raise Exception("jaco_net type wrong") return torch.stack(deco_out,dim=1) class Classify(nn.Module): def __init__(self): super(Classify, self).__init__() self.last_fc = nn.Linear(512,8) def forward(self, feature): # mfcc= torch.unsqueeze(mfcc, 1) x = self.last_fc(feature) return x class TF_net(nn.Module): def __init__(self): super(TF_net, self).__init__() down_blocks = [] for i in range(8): down_blocks.append(DownBlock2d(3 if i == 0 else 2 * (2 ** i), 2 * (2 ** (i + 1)), kernel_size=3, padding=1)) self.down_blocks = nn.ModuleList(down_blocks) # self.lmark_encoder = nn.Sequential( # nn.Linear(16,256), # nn.ReLU(True), # nn.Linear(256,512), # nn.ReLU(True), # ) self.pose_encoder = nn.Sequential( nn.Linear(6,128), nn.ReLU(True), nn.Linear(128,256), nn.ReLU(True), ) self.audio_eocder = nn.Sequential( conv2d(1,64,3,1,1), conv2d(64,128,3,1,1), nn.MaxPool2d(3, stride=(1,2)), conv2d(128,256,3,1,1), conv2d(256,256,3,1,1), conv2d(256,512,3,1,1), nn.MaxPool2d(3, stride=(2,2)) ) self.audio_eocder_fc = nn.Sequential( nn.Linear(1024 *12,2048), nn.ReLU(True), nn.Linear(2048,256), nn.ReLU(True), ) self.lstm = nn.LSTM(256*4,256,3,batch_first = True) self.lstm_two = nn.LSTM(256*6,256,3,batch_first = True) # self.lstm_fc = nn.Sequential( # nn.Linear(256,16), # ) self.decon = nn.Sequential( nn.ConvTranspose2d(256, 256, kernel_size=6, stride=2, padding=1, bias=True),#4,4 nn.BatchNorm2d(256), nn.ReLU(True), nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True), #16,16 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),#32,32 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 32+3, kernel_size=4, stride=2, padding=1, bias=True),#64,64 # nn.ConvTranspose2d(128, 32*4, kernel_size=2, stride=2, padding=3, bias=True),#64,64 ) self.generator = Generator(64,256,8) self.instance_norm = InstanceNorm() self.style_mod = ApplyStyle(512, 1024, use_wscale=True) self.style_mod1 = ApplyStyle(512, 35, use_wscale=True) def adain_forward(self, example_image, audio, pose, jaco_net, emo_features): hidden = ( torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()), torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda())) outs = example_image for down_block in self.down_blocks: outs = down_block(outs) image_feature = outs image_feature = image_feature.view(image_feature.shape[0], -1) lstm_input = [] for step_t in range(audio.size(1)): current_audio = audio[ : ,step_t , :, :].unsqueeze(1) current_feature = self.audio_eocder(current_audio) current_feature = current_feature.view(current_feature.size(0), -1) current_feature = self.audio_eocder_fc(current_feature) #256 pose_f = self.pose_encoder(pose[:,step_t]) #256 features = torch.cat([image_feature, current_feature, pose_f], 1) features = torch.unsqueeze(torch.unsqueeze(features,-1),-1) features = self.instance_norm(features) x = self.style_mod(features, emo_features[step_t]) # t = adain(torch.unsqueeze(torch.unsqueeze(features,-1),-1), torch.unsqueeze(torch.unsqueeze(emo_features[step_t],1),2)) lstm_input.append(torch.squeeze(torch.squeeze(x,-1),-1)) lstm_input = torch.stack(lstm_input, dim = 1) lstm_out, _ = self.lstm(lstm_input, hidden) # fc_out = [] deco_out = [] for step_t in range(audio.size(1)): fc_in = lstm_out[:,step_t,:] # fc_out.append(self.lstm_fc(fc_in)) if jaco_net == 'cnn': fc_feature = torch.unsqueeze(fc_in,2) fc_feature = torch.unsqueeze(fc_feature,3) deco_out.append(self.decon(fc_feature)) elif jaco_net == 'gan': result,_ = self.generator([fc_in]) deco_out.append(result) else: raise Exception("jaco_net type wrong") return torch.stack(deco_out,dim=1) def adain_feature2(self, example_image, audio, pose, jaco_net, emo_features): hidden = ( torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()), torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda())) outs = example_image for down_block in self.down_blocks: outs = down_block(outs) image_feature = outs image_feature = image_feature.view(image_feature.shape[0], -1) lstm_input = [] for step_t in range(audio.size(1)): current_audio = audio[ : ,step_t , :, :].unsqueeze(1) current_feature = self.audio_eocder(current_audio) current_feature = current_feature.view(current_feature.size(0), -1) current_feature = self.audio_eocder_fc(current_feature) #256 pose_f = self.pose_encoder(pose[:,step_t]) #256 features = torch.cat([image_feature, current_feature, pose_f], 1) lstm_input.append(features) lstm_input = torch.stack(lstm_input, dim = 1) lstm_out, _ = self.lstm(lstm_input, hidden) # fc_out = [] deco_out = [] for step_t in range(audio.size(1)): fc_in = lstm_out[:,step_t,:] # fc_out.append(self.lstm_fc(fc_in)) if jaco_net == 'cnn': fc_feature = torch.unsqueeze(fc_in,2) fc_feature = torch.unsqueeze(fc_feature,3) fc_feature = self.decon(fc_feature) fc_feature = self.instance_norm(fc_feature) t = self.style_mod1(fc_feature, emo_features[step_t]) # emo_feature = torch.unsqueeze(torch.unsqueeze(emo_features[step_t],-1),-1) # emo_feature = emo_feature.repeat(1,fc_feature.shape[1],1,1) # t = adain(fc_feature, emo_feature) deco_out.append(t) elif jaco_net == 'gan': result,_ = self.generator([fc_in]) deco_out.append(result) else: raise Exception("jaco_net type wrong") return torch.stack(deco_out,dim=1) def forward(self, example_image, audio, pose, jaco_net, emo_features): hidden = ( torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()), torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda())) outs = example_image for down_block in self.down_blocks: outs = down_block(outs) image_feature = outs image_feature = image_feature.view(image_feature.shape[0], -1) lstm_input = [] for step_t in range(audio.size(1)): current_audio = audio[ : ,step_t , :, :].unsqueeze(1) current_feature = self.audio_eocder(current_audio) current_feature = current_feature.view(current_feature.size(0), -1) current_feature = self.audio_eocder_fc(current_feature) #256 pose_f = self.pose_encoder(pose[:,step_t]) #256 features = torch.cat([image_feature, current_feature, pose_f, emo_features[step_t]], 1) lstm_input.append(features) lstm_input = torch.stack(lstm_input, dim = 1) lstm_out, _ = self.lstm_two(lstm_input, hidden) fc_out = [] deco_out = [] for step_t in range(audio.size(1)): fc_in = lstm_out[:,step_t,:] # fc_out.append(self.lstm_fc(fc_in)) if jaco_net == 'cnn': fc_feature = torch.unsqueeze(fc_in,2) fc_feature = torch.unsqueeze(fc_feature,3) deco_out.append(self.decon(fc_feature)) elif jaco_net == 'gan': result,_ = self.generator([fc_in]) deco_out.append(result) else: raise Exception("jaco_net type wrong") return torch.stack(deco_out,dim=1) class AT_net2(nn.Module): def __init__(self): super(AT_net2, self).__init__() down_blocks = [] for i in range(8): down_blocks.append(DownBlock2d(3 if i == 0 else 2 * (2 ** i), 2 * (2 ** (i + 1)), kernel_size=3, padding=1)) self.down_blocks = nn.ModuleList(down_blocks) # self.lmark_encoder = nn.Sequential( # nn.Linear(16,256), # nn.ReLU(True), # nn.Linear(256,512), # nn.ReLU(True), # ) self.pose_encoder = nn.Sequential( nn.Linear(6,128), nn.ReLU(True), nn.Linear(128,256), nn.ReLU(True), ) self.audio_eocder = nn.Sequential( conv2d(1,64,3,1,1), conv2d(64,128,3,1,1), nn.MaxPool2d(3, stride=(1,2)), conv2d(128,256,3,1,1), conv2d(256,256,3,1,1), conv2d(256,512,3,1,1), nn.MaxPool2d(3, stride=(2,2)) ) self.audio_eocder_fc = nn.Sequential( nn.Linear(1024 *12,2048), nn.ReLU(True), nn.Linear(2048,256), nn.ReLU(True), ) self.lstm = nn.LSTM(256*4,256,3,batch_first = True) # self.lstm_fc = nn.Sequential( # nn.Linear(256,16), # ) self.decon = nn.Sequential( nn.ConvTranspose2d(256, 256, kernel_size=6, stride=2, padding=1, bias=True),#4,4 nn.BatchNorm2d(256), nn.ReLU(True), nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True), #16,16 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),#32,32 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 32+3, kernel_size=4, stride=2, padding=1, bias=True),#64,64 # nn.ConvTranspose2d(128, 32*4, kernel_size=2, stride=2, padding=3, bias=True),#64,64 ) self.generator = Generator(64,256,8) def forward(self, example_image, audio, pose, jaco_net, weight): hidden = ( torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()), torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda())) outs = example_image for down_block in self.down_blocks: outs = down_block(outs) image_feature = outs image_feature = image_feature.view(image_feature.shape[0], -1) lstm_input = [] for step_t in range(audio.size(1)): current_audio = audio[ : ,step_t , :, :].unsqueeze(1) current_feature = self.audio_eocder(current_audio) current_feature = current_feature.view(current_feature.size(0), -1) current_feature = self.audio_eocder_fc(current_feature)*weight pose_f = self.pose_encoder(pose[:,step_t]) features = torch.cat([image_feature, current_feature, pose_f], 1) lstm_input.append(features) lstm_input = torch.stack(lstm_input, dim = 1) lstm_out, _ = self.lstm(lstm_input, hidden) fc_out = [] deco_out = [] for step_t in range(audio.size(1)): fc_in = lstm_out[:,step_t,:] # fc_out.append(self.lstm_fc(fc_in)) if jaco_net == 'cnn': fc_feature = torch.unsqueeze(fc_in,2) fc_feature = torch.unsqueeze(fc_feature,3) deco_out.append(self.decon(fc_feature)) elif jaco_net == 'gan': result,_ = self.generator([fc_in]) deco_out.append(result) else: raise Exception("jaco_net type wrong") return torch.stack(deco_out,dim=1) class Ct_encoder(nn.Module): def __init__(self): super(Ct_encoder, self).__init__() self.audio_eocder = nn.Sequential( conv2d(1,64,3,1,1), conv2d(64,128,3,1,1), nn.MaxPool2d(3, stride=(1,2)), conv2d(128,256,3,1,1), conv2d(256,256,3,1,1), conv2d(256,512,3,1,1), nn.MaxPool2d(3, stride=(2,2)) ) self.audio_eocder_fc = nn.Sequential( nn.Linear(1024 *12,2048), nn.ReLU(True), nn.Linear(2048,256), nn.ReLU(True), ) def forward(self, audio): feature = self.audio_eocder(audio) feature = feature.view(feature.size(0),-1) x = self.audio_eocder_fc(feature) return x class EmotionNet(nn.Module): def __init__(self): super(EmotionNet, self).__init__() self.emotion_eocder = nn.Sequential( conv2d(1,64,3,1,1), nn.MaxPool2d((1,3), stride=(1,2)), #[1, 64, 12, 12] conv2d(64,128,3,1,1), conv2d(128,256,3,1,1), nn.MaxPool2d((12,1), stride=(12,1)), #[1, 256, 1, 12] conv2d(256,512,3,1,1), nn.MaxPool2d((1,2), stride=(1,2)) #[1, 512, 1, 6] ) self.emotion_eocder_fc = nn.Sequential( nn.Linear(512 *6,2048), nn.ReLU(True), nn.Linear(2048,128), nn.ReLU(True), ) self.last_fc = nn.Linear(128,8) self.re_id = nn.Sequential( conv2d(512,1024,3,1,1), nn.MaxPool2d((1,2), stride=(1,2)), #[1, 1024, 1, 3] conv2d(1024,1024,3,1,1), conv2d(1024,2048,3,1,1), nn.MaxPool2d((1,2), stride=(1,2)) #[1, 2048, 1, 1] ) self.re_id_fc = nn.Sequential( nn.Linear(2048,512), nn.ReLU(True), nn.Linear(512,128), nn.ReLU(True), ) def forward(self, mfcc): # mfcc= torch.unsqueeze(mfcc, 1) mfcc=torch.transpose(mfcc,2,3) feature = self.emotion_eocder(mfcc) # id_feature = feature.detach() feature = feature.view(feature.size(0),-1) x = self.emotion_eocder_fc(feature) # remove_feature = self.re_id(id_feature) # remove_feature = remove_feature.view(remove_feature.size(0),-1) # y = self.re_id_fc(remove_feature) return x class AF2F(nn.Module): def __init__(self): super(AF2F, self).__init__() self.decon = nn.Sequential( nn.ConvTranspose2d(384, 256, kernel_size=6, stride=2, padding=1, bias=True),#4,4 nn.BatchNorm2d(256), nn.ReLU(True), nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=True), #16,16 nn.BatchNorm2d(64), nn.ReLU(True), nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),#32,32 nn.BatchNorm2d(64), nn.ReLU(True), nn.ConvTranspose2d(64, 32+3, kernel_size=4, stride=2, padding=1, bias=True),#64,64 ) def forward(self, content,emotion): features = torch.cat([content, emotion], 1) #connect tensors inputs and dimension features = torch.unsqueeze(features,2) features = torch.unsqueeze(features,3) x = self.decon(features) return x class AF2F_s(nn.Module): def __init__(self): super(AF2F_s, self).__init__() self.decon = nn.Sequential( nn.ConvTranspose2d(256, 256, kernel_size=6, stride=2, padding=1, bias=True),#4,4 nn.BatchNorm2d(256), nn.ReLU(True), nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=True), #16,16 nn.BatchNorm2d(64), nn.ReLU(True), nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),#32,32 nn.BatchNorm2d(64), nn.ReLU(True), nn.ConvTranspose2d(64, 32+3, kernel_size=4, stride=2, padding=1, bias=True),#64,64 nn.ReLU(), ) def forward(self, content): # features = torch.cat([content, emotion], 1) #connect tensors inputs and dimension features = torch.unsqueeze(content,2) features = torch.unsqueeze(features,3) x = self.decon(features) return x class A2I(nn.Module): def __init__(self): super(A2I, self).__init__() self.audio_eocder = nn.Sequential( conv2d(1,64,3,1,1), conv2d(64,128,3,1,1), nn.MaxPool2d((1,5), stride=(1,2)), conv2d(128,256,3,1,1), conv2d(256,256,3,1,1), nn.MaxPool2d((5,5), stride=(2,2)) ) self.decon = nn.Sequential( nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=True), #16,16 nn.BatchNorm2d(64), nn.ReLU(True), nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=True),#32,32 nn.BatchNorm2d(32), nn.ReLU(True), nn.ConvTranspose2d(32, 2, kernel_size=4, stride=2, padding=1, bias=True),#64,64 nn.ReLU(), ) def forward(self, mfcc): mfcc= torch.unsqueeze(mfcc, 1) mfcc=torch.transpose(mfcc,2,3) feature = self.audio_eocder(mfcc) # id_feature = feature.detach() x = self.decon(feature) return x def kp2gaussian(kp, spatial_size, kp_variance): """ Transform a keypoint into gaussian like representation """ mean = kp['value'] #[4,10,2] coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) #[h,w,2] number_of_leading_dimensions = len(mean.shape) - 1 shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape #5 coordinate_grid = coordinate_grid.view(*shape) #[1,1,h,w,2] repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1) coordinate_grid = coordinate_grid.repeat(*repeats) #[4,10,h,w,2] # Preprocess kp shape shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2) mean = mean.view(*shape) #[4,10,1,1,2] mean_sub = (coordinate_grid - mean) out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) return out def make_coordinate_grid(spatial_size, type): """ Create a meshgrid [-1,1] x [-1,1] of given spatial_size. """ h, w = spatial_size x = torch.arange(w).type(type) y = torch.arange(h).type(type) x = (2 * (x / (w - 1)) - 1) y = (2 * (y / (h - 1)) - 1) yy = y.view(-1, 1).repeat(1, w) xx = x.view(1, -1).repeat(h, 1) meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) return meshed class ResBlock2d(nn.Module): """ Res block, preserve spatial resolution. """ def __init__(self, in_features, kernel_size, padding): super(ResBlock2d, self).__init__() self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding) self.norm1 = BatchNorm2d(in_features, affine=True) self.norm2 = BatchNorm2d(in_features, affine=True) def forward(self, x): out = self.norm1(x) out = F.relu(out) out = self.conv1(out) out = self.norm2(out) out = F.relu(out) out = self.conv2(out) out += x return out class UpBlock2d(nn.Module): """ Upsampling block for use in decoder. """ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): super(UpBlock2d, self).__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = BatchNorm2d(out_features, affine=True) def forward(self, x): out = F.interpolate(x, scale_factor=2) out = self.conv(out) out = self.norm(out) out = F.relu(out) return out class DownBlock2d(nn.Module): """ Downsampling block for use in encoder. """ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): super(DownBlock2d, self).__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = BatchNorm2d(out_features, affine=True) self.pool = nn.AvgPool2d(kernel_size=(2, 2)) def forward(self, x): out = self.conv(x) out = self.norm(out) out = F.relu(out) out = self.pool(out) return out class SameBlock2d(nn.Module): """ Simple block, preserve spatial resolution. """ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): super(SameBlock2d, self).__init__() self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = BatchNorm2d(out_features, affine=True) def forward(self, x): out = self.conv(x) out = self.norm(out) out = F.relu(out) return out class Encoder(nn.Module): """ Hourglass Encoder """ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): super(Encoder, self).__init__() down_blocks = [] for i in range(num_blocks): down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1)) self.down_blocks = nn.ModuleList(down_blocks) def forward(self, x): outs = [x] for down_block in self.down_blocks: outs.append(down_block(outs[-1])) return outs class Decoder(nn.Module): """ Hourglass Decoder """ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): super(Decoder, self).__init__() up_blocks = [] for i in range(num_blocks)[::-1]: in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) out_filters = min(max_features, block_expansion * (2 ** i)) up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1)) self.up_blocks = nn.ModuleList(up_blocks) self.out_filters = block_expansion + in_features def forward(self, x): out = x.pop() for up_block in self.up_blocks: out = up_block(out) skip = x.pop() out = torch.cat([out, skip], dim=1) return out class Hourglass(nn.Module): """ Hourglass architecture. """ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): super(Hourglass, self).__init__() self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) self.out_filters = self.decoder.out_filters def forward(self, x): return self.decoder(self.encoder(x)) class AntiAliasInterpolation2d(nn.Module): """ Band-limited downsampling, for better preservation of the input signal. """ def __init__(self, channels, scale): super(AntiAliasInterpolation2d, self).__init__() # sigma = (1 / scale - 1) / 2 sigma = 1.5 kernel_size = 2 * round(sigma * 4) + 1 self.ka = kernel_size // 2 self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka kernel_size = [kernel_size, kernel_size] sigma = [sigma, sigma] # The gaussian kernel is the product of the # gaussian function of each dimension. kernel = 1 meshgrids = torch.meshgrid( [ torch.arange(size, dtype=torch.float32) for size in kernel_size ] ) for size, std, mgrid in zip(kernel_size, sigma, meshgrids): mean = (size - 1) / 2 kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) # Make sure sum of values in gaussian kernel equals 1. kernel = kernel / torch.sum(kernel) # Reshape to depthwise convolutional weight kernel = kernel.view(1, 1, *kernel.size()) kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) self.register_buffer('weight', kernel) self.groups = channels self.scale = scale inv_scale = 1 / scale self.int_inv_scale = int(inv_scale) def forward(self, input): if self.scale == 1.0: return input out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) out = F.conv2d(out, weight=self.weight, groups=self.groups) out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale] return out def sigmoid(x): return 1 / (1 + math.exp(-x)) def norm_angle(angle): norm_angle = sigmoid(10 * (abs(angle) / 0.7853975 - 1)) return norm_angle def conv3x3(in_planes, out_planes, stride=1): "3x3 convolution with padding" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None): super(BasicBlock, self).__init__() self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU() self.conv2 = conv3x3(planes, planes) self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) self.relu = nn.ReLU() self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out = out + residual out = self.relu(out) return out class EmDetector(nn.Module): """ Detecting a keypoints. Return keypoint position and jacobian near each keypoint. """ def __init__(self, block_expansion, num_channels, max_features, num_blocks, scale_factor=1, num_classes=8): super(EmDetector, self).__init__() self.inplanes = 64 self.predictor = Hourglass(block_expansion, in_features=num_channels, max_features=max_features, num_blocks=num_blocks) self.scale_factor = scale_factor if self.scale_factor != 1: self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) self.conv1 = nn.Conv2d(self.predictor.out_filters, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU() self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) layers = [2,2,2,2] self.layer1 = self._make_layer(BasicBlock, 64, layers[0]) self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2) self.layer3 = self._make_layer(BasicBlock, 256, layers[2], stride=2) self.layer4 = self._make_layer(BasicBlock, 512, layers[3], stride=2) self.avgpool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes) self.classify = Classify() def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample)) self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(self.inplanes, planes)) return nn.Sequential(*layers) def adain_feature(self, x): #torch.Size([4, 3, H, W]) if self.scale_factor != 1: x = self.down(x) # 0.25 [4, 3, H/4, W/4] feature_map = self.predictor(x) #[4,3+32,H/4, W/4] # out = self.fc(out) return feature_map def forward(self, x): #torch.Size([4, 3, H, W]) if self.scale_factor != 1: x = self.down(x) # 0.25 [4, 3, H/4, W/4] feature_map = self.predictor(x) #[4,3+32,H/4, W/4] f = self.conv1(feature_map) #[16,64,64,64] f = self.bn1(f) #torch.Size([16, 64, 64, 64]) f = self.relu(f) f = self.maxpool(f) #[16, 64, 32, 32] f = self.layer1(f) #[16, 64, 32, 32] f = self.layer2(f) #[16, 128, 16, 16]) f = self.layer3(f) #[16, 256, 8, 8] f = self.layer4(f) #[16, 512, 4, 4] f = self.avgpool(f) #[16, 512, 1, 1] out = f.squeeze(3).squeeze(2) fake = self.classify(out) # out = self.fc(out) return out, fake class Emotion_k(nn.Module): """ Detecting a keypoints. Return keypoint position and jacobian near each keypoint. """ def __init__(self, block_expansion, num_channels, max_features, num_blocks, scale_factor=1, num_classes=8): super(Emotion_k, self).__init__() self.inplanes = 64 self.predictor = Hourglass(block_expansion, in_features=num_channels, max_features=max_features, num_blocks=num_blocks) self.scale_factor = scale_factor if self.scale_factor != 1: self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) self.conv1 = nn.Conv2d(self.predictor.out_filters, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU() self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) layers = [2,2,2,2] self.layer1 = self._make_layer(BasicBlock, 64, layers[0]) self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2) self.layer3 = self._make_layer(BasicBlock, 256, layers[2], stride=2) self.layer4 = self._make_layer(BasicBlock, 512, layers[3], stride=2) self.avgpool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes) self.embed_fn, self.input_ch = get_embedder(10, 0) self.fc_p = nn.Sequential( nn.Linear(10 * 126,1024), nn.ReLU(True), nn.Linear(1024,512), nn.ReLU(True), ) self.fc_n = nn.Sequential( nn.Linear(10 * 6,128), nn.ReLU(True), nn.Linear(128,512), nn.ReLU(True), ) self.fc_all = nn.Sequential( nn.Linear(1024,512), nn.ReLU(True), nn.Linear(512,256), nn.ReLU(True), nn.Linear(256,64), nn.ReLU(True), ) # self.fc_single = nn.Sequential( # nn.Linear(512,256), # nn.ReLU(True), # nn.Linear(256,64), # nn.ReLU(True), # ) self.final = nn.Sequential( nn.Conv1d(1,2,4,2,1), nn.MaxPool1d(2,stride=2), nn.ReLU(True), nn.Conv1d(2,4,4,2,1), nn.ReLU(True), nn.Conv1d(4,4,3), ) self.final_4 = nn.Sequential( nn.Conv1d(4,4,3,1,1), nn.MaxPool1d(2,stride=2), nn.ReLU(True), nn.Conv1d(4,4,3,1) ) self.final_10 = nn.Sequential( nn.Conv1d(4,8,3,1,1), #[B,8,16] nn.MaxPool1d(2,stride=2), #[B,8,8] nn.ReLU(True), nn.Conv1d(8,10,3,1), #[B,10,6] ) self.classify = Classify() def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample)) self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(self.inplanes, planes)) return nn.Sequential(*layers) def linear_10(self, x, value, jacobian): #torch.Size([4, 3, H, W]) if self.scale_factor != 1: x = self.down(x) # 0.25 [4, 3, H/4, W/4] feature_map = self.predictor(x) #[4,3+32,H/4, W/4] f = self.conv1(feature_map) #[16,64,64,64] f = self.bn1(f) #torch.Size([16, 64, 64, 64]) f = self.relu(f) f = self.maxpool(f) #[16, 64, 32, 32] f = self.layer1(f) #[16, 64, 32, 32] f = self.layer2(f) #[16, 128, 16, 16]) f = self.layer3(f) #[16, 256, 8, 8] f = self.layer4(f) #[16, 512, 4, 4] f = self.avgpool(f) #[16, 512, 1, 1] out = f.squeeze(3).squeeze(2) fake = self.classify(out) jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4) neu_input = torch.cat((value,jacobian),2) posi_input = self.embed_fn(neu_input) posi_input =posi_input.reshape(posi_input.shape[0],-1) ner_feature = self.fc_p(posi_input) all_fc = self.fc_all(torch.cat((out,ner_feature),1)).reshape(-1,4,16) result = self.final_10(all_fc) e_value = result[:,:,:2] e_jacobian = result[:,:,2:].reshape(result.shape[0],10,2,2) kp = {'value': e_value,'jacobian': e_jacobian} return kp, fake def linear_4(self, x, value, jacobian): #torch.Size([4, 3, H, W]) if self.scale_factor != 1: x = self.down(x) # 0.25 [4, 3, H/4, W/4] feature_map = self.predictor(x) #[4,3+32,H/4, W/4] f = self.conv1(feature_map) #[16,64,64,64] f = self.bn1(f) #torch.Size([16, 64, 64, 64]) f = self.relu(f) f = self.maxpool(f) #[16, 64, 32, 32] f = self.layer1(f) #[16, 64, 32, 32] f = self.layer2(f) #[16, 128, 16, 16]) f = self.layer3(f) #[16, 256, 8, 8] f = self.layer4(f) #[16, 512, 4, 4] f = self.avgpool(f) #[16, 512, 1, 1] out = f.squeeze(3).squeeze(2) fake = self.classify(out) # jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4) # neu_input = torch.cat((value,jacobian),2) # posi_input = self.embed_fn(neu_input) # posi_input =posi_input.reshape(posi_input.shape[0],-1) # ner_feature = self.fc_p(posi_input) # all_fc = self.fc_all(torch.cat((out,ner_feature),1)).reshape(-1,4,16) all_fc = torch.unsqueeze(self.fc_single(out),1) result = self.final(all_fc) e_value = result[:,:,:2] e_jacobian = result[:,:,2:].reshape(result.shape[0],4,2,2) kp = {'value': e_value,'jacobian': e_jacobian} # out = self.fc(out) return kp, fake def linear_np_10(self, x, value, jacobian): #torch.Size([4, 3, H, W]) if self.scale_factor != 1: x = self.down(x) # 0.25 [4, 3, H/4, W/4] feature_map = self.predictor(x) #[4,3+32,H/4, W/4] f = self.conv1(feature_map) #[16,64,64,64] f = self.bn1(f) #torch.Size([16, 64, 64, 64]) f = self.relu(f) f = self.maxpool(f) #[16, 64, 32, 32] f = self.layer1(f) #[16, 64, 32, 32] f = self.layer2(f) #[16, 128, 16, 16]) f = self.layer3(f) #[16, 256, 8, 8] f = self.layer4(f) #[16, 512, 4, 4] f = self.avgpool(f) #[16, 512, 1, 1] out = f.squeeze(3).squeeze(2) fake = self.classify(out) jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4) neu_input = torch.cat((value,jacobian),2) posi_input =neu_input.reshape(neu_input.shape[0],-1) ner_feature = self.fc_n(posi_input) all_fc = self.fc_all(torch.cat((out,ner_feature),1)).reshape(-1,4,16) result = self.final_10(all_fc) e_value = result[:,:,:2] e_jacobian = result[:,:,2:].reshape(result.shape[0],10,2,2) kp = {'value': e_value,'jacobian': e_jacobian} # out = self.fc(out) return kp, fake def linear_np_4(self, x, value, jacobian): #torch.Size([4, 3, H, W]) if self.scale_factor != 1: x = self.down(x) # 0.25 [4, 3, H/4, W/4] feature_map = self.predictor(x) #[4,3+32,H/4, W/4] f = self.conv1(feature_map) #[16,64,64,64] f = self.bn1(f) #torch.Size([16, 64, 64, 64]) f = self.relu(f) f = self.maxpool(f) #[16, 64, 32, 32] f = self.layer1(f) #[16, 64, 32, 32] f = self.layer2(f) #[16, 128, 16, 16]) f = self.layer3(f) #[16, 256, 8, 8] f = self.layer4(f) #[16, 512, 4, 4] f = self.avgpool(f) #[16, 512, 1, 1] out = f.squeeze(3).squeeze(2) fake = self.classify(out) jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4) neu_input = torch.cat((value,jacobian),2) posi_input =neu_input.reshape(neu_input.shape[0],-1) ner_feature = self.fc_n(posi_input) all_fc = torch.unsqueeze(self.fc_all(torch.cat((out,ner_feature),1)),1) result = self.final(all_fc) e_value = result[:,:,:2] e_jacobian = result[:,:,2:].reshape(result.shape[0],4,2,2) kp = {'value': e_value,'jacobian': e_jacobian} # out = self.fc(out) return kp, fake def emotion_feature(self, feature, value, jacobian): #torch.Size([4, 3, H, W]) out = feature fake = self.classify(out) jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4) neu_input = torch.cat((value,jacobian),2) posi_input = self.embed_fn(neu_input) posi_input =posi_input.reshape(posi_input.shape[0],-1) ner_feature = self.fc_p(posi_input) all_fc = torch.unsqueeze(self.fc_all(torch.cat((out,ner_feature),1)),1) result = self.final(all_fc) e_value = result[:,:,:2] e_jacobian = result[:,:,2:].reshape(result.shape[0],4,2,2) kp = {'value': e_value,'jacobian': e_jacobian} # out = self.fc(out) return kp, fake def feature(self, x): #torch.Size([4, 3, H, W]) if self.scale_factor != 1: x = self.down(x) # 0.25 [4, 3, H/4, W/4] feature_map = self.predictor(x) #[4,3+32,H/4, W/4] f = self.conv1(feature_map) #[16,64,64,64] f = self.bn1(f) #torch.Size([16, 64, 64, 64]) f = self.relu(f) f = self.maxpool(f) #[16, 64, 32, 32] f = self.layer1(f) #[16, 64, 32, 32] f = self.layer2(f) #[16, 128, 16, 16]) f = self.layer3(f) #[16, 256, 8, 8] f = self.layer4(f) #[16, 512, 4, 4] f = self.avgpool(f) #[16, 512, 1, 1] out = f.squeeze(3).squeeze(2) # out = self.fc(out) return out def forward(self, x, value, jacobian): #torch.Size([4, 3, H, W]) if self.scale_factor != 1: x = self.down(x) # 0.25 [4, 3, H/4, W/4] feature_map = self.predictor(x) #[4,3+32,H/4, W/4] f = self.conv1(feature_map) #[16,64,64,64] f = self.bn1(f) #torch.Size([16, 64, 64, 64]) f = self.relu(f) f = self.maxpool(f) #[16, 64, 32, 32] f = self.layer1(f) #[16, 64, 32, 32] f = self.layer2(f) #[16, 128, 16, 16]) f = self.layer3(f) #[16, 256, 8, 8] f = self.layer4(f) #[16, 512, 4, 4] f = self.avgpool(f) #[16, 512, 1, 1] out = f.squeeze(3).squeeze(2) fake = self.classify(out) jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4) neu_input = torch.cat((value,jacobian),2) posi_input = self.embed_fn(neu_input) posi_input =posi_input.reshape(posi_input.shape[0],-1) ner_feature = self.fc_p(posi_input) all_fc = torch.unsqueeze(self.fc_all(torch.cat((out,ner_feature),1)),1) result = self.final(all_fc) e_value = result[:,:,:2] e_jacobian = result[:,:,2:].reshape(result.shape[0],4,2,2) kp = {'value': e_value,'jacobian': e_jacobian} # out = self.fc(out) return kp, fake class Emotion_map(nn.Module): """ Detecting a keypoints. Return keypoint position and jacobian near each keypoint. """ def __init__(self, block_expansion, num_channels, max_features, num_blocks, scale_factor=1, num_classes=8): super(Emotion_map, self).__init__() self.inplanes = 64 self.predictor = Hourglass(block_expansion, in_features=num_channels, max_features=max_features, num_blocks=num_blocks) self.scale_factor = scale_factor if self.scale_factor != 1: self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) self.conv1 = nn.Conv2d(self.predictor.out_filters, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU() self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) layers = [2,2,2,2] self.layer1 = self._make_layer(BasicBlock, 64, layers[0]) self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2) self.layer3 = self._make_layer(BasicBlock, 256, layers[2], stride=2) self.layer4 = self._make_layer(BasicBlock, 512, layers[3], stride=2) self.avgpool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes) self.embed_fn, self.input_ch = get_embedder(10, 0) self.fc_p = nn.Sequential( nn.Linear(10 * 126,1024), nn.ReLU(True), nn.Linear(1024,512), nn.ReLU(True), ) self.fc_all = nn.Sequential( nn.Linear(1024,2048), nn.ReLU(True) ) self.final = nn.Sequential( nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8 nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=True), #16,16 nn.BatchNorm2d(64), nn.ReLU(True), nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),#32,32 nn.BatchNorm2d(64), nn.ReLU(True), nn.ConvTranspose2d(64, 32+3, kernel_size=4, stride=2, padding=1, bias=True),#64,64 ) self.classify = Classify() self.kp = nn.Conv2d(in_channels=35, out_channels=10, kernel_size=(7, 7), padding=0) self.jacobian = nn.Conv2d(in_channels=35, out_channels=4 * 10, kernel_size=(7, 7), padding=0) self.jacobian.weight.data.zero_() self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * 10, dtype=torch.float)) self.temperature = 0.1 self.kp_4 = nn.Conv2d(in_channels=35, out_channels=4, kernel_size=(7, 7), padding=0) self.jacobian_4 = nn.Conv2d(in_channels=35, out_channels=4 * 4, kernel_size=(7, 7), padding=0) self.jacobian_4.weight.data.zero_() self.jacobian_4.bias.data.copy_(torch.tensor([1, 0, 0, 1] * 4, dtype=torch.float)) def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample)) self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(self.inplanes, planes)) return nn.Sequential(*layers) def gaussian2kp(self, heatmap): """ Extract the mean and from a heatmap """ shape = heatmap.shape heatmap = heatmap.unsqueeze(-1) #[4,10,58,58,1] grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) #[1,1,58,58,2] value = (heatmap * grid).sum(dim=(2, 3)) #[4,10,2] kp = {'value': value} return kp def map_4(self, x, value, jacobian): #torch.Size([4, 3, H, W]) if self.scale_factor != 1: x = self.down(x) # 0.25 [4, 3, H/4, W/4] feature_map = self.predictor(x) #[4,3+32,H/4, W/4] f = self.conv1(feature_map) #[16,64,64,64] f = self.bn1(f) #torch.Size([16, 64, 64, 64]) f = self.relu(f) f = self.maxpool(f) #[16, 64, 32, 32] f = self.layer1(f) #[16, 64, 32, 32] f = self.layer2(f) #[16, 128, 16, 16]) f = self.layer3(f) #[16, 256, 8, 8] f = self.layer4(f) #[16, 512, 4, 4] f = self.avgpool(f) #[16, 512, 1, 1] out = f.squeeze(3).squeeze(2) fake = self.classify(out) jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4) neu_input = torch.cat((value,jacobian),2) posi_input = self.embed_fn(neu_input) posi_input =posi_input.reshape(posi_input.shape[0],-1) ner_feature = self.fc_p(posi_input) all_fc = self.fc_all(torch.cat((out,ner_feature),1)).reshape(-1,128,4,4) feature_map = self.final(all_fc) prediction = self.kp_4(feature_map) #[4,10,H/4-6, W/4-6] final_shape = prediction.shape heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58] heatmap = F.softmax(heatmap / self.temperature, dim=2) heatmap = heatmap.view(*final_shape) #[4,10,58,58] out = self.gaussian2kp(heatmap) out['heatmap'] = heatmap if self.jacobian is not None: jacobian_map = self.jacobian_4(feature_map) ##[4,40,H/4-6, W/4-6] jacobian_map = jacobian_map.reshape(final_shape[0], 4, 4, final_shape[2], final_shape[3]) heatmap = heatmap.unsqueeze(2) jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6] jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) jacobian = jacobian.sum(dim=-1) #[4,10,4] jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2] out['jacobian'] = jacobian return out, fake def forward(self, x, value, jacobian): #torch.Size([4, 3, H, W]) if self.scale_factor != 1: x = self.down(x) # 0.25 [4, 3, H/4, W/4] feature_map = self.predictor(x) #[4,3+32,H/4, W/4] f = self.conv1(feature_map) #[16,64,64,64] f = self.bn1(f) #torch.Size([16, 64, 64, 64]) f = self.relu(f) f = self.maxpool(f) #[16, 64, 32, 32] f = self.layer1(f) #[16, 64, 32, 32] f = self.layer2(f) #[16, 128, 16, 16]) f = self.layer3(f) #[16, 256, 8, 8] f = self.layer4(f) #[16, 512, 4, 4] f = self.avgpool(f) #[16, 512, 1, 1] out = f.squeeze(3).squeeze(2) fake = self.classify(out) jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4) neu_input = torch.cat((value,jacobian),2) posi_input = self.embed_fn(neu_input) posi_input =posi_input.reshape(posi_input.shape[0],-1) ner_feature = self.fc_p(posi_input) all_fc = self.fc_all(torch.cat((out,ner_feature),1)).reshape(-1,128,4,4) feature_map = self.final(all_fc) prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6] final_shape = prediction.shape heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58] heatmap = F.softmax(heatmap / self.temperature, dim=2) heatmap = heatmap.view(*final_shape) #[4,10,58,58] out = self.gaussian2kp(heatmap) out['heatmap'] = heatmap if self.jacobian is not None: jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6] jacobian_map = jacobian_map.reshape(final_shape[0], 10, 4, final_shape[2], final_shape[3]) heatmap = heatmap.unsqueeze(2) jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6] jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) jacobian = jacobian.sum(dim=-1) #[4,10,4] jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2] out['jacobian'] = jacobian return out, fake def conv2d(channel_in, channel_out, ksize=3, stride=1, padding=1, activation=nn.ReLU, normalizer=nn.BatchNorm2d): layer = list() bias = True if not normalizer else False layer.append(nn.Conv2d(channel_in, channel_out, ksize, stride, padding, bias=bias)) _apply(layer, activation, normalizer, channel_out) # init.kaiming_normal(layer[0].weight) return nn.Sequential(*layer) def _apply(layer, activation, normalizer, channel_out=None): if normalizer: layer.append(normalizer(channel_out)) if activation: layer.append(activation()) return layer ================================================ FILE: ops.py ================================================ import torch import torchvision import torch.nn as nn import torch.nn.init as init from torch.autograd import Variable class ResidualBlock(nn.Module): def __init__(self, channel_in, channel_out): super(ResidualBlock, self).__init__() self.block = nn.Sequential( conv3d(channel_in, channel_out, 3, 1, 1), conv3d(channel_out, channel_out, 3, 1, 1, activation=None) ) self.lrelu = nn.ReLU(0.2) def forward(self, x): residual = x out = self.block(x) out += residual out = self.lrelu(out) return out def linear(channel_in, channel_out, activation=nn.ReLU, normalizer=nn.BatchNorm1d): layer = list() bias = True if not normalizer else False layer.append(nn.Linear(channel_in, channel_out, bias=bias)) _apply(layer, activation, normalizer, channel_out) # init.kaiming_normal(layer[0].weight) return nn.Sequential(*layer) def conv2d(channel_in, channel_out, ksize=3, stride=1, padding=1, activation=nn.ReLU, normalizer=nn.BatchNorm2d): layer = list() bias = True if not normalizer else False layer.append(nn.Conv2d(channel_in, channel_out, ksize, stride, padding, bias=bias)) _apply(layer, activation, normalizer, channel_out) # init.kaiming_normal(layer[0].weight) return nn.Sequential(*layer) def conv_transpose2d(channel_in, channel_out, ksize=4, stride=2, padding=1, activation=nn.ReLU, normalizer=nn.BatchNorm2d): layer = list() bias = True if not normalizer else False layer.append(nn.ConvTranspose2d(channel_in, channel_out, ksize, stride, padding, bias=bias)) _apply(layer, activation, normalizer, channel_out) # init.kaiming_normal(layer[0].weight) return nn.Sequential(*layer) def nn_conv2d(channel_in, channel_out, ksize=3, stride=1, padding=1, scale_factor=2, activation=nn.ReLU, normalizer=nn.BatchNorm2d): layer = list() bias = True if not normalizer else False layer.append(nn.UpsamplingNearest2d(scale_factor=scale_factor)) layer.append(nn.Conv2d(channel_in, channel_out, ksize, stride, padding, bias=bias)) _apply(layer, activation, normalizer, channel_out) # init.kaiming_normal(layer[1].weight) return nn.Sequential(*layer) def _apply(layer, activation, normalizer, channel_out=None): if normalizer: layer.append(normalizer(channel_out)) if activation: layer.append(activation()) return layer ================================================ FILE: process_data.py ================================================ # -*- coding: utf-8 -*- """ Created on Thu Jun 24 11:36:01 2021 @author: Xinya """ import os import glob import time import numpy as np import csv import cv2 import dlib from skimage import transform as tf import librosa import python_speech_features detector = dlib.get_frontal_face_detector() predictor = dlib.shape_predictor('./shape_predictor_68_face_landmarks.dat') import imageio def save(path, frames, format): if format == '.mp4': imageio.mimsave(path, frames) elif format == '.png': if not os.path.exists(path): os.makedirs(path) for j, frame in enumerate(frames): cv2.imwrite(path+'/'+str(j)+'.png',frame) # imageio.imsave(os.path.join(path, str(j) + '.png'), frames[j]) else: print ("Unknown format %s" % format) exit() def crop_image(image_path, out_path): template = np.load('./M003_template.npy') image = cv2.imread(image_path) gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) rects = detector(gray, 1) #detect human face if len(rects) != 1: return 0 for (j, rect) in enumerate(rects): shape = predictor(gray, rect) #detect 68 points shape = shape_to_np(shape) pts2 = np.float32(template[:47,:]) # pts2 = np.float32(template[17:35,:]) # pts1 = np.vstack((landmark[27:36,:], landmark[39,:],landmark[42,:],landmark[45,:])) pts1 = np.float32(shape[:47,:]) #eye and nose # pts1 = np.float32(landmark[17:35,:]) tform = tf.SimilarityTransform() tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters. dst = tf.warp(image, tform, output_shape=(256, 256)) dst = np.array(dst * 255, dtype=np.uint8) cv2.imwrite(out_path,dst) def shape_to_np(shape, dtype="int"): # initialize the list of (x, y)-coordinates coords = np.zeros((shape.num_parts, 2), dtype=dtype) # loop over all facial landmarks and convert them # to a 2-tuple of (x, y)-coordinates for i in range(0, shape.num_parts): coords[i] = (shape.part(i).x, shape.part(i).y) # return the list of (x, y)-coordinates return coords def crop_image_tem(video_path, out_path): image_all = [] videoCapture = cv2.VideoCapture(video_path) success, frame = videoCapture.read() n = 0 while success : image_all.append(frame) n = n + 1 success, frame = videoCapture.read() if len(image_all)!=0 : template = np.load('./M003_template.npy') image=image_all[0] gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) rects = detector(gray, 1) #detect human face if len(rects) != 1: return 0 for (j, rect) in enumerate(rects): shape = predictor(gray, rect) #detect 68 points shape = shape_to_np(shape) pts2 = np.float32(template[:47,:]) # pts2 = np.float32(template[17:35,:]) # pts1 = np.vstack((landmark[27:36,:], landmark[39,:],landmark[42,:],landmark[45,:])) pts1 = np.float32(shape[:47,:]) #eye and nose # pts1 = np.float32(landmark[17:35,:]) tform = tf.SimilarityTransform() tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters. out = [] for i in range(len(image_all)): image = image_all[i] dst = tf.warp(image, tform, output_shape=(256, 256)) dst = np.array(dst * 255, dtype=np.uint8) out.append(dst) if not os.path.exists(out_path): os.makedirs(out_path) save(out_path,out,'.png') def proc_audio(src_mouth_path, dst_audio_path): audio_command = 'ffmpeg -i \"{}\" -loglevel error -y -f wav -acodec pcm_s16le ' \ '-ar 16000 \"{}\"'.format(src_mouth_path, dst_audio_path) os.system(audio_command) def audio2mfcc(audio_file, save, name): speech, sr = librosa.load(audio_file, sr=16000) # mfcc = python_speech_features.mfcc(speech ,16000,winstep=0.01) speech = np.insert(speech, 0, np.zeros(1920)) speech = np.append(speech, np.zeros(1920)) mfcc = python_speech_features.mfcc(speech,16000,winstep=0.01) if not os.path.exists(save): os.makedirs(save) time_len = mfcc.shape[0] mfcc_all = [] for input_idx in range(int((time_len-28)/4)+1): # target_idx = input_idx + sample_delay #14 input_feat = mfcc[4*input_idx:4*input_idx+28,:] mfcc_all.append(input_feat) np.save(os.path.join(save,name+'.npy'), mfcc_all) print(input_idx) if __name__ == "__main__": #video alignment video_path = './test/crop/M030_sad_3_001.mp4' out_path = './test/crop/M030_sad_3_001' crop_image_tem(video_path, out_path) #image alignment image_path = './test/raw_image/brade2.jpg' out_path = './test/image/brade2.jpg' crop_image(image_path, out_path) #change_audio_sample_rate src_mouth_path = './test/audio/00015.mp3' dst_audio_path = './test/audio/00015.mov' proc_audio(src_mouth_path, dst_audio_path) #audio2mfcc #mead path = './dataset/MEAD/audio/' pathDir = os.listdir(path) for i in range(len(pathDir)):#len(pathDir) name = pathDir[i] filepath = os.path.join(path,name) if os.path.exists(filepath): Dir = os.listdir(filepath) save_path = './dataset/MEAD/MEAD_MFCC/'+name os.makedirs(save_path,exist_ok=True) for j in range(len(Dir)): index = Dir[j].split('.')[0] audio_path = os.path.join(filepath,Dir[j]) audio2mfcc(audio_path, save_path,index) print(i,name,j,index) else: print('not exist ',filepath) ================================================ FILE: requirements.txt ================================================ torch==1.10.1 torchvision==0.11.2 numpy librosa opencv-python python_speech_features pickle-mixin matplotlib scikit-image Pillow tqdm dlib scipy pyyaml imageio pandas ================================================ FILE: run.py ================================================ import matplotlib matplotlib.use('Agg') import os, sys import yaml from argparse import ArgumentParser from time import gmtime, strftime from shutil import copy from frames_dataset import MeadDataset, AudioDataset, VoxDataset from modules.generator import OcclusionAwareGenerator from modules.discriminator import MultiScaleDiscriminator from modules.keypoint_detector import KPDetector, Audio_Feature, KPDetector_a from modules.util import AT_net,Emotion_k,get_logger import torch from train import train_part1, train_part1_fine_tune, train_part2 from reconstruction import reconstruction from animate import animate if __name__ == "__main__": if sys.version_info[0] < 3: raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7") parser = ArgumentParser() parser.add_argument("--config", default="config/train_part1.yaml", help="path to config")# required=True parser.add_argument("--mode", default="train_part1", choices=["train_part1", "train_part1_fine_tune", "train_part2"]) parser.add_argument("--log_dir", default='log', help="path to log into") parser.add_argument("--checkpoint", default='124_52000.pth.tar', help="path to checkpoint to restore") parser.add_argument("--audio_checkpoint", default=None, help="path to audio_checkpoint to restore") parser.add_argument("--emo_checkpoint", default=None, help="path to audio_checkpoint to restore") parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))), help="Names of the devices comma separated.") parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture") parser.set_defaults(verbose=False) opt = parser.parse_args() with open(opt.config) as f: config = yaml.load(f) name = os.path.basename(opt.config).split('.')[0] if opt.checkpoint is not None: log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]) log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime()) else: log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]) log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime()) if not os.path.exists(log_dir): os.makedirs(log_dir) if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))): copy(opt.config, log_dir) # logger = get_logger(os.path.join(log_dir, "log.txt")) generator = OcclusionAwareGenerator(**config['model_params']['generator_params'], **config['model_params']['common_params']) if torch.cuda.is_available(): generator.to(opt.device_ids[0]) if opt.verbose: print(generator) discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'], **config['model_params']['common_params']) if torch.cuda.is_available(): discriminator.to(opt.device_ids[0]) if opt.verbose: print(discriminator) kp_detector = KPDetector(**config['model_params']['kp_detector_params'], **config['model_params']['common_params']) kp_detector_a = KPDetector_a(**config['model_params']['kp_detector_params'], **config['model_params']['audio_params']) if torch.cuda.is_available(): kp_detector.to(opt.device_ids[0]) kp_detector_a.to(opt.device_ids[0]) audio_feature = AT_net() emo_feature = Emotion_k(block_expansion=32, num_channels=3, max_features=1024, num_blocks=5, scale_factor=0.25, num_classes=8) if torch.cuda.is_available(): audio_feature.to(opt.device_ids[0]) emo_feature.to(opt.device_ids[0]) if opt.verbose: print(kp_detector) print(kp_detector_a) print(audio_feature) print(emo_feature) # logger.info("Successfully load models.") if config['dataset_params']['name'] == 'Vox': dataset = VoxDataset(is_train=True, **config['dataset_params']) test_dataset = VoxDataset(is_train=False, **config['dataset_params']) elif config['dataset_params']['name'] == 'Lrw': dataset = AudioDataset(is_train=True, **config['dataset_params']) test_dataset = AudioDataset(is_train=False, **config['dataset_params']) elif config['dataset_params']['name'] == 'MEAD': dataset = MeadDataset(is_train=True, **config['dataset_params']) test_dataset = MeadDataset(is_train=False, **config['dataset_params']) if opt.mode == 'train_part1': print("Training part1...") train_part1(config, generator, discriminator, kp_detector, kp_detector_a,audio_feature, opt.checkpoint, opt.audio_checkpoint, log_dir, dataset, test_dataset,opt.device_ids, name) elif opt.mode == 'train_part1_fine_tune': print("Finetune part1...") train_part1_fine_tune(config, generator, discriminator, kp_detector, kp_detector_a,audio_feature, opt.checkpoint, opt.audio_checkpoint, log_dir, dataset, test_dataset,opt.device_ids, name) elif opt.mode == 'train_part2': print("Training part2...") train_part2(config, generator, discriminator, kp_detector, emo_feature,kp_detector_a,audio_feature, opt.checkpoint, opt.audio_checkpoint, opt.emo_checkpoint, log_dir, dataset,test_dataset,opt.device_ids, name) ================================================ FILE: sync_batchnorm/__init__.py ================================================ # -*- coding: utf-8 -*- # File : __init__.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d from .replicate import DataParallelWithCallback, patch_replication_callback ================================================ FILE: sync_batchnorm/batchnorm.py ================================================ # -*- coding: utf-8 -*- # File : batchnorm.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import collections import torch import torch.nn.functional as F from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast from .comm import SyncMaster __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] def _sum_ft(tensor): """sum over the first and last dimention""" return tensor.sum(dim=0).sum(dim=-1) def _unsqueeze_ft(tensor): """add new dementions at the front and the tail""" return tensor.unsqueeze(0).unsqueeze(-1) _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) class _SynchronizedBatchNorm(_BatchNorm): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) self._sync_master = SyncMaster(self._data_parallel_master) self._is_parallel = False self._parallel_id = None self._slave_pipe = None def forward(self, input): # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. if not (self._is_parallel and self.training): return F.batch_norm( input, self.running_mean, self.running_var, self.weight, self.bias, self.training, self.momentum, self.eps) # Resize the input to (B, C, -1). input_shape = input.size() input = input.view(input.size(0), self.num_features, -1) # Compute the sum and square-sum. sum_size = input.size(0) * input.size(2) input_sum = _sum_ft(input) input_ssum = _sum_ft(input ** 2) # Reduce-and-broadcast the statistics. if self._parallel_id == 0: mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) else: mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) # Compute the output. if self.affine: # MJY:: Fuse the multiplication for speed. output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) else: output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) # Reshape it. return output.view(input_shape) def __data_parallel_replicate__(self, ctx, copy_id): self._is_parallel = True self._parallel_id = copy_id # parallel_id == 0 means master device. if self._parallel_id == 0: ctx.sync_master = self._sync_master else: self._slave_pipe = ctx.sync_master.register_slave(copy_id) def _data_parallel_master(self, intermediates): """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" # Always using same "device order" makes the ReduceAdd operation faster. # Thanks to:: Tete Xiao (http://tetexiao.com/) intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) to_reduce = [i[1][:2] for i in intermediates] to_reduce = [j for i in to_reduce for j in i] # flatten target_gpus = [i[1].sum.get_device() for i in intermediates] sum_size = sum([i[1].sum_size for i in intermediates]) sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) broadcasted = Broadcast.apply(target_gpus, mean, inv_std) outputs = [] for i, rec in enumerate(intermediates): outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) return outputs def _compute_mean_std(self, sum_, ssum, size): """Compute the mean and standard-deviation with sum and square-sum. This method also maintains the moving average on the master device.""" assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' mean = sum_ / size sumvar = ssum - sum_ * mean unbias_var = sumvar / (size - 1) bias_var = sumvar / size self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data return mean, bias_var.clamp(self.eps) ** -0.5 class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a mini-batch. .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm1d as the mean and standard-deviation are reduced across all devices during training. For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm Args: num_features: num_features from an expected input of size `batch_size x num_features [x width]` eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape: - Input: :math:`(N, C)` or :math:`(N, C, L)` - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm1d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm1d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 2 and input.dim() != 3: raise ValueError('expected 2D or 3D input (got {}D input)' .format(input.dim())) super(SynchronizedBatchNorm1d, self)._check_input_dim(input) class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch of 3d inputs .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm2d as the mean and standard-deviation are reduced across all devices during training. For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm Args: num_features: num_features from an expected input of size batch_size x num_features x height x width eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape: - Input: :math:`(N, C, H, W)` - Output: :math:`(N, C, H, W)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm2d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm2d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 4: raise ValueError('expected 4D input (got {}D input)' .format(input.dim())) super(SynchronizedBatchNorm2d, self)._check_input_dim(input) class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch of 4d inputs .. math:: y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta This module differs from the built-in PyTorch BatchNorm3d as the mean and standard-deviation are reduced across all devices during training. For example, when one uses `nn.DataParallel` to wrap the network during training, PyTorch's implementation normalize the tensor on each device using the statistics only on that device, which accelerated the computation and is also easy to implement, but the statistics might be inaccurate. Instead, in this synchronized version, the statistics will be computed over all training samples distributed on multiple devices. Note that, for one-GPU or CPU-only case, this module behaves exactly same as the built-in PyTorch implementation. The mean and standard-deviation are calculated per-dimension over the mini-batches and gamma and beta are learnable parameter vectors of size C (where C is the input size). During training, this layer keeps a running estimate of its computed mean and variance. The running sum is kept with a default momentum of 0.1. During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm or Spatio-temporal BatchNorm Args: num_features: num_features from an expected input of size batch_size x num_features x depth x height x width eps: a value added to the denominator for numerical stability. Default: 1e-5 momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, gives the layer learnable affine parameters. Default: ``True`` Shape: - Input: :math:`(N, C, D, H, W)` - Output: :math:`(N, C, D, H, W)` (same shape as input) Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm3d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm3d(100, affine=False) >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) >>> output = m(input) """ def _check_input_dim(self, input): if input.dim() != 5: raise ValueError('expected 5D input (got {}D input)' .format(input.dim())) super(SynchronizedBatchNorm3d, self)._check_input_dim(input) ================================================ FILE: sync_batchnorm/comm.py ================================================ # -*- coding: utf-8 -*- # File : comm.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import queue import collections import threading __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] class FutureResult(object): """A thread-safe future implementation. Used only as one-to-one pipe.""" def __init__(self): self._result = None self._lock = threading.Lock() self._cond = threading.Condition(self._lock) def put(self, result): with self._lock: assert self._result is None, 'Previous result has\'t been fetched.' self._result = result self._cond.notify() def get(self): with self._lock: if self._result is None: self._cond.wait() res = self._result self._result = None return res _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) class SlavePipe(_SlavePipeBase): """Pipe for master-slave communication.""" def run_slave(self, msg): self.queue.put((self.identifier, msg)) ret = self.result.get() self.queue.put(True) return ret class SyncMaster(object): """An abstract `SyncMaster` object. - During the replication, as the data parallel will trigger an callback of each module, all slave devices should call `register(id)` and obtain an `SlavePipe` to communicate with the master. - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, and passed to a registered callback. - After receiving the messages, the master device should gather the information and determine to message passed back to each slave devices. """ def __init__(self, master_callback): """ Args: master_callback: a callback to be invoked after having collected messages from slave devices. """ self._master_callback = master_callback self._queue = queue.Queue() self._registry = collections.OrderedDict() self._activated = False def __getstate__(self): return {'master_callback': self._master_callback} def __setstate__(self, state): self.__init__(state['master_callback']) def register_slave(self, identifier): """ Register an slave device. Args: identifier: an identifier, usually is the device id. Returns: a `SlavePipe` object which can be used to communicate with the master device. """ if self._activated: assert self._queue.empty(), 'Queue is not clean before next initialization.' self._activated = False self._registry.clear() future = FutureResult() self._registry[identifier] = _MasterRegistry(future) return SlavePipe(identifier, self._queue, future) def run_master(self, master_msg): """ Main entry for the master device in each forward pass. The messages were first collected from each devices (including the master device), and then an callback will be invoked to compute the message to be sent back to each devices (including the master device). Args: master_msg: the message that the master want to send to itself. This will be placed as the first message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. Returns: the message to be sent back to the master device. """ self._activated = True intermediates = [(0, master_msg)] for i in range(self.nr_slaves): intermediates.append(self._queue.get()) results = self._master_callback(intermediates) assert results[0][0] == 0, 'The first result should belongs to the master.' for i, res in results: if i == 0: continue self._registry[i].result.put(res) for i in range(self.nr_slaves): assert self._queue.get() is True return results[0][1] @property def nr_slaves(self): return len(self._registry) ================================================ FILE: sync_batchnorm/replicate.py ================================================ # -*- coding: utf-8 -*- # File : replicate.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import functools from torch.nn.parallel.data_parallel import DataParallel __all__ = [ 'CallbackContext', 'execute_replication_callbacks', 'DataParallelWithCallback', 'patch_replication_callback' ] class CallbackContext(object): pass def execute_replication_callbacks(modules): """ Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` Note that, as all modules are isomorphism, we assign each sub-module with a context (shared among multiple copies of this module on different devices). Through this context, different copies can share some information. We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback of any slave copies. """ master_copy = modules[0] nr_modules = len(list(master_copy.modules())) ctxs = [CallbackContext() for _ in range(nr_modules)] for i, module in enumerate(modules): for j, m in enumerate(module.modules()): if hasattr(m, '__data_parallel_replicate__'): m.__data_parallel_replicate__(ctxs[j], i) class DataParallelWithCallback(DataParallel): """ Data Parallel with a replication callback. An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by original `replicate` function. The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` Examples: > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) # sync_bn.__data_parallel_replicate__ will be invoked. """ def replicate(self, module, device_ids): modules = super(DataParallelWithCallback, self).replicate(module, device_ids) execute_replication_callbacks(modules) return modules def patch_replication_callback(data_parallel): """ Monkey-patch an existing `DataParallel` object. Add the replication callback. Useful when you have customized `DataParallel` implementation. Examples: > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) > patch_replication_callback(sync_bn) # this is equivalent to > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) """ assert isinstance(data_parallel, DataParallel) old_replicate = data_parallel.replicate @functools.wraps(old_replicate) def new_replicate(module, device_ids): modules = old_replicate(module, device_ids) execute_replication_callbacks(modules) return modules data_parallel.replicate = new_replicate ================================================ FILE: sync_batchnorm/unittest.py ================================================ # -*- coding: utf-8 -*- # File : unittest.py # Author : Jiayuan Mao # Email : maojiayuan@gmail.com # Date : 27/01/2018 # # This file is part of Synchronized-BatchNorm-PyTorch. # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch # Distributed under MIT License. import unittest import numpy as np from torch.autograd import Variable def as_numpy(v): if isinstance(v, Variable): v = v.data return v.cpu().numpy() class TorchTestCase(unittest.TestCase): def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): npa, npb = as_numpy(a), as_numpy(b) self.assertTrue( np.allclose(npa, npb, atol=atol), 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) ) ================================================ FILE: train.py ================================================ from tqdm import trange import torch import torch.nn as nn from torch.utils.data import DataLoader from logger import Logger from modules.model import DiscriminatorFullModel, TrainPart1Model, TrainPart2Model import itertools from torch.optim.lr_scheduler import MultiStepLR from sync_batchnorm import DataParallelWithCallback from frames_dataset import DatasetRepeater,TestsetRepeater import time from tensorboardX import SummaryWriter def train_part1(config, generator, discriminator, kp_detector, kp_detector_a,audio_feature, checkpoint, audio_checkpoint, log_dir, dataset, test_dataset, device_ids, name): train_params = config['train_params'] optimizer_audio_feature = torch.optim.Adam(itertools.chain(audio_feature.parameters(),kp_detector_a.parameters()), lr=train_params['lr_audio_feature'], betas=(0.5, 0.999)) if checkpoint is not None: start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector, audio_feature, optimizer_generator, optimizer_discriminator, None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector, None if train_params['lr_audio_feature'] == 0 else optimizer_audio_feature) if audio_checkpoint is not None: pretrain = torch.load(audio_checkpoint) kp_detector_a.load_state_dict(pretrain['kp_detector_a']) audio_feature.load_state_dict(pretrain['audio_feature']) optimizer_audio_feature.load_state_dict(pretrain['optimizer_audio_feature']) start_epoch = pretrain['epoch'] else: start_epoch = 0 scheduler_audio_feature = MultiStepLR(optimizer_audio_feature, train_params['epoch_milestones'], gamma=0.1, last_epoch=-1 + start_epoch * (train_params['lr_audio_feature'] != 0)) if 'num_repeats' in train_params or train_params['num_repeats'] != 1: dataset = DatasetRepeater(dataset, train_params['num_repeats']) test_dataset = TestsetRepeater(test_dataset, train_params['num_repeats']) dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=0, drop_last=True)#6 test_dataloader = DataLoader(test_dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=0, drop_last=True)#6 num_steps_per_epoch = len(dataloader) num_steps_test_epoch = len(test_dataloader) generator_full = TrainPart1Model(kp_detector, kp_detector_a, audio_feature, generator, discriminator, train_params,device_ids) discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params) if len(device_ids)>1: generator_full=torch.nn.DataParallel(generator_full) discriminator_full=torch.nn.DataParallel(discriminator_full) if torch.cuda.is_available(): if len(device_ids) == 1: generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids) discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids) elif len(device_ids)>1: generator_full = generator_full.to(device_ids[0]) discriminator_full = discriminator_full.to(device_ids[0]) step = 0 t0 = time.time() writer=SummaryWriter(comment=name) train_itr=0 test_itr=0 with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger: for epoch in trange(start_epoch, train_params['num_epochs']): for x in dataloader: losses_generator, generated = generator_full(x) loss_values = [val.mean() for val in losses_generator.values()] loss = sum(loss_values) writer.add_scalar('Train',loss,train_itr) writer.add_scalar('Train_value',loss_values[0],train_itr) writer.add_scalar('Train_heatmap',loss_values[1],train_itr) writer.add_scalar('Train_jacobian',loss_values[2],train_itr) train_itr+=1 loss.backward() optimizer_audio_feature.step() optimizer_audio_feature.zero_grad() d = time.time() # if train_params['loss_weights']['generator_gan'] != 0: # optimizer_discriminator.zero_grad() # else: # losses_discriminator = {} # losses_generator.update(losses_discriminator) losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} logger.log_iter(losses=losses) e = time.time() step += 1 if(step % 500 == 0): logger.log_epoch(epoch,step, {'audio_feature': audio_feature, 'kp_detector_a':kp_detector_a, 'optimizer_audio_feature': optimizer_audio_feature}, inp=x, out=generated) scheduler_audio_feature.step() for x in test_dataloader: with torch.no_grad(): losses_generator, generated = generator_full(x) loss_values = [val.mean() for val in losses_generator.values()] loss = sum(loss_values) writer.add_scalar('Test',loss,test_itr) writer.add_scalar('Test_value',loss_values[0],test_itr) writer.add_scalar('Test_heatmap',loss_values[1],test_itr) writer.add_scalar('Test_jacobian',loss_values[2],test_itr) test_itr+=1 def train_part1_fine_tune(config, generator, discriminator, kp_detector, kp_detector_a,audio_feature, checkpoint, audio_checkpoint, log_dir, dataset, dataset2, test_dataset, device_ids, name): train_params = config['train_params'] optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999)) optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999)) optimizer_audio_feature = torch.optim.Adam(itertools.chain(audio_feature.parameters(),kp_detector_a.parameters()), lr=train_params['lr_audio_feature'], betas=(0.5, 0.999)) # optimizer_kp_detector_a = torch.optim.Adam(kp_detector_a.parameters(), lr=train_params['lr_audio_feature'], betas=(0.5, 0.999)) if checkpoint is not None: start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector, audio_feature, optimizer_generator, optimizer_discriminator, None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector, None if train_params['lr_audio_feature'] == 0 else optimizer_audio_feature) if audio_checkpoint is not None: pretrain = torch.load(audio_checkpoint) kp_detector_a.load_state_dict(pretrain['kp_detector_a']) audio_feature.load_state_dict(pretrain['audio_feature']) # optimizer_kp_detector_a.load_state_dict(pretrain['optimizer_kp_detector_a']) optimizer_audio_feature.load_state_dict(pretrain['optimizer_audio_feature']) start_epoch = pretrain['epoch'] else: start_epoch = 0 scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1, last_epoch=start_epoch - 1) scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1, last_epoch=start_epoch - 1) scheduler_audio_feature = MultiStepLR(optimizer_audio_feature, train_params['epoch_milestones'], gamma=0.1, last_epoch=-1 + start_epoch * (train_params['lr_audio_feature'] != 0)) if 'num_repeats' in train_params or train_params['num_repeats'] != 1: dataset = DatasetRepeater(dataset, train_params['num_repeats']) test_dataset = TestsetRepeater(test_dataset, train_params['num_repeats']) dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=0, drop_last=True)#6 test_dataloader = DataLoader(test_dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=0, drop_last=True)#6 num_steps_per_epoch = len(dataloader) num_steps_test_epoch = len(test_dataloader) generator_full = TrainFullModel(kp_detector, kp_detector_a, audio_feature, generator, discriminator, train_params,device_ids) discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params) print('End dataload ', file=open('log/MEAD_LRW_test_a.txt', 'a')) if len(device_ids)>1: generator_full=torch.nn.DataParallel(generator_full) discriminator_full=torch.nn.DataParallel(discriminator_full) if torch.cuda.is_available(): if len(device_ids) == 1: generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids) discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids) elif len(device_ids)>1: generator_full = generator_full.to(device_ids[0]) discriminator_full = discriminator_full.to(device_ids[0]) step = 0 t0 = time.time() writer=SummaryWriter(comment=name) train_itr=0 test_itr=0 with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger: for epoch in trange(start_epoch, train_params['num_epochs']): for x in dataloader: losses_generator, generated = generator_full(x) loss_values = [val.mean() for val in losses_generator.values()] loss = sum(loss_values) writer.add_scalar('Train',loss,train_itr) writer.add_scalar('Train_value',loss_values[0],train_itr) writer.add_scalar('Train_heatmap',loss_values[1],train_itr) writer.add_scalar('Train_jacobian',loss_values[2],train_itr) writer.add_scalar('Train_perceptual',loss_values[3],train_itr) train_itr+=1 loss.backward() optimizer_audio_feature.step() optimizer_audio_feature.zero_grad() optimizer_generator.step() optimizer_generator.zero_grad() # optimizer_kp_detector_a.step() # optimizer_kp_detector_a.zero_grad() if train_params['loss_weights']['discriminator_gan'] != 0: optimizer_discriminator.zero_grad() # losses_discriminator = discriminator_full(x, generated) # loss_values = [val.mean() for val in losses_discriminator.values()] # loss = sum(loss_values) # loss.backward() # optimizer_discriminator.step() # optimizer_discriminator.zero_grad() else: losses_discriminator = {} losses_generator.update(losses_discriminator) losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} logger.log_iter(losses=losses) step += 1 if(step % 500 == 0): logger.log_epoch(epoch,step, {'audio_feature': audio_feature, 'kp_detector_a':kp_detector_a, 'generator': generator, 'optimizer_generator':optimizer_generator, 'optimizer_audio_feature': optimizer_audio_feature}, inp=x, out=generated) scheduler_generator.step() scheduler_discriminator.step() scheduler_audio_feature.step() for x in test_dataloader: with torch.no_grad(): losses_generator, generated = generator_full(x) loss_values = [val.mean() for val in losses_generator.values()] loss = sum(loss_values) writer.add_scalar('Test',loss,test_itr) writer.add_scalar('Test_value',loss_values[0],test_itr) writer.add_scalar('Test_heatmap',loss_values[1],test_itr) writer.add_scalar('Test_jacobian',loss_values[2],test_itr) writer.add_scalar('Test_perceptual',loss_values[3],test_itr) test_itr+=1 def train_part2(config, generator, discriminator, kp_detector, emo_detector, kp_detector_a,audio_feature, checkpoint, audio_checkpoint, emo_checkpoint, log_dir, dataset, test_dataset, device_ids, exp_name): train_params = config['train_params'] optimizer_emo_detector = torch.optim.Adam(emo_detector.parameters(), lr=train_params['lr_audio_feature'], betas=(0.5, 0.999)) if checkpoint is not None: start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector, audio_feature, optimizer_generator, optimizer_discriminator, None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector, None if train_params['lr_audio_feature'] == 0 else optimizer_audio_feature) if emo_checkpoint is not None: pretrain = torch.load(emo_checkpoint) tgt_state = emo_detector.state_dict() strip = 'module.' if 'emo_detector' in pretrain: emo_detector.load_state_dict(pretrain['emo_detector']) optimizer_emo_detector.load_state_dict(pretrain['optimizer_emo_detector']) print('emo_detector in pretrain + load', file=open('log/'+exp_name+'.txt', 'a')) for name, param in pretrain.items(): if isinstance(param, nn.Parameter): param = param.data if strip is not None and name.startswith(strip): name = name[len(strip):] if name not in tgt_state: continue tgt_state[name].copy_(param) print(name) if audio_checkpoint is not None: pretrain = torch.load(audio_checkpoint) kp_detector_a.load_state_dict(pretrain['kp_detector_a']) audio_feature.load_state_dict(pretrain['audio_feature']) optimizer_audio_feature.load_state_dict(pretrain['optimizer_audio_feature']) if 'emo_detector' in pretrain: emo_detector.load_state_dict(pretrain['emo_detector']) optimizer_emo_detector.load_state_dict(pretrain['optimizer_emo_detector']) start_epoch = pretrain['epoch'] else: start_epoch = 0 scheduler_emo_detector = MultiStepLR(optimizer_emo_detector, train_params['epoch_milestones'], gamma=0.1, last_epoch=-1 + start_epoch * (train_params['lr_audio_feature'] != 0)) if 'num_repeats' in train_params or train_params['num_repeats'] != 1: dataset = DatasetRepeater(dataset, train_params['num_repeats']) test_dataset = TestsetRepeater(test_dataset, train_params['num_repeats']) dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=0, drop_last=True)#6 test_dataloader = DataLoader(test_dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=0, drop_last=True)#6 num_steps_per_epoch = len(dataloader) num_steps_test_epoch = len(test_dataloader) generator_full = TrainPart2Model(kp_detector, emo_detector,kp_detector_a, audio_feature,generator, discriminator, train_params,device_ids) discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params) if len(device_ids)>1: generator_full=torch.nn.DataParallel(generator_full) discriminator_full=torch.nn.DataParallel(discriminator_full) if torch.cuda.is_available(): if len(device_ids) == 1: generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids) discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids) elif len(device_ids)>1: generator_full = generator_full.to(device_ids[0]) discriminator_full = discriminator_full.to(device_ids[0]) step = 0 t0 = time.time() writer=SummaryWriter(comment=exp_name) train_itr=0 test_itr=0 with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger: for epoch in trange(start_epoch, train_params['num_epochs']): for x in dataloader: losses_generator, generated = generator_full(x) loss_values = [val.mean() for val in losses_generator.values()] loss = sum(loss_values) writer.add_scalar('Train',loss,train_itr) writer.add_scalar('Train_value',loss_values[0],train_itr) # writer.add_scalar('Train_heatmap',loss_values[1],train_itr) writer.add_scalar('Train_jacobian',loss_values[1],train_itr) writer.add_scalar('Train_classify',loss_values[2],train_itr) train_itr+=1 loss.backward() optimizer_emo_detector.step() optimizer_emo_detector.zero_grad() losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} logger.log_iter(losses=losses) step += 1 if(step % 1000 == 0): logger.log_epoch(epoch,step, {'audio_feature': audio_feature, 'kp_detector_a':kp_detector_a, 'emo_detector':emo_detector, 'optimizer_emo_detector': optimizer_emo_detector, # 'optimizer_kp_detector_a':optimizer_kp_detector_a, 'optimizer_audio_feature': optimizer_audio_feature}, inp=x, out=generated) scheduler_emo_detector.step() for x in test_dataloader: with torch.no_grad(): losses_generator, generated = generator_full(x) loss_values = [val.mean() for val in losses_generator.values()] loss = sum(loss_values) writer.add_scalar('Test',loss,test_itr) writer.add_scalar('Test_value',loss_values[0],test_itr) # writer.add_scalar('Test_heatmap',loss_values[1],test_itr) writer.add_scalar('Test_jacobian',loss_values[1],test_itr) writer.add_scalar('Test_classify',loss_values[2],test_itr) test_itr+=1