Showing preview only (361K chars total). Download the full file or copy to clipboard to get everything.
Repository: jixinya/EAMM
Branch: main
Commit: e176a79865f7
Files: 53
Total size: 345.8 KB
Directory structure:
gitextract_7pp944hx/
├── 3DDFA_V2/
│ ├── demo.py
│ └── utils/
│ └── pose.py
├── LICENSE
├── M003_template.npy
├── README.md
├── augmentation.py
├── config/
│ ├── MEAD_emo_video_aug_delta_4_crop_random_crop.yaml
│ ├── train_part1.yaml
│ ├── train_part1_fine_tune.yaml
│ └── train_part2.yaml
├── dataset/
│ ├── LRW/
│ │ ├── MFCC/
│ │ │ └── ABOUT/
│ │ │ └── ABOUT_00001.npy
│ │ └── Pose/
│ │ └── ABOUT/
│ │ └── ABOUT_00001.npy
│ └── MEAD/
│ └── list/
│ └── MEAD_fomm_neu_dic_crop.npy
├── demo.py
├── filter1.py
├── frames_dataset.py
├── logger.py
├── modules/
│ ├── dense_motion.py
│ ├── discriminator.py
│ ├── function.py
│ ├── generator.py
│ ├── keypoint_detector.py
│ ├── model.py
│ ├── model_delta_map.py
│ ├── model_gen.py
│ ├── ops.py
│ ├── stylegan2.py
│ └── util.py
├── ops.py
├── process_data.py
├── requirements.txt
├── run.py
├── sync_batchnorm/
│ ├── __init__.py
│ ├── batchnorm.py
│ ├── comm.py
│ ├── replicate.py
│ └── unittest.py
├── test/
│ ├── pose/
│ │ ├── 14.npy
│ │ ├── 21.npy
│ │ ├── 60.npy
│ │ ├── 7.npy
│ │ ├── anne.npy
│ │ ├── brade2.npy
│ │ ├── dune_1.npy
│ │ ├── dune_2.npy
│ │ ├── jake4.npy
│ │ ├── mona.npy
│ │ ├── paint1.npy
│ │ └── scarlett.npy
│ └── pose_long/
│ ├── 0zn70Ak8lRc_Daniel_Auteuil_0zn70Ak8lRc_0002.npy
│ ├── 1hEr7qKRKL4_Daniel_Dae_Kim_1hEr7qKRKL4_0004.npy
│ └── 50IAfJCypFI_Alex_Kingston_50IAfJCypFI_0001.npy
└── train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: 3DDFA_V2/demo.py
================================================
# coding: utf-8
__author__ = 'cleardusk'
import sys
import argparse
import cv2
import yaml
import os
import time
from FaceBoxes import FaceBoxes
from TDDFA import TDDFA
from utils.render import render
#from utils.render_ctypes import render # faster
from utils.depth import depth
from utils.pncc import pncc
from utils.uv import uv_tex
from utils.pose import viz_pose, get_pose
from utils.serialization import ser_to_ply, ser_to_obj
from utils.functions import draw_landmarks, get_suffix
from utils.tddfa_util import str2bool
import numpy as np
from tqdm import tqdm
import copy
import concurrent.futures
from multiprocessing import Pool
def main(args, img, save_path, pose_path):
    """Run 3DDFA on one BGR image and dispatch on `args.opt`.

    Args:
        args: parsed CLI namespace (config, onnx, mode, opt, img_fp, show_flag).
        img: BGR image array (e.g. from cv2.imread / cv2.VideoCapture).
        save_path: visualization output path used by the 'pose' branch (may be None).
        pose_path: .npy output path used by the 'pose' branch (may be None).

    Returns:
        The 7-element pose array when `args.opt == 'pose'`; otherwise None.
        Also returns None when no face is detected.
    """
    cfg = yaml.load(open(args.config), Loader=yaml.SafeLoader)
    # Init FaceBoxes and TDDFA, recommend using onnx flag
    if args.onnx:
        import os
        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
        os.environ['OMP_NUM_THREADS'] = '4'
        from FaceBoxes.FaceBoxes_ONNX import FaceBoxes_ONNX
        from TDDFA_ONNX import TDDFA_ONNX
        face_boxes = FaceBoxes_ONNX()
        tddfa = TDDFA_ONNX(**cfg)
    else:
        gpu_mode = args.mode == 'gpu'
        tddfa = TDDFA(gpu_mode=gpu_mode, **cfg)
        face_boxes = FaceBoxes()
    # Detect faces, get 3DMM params and roi boxes
    boxes = face_boxes(img)
    n = len(boxes)
    if n == 0:
        print(f'No face detected, exit')
        return None
    print(f'Detect {n} faces')
    param_lst, roi_box_lst = tddfa(img, boxes)
    # Visualization and serialization
    dense_flag = args.opt in ('2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'ply', 'obj')
    old_suffix = 'png'
    new_suffix = f'.{args.opt}' if args.opt in ('ply', 'obj') else '.jpg'
    wfp = f'examples/results/{args.img_fp.split("/")[-1].replace(old_suffix, "")}_{args.opt}' + new_suffix
    ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, dense_flag=dense_flag)
    # Fix: `all_pose` used to be assigned only in the 'pose' branch, so every
    # other opt crashed with NameError at the return below.
    all_pose = None
    if args.opt == '2d_sparse':
        draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp)
    elif args.opt == '2d_dense':
        draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp)
    elif args.opt == '3d':
        render(img, ver_lst, tddfa.tri, alpha=0.6, show_flag=args.show_flag, wfp=wfp)
    elif args.opt == 'depth':
        # if `with_bf_flag` is False, the background is black
        depth(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True)
    elif args.opt == 'pncc':
        pncc(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True)
    elif args.opt == 'uv_tex':
        uv_tex(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp)
    elif args.opt == 'pose':
        all_pose = get_pose(img, param_lst, ver_lst, show_flag=args.show_flag, wfp=save_path, wnp=pose_path)
    elif args.opt == 'ply':
        ser_to_ply(ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp)
    elif args.opt == 'obj':
        ser_to_obj(img, ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp)
    else:
        raise ValueError(f'Unknown opt {args.opt}')
    return all_pose
def process_word(i):
    """Extract per-frame head pose for every video of the i-th word folder.

    Decodes each video under <path>/<word>/ into memory, runs `main` on every
    frame and saves an (n_frames, 7) pose array per video under <pose>/<word>/.
    Relies on the module-level `args` parsed in the __main__ block.
    """
    # Hard-coded dataset locations on an external drive.
    path = '/media/xinya/Backup Plus/sense_shixi_data/new_crop/MEAD_fomm_video_6/'
    save = '/media/xinya/Backup Plus/sense_shixi_data/new_crop/MEAD_fomm_pose_im/'  # currently unused
    pose = '/media/xinya/Backup Plus/sense_shixi_data/new_crop/MEAD_fomm_pose/'
    start = time.time()
    Dir = os.listdir(path)
    Dir.sort()  # sort so index i maps to a stable word folder
    word = Dir[i]
    wpath = os.path.join(path, word)
    print(wpath)
    pathDir = os.listdir(wpath)
    pose_file = os.path.join(pose, word)
    if not os.path.exists(pose_file):
        os.makedirs(pose_file)
    for j in range(len(pathDir)):
        name = pathDir[j]
        fpath = os.path.join(wpath, name)
        # Decode the whole video into memory, frame by frame.
        image_all = []
        videoCapture = cv2.VideoCapture(fpath)
        success, frame = videoCapture.read()
        n = 0
        while success:
            image_all.append(frame)
            n = n + 1
            success, frame = videoCapture.read()
        # One 7-element pose row per frame (see get_pose in utils/pose.py).
        pose_all = np.zeros((len(image_all), 7))
        for k in range(len(image_all)):
            # NOTE(review): `main` returns None when no face is detected,
            # which would fail this row assignment — confirm every frame
            # contains a detectable face.
            pose_all[k] = main(args, image_all[k], None, None)
        np.save(os.path.join(pose, word, name.split('.')[0] + '.npy'), pose_all)
        st = time.time() - start
        # NOTE(review): the log path uses '/media/thea/...' while the data
        # paths use '/media/xinya/...' — presumably stale; verify.
        # `k` is the last frame index from the loop above (NameError if the
        # video decoded zero frames).
        print(str(i)+' '+word+' '+str(j)+' '+name+' '+str(k)+'time: '+str(st), file=open('/media/thea/Backup Plus/sense_shixi_data/new_crop/pose_mead6.txt', 'a'))
        print(i, word, j, name, k)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='The demo of still image of 3DDFA_V2')
    parser.add_argument('-c', '--config', type=str, default='configs/mb1_120x120.yml')
    parser.add_argument('-f', '--img_fp', type=str, default='examples/inputs/0.png')
    parser.add_argument('-m', '--mode', type=str, default='cpu', help='gpu or cpu mode')
    parser.add_argument('-o', '--opt', type=str, default='pose',
                        choices=['2d_sparse', '2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'pose', 'ply', 'obj'])
    parser.add_argument('--show_flag', type=str2bool, default='False', help='whether to show the visualization result')
    parser.add_argument('--onnx', action='store_true', default=False)
    args = parser.parse_args()
    # Batch mode: estimate the pose of every still image in test/image/ and
    # save one (1, 7) array ([yaw, pitch, roll, s, tx, ty, tz], see get_pose
    # in utils/pose.py) per image into test/pose/.
    filepath = 'test/image/'
    pathDir = os.listdir(filepath)
    for i in range(len(pathDir)):
        image = cv2.imread(os.path.join(filepath, pathDir[i]))
        # NOTE(review): `main` returns None when no face is detected, which
        # would make `.reshape(1, 7)` raise here — confirm each test image
        # contains a detectable face.
        pose = main(args, image, None, None).reshape(1, 7)
        np.save('test/pose/' + pathDir[i].split('.')[0] + '.npy', pose)
        print(i, pathDir[i])
'''
def main(args):
cfg = yaml.load(open(args.config), Loader=yaml.SafeLoader)
# Init FaceBoxes and TDDFA, recommend using onnx flag
if args.onnx:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
os.environ['OMP_NUM_THREADS'] = '4'
from FaceBoxes.FaceBoxes_ONNX import FaceBoxes_ONNX
from TDDFA_ONNX import TDDFA_ONNX
face_boxes = FaceBoxes_ONNX()
tddfa = TDDFA_ONNX(**cfg)
else:
gpu_mode = args.mode == 'gpu'
tddfa = TDDFA(gpu_mode=gpu_mode, **cfg)
face_boxes = FaceBoxes()
# Given a still image path and load to BGR channel
img = cv2.imread(args.img_fp)
# Detect faces, get 3DMM params and roi boxes
boxes = face_boxes(img)
n = len(boxes)
if n == 0:
print(f'No face detected, exit')
sys.exit(-1)
print(f'Detect {n} faces')
param_lst, roi_box_lst = tddfa(img, boxes)
# Visualization and serialization
dense_flag = args.opt in ('2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'ply', 'obj')
old_suffix = get_suffix(args.img_fp)
new_suffix = f'.{args.opt}' if args.opt in ('ply', 'obj') else '.jpg'
wfp = f'examples/results/{args.img_fp.split("/")[-1].replace(old_suffix, "")}_{args.opt}' + new_suffix
ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, dense_flag=dense_flag)
if args.opt == '2d_sparse':
draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp)
elif args.opt == '2d_dense':
draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp)
elif args.opt == '3d':
render(img, ver_lst, tddfa.tri, alpha=0.6, show_flag=args.show_flag, wfp=wfp)
elif args.opt == 'depth':
# if `with_bf_flag` is False, the background is black
depth(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True)
elif args.opt == 'pncc':
pncc(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True)
elif args.opt == 'uv_tex':
uv_tex(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp)
elif args.opt == 'pose':
viz_pose(img, param_lst, ver_lst, show_flag=args.show_flag, wfp=wfp)
elif args.opt == 'ply':
ser_to_ply(ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp)
elif args.opt == 'obj':
ser_to_obj(img, ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp)
else:
raise ValueError(f'Unknown opt {args.opt}')
'''
================================================
FILE: 3DDFA_V2/utils/pose.py
================================================
# coding: utf-8
"""
Reference: https://github.com/YadiraF/PRNet/blob/master/utils/estimate_pose.py
Calculating pose from the output 3DMM parameters, you can also try to use solvePnP to perform estimation
"""
__author__ = 'cleardusk'
import cv2
import numpy as np
from math import cos, sin, atan2, asin, sqrt
from .functions import calc_hypotenuse, plot_image
def P2sRt(P):
    """Decompose an affine camera matrix into scale, rotation and translation.

    Args:
        P: (3, 4) affine camera matrix.

    Returns:
        s: scalar scale factor (mean norm of the first two rows).
        R: (3, 3) rotation matrix built from the normalized rows.
        t3d: (3,) translation vector (last column of P).
    """
    translation = P[:, 3]
    row_x = P[0:1, :3]
    row_y = P[1:2, :3]
    norm_x = np.linalg.norm(row_x)
    norm_y = np.linalg.norm(row_y)
    scale = (norm_x + norm_y) / 2.0
    unit_x = row_x / norm_x
    unit_y = row_y / norm_y
    unit_z = np.cross(unit_x, unit_y)
    rotation = np.concatenate((unit_x, unit_y, unit_z), 0)
    return scale, rotation, translation
def matrix2angle(R):
    """Compute three Euler angles from a rotation matrix.

    Ref: http://www.gregslabaugh.net/publications/euler.pdf and
    https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv

    Args:
        R: (3, 3) rotation matrix.

    Returns:
        (x, y, z) angles in radians (labelled yaw/pitch/roll upstream).
    """
    sin_x = R[2, 0]
    # Near-gimbal-lock branches: |sin(x)| ~ 1 makes the generic formulas
    # numerically unstable, so z is pinned to 0 there.
    if sin_x > 0.998:
        z = 0
        x = np.pi / 2
        y = z + atan2(-R[0, 1], -R[0, 2])
    elif sin_x < -0.998:
        z = 0
        x = -np.pi / 2
        y = -z + atan2(R[0, 1], R[0, 2])
    else:
        x = asin(sin_x)
        cos_x = cos(x)
        y = atan2(R[2, 1] / cos_x, R[2, 2] / cos_x)
        z = atan2(R[1, 0] / cos_x, R[0, 0] / cos_x)
    return x, y, z
def angle2matrix(theta):
    """Build a rotation matrix from three Euler angles (radians).

    Fixed docstring: the previous one was copy-pasted from `matrix2angle`
    and described the inverse operation. Note the index mapping is kept
    as-is for compatibility: theta[1] drives the rotation about x,
    theta[0] the rotation about y, and theta[2] the rotation about z.

    Args:
        theta: sequence of three angles in radians.

    Returns:
        R: (3, 3) rotation matrix, composed as R_z @ R_y @ R_x.
    """
    R_x = np.array([[1, 0, 0],
                    [0, cos(theta[1]), -sin(theta[1])],
                    [0, sin(theta[1]), cos(theta[1])]
                    ])
    R_y = np.array([[cos(theta[0]), 0, sin(-theta[0])],
                    [0, 1, 0],
                    [-sin(-theta[0]), 0, cos(theta[0])]
                    ])
    R_z = np.array([[cos(theta[2]), -sin(theta[2]), 0],
                    [sin(theta[2]), cos(theta[2]), 0],
                    [0, 0, 1]
                    ])
    R = np.dot(R_z, np.dot(R_y, R_x))
    return R
def angle2matrix_3ddfa(angles):
    """Get the rotation matrix from three rotation angles (radians), matching
    the 3DDFA convention.

    Args:
        angles: [3,] angles; angles[1] is used as pitch (x), angles[0] as
            yaw (y), angles[2] as roll (z).

    Returns:
        (3, 3) float32 rotation matrix, composed as Rx @ Ry @ Rz.
    """
    pitch, yaw, roll = angles[1], angles[0], angles[2]
    rot_x = np.array([[1, 0, 0],
                      [0, cos(pitch), sin(pitch)],
                      [0, -sin(pitch), cos(pitch)]])
    rot_y = np.array([[cos(yaw), 0, -sin(yaw)],
                      [0, 1, 0],
                      [sin(yaw), 0, cos(yaw)]])
    rot_z = np.array([[cos(roll), sin(roll), 0],
                      [-sin(roll), cos(roll), 0],
                      [0, 0, 1]])
    combined = rot_x.dot(rot_y).dot(rot_z)
    return combined.astype(np.float32)
def calc_pose(param):
    """Compute the scale-free projection matrix and Euler angles (degrees)
    from a 3DMM parameter vector.

    Args:
        param: parameter vector whose first 12 entries form the 3x4 affine
            camera matrix.

    Returns:
        P: (3, 4) matrix of rotation plus translation (scale removed).
        pose: list of three Euler angles in degrees.
    """
    camera = param[:12].reshape(3, -1)  # affine camera matrix
    s, R, t3d = P2sRt(camera)
    P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)  # without scale
    angles = matrix2angle(R)
    pose = [a * 180 / np.pi for a in angles]
    return P, pose
def build_camera_box(rear_size=90):
    """Build the 10 corner points of a camera frustum box for pose display.

    Args:
        rear_size: half-extent of the rear rectangle (at depth 0).

    Returns:
        (10, 3) float32 array: 5 rear points then 5 front points, each
        rectangle closed by repeating its first corner.
    """
    rear_depth = 0
    front_size = int(4 / 3 * rear_size)
    front_depth = int(4 / 3 * rear_size)
    rear = [(-rear_size, -rear_size, rear_depth),
            (-rear_size, rear_size, rear_depth),
            (rear_size, rear_size, rear_depth),
            (rear_size, -rear_size, rear_depth),
            (-rear_size, -rear_size, rear_depth)]
    front = [(-front_size, -front_size, front_depth),
             (-front_size, front_size, front_depth),
             (front_size, front_size, front_depth),
             (front_size, -front_size, front_depth),
             (-front_size, -front_size, front_depth)]
    return np.array(rear + front, dtype=np.float32).reshape(-1, 3)
def plot_pose_box(img, P, ver, color=(40, 255, 0), line_width=2):
    """ Draw a 3D box as annotation of pose.
    Ref:https://github.com/yinguobing/head-pose-estimation/blob/master/pose_estimator.py
    Args:
        img: the input image
        P: (3, 4). Affine Camera Matrix.
        ver: (2, 68) or (3, 68) landmark vertices; the box is scaled and
            centered relative to them.
        color: BGR line color.
        line_width: line thickness in pixels.
    Returns:
        The image with the box drawn on it.
    """
    llength = calc_hypotenuse(ver)
    point_3d = build_camera_box(llength)
    # Map to 2d image points
    point_3d_homo = np.hstack((point_3d, np.ones([point_3d.shape[0], 1])))  # n x 4
    point_2d = point_3d_homo.dot(P.T)[:, :2]
    # Flip y to image coordinates and re-center the box on the mean of the
    # first 27 landmark columns.
    point_2d[:, 1] = - point_2d[:, 1]
    point_2d[:, :2] = point_2d[:, :2] - np.mean(point_2d[:4, :2], 0) + np.mean(ver[:2, :27], 1)
    point_2d = np.int32(point_2d.reshape(-1, 2))
    # Draw all the lines: both closed rectangles, then the three edges that
    # connect the rear rectangle (indices 1-3) to the front one (6-8).
    cv2.polylines(img, [point_2d], True, color, line_width, cv2.LINE_AA)
    cv2.line(img, tuple(point_2d[1]), tuple(
        point_2d[6]), color, line_width, cv2.LINE_AA)
    cv2.line(img, tuple(point_2d[2]), tuple(
        point_2d[7]), color, line_width, cv2.LINE_AA)
    cv2.line(img, tuple(point_2d[3]), tuple(
        point_2d[8]), color, line_width, cv2.LINE_AA)
    return img
def viz_pose(img, param_lst, ver_lst, show_flag=False, wfp=None):
    """Draw a pose box for every detected face and optionally save/show.

    Args:
        img: BGR image to draw on.
        param_lst: list of 3DMM parameter vectors, one per face.
        ver_lst: list of landmark vertex arrays, one per face.
        show_flag: if True, display the annotated image via plot_image.
        wfp: optional path to write the annotated image to.
    Returns:
        The annotated image.
    """
    for param, ver in zip(param_lst, ver_lst):
        P, pose = calc_pose(param)
        img = plot_pose_box(img, P, ver)
        # print(P[:, :3])
        print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}')
    if wfp is not None:
        cv2.imwrite(wfp, img)
        print(f'Save visualization result to {wfp}')
    if show_flag:
        plot_image(img)
    return img
def pose_6(param):
    """Extract scale, Euler angles, translation and projection matrix from a
    3DMM parameter vector.

    Removed leftover debug output: the stray `print(t3d)/print(R)/print(R1)`
    calls and the unused `R1 = angle2matrix(pose)` round-trip check.

    Args:
        param: parameter vector whose first 12 entries form the 3x4 affine
            camera matrix.

    Returns:
        s: scalar scale factor.
        pose: three Euler angles in degrees.
        t3d: (3,) translation vector.
        P: (3, 4) rotation + translation matrix (scale removed).
    """
    P = param[:12].reshape(3, -1)  # camera matrix
    s, R, t3d = P2sRt(P)
    P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)  # without scale
    pose = matrix2angle(R)
    pose = [p * 180 / np.pi for p in pose]
    return s, pose, t3d, P
def smooth_pose(img, param_lst, ver_lst, pose_new, show_flag=False, wfp=None, wnp=None):
    """Draw a pose box using an externally supplied (smoothed) pose.

    Args:
        img: BGR image to draw on.
        param_lst: per-face 3DMM parameter vectors (only iterated, not read).
        ver_lst: per-face landmark vertex arrays.
        pose_new: 7-vector [yaw, pitch, roll, scale, tx, ty, tz]; angles in degrees.
        show_flag: display the result when True.
        wfp: optional image output path.
        wnp: optional .npy output path (saves a placeholder array).

    Returns:
        The annotated image.
    """
    for param, ver in zip(param_lst, ver_lst):
        t3d = np.array([pose_new[4], pose_new[5], pose_new[6]])
        theta = np.array([pose_new[0], pose_new[1], pose_new[2]])
        theta = [p * np.pi / 180 for p in theta]  # degrees -> radians
        R = angle2matrix(theta)
        P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)
        img = plot_pose_box(img, P, ver)
        print(P, pose_new)
        print(f'yaw: {theta[0]:.1f}, pitch: {theta[1]:.1f}, roll: {theta[2]:.1f}')
    all_pose = np.array([0])  # placeholder: no pose is recomputed here
    if wfp is not None:
        cv2.imwrite(wfp, img)
        print(f'Save visualization result to {wfp}')
    if wnp is not None:
        np.save(wnp, all_pose)
        # Fix: this message previously reported `wfp` instead of `wnp`.
        print(f'Save pose result to {wnp}')
    if show_flag:
        plot_image(img)
    return img
def get_pose(img, param_lst, ver_lst, show_flag=False, wfp=None, wnp=None):
    """Estimate a 7-DoF pose per face and annotate the image.

    Args:
        img: BGR image to draw on.
        param_lst: per-face 3DMM parameter vectors.
        ver_lst: per-face landmark vertex arrays.
        show_flag: display the annotated image when True.
        wfp: optional image output path.
        wnp: optional .npy output path for the pose array.

    Returns:
        (7,) array [yaw, pitch, roll, s, tx, ty, tz] of the LAST face in
        the lists (angles in degrees).
    """
    for param, ver in zip(param_lst, ver_lst):
        s, pose, t3d, P = pose_6(param)
        img = plot_pose_box(img, P, ver)
        print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}')
        all_pose = np.array([pose[0], pose[1], pose[2], s, t3d[0], t3d[1], t3d[2]])
    if wfp is not None:
        cv2.imwrite(wfp, img)
        print(f'Save visualization result to {wfp}')
    if wnp is not None:
        np.save(wnp, all_pose)
        # Fix: this message previously reported `wfp` instead of `wnp`.
        print(f'Save pose result to {wnp}')
    if show_flag:
        plot_image(img)
    return all_pose
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2022 jixinya
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# EAMM: One-Shot Emotional Talking Face via Audio-Based Emotion-Aware Motion Model [SIGGRAPH 2022 Conference]
Xinya Ji, [Hang Zhou](https://hangz-nju-cuhk.github.io/), Kaisiyuan Wang, [Qianyi Wu](https://wuqianyi.top/), [Wayne Wu](http://wywu.github.io/), [Feng Xu](http://xufeng.site/), [Xun Cao](https://cite.nju.edu.cn/People/Faculty/20190621/i5054.html)
[[Project]](https://jixinya.github.io/projects/EAMM/) [[Paper]](https://arxiv.org/abs/2205.15278)

Given a single portrait image, we can synthesize emotional talking faces, where mouth movements match the input audio and facial emotion dynamics follow the emotion source video.
## Installation
We train and test with Python 3.6 and PyTorch. To install the dependencies, run:
```
pip install -r requirements.txt
```
## Testing
- Download the pre-trained models and data under the following link: [google-drive](https://drive.google.com/file/d/1IL9LjH3JegyMqJABqMxrX3StAq_v8Gtp/view?usp=sharing) and put the file in corresponding places.
- Run the demo:
`python demo.py --source_image path/to/image --driving_video path/to/emotion_video --pose_file path/to/pose --in_file path/to/audio --emotion emotion_type`
- Prepare testing data:
prepare source_image -- crop_image in process_data.py
prepare driving_video -- crop_image_tem in process_data.py
prepare pose -- detect pose using [3DDFA_V2](https://github.com/cleardusk/3DDFA_V2)
## Training
- Training data structure:
```
./data/<dataset_name>
├──fomm_crop
│ ├──id/file_name # cropped images
│ │ ├──0.png
│ │ ├──...
├──fomm_pose_crop
│ ├──id
│ │ ├──file_name.npy # pose of the cropped images
│ │ ├──...
├──MFCC
│ ├──id
│ │ ├──file_name.npy # MFCC of the audio
│ │ ├──...
*The cropped images are generated by 'crop_image_tem' in process_data.py
*The pose of the cropped video are generated by 3DDFA_V2/demo.py
*The MFCC of the audio are generated by 'audio2mfcc' in process_data.py
```
- Step 1 : Train the Audio2Facial-Dynamics Module using LRW dataset
`python run.py --config config/train_part1.yaml --mode train_part1 --checkpoint log/124_52000.pth.tar `
- Step 2 : Fine-tune the Audio2Facial-Dynamics Module after getting stable results from step1
`python run.py --config config/train_part1_fine_tune.yaml --mode train_part1_fine_tune --checkpoint log/124_52000.pth.tar --audio_chechpoint checkpoint/from/step_1`
- Step 3 : Train the Implicit Emotion Displacement Learner
`python run.py --config config/train_part2.yaml --mode train_part2 --checkpoint log/124_52000.pth.tar --audio_chechpoint checkpoint/from/step_2`
## Citation
```
@inproceedings{10.1145/3528233.3530745,
author = {Ji, Xinya and Zhou, Hang and Wang, Kaisiyuan and Wu, Qianyi and Wu, Wayne and Xu, Feng and Cao, Xun},
title = {EAMM: One-Shot Emotional Talking Face via Audio-Based Emotion-Aware Motion Model},
year = {2022},
isbn = {9781450393379},
url = {https://doi.org/10.1145/3528233.3530745},
doi = {10.1145/3528233.3530745},
booktitle = {ACM SIGGRAPH 2022 Conference Proceedings},
series = {SIGGRAPH '22}
}
```
================================================
FILE: augmentation.py
================================================
"""
Code from https://github.com/hassony2/torch_videovision
"""
import numbers
import math
import random
import numpy as np
import PIL
import cv2
from skimage.transform import resize, rotate, AffineTransform, warp
from skimage.util import pad
import torchvision
import warnings
from skimage import img_as_ubyte, img_as_float
def crop_clip(clip, min_h, min_w, h, w):
    """Crop every frame of a clip to the same (h, w) window.

    Args:
        clip: list of (H, W, C) numpy.ndarray frames or PIL.Images.
        min_h, min_w: top-left corner of the crop window.
        h, w: crop height and width.

    Returns:
        List of cropped frames.

    Raises:
        TypeError: if frames are neither ndarrays nor PIL images.
    """
    first = clip[0]
    if isinstance(first, np.ndarray):
        return [frame[min_h:min_h + h, min_w:min_w + w, :] for frame in clip]
    if isinstance(first, PIL.Image.Image):
        box = (min_w, min_h, min_w + w, min_h + h)
        return [frame.crop(box) for frame in clip]
    raise TypeError('Expected numpy.ndarray or PIL.Image' +
                    'but got list of {0}'.format(type(clip[0])))
def pad_clip(clip, h, w):
    """Edge-pad a clip so each frame is at least (h, w).

    Args:
        clip: (n, H, W, C) array-like of frames.
        h, w: minimum target height and width; dimensions already larger
            are left untouched.

    Returns:
        (n, max(h, H), max(w, W), C) padded array.
    """
    im_h, im_w = clip[0].shape[:2]
    pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
    pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
    # np.pad replaces skimage.util.pad (an alias of np.pad that was removed
    # in scikit-image 0.18); behavior is identical.
    return np.pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
def resize_clip(clip, size, interpolation='bilinear'):
    """Resize every frame of a clip.

    Args:
        clip: list of (h, w, c) numpy.ndarray frames or PIL.Images.
        size: a single number scales the shorter side to that value,
            keeping the aspect ratio; a (h, w) pair is used directly.
        interpolation: 'bilinear' or 'nearest'.

    Returns:
        List of resized frames, or `clip` unchanged when the shorter side
        already matches a scalar `size`.

    Raises:
        TypeError: if frames are neither ndarrays nor PIL images.
    """
    if isinstance(clip[0], np.ndarray):
        if isinstance(size, numbers.Number):
            im_h, im_w, im_c = clip[0].shape
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w
                                                   and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]  # (h, w) -> (w, h) for skimage resize
        scaled = [
            resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
                   mode='constant', anti_aliasing=True) for img in clip
        ]
    elif isinstance(clip[0], PIL.Image.Image):
        if isinstance(size, numbers.Number):
            im_w, im_h = clip[0].size
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w
                                                   and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]
        # NOTE(review): this mapping looks inverted — 'bilinear' selects
        # PIL NEAREST and anything else selects BILINEAR; confirm whether
        # this is intentional.
        if interpolation == 'bilinear':
            pil_inter = PIL.Image.NEAREST
        else:
            pil_inter = PIL.Image.BILINEAR
        scaled = [img.resize(size, pil_inter) for img in clip]
    else:
        raise TypeError('Expected numpy.ndarray or PIL.Image' +
                        'but got list of {0}'.format(type(clip[0])))
    return scaled
def get_resize_sizes(im_h, im_w, size):
    """Return (new_h, new_w) that scales the shorter side to `size` while
    preserving the aspect ratio."""
    if im_w < im_h:
        return int(size * im_h / im_w), size
    return size, int(size * im_w / im_h)
class RandomFlip(object):
    """Randomly reverse a clip in time and/or mirror frames horizontally.

    Each enabled flip fires independently with probability 0.5. The random
    draw happens before the flag check to keep RNG consumption identical
    whether or not a flip is enabled.
    """

    def __init__(self, time_flip=False, horizontal_flip=False):
        self.time_flip = time_flip
        self.horizontal_flip = horizontal_flip

    def __call__(self, clip):
        if random.random() < 0.5 and self.time_flip:
            # Reverse the frame order.
            return clip[::-1]
        if random.random() < 0.5 and self.horizontal_flip:
            # Mirror every frame left-right.
            return [np.fliplr(frame) for frame in clip]
        return clip
class RandomResize(object):
    """Resize a clip by a single random scale factor.

    Args:
        ratio (tuple): inclusive range the scale factor is drawn from.
        interpolation (str): 'nearest' or 'bilinear', forwarded to
            resize_clip.
    """

    def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
        self.ratio = ratio
        self.interpolation = interpolation

    def __call__(self, clip):
        factor = random.uniform(self.ratio[0], self.ratio[1])
        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size
        target = (int(im_w * factor), int(im_h * factor))
        return resize_clip(clip, target, interpolation=self.interpolation)
class RandomCrop(object):
    """Extract a random crop at the same location for a list of videos.

    Args:
        size (sequence or int): desired output size (h, w); an int yields a
            square crop.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (size, size)
        self.size = size

    def __call__(self, clip):
        """
        Args:
            clip: list of (h, w, c) numpy.ndarray frames or PIL.Images.
        Returns:
            Cropped list of frames; frames smaller than the target are
            edge-padded first.
        Raises:
            TypeError: on unsupported frame types.
        """
        h, w = self.size
        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image' +
                            'but got list of {0}'.format(type(clip[0])))
        clip = pad_clip(clip, h, w)
        im_h, im_w = clip.shape[1:3]
        # Fix: the guard conditions were swapped (the x offset was gated on
        # the height match and the y offset on the width match), which
        # silently disabled the random offset along the wrong axis.
        x1 = 0 if w == im_w else random.randint(0, im_w - w)
        y1 = 0 if h == im_h else random.randint(0, im_h - h)
        cropped = crop_clip(clip, y1, x1, h, w)
        return cropped
class MouthCrop(object):
    """Overwrite a fixed rectangular (mouth) region of every frame with
    uniform random noise.

    Args:
        center_x, center_y: center of the masked rectangle in pixels.
        mask_width, mask_height: rectangle extents in pixels.
    """

    def __init__(self, center_x, center_y, mask_width, mask_height):
        self.center_x = center_x
        self.center_y = center_y
        self.mask_width = mask_width
        self.mask_height = mask_height

    def __call__(self, clip):
        """
        Args:
            clip: list of (h, w, c) numpy.ndarray frames.
        Returns:
            numpy.ndarray of the masked frames; the input frames are left
            untouched (each one is copied before masking).
        """
        x0 = self.center_x - int(self.mask_width / 2)
        y0 = self.center_y - int(self.mask_height / 2)
        x1 = x0 + self.mask_width
        y1 = y0 + self.mask_height
        masked = []
        for frame in clip:
            # Fresh uniform noise in [0, 1) per frame.
            noise = np.random.rand(self.mask_height, self.mask_width, 3)
            out = frame.copy()
            out[y0:y1, x0:x1, :] = noise
            masked.append(out)
        return np.array(masked)
class RandomRotation(object):
    """Rotate an entire clip by one random angle drawn from given bounds.

    Args:
        degrees (sequence or int): range of degrees to sample from. A single
            number d expands to (-d, +d) and must be non-negative.
    """

    def __init__(self, degrees):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError('If degrees is a single number,'
                                 'must be positive')
            degrees = (-degrees, degrees)
        elif len(degrees) != 2:
            raise ValueError('If degrees is a sequence,'
                             'it must be of len 2.')
        self.degrees = degrees

    def __call__(self, clip):
        """
        Args:
            clip: list of (h, w, c) numpy.ndarray frames or PIL.Images.
        Returns:
            List of frames, all rotated by the same sampled angle.
        Raises:
            TypeError: on unsupported frame types.
        """
        angle = random.uniform(*self.degrees)
        first = clip[0]
        if isinstance(first, np.ndarray):
            return [rotate(image=frame, angle=angle, preserve_range=True) for frame in clip]
        if isinstance(first, PIL.Image.Image):
            return [frame.rotate(angle) for frame in clip]
        raise TypeError('Expected numpy.ndarray or PIL.Image' +
                        'but got list of {0}'.format(type(clip[0])))
class RandomPerspective(object):
    """Apply an independent random perspective warp to every frame.

    Args:
        pers_num: exclusive upper bound for the random perspective offset;
            the magnitude is sampled from [20, pers_num) with a random sign.
        enlarge_num: exclusive upper bound for the random enlargement
            offset; sampled the same way.
    """
    def __init__(self, pers_num, enlarge_num):
        self.pers_num = pers_num
        self.enlarge_num = enlarge_num

    def __call__(self, clip):
        """
        Args:
            clip: list of (h, w, c) numpy.ndarray frames.
        Returns:
            The same list object with each frame replaced by its warped
            version (output size fixed at 256x256).
        """
        out = clip
        for i in range(len(clip)):
            # Fresh random magnitudes (with random sign) per frame.
            self.pers_size = np.random.randint(20, self.pers_num) * pow(-1, np.random.randint(2))
            self.enlarge_size = np.random.randint(20, self.enlarge_num) * pow(-1, np.random.randint(2))
            h, w, c = clip[i].shape
            crop_size = 256
            # NOTE(review): `h` appears in the x coordinate and `w` in the y
            # coordinate of these corner lists — confirm this is intended
            # (it is symmetric for the square 256x256 frames used here).
            dst = np.array([
                [-self.enlarge_size, -self.enlarge_size],
                [-self.enlarge_size + self.pers_size, w + self.enlarge_size],
                [h + self.enlarge_size, -self.enlarge_size],
                [h + self.enlarge_size - self.pers_size, w + self.enlarge_size], ], dtype=np.float32)
            src = np.array([[-self.enlarge_size, -self.enlarge_size], [-self.enlarge_size, w + self.enlarge_size],
                            [h + self.enlarge_size, -self.enlarge_size], [h + self.enlarge_size, w + self.enlarge_size]]).astype(np.float32())
            M = cv2.getPerspectiveTransform(src, dst)
            warped = cv2.warpPerspective(clip[i], M, (crop_size, crop_size), borderMode=cv2.BORDER_REPLICATE)
            out[i] = warped
        return out
class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation and hue of the clip
    Args:
        brightness (float): How much to jitter brightness. brightness_factor
        is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
        is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
        is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue(float): How much to jitter hue. hue_factor is chosen uniformly from
        [-hue, hue]. Should be >=0 and <= 0.5.
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def get_params(self, brightness, contrast, saturation, hue):
        """Sample one jitter factor per property; None disables a property."""
        if brightness > 0:
            brightness_factor = random.uniform(
                max(0, 1 - brightness), 1 + brightness)
        else:
            brightness_factor = None
        if contrast > 0:
            contrast_factor = random.uniform(
                max(0, 1 - contrast), 1 + contrast)
        else:
            contrast_factor = None
        if saturation > 0:
            saturation_factor = random.uniform(
                max(0, 1 - saturation), 1 + saturation)
        else:
            saturation_factor = None
        if hue > 0:
            hue_factor = random.uniform(-hue, hue)
        else:
            hue_factor = None
        return brightness_factor, contrast_factor, saturation_factor, hue_factor

    def _build_transforms(self, brightness, saturation, hue, contrast):
        """Build the randomly ordered list of per-image jitter callables.

        Extracted from the two previously duplicated copies in __call__;
        append order and random.shuffle usage are unchanged.
        """
        img_transforms = []
        if brightness is not None:
            img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
        if saturation is not None:
            img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
        if hue is not None:
            img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
        if contrast is not None:
            img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
        random.shuffle(img_transforms)
        return img_transforms

    def __call__(self, clip):
        """
        Args:
            clip (list): list of PIL.Image or numpy.ndarray frames.
        Returns:
            list: the jittered frames.
        Raises:
            TypeError: on unsupported frame types.
        """
        if isinstance(clip[0], np.ndarray):
            brightness, contrast, saturation, hue = self.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
            # ndarray frames are round-tripped through PIL for the jitter ops.
            img_transforms = self._build_transforms(brightness, saturation, hue, contrast)
            img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
                                                                                                     img_as_float]
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                jittered_clip = []
                for img in clip:
                    jittered_img = img
                    for func in img_transforms:
                        jittered_img = func(jittered_img)
                    jittered_clip.append(jittered_img.astype('float32'))
        elif isinstance(clip[0], PIL.Image.Image):
            brightness, contrast, saturation, hue = self.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
            img_transforms = self._build_transforms(brightness, saturation, hue, contrast)
            jittered_clip = []
            for img in clip:
                # Fix: transforms were previously each applied to the
                # original `img`, so only the last one in the shuffled list
                # took effect; chain them instead.
                jittered_img = img
                for func in img_transforms:
                    jittered_img = func(jittered_img)
                jittered_clip.append(jittered_img)
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image' +
                            'but got list of {0}'.format(type(clip[0])))
        return jittered_clip
class AllAugmentationTransform:
    """Compose the configured clip augmentations and apply them in a fixed order.

    Each ``*_param`` argument is either ``None`` (transform disabled) or a dict
    of keyword arguments for the corresponding transform class. The application
    order is fixed: mouth crop, flip, rotation, perspective, resize, crop,
    color jitter.
    """

    def __init__(self, crop_mouth_param=None, resize_param=None, rotation_param=None,
                 perspective_param=None, flip_param=None, crop_param=None, jitter_param=None):
        self.transforms = []

        def add(params, build):
            # Instantiate lazily so a transform class is only referenced when
            # its config section is actually present.
            if params is not None:
                self.transforms.append(build(params))

        add(crop_mouth_param, lambda p: MouthCrop(**p))
        add(flip_param, lambda p: RandomFlip(**p))
        add(rotation_param, lambda p: RandomRotation(**p))
        add(perspective_param, lambda p: RandomPerspective(**p))
        add(resize_param, lambda p: RandomResize(**p))
        add(crop_param, lambda p: RandomCrop(**p))
        add(jitter_param, lambda p: ColorJitter(**p))

    def __call__(self, clip):
        """Run the clip through every enabled transform, in order."""
        for transform in self.transforms:
            clip = transform(clip)
        return clip
================================================
FILE: config/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml
================================================
dataset_params:
root_dir: /mnt/lustre/share_data/jixinya/MEAD/
frame_shape: [256, 256, 3]
id_sampling: False
pairs_list: Random_choice
augmentation_params:
crop_mouth_param:
center_x: 135
center_y: 190
mask_width: 100
mask_height: 60
rotation_param:
degrees: 30
perspective_param:
pers_num: 30
enlarge_num: 40
flip_param:
horizontal_flip: True
time_flip: False
jitter_param:
brightness: 0
contrast: 0
saturation: 0
hue: 0
model_params:
common_params:
num_kp: 10
num_channels: 3
estimate_jacobian: True
audio_params:
num_kp: 10
num_channels : 3
num_channels_a : 3
estimate_jacobian: True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
train_params:
type: linear_4
smooth: False
jaco_net: cnn
ldmark: fake
generator: not
train_generator: False
num_epochs: 300
num_repeats: 1
epoch_milestones: [60, 90]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
lr_audio_feature: 2.0e-4
batch_size: 16
scales: [1, 0.5, 0.25, 0.125]
checkpoint_freq: 1
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 0
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 0
equivariance_jacobian: 0
emo: 10
reconstruction_params:
num_videos: 1000
format: '.mp4'
animate_params:
num_pairs: 50
format: '.mp4'
normalization_params:
adapt_movement_scale: False
use_relative_movement: True
use_relative_jacobian: True
visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'
================================================
FILE: config/train_part1.yaml
================================================
dataset_params:
name: Vox
root_dir: dataset/LRW/
frame_shape: [256, 256, 3]
id_sampling: False
augmentation_params:
flip_param:
horizontal_flip: False
time_flip: False
jitter_param:
brightness: 0.1
contrast: 0.1
saturation: 0.1
hue: 0.1
model_params:
common_params:
num_kp: 10
num_channels: 3
estimate_jacobian: True
audio_params:
num_kp: 10
num_channels : 3
num_channels_a : 3
estimate_jacobian: True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
train_params:
jaco_net: cnn
ldmark: fake
generator: not
num_epochs: 300
num_repeats: 1
epoch_milestones: [60, 90]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
lr_audio_feature: 2.0e-4
batch_size: 8
scales: [1, 0.5, 0.25, 0.125]
checkpoint_freq: 1
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 0
discriminator_gan: 0
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 0
equivariance_jacobian: 0
audio: 10
visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'
================================================
FILE: config/train_part1_fine_tune.yaml
================================================
dataset_params:
name: LRW
root_dir: dataset/LRW/
frame_shape: [256, 256, 3]
id_sampling: False
augmentation_params:
flip_param:
horizontal_flip: False
time_flip: False
jitter_param:
brightness: 0.1
contrast: 0.1
saturation: 0.1
hue: 0.1
model_params:
common_params:
num_kp: 10
num_channels: 3
estimate_jacobian: True
audio_params:
num_kp: 10
num_channels : 3
num_channels_a : 3
estimate_jacobian: True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
train_params:
jaco_net: cnn
ldmark: fake
generator: audio
num_epochs: 300
num_repeats: 1
epoch_milestones: [60, 90]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
lr_audio_feature: 2.0e-4
batch_size: 6
scales: [1, 0.5, 0.25, 0.125]
checkpoint_freq: 1
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 0
discriminator_gan: 0
feature_matching: [10, 10, 10, 10]
perceptual: [0.1, 0.1, 0.1, 0.1, 0.1]
equivariance_value: 0
equivariance_jacobian: 0
audio: 10
visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'
================================================
FILE: config/train_part2.yaml
================================================
dataset_params:
name: MEAD
root_dir: dataset/MEAD/
frame_shape: [256, 256, 3]
id_sampling: False
augmentation_params:
crop_mouth_param:
center_x: 135
center_y: 190
mask_width: 100
mask_height: 60
rotation_param:
degrees: 30
perspective_param:
pers_num: 30
enlarge_num: 40
flip_param:
horizontal_flip: True
time_flip: False
jitter_param:
brightness: 0
contrast: 0
saturation: 0
hue: 0
model_params:
common_params:
num_kp: 10
num_channels: 3
estimate_jacobian: True
audio_params:
num_kp: 10
num_channels : 3
num_channels_a : 3
estimate_jacobian: True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
train_params:
type: linear_4
smooth: False
jaco_net: cnn
ldmark: fake
generator: not
num_epochs: 300
num_repeats: 1
epoch_milestones: [60, 90]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
lr_audio_feature: 2.0e-4
batch_size: 16
scales: [1, 0.5, 0.25, 0.125]
checkpoint_freq: 1
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 0
discriminator_gan: 0
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 0
equivariance_jacobian: 0
emo: 10
visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'
================================================
FILE: demo.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Oct 6 20:57:27 2021
@author: thea
"""
import matplotlib
matplotlib.use('Agg')
import os,sys
import yaml
from argparse import ArgumentParser
from tqdm import tqdm
from skimage import io, img_as_float32
import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from filter1 import OneEuroFilter
import torch.utils
from torch.autograd import Variable
from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector, KPDetector_a
from modules.util import AT_net, Emotion_k, Emotion_map, AT_net2
from augmentation import AllAugmentationTransform
from scipy.spatial import ConvexHull
import python_speech_features
from pathlib import Path
import dlib
import cv2
import librosa
from skimage import transform as tf
#from audiolm.models import AT_emoiton
#from audiolm.utils import plot_flmarks
if sys.version_info[0] < 3:
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.6")
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor('./shape_predictor_68_face_landmarks.dat')
def load_checkpoints(opt, checkpoint_path, audio_checkpoint_path, emo_checkpoint_path, cpu=False):
    """Build all networks from opt.config and restore their weights.

    Three checkpoint files are loaded: `checkpoint_path` (generator +
    kp_detector), `audio_checkpoint_path` (audio_feature + kp_detector_a) and
    `emo_checkpoint_path` (emo_detector). All models are set to eval mode.

    Returns (generator, kp_detector, kp_detector_a, audio_feature, emo_detector).
    """
    with open(opt.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    if not cpu:
        generator.cuda()
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not cpu:
        kp_detector.cuda()
    kp_detector_a = KPDetector_a(**config['model_params']['kp_detector_params'],
                                 **config['model_params']['audio_params'])
    audio_feature = AT_net2()
    # The emotion head variant is selected by the --type prefix.
    # NOTE(review): if opt.type starts with neither 'linear' nor 'map',
    # emo_detector is never assigned and the .cuda()/load below raises NameError.
    if opt.type.startswith('linear'):
        emo_detector = Emotion_k(block_expansion=32, num_channels=3, max_features=1024,
                                 num_blocks=5, scale_factor=0.25, num_classes=8)
    elif opt.type.startswith('map'):
        emo_detector = Emotion_map(block_expansion=32, num_channels=3, max_features=1024,
                                   num_blocks=5, scale_factor=0.25, num_classes=8)
    if not cpu:
        kp_detector_a.cuda()
        audio_feature.cuda()
        emo_detector.cuda()

    # On CPU, remap checkpoint tensors so CUDA-saved weights can be loaded.
    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        audio_checkpoint = torch.load(audio_checkpoint_path, map_location=torch.device('cpu'))
        emo_checkpoint = torch.load(emo_checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)
        audio_checkpoint = torch.load(audio_checkpoint_path)
        emo_checkpoint = torch.load(emo_checkpoint_path)

    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])
    audio_feature.load_state_dict(audio_checkpoint['audio_feature'])
    kp_detector_a.load_state_dict(audio_checkpoint['kp_detector_a'])
    emo_detector.load_state_dict(emo_checkpoint['emo_detector'])

    # Second .cuda() pass is redundant with the calls above but harmless.
    if not cpu:
        generator = generator.cuda()
        kp_detector = kp_detector.cuda()
        audio_feature = audio_feature.cuda()
        kp_detector_a = kp_detector_a.cuda()
        emo_detector = emo_detector.cuda()

    generator.eval()
    kp_detector.eval()
    audio_feature.eval()
    kp_detector_a.eval()
    emo_detector.eval()

    return generator, kp_detector, kp_detector_a, audio_feature, emo_detector
def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
                 use_relative_movement=False, use_relative_jacobian=False):
    """Re-express driving keypoints relative to the source frame.

    With `use_relative_movement`, the driving motion (offset from the first
    driving frame, optionally rescaled by the ratio of convex-hull areas) is
    added to the source keypoints; with `use_relative_jacobian`, the driving
    jacobian change is composed onto the source jacobian. Returns a new dict;
    the inputs are not modified.
    """
    if adapt_movement_scale:
        # Scale motion by the ratio of keypoint convex-hull areas (source vs.
        # initial driving frame), so a larger face moves proportionally more.
        source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
        driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
        scale = np.sqrt(source_area) / np.sqrt(driving_area)
    else:
        scale = 1

    kp_new = dict(kp_driving)
    if use_relative_movement:
        kp_new['value'] = (kp_driving['value'] - kp_driving_initial['value']) * scale + kp_source['value']
        if use_relative_jacobian:
            relative_jac = torch.matmul(kp_driving['jacobian'],
                                        torch.inverse(kp_driving_initial['jacobian']))
            kp_new['jacobian'] = torch.matmul(relative_jac, kp_source['jacobian'])
    return kp_new
def shape_to_np(shape, dtype="int"):
    """Convert a dlib landmark detection into an (N, 2) array of (x, y) points.

    `shape` must expose `num_parts` and `part(i)` (each part having .x/.y),
    as dlib's full_object_detection does.
    """
    coords = np.zeros((shape.num_parts, 2), dtype=dtype)
    for idx in range(shape.num_parts):
        point = shape.part(idx)
        coords[idx] = (point.x, point.y)
    return coords
def get_aligned_image(driving_video, opt):
    """Align every frame of the driving video to the pose of its first frame.

    Landmarks are detected with the module-level dlib `detector`/`predictor`.
    The first frame's landmarks serve as the template; each frame is warped
    with a similarity transform estimated on the first 35 landmark points
    (eyes and nose region). Returns a list of float32 (256, 256) frames.

    NOTE(review): frames where dlib finds no face are silently dropped from
    the output, so the result can be shorter than the input.
    """
    aligned_array = []
    video_array = np.array(driving_video)
    source_image = video_array[0]
    # Frames arrive as floats in [0, 1]; dlib/cv2 expect uint8.
    source_image = np.array(source_image * 255, dtype=np.uint8)
    gray = cv2.cvtColor(source_image, cv2.COLOR_BGR2GRAY)
    rects = detector(gray, 1)  # detect human face
    for (i, rect) in enumerate(rects):
        template = predictor(gray, rect)  # detect 68 landmark points
        template = shape_to_np(template)
    # Shift the template upward for wide-eyed emotions so the warped face
    # leaves room for raised brows/open mouth.
    if opt.emotion == 'surprised' or opt.emotion == 'fear':
        template = template - [0, 10]
    for i in range(len(video_array)):
        image = np.array(video_array[i] * 255, dtype=np.uint8)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        rects = detector(gray, 1)  # detect human face
        for (j, rect) in enumerate(rects):
            shape = predictor(gray, rect)  # detect 68 landmark points
            shape = shape_to_np(shape)
            # Only the first 35 points (face contour + brows + nose bridge)
            # drive the alignment; the mouth is excluded on purpose.
            pts2 = np.float32(template[:35, :])
            pts1 = np.float32(shape[:35, :])
            tform = tf.SimilarityTransform()
            tform.estimate(pts2, pts1)  # set the transformation matrix
            dst = tf.warp(image, tform, output_shape=(256, 256))
            # tf.warp returns floats in [0, 1].
            dst = np.array(dst, dtype=np.float32)
            aligned_array.append(dst)
    return aligned_array
def get_transformed_image(driving_video, opt):
    """Run the driving video through the augmentation pipeline from opt.config.

    Reads `dataset_params.augmentation_params` out of the YAML config and
    applies the resulting AllAugmentationTransform to the whole clip.
    """
    frames = np.array(driving_video)
    with open(opt.config) as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    augment = AllAugmentationTransform(**config['dataset_params']['augmentation_params'])
    return augment(frames)
def make_animation_smooth(source_image, driving_video, transformed_video, deco_out, kp_loss, generator, kp_detector, kp_detector_a, emo_detector, opt, relative=True, adapt_movement_scale=True, cpu=False):
    """Render the animation, temporally smoothing keypoints with One-Euro filters.

    Pass 1 collects audio-driven keypoints (`kp_detector_a` over `deco_out`)
    and, when `opt.add_emo` is set, emotion displacements from `emo_detector`.
    Both tracks are then smoothed frame-by-frame with One-Euro filters.
    Pass 2 adds the emotion displacement onto selected keypoints, normalizes
    against the source, and runs the generator per frame.

    Returns (predictions, features): the rendered frames (H, W, C in [0, 1])
    and the stacked emo-detector feature maps (empty array unless opt.add_emo).
    """
    with torch.no_grad():
        predictions = []
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        if not cpu:
            source = source.cuda()
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        transformed_driving = torch.tensor(np.array(transformed_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        kp_source = kp_detector(source)
        # Keypoints of the first audio frame anchor the relative-motion transfer.
        kp_driving_initial = kp_detector_a(deco_out[:, 0])
        emo_driving_all = []
        features = []
        kp_driving_all = []
        # Pass 1: collect per-frame keypoints (and emotion offsets if enabled).
        for frame_idx in tqdm(range(len(deco_out[0]))):
            driving_frame = driving[:, :, frame_idx]
            transformed_frame = transformed_driving[:, :, frame_idx]
            if not cpu:
                driving_frame = driving_frame.cuda()
                transformed_frame = transformed_frame.cuda()
            kp_driving = kp_detector_a(deco_out[:, frame_idx])
            kp_driving_all.append(kp_driving)
            if opt.add_emo:
                value = kp_driving['value']
                jacobian = kp_driving['jacobian']
                # NOTE(review): emo_driving is only assigned for
                # opt.type == 'linear_3'; other types would hit a NameError here.
                if opt.type == 'linear_3':
                    emo_driving, _ = emo_detector(transformed_frame, value, jacobian)
                features.append(emo_detector.feature(transformed_frame).data.cpu().numpy())
                emo_driving_all.append(emo_driving)
        features = np.array(features)
        if opt.add_emo:
            # Smooth the emotion displacements; values are scaled up by 100
            # before filtering and back down after (filter works per-sample).
            one_euro_filter_v = OneEuroFilter(mincutoff=1, beta=0.2, dcutoff=1.0, freq=100)  # 1 0.4
            one_euro_filter_j = OneEuroFilter(mincutoff=1, beta=0.2, dcutoff=1.0, freq=100)  # 1 0.4
            for j in range(len(emo_driving_all)):
                emo_driving_all[j]['value'] = one_euro_filter_v.process(emo_driving_all[j]['value'].cpu() * 100) / 100
                emo_driving_all[j]['value'] = emo_driving_all[j]['value'].cuda()
                emo_driving_all[j]['jacobian'] = one_euro_filter_j.process(emo_driving_all[j]['jacobian'].cpu() * 100) / 100
                emo_driving_all[j]['jacobian'] = emo_driving_all[j]['jacobian'].cuda()
        # Smooth the audio-driven keypoint track (scaled by 10 for filtering).
        one_euro_filter_v = OneEuroFilter(mincutoff=0.05, beta=8, dcutoff=1.0, freq=100)
        one_euro_filter_j = OneEuroFilter(mincutoff=0.05, beta=8, dcutoff=1.0, freq=100)
        for j in range(len(kp_driving_all)):
            kp_driving_all[j]['value'] = one_euro_filter_v.process(kp_driving_all[j]['value'].cpu() * 10) / 10
            kp_driving_all[j]['value'] = kp_driving_all[j]['value'].cuda()
            kp_driving_all[j]['jacobian'] = one_euro_filter_j.process(kp_driving_all[j]['jacobian'].cpu() * 10) / 10
            kp_driving_all[j]['jacobian'] = kp_driving_all[j]['jacobian'].cuda()
        # Pass 2: compose keypoints and render each frame.
        for frame_idx in tqdm(range(len(deco_out[0]))):
            if opt.check_add:
                # Debug mode: freeze the mouth keypoints at frame 0 so only
                # the emotion displacement is visible.
                kp_driving = kp_detector_a(deco_out[:, 0])
            else:
                kp_driving = kp_driving_all[frame_idx]
            if opt.add_emo:
                emo_driving = emo_driving_all[frame_idx]
                if opt.type == 'linear_3':
                    # Add the emotion offsets onto keypoints 1, 4 and 6;
                    # keypoint 1 is attenuated by 0.2.
                    kp_driving['value'][:, 1] = kp_driving['value'][:, 1] + emo_driving['value'][:, 0] * 0.2
                    kp_driving['jacobian'][:, 1] = kp_driving['jacobian'][:, 1] + emo_driving['jacobian'][:, 0] * 0.2
                    kp_driving['value'][:, 4] = kp_driving['value'][:, 4] + emo_driving['value'][:, 1]
                    kp_driving['jacobian'][:, 4] = kp_driving['jacobian'][:, 4] + emo_driving['jacobian'][:, 1]
                    kp_driving['value'][:, 6] = kp_driving['value'][:, 6] + emo_driving['value'][:, 2]
                    kp_driving['jacobian'][:, 6] = kp_driving['jacobian'][:, 6] + emo_driving['jacobian'][:, 2]
            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions, features
def test_auido(example_image, audio_feature, all_pose, opt):
    """Run the audio network: MFCC windows + pose track -> decoded motion.

    Extracts overlapping MFCC windows from opt.in_file (resampled to 16 kHz,
    padded with 0.12 s of silence on both ends), optionally smooths the pose
    track, length-matches pose to the MFCC sequence, and feeds everything to
    `audio_feature`. Returns the decoder output tensor.

    NOTE(review): despite the opt.cpu checks below, t_mfcc is always moved
    to CUDA, so this function cannot actually run in CPU mode as written.
    """
    with open(opt.config) as f:
        para = yaml.load(f, Loader=yaml.FullLoader)
    if not opt.cpu:
        audio_feature = audio_feature.cuda()
    audio_feature.eval()
    test_file = opt.in_file
    # Only the first 6 of the 7 pose components are used.
    pose = all_pose[:, :6]
    if len(pose) == 1:
        # Static pose: replicate the single row.
        pose = np.repeat(pose, 100, 0)
    elif opt.smooth_pose:
        one_euro_filter = OneEuroFilter(mincutoff=0.004, beta=0.7, dcutoff=1.0, freq=100)
        for j in range(len(pose)):
            pose[j] = one_euro_filter.process(pose[j])
    # HWC float image -> CHW.
    example_image = np.array(example_image, dtype='float32').transpose((2, 0, 1))
    speech, sr = librosa.load(test_file, sr=16000)
    # Pad 1920 samples (0.12 s at 16 kHz) of silence front and back so the
    # sliding MFCC windows cover the clip edges.
    speech = np.insert(speech, 0, np.zeros(1920))
    speech = np.append(speech, np.zeros(1920))
    mfcc = python_speech_features.mfcc(speech, 16000, winstep=0.01)
    print('=======================================')
    print('Start to generate images')
    ind = 3
    with torch.no_grad():
        fake_lmark = []
        input_mfcc = []
        # One 28-row MFCC window (rows (ind-3)*4 .. (ind+4)*4, first
        # coefficient dropped) per output frame.
        while ind <= int(mfcc.shape[0] / 4) - 4:
            t_mfcc = mfcc[(ind - 3) * 4: (ind + 4) * 4, 1:]
            t_mfcc = torch.FloatTensor(t_mfcc).cuda()
            input_mfcc.append(t_mfcc)
            ind += 1
        input_mfcc = torch.stack(input_mfcc, dim=0)
        # Length-match the pose track to the MFCC sequence: ping-pong + tile
        # when too short, truncate when too long.
        if (len(pose) < len(input_mfcc)):
            gap = len(input_mfcc) - len(pose)
            n = int((gap / len(pose) / 2)) + 2
            pose = np.concatenate((pose, pose[::-1, :]), axis=0)
            pose = np.tile(pose, (n, 1))
        if (len(pose) > len(input_mfcc)):
            pose = pose[:len(input_mfcc), :]
        if not opt.cpu:
            example_image = Variable(torch.FloatTensor(example_image.astype(float))).cuda()
            example_image = torch.unsqueeze(example_image, 0)
            pose = Variable(torch.FloatTensor(pose.astype(float))).cuda()
            pose = pose.unsqueeze(0)
        input_mfcc = input_mfcc.unsqueeze(0)
        # 1.6 is a fixed scale passed to the audio net; jaco_net selects the
        # jacobian sub-network named in the config.
        deco_out = audio_feature(example_image, input_mfcc, pose, para['train_params']['jaco_net'], 1.6)
    return deco_out
def save(path, frames, format):
    """Write frames to `path` as individual PNG files named 0.png, 1.png, ...

    Parameters
    ----------
    path : str
        Output directory; created (including parents) if missing.
    frames : iterable of ndarray
        Frames to write, one file per frame, named by index.
    format : str
        Only '.png' is supported; any other value prints a message and
        terminates the program (original behavior preserved).
    """
    if format == '.png':
        # exist_ok avoids the check-then-create race of the old
        # os.path.exists() + os.makedirs() pair.
        os.makedirs(path, exist_ok=True)
        for j, frame in enumerate(frames):
            # os.path.join instead of manual '/' concatenation.
            imageio.imsave(os.path.join(path, str(j) + '.png'), frame)
    else:
        print("Unknown format %s" % format)
        sys.exit()
class VideoWriter(object):
    """Thin wrapper around cv2.VideoWriter using the XVID codec."""

    def __init__(self, path, width, height, fps):
        self.path = path
        codec = cv2.VideoWriter_fourcc(*'XVID')
        self.out = cv2.VideoWriter(self.path, codec, fps, (width, height))

    def write_frame(self, frame):
        """Append one frame to the video file."""
        self.out.write(frame)

    def end(self):
        """Finalize and close the underlying writer."""
        self.out.release()
def concatenate(number, imgs, save_path, fps=25):
    """Tile `number` equally-long frame sequences side by side and save a video.

    Replaces the old hand-written branches for 2/3/4/5 sequences (which raised
    NameError for any other count) with a single generic implementation that
    works for any number >= 1.

    Parameters
    ----------
    number : int
        How many sequences are stacked along the first axis of `imgs`.
    imgs : ndarray
        Frames of shape (number * T, H, W, 3); consecutive chunks of T frames
        belong to one sequence.
    save_path : str
        Output video path.
    fps : int, optional
        Frame rate of the saved video (default 25, matching the old
        hard-coded value).
    """
    width, height = imgs.shape[-3:-1]
    sequences = imgs.reshape(number, -1, width, height, 3)
    # Frame i of the output is the horizontal concatenation of frame i of
    # every sequence, left to right.
    im_all = [np.concatenate([seq[i] for seq in sequences], axis=1)
              for i in range(sequences.shape[1])]
    imageio.mimsave(save_path, [img_as_ubyte(frame) for frame in im_all], fps=fps)
def add_audio(video_name=None, audio_dir=None):
    """Mux the audio track of `audio_dir` into `video_name` via ffmpeg.

    Both streams are copied (no re-encode); the result is written next to the
    input with the extension changed from .mp4 to .mov.
    NOTE(review): the command goes through the shell unquoted, so paths with
    spaces or shell metacharacters will break it.
    """
    output_name = video_name.replace('.mp4', '.mov')
    command = 'ffmpeg -i ' + video_name + ' -i ' + audio_dir + ' -vcodec copy -acodec copy -y ' + output_name
    print(command)
    os.system(command)
def crop_image(source_image):
    """Crop/align the face in an image file to the M003 landmark template.

    Detects 68 landmarks with the module-level dlib `detector`/`predictor`,
    estimates a similarity transform from the template's first 47 points to
    the detected ones, and warps the image to 256x256.

    Returns the aligned uint8 image, or the int 0 when exactly one face is
    not found (callers must check for this sentinel).
    """
    template = np.load('./M003_template.npy')
    image = cv2.imread(source_image)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    rects = detector(gray, 1)  # detect human face
    # Require exactly one face; 0 is returned as a failure sentinel.
    if len(rects) != 1:
        return 0
    for (j, rect) in enumerate(rects):
        shape = predictor(gray, rect)  # detect 68 landmark points
        shape = shape_to_np(shape)
        # Align on the first 47 points (contour, brows, nose, eyes).
        pts2 = np.float32(template[:47, :])
        pts1 = np.float32(shape[:47, :])
        tform = tf.SimilarityTransform()
        tform.estimate(pts2, pts1)  # set the transformation matrix
        dst = tf.warp(image, tform, output_shape=(256, 256))
        # tf.warp returns floats in [0, 1]; rescale to uint8.
        dst = np.array(dst * 255, dtype=np.uint8)
    return dst
def smooth_pose(pose_file, pose_long):
    """Transfer the motion of a long pose track onto a single anchor pose.

    `pose_file` holds one pose row (shape (1, D)); `pose_long` holds a (T, D)
    track. Each frame's offset from the track's first frame is added to the
    anchor, yielding a (T, D) array that starts at the anchor pose and follows
    the track's motion.
    """
    anchor = np.load(pose_file)
    track = np.load(pose_long)
    offsets = track - track[0, :]
    print(len(offsets))
    tiled = np.repeat(anchor, len(offsets), axis=0)
    return tiled + offsets
def test(opt, name):
    """End-to-end demo: render neutral and emotional talking-head videos.

    Pipeline: load/derive the pose track, crop-align the source image, read
    and align the driving video, load all checkpoints, run the audio network,
    then render twice via make_animation_smooth (opt.add_emo False then True).
    Writes neutral.mp4, emotion.mp4 (+ audio-muxed .mov) and a side-by-side
    all.mp4 into opt.result_path.

    Note: the `name` parameter is currently unused.
    """
    all_pose = np.load(opt.pose_file).reshape(-1, 7)
    if opt.pose_long:
        # Replace the static pose with a motion track borrowed from pose_given.
        all_pose = smooth_pose(opt.pose_file, opt.pose_given)
    source_image = img_as_float32(crop_image(opt.source_image))
    source_image = resize(source_image, (256, 256))[..., :3]
    reader = imageio.get_reader(opt.driving_video)
    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        # Truncated/corrupt videos: keep whatever frames were read.
        pass
    reader.close()
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
    driving_video = get_aligned_image(driving_video, opt)
    transformed_video = get_transformed_image(driving_video, opt)
    transformed_video = np.array(transformed_video)
    generator, kp_detector, kp_detector_a, audio_feature, emo_detector = load_checkpoints(opt=opt, checkpoint_path=opt.checkpoint, audio_checkpoint_path=opt.audio_checkpoint, emo_checkpoint_path=opt.emo_checkpoint, cpu=opt.cpu)
    deco_out = test_auido(source_image, audio_feature, all_pose, opt)
    # Length-match the video to the audio-decoded sequence: np.resize repeats
    # frames cyclically when too short, slicing truncates when too long.
    if len(driving_video) < len(deco_out[0]):
        driving_video = np.resize(driving_video, (len(deco_out[0]), 256, 256, 3))
        transformed_video = np.resize(transformed_video, (len(deco_out[0]), 256, 256, 3))
    else:
        # NOTE(review): only driving_video is truncated here;
        # transformed_video keeps its original length — verify intentional.
        driving_video = driving_video[:len(deco_out[0])]
    # First pass: mouth motion only.
    opt.add_emo = False
    predictions, _ = make_animation_smooth(source_image, driving_video, transformed_video, deco_out, opt.kp_loss, generator, kp_detector, kp_detector_a, emo_detector, opt, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
    imageio.mimsave(os.path.join(opt.result_path, 'neutral.mp4'), [img_as_ubyte(frame) for frame in predictions], fps=fps)
    predictions = np.array(predictions)
    # Second pass: with emotion displacement.
    opt.add_emo = True
    predictions1, _ = make_animation_smooth(source_image, driving_video, transformed_video, deco_out, opt.kp_loss, generator, kp_detector, kp_detector_a, emo_detector, opt, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
    imageio.mimsave(os.path.join(opt.result_path, 'emotion.mp4'), [img_as_ubyte(frame) for frame in predictions1], fps=fps)
    add_audio(os.path.join(opt.result_path, 'emotion.mp4'), opt.in_file)
    predictions1 = np.array(predictions1)
    # Side-by-side comparison: driving | neutral | emotional.
    all_imgs = np.concatenate((driving_video, predictions, predictions1), axis=0)
    save_path = os.path.join(opt.result_path, 'all.mp4')
    concatenate(3, all_imgs, save_path)
    add_audio(save_path, opt.in_file)
if __name__ == "__main__":
    # Command-line entry point: assemble demo options and run test().
    # NOTE(review): the boolean-valued options below (--smooth_pose,
    # --pose_long, --add_emo, --check_add, --kp_loss, --weight) use plain
    # defaults instead of store_true actions, so any value passed on the
    # command line arrives as a non-empty (truthy) string — confirm intended.
    parser = ArgumentParser()
    parser.add_argument("--config", default='config/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml', help="path to config")
    parser.add_argument("--audio_checkpoint", default='log/1-6000.pth.tar', help="path to checkpoint to restore")
    parser.add_argument("--checkpoint", default='log/124_52000.pth.tar', help="path to checkpoint to restore")
    parser.add_argument("--emo_checkpoint", default='log/5-3000.pth.tar', help="path to checkpoint to restore")
    parser.add_argument("--source_image", default='test/image/21.png', help="path to source image")
    parser.add_argument("--driving_video", default='test/video/disgusted.mp4', help="path to driving video")
    parser.add_argument('--in_file', type=str, default='test/audio/sample1.mov')
    parser.add_argument('--pose_file', type=str, default='test/pose/21.npy')
    parser.add_argument('--pose_given', type=str, default='test/pose_long/0zn70Ak8lRc_Daniel_Auteuil_0zn70Ak8lRc_0002.npy')
    parser.add_argument("--result_path", default='result/', help="path to output")
    parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates")
    parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")
    parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
    parser.add_argument("--kp_loss", default=0, help="keypoint loss.")
    parser.add_argument("--smooth_pose", default=True, help="cpu mode.")
    parser.add_argument("--pose_long", default=False, help="use given long poses.")
    parser.add_argument("--weight", default=0, help="cpu mode.")
    parser.add_argument("--add_emo", default=False, help="add emotion.")
    parser.add_argument("--check_add", default=False, help="check emotion displacement.")
    parser.add_argument("--type", default='linear_3', help="add emotion type.")
    parser.add_argument("--emotion", default='disgusted', help="emotion category, 'angry', 'contempt','disgusted','fear','happy','neutral','sad','surprised'.")
    parser.set_defaults(relative=False)
    parser.set_defaults(adapt_scale=False)
    opt = parser.parse_args()
    test(opt, 'test')
================================================
FILE: filter1.py
================================================
import cv2
#import pickle
import time
import numpy as np
import copy
from matplotlib import pyplot as plt
from tqdm import tqdm
class LowPassFilter:
    """Stateful exponential smoother: y_t = alpha * x_t + (1 - alpha) * y_{t-1}.

    The first sample passes through unchanged. Both the last raw sample and
    the last filtered value are kept so callers (e.g. OneEuroFilter) can
    derive the signal's speed.
    """

    def __init__(self):
        self.prev_raw_value = None
        self.prev_filtered_value = None

    def process(self, value, alpha):
        """Smooth one sample with the given blending factor and return it."""
        if self.prev_raw_value is None:
            filtered = value
        else:
            filtered = alpha * value + (1.0 - alpha) * self.prev_filtered_value
        self.prev_raw_value = value
        self.prev_filtered_value = filtered
        return filtered
class OneEuroFilter:
    """One-Euro filter: a low-pass filter whose cutoff adapts to signal speed.

    Slow signals get heavy smoothing (less jitter); fast signals raise the
    cutoff via `beta` (less lag). Works element-wise on scalars or numpy
    arrays at a fixed sampling frequency `freq`.
    """

    def __init__(self, mincutoff=1.0, beta=0.0, dcutoff=1.0, freq=30):
        self.freq = freq
        self.mincutoff = mincutoff
        self.beta = beta
        self.dcutoff = dcutoff
        self.x_filter = LowPassFilter()
        self.dx_filter = LowPassFilter()

    def compute_alpha(self, cutoff):
        """Blending factor for a first-order low-pass at `cutoff` Hz."""
        period = 1.0 / self.freq
        tau = 1.0 / (2 * np.pi * cutoff)
        return 1.0 / (1.0 + tau / period)

    def process(self, x):
        """Filter one sample and return the smoothed value."""
        last_raw = self.x_filter.prev_raw_value
        # Finite-difference speed estimate (zero on the very first sample).
        dx = 0.0 if last_raw is None else (x - last_raw) * self.freq
        speed = self.dx_filter.process(dx, self.compute_alpha(self.dcutoff))
        # Cutoff grows with the smoothed speed: fast motion -> less smoothing.
        cutoff = self.mincutoff + self.beta * np.abs(speed)
        return self.x_filter.process(x, self.compute_alpha(cutoff))
================================================
FILE: frames_dataset.py
================================================
import os
from skimage import io, img_as_float32, transform
from skimage.color import gray2rgb
from sklearn.model_selection import train_test_split
from imageio import mimread
import numpy as np
from torch.utils.data import Dataset
import pandas as pd
from augmentation import AllAugmentationTransform
import glob
import pickle
import random
from filter1 import OneEuroFilter
def read_video(name, frame_shape):
    """
    Read video which can be:
      - a directory of frame images,
      - a single .png/.jpg image of horizontally concatenated frames,
      - a '.mp4', '.gif' or '.mov' file.

    Returns a float32 array of shape (num_frames,) + frame_shape.
    Raises Exception for any other extension.
    """
    if os.path.isdir(name):
        # Directory of frames: rely on lexicographic order of filenames.
        frames = sorted(os.listdir(name))
        num_frames = len(frames)
        video_array = np.array(
            [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
    elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
        # A single image storing the frames side by side along the width axis.
        image = io.imread(name)
        if len(image.shape) == 2 or image.shape[2] == 1:
            image = gray2rgb(image)
        if image.shape[2] == 4:
            # Drop the alpha channel.
            image = image[..., :3]
        image = img_as_float32(image)
        # Swap H and W, cut the long axis into frames, then swap back so each
        # frame comes out as frame_shape.
        video_array = np.moveaxis(image, 1, 0)
        video_array = video_array.reshape((-1,) + frame_shape)
        video_array = np.moveaxis(video_array, 1, 2)
    elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
        video = np.array(mimread(name))
        if len(video.shape) == 3:
            # Grayscale video: expand each frame to RGB.
            video = np.array([gray2rgb(frame) for frame in video])
        if video.shape[-1] == 4:
            video = video[..., :3]
        video_array = img_as_float32(video)
    else:
        raise Exception("Unknown file extensions %s" % name)
    return video_array
def get_list(ipath, base_name):
    """Collect 'name/word' ids under ipath/base_name whose frame folder exists.

    For every <name>/<word-file> found below ipath/base_name, the id
    'name/word' (extension stripped) is kept only when the matching frame
    directory exists under the hard-coded LRW image root.
    """
    split_dir = os.path.join(ipath, base_name)
    image_path = os.path.join('/mnt/lustre/share/jixinya/LRW/Image/', base_name)
    collected = []
    for name in os.listdir(split_dir):
        word_dir = os.path.join(split_dir, name)
        for entry in os.listdir(word_dir):
            word = entry.split('.')[0]
            if os.path.exists(os.path.join(image_path, name, word)):
                collected.append(name + '/' + word)
    print('get list ' + os.path.basename(split_dir))
    return collected
class AudioDataset(Dataset):
"""
Dataset of videos, each video can be represented as:
- an image of concatenated frames
- '.mp4' or '.gif'
- folder with all frames
"""
def __init__(self, name, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
             random_seed=0, augmentation_params=None):
    """Set up MFCC/Image/pose directories and pick the train or test split.

    A predefined split is used when <root_dir>/pose/train_fo exists
    (video id lists are then loaded from ../LRW/list/*.npy); otherwise a
    random 80/20 split is attempted. Augmentations are only applied in
    training mode.
    """
    self.root_dir = root_dir
    self.audio_dir = os.path.join(root_dir, 'MFCC')
    self.image_dir = os.path.join(root_dir, 'Image')
    self.pose_dir = os.path.join(root_dir, 'pose')
    self.frame_shape = tuple(frame_shape)
    self.id_sampling = id_sampling
    if os.path.exists(os.path.join(self.pose_dir, 'train_fo')):
        assert os.path.exists(os.path.join(self.pose_dir, 'test_fo'))
        print("Use predefined train-test split.")
        if id_sampling:
            # One entry per identity: video names share an identity prefix
            # before '#'.
            train_videos = {os.path.basename(video).split('#')[0] for video in
                            os.listdir(os.path.join(self.image_dir, 'train'))}
            train_videos = list(train_videos)
        else:
            train_videos = np.load('../LRW/list/train_fo.npy')
        test_videos = np.load('../LRW/list/test_fo.npy')
        # Narrow each data dir to the chosen split.
        self.image_dir = os.path.join(self.image_dir, 'train_fo' if is_train else 'test_fo')
        self.audio_dir = os.path.join(self.audio_dir, 'train' if is_train else 'test')
        self.pose_dir = os.path.join(self.pose_dir, 'train_fo' if is_train else 'test_fo')
    else:
        print("Use random train-test split.")
        # NOTE(review): self.videos has not been assigned yet on this path,
        # so this line raises AttributeError — the predefined-split branch
        # appears to be the only one actually exercised.
        train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
    if is_train:
        self.videos = train_videos
    else:
        self.videos = test_videos
    self.is_train = is_train
    if self.is_train:
        self.transform = AllAugmentationTransform(**augmentation_params)
    else:
        self.transform = None
def __len__(self):
return len(self.videos)
def __getitem__(self, idx):
if self.is_train and self.id_sampling:
name = self.videos[idx].split('.')[0]
path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
else:
name = self.videos[idx].split('.')[0]
audio_path = os.path.join(self.audio_dir, name)
pose_path = os.path.join(self.pose_dir,name)
path = os.path.join(self.image_dir, name)
video_name = os.path.basename(path)
if os.path.isdir(path):
# if self.is_train and os.path.isdir(path):
# mfcc loading
r = random.choice([x for x in range(3, 8)])
example_image = img_as_float32(io.imread(os.path.join(path, str(r)+'.png')))
mfccs = []
for ind in range(1, 17):
# t_mfcc = mfcc[(r + ind - 3) * 4: (r + ind + 4) * 4, 1:]
t_mfcc = np.load(os.path.join(audio_path,str(r + ind)+'.npy'),allow_pickle=True)[:, 1:]
mfccs.append(t_mfcc)
mfccs = np.array(mfccs)
poses = []
video_array = []
for ind in range(1, 17):
t_pose = np.load(os.path.join(self.pose_dir,name+'.npy'))[r+ind,:-1]
poses.append(t_pose)
image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png')))
video_array.append(image)
poses = np.array(poses)
video_array = np.array(video_array)
else:
print('Wrong, data path not an existing file.')
if self.transform is not None:
video_array = self.transform(video_array)
out = {}
driving = np.array(video_array, dtype='float32')
spatial_size = np.array(driving.shape[1:3][::-1])[np.newaxis]
driving_pose = np.array(poses, dtype='float32')
example_image = np.array(example_image, dtype='float32')
out['example_image'] = example_image.transpose((2, 0, 1))
out['driving_pose'] = driving_pose
out['driving'] = driving.transpose((0, 3, 1, 2))
out['driving_audio'] = np.array(mfccs, dtype='float32')
# out['name'] = video_name
return out
class VoxDataset(Dataset):
    """
    Dataset of videos, each video can be represented as:
      - an image of concatenated frames
      - '.mp4' or '.gif'
      - folder with all frames

    VoxCeleb1 variant: __getitem__ yields a 16-frame window of aligned
    frames with the matching MFCC slices and pose vectors, plus one
    earlier frame as the identity example.
    """
    def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
                 random_seed=0, pairs_list=None, augmentation_params=None):
        self.root_dir = root_dir
        self.audio_dir = os.path.join(root_dir,'MFCC')       # one MFCC array per clip
        self.image_dir = os.path.join(root_dir,'align_img')  # aligned frame folders
        self.pose_dir = os.path.join(root_dir,'align_pose')  # one pose array per clip
        # assert len(os.listdir(self.audio_dir)) == len(os.listdir(self.image_dir)), 'audio and image length not equal'
        # df=open('../LRW/list/test_fo.txt','rb')
        # self.videos=pickle.load(df)
        # df.close()
        # Pre-computed list of valid clip ids (cluster-specific path).
        self.videos=np.load('/mnt/lustre/share_data/jixinya/VoxCeleb1_Cut/right.npy')
        # self.videos = os.listdir(self.landmark_dir)
        self.frame_shape = tuple(frame_shape)
        self.pairs_list = pairs_list
        self.id_sampling = id_sampling
        if os.path.exists(os.path.join(self.pose_dir, 'train_fo')):
            assert os.path.exists(os.path.join(self.pose_dir, 'test_fo'))
            print("Use predefined train-test split.")
            if id_sampling:
                train_videos = {os.path.basename(video).split('#')[0] for video in
                                os.listdir(os.path.join(self.image_dir, 'train'))}
                train_videos = list(train_videos)
            else:
                train_videos = np.load('/mnt/lustre/share_data/jixinya/VoxCeleb1_Cut/right.npy')# get_list(self.pose_dir, 'train_fo')
            # NOTE(review): test_videos is never assigned in this branch, so
            # is_train=False raises NameError below.
            self.image_dir = os.path.join(self.image_dir, 'train_fo' if is_train else 'test_fo')
            self.audio_dir = os.path.join(self.audio_dir, 'train' if is_train else 'test')
            self.pose_dir = os.path.join(self.pose_dir, 'train_fo' if is_train else 'test_fo')
        else:
            print("Use random train-test split.")
            train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
        if is_train:
            self.videos = train_videos
        else:
            self.videos = test_videos
        self.is_train = is_train
        if self.is_train:
            self.transform = AllAugmentationTransform(**augmentation_params)
        else:
            self.transform = None

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        if self.is_train and self.id_sampling:
            name = self.videos[idx].split('.')[0]
            path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
        else:
            name = self.videos[idx].split('.')[0]
            audio_path = os.path.join(self.audio_dir, name+'.npy')
            pose_path = os.path.join(self.pose_dir,name+'.npy')
            path = os.path.join(self.image_dir, name)
        video_name = os.path.basename(path)
        if os.path.isdir(path):
            # if self.is_train and os.path.isdir(path):
            frames = os.listdir(path)
            num_frames = len(frames)
            # NOTE(review): these two random frames are discarded —
            # video_array is rebuilt from scratch below.
            frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
            video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]
            mfcc = np.load(audio_path)
            pose = np.load(pose_path)
            # print(audio_path,pose_path,len(mfcc))
            # NOTE(review): no-op check — the bare comparison neither asserts
            # nor raises, so short clips are not actually filtered here.
            try:
                len(mfcc) > 16
            except:
                print('wrongmfcc len:',audio_path)
            # Window start r: frames r+1..r+16 form the driving window.
            if 16 < len(mfcc) < 24 :
                r = 0
            else:
                r = random.choice([x for x in range(3, len(mfcc)-20)])
            mfccs = []
            poses = []
            video_array = []
            for ind in range(1, 17):
                t_mfcc = mfcc[r+ind][:, 1:]  # drop MFCC coefficient 0
                mfccs.append(t_mfcc)
                t_pose = pose[r+ind,:-1]     # drop the last pose column
                poses.append(t_pose)
                image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png')))
                video_array.append(image)
            mfccs = np.array(mfccs)
            poses = np.array(poses)
            video_array = np.array(video_array)
            # Frame r (just before the window) serves as the identity example.
            example_image = img_as_float32(io.imread(os.path.join(path, str(r)+'.png')))
        else:
            # NOTE(review): only prints; the variables used below are then
            # undefined and a NameError follows.
            print('Wrong, data path not an existing file.')
        if self.transform is not None:
            video_array = self.transform(video_array)
        out = {}
        driving = np.array(video_array, dtype='float32')
        spatial_size = np.array(driving.shape[1:3][::-1])[np.newaxis]  # unused
        driving_pose = np.array(poses, dtype='float32')
        example_image = np.array(example_image, dtype='float32')
        out['example_image'] = example_image.transpose((2, 0, 1))  # (C, H, W)
        out['driving_pose'] = driving_pose
        out['driving'] = driving.transpose((0, 3, 1, 2))           # (T, C, H, W)
        out['driving_audio'] = np.array(mfccs, dtype='float32')
        # out['name'] = video_name
        return out
class MeadDataset(Dataset):
    """
    Dataset of videos, each video can be represented as:
      - an image of concatenated frames
      - '.mp4' or '.gif'
      - folder with all frames

    MEAD variant: the identity example is a random frame from a *neutral*
    clip of the same speaker (looked up through self.dict), and the raw
    pose sequence is smoothed with a one-euro filter before sampling.
    """
    def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
                 random_seed=0, augmentation_params=None):
        self.root_dir = root_dir
        self.audio_dir = os.path.join(root_dir,'MEAD_MFCC')        # one MFCC array per clip
        self.image_dir = os.path.join(root_dir,'MEAD_fomm_crop')   # cropped frame folders
        self.pose_dir = os.path.join(root_dir,'MEAD_fomm_pose_crop')  # raw pose arrays
        # Clip list and speaker -> neutral-clip-list mapping (cluster paths).
        self.videos = np.load('/mnt/lustre/share_data/jixinya/MEAD/MEAD_fomm_audio_less_crop.npy')
        self.dict = np.load('/mnt/lustre/share_data/jixinya/MEAD/MEAD_fomm_neu_dic_crop.npy',allow_pickle=True).item()
        # self.videos = os.listdir(root_dir)
        self.frame_shape = tuple(frame_shape)
        self.id_sampling = id_sampling
        if os.path.exists(os.path.join(root_dir, 'train')):
            assert os.path.exists(os.path.join(root_dir, 'test'))
            print("Use predefined train-test split.")
            if id_sampling:
                train_videos = {os.path.basename(video).split('#')[0] for video in
                                os.listdir(os.path.join(root_dir, 'train'))}
                train_videos = list(train_videos)
            else:
                train_videos = os.listdir(os.path.join(root_dir, 'train'))
            test_videos = os.listdir(os.path.join(root_dir, 'test'))
            self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
        else:
            print("Use random train-test split.")
            train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
        if is_train:
            self.videos = train_videos
        else:
            self.videos = test_videos
        self.is_train = is_train
        if self.is_train:
            self.transform = AllAugmentationTransform(**augmentation_params)
        else:
            self.transform = None

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        if self.is_train and self.id_sampling:
            name = self.videos[idx]
            path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
        else:
            name = self.videos[idx]
            path = os.path.join(self.image_dir, name)
        video_name = os.path.basename(path)
        # Speaker id is the parent directory of the clip folder.
        id_name = path.split('/')[-2]
        # Random neutral clip of the same speaker supplies the example image.
        neu_list = self.dict[id_name]
        neu_path = os.path.join(self.image_dir, np.random.choice(neu_list))
        audio_path = os.path.join(self.audio_dir, name+'.npy')
        pose_path = os.path.join(self.pose_dir,name+'.npy')
        if self.is_train and os.path.isdir(path):
            mfcc = np.load(audio_path)
            pose_raw = np.load(pose_path)
            # Smooth the 7-dim raw pose sequence with a one-euro filter.
            one_euro_filter = OneEuroFilter(mincutoff=0.01, beta=0.7, dcutoff=1.0, freq=100)
            pose = np.zeros((len(pose_raw),7))
            for j in range(len(pose_raw)):
                pose[j]=one_euro_filter.process(pose_raw[j])
            # print(audio_path,pose_path,len(mfcc))
            neu_frames = os.listdir(neu_path)
            num_neu_frames = len(neu_frames)
            frame_idx = np.random.choice(num_neu_frames)
            example_image = img_as_float32(io.imread(os.path.join(neu_path, neu_frames[frame_idx])))
            # NOTE(review): no-op check — the bare comparison cannot raise.
            try:
                len(mfcc) > 16
            except:
                print('wrongmfcc len:',audio_path)
            # Window start r: frames r+1..r+16 form the driving window.
            if 16 < len(mfcc) < 24 :
                r = 0
            else:
                r = random.choice([x for x in range(3, len(mfcc)-20)])
            mfccs = []
            poses = []
            video_array = []
            for ind in range(1, 17):
                t_mfcc = mfcc[r+ind][:, 1:]  # drop MFCC coefficient 0
                mfccs.append(t_mfcc)
                t_pose = pose[r+ind,:-1]     # drop the last (7th) pose column
                poses.append(t_pose)
                image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png')))
                video_array.append(image)
            mfccs = np.array(mfccs)
            poses = np.array(poses)
            video_array = np.array(video_array)
        else:
            print('Wrong, data path not an existing file.')
        if self.transform is not None:
            video_array = self.transform(video_array)
        out = {}
        if self.is_train:
            driving = np.array(video_array, dtype='float32')
            driving_pose = np.array(poses, dtype='float32')
            example_image = np.array(example_image, dtype='float32')
            out['example_image'] = example_image.transpose((2, 0, 1))  # (C, H, W)
            out['driving_pose'] = driving_pose
            out['driving'] = driving.transpose((0, 3, 1, 2))           # (T, C, H, W)
            out['driving_audio'] = np.array(mfccs, dtype='float32')
        # out['name'] = id_name+'/'+video_name
        # NOTE(review): out is returned empty when is_train is False.
        return out
class DatasetRepeater(Dataset):
    """
    Pass several times over the same dataset for better i/o performance
    """

    def __init__(self, dataset, num_repeats=100):
        # Wrapped dataset and the number of virtual passes over it.
        self.dataset = dataset
        self.num_repeats = num_repeats

    def __len__(self):
        # Virtual length: each pass re-exposes the whole dataset.
        return len(self.dataset) * self.num_repeats

    def __getitem__(self, idx):
        # Fold the virtual index back into the wrapped dataset's range.
        return self.dataset[idx % len(self.dataset)]
class TestsetRepeater(Dataset):
    """
    Pass several times over the same dataset for better i/o performance
    """

    def __init__(self, dataset, num_repeats=100):
        self.dataset = dataset          # wrapped dataset
        self.num_repeats = num_repeats  # virtual pass count

    def __len__(self):
        return self.num_repeats * len(self.dataset)

    def __getitem__(self, idx):
        # Wrap around so repeated passes revisit the same items.
        wrapped = idx % len(self.dataset)
        return self.dataset[wrapped]
class PairedDataset(Dataset):
    """
    Dataset of pairs for animation.

    Builds (driving, source) index pairs either by subsampling the full
    index product (no pairs list) or from a CSV with 'source'/'driving'
    columns restricted to clips present in the wrapped dataset.
    """

    def __init__(self, initial_dataset, number_of_pairs, seed=0):
        self.initial_dataset = initial_dataset
        pairs_list = self.initial_dataset.pairs_list
        np.random.seed(seed)
        if pairs_list is None:
            # All (i, j) combinations over the first max_idx items,
            # randomly subsampled without replacement.
            max_idx = min(number_of_pairs, len(initial_dataset))
            grid = np.mgrid[:max_idx, :max_idx].reshape(2, -1).T
            number_of_pairs = min(grid.shape[0], number_of_pairs)
            chosen = np.random.choice(grid.shape[0], number_of_pairs, replace=False)
            self.pairs = grid.take(chosen, axis=0)
        else:
            # Keep only CSV rows whose source and driving clips both exist.
            videos = self.initial_dataset.videos
            name_to_index = {name: index for index, name in enumerate(videos)}
            pairs = pd.read_csv(pairs_list)
            pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))]
            number_of_pairs = min(pairs.shape[0], number_of_pairs)
            self.pairs = []
            self.start_frames = []
            for ind in range(number_of_pairs):
                driving_idx = name_to_index[pairs['driving'].iloc[ind]]
                source_idx = name_to_index[pairs['source'].iloc[ind]]
                self.pairs.append((driving_idx, source_idx))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        driving_idx, source_idx = self.pairs[idx]
        driving_item = self.initial_dataset[driving_idx]
        source_item = self.initial_dataset[source_idx]
        # Merge both items under 'driving_'/'source_' key prefixes.
        merged = {'driving_' + key: value for key, value in driving_item.items()}
        merged.update({'source_' + key: value for key, value in source_item.items()})
        return merged
================================================
FILE: logger.py
================================================
import numpy as np
import torch
import torch.nn.functional as F
import imageio
import os
from skimage.draw import circle
import matplotlib.pyplot as plt
import collections
class Logger:
    """Training bookkeeper: accumulates per-iteration losses, appends epoch
    summaries to a text log, and saves checkpoints plus reconstruction
    visualizations."""

    def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):
        self.loss_list = []             # loss rows accumulated since the last log_scores()
        self.cpk_dir = log_dir          # checkpoints are written next to the log
        self.visualizations_dir = os.path.join(log_dir, 'train-vis')
        if not os.path.exists(self.visualizations_dir):
            os.makedirs(self.visualizations_dir)
        # Opened in append mode so a restarted run keeps earlier history.
        self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
        self.zfill_num = zfill_num      # zero-padding width for step numbers in names
        self.visualizer = Visualizer(**visualizer_params)
        self.checkpoint_freq = checkpoint_freq  # checkpoint every N epochs
        self.epoch = 0
        self.best_loss = float('inf')   # NOTE(review): never updated in the visible code
        self.names = None               # loss names, captured on the first log_iter()

    def log_scores(self, loss_names):
        """Write the mean of the accumulated losses to the log file and reset."""
        loss_mean = np.array(self.loss_list).mean(axis=0)
        loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
        # Prefix each line with "<epoch><zero-padded step>)".
        loss_string = str(str(self.epoch)+str(self.step).zfill(self.zfill_num)) + ") " + loss_string
        print(loss_string, file=self.log_file)
        self.loss_list = []
        self.log_file.flush()

    def visualize_rec(self, inp, out):
        """Save a visualization grid for the last frame of the current batch."""
        # image = self.visualizer.visualize(inp['driving'], inp['source'], out)
        image = self.visualizer.visualize(inp['driving'][:,-1], inp['transformed_driving'][:,-1], inp['example_image'], out)
        imageio.imsave(os.path.join(self.visualizations_dir, "%s-%s-rec.png" % (str(self.epoch),str(self.step).zfill(self.zfill_num))), image)

    def save_cpk(self, emergent=False):
        """Serialize all registered models plus epoch/step to a checkpoint.

        With emergent=True an already-existing checkpoint for this
        epoch/step is left untouched.
        """
        cpk = {k: v.state_dict() for k, v in self.models.items()}
        cpk['epoch'] = self.epoch
        cpk['step'] = self.step
        cpk_path = os.path.join(self.cpk_dir, '%s-%s-checkpoint.pth.tar' % (str(self.epoch),str(self.step).zfill(self.zfill_num)))
        if not (os.path.exists(cpk_path) and emergent):
            torch.save(cpk, cpk_path)

    @staticmethod
    def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_detector=None, audio_feature=None,
                 optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None, optimizer_audio_feature = None):
        """Restore whichever modules/optimizers are passed in from a checkpoint.

        Returns the stored epoch number. A missing discriminator (or its
        optimizer) in the checkpoint is tolerated — it stays randomly
        initialized. `audio_feature` and its optimizer are currently ignored.
        """
        checkpoint = torch.load(checkpoint_path)
        if generator is not None:
            generator.load_state_dict(checkpoint['generator'])
        if kp_detector is not None:
            kp_detector.load_state_dict(checkpoint['kp_detector'])
        if discriminator is not None:
            try:
                discriminator.load_state_dict(checkpoint['discriminator'])
            except:
                print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')
        # if audio_feature is not None:
        #     audio_feature.load_state_dict(checkpoint['audio_feature'])
        if optimizer_generator is not None:
            optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
        if optimizer_discriminator is not None:
            try:
                optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
            except RuntimeError as e:
                print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')
        if optimizer_kp_detector is not None:
            optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
        # if optimizer_audio_feature is not None:
        #     a = checkpoint['optimizer_kp_detector']['param_groups']
        #     a[0].pop('params')
        #     optimizer_audio_feature.load_state_dict(checkpoint['optimizer_audio_feature'])
        return checkpoint['epoch']

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Save a final checkpoint if training ever registered models.
        if 'models' in self.__dict__:
            self.save_cpk()
        self.log_file.close()

    def log_iter(self, losses):
        """Record one iteration's scalar losses (order of keys is captured once)."""
        losses = collections.OrderedDict(losses.items())
        if self.names is None:
            self.names = list(losses.keys())
        self.loss_list.append(list(losses.values()))

    def log_epoch(self, epoch, step, models, inp, out):
        """End-of-epoch hook: maybe checkpoint, then flush scores and visuals."""
        self.epoch = epoch
        self.step = step
        self.models = models
        if (self.epoch + 1) % self.checkpoint_freq == 0:
            self.save_cpk()
        self.log_scores(self.names)
        self.visualize_rec(inp, out)
class Visualizer:
    """Builds image grids — one vertical column per tensor, batch samples
    stacked — visualizing sources, drivings, keypoints, deformations,
    occlusion maps and predictions."""

    def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
        self.kp_size = kp_size          # keypoint dot radius in pixels
        self.draw_border = draw_border  # whiten first/last pixel columns of each image
        self.colormap = plt.get_cmap(colormap)

    def draw_image_with_kp(self, image, kp_array):
        """Draw colored keypoint dots on one image; kp coords are in [-1, 1]."""
        image = np.copy(image)
        spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
        # Map normalized [-1, 1] coordinates to pixel coordinates.
        kp_array = spatial_size * (kp_array + 1) / 2
        num_kp = kp_array.shape[0]
        for kp_ind, kp in enumerate(kp_array):
            rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
            # One distinct colormap color per keypoint index.
            image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
        return image

    def create_image_column_with_kp(self, images, kp):
        """Annotate a batch of images with keypoints, then stack vertically."""
        image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
        return self.create_image_column(image_array)

    def create_image_column(self, images):
        """Stack a batch of images vertically, optionally whitening borders."""
        if self.draw_border:
            images = np.copy(images)
            images[:, :, [0, -1]] = (1, 1, 1)
            images[:, :, [0, -1]] = (1, 1, 1)
        return np.concatenate(list(images), axis=0)

    def create_image_grid(self, *args):
        """Concatenate columns horizontally; (images, kp) tuples get keypoints drawn."""
        out = []
        for arg in args:
            if type(arg) == tuple:
                out.append(self.create_image_column_with_kp(arg[0], arg[1]))
            else:
                out.append(self.create_image_column(arg))
        return np.concatenate(out, axis=1)

    def visualize(self, driving, transformed_driving, source, out):
        """Assemble the full visualization grid and return it as a uint8 image.

        Columns appear in the order built below; optional entries are only
        added when present in `out`.
        """
        images = []
        # Source image with keypoints
        source = source.data.cpu()
        kp_source = out['kp_source']['value'].data.cpu().numpy()
        source = np.transpose(source, [0, 2, 3, 1])
        images.append((source, kp_source))
        # Equivariance visualization
        if 'transformed_frame' in out:
            transformed = out['transformed_frame'].data.cpu().numpy()
            transformed = np.transpose(transformed, [0, 2, 3, 1])
            transformed_kp = out['transformed_kp']['value'].data.cpu().numpy()
            images.append((transformed, transformed_kp))
        # Equivariance visualization
        transformed_driving = transformed_driving.data.cpu().numpy()
        transformed_driving = np.transpose(transformed_driving, [0, 2, 3, 1])
        images.append(transformed_driving)
        # Driving image with keypoints (last element of the kp_driving sequence).
        kp_driving = out['kp_driving'][-1]['value'].data.cpu().numpy() #[-1]['value']
        driving = driving.data.cpu().numpy()
        driving = np.transpose(driving, [0, 2, 3, 1])
        images.append((driving, kp_driving))
        # Deformed image
        if 'deformed' in out:
            deformed = out['deformed'].data.cpu().numpy()
            deformed = np.transpose(deformed, [0, 2, 3, 1])
            images.append(deformed)
        # Result with and without keypoints
        prediction = out['prediction'].data.cpu().numpy()
        prediction = np.transpose(prediction, [0, 2, 3, 1])
        if 'kp_norm' in out:
            kp_norm = out['kp_norm']['value'].data.cpu().numpy()
            images.append((prediction, kp_norm))
        images.append(prediction)
        ## Occlusion map
        if 'occlusion_map' in out:
            # Repeat the 1-channel map to RGB and resize to the source size.
            occlusion_map = out['occlusion_map'].data.cpu().repeat(1, 3, 1, 1)
            occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
            occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
            images.append(occlusion_map)
        # Deformed images according to each individual transform
        if 'sparse_deformed' in out:
            full_mask = []
            for i in range(out['sparse_deformed'].shape[1]):
                image = out['sparse_deformed'][:, i].data.cpu()
                image = F.interpolate(image, size=source.shape[1:3])
                mask = out['mask'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)
                mask = F.interpolate(mask, size=source.shape[1:3])
                image = np.transpose(image.numpy(), (0, 2, 3, 1))
                mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
                # Index 0 is the background slot and is tinted black.
                if i != 0:
                    color = np.array(self.colormap((i - 1) / (out['sparse_deformed'].shape[1] - 1)))[:3]
                else:
                    color = np.array((0, 0, 0))
                color = color.reshape((1, 1, 1, 3))
                images.append(image)
                if i != 0:
                    images.append(mask * color)
                else:
                    images.append(mask)
                full_mask.append(mask * color)
            images.append(sum(full_mask))
        image = self.create_image_grid(*images)
        image = (255 * image).astype(np.uint8)
        return image
================================================
FILE: modules/dense_motion.py
================================================
from torch import nn
import torch.nn.functional as F
import torch
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
class DenseMotionNetwork(nn.Module):
    """
    Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving

    Per-keypoint heatmap differences and sparsely deformed copies of the
    source image are fed through an hourglass network, which predicts soft
    masks blending the per-keypoint motions into one dense flow field,
    plus an optional occlusion map.
    """

    def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False,
                 scale_factor=1, kp_variance=0.01):
        super(DenseMotionNetwork, self).__init__()
        # Input channels: (heatmap + deformed source) per keypoint plus one
        # background slot, hence (num_kp + 1) * (num_channels + 1).
        self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1),
                                   max_features=max_features, num_blocks=num_blocks)
        self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3))
        if estimate_occlusion_map:
            self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
        else:
            self.occlusion = None
        self.num_kp = num_kp
        self.scale_factor = scale_factor  # work at a downscaled resolution when != 1
        self.kp_variance = kp_variance    # Gaussian spread used for keypoint heatmaps
        if self.scale_factor != 1:
            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)

    def create_heatmap_representations(self, source_image, kp_driving, kp_source):
        """
        Eq 6. in the paper H_k(z)
        """
        spatial_size = source_image.shape[2:]
        gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
        gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
        heatmap = gaussian_driving - gaussian_source #[4,10,H,W]
        #adding background feature
        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type())
        heatmap = torch.cat([zeros, heatmap], dim=1)
        heatmap = heatmap.unsqueeze(2) #[4,11,1,h,w]
        return heatmap

    def create_sparse_motions(self, source_image, kp_driving, kp_source):
        """
        Eq 4. in the paper T_{s<-d}(z)
        """
        bs, _, h, w = source_image.shape
        identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
        identity_grid = identity_grid.view(1, 1, h, w, 2)
        coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2) #[4,10,64,64,2]
        if 'jacobian' in kp_driving:
            # First-order (Jacobian) correction around each keypoint.
            jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
            jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
            jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
            coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
            coordinate_grid = coordinate_grid.squeeze(-1)
        driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2)
        #adding background feature
        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)
        return sparse_motions

    def create_deformed_source_image(self, source_image, sparse_motions):
        r"""
        Eq 7. in the paper \hat{T}_{s<-d}(z)
        """
        bs, _, h, w = source_image.shape
        # Warp one copy of the source per keypoint (plus background slot).
        source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)
        source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)
        sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))
        sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
        sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))
        return sparse_deformed

    def forward(self, source_image, kp_driving, kp_source):
        """Predict the dense flow (and optional occlusion map) from keypoints."""
        if self.scale_factor != 1:
            source_image = self.down(source_image) #[4,3,H*scale,W*scale]
        bs, _, h, w = source_image.shape
        out_dict = dict()
        heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source) #[4,11,1,64,64]
        sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source) #[4,11,64,64,2]
        deformed_source = self.create_deformed_source_image(source_image, sparse_motion) #[4,11,3,64,64]
        out_dict['sparse_deformed'] = deformed_source
        input = torch.cat([heatmap_representation, deformed_source], dim=2)
        input = input.view(bs, -1, h, w) #[4,11*4,64,64]
        prediction = self.hourglass(input) #[4,108,64,64]
        # Soft assignment of each pixel to one keypoint motion (or background).
        mask = self.mask(prediction)
        mask = F.softmax(mask, dim=1) #[4,11,64,64]
        out_dict['mask'] = mask
        mask = mask.unsqueeze(2)
        sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)
        # Mask-weighted sum of the candidate motions -> dense flow.
        deformation = (sparse_motion * mask).sum(dim=1)
        deformation = deformation.permute(0, 2, 3, 1) #[4,64,64,2]
        out_dict['deformation'] = deformation
        # Sec. 3.2 in the paper
        if self.occlusion:
            occlusion_map = torch.sigmoid(self.occlusion(prediction))
            out_dict['occlusion_map'] = occlusion_map #[4,1,64,64]
        return out_dict
================================================
FILE: modules/discriminator.py
================================================
from torch import nn
import torch.nn.functional as F
from modules.util import kp2gaussian
import torch
class DownBlock2d(nn.Module):
    """
    Simple block for processing video (encoder).

    Conv -> optional InstanceNorm -> LeakyReLU(0.2) -> optional 2x2 avg-pool.
    """

    def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
        super(DownBlock2d, self).__init__()
        conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
        # Spectral normalization wraps the convolution when requested.
        self.conv = nn.utils.spectral_norm(conv) if sn else conv
        self.norm = nn.InstanceNorm2d(out_features, affine=True) if norm else None
        self.pool = pool

    def forward(self, x):
        out = self.conv(x)
        if self.norm is not None:
            out = self.norm(out)
        out = F.leaky_relu(out, 0.2)
        if self.pool:
            out = F.avg_pool2d(out, (2, 2))
        return out
class Discriminator(nn.Module):
    """
    Discriminator similar to Pix2Pix

    A stack of DownBlock2d layers followed by a 1x1 convolution producing a
    patch-wise prediction map; all intermediate feature maps are returned
    so they can be used for feature-matching losses.
    """

    def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
                 sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs):
        super(Discriminator, self).__init__()
        blocks = []
        for i in range(num_blocks):
            if i == 0:
                # Keypoint heatmaps are concatenated to the image channels.
                in_ch = num_channels + num_kp * use_kp
            else:
                in_ch = min(max_features, block_expansion * (2 ** i))
            out_ch = min(max_features, block_expansion * (2 ** (i + 1)))
            blocks.append(DownBlock2d(in_ch, out_ch, norm=(i != 0), kernel_size=4,
                                      pool=(i != num_blocks - 1), sn=sn))
        self.down_blocks = nn.ModuleList(blocks)
        final_conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
        self.conv = nn.utils.spectral_norm(final_conv) if sn else final_conv
        self.use_kp = use_kp
        self.kp_variance = kp_variance

    def forward(self, x, kp=None):
        out = x
        if self.use_kp:
            # Render keypoints as Gaussian heatmaps and stack onto the input.
            heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)
            out = torch.cat([out, heatmap], dim=1)
        feature_maps = []
        for block in self.down_blocks:
            out = block(out)
            feature_maps.append(out)
        prediction_map = self.conv(out)
        return feature_maps, prediction_map
class MultiScaleDiscriminator(nn.Module):
    """
    Multi-scale (scale) discriminator

    Holds one Discriminator per scale. Dots in scale names are swapped for
    dashes to form valid ModuleDict keys, and swapped back when indexing
    the input/output dictionaries.
    """

    def __init__(self, scales=(), **kwargs):
        super(MultiScaleDiscriminator, self).__init__()
        self.scales = scales
        self.discs = nn.ModuleDict(
            {str(scale).replace('.', '-'): Discriminator(**kwargs) for scale in scales})

    def forward(self, x, kp=None):
        out_dict = {}
        for dict_key, disc in self.discs.items():
            scale = str(dict_key).replace('-', '.')
            feature_maps, prediction_map = disc(x['prediction_' + scale], kp)
            out_dict['feature_maps_' + scale] = feature_maps
            out_dict['prediction_map_' + scale] = prediction_map
        return out_dict
================================================
FILE: modules/function.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 30 17:45:24 2021
@author: SENSETIME\jixinya1
"""
import torch
def calc_mean_std(feat, eps=1e-5):
    """Per-channel mean and std of a 4-D feature map.

    feat: tensor of shape (N, C, H, W).
    eps: small constant added to the variance to avoid divide-by-zero.
    Returns (mean, std), each shaped (N, C, 1, 1).
    """
    size = feat.size()
    assert (len(size) == 4)
    n, c = size[0], size[1]
    flat = feat.view(n, c, -1)
    mean = flat.mean(dim=2).view(n, c, 1, 1)
    std = (flat.var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    return mean, std
def adaptive_instance_normalization(content_feat, style_feat):
    """AdaIN: re-normalize content features to the style features' statistics.

    Both inputs are (N, C, H, W) with matching N and C; the content is
    whitened per channel and then re-colored with the style's per-channel
    mean and std.
    """
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    shape = content_feat.size()
    s_mean, s_std = calc_mean_std(style_feat)
    c_mean, c_std = calc_mean_std(content_feat)
    whitened = (content_feat - c_mean.expand(shape)) / c_std.expand(shape)
    return whitened * s_std.expand(shape) + s_mean.expand(shape)
def _calc_feat_flatten_mean_std(feat):
# takes 3D feat (C, H, W), return mean and std of array within channels
assert (feat.size()[0] == 3)
assert (isinstance(feat, torch.FloatTensor))
feat_flatten = feat.view(3, -1)
mean = feat_flatten.mean(dim=-1, keepdim=True)
std = feat_flatten.std(dim=-1, keepdim=True)
return feat_flatten, mean, std
def _mat_sqrt(x):
U, D, V = torch.svd(x)
return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())
def coral(source, target):
    """CORAL color transfer: align the channel-wise first- and second-order
    statistics of `source` to those of `target`.

    Both arguments are (3, H, W) float tensors; the transferred source is
    returned with the same shape.
    """
    # Normalize each image's flattened channels to zero mean / unit std.
    source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
    source_f_norm = (source_f - source_f_mean.expand_as(source_f)) / source_f_std.expand_as(source_f)
    source_f_cov_eye = torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)

    target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
    target_f_norm = (target_f - target_f_mean.expand_as(target_f)) / target_f_std.expand_as(target_f)
    target_f_cov_eye = torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)

    # Whiten with the source covariance, re-color with the target covariance.
    whitened = torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), source_f_norm)
    source_f_norm_transfer = torch.mm(_mat_sqrt(target_f_cov_eye), whitened)

    # Restore the target's per-channel mean and std.
    source_f_transfer = (source_f_norm_transfer * target_f_std.expand_as(source_f_norm)
                         + target_f_mean.expand_as(source_f_norm))
    return source_f_transfer.view(source.size())
================================================
FILE: modules/generator.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
from modules.dense_motion import DenseMotionNetwork
class OcclusionAwareGenerator(nn.Module):
"""
Generator that given source image and and keypoints try to transform image according to movement trajectories
induced by keypoints. Generator follows Johnson architecture.
"""
def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
super(OcclusionAwareGenerator, self).__init__()
if dense_motion_params is not None:
self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
estimate_occlusion_map=estimate_occlusion_map,
**dense_motion_params)
else:
self.dense_motion_network = None
self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
down_blocks = []
for i in range(num_down_blocks):
in_features = min(max_features, block_expansion * (2 ** i))
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
self.down_blocks = nn.ModuleList(down_blocks)
up_blocks = []
for i in range(num_down_blocks):
in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))
out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))
up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
self.up_blocks = nn.ModuleList(up_blocks)
self.bottleneck = torch.nn.Sequential()
in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
for i in range(num_bottleneck_blocks):
self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))
self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
self.estimate_occlusion_map = estimate_occlusion_map
self.num_channels = num_channels
def deform_input(self, inp, deformation):
_, h_old, w_old, _ = deformation.shape
_, _, h, w = inp.shape
if h_old != h or w_old != w:
deformation = deformation.permute(0, 3, 1, 2)
deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
deformation = deformation.permute(0, 2, 3, 1)
return F.grid_sample(inp, deformation)
def forward(self, source_image, kp_driving, kp_source):
    """Render the source identity in the driving pose.

    Returns a dict with 'prediction' (generated image) and, when dense motion
    is enabled, 'mask', 'sparse_deformed', 'deformed' and optionally
    'occlusion_map'.
    """
    # Encoding (downsampling) part
    out = self.first(source_image)  # [4,64,H,W]
    for i in range(len(self.down_blocks)):
        out = self.down_blocks[i](out)  # [4,256,H/4,W/4]
    # Transforming feature representation according to deformation and occlusion
    output_dict = {}
    if self.dense_motion_network is not None:
        dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
                                                 kp_source=kp_source)
        output_dict['mask'] = dense_motion['mask']
        output_dict['sparse_deformed'] = dense_motion['sparse_deformed']
        if 'occlusion_map' in dense_motion:
            occlusion_map = dense_motion['occlusion_map']
            output_dict['occlusion_map'] = occlusion_map
        else:
            occlusion_map = None
        deformation = dense_motion['deformation']
        # Warp the encoded features towards the driving pose.
        out = self.deform_input(out, deformation)
        if occlusion_map is not None:
            # Match occlusion-map resolution to the features before masking.
            if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
                occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
            out = out * occlusion_map
        # Also expose the raw warped source image for visualization/losses.
        output_dict["deformed"] = self.deform_input(source_image, deformation)
    # Decoding part
    out = self.bottleneck(out)  # [4,256,64,64]
    for i in range(len(self.up_blocks)):
        out = self.up_blocks[i](out)
    out = self.final(out)
    out = torch.sigmoid(out)  # [4,3,256,256]
    output_dict["prediction"] = out
    return output_dict
================================================
FILE: modules/keypoint_detector.py
================================================
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d, Ct_encoder, EmotionNet, AF2F, AF2F_s, draw_heatmap
class KPDetector(nn.Module):
    """
    Detecting keypoints. Returns keypoint positions ('value'), the softmax
    heatmap they were derived from, and — when estimate_jacobian is set —
    a 2x2 jacobian near each keypoint.
    """
    def __init__(self, block_expansion, num_kp, num_channels, max_features,
                 num_blocks, temperature, estimate_jacobian=False, scale_factor=1,
                 single_jacobian_map=False, pad=0):
        super(KPDetector, self).__init__()
        # Hourglass backbone shared by the heatmap and jacobian heads.
        self.predictor = Hourglass(block_expansion, in_features=num_channels,
                                   max_features=max_features, num_blocks=num_blocks)
        # One heatmap channel per keypoint.
        self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
                            padding=pad)
        if estimate_jacobian:
            self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
            # Four values (a flattened 2x2 matrix) per jacobian map.
            self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
                                      out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
            # Zero weights + identity bias so each jacobian starts as identity.
            self.jacobian.weight.data.zero_()
            self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
        else:
            self.jacobian = None
        self.temperature = temperature  # softmax temperature for the heatmap
        self.scale_factor = scale_factor
        if self.scale_factor != 1:
            # Anti-aliased down-sampling of the input image.
            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)

    def gaussian2kp(self, heatmap):
        """
        Extract the mean (expected keypoint coordinate) from a heatmap.
        """
        shape = heatmap.shape
        heatmap = heatmap.unsqueeze(-1)  # [4,10,58,58,1]
        grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)  # [1,1,58,58,2]
        # Expectation of the coordinate grid under the heatmap distribution.
        value = (heatmap * grid).sum(dim=(2, 3))  # [4,10,2]
        kp = {'value': value}
        return kp

    def audio_feature(self, x, heatmap):
        """Compute per-keypoint jacobians from feature map `x` weighted by an
        externally supplied heatmap (instead of the one predicted from `x`).

        NOTE(review): assumes `heatmap` carries a singleton dim 2 that is
        squeezed here and re-added below — confirm against callers.
        """
        # prediction = self.kp(x) #[4,10,H/4-6, W/4-6]
        # final_shape = prediction.shape
        # heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]
        # heatmap = F.softmax(heatmap / self.temperature, dim=2)
        # heatmap = heatmap.view(*final_shape) #[4,10,58,58]
        # out = self.gaussian2kp(heatmap)
        final_shape = heatmap.squeeze(2).shape
        if self.jacobian is not None:
            jacobian_map = self.jacobian(x)  # [4,40,H/4-6, W/4-6]
            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
                                                final_shape[3])
            heatmap = heatmap.unsqueeze(2)
            # Heatmap-weighted average of the dense jacobian field.
            jacobian = heatmap * jacobian_map  # [4,10,4,H/4-6, W/4-6]
            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
            jacobian = jacobian.sum(dim=-1)  # [4,10,4]
            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)  # [4,10,2,2]
        return jacobian

    def forward(self, x):  # torch.Size([4, 3, H, W])
        """Predict keypoints (and jacobians) from an image batch."""
        if self.scale_factor != 1:
            x = self.down(x)  # 0.25 [4, 3, H/4, W/4]
        feature_map = self.predictor(x)  # [4,3+32,H/4, W/4]
        prediction = self.kp(feature_map)  # [4,10,H/4-6, W/4-6]
        final_shape = prediction.shape
        # Spatial softmax turns each channel into a probability map.
        heatmap = prediction.view(final_shape[0], final_shape[1], -1)  # [4, 10, 58*58]
        heatmap = F.softmax(heatmap / self.temperature, dim=2)
        heatmap = heatmap.view(*final_shape)  # [4,10,58,58]
        out = self.gaussian2kp(heatmap)
        out['heatmap'] = heatmap
        if self.jacobian is not None:
            jacobian_map = self.jacobian(feature_map)  # [4,40,H/4-6, W/4-6]
            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
                                                final_shape[3])
            heatmap = heatmap.unsqueeze(2)
            # Heatmap-weighted average of the dense jacobian field.
            jacobian = heatmap * jacobian_map  # [4,10,4,H/4-6, W/4-6]
            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
            jacobian = jacobian.sum(dim=-1)  # [4,10,4]
            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)  # [4,10,2,2]
            out['jacobian'] = jacobian
        return out
class KPDetector_a(nn.Module):
    """
    Keypoint detector variant driven by an externally computed feature map
    (audio branch): `forward` takes the feature map directly instead of an
    image, so no backbone down-sampling happens at inference.
    """
    def __init__(self, block_expansion, num_kp, num_channels, num_channels_a, max_features,
                 num_blocks, temperature, estimate_jacobian=False, scale_factor=1,
                 single_jacobian_map=False, pad=0):
        super(KPDetector_a, self).__init__()
        # Backbone consumes num_channels_a input channels (audio-derived features).
        self.predictor = Hourglass(block_expansion, in_features=num_channels_a,
                                   max_features=max_features, num_blocks=num_blocks)
        self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
                            padding=pad)
        if estimate_jacobian:
            self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
            self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
                                      out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
            # Zero weights + identity bias so each jacobian starts as identity.
            self.jacobian.weight.data.zero_()
            self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
        else:
            self.jacobian = None
        self.temperature = temperature
        self.scale_factor = scale_factor
        if self.scale_factor != 1:
            # NOTE(review): built from num_channels (not num_channels_a) and never
            # used by forward() below — looks vestigial; confirm before relying on it.
            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)

    def gaussian2kp(self, heatmap):
        """
        Extract the mean (expected keypoint coordinate) from a heatmap.
        """
        shape = heatmap.shape
        heatmap = heatmap.unsqueeze(-1)  # [4,10,58,58,1]
        grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)  # [1,1,58,58,2]
        value = (heatmap * grid).sum(dim=(2, 3))  # [4,10,2]
        kp = {'value': value}
        return kp

    def audio_feature(self, x, heatmap):
        """Per-keypoint jacobians from feature map `x`, weighted by an external heatmap."""
        # prediction = self.kp(x) #[4,10,H/4-6, W/4-6]
        # final_shape = prediction.shape
        # heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]
        # heatmap = F.softmax(heatmap / self.temperature, dim=2)
        # heatmap = heatmap.view(*final_shape) #[4,10,58,58]
        # out = self.gaussian2kp(heatmap)
        final_shape = heatmap.squeeze(2).shape
        if self.jacobian is not None:
            jacobian_map = self.jacobian(x)  # [4,40,H/4-6, W/4-6]
            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
                                                final_shape[3])
            heatmap = heatmap.unsqueeze(2)
            jacobian = heatmap * jacobian_map  # [4,10,4,H/4-6, W/4-6]
            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
            jacobian = jacobian.sum(dim=-1)  # [4,10,4]
            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)  # [4,10,2,2]
        return jacobian

    def forward(self, feature_map):  # torch.Size([4, 3, H, W])
        """Predict keypoints (and jacobians) from a precomputed feature map."""
        prediction = self.kp(feature_map)  # [4,10,H/4-6, W/4-6]
        final_shape = prediction.shape
        # Spatial softmax turns each channel into a probability map.
        heatmap = prediction.view(final_shape[0], final_shape[1], -1)  # [4, 10, 58*58]
        heatmap = F.softmax(heatmap / self.temperature, dim=2)
        heatmap = heatmap.view(*final_shape)  # [4,10,58,58]
        out = self.gaussian2kp(heatmap)
        out['heatmap'] = heatmap
        if self.jacobian is not None:
            jacobian_map = self.jacobian(feature_map)  # [4,40,H/4-6, W/4-6]
            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
                                                final_shape[3])
            heatmap = heatmap.unsqueeze(2)
            jacobian = heatmap * jacobian_map  # [4,10,4,H/4-6, W/4-6]
            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
            jacobian = jacobian.sum(dim=-1)  # [4,10,4]
            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)  # [4,10,2,2]
            out['jacobian'] = jacobian
        return out
class Audio_Feature(nn.Module):
    """Encode an audio window with content/emotion branches and decode the content code."""

    def __init__(self):
        super(Audio_Feature, self).__init__()
        self.con_encoder = Ct_encoder()   # content branch
        self.emo_encoder = EmotionNet()   # emotion branch
        self.decoder = AF2F_s()

    def forward(self, x):
        """Add a channel dim to `x`, encode it, and decode the content code."""
        x = x.unsqueeze(1)
        content = self.con_encoder(x)
        # The emotion code is computed but not consumed here (the concat into
        # the decoder is disabled); the call is kept so module behavior matches.
        emotion = self.emo_encoder(x)
        return self.decoder(content)
'''
def forward(self, x, cube, audio): #torch.Size([4, 3, H, W])
if self.scale_factor != 1:
x = self.down(x) # 0.25 [4, 3, H/4, W/4]
cube = cube.unsqueeze(1)
feature = torch.cat([x,cube,audio],dim=1)
feature_map = self.predictor(feature) #[4,3+32,H/4, W/4]
prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6]
final_shape = prediction.shape
heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]
heatmap = F.softmax(heatmap / self.temperature, dim=2)
heatmap = heatmap.view(*final_shape) #[4,10,58,58]
out = self.gaussian2kp(heatmap)
out['heatmap'] = heatmap
if self.jacobian is not None:
jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6]
jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
final_shape[3])
heatmap = heatmap.unsqueeze(2)
jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]
jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
jacobian = jacobian.sum(dim=-1) #[4,10,4]
jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2]
out['jacobian'] = jacobian
return out
'''
================================================
FILE: modules/model.py
================================================
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import AntiAliasInterpolation2d, make_coordinate_grid
from torchvision import models
import numpy as np
from torch.autograd import grad
class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss. See Sec 3.3.

    Splits a pretrained VGG-19 feature stack into five slices; forward returns
    the activation after each slice, with ImageNet normalization applied first.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        pretrained = models.vgg19(pretrained=True).features
        # Layer-index boundaries of the five relu slices.
        bounds = (0, 2, 7, 12, 21, 30)
        for slice_idx in range(5):
            stage = torch.nn.Sequential()
            for layer_idx in range(bounds[slice_idx], bounds[slice_idx + 1]):
                stage.add_module(str(layer_idx), pretrained[layer_idx])
            setattr(self, 'slice{}'.format(slice_idx + 1), stage)
        # ImageNet normalization constants, stored as frozen parameters.
        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the five slice activations for normalized input X."""
        current = (X - self.mean) / self.std
        out = []
        for slice_idx in range(1, 6):
            current = getattr(self, 'slice{}'.format(slice_idx))(current)
            out.append(current)
        return out
class ImagePyramide(torch.nn.Module):
    """
    Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        # One anti-aliased down-sampler per scale. ModuleDict keys cannot
        # contain '.', so it is stored as '-' and swapped back in forward().
        self.downs = nn.ModuleDict({
            str(scale).replace('.', '-'): AntiAliasInterpolation2d(num_channels, scale)
            for scale in scales
        })

    def forward(self, x):
        """Return {'prediction_<scale>': downsampled x} for every configured scale."""
        return {
            'prediction_' + key.replace('-', '.'): down_module(x)
            for key, down_module in self.downs.items()
        }
class Transform:
    """
    Random tps transformation for equivariance constraints. See Sec 3.3

    Samples one random affine warp (plus an optional thin-plate-spline
    component) per batch element; used to build equivariance losses for the
    keypoint detector.
    """
    def __init__(self, bs, **kwargs):
        """bs: batch size. kwargs: 'sigma_affine', optionally 'sigma_tps' + 'points_tps'."""
        noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
        # Random affine = identity + Gaussian perturbation.
        self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
        self.bs = bs
        if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
            self.tps = True
            self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
            self.control_points = self.control_points.unsqueeze(0)
            self.control_params = torch.normal(mean=0,
                                               std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
        else:
            self.tps = False

    def transform_frame(self, frame):
        """Warp `frame` (B, C, H, W) with the sampled transformation."""
        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0)  # [1,H,W,2]
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
        return F.grid_sample(frame, grid, padding_mode="reflection")

    def inverse_transform_frame(self, frame):
        """Warp `frame` with the inverted affine component."""
        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0)  # [1,H,W,2]
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        grid = self.inverse_warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
        return F.grid_sample(frame, grid, padding_mode="reflection")

    def warp_coordinates(self, coordinates):
        """Apply the affine (+TPS) warp to coordinates of shape (B, N, 2)."""
        theta = self.theta.type(coordinates.type())
        theta = theta.unsqueeze(1)
        transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
        transformed = transformed.squeeze(-1)
        if self.tps:
            control_points = self.control_points.type(coordinates.type())
            control_params = self.control_params.type(coordinates.type())
            # TPS radial basis: r^2 * log(r), with L1 distance to control points.
            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
            distances = torch.abs(distances).sum(-1)
            result = distances ** 2
            result = result * torch.log(distances + 1e-6)
            result = result * control_params
            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
            transformed = transformed + result
        return transformed

    def inverse_warp_coordinates(self, coordinates):
        """Apply the inverse of the affine warp (the TPS residual, if any, stays forward)."""
        theta = self.theta.type(coordinates.type())
        theta = theta.unsqueeze(1)
        # BUG FIX: the homogeneous row used a hard-coded .cuda(), which crashed
        # on CPU-only runs; cast it to match `coordinates` (dtype and device).
        a = torch.FloatTensor([[[[0, 0, 1]]]]).repeat([self.bs, 1, 1, 1]).type(coordinates.type())
        c = torch.cat((theta, a), 2)
        # Invert the 3x3 homogeneous affine and keep its top two rows.
        d = c.inverse()[:, :, :2, :]
        d = d.type(coordinates.type())
        transformed = torch.matmul(d[:, :, :, :2], coordinates.unsqueeze(-1)) + d[:, :, :, 2:]
        transformed = transformed.squeeze(-1)
        if self.tps:
            control_points = self.control_points.type(coordinates.type())
            control_params = self.control_params.type(coordinates.type())
            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
            distances = torch.abs(distances).sum(-1)
            result = distances ** 2
            result = result * torch.log(distances + 1e-6)
            result = result * control_params
            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
            transformed = transformed + result
        return transformed

    def jacobian(self, coordinates):
        """d(warp)/d(coordinates) via autograd; returns (..., 2, 2)."""
        coordinates.requires_grad = True
        new_coordinates = self.warp_coordinates(coordinates)  # [4,10,2]
        grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
        grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
        jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
        return jacobian
def detach_kp(kp):
    """Return a copy of the keypoint dict with every tensor detached from autograd."""
    detached = {}
    for name, tensor in kp.items():
        detached[name] = tensor.detach()
    return detached
class TrainPart1Model(torch.nn.Module):
    """
    Merge all generator related updates into single model for better multi-gpu usage

    Stage-1 wrapper: trains the audio keypoint detector (kp_extractor_a over
    audio_feature output) to match the visual keypoint detector (kp_extractor)
    across 16 driving frames, optionally adding a perceptual reconstruction
    loss through the generator.
    """
    def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids):
        # BUG FIX: the original called super(TrainFullModel, self).__init__(),
        # but no class named TrainFullModel exists in this module, so
        # constructing the model raised NameError. Use this class's own name.
        super(TrainPart1Model, self).__init__()
        self.kp_extractor = kp_extractor
        self.kp_extractor_a = kp_extractor_a
        self.audio_feature = audio_feature
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()
        self.loss_weights = train_params['loss_weights']
        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()
        self.mse_loss_fn = nn.MSELoss().cuda()

    def forward(self, x):
        """Compute stage-1 losses for batch dict x.

        Expects 'example_image', 'driving' (16 frames), 'driving_audio' and
        'driving_pose' in x. Returns (loss_values, generated).
        """
        # Defined up-front so the unknown-config branch below cannot NameError.
        generated = None
        kp_source = self.kp_extractor(x['example_image'])
        # Visual keypoints per driving frame: the supervision target.
        kp_driving = []
        for i in range(16):
            kp_driving.append(self.kp_extractor(x['driving'][:, i]))
        kp_driving_a = []
        deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
        loss_values = {}
        if self.loss_weights['audio'] != 0:
            # Audio-predicted keypoints per frame.
            kp_driving_a = []
            for i in range(16):
                kp_driving_a.append(self.kp_extractor_a(deco_out[:, i]))
            loss_value = 0
            loss_heatmap = 0
            loss_jacobian = 0
            loss_perceptual = 0
            # L1 matching of jacobians, heatmaps and positions between the
            # visual (detached values) and audio keypoint detectors.
            for i in range(len(kp_driving)):
                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian']).mean()) * self.loss_weights['audio']
                loss_heatmap += (torch.abs(kp_driving[i]['heatmap'] - kp_driving_a[i]['heatmap']).mean()) * self.loss_weights['audio'] * 100
                loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']).mean()) * self.loss_weights['audio']
            loss_values['loss_value'] = loss_value / len(kp_driving)
            loss_values['loss_heatmap'] = loss_heatmap / len(kp_driving)
            loss_values['loss_jacobian'] = loss_jacobian / len(kp_driving)
            if self.train_params['generator'] == 'not':
                # Run the generator once for logging; no reconstruction loss.
                for i in range(1):
                    generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i])
                    generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a})
            elif self.train_params['generator'] == 'visual':
                # Reconstruct every 4th frame from VISUAL keypoints.
                for i in range(0, len(kp_driving), 4):
                    generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i])
                    generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
                    pyramide_real = self.pyramid(x['driving'][:, i])
                    pyramide_generated = self.pyramid(generated['prediction'])
                    if sum(self.loss_weights['perceptual']) != 0:
                        value_total = 0
                        for scale in self.scales:
                            x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                            y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
                            # NOTE(review): this enumerate deliberately shadows the
                            # outer frame index i (the outer range restores it).
                            for i, weight in enumerate(self.loss_weights['perceptual']):
                                value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                                value_total += self.loss_weights['perceptual'][i] * value
                        loss_perceptual += value_total
                length = int((len(kp_driving) - 1) / 4) + 1
                loss_values['perceptual'] = loss_perceptual / length
            elif self.train_params['generator'] == 'audio':
                # Reconstruct every 4th frame from AUDIO keypoints.
                for i in range(0, len(kp_driving), 4):
                    generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i])
                    generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a})
                    pyramide_real = self.pyramid(x['driving'][:, i])
                    pyramide_generated = self.pyramid(generated['prediction'])
                    if sum(self.loss_weights['perceptual']) != 0:
                        value_total = 0
                        for scale in self.scales:
                            x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                            y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
                            for i, weight in enumerate(self.loss_weights['perceptual']):
                                value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                                value_total += self.loss_weights['perceptual'][i] * value
                        loss_perceptual += value_total
                length = int((len(kp_driving) - 1) / 4) + 1
                loss_values['perceptual'] = loss_perceptual / length
            else:
                print('wrong train_params: ', self.train_params['generator'])
        return loss_values, generated
class TrainPart2Model(torch.nn.Module):
    """
    Merge all generator related updates into single model for better multi-gpu usage

    Stage-2 wrapper: learns an emotion residual (emo_feature) on top of the
    audio-predicted keypoints so that audio keypoints + emotion deltas match
    the visual keypoints; optionally adds an emotion classification loss.
    """
    def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids):
        # BUG FIX: the original called super(TrainFullModel, self).__init__(),
        # but no class named TrainFullModel exists in this module, so
        # constructing the model raised NameError. Use this class's own name.
        super(TrainPart2Model, self).__init__()
        self.kp_extractor = kp_extractor
        self.kp_extractor_a = kp_extractor_a
        self.audio_feature = audio_feature
        self.emo_feature = emo_feature
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()
        self.loss_weights = train_params['loss_weights']
        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()
        self.mse_loss_fn = nn.MSELoss().cuda()
        self.CroEn_loss = nn.CrossEntropyLoss().cuda()

    def forward(self, x):
        """Compute stage-2 losses for batch dict x.

        Expects 'example_image', 'driving' (16 frames), 'driving_audio',
        'driving_pose', 'transformed_driving' and 'emotion' in x.
        Returns (loss_values, generated).
        """
        # BUG FIX: `generated` was returned without ever being assigned, so
        # every call raised NameError. No generator runs in stage 2, so the
        # second return value is explicitly None.
        generated = None
        kp_source = self.kp_extractor(x['example_image'])
        kp_driving = []
        kp_emo = []
        # Visual keypoints per driving frame: the supervision target.
        for i in range(16):
            kp_driving.append(self.kp_extractor(x['driving'][:, i]))
            # kp_emo.append(self.emo_detector(x['driving'][:,i]))
        kp_driving_a = []
        deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
        # emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
        loss_values = {}
        if self.loss_weights['emo'] != 0:
            kp_driving_a = []
            fakes = []
            for i in range(16):
                # PERF FIX: the original ran kp_extractor_a three times per
                # frame (once for append, once for 'value', once for
                # 'jacobian'); one forward pass yields all of them.
                kp_a = self.kp_extractor_a(deco_out[:, i])
                kp_driving_a.append(kp_a)
                value = kp_a['value']
                jacobian = kp_a['jacobian']
                # Emotion residual head selected by config; each variant
                # returns (keypoint deltas, emotion-class logits).
                if self.train_params['type'] == 'linear_4':
                    out, fake = self.emo_feature(x['transformed_driving'][:, i], value, jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                elif self.train_params['type'] == 'linear_10':
                    out, fake = self.emo_feature.linear_10(x['transformed_driving'][:, i], value, jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                elif self.train_params['type'] == 'linear_4_new':
                    out, fake = self.emo_feature.linear_4(x['transformed_driving'][:, i], value, jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                elif self.train_params['type'] == 'linear_np_4':
                    out, fake = self.emo_feature.linear_np_4(x['transformed_driving'][:, i], value, jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                elif self.train_params['type'] == 'linear_np_10':
                    out, fake = self.emo_feature.linear_np_10(x['transformed_driving'][:, i], value, jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
            loss_value = 0
            loss_jacobian = 0
            loss_classify = 0
            # kp_all aliases kp_driving_a; the in-place writes below fold the
            # emotion deltas into the audio keypoints (mouth/brow indices
            # 1, 4, 6, 8) AFTER each frame's losses are computed.
            kp_all = kp_driving_a
            for i in range(len(kp_driving)):
                if self.train_params['type'] == 'linear_4' or self.train_params['type'] == 'linear_4_new' or self.train_params['type'] == 'linear_np_4':
                    loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:, 1] - kp_driving_a[i]['jacobian'][:, 1] - kp_emo[i]['jacobian'][:, 0]).mean()) * self.loss_weights['emo']
                    loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:, 4] - kp_driving_a[i]['jacobian'][:, 4] - kp_emo[i]['jacobian'][:, 1]).mean()) * self.loss_weights['emo']
                    loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:, 6] - kp_driving_a[i]['jacobian'][:, 6] - kp_emo[i]['jacobian'][:, 2]).mean()) * self.loss_weights['emo']
                    loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:, 8] - kp_driving_a[i]['jacobian'][:, 8] - kp_emo[i]['jacobian'][:, 3]).mean()) * self.loss_weights['emo']
                    loss_classify += self.CroEn_loss(fakes[i], x['emotion'])
                    loss_value += (torch.abs(kp_driving[i]['value'][:, 1].detach() - kp_driving_a[i]['value'][:, 1] - kp_emo[i]['value'][:, 0]).mean()) * self.loss_weights['emo']
                    loss_value += (torch.abs(kp_driving[i]['value'][:, 4].detach() - kp_driving_a[i]['value'][:, 4] - kp_emo[i]['value'][:, 1]).mean()) * self.loss_weights['emo']
                    loss_value += (torch.abs(kp_driving[i]['value'][:, 6].detach() - kp_driving_a[i]['value'][:, 6] - kp_emo[i]['value'][:, 2]).mean()) * self.loss_weights['emo']
                    loss_value += (torch.abs(kp_driving[i]['value'][:, 8].detach() - kp_driving_a[i]['value'][:, 8] - kp_emo[i]['value'][:, 3]).mean()) * self.loss_weights['emo']
                    kp_all[i]['jacobian'][:, 1] = kp_emo[i]['jacobian'][:, 0] + kp_driving_a[i]['jacobian'][:, 1]
                    kp_all[i]['jacobian'][:, 4] = kp_emo[i]['jacobian'][:, 1] + kp_driving_a[i]['jacobian'][:, 4]
                    kp_all[i]['jacobian'][:, 6] = kp_emo[i]['jacobian'][:, 2] + kp_driving_a[i]['jacobian'][:, 6]
                    kp_all[i]['jacobian'][:, 8] = kp_emo[i]['jacobian'][:, 3] + kp_driving_a[i]['jacobian'][:, 8]
                    kp_all[i]['value'][:, 1] = kp_emo[i]['value'][:, 0] + kp_driving_a[i]['value'][:, 1]
                    kp_all[i]['value'][:, 4] = kp_emo[i]['value'][:, 1] + kp_driving_a[i]['value'][:, 4]
                    kp_all[i]['value'][:, 6] = kp_emo[i]['value'][:, 2] + kp_driving_a[i]['value'][:, 6]
                    kp_all[i]['value'][:, 8] = kp_emo[i]['value'][:, 3] + kp_driving_a[i]['value'][:, 8]
                elif self.train_params['type'] == 'linear_10' or self.train_params['type'] == 'linear_np_10':
                    # 10-keypoint variant: residual applies to all keypoints.
                    loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian'] - kp_emo[i]['jacobian']).mean()) * self.loss_weights['emo']
                    loss_classify += self.CroEn_loss(fakes[i], x['emotion'])
                    loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value'] - kp_emo[i]['value']).mean()) * self.loss_weights['emo']
                    # kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value']
            loss_values['loss_value'] = loss_value / len(kp_driving)
            # loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving)
            loss_values['loss_jacobian'] = loss_jacobian / len(kp_driving)
            if self.train_params['classify'] == True:
                loss_values['loss_classify'] = loss_classify / len(kp_driving)
            else:
                loss_values['loss_classify'] = torch.tensor(0, device=loss_values['loss_value'].device)
        return loss_values, generated
class GeneratorFullModel(torch.nn.Module):
    """
    Merge all generator related updates into single model for better multi-gpu usage

    NOTE(review): this class appears stale / partially disabled. `forward`
    reads `generated`, `kp_source`, `kp_source_a` and `kp_driving_a` that are
    never assigned (the assigning lines are commented out), so calling it
    as-is raises NameError. Kept byte-identical; confirm whether it is dead
    code before use.
    """
    def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params):
        super(GeneratorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.kp_extractor_a = kp_extractor_a
        # self.content_encoder = content_encoder
        # self.emotion_encoder = emotion_encoder
        self.audio_feature = audio_feature
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()
        self.loss_weights = train_params['loss_weights']
        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()
        # PCA basis/mean used to reconstruct landmarks from the audio net's
        # low-dimensional output. NOTE(review): '...' are placeholder paths —
        # they must point at real LRW list files for this class to load.
        self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda()
        self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda()

    def forward(self, x):
        # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])
        # source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))
        # kp_source = self.kp_extractor(x['source'])
        # kp_source_a = self.kp_extractor_a(x['source'], x['source_cube'], source_a_f)
        # driving_a_f = self.audio_feature(self.content_encoder(x['driving_audio'].unsqueeze(1)), self.emotion_encoder(x['driving_audio'].unsqueeze(1)))
        # driving_a_f = self.audio_feature(x['driving_audio'])
        # kp_driving = self.kp_extractor(x['driving'])
        # kp_driving_a = self.kp_extractor_a(x['driving'], x['driving_cube'], driving_a_f)
        kp_driving = []
        # NOTE(review): here kp_extractor is called with three arguments
        # (image, landmark, weight) — a different signature than elsewhere in
        # this file; confirm which detector variant this expects.
        for i in range(16):
            kp_driving.append(self.kp_extractor(x['driving'][:, i], x['driving_landmark'][:, i], self.loss_weights['equivariance_value']))
        kp_driving_a = []
        fc_out, deco_out = self.audio_feature(x['example_landmark'], x['driving_audio'], x['driving_pose'])
        # Reconstruct full landmarks: PCA coefficients -> landmark space + mean.
        fake_lmark = fc_out + x['example_landmark'].expand_as(fc_out)
        fake_lmark = torch.mm(fake_lmark, self.pca.t())
        fake_lmark = fake_lmark + self.mean.expand_as(fake_lmark)
        fake_lmark = fake_lmark.unsqueeze(0)
        # for i in range(16):
        #     kp_driving_a.append()
        # generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)
        # generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
        loss_values = {}
        pyramide_real = self.pyramid(x['driving'])
        # NOTE(review): `generated` is undefined at this point (its assignment
        # is commented out above) — this line raises NameError if reached.
        pyramide_generated = self.pyramid(generated['prediction'])
        if self.loss_weights['audio'] != 0:
            # Audio/visual keypoint agreement losses (kp_source_a/kp_driving_a
            # are likewise undefined here — see class-level note).
            value = torch.abs(kp_source['jacobian'].detach() - kp_source_a['jacobian'].detach()).mean() + torch.abs(kp_driving['jacobian'].detach() - kp_driving_a['jacobian']).mean()
            value = value / 2
            loss_values['jacobian'] = value * self.loss_weights['audio']
            value = torch.abs(kp_source['heatmap'].detach() - kp_source_a['heatmap'].detach()).mean() + torch.abs(kp_driving['heatmap'].detach() - kp_driving_a['heatmap']).mean()
            value = value / 2
            loss_values['heatmap'] = value * self.loss_weights['audio']
            value = torch.abs(kp_source['value'].detach() - kp_source_a['value'].detach()).mean() + torch.abs(kp_driving['value'].detach() - kp_driving_a['value']).mean()
            value = value / 2
            loss_values['value'] = value * self.loss_weights['audio']
        if sum(self.loss_weights['perceptual']) != 0:
            # Multi-scale VGG perceptual loss.
            value_total = 0
            for scale in self.scales:
                x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
                for i, weight in enumerate(self.loss_weights['perceptual']):
                    value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                    value_total += self.loss_weights['perceptual'][i] * value
            loss_values['perceptual'] = value_total
        if self.loss_weights['generator_gan'] != 0:
            # LSGAN generator loss plus optional feature matching.
            discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
            discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
            value_total = 0
            for scale in self.disc_scales:
                key = 'prediction_map_%s' % scale
                value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
                value_total += self.loss_weights['generator_gan'] * value
            loss_values['gen_gan'] = value_total
            if sum(self.loss_weights['feature_matching']) != 0:
                value_total = 0
                for scale in self.disc_scales:
                    key = 'feature_maps_%s' % scale
                    for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
                        if self.loss_weights['feature_matching'][i] == 0:
                            continue
                        value = torch.abs(a - b).mean()
                        value_total += self.loss_weights['feature_matching'][i] * value
                loss_values['feature_matching'] = value_total
        if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
            # Equivariance constraints under a random TPS warp (Sec 3.3).
            transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
            transformed_frame = transform.transform_frame(x['driving'])
            transformed_landmark = transform.inverse_warp_coordinates(x['driving_landmark'])
            transformed_kp = self.kp_extractor(transformed_frame)
            generated['transformed_frame'] = transformed_frame
            generated['transformed_kp'] = transformed_kp
            ## Value loss part
            if self.loss_weights['equivariance_value'] != 0:
                value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
                loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
            ## jacobian loss part
            if self.loss_weights['equivariance_jacobian'] != 0:
                jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
                                                    transformed_kp['jacobian'])
                normed_driving = torch.inverse(kp_driving['jacobian'])
                normed_transformed = jacobian_transformed
                value = torch.matmul(normed_driving, normed_transformed)
                eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())
                value = torch.abs(eye - value).mean()
                loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value
        return loss_values, generated
class DiscriminatorFullModel(torch.nn.Module):
    """
    Merge all discriminator related updates into single model for better multi-gpu usage
    """
    def __init__(self, kp_extractor, generator, discriminator, train_params):
        super(DiscriminatorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()
        self.loss_weights = train_params['loss_weights']

    def forward(self, x, generated):
        """LSGAN discriminator loss accumulated over all pyramid scales."""
        real_pyramid = self.pyramid(x['driving'])
        # Detach the generator output: only the discriminator is updated here.
        fake_pyramid = self.pyramid(generated['prediction'].detach())
        kp_driving = generated['kp_driving']
        maps_fake = self.discriminator(fake_pyramid, kp=detach_kp(kp_driving))
        maps_real = self.discriminator(real_pyramid, kp=detach_kp(kp_driving))
        total = 0
        for scale in self.scales:
            key = 'prediction_map_%s' % scale
            scale_loss = (1 - maps_real[key]) ** 2 + maps_fake[key] ** 2
            total += self.loss_weights['discriminator_gan'] * scale_loss.mean()
        return {'disc_gan': total}
================================================
FILE: modules/model_delta_map.py
================================================
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import AntiAliasInterpolation2d, make_coordinate_grid
from torchvision import models
import numpy as np
from torch.autograd import grad
class Vgg19(torch.nn.Module):
    """
    VGG19 feature extractor for the perceptual loss. See Sec 3.3.
    Splits torchvision's pretrained vgg19.features into five relu stages and
    returns the activation after each stage.
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        features = models.vgg19(pretrained=True).features
        # Layer indices delimiting the five relu stages of vgg19.features.
        bounds = (0, 2, 7, 12, 21, 30)
        stages = [torch.nn.Sequential() for _ in range(5)]
        for stage_idx, stage in enumerate(stages):
            for layer_idx in range(bounds[stage_idx], bounds[stage_idx + 1]):
                # Keep the original layer index as the module name so the
                # state_dict keys match the hand-unrolled version.
                stage.add_module(str(layer_idx), features[layer_idx])
        self.slice1, self.slice2, self.slice3, self.slice4, self.slice5 = stages
        # ImageNet normalization constants, stored as frozen parameters so they
        # follow the module across devices.
        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = (X - self.mean) / self.std
        outputs = []
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            outputs.append(h)
        return outputs
class ImagePyramide(torch.nn.Module):
    """
    Multi-scale image pyramid for the pyramid perceptual loss. See Sec 3.3.
    One anti-aliased downsampler per requested scale.
    """

    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        layers = {}
        for scale in scales:
            # ModuleDict keys may not contain '.', so encode e.g. 0.5 as "0-5".
            layers[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(layers)

    def forward(self, x):
        result = {}
        for name, layer in self.downs.items():
            # Decode the key back to its numeric form for the output dict.
            result['prediction_' + str(name).replace('-', '.')] = layer(x)
        return result
class Transform:
    """
    Random affine + thin-plate-spline warp used for the equivariance
    constraints. See Sec 3.3.

    The affine part is a random perturbation of the identity; the optional TPS
    part is enabled when both 'sigma_tps' and 'points_tps' are supplied.
    """

    def __init__(self, bs, **kwargs):
        affine_noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
        self.theta = affine_noise + torch.eye(2, 3).view(1, 2, 3)
        self.bs = bs
        self.tps = ('sigma_tps' in kwargs) and ('points_tps' in kwargs)
        if self.tps:
            pts = kwargs['points_tps']
            grid = make_coordinate_grid((pts, pts), type=affine_noise.type())
            self.control_points = grid.unsqueeze(0)
            self.control_params = torch.normal(
                mean=0, std=kwargs['sigma_tps'] * torch.ones([bs, 1, pts ** 2]))

    def transform_frame(self, frame):
        # Build a [1, H*W, 2] coordinate grid, warp it, and resample the frame.
        h, w = frame.shape[2], frame.shape[3]
        grid = make_coordinate_grid((h, w), type=frame.type()).unsqueeze(0)  # [1,H,W,2]
        warped_grid = self.warp_coordinates(grid.view(1, h * w, 2))
        return F.grid_sample(frame, warped_grid.view(self.bs, h, w, 2), padding_mode="reflection")

    def inverse_transform_frame(self, frame):
        # Same as transform_frame but sampling along the inverse affine warp.
        h, w = frame.shape[2], frame.shape[3]
        grid = make_coordinate_grid((h, w), type=frame.type()).unsqueeze(0)  # [1,H,W,2]
        warped_grid = self.inverse_warp_coordinates(grid.view(1, h * w, 2))
        return F.grid_sample(frame, warped_grid.view(self.bs, h, w, 2), padding_mode="reflection")

    def _tps_offset(self, coordinates):
        # Radial-basis contribution of the TPS control points, U(r) = r^2 log r,
        # with an L1 "radius" as in the original FOMM implementation.
        cp = self.control_points.type(coordinates.type())
        cw = self.control_params.type(coordinates.type())
        dist = coordinates.view(coordinates.shape[0], -1, 1, 2) - cp.view(1, 1, -1, 2)
        dist = torch.abs(dist).sum(-1)
        basis = (dist ** 2) * torch.log(dist + 1e-6)
        return (basis * cw).sum(dim=2).view(self.bs, coordinates.shape[1], 1)

    def warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type()).unsqueeze(1)
        warped = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
        warped = warped.squeeze(-1)
        if self.tps:
            warped = warped + self._tps_offset(coordinates)
        return warped

    def inverse_warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type()).unsqueeze(1)
        # Augment the 2x3 affine to 3x3, invert it, keep the top two rows.
        # NOTE(review): forces CUDA here, so this method requires a GPU.
        bottom_row = torch.FloatTensor([[[[0, 0, 1]]]]).repeat([self.bs, 1, 1, 1]).cuda()
        full = torch.cat((theta, bottom_row), 2)
        inv = full.inverse()[:, :, :2, :].type(coordinates.type())
        warped = torch.matmul(inv[:, :, :, :2], coordinates.unsqueeze(-1)) + inv[:, :, :, 2:]
        warped = warped.squeeze(-1)
        if self.tps:
            # NOTE(review): adds the FORWARD tps offset after the inverse affine,
            # so this is not an exact inverse when tps is enabled — kept as-is.
            warped = warped + self._tps_offset(coordinates)
        return warped

    def jacobian(self, coordinates):
        coordinates.requires_grad = True
        warped = self.warp_coordinates(coordinates)  # [bs, n_kp, 2]
        gx = grad(warped[..., 0].sum(), coordinates, create_graph=True)
        gy = grad(warped[..., 1].sum(), coordinates, create_graph=True)
        return torch.cat([gx[0].unsqueeze(-2), gy[0].unsqueeze(-2)], dim=-2)
def detach_kp(kp):
    """Return a copy of the keypoint dict with every tensor detached from the autograd graph."""
    detached = {}
    for name, tensor in kp.items():
        detached[name] = tensor.detach()
    return detached
class TrainFullModel(torch.nn.Module):
    """
    Merge all generator related updates into single model for better multi-gpu usage.

    Combines the visual keypoint detector (teacher), the audio-driven keypoint
    branch (audio_feature + kp_extractor_a), the emotion displacement/classifier
    head (emo_feature) and the image generator; forward() returns a dict of
    scalar losses plus the last generated frame dict.
    """
    def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids):
        super(TrainFullModel, self).__init__()
        self.kp_extractor = kp_extractor      # image -> keypoints (teacher signal)
        self.kp_extractor_a = kp_extractor_a  # decoded audio features -> keypoints
        # self.emo_detector = emo_detector
        # self.content_encoder = content_encoder
        # self.emotion_encoder = emotion_encoder
        self.audio_feature = audio_feature    # (image, audio, pose) -> per-frame feature maps
        self.emo_feature = emo_feature        # emotion displacement + classification logits
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()
        self.loss_weights = train_params['loss_weights']
        # VGG is only instantiated when some perceptual weight is non-zero.
        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()
        # self.pca = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/U_106.npy'))[:, :16].to(device_ids[0])
        # self.mean = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/mean_106.npy')).to(device_ids[0])
        self.mse_loss_fn = nn.MSELoss().cuda()
        self.CroEn_loss = nn.CrossEntropyLoss().cuda()
    def forward(self, x):
        # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])
        # source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))
        kp_source = self.kp_extractor(x['example_image'])
        kp_driving = []
        kp_emo = []
        # Teacher keypoints from the 16 ground-truth driving frames.
        for i in range(16):
            kp_driving.append(self.kp_extractor(x['driving'][:,i]))
            # kp_emo.append(self.emo_detector(x['driving'][:,i]))
        # print('KP_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a'))
        kp_driving_a = [] #x['example_image'],
        # Audio branch: decoded per-frame feature maps for all 16 frames.
        deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
        # emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
        loss_values = {}
        if self.loss_weights['emo'] != 0:
            kp_driving_a = []
            fakes = []
            for i in range(16):
                # NOTE(review): kp_extractor_a is invoked three times on the same
                # input here; the 'value'/'jacobian' lookups could reuse the first
                # result — confirm the extractor is deterministic before changing.
                kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))#
                value = self.kp_extractor_a(deco_out[:,i])['value']
                jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian']
                if self.train_params['type'] == 'map_4':
                    # Emotion head predicts displacements for 4 selected keypoints
                    # (indices 1, 4, 6, 8 below) plus classification logits.
                    out, fake = self.emo_feature.map_4(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                    # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))
                elif self.train_params['type'] == 'map_10':
                    # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))
                    # 10-keypoint variant: emotion delta for every keypoint.
                    out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                    # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))
            # print('Kp_audio_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a'))
            loss_value = 0
            # loss_heatmap = 0
            loss_jacobian = 0
            loss_perceptual = 0
            loss_classify = 0
            # NOTE(review): kp_all ALIASES kp_driving_a (no copy) — the in-place
            # edits below also mutate kp_driving_a; confirm that is intended.
            kp_all = kp_driving_a
            for i in range(len(kp_driving)):
                if self.train_params['type'] == 'map_4':
                    # L1 between teacher jacobian and (audio jacobian + emotion delta)
                    # for the four emotion-sensitive keypoints 1, 4, 6, 8.
                    loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,1] - kp_driving_a[i]['jacobian'][:,1] -kp_emo[i]['jacobian'][:,0]).mean())*self.loss_weights['emo']
                    loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,4] - kp_driving_a[i]['jacobian'][:,4] -kp_emo[i]['jacobian'][:,1]).mean())*self.loss_weights['emo']
                    loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,6] - kp_driving_a[i]['jacobian'][:,6] -kp_emo[i]['jacobian'][:,2]).mean())*self.loss_weights['emo']
                    loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,8] - kp_driving_a[i]['jacobian'][:,8] -kp_emo[i]['jacobian'][:,3]).mean())*self.loss_weights['emo']
                    # Cross-entropy on the emotion classification logits.
                    loss_classify += self.CroEn_loss(fakes[i],x['emotion'])
                    # Same decomposition loss for keypoint positions (teacher detached).
                    loss_value += (torch.abs(kp_driving[i]['value'][:,1] .detach() - kp_driving_a[i]['value'][:,1] - kp_emo[i]['value'][:,0] ).mean())*self.loss_weights['emo']
                    loss_value += (torch.abs(kp_driving[i]['value'][:,4] .detach() - kp_driving_a[i]['value'][:,4] - kp_emo[i]['value'][:,1] ).mean())*self.loss_weights['emo']
                    loss_value += (torch.abs(kp_driving[i]['value'][:,6] .detach() - kp_driving_a[i]['value'][:,6] - kp_emo[i]['value'][:,2] ).mean())*self.loss_weights['emo']
                    loss_value += (torch.abs(kp_driving[i]['value'][:,8] .detach() - kp_driving_a[i]['value'][:,8] - kp_emo[i]['value'][:,3] ).mean())*self.loss_weights['emo']
                    # Compose the final keypoints: audio prediction + emotion delta
                    # (in-place on kp_all, which aliases kp_driving_a — see above).
                    kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1]
                    kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4]
                    kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6]
                    kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8]
                    kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1]
                    kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + kp_driving_a[i]['value'][:,4]
                    kp_all[i]['value'][:,6] = kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6]
                    kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8]
                elif self.train_params['type'] == 'map_10':
                    # 10-keypoint variant: the decomposition loss runs over all keypoints.
                    loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian'] -kp_emo[i]['jacobian']).mean())*self.loss_weights['emo']
                    loss_classify += self.CroEn_loss(fakes[i],x['emotion'])
                    loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value'] - kp_emo[i]['value'] ).mean())*self.loss_weights['emo']
                    # kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value']
            # Average the accumulated losses over the 16 frames.
            loss_values['loss_value'] = loss_value/len(kp_driving)
            # loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving)
            loss_values['loss_jacobian'] = loss_jacobian/len(kp_driving)
            loss_values['loss_classify'] = loss_classify/len(kp_driving)
        # NOTE(review): the branches below use kp_all / loss_perceptual, which are
        # only defined when loss_weights['emo'] != 0 — confirm the configs always
        # set a non-zero emo weight when a generator mode other than 'visual'/'audio'
        # perceptual-free paths is selected.
        if self.train_params['generator'] == 'not':
            # Placeholder so the 'perceptual' key always exists; mse(x, x) is 0.
            loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out)
            for i in range(1): #0,len(kp_driving),4
                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_all})
        elif self.train_params['generator'] == 'visual':
            # Render every 4th frame from the TEACHER keypoints and score it
            # against the real frame with the VGG perceptual loss.
            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4
                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
                pyramide_real = self.pyramid(x['driving'][:,i])
                pyramide_generated = self.pyramid(generated['prediction'])
                if sum(self.loss_weights['perceptual']) != 0:
                    value_total = 0
                    for scale in self.scales:
                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
                        # Inner `i` shadows the frame index; harmless because the
                        # outer loop reassigns `i` from its range iterator.
                        for i, weight in enumerate(self.loss_weights['perceptual']):
                            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                            value_total += self.loss_weights['perceptual'][i] * value
                    loss_perceptual += value_total
            # Number of frames actually rendered (ceil(len/4)).
            length = int((len(kp_driving)-1)/4)+1
            loss_values['perceptual'] = loss_perceptual/length
        elif self.train_params['generator'] == 'audio':
            # Same as 'visual' but rendering from the AUDIO-predicted keypoints.
            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4
                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a})
                pyramide_real = self.pyramid(x['driving'][:,i])
                pyramide_generated = self.pyramid(generated['prediction'])
                if sum(self.loss_weights['perceptual']) != 0:
                    value_total = 0
                    for scale in self.scales:
                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
                        for i, weight in enumerate(self.loss_weights['perceptual']):
                            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                            value_total += self.loss_weights['perceptual'][i] * value
                    loss_perceptual += value_total
            length = int((len(kp_driving)-1)/4)+1
            loss_values['perceptual'] = loss_perceptual/length
        else:
            print('wrong train_params: ', self.train_params['generator'])
        return loss_values,generated
class GeneratorFullModel(torch.nn.Module):
    """
    Merge all generator related updates into single model for better multi-gpu usage.

    NOTE(review): forward() references names whose defining statements are
    commented out (`generated`, `kp_source`, `kp_source_a`, `driving_a_f`), and
    treats the list `kp_driving` as a dict in the equivariance section, so this
    class as written would raise at call time — it appears to be retained as
    reference code rather than an active training path.
    """
    def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params):
        super(GeneratorFullModel, self).__init__()
        self.kp_extractor = kp_extractor      # image -> keypoints
        self.kp_extractor_a = kp_extractor_a  # audio features -> keypoints
        # self.content_encoder = content_encoder
        # self.emotion_encoder = emotion_encoder
        self.audio_feature = audio_feature
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()
        self.loss_weights = train_params['loss_weights']
        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()
        # PCA basis (first 16 components) and mean landmark for reconstructing
        # 106-point landmarks from PCA space. NOTE(review): '...' is a placeholder
        # path that must be edited to the local LRW list directory before use.
        self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda()
        self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda()
    def forward(self, x):
        # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])
        # source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))
        # kp_source = self.kp_extractor(x['source'])
        # kp_source_a = self.kp_extractor_a(x['source'], x['source_cube'], source_a_f)
        # driving_a_f = self.audio_feature(self.content_encoder(x['driving_audio'].unsqueeze(1)), self.emotion_encoder(x['driving_audio'].unsqueeze(1)))
        # driving_a_f = self.audio_feature(x['driving_audio'])
        # kp_driving = self.kp_extractor(x['driving'])
        # kp_driving_a = self.kp_extractor_a(x['driving'], x['driving_cube'], driving_a_f)
        # Per-frame teacher keypoints for the 16 driving frames.
        kp_driving = []
        for i in range(16):
            kp_driving.append(self.kp_extractor(x['driving'][:,i],x['driving_landmark'][:,i],self.loss_weights['equivariance_value']))
        kp_driving_a = []
        # Audio branch: PCA-space landmark offsets plus decoded feature maps.
        fc_out, deco_out = self.audio_feature(x['example_landmark'], x['driving_audio'], x['driving_pose'])
        # Reconstruct absolute landmarks: offset + example, back-projected
        # through the PCA basis and re-centered on the mean shape.
        fake_lmark=fc_out + x['example_landmark'].expand_as(fc_out)
        fake_lmark = torch.mm( fake_lmark, self.pca.t() )
        fake_lmark = fake_lmark + self.mean.expand_as(fake_lmark)
        fake_lmark = fake_lmark.unsqueeze(0)
        # for i in range(16):
        #     kp_driving_a.append()
        # generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)
        # generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
        loss_values = {}
        pyramide_real = self.pyramid(x['driving'])
        # NOTE(review): `generated` is used here but its assignment above is
        # commented out — this line raises NameError if forward() is called.
        pyramide_generated = self.pyramid(generated['prediction'])
        if self.loss_weights['audio'] != 0:
            # NOTE(review): kp_source / kp_source_a / kp_driving_a (as dicts) are
            # undefined in this branch — see the commented-out lines above.
            value = torch.abs(kp_source['jacobian'].detach() - kp_source_a['jacobian'].detach()).mean() + torch.abs(kp_driving['jacobian'].detach() - kp_driving_a['jacobian']).mean()
            value = value/2
            loss_values['jacobian'] = value*self.loss_weights['audio']
            value = torch.abs(kp_source['heatmap'].detach() - kp_source_a['heatmap'].detach()).mean() + torch.abs(kp_driving['heatmap'].detach() - kp_driving_a['heatmap']).mean()
            value = value/2
            loss_values['heatmap'] = value*self.loss_weights['audio']
            value = torch.abs(kp_source['value'].detach() - kp_source_a['value'].detach()).mean() + torch.abs(kp_driving['value'].detach() - kp_driving_a['value']).mean()
            value = value/2
            loss_values['value'] = value*self.loss_weights['audio']
        # Multi-scale VGG perceptual loss between generated and real pyramids.
        if sum(self.loss_weights['perceptual']) != 0:
            value_total = 0
            for scale in self.scales:
                x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
                for i, weight in enumerate(self.loss_weights['perceptual']):
                    value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                    value_total += self.loss_weights['perceptual'][i] * value
            loss_values['perceptual'] = value_total
        # LSGAN generator loss (+ optional discriminator feature matching).
        if self.loss_weights['generator_gan'] != 0:
            discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
            discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
            value_total = 0
            for scale in self.disc_scales:
                key = 'prediction_map_%s' % scale
                value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
                value_total += self.loss_weights['generator_gan'] * value
            loss_values['gen_gan'] = value_total
            if sum(self.loss_weights['feature_matching']) != 0:
                value_total = 0
                for scale in self.disc_scales:
                    key = 'feature_maps_%s' % scale
                    for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
                        if self.loss_weights['feature_matching'][i] == 0:
                            continue
                        value = torch.abs(a - b).mean()
                        value_total += self.loss_weights['feature_matching'][i] * value
                loss_values['feature_matching'] = value_total
        # Equivariance constraints on keypoint values/jacobians under a random
        # TPS warp of the driving frame (FOMM Sec 3.3).
        if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
            transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
            transformed_frame = transform.transform_frame(x['driving'])
            # NOTE(review): transformed_landmark is computed but never used.
            transformed_landmark = transform.inverse_warp_coordinates(x['driving_landmark'])
            transformed_kp = self.kp_extractor(transformed_frame)
            generated['transformed_frame'] = transformed_frame
            generated['transformed_kp'] = transformed_kp
            ## Value loss part
            if self.loss_weights['equivariance_value'] != 0:
                # NOTE(review): kp_driving is a LIST here, so kp_driving['value']
                # would raise TypeError — further evidence this path is inactive.
                value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
                loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
            ## jacobian loss part
            if self.loss_weights['equivariance_jacobian'] != 0:
                jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
                                                    transformed_kp['jacobian'])
                normed_driving = torch.inverse(kp_driving['jacobian'])
                normed_transformed = jacobian_transformed
                value = torch.matmul(normed_driving, normed_transformed)
                # Deviation of T_driving^{-1} . T_transformed from the identity.
                eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())
                value = torch.abs(eye - value).mean()
                loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value
        return loss_values, generated
class DiscriminatorFullModel(torch.nn.Module):
    """
    Wraps the discriminator pass so the full LSGAN discriminator loss is
    computed inside a single module (plays nicely with multi-GPU DataParallel).
    """

    def __init__(self, kp_extractor, generator, discriminator, train_params):
        super(DiscriminatorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = self.discriminator.scales
        # Multi-scale image pyramid feeding each discriminator scale.
        pyramid = ImagePyramide(self.scales, generator.num_channels)
        self.pyramid = pyramid.cuda() if torch.cuda.is_available() else pyramid
        self.loss_weights = train_params['loss_weights']

    def forward(self, x, generated):
        real_pyramid = self.pyramid(x['driving'])
        fake_pyramid = self.pyramid(generated['prediction'].detach())
        # Keypoints are detached: the discriminator must not backprop into them.
        kp = detach_kp(generated['kp_driving'])
        maps_fake = self.discriminator(fake_pyramid, kp=kp)
        maps_real = self.discriminator(real_pyramid, kp=kp)
        total = 0
        for scale in self.scales:
            key = 'prediction_map_%s' % scale
            # LSGAN objective: push real maps toward 1 and fake maps toward 0.
            per_scale = (1 - maps_real[key]) ** 2 + maps_fake[key] ** 2
            total = total + self.loss_weights['discriminator_gan'] * per_scale.mean()
        return {'disc_gan': total}
================================================
FILE: modules/model_gen.py
================================================
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import AntiAliasInterpolation2d, make_coordinate_grid
from torchvision import models
import numpy as np
from torch.autograd import grad
class Vgg19(torch.nn.Module):
    """
    VGG19 feature extractor for the perceptual loss. See Sec 3.3.
    Splits torchvision's pretrained vgg19.features into five relu stages and
    returns the activation after each stage.
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        features = models.vgg19(pretrained=True).features
        # Layer indices delimiting the five relu stages of vgg19.features.
        bounds = (0, 2, 7, 12, 21, 30)
        stages = [torch.nn.Sequential() for _ in range(5)]
        for stage_idx, stage in enumerate(stages):
            for layer_idx in range(bounds[stage_idx], bounds[stage_idx + 1]):
                # Keep the original layer index as the module name so the
                # state_dict keys match the hand-unrolled version.
                stage.add_module(str(layer_idx), features[layer_idx])
        self.slice1, self.slice2, self.slice3, self.slice4, self.slice5 = stages
        # ImageNet normalization constants, stored as frozen parameters so they
        # follow the module across devices.
        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = (X - self.mean) / self.std
        outputs = []
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            outputs.append(h)
        return outputs
class ImagePyramide(torch.nn.Module):
    """
    Multi-scale image pyramid for the pyramid perceptual loss. See Sec 3.3.
    One anti-aliased downsampler per requested scale.
    """

    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        layers = {}
        for scale in scales:
            # ModuleDict keys may not contain '.', so encode e.g. 0.5 as "0-5".
            layers[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(layers)

    def forward(self, x):
        result = {}
        for name, layer in self.downs.items():
            # Decode the key back to its numeric form for the output dict.
            result['prediction_' + str(name).replace('-', '.')] = layer(x)
        return result
class Transform:
    """
    Random affine + thin-plate-spline warp used for the equivariance
    constraints. See Sec 3.3.

    The affine part is a random perturbation of the identity; the optional TPS
    part is enabled when both 'sigma_tps' and 'points_tps' are supplied.
    """

    def __init__(self, bs, **kwargs):
        affine_noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
        self.theta = affine_noise + torch.eye(2, 3).view(1, 2, 3)
        self.bs = bs
        self.tps = ('sigma_tps' in kwargs) and ('points_tps' in kwargs)
        if self.tps:
            pts = kwargs['points_tps']
            grid = make_coordinate_grid((pts, pts), type=affine_noise.type())
            self.control_points = grid.unsqueeze(0)
            self.control_params = torch.normal(
                mean=0, std=kwargs['sigma_tps'] * torch.ones([bs, 1, pts ** 2]))

    def transform_frame(self, frame):
        # Build a [1, H*W, 2] coordinate grid, warp it, and resample the frame.
        h, w = frame.shape[2], frame.shape[3]
        grid = make_coordinate_grid((h, w), type=frame.type()).unsqueeze(0)  # [1,H,W,2]
        warped_grid = self.warp_coordinates(grid.view(1, h * w, 2))
        return F.grid_sample(frame, warped_grid.view(self.bs, h, w, 2), padding_mode="reflection")

    def inverse_transform_frame(self, frame):
        # Same as transform_frame but sampling along the inverse affine warp.
        h, w = frame.shape[2], frame.shape[3]
        grid = make_coordinate_grid((h, w), type=frame.type()).unsqueeze(0)  # [1,H,W,2]
        warped_grid = self.inverse_warp_coordinates(grid.view(1, h * w, 2))
        return F.grid_sample(frame, warped_grid.view(self.bs, h, w, 2), padding_mode="reflection")

    def _tps_offset(self, coordinates):
        # Radial-basis contribution of the TPS control points, U(r) = r^2 log r,
        # with an L1 "radius" as in the original FOMM implementation.
        cp = self.control_points.type(coordinates.type())
        cw = self.control_params.type(coordinates.type())
        dist = coordinates.view(coordinates.shape[0], -1, 1, 2) - cp.view(1, 1, -1, 2)
        dist = torch.abs(dist).sum(-1)
        basis = (dist ** 2) * torch.log(dist + 1e-6)
        return (basis * cw).sum(dim=2).view(self.bs, coordinates.shape[1], 1)

    def warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type()).unsqueeze(1)
        warped = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
        warped = warped.squeeze(-1)
        if self.tps:
            warped = warped + self._tps_offset(coordinates)
        return warped

    def inverse_warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type()).unsqueeze(1)
        # Augment the 2x3 affine to 3x3, invert it, keep the top two rows.
        # NOTE(review): forces CUDA here, so this method requires a GPU.
        bottom_row = torch.FloatTensor([[[[0, 0, 1]]]]).repeat([self.bs, 1, 1, 1]).cuda()
        full = torch.cat((theta, bottom_row), 2)
        inv = full.inverse()[:, :, :2, :].type(coordinates.type())
        warped = torch.matmul(inv[:, :, :, :2], coordinates.unsqueeze(-1)) + inv[:, :, :, 2:]
        warped = warped.squeeze(-1)
        if self.tps:
            # NOTE(review): adds the FORWARD tps offset after the inverse affine,
            # so this is not an exact inverse when tps is enabled — kept as-is.
            warped = warped + self._tps_offset(coordinates)
        return warped

    def jacobian(self, coordinates):
        coordinates.requires_grad = True
        warped = self.warp_coordinates(coordinates)  # [bs, n_kp, 2]
        gx = grad(warped[..., 0].sum(), coordinates, create_graph=True)
        gy = grad(warped[..., 1].sum(), coordinates, create_graph=True)
        return torch.cat([gx[0].unsqueeze(-2), gy[0].unsqueeze(-2)], dim=-2)
def detach_kp(kp):
    """Return a copy of the keypoint dict with every tensor detached from the autograd graph."""
    detached = {}
    for name, tensor in kp.items():
        detached[name] = tensor.detach()
    return detached
class TrainFullModel(torch.nn.Module):
    """
    Merge all generator related updates into single model for better multi-gpu usage.

    Couples the visual keypoint detector, the audio-driven keypoint detector,
    the emotion feature branch and the generator; forward() returns the loss
    dict together with the last generated frame.
    """

    def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids):
        # kp_extractor:   keypoint detector run on real images (ground-truth branch)
        # kp_extractor_a: keypoint detector run on decoded audio feature maps
        # emo_feature:    emotion branch predicting keypoint/jacobian offsets
        # audio_feature:  audio encoder/decoder producing per-frame features
        # NOTE(review): device_ids is unused except in the commented-out PCA
        # lines below — confirm whether it can be dropped from the signature.
        super(TrainFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.kp_extractor_a = kp_extractor_a
        # self.emo_detector = emo_detector
        # self.content_encoder = content_encoder
        # self.emotion_encoder = emotion_encoder
        self.audio_feature = audio_feature
        self.emo_feature = emo_feature
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        # Image pyramid feeding the multi-scale perceptual loss.
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()
        self.loss_weights = train_params['loss_weights']
        # VGG19 is only instantiated when the perceptual loss is enabled.
        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()
        # self.pca = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/U_106.npy'))[:, :16].to(device_ids[0])
        # self.mean = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/mean_106.npy')).to(device_ids[0])
        # NOTE(review): unconditional .cuda() here, unlike the guarded calls
        # above — CPU-only runs will fail; confirm that is intended.
        self.mse_loss_fn = nn.MSELoss().cuda()
        self.CroEn_loss = nn.CrossEntropyLoss().cuda()

    def forward(self, x):
        """Compute the training losses for one batch.

        `x` is the loader's batch dict; keys read here: 'example_image',
        'driving' (16 frames), 'driving_audio', 'driving_pose',
        'transformed_driving' and 'name'.
        Returns (loss_values, generated) — `generated` is the output of the
        last generator call in the selected branch.
        """
        # source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])
        # source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))
        kp_source = self.kp_extractor(x['example_image'])
        # print(x['name'],len(x['name']))
        kp_driving = []
        kp_emo = []
        # Ground-truth keypoints for each of the 16 driving frames.
        for i in range(16):
            kp_driving.append(self.kp_extractor(x['driving'][:,i]))
            # kp_emo.append(self.emo_detector(x['driving'][:,i]))
        # print('KP_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a'))
        kp_driving_a = [] #x['example_image'],
        # Per-frame feature maps decoded from audio (conditioned on pose).
        deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
        # emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
        loss_values = {}
        if self.loss_weights['emo'] != 0:
            kp_driving_a = []
            fakes = []
            for i in range(16):
                kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))#
                # NOTE(review): the audio keypoint detector is invoked three
                # times per frame on the same input; the two calls below could
                # reuse kp_driving_a[-1].
                value = self.kp_extractor_a(deco_out[:,i])['value']
                jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian']
                # Dispatch on the configured emotion head; x['name'][0] == 0
                # presumably tags samples from the emotional dataset — verify
                # against the dataset's __getitem__.
                if self.train_params['type'] == 'linear_4' and x['name'][0] == 0:
                    out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                    # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))
                elif self.train_params['type'] == 'linear_10' and x['name'][0] == 0:
                    # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))
                    out, fake = self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                elif self.train_params['type'] == 'linear_4_new' and x['name'][0] == 0:
                    # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))
                    out, fake = self.emo_feature.linear_4(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                elif self.train_params['type'] == 'linear_np_4':
                    # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))
                    out, fake = self.emo_feature.linear_np_4(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                elif self.train_params['type'] == 'linear_np_10':
                    # kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))
                    out, fake = self.emo_feature.linear_np_10(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                    # kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))
        # print('Kp_audio_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a'))
        loss_perceptual = 0
        # NOTE(review): kp_all aliases kp_driving_a (no copy), so the in-place
        # edits below also mutate the audio-branch predictions.
        kp_all = kp_driving_a
        if self.train_params['smooth'] == True:
            # NOTE(review): value_all / jacobian_all are initialised with
            # torch.randn and never written again before the smoothness loss
            # below, so that loss is computed on random noise — looks like a
            # bug (the per-frame values were presumably meant to be copied in).
            # Also, `out` is only bound if one of the branches above ran.
            value_all = torch.randn(len(kp_driving),out['value'].shape[0],out['value'].shape[1],out['value'].shape[2]).cuda()
            jacobian_all = torch.randn(len(kp_driving),out['jacobian'].shape[0],out['jacobian'].shape[1],2,2).cuda()
            print(len(kp_driving))
        for i in range(len(kp_driving)):
            # if x['name'][i] == 'LRW':
            # loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian']).mean())*self.loss_weights['emo']
            # loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']).mean())*self.loss_weights['emo']
            # loss_classify += self.mse_loss_fn(deco_out,deco_out)
            # Add the 4 predicted emotion offsets onto keypoints 1, 4, 6 and 8
            # (presumably the emotion-sensitive keypoints — confirm) of the
            # audio-driven prediction.
            if self.train_params['type'] == 'linear_4' and x['name'][0] == 0:
                kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1]
                kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4]
                kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6]
                kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8]
                kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1]
                kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + kp_driving_a[i]['value'][:,4]
                kp_all[i]['value'][:,6] = kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6]
                kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8]
                # kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value']
        if self.train_params['smooth'] == True:
            # Second-difference (acceleration) penalty over the frame axis.
            loss_smooth = 0
            loss_smooth += (torch.abs(value_all[2:,:,:,:] + value_all[:-2,:,:,:].detach() -2*value_all[1:-1,:,:,:].detach()).mean())*self.loss_weights['emo'] *100
            loss_smooth += (torch.abs(jacobian_all[2:,:,:,:] + jacobian_all[:-2,:,:,:].detach() -2*jacobian_all[1:-1,:,:,:].detach()).mean())*self.loss_weights['emo'] *100
            loss_values['loss_smooth'] = loss_smooth/len(kp_driving)
        else:
            # Zero placeholder (MSE of a tensor with itself) keeps the loss
            # dict keys consistent across configurations.
            loss_values['loss_smooth'] = self.mse_loss_fn(deco_out,deco_out)
        if self.train_params['generator'] == 'not':
            # No perceptual supervision; still synthesise one frame so the
            # caller gets something to log/visualise.
            loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out)
            for i in range(1): #0,len(kp_driving),4
                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_all})
        elif self.train_params['generator'] == 'visual':
            # Reconstruct every 4th frame from the *visual* keypoints and
            # accumulate the multi-scale VGG perceptual loss.
            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4
                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
                pyramide_real = self.pyramid(x['driving'][:,i])
                pyramide_generated = self.pyramid(generated['prediction'])
                if sum(self.loss_weights['perceptual']) != 0:
                    value_total = 0
                    for scale in self.scales:
                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
                        # NOTE(review): this inner `i` shadows the frame index;
                        # harmless only because Python's for rebinds `i` from
                        # the range iterator on the next outer iteration.
                        for i, weight in enumerate(self.loss_weights['perceptual']):
                            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                            value_total += self.loss_weights['perceptual'][i] * value
                    loss_perceptual += value_total
            # Number of frames actually generated (ceil(len/4)).
            length = int((len(kp_driving)-1)/4)+1
            loss_values['perceptual'] = loss_perceptual/length
        elif self.train_params['generator'] == 'audio':
            # Same as 'visual' but driving the generator with the
            # audio-predicted (emotion-adjusted) keypoints.
            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4
                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_all})
                pyramide_real = self.pyramid(x['driving'][:,i])
                pyramide_generated = self.pyramid(generated['prediction'])
                # loss_mse = nn.MSELoss(generated['prediction'],x['driving'][:,i])
                if sum(self.loss_weights['perceptual']) != 0:
                    value_total = 0
                    for scale in self.scales:
                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
                        for i, weight in enumerate(self.loss_weights['perceptual']):
                            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                            value_total += self.loss_weights['perceptual'][i] * value
                    loss_perceptual += value_total
            length = int((len(kp_driving)-1)/4)+1
            loss_values['perceptual'] = loss_perceptual/length
            # loss_values['mse'] = loss_mse/length
        else:
            print('wrong train_params: ', self.train_params['generator'])
        # NOTE(review): if the else branch above is taken, `generated` is
        # unbound and this raises UnboundLocalError.
        return loss_values,generated
class GeneratorFullModel(torch.nn.Module):
"""
Merge all generator related updates into single model for better multi-gpu usage
"""
def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params):
    # Wires the visual/audio keypoint detectors, audio feature network,
    # generator and discriminator, and builds the perceptual-loss helpers
    # (image pyramid + VGG19). Also loads an LRW landmark PCA basis.
    super(GeneratorFullModel, self).__init__()
    self.kp_extractor = kp_extractor
    self.kp_extractor_a = kp_extractor_a
    # self.content_encoder = content_encoder
    # self.emotion_encoder = emotion_encoder
    self.audio_feature = audio_feature
    self.generator = generator
    self.discriminator = discriminator
    self.train_params = train_params
    self.scales = train_params['scales']
    self.disc_scales = self.discriminator.scales
    # Image pyramid used for the multi-scale perceptual loss.
    self.pyramid = ImagePyramide(self.scales, generator.num_channels)
    if torch.cuda.is_available():
        self.pyramid = self.pyramid.cuda()
    self.loss_weights = train_params['loss_weights']
    # VGG19 only instantiated when the perceptual loss is enabled.
    if sum(self.loss_weights['perceptual']) != 0:
        self.vgg = Vgg19()
        if torch.cuda.is_available():
            self.vgg = self.vgg.cuda()
    # NOTE(review): '.../LRW/list/...' are placeholder paths left by the
    # authors — np.load will fail until they point at the local LRW list
    # files. U_106.npy appears to hold a PCA basis (first 16 components
    # kept); mean_106.npy a mean landmark vector. Also note the
    # unconditional .cuda() here, unlike the guarded calls above.
    self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda()
    self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda()
def forward(self, x):
# source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])
# source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))
#
gitextract_7pp944hx/ ├── 3DDFA_V2/ │ ├── demo.py │ └── utils/ │ └── pose.py ├── LICENSE ├── M003_template.npy ├── README.md ├── augmentation.py ├── config/ │ ├── MEAD_emo_video_aug_delta_4_crop_random_crop.yaml │ ├── train_part1.yaml │ ├── train_part1_fine_tune.yaml │ └── train_part2.yaml ├── dataset/ │ ├── LRW/ │ │ ├── MFCC/ │ │ │ └── ABOUT/ │ │ │ └── ABOUT_00001.npy │ │ └── Pose/ │ │ └── ABOUT/ │ │ └── ABOUT_00001.npy │ └── MEAD/ │ └── list/ │ └── MEAD_fomm_neu_dic_crop.npy ├── demo.py ├── filter1.py ├── frames_dataset.py ├── logger.py ├── modules/ │ ├── dense_motion.py │ ├── discriminator.py │ ├── function.py │ ├── generator.py │ ├── keypoint_detector.py │ ├── model.py │ ├── model_delta_map.py │ ├── model_gen.py │ ├── ops.py │ ├── stylegan2.py │ └── util.py ├── ops.py ├── process_data.py ├── requirements.txt ├── run.py ├── sync_batchnorm/ │ ├── __init__.py │ ├── batchnorm.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── test/ │ ├── pose/ │ │ ├── 14.npy │ │ ├── 21.npy │ │ ├── 60.npy │ │ ├── 7.npy │ │ ├── anne.npy │ │ ├── brade2.npy │ │ ├── dune_1.npy │ │ ├── dune_2.npy │ │ ├── jake4.npy │ │ ├── mona.npy │ │ ├── paint1.npy │ │ └── scarlett.npy │ └── pose_long/ │ ├── 0zn70Ak8lRc_Daniel_Auteuil_0zn70Ak8lRc_0002.npy │ ├── 1hEr7qKRKL4_Daniel_Dae_Kim_1hEr7qKRKL4_0004.npy │ └── 50IAfJCypFI_Alex_Kingston_50IAfJCypFI_0001.npy └── train.py
SYMBOL INDEX (451 symbols across 25 files)
FILE: 3DDFA_V2/demo.py
function main (line 29) | def main(args,img, save_path, pose_path):
function process_word (line 102) | def process_word(i):
FILE: 3DDFA_V2/utils/pose.py
function P2sRt (line 18) | def P2sRt(P):
function matrix2angle (line 39) | def matrix2angle(R):
function angle2matrix (line 65) | def angle2matrix(theta):
function angle2matrix_3ddfa (line 112) | def angle2matrix_3ddfa(angles):
function calc_pose (line 140) | def calc_pose(param):
function build_camera_box (line 150) | def build_camera_box(rear_size=90):
function plot_pose_box (line 171) | def plot_pose_box(img, P, ver, color=(40, 255, 0), line_width=2):
function viz_pose (line 201) | def viz_pose(img, param_lst, ver_lst, show_flag=False, wfp=None):
function pose_6 (line 217) | def pose_6(param):
function smooth_pose (line 231) | def smooth_pose(img, param_lst, ver_lst, pose_new, show_flag=False, wfp=...
function get_pose (line 263) | def get_pose(img, param_lst, ver_lst, show_flag=False, wfp=None, wnp = N...
FILE: augmentation.py
function crop_clip (line 20) | def crop_clip(clip, min_h, min_w, h, w):
function pad_clip (line 34) | def pad_clip(clip, h, w):
function resize_clip (line 42) | def resize_clip(clip, size, interpolation='bilinear'):
function get_resize_sizes (line 81) | def get_resize_sizes(im_h, im_w, size):
class RandomFlip (line 91) | class RandomFlip(object):
method __init__ (line 92) | def __init__(self, time_flip=False, horizontal_flip=False):
method __call__ (line 96) | def __call__(self, clip):
class RandomResize (line 105) | class RandomResize(object):
method __init__ (line 115) | def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
method __call__ (line 119) | def __call__(self, clip):
class RandomCrop (line 136) | class RandomCrop(object):
method __init__ (line 143) | def __init__(self, size):
method __call__ (line 149) | def __call__(self, clip):
class MouthCrop (line 175) | class MouthCrop(object):
method __init__ (line 182) | def __init__(self, center_x, center_y, mask_width, mask_height):
method __call__ (line 190) | def __call__(self, clip):
class RandomRotation (line 215) | class RandomRotation(object):
method __init__ (line 224) | def __init__(self, degrees):
method __call__ (line 237) | def __call__(self, clip):
class RandomPerspective (line 256) | class RandomPerspective(object):
method __init__ (line 265) | def __init__(self, pers_num, enlarge_num):
method __call__ (line 269) | def __call__(self, clip):
class ColorJitter (line 297) | class ColorJitter(object):
method __init__ (line 310) | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
method get_params (line 316) | def get_params(self, brightness, contrast, saturation, hue):
method __call__ (line 341) | def __call__(self, clip):
class AllAugmentationTransform (line 403) | class AllAugmentationTransform:
method __init__ (line 404) | def __init__(self, crop_mouth_param = None, resize_param=None, rotatio...
method __call__ (line 427) | def __call__(self, clip):
FILE: demo.py
function load_checkpoints (line 49) | def load_checkpoints(opt, checkpoint_path, audio_checkpoint_path, emo_ch...
function normalize_kp (line 112) | def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_moveme...
function shape_to_np (line 134) | def shape_to_np(shape, dtype="int"):
function get_aligned_image (line 146) | def get_aligned_image(driving_video, opt):
function get_transformed_image (line 184) | def get_transformed_image(driving_video, opt):
function make_animation_smooth (line 194) | def make_animation_smooth(source_image, driving_video, transformed_video...
function test_auido (line 286) | def test_auido(example_image, audio_feature, all_pose, opt):
function save (line 357) | def save(path, frames, format):
class VideoWriter (line 370) | class VideoWriter(object):
method __init__ (line 371) | def __init__(self, path, width, height, fps):
method write_frame (line 376) | def write_frame(self, frame):
method end (line 379) | def end(self):
function concatenate (line 382) | def concatenate(number, imgs, save_path):
function add_audio (line 427) | def add_audio(video_name=None, audio_dir = None):
function crop_image (line 433) | def crop_image(source_image):
function smooth_pose (line 456) | def smooth_pose(pose_file, pose_long):
function test (line 467) | def test(opt, name):
FILE: filter1.py
class LowPassFilter (line 13) | class LowPassFilter:
method __init__ (line 14) | def __init__(self):
method process (line 18) | def process(self, value, alpha):
class OneEuroFilter (line 28) | class OneEuroFilter:
method __init__ (line 29) | def __init__(self, mincutoff=1.0, beta=0.0, dcutoff=1.0, freq=30):
method compute_alpha (line 37) | def compute_alpha(self, cutoff):
method process (line 42) | def process(self, x):
FILE: frames_dataset.py
function read_video (line 15) | def read_video(name, frame_shape):
function get_list (line 55) | def get_list(ipath,base_name):
class AudioDataset (line 75) | class AudioDataset(Dataset):
method __init__ (line 83) | def __init__(self, name, root_dir, frame_shape=(256, 256, 3), id_sampl...
method __len__ (line 132) | def __len__(self):
method __getitem__ (line 135) | def __getitem__(self, idx):
class VoxDataset (line 196) | class VoxDataset(Dataset):
method __init__ (line 204) | def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=Fa...
method __len__ (line 252) | def __len__(self):
method __getitem__ (line 255) | def __getitem__(self, idx):
class MeadDataset (line 328) | class MeadDataset(Dataset):
method __init__ (line 336) | def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=Fa...
method __len__ (line 378) | def __len__(self):
method __getitem__ (line 381) | def __getitem__(self, idx):
class DatasetRepeater (line 461) | class DatasetRepeater(Dataset):
method __init__ (line 466) | def __init__(self, dataset, num_repeats=100):
method __len__ (line 471) | def __len__(self):
method __getitem__ (line 474) | def __getitem__(self, idx):
class TestsetRepeater (line 481) | class TestsetRepeater(Dataset):
method __init__ (line 486) | def __init__(self, dataset, num_repeats=100):
method __len__ (line 491) | def __len__(self):
method __getitem__ (line 494) | def __getitem__(self, idx):
class PairedDataset (line 499) | class PairedDataset(Dataset):
method __init__ (line 504) | def __init__(self, initial_dataset, number_of_pairs, seed=0):
method __len__ (line 529) | def __len__(self):
method __getitem__ (line 532) | def __getitem__(self, idx):
FILE: logger.py
class Logger (line 13) | class Logger:
method __init__ (line 14) | def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=Non...
method log_scores (line 29) | def log_scores(self, loss_names):
method visualize_rec (line 39) | def visualize_rec(self, inp, out):
method save_cpk (line 44) | def save_cpk(self, emergent=False):
method load_cpk (line 53) | def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_d...
method __enter__ (line 83) | def __enter__(self):
method __exit__ (line 86) | def __exit__(self, exc_type, exc_val, exc_tb):
method log_iter (line 91) | def log_iter(self, losses):
method log_epoch (line 97) | def log_epoch(self, epoch, step, models, inp, out):
class Visualizer (line 107) | class Visualizer:
method __init__ (line 108) | def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbo...
method draw_image_with_kp (line 113) | def draw_image_with_kp(self, image, kp_array):
method create_image_column_with_kp (line 123) | def create_image_column_with_kp(self, images, kp):
method create_image_column (line 127) | def create_image_column(self, images):
method create_image_grid (line 134) | def create_image_grid(self, *args):
method visualize (line 143) | def visualize(self, driving, transformed_driving, source, out):
FILE: modules/dense_motion.py
class DenseMotionNetwork (line 7) | class DenseMotionNetwork(nn.Module):
method __init__ (line 12) | def __init__(self, block_expansion, num_blocks, max_features, num_kp, ...
method create_heatmap_representations (line 32) | def create_heatmap_representations(self, source_image, kp_driving, kp_...
method create_sparse_motions (line 47) | def create_sparse_motions(self, source_image, kp_driving, kp_source):
method create_deformed_source_image (line 69) | def create_deformed_source_image(self, source_image, sparse_motions):
method forward (line 81) | def forward(self, source_image, kp_driving, kp_source):
FILE: modules/discriminator.py
class DownBlock2d (line 7) | class DownBlock2d(nn.Module):
method __init__ (line 12) | def __init__(self, in_features, out_features, norm=False, kernel_size=...
method forward (line 25) | def forward(self, x):
class Discriminator (line 36) | class Discriminator(nn.Module):
method __init__ (line 41) | def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, m...
method forward (line 59) | def forward(self, x, kp=None):
class MultiScaleDiscriminator (line 74) | class MultiScaleDiscriminator(nn.Module):
method __init__ (line 79) | def __init__(self, scales=(), **kwargs):
method forward (line 87) | def forward(self, x, kp=None):
FILE: modules/function.py
function calc_mean_std (line 12) | def calc_mean_std(feat, eps=1e-5):
function adaptive_instance_normalization (line 23) | def adaptive_instance_normalization(content_feat, style_feat):
function _calc_feat_flatten_mean_std (line 34) | def _calc_feat_flatten_mean_std(feat):
function _mat_sqrt (line 44) | def _mat_sqrt(x):
function coral (line 49) | def coral(source, target):
FILE: modules/generator.py
class OcclusionAwareGenerator (line 8) | class OcclusionAwareGenerator(nn.Module):
method __init__ (line 14) | def __init__(self, num_channels, num_kp, block_expansion, max_features...
method deform_input (line 50) | def deform_input(self, inp, deformation):
method forward (line 59) | def forward(self, source_image, kp_driving, kp_source):
FILE: modules/keypoint_detector.py
class KPDetector (line 7) | class KPDetector(nn.Module):
method __init__ (line 12) | def __init__(self, block_expansion, num_kp, num_channels, max_features,
method gaussian2kp (line 40) | def gaussian2kp(self, heatmap):
method audio_feature (line 52) | def audio_feature(self, x, heatmap):
method forward (line 77) | def forward(self, x): #torch.Size([4, 3, H, W])
class KPDetector_a (line 110) | class KPDetector_a(nn.Module):
method __init__ (line 115) | def __init__(self, block_expansion, num_kp, num_channels,num_channels_...
method gaussian2kp (line 143) | def gaussian2kp(self, heatmap):
method audio_feature (line 155) | def audio_feature(self, x, heatmap):
method forward (line 180) | def forward(self, feature_map): #torch.Size([4, 3, H, W])
class Audio_Feature (line 208) | class Audio_Feature(nn.Module):
method __init__ (line 209) | def __init__(self):
method forward (line 218) | def forward(self, x):
FILE: modules/model.py
class Vgg19 (line 10) | class Vgg19(torch.nn.Module):
method __init__ (line 14) | def __init__(self, requires_grad=False):
method forward (line 42) | def forward(self, X):
class ImagePyramide (line 53) | class ImagePyramide(torch.nn.Module):
method __init__ (line 57) | def __init__(self, scales, num_channels):
method forward (line 64) | def forward(self, x):
class Transform (line 71) | class Transform:
method __init__ (line 75) | def __init__(self, bs, **kwargs):
method transform_frame (line 89) | def transform_frame(self, frame):
method inverse_transform_frame (line 95) | def inverse_transform_frame(self, frame):
method warp_coordinates (line 101) | def warp_coordinates(self, coordinates):
method inverse_warp_coordinates (line 121) | def inverse_warp_coordinates(self, coordinates):
method jacobian (line 146) | def jacobian(self, coordinates):
function detach_kp (line 155) | def detach_kp(kp):
class TrainPart1Model (line 158) | class TrainPart1Model(torch.nn.Module):
method __init__ (line 163) | def __init__(self, kp_extractor, kp_extractor_a, audio_feature, genera...
method forward (line 187) | def forward(self, x):
class TrainPart2Model (line 282) | class TrainPart2Model(torch.nn.Module):
method __init__ (line 287) | def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_fe...
method forward (line 312) | def forward(self, x):
class GeneratorFullModel (line 416) | class GeneratorFullModel(torch.nn.Module):
method __init__ (line 421) | def __init__(self, kp_extractor, kp_extractor_a, audio_feature, genera...
method forward (line 447) | def forward(self, x):
class DiscriminatorFullModel (line 557) | class DiscriminatorFullModel(torch.nn.Module):
method __init__ (line 562) | def __init__(self, kp_extractor, generator, discriminator, train_params):
method forward (line 575) | def forward(self, x, generated):
FILE: modules/model_delta_map.py
class Vgg19 (line 10) | class Vgg19(torch.nn.Module):
method __init__ (line 14) | def __init__(self, requires_grad=False):
method forward (line 42) | def forward(self, X):
class ImagePyramide (line 53) | class ImagePyramide(torch.nn.Module):
method __init__ (line 57) | def __init__(self, scales, num_channels):
method forward (line 64) | def forward(self, x):
class Transform (line 71) | class Transform:
method __init__ (line 75) | def __init__(self, bs, **kwargs):
method transform_frame (line 89) | def transform_frame(self, frame):
method inverse_transform_frame (line 95) | def inverse_transform_frame(self, frame):
method warp_coordinates (line 101) | def warp_coordinates(self, coordinates):
method inverse_warp_coordinates (line 121) | def inverse_warp_coordinates(self, coordinates):
method jacobian (line 146) | def jacobian(self, coordinates):
function detach_kp (line 155) | def detach_kp(kp):
class TrainFullModel (line 158) | class TrainFullModel(torch.nn.Module):
method __init__ (line 163) | def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_fe...
method forward (line 192) | def forward(self, x):
class GeneratorFullModel (line 325) | class GeneratorFullModel(torch.nn.Module):
method __init__ (line 330) | def __init__(self, kp_extractor, kp_extractor_a, audio_feature, genera...
method forward (line 356) | def forward(self, x):
class DiscriminatorFullModel (line 466) | class DiscriminatorFullModel(torch.nn.Module):
method __init__ (line 471) | def __init__(self, kp_extractor, generator, discriminator, train_params):
method forward (line 484) | def forward(self, x, generated):
FILE: modules/model_gen.py
class Vgg19 (line 10) | class Vgg19(torch.nn.Module):
method __init__ (line 14) | def __init__(self, requires_grad=False):
method forward (line 42) | def forward(self, X):
class ImagePyramide (line 53) | class ImagePyramide(torch.nn.Module):
method __init__ (line 57) | def __init__(self, scales, num_channels):
method forward (line 64) | def forward(self, x):
class Transform (line 71) | class Transform:
method __init__ (line 75) | def __init__(self, bs, **kwargs):
method transform_frame (line 89) | def transform_frame(self, frame):
method inverse_transform_frame (line 95) | def inverse_transform_frame(self, frame):
method warp_coordinates (line 101) | def warp_coordinates(self, coordinates):
method inverse_warp_coordinates (line 121) | def inverse_warp_coordinates(self, coordinates):
method jacobian (line 146) | def jacobian(self, coordinates):
function detach_kp (line 155) | def detach_kp(kp):
class TrainFullModel (line 158) | class TrainFullModel(torch.nn.Module):
method __init__ (line 163) | def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_fe...
method forward (line 192) | def forward(self, x):
class GeneratorFullModel (line 341) | class GeneratorFullModel(torch.nn.Module):
method __init__ (line 346) | def __init__(self, kp_extractor, kp_extractor_a, audio_feature, genera...
method forward (line 372) | def forward(self, x):
class DiscriminatorFullModel (line 482) | class DiscriminatorFullModel(torch.nn.Module):
method __init__ (line 487) | def __init__(self, kp_extractor, generator, discriminator, train_params):
method forward (line 500) | def forward(self, x, generated):
FILE: modules/ops.py
function linear (line 8) | def linear(channel_in, channel_out,
function conv2d (line 21) | def conv2d(channel_in, channel_out,
function conv_transpose2d (line 37) | def conv_transpose2d(channel_in, channel_out,
function nn_conv2d (line 53) | def nn_conv2d(channel_in, channel_out,
function _apply (line 71) | def _apply(layer, activation, normalizer, channel_out=None):
FILE: modules/stylegan2.py
function fused_leaky_relu (line 25) | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
class FusedLeakyReLU (line 29) | class FusedLeakyReLU(nn.Module):
method __init__ (line 30) | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
method forward (line 36) | def forward(self, input):
function upfirdn2d_native (line 45) | def upfirdn2d_native(
function upfirdn2d (line 82) | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
class PixelNorm (line 86) | class PixelNorm(nn.Module):
method __init__ (line 87) | def __init__(self):
method forward (line 90) | def forward(self, input):
function make_kernel (line 94) | def make_kernel(k):
class Upsample (line 105) | class Upsample(nn.Module):
method __init__ (line 106) | def __init__(self, kernel, factor=2):
method forward (line 120) | def forward(self, input):
class Downsample (line 126) | class Downsample(nn.Module):
method __init__ (line 127) | def __init__(self, kernel, factor=2):
method forward (line 141) | def forward(self, input):
class Blur (line 147) | class Blur(nn.Module):
method __init__ (line 148) | def __init__(self, kernel, pad, upsample_factor=1):
method forward (line 160) | def forward(self, input):
class EqualConv2d (line 166) | class EqualConv2d(nn.Module):
method __init__ (line 167) | def __init__(
method forward (line 186) | def forward(self, input):
method __repr__ (line 199) | def __repr__(self):
class EqualLinear (line 206) | class EqualLinear(nn.Module):
method __init__ (line 207) | def __init__(
method forward (line 225) | def forward(self, input):
method __repr__ (line 237) | def __repr__(self):
class ScaledLeakyReLU (line 243) | class ScaledLeakyReLU(nn.Module):
method __init__ (line 244) | def __init__(self, negative_slope=0.2):
method forward (line 249) | def forward(self, input):
class ModulatedConv2d (line 255) | class ModulatedConv2d(nn.Module):
method __init__ (line 256) | def __init__(
method __repr__ (line 305) | def __repr__(self):
method forward (line 311) | def forward(self, input, style):
class NoiseInjection (line 358) | class NoiseInjection(nn.Module):
method __init__ (line 359) | def __init__(self):
method forward (line 364) | def forward(self, image, noise=None):
class ConstantInput (line 372) | class ConstantInput(nn.Module):
method __init__ (line 373) | def __init__(self, channel, size=4):
method forward (line 378) | def forward(self, input):
class StyledConv (line 385) | class StyledConv(nn.Module):
method __init__ (line 386) | def __init__(
method forward (line 415) | def forward(self, input, style=None, noise=None):
class ToRGB (line 425) | class ToRGB(nn.Module):
method __init__ (line 426) | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[...
method forward (line 435) | def forward(self, input, style, skip=None):
class Generator (line 447) | class Generator(nn.Module):
method __init__ (line 448) | def __init__(
method make_noise (line 533) | def make_noise(self):
method mean_latent (line 544) | def mean_latent(self, n_latent):
method get_latent (line 552) | def get_latent(self, input):
method forward (line 555) | def forward(
class ConvLayer (line 630) | class ConvLayer(nn.Sequential):
method __init__ (line 631) | def __init__(
class ResBlock (line 679) | class ResBlock(nn.Module):
method __init__ (line 680) | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], ...
method forward (line 694) | def forward(self, input):
class StyleGAN2Discriminator (line 704) | class StyleGAN2Discriminator(nn.Module):
method __init__ (line 705) | def __init__(self, input_nc, ndf=64, n_layers=3, no_antialias=False, s...
method forward (line 761) | def forward(self, input, get_minibatch_features=False):
class TileStyleGAN2Discriminator (line 795) | class TileStyleGAN2Discriminator(StyleGAN2Discriminator):
method forward (line 796) | def forward(self, input):
class StyleGAN2Encoder (line 806) | class StyleGAN2Encoder(nn.Module):
method __init__ (line 807) | def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_b...
method forward (line 843) | def forward(self, input, layers=[], get_features=False):
class StyleGAN2Decoder (line 860) | class StyleGAN2Decoder(nn.Module):
method __init__ (line 861) | def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_b...
method forward (line 902) | def forward(self, input):
class StyleGAN2Generator (line 906) | class StyleGAN2Generator(nn.Module):
method __init__ (line 907) | def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_b...
method forward (line 913) | def forward(self, input, layers=[], encode_only=False):
FILE: modules/util.py
class InstanceNorm (line 26) | class InstanceNorm(nn.Module):
method __init__ (line 27) | def __init__(self, epsilon=1e-8):
method forward (line 35) | def forward(self, x):
class ApplyStyle (line 41) | class ApplyStyle(nn.Module):
method __init__ (line 45) | def __init__(self, latent_size, channels, use_wscale):
method forward (line 52) | def forward(self, x, latent):
class FC (line 60) | class FC(nn.Module):
method __init__ (line 61) | def __init__(self,
method forward (line 87) | def forward(self, x):
class Embedder (line 97) | class Embedder:
method __init__ (line 98) | def __init__(self, **kwargs):
method create_embedding_fn (line 102) | def create_embedding_fn(self):
method embed (line 126) | def embed(self, inputs):
function get_embedder (line 130) | def get_embedder(multires, i=0):
function draw_heatmap (line 148) | def draw_heatmap(landmark, width, height):
class NA_net (line 175) | class NA_net(nn.Module):
method __init__ (line 176) | def __init__(self):
method forward (line 195) | def forward(self, neutral):
class AT_net (line 203) | class AT_net(nn.Module):
method __init__ (line 204) | def __init__(self):
method forward (line 270) | def forward(self, example_image, audio, pose, jaco_net):
class Classify (line 306) | class Classify(nn.Module):
method __init__ (line 307) | def __init__(self):
method forward (line 314) | def forward(self, feature):
class TF_net (line 321) | class TF_net(nn.Module):
method __init__ (line 322) | def __init__(self):
method adain_forward (line 391) | def adain_forward(self, example_image, audio, pose, jaco_net, emo_feat...
method adain_feature2 (line 434) | def adain_feature2(self, example_image, audio, pose, jaco_net, emo_fea...
method forward (line 477) | def forward(self, example_image, audio, pose, jaco_net, emo_features):
class AT_net2 (line 514) | class AT_net2(nn.Module):
method __init__ (line 515) | def __init__(self):
method forward (line 580) | def forward(self, example_image, audio, pose, jaco_net, weight):
class Ct_encoder (line 618) | class Ct_encoder(nn.Module):
method __init__ (line 619) | def __init__(self):
method forward (line 638) | def forward(self, audio):
class EmotionNet (line 647) | class EmotionNet(nn.Module):
method __init__ (line 648) | def __init__(self):
method forward (line 697) | def forward(self, mfcc):
class AF2F (line 715) | class AF2F(nn.Module):
method __init__ (line 716) | def __init__(self):
method forward (line 736) | def forward(self, content,emotion):
class AF2F_s (line 745) | class AF2F_s(nn.Module):
method __init__ (line 746) | def __init__(self):
method forward (line 766) | def forward(self, content):
class A2I (line 776) | class A2I(nn.Module):
method __init__ (line 777) | def __init__(self):
method forward (line 804) | def forward(self, mfcc):
function kp2gaussian (line 815) | def kp2gaussian(kp, spatial_size, kp_variance):
function make_coordinate_grid (line 839) | def make_coordinate_grid(spatial_size, type):
class ResBlock2d (line 858) | class ResBlock2d(nn.Module):
method __init__ (line 863) | def __init__(self, in_features, kernel_size, padding):
method forward (line 872) | def forward(self, x):
class UpBlock2d (line 883) | class UpBlock2d(nn.Module):
method __init__ (line 888) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
method forward (line 895) | def forward(self, x):
class DownBlock2d (line 903) | class DownBlock2d(nn.Module):
method __init__ (line 908) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
method forward (line 915) | def forward(self, x):
class SameBlock2d (line 923) | class SameBlock2d(nn.Module):
method __init__ (line 928) | def __init__(self, in_features, out_features, groups=1, kernel_size=3,...
method forward (line 934) | def forward(self, x):
class Encoder (line 941) | class Encoder(nn.Module):
method __init__ (line 946) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
method forward (line 956) | def forward(self, x):
class Decoder (line 963) | class Decoder(nn.Module):
method __init__ (line 968) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
method forward (line 981) | def forward(self, x):
class Hourglass (line 990) | class Hourglass(nn.Module):
method __init__ (line 995) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
method forward (line 1001) | def forward(self, x):
class AntiAliasInterpolation2d (line 1005) | class AntiAliasInterpolation2d(nn.Module):
method __init__ (line 1009) | def __init__(self, channels, scale):
method forward (line 1044) | def forward(self, input):
function sigmoid (line 1054) | def sigmoid(x):
function norm_angle (line 1058) | def norm_angle(angle):
function conv3x3 (line 1063) | def conv3x3(in_planes, out_planes, stride=1):
class BasicBlock (line 1069) | class BasicBlock(nn.Module):
method __init__ (line 1072) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 1082) | def forward(self, x):
class Bottleneck (line 1101) | class Bottleneck(nn.Module):
method __init__ (line 1104) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 1117) | def forward(self, x):
class EmDetector (line 1139) | class EmDetector(nn.Module):
method __init__ (line 1144) | def __init__(self, block_expansion, num_channels, max_features,
method _make_layer (line 1170) | def _make_layer(self, block, planes, blocks, stride=1):
method adain_feature (line 1187) | def adain_feature(self, x): #torch.Size([4, 3, H, W])
method forward (line 1197) | def forward(self, x): #torch.Size([4, 3, H, W])
class Emotion_k (line 1223) | class Emotion_k(nn.Module):
method __init__ (line 1228) | def __init__(self, block_expansion, num_channels, max_features,
method _make_layer (line 1316) | def _make_layer(self, block, planes, blocks, stride=1):
method linear_10 (line 1333) | def linear_10(self, x, value, jacobian): #torch.Size([4, 3, H, W])
method linear_4 (line 1364) | def linear_4(self, x, value, jacobian): #torch.Size([4, 3, H, W])
method linear_np_10 (line 1396) | def linear_np_10(self, x, value, jacobian): #torch.Size([4, 3, H, W])
method linear_np_4 (line 1427) | def linear_np_4(self, x, value, jacobian): #torch.Size([4, 3, H, W])
method emotion_feature (line 1459) | def emotion_feature(self, feature, value, jacobian): #torch.Size([4, 3...
method feature (line 1477) | def feature(self, x): #torch.Size([4, 3, H, W])
method forward (line 1498) | def forward(self, x, value, jacobian): #torch.Size([4, 3, H, W])
class Emotion_map (line 1529) | class Emotion_map(nn.Module):
method __init__ (line 1534) | def __init__(self, block_expansion, num_channels, max_features,
method _make_layer (line 1607) | def _make_layer(self, block, planes, blocks, stride=1):
method gaussian2kp (line 1624) | def gaussian2kp(self, heatmap):
method map_4 (line 1636) | def map_4(self, x, value, jacobian): #torch.Size([4, 3, H, W])
method forward (line 1687) | def forward(self, x, value, jacobian): #torch.Size([4, 3, H, W])
function conv2d (line 1740) | def conv2d(channel_in, channel_out,
function _apply (line 1755) | def _apply(layer, activation, normalizer, channel_out=None):
FILE: ops.py
class ResidualBlock (line 8) | class ResidualBlock(nn.Module):
method __init__ (line 9) | def __init__(self, channel_in, channel_out):
method forward (line 19) | def forward(self, x):
function linear (line 27) | def linear(channel_in, channel_out,
function conv2d (line 40) | def conv2d(channel_in, channel_out,
function conv_transpose2d (line 56) | def conv_transpose2d(channel_in, channel_out,
function nn_conv2d (line 72) | def nn_conv2d(channel_in, channel_out,
function _apply (line 90) | def _apply(layer, activation, normalizer, channel_out=None):
FILE: process_data.py
function save (line 29) | def save(path, frames, format):
function crop_image (line 44) | def crop_image(image_path, out_path):
function shape_to_np (line 70) | def shape_to_np(shape, dtype="int"):
function crop_image_tem (line 85) | def crop_image_tem(video_path, out_path):
function proc_audio (line 124) | def proc_audio(src_mouth_path, dst_audio_path):
function audio2mfcc (line 130) | def audio2mfcc(audio_file, save, name):
FILE: sync_batchnorm/batchnorm.py
function _sum_ft (line 24) | def _sum_ft(tensor):
function _unsqueeze_ft (line 29) | def _unsqueeze_ft(tensor):
class _SynchronizedBatchNorm (line 38) | class _SynchronizedBatchNorm(_BatchNorm):
method __init__ (line 39) | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
method forward (line 48) | def forward(self, input):
method __data_parallel_replicate__ (line 80) | def __data_parallel_replicate__(self, ctx, copy_id):
method _data_parallel_master (line 90) | def _data_parallel_master(self, intermediates):
method _compute_mean_std (line 113) | def _compute_mean_std(self, sum_, ssum, size):
class SynchronizedBatchNorm1d (line 128) | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
method _check_input_dim (line 184) | def _check_input_dim(self, input):
class SynchronizedBatchNorm2d (line 191) | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
method _check_input_dim (line 247) | def _check_input_dim(self, input):
class SynchronizedBatchNorm3d (line 254) | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
method _check_input_dim (line 311) | def _check_input_dim(self, input):
FILE: sync_batchnorm/comm.py
class FutureResult (line 18) | class FutureResult(object):
method __init__ (line 21) | def __init__(self):
method put (line 26) | def put(self, result):
method get (line 32) | def get(self):
class SlavePipe (line 46) | class SlavePipe(_SlavePipeBase):
method run_slave (line 49) | def run_slave(self, msg):
class SyncMaster (line 56) | class SyncMaster(object):
method __init__ (line 67) | def __init__(self, master_callback):
method __getstate__ (line 78) | def __getstate__(self):
method __setstate__ (line 81) | def __setstate__(self, state):
method register_slave (line 84) | def register_slave(self, identifier):
method run_master (line 102) | def run_master(self, master_msg):
method nr_slaves (line 136) | def nr_slaves(self):
FILE: sync_batchnorm/replicate.py
class CallbackContext (line 23) | class CallbackContext(object):
function execute_replication_callbacks (line 27) | def execute_replication_callbacks(modules):
class DataParallelWithCallback (line 50) | class DataParallelWithCallback(DataParallel):
method replicate (line 64) | def replicate(self, module, device_ids):
function patch_replication_callback (line 70) | def patch_replication_callback(data_parallel):
FILE: sync_batchnorm/unittest.py
function as_numpy (line 17) | def as_numpy(v):
class TorchTestCase (line 23) | class TorchTestCase(unittest.TestCase):
method assertTensorClose (line 24) | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
FILE: train.py
function train_part1 (line 18) | def train_part1(config, generator, discriminator, kp_detector, kp_detect...
function train_part1_fine_tune (line 133) | def train_part1_fine_tune(config, generator, discriminator, kp_detector,...
function train_part2 (line 273) | def train_part2(config, generator, discriminator, kp_detector, emo_detec...
Condensed preview — 53 files, each showing its path, character count, and a content snippet. Download the .json file or copy it to your clipboard to get the full structured content (366K chars).
[
{
"path": "3DDFA_V2/demo.py",
"chars": 8816,
"preview": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport sys\nimport argparse\nimport cv2\nimport yaml\nimport os\nimport time\nfrom "
},
{
"path": "3DDFA_V2/utils/pose.py",
"chars": 8314,
"preview": "# coding: utf-8\n\n\"\"\"\nReference: https://github.com/YadiraF/PRNet/blob/master/utils/estimate_pose.py\n\nCalculating pose fr"
},
{
"path": "LICENSE",
"chars": 1064,
"preview": "MIT License\n\nCopyright (c) 2022 jixinya\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof"
},
{
"path": "README.md",
"chars": 3250,
"preview": "# EAMM: One-Shot Emotional Talking Face via Audio-Based Emotion-Aware Motion Model [SIGGRAPH 2022 Conference]\r\n\r\nXinya "
},
{
"path": "augmentation.py",
"chars": 16022,
"preview": "\"\"\"\nCode from https://github.com/hassony2/torch_videovision\n\"\"\"\n\nimport numbers\nimport math\nimport random\nimport numpy a"
},
{
"path": "config/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml",
"chars": 2210,
"preview": "dataset_params:\n root_dir: /mnt/lustre/share_data/jixinya/MEAD/\n frame_shape: [256, 256, 3]\n id_sampling: False\n pai"
},
{
"path": "config/train_part1.yaml",
"chars": 1684,
"preview": "dataset_params:\n name: Vox\n root_dir: dataset/LRW/\n frame_shape: [256, 256, 3]\n id_sampling: False\n augmentation_pa"
},
{
"path": "config/train_part1_fine_tune.yaml",
"chars": 1689,
"preview": "dataset_params:\n name: LRW\n root_dir: dataset/LRW/\n frame_shape: [256, 256, 3]\n id_sampling: False\n augmentation_pa"
},
{
"path": "config/train_part2.yaml",
"chars": 1919,
"preview": "dataset_params:\n name: MEAD\n root_dir: dataset/MEAD/\n frame_shape: [256, 256, 3]\n id_sampling: False\n augmentation_"
},
{
"path": "demo.py",
"chars": 22742,
"preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Wed Oct 6 20:57:27 2021\n@author: thea\n\"\"\"\n\nimport matplot"
},
{
"path": "filter1.py",
"chars": 1201,
"preview": "import cv2\n#import pickle\nimport time\nimport numpy as np\nimport copy\n\nfrom matplotlib import pyplot as plt\nfrom tqdm imp"
},
{
"path": "frames_dataset.py",
"chars": 20138,
"preview": "import os\nfrom skimage import io, img_as_float32, transform\nfrom skimage.color import gray2rgb\nfrom sklearn.model_select"
},
{
"path": "logger.py",
"chars": 9203,
"preview": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nimport imageio\n\nimport os\nfrom skimage.draw import circl"
},
{
"path": "modules/dense_motion.py",
"chars": 5189,
"preview": "from torch import nn\nimport torch.nn.functional as F\nimport torch\nfrom modules.util import Hourglass, AntiAliasInterpola"
},
{
"path": "modules/discriminator.py",
"chars": 3156,
"preview": "from torch import nn\nimport torch.nn.functional as F\nfrom modules.util import kp2gaussian\nimport torch\n\n\nclass DownBlock"
},
{
"path": "modules/function.py",
"chars": 2525,
"preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Thu Sep 30 17:45:24 2021\n\n@author: SENSETIME\\jixinya1\n\"\"\"\n"
},
{
"path": "modules/generator.py",
"chars": 4627,
"preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom modules.util import ResBlock2d, SameBlock2d, UpBl"
},
{
"path": "modules/keypoint_detector.py",
"chars": 10448,
"preview": "from torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom modules.util import Hourglass, make_coordinate_gr"
},
{
"path": "modules/model.py",
"chars": 28798,
"preview": "from torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom modules.util import AntiAliasInterpolation2d, mak"
},
{
"path": "modules/model_delta_map.py",
"chars": 25745,
"preview": "from torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom modules.util import AntiAliasInterpolation2d, mak"
},
{
"path": "modules/model_gen.py",
"chars": 25806,
"preview": "from torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom modules.util import AntiAliasInterpolation2d, mak"
},
{
"path": "modules/ops.py",
"chars": 2305,
"preview": "import torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom torch.autograd import Variable\n\n"
},
{
"path": "modules/stylegan2.py",
"chars": 28081,
"preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Thu Jul 8 01:03:50 2021\n\n@author: thea\n\"\"\"\n\n\"\"\"\nThe netwo"
},
{
"path": "modules/util.py",
"chars": 63282,
"preview": "from torch import nn\n\nimport torch.nn.functional as F\nimport torch\nimport numpy as np\nimport cv2\nfrom sync_batchnorm imp"
},
{
"path": "ops.py",
"chars": 2797,
"preview": "import torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom torch.autograd import Variable\n\n"
},
{
"path": "process_data.py",
"chars": 5768,
"preview": "# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Thu Jun 24 11:36:01 2021\n\n@author: Xinya\n\"\"\"\n\nimport os\nimport glob\nimport time\ni"
},
{
"path": "requirements.txt",
"chars": 167,
"preview": "torch==1.10.1\ntorchvision==0.11.2\nnumpy\nlibrosa\nopencv-python\npython_speech_features\npickle-mixin\nmatplotlib\nscikit-imag"
},
{
"path": "run.py",
"chars": 5496,
"preview": "import matplotlib\n\nmatplotlib.use('Agg')\n\nimport os, sys\nimport yaml\nfrom argparse import ArgumentParser\nfrom time impor"
},
{
"path": "sync_batchnorm/__init__.py",
"chars": 449,
"preview": "# -*- coding: utf-8 -*-\n# File : __init__.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/2"
},
{
"path": "sync_batchnorm/batchnorm.py",
"chars": 12973,
"preview": "# -*- coding: utf-8 -*-\n# File : batchnorm.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/"
},
{
"path": "sync_batchnorm/comm.py",
"chars": 4449,
"preview": "# -*- coding: utf-8 -*-\n# File : comm.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/2018\n"
},
{
"path": "sync_batchnorm/replicate.py",
"chars": 3226,
"preview": "# -*- coding: utf-8 -*-\n# File : replicate.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/"
},
{
"path": "sync_batchnorm/unittest.py",
"chars": 835,
"preview": "# -*- coding: utf-8 -*-\n# File : unittest.py\n# Author : Jiayuan Mao\n# Email : maojiayuan@gmail.com\n# Date : 27/01/2"
},
{
"path": "train.py",
"chars": 19729,
"preview": "from tqdm import trange\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader\n\nfrom logger import L"
}
]
// ... and 19 more files (download for full content)
About this extraction
This page contains the full source code of the jixinya/EAMM GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 53 files (345.8 KB), approximately 92.0k tokens, and a symbol index with 451 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — a free GitHub repo-to-text converter for AI. Built by Nikandr Surkov.