Full Code of jixinya/EAMM for AI

Repository: jixinya/EAMM
Branch: main
Commit: e176a79865f7
Files: 53
Total size: 345.8 KB

Directory structure:
gitextract_7pp944hx/

├── 3DDFA_V2/
│   ├── demo.py
│   └── utils/
│       └── pose.py
├── LICENSE
├── M003_template.npy
├── README.md
├── augmentation.py
├── config/
│   ├── MEAD_emo_video_aug_delta_4_crop_random_crop.yaml
│   ├── train_part1.yaml
│   ├── train_part1_fine_tune.yaml
│   └── train_part2.yaml
├── dataset/
│   ├── LRW/
│   │   ├── MFCC/
│   │   │   └── ABOUT/
│   │   │       └── ABOUT_00001.npy
│   │   └── Pose/
│   │       └── ABOUT/
│   │           └── ABOUT_00001.npy
│   └── MEAD/
│       └── list/
│           └── MEAD_fomm_neu_dic_crop.npy
├── demo.py
├── filter1.py
├── frames_dataset.py
├── logger.py
├── modules/
│   ├── dense_motion.py
│   ├── discriminator.py
│   ├── function.py
│   ├── generator.py
│   ├── keypoint_detector.py
│   ├── model.py
│   ├── model_delta_map.py
│   ├── model_gen.py
│   ├── ops.py
│   ├── stylegan2.py
│   └── util.py
├── ops.py
├── process_data.py
├── requirements.txt
├── run.py
├── sync_batchnorm/
│   ├── __init__.py
│   ├── batchnorm.py
│   ├── comm.py
│   ├── replicate.py
│   └── unittest.py
├── test/
│   ├── pose/
│   │   ├── 14.npy
│   │   ├── 21.npy
│   │   ├── 60.npy
│   │   ├── 7.npy
│   │   ├── anne.npy
│   │   ├── brade2.npy
│   │   ├── dune_1.npy
│   │   ├── dune_2.npy
│   │   ├── jake4.npy
│   │   ├── mona.npy
│   │   ├── paint1.npy
│   │   └── scarlett.npy
│   └── pose_long/
│       ├── 0zn70Ak8lRc_Daniel_Auteuil_0zn70Ak8lRc_0002.npy
│       ├── 1hEr7qKRKL4_Daniel_Dae_Kim_1hEr7qKRKL4_0004.npy
│       └── 50IAfJCypFI_Alex_Kingston_50IAfJCypFI_0001.npy
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: 3DDFA_V2/demo.py
================================================
# coding: utf-8

__author__ = 'cleardusk'

import sys
import argparse
import cv2
import yaml
import os
import time
from FaceBoxes import FaceBoxes
from TDDFA import TDDFA
from utils.render import render
#from utils.render_ctypes import render  # faster
from utils.depth import depth
from utils.pncc import pncc
from utils.uv import uv_tex
from utils.pose import viz_pose, get_pose
from utils.serialization import ser_to_ply, ser_to_obj
from utils.functions import draw_landmarks, get_suffix
from utils.tddfa_util import str2bool
import numpy as np
from tqdm import tqdm
import copy

import concurrent.futures
from multiprocessing import Pool

def main(args,img, save_path, pose_path):
 #   begin = time.time()
    cfg = yaml.load(open(args.config), Loader=yaml.SafeLoader)

    # Init FaceBoxes and TDDFA, recommend using onnx flag
    if args.onnx:
        import os
        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
        os.environ['OMP_NUM_THREADS'] = '4'

        from FaceBoxes.FaceBoxes_ONNX import FaceBoxes_ONNX
        from TDDFA_ONNX import TDDFA_ONNX

        face_boxes = FaceBoxes_ONNX()
        tddfa = TDDFA_ONNX(**cfg)
    else:
        gpu_mode = args.mode == 'gpu'
        tddfa = TDDFA(gpu_mode=gpu_mode, **cfg)
        face_boxes = FaceBoxes()

    # Given a still image path and load to BGR channel
  #  img = cv2.imread(img_path) #args.img_fp

    # Detect faces, get 3DMM params and roi boxes
    boxes = face_boxes(img)
    n = len(boxes)
    if n == 0:
        print(f'No face detected, exit')
      #  sys.exit(-1)
        return None
    print(f'Detect {n} faces')

    param_lst, roi_box_lst = tddfa(img, boxes)
    #detection time
  #  detect_time = time.time()-begin
 #   print('detection time: '+str(detect_time), file=open('/mnt/lustre/jixinya/Home/3DDFA_V2/pose.txt', 'a'))
    # Visualization and serialization
    dense_flag = args.opt in ('2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'ply', 'obj')
  #  old_suffix = get_suffix(img_path)
    old_suffix = 'png'
    new_suffix = f'.{args.opt}' if args.opt in ('ply', 'obj') else '.jpg'

    wfp = f'examples/results/{args.img_fp.split("/")[-1].replace(old_suffix, "")}_{args.opt}' + new_suffix

    ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, dense_flag=dense_flag)

    all_pose = None  # only set by the 'pose' branch below; other options write files and return None
    if args.opt == '2d_sparse':
        draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp)
    elif args.opt == '2d_dense':
        draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp)
    elif args.opt == '3d':
        render(img, ver_lst, tddfa.tri, alpha=0.6, show_flag=args.show_flag, wfp=wfp)
    elif args.opt == 'depth':

        # if `with_bf_flag` is False, the background is black
        depth(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True)
    elif args.opt == 'pncc':
        pncc(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True)
    elif args.opt == 'uv_tex':
        uv_tex(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp)
    elif args.opt == 'pose':
        all_pose = get_pose(img, param_lst, ver_lst, show_flag=args.show_flag, wfp=save_path, wnp = pose_path)
    elif args.opt == 'ply':
        ser_to_ply(ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp)
    elif args.opt == 'obj':
        ser_to_obj(img, ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp)
    else:
        raise ValueError(f'Unknown opt {args.opt}')

    return all_pose



def process_word(i):
    path = '/media/xinya/Backup Plus/sense_shixi_data/new_crop/MEAD_fomm_video_6/'
    save = '/media/xinya/Backup Plus/sense_shixi_data/new_crop/MEAD_fomm_pose_im/'
    pose = '/media/xinya/Backup Plus/sense_shixi_data/new_crop/MEAD_fomm_pose/'
    start = time.time()
    Dir = os.listdir(path)
    Dir.sort()
    word = Dir[i]
    wpath = os.path.join(path, word)
    print(wpath)
    pathDir = os.listdir(wpath)
    pose_file = os.path.join(pose,word)
    if not os.path.exists(pose_file):
        os.makedirs(pose_file)

    for j in range(len(pathDir)):
        name = pathDir[j]
     #   save_file = os.path.join(save,word,name)
     #   if not os.path.exists(save_file):
     #       os.makedirs(save_file)
        fpath = os.path.join(wpath,name)
        image_all = []
        videoCapture = cv2.VideoCapture(fpath)

        success, frame = videoCapture.read()

        n = 0
        while success :
            image_all.append(frame)
            n = n + 1
            success, frame = videoCapture.read()

     #   fDir = os.listdir(fpath)
        pose_all = np.zeros((len(image_all),7))
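        # one row per frame: [yaw, pitch, roll (degrees), scale, t3d_x, t3d_y, t3d_z] from get_pose()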
        for k in range(len(image_all)):
    #        index = fDir[k].split('.')[0]
    #        img_path = os.path.join(fpath,str(k)+'.png')

     #       pose_all[k] = main(args,image_all[k], os.path.join(save_file,str(k)+'.jpg'), None)
            pose_all[k] = main(args,image_all[k], None, None)
        np.save(os.path.join(pose,word,name.split('.')[0]+'.npy'),pose_all)
        st = time.time()-start
        print(str(i)+' '+word+' '+str(j)+' '+name+' '+str(k)+'time: '+str(st), file=open('/media/thea/Backup Plus/sense_shixi_data/new_crop/pose_mead6.txt', 'a'))
        print(i,word,j,name,k)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='The demo of still image of 3DDFA_V2')
    parser.add_argument('-c', '--config', type=str, default='configs/mb1_120x120.yml')
    parser.add_argument('-f', '--img_fp', type=str, default='examples/inputs/0.png')
    parser.add_argument('-m', '--mode', type=str, default='cpu', help='gpu or cpu mode')
    parser.add_argument('-o', '--opt', type=str, default='pose',
                        choices=['2d_sparse', '2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'pose', 'ply', 'obj'])
    parser.add_argument('--show_flag', type=str2bool, default='False', help='whether to show the visualization result')
    parser.add_argument('--onnx', action='store_true', default=False)

    args = parser.parse_args()


    
    filepath = 'test/image/'
    pathDir = os.listdir(filepath)
    for i in range(len(pathDir)):
        image= cv2.imread(os.path.join(filepath,pathDir[i]))
        pose = main(args,image, None, None).reshape(1,7)

        np.save('test/pose/'+pathDir[i].split('.')[0]+'.npy',pose)
        print(i,pathDir[i])
        
        
'''





def main(args):
    cfg = yaml.load(open(args.config), Loader=yaml.SafeLoader)

    # Init FaceBoxes and TDDFA, recommend using onnx flag
    if args.onnx:
        import os
        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
        os.environ['OMP_NUM_THREADS'] = '4'

        from FaceBoxes.FaceBoxes_ONNX import FaceBoxes_ONNX
        from TDDFA_ONNX import TDDFA_ONNX

        face_boxes = FaceBoxes_ONNX()
        tddfa = TDDFA_ONNX(**cfg)
    else:
        gpu_mode = args.mode == 'gpu'
        tddfa = TDDFA(gpu_mode=gpu_mode, **cfg)
        face_boxes = FaceBoxes()

    # Given a still image path and load to BGR channel
    img = cv2.imread(args.img_fp)

    # Detect faces, get 3DMM params and roi boxes
    boxes = face_boxes(img)
    n = len(boxes)
    if n == 0:
        print(f'No face detected, exit')
        sys.exit(-1)
    print(f'Detect {n} faces')

    param_lst, roi_box_lst = tddfa(img, boxes)

    # Visualization and serialization
    dense_flag = args.opt in ('2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'ply', 'obj')
    old_suffix = get_suffix(args.img_fp)
    new_suffix = f'.{args.opt}' if args.opt in ('ply', 'obj') else '.jpg'

    wfp = f'examples/results/{args.img_fp.split("/")[-1].replace(old_suffix, "")}_{args.opt}' + new_suffix

    ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, dense_flag=dense_flag)

    if args.opt == '2d_sparse':
        draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp)
    elif args.opt == '2d_dense':
        draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp)
    elif args.opt == '3d':
        render(img, ver_lst, tddfa.tri, alpha=0.6, show_flag=args.show_flag, wfp=wfp)
    elif args.opt == 'depth':
        # if `with_bf_flag` is False, the background is black
        depth(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True)
    elif args.opt == 'pncc':
        pncc(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True)
    elif args.opt == 'uv_tex':
        uv_tex(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp)
    elif args.opt == 'pose':
        viz_pose(img, param_lst, ver_lst, show_flag=args.show_flag, wfp=wfp)
    elif args.opt == 'ply':
        ser_to_ply(ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp)
    elif args.opt == 'obj':
        ser_to_obj(img, ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp)
    else:
        raise ValueError(f'Unknown opt {args.opt}')
'''

================================================
FILE: 3DDFA_V2/utils/pose.py
================================================
# coding: utf-8

"""
Reference: https://github.com/YadiraF/PRNet/blob/master/utils/estimate_pose.py

Calculating pose from the output 3DMM parameters, you can also try to use solvePnP to perform estimation
"""

__author__ = 'cleardusk'

import cv2
import numpy as np
from math import cos, sin, atan2, asin, sqrt

from .functions import calc_hypotenuse, plot_image


def P2sRt(P):
    """ decompositing camera matrix P.
    Args:
        P: (3, 4). Affine Camera Matrix.
    Returns:
        s: scale factor.
        R: (3, 3). rotation matrix.
        t2d: (2,). 2d translation.
    """
    t3d = P[:, 3]
    R1 = P[0:1, :3]
    R2 = P[1:2, :3]
    s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0
    r1 = R1 / np.linalg.norm(R1)
    r2 = R2 / np.linalg.norm(R2)
    r3 = np.cross(r1, r2)

    R = np.concatenate((r1, r2, r3), 0)
    return s, R, t3d


def matrix2angle(R):
    """ compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf
    refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv
    todo: check and debug
     Args:
         R: (3,3). rotation matrix
     Returns:
         x: yaw
         y: pitch
         z: roll
     """
    if R[2, 0] > 0.998:
        z = 0
        x = np.pi / 2
        y = z + atan2(-R[0, 1], -R[0, 2])
    elif R[2, 0] < -0.998:
        z = 0
        x = -np.pi / 2
        y = -z + atan2(R[0, 1], R[0, 2])
    else:
        x = asin(R[2, 0])
        y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x))
        z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x))

    return x, y, z

def angle2matrix(theta):
    """ get rotation matrix from three Euler angles (radian).
    Args:
        theta: [3,]. Euler angles, in the same order as returned by matrix2angle (yaw, pitch, roll).
    Returns:
        R: (3, 3). rotation matrix.
    """
    R_x = np.array([[1,              0,               0],
                    [0,  cos(theta[1]), -sin(theta[1])],
                    [0,  sin(theta[1]),  cos(theta[1])]])

    R_y = np.array([[ cos(theta[0]),  0,  sin(-theta[0])],
                    [              0, 1,               0],
                    [-sin(-theta[0]), 0,   cos(theta[0])]])

    R_z = np.array([[cos(theta[2]), -sin(theta[2]), 0],
                    [sin(theta[2]),  cos(theta[2]), 0],
                    [            0,              0, 1]])

    R = np.dot(R_z, np.dot(R_y, R_x))

    return R

def angle2matrix_3ddfa(angles):
    ''' get rotation matrix from three rotation angles(radian). The same as in 3DDFA.
    Args:
        angles: [3,]. x, y, z angles
        x: pitch.
        y: yaw. 
        z: roll. 
    Returns:
        R: 3x3. rotation matrix.
    '''
    # x, y, z = np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2])
    x, y, z = angles[1], angles[0], angles[2]
    
    # x
    Rx=np.array([[1,      0,       0],
                 [0, cos(x),  sin(x)],
                 [0, -sin(x),   cos(x)]])
    # y
    Ry=np.array([[ cos(y), 0, -sin(y)],
                 [      0, 1,      0],
                 [sin(y), 0, cos(y)]])
    # z
    Rz=np.array([[cos(z), sin(z), 0],
                 [-sin(z),  cos(z), 0],
                 [     0,       0, 1]])
    R = Rx.dot(Ry).dot(Rz)
    return R.astype(np.float32)

def calc_pose(param):
    P = param[:12].reshape(3, -1)  # camera matrix
    s, R, t3d = P2sRt(P)
    P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)  # without scale
    pose = matrix2angle(R)
    pose = [p * 180 / np.pi for p in pose]

    return P, pose


def build_camera_box(rear_size=90):
    point_3d = []
    rear_depth = 0
    point_3d.append((-rear_size, -rear_size, rear_depth))
    point_3d.append((-rear_size, rear_size, rear_depth))
    point_3d.append((rear_size, rear_size, rear_depth))
    point_3d.append((rear_size, -rear_size, rear_depth))
    point_3d.append((-rear_size, -rear_size, rear_depth))

    front_size = int(4 / 3 * rear_size)
    front_depth = int(4 / 3 * rear_size)
    point_3d.append((-front_size, -front_size, front_depth))
    point_3d.append((-front_size, front_size, front_depth))
    point_3d.append((front_size, front_size, front_depth))
    point_3d.append((front_size, -front_size, front_depth))
    point_3d.append((-front_size, -front_size, front_depth))
    point_3d = np.array(point_3d, dtype=np.float32).reshape(-1, 3)

    return point_3d


def plot_pose_box(img, P, ver, color=(40, 255, 0), line_width=2):
    """ Draw a 3D box as annotation of pose.
    Ref:https://github.com/yinguobing/head-pose-estimation/blob/master/pose_estimator.py
    Args:
        img: the input image
        P: (3, 4). Affine Camera Matrix.
        kpt: (2, 68) or (3, 68)
    """
    llength = calc_hypotenuse(ver)
    point_3d = build_camera_box(llength)
    # Map to 2d image points
    point_3d_homo = np.hstack((point_3d, np.ones([point_3d.shape[0], 1])))  # n x 4
    point_2d = point_3d_homo.dot(P.T)[:, :2]

    point_2d[:, 1] = - point_2d[:, 1]
    point_2d[:, :2] = point_2d[:, :2] - np.mean(point_2d[:4, :2], 0) + np.mean(ver[:2, :27], 1)
    point_2d = np.int32(point_2d.reshape(-1, 2))

    # Draw all the lines
    cv2.polylines(img, [point_2d], True, color, line_width, cv2.LINE_AA)
    cv2.line(img, tuple(point_2d[1]), tuple(
        point_2d[6]), color, line_width, cv2.LINE_AA)
    cv2.line(img, tuple(point_2d[2]), tuple(
        point_2d[7]), color, line_width, cv2.LINE_AA)
    cv2.line(img, tuple(point_2d[3]), tuple(
        point_2d[8]), color, line_width, cv2.LINE_AA)

    return img


def viz_pose(img, param_lst, ver_lst, show_flag=False, wfp=None):
    for param, ver in zip(param_lst, ver_lst):
        P, pose = calc_pose(param)
        img = plot_pose_box(img, P, ver)
        # print(P[:, :3])
        print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}')

    if wfp is not None:
        cv2.imwrite(wfp, img)
        print(f'Save visualization result to {wfp}')

    if show_flag:
        plot_image(img)

    return img

def pose_6(param):
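    # Decompose the camera matrix (param[:12]) into scale s, Euler angles in degrees,
    # 3-D translation t3d and the scale-free pose matrix P.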
    P = param[:12].reshape(3, -1)  # camera matrix
    s, R, t3d = P2sRt(P)
    P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)  # without scale
    pose = matrix2angle(R)
    print(t3d)
    R1 = angle2matrix(pose)
    print(R)
    print(R1)
    pose = [p * 180 / np.pi for p in pose]
    
    return s, pose, t3d, P


def smooth_pose(img, param_lst, ver_lst, pose_new, show_flag=False, wfp=None, wnp = None):
    for param, ver in zip(param_lst, ver_lst):
        t3d = np.array([pose_new[4],pose_new[5],pose_new[6]])
        
        theta = np.array([pose_new[0],pose_new[1],pose_new[2]])
        theta = [p * np.pi / 180 for p in theta]
        R = angle2matrix(theta)
        P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) 
        img = plot_pose_box(img, P, ver)
    #    print(P,P.shape,t3d)
        print(P,pose_new)
        print(f'yaw: {theta[0]:.1f}, pitch: {theta[1]:.1f}, roll: {theta[2]:.1f}')
        all_pose = [0]
        all_pose = np.array(all_pose)

    if wfp is not None:
        cv2.imwrite(wfp, img)
        print(f'Save visualization result to {wfp}')
        
    if wnp is not None:
        np.save(wnp, all_pose)
        print(f'Save pose result to {wnp}')
        
    if show_flag:
        plot_image(img)

    return img

    
    
    

def get_pose(img, param_lst, ver_lst, show_flag=False, wfp=None, wnp = None):
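    # Build a 7-dim pose vector [yaw, pitch, roll (degrees), scale, t3d_x, t3d_y, t3d_z]
    # for the detected face; optionally save the visualization (wfp) and the pose array (wnp).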
    for param, ver in zip(param_lst, ver_lst):
        s, pose, t3d, P = pose_6(param)
        img = plot_pose_box(img, P, ver)
    #    print(P,P.shape,t3d)
        print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}')
        all_pose = [pose[0],pose[1],pose[2],s,t3d[0],t3d[1],t3d[2]]
        all_pose = np.array(all_pose)

    if wfp is not None:
        cv2.imwrite(wfp, img)
        print(f'Save visualization result to {wfp}')
        
    if wnp is not None:
        np.save(wnp, all_pose)
        print(f'Save pose result to {wnp}')
        
    if show_flag:
        plot_image(img)

    return all_pose



================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2022 jixinya

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# EAMM:  One-Shot Emotional Talking Face via Audio-Based Emotion-Aware Motion Model [SIGGRAPH 2022 Conference]

Xinya Ji, [Hang Zhou](https://hangz-nju-cuhk.github.io/), Kaisiyuan Wang, [Qianyi Wu](https://wuqianyi.top/), [Wayne Wu](http://wywu.github.io/), [Feng Xu](http://xufeng.site/), [Xun Cao](https://cite.nju.edu.cn/People/Faculty/20190621/i5054.html)

[[Project]](https://jixinya.github.io/projects/EAMM/)  [[Paper]](https://arxiv.org/abs/2205.15278)    

![visualization](demo/teaser-1.png)

Given a single portrait image, we can synthesize emotional talking faces, where mouth movements match the input audio and facial emotion dynamics follow the emotion source video.

## Installation

We train and test with Python 3.6 and PyTorch. To install the dependencies, run:

```
pip install -r requirements.txt
```

## Testing

- Download the pre-trained models and data from the following link: [google-drive](https://drive.google.com/file/d/1IL9LjH3JegyMqJABqMxrX3StAq_v8Gtp/view?usp=sharing) and put the files in the corresponding places.

- Run the demo:
  
  `python demo.py --source_image path/to/image --driving_video path/to/emotion_video --pose_file path/to/pose --in_file path/to/audio --emotion emotion_type`
  
- Prepare testing data:

  prepare source_image -- use `crop_image` in process_data.py

  prepare driving_video -- use `crop_image_tem` in process_data.py

  prepare pose -- detect pose with [3DDFA_V2](https://github.com/cleardusk/3DDFA_V2); a sketch of this step follows below
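
  The pose file can be produced the same way as the `__main__` loop at the bottom of `3DDFA_V2/demo.py`. Below is a minimal sketch (run from inside `3DDFA_V2/`), assuming the default 3DDFA_V2 config and weights are in place; the `test/image/` and `test/pose/` folders are illustrative:

  ```python
  # Hedged sketch of the pose-preparation step; paths and argument values are assumptions.
  import os
  import cv2
  import numpy as np
  from types import SimpleNamespace

  from demo import main as pose_main  # 3DDFA_V2/demo.py

  args = SimpleNamespace(config='configs/mb1_120x120.yml', mode='cpu', opt='pose',
                         show_flag=False, onnx=False, img_fp='examples/inputs/0.png')

  image_dir, pose_dir = 'test/image/', 'test/pose/'
  for name in sorted(os.listdir(image_dir)):
      image = cv2.imread(os.path.join(image_dir, name))
      # returns [yaw, pitch, roll (degrees), scale, t3d_x, t3d_y, t3d_z] for the detected face
      pose = pose_main(args, image, None, None).reshape(1, 7)
      np.save(os.path.join(pose_dir, os.path.splitext(name)[0] + '.npy'), pose)
  ```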

## Training

- Training data structure:

  ```
  ./data/<dataset_name>
  ├──fomm_crop
  │  ├──id/file_name   # cropped images
  │  │  ├──0.png
  │  │  ├──...
  ├──fomm_pose_crop
  │  ├──id   
  │  │  ├──file_name.npy  # pose of the cropped images
  │  │  ├──...
  ├──MFCC
  │  ├──id   
  │  │  ├──file_name.npy  # MFCC of the audio
  │  │  ├──...
  
  
  *The cropped images are generated by 'crop_image_tem' in process_data.py
  *The poses of the cropped videos are generated by 3DDFA_V2/demo.py
  *The MFCCs of the audio are generated by 'audio2mfcc' in process_data.py
  ```

    

- Step 1 : Train the Audio2Facial-Dynamics Module using the LRW dataset

  `python run.py --config config/train_part1.yaml --mode train_part1 --checkpoint log/124_52000.pth.tar `

- Step 2 : Fine-tune the Audio2Facial-Dynamics Module after getting stable results from step 1

  `python run.py --config config/train_part1_fine_tune.yaml --mode train_part1_fine_tune --checkpoint log/124_52000.pth.tar --audio_chechpoint  checkpoint/from/step_1`

- Step 3 : Train the Implicit Emotion Displacement Learner

  `python run.py --config config/train_part2.yaml --mode train_part2 --checkpoint log/124_52000.pth.tar --audio_chechpoint  checkpoint/from/step_2`

## Citation

```
@inproceedings{10.1145/3528233.3530745,
author = {Ji, Xinya and Zhou, Hang and Wang, Kaisiyuan and Wu, Qianyi and Wu, Wayne and Xu, Feng and Cao, Xun},
title = {EAMM: One-Shot Emotional Talking Face via Audio-Based Emotion-Aware Motion Model},
year = {2022},
isbn = {9781450393379},
url = {https://doi.org/10.1145/3528233.3530745},
doi = {10.1145/3528233.3530745},
booktitle = {ACM SIGGRAPH 2022 Conference Proceedings},
series = {SIGGRAPH '22}
}


```



================================================
FILE: augmentation.py
================================================
"""
Code from https://github.com/hassony2/torch_videovision
"""

import numbers
import math
import random
import numpy as np
import PIL
import cv2
from skimage.transform import resize, rotate, AffineTransform, warp
from skimage.util import pad
import torchvision

import warnings

from skimage import img_as_ubyte, img_as_float


def crop_clip(clip, min_h, min_w, h, w):
    if isinstance(clip[0], np.ndarray):
        cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]

    elif isinstance(clip[0], PIL.Image.Image):
        cropped = [
            img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
            ]
    else:
        raise TypeError('Expected numpy.ndarray or PIL.Image' +
                        'but got list of {0}'.format(type(clip[0])))
    return cropped


def pad_clip(clip, h, w):
    im_h, im_w = clip[0].shape[:2]
    pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
    pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)

    return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')


def resize_clip(clip, size, interpolation='bilinear'):
    if isinstance(clip[0], np.ndarray):
        if isinstance(size, numbers.Number):
            im_h, im_w, im_c = clip[0].shape
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w
                                                   and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]

        scaled = [
            resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
                   mode='constant', anti_aliasing=True) for img in clip
            ]
    elif isinstance(clip[0], PIL.Image.Image):
        if isinstance(size, numbers.Number):
            im_w, im_h = clip[0].size
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w
                                                   and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]
        if interpolation == 'bilinear':
            pil_inter = PIL.Image.BILINEAR
        else:
            pil_inter = PIL.Image.NEAREST
        scaled = [img.resize(size, pil_inter) for img in clip]
    else:
        raise TypeError('Expected numpy.ndarray or PIL.Image' +
                        'but got list of {0}'.format(type(clip[0])))
    return scaled


def get_resize_sizes(im_h, im_w, size):
    if im_w < im_h:
        ow = size
        oh = int(size * im_h / im_w)
    else:
        oh = size
        ow = int(size * im_w / im_h)
    return oh, ow


class RandomFlip(object):
    def __init__(self, time_flip=False, horizontal_flip=False):
        self.time_flip = time_flip
        self.horizontal_flip = horizontal_flip

    def __call__(self, clip):
        if random.random() < 0.5 and self.time_flip:
            return clip[::-1]
        if random.random() < 0.5 and self.horizontal_flip:
            return [np.fliplr(img) for img in clip]

        return clip


class RandomResize(object):
    """Resizes a list of (H x W x C) numpy.ndarray to the final size
    The larger the original image is, the more times it takes to
    interpolate
    Args:
    interpolation (str): Can be one of 'nearest', 'bilinear'
    defaults to nearest
    size (tuple): (widht, height)
    """

    def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
        self.ratio = ratio
        self.interpolation = interpolation

    def __call__(self, clip):
        scaling_factor = random.uniform(self.ratio[0], self.ratio[1])

        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size

        new_w = int(im_w * scaling_factor)
        new_h = int(im_h * scaling_factor)
        new_size = (new_w, new_h)
        resized = resize_clip(
            clip, new_size, interpolation=self.interpolation)

        return resized


class RandomCrop(object):
    """Extract random crop at the same location for a list of videos
    Args:
    size (sequence or int): Desired output size for the
    crop in format (h, w)
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (size, size)

        self.size = size

    def __call__(self, clip):
        """
        Args:
        img (PIL.Image or numpy.ndarray): List of videos to be cropped
        in format (h, w, c) in numpy.ndarray
        Returns:
        PIL.Image or numpy.ndarray: Cropped list of videos
        """
        h, w = self.size
        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image' +
                            'but got list of {0}'.format(type(clip[0])))

        clip = pad_clip(clip, h, w)
        im_h, im_w = clip.shape[1:3]
        x1 = 0 if h == im_h else random.randint(0, im_w - w)
        y1 = 0 if w == im_w else random.randint(0, im_h - h)
        cropped = crop_clip(clip, y1, x1, h, w)

        return cropped


class MouthCrop(object):
    """Extract random crop at the same location for a list of videos
    Args:
    size (sequence or int): Desired output size for the
    crop in format (h, w)
    """

    def __init__(self, center_x, center_y, mask_width, mask_height):
        

        self.center_x = center_x
        self.center_y = center_y
        self.mask_width = mask_width
        self.mask_height = mask_height

    def __call__(self, clip):
        """
        Args:
        img (PIL.Image or numpy.ndarray): List of videos to be cropped
        in format (h, w, c) in numpy.ndarray
        Returns:
        PIL.Image or numpy.ndarray: Cropped list of videos
        """
        start_x = self.center_x - int(self.mask_width/2)
        start_y = self.center_y - int(self.mask_height/2) 
        end_x = start_x + self.mask_width
        end_y = start_y + self.mask_height
        # mask is all white
        # mask = 255*np.ones((mask_height, mask_width, 3), dtype=np.uint8)
        # mask is uniform noise
        cropped = []
        for i in range(len(clip)):
            mask = np.random.rand(self.mask_height, self.mask_width, 3)
            img = clip[i].copy()
            img[start_y:end_y, start_x:end_x, :] = mask
        
            cropped.append(img)
        cropped = np.array(cropped)
        return cropped

class RandomRotation(object):
    """Rotate entire clip randomly by a random angle within
    given bounds
    Args:
    degrees (sequence or int): Range of degrees to select from
    If degrees is a number instead of sequence like (min, max),
    the range of degrees, will be (-degrees, +degrees).
    """

    def __init__(self, degrees):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError('If degrees is a single number,'
                                 'must be positive')
            degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError('If degrees is a sequence,'
                                 'it must be of len 2.')

        self.degrees = degrees

    def __call__(self, clip):
        """
        Args:
        img (PIL.Image or numpy.ndarray): List of videos to be cropped
        in format (h, w, c) in numpy.ndarray
        Returns:
        PIL.Image or numpy.ndarray: Cropped list of videos
        """
        angle = random.uniform(self.degrees[0], self.degrees[1])
        if isinstance(clip[0], np.ndarray):
            rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
        elif isinstance(clip[0], PIL.Image.Image):
            rotated = [img.rotate(angle) for img in clip]
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image' +
                            'but got list of {0}'.format(type(clip[0])))

        return rotated

class RandomPerspective(object):
    """Rotate entire clip randomly by a random angle within
    given bounds
    Args:
    degrees (sequence or int): Range of degrees to select from
    If degrees is a number instead of sequence like (min, max),
    the range of degrees, will be (-degrees, +degrees).
    """

    def __init__(self, pers_num, enlarge_num):
        self.pers_num = pers_num
        self.enlarge_num = enlarge_num

    def __call__(self, clip):
        """
        Args:
        img (PIL.Image or numpy.ndarray): List of videos to be cropped
        in format (h, w, c) in numpy.ndarray
        Returns:
        PIL.Image or numpy.ndarray: Cropped list of videos
        """
        out = clip
        for i in range(len(clip)):
            self.pers_size = np.random.randint(20, self.pers_num) * pow(-1, np.random.randint(2))
            self.enlarge_size = np.random.randint(20, self.enlarge_num) * pow(-1, np.random.randint(2))
            h, w, c = clip[i].shape
            crop_size=256
            dst = np.array([
                [-self.enlarge_size, -self.enlarge_size],
                [-self.enlarge_size + self.pers_size, w + self.enlarge_size],
                [h + self.enlarge_size, -self.enlarge_size],
                [h + self.enlarge_size - self.pers_size, w + self.enlarge_size],], dtype=np.float32)
            src = np.array([[-self.enlarge_size, -self.enlarge_size], [-self.enlarge_size, w + self.enlarge_size],
                        [h + self.enlarge_size, -self.enlarge_size], [h + self.enlarge_size, w + self.enlarge_size]]).astype(np.float32)
            M = cv2.getPerspectiveTransform(src, dst)
            warped = cv2.warpPerspective(clip[i], M, (crop_size, crop_size), borderMode=cv2.BORDER_REPLICATE)
            out[i] = warped

        return out


class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation and hue of the clip
    Args:
    brightness (float): How much to jitter brightness. brightness_factor
    is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
    contrast (float): How much to jitter contrast. contrast_factor
    is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
    saturation (float): How much to jitter saturation. saturation_factor
    is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
    hue(float): How much to jitter hue. hue_factor is chosen uniformly from
    [-hue, hue]. Should be >=0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def get_params(self, brightness, contrast, saturation, hue):
        if brightness > 0:
            brightness_factor = random.uniform(
                max(0, 1 - brightness), 1 + brightness)
        else:
            brightness_factor = None

        if contrast > 0:
            contrast_factor = random.uniform(
                max(0, 1 - contrast), 1 + contrast)
        else:
            contrast_factor = None

        if saturation > 0:
            saturation_factor = random.uniform(
                max(0, 1 - saturation), 1 + saturation)
        else:
            saturation_factor = None

        if hue > 0:
            hue_factor = random.uniform(-hue, hue)
        else:
            hue_factor = None
        return brightness_factor, contrast_factor, saturation_factor, hue_factor

    def __call__(self, clip):
        """
        Args:
        clip (list): list of PIL.Image
        Returns:
        list PIL.Image : list of transformed PIL.Image
        """
        if isinstance(clip[0], np.ndarray):
            brightness, contrast, saturation, hue = self.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)

            # Create img transform function sequence
            img_transforms = []
            if brightness is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
            if saturation is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
            if hue is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
            if contrast is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
            random.shuffle(img_transforms)
            img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
                                                                                                     img_as_float]

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                jittered_clip = []
                for img in clip:
                    jittered_img = img
                    for func in img_transforms:
                        jittered_img = func(jittered_img)
                    jittered_clip.append(jittered_img.astype('float32'))
        elif isinstance(clip[0], PIL.Image.Image):
            brightness, contrast, saturation, hue = self.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)

            # Create img transform function sequence
            img_transforms = []
            if brightness is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
            if saturation is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
            if hue is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
            if contrast is not None:
                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
            random.shuffle(img_transforms)

            # Apply to all videos
            jittered_clip = []
            for img in clip:
                jittered_img = img
                for func in img_transforms:
                    jittered_img = func(jittered_img)
                jittered_clip.append(jittered_img)

        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image' +
                            'but got list of {0}'.format(type(clip[0])))
        return jittered_clip


class AllAugmentationTransform:
    def __init__(self, crop_mouth_param = None, resize_param=None, rotation_param=None, perspective_param=None, flip_param=None, crop_param=None, jitter_param=None):
        self.transforms = []
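        # the transforms below are applied sequentially, in the order they are appended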
        if crop_mouth_param is not None:
            self.transforms.append(MouthCrop(**crop_mouth_param))
        
        if flip_param is not None:
            self.transforms.append(RandomFlip(**flip_param))

        if rotation_param is not None:
            self.transforms.append(RandomRotation(**rotation_param))
        
        if perspective_param is not None:
            self.transforms.append(RandomPerspective(**perspective_param))

        if resize_param is not None:
            self.transforms.append(RandomResize(**resize_param))

        if crop_param is not None:
            self.transforms.append(RandomCrop(**crop_param))

        if jitter_param is not None:
            self.transforms.append(ColorJitter(**jitter_param))
      
    def __call__(self, clip):
        for t in self.transforms:
            clip = t(clip)
        return clip


================================================
FILE: config/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml
================================================
dataset_params:
  root_dir: /mnt/lustre/share_data/jixinya/MEAD/
  frame_shape: [256, 256, 3]
  id_sampling: False
  pairs_list: Random_choice
  augmentation_params:
    crop_mouth_param: 
      center_x: 135
      center_y: 190
      mask_width: 100
      mask_height: 60
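      # pixel coordinates within the 256x256 frame; consumed by MouthCrop in augmentation.py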
    rotation_param: 
      degrees: 30
    perspective_param: 
      pers_num: 30
      enlarge_num: 40
    flip_param:
      horizontal_flip: True
      time_flip: False
    jitter_param:
      brightness: 0
      contrast: 0
      saturation: 0
      hue: 0

model_params:
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: True
  audio_params:
    num_kp: 10
    num_channels : 3
    num_channels_a : 3
    estimate_jacobian: True
  kp_detector_params:
     temperature: 0.1
     block_expansion: 32
     max_features: 1024
     scale_factor: 0.25
     num_blocks: 5
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    num_bottleneck_blocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 64
      max_features: 1024
      num_blocks: 5
      scale_factor: 0.25
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True

train_params:
  type: linear_4
  smooth: False
  jaco_net: cnn
  ldmark: fake
  generator: not
  train_generator: False
  num_epochs: 300
  num_repeats: 1
  epoch_milestones: [60, 90]
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4
  lr_audio_feature: 2.0e-4
  batch_size: 16
  scales: [1, 0.5, 0.25, 0.125]
  checkpoint_freq: 1
  transform_params:
    sigma_affine: 0.05
    sigma_tps: 0.005
    points_tps: 5
  loss_weights:
    generator_gan: 0
    discriminator_gan: 1
    feature_matching: [10, 10, 10, 10]
    perceptual: [10, 10, 10, 10, 10]
    equivariance_value: 0
    equivariance_jacobian: 0
    emo: 10

reconstruction_params:
  num_videos: 1000
  format: '.mp4'

animate_params:
  num_pairs: 50
  format: '.mp4'
  normalization_params:
    adapt_movement_scale: False
    use_relative_movement: True
    use_relative_jacobian: True

visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'


================================================
FILE: config/train_part1.yaml
================================================
dataset_params:
  name: Vox
  root_dir: dataset/LRW/
  frame_shape: [256, 256, 3]
  id_sampling: False
  augmentation_params:
    flip_param:
      horizontal_flip: False
      time_flip: False
    jitter_param:
      brightness: 0.1
      contrast: 0.1
      saturation: 0.1
      hue: 0.1


model_params:
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: True
  audio_params:
    num_kp: 10
    num_channels : 3
    num_channels_a : 3
    estimate_jacobian: True
  kp_detector_params:
     temperature: 0.1
     block_expansion: 32
     max_features: 1024
     scale_factor: 0.25
     num_blocks: 5
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    num_bottleneck_blocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 64
      max_features: 1024
      num_blocks: 5
      scale_factor: 0.25
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True

train_params:
  jaco_net: cnn
  ldmark: fake
  generator: not
  num_epochs: 300
  num_repeats: 1
  epoch_milestones: [60, 90]
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4
  lr_audio_feature: 2.0e-4
  batch_size: 8
  scales: [1, 0.5, 0.25, 0.125]
  checkpoint_freq: 1
  transform_params:
    sigma_affine: 0.05
    sigma_tps: 0.005
    points_tps: 5
  loss_weights:
    generator_gan: 0
    discriminator_gan: 0
    feature_matching: [10, 10, 10, 10]
    perceptual: [10, 10, 10, 10, 10]
    equivariance_value: 0
    equivariance_jacobian: 0
    audio: 10



visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'


================================================
FILE: config/train_part1_fine_tune.yaml
================================================
dataset_params:
  name: LRW
  root_dir: dataset/LRW/
  frame_shape: [256, 256, 3]
  id_sampling: False
  augmentation_params:
    flip_param:
      horizontal_flip: False
      time_flip: False
    jitter_param:
      brightness: 0.1
      contrast: 0.1
      saturation: 0.1
      hue: 0.1


model_params:
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: True
  audio_params:
    num_kp: 10
    num_channels : 3
    num_channels_a : 3
    estimate_jacobian: True
  kp_detector_params:
     temperature: 0.1
     block_expansion: 32
     max_features: 1024
     scale_factor: 0.25
     num_blocks: 5
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    num_bottleneck_blocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 64
      max_features: 1024
      num_blocks: 5
      scale_factor: 0.25
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True

train_params:
  jaco_net: cnn
  ldmark: fake
  generator: audio
  num_epochs: 300
  num_repeats: 1
  epoch_milestones: [60, 90]
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4
  lr_audio_feature: 2.0e-4
  batch_size: 6
  scales: [1, 0.5, 0.25, 0.125]
  checkpoint_freq: 1
  transform_params:
    sigma_affine: 0.05
    sigma_tps: 0.005
    points_tps: 5
  loss_weights:
    generator_gan: 0
    discriminator_gan: 0
    feature_matching: [10, 10, 10, 10]
    perceptual: [0.1, 0.1, 0.1, 0.1, 0.1]
    equivariance_value: 0
    equivariance_jacobian: 0
    audio: 10

visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'


================================================
FILE: config/train_part2.yaml
================================================
dataset_params:
  name: MEAD
  root_dir: dataset/MEAD/
  frame_shape: [256, 256, 3]
  id_sampling: False
  augmentation_params:
    crop_mouth_param: 
      center_x: 135
      center_y: 190
      mask_width: 100
      mask_height: 60
    rotation_param: 
      degrees: 30
    perspective_param: 
      pers_num: 30
      enlarge_num: 40
    flip_param:
      horizontal_flip: True
      time_flip: False
    jitter_param: 
      brightness: 0
      contrast: 0
      saturation: 0
      hue: 0

model_params:
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: True
  audio_params:
    num_kp: 10
    num_channels : 3
    num_channels_a : 3
    estimate_jacobian: True
  kp_detector_params:
     temperature: 0.1
     block_expansion: 32
     max_features: 1024
     scale_factor: 0.25
     num_blocks: 5
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    num_bottleneck_blocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 64
      max_features: 1024
      num_blocks: 5
      scale_factor: 0.25
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True

train_params:
  type: linear_4
  smooth: False
  jaco_net: cnn
  ldmark: fake
  generator: not
  num_epochs: 300
  num_repeats: 1
  epoch_milestones: [60, 90]
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4
  lr_audio_feature: 2.0e-4
  batch_size: 16
  scales: [1, 0.5, 0.25, 0.125]
  checkpoint_freq: 1
  transform_params:
    sigma_affine: 0.05
    sigma_tps: 0.005
    points_tps: 5
  loss_weights:
    generator_gan: 0
    discriminator_gan: 0
    feature_matching: [10, 10, 10, 10]
    perceptual: [10, 10, 10, 10, 10]
    equivariance_value: 0
    equivariance_jacobian: 0
    emo: 10


visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'


================================================
FILE: demo.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Oct  6 20:57:27 2021
@author: thea
"""

import matplotlib
matplotlib.use('Agg')
import os,sys
import yaml
from argparse import ArgumentParser
from tqdm import tqdm
from skimage import io, img_as_float32
import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from filter1 import OneEuroFilter
import torch.utils

from torch.autograd import Variable
from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector, KPDetector_a
from modules.util import AT_net, Emotion_k, Emotion_map, AT_net2
from augmentation import AllAugmentationTransform

from scipy.spatial import ConvexHull

import python_speech_features
from pathlib import Path
import dlib
import cv2
import librosa
from skimage import transform as tf
#from audiolm.models import AT_emoiton
#from audiolm.utils import plot_flmarks
if sys.version_info[0] < 3:
    raise Exception("You must use Python 3 or higher. Recommended version is Python 3.6")


detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor('./shape_predictor_68_face_landmarks.dat')




def load_checkpoints(opt, checkpoint_path, audio_checkpoint_path, emo_checkpoint_path, cpu=False):

    with open(opt.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    if not cpu:
        generator.cuda()

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not cpu:
        kp_detector.cuda()

    kp_detector_a = KPDetector_a(**config['model_params']['kp_detector_params'],
                             **config['model_params']['audio_params'])

    audio_feature = AT_net2()
    if opt.type.startswith('linear'):
        emo_detector = Emotion_k(block_expansion=32, num_channels=3, max_features=1024,
                 num_blocks=5, scale_factor=0.25, num_classes=8)
    elif opt.type.startswith('map'):
        emo_detector = Emotion_map(block_expansion=32, num_channels=3, max_features=1024,
                 num_blocks=5, scale_factor=0.25, num_classes=8)
    if not cpu:
        kp_detector_a.cuda()
        audio_feature.cuda()
        emo_detector.cuda()




    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        audio_checkpoint = torch.load(audio_checkpoint_path, map_location=torch.device('cpu'))
        emo_checkpoint = torch.load(emo_checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)
        audio_checkpoint = torch.load(audio_checkpoint_path)
        emo_checkpoint = torch.load(emo_checkpoint_path)

    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])
    audio_feature.load_state_dict(audio_checkpoint['audio_feature'])
    kp_detector_a.load_state_dict(audio_checkpoint['kp_detector_a'])
    emo_detector.load_state_dict(emo_checkpoint['emo_detector'])
    

    if not cpu:
        generator = generator.cuda()
        kp_detector = kp_detector.cuda()
        audio_feature = audio_feature.cuda()
        kp_detector_a = kp_detector_a.cuda()
        emo_detector = emo_detector.cuda()

    generator.eval()
    kp_detector.eval()
    audio_feature.eval()
    kp_detector_a.eval()
    emo_detector.eval()
    return generator, kp_detector, kp_detector_a, audio_feature, emo_detector

def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
                 use_relative_movement=False, use_relative_jacobian=False):
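    # Transfer driving motion onto the source: keypoint displacements (and jacobian changes)
    # are taken relative to the initial driving frame and then applied to the source keypoints.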
    if adapt_movement_scale:
        source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
        driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
    else:
        adapt_movement_scale = 1

    kp_new = {k: v for k, v in kp_driving.items()}

    if use_relative_movement:
        kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
        kp_value_diff *= adapt_movement_scale
        kp_new['value'] = kp_value_diff + kp_source['value']

        if use_relative_jacobian:
            jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
            kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])

    return kp_new

def shape_to_np(shape, dtype="int"):
    # initialize the list of (x, y)-coordinates
    coords = np.zeros((shape.num_parts, 2), dtype=dtype)

    # loop over all facial landmarks and convert them
    # to a 2-tuple of (x, y)-coordinates
    for i in range(0, shape.num_parts):
        coords[i] = (shape.part(i).x, shape.part(i).y)

    # return the list of (x, y)-coordinates
    return coords

def get_aligned_image(driving_video, opt):
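    # Align every frame of the driving video to its first frame: a similarity transform is
    # estimated between the first 35 facial landmarks of each frame and those of the first
    # (template) frame, and each frame is warped to 256x256 so the facial layout follows the template.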
    aligned_array = []

    video_array = np.array(driving_video)
    source_image=video_array[0]
   # aligned_array.append(source_image)
    source_image = np.array(source_image * 255, dtype=np.uint8)
    gray = cv2.cvtColor(source_image, cv2.COLOR_BGR2GRAY)
    rects = detector(gray, 1)  #detect human face
    for (i, rect) in enumerate(rects):
        template = predictor(gray, rect) #detect 68 points
        template = shape_to_np(template)

    if opt.emotion == 'surprised' or opt.emotion == 'fear':
        template = template-[0,10]
    for i in range(len(video_array)):
        image=np.array(video_array[i] * 255, dtype=np.uint8)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        rects = detector(gray, 1)  #detect human face
        for (j, rect) in enumerate(rects):
            shape = predictor(gray, rect) #detect 68 points
            shape = shape_to_np(shape)

        pts2 = np.float32(template[:35,:])
        pts1 = np.float32(shape[:35,:]) #eye and nose

    #    pts2 = np.float32(np.concatenate((template[:16,:],template[27:36,:]),axis = 0))
    #    pts1 = np.float32(np.concatenate((shape[:16,:],shape[27:36,:]),axis = 0)) #eye and nose
        # pts1 = np.float32(landmark[17:35,:])
        tform = tf.SimilarityTransform()
        tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters.
        dst = tf.warp(image, tform, output_shape=(256, 256))

        dst = np.array(dst, dtype=np.float32)
        aligned_array.append(dst)

    return aligned_array

def get_transformed_image(driving_video, opt):
    video_array = np.array(driving_video)
    with open(opt.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    transformations = AllAugmentationTransform(**config['dataset_params']['augmentation_params'])
    transformed_array = transformations(video_array)
    return transformed_array



def make_animation_smooth(source_image, driving_video, transformed_video, deco_out, kp_loss, generator, kp_detector, kp_detector_a, emo_detector, opt, relative=True, adapt_movement_scale=True, cpu=False):
    with torch.no_grad():
        predictions = []

        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)

        if not cpu:
            source = source.cuda()

        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        transformed_driving = torch.tensor(np.array(transformed_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)

        kp_source = kp_detector(source)
        kp_driving_initial = kp_detector_a(deco_out[:,0])

        emo_driving_all = []
        features = []
        kp_driving_all = []
        for frame_idx in tqdm(range(len(deco_out[0]))):

            driving_frame = driving[:, :, frame_idx]
            transformed_frame = transformed_driving[:, :, frame_idx]
            if not cpu:
                driving_frame = driving_frame.cuda()
                transformed_frame = transformed_frame.cuda()
            kp_driving = kp_detector_a(deco_out[:,frame_idx])
            kp_driving_all.append(kp_driving)
            if opt.add_emo:
                value = kp_driving['value']
                jacobian = kp_driving['jacobian']
                if opt.type == 'linear_3':
                    emo_driving,_ = emo_detector(transformed_frame,value,jacobian)
                    features.append(emo_detector.feature(transformed_frame).data.cpu().numpy())
            
                emo_driving_all.append(emo_driving)
        features = np.array(features)
        if opt.add_emo:        
            one_euro_filter_v = OneEuroFilter(mincutoff=1, beta=0.2, dcutoff=1.0, freq=100)#1 0.4
            one_euro_filter_j = OneEuroFilter(mincutoff=1, beta=0.2, dcutoff=1.0, freq=100)#1 0.4

            for j in range(len(emo_driving_all)):
                emo_driving_all[j]['value']=one_euro_filter_v.process(emo_driving_all[j]['value'].cpu()*100)/100
                emo_driving_all[j]['value'] = emo_driving_all[j]['value'].cuda()
                emo_driving_all[j]['jacobian']=one_euro_filter_j.process(emo_driving_all[j]['jacobian'].cpu()*100)/100
                emo_driving_all[j]['jacobian'] = emo_driving_all[j]['jacobian'].cuda()


        one_euro_filter_v = OneEuroFilter(mincutoff=0.05, beta=8, dcutoff=1.0, freq=100)
        one_euro_filter_j = OneEuroFilter(mincutoff=0.05, beta=8, dcutoff=1.0, freq=100)

        for j in range(len(kp_driving_all)):
            kp_driving_all[j]['value']=one_euro_filter_v.process(kp_driving_all[j]['value'].cpu()*10)/10
            kp_driving_all[j]['value'] = kp_driving_all[j]['value'].cuda()
            kp_driving_all[j]['jacobian']=one_euro_filter_j.process(kp_driving_all[j]['jacobian'].cpu()*10)/10
            kp_driving_all[j]['jacobian'] = kp_driving_all[j]['jacobian'].cuda()


        for frame_idx in tqdm(range(len(deco_out[0]))):
            
            if opt.check_add:
                kp_driving = kp_detector_a(deco_out[:,0])
            else:
                kp_driving = kp_driving_all[frame_idx]

       #     kp_driving_real = kp_detector(driving_frame)

       #     kp_driving['value'] = (1-opt.weight)*kp_driving['value'] + opt.weight*kp_driving_real['value']
       #     kp_driving['jacobian'] = (1-opt.weight)*kp_driving['jacobian'] + opt.weight*kp_driving_real['jacobian']

            if opt.add_emo:
                emo_driving = emo_driving_all[frame_idx]
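                # The emotion branch's predicted displacements are added onto a subset of
                # the 10 unsupervised keypoints (indices 1, 4 and 6 below); the first
                # displacement is additionally scaled by 0.2.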
                if opt.type == 'linear_3':
                    kp_driving['value'][:,1] = kp_driving['value'][:,1] + emo_driving['value'][:,0]*0.2
                    kp_driving['jacobian'][:,1] = kp_driving['jacobian'][:,1] + emo_driving['jacobian'][:,0]*0.2
                    kp_driving['value'][:,4] = kp_driving['value'][:,4] + emo_driving['value'][:,1]
                    kp_driving['jacobian'][:,4] = kp_driving['jacobian'][:,4] + emo_driving['jacobian'][:,1]
                    kp_driving['value'][:,6] = kp_driving['value'][:,6] + emo_driving['value'][:,2]
                    kp_driving['jacobian'][:,6] = kp_driving['jacobian'][:,6] + emo_driving['jacobian'][:,2]
                   # kp_driving['value'][:,8] = kp_driving['value'][:,8] + emo_driving['value'][:,3]
                   # kp_driving['jacobian'][:,8] = kp_driving['jacobian'][:,8] + emo_driving['jacobian'][:,3]
               
         
            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)

            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions, features



def test_auido(example_image, audio_feature, all_pose, opt):
    with open(opt.config) as f:
        para = yaml.load(f, Loader=yaml.FullLoader)

  #  encoder = audio_feature()
    if not opt.cpu:
        audio_feature = audio_feature.cuda()

    audio_feature.eval()
 #   decoder.eval()
    test_file = opt.in_file
    pose = all_pose[:,:6]
    if len(pose) == 1:
        pose = np.repeat(pose,100,0)

    elif opt.smooth_pose:
        one_euro_filter = OneEuroFilter(mincutoff=0.004, beta=0.7, dcutoff=1.0, freq=100)


        for j in range(len(pose)):
            pose[j]=one_euro_filter.process(pose[j])
      #      pose[j]=pose[0]

    example_image = np.array(example_image, dtype='float32').transpose((2, 0, 1))

    


    speech, sr = librosa.load(test_file, sr=16000)
  #  mfcc = python_speech_features.mfcc(speech ,16000,winstep=0.01)
    speech = np.insert(speech, 0, np.zeros(1920))
    speech = np.append(speech, np.zeros(1920))
    mfcc = python_speech_features.mfcc(speech,16000,winstep=0.01)


    print ('=======================================')
    print ('Start to generate images')

    ind = 3
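    # MFCC frames are computed every 10 ms (winstep=0.01), i.e. 4 MFCC frames per
    # 25-fps video frame; the 1920-sample (0.12 s) zero padding added above keeps the
    # first windows valid. Below, each output frame index `ind` takes a 28-row MFCC
    # window covering roughly 7 video frames around it.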
    with torch.no_grad():
        fake_lmark = []
        input_mfcc = []
        while ind <= int(mfcc.shape[0]/4) - 4:
            t_mfcc =mfcc[( ind - 3)*4: (ind + 4)*4, 1:]
            t_mfcc = torch.FloatTensor(t_mfcc).cuda()
            input_mfcc.append(t_mfcc)
            ind += 1
        input_mfcc = torch.stack(input_mfcc,dim = 0)

        if (len(pose)<len(input_mfcc)):
            gap = len(input_mfcc)-len(pose)
            n = int((gap/len(pose)/2)) +2
            pose = np.concatenate((pose,pose[::-1,:]),axis = 0)
            pose = np.tile(pose, (n,1))
        if(len(pose)>len(input_mfcc)):
            pose = pose[:len(input_mfcc),:]
        
        if not opt.cpu:
            example_image = Variable(torch.FloatTensor(example_image.astype(float)) ).cuda()
            example_image = torch.unsqueeze(example_image,0)
            pose = Variable(torch.FloatTensor(pose.astype(float)) ).cuda()
        
        pose = pose.unsqueeze(0)

        input_mfcc = input_mfcc.unsqueeze(0)

        deco_out = audio_feature(example_image,input_mfcc,pose,para['train_params']['jaco_net'],1.6)

        return deco_out


def save(path, frames, format):

    if format == '.png':
        if not os.path.exists(path):

            os.makedirs(path)
        for j, frame in enumerate(frames):
            imageio.imsave(path+'/'+str(j)+'.png',frame)
    #        imageio.imsave(os.path.join(path, str(j) + '.png'), frames[j])
    else:
        print ("Unknown format %s" % format)
        exit()

class VideoWriter(object):
    def __init__(self, path, width, height, fps):
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        self.path = path
        self.out = cv2.VideoWriter(self.path, fourcc, fps, (width, height))

    def write_frame(self, frame):
        self.out.write(frame)

    def end(self):
        self.out.release()

def concatenate(number, imgs, save_path):
    width, height = imgs.shape[-3:-1]
    imgs = imgs.reshape(number, -1, width, height, 3)

    # Stack the `number` videos side by side, frame by frame.
    im_all = []
    for i in range(imgs.shape[1]):
        im = np.concatenate([imgs[k][i] for k in range(number)], axis=1)
        im_all.append(im)

    imageio.mimsave(save_path, [img_as_ubyte(frame) for frame in im_all], fps=25)

def add_audio(video_name=None, audio_dir = None):

    command = 'ffmpeg -i ' + video_name  + ' -i ' + audio_dir + ' -vcodec copy  -acodec copy -y  ' + video_name.replace('.mp4','.mov')
    print (command)
    os.system(command)

def crop_image(source_image):
    
    template = np.load('./M003_template.npy')
    image= cv2.imread(source_image)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    rects = detector(gray, 1)  #detect human face
    if len(rects) != 1:
        return 0
    for (j, rect) in enumerate(rects):
        shape = predictor(gray, rect) #detect 68 points
        shape = shape_to_np(shape)

    pts2 = np.float32(template[:47,:])
    pts1 = np.float32(shape[:47,:]) # jaw, eyebrow, nose and eye landmarks
    # pts1 = np.float32(landmark[17:35,:])
    tform = tf.SimilarityTransform()
    tform.estimate(pts2, pts1)  # estimate the similarity transform mapping template points onto the detected points
  
    dst = tf.warp(image, tform, output_shape=(256, 256))

    dst = np.array(dst * 255, dtype=np.uint8)
    return dst 

def smooth_pose(pose_file, pose_long):
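    # Re-target a long driving pose sequence onto a single start pose: the pose in
    # `pose_file` gives the starting head pose, and the frame-wise offsets of the
    # sequence in `pose_long` (relative to its first frame) are added on top of it.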
    start = np.load(pose_file)
    video_pose = np.load(pose_long)
    delta = video_pose - video_pose[0,:]
    print(len(delta))
    
    pose = np.repeat(start,len(delta),axis = 0)
    all_pose =  pose + delta

    return all_pose

def test(opt, name):

    all_pose = np.load(opt.pose_file).reshape(-1,7)
    if opt.pose_long:

        all_pose = smooth_pose(opt.pose_file,opt.pose_given)

    
   # source_image = img_as_float32(io.imread(opt.source_image))
    source_image = img_as_float32(crop_image(opt.source_image))
    source_image = resize(source_image, (256, 256))[..., :3]
  
    reader = imageio.get_reader(opt.driving_video)
    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        pass
    reader.close()

   
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
    driving_video = get_aligned_image(driving_video, opt)
    transformed_video = get_transformed_image(driving_video, opt)
    transformed_video = np.array(transformed_video)

    generator, kp_detector,kp_detector_a, audio_feature, emo_detector = load_checkpoints(opt=opt, checkpoint_path=opt.checkpoint, audio_checkpoint_path=opt.audio_checkpoint, emo_checkpoint_path = opt.emo_checkpoint, cpu=opt.cpu)
 
    deco_out = test_auido(source_image, audio_feature, all_pose, opt)
    if len(driving_video) < len(deco_out[0]):
        driving_video = np.resize(driving_video,(len(deco_out[0]),256,256,3))
        transformed_video = np.resize(transformed_video,(len(deco_out[0]),256,256,3))

    else:
        driving_video = driving_video[:len(deco_out[0])]
    opt.add_emo = False
    predictions, _ = make_animation_smooth(source_image, driving_video, transformed_video, deco_out, opt.kp_loss, generator, kp_detector, kp_detector_a, emo_detector, opt, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
  
    imageio.mimsave(os.path.join(opt.result_path,'neutral.mp4'), [img_as_ubyte(frame) for frame in predictions], fps=fps)
    predictions = np.array(predictions)
     
    opt.add_emo = True
  
    predictions1,_ = make_animation_smooth(source_image, driving_video, transformed_video, deco_out, opt.kp_loss, generator, kp_detector, kp_detector_a, emo_detector, opt, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
  
    imageio.mimsave(os.path.join(opt.result_path,'emotion.mp4'), [img_as_ubyte(frame) for frame in predictions1], fps=fps)
    add_audio(os.path.join(opt.result_path,'emotion.mp4'),opt.in_file)
    predictions1 = np.array(predictions1)
    all_imgs = np.concatenate((driving_video,predictions,predictions1),axis = 0)
    save_path = os.path.join(opt.result_path, 'all.mp4')
    concatenate(3, all_imgs, save_path)
    add_audio(save_path,opt.in_file)



if __name__ == "__main__":
   
    
   
    parser = ArgumentParser()
    parser.add_argument("--config", default ='config/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml', help="path to config")#required=True default ='config/vox-256.yaml'
 
    parser.add_argument("--audio_checkpoint", default='log/1-6000.pth.tar', help="path to checkpoint to restore")
    parser.add_argument("--checkpoint", default='log/124_52000.pth.tar', help="path to checkpoint to restore")
   # parser.add_argument("--emo_checkpoint", default='ablation/ablation/ten/10-6000.pth.tar', help="path to checkpoint to restore")
    parser.add_argument("--emo_checkpoint", default='log/5-3000.pth.tar', help="path to checkpoint to restore")

    parser.add_argument("--source_image", default='test/image/21.png', help="path to source image")
 
    parser.add_argument("--driving_video", default='test/video/disgusted.mp4', help="path to driving video")#data/M030/video/M030_angry_
    parser.add_argument('--in_file', type=str, default='test/audio/sample1.mov')
    parser.add_argument('--pose_file', type=str, default='test/pose/21.npy')
    parser.add_argument('--pose_given', type=str, default='test/pose_long/0zn70Ak8lRc_Daniel_Auteuil_0zn70Ak8lRc_0002.npy')

    parser.add_argument("--result_path", default='result/', help="path to output")#'/media/thea/新加卷/fomm/Exp/'+emotion+'.mp4'

    parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates")
    parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")

    parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
    parser.add_argument("--kp_loss", default=0, help="keypoint loss.")

    parser.add_argument("--smooth_pose",  default=True, help="cpu mode.")
    parser.add_argument("--pose_long",  default=False, help="use given long poses.")
    parser.add_argument("--weight",  default=0, help="cpu mode.")
    parser.add_argument("--add_emo",  default=False, help="add emotion.")
    parser.add_argument("--check_add",  default=False, help="check emotion displacement.")
    parser.add_argument("--type",  default='linear_3', help="add emotion type.")
    parser.add_argument("--emotion",  default='disgusted', help="emotion category, 'angry', 'contempt','disgusted','fear','happy','neutral','sad','surprised'.")
    parser.set_defaults(relative=False)
    parser.set_defaults(adapt_scale=False)

    opt = parser.parse_args()
 #   opt.cpu = True
   
    test(opt,'test')
         
    

================================================
FILE: filter1.py
================================================
import cv2
#import pickle
import time
import numpy as np
import copy

from matplotlib import pyplot as plt
from tqdm import tqdm




class LowPassFilter:
  def __init__(self):
    self.prev_raw_value = None
    self.prev_filtered_value = None

  def process(self, value, alpha):
    if self.prev_raw_value is None:
      s = value
    else:
      s = alpha * value + (1.0 - alpha) * self.prev_filtered_value
    self.prev_raw_value = value
    self.prev_filtered_value = s
    return s


class OneEuroFilter:
  def __init__(self, mincutoff=1.0, beta=0.0, dcutoff=1.0, freq=30):
    self.freq = freq
    self.mincutoff = mincutoff
    self.beta = beta
    self.dcutoff = dcutoff
    self.x_filter = LowPassFilter()
    self.dx_filter = LowPassFilter()

  def compute_alpha(self, cutoff):
    te = 1.0 / self.freq
    tau = 1.0 / (2 * np.pi * cutoff)
    return 1.0 / (1.0 + tau / te)

  def process(self, x):
    prev_x = self.x_filter.prev_raw_value
    dx = 0.0 if prev_x is None else (x - prev_x) * self.freq
    edx = self.dx_filter.process(dx, self.compute_alpha(self.dcutoff))
    cutoff = self.mincutoff + self.beta * np.abs(edx)
    return self.x_filter.process(x, self.compute_alpha(cutoff))
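

# A minimal usage sketch (illustrative only, not used by the training or demo
# scripts): smoothing a noisy synthetic 1-D signal sample-by-sample, the same
# way per-frame poses and keypoints are smoothed elsewhere in this repository.
# The filter parameters below are example values.
if __name__ == '__main__':
    t = np.linspace(0, 2 * np.pi, 200)
    noisy = np.sin(t) + 0.1 * np.random.randn(200)

    one_euro = OneEuroFilter(mincutoff=1.0, beta=0.5, dcutoff=1.0, freq=100)
    smoothed = np.array([one_euro.process(x) for x in noisy])

    plt.plot(t, noisy, label='noisy')
    plt.plot(t, smoothed, label='one-euro filtered')
    plt.legend()
    plt.show()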



================================================
FILE: frames_dataset.py
================================================
import os
from skimage import io, img_as_float32, transform
from skimage.color import gray2rgb
from sklearn.model_selection import train_test_split
from imageio import mimread

import numpy as np
from torch.utils.data import Dataset
import pandas as pd
from augmentation import AllAugmentationTransform
import glob
import pickle
import random
from filter1 import OneEuroFilter
def read_video(name, frame_shape):
    """
    Read a video, which can be:
      - an image of concatenated frames
      - a '.mp4', '.gif' or '.mov' file
      - a folder with frames
    """

    if os.path.isdir(name):
        frames = sorted(os.listdir(name))
        num_frames = len(frames)
        video_array = np.array(
            [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
    elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
        image = io.imread(name)

        if len(image.shape) == 2 or image.shape[2] == 1:
            image = gray2rgb(image)

        if image.shape[2] == 4:
            image = image[..., :3]

        image = img_as_float32(image)

        video_array = np.moveaxis(image, 1, 0)

        video_array = video_array.reshape((-1,) + frame_shape)
        video_array = np.moveaxis(video_array, 1, 2)
    elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
        video = np.array(mimread(name))
        if len(video.shape) == 3:
            video = np.array([gray2rgb(frame) for frame in video])
        if video.shape[-1] == 4:
            video = video[..., :3]
        video_array = img_as_float32(video)
    else:
        raise Exception("Unknown file extensions  %s" % name)

    return video_array

def get_list(ipath,base_name):
#ipath = '/mnt/lustre/share/jixinya/LRW/pose/train_fo/'
    ipath = os.path.join(ipath,base_name)
    name_list = os.listdir(ipath)
    image_path = os.path.join('/mnt/lustre/share/jixinya/LRW/Image/',base_name)
    all = []
    for k in range(len(name_list)):
        name = name_list[k]
        path_ = os.path.join(ipath,name)
        Dir = os.listdir(path_)
        for i in range(len(Dir)):
            word = Dir[i]
            path = os.path.join(path_, word)
            if os.path.exists(os.path.join(image_path,name,word.split('.')[0])):
                all.append(name+'/'+word.split('.')[0])
            #print(k,name,i,word)
    print('get list '+os.path.basename(ipath))
    return all


class AudioDataset(Dataset):
    """
    Dataset of videos; each video can be represented as:
      - an image of concatenated frames
      - '.mp4' or '.gif'
      - folder with all frames
    """

    def __init__(self, name, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
                 random_seed=0, augmentation_params=None):
        self.root_dir = root_dir
        self.audio_dir = os.path.join(root_dir,'MFCC')
        self.image_dir = os.path.join(root_dir,'Image')
        self.pose_dir = os.path.join(root_dir,'pose')
      #  assert len(os.listdir(self.audio_dir)) == len(os.listdir(self.image_dir)), 'audio and image length not equal'

      #  self.videos=np.load('../LRW/list/train_fo.npy')
      #  self.videos = os.listdir(self.landmark_dir)
        self.frame_shape = tuple(frame_shape)
       
        self.id_sampling = id_sampling

        if os.path.exists(os.path.join(self.pose_dir, 'train_fo')):
            assert os.path.exists(os.path.join(self.pose_dir, 'test_fo'))
            print("Use predefined train-test split.")
            if id_sampling:
                train_videos = {os.path.basename(video).split('#')[0] for video in
                                os.listdir(os.path.join(self.image_dir, 'train'))}
                train_videos = list(train_videos)
            else:
                train_videos =  np.load('../LRW/list/train_fo.npy')# get_list(self.pose_dir, 'train_fo')
         #   df=open('../LRW/list/test_fo.txt','rb')
            test_videos=np.load('../LRW/list/test_fo.npy')
         #   df.close()
         #   test_videos = np.load('../LRW/list/train_fo.npy')
            #get_list(self.pose_dir, 'test_fo')
        #    self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
           
            self.image_dir = os.path.join(self.image_dir, 'train_fo' if is_train else 'test_fo')
            self.audio_dir = os.path.join(self.audio_dir, 'train' if is_train else 'test')
            self.pose_dir = os.path.join(self.pose_dir, 'train_fo' if is_train else 'test_fo')
        else:
            print("Use random train-test split.")
            train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)

        if is_train:
            self.videos = train_videos
        else:
            self.videos = test_videos

        self.is_train = is_train

        if self.is_train:
            self.transform = AllAugmentationTransform(**augmentation_params)
        else:
            self.transform = None

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        if self.is_train and self.id_sampling:
            name = self.videos[idx].split('.')[0]
            path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
        else:
            name = self.videos[idx].split('.')[0]
           
            audio_path = os.path.join(self.audio_dir, name)
            pose_path = os.path.join(self.pose_dir,name)
            path = os.path.join(self.image_dir, name)

        video_name = os.path.basename(path)

        if  os.path.isdir(path):
     #   if self.is_train and os.path.isdir(path):
         
            # mfcc loading
            r = random.choice([x for x in range(3, 8)])
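            # frame r serves as the identity (example) image; the 16 frames
            # r+1 .. r+16 form the training clip, each paired with its MFCC window
            # and head pose below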

            example_image = img_as_float32(io.imread(os.path.join(path, str(r)+'.png')))

            mfccs = []
            for ind in range(1, 17):
              #  t_mfcc = mfcc[(r + ind - 3) * 4: (r + ind + 4) * 4, 1:]
                t_mfcc = np.load(os.path.join(audio_path,str(r + ind)+'.npy'),allow_pickle=True)[:, 1:]
                mfccs.append(t_mfcc)
            mfccs = np.array(mfccs)
            
            poses = []
            video_array = []
            for ind in range(1, 17):
              
                t_pose = np.load(os.path.join(self.pose_dir,name+'.npy'))[r+ind,:-1]
                
                poses.append(t_pose)
                image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png')))
                video_array.append(image)
            poses = np.array(poses)
            video_array = np.array(video_array)

        else:
            print('Wrong, data path is not an existing directory.')

        if self.transform is not None:
            video_array = self.transform(video_array)

        out = {}
     
        driving = np.array(video_array, dtype='float32')
        spatial_size = np.array(driving.shape[1:3][::-1])[np.newaxis]
        driving_pose = np.array(poses, dtype='float32')
        example_image = np.array(example_image, dtype='float32')

        out['example_image'] = example_image.transpose((2, 0, 1))
        out['driving_pose'] = driving_pose
        out['driving'] = driving.transpose((0, 3, 1, 2))
        out['driving_audio'] = np.array(mfccs, dtype='float32')
    #    out['name'] = video_name

        return out

class VoxDataset(Dataset):
    """
    Dataset of videos; each video can be represented as:
      - an image of concatenated frames
      - '.mp4' or '.gif'
      - folder with all frames
    """

    def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
                 random_seed=0, pairs_list=None, augmentation_params=None):
        self.root_dir = root_dir
        self.audio_dir = os.path.join(root_dir,'MFCC')
        self.image_dir = os.path.join(root_dir,'align_img')

        self.pose_dir = os.path.join(root_dir,'align_pose')
      #  assert len(os.listdir(self.audio_dir)) == len(os.listdir(self.image_dir)), 'audio and image length not equal'


     #   df=open('../LRW/list/test_fo.txt','rb')
     #  self.videos=pickle.load(df)
     #   df.close()
        self.videos=np.load('/mnt/lustre/share_data/jixinya/VoxCeleb1_Cut/right.npy')
      #  self.videos = os.listdir(self.landmark_dir)
        self.frame_shape = tuple(frame_shape)
        self.pairs_list = pairs_list
        self.id_sampling = id_sampling

        if os.path.exists(os.path.join(self.pose_dir, 'train_fo')):
            assert os.path.exists(os.path.join(self.pose_dir, 'test_fo'))
            print("Use predefined train-test split.")
            if id_sampling:
                train_videos = {os.path.basename(video).split('#')[0] for video in
                                os.listdir(os.path.join(self.image_dir, 'train'))}
                train_videos = list(train_videos)
            else:
                train_videos = np.load('/mnt/lustre/share_data/jixinya/VoxCeleb1_Cut/right.npy')# get_list(self.pose_dir, 'train_fo')
      
            self.image_dir = os.path.join(self.image_dir, 'train_fo' if is_train else 'test_fo')
            self.audio_dir = os.path.join(self.audio_dir, 'train' if is_train else 'test')
            self.pose_dir = os.path.join(self.pose_dir, 'train_fo' if is_train else 'test_fo')
        else:
            print("Use random train-test split.")
            train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)

        if is_train:
            self.videos = train_videos
        else:
            self.videos = test_videos

        self.is_train = is_train

        if self.is_train:
            self.transform = AllAugmentationTransform(**augmentation_params)
        else:
            self.transform = None

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        if self.is_train and self.id_sampling:
            name = self.videos[idx].split('.')[0]
            path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
        else:
            name = self.videos[idx].split('.')[0]

            audio_path = os.path.join(self.audio_dir, name+'.npy')
            pose_path = os.path.join(self.pose_dir,name+'.npy')
            path = os.path.join(self.image_dir, name)

        video_name = os.path.basename(path)

        if  os.path.isdir(path):
     #   if self.is_train and os.path.isdir(path):
            frames = os.listdir(path)
            num_frames = len(frames)
            frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
            video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]
            mfcc = np.load(audio_path)
            pose = np.load(pose_path)

          #  print(audio_path,pose_path,len(mfcc))

            if len(mfcc) <= 16:
                print('wrong mfcc len:', audio_path)
            if 16 < len(mfcc) < 24 :
                r = 0
            else:

                r = random.choice([x for x in range(3, len(mfcc)-20)])

            mfccs = []
            poses = []
            video_array = []
            for ind in range(1, 17):
                t_mfcc = mfcc[r+ind][:, 1:]
                mfccs.append(t_mfcc)
                t_pose = pose[r+ind,:-1]
                poses.append(t_pose)
                image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png')))
                video_array.append(image)
            mfccs = np.array(mfccs)
            poses = np.array(poses)
            video_array = np.array(video_array)

            example_image = img_as_float32(io.imread(os.path.join(path, str(r)+'.png')))


        else:
            print('Wrong, data path is not an existing directory.')

        if self.transform is not None:
            video_array = self.transform(video_array)

        out = {}

        driving = np.array(video_array, dtype='float32')

        spatial_size = np.array(driving.shape[1:3][::-1])[np.newaxis]
        driving_pose = np.array(poses, dtype='float32')
        example_image = np.array(example_image, dtype='float32')
        out['example_image'] = example_image.transpose((2, 0, 1))
        out['driving_pose'] = driving_pose
        out['driving'] = driving.transpose((0, 3, 1, 2))

        out['driving_audio'] = np.array(mfccs, dtype='float32')
    #    out['name'] = video_name

        return out

class MeadDataset(Dataset):
    """
    Dataset of videos; each video can be represented as:
      - an image of concatenated frames
      - '.mp4' or '.gif'
      - folder with all frames
    """

    def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
                 random_seed=0, augmentation_params=None):
        self.root_dir = root_dir

        self.audio_dir = os.path.join(root_dir,'MEAD_MFCC')
        self.image_dir = os.path.join(root_dir,'MEAD_fomm_crop')

        self.pose_dir = os.path.join(root_dir,'MEAD_fomm_pose_crop')

        self.videos = np.load('/mnt/lustre/share_data/jixinya/MEAD/MEAD_fomm_audio_less_crop.npy')
        self.dict = np.load('/mnt/lustre/share_data/jixinya/MEAD/MEAD_fomm_neu_dic_crop.npy',allow_pickle=True).item()
       # self.videos = os.listdir(root_dir)
        self.frame_shape = tuple(frame_shape)

        self.id_sampling = id_sampling
        if os.path.exists(os.path.join(root_dir, 'train')):
            assert os.path.exists(os.path.join(root_dir, 'test'))
            print("Use predefined train-test split.")
            if id_sampling:
                train_videos = {os.path.basename(video).split('#')[0] for video in
                                os.listdir(os.path.join(root_dir, 'train'))}
                train_videos = list(train_videos)
            else:
                train_videos = os.listdir(os.path.join(root_dir, 'train'))
            test_videos = os.listdir(os.path.join(root_dir, 'test'))
            self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
        else:
            print("Use random train-test split.")
            train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)

        if is_train:
            self.videos = train_videos
        else:
            self.videos = test_videos

        self.is_train = is_train

        if self.is_train:
            self.transform = AllAugmentationTransform(**augmentation_params)
        else:
            self.transform = None

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        if self.is_train and self.id_sampling:
            name = self.videos[idx]
            path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
        else:
            name = self.videos[idx]
            path = os.path.join(self.image_dir, name)

            video_name = os.path.basename(path)
            id_name = path.split('/')[-2]
            neu_list = self.dict[id_name]
            neu_path = os.path.join(self.image_dir, np.random.choice(neu_list))

            audio_path = os.path.join(self.audio_dir, name+'.npy')
            pose_path = os.path.join(self.pose_dir,name+'.npy')


        if self.is_train and os.path.isdir(path):

            mfcc = np.load(audio_path)
            pose_raw = np.load(pose_path)
            one_euro_filter = OneEuroFilter(mincutoff=0.01, beta=0.7, dcutoff=1.0, freq=100)
            pose = np.zeros((len(pose_raw),7))

            for j in range(len(pose_raw)):
                pose[j]=one_euro_filter.process(pose_raw[j])
          #  print(audio_path,pose_path,len(mfcc))

            neu_frames = os.listdir(neu_path)
            num_neu_frames = len(neu_frames)
            frame_idx = np.random.choice(num_neu_frames)
            example_image = img_as_float32(io.imread(os.path.join(neu_path, neu_frames[frame_idx])))
            if len(mfcc) <= 16:
                print('wrong mfcc len:', audio_path)
            if 16 < len(mfcc) < 24 :
                r = 0
            else:

                r = random.choice([x for x in range(3, len(mfcc)-20)])

            mfccs = []
            poses = []
            video_array = []
            for ind in range(1, 17):
                t_mfcc = mfcc[r+ind][:, 1:]
                mfccs.append(t_mfcc)
                t_pose = pose[r+ind,:-1]
                poses.append(t_pose)
                image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png')))
                video_array.append(image)
            mfccs = np.array(mfccs)
            poses = np.array(poses)
            video_array = np.array(video_array)

        else:
            print('Wrong, data path is not an existing directory.')

        if self.transform is not None:
            video_array = self.transform(video_array)

        out = {}
        if self.is_train:
      
            driving = np.array(video_array, dtype='float32')
            driving_pose = np.array(poses, dtype='float32')
            example_image = np.array(example_image, dtype='float32')


            out['example_image'] = example_image.transpose((2, 0, 1))
            out['driving_pose'] = driving_pose
            out['driving'] = driving.transpose((0, 3, 1, 2))
            out['driving_audio'] = np.array(mfccs, dtype='float32')

      #  out['name'] = id_name+'/'+video_name

        return out


class DatasetRepeater(Dataset):
    """
    Pass several times over the same dataset for better i/o performance
    """

    def __init__(self, dataset, num_repeats=100):
        self.dataset = dataset
    #    self.dataset2 = dataset2
        self.num_repeats = num_repeats

    def __len__(self):
        return self.num_repeats * self.dataset.__len__()

    def __getitem__(self, idx):
     #   if idx % 5 == 0:
     #       return self.dataset2[idx % self.dataset2.__len__()]#% self.dataset.__len__()
     #   else:
     #       return self.dataset[idx % self.dataset.__len__()]
        return self.dataset[idx % self.dataset.__len__()]

class TestsetRepeater(Dataset):
    """
    Pass several times over the same dataset for better i/o performance
    """

    def __init__(self, dataset, num_repeats=100):
        self.dataset = dataset

        self.num_repeats = num_repeats

    def __len__(self):
        return self.num_repeats * self.dataset.__len__()

    def __getitem__(self, idx):

        return self.dataset[idx % self.dataset.__len__()]#% self.dataset.__len__()


class PairedDataset(Dataset):
    """
    Dataset of pairs for animation.
    """

    def __init__(self, initial_dataset, number_of_pairs, seed=0):
        self.initial_dataset = initial_dataset
        pairs_list = self.initial_dataset.pairs_list

        np.random.seed(seed)

        if pairs_list is None:
            max_idx = min(number_of_pairs, len(initial_dataset))
            nx, ny = max_idx, max_idx
            xy = np.mgrid[:nx, :ny].reshape(2, -1).T
            number_of_pairs = min(xy.shape[0], number_of_pairs)
            self.pairs = xy.take(np.random.choice(xy.shape[0], number_of_pairs, replace=False), axis=0)
        else:
            videos = self.initial_dataset.videos
            name_to_index = {name: index for index, name in enumerate(videos)}
            pairs = pd.read_csv(pairs_list)
            pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))]

            number_of_pairs = min(pairs.shape[0], number_of_pairs)
            self.pairs = []
            self.start_frames = []
            for ind in range(number_of_pairs):
                self.pairs.append(
                    (name_to_index[pairs['driving'].iloc[ind]], name_to_index[pairs['source'].iloc[ind]]))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        pair = self.pairs[idx]
        first = self.initial_dataset[pair[0]]
        second = self.initial_dataset[pair[1]]
        first = {'driving_' + key: value for key, value in first.items()}
        second = {'source_' + key: value for key, value in second.items()}

        return {**first, **second}


================================================
FILE: logger.py
================================================
import numpy as np
import torch
import torch.nn.functional as F
import imageio

import os
from skimage.draw import circle

import matplotlib.pyplot as plt
import collections


class Logger:
    def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):

        self.loss_list = []
        self.cpk_dir = log_dir
        self.visualizations_dir = os.path.join(log_dir, 'train-vis')
        if not os.path.exists(self.visualizations_dir):
            os.makedirs(self.visualizations_dir)
        self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
        self.zfill_num = zfill_num
        self.visualizer = Visualizer(**visualizer_params)
        self.checkpoint_freq = checkpoint_freq
        self.epoch = 0
        self.best_loss = float('inf')
        self.names = None

    def log_scores(self, loss_names):
        loss_mean = np.array(self.loss_list).mean(axis=0)

        loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
        loss_string = str(str(self.epoch)+str(self.step).zfill(self.zfill_num)) + ") " + loss_string

        print(loss_string, file=self.log_file)
        self.loss_list = []
        self.log_file.flush()

    def visualize_rec(self, inp, out):
      #  image = self.visualizer.visualize(inp['driving'], inp['source'], out)
        image = self.visualizer.visualize(inp['driving'][:,-1], inp['transformed_driving'][:,-1], inp['example_image'], out)
        imageio.imsave(os.path.join(self.visualizations_dir, "%s-%s-rec.png" % (str(self.epoch),str(self.step).zfill(self.zfill_num))), image)

    def save_cpk(self, emergent=False):
        cpk = {k: v.state_dict() for k, v in self.models.items()}
        cpk['epoch'] = self.epoch
        cpk['step'] = self.step
        cpk_path = os.path.join(self.cpk_dir, '%s-%s-checkpoint.pth.tar' % (str(self.epoch),str(self.step).zfill(self.zfill_num)))
        if not (os.path.exists(cpk_path) and emergent):
            torch.save(cpk, cpk_path)

    @staticmethod
    def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_detector=None, audio_feature=None,
                 optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None, optimizer_audio_feature = None):
        checkpoint = torch.load(checkpoint_path)
        if generator is not None:
            generator.load_state_dict(checkpoint['generator'])
        if kp_detector is not None:
            kp_detector.load_state_dict(checkpoint['kp_detector'])
        if discriminator is not None:
            try:
               discriminator.load_state_dict(checkpoint['discriminator'])
            except:
               print ('No discriminator in the state-dict. Discriminator will be randomly initialized')
    #    if audio_feature is not None:
    #        audio_feature.load_state_dict(checkpoint['audio_feature'])
        if optimizer_generator is not None:
            optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
        if optimizer_discriminator is not None:
            try:
                optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
            except RuntimeError as e:
                print ('No discriminator optimizer in the state-dict. Optimizer will not be initialized')
        if optimizer_kp_detector is not None:
            optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
  #      if optimizer_audio_feature is not None:
  #          a = checkpoint['optimizer_kp_detector']['param_groups']
  #          a[0].pop('params')
  #          optimizer_audio_feature.load_state_dict(checkpoint['optimizer_audio_feature'])

        return checkpoint['epoch']

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if 'models' in self.__dict__:
            self.save_cpk()
        self.log_file.close()

    def log_iter(self, losses):
        losses = collections.OrderedDict(losses.items())
        if self.names is None:
            self.names = list(losses.keys())
        self.loss_list.append(list(losses.values()))

    def log_epoch(self, epoch, step, models, inp, out):
        self.epoch = epoch
        self.step = step
        self.models = models
        if (self.epoch + 1) % self.checkpoint_freq == 0:
            self.save_cpk()
        self.log_scores(self.names)
        self.visualize_rec(inp, out)


class Visualizer:
    def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
        self.kp_size = kp_size
        self.draw_border = draw_border
        self.colormap = plt.get_cmap(colormap)

    def draw_image_with_kp(self, image, kp_array):
        image = np.copy(image)
        spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
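        # keypoints are predicted in [-1, 1] normalized coordinates; map them to pixel coordinates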
        kp_array = spatial_size * (kp_array + 1) / 2
        num_kp = kp_array.shape[0]
        for kp_ind, kp in enumerate(kp_array):
            rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
            image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
        return image

    def create_image_column_with_kp(self, images, kp):
        image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
        return self.create_image_column(image_array)

    def create_image_column(self, images):
        if self.draw_border:
            images = np.copy(images)
            images[:, :, [0, -1]] = (1, 1, 1)
            images[:, :, [0, -1]] = (1, 1, 1)
        return np.concatenate(list(images), axis=0)

    def create_image_grid(self, *args):
        out = []
        for arg in args:
            if type(arg) == tuple:
                out.append(self.create_image_column_with_kp(arg[0], arg[1]))
            else:
                out.append(self.create_image_column(arg))
        return np.concatenate(out, axis=1)

    def visualize(self, driving, transformed_driving, source, out):
        images = []

        # Source image with keypoints
        source = source.data.cpu()
        kp_source = out['kp_source']['value'].data.cpu().numpy()
        source = np.transpose(source, [0, 2, 3, 1])
        images.append((source, kp_source))

        # Equivariance visualization
        if 'transformed_frame' in out:
            transformed = out['transformed_frame'].data.cpu().numpy()
            transformed = np.transpose(transformed, [0, 2, 3, 1])
            transformed_kp = out['transformed_kp']['value'].data.cpu().numpy()
            images.append((transformed, transformed_kp))

        # Equivariance visualization
        transformed_driving = transformed_driving.data.cpu().numpy()
        transformed_driving = np.transpose(transformed_driving, [0, 2, 3, 1])
        images.append(transformed_driving)

        # Driving image with keypoints
        kp_driving = out['kp_driving'][-1]['value'].data.cpu().numpy() #[-1]['value']
        driving = driving.data.cpu().numpy()
        driving = np.transpose(driving, [0, 2, 3, 1])
        images.append((driving, kp_driving))

        # Deformed image
        if 'deformed' in out:
            deformed = out['deformed'].data.cpu().numpy()
            deformed = np.transpose(deformed, [0, 2, 3, 1])
            images.append(deformed)

        # Result with and without keypoints
        prediction = out['prediction'].data.cpu().numpy()
        prediction = np.transpose(prediction, [0, 2, 3, 1])
        if 'kp_norm' in out:
            kp_norm = out['kp_norm']['value'].data.cpu().numpy()
            images.append((prediction, kp_norm))
        images.append(prediction)


        ## Occlusion map
        if 'occlusion_map' in out:
            occlusion_map = out['occlusion_map'].data.cpu().repeat(1, 3, 1, 1)
            occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
            occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
            images.append(occlusion_map)

        # Deformed images according to each individual transform
        if 'sparse_deformed' in out:
            full_mask = []
            for i in range(out['sparse_deformed'].shape[1]):
                image = out['sparse_deformed'][:, i].data.cpu()
                image = F.interpolate(image, size=source.shape[1:3])
                mask = out['mask'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)
                mask = F.interpolate(mask, size=source.shape[1:3])
                image = np.transpose(image.numpy(), (0, 2, 3, 1))
                mask = np.transpose(mask.numpy(), (0, 2, 3, 1))

                if i != 0:
                    color = np.array(self.colormap((i - 1) / (out['sparse_deformed'].shape[1] - 1)))[:3]
                else:
                    color = np.array((0, 0, 0))

                color = color.reshape((1, 1, 1, 3))

                images.append(image)
                if i != 0:
                    images.append(mask * color)
                else:
                    images.append(mask)

                full_mask.append(mask * color)

            images.append(sum(full_mask))

        image = self.create_image_grid(*images)
        image = (255 * image).astype(np.uint8)
        return image


================================================
FILE: modules/dense_motion.py
================================================
from torch import nn
import torch.nn.functional as F
import torch
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian


class DenseMotionNetwork(nn.Module):
    """
    Module that predicts a dense motion field from the sparse motion representation given by kp_source and kp_driving
    """

    def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False,
                 scale_factor=1, kp_variance=0.01):
        super(DenseMotionNetwork, self).__init__()
        self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1),
                                   max_features=max_features, num_blocks=num_blocks)

        self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3))

        if estimate_occlusion_map:
            self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
        else:
            self.occlusion = None

        self.num_kp = num_kp
        self.scale_factor = scale_factor
        self.kp_variance = kp_variance

        if self.scale_factor != 1:
            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)

    def create_heatmap_representations(self, source_image, kp_driving, kp_source):
        """
        Eq 6. in the paper H_k(z)
        """
        spatial_size = source_image.shape[2:]
        gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
        gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
        heatmap = gaussian_driving - gaussian_source #[4,10,H,W]

        #adding background feature
        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type())
        heatmap = torch.cat([zeros, heatmap], dim=1)
        heatmap = heatmap.unsqueeze(2) #[4,11,1,h,w]
        return heatmap

    def create_sparse_motions(self, source_image, kp_driving, kp_source):
        """
        Eq 4. in the paper T_{s<-d}(z)
        """
        bs, _, h, w = source_image.shape
        identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
        identity_grid = identity_grid.view(1, 1, h, w, 2)
        coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2) #[4,10,64,64,2]
        if 'jacobian' in kp_driving:
            jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
            jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
            jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
            coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
            coordinate_grid = coordinate_grid.squeeze(-1)

        driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2)

        #adding background feature
        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)
        return sparse_motions

    def create_deformed_source_image(self, source_image, sparse_motions):
        """
        Eq 7. in the paper \hat{T}_{s<-d}(z)
        """
        bs, _, h, w = source_image.shape
        source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)
        source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)
        sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))
        sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
        sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))
        return sparse_deformed

    def forward(self, source_image, kp_driving, kp_source):
        if self.scale_factor != 1:
            source_image = self.down(source_image) #[4,3,H*scale,W*scale]

        bs, _, h, w = source_image.shape

        out_dict = dict()
        heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source) #[4,11,1,64,64]
        sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source) #[4,11,64,64,2]
        deformed_source = self.create_deformed_source_image(source_image, sparse_motion) #[4,11,3,64,64]
        out_dict['sparse_deformed'] = deformed_source

        input = torch.cat([heatmap_representation, deformed_source], dim=2)
        input = input.view(bs, -1, h, w) #[4,11*4,64,64]

        prediction = self.hourglass(input) #[4,108,64,64]

        mask = self.mask(prediction)
        mask = F.softmax(mask, dim=1) #[4,11,64,64]
        out_dict['mask'] = mask
        mask = mask.unsqueeze(2)
        sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)
        deformation = (sparse_motion * mask).sum(dim=1)
        deformation = deformation.permute(0, 2, 3, 1) #[4,64,64,2]

        out_dict['deformation'] = deformation

        # Sec. 3.2 in the paper
        if self.occlusion:
            occlusion_map = torch.sigmoid(self.occlusion(prediction))
            out_dict['occlusion_map'] = occlusion_map #[4,1,64,64]

        return out_dict


================================================
FILE: modules/discriminator.py
================================================
from torch import nn
import torch.nn.functional as F
from modules.util import kp2gaussian
import torch


class DownBlock2d(nn.Module):
    """
    Simple block for processing video (encoder).
    """

    def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
        super(DownBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)

        if sn:
            self.conv = nn.utils.spectral_norm(self.conv)

        if norm:
            self.norm = nn.InstanceNorm2d(out_features, affine=True)
        else:
            self.norm = None
        self.pool = pool

    def forward(self, x):
        out = x
        out = self.conv(out)
        if self.norm:
            out = self.norm(out)
        out = F.leaky_relu(out, 0.2)
        if self.pool:
            out = F.avg_pool2d(out, (2, 2))
        return out


class Discriminator(nn.Module):
    """
    Discriminator similar to Pix2Pix
    """

    def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
                 sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs):
        super(Discriminator, self).__init__()

        down_blocks = []
        for i in range(num_blocks):
            down_blocks.append(
                DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)),
                            min(max_features, block_expansion * (2 ** (i + 1))),
                            norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))

        self.down_blocks = nn.ModuleList(down_blocks)
        self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
        if sn:
            self.conv = nn.utils.spectral_norm(self.conv)
        self.use_kp = use_kp
        self.kp_variance = kp_variance

    def forward(self, x, kp=None):
        feature_maps = []
        out = x
        if self.use_kp:
            heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)
            out = torch.cat([out, heatmap], dim=1)

        for down_block in self.down_blocks:
            feature_maps.append(down_block(out))
            out = feature_maps[-1]
        prediction_map = self.conv(out)

        return feature_maps, prediction_map


class MultiScaleDiscriminator(nn.Module):
    """
    Multi-scale discriminator
    """

    def __init__(self, scales=(), **kwargs):
        super(MultiScaleDiscriminator, self).__init__()
        self.scales = scales
        discs = {}
        for scale in scales:
            discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
        self.discs = nn.ModuleDict(discs)

    def forward(self, x, kp=None):
        out_dict = {}
        for scale, disc in self.discs.items():
            scale = str(scale).replace('-', '.')
            key = 'prediction_' + scale
            feature_maps, prediction_map = disc(x[key], kp)
            out_dict['feature_maps_' + scale] = feature_maps
            out_dict['prediction_map_' + scale] = prediction_map
        return out_dict


================================================
FILE: modules/function.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 30 17:45:24 2021

@author: SENSETIME\jixinya1
"""

import torch


def calc_mean_std(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat, style_feat):
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    normalized_feat = (content_feat - content_mean.expand(
        size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)


def _calc_feat_flatten_mean_std(feat):
    # takes 3D feat (C, H, W), return mean and std of array within channels
    assert (feat.size()[0] == 3)
    assert (isinstance(feat, torch.FloatTensor))
    feat_flatten = feat.view(3, -1)
    mean = feat_flatten.mean(dim=-1, keepdim=True)
    std = feat_flatten.std(dim=-1, keepdim=True)
    return feat_flatten, mean, std


def _mat_sqrt(x):
    U, D, V = torch.svd(x)
    return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())


def coral(source, target):
    # assume both source and target are 3D array (C, H, W)
    # Note: flatten -> f

    source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
    source_f_norm = (source_f - source_f_mean.expand_as(
        source_f)) / source_f_std.expand_as(source_f)
    source_f_cov_eye = \
        torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)

    target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
    target_f_norm = (target_f - target_f_mean.expand_as(
        target_f)) / target_f_std.expand_as(target_f)
    target_f_cov_eye = \
        torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)

    source_f_norm_transfer = torch.mm(
        _mat_sqrt(target_f_cov_eye),
        torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
                 source_f_norm)
    )

    source_f_transfer = source_f_norm_transfer * \
                        target_f_std.expand_as(source_f_norm) + \
                        target_f_mean.expand_as(source_f_norm)

    return source_f_transfer.view(source.size())

================================================
FILE: modules/generator.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
from modules.dense_motion import DenseMotionNetwork


class OcclusionAwareGenerator(nn.Module):
    """
    Generator that, given a source image and keypoints, tries to transform the image according to the movement
    trajectories induced by the keypoints. The generator follows the Johnson architecture.
    """

    def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
                 num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
        super(OcclusionAwareGenerator, self).__init__()

        if dense_motion_params is not None:
            self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
                                                           estimate_occlusion_map=estimate_occlusion_map,
                                                           **dense_motion_params)
        else:
            self.dense_motion_network = None

        self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))

        down_blocks = []
        for i in range(num_down_blocks):
            in_features = min(max_features, block_expansion * (2 ** i))
            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
        self.down_blocks = nn.ModuleList(down_blocks)

        up_blocks = []
        for i in range(num_down_blocks):
            in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))
            out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))
            up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
        self.up_blocks = nn.ModuleList(up_blocks)

        self.bottleneck = torch.nn.Sequential()
        in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
        for i in range(num_bottleneck_blocks):
            self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))

        self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
        self.estimate_occlusion_map = estimate_occlusion_map
        self.num_channels = num_channels

    def deform_input(self, inp, deformation):
        _, h_old, w_old, _ = deformation.shape
        _, _, h, w = inp.shape
        if h_old != h or w_old != w:
            deformation = deformation.permute(0, 3, 1, 2)
            deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
            deformation = deformation.permute(0, 2, 3, 1)
        return F.grid_sample(inp, deformation)

    def forward(self, source_image, kp_driving, kp_source):
        # Encoding (downsampling) part
        out = self.first(source_image) #[4,64,H,W]
        for i in range(len(self.down_blocks)):
            out = self.down_blocks[i](out) #[4,256,H/4,W/4]

        # Transforming feature representation according to deformation and occlusion
        output_dict = {}
        if self.dense_motion_network is not None:
            dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
                                                     kp_source=kp_source)
            output_dict['mask'] = dense_motion['mask']
            output_dict['sparse_deformed'] = dense_motion['sparse_deformed']

            if 'occlusion_map' in dense_motion:
                occlusion_map = dense_motion['occlusion_map']
                output_dict['occlusion_map'] = occlusion_map
            else:
                occlusion_map = None
            deformation = dense_motion['deformation']
            out = self.deform_input(out, deformation)

            if occlusion_map is not None:
                if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
                    occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
                out = out * occlusion_map

            output_dict["deformed"] = self.deform_input(source_image, deformation)

        # Decoding part
        out = self.bottleneck(out) #[4,256,64,64]
        for i in range(len(self.up_blocks)):
            out = self.up_blocks[i](out)
        out = self.final(out)
        out = torch.sigmoid(out) #[4,3,256,256]

        output_dict["prediction"] = out

        return output_dict
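

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original file). The
# hyper-parameters below are illustrative assumptions and may differ from the
# config YAMLs; with dense_motion_params=None the generator simply encodes
# and decodes the source image, so no keypoints are needed for the call.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    generator = OcclusionAwareGenerator(num_channels=3, num_kp=10, block_expansion=64,
                                        max_features=512, num_down_blocks=2,
                                        num_bottleneck_blocks=3, dense_motion_params=None)
    source = torch.randn(1, 3, 64, 64)
    out = generator(source, kp_driving=None, kp_source=None)
    print(out['prediction'].shape)  # torch.Size([1, 3, 64, 64])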


================================================
FILE: modules/keypoint_detector.py
================================================
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d, Ct_encoder, EmotionNet, AF2F, AF2F_s, draw_heatmap


class KPDetector(nn.Module):
    """
    Detects keypoints. Returns the keypoint positions and the Jacobian near each keypoint.
    """

    def __init__(self, block_expansion, num_kp, num_channels, max_features,
                 num_blocks, temperature, estimate_jacobian=False, scale_factor=1,
                 single_jacobian_map=False, pad=0):
        super(KPDetector, self).__init__()

        self.predictor = Hourglass(block_expansion, in_features=num_channels,
                                   max_features=max_features, num_blocks=num_blocks)

        self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
                            padding=pad)

        if estimate_jacobian:
            self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
            self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
                                      out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
            self.jacobian.weight.data.zero_()
            self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
        else:
            self.jacobian = None

        self.temperature = temperature
        self.scale_factor = scale_factor
        if self.scale_factor != 1:
            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
        
        
        
        
    def gaussian2kp(self, heatmap):
        """
        Extract the mean keypoint position from a heatmap.
        """
        shape = heatmap.shape
        heatmap = heatmap.unsqueeze(-1) #[4,10,58,58,1]
        grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) #[1,1,58,58,2]
        value = (heatmap * grid).sum(dim=(2, 3)) #[4,10,2]
        kp = {'value': value}

        return kp
    
    def audio_feature(self, x, heatmap):
        
      #  prediction = self.kp(x) #[4,10,H/4-6, W/4-6]

      #  final_shape = prediction.shape
      #  heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]
     #   heatmap = F.softmax(heatmap / self.temperature, dim=2)
     #   heatmap = heatmap.view(*final_shape) #[4,10,58,58]

     #   out = self.gaussian2kp(heatmap)
        final_shape = heatmap.squeeze(2).shape   
     
        if self.jacobian is not None:
            jacobian_map = self.jacobian(x) ##[4,40,H/4-6, W/4-6]
            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
                                                final_shape[3])
            heatmap = heatmap.unsqueeze(2)

            jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]
            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
            jacobian = jacobian.sum(dim=-1) #[4,10,4]
            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2]
            
        return jacobian
    
    def forward(self, x): #torch.Size([4, 3, H, W])
        if self.scale_factor != 1:
            x = self.down(x) # 0.25 [4, 3, H/4, W/4]

        feature_map = self.predictor(x) #[4,3+32,H/4, W/4]
        prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6]

        final_shape = prediction.shape
        
        heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]
        heatmap = F.softmax(heatmap / self.temperature, dim=2)
        heatmap = heatmap.view(*final_shape) #[4,10,58,58]
        
        out = self.gaussian2kp(heatmap)
        out['heatmap'] = heatmap
        
        if self.jacobian is not None:
            jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6]
            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
                                                final_shape[3])
            heatmap = heatmap.unsqueeze(2)

            jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]
            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
            jacobian = jacobian.sum(dim=-1) #[4,10,4]
            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2]
            out['jacobian'] = jacobian

        return out
    
    


class KPDetector_a(nn.Module):
    """
    Detects keypoints from a precomputed feature map. Returns the keypoint positions and the Jacobian near each keypoint.
    """

    def __init__(self, block_expansion, num_kp, num_channels,num_channels_a, max_features,
                 num_blocks, temperature, estimate_jacobian=False, scale_factor=1,
                 single_jacobian_map=False, pad=0):
        super(KPDetector_a, self).__init__()

        self.predictor = Hourglass(block_expansion, in_features=num_channels_a,
                                   max_features=max_features, num_blocks=num_blocks)

        self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
                            padding=pad)

        if estimate_jacobian:
            self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
            self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
                                      out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
            self.jacobian.weight.data.zero_()
            self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
        else:
            self.jacobian = None

        self.temperature = temperature
        self.scale_factor = scale_factor
        if self.scale_factor != 1:
            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
        
        
        
        
    def gaussian2kp(self, heatmap):
        """
        Extract the mean keypoint position from a heatmap.
        """
        shape = heatmap.shape
        heatmap = heatmap.unsqueeze(-1) #[4,10,58,58,1]
        grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) #[1,1,58,58,2]
        value = (heatmap * grid).sum(dim=(2, 3)) #[4,10,2]
        kp = {'value': value}

        return kp
    
    def audio_feature(self, x, heatmap):
        
      #  prediction = self.kp(x) #[4,10,H/4-6, W/4-6]

      #  final_shape = prediction.shape
      #  heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]
     #   heatmap = F.softmax(heatmap / self.temperature, dim=2)
     #   heatmap = heatmap.view(*final_shape) #[4,10,58,58]

     #   out = self.gaussian2kp(heatmap)
        final_shape = heatmap.squeeze(2).shape   
     
        if self.jacobian is not None:
            jacobian_map = self.jacobian(x) ##[4,40,H/4-6, W/4-6]
            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
                                                final_shape[3])
            heatmap = heatmap.unsqueeze(2)

            jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]
            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
            jacobian = jacobian.sum(dim=-1) #[4,10,4]
            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2]
            
        return jacobian
    
    def forward(self,  feature_map): #torch.Size([4, 3, H, W])
       
        prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6]

        final_shape = prediction.shape
        
        heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]
        heatmap = F.softmax(heatmap / self.temperature, dim=2)
        heatmap = heatmap.view(*final_shape) #[4,10,58,58]
        
        out = self.gaussian2kp(heatmap)
        out['heatmap'] = heatmap
        
        if self.jacobian is not None:
            jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6]
            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
                                                final_shape[3])
            heatmap = heatmap.unsqueeze(2)

            jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]
            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
            jacobian = jacobian.sum(dim=-1) #[4,10,4]
            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2]
            out['jacobian'] = jacobian

        return out
  
    
class Audio_Feature(nn.Module):
    def __init__(self):
        super(Audio_Feature, self).__init__()
        
        self.con_encoder = Ct_encoder()
        self.emo_encoder = EmotionNet()
        self.decoder = AF2F_s()

    
    
    def forward(self, x):
        x = x.unsqueeze(1)
      
        c = self.con_encoder(x)
        e = self.emo_encoder(x)
        
     #   d = torch.cat([c, e], dim=1)
        d = self.decoder(c)
        
        
        return d
'''
def forward(self, x, cube, audio): #torch.Size([4, 3, H, W])
        if self.scale_factor != 1:
            x = self.down(x) # 0.25 [4, 3, H/4, W/4]
        
        cube = cube.unsqueeze(1)
        feature = torch.cat([x,cube,audio],dim=1)
        feature_map = self.predictor(feature) #[4,3+32,H/4, W/4]
        prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6]

        final_shape = prediction.shape
        heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]
        heatmap = F.softmax(heatmap / self.temperature, dim=2)
        heatmap = heatmap.view(*final_shape) #[4,10,58,58]

        out = self.gaussian2kp(heatmap)
        out['heatmap'] = heatmap
        if self.jacobian is not None:
            jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6]
            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
                                                final_shape[3])
            heatmap = heatmap.unsqueeze(2)

            jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]
            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
            jacobian = jacobian.sum(dim=-1) #[4,10,4]
            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2]
            out['jacobian'] = jacobian

        return out
'''
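

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The hyper-parameters
# below are illustrative FOMM-style assumptions and may differ from the
# values in the training configs.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    detector = KPDetector(block_expansion=32, num_kp=10, num_channels=3,
                          max_features=1024, num_blocks=5, temperature=0.1,
                          estimate_jacobian=True, scale_factor=0.25)
    image = torch.randn(2, 3, 256, 256)
    kp = detector(image)
    print(kp['value'].shape)     # torch.Size([2, 10, 2])   keypoint coordinates in [-1, 1]
    print(kp['jacobian'].shape)  # torch.Size([2, 10, 2, 2])
    print(kp['heatmap'].shape)   # torch.Size([2, 10, 58, 58]) after the unpadded 7x7 conv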


================================================
FILE: modules/model.py
================================================
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import AntiAliasInterpolation2d, make_coordinate_grid
from torchvision import models
import numpy as np
from torch.autograd import grad


class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss. See Sec 3.3.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class ImagePyramide(torch.nn.Module):
    """
    Create an image pyramid for computing the pyramid perceptual loss. See Sec 3.3.
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        for scale in scales:
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict
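

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how the pyramid
# perceptual loss used by the *FullModel classes below is assembled from
# Vgg19 and ImagePyramide. The scales and per-layer weights here are
# assumptions; the real values come from train_params['scales'] and
# loss_weights['perceptual']. Note that Vgg19() downloads pretrained
# torchvision weights on first use.
# ---------------------------------------------------------------------------
def _demo_pyramid_perceptual_loss():
    scales = [0.5, 0.25]             # assumed pyramid scales for the sketch
    weights = [10, 10, 10, 10, 10]   # assumed per-VGG-layer weights
    vgg = Vgg19().eval()
    pyramid = ImagePyramide(scales, num_channels=3)

    generated = torch.rand(1, 3, 256, 256)  # stand-in for the generator output
    real = torch.rand(1, 3, 256, 256)       # stand-in for the driving frame

    pyramide_generated, pyramide_real = pyramid(generated), pyramid(real)
    total = 0
    for scale in scales:
        x_vgg = vgg(pyramide_generated['prediction_' + str(scale)])
        y_vgg = vgg(pyramide_real['prediction_' + str(scale)])
        for i, weight in enumerate(weights):
            total = total + weight * torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
    return total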


class Transform:
    """
    Random tps transformation for equivariance constraints. See Sec 3.3
    """
    def __init__(self, bs, **kwargs):
        noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
        self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
        self.bs = bs

        if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
            self.tps = True
            self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
            self.control_points = self.control_points.unsqueeze(0)
            self.control_params = torch.normal(mean=0,
                                               std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
        else:
            self.tps = False

    def transform_frame(self, frame):
        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
        return F.grid_sample(frame, grid, padding_mode="reflection")
    
    def inverse_transform_frame(self, frame):
        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        grid = self.inverse_warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
        return F.grid_sample(frame, grid, padding_mode="reflection")
    
    def warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type())
        theta = theta.unsqueeze(1)
        transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
        transformed = transformed.squeeze(-1)

        if self.tps:
            control_points = self.control_points.type(coordinates.type())
            control_params = self.control_params.type(coordinates.type())
            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
            distances = torch.abs(distances).sum(-1)

            result = distances ** 2
            result = result * torch.log(distances + 1e-6)
            result = result * control_params
            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
            transformed = transformed + result

        return transformed

    def inverse_warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type())
        theta = theta.unsqueeze(1)
        a = torch.FloatTensor([[[[0,0,1]]]]).repeat([self.bs,1,1,1]).cuda()
        c = torch.cat((theta,a),2)
        d = c.inverse()[:,:,:2,:]
        d = d.type(coordinates.type())
        transformed = torch.matmul(d[:, :, :, :2], coordinates.unsqueeze(-1)) + d[:, :, :, 2:]
        transformed = transformed.squeeze(-1)
        
        if self.tps:
            control_points = self.control_points.type(coordinates.type())
            control_params = self.control_params.type(coordinates.type())
            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
            distances = torch.abs(distances).sum(-1)

            result = distances ** 2
            result = result * torch.log(distances + 1e-6)
            result = result * control_params
            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
            transformed = transformed + result
        
        
        return transformed

    def jacobian(self, coordinates):
        coordinates.requires_grad=True
        new_coordinates = self.warp_coordinates(coordinates)#[4,10,2]
        grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
        grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
        jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
        return jacobian
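

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how Transform is used
# for the equivariance terms. The sigma/points values are assumptions; the
# real ones come from train_params['transform_params']. Only the CPU-safe
# methods are exercised here (inverse_warp_coordinates calls .cuda()).
# ---------------------------------------------------------------------------
def _demo_equivariance_transform():
    bs = 2
    transform = Transform(bs, sigma_affine=0.05, sigma_tps=0.005, points_tps=5)

    frame = torch.rand(bs, 3, 64, 64)
    warped_frame = transform.transform_frame(frame)     # randomly TPS-warped frames

    keypoints = torch.rand(bs, 10, 2) * 2 - 1           # fake keypoints in [-1, 1]
    warped_kp = transform.warp_coordinates(keypoints)   # where the warp moves them
    jacobian = transform.jacobian(keypoints)            # local 2x2 Jacobian of the warp
    return warped_frame.shape, warped_kp.shape, jacobian.shape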


def detach_kp(kp):
    return {key: value.detach() for key, value in kp.items()}

class TrainPart1Model(torch.nn.Module):
    """
    Merge all generator-related updates into a single model for better multi-GPU usage.
    """

    def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids):
        super(TrainPart1Model, self).__init__()
        self.kp_extractor = kp_extractor
        self.kp_extractor_a = kp_extractor_a

        self.audio_feature = audio_feature
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']

        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()
        
     
        self.mse_loss_fn   =  nn.MSELoss().cuda()
    def forward(self, x):
 
        kp_source = self.kp_extractor(x['example_image'])

        kp_driving = []
        for i in range(16):
            kp_driving.append(self.kp_extractor(x['driving'][:,i]))

        kp_driving_a = [] #x['example_image'],
        deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
        loss_values = {}
        
        if self.loss_weights['audio'] != 0:
            
            kp_driving_a = []
            for i in range(16):
                kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))#
       
   
        loss_value = 0
        loss_heatmap = 0
        loss_jacobian = 0
        loss_perceptual = 0
        for i in range(len(kp_driving)):
            loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian']).mean())*self.loss_weights['audio']
            
         #   loss_jacobian = loss_jacobian*self.loss_weights['audio']
            loss_heatmap += (torch.abs(kp_driving[i]['heatmap'] - kp_driving_a[i]['heatmap']).mean())*self.loss_weights['audio']*100
           
            
            loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']).mean())*self.loss_weights['audio']
           
        loss_values['loss_value'] = loss_value/len(kp_driving)
        loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving)
        loss_values['loss_jacobian'] = loss_jacobian/len(kp_driving)

   
        if self.train_params['generator'] == 'not':
     #       loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out)
            for i in range(1): #0,len(kp_driving),4
 
                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a})
        elif self.train_params['generator'] == 'visual':
            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4
 
                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
                
                pyramide_real = self.pyramid(x['driving'][:,i])
                pyramide_generated = self.pyramid(generated['prediction'])
        
                if sum(self.loss_weights['perceptual']) != 0:
                    value_total = 0
                    for scale in self.scales:
                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])

                        for i, weight in enumerate(self.loss_weights['perceptual']):
                            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                            value_total += self.loss_weights['perceptual'][i] * value
                    loss_perceptual += value_total
        
            length = int((len(kp_driving)-1)/4)+1
            loss_values['perceptual'] = loss_perceptual/length
        elif self.train_params['generator'] == 'audio':
            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4
 
                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a})
                
                pyramide_real = self.pyramid(x['driving'][:,i])
                pyramide_generated = self.pyramid(generated['prediction'])
        
                if sum(self.loss_weights['perceptual']) != 0:
                    value_total = 0
                    for scale in self.scales:
                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])

                        for i, weight in enumerate(self.loss_weights['perceptual']):
                            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                            value_total += self.loss_weights['perceptual'][i] * value
                    loss_perceptual += value_total
        
            length = int((len(kp_driving)-1)/4)+1
            loss_values['perceptual'] = loss_perceptual/length
        else:
            print('wrong train_params: ', self.train_params['generator'])
      
        
      
        return loss_values,generated


class TrainPart2Model(torch.nn.Module):
    """
    Merge all generator-related updates into a single model for better multi-GPU usage.
    """

    def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids):
        super(TrainPart2Model, self).__init__()
        self.kp_extractor = kp_extractor
        self.kp_extractor_a = kp_extractor_a

        self.audio_feature = audio_feature
        self.emo_feature = emo_feature
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']

        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()

        self.mse_loss_fn   =  nn.MSELoss().cuda()
        self.CroEn_loss =  nn.CrossEntropyLoss().cuda()
    def forward(self, x):
 
        kp_source = self.kp_extractor(x['example_image'])

        kp_driving = []
        kp_emo = []
        for i in range(16):
            kp_driving.append(self.kp_extractor(x['driving'][:,i]))
    #        kp_emo.append(self.emo_detector(x['driving'][:,i]))

        kp_driving_a = [] #x['example_image'],
        deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
    #    emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
        loss_values = {}

        if self.loss_weights['emo'] != 0:

            kp_driving_a = []
            fakes = []
            for i in range(16):
                kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))#
                value = self.kp_extractor_a(deco_out[:,i])['value']
                jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian']
                if self.train_params['type'] == 'linear_4' :
                    out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                 #   kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))
                elif self.train_params['type'] == 'linear_10':
                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))

                    out, fake = self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                elif self.train_params['type'] == 'linear_4_new':
                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))

                    out, fake = self.emo_feature.linear_4(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                elif self.train_params['type'] == 'linear_np_4':
                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))

                    out, fake = self.emo_feature.linear_np_4(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                elif self.train_params['type'] == 'linear_np_10':
                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))

                    out, fake = self.emo_feature.linear_np_10(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
          
        loss_value = 0

        loss_jacobian = 0

        loss_classify = 0
        kp_all = kp_driving_a
     
        for i in range(len(kp_driving)):
       
            if self.train_params['type'] == 'linear_4' or self.train_params['type'] == 'linear_4_new' or self.train_params['type'] == 'linear_np_4':
                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,1] - kp_driving_a[i]['jacobian'][:,1] -kp_emo[i]['jacobian'][:,0]).mean())*self.loss_weights['emo']
                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,4] - kp_driving_a[i]['jacobian'][:,4] -kp_emo[i]['jacobian'][:,1]).mean())*self.loss_weights['emo']
                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,6] - kp_driving_a[i]['jacobian'][:,6] -kp_emo[i]['jacobian'][:,2]).mean())*self.loss_weights['emo']
                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,8] - kp_driving_a[i]['jacobian'][:,8] -kp_emo[i]['jacobian'][:,3]).mean())*self.loss_weights['emo']

                loss_classify += self.CroEn_loss(fakes[i],x['emotion'])
                loss_value += (torch.abs(kp_driving[i]['value'][:,1] .detach() - kp_driving_a[i]['value'][:,1]  - kp_emo[i]['value'][:,0] ).mean())*self.loss_weights['emo']
                loss_value += (torch.abs(kp_driving[i]['value'][:,4] .detach() - kp_driving_a[i]['value'][:,4]  - kp_emo[i]['value'][:,1] ).mean())*self.loss_weights['emo']
                loss_value += (torch.abs(kp_driving[i]['value'][:,6] .detach() - kp_driving_a[i]['value'][:,6]  - kp_emo[i]['value'][:,2] ).mean())*self.loss_weights['emo']
                loss_value += (torch.abs(kp_driving[i]['value'][:,8] .detach() - kp_driving_a[i]['value'][:,8]  - kp_emo[i]['value'][:,3] ).mean())*self.loss_weights['emo']
                kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1]
                kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4]
                kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6]
                kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8]
                kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1]
                kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + kp_driving_a[i]['value'][:,4]
                kp_all[i]['value'][:,6] = kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6]
                kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8]
            elif self.train_params['type'] == 'linear_10' or self.train_params['type'] == 'linear_np_10':
                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian'] -kp_emo[i]['jacobian']).mean())*self.loss_weights['emo']

                loss_classify += self.CroEn_loss(fakes[i],x['emotion'])
                loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']  - kp_emo[i]['value'] ).mean())*self.loss_weights['emo']

        #    kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value']

        loss_values['loss_value'] = loss_value/len(kp_driving)
  #      loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving)
        loss_values['loss_jacobian'] = loss_jacobian/len(kp_driving)
        if self.train_params['classify'] == True:
            loss_values['loss_classify'] = loss_classify/len(kp_driving)
        else:
            loss_values['loss_classify'] = torch.tensor(0, device = loss_values['loss_value'].device)
        
        



        return loss_values,generated


class GeneratorFullModel(torch.nn.Module):
    """
    Merge all generator-related updates into a single model for better multi-GPU usage.
    """

    def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params):
        super(GeneratorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.kp_extractor_a = kp_extractor_a
    #    self.content_encoder = content_encoder
    #    self.emotion_encoder = emotion_encoder
        self.audio_feature = audio_feature
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']

        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()
        
        self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda()
        self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda()
        
    def forward(self, x):
   #     source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])
      #  source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))
   #     kp_source = self.kp_extractor(x['source'])
   #     kp_source_a = self.kp_extractor_a(x['source'], x['source_cube'], source_a_f)
      #  driving_a_f = self.audio_feature(self.content_encoder(x['driving_audio'].unsqueeze(1)), self.emotion_encoder(x['driving_audio'].unsqueeze(1)))
      #  driving_a_f = self.audio_feature(x['driving_audio'])
      #  kp_driving = self.kp_extractor(x['driving'])
   #     kp_driving_a = self.kp_extractor_a(x['driving'], x['driving_cube'], driving_a_f)
       
        kp_driving = []
        for i in range(16):
            kp_driving.append(self.kp_extractor(x['driving'][:,i],x['driving_landmark'][:,i],self.loss_weights['equivariance_value']))
        
        kp_driving_a = []
        fc_out, deco_out = self.audio_feature(x['example_landmark'], x['driving_audio'], x['driving_pose'])
        fake_lmark=fc_out + x['example_landmark'].expand_as(fc_out)
        
      
        fake_lmark = torch.mm( fake_lmark, self.pca.t() )
        fake_lmark = fake_lmark + self.mean.expand_as(fake_lmark)
    

        fake_lmark = fake_lmark.unsqueeze(0) 

    #    for i in range(16):
    #        kp_driving_a.append()
        
   #     generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)
   #     generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})

        loss_values = {}

        pyramide_real = self.pyramid(x['driving'])
        pyramide_generated = self.pyramid(generated['prediction'])
        
        if self.loss_weights['audio'] != 0:
            value = torch.abs(kp_source['jacobian'].detach() - kp_source_a['jacobian'].detach()).mean() + torch.abs(kp_driving['jacobian'].detach() - kp_driving_a['jacobian']).mean()
            value = value/2
            loss_values['jacobian'] = value*self.loss_weights['audio']
            value = torch.abs(kp_source['heatmap'].detach() - kp_source_a['heatmap'].detach()).mean() + torch.abs(kp_driving['heatmap'].detach() - kp_driving_a['heatmap']).mean()
            value = value/2
            loss_values['heatmap'] = value*self.loss_weights['audio']
            value = torch.abs(kp_source['value'].detach() - kp_source_a['value'].detach()).mean() + torch.abs(kp_driving['value'].detach() - kp_driving_a['value']).mean()
            value = value/2
            loss_values['value'] = value*self.loss_weights['audio']
            
        if sum(self.loss_weights['perceptual']) != 0:
            value_total = 0
            for scale in self.scales:
                x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])

                for i, weight in enumerate(self.loss_weights['perceptual']):
                    value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                    value_total += self.loss_weights['perceptual'][i] * value
                loss_values['perceptual'] = value_total

        if self.loss_weights['generator_gan'] != 0:
            discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
            discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
            value_total = 0
            for scale in self.disc_scales:
                key = 'prediction_map_%s' % scale
                value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
                value_total += self.loss_weights['generator_gan'] * value
            loss_values['gen_gan'] = value_total

            if sum(self.loss_weights['feature_matching']) != 0:
                value_total = 0
                for scale in self.disc_scales:
                    key = 'feature_maps_%s' % scale
                    for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
                        if self.loss_weights['feature_matching'][i] == 0:
                            continue
                        value = torch.abs(a - b).mean()
                        value_total += self.loss_weights['feature_matching'][i] * value
                    loss_values['feature_matching'] = value_total

        if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
            transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
            transformed_frame = transform.transform_frame(x['driving'])
            transformed_landmark =  transform.inverse_warp_coordinates(x['driving_landmark'])
            transformed_kp = self.kp_extractor(transformed_frame)

            generated['transformed_frame'] = transformed_frame
            generated['transformed_kp'] = transformed_kp
            
            ## Value loss part
            if self.loss_weights['equivariance_value'] != 0:
                value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
                loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value

            ## jacobian loss part
            if self.loss_weights['equivariance_jacobian'] != 0:
                jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
                                                    transformed_kp['jacobian'])

                normed_driving = torch.inverse(kp_driving['jacobian'])
                normed_transformed = jacobian_transformed
                value = torch.matmul(normed_driving, normed_transformed)

                eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())

                value = torch.abs(eye - value).mean()
                loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value

        return loss_values, generated


class DiscriminatorFullModel(torch.nn.Module):
    """
    Merge all discriminator-related updates into a single model for better multi-GPU usage.
    """

    def __init__(self, kp_extractor, generator, discriminator, train_params):
        super(DiscriminatorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']

    def forward(self, x, generated):
        pyramide_real = self.pyramid(x['driving'])
        pyramide_generated = self.pyramid(generated['prediction'].detach())

        kp_driving = generated['kp_driving']
        discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
        discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))

        loss_values = {}
        value_total = 0
        for scale in self.scales:
            key = 'prediction_map_%s' % scale
            value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2
            value_total += self.loss_weights['discriminator_gan'] * value.mean()
        loss_values['disc_gan'] = value_total

        return loss_values


================================================
FILE: modules/model_delta_map.py
================================================
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import AntiAliasInterpolation2d, make_coordinate_grid
from torchvision import models
import numpy as np
from torch.autograd import grad


class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss. See Sec 3.3.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class ImagePyramide(torch.nn.Module):
    """
    Create an image pyramid for computing the pyramid perceptual loss. See Sec 3.3.
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        for scale in scales:
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict


class Transform:
    """
    Random tps transformation for equivariance constraints. See Sec 3.3
    """
    def __init__(self, bs, **kwargs):
        noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
        self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
        self.bs = bs

        if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
            self.tps = True
            self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
            self.control_points = self.control_points.unsqueeze(0)
            self.control_params = torch.normal(mean=0,
                                               std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
        else:
            self.tps = False

    def transform_frame(self, frame):
        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
        return F.grid_sample(frame, grid, padding_mode="reflection")
    
    def inverse_transform_frame(self, frame):
        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        grid = self.inverse_warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
        return F.grid_sample(frame, grid, padding_mode="reflection")
    
    def warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type())
        theta = theta.unsqueeze(1)
        transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
        transformed = transformed.squeeze(-1)

        if self.tps:
            control_points = self.control_points.type(coordinates.type())
            control_params = self.control_params.type(coordinates.type())
            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
            distances = torch.abs(distances).sum(-1)

            result = distances ** 2
            result = result * torch.log(distances + 1e-6)
            result = result * control_params
            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
            transformed = transformed + result

        return transformed

    def inverse_warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type())
        theta = theta.unsqueeze(1)
        a = torch.FloatTensor([[[[0,0,1]]]]).repeat([self.bs,1,1,1]).cuda()
        c = torch.cat((theta,a),2)
        d = c.inverse()[:,:,:2,:]
        d = d.type(coordinates.type())
        transformed = torch.matmul(d[:, :, :, :2], coordinates.unsqueeze(-1)) + d[:, :, :, 2:]
        transformed = transformed.squeeze(-1)
        
        if self.tps:
            control_points = self.control_points.type(coordinates.type())
            control_params = self.control_params.type(coordinates.type())
            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
            distances = torch.abs(distances).sum(-1)

            result = distances ** 2
            result = result * torch.log(distances + 1e-6)
            result = result * control_params
            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
            transformed = transformed + result
        
        
        return transformed

    def jacobian(self, coordinates):
        coordinates.requires_grad=True
        new_coordinates = self.warp_coordinates(coordinates)#[4,10,2]
        grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
        grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
        jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
        return jacobian


def detach_kp(kp):
    return {key: value.detach() for key, value in kp.items()}

class TrainFullModel(torch.nn.Module):
    """
    Merge all generator-related updates into a single model for better multi-GPU usage.
    """

    def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids):
        super(TrainFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.kp_extractor_a = kp_extractor_a
    #    self.emo_detector = emo_detector
    #    self.content_encoder = content_encoder
    #    self.emotion_encoder = emotion_encoder
        self.audio_feature = audio_feature
        self.emo_feature = emo_feature
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']

        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()
        
       # self.pca = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/U_106.npy'))[:, :16].to(device_ids[0])
      #  self.mean = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/mean_106.npy')).to(device_ids[0])
        self.mse_loss_fn   =  nn.MSELoss().cuda()
        self.CroEn_loss =  nn.CrossEntropyLoss().cuda()
    def forward(self, x):
   #     source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])
      #  source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))
        kp_source = self.kp_extractor(x['example_image'])

        kp_driving = []
        kp_emo = []
        for i in range(16):
            kp_driving.append(self.kp_extractor(x['driving'][:,i]))
    #        kp_emo.append(self.emo_detector(x['driving'][:,i]))
    #    print('KP_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a'))
        kp_driving_a = [] #x['example_image'],
        deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
    #    emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
        loss_values = {}
        
        if self.loss_weights['emo'] != 0:
            
            kp_driving_a = []
            fakes = []
            for i in range(16):
                kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))#
                value = self.kp_extractor_a(deco_out[:,i])['value']
                jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian']
                if self.train_params['type'] == 'map_4':
                    out, fake = self.emo_feature.map_4(x['transformed_driving'][:,i],value,jacobian)   
                    kp_emo.append(out)
                    fakes.append(fake)
                 #   kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))
                elif self.train_params['type'] == 'map_10':
                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))
                
                    out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian)   
                    kp_emo.append(out)
                    fakes.append(fake)
            #    kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))
    #    print('Kp_audio_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a'))
        loss_value = 0
    #    loss_heatmap = 0
        loss_jacobian = 0
        loss_perceptual = 0
        loss_classify = 0
        kp_all = kp_driving_a
        for i in range(len(kp_driving)):
            if self.train_params['type'] == 'map_4':
                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,1] - kp_driving_a[i]['jacobian'][:,1] -kp_emo[i]['jacobian'][:,0]).mean())*self.loss_weights['emo']
                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,4] - kp_driving_a[i]['jacobian'][:,4] -kp_emo[i]['jacobian'][:,1]).mean())*self.loss_weights['emo']
                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,6] - kp_driving_a[i]['jacobian'][:,6] -kp_emo[i]['jacobian'][:,2]).mean())*self.loss_weights['emo']
                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,8] - kp_driving_a[i]['jacobian'][:,8] -kp_emo[i]['jacobian'][:,3]).mean())*self.loss_weights['emo']
        
                loss_classify += self.CroEn_loss(fakes[i],x['emotion'])
                loss_value += (torch.abs(kp_driving[i]['value'][:,1] .detach() - kp_driving_a[i]['value'][:,1]  - kp_emo[i]['value'][:,0] ).mean())*self.loss_weights['emo']
                loss_value += (torch.abs(kp_driving[i]['value'][:,4] .detach() - kp_driving_a[i]['value'][:,4]  - kp_emo[i]['value'][:,1] ).mean())*self.loss_weights['emo']
                loss_value += (torch.abs(kp_driving[i]['value'][:,6] .detach() - kp_driving_a[i]['value'][:,6]  - kp_emo[i]['value'][:,2] ).mean())*self.loss_weights['emo']
                loss_value += (torch.abs(kp_driving[i]['value'][:,8] .detach() - kp_driving_a[i]['value'][:,8]  - kp_emo[i]['value'][:,3] ).mean())*self.loss_weights['emo']
                kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1]
                kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4]
                kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6]
                kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8]
                kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1]
                kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + kp_driving_a[i]['value'][:,4]
                kp_all[i]['value'][:,6] = kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6]
                kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8]
            elif self.train_params['type'] == 'map_10':
                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian'] -kp_emo[i]['jacobian']).mean())*self.loss_weights['emo']
        
                loss_classify += self.CroEn_loss(fakes[i],x['emotion'])
                loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']  - kp_emo[i]['value'] ).mean())*self.loss_weights['emo']
            
        #    kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value']
            
        loss_values['loss_value'] = loss_value/len(kp_driving)
  #      loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving)
        loss_values['loss_jacobian'] = loss_jacobian/len(kp_driving)
        loss_values['loss_classify'] = loss_classify/len(kp_driving)
   
        if self.train_params['generator'] == 'not':
            loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out)
            for i in range(1): #0,len(kp_driving),4
 
                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_all})
        elif self.train_params['generator'] == 'visual':
            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4
 
                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
                
                pyramide_real = self.pyramid(x['driving'][:,i])
                pyramide_generated = self.pyramid(generated['prediction'])
        
                if sum(self.loss_weights['perceptual']) != 0:
                    value_total = 0
                    for scale in self.scales:
                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])

                        for k, weight in enumerate(self.loss_weights['perceptual']):
                            value = torch.abs(x_vgg[k] - y_vgg[k].detach()).mean()
                            value_total += weight * value
                    loss_perceptual += value_total
        
            length = int((len(kp_driving)-1)/4)+1
            loss_values['perceptual'] = loss_perceptual/length
        elif self.train_params['generator'] == 'audio':
            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4
 
                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a})
                
                pyramide_real = self.pyramid(x['driving'][:,i])
                pyramide_generated = self.pyramid(generated['prediction'])
        
                if sum(self.loss_weights['perceptual']) != 0:
                    value_total = 0
                    for scale in self.scales:
                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])

                        for k, weight in enumerate(self.loss_weights['perceptual']):
                            value = torch.abs(x_vgg[k] - y_vgg[k].detach()).mean()
                            value_total += weight * value
                    loss_perceptual += value_total
        
            length = int((len(kp_driving)-1)/4)+1
            loss_values['perceptual'] = loss_perceptual/length
        else:
            raise ValueError('wrong train_params generator: %s' % self.train_params['generator'])
      
        
      
        return loss_values,generated

class GeneratorFullModel(torch.nn.Module):
    """
    Merge all generator-related updates into a single model for better multi-GPU usage
    """

    def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params):
        super(GeneratorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.kp_extractor_a = kp_extractor_a
    #    self.content_encoder = content_encoder
    #    self.emotion_encoder = emotion_encoder
        self.audio_feature = audio_feature
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']

        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()
        
        self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda()
        self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda()
        
    def forward(self, x):
   #     source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])
      #  source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))
   #     kp_source = self.kp_extractor(x['source'])
   #     kp_source_a = self.kp_extractor_a(x['source'], x['source_cube'], source_a_f)
      #  driving_a_f = self.audio_feature(self.content_encoder(x['driving_audio'].unsqueeze(1)), self.emotion_encoder(x['driving_audio'].unsqueeze(1)))
      #  driving_a_f = self.audio_feature(x['driving_audio'])
      #  kp_driving = self.kp_extractor(x['driving'])
   #     kp_driving_a = self.kp_extractor_a(x['driving'], x['driving_cube'], driving_a_f)
       
        kp_driving = []
        for i in range(16):
            kp_driving.append(self.kp_extractor(x['driving'][:,i],x['driving_landmark'][:,i],self.loss_weights['equivariance_value']))
        
        kp_driving_a = []
        fc_out, deco_out = self.audio_feature(x['example_landmark'], x['driving_audio'], x['driving_pose'])
        fake_lmark=fc_out + x['example_landmark'].expand_as(fc_out)
        
      
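        # reconstruct the landmark coordinates from the predicted PCA coefficients
        # (U_106 basis): lmark = coeff @ U[:, :16]^T + mean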
        fake_lmark = torch.mm( fake_lmark, self.pca.t() )
        fake_lmark = fake_lmark + self.mean.expand_as(fake_lmark)
    

        fake_lmark = fake_lmark.unsqueeze(0) 

    #    for i in range(16):
    #        kp_driving_a.append()
        
   #     generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)
   #     generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})

        loss_values = {}

        pyramide_real = self.pyramid(x['driving'])
        pyramide_generated = self.pyramid(generated['prediction'])
        
        if self.loss_weights['audio'] != 0:
            value = torch.abs(kp_source['jacobian'].detach() - kp_source_a['jacobian'].detach()).mean() + torch.abs(kp_driving['jacobian'].detach() - kp_driving_a['jacobian']).mean()
            value = value/2
            loss_values['jacobian'] = value*self.loss_weights['audio']
            value = torch.abs(kp_source['heatmap'].detach() - kp_source_a['heatmap'].detach()).mean() + torch.abs(kp_driving['heatmap'].detach() - kp_driving_a['heatmap']).mean()
            value = value/2
            loss_values['heatmap'] = value*self.loss_weights['audio']
            value = torch.abs(kp_source['value'].detach() - kp_source_a['value'].detach()).mean() + torch.abs(kp_driving['value'].detach() - kp_driving_a['value']).mean()
            value = value/2
            loss_values['value'] = value*self.loss_weights['audio']
            
        if sum(self.loss_weights['perceptual']) != 0:
            value_total = 0
            for scale in self.scales:
                x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])

                for i, weight in enumerate(self.loss_weights['perceptual']):
                    value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                    value_total += self.loss_weights['perceptual'][i] * value
                loss_values['perceptual'] = value_total

        if self.loss_weights['generator_gan'] != 0:
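            # least-squares GAN generator term: push the discriminator's prediction maps
            # for generated frames toward 1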
            discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
            discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
            value_total = 0
            for scale in self.disc_scales:
                key = 'prediction_map_%s' % scale
                value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
                value_total += self.loss_weights['generator_gan'] * value
            loss_values['gen_gan'] = value_total

            if sum(self.loss_weights['feature_matching']) != 0:
                value_total = 0
                for scale in self.disc_scales:
                    key = 'feature_maps_%s' % scale
                    for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
                        if self.loss_weights['feature_matching'][i] == 0:
                            continue
                        value = torch.abs(a - b).mean()
                        value_total += self.loss_weights['feature_matching'][i] * value
                    loss_values['feature_matching'] = value_total

        if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
            transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
            transformed_frame = transform.transform_frame(x['driving'])
            transformed_landmark =  transform.inverse_warp_coordinates(x['driving_landmark'])
            transformed_kp = self.kp_extractor(transformed_frame)

            generated['transformed_frame'] = transformed_frame
            generated['transformed_kp'] = transformed_kp
            
            ## Value loss part
            if self.loss_weights['equivariance_value'] != 0:
                value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
                loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value

            ## jacobian loss part
            if self.loss_weights['equivariance_jacobian'] != 0:
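                # the predicted Jacobians should transform consistently under the known warp:
                # inverse(J_driving) @ (J_warp @ J_transformed) should be close to identity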
                jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
                                                    transformed_kp['jacobian'])

                normed_driving = torch.inverse(kp_driving['jacobian'])
                normed_transformed = jacobian_transformed
                value = torch.matmul(normed_driving, normed_transformed)

                eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())

                value = torch.abs(eye - value).mean()
                loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value

        return loss_values, generated


class DiscriminatorFullModel(torch.nn.Module):
    """
    Merge all discriminator-related updates into a single model for better multi-GPU usage
    """

    def __init__(self, kp_extractor, generator, discriminator, train_params):
        super(DiscriminatorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']

    def forward(self, x, generated):
        pyramide_real = self.pyramid(x['driving'])
        pyramide_generated = self.pyramid(generated['prediction'].detach())

        kp_driving = generated['kp_driving']
        discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
        discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))

        loss_values = {}
        value_total = 0
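        # least-squares GAN discriminator term: prediction maps for real frames are pushed
        # toward 1, maps for generated frames toward 0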
        for scale in self.scales:
            key = 'prediction_map_%s' % scale
            value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2
            value_total += self.loss_weights['discriminator_gan'] * value.mean()
        loss_values['disc_gan'] = value_total

        return loss_values


================================================
FILE: modules/model_gen.py
================================================
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import AntiAliasInterpolation2d, make_coordinate_grid
from torchvision import models
import numpy as np
from torch.autograd import grad


class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss. See Sec 3.3.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out

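
# A minimal usage sketch (hypothetical helper, not referenced elsewhere in this file):
# how the Vgg19 feature slices above are typically combined into the weighted
# perceptual loss that the full models below compute inline. Assumes `vgg` is a
# Vgg19 instance, `weights` a per-layer weight list such as
# train_params['loss_weights']['perceptual'], and `pred`/`target` image batches in [0, 1].
def _vgg_perceptual_loss_sketch(vgg, pred, target, weights):
    pred_feats = vgg(pred)        # five feature maps: relu1_1 ... relu5_1
    target_feats = vgg(target)
    loss = 0.0
    for k, weight in enumerate(weights):
        # L1 distance per layer; target features are detached so gradients only
        # flow through the generated image
        loss = loss + weight * torch.abs(pred_feats[k] - target_feats[k].detach()).mean()
    return loss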

class ImagePyramide(torch.nn.Module):
    """
    Create an image pyramid for computing the pyramid perceptual loss. See Sec 3.3
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        for scale in scales:
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict

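
# A minimal sketch (hypothetical helper, builds on _vgg_perceptual_loss_sketch above):
# the multi-scale form used by the models in this file, where every pyramid level
# 'prediction_<scale>' of the generated and real frames is compared in VGG feature space.
def _pyramid_perceptual_loss_sketch(pyramid, vgg, pred, target, scales, weights):
    pyramide_pred = pyramid(pred)        # e.g. {'prediction_1': ..., 'prediction_0.5': ...}
    pyramide_target = pyramid(target)
    loss = 0.0
    for scale in scales:
        key = 'prediction_' + str(scale)
        loss = loss + _vgg_perceptual_loss_sketch(vgg, pyramide_pred[key], pyramide_target[key], weights)
    return loss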

class Transform:
    """
    Random TPS transformation for equivariance constraints. See Sec 3.3
    """
    def __init__(self, bs, **kwargs):
        noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
        self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
        self.bs = bs

        if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
            self.tps = True
            self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
            self.control_points = self.control_points.unsqueeze(0)
            self.control_params = torch.normal(mean=0,
                                               std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
        else:
            self.tps = False

    def transform_frame(self, frame):
        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
        return F.grid_sample(frame, grid, padding_mode="reflection")

    def inverse_transform_frame(self, frame):
        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        grid = self.inverse_warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
        return F.grid_sample(frame, grid, padding_mode="reflection")

    def warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type())
        theta = theta.unsqueeze(1)
        transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
        transformed = transformed.squeeze(-1)

        if self.tps:
            control_points = self.control_points.type(coordinates.type())
            control_params = self.control_params.type(coordinates.type())
            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
            distances = torch.abs(distances).sum(-1)

            result = distances ** 2
            result = result * torch.log(distances + 1e-6)
            result = result * control_params
            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
            transformed = transformed + result

        return transformed

    def inverse_warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type())
        theta = theta.unsqueeze(1)
        a = torch.FloatTensor([[[[0,0,1]]]]).repeat([self.bs,1,1,1]).cuda()
        c = torch.cat((theta,a),2)
        d = c.inverse()[:,:,:2,:]
        d = d.type(coordinates.type())
        transformed = torch.matmul(d[:, :, :, :2], coordinates.unsqueeze(-1)) + d[:, :, :, 2:]
        transformed = transformed.squeeze(-1)

        if self.tps:
            control_points = self.control_points.type(coordinates.type())
            control_params = self.control_params.type(coordinates.type())
            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
            distances = torch.abs(distances).sum(-1)

            result = distances ** 2
            result = result * torch.log(distances + 1e-6)
            result = result * control_params
            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
            transformed = transformed + result


        return transformed

    def jacobian(self, coordinates):
        coordinates.requires_grad=True
        new_coordinates = self.warp_coordinates(coordinates)#[4,10,2]
        grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
        grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
        jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
        return jacobian

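
# A minimal sketch (hypothetical helper, illustrative transform parameters): the
# equivariance-value constraint applied by the full models below. Keypoints detected
# on a randomly warped copy of the frames, once mapped back through the warp, should
# coincide with the keypoints detected on the original frames.
def _equivariance_value_sketch(kp_extractor, frames):
    transform = Transform(frames.shape[0], sigma_affine=0.05, sigma_tps=0.005, points_tps=5)
    transformed_frames = transform.transform_frame(frames)
    kp = kp_extractor(frames)
    kp_transformed = kp_extractor(transformed_frames)
    # warp_coordinates maps coordinates in the warped frame back to the original frame
    return torch.abs(kp['value'] - transform.warp_coordinates(kp_transformed['value'])).mean()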

def detach_kp(kp):
    return {key: value.detach() for key, value in kp.items()}

class TrainFullModel(torch.nn.Module):
    """
    Merge all generator-related updates into a single model for better multi-GPU usage
    """

    def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids):
        super(TrainFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.kp_extractor_a = kp_extractor_a
    #    self.emo_detector = emo_detector
    #    self.content_encoder = content_encoder
    #    self.emotion_encoder = emotion_encoder
        self.audio_feature = audio_feature
        self.emo_feature = emo_feature
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']

        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()

       # self.pca = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/U_106.npy'))[:, :16].to(device_ids[0])
      #  self.mean = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/mean_106.npy')).to(device_ids[0])
        self.mse_loss_fn   =  nn.MSELoss().cuda()
        self.CroEn_loss =  nn.CrossEntropyLoss().cuda()
    def forward(self, x):
   #     source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])
      #  source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))
        kp_source = self.kp_extractor(x['example_image'])
      #  print(x['name'],len(x['name']))
        kp_driving = []
        kp_emo = []
        for i in range(16):
            kp_driving.append(self.kp_extractor(x['driving'][:,i]))
    #        kp_emo.append(self.emo_detector(x['driving'][:,i]))
    #    print('KP_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a'))
        kp_driving_a = [] #x['example_image'],
        deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
    #    emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])
        loss_values = {}

        if self.loss_weights['emo'] != 0:

            kp_driving_a = []
            fakes = []
            for i in range(16):
                kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))#
                value = self.kp_extractor_a(deco_out[:,i])['value']
                jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian']
                if self.train_params['type'] == 'linear_4' and x['name'][0] == 0:
                    out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                 #   kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))
                elif self.train_params['type'] == 'linear_10' and x['name'][0] == 0:
                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))

                    out, fake = self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                elif self.train_params['type'] == 'linear_4_new' and x['name'][0] == 0:
                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))

                    out, fake = self.emo_feature.linear_4(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                elif self.train_params['type'] == 'linear_np_4':
                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))

                    out, fake = self.emo_feature.linear_np_4(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
                elif self.train_params['type'] == 'linear_np_10':
                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))

                    out, fake = self.emo_feature.linear_np_10(x['transformed_driving'][:,i],value,jacobian)
                    kp_emo.append(out)
                    fakes.append(fake)
            #    kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))
    #    print('Kp_audio_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a'))

        loss_perceptual = 0

        kp_all = kp_driving_a
        if self.train_params['smooth'] == True:
            value_all = torch.randn(len(kp_driving),out['value'].shape[0],out['value'].shape[1],out['value'].shape[2]).cuda()
            jacobian_all = torch.randn(len(kp_driving),out['jacobian'].shape[0],out['jacobian'].shape[1],2,2).cuda()
        print(len(kp_driving))
        for i in range(len(kp_driving)):
          #  if x['name'][i] == 'LRW':
          #      loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian']).mean())*self.loss_weights['emo']

          #      loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']).mean())*self.loss_weights['emo']
          #      loss_classify += self.mse_loss_fn(deco_out,deco_out)
            if self.train_params['type'] == 'linear_4' and x['name'][0] == 0:

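                # add the four predicted emotion deltas onto the audio-driven keypoints
                # 1, 4, 6 and 8; kp_all aliases kp_driving_a, so these entries are updated in place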
                kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1]
                kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4]
                kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6]
                kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8]
                kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1]
                kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + kp_driving_a[i]['value'][:,4]
                kp_all[i]['value'][:,6] = kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6]
                kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8]

        #    kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value']


        if self.train_params['smooth'] == True:
            loss_smooth = 0
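            # second-order temporal difference x[t+1] - 2*x[t] + x[t-1] of the keypoint
            # values and Jacobians; penalizing it encourages smooth keypoint trajectories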
            loss_smooth += (torch.abs(value_all[2:,:,:,:] + value_all[:-2,:,:,:].detach() -2*value_all[1:-1,:,:,:].detach()).mean())*self.loss_weights['emo'] *100
            loss_smooth += (torch.abs(jacobian_all[2:,:,:,:] + jacobian_all[:-2,:,:,:].detach() -2*jacobian_all[1:-1,:,:,:].detach()).mean())*self.loss_weights['emo'] *100
            loss_values['loss_smooth'] = loss_smooth/len(kp_driving)
        else:
            loss_values['loss_smooth'] = self.mse_loss_fn(deco_out,deco_out)
        if self.train_params['generator'] == 'not':
            loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out)
            for i in range(1): #0,len(kp_driving),4

                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_all})
        elif self.train_params['generator'] == 'visual':
            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4

                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})

                pyramide_real = self.pyramid(x['driving'][:,i])
                pyramide_generated = self.pyramid(generated['prediction'])

                if sum(self.loss_weights['perceptual']) != 0:
                    value_total = 0
                    for scale in self.scales:
                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])

                        for k, weight in enumerate(self.loss_weights['perceptual']):
                            value = torch.abs(x_vgg[k] - y_vgg[k].detach()).mean()
                            value_total += weight * value
                    loss_perceptual += value_total

            length = int((len(kp_driving)-1)/4)+1
            loss_values['perceptual'] = loss_perceptual/length
        elif self.train_params['generator'] == 'audio':
            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4

                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i])
                generated.update({'kp_source': kp_source, 'kp_driving': kp_all})

                pyramide_real = self.pyramid(x['driving'][:,i])
                pyramide_generated = self.pyramid(generated['prediction'])
            #    loss_mse = nn.MSELoss(generated['prediction'],x['driving'][:,i])
                if sum(self.loss_weights['perceptual']) != 0:
                    value_total = 0
                    for scale in self.scales:
                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])

                        for k, weight in enumerate(self.loss_weights['perceptual']):
                            value = torch.abs(x_vgg[k] - y_vgg[k].detach()).mean()
                            value_total += weight * value
                    loss_perceptual += value_total

            length = int((len(kp_driving)-1)/4)+1
            loss_values['perceptual'] = loss_perceptual/length
      #      loss_values['mse'] = loss_mse/length
            
        else:
            raise ValueError('wrong train_params generator: %s' % self.train_params['generator'])



        return loss_values,generated

class GeneratorFullModel(torch.nn.Module):
    """
    Merge all generator-related updates into a single model for better multi-GPU usage
    """

    def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params):
        super(GeneratorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.kp_extractor_a = kp_extractor_a
    #    self.content_encoder = content_encoder
    #    self.emotion_encoder = emotion_encoder
        self.audio_feature = audio_feature
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']

        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()

        self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda()
        self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda()

    def forward(self, x):
   #     source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])
      #  source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))
   # 
SYMBOL INDEX (451 symbols across 25 files)

FILE: 3DDFA_V2/demo.py
  function main (line 29) | def main(args,img, save_path, pose_path):
  function process_word (line 102) | def process_word(i):

FILE: 3DDFA_V2/utils/pose.py
  function P2sRt (line 18) | def P2sRt(P):
  function matrix2angle (line 39) | def matrix2angle(R):
  function angle2matrix (line 65) | def angle2matrix(theta):
  function angle2matrix_3ddfa (line 112) | def angle2matrix_3ddfa(angles):
  function calc_pose (line 140) | def calc_pose(param):
  function build_camera_box (line 150) | def build_camera_box(rear_size=90):
  function plot_pose_box (line 171) | def plot_pose_box(img, P, ver, color=(40, 255, 0), line_width=2):
  function viz_pose (line 201) | def viz_pose(img, param_lst, ver_lst, show_flag=False, wfp=None):
  function pose_6 (line 217) | def pose_6(param):
  function smooth_pose (line 231) | def smooth_pose(img, param_lst, ver_lst, pose_new, show_flag=False, wfp=...
  function get_pose (line 263) | def get_pose(img, param_lst, ver_lst, show_flag=False, wfp=None, wnp = N...

FILE: augmentation.py
  function crop_clip (line 20) | def crop_clip(clip, min_h, min_w, h, w):
  function pad_clip (line 34) | def pad_clip(clip, h, w):
  function resize_clip (line 42) | def resize_clip(clip, size, interpolation='bilinear'):
  function get_resize_sizes (line 81) | def get_resize_sizes(im_h, im_w, size):
  class RandomFlip (line 91) | class RandomFlip(object):
    method __init__ (line 92) | def __init__(self, time_flip=False, horizontal_flip=False):
    method __call__ (line 96) | def __call__(self, clip):
  class RandomResize (line 105) | class RandomResize(object):
    method __init__ (line 115) | def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
    method __call__ (line 119) | def __call__(self, clip):
  class RandomCrop (line 136) | class RandomCrop(object):
    method __init__ (line 143) | def __init__(self, size):
    method __call__ (line 149) | def __call__(self, clip):
  class MouthCrop (line 175) | class MouthCrop(object):
    method __init__ (line 182) | def __init__(self, center_x, center_y, mask_width, mask_height):
    method __call__ (line 190) | def __call__(self, clip):
  class RandomRotation (line 215) | class RandomRotation(object):
    method __init__ (line 224) | def __init__(self, degrees):
    method __call__ (line 237) | def __call__(self, clip):
  class RandomPerspective (line 256) | class RandomPerspective(object):
    method __init__ (line 265) | def __init__(self, pers_num, enlarge_num):
    method __call__ (line 269) | def __call__(self, clip):
  class ColorJitter (line 297) | class ColorJitter(object):
    method __init__ (line 310) | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
    method get_params (line 316) | def get_params(self, brightness, contrast, saturation, hue):
    method __call__ (line 341) | def __call__(self, clip):
  class AllAugmentationTransform (line 403) | class AllAugmentationTransform:
    method __init__ (line 404) | def __init__(self, crop_mouth_param = None, resize_param=None, rotatio...
    method __call__ (line 427) | def __call__(self, clip):

FILE: demo.py
  function load_checkpoints (line 49) | def load_checkpoints(opt, checkpoint_path, audio_checkpoint_path, emo_ch...
  function normalize_kp (line 112) | def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_moveme...
  function shape_to_np (line 134) | def shape_to_np(shape, dtype="int"):
  function get_aligned_image (line 146) | def get_aligned_image(driving_video, opt):
  function get_transformed_image (line 184) | def get_transformed_image(driving_video, opt):
  function make_animation_smooth (line 194) | def make_animation_smooth(source_image, driving_video, transformed_video...
  function test_auido (line 286) | def test_auido(example_image, audio_feature, all_pose, opt):
  function save (line 357) | def save(path, frames, format):
  class VideoWriter (line 370) | class VideoWriter(object):
    method __init__ (line 371) | def __init__(self, path, width, height, fps):
    method write_frame (line 376) | def write_frame(self, frame):
    method end (line 379) | def end(self):
  function concatenate (line 382) | def concatenate(number, imgs, save_path):
  function add_audio (line 427) | def add_audio(video_name=None, audio_dir = None):
  function crop_image (line 433) | def crop_image(source_image):
  function smooth_pose (line 456) | def smooth_pose(pose_file, pose_long):
  function test (line 467) | def test(opt, name):

FILE: filter1.py
  class LowPassFilter (line 13) | class LowPassFilter:
    method __init__ (line 14) | def __init__(self):
    method process (line 18) | def process(self, value, alpha):
  class OneEuroFilter (line 28) | class OneEuroFilter:
    method __init__ (line 29) | def __init__(self, mincutoff=1.0, beta=0.0, dcutoff=1.0, freq=30):
    method compute_alpha (line 37) | def compute_alpha(self, cutoff):
    method process (line 42) | def process(self, x):

FILE: frames_dataset.py
  function read_video (line 15) | def read_video(name, frame_shape):
  function get_list (line 55) | def get_list(ipath,base_name):
  class AudioDataset (line 75) | class AudioDataset(Dataset):
    method __init__ (line 83) | def __init__(self, name, root_dir, frame_shape=(256, 256, 3), id_sampl...
    method __len__ (line 132) | def __len__(self):
    method __getitem__ (line 135) | def __getitem__(self, idx):
  class VoxDataset (line 196) | class VoxDataset(Dataset):
    method __init__ (line 204) | def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=Fa...
    method __len__ (line 252) | def __len__(self):
    method __getitem__ (line 255) | def __getitem__(self, idx):
  class MeadDataset (line 328) | class MeadDataset(Dataset):
    method __init__ (line 336) | def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=Fa...
    method __len__ (line 378) | def __len__(self):
    method __getitem__ (line 381) | def __getitem__(self, idx):
  class DatasetRepeater (line 461) | class DatasetRepeater(Dataset):
    method __init__ (line 466) | def __init__(self, dataset, num_repeats=100):
    method __len__ (line 471) | def __len__(self):
    method __getitem__ (line 474) | def __getitem__(self, idx):
  class TestsetRepeater (line 481) | class TestsetRepeater(Dataset):
    method __init__ (line 486) | def __init__(self, dataset, num_repeats=100):
    method __len__ (line 491) | def __len__(self):
    method __getitem__ (line 494) | def __getitem__(self, idx):
  class PairedDataset (line 499) | class PairedDataset(Dataset):
    method __init__ (line 504) | def __init__(self, initial_dataset, number_of_pairs, seed=0):
    method __len__ (line 529) | def __len__(self):
    method __getitem__ (line 532) | def __getitem__(self, idx):

FILE: logger.py
  class Logger (line 13) | class Logger:
    method __init__ (line 14) | def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=Non...
    method log_scores (line 29) | def log_scores(self, loss_names):
    method visualize_rec (line 39) | def visualize_rec(self, inp, out):
    method save_cpk (line 44) | def save_cpk(self, emergent=False):
    method load_cpk (line 53) | def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_d...
    method __enter__ (line 83) | def __enter__(self):
    method __exit__ (line 86) | def __exit__(self, exc_type, exc_val, exc_tb):
    method log_iter (line 91) | def log_iter(self, losses):
    method log_epoch (line 97) | def log_epoch(self, epoch, step, models, inp, out):
  class Visualizer (line 107) | class Visualizer:
    method __init__ (line 108) | def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbo...
    method draw_image_with_kp (line 113) | def draw_image_with_kp(self, image, kp_array):
    method create_image_column_with_kp (line 123) | def create_image_column_with_kp(self, images, kp):
    method create_image_column (line 127) | def create_image_column(self, images):
    method create_image_grid (line 134) | def create_image_grid(self, *args):
    method visualize (line 143) | def visualize(self, driving, transformed_driving, source, out):

FILE: modules/dense_motion.py
  class DenseMotionNetwork (line 7) | class DenseMotionNetwork(nn.Module):
    method __init__ (line 12) | def __init__(self, block_expansion, num_blocks, max_features, num_kp, ...
    method create_heatmap_representations (line 32) | def create_heatmap_representations(self, source_image, kp_driving, kp_...
    method create_sparse_motions (line 47) | def create_sparse_motions(self, source_image, kp_driving, kp_source):
    method create_deformed_source_image (line 69) | def create_deformed_source_image(self, source_image, sparse_motions):
    method forward (line 81) | def forward(self, source_image, kp_driving, kp_source):

FILE: modules/discriminator.py
  class DownBlock2d (line 7) | class DownBlock2d(nn.Module):
    method __init__ (line 12) | def __init__(self, in_features, out_features, norm=False, kernel_size=...
    method forward (line 25) | def forward(self, x):
  class Discriminator (line 36) | class Discriminator(nn.Module):
    method __init__ (line 41) | def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, m...
    method forward (line 59) | def forward(self, x, kp=None):
  class MultiScaleDiscriminator (line 74) | class MultiScaleDiscriminator(nn.Module):
    method __init__ (line 79) | def __init__(self, scales=(), **kwargs):
    method forward (line 87) | def forward(self, x, kp=None):

FILE: modules/function.py
  function calc_mean_std (line 12) | def calc_mean_std(feat, eps=1e-5):
  function adaptive_instance_normalization (line 23) | def adaptive_instance_normalization(content_feat, style_feat):
  function _calc_feat_flatten_mean_std (line 34) | def _calc_feat_flatten_mean_std(feat):
  function _mat_sqrt (line 44) | def _mat_sqrt(x):
  function coral (line 49) | def coral(source, target):

FILE: modules/generator.py
  class OcclusionAwareGenerator (line 8) | class OcclusionAwareGenerator(nn.Module):
    method __init__ (line 14) | def __init__(self, num_channels, num_kp, block_expansion, max_features...
    method deform_input (line 50) | def deform_input(self, inp, deformation):
    method forward (line 59) | def forward(self, source_image, kp_driving, kp_source):

FILE: modules/keypoint_detector.py
  class KPDetector (line 7) | class KPDetector(nn.Module):
    method __init__ (line 12) | def __init__(self, block_expansion, num_kp, num_channels, max_features,
    method gaussian2kp (line 40) | def gaussian2kp(self, heatmap):
    method audio_feature (line 52) | def audio_feature(self, x, heatmap):
    method forward (line 77) | def forward(self, x): #torch.Size([4, 3, H, W])
  class KPDetector_a (line 110) | class KPDetector_a(nn.Module):
    method __init__ (line 115) | def __init__(self, block_expansion, num_kp, num_channels,num_channels_...
    method gaussian2kp (line 143) | def gaussian2kp(self, heatmap):
    method audio_feature (line 155) | def audio_feature(self, x, heatmap):
    method forward (line 180) | def forward(self,  feature_map): #torch.Size([4, 3, H, W])
  class Audio_Feature (line 208) | class Audio_Feature(nn.Module):
    method __init__ (line 209) | def __init__(self):
    method forward (line 218) | def forward(self, x):

FILE: modules/model.py
  class Vgg19 (line 10) | class Vgg19(torch.nn.Module):
    method __init__ (line 14) | def __init__(self, requires_grad=False):
    method forward (line 42) | def forward(self, X):
  class ImagePyramide (line 53) | class ImagePyramide(torch.nn.Module):
    method __init__ (line 57) | def __init__(self, scales, num_channels):
    method forward (line 64) | def forward(self, x):
  class Transform (line 71) | class Transform:
    method __init__ (line 75) | def __init__(self, bs, **kwargs):
    method transform_frame (line 89) | def transform_frame(self, frame):
    method inverse_transform_frame (line 95) | def inverse_transform_frame(self, frame):
    method warp_coordinates (line 101) | def warp_coordinates(self, coordinates):
    method inverse_warp_coordinates (line 121) | def inverse_warp_coordinates(self, coordinates):
    method jacobian (line 146) | def jacobian(self, coordinates):
  function detach_kp (line 155) | def detach_kp(kp):
  class TrainPart1Model (line 158) | class TrainPart1Model(torch.nn.Module):
    method __init__ (line 163) | def __init__(self, kp_extractor, kp_extractor_a, audio_feature, genera...
    method forward (line 187) | def forward(self, x):
  class TrainPart2Model (line 282) | class TrainPart2Model(torch.nn.Module):
    method __init__ (line 287) | def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_fe...
    method forward (line 312) | def forward(self, x):
  class GeneratorFullModel (line 416) | class GeneratorFullModel(torch.nn.Module):
    method __init__ (line 421) | def __init__(self, kp_extractor, kp_extractor_a, audio_feature, genera...
    method forward (line 447) | def forward(self, x):
  class DiscriminatorFullModel (line 557) | class DiscriminatorFullModel(torch.nn.Module):
    method __init__ (line 562) | def __init__(self, kp_extractor, generator, discriminator, train_params):
    method forward (line 575) | def forward(self, x, generated):

FILE: modules/model_delta_map.py
  class Vgg19 (line 10) | class Vgg19(torch.nn.Module):
    method __init__ (line 14) | def __init__(self, requires_grad=False):
    method forward (line 42) | def forward(self, X):
  class ImagePyramide (line 53) | class ImagePyramide(torch.nn.Module):
    method __init__ (line 57) | def __init__(self, scales, num_channels):
    method forward (line 64) | def forward(self, x):
  class Transform (line 71) | class Transform:
    method __init__ (line 75) | def __init__(self, bs, **kwargs):
    method transform_frame (line 89) | def transform_frame(self, frame):
    method inverse_transform_frame (line 95) | def inverse_transform_frame(self, frame):
    method warp_coordinates (line 101) | def warp_coordinates(self, coordinates):
    method inverse_warp_coordinates (line 121) | def inverse_warp_coordinates(self, coordinates):
    method jacobian (line 146) | def jacobian(self, coordinates):
  function detach_kp (line 155) | def detach_kp(kp):
  class TrainFullModel (line 158) | class TrainFullModel(torch.nn.Module):
    method __init__ (line 163) | def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_fe...
    method forward (line 192) | def forward(self, x):
  class GeneratorFullModel (line 325) | class GeneratorFullModel(torch.nn.Module):
    method __init__ (line 330) | def __init__(self, kp_extractor, kp_extractor_a, audio_feature, genera...
    method forward (line 356) | def forward(self, x):
  class DiscriminatorFullModel (line 466) | class DiscriminatorFullModel(torch.nn.Module):
    method __init__ (line 471) | def __init__(self, kp_extractor, generator, discriminator, train_params):
    method forward (line 484) | def forward(self, x, generated):

FILE: modules/model_gen.py
  class Vgg19 (line 10) | class Vgg19(torch.nn.Module):
    method __init__ (line 14) | def __init__(self, requires_grad=False):
    method forward (line 42) | def forward(self, X):
  class ImagePyramide (line 53) | class ImagePyramide(torch.nn.Module):
    method __init__ (line 57) | def __init__(self, scales, num_channels):
    method forward (line 64) | def forward(self, x):
  class Transform (line 71) | class Transform:
    method __init__ (line 75) | def __init__(self, bs, **kwargs):
    method transform_frame (line 89) | def transform_frame(self, frame):
    method inverse_transform_frame (line 95) | def inverse_transform_frame(self, frame):
    method warp_coordinates (line 101) | def warp_coordinates(self, coordinates):
    method inverse_warp_coordinates (line 121) | def inverse_warp_coordinates(self, coordinates):
    method jacobian (line 146) | def jacobian(self, coordinates):
  function detach_kp (line 155) | def detach_kp(kp):
  class TrainFullModel (line 158) | class TrainFullModel(torch.nn.Module):
    method __init__ (line 163) | def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_fe...
    method forward (line 192) | def forward(self, x):
  class GeneratorFullModel (line 341) | class GeneratorFullModel(torch.nn.Module):
    method __init__ (line 346) | def __init__(self, kp_extractor, kp_extractor_a, audio_feature, genera...
    method forward (line 372) | def forward(self, x):
  class DiscriminatorFullModel (line 482) | class DiscriminatorFullModel(torch.nn.Module):
    method __init__ (line 487) | def __init__(self, kp_extractor, generator, discriminator, train_params):
    method forward (line 500) | def forward(self, x, generated):

FILE: modules/ops.py
  function linear (line 8) | def linear(channel_in, channel_out,
  function conv2d (line 21) | def conv2d(channel_in, channel_out,
  function conv_transpose2d (line 37) | def conv_transpose2d(channel_in, channel_out,
  function nn_conv2d (line 53) | def nn_conv2d(channel_in, channel_out,
  function _apply (line 71) | def _apply(layer, activation, normalizer, channel_out=None):

FILE: modules/stylegan2.py
  function fused_leaky_relu (line 25) | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
  class FusedLeakyReLU (line 29) | class FusedLeakyReLU(nn.Module):
    method __init__ (line 30) | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
    method forward (line 36) | def forward(self, input):
  function upfirdn2d_native (line 45) | def upfirdn2d_native(
  function upfirdn2d (line 82) | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
  class PixelNorm (line 86) | class PixelNorm(nn.Module):
    method __init__ (line 87) | def __init__(self):
    method forward (line 90) | def forward(self, input):
  function make_kernel (line 94) | def make_kernel(k):
  class Upsample (line 105) | class Upsample(nn.Module):
    method __init__ (line 106) | def __init__(self, kernel, factor=2):
    method forward (line 120) | def forward(self, input):
  class Downsample (line 126) | class Downsample(nn.Module):
    method __init__ (line 127) | def __init__(self, kernel, factor=2):
    method forward (line 141) | def forward(self, input):
  class Blur (line 147) | class Blur(nn.Module):
    method __init__ (line 148) | def __init__(self, kernel, pad, upsample_factor=1):
    method forward (line 160) | def forward(self, input):
  class EqualConv2d (line 166) | class EqualConv2d(nn.Module):
    method __init__ (line 167) | def __init__(
    method forward (line 186) | def forward(self, input):
    method __repr__ (line 199) | def __repr__(self):
  class EqualLinear (line 206) | class EqualLinear(nn.Module):
    method __init__ (line 207) | def __init__(
    method forward (line 225) | def forward(self, input):
    method __repr__ (line 237) | def __repr__(self):
  class ScaledLeakyReLU (line 243) | class ScaledLeakyReLU(nn.Module):
    method __init__ (line 244) | def __init__(self, negative_slope=0.2):
    method forward (line 249) | def forward(self, input):
  class ModulatedConv2d (line 255) | class ModulatedConv2d(nn.Module):
    method __init__ (line 256) | def __init__(
    method __repr__ (line 305) | def __repr__(self):
    method forward (line 311) | def forward(self, input, style):
  class NoiseInjection (line 358) | class NoiseInjection(nn.Module):
    method __init__ (line 359) | def __init__(self):
    method forward (line 364) | def forward(self, image, noise=None):
  class ConstantInput (line 372) | class ConstantInput(nn.Module):
    method __init__ (line 373) | def __init__(self, channel, size=4):
    method forward (line 378) | def forward(self, input):
  class StyledConv (line 385) | class StyledConv(nn.Module):
    method __init__ (line 386) | def __init__(
    method forward (line 415) | def forward(self, input, style=None, noise=None):
  class ToRGB (line 425) | class ToRGB(nn.Module):
    method __init__ (line 426) | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[...
    method forward (line 435) | def forward(self, input, style, skip=None):
  class Generator (line 447) | class Generator(nn.Module):
    method __init__ (line 448) | def __init__(
    method make_noise (line 533) | def make_noise(self):
    method mean_latent (line 544) | def mean_latent(self, n_latent):
    method get_latent (line 552) | def get_latent(self, input):
    method forward (line 555) | def forward(
  class ConvLayer (line 630) | class ConvLayer(nn.Sequential):
    method __init__ (line 631) | def __init__(
  class ResBlock (line 679) | class ResBlock(nn.Module):
    method __init__ (line 680) | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], ...
    method forward (line 694) | def forward(self, input):
  class StyleGAN2Discriminator (line 704) | class StyleGAN2Discriminator(nn.Module):
    method __init__ (line 705) | def __init__(self, input_nc, ndf=64, n_layers=3, no_antialias=False, s...
    method forward (line 761) | def forward(self, input, get_minibatch_features=False):
  class TileStyleGAN2Discriminator (line 795) | class TileStyleGAN2Discriminator(StyleGAN2Discriminator):
    method forward (line 796) | def forward(self, input):
  class StyleGAN2Encoder (line 806) | class StyleGAN2Encoder(nn.Module):
    method __init__ (line 807) | def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_b...
    method forward (line 843) | def forward(self, input, layers=[], get_features=False):
  class StyleGAN2Decoder (line 860) | class StyleGAN2Decoder(nn.Module):
    method __init__ (line 861) | def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_b...
    method forward (line 902) | def forward(self, input):
  class StyleGAN2Generator (line 906) | class StyleGAN2Generator(nn.Module):
    method __init__ (line 907) | def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_b...
    method forward (line 913) | def forward(self, input, layers=[], encode_only=False):

FILE: modules/util.py
  class InstanceNorm (line 26) | class InstanceNorm(nn.Module):
    method __init__ (line 27) | def __init__(self, epsilon=1e-8):
    method forward (line 35) | def forward(self, x):
  class ApplyStyle (line 41) | class ApplyStyle(nn.Module):
    method __init__ (line 45) | def __init__(self, latent_size, channels, use_wscale):
    method forward (line 52) | def forward(self, x, latent):
  class FC (line 60) | class FC(nn.Module):
    method __init__ (line 61) | def __init__(self,
    method forward (line 87) | def forward(self, x):
  class Embedder (line 97) | class Embedder:
    method __init__ (line 98) | def __init__(self, **kwargs):
    method create_embedding_fn (line 102) | def create_embedding_fn(self):
    method embed (line 126) | def embed(self, inputs):
  function get_embedder (line 130) | def get_embedder(multires, i=0):
  function draw_heatmap (line 148) | def draw_heatmap(landmark, width, height):
  class NA_net (line 175) | class NA_net(nn.Module):
    method __init__ (line 176) | def __init__(self):
    method forward (line 195) | def forward(self, neutral):
  class AT_net (line 203) | class AT_net(nn.Module):
    method __init__ (line 204) | def __init__(self):
    method forward (line 270) | def forward(self, example_image, audio, pose, jaco_net):
  class Classify (line 306) | class Classify(nn.Module):
    method __init__ (line 307) | def __init__(self):
    method forward (line 314) | def forward(self, feature):
  class TF_net (line 321) | class TF_net(nn.Module):
    method __init__ (line 322) | def __init__(self):
    method adain_forward (line 391) | def adain_forward(self, example_image, audio, pose, jaco_net, emo_feat...
    method adain_feature2 (line 434) | def adain_feature2(self, example_image, audio, pose, jaco_net, emo_fea...
    method forward (line 477) | def forward(self, example_image, audio, pose, jaco_net, emo_features):
  class AT_net2 (line 514) | class AT_net2(nn.Module):
    method __init__ (line 515) | def __init__(self):
    method forward (line 580) | def forward(self, example_image, audio, pose, jaco_net, weight):
  class Ct_encoder (line 618) | class Ct_encoder(nn.Module):
    method __init__ (line 619) | def __init__(self):
    method forward (line 638) | def forward(self, audio):
  class EmotionNet (line 647) | class EmotionNet(nn.Module):
    method __init__ (line 648) | def __init__(self):
    method forward (line 697) | def forward(self, mfcc):
  class AF2F (line 715) | class AF2F(nn.Module):
    method __init__ (line 716) | def __init__(self):
    method forward (line 736) | def forward(self, content,emotion):
  class AF2F_s (line 745) | class AF2F_s(nn.Module):
    method __init__ (line 746) | def __init__(self):
    method forward (line 766) | def forward(self, content):
  class A2I (line 776) | class A2I(nn.Module):
    method __init__ (line 777) | def __init__(self):
    method forward (line 804) | def forward(self, mfcc):
  function kp2gaussian (line 815) | def kp2gaussian(kp, spatial_size, kp_variance):
  function make_coordinate_grid (line 839) | def make_coordinate_grid(spatial_size, type):
  class ResBlock2d (line 858) | class ResBlock2d(nn.Module):
    method __init__ (line 863) | def __init__(self, in_features, kernel_size, padding):
    method forward (line 872) | def forward(self, x):
  class UpBlock2d (line 883) | class UpBlock2d(nn.Module):
    method __init__ (line 888) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
    method forward (line 895) | def forward(self, x):
  class DownBlock2d (line 903) | class DownBlock2d(nn.Module):
    method __init__ (line 908) | def __init__(self, in_features, out_features, kernel_size=3, padding=1...
    method forward (line 915) | def forward(self, x):
  class SameBlock2d (line 923) | class SameBlock2d(nn.Module):
    method __init__ (line 928) | def __init__(self, in_features, out_features, groups=1, kernel_size=3,...
    method forward (line 934) | def forward(self, x):
  class Encoder (line 941) | class Encoder(nn.Module):
    method __init__ (line 946) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
    method forward (line 956) | def forward(self, x):
  class Decoder (line 963) | class Decoder(nn.Module):
    method __init__ (line 968) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
    method forward (line 981) | def forward(self, x):
  class Hourglass (line 990) | class Hourglass(nn.Module):
    method __init__ (line 995) | def __init__(self, block_expansion, in_features, num_blocks=3, max_fea...
    method forward (line 1001) | def forward(self, x):
  class AntiAliasInterpolation2d (line 1005) | class AntiAliasInterpolation2d(nn.Module):
    method __init__ (line 1009) | def __init__(self, channels, scale):
    method forward (line 1044) | def forward(self, input):
  function sigmoid (line 1054) | def sigmoid(x):
  function norm_angle (line 1058) | def norm_angle(angle):
  function conv3x3 (line 1063) | def conv3x3(in_planes, out_planes, stride=1):
  class BasicBlock (line 1069) | class BasicBlock(nn.Module):
    method __init__ (line 1072) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 1082) | def forward(self, x):
  class Bottleneck (line 1101) | class Bottleneck(nn.Module):
    method __init__ (line 1104) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 1117) | def forward(self, x):
  class EmDetector (line 1139) | class EmDetector(nn.Module):
    method __init__ (line 1144) | def __init__(self, block_expansion,  num_channels, max_features,
    method _make_layer (line 1170) | def _make_layer(self, block, planes, blocks, stride=1):
    method adain_feature (line 1187) | def adain_feature(self, x): #torch.Size([4, 3, H, W])
    method forward (line 1197) | def forward(self, x): #torch.Size([4, 3, H, W])
  class Emotion_k (line 1223) | class Emotion_k(nn.Module):
    method __init__ (line 1228) | def __init__(self, block_expansion,  num_channels, max_features,
    method _make_layer (line 1316) | def _make_layer(self, block, planes, blocks, stride=1):
    method linear_10 (line 1333) | def linear_10(self, x, value, jacobian): #torch.Size([4, 3, H, W])
    method linear_4 (line 1364) | def linear_4(self, x, value, jacobian): #torch.Size([4, 3, H, W])
    method linear_np_10 (line 1396) | def linear_np_10(self, x, value, jacobian): #torch.Size([4, 3, H, W])
    method linear_np_4 (line 1427) | def linear_np_4(self, x, value, jacobian): #torch.Size([4, 3, H, W])
    method emotion_feature (line 1459) | def emotion_feature(self, feature, value, jacobian): #torch.Size([4, 3...
    method feature (line 1477) | def feature(self, x): #torch.Size([4, 3, H, W])
    method forward (line 1498) | def forward(self, x, value, jacobian): #torch.Size([4, 3, H, W])
  class Emotion_map (line 1529) | class Emotion_map(nn.Module):
    method __init__ (line 1534) | def __init__(self, block_expansion,  num_channels, max_features,
    method _make_layer (line 1607) | def _make_layer(self, block, planes, blocks, stride=1):
    method gaussian2kp (line 1624) | def gaussian2kp(self, heatmap):
    method map_4 (line 1636) | def map_4(self, x, value, jacobian): #torch.Size([4, 3, H, W])
    method forward (line 1687) | def forward(self, x, value, jacobian): #torch.Size([4, 3, H, W])
  function conv2d (line 1740) | def conv2d(channel_in, channel_out,
  function _apply (line 1755) | def _apply(layer, activation, normalizer, channel_out=None):
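
Several of the helpers indexed above (make_coordinate_grid, kp2gaussian, Hourglass, AntiAliasInterpolation2d) follow the First Order Motion Model convention of keypoints in normalized [-1, 1] coordinates rendered as Gaussian heatmaps. Below is a hedged sketch of that idea; the tensor shapes follow the usual FOMM layout, but the exact signatures in modules/util.py (for example, whether the keypoints arrive as a dict) may differ.

# Hedged sketch of the keypoint-to-heatmap idea behind kp2gaussian /
# make_coordinate_grid. Shapes follow the FOMM convention; the signatures in
# modules/util.py may differ.
import torch


def make_coordinate_grid_sketch(spatial_size, dtype=torch.float32):
    """Meshgrid of (x, y) coordinates normalized to [-1, 1]."""
    h, w = spatial_size
    y = torch.linspace(-1, 1, h, dtype=dtype)
    x = torch.linspace(-1, 1, w, dtype=dtype)
    yy, xx = torch.meshgrid(y, x, indexing='ij')
    return torch.stack([xx, yy], dim=-1)           # (H, W, 2)


def kp2gaussian_sketch(kp_value, spatial_size, kp_variance=0.01):
    """kp_value: (B, K, 2) normalized keypoints -> (B, K, H, W) heatmaps."""
    grid = make_coordinate_grid_sketch(spatial_size, kp_value.dtype)
    grid = grid.view(1, 1, *spatial_size, 2)        # (1, 1, H, W, 2)
    mean = kp_value.reshape(*kp_value.shape[:2], 1, 1, 2)
    dist = ((grid - mean) ** 2).sum(-1)             # squared distance to each keypoint
    return torch.exp(-0.5 * dist / kp_variance)


if __name__ == '__main__':
    kp = torch.zeros(2, 10, 2)                      # 10 keypoints at the image center
    print(kp2gaussian_sketch(kp, (64, 64)).shape)   # torch.Size([2, 10, 64, 64])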

FILE: ops.py
  class ResidualBlock (line 8) | class ResidualBlock(nn.Module):
    method __init__ (line 9) | def __init__(self, channel_in, channel_out):
    method forward (line 19) | def forward(self, x):
  function linear (line 27) | def linear(channel_in, channel_out,
  function conv2d (line 40) | def conv2d(channel_in, channel_out,
  function conv_transpose2d (line 56) | def conv_transpose2d(channel_in, channel_out,
  function nn_conv2d (line 72) | def nn_conv2d(channel_in, channel_out,
  function _apply (line 90) | def _apply(layer, activation, normalizer, channel_out=None):
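
ops.py is a set of small layer factories (linear, conv2d, conv_transpose2d, nn_conv2d) that all funnel through _apply, which optionally appends a normalizer and an activation to the raw layer. A hedged sketch of that factory pattern follows; argument names and defaults here are assumptions, not the file's exact signatures.

# Hedged sketch of the layer-factory pattern suggested by ops.py:
# build a raw layer, then let _apply wrap it with normalization/activation.
import torch
import torch.nn as nn


def _apply_sketch(layer, activation=None, normalizer=None, channel_out=None):
    mods = [layer]
    if normalizer is not None:
        mods.append(normalizer(channel_out))        # e.g. nn.BatchNorm2d
    if activation is not None:
        mods.append(activation())                   # e.g. nn.ReLU
    return nn.Sequential(*mods)


def conv2d_sketch(channel_in, channel_out, ksize=3, stride=1, padding=1,
                  activation=nn.ReLU, normalizer=nn.BatchNorm2d):
    layer = nn.Conv2d(channel_in, channel_out, ksize, stride, padding)
    return _apply_sketch(layer, activation, normalizer, channel_out)


if __name__ == '__main__':
    block = conv2d_sketch(3, 16)
    print(block(torch.randn(1, 3, 32, 32)).shape)   # torch.Size([1, 16, 32, 32])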

FILE: process_data.py
  function save (line 29) | def save(path, frames, format):
  function crop_image (line 44) | def crop_image(image_path, out_path):
  function shape_to_np (line 70) | def shape_to_np(shape, dtype="int"):
  function crop_image_tem (line 85) | def crop_image_tem(video_path, out_path):
  function proc_audio (line 124) | def proc_audio(src_mouth_path, dst_audio_path):
  function audio2mfcc (line 130) | def audio2mfcc(audio_file, save, name):
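
process_data.py covers preprocessing: cropping face frames, pulling audio from video (proc_audio), and converting audio to MFCC features (audio2mfcc), with librosa and python_speech_features listed in requirements.txt. A hedged sketch of the audio-to-MFCC step is below; the 16 kHz sample rate, 13 cepstral coefficients, and output layout are assumptions rather than the repository's exact settings.

# Hedged sketch of an audio2mfcc-style helper: load a wav and save MFCC
# features as .npy. The 16 kHz rate and 13 coefficients are assumptions.
import numpy as np
import librosa
from python_speech_features import mfcc


def audio2mfcc_sketch(audio_file, out_path, sample_rate=16000, num_cep=13):
    signal, sr = librosa.load(audio_file, sr=sample_rate)   # resample to 16 kHz
    feats = mfcc(signal, samplerate=sr, numcep=num_cep)     # (num_frames, 13)
    np.save(out_path, feats.astype(np.float32))
    return feats.shape


# Example (paths are placeholders):
# audio2mfcc_sketch('clip.wav', 'clip_mfcc.npy')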

FILE: sync_batchnorm/batchnorm.py
  function _sum_ft (line 24) | def _sum_ft(tensor):
  function _unsqueeze_ft (line 29) | def _unsqueeze_ft(tensor):
  class _SynchronizedBatchNorm (line 38) | class _SynchronizedBatchNorm(_BatchNorm):
    method __init__ (line 39) | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
    method forward (line 48) | def forward(self, input):
    method __data_parallel_replicate__ (line 80) | def __data_parallel_replicate__(self, ctx, copy_id):
    method _data_parallel_master (line 90) | def _data_parallel_master(self, intermediates):
    method _compute_mean_std (line 113) | def _compute_mean_std(self, sum_, ssum, size):
  class SynchronizedBatchNorm1d (line 128) | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    method _check_input_dim (line 184) | def _check_input_dim(self, input):
  class SynchronizedBatchNorm2d (line 191) | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    method _check_input_dim (line 247) | def _check_input_dim(self, input):
  class SynchronizedBatchNorm3d (line 254) | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    method _check_input_dim (line 311) | def _check_input_dim(self, input):
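
_SynchronizedBatchNorm gathers each device's element count, sum, and sum of squares, and _compute_mean_std reduces them into one global mean and inverse standard deviation, so normalization matches what a single large batch would produce. A small sketch of that reduction (an illustration of the statistics only, not the module's code):

# Hedged sketch of combining per-device sums into global batch statistics,
# the reduction performed by _compute_mean_std in sync_batchnorm/batchnorm.py.
import torch


def combine_stats_sketch(sums, ssums, counts, eps=1e-5):
    """sums/ssums: per-device sum and sum of squares; counts: element counts."""
    total = float(sum(counts))
    mean = sum(sums) / total
    var = sum(ssums) / total - mean ** 2           # E[x^2] - E[x]^2
    inv_std = 1.0 / torch.sqrt(var + eps)
    return mean, inv_std


if __name__ == '__main__':
    x1, x2 = torch.randn(8, 16), torch.randn(8, 16)   # two "devices"
    mean, inv_std = combine_stats_sketch(
        [x1.sum(0), x2.sum(0)], [(x1 ** 2).sum(0), (x2 ** 2).sum(0)],
        [x1.shape[0], x2.shape[0]])
    print(torch.allclose(mean, torch.cat([x1, x2]).mean(0), atol=1e-5))  # True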

FILE: sync_batchnorm/comm.py
  class FutureResult (line 18) | class FutureResult(object):
    method __init__ (line 21) | def __init__(self):
    method put (line 26) | def put(self, result):
    method get (line 32) | def get(self):
  class SlavePipe (line 46) | class SlavePipe(_SlavePipeBase):
    method run_slave (line 49) | def run_slave(self, msg):
  class SyncMaster (line 56) | class SyncMaster(object):
    method __init__ (line 67) | def __init__(self, master_callback):
    method __getstate__ (line 78) | def __getstate__(self):
    method __setstate__ (line 81) | def __setstate__(self, state):
    method register_slave (line 84) | def register_slave(self, identifier):
    method run_master (line 102) | def run_master(self, master_msg):
    method nr_slaves (line 136) | def nr_slaves(self):
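
comm.py implements the master/slave handshake used by the synchronized BatchNorm: each replica registers through register_slave, receives a SlavePipe, and blocks on a FutureResult until run_master on the master copy has reduced every replica's message. A hedged sketch of the FutureResult primitive, a one-shot value guarded by a condition variable (not the file's exact code):

# Hedged sketch of a one-shot FutureResult: one thread puts a value,
# another thread blocks in get() until it arrives.
import threading


class FutureResultSketch:
    def __init__(self):
        self._result = None
        self._cond = threading.Condition(threading.Lock())

    def put(self, result):
        with self._cond:
            assert self._result is None, 'future can only be set once'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._cond:
            while self._result is None:
                self._cond.wait()
            res, self._result = self._result, None   # consume the value
            return res


if __name__ == '__main__':
    fut = FutureResultSketch()
    threading.Timer(0.1, fut.put, args=(42,)).start()
    print(fut.get())   # 42, printed once the timer thread delivers the result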

FILE: sync_batchnorm/replicate.py
  class CallbackContext (line 23) | class CallbackContext(object):
  function execute_replication_callbacks (line 27) | def execute_replication_callbacks(modules):
  class DataParallelWithCallback (line 50) | class DataParallelWithCallback(DataParallel):
    method replicate (line 64) | def replicate(self, module, device_ids):
  function patch_replication_callback (line 70) | def patch_replication_callback(data_parallel):
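
DataParallelWithCallback re-invokes __data_parallel_replicate__ on every replica after nn.DataParallel copies the model, which is what lets the synchronized BatchNorm replicas find their SyncMaster; patch_replication_callback applies the same fix to an existing DataParallel instance. A hedged usage sketch follows; it assumes the package __init__ re-exports these names (the standard layout for this sync_batchnorm implementation) and that at least two GPUs are available.

# Hedged usage sketch: wrap a model that uses SynchronizedBatchNorm2d with
# DataParallelWithCallback so replication callbacks fire on every replica.
# Assumes sync_batchnorm/__init__.py re-exports these names; otherwise import
# from sync_batchnorm.batchnorm and sync_batchnorm.replicate directly.
import torch
import torch.nn as nn
from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    SynchronizedBatchNorm2d(16),       # batch stats are synchronized across GPUs
    nn.ReLU(),
)

if torch.cuda.device_count() >= 2:     # synchronization only matters on >1 GPU
    model = DataParallelWithCallback(model, device_ids=[0, 1]).cuda()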

FILE: sync_batchnorm/unittest.py
  function as_numpy (line 17) | def as_numpy(v):
  class TorchTestCase (line 23) | class TorchTestCase(unittest.TestCase):
    method assertTensorClose (line 24) | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
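
The test helper converts tensors to NumPy (as_numpy) and compares them with absolute and relative tolerances (assertTensorClose). A hedged equivalent using numpy.allclose:

# Hedged sketch of an assertTensorClose-style helper: compare two tensors
# after converting them to NumPy, with absolute/relative tolerances.
import numpy as np
import unittest


class TorchTestCaseSketch(unittest.TestCase):
    def assert_tensor_close(self, a, b, atol=1e-3, rtol=1e-3):
        a, b = a.detach().cpu().numpy(), b.detach().cpu().numpy()
        self.assertTrue(np.allclose(a, b, atol=atol, rtol=rtol),
                        msg=f'max abs diff: {np.abs(a - b).max()}')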

FILE: train.py
  function train_part1 (line 18) | def train_part1(config, generator, discriminator, kp_detector, kp_detect...
  function train_part1_fine_tune (line 133) | def train_part1_fine_tune(config, generator, discriminator, kp_detector,...
  function train_part2 (line 273) | def train_part2(config, generator, discriminator, kp_detector, emo_detec...
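
train.py exposes the three training entry points matching the similarly named configs under config/: train_part1, train_part1_fine_tune, and train_part2, each taking the generator, discriminator, and keypoint/emotion detectors listed above. Below is a hedged skeleton of the kind of epoch loop such a function typically runs; every name in it is an illustrative placeholder, not the repository's actual structure.

# Hedged skeleton of a train_part*-style loop: iterate epochs, step the
# optimizer, log losses. All names here are illustrative placeholders.
import torch
from torch.utils.data import DataLoader


def train_sketch(config, full_model, dataset, device='cuda'):
    params = config['train_params']                       # assumed keys
    loader = DataLoader(dataset, batch_size=params['batch_size'],
                        shuffle=True, num_workers=4, drop_last=True)
    optimizer = torch.optim.Adam(full_model.parameters(), lr=params['lr'])

    for epoch in range(params['num_epochs']):
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()
                     if torch.is_tensor(v)}
            losses = full_model(batch)                    # dict of named losses
            loss = sum(v.mean() for v in losses.values())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f'epoch {epoch}: loss {loss.item():.4f}')
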
Condensed preview — 53 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (366K chars).
[
  {
    "path": "3DDFA_V2/demo.py",
    "chars": 8816,
    "preview": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport sys\nimport argparse\nimport cv2\nimport yaml\nimport os\nimport time\nfrom "
  },
  {
    "path": "3DDFA_V2/utils/pose.py",
    "chars": 8314,
    "preview": "# coding: utf-8\n\n\"\"\"\nReference: https://github.com/YadiraF/PRNet/blob/master/utils/estimate_pose.py\n\nCalculating pose fr"
  },
  {
    "path": "LICENSE",
    "chars": 1064,
    "preview": "MIT License\n\nCopyright (c) 2022 jixinya\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof"
  },
  {
    "path": "README.md",
    "chars": 3250,
    "preview": "# EAMM:  One-Shot Emotional Talking Face via Audio-Based Emotion-Aware Motion Model [SIGGRAPH 2022 Conference]\r\n\r\nXinya "
  },
  {
    "path": "augmentation.py",
    "chars": 16022,
    "preview": "\"\"\"\nCode from https://github.com/hassony2/torch_videovision\n\"\"\"\n\nimport numbers\nimport math\nimport random\nimport numpy a"
  },
  {
    "path": "config/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml",
    "chars": 2210,
    "preview": "dataset_params:\n  root_dir: /mnt/lustre/share_data/jixinya/MEAD/\n  frame_shape: [256, 256, 3]\n  id_sampling: False\n  pai"
  },
  {
    "path": "config/train_part1.yaml",
    "chars": 1684,
    "preview": "dataset_params:\n  name: Vox\n  root_dir: dataset/LRW/\n  frame_shape: [256, 256, 3]\n  id_sampling: False\n  augmentation_pa"
  },
  {
    "path": "config/train_part1_fine_tune.yaml",
    "chars": 1689,
    "preview": "dataset_params:\n  name: LRW\n  root_dir: dataset/LRW/\n  frame_shape: [256, 256, 3]\n  id_sampling: False\n  augmentation_pa"
  },
  {
    "path": "config/train_part2.yaml",
    "chars": 1919,
    "preview": "dataset_params:\n  name: MEAD\n  root_dir: dataset/MEAD/\n  frame_shape: [256, 256, 3]\n  id_sampling: False\n  augmentation_"
  },
  {
    "path": "demo.py",
    "chars": 22742,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Wed Oct  6 20:57:27 2021\n@author: thea\n\"\"\"\n\nimport matplot"
  },
  {
    "path": "filter1.py",
    "chars": 1201,
    "preview": "import cv2\n#import pickle\nimport time\nimport numpy as np\nimport copy\n\nfrom matplotlib import pyplot as plt\nfrom tqdm imp"
  },
  {
    "path": "frames_dataset.py",
    "chars": 20138,
    "preview": "import os\nfrom skimage import io, img_as_float32, transform\nfrom skimage.color import gray2rgb\nfrom sklearn.model_select"
  },
  {
    "path": "logger.py",
    "chars": 9203,
    "preview": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nimport imageio\n\nimport os\nfrom skimage.draw import circl"
  },
  {
    "path": "modules/dense_motion.py",
    "chars": 5189,
    "preview": "from torch import nn\nimport torch.nn.functional as F\nimport torch\nfrom modules.util import Hourglass, AntiAliasInterpola"
  },
  {
    "path": "modules/discriminator.py",
    "chars": 3156,
    "preview": "from torch import nn\nimport torch.nn.functional as F\nfrom modules.util import kp2gaussian\nimport torch\n\n\nclass DownBlock"
  },
  {
    "path": "modules/function.py",
    "chars": 2525,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Thu Sep 30 17:45:24 2021\n\n@author: SENSETIME\\jixinya1\n\"\"\"\n"
  },
  {
    "path": "modules/generator.py",
    "chars": 4627,
    "preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom modules.util import ResBlock2d, SameBlock2d, UpBl"
  },
  {
    "path": "modules/keypoint_detector.py",
    "chars": 10448,
    "preview": "from torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom modules.util import Hourglass, make_coordinate_gr"
  },
  {
    "path": "modules/model.py",
    "chars": 28798,
    "preview": "from torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom modules.util import AntiAliasInterpolation2d, mak"
  },
  {
    "path": "modules/model_delta_map.py",
    "chars": 25745,
    "preview": "from torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom modules.util import AntiAliasInterpolation2d, mak"
  },
  {
    "path": "modules/model_gen.py",
    "chars": 25806,
    "preview": "from torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom modules.util import AntiAliasInterpolation2d, mak"
  },
  {
    "path": "modules/ops.py",
    "chars": 2305,
    "preview": "import torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom torch.autograd import Variable\n\n"
  },
  {
    "path": "modules/stylegan2.py",
    "chars": 28081,
    "preview": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Thu Jul  8 01:03:50 2021\n\n@author: thea\n\"\"\"\n\n\"\"\"\nThe netwo"
  },
  {
    "path": "modules/util.py",
    "chars": 63282,
    "preview": "from torch import nn\n\nimport torch.nn.functional as F\nimport torch\nimport numpy as np\nimport cv2\nfrom sync_batchnorm imp"
  },
  {
    "path": "ops.py",
    "chars": 2797,
    "preview": "import torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom torch.autograd import Variable\n\n"
  },
  {
    "path": "process_data.py",
    "chars": 5768,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Thu Jun 24 11:36:01 2021\n\n@author: Xinya\n\"\"\"\n\nimport os\nimport glob\nimport time\ni"
  },
  {
    "path": "requirements.txt",
    "chars": 167,
    "preview": "torch==1.10.1\ntorchvision==0.11.2\nnumpy\nlibrosa\nopencv-python\npython_speech_features\npickle-mixin\nmatplotlib\nscikit-imag"
  },
  {
    "path": "run.py",
    "chars": 5496,
    "preview": "import matplotlib\n\nmatplotlib.use('Agg')\n\nimport os, sys\nimport yaml\nfrom argparse import ArgumentParser\nfrom time impor"
  },
  {
    "path": "sync_batchnorm/__init__.py",
    "chars": 449,
    "preview": "# -*- coding: utf-8 -*-\n# File   : __init__.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2"
  },
  {
    "path": "sync_batchnorm/batchnorm.py",
    "chars": 12973,
    "preview": "# -*- coding: utf-8 -*-\n# File   : batchnorm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/"
  },
  {
    "path": "sync_batchnorm/comm.py",
    "chars": 4449,
    "preview": "# -*- coding: utf-8 -*-\n# File   : comm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n"
  },
  {
    "path": "sync_batchnorm/replicate.py",
    "chars": 3226,
    "preview": "# -*- coding: utf-8 -*-\n# File   : replicate.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/"
  },
  {
    "path": "sync_batchnorm/unittest.py",
    "chars": 835,
    "preview": "# -*- coding: utf-8 -*-\n# File   : unittest.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2"
  },
  {
    "path": "train.py",
    "chars": 19729,
    "preview": "from tqdm import trange\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader\n\nfrom logger import L"
  }
]

// ... and 19 more files (download for full content)
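
The condensed preview above is a JSON array whose entries carry path, chars, and preview fields. Once the full .json is downloaded, a few lines of Python are enough to index it, for example to list the largest files; the filename below is a placeholder.

# Hedged sketch: load the downloaded condensed-preview JSON and sort files
# by size. 'eamm_preview.json' is a placeholder filename.
import json

with open('eamm_preview.json') as f:
    files = json.load(f)                 # list of {"path", "chars", "preview"}

for entry in sorted(files, key=lambda e: e['chars'], reverse=True)[:5]:
    print(f"{entry['chars']:>7}  {entry['path']}")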

About this extraction

This page contains the full source code of the jixinya/EAMM GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 53 files (345.8 KB), approximately 92.0k tokens, and a symbol index with 451 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract, a free GitHub repo-to-text converter for AI. Built by Nikandr Surkov.