[
  {
    "path": "3DDFA_V2/demo.py",
    "content": "# coding: utf-8\n\n__author__ = 'cleardusk'\n\nimport sys\nimport argparse\nimport cv2\nimport yaml\nimport os\nimport time\nfrom FaceBoxes import FaceBoxes\nfrom TDDFA import TDDFA\nfrom utils.render import render\n#from utils.render_ctypes import render  # faster\nfrom utils.depth import depth\nfrom utils.pncc import pncc\nfrom utils.uv import uv_tex\nfrom utils.pose import viz_pose, get_pose\nfrom utils.serialization import ser_to_ply, ser_to_obj\nfrom utils.functions import draw_landmarks, get_suffix\nfrom utils.tddfa_util import str2bool\nimport numpy as np\nfrom tqdm import tqdm\nimport copy\n\nimport concurrent.futures\nfrom multiprocessing import Pool\n\ndef main(args,img, save_path, pose_path):\n #   begin = time.time()\n    cfg = yaml.load(open(args.config), Loader=yaml.SafeLoader)\n\n    # Init FaceBoxes and TDDFA, recommend using onnx flag\n    if args.onnx:\n        import os\n        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'\n        os.environ['OMP_NUM_THREADS'] = '4'\n\n        from FaceBoxes.FaceBoxes_ONNX import FaceBoxes_ONNX\n        from TDDFA_ONNX import TDDFA_ONNX\n\n        face_boxes = FaceBoxes_ONNX()\n        tddfa = TDDFA_ONNX(**cfg)\n    else:\n        gpu_mode = args.mode == 'gpu'\n        tddfa = TDDFA(gpu_mode=gpu_mode, **cfg)\n        face_boxes = FaceBoxes()\n\n    # Given a still image path and load to BGR channel\n  #  img = cv2.imread(img_path) #args.img_fp\n\n    # Detect faces, get 3DMM params and roi boxes\n    boxes = face_boxes(img)\n    n = len(boxes)\n    if n == 0:\n        print(f'No face detected, exit')\n      #  sys.exit(-1)\n        return None\n    print(f'Detect {n} faces')\n\n    param_lst, roi_box_lst = tddfa(img, boxes)\n    #detection time\n  #  detect_time = time.time()-begin\n #   print('detection time: '+str(detect_time), file=open('/mnt/lustre/jixinya/Home/3DDFA_V2/pose.txt', 'a'))\n    # Visualization and serialization\n    dense_flag = args.opt in ('2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'ply', 'obj')\n  #  old_suffix = get_suffix(img_path)\n    old_suffix = 'png'\n    new_suffix = f'.{args.opt}' if args.opt in ('ply', 'obj') else '.jpg'\n\n    wfp = f'examples/results/{args.img_fp.split(\"/\")[-1].replace(old_suffix, \"\")}_{args.opt}' + new_suffix\n\n    ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, dense_flag=dense_flag)\n\n    if args.opt == '2d_sparse':\n        draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp)\n    elif args.opt == '2d_dense':\n        draw_landmarks(img, ver_lst, show_flag=args.show_flag, dense_flag=dense_flag, wfp=wfp)\n    elif args.opt == '3d':\n        render(img, ver_lst, tddfa.tri, alpha=0.6, show_flag=args.show_flag, wfp=wfp)\n    elif args.opt == 'depth':\n\n        # if `with_bf_flag` is False, the background is black\n        depth(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True)\n    elif args.opt == 'pncc':\n        pncc(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp, with_bg_flag=True)\n    elif args.opt == 'uv_tex':\n        uv_tex(img, ver_lst, tddfa.tri, show_flag=args.show_flag, wfp=wfp)\n    elif args.opt == 'pose':\n        all_pose = get_pose(img, param_lst, ver_lst, show_flag=args.show_flag, wfp=save_path, wnp = pose_path)\n    elif args.opt == 'ply':\n        ser_to_ply(ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp)\n    elif args.opt == 'obj':\n        ser_to_obj(img, ver_lst, tddfa.tri, height=img.shape[0], wfp=wfp)\n    else:\n        raise ValueError(f'Unknown opt 
{args.opt}')\n\n    return all_pose\n\n\n\ndef process_word(i):\n    path = '/media/xinya/Backup Plus/sense_shixi_data/new_crop/MEAD_fomm_video_6/'\n    save = '/media/xinya/Backup Plus/sense_shixi_data/new_crop/MEAD_fomm_pose_im/'\n    pose = '/media/xinya/Backup Plus/sense_shixi_data/new_crop/MEAD_fomm_pose/'\n    start = time.time()\n    Dir = os.listdir(path)\n    Dir.sort()\n    word = Dir[i]\n    wpath = os.path.join(path, word)\n    print(wpath)\n    pathDir = os.listdir(wpath)\n    pose_file = os.path.join(pose,word)\n    if not os.path.exists(pose_file):\n        os.makedirs(pose_file)\n\n    for j in range(len(pathDir)):\n        name = pathDir[j]\n     #   save_file = os.path.join(save,word,name)\n     #   if not os.path.exists(save_file):\n     #       os.makedirs(save_file)\n        fpath = os.path.join(wpath,name)\n        image_all = []\n        videoCapture = cv2.VideoCapture(fpath)\n\n        success, frame = videoCapture.read()\n\n        n = 0\n        while success :\n            image_all.append(frame)\n            n = n + 1\n            success, frame = videoCapture.read()\n\n        pose_all = np.zeros((len(image_all),7))\n        for k in range(len(image_all)):\n            pose_all[k] = main(args,image_all[k], None, None)\n        np.save(os.path.join(pose,word,name.split('.')[0]+'.npy'),pose_all)\n        st = time.time()-start\n        print(str(i)+' '+word+' '+str(j)+' '+name+' '+str(k)+'time: '+str(st), file=open('/media/thea/Backup Plus/sense_shixi_data/new_crop/pose_mead6.txt', 'a'))\n        print(i,word,j,name,k)\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='The demo of still image of 3DDFA_V2')\n    parser.add_argument('-c', '--config', type=str, default='configs/mb1_120x120.yml')\n    parser.add_argument('-f', '--img_fp', type=str, default='examples/inputs/0.png')\n    parser.add_argument('-m', '--mode', type=str, default='cpu', help='gpu or cpu mode')\n    parser.add_argument('-o', '--opt', type=str, default='pose',\n                        choices=['2d_sparse', '2d_dense', '3d', 'depth', 'pncc', 'uv_tex', 'pose', 'ply', 'obj'])\n    parser.add_argument('--show_flag', type=str2bool, default='False', help='whether to show the visualization result')\n    parser.add_argument('--onnx', action='store_true', default=False)\n\n    args = parser.parse_args()\n\n    filepath = 'test/image/'\n    pathDir = os.listdir(filepath)\n    for i in range(len(pathDir)):\n        image = cv2.imread(os.path.join(filepath,pathDir[i]))\n        pose = main(args,image, None, None).reshape(1,7)\n\n        np.save('test/pose/'+pathDir[i].split('.')[0]+'.npy',pose)\n        print(i,pathDir[i])\n"
  },
  {
    "path": "3DDFA_V2/utils/pose.py",
    "content": "# coding: utf-8\n\n\"\"\"\nReference: https://github.com/YadiraF/PRNet/blob/master/utils/estimate_pose.py\n\nCalculating pose from the output 3DMM parameters, you can also try to use solvePnP to perform estimation\n\"\"\"\n\n__author__ = 'cleardusk'\n\nimport cv2\nimport numpy as np\nfrom math import cos, sin, atan2, asin, sqrt\n\nfrom .functions import calc_hypotenuse, plot_image\n\n\ndef P2sRt(P):\n    \"\"\" decompositing camera matrix P.\n    Args:\n        P: (3, 4). Affine Camera Matrix.\n    Returns:\n        s: scale factor.\n        R: (3, 3). rotation matrix.\n        t2d: (2,). 2d translation.\n    \"\"\"\n    t3d = P[:, 3]\n    R1 = P[0:1, :3]\n    R2 = P[1:2, :3]\n    s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0\n    r1 = R1 / np.linalg.norm(R1)\n    r2 = R2 / np.linalg.norm(R2)\n    r3 = np.cross(r1, r2)\n\n    R = np.concatenate((r1, r2, r3), 0)\n    return s, R, t3d\n\n\ndef matrix2angle(R):\n    \"\"\" compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf\n    refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv\n    todo: check and debug\n     Args:\n         R: (3,3). rotation matrix\n     Returns:\n         x: yaw\n         y: pitch\n         z: roll\n     \"\"\"\n    if R[2, 0] > 0.998:\n        z = 0\n        x = np.pi / 2\n        y = z + atan2(-R[0, 1], -R[0, 2])\n    elif R[2, 0] < -0.998:\n        z = 0\n        x = -np.pi / 2\n        y = -z + atan2(R[0, 1], R[0, 2])\n    else:\n        x = asin(R[2, 0])\n        y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x))\n        z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x))\n\n    return x, y, z\n\ndef angle2matrix(theta):\n    \"\"\" compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf\n    refined by: https://stackoverflow.com/questions/43364900/rotation-matrix-to-euler-angles-with-opencv\n    todo: check and debug\n     Args:\n         R: (3,3). rotation matrix\n     Returns:\n         x: yaw\n         y: pitch\n         z: roll\n     \"\"\"\n    R_x = np.array([[1,         0,                  0         ],\n\n                    [0,         cos(theta[1]), -sin(theta[1]) ],\n\n                    [0,         sin(theta[1]), cos(theta[1])  ]\n\n                    ])\n\n \n\n    R_y = np.array([[cos(theta[0]),    0,      sin(-theta[0])  ],\n\n                    [0,                     1,      0         ],\n\n                    [-sin(-theta[0]),   0,      cos(theta[0])  ]\n\n                    ])\n\n \n\n    R_z = np.array([[cos(theta[2]),    -sin(theta[2]),    0],\n\n                    [sin(theta[2]),    cos(theta[2]),     0],\n\n                    [0,                     0,            1]\n\n                    ])\n\n \n\n    R = np.dot(R_z, np.dot( R_y, R_x ))\n\n \n\n    return R\n\ndef angle2matrix_3ddfa(angles):\n    ''' get rotation matrix from three rotation angles(radian). The same as in 3DDFA.\n    Args:\n        angles: [3,]. x, y, z angles\n        x: pitch.\n        y: yaw. \n        z: roll. \n    Returns:\n        R: 3x3. 
rotation matrix.\n    '''\n    # x, y, z = np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2])\n    x, y, z = angles[1], angles[0], angles[2]\n    \n    # x\n    Rx=np.array([[1,      0,       0],\n                 [0, cos(x),  sin(x)],\n                 [0, -sin(x),   cos(x)]])\n    # y\n    Ry=np.array([[ cos(y), 0, -sin(y)],\n                 [      0, 1,      0],\n                 [sin(y), 0, cos(y)]])\n    # z\n    Rz=np.array([[cos(z), sin(z), 0],\n                 [-sin(z),  cos(z), 0],\n                 [     0,       0, 1]])\n    R = Rx.dot(Ry).dot(Rz)\n    return R.astype(np.float32)\n\ndef calc_pose(param):\n    P = param[:12].reshape(3, -1)  # camera matrix\n    s, R, t3d = P2sRt(P)\n    P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)  # without scale\n    pose = matrix2angle(R)\n    pose = [p * 180 / np.pi for p in pose]\n\n    return P, pose\n\n\ndef build_camera_box(rear_size=90):\n    point_3d = []\n    rear_depth = 0\n    point_3d.append((-rear_size, -rear_size, rear_depth))\n    point_3d.append((-rear_size, rear_size, rear_depth))\n    point_3d.append((rear_size, rear_size, rear_depth))\n    point_3d.append((rear_size, -rear_size, rear_depth))\n    point_3d.append((-rear_size, -rear_size, rear_depth))\n\n    front_size = int(4 / 3 * rear_size)\n    front_depth = int(4 / 3 * rear_size)\n    point_3d.append((-front_size, -front_size, front_depth))\n    point_3d.append((-front_size, front_size, front_depth))\n    point_3d.append((front_size, front_size, front_depth))\n    point_3d.append((front_size, -front_size, front_depth))\n    point_3d.append((-front_size, -front_size, front_depth))\n    point_3d = np.array(point_3d, dtype=np.float32).reshape(-1, 3)\n\n    return point_3d\n\n\ndef plot_pose_box(img, P, ver, color=(40, 255, 0), line_width=2):\n    \"\"\" Draw a 3D box as annotation of pose.\n    Ref:https://github.com/yinguobing/head-pose-estimation/blob/master/pose_estimator.py\n    Args:\n        img: the input image\n        P: (3, 4). 
Affine Camera Matrix.\n        ver: (2, 68) or (3, 68). sparse landmarks used to place the box\n    \"\"\"\n    llength = calc_hypotenuse(ver)\n    point_3d = build_camera_box(llength)\n    # Map to 2d image points\n    point_3d_homo = np.hstack((point_3d, np.ones([point_3d.shape[0], 1])))  # n x 4\n    point_2d = point_3d_homo.dot(P.T)[:, :2]\n\n    point_2d[:, 1] = -point_2d[:, 1]\n    point_2d[:, :2] = point_2d[:, :2] - np.mean(point_2d[:4, :2], 0) + np.mean(ver[:2, :27], 1)\n    point_2d = np.int32(point_2d.reshape(-1, 2))\n\n    # Draw all the lines\n    cv2.polylines(img, [point_2d], True, color, line_width, cv2.LINE_AA)\n    cv2.line(img, tuple(point_2d[1]), tuple(point_2d[6]), color, line_width, cv2.LINE_AA)\n    cv2.line(img, tuple(point_2d[2]), tuple(point_2d[7]), color, line_width, cv2.LINE_AA)\n    cv2.line(img, tuple(point_2d[3]), tuple(point_2d[8]), color, line_width, cv2.LINE_AA)\n\n    return img\n\n\ndef viz_pose(img, param_lst, ver_lst, show_flag=False, wfp=None):\n    for param, ver in zip(param_lst, ver_lst):\n        P, pose = calc_pose(param)\n        img = plot_pose_box(img, P, ver)\n        print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}')\n\n    if wfp is not None:\n        cv2.imwrite(wfp, img)\n        print(f'Save visualization result to {wfp}')\n\n    if show_flag:\n        plot_image(img)\n\n    return img\n\n\ndef pose_6(param):\n    P = param[:12].reshape(3, -1)  # camera matrix\n    s, R, t3d = P2sRt(P)\n    P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)  # without scale\n    pose = matrix2angle(R)\n    pose = [p * 180 / np.pi for p in pose]  # yaw, pitch, roll in degrees\n\n    return s, pose, t3d, P\n\n\ndef smooth_pose(img, param_lst, ver_lst, pose_new, show_flag=False, wfp=None, wnp=None):\n    for param, ver in zip(param_lst, ver_lst):\n        t3d = np.array([pose_new[4], pose_new[5], pose_new[6]])\n\n        theta = np.array([pose_new[0], pose_new[1], pose_new[2]])\n        theta = [p * np.pi / 180 for p in theta]\n        R = angle2matrix(theta)\n        P = np.concatenate((R, t3d.reshape(3, -1)), axis=1)\n        img = plot_pose_box(img, P, ver)\n        print(P, pose_new)\n        print(f'yaw: {theta[0]:.1f}, pitch: {theta[1]:.1f}, roll: {theta[2]:.1f}')\n        # smooth_pose only redraws the pose box from a given pose; the saved array is a placeholder\n        all_pose = np.array([0])\n\n    if wfp is not None:\n        cv2.imwrite(wfp, img)\n        print(f'Save visualization result to {wfp}')\n\n    if wnp is not None:\n        np.save(wnp, all_pose)\n        print(f'Save pose result to {wnp}')\n\n    if show_flag:\n        plot_image(img)\n\n    return img\n\n\ndef get_pose(img, param_lst, ver_lst, show_flag=False, wfp=None, wnp=None):\n    for param, ver in zip(param_lst, ver_lst):\n        s, pose, t3d, P = pose_6(param)\n        img = plot_pose_box(img, P, ver)\n        print(f'yaw: {pose[0]:.1f}, pitch: {pose[1]:.1f}, roll: {pose[2]:.1f}')\n        # 7-dim pose: [yaw, pitch, roll, scale, tx, ty, tz]\n        all_pose = np.array([pose[0], pose[1], pose[2], s, t3d[0], t3d[1], t3d[2]])\n\n    if wfp is not None:\n        cv2.imwrite(wfp, img)\n        print(f'Save visualization result to {wfp}')\n\n    if wnp is not None:\n        np.save(wnp, all_pose)\n        print(f'Save pose result to {wnp}')\n\n    if show_flag:\n        plot_image(img)\n\n    return all_pose\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 jixinya\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# EAMM:  One-Shot Emotional Talking Face via Audio-Based Emotion-Aware Motion Model [SIGGRAPH 2022 Conference]\r\n\r\nXinya Ji, [Hang Zhou](https://hangz-nju-cuhk.github.io/), Kaisiyuan Wang, [Qianyi Wu](https://wuqianyi.top/), [Wayne Wu](http://wywu.github.io/), [Feng Xu](http://xufeng.site/), [Xun Cao](https://cite.nju.edu.cn/People/Faculty/20190621/i5054.html)\r\n\r\n[[Project]](https://jixinya.github.io/projects/EAMM/)  [[Paper]](https://arxiv.org/abs/2205.15278)    \r\n\r\n![visualization](demo/teaser-1.png)\r\n\r\nGiven a single portrait image, we can synthesize emotional talking faces, where mouth movements match the input audio and facial emotion dynamics follow the emotion source video.\r\n\r\n## Installation\r\n\r\nWe train and test based on Python3.6 and Pytorch. To install the dependencies run:\r\n\r\n```\r\npip install -r requirements.txt\r\n```\r\n\r\n## Testing\r\n\r\n- Download the pre-trained models and data under the following link: [google-drive](https://drive.google.com/file/d/1IL9LjH3JegyMqJABqMxrX3StAq_v8Gtp/view?usp=sharing) and put the file in corresponding places.\r\n\r\n- Run the demo：\r\n  \r\n  `python demo.py --source_image path/to/image --driving_video path/to/emotion_video --pose_file path/to/pose --in_file path/to/audio --emotion emotion_type`\r\n  \r\n- Prepare testing data：\r\n\r\n  prepare source_image -- crop_image in process_data.py\r\n\r\n  prepare driving_video -- crop_image_tem in process_data.py\r\n\r\n  prepare pose -- detect pose using [3DDFA_V2](https://github.com/cleardusk/3DDFA_V2)\r\n\r\n## Training\r\n\r\n- Training data structure:\r\n\r\n  ```\r\n  ./data/<dataset_name>\r\n  ├──fomm_crop\r\n  │  ├──id/file_name   # cropped images\r\n  │  │  ├──0.png\r\n  │  │  ├──...\r\n  ├──fomm_pose_crop\r\n  │  ├──id   \r\n  │  │  ├──file_name.npy  # pose of the cropped images\r\n  │  │  ├──...\r\n  ├──MFCC\r\n  │  ├──id   \r\n  │  │  ├──file_name.npy  # MFCC of the audio\r\n  │  │  ├──...\r\n  \r\n  \r\n  *The cropped images are generated by 'crop_image_tem' in process_data.py\r\n  *The pose of the cropped video are generated by 3DDFA_V2/demo.py\r\n  *The MFCC of the audio are generated by 'audio2mfcc' in process_data.py\r\n  ```\r\n\r\n    \r\n\r\n- Step 1 : Train the Audio2Facial-Dynamics Module using LRW dataset\r\n\r\n  `python run.py --config config/train_part1.yaml --mode train_part1 --checkpoint log/124_52000.pth.tar `\r\n\r\n- Step 2 : Fine-tune the Audio2Facial-Dynamics Module after getting stable results from step1\r\n\r\n  `python run.py --config config/train_part1_fine_tune.yaml --mode train_part1_fine_tune --checkpoint log/124_52000.pth.tar --audio_chechpoint  checkpoint/from/step_1`\r\n\r\n- Setp 3 : Train the Implicit Emotion Displacement Learner\r\n\r\n  `python run.py --config config/train_part2.yaml --mode train_part2 --checkpoint log/124_52000.pth.tar --audio_chechpoint  checkpoint/from/step_2`\r\n\r\n## Citation\r\n\r\n```\r\n@inproceedings{10.1145/3528233.3530745,\r\nauthor = {Ji, Xinya and Zhou, Hang and Wang, Kaisiyuan and Wu, Qianyi and Wu, Wayne and Xu, Feng and Cao, Xun},\r\ntitle = {EAMM: One-Shot Emotional Talking Face via Audio-Based Emotion-Aware Motion Model},\r\nyear = {2022},\r\nisbn = {9781450393379},\r\nurl = {https://doi.org/10.1145/3528233.3530745},\r\ndoi = {10.1145/3528233.3530745},\r\nbooktitle = {ACM SIGGRAPH 2022 Conference Proceedings},\r\nseries = {SIGGRAPH '22}\r\n}\r\n\r\n\r\n```\r\n\r\n"
  },
  {
    "path": "augmentation.py",
    "content": "\"\"\"\nCode from https://github.com/hassony2/torch_videovision\n\"\"\"\n\nimport numbers\nimport math\nimport random\nimport numpy as np\nimport PIL\nimport cv2\nfrom skimage.transform import resize, rotate, AffineTransform, warp\nfrom skimage.util import pad\nimport torchvision\n\nimport warnings\n\nfrom skimage import img_as_ubyte, img_as_float\n\n\ndef crop_clip(clip, min_h, min_w, h, w):\n    if isinstance(clip[0], np.ndarray):\n        cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]\n\n    elif isinstance(clip[0], PIL.Image.Image):\n        cropped = [\n            img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip\n            ]\n    else:\n        raise TypeError('Expected numpy.ndarray or PIL.Image' +\n                        'but got list of {0}'.format(type(clip[0])))\n    return cropped\n\n\ndef pad_clip(clip, h, w):\n    im_h, im_w = clip[0].shape[:2]\n    pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)\n    pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)\n\n    return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')\n\n\ndef resize_clip(clip, size, interpolation='bilinear'):\n    if isinstance(clip[0], np.ndarray):\n        if isinstance(size, numbers.Number):\n            im_h, im_w, im_c = clip[0].shape\n            # Min spatial dim already matches minimal size\n            if (im_w <= im_h and im_w == size) or (im_h <= im_w\n                                                   and im_h == size):\n                return clip\n            new_h, new_w = get_resize_sizes(im_h, im_w, size)\n            size = (new_w, new_h)\n        else:\n            size = size[1], size[0]\n\n        scaled = [\n            resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,\n                   mode='constant', anti_aliasing=True) for img in clip\n            ]\n    elif isinstance(clip[0], PIL.Image.Image):\n        if isinstance(size, numbers.Number):\n            im_w, im_h = clip[0].size\n            # Min spatial dim already matches minimal size\n            if (im_w <= im_h and im_w == size) or (im_h <= im_w\n                                                   and im_h == size):\n                return clip\n            new_h, new_w = get_resize_sizes(im_h, im_w, size)\n            size = (new_w, new_h)\n        else:\n            size = size[1], size[0]\n        if interpolation == 'bilinear':\n            pil_inter = PIL.Image.NEAREST\n        else:\n            pil_inter = PIL.Image.BILINEAR\n        scaled = [img.resize(size, pil_inter) for img in clip]\n    else:\n        raise TypeError('Expected numpy.ndarray or PIL.Image' +\n                        'but got list of {0}'.format(type(clip[0])))\n    return scaled\n\n\ndef get_resize_sizes(im_h, im_w, size):\n    if im_w < im_h:\n        ow = size\n        oh = int(size * im_h / im_w)\n    else:\n        oh = size\n        ow = int(size * im_w / im_h)\n    return oh, ow\n\n\nclass RandomFlip(object):\n    def __init__(self, time_flip=False, horizontal_flip=False):\n        self.time_flip = time_flip\n        self.horizontal_flip = horizontal_flip\n\n    def __call__(self, clip):\n        if random.random() < 0.5 and self.time_flip:\n            return clip[::-1]\n        if random.random() < 0.5 and self.horizontal_flip:\n            return [np.fliplr(img) for img in clip]\n\n        return clip\n\n\nclass RandomResize(object):\n    \"\"\"Resizes a list of (H x W x C) numpy.ndarray to the final 
size\n    The larger the original image is, the more times it takes to\n    interpolate\n    Args:\n    interpolation (str): Can be one of 'nearest', 'bilinear'\n    defaults to nearest\n    size (tuple): (widht, height)\n    \"\"\"\n\n    def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):\n        self.ratio = ratio\n        self.interpolation = interpolation\n\n    def __call__(self, clip):\n        scaling_factor = random.uniform(self.ratio[0], self.ratio[1])\n\n        if isinstance(clip[0], np.ndarray):\n            im_h, im_w, im_c = clip[0].shape\n        elif isinstance(clip[0], PIL.Image.Image):\n            im_w, im_h = clip[0].size\n\n        new_w = int(im_w * scaling_factor)\n        new_h = int(im_h * scaling_factor)\n        new_size = (new_w, new_h)\n        resized = resize_clip(\n            clip, new_size, interpolation=self.interpolation)\n\n        return resized\n\n\nclass RandomCrop(object):\n    \"\"\"Extract random crop at the same location for a list of videos\n    Args:\n    size (sequence or int): Desired output size for the\n    crop in format (h, w)\n    \"\"\"\n\n    def __init__(self, size):\n        if isinstance(size, numbers.Number):\n            size = (size, size)\n\n        self.size = size\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n        img (PIL.Image or numpy.ndarray): List of videos to be cropped\n        in format (h, w, c) in numpy.ndarray\n        Returns:\n        PIL.Image or numpy.ndarray: Cropped list of videos\n        \"\"\"\n        h, w = self.size\n        if isinstance(clip[0], np.ndarray):\n            im_h, im_w, im_c = clip[0].shape\n        elif isinstance(clip[0], PIL.Image.Image):\n            im_w, im_h = clip[0].size\n        else:\n            raise TypeError('Expected numpy.ndarray or PIL.Image' +\n                            'but got list of {0}'.format(type(clip[0])))\n\n        clip = pad_clip(clip, h, w)\n        im_h, im_w = clip.shape[1:3]\n        x1 = 0 if h == im_h else random.randint(0, im_w - w)\n        y1 = 0 if w == im_w else random.randint(0, im_h - h)\n        cropped = crop_clip(clip, y1, x1, h, w)\n\n        return cropped\n\n\nclass MouthCrop(object):\n    \"\"\"Extract random crop at the same location for a list of videos\n    Args:\n    size (sequence or int): Desired output size for the\n    crop in format (h, w)\n    \"\"\"\n\n    def __init__(self, center_x, center_y, mask_width, mask_height):\n        \n\n        self.center_x = center_x\n        self.center_y = center_y\n        self.mask_width = mask_width\n        self.mask_height = mask_height\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n        img (PIL.Image or numpy.ndarray): List of videos to be cropped\n        in format (h, w, c) in numpy.ndarray\n        Returns:\n        PIL.Image or numpy.ndarray: Cropped list of videos\n        \"\"\"\n        start_x = self.center_x - int(self.mask_width/2)\n        start_y = self.center_y - int(self.mask_height/2) \n        end_x = start_x + self.mask_width\n        end_y = start_y + self.mask_height\n        # mask is all white\n        # mask = 255*np.ones((mask_height, mask_width, 3), dtype=np.uint8)\n        # mask is uniform noise\n        cropped = []\n        for i in range(len(clip)):\n            mask = np.random.rand(self.mask_height, self.mask_width, 3)\n            img = clip[i].copy()\n            img[start_y:end_y, start_x:end_x, :] = mask\n        \n            cropped.append(img)\n        cropped = np.array(cropped)\n       
 return cropped\n\nclass RandomRotation(object):\n    \"\"\"Rotate entire clip randomly by a random angle within\n    given bounds\n    Args:\n    degrees (sequence or int): Range of degrees to select from\n    If degrees is a number instead of sequence like (min, max),\n    the range of degrees, will be (-degrees, +degrees).\n    \"\"\"\n\n    def __init__(self, degrees):\n        if isinstance(degrees, numbers.Number):\n            if degrees < 0:\n                raise ValueError('If degrees is a single number,'\n                                 'must be positive')\n            degrees = (-degrees, degrees)\n        else:\n            if len(degrees) != 2:\n                raise ValueError('If degrees is a sequence,'\n                                 'it must be of len 2.')\n\n        self.degrees = degrees\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n        img (PIL.Image or numpy.ndarray): List of videos to be cropped\n        in format (h, w, c) in numpy.ndarray\n        Returns:\n        PIL.Image or numpy.ndarray: Cropped list of videos\n        \"\"\"\n        angle = random.uniform(self.degrees[0], self.degrees[1])\n        if isinstance(clip[0], np.ndarray):\n            rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]\n        elif isinstance(clip[0], PIL.Image.Image):\n            rotated = [img.rotate(angle) for img in clip]\n        else:\n            raise TypeError('Expected numpy.ndarray or PIL.Image' +\n                            'but got list of {0}'.format(type(clip[0])))\n\n        return rotated\n\nclass RandomPerspective(object):\n    \"\"\"Rotate entire clip randomly by a random angle within\n    given bounds\n    Args:\n    degrees (sequence or int): Range of degrees to select from\n    If degrees is a number instead of sequence like (min, max),\n    the range of degrees, will be (-degrees, +degrees).\n    \"\"\"\n\n    def __init__(self, pers_num, enlarge_num):\n        self.pers_num = pers_num\n        self.enlarge_num = enlarge_num\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n        img (PIL.Image or numpy.ndarray): List of videos to be cropped\n        in format (h, w, c) in numpy.ndarray\n        Returns:\n        PIL.Image or numpy.ndarray: Cropped list of videos\n        \"\"\"\n        out = clip\n        for i in range(len(clip)):\n            self.pers_size = np.random.randint(20, self.pers_num) * pow(-1, np.random.randint(2))\n            self.enlarge_size = np.random.randint(20, self.enlarge_num) * pow(-1, np.random.randint(2))\n            h, w, c = clip[i].shape\n            crop_size=256\n            dst = np.array([\n                [-self.enlarge_size, -self.enlarge_size],\n                [-self.enlarge_size + self.pers_size, w + self.enlarge_size],\n                [h + self.enlarge_size, -self.enlarge_size],\n                [h + self.enlarge_size - self.pers_size, w + self.enlarge_size],], dtype=np.float32)\n            src = np.array([[-self.enlarge_size, -self.enlarge_size], [-self.enlarge_size, w + self.enlarge_size],\n                        [h + self.enlarge_size, -self.enlarge_size], [h + self.enlarge_size, w + self.enlarge_size]]).astype(np.float32())\n            M = cv2.getPerspectiveTransform(src, dst)\n            warped = cv2.warpPerspective(clip[i], M, (crop_size, crop_size), borderMode=cv2.BORDER_REPLICATE)\n            out[i] = warped\n\n        return out\n\n\nclass ColorJitter(object):\n    \"\"\"Randomly change the brightness, contrast and saturation 
and hue of the clip\n    Args:\n    brightness (float): How much to jitter brightness. brightness_factor\n    is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].\n    contrast (float): How much to jitter contrast. contrast_factor\n    is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].\n    saturation (float): How much to jitter saturation. saturation_factor\n    is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].\n    hue(float): How much to jitter hue. hue_factor is chosen uniformly from\n    [-hue, hue]. Should be >=0 and <= 0.5.\n    \"\"\"\n\n    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):\n        self.brightness = brightness\n        self.contrast = contrast\n        self.saturation = saturation\n        self.hue = hue\n\n    def get_params(self, brightness, contrast, saturation, hue):\n        if brightness > 0:\n            brightness_factor = random.uniform(\n                max(0, 1 - brightness), 1 + brightness)\n        else:\n            brightness_factor = None\n\n        if contrast > 0:\n            contrast_factor = random.uniform(\n                max(0, 1 - contrast), 1 + contrast)\n        else:\n            contrast_factor = None\n\n        if saturation > 0:\n            saturation_factor = random.uniform(\n                max(0, 1 - saturation), 1 + saturation)\n        else:\n            saturation_factor = None\n\n        if hue > 0:\n            hue_factor = random.uniform(-hue, hue)\n        else:\n            hue_factor = None\n        return brightness_factor, contrast_factor, saturation_factor, hue_factor\n\n    def __call__(self, clip):\n        \"\"\"\n        Args:\n        clip (list): list of PIL.Image\n        Returns:\n        list PIL.Image : list of transformed PIL.Image\n        \"\"\"\n        if isinstance(clip[0], np.ndarray):\n            brightness, contrast, saturation, hue = self.get_params(\n                self.brightness, self.contrast, self.saturation, self.hue)\n\n            # Create img transform function sequence\n            img_transforms = []\n            if brightness is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))\n            if saturation is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))\n            if hue is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))\n            if contrast is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))\n            random.shuffle(img_transforms)\n            img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,\n                                                                                                     img_as_float]\n\n            with warnings.catch_warnings():\n                warnings.simplefilter(\"ignore\")\n                jittered_clip = []\n                for img in clip:\n                    jittered_img = img\n                    for func in img_transforms:\n                        jittered_img = func(jittered_img)\n                    jittered_clip.append(jittered_img.astype('float32'))\n        elif isinstance(clip[0], PIL.Image.Image):\n            brightness, contrast, saturation, hue = self.get_params(\n                self.brightness, self.contrast, 
self.saturation, self.hue)\n\n            # Create img transform function sequence\n            img_transforms = []\n            if brightness is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))\n            if saturation is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))\n            if hue is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))\n            if contrast is not None:\n                img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))\n            random.shuffle(img_transforms)\n\n            # Apply to all videos\n            jittered_clip = []\n            for img in clip:\n                for func in img_transforms:\n                    jittered_img = func(img)\n                jittered_clip.append(jittered_img)\n\n        else:\n            raise TypeError('Expected numpy.ndarray or PIL.Image' +\n                            'but got list of {0}'.format(type(clip[0])))\n        return jittered_clip\n\n\nclass AllAugmentationTransform:\n    def __init__(self, crop_mouth_param = None, resize_param=None, rotation_param=None, perspective_param=None, flip_param=None, crop_param=None, jitter_param=None):\n        self.transforms = []\n        if crop_mouth_param is not None:\n            self.transforms.append(MouthCrop(**crop_mouth_param))\n        \n        if flip_param is not None:\n            self.transforms.append(RandomFlip(**flip_param))\n\n        if rotation_param is not None:\n            self.transforms.append(RandomRotation(**rotation_param))\n        \n        if perspective_param is not None:\n            self.transforms.append(RandomPerspective(**perspective_param))\n\n        if resize_param is not None:\n            self.transforms.append(RandomResize(**resize_param))\n\n        if crop_param is not None:\n            self.transforms.append(RandomCrop(**crop_param))\n\n        if jitter_param is not None:\n            self.transforms.append(ColorJitter(**jitter_param))\n      \n    def __call__(self, clip):\n        for t in self.transforms:\n            clip = t(clip)\n        return clip\n"
  },
  {
    "path": "config/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml",
    "content": "dataset_params:\n  root_dir: /mnt/lustre/share_data/jixinya/MEAD/\n  frame_shape: [256, 256, 3]\n  id_sampling: False\n  pairs_list: Random_choice\n  augmentation_params:\n    crop_mouth_param: \n      center_x: 135\n      center_y: 190\n      mask_width: 100\n      mask_height: 60\n    rotation_param: \n      degrees: 30\n    perspective_param: \n      pers_num: 30\n      enlarge_num: 40\n    flip_param:\n      horizontal_flip: True\n      time_flip: False\n    jitter_param:\n      brightness: 0\n      contrast: 0\n      saturation: 0\n      hue: 0\n\nmodel_params:\n  common_params:\n    num_kp: 10\n    num_channels: 3\n    estimate_jacobian: True\n  audio_params:\n    num_kp: 10\n    num_channels : 3\n    num_channels_a : 3\n    estimate_jacobian: True\n  kp_detector_params:\n     temperature: 0.1\n     block_expansion: 32\n     max_features: 1024\n     scale_factor: 0.25\n     num_blocks: 5\n  generator_params:\n    block_expansion: 64\n    max_features: 512\n    num_down_blocks: 2\n    num_bottleneck_blocks: 6\n    estimate_occlusion_map: True\n    dense_motion_params:\n      block_expansion: 64\n      max_features: 1024\n      num_blocks: 5\n      scale_factor: 0.25\n  discriminator_params:\n    scales: [1]\n    block_expansion: 32\n    max_features: 512\n    num_blocks: 4\n    sn: True\n\ntrain_params:\n  type: linear_4\n  smooth: False\n  jaco_net: cnn\n  ldmark: fake\n  generator: not\n  train_generator: False\n  num_epochs: 300\n  num_repeats: 1\n  epoch_milestones: [60, 90]\n  lr_generator: 2.0e-4\n  lr_discriminator: 2.0e-4\n  lr_kp_detector: 2.0e-4\n  lr_audio_feature: 2.0e-4\n  batch_size: 16\n  scales: [1, 0.5, 0.25, 0.125]\n  checkpoint_freq: 1\n  transform_params:\n    sigma_affine: 0.05\n    sigma_tps: 0.005\n    points_tps: 5\n  loss_weights:\n    generator_gan: 0\n    discriminator_gan: 1\n    feature_matching: [10, 10, 10, 10]\n    perceptual: [10, 10, 10, 10, 10]\n    equivariance_value: 0\n    equivariance_jacobian: 0\n    emo: 10\n\nreconstruction_params:\n  num_videos: 1000\n  format: '.mp4'\n\nanimate_params:\n  num_pairs: 50\n  format: '.mp4'\n  normalization_params:\n    adapt_movement_scale: False\n    use_relative_movement: True\n    use_relative_jacobian: True\n\nvisualizer_params:\n  kp_size: 5\n  draw_border: True\n  colormap: 'gist_rainbow'\n"
  },
  {
    "path": "config/train_part1.yaml",
    "content": "dataset_params:\n  name: Vox\n  root_dir: dataset/LRW/\n  frame_shape: [256, 256, 3]\n  id_sampling: False\n  augmentation_params:\n    flip_param:\n      horizontal_flip: False\n      time_flip: False\n    jitter_param:\n      brightness: 0.1\n      contrast: 0.1\n      saturation: 0.1\n      hue: 0.1\n\n\nmodel_params:\n  common_params:\n    num_kp: 10\n    num_channels: 3\n    estimate_jacobian: True\n  audio_params:\n    num_kp: 10\n    num_channels : 3\n    num_channels_a : 3\n    estimate_jacobian: True\n  kp_detector_params:\n     temperature: 0.1\n     block_expansion: 32\n     max_features: 1024\n     scale_factor: 0.25\n     num_blocks: 5\n  generator_params:\n    block_expansion: 64\n    max_features: 512\n    num_down_blocks: 2\n    num_bottleneck_blocks: 6\n    estimate_occlusion_map: True\n    dense_motion_params:\n      block_expansion: 64\n      max_features: 1024\n      num_blocks: 5\n      scale_factor: 0.25\n  discriminator_params:\n    scales: [1]\n    block_expansion: 32\n    max_features: 512\n    num_blocks: 4\n    sn: True\n\ntrain_params:\n  jaco_net: cnn\n  ldmark: fake\n  generator: not\n  num_epochs: 300\n  num_repeats: 1\n  epoch_milestones: [60, 90]\n  lr_generator: 2.0e-4\n  lr_discriminator: 2.0e-4\n  lr_kp_detector: 2.0e-4\n  lr_audio_feature: 2.0e-4\n  batch_size: 8\n  scales: [1, 0.5, 0.25, 0.125]\n  checkpoint_freq: 1\n  transform_params:\n    sigma_affine: 0.05\n    sigma_tps: 0.005\n    points_tps: 5\n  loss_weights:\n    generator_gan: 0\n    discriminator_gan: 0\n    feature_matching: [10, 10, 10, 10]\n    perceptual: [10, 10, 10, 10, 10]\n    equivariance_value: 0\n    equivariance_jacobian: 0\n    audio: 10\n\n\n\nvisualizer_params:\n  kp_size: 5\n  draw_border: True\n  colormap: 'gist_rainbow'\n"
  },
  {
    "path": "config/train_part1_fine_tune.yaml",
    "content": "dataset_params:\n  name: LRW\n  root_dir: dataset/LRW/\n  frame_shape: [256, 256, 3]\n  id_sampling: False\n  augmentation_params:\n    flip_param:\n      horizontal_flip: False\n      time_flip: False\n    jitter_param:\n      brightness: 0.1\n      contrast: 0.1\n      saturation: 0.1\n      hue: 0.1\n\n\nmodel_params:\n  common_params:\n    num_kp: 10\n    num_channels: 3\n    estimate_jacobian: True\n  audio_params:\n    num_kp: 10\n    num_channels : 3\n    num_channels_a : 3\n    estimate_jacobian: True\n  kp_detector_params:\n     temperature: 0.1\n     block_expansion: 32\n     max_features: 1024\n     scale_factor: 0.25\n     num_blocks: 5\n  generator_params:\n    block_expansion: 64\n    max_features: 512\n    num_down_blocks: 2\n    num_bottleneck_blocks: 6\n    estimate_occlusion_map: True\n    dense_motion_params:\n      block_expansion: 64\n      max_features: 1024\n      num_blocks: 5\n      scale_factor: 0.25\n  discriminator_params:\n    scales: [1]\n    block_expansion: 32\n    max_features: 512\n    num_blocks: 4\n    sn: True\n\ntrain_params:\n  jaco_net: cnn\n  ldmark: fake\n  generator: audio\n  num_epochs: 300\n  num_repeats: 1\n  epoch_milestones: [60, 90]\n  lr_generator: 2.0e-4\n  lr_discriminator: 2.0e-4\n  lr_kp_detector: 2.0e-4\n  lr_audio_feature: 2.0e-4\n  batch_size: 6\n  scales: [1, 0.5, 0.25, 0.125]\n  checkpoint_freq: 1\n  transform_params:\n    sigma_affine: 0.05\n    sigma_tps: 0.005\n    points_tps: 5\n  loss_weights:\n    generator_gan: 0\n    discriminator_gan: 0\n    feature_matching: [10, 10, 10, 10]\n    perceptual: [0.1, 0.1, 0.1, 0.1, 0.1]\n    equivariance_value: 0\n    equivariance_jacobian: 0\n    audio: 10\n\nvisualizer_params:\n  kp_size: 5\n  draw_border: True\n  colormap: 'gist_rainbow'\n"
  },
  {
    "path": "config/train_part2.yaml",
    "content": "dataset_params:\n  name: MEAD\n  root_dir: dataset/MEAD/\n  frame_shape: [256, 256, 3]\n  id_sampling: False\n  augmentation_params:\n    crop_mouth_param: \n      center_x: 135\n      center_y: 190\n      mask_width: 100\n      mask_height: 60\n    rotation_param: \n      degrees: 30\n    perspective_param: \n      pers_num: 30\n      enlarge_num: 40\n    flip_param:\n      horizontal_flip: True\n      time_flip: False\n    jitter_param: \n      brightness: 0\n      contrast: 0\n      saturation: 0\n      hue: 0\n\nmodel_params:\n  common_params:\n    num_kp: 10\n    num_channels: 3\n    estimate_jacobian: True\n  audio_params:\n    num_kp: 10\n    num_channels : 3\n    num_channels_a : 3\n    estimate_jacobian: True\n  kp_detector_params:\n     temperature: 0.1\n     block_expansion: 32\n     max_features: 1024\n     scale_factor: 0.25\n     num_blocks: 5\n  generator_params:\n    block_expansion: 64\n    max_features: 512\n    num_down_blocks: 2\n    num_bottleneck_blocks: 6\n    estimate_occlusion_map: True\n    dense_motion_params:\n      block_expansion: 64\n      max_features: 1024\n      num_blocks: 5\n      scale_factor: 0.25\n  discriminator_params:\n    scales: [1]\n    block_expansion: 32\n    max_features: 512\n    num_blocks: 4\n    sn: True\n\ntrain_params:\n  type: linear_4\n  smooth: False\n  jaco_net: cnn\n  ldmark: fake\n  generator: not\n  num_epochs: 300\n  num_repeats: 1\n  epoch_milestones: [60, 90]\n  lr_generator: 2.0e-4\n  lr_discriminator: 2.0e-4\n  lr_kp_detector: 2.0e-4\n  lr_audio_feature: 2.0e-4\n  batch_size: 16\n  scales: [1, 0.5, 0.25, 0.125]\n  checkpoint_freq: 1\n  transform_params:\n    sigma_affine: 0.05\n    sigma_tps: 0.005\n    points_tps: 5\n  loss_weights:\n    generator_gan: 0\n    discriminator_gan: 0\n    feature_matching: [10, 10, 10, 10]\n    perceptual: [10, 10, 10, 10, 10]\n    equivariance_value: 0\n    equivariance_jacobian: 0\n    emo: 10\n\n\nvisualizer_params:\n  kp_size: 5\n  draw_border: True\n  colormap: 'gist_rainbow'\n"
  },
  {
    "path": "demo.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Wed Oct  6 20:57:27 2021\n@author: thea\n\"\"\"\n\nimport matplotlib\nmatplotlib.use('Agg')\nimport os,sys\nimport yaml\nfrom argparse import ArgumentParser\nfrom tqdm import tqdm\nfrom skimage import io, img_as_float32\nimport imageio\nimport numpy as np\nfrom skimage.transform import resize\nfrom skimage import img_as_ubyte\nimport torch\nfrom filter1 import OneEuroFilter\nimport torch.utils\n\nfrom torch.autograd import Variable\nfrom modules.generator import OcclusionAwareGenerator\nfrom modules.keypoint_detector import KPDetector, KPDetector_a\nfrom modules.util import AT_net, Emotion_k, Emotion_map, AT_net2\nfrom augmentation import AllAugmentationTransform\n\nfrom scipy.spatial import ConvexHull\n\nimport python_speech_features\nfrom pathlib import Path\nimport dlib\nimport cv2\nimport librosa\nfrom skimage import transform as tf\n#from audiolm.models import AT_emoiton\n#from audiolm.utils import plot_flmarks\nif sys.version_info[0] < 3:\n    raise Exception(\"You must use Python 3 or higher. Recommended version is Python 3.6\")\n\n\ndetector = dlib.get_frontal_face_detector()\npredictor = dlib.shape_predictor('./shape_predictor_68_face_landmarks.dat')\n\n\n\n\ndef load_checkpoints(opt, checkpoint_path, audio_checkpoint_path, emo_checkpoint_path, cpu=False):\n\n    with open(opt.config) as f:\n        config = yaml.load(f, Loader=yaml.FullLoader)\n\n    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],\n                                        **config['model_params']['common_params'])\n    if not cpu:\n        generator.cuda()\n\n    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],\n                             **config['model_params']['common_params'])\n    if not cpu:\n        kp_detector.cuda()\n\n    kp_detector_a = KPDetector_a(**config['model_params']['kp_detector_params'],\n                             **config['model_params']['audio_params'])\n\n    audio_feature = AT_net2()\n    if opt.type.startswith('linear'):\n        emo_detector = Emotion_k(block_expansion=32, num_channels=3, max_features=1024,\n                 num_blocks=5, scale_factor=0.25, num_classes=8)\n    elif opt.type.startswith('map'):\n        emo_detector = Emotion_map(block_expansion=32, num_channels=3, max_features=1024,\n                 num_blocks=5, scale_factor=0.25, num_classes=8)\n    if not cpu:\n        kp_detector_a.cuda()\n        audio_feature.cuda()\n        emo_detector.cuda()\n\n\n\n\n    if cpu:\n        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))\n        audio_checkpoint = torch.load(audio_checkpoint_path, map_location=torch.device('cpu'))\n        emo_checkpoint = torch.load(emo_checkpoint_path, map_location=torch.device('cpu'))\n    else:\n        checkpoint = torch.load(checkpoint_path)\n        audio_checkpoint = torch.load(audio_checkpoint_path)\n        emo_checkpoint = torch.load(emo_checkpoint_path)\n\n    generator.load_state_dict(checkpoint['generator'])\n    kp_detector.load_state_dict(checkpoint['kp_detector'])\n    audio_feature.load_state_dict(audio_checkpoint['audio_feature'])\n    kp_detector_a.load_state_dict(audio_checkpoint['kp_detector_a'])\n    emo_detector.load_state_dict(emo_checkpoint['emo_detector'])\n    \n\n    if not cpu:\n        generator = generator.cuda()\n        kp_detector = kp_detector.cuda()\n        audio_feature = audio_feature.cuda()\n        kp_detector_a = kp_detector_a.cuda()\n 
       emo_detector = emo_detector.cuda()\n\n    generator.eval()\n    kp_detector.eval()\n    audio_feature.eval()\n    kp_detector_a.eval()\n    emo_detector.eval()\n    return generator, kp_detector, kp_detector_a, audio_feature, emo_detector\n\ndef normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,\n                 use_relative_movement=False, use_relative_jacobian=False):\n    if adapt_movement_scale:\n        source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume\n        driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume\n        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)\n    else:\n        adapt_movement_scale = 1\n\n    kp_new = {k: v for k, v in kp_driving.items()}\n\n    if use_relative_movement:\n        kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])\n        kp_value_diff *= adapt_movement_scale\n        kp_new['value'] = kp_value_diff + kp_source['value']\n\n        if use_relative_jacobian:\n            jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))\n            kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])\n\n    return kp_new\n\ndef shape_to_np(shape, dtype=\"int\"):\n    # initialize the list of (x, y)-coordinates\n    coords = np.zeros((shape.num_parts, 2), dtype=dtype)\n\n    # loop over all facial landmarks and convert them\n    # to a 2-tuple of (x, y)-coordinates\n    for i in range(0, shape.num_parts):\n        coords[i] = (shape.part(i).x, shape.part(i).y)\n\n    # return the list of (x, y)-coordinates\n    return coords\n\ndef get_aligned_image(driving_video, opt):\n    aligned_array = []\n\n    video_array = np.array(driving_video)\n    source_image=video_array[0]\n   # aligned_array.append(source_image)\n    source_image = np.array(source_image * 255, dtype=np.uint8)\n    gray = cv2.cvtColor(source_image, cv2.COLOR_BGR2GRAY)\n    rects = detector(gray, 1)  #detect human face\n    for (i, rect) in enumerate(rects):\n        template = predictor(gray, rect) #detect 68 points\n        template = shape_to_np(template)\n\n    if opt.emotion == 'surprised' or opt.emotion == 'fear':\n        template = template-[0,10]\n    for i in range(len(video_array)):\n        image=np.array(video_array[i] * 255, dtype=np.uint8)\n        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n        rects = detector(gray, 1)  #detect human face\n        for (j, rect) in enumerate(rects):\n            shape = predictor(gray, rect) #detect 68 points\n            shape = shape_to_np(shape)\n\n        pts2 = np.float32(template[:35,:])\n        pts1 = np.float32(shape[:35,:]) #eye and nose\n\n    #    pts2 = np.float32(np.concatenate((template[:16,:],template[27:36,:]),axis = 0))\n    #    pts1 = np.float32(np.concatenate((shape[:16,:],shape[27:36,:]),axis = 0)) #eye and nose\n        # pts1 = np.float32(landmark[17:35,:])\n        tform = tf.SimilarityTransform()\n        tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters.\n        dst = tf.warp(image, tform, output_shape=(256, 256))\n\n        dst = np.array(dst, dtype=np.float32)\n        aligned_array.append(dst)\n\n    return aligned_array\n\ndef get_transformed_image(driving_video, opt):\n    video_array = np.array(driving_video)\n    with open(opt.config) as f:\n        config = yaml.load(f, Loader=yaml.FullLoader)\n    transformations = 
AllAugmentationTransform(**config['dataset_params']['augmentation_params'])\n    transformed_array = transformations(video_array)\n    return transformed_array\n\n\n\ndef make_animation_smooth(source_image, driving_video, transformed_video, deco_out, kp_loss, generator, kp_detector, kp_detector_a, emo_detector, opt, relative=True, adapt_movement_scale=True, cpu=False):\n    with torch.no_grad():\n        predictions = []\n\n        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)\n\n        if not cpu:\n            source = source.cuda()\n\n        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)\n        transformed_driving = torch.tensor(np.array(transformed_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)\n\n        kp_source = kp_detector(source)\n        kp_driving_initial = kp_detector_a(deco_out[:,0])\n\n        emo_driving_all = []\n        features = []\n        kp_driving_all = []\n        for frame_idx in tqdm(range(len(deco_out[0]))):\n\n            driving_frame = driving[:, :, frame_idx]\n            transformed_frame = transformed_driving[:, :, frame_idx]\n            if not cpu:\n                driving_frame = driving_frame.cuda()\n                transformed_frame = transformed_frame.cuda()\n            kp_driving = kp_detector_a(deco_out[:,frame_idx])\n            kp_driving_all.append(kp_driving)\n            if opt.add_emo:\n                value = kp_driving['value']\n                jacobian = kp_driving['jacobian']\n                if opt.type == 'linear_3':\n                    emo_driving,_ = emo_detector(transformed_frame,value,jacobian)\n                    features.append(emo_detector.feature(transformed_frame).data.cpu().numpy())\n            \n                emo_driving_all.append(emo_driving)\n        features = np.array(features)\n        if opt.add_emo:        \n            one_euro_filter_v = OneEuroFilter(mincutoff=1, beta=0.2, dcutoff=1.0, freq=100)#1 0.4\n            one_euro_filter_j = OneEuroFilter(mincutoff=1, beta=0.2, dcutoff=1.0, freq=100)#1 0.4\n\n            for j in range(len(emo_driving_all)):\n                emo_driving_all[j]['value']=one_euro_filter_v.process(emo_driving_all[j]['value'].cpu()*100)/100\n                emo_driving_all[j]['value'] = emo_driving_all[j]['value'].cuda()\n                emo_driving_all[j]['jacobian']=one_euro_filter_j.process(emo_driving_all[j]['jacobian'].cpu()*100)/100\n                emo_driving_all[j]['jacobian'] = emo_driving_all[j]['jacobian'].cuda()\n\n\n        one_euro_filter_v = OneEuroFilter(mincutoff=0.05, beta=8, dcutoff=1.0, freq=100)\n        one_euro_filter_j = OneEuroFilter(mincutoff=0.05, beta=8, dcutoff=1.0, freq=100)\n\n        for j in range(len(kp_driving_all)):\n            kp_driving_all[j]['value']=one_euro_filter_v.process(kp_driving_all[j]['value'].cpu()*10)/10\n            kp_driving_all[j]['value'] = kp_driving_all[j]['value'].cuda()\n            kp_driving_all[j]['jacobian']=one_euro_filter_j.process(kp_driving_all[j]['jacobian'].cpu()*10)/10\n            kp_driving_all[j]['jacobian'] = kp_driving_all[j]['jacobian'].cuda()\n\n\n        for frame_idx in tqdm(range(len(deco_out[0]))):\n            \n            if opt.check_add:\n                kp_driving = kp_detector_a(deco_out[:,0])\n            else:\n                kp_driving = kp_driving_all[frame_idx]\n\n       #     kp_driving_real = kp_detector(driving_frame)\n\n       #     kp_driving['value'] = 
(1-opt.weight)*kp_driving['value'] + opt.weight*kp_driving_real['value']\n       #     kp_driving['jacobian'] = (1-opt.weight)*kp_driving['jacobian'] + opt.weight*kp_driving_real['jacobian']\n\n            if opt.add_emo:\n                emo_driving = emo_driving_all[frame_idx]\n                if opt.type == 'linear_3':\n                    kp_driving['value'][:,1] = kp_driving['value'][:,1] + emo_driving['value'][:,0]*0.2\n                    kp_driving['jacobian'][:,1] = kp_driving['jacobian'][:,1] + emo_driving['jacobian'][:,0]*0.2\n                    kp_driving['value'][:,4] = kp_driving['value'][:,4] + emo_driving['value'][:,1]\n                    kp_driving['jacobian'][:,4] = kp_driving['jacobian'][:,4] + emo_driving['jacobian'][:,1]\n                    kp_driving['value'][:,6] = kp_driving['value'][:,6] + emo_driving['value'][:,2]\n                    kp_driving['jacobian'][:,6] = kp_driving['jacobian'][:,6] + emo_driving['jacobian'][:,2]\n                   # kp_driving['value'][:,8] = kp_driving['value'][:,8] + emo_driving['value'][:,3]\n                   # kp_driving['jacobian'][:,8] = kp_driving['jacobian'][:,8] + emo_driving['jacobian'][:,3]\n               \n         \n            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,\n                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,\n                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)\n            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)\n\n            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])\n    return predictions, features\n\n\n\ndef test_auido(example_image, audio_feature, all_pose, opt):\n    with open(opt.config) as f:\n        para = yaml.load(f, Loader=yaml.FullLoader)\n\n  #  encoder = audio_feature()\n    if not opt.cpu:\n        audio_feature = audio_feature.cuda()\n\n    audio_feature.eval()\n #   decoder.eval()\n    test_file = opt.in_file\n    pose = all_pose[:,:6]\n    if len(pose) == 1:\n        pose = np.repeat(pose,100,0)\n\n    elif opt.smooth_pose:\n        one_euro_filter = OneEuroFilter(mincutoff=0.004, beta=0.7, dcutoff=1.0, freq=100)\n\n\n        for j in range(len(pose)):\n            pose[j]=one_euro_filter.process(pose[j])\n      #      pose[j]=pose[0]\n\n    example_image = np.array(example_image, dtype='float32').transpose((2, 0, 1))\n\n    \n\n\n    speech, sr = librosa.load(test_file, sr=16000)\n  #  mfcc = python_speech_features.mfcc(speech ,16000,winstep=0.01)\n    speech = np.insert(speech, 0, np.zeros(1920))\n    speech = np.append(speech, np.zeros(1920))\n    mfcc = python_speech_features.mfcc(speech,16000,winstep=0.01)\n\n\n    print ('=======================================')\n    print ('Start to generate images')\n\n    ind = 3\n    with torch.no_grad():\n        fake_lmark = []\n        input_mfcc = []\n        while ind <= int(mfcc.shape[0]/4) - 4:\n            t_mfcc =mfcc[( ind - 3)*4: (ind + 4)*4, 1:]\n            t_mfcc = torch.FloatTensor(t_mfcc).cuda()\n            input_mfcc.append(t_mfcc)\n            ind += 1\n        input_mfcc = torch.stack(input_mfcc,dim = 0)\n\n        if (len(pose)<len(input_mfcc)):\n            gap = len(input_mfcc)-len(pose)\n            n = int((gap/len(pose)/2)) +2\n            pose = np.concatenate((pose,pose[::-1,:]),axis = 0)\n            pose = np.tile(pose, (n,1))\n        if(len(pose)>len(input_mfcc)):\n            pose = 
pose[:len(input_mfcc),:]\n        \n        if not opt.cpu:\n            example_image = Variable(torch.FloatTensor(example_image.astype(float)) ).cuda()\n            example_image = torch.unsqueeze(example_image,0)\n            pose = Variable(torch.FloatTensor(pose.astype(float)) ).cuda()\n        \n        pose = pose.unsqueeze(0)\n\n        input_mfcc = input_mfcc.unsqueeze(0)\n\n        deco_out = audio_feature(example_image,input_mfcc,pose,para['train_params']['jaco_net'],1.6)\n\n        return deco_out\n\n\ndef save(path, frames, format):\n\n    if format == '.png':\n        if not os.path.exists(path):\n\n            os.makedirs(path)\n        for j, frame in enumerate(frames):\n            imageio.imsave(path+'/'+str(j)+'.png',frame)\n    #        imageio.imsave(os.path.join(path, str(j) + '.png'), frames[j])\n    else:\n        print (\"Unknown format %s\" % format)\n        exit()\n\nclass VideoWriter(object):\n    def __init__(self, path, width, height, fps):\n        fourcc = cv2.VideoWriter_fourcc(*'XVID')\n        self.path = path\n        self.out = cv2.VideoWriter(self.path, fourcc, fps, (width, height))\n\n    def write_frame(self, frame):\n        self.out.write(frame)\n\n    def end(self):\n        self.out.release()\n\ndef concatenate(number, imgs, save_path):\n    width, height = imgs.shape[-3:-1]\n    imgs = imgs.reshape(number,-1,width,height,3)\n    if number == 2:\n        left = imgs[0]\n        right = imgs[1]\n\n        im_all = []\n        for i in range(len(left)):\n            im = np.concatenate((left[i],right[i]),axis = 1)\n            im_all.append(im)\n    if number == 3:\n        left = imgs[0]\n        middle = imgs[1]\n        right = imgs[2]\n\n        im_all = []\n        for i in range(len(left)):\n            im = np.concatenate((left[i],middle[i],right[i]),axis = 1)\n            im_all.append(im)\n    if number == 4:\n        left = imgs[0]\n        left2 = imgs[1]\n        right = imgs[2]\n        right2 = imgs[3]\n\n        im_all = []\n        for i in range(len(left)):\n            im = np.concatenate((left[i],left2[i],right[i],right2[i]),axis = 1)\n            im_all.append(im)\n    if number == 5:\n        left = imgs[0]\n        left2 = imgs[1]\n        middle = imgs[2]\n        right = imgs[3]\n        right2 = imgs[4]\n\n        im_all = []\n        for i in range(len(left)):\n            im = np.concatenate((left[i],left2[i],middle[i],right[i],right2[i]),axis = 1)\n            im_all.append(im)\n\n\n    imageio.mimsave(save_path, [img_as_ubyte(frame) for frame in im_all], fps=25)\n\ndef add_audio(video_name=None, audio_dir = None):\n\n    command = 'ffmpeg -i ' + video_name  + ' -i ' + audio_dir + ' -vcodec copy  -acodec copy -y  ' + video_name.replace('.mp4','.mov')\n    print (command)\n    os.system(command)\n\ndef crop_image(source_image):\n    \n    template = np.load('./M003_template.npy')\n    image= cv2.imread(source_image)\n    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n    rects = detector(gray, 1)  #detect human face\n    if len(rects) != 1:\n        return 0\n    for (j, rect) in enumerate(rects):\n        shape = predictor(gray, rect) #detect 68 points\n        shape = shape_to_np(shape)\n\n    pts2 = np.float32(template[:47,:])\n    pts1 = np.float32(shape[:47,:]) #eye and nose\n    # pts1 = np.float32(landmark[17:35,:])\n    tform = tf.SimilarityTransform()\n    tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters.\n  \n    dst = tf.warp(image, tform, output_shape=(256, 256))\n\n    
dst = np.array(dst * 255, dtype=np.uint8)\n    return dst \n\ndef smooth_pose(pose_file, pose_long):\n    start = np.load(pose_file)\n    video_pose = np.load(pose_long)\n    delta = video_pose - video_pose[0,:]\n    print(len(delta))\n    \n    pose = np.repeat(start,len(delta),axis = 0)\n    all_pose =  pose + delta\n\n    return all_pose\n\ndef test(opt, name):\n\n    all_pose = np.load(opt.pose_file).reshape(-1,7)\n    if opt.pose_long:\n\n        all_pose = smooth_pose(opt.pose_file,opt.pose_given)\n\n    \n   # source_image = img_as_float32(io.imread(opt.source_image))\n    source_image = img_as_float32(crop_image(opt.source_image))\n    source_image = resize(source_image, (256, 256))[..., :3]\n  \n    reader = imageio.get_reader(opt.driving_video)\n    fps = reader.get_meta_data()['fps']\n    driving_video = []\n    try:\n        for im in reader:\n            driving_video.append(im)\n    except RuntimeError:\n        pass\n    reader.close()\n\n   \n    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]\n    driving_video = get_aligned_image(driving_video, opt)\n    transformed_video = get_transformed_image(driving_video, opt)\n    transformed_video = np.array(transformed_video)\n\n    generator, kp_detector,kp_detector_a, audio_feature, emo_detector = load_checkpoints(opt=opt, checkpoint_path=opt.checkpoint, audio_checkpoint_path=opt.audio_checkpoint, emo_checkpoint_path = opt.emo_checkpoint, cpu=opt.cpu)\n \n    deco_out = test_auido(source_image, audio_feature, all_pose, opt)\n    if len(driving_video) < len(deco_out[0]):\n        driving_video = np.resize(driving_video,(len(deco_out[0]),256,256,3))\n        transformed_video = np.resize(transformed_video,(len(deco_out[0]),256,256,3))\n\n    else:\n        driving_video = driving_video[:len(deco_out[0])]\n    opt.add_emo = False\n    predictions, _ = make_animation_smooth(source_image, driving_video, transformed_video, deco_out, opt.kp_loss, generator, kp_detector, kp_detector_a, emo_detector, opt, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)\n  \n    imageio.mimsave(os.path.join(opt.result_path,'neutral.mp4'), [img_as_ubyte(frame) for frame in predictions], fps=fps)\n    predictions = np.array(predictions)\n     \n    opt.add_emo = True\n  \n    predictions1,_ = make_animation_smooth(source_image, driving_video, transformed_video, deco_out, opt.kp_loss, generator, kp_detector, kp_detector_a, emo_detector, opt, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)\n  \n    imageio.mimsave(os.path.join(opt.result_path,'emotion.mp4'), [img_as_ubyte(frame) for frame in predictions1], fps=fps)\n    add_audio(os.path.join(opt.result_path,'emotion.mp4'),opt.in_file)\n    predictions1 = np.array(predictions1)\n    all_imgs = np.concatenate((driving_video,predictions,predictions1),axis = 0)\n    save_path = os.path.join(opt.result_path, 'all.mp4')\n    concatenate(3, all_imgs, save_path)\n    add_audio(save_path,opt.in_file)\n\n\n\nif __name__ == \"__main__\":\n   \n    \n   \n    parser = ArgumentParser()\n    parser.add_argument(\"--config\", default ='config/MEAD_emo_video_aug_delta_4_crop_random_crop.yaml', help=\"path to config\")#required=True default ='config/vox-256.yaml'\n \n    parser.add_argument(\"--audio_checkpoint\", default='log/1-6000.pth.tar', help=\"path to checkpoint to restore\")\n    parser.add_argument(\"--checkpoint\", default='log/124_52000.pth.tar', help=\"path to checkpoint to restore\")\n   # 
parser.add_argument(\"--emo_checkpoint\", default='ablation/ablation/ten/10-6000.pth.tar', help=\"path to checkpoint to restore\")\n    parser.add_argument(\"--emo_checkpoint\", default='log/5-3000.pth.tar', help=\"path to checkpoint to restore\")\n\n    parser.add_argument(\"--source_image\", default='test/image/21.png', help=\"path to source image\")\n \n    parser.add_argument(\"--driving_video\", default='test/video/disgusted.mp4', help=\"path to driving video\")#data/M030/video/M030_angry_\n    parser.add_argument('--in_file', type=str, default='test/audio/sample1.mov')\n    parser.add_argument('--pose_file', type=str, default='test/pose/21.npy')\n    parser.add_argument('--pose_given', type=str, default='test/pose_long/0zn70Ak8lRc_Daniel_Auteuil_0zn70Ak8lRc_0002.npy')\n\n    parser.add_argument(\"--result_path\", default='result/', help=\"path to output\")#'/media/thea/新加卷/fomm/Exp/'+emotion+'.mp4'\n\n    parser.add_argument(\"--relative\", dest=\"relative\", action=\"store_true\", help=\"use relative or absolute keypoint coordinates\")\n    parser.add_argument(\"--adapt_scale\", dest=\"adapt_scale\", action=\"store_true\", help=\"adapt movement scale based on convex hull of keypoints\")\n\n    parser.add_argument(\"--cpu\", dest=\"cpu\", action=\"store_true\", help=\"cpu mode.\")\n    parser.add_argument(\"--kp_loss\", default=0, help=\"keypoint loss.\")\n\n    parser.add_argument(\"--smooth_pose\",  default=True, help=\"cpu mode.\")\n    parser.add_argument(\"--pose_long\",  default=False, help=\"use given long poses.\")\n    parser.add_argument(\"--weight\",  default=0, help=\"cpu mode.\")\n    parser.add_argument(\"--add_emo\",  default=False, help=\"add emotion.\")\n    parser.add_argument(\"--check_add\",  default=False, help=\"check emotion displacement.\")\n    parser.add_argument(\"--type\",  default='linear_3', help=\"add emotion type.\")\n    parser.add_argument(\"--emotion\",  default='disgusted', help=\"emotion category, 'angry', 'contempt','disgusted','fear','happy','neutral','sad','surprised'.\")\n    parser.set_defaults(relative=False)\n    parser.set_defaults(adapt_scale=False)\n\n    opt = parser.parse_args()\n #   opt.cpu = True\n   \n    test(opt,'test')\n         \n    "
  },
  {
    "path": "filter1.py",
    "content": "import cv2\n#import pickle\nimport time\nimport numpy as np\nimport copy\n\nfrom matplotlib import pyplot as plt\nfrom tqdm import tqdm\n\n\n\n\nclass LowPassFilter:\n  def __init__(self):\n    self.prev_raw_value = None\n    self.prev_filtered_value = None\n\n  def process(self, value, alpha):\n    if self.prev_raw_value is None:\n      s = value\n    else:\n      s = alpha * value + (1.0 - alpha) * self.prev_filtered_value\n    self.prev_raw_value = value\n    self.prev_filtered_value = s\n    return s\n\n\nclass OneEuroFilter:\n  def __init__(self, mincutoff=1.0, beta=0.0, dcutoff=1.0, freq=30):\n    self.freq = freq\n    self.mincutoff = mincutoff\n    self.beta = beta\n    self.dcutoff = dcutoff\n    self.x_filter = LowPassFilter()\n    self.dx_filter = LowPassFilter()\n\n  def compute_alpha(self, cutoff):\n    te = 1.0 / self.freq\n    tau = 1.0 / (2 * np.pi * cutoff)\n    return 1.0 / (1.0 + tau / te)\n\n  def process(self, x):\n    prev_x = self.x_filter.prev_raw_value\n    dx = 0.0 if prev_x is None else (x - prev_x) * self.freq\n    edx = self.dx_filter.process(dx, self.compute_alpha(self.dcutoff))\n    cutoff = self.mincutoff + self.beta * np.abs(edx)\n    return self.x_filter.process(x, self.compute_alpha(cutoff))\n\n"
  },
  {
    "path": "frames_dataset.py",
    "content": "import os\nfrom skimage import io, img_as_float32, transform\nfrom skimage.color import gray2rgb\nfrom sklearn.model_selection import train_test_split\nfrom imageio import mimread\n\nimport numpy as np\nfrom torch.utils.data import Dataset\nimport pandas as pd\nfrom augmentation import AllAugmentationTransform\nimport glob\nimport pickle\nimport random\nfrom filter1 import OneEuroFilter\ndef read_video(name, frame_shape):\n    \"\"\"\n    Read video which can be:\n      - an image of concatenated frames\n      - '.mp4' and'.gif'\n      - folder with videos\n    \"\"\"\n\n    if os.path.isdir(name):\n        frames = sorted(os.listdir(name))\n        num_frames = len(frames)\n        video_array = np.array(\n            [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])\n    elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):\n        image = io.imread(name)\n\n        if len(image.shape) == 2 or image.shape[2] == 1:\n            image = gray2rgb(image)\n\n        if image.shape[2] == 4:\n            image = image[..., :3]\n\n        image = img_as_float32(image)\n\n        video_array = np.moveaxis(image, 1, 0)\n\n        video_array = video_array.reshape((-1,) + frame_shape)\n        video_array = np.moveaxis(video_array, 1, 2)\n    elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):\n        video = np.array(mimread(name))\n        if len(video.shape) == 3:\n            video = np.array([gray2rgb(frame) for frame in video])\n        if video.shape[-1] == 4:\n            video = video[..., :3]\n        video_array = img_as_float32(video)\n    else:\n        raise Exception(\"Unknown file extensions  %s\" % name)\n\n    return video_array\n\ndef get_list(ipath,base_name):\n#ipath = '/mnt/lustre/share/jixinya/LRW/pose/train_fo/'\n    ipath = os.path.join(ipath,base_name)\n    name_list = os.listdir(ipath)\n    image_path = os.path.join('/mnt/lustre/share/jixinya/LRW/Image/',base_name)\n    all = []\n    for k in range(len(name_list)):\n        name = name_list[k]\n        path_ = os.path.join(ipath,name)\n        Dir = os.listdir(path_)\n        for i in range(len(Dir)):\n            word = Dir[i]\n            path = os.path.join(path_, word)\n            if os.path.exists(os.path.join(image_path,name,word.split('.')[0])):\n                all.append(name+'/'+word.split('.')[0])\n            #print(k,name,i,word)\n    print('get list '+os.path.basename(ipath))\n    return all\n\n\nclass AudioDataset(Dataset):\n    \"\"\"\n    Dataset of videos, each video can be represented as:\n      - an image of concatenated frames\n      - '.mp4' or '.gif'\n      - folder with all frames\n    \"\"\"\n\n    def __init__(self, name, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,\n                 random_seed=0, augmentation_params=None):\n        self.root_dir = root_dir\n        self.audio_dir = os.path.join(root_dir,'MFCC')\n        self.image_dir = os.path.join(root_dir,'Image')\n        self.pose_dir = os.path.join(root_dir,'pose')\n      #  assert len(os.listdir(self.audio_dir)) == len(os.listdir(self.image_dir)), 'audio and image length not equal'\n\n      #  self.videos=np.load('../LRW/list/train_fo.npy')\n      #  self.videos = os.listdir(self.landmark_dir)\n        self.frame_shape = tuple(frame_shape)\n       \n        self.id_sampling = id_sampling\n\n        if os.path.exists(os.path.join(self.pose_dir, 'train_fo')):\n            assert 
os.path.exists(os.path.join(self.pose_dir, 'test_fo'))\n            print(\"Use predefined train-test split.\")\n            if id_sampling:\n                train_videos = {os.path.basename(video).split('#')[0] for video in\n                                os.listdir(os.path.join(self.image_dir, 'train'))}\n                train_videos = list(train_videos)\n            else:\n                train_videos =  np.load('../LRW/list/train_fo.npy')# get_list(self.pose_dir, 'train_fo')\n         #   df=open('../LRW/list/test_fo.txt','rb')\n            test_videos=np.load('../LRW/list/test_fo.npy')\n         #   df.close()\n         #   test_videos = np.load('../LRW/list/train_fo.npy')\n            #get_list(self.pose_dir, 'test_fo')\n        #    self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')\n           \n            self.image_dir = os.path.join(self.image_dir, 'train_fo' if is_train else 'test_fo')\n            self.audio_dir = os.path.join(self.audio_dir, 'train' if is_train else 'test')\n            self.pose_dir = os.path.join(self.pose_dir, 'train_fo' if is_train else 'test_fo')\n        else:\n            print(\"Use random train-test split.\")\n            train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)\n\n        if is_train:\n            self.videos = train_videos\n        else:\n            self.videos = test_videos\n\n        self.is_train = is_train\n\n        if self.is_train:\n            self.transform = AllAugmentationTransform(**augmentation_params)\n        else:\n            self.transform = None\n\n    def __len__(self):\n        return len(self.videos)\n\n    def __getitem__(self, idx):\n        if self.is_train and self.id_sampling:\n            name = self.videos[idx].split('.')[0]\n            path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))\n        else:\n            name = self.videos[idx].split('.')[0]\n           \n            audio_path = os.path.join(self.audio_dir, name)\n            pose_path = os.path.join(self.pose_dir,name)\n            path = os.path.join(self.image_dir, name)\n\n        video_name = os.path.basename(path)\n\n        if  os.path.isdir(path):\n     #   if self.is_train and os.path.isdir(path):\n         \n            # mfcc loading\n            r = random.choice([x for x in range(3, 8)])\n\n            example_image = img_as_float32(io.imread(os.path.join(path, str(r)+'.png')))\n\n            mfccs = []\n            for ind in range(1, 17):\n              #  t_mfcc = mfcc[(r + ind - 3) * 4: (r + ind + 4) * 4, 1:]\n                t_mfcc = np.load(os.path.join(audio_path,str(r + ind)+'.npy'),allow_pickle=True)[:, 1:]\n                mfccs.append(t_mfcc)\n            mfccs = np.array(mfccs)\n            \n            poses = []\n            video_array = []\n            for ind in range(1, 17):\n              \n                t_pose = np.load(os.path.join(self.pose_dir,name+'.npy'))[r+ind,:-1]\n                \n                poses.append(t_pose)\n                image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png')))\n                video_array.append(image)\n            poses = np.array(poses)\n            video_array = np.array(video_array)\n\n        else:\n            print('Wrong, data path not an existing file.')\n\n        if self.transform is not None:\n            video_array = self.transform(video_array)\n\n        out = {}\n     \n        driving = np.array(video_array, dtype='float32')\n        
spatial_size = np.array(driving.shape[1:3][::-1])[np.newaxis]\n        driving_pose = np.array(poses, dtype='float32')\n        example_image = np.array(example_image, dtype='float32')\n\n        out['example_image'] = example_image.transpose((2, 0, 1))\n        out['driving_pose'] = driving_pose\n        out['driving'] = driving.transpose((0, 3, 1, 2))\n        out['driving_audio'] = np.array(mfccs, dtype='float32')\n    #    out['name'] = video_name\n\n        return out\n\nclass VoxDataset(Dataset):\n    \"\"\"\n    Dataset of videos, each video can be represented as:\n      - an image of concatenated frames\n      - '.mp4' or '.gif'\n      - folder with all frames\n    \"\"\"\n\n    def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,\n                 random_seed=0, pairs_list=None, augmentation_params=None):\n        self.root_dir = root_dir\n        self.audio_dir = os.path.join(root_dir,'MFCC')\n        self.image_dir = os.path.join(root_dir,'align_img')\n\n        self.pose_dir = os.path.join(root_dir,'align_pose')\n      #  assert len(os.listdir(self.audio_dir)) == len(os.listdir(self.image_dir)), 'audio and image length not equal'\n\n\n     #   df=open('../LRW/list/test_fo.txt','rb')\n     #  self.videos=pickle.load(df)\n     #   df.close()\n        self.videos=np.load('/mnt/lustre/share_data/jixinya/VoxCeleb1_Cut/right.npy')\n      #  self.videos = os.listdir(self.landmark_dir)\n        self.frame_shape = tuple(frame_shape)\n        self.pairs_list = pairs_list\n        self.id_sampling = id_sampling\n\n        if os.path.exists(os.path.join(self.pose_dir, 'train_fo')):\n            assert os.path.exists(os.path.join(self.pose_dir, 'test_fo'))\n            print(\"Use predefined train-test split.\")\n            if id_sampling:\n                train_videos = {os.path.basename(video).split('#')[0] for video in\n                                os.listdir(os.path.join(self.image_dir, 'train'))}\n                train_videos = list(train_videos)\n            else:\n                train_videos = np.load('/mnt/lustre/share_data/jixinya/VoxCeleb1_Cut/right.npy')# get_list(self.pose_dir, 'train_fo')\n      \n            self.image_dir = os.path.join(self.image_dir, 'train_fo' if is_train else 'test_fo')\n            self.audio_dir = os.path.join(self.audio_dir, 'train' if is_train else 'test')\n            self.pose_dir = os.path.join(self.pose_dir, 'train_fo' if is_train else 'test_fo')\n        else:\n            print(\"Use random train-test split.\")\n            train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)\n\n        if is_train:\n            self.videos = train_videos\n        else:\n            self.videos = test_videos\n\n        self.is_train = is_train\n\n        if self.is_train:\n            self.transform = AllAugmentationTransform(**augmentation_params)\n        else:\n            self.transform = None\n\n    def __len__(self):\n        return len(self.videos)\n\n    def __getitem__(self, idx):\n        if self.is_train and self.id_sampling:\n            name = self.videos[idx].split('.')[0]\n            path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))\n        else:\n            name = self.videos[idx].split('.')[0]\n\n            audio_path = os.path.join(self.audio_dir, name+'.npy')\n            pose_path = os.path.join(self.pose_dir,name+'.npy')\n            path = os.path.join(self.image_dir, name)\n\n        video_name = 
os.path.basename(path)\n\n        if  os.path.isdir(path):\n     #   if self.is_train and os.path.isdir(path):\n            frames = os.listdir(path)\n            num_frames = len(frames)\n            frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))\n            video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]\n            mfcc = np.load(audio_path)\n            pose = np.load(pose_path)\n\n          #  print(audio_path,pose_path,len(mfcc))\n\n            try:\n                len(mfcc) > 16\n            except:\n                print('wrongmfcc len:',audio_path)\n            if 16 < len(mfcc) < 24 :\n                r = 0\n            else:\n\n                r = random.choice([x for x in range(3, len(mfcc)-20)])\n\n            mfccs = []\n            poses = []\n            video_array = []\n            for ind in range(1, 17):\n                t_mfcc = mfcc[r+ind][:, 1:]\n                mfccs.append(t_mfcc)\n                t_pose = pose[r+ind,:-1]\n                poses.append(t_pose)\n                image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png')))\n                video_array.append(image)\n            mfccs = np.array(mfccs)\n            poses = np.array(poses)\n            video_array = np.array(video_array)\n\n            example_image = img_as_float32(io.imread(os.path.join(path, str(r)+'.png')))\n\n\n        else:\n            print('Wrong, data path not an existing file.')\n\n        if self.transform is not None:\n            video_array = self.transform(video_array)\n\n        out = {}\n\n        driving = np.array(video_array, dtype='float32')\n\n        spatial_size = np.array(driving.shape[1:3][::-1])[np.newaxis]\n        driving_pose = np.array(poses, dtype='float32')\n        example_image = np.array(example_image, dtype='float32')\n        out['example_image'] = example_image.transpose((2, 0, 1))\n        out['driving_pose'] = driving_pose\n        out['driving'] = driving.transpose((0, 3, 1, 2))\n\n        out['driving_audio'] = np.array(mfccs, dtype='float32')\n    #    out['name'] = video_name\n\n        return out\n\nclass MeadDataset(Dataset):\n    \"\"\"\n    Dataset of videos, each video can be represented as:\n      - an image of concatenated frames\n      - '.mp4' or '.gif'\n      - folder with all frames\n    \"\"\"\n\n    def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,\n                 random_seed=0, augmentation_params=None):\n        self.root_dir = root_dir\n\n        self.audio_dir = os.path.join(root_dir,'MEAD_MFCC')\n        self.image_dir = os.path.join(root_dir,'MEAD_fomm_crop')\n\n        self.pose_dir = os.path.join(root_dir,'MEAD_fomm_pose_crop')\n\n        self.videos = np.load('/mnt/lustre/share_data/jixinya/MEAD/MEAD_fomm_audio_less_crop.npy')\n        self.dict = np.load('/mnt/lustre/share_data/jixinya/MEAD/MEAD_fomm_neu_dic_crop.npy',allow_pickle=True).item()\n       # self.videos = os.listdir(root_dir)\n        self.frame_shape = tuple(frame_shape)\n\n        self.id_sampling = id_sampling\n        if os.path.exists(os.path.join(root_dir, 'train')):\n            assert os.path.exists(os.path.join(root_dir, 'test'))\n            print(\"Use predefined train-test split.\")\n            if id_sampling:\n                train_videos = {os.path.basename(video).split('#')[0] for video in\n                                os.listdir(os.path.join(root_dir, 'train'))}\n                train_videos = 
list(train_videos)\n            else:\n                train_videos = os.listdir(os.path.join(root_dir, 'train'))\n            test_videos = os.listdir(os.path.join(root_dir, 'test'))\n            self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')\n        else:\n            print(\"Use random train-test split.\")\n            train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)\n\n        if is_train:\n            self.videos = train_videos\n        else:\n            self.videos = test_videos\n\n        self.is_train = is_train\n\n        if self.is_train:\n            self.transform = AllAugmentationTransform(**augmentation_params)\n        else:\n            self.transform = None\n\n    def __len__(self):\n        return len(self.videos)\n\n    def __getitem__(self, idx):\n        if self.is_train and self.id_sampling:\n            name = self.videos[idx]\n            path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))\n        else:\n            name = self.videos[idx]\n            path = os.path.join(self.image_dir, name)\n\n            video_name = os.path.basename(path)\n            id_name = path.split('/')[-2]\n            neu_list = self.dict[id_name]\n            neu_path = os.path.join(self.image_dir, np.random.choice(neu_list))\n\n            audio_path = os.path.join(self.audio_dir, name+'.npy')\n            pose_path = os.path.join(self.pose_dir,name+'.npy')\n\n\n        if self.is_train and os.path.isdir(path):\n\n            mfcc = np.load(audio_path)\n            pose_raw = np.load(pose_path)\n            one_euro_filter = OneEuroFilter(mincutoff=0.01, beta=0.7, dcutoff=1.0, freq=100)\n            pose = np.zeros((len(pose_raw),7))\n\n            for j in range(len(pose_raw)):\n                pose[j]=one_euro_filter.process(pose_raw[j])\n          #  print(audio_path,pose_path,len(mfcc))\n\n            neu_frames = os.listdir(neu_path)\n            num_neu_frames = len(neu_frames)\n            frame_idx = np.random.choice(num_neu_frames)\n            example_image = img_as_float32(io.imread(os.path.join(neu_path, neu_frames[frame_idx])))\n            try:\n                len(mfcc) > 16\n            except:\n                print('wrongmfcc len:',audio_path)\n            if 16 < len(mfcc) < 24 :\n                r = 0\n            else:\n\n                r = random.choice([x for x in range(3, len(mfcc)-20)])\n\n            mfccs = []\n            poses = []\n            video_array = []\n            for ind in range(1, 17):\n                t_mfcc = mfcc[r+ind][:, 1:]\n                mfccs.append(t_mfcc)\n                t_pose = pose[r+ind,:-1]\n                poses.append(t_pose)\n                image = img_as_float32(io.imread(os.path.join(path, str(r + ind)+'.png')))\n                video_array.append(image)\n            mfccs = np.array(mfccs)\n            poses = np.array(poses)\n            video_array = np.array(video_array)\n\n        else:\n            print('Wrong, data path not an existing file.')\n\n        if self.transform is not None:\n            video_array = self.transform(video_array)\n\n        out = {}\n        if self.is_train:\n      \n            driving = np.array(video_array, dtype='float32')\n            driving_pose = np.array(poses, dtype='float32')\n            example_image = np.array(example_image, dtype='float32')\n\n\n            out['example_image'] = example_image.transpose((2, 0, 1))\n            out['driving_pose'] = 
driving_pose\n            out['driving'] = driving.transpose((0, 3, 1, 2))\n            out['driving_audio'] = np.array(mfccs, dtype='float32')\n\n      #  out['name'] = id_name+'/'+video_name\n\n        return out\n\n\nclass DatasetRepeater(Dataset):\n    \"\"\"\n    Pass several times over the same dataset for better i/o performance\n    \"\"\"\n\n    def __init__(self, dataset, num_repeats=100):\n        self.dataset = dataset\n    #    self.dataset2 = dataset2\n        self.num_repeats = num_repeats\n\n    def __len__(self):\n        return self.num_repeats * self.dataset.__len__()\n\n    def __getitem__(self, idx):\n     #   if idx % 5 == 0:\n     #       return self.dataset2[idx % self.dataset2.__len__()]#% self.dataset.__len__()\n     #   else:\n     #       return self.dataset[idx % self.dataset.__len__()]\n        return self.dataset[idx % self.dataset.__len__()]\n\nclass TestsetRepeater(Dataset):\n    \"\"\"\n    Pass several times over the same dataset for better i/o performance\n    \"\"\"\n\n    def __init__(self, dataset, num_repeats=100):\n        self.dataset = dataset\n\n        self.num_repeats = num_repeats\n\n    def __len__(self):\n        return self.num_repeats * self.dataset.__len__()\n\n    def __getitem__(self, idx):\n\n        return self.dataset[idx % self.dataset.__len__()]#% self.dataset.__len__()\n\n\nclass PairedDataset(Dataset):\n    \"\"\"\n    Dataset of pairs for animation.\n    \"\"\"\n\n    def __init__(self, initial_dataset, number_of_pairs, seed=0):\n        self.initial_dataset = initial_dataset\n        pairs_list = self.initial_dataset.pairs_list\n\n        np.random.seed(seed)\n\n        if pairs_list is None:\n            max_idx = min(number_of_pairs, len(initial_dataset))\n            nx, ny = max_idx, max_idx\n            xy = np.mgrid[:nx, :ny].reshape(2, -1).T\n            number_of_pairs = min(xy.shape[0], number_of_pairs)\n            self.pairs = xy.take(np.random.choice(xy.shape[0], number_of_pairs, replace=False), axis=0)\n        else:\n            videos = self.initial_dataset.videos\n            name_to_index = {name: index for index, name in enumerate(videos)}\n            pairs = pd.read_csv(pairs_list)\n            pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))]\n\n            number_of_pairs = min(pairs.shape[0], number_of_pairs)\n            self.pairs = []\n            self.start_frames = []\n            for ind in range(number_of_pairs):\n                self.pairs.append(\n                    (name_to_index[pairs['driving'].iloc[ind]], name_to_index[pairs['source'].iloc[ind]]))\n\n    def __len__(self):\n        return len(self.pairs)\n\n    def __getitem__(self, idx):\n        pair = self.pairs[idx]\n        first = self.initial_dataset[pair[0]]\n        second = self.initial_dataset[pair[1]]\n        first = {'driving_' + key: value for key, value in first.items()}\n        second = {'source_' + key: value for key, value in second.items()}\n\n        return {**first, **second}\n"
  },
  {
    "path": "logger.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nimport imageio\n\nimport os\nfrom skimage.draw import circle\n\nimport matplotlib.pyplot as plt\nimport collections\n\n\nclass Logger:\n    def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):\n\n        self.loss_list = []\n        self.cpk_dir = log_dir\n        self.visualizations_dir = os.path.join(log_dir, 'train-vis')\n        if not os.path.exists(self.visualizations_dir):\n            os.makedirs(self.visualizations_dir)\n        self.log_file = open(os.path.join(log_dir, log_file_name), 'a')\n        self.zfill_num = zfill_num\n        self.visualizer = Visualizer(**visualizer_params)\n        self.checkpoint_freq = checkpoint_freq\n        self.epoch = 0\n        self.best_loss = float('inf')\n        self.names = None\n\n    def log_scores(self, loss_names):\n        loss_mean = np.array(self.loss_list).mean(axis=0)\n\n        loss_string = \"; \".join([\"%s - %.5f\" % (name, value) for name, value in zip(loss_names, loss_mean)])\n        loss_string = str(str(self.epoch)+str(self.step).zfill(self.zfill_num)) + \") \" + loss_string\n\n        print(loss_string, file=self.log_file)\n        self.loss_list = []\n        self.log_file.flush()\n\n    def visualize_rec(self, inp, out):\n      #  image = self.visualizer.visualize(inp['driving'], inp['source'], out)\n        image = self.visualizer.visualize(inp['driving'][:,-1], inp['transformed_driving'][:,-1], inp['example_image'], out)\n        imageio.imsave(os.path.join(self.visualizations_dir, \"%s-%s-rec.png\" % (str(self.epoch),str(self.step).zfill(self.zfill_num))), image)\n\n    def save_cpk(self, emergent=False):\n        cpk = {k: v.state_dict() for k, v in self.models.items()}\n        cpk['epoch'] = self.epoch\n        cpk['step'] = self.step\n        cpk_path = os.path.join(self.cpk_dir, '%s-%s-checkpoint.pth.tar' % (str(self.epoch),str(self.step).zfill(self.zfill_num)))\n        if not (os.path.exists(cpk_path) and emergent):\n            torch.save(cpk, cpk_path)\n\n    @staticmethod\n    def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_detector=None, audio_feature=None,\n                 optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None, optimizer_audio_feature = None):\n        checkpoint = torch.load(checkpoint_path)\n        if generator is not None:\n            generator.load_state_dict(checkpoint['generator'])\n        if kp_detector is not None:\n            kp_detector.load_state_dict(checkpoint['kp_detector'])\n        if discriminator is not None:\n            try:\n               discriminator.load_state_dict(checkpoint['discriminator'])\n            except:\n               print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')\n    #    if audio_feature is not None:\n    #        audio_feature.load_state_dict(checkpoint['audio_feature'])\n        if optimizer_generator is not None:\n            optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])\n        if optimizer_discriminator is not None:\n            try:\n                optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])\n            except RuntimeError as e:\n                print ('No discriminator optimizer in the state-dict. 
Optimizer will be not initialized')\n        if optimizer_kp_detector is not None:\n            optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])\n  #      if optimizer_audio_feature is not None:\n  #          a = checkpoint['optimizer_kp_detector']['param_groups']\n  #          a[0].pop('params')\n  #          optimizer_audio_feature.load_state_dict(checkpoint['optimizer_audio_feature'])\n\n        return checkpoint['epoch']\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        if 'models' in self.__dict__:\n            self.save_cpk()\n        self.log_file.close()\n\n    def log_iter(self, losses):\n        losses = collections.OrderedDict(losses.items())\n        if self.names is None:\n            self.names = list(losses.keys())\n        self.loss_list.append(list(losses.values()))\n\n    def log_epoch(self, epoch, step, models, inp, out):\n        self.epoch = epoch\n        self.step = step\n        self.models = models\n        if (self.epoch + 1) % self.checkpoint_freq == 0:\n            self.save_cpk()\n        self.log_scores(self.names)\n        self.visualize_rec(inp, out)\n\n\nclass Visualizer:\n    def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):\n        self.kp_size = kp_size\n        self.draw_border = draw_border\n        self.colormap = plt.get_cmap(colormap)\n\n    def draw_image_with_kp(self, image, kp_array):\n        image = np.copy(image)\n        spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]\n        kp_array = spatial_size * (kp_array + 1) / 2\n        num_kp = kp_array.shape[0]\n        for kp_ind, kp in enumerate(kp_array):\n            rr, cc = circle(kp[1], kp[0], self.kp_size, shape=image.shape[:2])\n            image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]\n        return image\n\n    def create_image_column_with_kp(self, images, kp):\n        image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])\n        return self.create_image_column(image_array)\n\n    def create_image_column(self, images):\n        if self.draw_border:\n            images = np.copy(images)\n            images[:, :, [0, -1]] = (1, 1, 1)\n            images[:, :, [0, -1]] = (1, 1, 1)\n        return np.concatenate(list(images), axis=0)\n\n    def create_image_grid(self, *args):\n        out = []\n        for arg in args:\n            if type(arg) == tuple:\n                out.append(self.create_image_column_with_kp(arg[0], arg[1]))\n            else:\n                out.append(self.create_image_column(arg))\n        return np.concatenate(out, axis=1)\n\n    def visualize(self, driving, transformed_driving, source, out):\n        images = []\n\n        # Source image with keypoints\n        source = source.data.cpu()\n        kp_source = out['kp_source']['value'].data.cpu().numpy()\n        source = np.transpose(source, [0, 2, 3, 1])\n        images.append((source, kp_source))\n\n        # Equivariance visualization\n        if 'transformed_frame' in out:\n            transformed = out['transformed_frame'].data.cpu().numpy()\n            transformed = np.transpose(transformed, [0, 2, 3, 1])\n            transformed_kp = out['transformed_kp']['value'].data.cpu().numpy()\n            images.append((transformed, transformed_kp))\n\n        # Equivariance visualization\n        transformed_driving = transformed_driving.data.cpu().numpy()\n        transformed_driving = np.transpose(transformed_driving, [0, 2, 3, 1])\n        
images.append(transformed_driving)\n\n        # Driving image with keypoints\n        kp_driving = out['kp_driving'][-1]['value'].data.cpu().numpy() #[-1]['value']\n        driving = driving.data.cpu().numpy()\n        driving = np.transpose(driving, [0, 2, 3, 1])\n        images.append((driving, kp_driving))\n\n        # Deformed image\n        if 'deformed' in out:\n            deformed = out['deformed'].data.cpu().numpy()\n            deformed = np.transpose(deformed, [0, 2, 3, 1])\n            images.append(deformed)\n\n        # Result with and without keypoints\n        prediction = out['prediction'].data.cpu().numpy()\n        prediction = np.transpose(prediction, [0, 2, 3, 1])\n        if 'kp_norm' in out:\n            kp_norm = out['kp_norm']['value'].data.cpu().numpy()\n            images.append((prediction, kp_norm))\n        images.append(prediction)\n\n\n        ## Occlusion map\n        if 'occlusion_map' in out:\n            occlusion_map = out['occlusion_map'].data.cpu().repeat(1, 3, 1, 1)\n            occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()\n            occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])\n            images.append(occlusion_map)\n\n        # Deformed images according to each individual transform\n        if 'sparse_deformed' in out:\n            full_mask = []\n            for i in range(out['sparse_deformed'].shape[1]):\n                image = out['sparse_deformed'][:, i].data.cpu()\n                image = F.interpolate(image, size=source.shape[1:3])\n                mask = out['mask'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)\n                mask = F.interpolate(mask, size=source.shape[1:3])\n                image = np.transpose(image.numpy(), (0, 2, 3, 1))\n                mask = np.transpose(mask.numpy(), (0, 2, 3, 1))\n\n                if i != 0:\n                    color = np.array(self.colormap((i - 1) / (out['sparse_deformed'].shape[1] - 1)))[:3]\n                else:\n                    color = np.array((0, 0, 0))\n\n                color = color.reshape((1, 1, 1, 3))\n\n                images.append(image)\n                if i != 0:\n                    images.append(mask * color)\n                else:\n                    images.append(mask)\n\n                full_mask.append(mask * color)\n\n            images.append(sum(full_mask))\n\n        image = self.create_image_grid(*images)\n        image = (255 * image).astype(np.uint8)\n        return image\n"
  },
  {
    "path": "modules/dense_motion.py",
    "content": "from torch import nn\nimport torch.nn.functional as F\nimport torch\nfrom modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian\n\n\nclass DenseMotionNetwork(nn.Module):\n    \"\"\"\n    Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving\n    \"\"\"\n\n    def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False,\n                 scale_factor=1, kp_variance=0.01):\n        super(DenseMotionNetwork, self).__init__()\n        self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1),\n                                   max_features=max_features, num_blocks=num_blocks)\n\n        self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3))\n\n        if estimate_occlusion_map:\n            self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))\n        else:\n            self.occlusion = None\n\n        self.num_kp = num_kp\n        self.scale_factor = scale_factor\n        self.kp_variance = kp_variance\n\n        if self.scale_factor != 1:\n            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)\n\n    def create_heatmap_representations(self, source_image, kp_driving, kp_source):\n        \"\"\"\n        Eq 6. in the paper H_k(z)\n        \"\"\"\n        spatial_size = source_image.shape[2:]\n        gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)\n        gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)\n        heatmap = gaussian_driving - gaussian_source #[4,10,H,W]\n\n        #adding background feature\n        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type())\n        heatmap = torch.cat([zeros, heatmap], dim=1)\n        heatmap = heatmap.unsqueeze(2) #[4,11,1,h,w]\n        return heatmap\n\n    def create_sparse_motions(self, source_image, kp_driving, kp_source):\n        \"\"\"\n        Eq 4. in the paper T_{s<-d}(z)\n        \"\"\"\n        bs, _, h, w = source_image.shape\n        identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())\n        identity_grid = identity_grid.view(1, 1, h, w, 2)\n        coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2) #[4,10,64,64,2]\n        if 'jacobian' in kp_driving:\n            jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))\n            jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)\n            jacobian = jacobian.repeat(1, 1, h, w, 1, 1)\n            coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))\n            coordinate_grid = coordinate_grid.squeeze(-1)\n\n        driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2)\n\n        #adding background feature\n        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)\n        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)\n        return sparse_motions\n\n    def create_deformed_source_image(self, source_image, sparse_motions):\n        \"\"\"\n        Eq 7. 
in the paper \\hat{T}_{s<-d}(z)\n        \"\"\"\n        bs, _, h, w = source_image.shape\n        source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)\n        source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)\n        sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))\n        sparse_deformed = F.grid_sample(source_repeat, sparse_motions)\n        sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))\n        return sparse_deformed\n\n    def forward(self, source_image, kp_driving, kp_source):\n        if self.scale_factor != 1:\n            source_image = self.down(source_image) #[4,3,H*scale,W*scale]\n\n        bs, _, h, w = source_image.shape\n\n        out_dict = dict()\n        heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source) #[4,11,1,64,64]\n        sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source) #[4,11,64,64,2]\n        deformed_source = self.create_deformed_source_image(source_image, sparse_motion) #[4,11,3,64,64]\n        out_dict['sparse_deformed'] = deformed_source\n\n        input = torch.cat([heatmap_representation, deformed_source], dim=2)\n        input = input.view(bs, -1, h, w) #[4,11*4,64,64]\n\n        prediction = self.hourglass(input) #[4,108,64,64]\n\n        mask = self.mask(prediction)\n        mask = F.softmax(mask, dim=1) #[4,11,64,64]\n        out_dict['mask'] = mask\n        mask = mask.unsqueeze(2)\n        sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)\n        deformation = (sparse_motion * mask).sum(dim=1)\n        deformation = deformation.permute(0, 2, 3, 1) #[4,64,64,2]\n\n        out_dict['deformation'] = deformation\n\n        # Sec. 3.2 in the paper\n        if self.occlusion:\n            occlusion_map = torch.sigmoid(self.occlusion(prediction))\n            out_dict['occlusion_map'] = occlusion_map #[4,1,64,64]\n\n        return out_dict\n"
  },
  {
    "path": "modules/discriminator.py",
    "content": "from torch import nn\nimport torch.nn.functional as F\nfrom modules.util import kp2gaussian\nimport torch\n\n\nclass DownBlock2d(nn.Module):\n    \"\"\"\n    Simple block for processing video (encoder).\n    \"\"\"\n\n    def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):\n        super(DownBlock2d, self).__init__()\n        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)\n\n        if sn:\n            self.conv = nn.utils.spectral_norm(self.conv)\n\n        if norm:\n            self.norm = nn.InstanceNorm2d(out_features, affine=True)\n        else:\n            self.norm = None\n        self.pool = pool\n\n    def forward(self, x):\n        out = x\n        out = self.conv(out)\n        if self.norm:\n            out = self.norm(out)\n        out = F.leaky_relu(out, 0.2)\n        if self.pool:\n            out = F.avg_pool2d(out, (2, 2))\n        return out\n\n\nclass Discriminator(nn.Module):\n    \"\"\"\n    Discriminator similar to Pix2Pix\n    \"\"\"\n\n    def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,\n                 sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs):\n        super(Discriminator, self).__init__()\n\n        down_blocks = []\n        for i in range(num_blocks):\n            down_blocks.append(\n                DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)),\n                            min(max_features, block_expansion * (2 ** (i + 1))),\n                            norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))\n\n        self.down_blocks = nn.ModuleList(down_blocks)\n        self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)\n        if sn:\n            self.conv = nn.utils.spectral_norm(self.conv)\n        self.use_kp = use_kp\n        self.kp_variance = kp_variance\n\n    def forward(self, x, kp=None):\n        feature_maps = []\n        out = x\n        if self.use_kp:\n            heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)\n            out = torch.cat([out, heatmap], dim=1)\n\n        for down_block in self.down_blocks:\n            feature_maps.append(down_block(out))\n            out = feature_maps[-1]\n        prediction_map = self.conv(out)\n\n        return feature_maps, prediction_map\n\n\nclass MultiScaleDiscriminator(nn.Module):\n    \"\"\"\n    Multi-scale (scale) discriminator\n    \"\"\"\n\n    def __init__(self, scales=(), **kwargs):\n        super(MultiScaleDiscriminator, self).__init__()\n        self.scales = scales\n        discs = {}\n        for scale in scales:\n            discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)\n        self.discs = nn.ModuleDict(discs)\n\n    def forward(self, x, kp=None):\n        out_dict = {}\n        for scale, disc in self.discs.items():\n            scale = str(scale).replace('-', '.')\n            key = 'prediction_' + scale\n            feature_maps, prediction_map = disc(x[key], kp)\n            out_dict['feature_maps_' + scale] = feature_maps\n            out_dict['prediction_map_' + scale] = prediction_map\n        return out_dict\n"
  },
  {
    "path": "modules/function.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Thu Sep 30 17:45:24 2021\n\n@author: SENSETIME\\jixinya1\n\"\"\"\n\nimport torch\n\n\ndef calc_mean_std(feat, eps=1e-5):\n    # eps is a small value added to the variance to avoid divide-by-zero.\n    size = feat.size()\n    assert (len(size) == 4)\n    N, C = size[:2]\n    feat_var = feat.view(N, C, -1).var(dim=2) + eps\n    feat_std = feat_var.sqrt().view(N, C, 1, 1)\n    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)\n    return feat_mean, feat_std\n\n\ndef adaptive_instance_normalization(content_feat, style_feat):\n    assert (content_feat.size()[:2] == style_feat.size()[:2])\n    size = content_feat.size()\n    style_mean, style_std = calc_mean_std(style_feat)\n    content_mean, content_std = calc_mean_std(content_feat)\n\n    normalized_feat = (content_feat - content_mean.expand(\n        size)) / content_std.expand(size)\n    return normalized_feat * style_std.expand(size) + style_mean.expand(size)\n\n\ndef _calc_feat_flatten_mean_std(feat):\n    # takes 3D feat (C, H, W), return mean and std of array within channels\n    assert (feat.size()[0] == 3)\n    assert (isinstance(feat, torch.FloatTensor))\n    feat_flatten = feat.view(3, -1)\n    mean = feat_flatten.mean(dim=-1, keepdim=True)\n    std = feat_flatten.std(dim=-1, keepdim=True)\n    return feat_flatten, mean, std\n\n\ndef _mat_sqrt(x):\n    U, D, V = torch.svd(x)\n    return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())\n\n\ndef coral(source, target):\n    # assume both source and target are 3D array (C, H, W)\n    # Note: flatten -> f\n\n    source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)\n    source_f_norm = (source_f - source_f_mean.expand_as(\n        source_f)) / source_f_std.expand_as(source_f)\n    source_f_cov_eye = \\\n        torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)\n\n    target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)\n    target_f_norm = (target_f - target_f_mean.expand_as(\n        target_f)) / target_f_std.expand_as(target_f)\n    target_f_cov_eye = \\\n        torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)\n\n    source_f_norm_transfer = torch.mm(\n        _mat_sqrt(target_f_cov_eye),\n        torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),\n                 source_f_norm)\n    )\n\n    source_f_transfer = source_f_norm_transfer * \\\n                        target_f_std.expand_as(source_f_norm) + \\\n                        target_f_mean.expand_as(source_f_norm)\n\n    return source_f_transfer.view(source.size())"
  },
  {
    "path": "modules/generator.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d\nfrom modules.dense_motion import DenseMotionNetwork\n\n\nclass OcclusionAwareGenerator(nn.Module):\n    \"\"\"\n    Generator that given source image and and keypoints try to transform image according to movement trajectories\n    induced by keypoints. Generator follows Johnson architecture.\n    \"\"\"\n\n    def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,\n                 num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):\n        super(OcclusionAwareGenerator, self).__init__()\n\n        if dense_motion_params is not None:\n            self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,\n                                                           estimate_occlusion_map=estimate_occlusion_map,\n                                                           **dense_motion_params)\n        else:\n            self.dense_motion_network = None\n\n        self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))\n\n        down_blocks = []\n        for i in range(num_down_blocks):\n            in_features = min(max_features, block_expansion * (2 ** i))\n            out_features = min(max_features, block_expansion * (2 ** (i + 1)))\n            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))\n        self.down_blocks = nn.ModuleList(down_blocks)\n\n        up_blocks = []\n        for i in range(num_down_blocks):\n            in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))\n            out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))\n            up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))\n        self.up_blocks = nn.ModuleList(up_blocks)\n\n        self.bottleneck = torch.nn.Sequential()\n        in_features = min(max_features, block_expansion * (2 ** num_down_blocks))\n        for i in range(num_bottleneck_blocks):\n            self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))\n\n        self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))\n        self.estimate_occlusion_map = estimate_occlusion_map\n        self.num_channels = num_channels\n\n    def deform_input(self, inp, deformation):\n        _, h_old, w_old, _ = deformation.shape\n        _, _, h, w = inp.shape\n        if h_old != h or w_old != w:\n            deformation = deformation.permute(0, 3, 1, 2)\n            deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')\n            deformation = deformation.permute(0, 2, 3, 1)\n        return F.grid_sample(inp, deformation)\n\n    def forward(self, source_image, kp_driving, kp_source):\n        # Encoding (downsampling) part\n        out = self.first(source_image) #[4,64,H,W]\n        for i in range(len(self.down_blocks)):\n            out = self.down_blocks[i](out) #[4,256,H/4,W/4]\n\n        # Transforming feature representation according to deformation and occlusion\n        output_dict = {}\n        if self.dense_motion_network is not None:\n            dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,\n                                                     
kp_source=kp_source)\n            output_dict['mask'] = dense_motion['mask']\n            output_dict['sparse_deformed'] = dense_motion['sparse_deformed']\n\n            if 'occlusion_map' in dense_motion:\n                occlusion_map = dense_motion['occlusion_map']\n                output_dict['occlusion_map'] = occlusion_map\n            else:\n                occlusion_map = None\n            deformation = dense_motion['deformation']\n            out = self.deform_input(out, deformation)\n\n            if occlusion_map is not None:\n                if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:\n                    occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')\n                out = out * occlusion_map\n\n            output_dict[\"deformed\"] = self.deform_input(source_image, deformation)\n\n        # Decoding part\n        out = self.bottleneck(out) #[4,256,64,64]\n        for i in range(len(self.up_blocks)):\n            out = self.up_blocks[i](out)\n        out = self.final(out)\n        out = torch.sigmoid(out) #[4,3,256,256]\n\n        output_dict[\"prediction\"] = out\n\n        return output_dict\n"
  },
  {
    "path": "modules/keypoint_detector.py",
    "content": "from torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d, Ct_encoder, EmotionNet, AF2F, AF2F_s, draw_heatmap\n\n\nclass KPDetector(nn.Module):\n    \"\"\"\n    Detecting a keypoints. Return keypoint position and jacobian near each keypoint.\n    \"\"\"\n\n    def __init__(self, block_expansion, num_kp, num_channels, max_features,\n                 num_blocks, temperature, estimate_jacobian=False, scale_factor=1,\n                 single_jacobian_map=False, pad=0):\n        super(KPDetector, self).__init__()\n\n        self.predictor = Hourglass(block_expansion, in_features=num_channels,\n                                   max_features=max_features, num_blocks=num_blocks)\n\n        self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),\n                            padding=pad)\n\n        if estimate_jacobian:\n            self.num_jacobian_maps = 1 if single_jacobian_map else num_kp\n            self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,\n                                      out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)\n            self.jacobian.weight.data.zero_()\n            self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))\n        else:\n            self.jacobian = None\n\n        self.temperature = temperature\n        self.scale_factor = scale_factor\n        if self.scale_factor != 1:\n            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)\n        \n        \n        \n        \n    def gaussian2kp(self, heatmap):\n        \"\"\"\n        Extract the mean and from a heatmap\n        \"\"\"\n        shape = heatmap.shape\n        heatmap = heatmap.unsqueeze(-1) #[4,10,58,58,1]\n        grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) #[1,1,58,58,2]\n        value = (heatmap * grid).sum(dim=(2, 3)) #[4,10,2]\n        kp = {'value': value}\n\n        return kp\n    \n    def audio_feature(self, x, heatmap):\n        \n      #  prediction = self.kp(x) #[4,10,H/4-6, W/4-6]\n\n      #  final_shape = prediction.shape\n      #  heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]\n     #   heatmap = F.softmax(heatmap / self.temperature, dim=2)\n     #   heatmap = heatmap.view(*final_shape) #[4,10,58,58]\n\n     #   out = self.gaussian2kp(heatmap)\n        final_shape = heatmap.squeeze(2).shape   \n     \n        if self.jacobian is not None:\n            jacobian_map = self.jacobian(x) ##[4,40,H/4-6, W/4-6]\n            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],\n                                                final_shape[3])\n            heatmap = heatmap.unsqueeze(2)\n\n            jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]\n            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)\n            jacobian = jacobian.sum(dim=-1) #[4,10,4]\n            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2]\n            \n        return jacobian\n    \n    def forward(self, x): #torch.Size([4, 3, H, W])\n        if self.scale_factor != 1:\n            x = self.down(x) # 0.25 [4, 3, H/4, W/4]\n\n        feature_map = self.predictor(x) #[4,3+32,H/4, W/4]\n        prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6]\n\n        final_shape = 
prediction.shape\n        \n        heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]\n        heatmap = F.softmax(heatmap / self.temperature, dim=2)\n        heatmap = heatmap.view(*final_shape) #[4,10,58,58]\n        \n        out = self.gaussian2kp(heatmap)\n        out['heatmap'] = heatmap\n        \n        if self.jacobian is not None:\n            jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6]\n            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],\n                                                final_shape[3])\n            heatmap = heatmap.unsqueeze(2)\n\n            jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]\n            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)\n            jacobian = jacobian.sum(dim=-1) #[4,10,4]\n            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2]\n            out['jacobian'] = jacobian\n\n        return out\n    \n    \n\n\nclass KPDetector_a(nn.Module):\n    \"\"\"\n    Detecting a keypoints. Return keypoint position and jacobian near each keypoint.\n    \"\"\"\n\n    def __init__(self, block_expansion, num_kp, num_channels,num_channels_a, max_features,\n                 num_blocks, temperature, estimate_jacobian=False, scale_factor=1,\n                 single_jacobian_map=False, pad=0):\n        super(KPDetector_a, self).__init__()\n\n        self.predictor = Hourglass(block_expansion, in_features=num_channels_a,\n                                   max_features=max_features, num_blocks=num_blocks)\n\n        self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),\n                            padding=pad)\n\n        if estimate_jacobian:\n            self.num_jacobian_maps = 1 if single_jacobian_map else num_kp\n            self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,\n                                      out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)\n            self.jacobian.weight.data.zero_()\n            self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))\n        else:\n            self.jacobian = None\n\n        self.temperature = temperature\n        self.scale_factor = scale_factor\n        if self.scale_factor != 1:\n            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)\n        \n        \n        \n        \n    def gaussian2kp(self, heatmap):\n        \"\"\"\n        Extract the mean and from a heatmap\n        \"\"\"\n        shape = heatmap.shape\n        heatmap = heatmap.unsqueeze(-1) #[4,10,58,58,1]\n        grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) #[1,1,58,58,2]\n        value = (heatmap * grid).sum(dim=(2, 3)) #[4,10,2]\n        kp = {'value': value}\n\n        return kp\n    \n    def audio_feature(self, x, heatmap):\n        \n      #  prediction = self.kp(x) #[4,10,H/4-6, W/4-6]\n\n      #  final_shape = prediction.shape\n      #  heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]\n     #   heatmap = F.softmax(heatmap / self.temperature, dim=2)\n     #   heatmap = heatmap.view(*final_shape) #[4,10,58,58]\n\n     #   out = self.gaussian2kp(heatmap)\n        final_shape = heatmap.squeeze(2).shape   \n     \n        if self.jacobian is not None:\n            jacobian_map = self.jacobian(x) ##[4,40,H/4-6, W/4-6]\n            
jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],\n                                                final_shape[3])\n            heatmap = heatmap.unsqueeze(2)\n\n            jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]\n            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)\n            jacobian = jacobian.sum(dim=-1) #[4,10,4]\n            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2]\n            \n        return jacobian\n    \n    def forward(self,  feature_map): #torch.Size([4, 3, H, W])\n       \n        prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6]\n\n        final_shape = prediction.shape\n        \n        heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]\n        heatmap = F.softmax(heatmap / self.temperature, dim=2)\n        heatmap = heatmap.view(*final_shape) #[4,10,58,58]\n        \n        out = self.gaussian2kp(heatmap)\n        out['heatmap'] = heatmap\n        \n        if self.jacobian is not None:\n            jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6]\n            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],\n                                                final_shape[3])\n            heatmap = heatmap.unsqueeze(2)\n\n            jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]\n            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)\n            jacobian = jacobian.sum(dim=-1) #[4,10,4]\n            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2]\n            out['jacobian'] = jacobian\n\n        return out\n  \n    \nclass Audio_Feature(nn.Module):\n    def __init__(self):\n        super(Audio_Feature, self).__init__()\n        \n        self.con_encoder = Ct_encoder()\n        self.emo_encoder = EmotionNet()\n        self.decoder = AF2F_s()\n\n    \n    \n    def forward(self, x):\n        x = x.unsqueeze(1)\n      \n        c = self.con_encoder(x)\n        e = self.emo_encoder(x)\n        \n     #   d = torch.cat([c, e], dim=1)\n        d = self.decoder(c)\n        \n        \n        return d\n'''\ndef forward(self, x, cube, audio): #torch.Size([4, 3, H, W])\n        if self.scale_factor != 1:\n            x = self.down(x) # 0.25 [4, 3, H/4, W/4]\n        \n        cube = cube.unsqueeze(1)\n        feature = torch.cat([x,cube,audio],dim=1)\n        feature_map = self.predictor(feature) #[4,3+32,H/4, W/4]\n        prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6]\n\n        final_shape = prediction.shape\n        heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]\n        heatmap = F.softmax(heatmap / self.temperature, dim=2)\n        heatmap = heatmap.view(*final_shape) #[4,10,58,58]\n\n        out = self.gaussian2kp(heatmap)\n        out['heatmap'] = heatmap\n        if self.jacobian is not None:\n            jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6]\n            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],\n                                                final_shape[3])\n            heatmap = heatmap.unsqueeze(2)\n\n            jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]\n            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)\n            jacobian = jacobian.sum(dim=-1) #[4,10,4]\n            jacobian = jacobian.view(jacobian.shape[0], 
jacobian.shape[1], 2, 2) #[4,10,2,2]\n            out['jacobian'] = jacobian\n\n        return out\n'''\n"
  },
  {
    "path": "modules/model.py",
    "content": "from torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom modules.util import AntiAliasInterpolation2d, make_coordinate_grid\nfrom torchvision import models\nimport numpy as np\nfrom torch.autograd import grad\n\n\nclass Vgg19(torch.nn.Module):\n    \"\"\"\n    Vgg19 network for perceptual loss. See Sec 3.3.\n    \"\"\"\n    def __init__(self, requires_grad=False):\n        super(Vgg19, self).__init__()\n        vgg_pretrained_features = models.vgg19(pretrained=True).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        for x in range(2):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(2, 7):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(7, 12):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(12, 21):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(21, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n\n        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),\n                                       requires_grad=False)\n        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),\n                                      requires_grad=False)\n\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        X = (X - self.mean) / self.std\n        h_relu1 = self.slice1(X)\n        h_relu2 = self.slice2(h_relu1)\n        h_relu3 = self.slice3(h_relu2)\n        h_relu4 = self.slice4(h_relu3)\n        h_relu5 = self.slice5(h_relu4)\n        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]\n        return out\n\n\nclass ImagePyramide(torch.nn.Module):\n    \"\"\"\n    Create image pyramide for computing pyramide perceptual loss. See Sec 3.3\n    \"\"\"\n    def __init__(self, scales, num_channels):\n        super(ImagePyramide, self).__init__()\n        downs = {}\n        for scale in scales:\n            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)\n        self.downs = nn.ModuleDict(downs)\n\n    def forward(self, x):\n        out_dict = {}\n        for scale, down_module in self.downs.items():\n            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)\n        return out_dict\n\n\nclass Transform:\n    \"\"\"\n    Random tps transformation for equivariance constraints. 
See Sec 3.3\n    \"\"\"\n    def __init__(self, bs, **kwargs):\n        noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))\n        self.theta = noise + torch.eye(2, 3).view(1, 2, 3)\n        self.bs = bs\n\n        if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):\n            self.tps = True\n            self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())\n            self.control_points = self.control_points.unsqueeze(0)\n            self.control_params = torch.normal(mean=0,\n                                               std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))\n        else:\n            self.tps = False\n\n    def transform_frame(self, frame):\n        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]\n        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)\n        grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)\n        return F.grid_sample(frame, grid, padding_mode=\"reflection\")\n    \n    def inverse_transform_frame(self, frame):\n        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]\n        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)\n        grid = self.inverse_warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)\n        return F.grid_sample(frame, grid, padding_mode=\"reflection\")\n    \n    def warp_coordinates(self, coordinates):\n        theta = self.theta.type(coordinates.type())\n        theta = theta.unsqueeze(1)\n        transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]\n        transformed = transformed.squeeze(-1)\n\n        if self.tps:\n            control_points = self.control_points.type(coordinates.type())\n            control_params = self.control_params.type(coordinates.type())\n            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)\n            distances = torch.abs(distances).sum(-1)\n\n            result = distances ** 2\n            result = result * torch.log(distances + 1e-6)\n            result = result * control_params\n            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)\n            transformed = transformed + result\n\n        return transformed\n\n    def inverse_warp_coordinates(self, coordinates):\n        theta = self.theta.type(coordinates.type())\n        theta = theta.unsqueeze(1)\n        a = torch.FloatTensor([[[[0,0,1]]]]).repeat([self.bs,1,1,1]).cuda()\n        c = torch.cat((theta,a),2)\n        d = c.inverse()[:,:,:2,:]\n        d = d.type(coordinates.type())\n        transformed = torch.matmul(d[:, :, :, :2], coordinates.unsqueeze(-1)) + d[:, :, :, 2:]\n        transformed = transformed.squeeze(-1)\n        \n        if self.tps:\n            control_points = self.control_points.type(coordinates.type())\n            control_params = self.control_params.type(coordinates.type())\n            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)\n            distances = torch.abs(distances).sum(-1)\n\n            result = distances ** 2\n            result = result * torch.log(distances + 1e-6)\n            result = result * control_params\n            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)\n            transformed = transformed + result\n        \n        \n   
     return transformed\n\n    def jacobian(self, coordinates):\n        coordinates.requires_grad=True\n        new_coordinates = self.warp_coordinates(coordinates)#[4,10,2]\n        grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)\n        grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)\n        jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)\n        return jacobian\n\n\ndef detach_kp(kp):\n    return {key: value.detach() for key, value in kp.items()}\n\nclass TrainPart1Model(torch.nn.Module):\n    \"\"\"\n    Merge all generator-related updates into a single model for better multi-GPU usage\n    \"\"\"\n\n    def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids):\n        super(TrainPart1Model, self).__init__()\n        self.kp_extractor = kp_extractor\n        self.kp_extractor_a = kp_extractor_a\n\n        self.audio_feature = audio_feature\n        self.generator = generator\n        self.discriminator = discriminator\n        self.train_params = train_params\n        self.scales = train_params['scales']\n        self.disc_scales = self.discriminator.scales\n        self.pyramid = ImagePyramide(self.scales, generator.num_channels)\n        if torch.cuda.is_available():\n            self.pyramid = self.pyramid.cuda()\n\n        self.loss_weights = train_params['loss_weights']\n\n        if sum(self.loss_weights['perceptual']) != 0:\n            self.vgg = Vgg19()\n            if torch.cuda.is_available():\n                self.vgg = self.vgg.cuda()\n\n        self.mse_loss_fn = nn.MSELoss().cuda()\n\n    def forward(self, x):\n        kp_source = self.kp_extractor(x['example_image'])\n\n        kp_driving = []\n        for i in range(16):\n            kp_driving.append(self.kp_extractor(x['driving'][:,i]))\n\n        kp_driving_a = [] #x['example_image'],\n        deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])\n        loss_values = {}\n\n        if self.loss_weights['audio'] != 0:\n            kp_driving_a = []\n            for i in range(16):\n                kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))\n\n        loss_value = 0\n        loss_heatmap = 0\n        loss_jacobian = 0\n        loss_perceptual = 0\n        for i in range(len(kp_driving)):\n            loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian']).mean())*self.loss_weights['audio']\n\n         #   loss_jacobian = loss_jacobian*self.loss_weights['audio']\n            loss_heatmap += (torch.abs(kp_driving[i]['heatmap'] - kp_driving_a[i]['heatmap']).mean())*self.loss_weights['audio']*100\n\n            loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']).mean())*self.loss_weights['audio']\n\n        loss_values['loss_value'] = loss_value/len(kp_driving)\n        loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving)\n        loss_values['loss_jacobian'] = loss_jacobian/len(kp_driving)\n\n        if self.train_params['generator'] == 'not':\n     #       loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out)\n            for i in range(1): #0,len(kp_driving),4\n\n                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i])\n                
generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a})\n        elif self.train_params['generator'] == 'visual':\n            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4\n\n                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i])\n                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})\n\n                pyramide_real = self.pyramid(x['driving'][:,i])\n                pyramide_generated = self.pyramid(generated['prediction'])\n\n                if sum(self.loss_weights['perceptual']) != 0:\n                    value_total = 0\n                    for scale in self.scales:\n                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])\n                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])\n\n                        for i, weight in enumerate(self.loss_weights['perceptual']):\n                            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()\n                            value_total += self.loss_weights['perceptual'][i] * value\n                    loss_perceptual += value_total\n\n            length = int((len(kp_driving)-1)/4)+1\n            loss_values['perceptual'] = loss_perceptual/length\n        elif self.train_params['generator'] == 'audio':\n            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4\n\n                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i])\n                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a})\n\n                pyramide_real = self.pyramid(x['driving'][:,i])\n                pyramide_generated = self.pyramid(generated['prediction'])\n\n                if sum(self.loss_weights['perceptual']) != 0:\n                    value_total = 0\n                    for scale in self.scales:\n                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])\n                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])\n\n                        for i, weight in enumerate(self.loss_weights['perceptual']):\n                            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()\n                            value_total += self.loss_weights['perceptual'][i] * value\n                    loss_perceptual += value_total\n\n            length = int((len(kp_driving)-1)/4)+1\n            loss_values['perceptual'] = loss_perceptual/length\n        else:\n            print('wrong train_params: ', self.train_params['generator'])\n\n        return loss_values, generated\n\n\nclass TrainPart2Model(torch.nn.Module):\n    \"\"\"\n    Merge all generator-related updates into a single model for better multi-GPU usage\n    \"\"\"\n\n    def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids):\n        super(TrainPart2Model, self).__init__()\n        self.kp_extractor = kp_extractor\n        self.kp_extractor_a = kp_extractor_a\n\n        self.audio_feature = audio_feature\n        self.emo_feature = emo_feature\n        self.generator = generator\n        self.discriminator = discriminator\n        self.train_params = train_params\n        self.scales = train_params['scales']\n        self.disc_scales = self.discriminator.scales\n        self.pyramid = ImagePyramide(self.scales, generator.num_channels)\n   
     if torch.cuda.is_available():\n            self.pyramid = self.pyramid.cuda()\n\n        self.loss_weights = train_params['loss_weights']\n\n        if sum(self.loss_weights['perceptual']) != 0:\n            self.vgg = Vgg19()\n            if torch.cuda.is_available():\n                self.vgg = self.vgg.cuda()\n\n        self.mse_loss_fn   =  nn.MSELoss().cuda()\n        self.CroEn_loss =  nn.CrossEntropyLoss().cuda()\n    def forward(self, x):\n \n        kp_source = self.kp_extractor(x['example_image'])\n\n        kp_driving = []\n        kp_emo = []\n        for i in range(16):\n            kp_driving.append(self.kp_extractor(x['driving'][:,i]))\n    #        kp_emo.append(self.emo_detector(x['driving'][:,i]))\n\n        kp_driving_a = [] #x['example_image'],\n        deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])\n    #    emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])\n        loss_values = {}\n\n        if self.loss_weights['emo'] != 0:\n\n            kp_driving_a = []\n            fakes = []\n            for i in range(16):\n                kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))#\n                value = self.kp_extractor_a(deco_out[:,i])['value']\n                jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian']\n                if self.train_params['type'] == 'linear_4' :\n                    out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian)\n                    kp_emo.append(out)\n                    fakes.append(fake)\n                 #   kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))\n                elif self.train_params['type'] == 'linear_10':\n                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))\n\n                    out, fake = self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)\n                    kp_emo.append(out)\n                    fakes.append(fake)\n                elif self.train_params['type'] == 'linear_4_new':\n                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))\n\n                    out, fake = self.emo_feature.linear_4(x['transformed_driving'][:,i],value,jacobian)\n                    kp_emo.append(out)\n                    fakes.append(fake)\n                elif self.train_params['type'] == 'linear_np_4':\n                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))\n\n                    out, fake = self.emo_feature.linear_np_4(x['transformed_driving'][:,i],value,jacobian)\n                    kp_emo.append(out)\n                    fakes.append(fake)\n                elif self.train_params['type'] == 'linear_np_10':\n                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))\n\n                    out, fake = self.emo_feature.linear_np_10(x['transformed_driving'][:,i],value,jacobian)\n                    kp_emo.append(out)\n                    fakes.append(fake)\n          \n        loss_value = 0\n\n        loss_jacobian = 0\n\n        loss_classify = 0\n        kp_all = kp_driving_a\n     \n        for i in range(len(kp_driving)):\n       \n            if self.train_params['type'] == 'linear_4' or self.train_params['type'] == 'linear_4_new' or self.train_params['type'] == 'linear_np_4':\n        
        loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,1] - kp_driving_a[i]['jacobian'][:,1] -kp_emo[i]['jacobian'][:,0]).mean())*self.loss_weights['emo']\n                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,4] - kp_driving_a[i]['jacobian'][:,4] -kp_emo[i]['jacobian'][:,1]).mean())*self.loss_weights['emo']\n                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,6] - kp_driving_a[i]['jacobian'][:,6] -kp_emo[i]['jacobian'][:,2]).mean())*self.loss_weights['emo']\n                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,8] - kp_driving_a[i]['jacobian'][:,8] -kp_emo[i]['jacobian'][:,3]).mean())*self.loss_weights['emo']\n\n                loss_classify += self.CroEn_loss(fakes[i],x['emotion'])\n                loss_value += (torch.abs(kp_driving[i]['value'][:,1] .detach() - kp_driving_a[i]['value'][:,1]  - kp_emo[i]['value'][:,0] ).mean())*self.loss_weights['emo']\n                loss_value += (torch.abs(kp_driving[i]['value'][:,4] .detach() - kp_driving_a[i]['value'][:,4]  - kp_emo[i]['value'][:,1] ).mean())*self.loss_weights['emo']\n                loss_value += (torch.abs(kp_driving[i]['value'][:,6] .detach() - kp_driving_a[i]['value'][:,6]  - kp_emo[i]['value'][:,2] ).mean())*self.loss_weights['emo']\n                loss_value += (torch.abs(kp_driving[i]['value'][:,8] .detach() - kp_driving_a[i]['value'][:,8]  - kp_emo[i]['value'][:,3] ).mean())*self.loss_weights['emo']\n                kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1]\n                kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4]\n                kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6]\n                kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8]\n                kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1]\n                kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + kp_driving_a[i]['value'][:,4]\n                kp_all[i]['value'][:,6] = kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6]\n                kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8]\n            elif self.train_params['type'] == 'linear_10' or self.train_params['type'] == 'linear_np_10':\n                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian'] -kp_emo[i]['jacobian']).mean())*self.loss_weights['emo']\n\n                loss_classify += self.CroEn_loss(fakes[i],x['emotion'])\n                loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']  - kp_emo[i]['value'] ).mean())*self.loss_weights['emo']\n\n        #    kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value']\n\n        loss_values['loss_value'] = loss_value/len(kp_driving)\n  #      loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving)\n        loss_values['loss_jacobian'] = loss_jacobian/len(kp_driving)\n        if self.train_params['classify'] == True:\n            loss_values['loss_classify'] = loss_classify/len(kp_driving)\n        else:\n            loss_values['loss_classify'] = torch.tensor(0, device = loss_values['loss_value'].device)\n        \n        \n\n\n\n        return loss_values,generated\n\n\nclass GeneratorFullModel(torch.nn.Module):\n    \"\"\"\n    Merge all generator related updates into single model for better multi-gpu usage\n    \"\"\"\n\n    def 
__init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params):\n        super(GeneratorFullModel, self).__init__()\n        self.kp_extractor = kp_extractor\n        self.kp_extractor_a = kp_extractor_a\n    #    self.content_encoder = content_encoder\n    #    self.emotion_encoder = emotion_encoder\n        self.audio_feature = audio_feature\n        self.generator = generator\n        self.discriminator = discriminator\n        self.train_params = train_params\n        self.scales = train_params['scales']\n        self.disc_scales = self.discriminator.scales\n        self.pyramid = ImagePyramide(self.scales, generator.num_channels)\n        if torch.cuda.is_available():\n            self.pyramid = self.pyramid.cuda()\n\n        self.loss_weights = train_params['loss_weights']\n\n        if sum(self.loss_weights['perceptual']) != 0:\n            self.vgg = Vgg19()\n            if torch.cuda.is_available():\n                self.vgg = self.vgg.cuda()\n        \n        self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda()\n        self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda()\n        \n    def forward(self, x):\n   #     source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])\n      #  source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))\n   #     kp_source = self.kp_extractor(x['source'])\n   #     kp_source_a = self.kp_extractor_a(x['source'], x['source_cube'], source_a_f)\n      #  driving_a_f = self.audio_feature(self.content_encoder(x['driving_audio'].unsqueeze(1)), self.emotion_encoder(x['driving_audio'].unsqueeze(1)))\n      #  driving_a_f = self.audio_feature(x['driving_audio'])\n      #  kp_driving = self.kp_extractor(x['driving'])\n   #     kp_driving_a = self.kp_extractor_a(x['driving'], x['driving_cube'], driving_a_f)\n       \n        kp_driving = []\n        for i in range(16):\n            kp_driving.append(self.kp_extractor(x['driving'][:,i],x['driving_landmark'][:,i],self.loss_weights['equivariance_value']))\n        \n        kp_driving_a = []\n        fc_out, deco_out = self.audio_feature(x['example_landmark'], x['driving_audio'], x['driving_pose'])\n        fake_lmark=fc_out + x['example_landmark'].expand_as(fc_out)\n        \n      \n        fake_lmark = torch.mm( fake_lmark, self.pca.t() )\n        fake_lmark = fake_lmark + self.mean.expand_as(fake_lmark)\n    \n\n        fake_lmark = fake_lmark.unsqueeze(0) \n\n    #    for i in range(16):\n    #        kp_driving_a.append()\n        \n   #     generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)\n   #     generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})\n\n        loss_values = {}\n\n        pyramide_real = self.pyramid(x['driving'])\n        pyramide_generated = self.pyramid(generated['prediction'])\n        \n        if self.loss_weights['audio'] != 0:\n            value = torch.abs(kp_source['jacobian'].detach() - kp_source_a['jacobian'].detach()).mean() + torch.abs(kp_driving['jacobian'].detach() - kp_driving_a['jacobian']).mean()\n            value = value/2\n            loss_values['jacobian'] = value*self.loss_weights['audio']\n            value = torch.abs(kp_source['heatmap'].detach() - kp_source_a['heatmap'].detach()).mean() + torch.abs(kp_driving['heatmap'].detach() - kp_driving_a['heatmap']).mean()\n            value = value/2\n            loss_values['heatmap'] = 
value*self.loss_weights['audio']\n            value = torch.abs(kp_source['value'].detach() - kp_source_a['value'].detach()).mean() + torch.abs(kp_driving['value'].detach() - kp_driving_a['value']).mean()\n            value = value/2\n            loss_values['value'] = value*self.loss_weights['audio']\n            \n        if sum(self.loss_weights['perceptual']) != 0:\n            value_total = 0\n            for scale in self.scales:\n                x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])\n                y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])\n\n                for i, weight in enumerate(self.loss_weights['perceptual']):\n                    value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()\n                    value_total += self.loss_weights['perceptual'][i] * value\n                loss_values['perceptual'] = value_total\n\n        if self.loss_weights['generator_gan'] != 0:\n            discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))\n            discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))\n            value_total = 0\n            for scale in self.disc_scales:\n                key = 'prediction_map_%s' % scale\n                value = ((1 - discriminator_maps_generated[key]) ** 2).mean()\n                value_total += self.loss_weights['generator_gan'] * value\n            loss_values['gen_gan'] = value_total\n\n            if sum(self.loss_weights['feature_matching']) != 0:\n                value_total = 0\n                for scale in self.disc_scales:\n                    key = 'feature_maps_%s' % scale\n                    for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):\n                        if self.loss_weights['feature_matching'][i] == 0:\n                            continue\n                        value = torch.abs(a - b).mean()\n                        value_total += self.loss_weights['feature_matching'][i] * value\n                    loss_values['feature_matching'] = value_total\n\n        if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:\n            transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])\n            transformed_frame = transform.transform_frame(x['driving'])\n            transformed_landmark =  transform.inverse_warp_coordinates(x['driving_landmark'])\n            transformed_kp = self.kp_extractor(transformed_frame)\n\n            generated['transformed_frame'] = transformed_frame\n            generated['transformed_kp'] = transformed_kp\n            \n            ## Value loss part\n            if self.loss_weights['equivariance_value'] != 0:\n                value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()\n                loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value\n\n            ## jacobian loss part\n            if self.loss_weights['equivariance_jacobian'] != 0:\n                jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),\n                                                    transformed_kp['jacobian'])\n\n                normed_driving = torch.inverse(kp_driving['jacobian'])\n                normed_transformed = jacobian_transformed\n                value = torch.matmul(normed_driving, normed_transformed)\n\n                eye = torch.eye(2).view(1, 
1, 2, 2).type(value.type())\n\n                value = torch.abs(eye - value).mean()\n                loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value\n\n        return loss_values, generated\n\n\nclass DiscriminatorFullModel(torch.nn.Module):\n    \"\"\"\n    Merge all discriminator related updates into single model for better multi-gpu usage\n    \"\"\"\n\n    def __init__(self, kp_extractor, generator, discriminator, train_params):\n        super(DiscriminatorFullModel, self).__init__()\n        self.kp_extractor = kp_extractor\n        self.generator = generator\n        self.discriminator = discriminator\n        self.train_params = train_params\n        self.scales = self.discriminator.scales\n        self.pyramid = ImagePyramide(self.scales, generator.num_channels)\n        if torch.cuda.is_available():\n            self.pyramid = self.pyramid.cuda()\n\n        self.loss_weights = train_params['loss_weights']\n\n    def forward(self, x, generated):\n        pyramide_real = self.pyramid(x['driving'])\n        pyramide_generated = self.pyramid(generated['prediction'].detach())\n\n        kp_driving = generated['kp_driving']\n        discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))\n        discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))\n\n        loss_values = {}\n        value_total = 0\n        for scale in self.scales:\n            key = 'prediction_map_%s' % scale\n            value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2\n            value_total += self.loss_weights['discriminator_gan'] * value.mean()\n        loss_values['disc_gan'] = value_total\n\n        return loss_values\n"
  },
  {
    "path": "modules/model_delta_map.py",
    "content": "from torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom modules.util import AntiAliasInterpolation2d, make_coordinate_grid\nfrom torchvision import models\nimport numpy as np\nfrom torch.autograd import grad\n\n\nclass Vgg19(torch.nn.Module):\n    \"\"\"\n    Vgg19 network for perceptual loss. See Sec 3.3.\n    \"\"\"\n    def __init__(self, requires_grad=False):\n        super(Vgg19, self).__init__()\n        vgg_pretrained_features = models.vgg19(pretrained=True).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        for x in range(2):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(2, 7):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(7, 12):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(12, 21):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(21, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n\n        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),\n                                       requires_grad=False)\n        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),\n                                      requires_grad=False)\n\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        X = (X - self.mean) / self.std\n        h_relu1 = self.slice1(X)\n        h_relu2 = self.slice2(h_relu1)\n        h_relu3 = self.slice3(h_relu2)\n        h_relu4 = self.slice4(h_relu3)\n        h_relu5 = self.slice5(h_relu4)\n        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]\n        return out\n\n\nclass ImagePyramide(torch.nn.Module):\n    \"\"\"\n    Create image pyramide for computing pyramide perceptual loss. See Sec 3.3\n    \"\"\"\n    def __init__(self, scales, num_channels):\n        super(ImagePyramide, self).__init__()\n        downs = {}\n        for scale in scales:\n            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)\n        self.downs = nn.ModuleDict(downs)\n\n    def forward(self, x):\n        out_dict = {}\n        for scale, down_module in self.downs.items():\n            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)\n        return out_dict\n\n\nclass Transform:\n    \"\"\"\n    Random tps transformation for equivariance constraints. 
See Sec 3.3\n    \"\"\"\n    def __init__(self, bs, **kwargs):\n        noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))\n        self.theta = noise + torch.eye(2, 3).view(1, 2, 3)\n        self.bs = bs\n\n        if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):\n            self.tps = True\n            self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())\n            self.control_points = self.control_points.unsqueeze(0)\n            self.control_params = torch.normal(mean=0,\n                                               std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))\n        else:\n            self.tps = False\n\n    def transform_frame(self, frame):\n        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]\n        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)\n        grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)\n        return F.grid_sample(frame, grid, padding_mode=\"reflection\")\n    \n    def inverse_transform_frame(self, frame):\n        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]\n        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)\n        grid = self.inverse_warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)\n        return F.grid_sample(frame, grid, padding_mode=\"reflection\")\n    \n    def warp_coordinates(self, coordinates):\n        theta = self.theta.type(coordinates.type())\n        theta = theta.unsqueeze(1)\n        transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]\n        transformed = transformed.squeeze(-1)\n\n        if self.tps:\n            control_points = self.control_points.type(coordinates.type())\n            control_params = self.control_params.type(coordinates.type())\n            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)\n            distances = torch.abs(distances).sum(-1)\n\n            result = distances ** 2\n            result = result * torch.log(distances + 1e-6)\n            result = result * control_params\n            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)\n            transformed = transformed + result\n\n        return transformed\n\n    def inverse_warp_coordinates(self, coordinates):\n        theta = self.theta.type(coordinates.type())\n        theta = theta.unsqueeze(1)\n        a = torch.FloatTensor([[[[0,0,1]]]]).repeat([self.bs,1,1,1]).cuda()\n        c = torch.cat((theta,a),2)\n        d = c.inverse()[:,:,:2,:]\n        d = d.type(coordinates.type())\n        transformed = torch.matmul(d[:, :, :, :2], coordinates.unsqueeze(-1)) + d[:, :, :, 2:]\n        transformed = transformed.squeeze(-1)\n        \n        if self.tps:\n            control_points = self.control_points.type(coordinates.type())\n            control_params = self.control_params.type(coordinates.type())\n            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)\n            distances = torch.abs(distances).sum(-1)\n\n            result = distances ** 2\n            result = result * torch.log(distances + 1e-6)\n            result = result * control_params\n            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)\n            transformed = transformed + result\n        \n        \n   
     return transformed\n\n    def jacobian(self, coordinates):\n        coordinates.requires_grad=True\n        new_coordinates = self.warp_coordinates(coordinates)#[4,10,2]\n        grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)\n        grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)\n        jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)\n        return jacobian\n\n\ndef detach_kp(kp):\n    return {key: value.detach() for key, value in kp.items()}\n\nclass TrainFullModel(torch.nn.Module):\n    \"\"\"\n    Merge all generator related updates into single model for better multi-gpu usage\n    \"\"\"\n\n    def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids):\n        super(TrainFullModel, self).__init__()\n        self.kp_extractor = kp_extractor\n        self.kp_extractor_a = kp_extractor_a\n    #    self.emo_detector = emo_detector\n    #    self.content_encoder = content_encoder\n    #    self.emotion_encoder = emotion_encoder\n        self.audio_feature = audio_feature\n        self.emo_feature = emo_feature\n        self.generator = generator\n        self.discriminator = discriminator\n        self.train_params = train_params\n        self.scales = train_params['scales']\n        self.disc_scales = self.discriminator.scales\n        self.pyramid = ImagePyramide(self.scales, generator.num_channels)\n        if torch.cuda.is_available():\n            self.pyramid = self.pyramid.cuda()\n\n        self.loss_weights = train_params['loss_weights']\n\n        if sum(self.loss_weights['perceptual']) != 0:\n            self.vgg = Vgg19()\n            if torch.cuda.is_available():\n                self.vgg = self.vgg.cuda()\n        \n       # self.pca = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/U_106.npy'))[:, :16].to(device_ids[0])\n      #  self.mean = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/mean_106.npy')).to(device_ids[0])\n        self.mse_loss_fn   =  nn.MSELoss().cuda()\n        self.CroEn_loss =  nn.CrossEntropyLoss().cuda()\n    def forward(self, x):\n   #     source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])\n      #  source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))\n        kp_source = self.kp_extractor(x['example_image'])\n\n        kp_driving = []\n        kp_emo = []\n        for i in range(16):\n            kp_driving.append(self.kp_extractor(x['driving'][:,i]))\n    #        kp_emo.append(self.emo_detector(x['driving'][:,i]))\n    #    print('KP_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a'))\n        kp_driving_a = [] #x['example_image'],\n        deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])\n    #    emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])\n        loss_values = {}\n        \n        if self.loss_weights['emo'] != 0:\n            \n            kp_driving_a = []\n            fakes = []\n            for i in range(16):\n                kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))#\n                value = self.kp_extractor_a(deco_out[:,i])['value']\n                jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian']\n                if self.train_params['type'] 
== 'map_4':\n                    out, fake = self.emo_feature.map_4(x['transformed_driving'][:,i],value,jacobian)   \n                    kp_emo.append(out)\n                    fakes.append(fake)\n                 #   kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))\n                elif self.train_params['type'] == 'map_10':\n                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))\n                \n                    out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian)   \n                    kp_emo.append(out)\n                    fakes.append(fake)\n            #    kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))\n    #    print('Kp_audio_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a'))\n        loss_value = 0\n    #    loss_heatmap = 0\n        loss_jacobian = 0\n        loss_perceptual = 0\n        loss_classify = 0\n        kp_all = kp_driving_a\n        for i in range(len(kp_driving)):\n            if self.train_params['type'] == 'map_4':\n                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,1] - kp_driving_a[i]['jacobian'][:,1] -kp_emo[i]['jacobian'][:,0]).mean())*self.loss_weights['emo']\n                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,4] - kp_driving_a[i]['jacobian'][:,4] -kp_emo[i]['jacobian'][:,1]).mean())*self.loss_weights['emo']\n                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,6] - kp_driving_a[i]['jacobian'][:,6] -kp_emo[i]['jacobian'][:,2]).mean())*self.loss_weights['emo']\n                loss_jacobian += (torch.abs(kp_driving[i]['jacobian'][:,8] - kp_driving_a[i]['jacobian'][:,8] -kp_emo[i]['jacobian'][:,3]).mean())*self.loss_weights['emo']\n        \n                loss_classify += self.CroEn_loss(fakes[i],x['emotion'])\n                loss_value += (torch.abs(kp_driving[i]['value'][:,1] .detach() - kp_driving_a[i]['value'][:,1]  - kp_emo[i]['value'][:,0] ).mean())*self.loss_weights['emo']\n                loss_value += (torch.abs(kp_driving[i]['value'][:,4] .detach() - kp_driving_a[i]['value'][:,4]  - kp_emo[i]['value'][:,1] ).mean())*self.loss_weights['emo']\n                loss_value += (torch.abs(kp_driving[i]['value'][:,6] .detach() - kp_driving_a[i]['value'][:,6]  - kp_emo[i]['value'][:,2] ).mean())*self.loss_weights['emo']\n                loss_value += (torch.abs(kp_driving[i]['value'][:,8] .detach() - kp_driving_a[i]['value'][:,8]  - kp_emo[i]['value'][:,3] ).mean())*self.loss_weights['emo']\n                kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1]\n                kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4]\n                kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6]\n                kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8]\n                kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1]\n                kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + kp_driving_a[i]['value'][:,4]\n                kp_all[i]['value'][:,6] = kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6]\n                kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8]\n            elif self.train_params['type'] == 'map_10':\n                loss_jacobian += 
(torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian'] -kp_emo[i]['jacobian']).mean())*self.loss_weights['emo']\n        \n                loss_classify += self.CroEn_loss(fakes[i],x['emotion'])\n                loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']  - kp_emo[i]['value'] ).mean())*self.loss_weights['emo']\n            \n        #    kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value']\n            \n        loss_values['loss_value'] = loss_value/len(kp_driving)\n  #      loss_values['loss_heatmap'] = loss_heatmap/len(kp_driving)\n        loss_values['loss_jacobian'] = loss_jacobian/len(kp_driving)\n        loss_values['loss_classify'] = loss_classify/len(kp_driving)\n   \n        if self.train_params['generator'] == 'not':\n            loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out)\n            for i in range(1): #0,len(kp_driving),4\n \n                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i])\n                generated.update({'kp_source': kp_source, 'kp_driving': kp_all})\n        elif self.train_params['generator'] == 'visual':\n            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4\n \n                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i])\n                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})\n                \n                pyramide_real = self.pyramid(x['driving'][:,i])\n                pyramide_generated = self.pyramid(generated['prediction'])\n        \n                if sum(self.loss_weights['perceptual']) != 0:\n                    value_total = 0\n                    for scale in self.scales:\n                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])\n                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])\n\n                        for i, weight in enumerate(self.loss_weights['perceptual']):\n                            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()\n                            value_total += self.loss_weights['perceptual'][i] * value\n                    loss_perceptual += value_total\n        \n            length = int((len(kp_driving)-1)/4)+1\n            loss_values['perceptual'] = loss_perceptual/length\n        elif self.train_params['generator'] == 'audio':\n            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4\n \n                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving_a[i])\n                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving_a})\n                \n                pyramide_real = self.pyramid(x['driving'][:,i])\n                pyramide_generated = self.pyramid(generated['prediction'])\n        \n                if sum(self.loss_weights['perceptual']) != 0:\n                    value_total = 0\n                    for scale in self.scales:\n                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])\n                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])\n\n                        for i, weight in enumerate(self.loss_weights['perceptual']):\n                            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()\n                            value_total += self.loss_weights['perceptual'][i] * value\n                    loss_perceptual += value_total\n        \n            length = 
int((len(kp_driving)-1)/4)+1\n            loss_values['perceptual'] = loss_perceptual/length\n        else:\n            print('wrong train_params: ', self.train_params['generator'])\n      \n        \n      \n        return loss_values,generated\n\nclass GeneratorFullModel(torch.nn.Module):\n    \"\"\"\n    Merge all generator related updates into single model for better multi-gpu usage\n    \"\"\"\n\n    def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params):\n        super(GeneratorFullModel, self).__init__()\n        self.kp_extractor = kp_extractor\n        self.kp_extractor_a = kp_extractor_a\n    #    self.content_encoder = content_encoder\n    #    self.emotion_encoder = emotion_encoder\n        self.audio_feature = audio_feature\n        self.generator = generator\n        self.discriminator = discriminator\n        self.train_params = train_params\n        self.scales = train_params['scales']\n        self.disc_scales = self.discriminator.scales\n        self.pyramid = ImagePyramide(self.scales, generator.num_channels)\n        if torch.cuda.is_available():\n            self.pyramid = self.pyramid.cuda()\n\n        self.loss_weights = train_params['loss_weights']\n\n        if sum(self.loss_weights['perceptual']) != 0:\n            self.vgg = Vgg19()\n            if torch.cuda.is_available():\n                self.vgg = self.vgg.cuda()\n        \n        self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda()\n        self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda()\n        \n    def forward(self, x):\n   #     source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])\n      #  source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))\n   #     kp_source = self.kp_extractor(x['source'])\n   #     kp_source_a = self.kp_extractor_a(x['source'], x['source_cube'], source_a_f)\n      #  driving_a_f = self.audio_feature(self.content_encoder(x['driving_audio'].unsqueeze(1)), self.emotion_encoder(x['driving_audio'].unsqueeze(1)))\n      #  driving_a_f = self.audio_feature(x['driving_audio'])\n      #  kp_driving = self.kp_extractor(x['driving'])\n   #     kp_driving_a = self.kp_extractor_a(x['driving'], x['driving_cube'], driving_a_f)\n       \n        kp_driving = []\n        for i in range(16):\n            kp_driving.append(self.kp_extractor(x['driving'][:,i],x['driving_landmark'][:,i],self.loss_weights['equivariance_value']))\n        \n        kp_driving_a = []\n        fc_out, deco_out = self.audio_feature(x['example_landmark'], x['driving_audio'], x['driving_pose'])\n        fake_lmark=fc_out + x['example_landmark'].expand_as(fc_out)\n        \n      \n        fake_lmark = torch.mm( fake_lmark, self.pca.t() )\n        fake_lmark = fake_lmark + self.mean.expand_as(fake_lmark)\n    \n\n        fake_lmark = fake_lmark.unsqueeze(0) \n\n    #    for i in range(16):\n    #        kp_driving_a.append()\n        \n   #     generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)\n   #     generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})\n\n        loss_values = {}\n\n        pyramide_real = self.pyramid(x['driving'])\n        pyramide_generated = self.pyramid(generated['prediction'])\n        \n        if self.loss_weights['audio'] != 0:\n            value = torch.abs(kp_source['jacobian'].detach() - kp_source_a['jacobian'].detach()).mean() + 
torch.abs(kp_driving['jacobian'].detach() - kp_driving_a['jacobian']).mean()\n            value = value/2\n            loss_values['jacobian'] = value*self.loss_weights['audio']\n            value = torch.abs(kp_source['heatmap'].detach() - kp_source_a['heatmap'].detach()).mean() + torch.abs(kp_driving['heatmap'].detach() - kp_driving_a['heatmap']).mean()\n            value = value/2\n            loss_values['heatmap'] = value*self.loss_weights['audio']\n            value = torch.abs(kp_source['value'].detach() - kp_source_a['value'].detach()).mean() + torch.abs(kp_driving['value'].detach() - kp_driving_a['value']).mean()\n            value = value/2\n            loss_values['value'] = value*self.loss_weights['audio']\n            \n        if sum(self.loss_weights['perceptual']) != 0:\n            value_total = 0\n            for scale in self.scales:\n                x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])\n                y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])\n\n                for i, weight in enumerate(self.loss_weights['perceptual']):\n                    value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()\n                    value_total += self.loss_weights['perceptual'][i] * value\n                loss_values['perceptual'] = value_total\n\n        if self.loss_weights['generator_gan'] != 0:\n            discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))\n            discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))\n            value_total = 0\n            for scale in self.disc_scales:\n                key = 'prediction_map_%s' % scale\n                value = ((1 - discriminator_maps_generated[key]) ** 2).mean()\n                value_total += self.loss_weights['generator_gan'] * value\n            loss_values['gen_gan'] = value_total\n\n            if sum(self.loss_weights['feature_matching']) != 0:\n                value_total = 0\n                for scale in self.disc_scales:\n                    key = 'feature_maps_%s' % scale\n                    for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):\n                        if self.loss_weights['feature_matching'][i] == 0:\n                            continue\n                        value = torch.abs(a - b).mean()\n                        value_total += self.loss_weights['feature_matching'][i] * value\n                    loss_values['feature_matching'] = value_total\n\n        if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:\n            transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])\n            transformed_frame = transform.transform_frame(x['driving'])\n            transformed_landmark =  transform.inverse_warp_coordinates(x['driving_landmark'])\n            transformed_kp = self.kp_extractor(transformed_frame)\n\n            generated['transformed_frame'] = transformed_frame\n            generated['transformed_kp'] = transformed_kp\n            \n            ## Value loss part\n            if self.loss_weights['equivariance_value'] != 0:\n                value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()\n                loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value\n\n            ## jacobian loss part\n            if self.loss_weights['equivariance_jacobian'] != 0:\n        
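        # Chain rule: the Jacobian of the random transform at the transformed keypoints, multiplied by\n                # the Jacobian predicted on the transformed frame, should match the Jacobian predicted on the\n                # driving frame; the product with the inverse driving Jacobian is pushed towards the identity.\n        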
        jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),\n                                                    transformed_kp['jacobian'])\n\n                normed_driving = torch.inverse(kp_driving['jacobian'])\n                normed_transformed = jacobian_transformed\n                value = torch.matmul(normed_driving, normed_transformed)\n\n                eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())\n\n                value = torch.abs(eye - value).mean()\n                loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value\n\n        return loss_values, generated\n\n\nclass DiscriminatorFullModel(torch.nn.Module):\n    \"\"\"\n    Merge all discriminator related updates into single model for better multi-gpu usage\n    \"\"\"\n\n    def __init__(self, kp_extractor, generator, discriminator, train_params):\n        super(DiscriminatorFullModel, self).__init__()\n        self.kp_extractor = kp_extractor\n        self.generator = generator\n        self.discriminator = discriminator\n        self.train_params = train_params\n        self.scales = self.discriminator.scales\n        self.pyramid = ImagePyramide(self.scales, generator.num_channels)\n        if torch.cuda.is_available():\n            self.pyramid = self.pyramid.cuda()\n\n        self.loss_weights = train_params['loss_weights']\n\n    def forward(self, x, generated):\n        pyramide_real = self.pyramid(x['driving'])\n        pyramide_generated = self.pyramid(generated['prediction'].detach())\n\n        kp_driving = generated['kp_driving']\n        discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))\n        discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))\n\n        loss_values = {}\n        value_total = 0\n        for scale in self.scales:\n            key = 'prediction_map_%s' % scale\n            value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2\n            value_total += self.loss_weights['discriminator_gan'] * value.mean()\n        loss_values['disc_gan'] = value_total\n\n        return loss_values\n"
  },
  {
    "path": "modules/model_gen.py",
    "content": "from torch import nn\nimport torch\nimport torch.nn.functional as F\nfrom modules.util import AntiAliasInterpolation2d, make_coordinate_grid\nfrom torchvision import models\nimport numpy as np\nfrom torch.autograd import grad\n\n\nclass Vgg19(torch.nn.Module):\n    \"\"\"\n    Vgg19 network for perceptual loss. See Sec 3.3.\n    \"\"\"\n    def __init__(self, requires_grad=False):\n        super(Vgg19, self).__init__()\n        vgg_pretrained_features = models.vgg19(pretrained=True).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        for x in range(2):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(2, 7):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(7, 12):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(12, 21):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(21, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n\n        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),\n                                       requires_grad=False)\n        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),\n                                      requires_grad=False)\n\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        X = (X - self.mean) / self.std\n        h_relu1 = self.slice1(X)\n        h_relu2 = self.slice2(h_relu1)\n        h_relu3 = self.slice3(h_relu2)\n        h_relu4 = self.slice4(h_relu3)\n        h_relu5 = self.slice5(h_relu4)\n        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]\n        return out\n\n\nclass ImagePyramide(torch.nn.Module):\n    \"\"\"\n    Create image pyramide for computing pyramide perceptual loss. See Sec 3.3\n    \"\"\"\n    def __init__(self, scales, num_channels):\n        super(ImagePyramide, self).__init__()\n        downs = {}\n        for scale in scales:\n            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)\n        self.downs = nn.ModuleDict(downs)\n\n    def forward(self, x):\n        out_dict = {}\n        for scale, down_module in self.downs.items():\n            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)\n        return out_dict\n\n\nclass Transform:\n    \"\"\"\n    Random tps transformation for equivariance constraints. 
See Sec 3.3\n    \"\"\"\n    def __init__(self, bs, **kwargs):\n        noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))\n        self.theta = noise + torch.eye(2, 3).view(1, 2, 3)\n        self.bs = bs\n\n        if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):\n            self.tps = True\n            self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())\n            self.control_points = self.control_points.unsqueeze(0)\n            self.control_params = torch.normal(mean=0,\n                                               std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))\n        else:\n            self.tps = False\n\n    def transform_frame(self, frame):\n        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]\n        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)\n        grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)\n        return F.grid_sample(frame, grid, padding_mode=\"reflection\")\n\n    def inverse_transform_frame(self, frame):\n        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) #[1,256,256,2]\n        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)\n        grid = self.inverse_warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)\n        return F.grid_sample(frame, grid, padding_mode=\"reflection\")\n\n    def warp_coordinates(self, coordinates):\n        theta = self.theta.type(coordinates.type())\n        theta = theta.unsqueeze(1)\n        transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]\n        transformed = transformed.squeeze(-1)\n\n        if self.tps:\n            control_points = self.control_points.type(coordinates.type())\n            control_params = self.control_params.type(coordinates.type())\n            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)\n            distances = torch.abs(distances).sum(-1)\n\n            result = distances ** 2\n            result = result * torch.log(distances + 1e-6)\n            result = result * control_params\n            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)\n            transformed = transformed + result\n\n        return transformed\n\n    def inverse_warp_coordinates(self, coordinates):\n        theta = self.theta.type(coordinates.type())\n        theta = theta.unsqueeze(1)\n        a = torch.FloatTensor([[[[0,0,1]]]]).repeat([self.bs,1,1,1]).cuda()\n        c = torch.cat((theta,a),2)\n        d = c.inverse()[:,:,:2,:]\n        d = d.type(coordinates.type())\n        transformed = torch.matmul(d[:, :, :, :2], coordinates.unsqueeze(-1)) + d[:, :, :, 2:]\n        transformed = transformed.squeeze(-1)\n\n        if self.tps:\n            control_points = self.control_points.type(coordinates.type())\n            control_params = self.control_params.type(coordinates.type())\n            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)\n            distances = torch.abs(distances).sum(-1)\n\n            result = distances ** 2\n            result = result * torch.log(distances + 1e-6)\n            result = result * control_params\n            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)\n            transformed = transformed + result\n\n\n        return transformed\n\n    
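# Jacobian of the random affine/TPS warp at the given coordinates, obtained by differentiating the\n    # warped coordinates with torch.autograd.grad; used by the equivariance constraint on keypoint Jacobians.\n    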
def jacobian(self, coordinates):\n        coordinates.requires_grad=True\n        new_coordinates = self.warp_coordinates(coordinates)#[4,10,2]\n        grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)\n        grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)\n        jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)\n        return jacobian\n\n\ndef detach_kp(kp):\n    return {key: value.detach() for key, value in kp.items()}\n\nclass TrainFullModel(torch.nn.Module):\n    \"\"\"\n    Merge all generator related updates into single model for better multi-gpu usage\n    \"\"\"\n\n    def __init__(self, kp_extractor, emo_feature, kp_extractor_a, audio_feature, generator, discriminator, train_params, device_ids):\n        super(TrainFullModel, self).__init__()\n        self.kp_extractor = kp_extractor\n        self.kp_extractor_a = kp_extractor_a\n    #    self.emo_detector = emo_detector\n    #    self.content_encoder = content_encoder\n    #    self.emotion_encoder = emotion_encoder\n        self.audio_feature = audio_feature\n        self.emo_feature = emo_feature\n        self.generator = generator\n        self.discriminator = discriminator\n        self.train_params = train_params\n        self.scales = train_params['scales']\n        self.disc_scales = self.discriminator.scales\n        self.pyramid = ImagePyramide(self.scales, generator.num_channels)\n        if torch.cuda.is_available():\n            self.pyramid = self.pyramid.cuda()\n\n        self.loss_weights = train_params['loss_weights']\n\n        if sum(self.loss_weights['perceptual']) != 0:\n            self.vgg = Vgg19()\n            if torch.cuda.is_available():\n                self.vgg = self.vgg.cuda()\n\n       # self.pca = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/U_106.npy'))[:, :16].to(device_ids[0])\n      #  self.mean = torch.FloatTensor(np.load('/mnt/lustre/jixinya/Home/LRW/list/mean_106.npy')).to(device_ids[0])\n        self.mse_loss_fn   =  nn.MSELoss().cuda()\n        self.CroEn_loss =  nn.CrossEntropyLoss().cuda()\n    def forward(self, x):\n   #     source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])\n      #  source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))\n        kp_source = self.kp_extractor(x['example_image'])\n      #  print(x['name'],len(x['name']))\n        kp_driving = []\n        kp_emo = []\n        for i in range(16):\n            kp_driving.append(self.kp_extractor(x['driving'][:,i]))\n    #        kp_emo.append(self.emo_detector(x['driving'][:,i]))\n    #    print('KP_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a'))\n        kp_driving_a = [] #x['example_image'],\n        deco_out = self.audio_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])\n    #    emo_out = self.emo_feature(x['example_image'], x['driving_audio'], x['driving_pose'], self.train_params['jaco_net'])\n        loss_values = {}\n\n        if self.loss_weights['emo'] != 0:\n\n            kp_driving_a = []\n            fakes = []\n            for i in range(16):\n                kp_driving_a.append(self.kp_extractor_a(deco_out[:,i]))#\n                value = self.kp_extractor_a(deco_out[:,i])['value']\n                jacobian = self.kp_extractor_a(deco_out[:,i])['jacobian']\n                if self.train_params['type'] == 'linear_4' and 
x['name'][0] == 0:\n                    out, fake = self.emo_feature(x['transformed_driving'][:,i],value,jacobian)\n                    kp_emo.append(out)\n                    fakes.append(fake)\n                 #   kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))\n                elif self.train_params['type'] == 'linear_10' and x['name'][0] == 0:\n                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))\n\n                    out, fake = self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian)\n                    kp_emo.append(out)\n                    fakes.append(fake)\n                elif self.train_params['type'] == 'linear_4_new' and x['name'][0] == 0:\n                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))\n\n                    out, fake = self.emo_feature.linear_4(x['transformed_driving'][:,i],value,jacobian)\n                    kp_emo.append(out)\n                    fakes.append(fake)\n                elif self.train_params['type'] == 'linear_np_4':\n                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))\n\n                    out, fake = self.emo_feature.linear_np_4(x['transformed_driving'][:,i],value,jacobian)\n                    kp_emo.append(out)\n                    fakes.append(fake)\n                elif self.train_params['type'] == 'linear_np_10':\n                 #   kp_emo.append(self.emo_feature.linear_10(x['transformed_driving'][:,i],value,jacobian))\n\n                    out, fake = self.emo_feature.linear_np_10(x['transformed_driving'][:,i],value,jacobian)\n                    kp_emo.append(out)\n                    fakes.append(fake)\n            #    kp_emo.append(self.emo_feature(x['transformed_driving'][:,i],value,jacobian))\n    #    print('Kp_audio_driving ', file=open('/mnt/lustre/jixinya/Home/fomm_audio/log/LRW_test.txt', 'a'))\n\n        loss_perceptual = 0\n\n        kp_all = kp_driving_a\n        if self.train_params['smooth'] == True:\n            value_all = torch.randn(len(kp_driving),out['value'].shape[0],out['value'].shape[1],out['value'].shape[2]).cuda()\n            jacobian_all = torch.randn(len(kp_driving),out['jacobian'].shape[0],out['jacobian'].shape[1],2,2).cuda()\n        print(len(kp_driving))\n        for i in range(len(kp_driving)):\n          #  if x['name'][i] == 'LRW':\n          #      loss_jacobian += (torch.abs(kp_driving[i]['jacobian'] - kp_driving_a[i]['jacobian']).mean())*self.loss_weights['emo']\n\n          #      loss_value += (torch.abs(kp_driving[i]['value'].detach() - kp_driving_a[i]['value']).mean())*self.loss_weights['emo']\n          #      loss_classify += self.mse_loss_fn(deco_out,deco_out)\n            if self.train_params['type'] == 'linear_4' and x['name'][0] == 0:\n\n                kp_all[i]['jacobian'][:,1] = kp_emo[i]['jacobian'][:,0] + kp_driving_a[i]['jacobian'][:,1]\n                kp_all[i]['jacobian'][:,4] = kp_emo[i]['jacobian'][:,1] + kp_driving_a[i]['jacobian'][:,4]\n                kp_all[i]['jacobian'][:,6] = kp_emo[i]['jacobian'][:,2] + kp_driving_a[i]['jacobian'][:,6]\n                kp_all[i]['jacobian'][:,8] = kp_emo[i]['jacobian'][:,3] + kp_driving_a[i]['jacobian'][:,8]\n                kp_all[i]['value'][:,1] = kp_emo[i]['value'][:,0] + kp_driving_a[i]['value'][:,1]\n                kp_all[i]['value'][:,4] = kp_emo[i]['value'][:,1] + kp_driving_a[i]['value'][:,4]\n                
kp_all[i]['value'][:,6] = kp_emo[i]['value'][:,2] + kp_driving_a[i]['value'][:,6]\n                kp_all[i]['value'][:,8] = kp_emo[i]['value'][:,3] + kp_driving_a[i]['value'][:,8]\n\n        #    kp_all[i]['value'] = kp_emo[i]['value'] + kp_driving_a[i]['value']\n\n\n        if self.train_params['smooth'] == True:\n            loss_smooth = 0\n            loss_smooth += (torch.abs(value_all[2:,:,:,:] + value_all[:-2,:,:,:].detach() -2*value_all[1:-1,:,:,:].detach()).mean())*self.loss_weights['emo'] *100\n            loss_smooth += (torch.abs(jacobian_all[2:,:,:,:] + jacobian_all[:-2,:,:,:].detach() -2*jacobian_all[1:-1,:,:,:].detach()).mean())*self.loss_weights['emo'] *100\n            loss_values['loss_smooth'] = loss_smooth/len(kp_driving)\n        else:\n            loss_values['loss_smooth'] = self.mse_loss_fn(deco_out,deco_out)\n        if self.train_params['generator'] == 'not':\n            loss_values['perceptual'] = self.mse_loss_fn(deco_out,deco_out)\n            for i in range(1): #0,len(kp_driving),4\n\n                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i])\n                generated.update({'kp_source': kp_source, 'kp_driving': kp_all})\n        elif self.train_params['generator'] == 'visual':\n            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4\n\n                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_driving[i])\n                generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})\n\n                pyramide_real = self.pyramid(x['driving'][:,i])\n                pyramide_generated = self.pyramid(generated['prediction'])\n\n                if sum(self.loss_weights['perceptual']) != 0:\n                    value_total = 0\n                    for scale in self.scales:\n                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])\n                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])\n\n                        for i, weight in enumerate(self.loss_weights['perceptual']):\n                            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()\n                            value_total += self.loss_weights['perceptual'][i] * value\n                    loss_perceptual += value_total\n\n            length = int((len(kp_driving)-1)/4)+1\n            loss_values['perceptual'] = loss_perceptual/length\n        elif self.train_params['generator'] == 'audio':\n            for i in range(0,len(kp_driving),4): #0,len(kp_driving),4\n\n                generated = self.generator(x['example_image'], kp_source=kp_source, kp_driving=kp_all[i])\n                generated.update({'kp_source': kp_source, 'kp_driving': kp_all})\n\n                pyramide_real = self.pyramid(x['driving'][:,i])\n                pyramide_generated = self.pyramid(generated['prediction'])\n            #    loss_mse = nn.MSELoss(generated['prediction'],x['driving'][:,i])\n                if sum(self.loss_weights['perceptual']) != 0:\n                    value_total = 0\n                    for scale in self.scales:\n                        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])\n                        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])\n\n                        for i, weight in enumerate(self.loss_weights['perceptual']):\n                            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()\n                            value_total += 
self.loss_weights['perceptual'][i] * value\n                    loss_perceptual += value_total\n\n            length = int((len(kp_driving)-1)/4)+1\n            loss_values['perceptual'] = loss_perceptual/length\n      #      loss_values['mse'] = loss_mse/length\n            \n        else:\n            print('wrong train_params: ', self.train_params['generator'])\n\n\n\n        return loss_values,generated\n\nclass GeneratorFullModel(torch.nn.Module):\n    \"\"\"\n    Merge all generator related updates into single model for better multi-gpu usage\n    \"\"\"\n\n    def __init__(self, kp_extractor, kp_extractor_a, audio_feature, generator, discriminator, train_params):\n        super(GeneratorFullModel, self).__init__()\n        self.kp_extractor = kp_extractor\n        self.kp_extractor_a = kp_extractor_a\n    #    self.content_encoder = content_encoder\n    #    self.emotion_encoder = emotion_encoder\n        self.audio_feature = audio_feature\n        self.generator = generator\n        self.discriminator = discriminator\n        self.train_params = train_params\n        self.scales = train_params['scales']\n        self.disc_scales = self.discriminator.scales\n        self.pyramid = ImagePyramide(self.scales, generator.num_channels)\n        if torch.cuda.is_available():\n            self.pyramid = self.pyramid.cuda()\n\n        self.loss_weights = train_params['loss_weights']\n\n        if sum(self.loss_weights['perceptual']) != 0:\n            self.vgg = Vgg19()\n            if torch.cuda.is_available():\n                self.vgg = self.vgg.cuda()\n\n        self.pca = torch.FloatTensor(np.load('.../LRW/list/U_106.npy'))[:, :16].cuda()\n        self.mean = torch.FloatTensor(np.load('.../LRW/list/mean_106.npy')).cuda()\n\n    def forward(self, x):\n   #     source_a_f = self.audio_feature(x['source_audio'],x['source_lm'],x[])\n      #  source_a_f = self.audio_feature(self.content_encoder(x['source_audio'].unsqueeze(1)), self.emotion_encoder(x['source_audio'].unsqueeze(1)))\n   #     kp_source = self.kp_extractor(x['source'])\n   #     kp_source_a = self.kp_extractor_a(x['source'], x['source_cube'], source_a_f)\n      #  driving_a_f = self.audio_feature(self.content_encoder(x['driving_audio'].unsqueeze(1)), self.emotion_encoder(x['driving_audio'].unsqueeze(1)))\n      #  driving_a_f = self.audio_feature(x['driving_audio'])\n      #  kp_driving = self.kp_extractor(x['driving'])\n   #     kp_driving_a = self.kp_extractor_a(x['driving'], x['driving_cube'], driving_a_f)\n\n        kp_driving = []\n        for i in range(16):\n            kp_driving.append(self.kp_extractor(x['driving'][:,i],x['driving_landmark'][:,i],self.loss_weights['equivariance_value']))\n\n        kp_driving_a = []\n        fc_out, deco_out = self.audio_feature(x['example_landmark'], x['driving_audio'], x['driving_pose'])\n        fake_lmark=fc_out + x['example_landmark'].expand_as(fc_out)\n\n\n        fake_lmark = torch.mm( fake_lmark, self.pca.t() )\n        fake_lmark = fake_lmark + self.mean.expand_as(fake_lmark)\n\n\n        fake_lmark = fake_lmark.unsqueeze(0)\n\n    #    for i in range(16):\n    #        kp_driving_a.append()\n\n   #     generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)\n   #     generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})\n\n        loss_values = {}\n\n        pyramide_real = self.pyramid(x['driving'])\n        pyramide_generated = self.pyramid(generated['prediction'])\n\n        if self.loss_weights['audio'] != 0:\n            value 
= torch.abs(kp_source['jacobian'].detach() - kp_source_a['jacobian'].detach()).mean() + torch.abs(kp_driving['jacobian'].detach() - kp_driving_a['jacobian']).mean()\n            value = value/2\n            loss_values['jacobian'] = value*self.loss_weights['audio']\n            value = torch.abs(kp_source['heatmap'].detach() - kp_source_a['heatmap'].detach()).mean() + torch.abs(kp_driving['heatmap'].detach() - kp_driving_a['heatmap']).mean()\n            value = value/2\n            loss_values['heatmap'] = value*self.loss_weights['audio']\n            value = torch.abs(kp_source['value'].detach() - kp_source_a['value'].detach()).mean() + torch.abs(kp_driving['value'].detach() - kp_driving_a['value']).mean()\n            value = value/2\n            loss_values['value'] = value*self.loss_weights['audio']\n\n        if sum(self.loss_weights['perceptual']) != 0:\n            value_total = 0\n            for scale in self.scales:\n                x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])\n                y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])\n\n                for i, weight in enumerate(self.loss_weights['perceptual']):\n                    value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()\n                    value_total += self.loss_weights['perceptual'][i] * value\n                loss_values['perceptual'] = value_total\n\n        if self.loss_weights['generator_gan'] != 0:\n            discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))\n            discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))\n            value_total = 0\n            for scale in self.disc_scales:\n                key = 'prediction_map_%s' % scale\n                value = ((1 - discriminator_maps_generated[key]) ** 2).mean()\n                value_total += self.loss_weights['generator_gan'] * value\n            loss_values['gen_gan'] = value_total\n\n            if sum(self.loss_weights['feature_matching']) != 0:\n                value_total = 0\n                for scale in self.disc_scales:\n                    key = 'feature_maps_%s' % scale\n                    for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):\n                        if self.loss_weights['feature_matching'][i] == 0:\n                            continue\n                        value = torch.abs(a - b).mean()\n                        value_total += self.loss_weights['feature_matching'][i] * value\n                    loss_values['feature_matching'] = value_total\n\n        if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:\n            transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])\n            transformed_frame = transform.transform_frame(x['driving'])\n            transformed_landmark =  transform.inverse_warp_coordinates(x['driving_landmark'])\n            transformed_kp = self.kp_extractor(transformed_frame)\n\n            generated['transformed_frame'] = transformed_frame\n            generated['transformed_kp'] = transformed_kp\n\n            ## Value loss part\n            if self.loss_weights['equivariance_value'] != 0:\n                value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()\n                loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value\n\n            ## jacobian loss part\n         
   if self.loss_weights['equivariance_jacobian'] != 0:\n                jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),\n                                                    transformed_kp['jacobian'])\n\n                normed_driving = torch.inverse(kp_driving['jacobian'])\n                normed_transformed = jacobian_transformed\n                value = torch.matmul(normed_driving, normed_transformed)\n\n                eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())\n\n                value = torch.abs(eye - value).mean()\n                loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value\n\n        return loss_values, generated\n\n\nclass DiscriminatorFullModel(torch.nn.Module):\n    \"\"\"\n    Merge all discriminator related updates into single model for better multi-gpu usage\n    \"\"\"\n\n    def __init__(self, kp_extractor, generator, discriminator, train_params):\n        super(DiscriminatorFullModel, self).__init__()\n        self.kp_extractor = kp_extractor\n        self.generator = generator\n        self.discriminator = discriminator\n        self.train_params = train_params\n        self.scales = self.discriminator.scales\n        self.pyramid = ImagePyramide(self.scales, generator.num_channels)\n        if torch.cuda.is_available():\n            self.pyramid = self.pyramid.cuda()\n\n        self.loss_weights = train_params['loss_weights']\n\n    def forward(self, x, generated):\n        pyramide_real = self.pyramid(x['driving'])\n        pyramide_generated = self.pyramid(generated['prediction'].detach())\n\n        kp_driving = generated['kp_driving']\n        discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))\n        discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))\n\n        loss_values = {}\n        value_total = 0\n        for scale in self.scales:\n            key = 'prediction_map_%s' % scale\n            value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2\n            value_total += self.loss_weights['discriminator_gan'] * value.mean()\n        loss_values['disc_gan'] = value_total\n\n        return loss_values\n"
  },
  {
    "path": "modules/ops.py",
    "content": "import torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom torch.autograd import Variable\n\n\ndef linear(channel_in, channel_out,\n           activation=nn.ReLU,\n           normalizer=nn.BatchNorm1d):\n    layer = list()\n    bias = True if not normalizer else False\n\n    layer.append(nn.Linear(channel_in, channel_out, bias=bias))\n    _apply(layer, activation, normalizer, channel_out)\n    # init.kaiming_normal(layer[0].weight)\n\n    return nn.Sequential(*layer)\n\n\ndef conv2d(channel_in, channel_out,\n           ksize=3, stride=1, padding=1,\n           activation=nn.ReLU,\n           normalizer=nn.BatchNorm2d):\n    layer = list()\n    bias = True if not normalizer else False\n\n    layer.append(nn.Conv2d(channel_in, channel_out,\n                     ksize, stride, padding,\n                     bias=bias))\n    _apply(layer, activation, normalizer, channel_out)\n    # init.kaiming_normal(layer[0].weight)\n\n    return nn.Sequential(*layer)\n\n\ndef conv_transpose2d(channel_in, channel_out,\n                     ksize=4, stride=2, padding=1,\n                     activation=nn.ReLU,\n                     normalizer=nn.BatchNorm2d):\n    layer = list()\n    bias = True if not normalizer else False\n\n    layer.append(nn.ConvTranspose2d(channel_in, channel_out,\n                              ksize, stride, padding,\n                              bias=bias))\n    _apply(layer, activation, normalizer, channel_out)\n    # init.kaiming_normal(layer[0].weight)\n\n    return nn.Sequential(*layer)\n\n\ndef nn_conv2d(channel_in, channel_out,\n              ksize=3, stride=1, padding=1,\n              scale_factor=2,\n              activation=nn.ReLU,\n              normalizer=nn.BatchNorm2d):\n    layer = list()\n    bias = True if not normalizer else False\n\n    layer.append(nn.UpsamplingNearest2d(scale_factor=scale_factor))\n    layer.append(nn.Conv2d(channel_in, channel_out,\n                           ksize, stride, padding,\n                           bias=bias))\n    _apply(layer, activation, normalizer, channel_out)\n    # init.kaiming_normal(layer[1].weight)\n\n    return nn.Sequential(*layer)\n\n\ndef _apply(layer, activation, normalizer, channel_out=None):\n    if normalizer:\n        layer.append(normalizer(channel_out))\n    if activation:\n        layer.append(activation())\n    return layer\n\n"
  },
  {
    "path": "modules/stylegan2.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Thu Jul  8 01:03:50 2021\n\n@author: thea\n\"\"\"\n\n\"\"\"\nThe network architectures is based on PyTorch implemenation of StyleGAN2Encoder.\nOriginal PyTorch repo: https://github.com/rosinality/style-based-gan-pytorch\nOrigianl StyelGAN2 paper: https://github.com/NVlabs/stylegan2\nWe　use the network architeture for our single-image traning setting.\n\"\"\"\n\nimport math\nimport numpy as np\nimport random\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\ndef fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):\n    return F.leaky_relu(input + bias, negative_slope) * scale\n\n\nclass FusedLeakyReLU(nn.Module):\n    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):\n        super().__init__()\n        self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))\n        self.negative_slope = negative_slope\n        self.scale = scale\n\n    def forward(self, input):\n        # print(\"FusedLeakyReLU: \", input.abs().mean())\n        out = fused_leaky_relu(input, self.bias,\n                               self.negative_slope,\n                               self.scale)\n        # print(\"FusedLeakyReLU: \", out.abs().mean())\n        return out\n\n\ndef upfirdn2d_native(\n    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1\n):\n    _, minor, in_h, in_w = input.shape\n    kernel_h, kernel_w = kernel.shape\n\n    out = input.view(-1, minor, in_h, 1, in_w, 1)\n    out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])\n    out = out.view(-1, minor, in_h * up_y, in_w * up_x)\n\n    out = F.pad(\n        out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]\n    )\n    out = out[\n        :,\n        :,\n        max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),\n        max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0),\n    ]\n\n    # out = out.permute(0, 3, 1, 2)\n    out = out.reshape(\n        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]\n    )\n    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)\n    out = F.conv2d(out, w)\n    out = out.reshape(\n        -1,\n        minor,\n        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,\n        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,\n    )\n    # out = out.permute(0, 2, 3, 1)\n\n    return out[:, :, ::down_y, ::down_x]\n\n\ndef upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):\n    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])\n\n\nclass PixelNorm(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, input):\n        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)\n\n\ndef make_kernel(k):\n    k = torch.tensor(k, dtype=torch.float32)\n\n    if len(k.shape) == 1:\n        k = k[None, :] * k[:, None]\n\n    k /= k.sum()\n\n    return k\n\n\nclass Upsample(nn.Module):\n    def __init__(self, kernel, factor=2):\n        super().__init__()\n\n        self.factor = factor\n        kernel = make_kernel(kernel) * (factor ** 2)\n        self.register_buffer('kernel', kernel)\n\n        p = kernel.shape[0] - factor\n\n        pad0 = (p + 1) // 2 + factor - 1\n        pad1 = p // 2\n\n        self.pad = (pad0, pad1)\n\n    def forward(self, input):\n        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)\n\n        return out\n\n\nclass Downsample(nn.Module):\n    def __init__(self, kernel, 
factor=2):\n        super().__init__()\n\n        self.factor = factor\n        kernel = make_kernel(kernel)\n        self.register_buffer('kernel', kernel)\n\n        p = kernel.shape[0] - factor\n\n        pad0 = (p + 1) // 2\n        pad1 = p // 2\n\n        self.pad = (pad0, pad1)\n\n    def forward(self, input):\n        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)\n\n        return out\n\n\nclass Blur(nn.Module):\n    def __init__(self, kernel, pad, upsample_factor=1):\n        super().__init__()\n\n        kernel = make_kernel(kernel)\n\n        if upsample_factor > 1:\n            kernel = kernel * (upsample_factor ** 2)\n\n        self.register_buffer('kernel', kernel)\n\n        self.pad = pad\n\n    def forward(self, input):\n        out = upfirdn2d(input, self.kernel, pad=self.pad)\n\n        return out\n\n\nclass EqualConv2d(nn.Module):\n    def __init__(\n        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True\n    ):\n        super().__init__()\n\n        self.weight = nn.Parameter(\n            torch.randn(out_channel, in_channel, kernel_size, kernel_size)\n        )\n        self.scale = math.sqrt(1) / math.sqrt(in_channel * (kernel_size ** 2))\n\n        self.stride = stride\n        self.padding = padding\n\n        if bias:\n            self.bias = nn.Parameter(torch.zeros(out_channel))\n\n        else:\n            self.bias = None\n\n    def forward(self, input):\n        # print(\"Before EqualConv2d: \", input.abs().mean())\n        out = F.conv2d(\n            input,\n            self.weight * self.scale,\n            bias=self.bias,\n            stride=self.stride,\n            padding=self.padding,\n        )\n        # print(\"After EqualConv2d: \", out.abs().mean(), (self.weight * self.scale).abs().mean())\n\n        return out\n\n    def __repr__(self):\n        return (\n            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'\n            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'\n        )\n\n\nclass EqualLinear(nn.Module):\n    def __init__(\n        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None\n    ):\n        super().__init__()\n\n        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))\n\n        if bias:\n            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))\n\n        else:\n            self.bias = None\n\n        self.activation = activation\n\n        self.scale = (math.sqrt(1) / math.sqrt(in_dim)) * lr_mul\n        self.lr_mul = lr_mul\n\n    def forward(self, input):\n        if self.activation:\n            out = F.linear(input, self.weight * self.scale)\n            out = fused_leaky_relu(out, self.bias * self.lr_mul)\n\n        else:\n            out = F.linear(\n                input, self.weight * self.scale, bias=self.bias * self.lr_mul\n            )\n\n        return out\n\n    def __repr__(self):\n        return (\n            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'\n        )\n\n\nclass ScaledLeakyReLU(nn.Module):\n    def __init__(self, negative_slope=0.2):\n        super().__init__()\n\n        self.negative_slope = negative_slope\n\n    def forward(self, input):\n        out = F.leaky_relu(input, negative_slope=self.negative_slope)\n\n        return out * math.sqrt(2)\n\n\nclass ModulatedConv2d(nn.Module):\n    def __init__(\n        self,\n        in_channel,\n        out_channel,\n        kernel_size,\n        
style_dim,\n        demodulate=True,\n        upsample=False,\n        downsample=False,\n        blur_kernel=[1, 3, 3, 1],\n    ):\n        super().__init__()\n\n        self.eps = 1e-8\n        self.kernel_size = kernel_size\n        self.in_channel = in_channel\n        self.out_channel = out_channel\n        self.upsample = upsample\n        self.downsample = downsample\n\n        if upsample:\n            factor = 2\n            p = (len(blur_kernel) - factor) - (kernel_size - 1)\n            pad0 = (p + 1) // 2 + factor - 1\n            pad1 = p // 2 + 1\n\n            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)\n\n        if downsample:\n            factor = 2\n            p = (len(blur_kernel) - factor) + (kernel_size - 1)\n            pad0 = (p + 1) // 2\n            pad1 = p // 2\n\n            self.blur = Blur(blur_kernel, pad=(pad0, pad1))\n\n        fan_in = in_channel * kernel_size ** 2\n        self.scale = math.sqrt(1) / math.sqrt(fan_in)\n        self.padding = kernel_size // 2\n\n        self.weight = nn.Parameter(\n            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)\n        )\n\n        if style_dim is not None and style_dim > 0:\n            self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)\n\n        self.demodulate = demodulate\n\n    def __repr__(self):\n        return (\n            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '\n            f'upsample={self.upsample}, downsample={self.downsample})'\n        )\n\n    def forward(self, input, style):\n        batch, in_channel, height, width = input.shape\n\n        if style is not None:\n            style = self.modulation(style).view(batch, 1, in_channel, 1, 1)\n        else:\n            style = torch.ones(batch, 1, in_channel, 1, 1).cuda()\n        weight = self.scale * self.weight * style\n\n        if self.demodulate:\n            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)\n            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)\n\n        weight = weight.view(\n            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size\n        )\n\n        if self.upsample:\n            input = input.view(1, batch * in_channel, height, width)\n            weight = weight.view(\n                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size\n            )\n            weight = weight.transpose(1, 2).reshape(\n                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size\n            )\n            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)\n            _, _, height, width = out.shape\n            out = out.view(batch, self.out_channel, height, width)\n            out = self.blur(out)\n\n        elif self.downsample:\n            input = self.blur(input)\n            _, _, height, width = input.shape\n            input = input.view(1, batch * in_channel, height, width)\n            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)\n            _, _, height, width = out.shape\n            out = out.view(batch, self.out_channel, height, width)\n\n        else:\n            input = input.view(1, batch * in_channel, height, width)\n            out = F.conv2d(input, weight, padding=self.padding, groups=batch)\n            _, _, height, width = out.shape\n            out = out.view(batch, self.out_channel, height, width)\n\n        return out\n\n\nclass 
NoiseInjection(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        self.weight = nn.Parameter(torch.zeros(1))\n\n    def forward(self, image, noise=None):\n        if noise is None:\n            batch, _, height, width = image.shape\n            noise = image.new_empty(batch, 1, height, width).normal_()\n\n        return image + self.weight * noise\n\n\nclass ConstantInput(nn.Module):\n    def __init__(self, channel, size=4):\n        super().__init__()\n\n        self.input = nn.Parameter(torch.randn(1, channel, size, size))\n\n    def forward(self, input):\n        batch = input.shape[0]\n        out = self.input.repeat(batch, 1, 1, 1)\n\n        return out\n\n\nclass StyledConv(nn.Module):\n    def __init__(\n        self,\n        in_channel,\n        out_channel,\n        kernel_size,\n        style_dim=None,\n        upsample=False,\n        blur_kernel=[1, 3, 3, 1],\n        demodulate=True,\n        inject_noise=False, #True\n    ):\n        super().__init__()\n\n        self.inject_noise = inject_noise\n        self.conv = ModulatedConv2d(\n            in_channel,\n            out_channel,\n            kernel_size,\n            style_dim,\n            upsample=upsample,\n            blur_kernel=blur_kernel,\n            demodulate=demodulate,\n        )\n\n        self.noise = NoiseInjection()\n        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))\n        # self.activate = ScaledLeakyReLU(0.2)\n        self.activate = FusedLeakyReLU(out_channel)\n\n    def forward(self, input, style=None, noise=None):\n        out = self.conv(input, style)\n        if self.inject_noise:\n            out = self.noise(out, noise=noise)\n        # out = out + self.bias\n        out = self.activate(out)\n\n        return out\n\n\nclass ToRGB(nn.Module):\n    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):\n        super().__init__()\n\n        if upsample:\n            self.upsample = Upsample(blur_kernel)\n\n        self.conv = ModulatedConv2d(in_channel, 3+32, 1, style_dim, demodulate=False)\n        self.bias = nn.Parameter(torch.zeros(1, 3+32, 1, 1))\n\n    def forward(self, input, style, skip=None):\n        out = self.conv(input, style)\n        out = out + self.bias\n\n        if skip is not None:\n            skip = self.upsample(skip)\n\n            out = out + skip\n\n        return out\n\n\nclass Generator(nn.Module):\n    def __init__(\n        self,\n        size,\n        style_dim,\n        n_mlp,\n        channel_multiplier=1,\n        blur_kernel=[1, 3, 3, 1],\n        lr_mlp=0.01,\n    ):\n        super().__init__()\n\n        self.size = size\n\n        self.style_dim = style_dim\n\n        layers = [PixelNorm()]\n\n        for i in range(n_mlp):\n            layers.append(\n                EqualLinear(\n                    style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'\n                )\n            )\n\n        self.style = nn.Sequential(*layers)\n\n        self.channels = {\n            4: 256,\n            8: 256,\n            16: 128,\n            32: 64,\n            64: 32 * channel_multiplier,\n            128: 16 * channel_multiplier,\n            256: 8 * channel_multiplier,\n            512: 4 * channel_multiplier,\n            1024: 2 * channel_multiplier,\n        }\n\n        self.input = ConstantInput(self.channels[4])\n        self.conv1 = StyledConv(\n            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel\n        )\n        self.to_rgb1 = 
ToRGB(self.channels[4], style_dim, upsample=False)\n\n        self.log_size = int(math.log(size, 2))\n        self.num_layers = (self.log_size - 2) * 2 + 1\n\n        self.convs = nn.ModuleList()\n        self.upsamples = nn.ModuleList()\n        self.to_rgbs = nn.ModuleList()\n        self.noises = nn.Module()\n\n        in_channel = self.channels[4]\n\n        for layer_idx in range(self.num_layers):\n            res = (layer_idx + 5) // 2\n            shape = [1, 1, 2 ** res, 2 ** res]\n            self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))\n\n        for i in range(3, self.log_size + 1):\n            out_channel = self.channels[2 ** i]\n\n            self.convs.append(\n                StyledConv(\n                    in_channel,\n                    out_channel,\n                    3,\n                    style_dim,\n                    upsample=True,\n                    blur_kernel=blur_kernel,\n                )\n            )\n\n            self.convs.append(\n                StyledConv(\n                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel\n                )\n            )\n\n            self.to_rgbs.append(ToRGB(out_channel, style_dim))\n\n            in_channel = out_channel\n\n        self.n_latent = self.log_size * 2 - 2\n\n    def make_noise(self):\n        device = self.input.input.device\n\n        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]\n\n        for i in range(3, self.log_size + 1):\n            for _ in range(2):\n                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))\n\n        return noises\n\n    def mean_latent(self, n_latent):\n        latent_in = torch.randn(\n            n_latent, self.style_dim, device=self.input.input.device\n        )\n        latent = self.style(latent_in).mean(0, keepdim=True)\n\n        return latent\n\n    def get_latent(self, input):\n        return self.style(input)\n\n    def forward(\n        self,\n        styles,\n        return_latents=False,\n        inject_index=None,\n        truncation=1,\n        truncation_latent=None,\n        input_is_latent=False,\n        noise=None,\n        randomize_noise=True,\n    ):\n        if not input_is_latent:\n            styles = [self.style(s) for s in styles]\n\n        if noise is None:\n            if randomize_noise:\n                noise = [None] * self.num_layers\n            else:\n                noise = [\n                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)\n                ]\n\n        if truncation < 1:\n            style_t = []\n\n            for style in styles:\n                style_t.append(\n                    truncation_latent + truncation * (style - truncation_latent)\n                )\n\n            styles = style_t\n\n        if len(styles) < 2:\n            inject_index = self.n_latent\n\n            if len(styles[0].shape) < 3:\n                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)\n\n            else:\n                latent = styles[0]\n\n        else:\n            if inject_index is None:\n                inject_index = random.randint(1, self.n_latent - 1)\n\n            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)\n            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)\n\n            latent = torch.cat([latent, latent2], 1)\n\n      #  out = self.input(latent)\n        out = styles[0].unsqueeze(-1).unsqueeze(-1).repeat(1,1,4,4)\n        out = self.conv1(out, 
latent[:, 0], noise=noise[0])\n\n        skip = self.to_rgb1(out, latent[:, 1])\n\n        i = 1\n        for conv1, conv2, noise1, noise2, to_rgb in zip(\n            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs\n        ):\n            out = conv1(out, latent[:, i], noise=noise1)\n            out = conv2(out, latent[:, i + 1], noise=noise2)\n            skip = to_rgb(out, latent[:, i + 2], skip)\n\n            i += 2\n\n        image = skip\n\n        if return_latents:\n            return image, latent\n\n        else:\n            return image, None\n\n\nclass ConvLayer(nn.Sequential):\n    def __init__(\n        self,\n        in_channel,\n        out_channel,\n        kernel_size,\n        downsample=False,\n        blur_kernel=[1, 3, 3, 1],\n        bias=True,\n        activate=True,\n    ):\n        layers = []\n\n        if downsample:\n            factor = 2\n            p = (len(blur_kernel) - factor) + (kernel_size - 1)\n            pad0 = (p + 1) // 2\n            pad1 = p // 2\n\n            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))\n\n            stride = 2\n            self.padding = 0\n\n        else:\n            stride = 1\n            self.padding = kernel_size // 2\n\n        layers.append(\n            EqualConv2d(\n                in_channel,\n                out_channel,\n                kernel_size,\n                padding=self.padding,\n                stride=stride,\n                bias=bias and not activate,\n            )\n        )\n\n        if activate:\n            if bias:\n                layers.append(FusedLeakyReLU(out_channel))\n\n            else:\n                layers.append(ScaledLeakyReLU(0.2))\n\n        super().__init__(*layers)\n\n\nclass ResBlock(nn.Module):\n    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample=True, skip_gain=1.0):\n        super().__init__()\n\n        self.skip_gain = skip_gain\n        self.conv1 = ConvLayer(in_channel, in_channel, 3)\n        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample, blur_kernel=blur_kernel)\n\n        if in_channel != out_channel or downsample:\n            self.skip = ConvLayer(\n                in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False\n            )\n        else:\n            self.skip = nn.Identity()\n\n    def forward(self, input):\n        out = self.conv1(input)\n        out = self.conv2(out)\n\n        skip = self.skip(input)\n        out = (out * self.skip_gain + skip) / math.sqrt(self.skip_gain ** 2 + 1.0)\n\n        return out\n\n\nclass StyleGAN2Discriminator(nn.Module):\n    def __init__(self, input_nc, ndf=64, n_layers=3, no_antialias=False, size=None, opt=None):\n        super().__init__()\n        self.opt = opt\n        self.stddev_group = 16\n        if size is None:\n            size = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))\n            if \"patch\" in self.opt.netD and self.opt.D_patch_size is not None:\n                size = 2 ** int(np.log2(self.opt.D_patch_size))\n\n        blur_kernel = [1, 3, 3, 1]\n        channel_multiplier = ndf / 64\n        channels = {\n            4: min(384, int(4096 * channel_multiplier)),\n            8: min(384, int(2048 * channel_multiplier)),\n            16: min(384, int(1024 * channel_multiplier)),\n            32: min(384, int(512 * channel_multiplier)),\n            64: int(256 * channel_multiplier),\n            128: int(128 * channel_multiplier),\n            256: int(64 * 
channel_multiplier),\n            512: int(32 * channel_multiplier),\n            1024: int(16 * channel_multiplier),\n        }\n\n        convs = [ConvLayer(3, channels[size], 1)]\n\n        log_size = int(math.log(size, 2))\n\n        in_channel = channels[size]\n\n        if \"smallpatch\" in self.opt.netD:\n            final_res_log2 = 4\n        elif \"patch\" in self.opt.netD:\n            final_res_log2 = 3\n        else:\n            final_res_log2 = 2\n\n        for i in range(log_size, final_res_log2, -1):\n            out_channel = channels[2 ** (i - 1)]\n\n            convs.append(ResBlock(in_channel, out_channel, blur_kernel))\n\n            in_channel = out_channel\n\n        self.convs = nn.Sequential(*convs)\n\n        if False and \"tile\" in self.opt.netD:\n            in_channel += 1\n        self.final_conv = ConvLayer(in_channel, channels[4], 3)\n        if \"patch\" in self.opt.netD:\n            self.final_linear = ConvLayer(channels[4], 1, 3, bias=False, activate=False)\n        else:\n            self.final_linear = nn.Sequential(\n                EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),\n                EqualLinear(channels[4], 1),\n            )\n\n    def forward(self, input, get_minibatch_features=False):\n        if \"patch\" in self.opt.netD and self.opt.D_patch_size is not None:\n            h, w = input.size(2), input.size(3)\n            y = torch.randint(h - self.opt.D_patch_size, ())\n            x = torch.randint(w - self.opt.D_patch_size, ())\n            input = input[:, :, y:y + self.opt.D_patch_size, x:x + self.opt.D_patch_size]\n        out = input\n        for i, conv in enumerate(self.convs):\n            out = conv(out)\n            # print(i, out.abs().mean())\n        # out = self.convs(input)\n\n        batch, channel, height, width = out.shape\n\n        if False and \"tile\" in self.opt.netD:\n            group = min(batch, self.stddev_group)\n            stddev = out.view(\n                group, -1, 1, channel // 1, height, width\n            )\n            stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)\n            stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)\n            stddev = stddev.repeat(group, 1, height, width)\n            out = torch.cat([out, stddev], 1)\n\n        out = self.final_conv(out)\n        # print(out.abs().mean())\n\n        if \"patch\" not in self.opt.netD:\n            out = out.view(batch, -1)\n        out = self.final_linear(out)\n\n        return out\n\n\nclass TileStyleGAN2Discriminator(StyleGAN2Discriminator):\n    def forward(self, input):\n        B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)\n        size = self.opt.D_patch_size\n        Y = H // size\n        X = W // size\n        input = input.view(B, C, Y, size, X, size)\n        input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)\n        return super().forward(input)\n\n\nclass StyleGAN2Encoder(nn.Module):\n    def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):\n        super().__init__()\n        assert opt is not None\n        self.opt = opt\n        channel_multiplier = ngf / 32\n        channels = {\n            4: min(512, int(round(4096 * channel_multiplier))),\n            8: min(512, int(round(2048 * channel_multiplier))),\n            16: min(512, int(round(1024 * channel_multiplier))),\n            32: min(512, int(round(512 * 
channel_multiplier))),\n            64: int(round(256 * channel_multiplier)),\n            128: int(round(128 * channel_multiplier)),\n            256: int(round(64 * channel_multiplier)),\n            512: int(round(32 * channel_multiplier)),\n            1024: int(round(16 * channel_multiplier)),\n        }\n\n        blur_kernel = [1, 3, 3, 1]\n\n        cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))\n        convs = [nn.Identity(),\n                 ConvLayer(3, channels[cur_res], 1)]\n\n        num_downsampling = self.opt.stylegan2_G_num_downsampling\n        for i in range(num_downsampling):\n            in_channel = channels[cur_res]\n            out_channel = channels[cur_res // 2]\n            convs.append(ResBlock(in_channel, out_channel, blur_kernel, downsample=True))\n            cur_res = cur_res // 2\n\n        for i in range(n_blocks // 2):\n            n_channel = channels[cur_res]\n            convs.append(ResBlock(n_channel, n_channel, downsample=False))\n\n        self.convs = nn.Sequential(*convs)\n\n    def forward(self, input, layers=[], get_features=False):\n        feat = input\n        feats = []\n        if -1 in layers:\n            layers.append(len(self.convs) - 1)\n        for layer_id, layer in enumerate(self.convs):\n            feat = layer(feat)\n            # print(layer_id, \" features \", feat.abs().mean())\n            if layer_id in layers:\n                feats.append(feat)\n\n        if get_features:\n            return feat, feats\n        else:\n            return feat\n\n\nclass StyleGAN2Decoder(nn.Module):\n    def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):\n        super().__init__()\n        assert opt is not None\n        self.opt = opt\n\n        blur_kernel = [1, 3, 3, 1]\n\n        channel_multiplier = ngf / 32\n        channels = {\n            4: min(512, int(round(4096 * channel_multiplier))),\n            8: min(512, int(round(2048 * channel_multiplier))),\n            16: min(512, int(round(1024 * channel_multiplier))),\n            32: min(512, int(round(512 * channel_multiplier))),\n            64: int(round(256 * channel_multiplier)),\n            128: int(round(128 * channel_multiplier)),\n            256: int(round(64 * channel_multiplier)),\n            512: int(round(32 * channel_multiplier)),\n            1024: int(round(16 * channel_multiplier)),\n        }\n\n        num_downsampling = self.opt.stylegan2_G_num_downsampling\n        cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) // (2 ** num_downsampling)\n        convs = []\n\n        for i in range(n_blocks // 2):\n            n_channel = channels[cur_res]\n            convs.append(ResBlock(n_channel, n_channel, downsample=False))\n\n        for i in range(num_downsampling):\n            in_channel = channels[cur_res]\n            out_channel = channels[cur_res * 2]\n            inject_noise = \"small\" not in self.opt.netG\n            convs.append(\n                StyledConv(in_channel, out_channel, 3, upsample=True, blur_kernel=blur_kernel, inject_noise=inject_noise)\n            )\n            cur_res = cur_res * 2\n\n        convs.append(ConvLayer(channels[cur_res], 3, 1))\n\n        self.convs = nn.Sequential(*convs)\n\n    def forward(self, input):\n        return self.convs(input)\n\n\nclass StyleGAN2Generator(nn.Module):\n    def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, 
padding_type='reflect', no_antialias=False, opt=None):\n        super().__init__()\n        self.opt = opt\n        self.encoder = StyleGAN2Encoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)\n        self.decoder = StyleGAN2Decoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)\n\n    def forward(self, input, layers=[], encode_only=False):\n        feat, feats = self.encoder(input, layers, True)\n        if encode_only:\n            return feats\n        else:\n            fake = self.decoder(feat)\n\n            if len(layers) > 0:\n                return fake, feats\n            else:\n                return fake"
  },
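  {
    "path": "examples/stylegan2_generator_usage.py",
    "content": "# Illustrative usage sketch (assumed example file; the examples/ path is not part of the original repo).\n# It exercises the Generator that modules/util.py imports (from modules.stylegan2 import Generator) and\n# instantiates as Generator(64, 256, 8): forward() takes a list of style codes and returns\n# (image, latent_or_None), as in the call result, _ = self.generator([fc_in]) inside AT_net/TF_net.\nimport torch\n\nfrom modules.stylegan2 import Generator\n\n# Same constructor arguments as AT_net/TF_net: size=64, style_dim=256, n_mlp=8.\ngenerator = Generator(64, 256, 8)\n\nbatch = 4\nstyle_code = torch.randn(batch, 256)  # stands in for one LSTM output step (fc_in)\n\n# A single style code is mapped through the style MLP and broadcast to every layer.\nimage, _ = generator([style_code])\n\n# Two style codes trigger style mixing at a random inject_index;\n# return_latents=True additionally returns the per-layer latent tensor.\nmixed, latent = generator([torch.randn(batch, 256), torch.randn(batch, 256)], return_latents=True)\n\nprint(image.shape, latent.shape)"
  },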
  {
    "path": "modules/util.py",
    "content": "from torch import nn\n\nimport torch.nn.functional as F\nimport torch\nimport numpy as np\nimport cv2\nfrom sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d\n\nfrom modules.stylegan2 import Generator\n\nimport torch.nn as nn\nimport math\nimport torch.utils.model_zoo as model_zoo\nfrom modules.function import adaptive_instance_normalization as adain\n\nimport pdb\n\n\n\n# Misc\nimg2mse = lambda x, y : torch.mean((x - y) ** 2)\nmse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))\nto8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)\n\n\nclass InstanceNorm(nn.Module):\n    def __init__(self, epsilon=1e-8):\n        \"\"\"\n            @notice: avoid in-place ops.\n            https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3\n        \"\"\"\n        super(InstanceNorm, self).__init__()\n        self.epsilon = epsilon\n\n    def forward(self, x):\n        x   = x - torch.mean(x, (2, 3), True)\n        tmp = torch.mul(x, x) # or x ** 2\n        tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)\n        return x * tmp\n\nclass ApplyStyle(nn.Module):\n    \"\"\"\n        @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb\n    \"\"\"\n    def __init__(self, latent_size, channels, use_wscale):\n        super(ApplyStyle, self).__init__()\n        self.linear = FC(latent_size,\n                      channels * 2,\n                      gain=1.0,\n                      use_wscale=use_wscale)\n\n    def forward(self, x, latent):\n        style = self.linear(latent)  # style => [batch_size, n_channels*2]\n        shape = [-1, 2, x.size(1), 1, 1]\n        style = style.view(shape)    # [batch_size, 2, n_channels, ...]\n        x = x * (style[:, 0] + 1.) 
+ style[:, 1]\n        return x\n\n\nclass FC(nn.Module):\n    def __init__(self,\n                 in_channels,\n                 out_channels,\n                 gain=2**(0.5),\n                 use_wscale=False,\n                 lrmul=1.0,\n                 bias=True):\n        \"\"\"\n            The complete conversion of Dense/FC/Linear Layer of original Tensorflow version.\n        \"\"\"\n        super(FC, self).__init__()\n        he_std = gain * in_channels ** (-0.5)  # He init\n        if use_wscale:\n            init_std = 1.0 / lrmul\n            self.w_lrmul = he_std * lrmul\n        else:\n            init_std = he_std / lrmul\n            self.w_lrmul = lrmul\n\n        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels) * init_std)\n        if bias:\n            self.bias = torch.nn.Parameter(torch.zeros(out_channels))\n            self.b_lrmul = lrmul\n        else:\n            self.bias = None\n\n    def forward(self, x):\n        if self.bias is not None:\n            out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul)\n        else:\n            out = F.linear(x, self.weight * self.w_lrmul)\n        out = F.leaky_relu(out, 0.2, inplace=True)\n        return out\n\n\n# Positional encoding (section 5.1)\nclass Embedder:\n    def __init__(self, **kwargs):\n        self.kwargs = kwargs\n        self.create_embedding_fn()\n\n    def create_embedding_fn(self):\n        embed_fns = []\n        d = self.kwargs['input_dims']\n        out_dim = 0\n        if self.kwargs['include_input']:\n            embed_fns.append(lambda x : x)\n            out_dim += d\n\n        max_freq = self.kwargs['max_freq_log2']\n        N_freqs = self.kwargs['num_freqs']\n\n        if self.kwargs['log_sampling']:\n            freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)\n        else:\n            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)\n\n        for freq in freq_bands:\n            for p_fn in self.kwargs['periodic_fns']:\n                embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))\n                out_dim += d\n\n        self.embed_fns = embed_fns\n        self.out_dim = out_dim\n\n    def embed(self, inputs):\n        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)\n\n\ndef get_embedder(multires, i=0):\n    if i == -1:\n        return nn.Identity(), 6\n\n    embed_kwargs = {\n                'include_input' : True,\n                'input_dims' : 6,\n                'max_freq_log2' : multires-1,\n                'num_freqs' : multires,\n                'log_sampling' : True,\n                'periodic_fns' : [torch.sin, torch.cos],\n    }\n\n    embedder_obj = Embedder(**embed_kwargs)\n    embed = lambda x, eo=embedder_obj : eo.embed(x)\n    return embed, embedder_obj.out_dim\n\n\ndef draw_heatmap(landmark, width, height):\n    batch = landmark.shape[0]\n    number = landmark.shape[1]\n    heatmap = np.zeros((batch, number,width, height), dtype=np.float32)\n    # draw mouth from mouth landmarks, landmarks: mouth landmark points, format: x1, y1, x2, y2, ..., x20,\n\n\n    landmark = (landmark+1)*29\n    for i in range(batch):\n        for pts_idx in range(number):\n            if int(landmark[i,pts_idx,0])<0:\n                landmark[i,pts_idx,0] = 0\n            if int(landmark[i,pts_idx,1])<0:\n                landmark[i,pts_idx,1] = 0\n            if int(landmark[i,pts_idx,0])>57:\n                landmark[i,pts_idx,0] = 57\n            if int(landmark[i,pts_idx,1])>57:\n         
       landmark[i,pts_idx,1] = 57\n            heatmap[i,pts_idx, int(landmark[i,pts_idx,1]), int(landmark[i,pts_idx,0])]=1\n            if heatmap[i,pts_idx].sum()== 1 :\n\n                heatmap[i,pts_idx] = cv2.GaussianBlur(heatmap[i,pts_idx], ksize=(3, 3), sigmaX=1, sigmaY=1)\n\n\n    heatmap = torch.tensor(heatmap).cuda()\n    return heatmap\n\nclass NA_net(nn.Module):\n    def __init__(self):\n        super(NA_net, self).__init__()\n\n\n\n        self.decon = nn.Sequential(\n                nn.ConvTranspose2d(1, 16, kernel_size=(2,3), stride=2, padding=(2,1), bias=True),#16,16\n                nn.BatchNorm2d(16),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(16, 32, kernel_size=4, stride=2, padding=1, bias=True),#8,8\n                nn.BatchNorm2d(32),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(32, 32+3, kernel_size=4, stride=2, padding=1, bias=True)#16,16\n\n\n                )\n\n\n\n    def forward(self, neutral):\n\n        feature = neutral.unsqueeze(1)\n        current_feature = self.decon(feature)\n\n\n        return current_feature\n\nclass AT_net(nn.Module):\n    def __init__(self):\n        super(AT_net, self).__init__()\n\n        down_blocks = []\n        for i in range(8):\n            down_blocks.append(DownBlock2d(3 if i == 0 else  2 * (2 ** i),\n                                            2 * (2 ** (i + 1)),\n                                           kernel_size=3, padding=1))\n        self.down_blocks = nn.ModuleList(down_blocks)\n\n\n     #   self.lmark_encoder = nn.Sequential(\n     #       nn.Linear(16,256),\n     #       nn.ReLU(True),\n     #       nn.Linear(256,512),\n     #       nn.ReLU(True),\n     #       )\n        self.pose_encoder = nn.Sequential(\n            nn.Linear(6,128),\n            nn.ReLU(True),\n            nn.Linear(128,256),\n            nn.ReLU(True),\n\n            )\n        self.audio_eocder = nn.Sequential(\n            conv2d(1,64,3,1,1),\n            conv2d(64,128,3,1,1),\n            nn.MaxPool2d(3, stride=(1,2)),\n            conv2d(128,256,3,1,1),\n            conv2d(256,256,3,1,1),\n            conv2d(256,512,3,1,1),\n            nn.MaxPool2d(3, stride=(2,2))\n            )\n        self.audio_eocder_fc = nn.Sequential(\n            nn.Linear(1024 *12,2048),\n            nn.ReLU(True),\n            nn.Linear(2048,256),\n            nn.ReLU(True),\n\n            )\n        self.lstm = nn.LSTM(256*4,256,3,batch_first = True)\n    #    self.lstm_fc = nn.Sequential(\n    #        nn.Linear(256,16),\n    #        )\n        self.decon = nn.Sequential(\n                nn.ConvTranspose2d(256, 256, kernel_size=6, stride=2, padding=1, bias=True),#4,4\n                nn.BatchNorm2d(256),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8\n                nn.BatchNorm2d(128),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True), #16,16\n                nn.BatchNorm2d(128),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),#32,32\n                nn.BatchNorm2d(128),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(128, 32+3, kernel_size=4, stride=2, padding=1, bias=True),#64,64\n            #    nn.ConvTranspose2d(128, 32*4, kernel_size=2, stride=2, padding=3, bias=True),#64,64\n\n\n                )\n        self.generator = 
Generator(64,256,8)\n\n\n\n    def forward(self, example_image, audio, pose, jaco_net):\n        hidden = ( torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()),\n                      torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()))\n        outs = example_image\n        for down_block in self.down_blocks:\n            outs = down_block(outs)\n            image_feature = outs\n        image_feature = image_feature.view(image_feature.shape[0], -1)\n        lstm_input = []\n        for step_t in range(audio.size(1)):\n            current_audio = audio[ : ,step_t , :, :].unsqueeze(1)\n            current_feature = self.audio_eocder(current_audio)\n            current_feature = current_feature.view(current_feature.size(0), -1)\n            current_feature = self.audio_eocder_fc(current_feature)\n            pose_f = self.pose_encoder(pose[:,step_t])\n            features = torch.cat([image_feature,  current_feature, pose_f], 1)\n            lstm_input.append(features)\n        lstm_input = torch.stack(lstm_input, dim = 1)\n        lstm_out, _ = self.lstm(lstm_input, hidden)\n        fc_out   = []\n        deco_out = []\n        for step_t in range(audio.size(1)):\n            fc_in = lstm_out[:,step_t,:]\n    #        fc_out.append(self.lstm_fc(fc_in))\n            if jaco_net == 'cnn':\n                fc_feature = torch.unsqueeze(fc_in,2)\n                fc_feature = torch.unsqueeze(fc_feature,3)\n                deco_out.append(self.decon(fc_feature))\n            elif jaco_net == 'gan':\n                result,_ = self.generator([fc_in])\n                deco_out.append(result)\n            else:\n                raise Exception(\"jaco_net type wrong\")\n\n        return torch.stack(deco_out,dim=1)\n\nclass Classify(nn.Module):\n    def __init__(self):\n        super(Classify, self).__init__()\n\n\n\n        self.last_fc = nn.Linear(512,8)\n\n    def forward(self, feature):\n       # mfcc= torch.unsqueeze(mfcc, 1)\n\n        x = self.last_fc(feature)\n\n        return x\n\nclass TF_net(nn.Module):\n    def __init__(self):\n        super(TF_net, self).__init__()\n\n        down_blocks = []\n        for i in range(8):\n            down_blocks.append(DownBlock2d(3 if i == 0 else  2 * (2 ** i),\n                                            2 * (2 ** (i + 1)),\n                                           kernel_size=3, padding=1))\n        self.down_blocks = nn.ModuleList(down_blocks)\n\n\n     #   self.lmark_encoder = nn.Sequential(\n     #       nn.Linear(16,256),\n     #       nn.ReLU(True),\n     #       nn.Linear(256,512),\n     #       nn.ReLU(True),\n     #       )\n        self.pose_encoder = nn.Sequential(\n            nn.Linear(6,128),\n            nn.ReLU(True),\n            nn.Linear(128,256),\n            nn.ReLU(True),\n\n            )\n        self.audio_eocder = nn.Sequential(\n            conv2d(1,64,3,1,1),\n            conv2d(64,128,3,1,1),\n            nn.MaxPool2d(3, stride=(1,2)),\n            conv2d(128,256,3,1,1),\n            conv2d(256,256,3,1,1),\n            conv2d(256,512,3,1,1),\n            nn.MaxPool2d(3, stride=(2,2))\n            )\n        self.audio_eocder_fc = nn.Sequential(\n            nn.Linear(1024 *12,2048),\n            nn.ReLU(True),\n            nn.Linear(2048,256),\n            nn.ReLU(True),\n\n            )\n        self.lstm = nn.LSTM(256*4,256,3,batch_first = True)\n        self.lstm_two = nn.LSTM(256*6,256,3,batch_first = True)\n    #    self.lstm_fc = nn.Sequential(\n    #        nn.Linear(256,16),\n    #       
 )\n        self.decon = nn.Sequential(\n                nn.ConvTranspose2d(256, 256, kernel_size=6, stride=2, padding=1, bias=True),#4,4\n                nn.BatchNorm2d(256),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8\n                nn.BatchNorm2d(128),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True), #16,16\n                nn.BatchNorm2d(128),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),#32,32\n                nn.BatchNorm2d(128),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(128, 32+3, kernel_size=4, stride=2, padding=1, bias=True),#64,64\n            #    nn.ConvTranspose2d(128, 32*4, kernel_size=2, stride=2, padding=3, bias=True),#64,64\n\n\n                )\n        self.generator = Generator(64,256,8)\n        self.instance_norm = InstanceNorm()\n        self.style_mod = ApplyStyle(512, 1024, use_wscale=True)\n        self.style_mod1 = ApplyStyle(512, 35, use_wscale=True)\n\n\n    def adain_forward(self, example_image, audio, pose, jaco_net, emo_features):\n        hidden = ( torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()),\n                      torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()))\n        outs = example_image\n        for down_block in self.down_blocks:\n            outs = down_block(outs)\n            image_feature = outs\n        image_feature = image_feature.view(image_feature.shape[0], -1)\n        lstm_input = []\n        for step_t in range(audio.size(1)):\n            current_audio = audio[ : ,step_t , :, :].unsqueeze(1)\n            current_feature = self.audio_eocder(current_audio)\n            current_feature = current_feature.view(current_feature.size(0), -1)\n            current_feature = self.audio_eocder_fc(current_feature) #256\n            pose_f = self.pose_encoder(pose[:,step_t]) #256\n            features = torch.cat([image_feature,  current_feature, pose_f], 1)\n            features = torch.unsqueeze(torch.unsqueeze(features,-1),-1)\n            features = self.instance_norm(features)\n            x = self.style_mod(features, emo_features[step_t])\n          #  t = adain(torch.unsqueeze(torch.unsqueeze(features,-1),-1), torch.unsqueeze(torch.unsqueeze(emo_features[step_t],1),2))\n\n            lstm_input.append(torch.squeeze(torch.squeeze(x,-1),-1))\n        lstm_input = torch.stack(lstm_input, dim = 1)\n        lstm_out, _ = self.lstm(lstm_input, hidden)\n    #    fc_out   = []\n        deco_out = []\n        for step_t in range(audio.size(1)):\n            fc_in = lstm_out[:,step_t,:]\n    #        fc_out.append(self.lstm_fc(fc_in))\n            if jaco_net == 'cnn':\n                fc_feature = torch.unsqueeze(fc_in,2)\n                fc_feature = torch.unsqueeze(fc_feature,3)\n                deco_out.append(self.decon(fc_feature))\n            elif jaco_net == 'gan':\n                result,_ = self.generator([fc_in])\n                deco_out.append(result)\n            else:\n                raise Exception(\"jaco_net type wrong\")\n\n        return torch.stack(deco_out,dim=1)\n\n\n\n    def adain_feature2(self, example_image, audio, pose, jaco_net, emo_features):\n        hidden = ( torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()),\n                      torch.autograd.Variable(torch.zeros(3, audio.size(0), 
256).cuda()))\n        outs = example_image\n        for down_block in self.down_blocks:\n            outs = down_block(outs)\n            image_feature = outs\n        image_feature = image_feature.view(image_feature.shape[0], -1)\n        lstm_input = []\n        for step_t in range(audio.size(1)):\n            current_audio = audio[ : ,step_t , :, :].unsqueeze(1)\n            current_feature = self.audio_eocder(current_audio)\n            current_feature = current_feature.view(current_feature.size(0), -1)\n            current_feature = self.audio_eocder_fc(current_feature) #256\n            pose_f = self.pose_encoder(pose[:,step_t]) #256\n            features = torch.cat([image_feature,  current_feature, pose_f], 1)\n\n            lstm_input.append(features)\n        lstm_input = torch.stack(lstm_input, dim = 1)\n        lstm_out, _ = self.lstm(lstm_input, hidden)\n    #    fc_out   = []\n        deco_out = []\n        for step_t in range(audio.size(1)):\n            fc_in = lstm_out[:,step_t,:]\n    #        fc_out.append(self.lstm_fc(fc_in))\n            if jaco_net == 'cnn':\n                fc_feature = torch.unsqueeze(fc_in,2)\n                fc_feature = torch.unsqueeze(fc_feature,3)\n                fc_feature = self.decon(fc_feature)\n                fc_feature = self.instance_norm(fc_feature)\n                t = self.style_mod1(fc_feature, emo_features[step_t])\n             #   emo_feature = torch.unsqueeze(torch.unsqueeze(emo_features[step_t],-1),-1)\n             #   emo_feature = emo_feature.repeat(1,fc_feature.shape[1],1,1)\n             #   t = adain(fc_feature, emo_feature)\n                deco_out.append(t)\n            elif jaco_net == 'gan':\n                result,_ = self.generator([fc_in])\n                deco_out.append(result)\n            else:\n                raise Exception(\"jaco_net type wrong\")\n\n        return torch.stack(deco_out,dim=1)\n\n    def forward(self, example_image, audio, pose, jaco_net, emo_features):\n        hidden = ( torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()),\n                      torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()))\n        outs = example_image\n        for down_block in self.down_blocks:\n            outs = down_block(outs)\n            image_feature = outs\n        image_feature = image_feature.view(image_feature.shape[0], -1)\n        lstm_input = []\n        for step_t in range(audio.size(1)):\n            current_audio = audio[ : ,step_t , :, :].unsqueeze(1)\n            current_feature = self.audio_eocder(current_audio)\n            current_feature = current_feature.view(current_feature.size(0), -1)\n            current_feature = self.audio_eocder_fc(current_feature) #256\n            pose_f = self.pose_encoder(pose[:,step_t]) #256\n            features = torch.cat([image_feature,  current_feature, pose_f, emo_features[step_t]], 1)\n            lstm_input.append(features)\n        lstm_input = torch.stack(lstm_input, dim = 1)\n        lstm_out, _ = self.lstm_two(lstm_input, hidden)\n        fc_out   = []\n        deco_out = []\n        for step_t in range(audio.size(1)):\n            fc_in = lstm_out[:,step_t,:]\n    #        fc_out.append(self.lstm_fc(fc_in))\n            if jaco_net == 'cnn':\n                fc_feature = torch.unsqueeze(fc_in,2)\n                fc_feature = torch.unsqueeze(fc_feature,3)\n                deco_out.append(self.decon(fc_feature))\n            elif jaco_net == 'gan':\n                result,_ = self.generator([fc_in])\n                
deco_out.append(result)\n            else:\n                raise Exception(\"jaco_net type wrong\")\n\n        return torch.stack(deco_out,dim=1)\n\n\nclass AT_net2(nn.Module):\n    def __init__(self):\n        super(AT_net2, self).__init__()\n\n        down_blocks = []\n        for i in range(8):\n            down_blocks.append(DownBlock2d(3 if i == 0 else  2 * (2 ** i),\n                                            2 * (2 ** (i + 1)),\n                                           kernel_size=3, padding=1))\n        self.down_blocks = nn.ModuleList(down_blocks)\n\n\n     #   self.lmark_encoder = nn.Sequential(\n     #       nn.Linear(16,256),\n     #       nn.ReLU(True),\n     #       nn.Linear(256,512),\n     #       nn.ReLU(True),\n     #       )\n        self.pose_encoder = nn.Sequential(\n            nn.Linear(6,128),\n            nn.ReLU(True),\n            nn.Linear(128,256),\n            nn.ReLU(True),\n\n            )\n        self.audio_eocder = nn.Sequential(\n            conv2d(1,64,3,1,1),\n            conv2d(64,128,3,1,1),\n            nn.MaxPool2d(3, stride=(1,2)),\n            conv2d(128,256,3,1,1),\n            conv2d(256,256,3,1,1),\n            conv2d(256,512,3,1,1),\n            nn.MaxPool2d(3, stride=(2,2))\n            )\n        self.audio_eocder_fc = nn.Sequential(\n            nn.Linear(1024 *12,2048),\n            nn.ReLU(True),\n            nn.Linear(2048,256),\n            nn.ReLU(True),\n\n            )\n        self.lstm = nn.LSTM(256*4,256,3,batch_first = True)\n    #    self.lstm_fc = nn.Sequential(\n    #        nn.Linear(256,16),\n    #        )\n        self.decon = nn.Sequential(\n                nn.ConvTranspose2d(256, 256, kernel_size=6, stride=2, padding=1, bias=True),#4,4\n                nn.BatchNorm2d(256),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8\n                nn.BatchNorm2d(128),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True), #16,16\n                nn.BatchNorm2d(128),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),#32,32\n                nn.BatchNorm2d(128),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(128, 32+3, kernel_size=4, stride=2, padding=1, bias=True),#64,64\n            #    nn.ConvTranspose2d(128, 32*4, kernel_size=2, stride=2, padding=3, bias=True),#64,64\n\n\n                )\n        self.generator = Generator(64,256,8)\n\n\n    def forward(self, example_image, audio, pose, jaco_net, weight):\n        hidden = ( torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()),\n                      torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()))\n        outs = example_image\n        for down_block in self.down_blocks:\n            outs = down_block(outs)\n            image_feature = outs\n        image_feature = image_feature.view(image_feature.shape[0], -1)\n        lstm_input = []\n        for step_t in range(audio.size(1)):\n            current_audio = audio[ : ,step_t , :, :].unsqueeze(1)\n            current_feature = self.audio_eocder(current_audio)\n            current_feature = current_feature.view(current_feature.size(0), -1)\n            current_feature = self.audio_eocder_fc(current_feature)*weight\n            pose_f = self.pose_encoder(pose[:,step_t])\n            features = torch.cat([image_feature,  current_feature, pose_f], 
1)\n            lstm_input.append(features)\n        lstm_input = torch.stack(lstm_input, dim = 1)\n        lstm_out, _ = self.lstm(lstm_input, hidden)\n        fc_out   = []\n        deco_out = []\n        for step_t in range(audio.size(1)):\n            fc_in = lstm_out[:,step_t,:]\n    #        fc_out.append(self.lstm_fc(fc_in))\n            if jaco_net == 'cnn':\n                fc_feature = torch.unsqueeze(fc_in,2)\n                fc_feature = torch.unsqueeze(fc_feature,3)\n                deco_out.append(self.decon(fc_feature))\n            elif jaco_net == 'gan':\n                result,_ = self.generator([fc_in])\n                deco_out.append(result)\n            else:\n                raise Exception(\"jaco_net type wrong\")\n\n        return torch.stack(deco_out,dim=1)\n\n\n\nclass Ct_encoder(nn.Module):\n    def __init__(self):\n        super(Ct_encoder, self).__init__()\n        self.audio_eocder = nn.Sequential(\n            conv2d(1,64,3,1,1),\n            conv2d(64,128,3,1,1),\n            nn.MaxPool2d(3, stride=(1,2)),\n            conv2d(128,256,3,1,1),\n            conv2d(256,256,3,1,1),\n            conv2d(256,512,3,1,1),\n            nn.MaxPool2d(3, stride=(2,2))\n            )\n        self.audio_eocder_fc = nn.Sequential(\n            nn.Linear(1024 *12,2048),\n            nn.ReLU(True),\n            nn.Linear(2048,256),\n            nn.ReLU(True),\n\n            )\n\n    def forward(self, audio):\n\n        feature = self.audio_eocder(audio)\n        feature = feature.view(feature.size(0),-1)\n        x = self.audio_eocder_fc(feature)\n\n        return x\n\n\nclass EmotionNet(nn.Module):\n    def __init__(self):\n        super(EmotionNet, self).__init__()\n\n        self.emotion_eocder = nn.Sequential(\n            conv2d(1,64,3,1,1),\n\n            nn.MaxPool2d((1,3), stride=(1,2)), #[1, 64, 12, 12]\n            conv2d(64,128,3,1,1),\n\n            conv2d(128,256,3,1,1),\n\n            nn.MaxPool2d((12,1), stride=(12,1)), #[1, 256, 1, 12]\n\n            conv2d(256,512,3,1,1),\n\n            nn.MaxPool2d((1,2), stride=(1,2)) #[1, 512, 1, 6]\n\n            )\n        self.emotion_eocder_fc = nn.Sequential(\n            nn.Linear(512 *6,2048),\n            nn.ReLU(True),\n            nn.Linear(2048,128),\n            nn.ReLU(True),\n\n            )\n\n        self.last_fc = nn.Linear(128,8)\n\n        self.re_id = nn.Sequential(\n            conv2d(512,1024,3,1,1),\n\n            nn.MaxPool2d((1,2), stride=(1,2)), #[1, 1024, 1, 3]\n            conv2d(1024,1024,3,1,1),\n\n            conv2d(1024,2048,3,1,1),\n\n            nn.MaxPool2d((1,2), stride=(1,2)) #[1, 2048, 1, 1]\n\n\n            )\n        self.re_id_fc = nn.Sequential(\n\n            nn.Linear(2048,512),\n            nn.ReLU(True),\n            nn.Linear(512,128),\n            nn.ReLU(True),\n            )\n\n\n    def forward(self, mfcc):\n       # mfcc= torch.unsqueeze(mfcc, 1)\n        mfcc=torch.transpose(mfcc,2,3)\n        feature = self.emotion_eocder(mfcc)\n\n   #     id_feature = feature.detach()\n\n        feature = feature.view(feature.size(0),-1)\n        x = self.emotion_eocder_fc(feature)\n\n\n  #      remove_feature = self.re_id(id_feature)\n  #      remove_feature = remove_feature.view(remove_feature.size(0),-1)\n  #      y = self.re_id_fc(remove_feature)\n\n        return x\n\n\nclass AF2F(nn.Module):\n    def __init__(self):\n        super(AF2F, self).__init__()\n        self.decon = nn.Sequential(\n                nn.ConvTranspose2d(384, 256, kernel_size=6, stride=2, padding=1, 
bias=True),#4,4\n                nn.BatchNorm2d(256),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8\n                nn.BatchNorm2d(128),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=True), #16,16\n                nn.BatchNorm2d(64),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),#32,32\n                nn.BatchNorm2d(64),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(64, 32+3, kernel_size=4, stride=2, padding=1, bias=True),#64,64\n\n\n                )\n\n    def forward(self, content,emotion):\n        features = torch.cat([content,  emotion], 1) #connect tensors inputs and dimension\n        features = torch.unsqueeze(features,2)\n        features = torch.unsqueeze(features,3)\n        x = self.decon(features)\n\n\n        return x\n\nclass AF2F_s(nn.Module):\n    def __init__(self):\n        super(AF2F_s, self).__init__()\n        self.decon = nn.Sequential(\n                nn.ConvTranspose2d(256, 256, kernel_size=6, stride=2, padding=1, bias=True),#4,4\n                nn.BatchNorm2d(256),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8\n                nn.BatchNorm2d(128),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=True), #16,16\n                nn.BatchNorm2d(64),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),#32,32\n                nn.BatchNorm2d(64),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(64, 32+3, kernel_size=4, stride=2, padding=1, bias=True),#64,64\n\n                nn.ReLU(),\n                )\n\n    def forward(self, content):\n       # features = torch.cat([content,  emotion], 1) #connect tensors inputs and dimension\n        features = torch.unsqueeze(content,2)\n        features = torch.unsqueeze(features,3)\n        x = self.decon(features)\n\n\n        return x\n\n\nclass A2I(nn.Module):\n    def __init__(self):\n        super(A2I, self).__init__()\n        self.audio_eocder = nn.Sequential(\n            conv2d(1,64,3,1,1),\n            conv2d(64,128,3,1,1),\n            nn.MaxPool2d((1,5), stride=(1,2)),\n            conv2d(128,256,3,1,1),\n            conv2d(256,256,3,1,1),\n\n            nn.MaxPool2d((5,5), stride=(2,2))\n            )\n        self.decon = nn.Sequential(\n\n                nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8\n                nn.BatchNorm2d(128),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=True), #16,16\n                nn.BatchNorm2d(64),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=True),#32,32\n                nn.BatchNorm2d(32),\n                nn.ReLU(True),\n                nn.ConvTranspose2d(32, 2, kernel_size=4, stride=2, padding=1, bias=True),#64,64\n\n                nn.ReLU(),\n                )\n\n    def forward(self, mfcc):\n        mfcc= torch.unsqueeze(mfcc, 1)\n        mfcc=torch.transpose(mfcc,2,3)\n        feature = self.audio_eocder(mfcc)\n\n   #     id_feature = feature.detach()\n\n        x = self.decon(feature)\n\n        return 
x\n\ndef kp2gaussian(kp, spatial_size, kp_variance):\n    \"\"\"\n    Transform a keypoint into gaussian like representation\n    \"\"\"\n    mean = kp['value'] #[4,10,2]\n\n    coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) #[h,w,2]\n    number_of_leading_dimensions = len(mean.shape) - 1\n    shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape #5\n    coordinate_grid = coordinate_grid.view(*shape) #[1,1,h,w,2]\n    repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1)\n    coordinate_grid = coordinate_grid.repeat(*repeats) #[4,10,h,w,2]\n\n    # Preprocess kp shape\n    shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2)\n    mean = mean.view(*shape) #[4,10,1,1,2]\n\n    mean_sub = (coordinate_grid - mean)\n\n    out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)\n\n    return out\n\n\ndef make_coordinate_grid(spatial_size, type):\n    \"\"\"\n    Create a meshgrid [-1,1] x [-1,1] of given spatial_size.\n    \"\"\"\n    h, w = spatial_size\n    x = torch.arange(w).type(type)\n    y = torch.arange(h).type(type)\n\n    x = (2 * (x / (w - 1)) - 1)\n    y = (2 * (y / (h - 1)) - 1)\n\n    yy = y.view(-1, 1).repeat(1, w)\n    xx = x.view(1, -1).repeat(h, 1)\n\n    meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)\n\n    return meshed\n\n\nclass ResBlock2d(nn.Module):\n    \"\"\"\n    Res block, preserve spatial resolution.\n    \"\"\"\n\n    def __init__(self, in_features, kernel_size, padding):\n        super(ResBlock2d, self).__init__()\n        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,\n                               padding=padding)\n        self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,\n                               padding=padding)\n        self.norm1 = BatchNorm2d(in_features, affine=True)\n        self.norm2 = BatchNorm2d(in_features, affine=True)\n\n    def forward(self, x):\n        out = self.norm1(x)\n        out = F.relu(out)\n        out = self.conv1(out)\n        out = self.norm2(out)\n        out = F.relu(out)\n        out = self.conv2(out)\n        out += x\n        return out\n\n\nclass UpBlock2d(nn.Module):\n    \"\"\"\n    Upsampling block for use in decoder.\n    \"\"\"\n\n    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):\n        super(UpBlock2d, self).__init__()\n\n        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,\n                              padding=padding, groups=groups)\n        self.norm = BatchNorm2d(out_features, affine=True)\n\n    def forward(self, x):\n        out = F.interpolate(x, scale_factor=2)\n        out = self.conv(out)\n        out = self.norm(out)\n        out = F.relu(out)\n        return out\n\n\nclass DownBlock2d(nn.Module):\n    \"\"\"\n    Downsampling block for use in encoder.\n    \"\"\"\n\n    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):\n        super(DownBlock2d, self).__init__()\n        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,\n                              padding=padding, groups=groups)\n        self.norm = BatchNorm2d(out_features, affine=True)\n        self.pool = nn.AvgPool2d(kernel_size=(2, 2))\n\n    def forward(self, x):\n        out = self.conv(x)\n        out = self.norm(out)\n        out = F.relu(out)\n        out = self.pool(out)\n        return 
out\n\n\nclass SameBlock2d(nn.Module):\n    \"\"\"\n    Simple block, preserve spatial resolution.\n    \"\"\"\n\n    def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):\n        super(SameBlock2d, self).__init__()\n        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,\n                              kernel_size=kernel_size, padding=padding, groups=groups)\n        self.norm = BatchNorm2d(out_features, affine=True)\n\n    def forward(self, x):\n        out = self.conv(x)\n        out = self.norm(out)\n        out = F.relu(out)\n        return out\n\n\nclass Encoder(nn.Module):\n    \"\"\"\n    Hourglass Encoder\n    \"\"\"\n\n    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):\n        super(Encoder, self).__init__()\n\n        down_blocks = []\n        for i in range(num_blocks):\n            down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),\n                                           min(max_features, block_expansion * (2 ** (i + 1))),\n                                           kernel_size=3, padding=1))\n        self.down_blocks = nn.ModuleList(down_blocks)\n\n    def forward(self, x):\n        outs = [x]\n        for down_block in self.down_blocks:\n            outs.append(down_block(outs[-1]))\n        return outs\n\n\nclass Decoder(nn.Module):\n    \"\"\"\n    Hourglass Decoder\n    \"\"\"\n\n    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):\n        super(Decoder, self).__init__()\n\n        up_blocks = []\n\n        for i in range(num_blocks)[::-1]:\n            in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))\n            out_filters = min(max_features, block_expansion * (2 ** i))\n            up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))\n\n        self.up_blocks = nn.ModuleList(up_blocks)\n        self.out_filters = block_expansion + in_features\n\n    def forward(self, x):\n        out = x.pop()\n        for up_block in self.up_blocks:\n            out = up_block(out)\n            skip = x.pop()\n            out = torch.cat([out, skip], dim=1)\n        return out\n\n\nclass Hourglass(nn.Module):\n    \"\"\"\n    Hourglass architecture.\n    \"\"\"\n\n    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):\n        super(Hourglass, self).__init__()\n        self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)\n        self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)\n        self.out_filters = self.decoder.out_filters\n\n    def forward(self, x):\n        return self.decoder(self.encoder(x))\n\n\nclass AntiAliasInterpolation2d(nn.Module):\n    \"\"\"\n    Band-limited downsampling, for better preservation of the input signal.\n    \"\"\"\n    def __init__(self, channels, scale):\n        super(AntiAliasInterpolation2d, self).__init__()\n     #   sigma = (1 / scale - 1) / 2\n        sigma = 1.5\n        kernel_size = 2 * round(sigma * 4) + 1\n        self.ka = kernel_size // 2\n        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka\n\n        kernel_size = [kernel_size, kernel_size]\n        sigma = [sigma, sigma]\n        # The gaussian kernel is the product of the\n        # gaussian function of each dimension.\n        kernel = 1\n        meshgrids = torch.meshgrid(\n            [\n                torch.arange(size, 
dtype=torch.float32)\n                for size in kernel_size\n                ]\n        )\n        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):\n            mean = (size - 1) / 2\n            kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))\n\n        # Make sure sum of values in gaussian kernel equals 1.\n        kernel = kernel / torch.sum(kernel)\n        # Reshape to depthwise convolutional weight\n        kernel = kernel.view(1, 1, *kernel.size())\n        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))\n\n        self.register_buffer('weight', kernel)\n        self.groups = channels\n        self.scale = scale\n        inv_scale = 1 / scale\n        self.int_inv_scale = int(inv_scale)\n\n    def forward(self, input):\n        if self.scale == 1.0:\n            return input\n\n        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))\n        out = F.conv2d(out, weight=self.weight, groups=self.groups)\n        out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]\n\n        return out\n\ndef sigmoid(x):\n    return 1 / (1 + math.exp(-x))\n\n\ndef norm_angle(angle):\n    norm_angle = sigmoid(10 * (abs(angle) / 0.7853975 - 1))\n    return norm_angle\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"3x3 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU()\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * 4)\n        self.relu = nn.ReLU()\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out = out + residual\n        out = self.relu(out)\n\n        return out\n\nclass EmDetector(nn.Module):\n    \"\"\"\n    Detecting a keypoints. 
Return keypoint position and jacobian near each keypoint.\n    \"\"\"\n\n    def __init__(self, block_expansion,  num_channels, max_features,\n                 num_blocks, scale_factor=1,  num_classes=8):\n        super(EmDetector, self).__init__()\n        self.inplanes = 64\n        self.predictor = Hourglass(block_expansion, in_features=num_channels,\n                                   max_features=max_features, num_blocks=num_blocks)\n\n\n\n\n        self.scale_factor = scale_factor\n        if self.scale_factor != 1:\n            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)\n        self.conv1 = nn.Conv2d(self.predictor.out_filters, 64, kernel_size=3, stride=1, padding=1,\n                               bias=False)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.relu = nn.ReLU()\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        layers = [2,2,2,2]\n        self.layer1 = self._make_layer(BasicBlock, 64, layers[0])\n        self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(BasicBlock, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(BasicBlock, 512, layers[3], stride=2)\n        self.avgpool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes)\n        self.classify = Classify()\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def adain_feature(self, x): #torch.Size([4, 3, H, W])\n        if self.scale_factor != 1:\n            x = self.down(x) # 0.25 [4, 3, H/4, W/4]\n\n        feature_map = self.predictor(x) #[4,3+32,H/4, W/4]\n\n    #    out = self.fc(out)\n\n        return feature_map\n\n    def forward(self, x): #torch.Size([4, 3, H, W])\n        if self.scale_factor != 1:\n            x = self.down(x) # 0.25 [4, 3, H/4, W/4]\n\n        feature_map = self.predictor(x) #[4,3+32,H/4, W/4]\n        f = self.conv1(feature_map) #[16,64,64,64]\n        f = self.bn1(f) #torch.Size([16, 64, 64, 64])\n        f = self.relu(f)\n        f = self.maxpool(f) #[16, 64, 32, 32]\n\n        f = self.layer1(f) #[16, 64, 32, 32]\n        f = self.layer2(f) #[16, 128, 16, 16])\n        f = self.layer3(f) #[16, 256, 8, 8]\n        f = self.layer4(f) #[16, 512, 4, 4]\n        f = self.avgpool(f) #[16, 512, 1, 1]\n        out = f.squeeze(3).squeeze(2)\n        fake = self.classify(out)\n    #    out = self.fc(out)\n\n        return out, fake\n\n\n\n\n\n\nclass Emotion_k(nn.Module):\n    \"\"\"\n    Detecting a keypoints. 
Return keypoint position and jacobian near each keypoint.\n    \"\"\"\n\n    def __init__(self, block_expansion,  num_channels, max_features,\n                 num_blocks, scale_factor=1,  num_classes=8):\n        super(Emotion_k, self).__init__()\n        self.inplanes = 64\n        self.predictor = Hourglass(block_expansion, in_features=num_channels,\n                                   max_features=max_features, num_blocks=num_blocks)\n\n\n\n\n        self.scale_factor = scale_factor\n        if self.scale_factor != 1:\n            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)\n        self.conv1 = nn.Conv2d(self.predictor.out_filters, 64, kernel_size=3, stride=1, padding=1,\n                               bias=False)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.relu = nn.ReLU()\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        layers = [2,2,2,2]\n        self.layer1 = self._make_layer(BasicBlock, 64, layers[0])\n        self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(BasicBlock, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(BasicBlock, 512, layers[3], stride=2)\n        self.avgpool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes)\n\n        self.embed_fn, self.input_ch = get_embedder(10, 0)\n\n        self.fc_p = nn.Sequential(\n            nn.Linear(10 * 126,1024),\n            nn.ReLU(True),\n            nn.Linear(1024,512),\n            nn.ReLU(True),\n\n            )\n        self.fc_n = nn.Sequential(\n            nn.Linear(10 * 6,128),\n            nn.ReLU(True),\n            nn.Linear(128,512),\n            nn.ReLU(True),\n\n            )\n\n        self.fc_all = nn.Sequential(\n            nn.Linear(1024,512),\n            nn.ReLU(True),\n            nn.Linear(512,256),\n            nn.ReLU(True),\n            nn.Linear(256,64),\n            nn.ReLU(True),\n            )\n\n      #  self.fc_single = nn.Sequential(\n      #      nn.Linear(512,256),\n      #      nn.ReLU(True),\n      #      nn.Linear(256,64),\n      #      nn.ReLU(True),\n      #      )\n\n        self.final = nn.Sequential(\n            nn.Conv1d(1,2,4,2,1),\n            nn.MaxPool1d(2,stride=2),\n            nn.ReLU(True),\n            nn.Conv1d(2,4,4,2,1),\n            nn.ReLU(True),\n            nn.Conv1d(4,4,3),\n\n            )\n\n        self.final_4 = nn.Sequential(\n            nn.Conv1d(4,4,3,1,1),\n            nn.MaxPool1d(2,stride=2),\n            nn.ReLU(True),\n            nn.Conv1d(4,4,3,1)\n\n            )\n\n        self.final_10 = nn.Sequential(\n            nn.Conv1d(4,8,3,1,1), #[B,8,16]\n            nn.MaxPool1d(2,stride=2), #[B,8,8]\n            nn.ReLU(True),\n            nn.Conv1d(8,10,3,1), #[B,10,6]\n\n\n            )\n\n        self.classify = Classify()\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n 
       return nn.Sequential(*layers)\n\n    def linear_10(self, x, value, jacobian): #torch.Size([4, 3, H, W])\n        if self.scale_factor != 1:\n            x = self.down(x) # 0.25 [4, 3, H/4, W/4]\n\n        feature_map = self.predictor(x) #[4,3+32,H/4, W/4]\n        f = self.conv1(feature_map) #[16,64,64,64]\n        f = self.bn1(f) #torch.Size([16, 64, 64, 64])\n        f = self.relu(f)\n        f = self.maxpool(f) #[16, 64, 32, 32]\n\n        f = self.layer1(f) #[16, 64, 32, 32]\n        f = self.layer2(f) #[16, 128, 16, 16])\n        f = self.layer3(f) #[16, 256, 8, 8]\n        f = self.layer4(f) #[16, 512, 4, 4]\n        f = self.avgpool(f) #[16, 512, 1, 1]\n        out = f.squeeze(3).squeeze(2)\n        fake = self.classify(out)\n        jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4)\n        neu_input = torch.cat((value,jacobian),2)\n        posi_input = self.embed_fn(neu_input)\n        posi_input =posi_input.reshape(posi_input.shape[0],-1)\n        ner_feature = self.fc_p(posi_input)\n        all_fc = self.fc_all(torch.cat((out,ner_feature),1)).reshape(-1,4,16)\n        result = self.final_10(all_fc)\n        e_value = result[:,:,:2]\n        e_jacobian = result[:,:,2:].reshape(result.shape[0],10,2,2)\n        kp = {'value': e_value,'jacobian': e_jacobian}\n\n        return kp, fake\n\n\n    def linear_4(self, x, value, jacobian): #torch.Size([4, 3, H, W])\n        if self.scale_factor != 1:\n            x = self.down(x) # 0.25 [4, 3, H/4, W/4]\n\n        feature_map = self.predictor(x) #[4,3+32,H/4, W/4]\n        f = self.conv1(feature_map) #[16,64,64,64]\n        f = self.bn1(f) #torch.Size([16, 64, 64, 64])\n        f = self.relu(f)\n        f = self.maxpool(f) #[16, 64, 32, 32]\n\n        f = self.layer1(f) #[16, 64, 32, 32]\n        f = self.layer2(f) #[16, 128, 16, 16])\n        f = self.layer3(f) #[16, 256, 8, 8]\n        f = self.layer4(f) #[16, 512, 4, 4]\n        f = self.avgpool(f) #[16, 512, 1, 1]\n        out = f.squeeze(3).squeeze(2)\n        fake = self.classify(out)\n    #    jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4)\n    #    neu_input = torch.cat((value,jacobian),2)\n    #    posi_input = self.embed_fn(neu_input)\n    #    posi_input =posi_input.reshape(posi_input.shape[0],-1)\n    #    ner_feature = self.fc_p(posi_input)\n    #    all_fc = self.fc_all(torch.cat((out,ner_feature),1)).reshape(-1,4,16)\n        all_fc = torch.unsqueeze(self.fc_single(out),1)\n        result = self.final(all_fc)\n        e_value = result[:,:,:2]\n        e_jacobian = result[:,:,2:].reshape(result.shape[0],4,2,2)\n        kp = {'value': e_value,'jacobian': e_jacobian}\n    #    out = self.fc(out)\n\n        return kp, fake\n\n    def linear_np_10(self, x, value, jacobian): #torch.Size([4, 3, H, W])\n        if self.scale_factor != 1:\n            x = self.down(x) # 0.25 [4, 3, H/4, W/4]\n\n        feature_map = self.predictor(x) #[4,3+32,H/4, W/4]\n        f = self.conv1(feature_map) #[16,64,64,64]\n        f = self.bn1(f) #torch.Size([16, 64, 64, 64])\n        f = self.relu(f)\n        f = self.maxpool(f) #[16, 64, 32, 32]\n\n        f = self.layer1(f) #[16, 64, 32, 32]\n        f = self.layer2(f) #[16, 128, 16, 16])\n        f = self.layer3(f) #[16, 256, 8, 8]\n        f = self.layer4(f) #[16, 512, 4, 4]\n        f = self.avgpool(f) #[16, 512, 1, 1]\n        out = f.squeeze(3).squeeze(2)\n        fake = self.classify(out)\n        jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4)\n        neu_input = 
torch.cat((value,jacobian),2)\n\n        posi_input =neu_input.reshape(neu_input.shape[0],-1)\n        ner_feature = self.fc_n(posi_input)\n        all_fc = self.fc_all(torch.cat((out,ner_feature),1)).reshape(-1,4,16)\n        result = self.final_10(all_fc)\n        e_value = result[:,:,:2]\n        e_jacobian = result[:,:,2:].reshape(result.shape[0],10,2,2)\n        kp = {'value': e_value,'jacobian': e_jacobian}\n    #    out = self.fc(out)\n\n        return kp, fake\n\n    def linear_np_4(self, x, value, jacobian): #torch.Size([4, 3, H, W])\n        if self.scale_factor != 1:\n            x = self.down(x) # 0.25 [4, 3, H/4, W/4]\n\n        feature_map = self.predictor(x) #[4,3+32,H/4, W/4]\n        f = self.conv1(feature_map) #[16,64,64,64]\n        f = self.bn1(f) #torch.Size([16, 64, 64, 64])\n        f = self.relu(f)\n        f = self.maxpool(f) #[16, 64, 32, 32]\n\n        f = self.layer1(f) #[16, 64, 32, 32]\n        f = self.layer2(f) #[16, 128, 16, 16])\n        f = self.layer3(f) #[16, 256, 8, 8]\n        f = self.layer4(f) #[16, 512, 4, 4]\n        f = self.avgpool(f) #[16, 512, 1, 1]\n        out = f.squeeze(3).squeeze(2)\n        fake = self.classify(out)\n        jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4)\n        neu_input = torch.cat((value,jacobian),2)\n\n        posi_input =neu_input.reshape(neu_input.shape[0],-1)\n        ner_feature = self.fc_n(posi_input)\n        all_fc = torch.unsqueeze(self.fc_all(torch.cat((out,ner_feature),1)),1)\n        result = self.final(all_fc)\n        e_value = result[:,:,:2]\n        e_jacobian = result[:,:,2:].reshape(result.shape[0],4,2,2)\n        kp = {'value': e_value,'jacobian': e_jacobian}\n    #    out = self.fc(out)\n\n        return kp, fake\n\n\n    def emotion_feature(self, feature, value, jacobian): #torch.Size([4, 3, H, W])\n\n        out = feature\n        fake = self.classify(out)\n        jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4)\n        neu_input = torch.cat((value,jacobian),2)\n        posi_input = self.embed_fn(neu_input)\n        posi_input =posi_input.reshape(posi_input.shape[0],-1)\n        ner_feature = self.fc_p(posi_input)\n        all_fc = torch.unsqueeze(self.fc_all(torch.cat((out,ner_feature),1)),1)\n        result = self.final(all_fc)\n        e_value = result[:,:,:2]\n        e_jacobian = result[:,:,2:].reshape(result.shape[0],4,2,2)\n        kp = {'value': e_value,'jacobian': e_jacobian}\n    #    out = self.fc(out)\n\n        return kp, fake\n\n    def feature(self, x): #torch.Size([4, 3, H, W])\n        if self.scale_factor != 1:\n            x = self.down(x) # 0.25 [4, 3, H/4, W/4]\n\n        feature_map = self.predictor(x) #[4,3+32,H/4, W/4]\n        f = self.conv1(feature_map) #[16,64,64,64]\n        f = self.bn1(f) #torch.Size([16, 64, 64, 64])\n        f = self.relu(f)\n        f = self.maxpool(f) #[16, 64, 32, 32]\n\n        f = self.layer1(f) #[16, 64, 32, 32]\n        f = self.layer2(f) #[16, 128, 16, 16])\n        f = self.layer3(f) #[16, 256, 8, 8]\n        f = self.layer4(f) #[16, 512, 4, 4]\n        f = self.avgpool(f) #[16, 512, 1, 1]\n        out = f.squeeze(3).squeeze(2)\n\n    #    out = self.fc(out)\n\n        return out\n\n    def forward(self, x, value, jacobian): #torch.Size([4, 3, H, W])\n        if self.scale_factor != 1:\n            x = self.down(x) # 0.25 [4, 3, H/4, W/4]\n\n        feature_map = self.predictor(x) #[4,3+32,H/4, W/4]\n        f = self.conv1(feature_map) #[16,64,64,64]\n        f = self.bn1(f) #torch.Size([16, 64, 64, 
64])\n        f = self.relu(f)\n        f = self.maxpool(f) #[16, 64, 32, 32]\n\n        f = self.layer1(f) #[16, 64, 32, 32]\n        f = self.layer2(f) #[16, 128, 16, 16])\n        f = self.layer3(f) #[16, 256, 8, 8]\n        f = self.layer4(f) #[16, 512, 4, 4]\n        f = self.avgpool(f) #[16, 512, 1, 1]\n        out = f.squeeze(3).squeeze(2)\n        fake = self.classify(out)\n        jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4)\n        neu_input = torch.cat((value,jacobian),2)\n        posi_input = self.embed_fn(neu_input)\n        posi_input =posi_input.reshape(posi_input.shape[0],-1)\n        ner_feature = self.fc_p(posi_input)\n        all_fc = torch.unsqueeze(self.fc_all(torch.cat((out,ner_feature),1)),1)\n        result = self.final(all_fc)\n        e_value = result[:,:,:2]\n        e_jacobian = result[:,:,2:].reshape(result.shape[0],4,2,2)\n        kp = {'value': e_value,'jacobian': e_jacobian}\n    #    out = self.fc(out)\n\n        return kp, fake\n\nclass Emotion_map(nn.Module):\n    \"\"\"\n    Detecting a keypoints. Return keypoint position and jacobian near each keypoint.\n    \"\"\"\n\n    def __init__(self, block_expansion,  num_channels, max_features,\n                 num_blocks, scale_factor=1,  num_classes=8):\n        super(Emotion_map, self).__init__()\n        self.inplanes = 64\n        self.predictor = Hourglass(block_expansion, in_features=num_channels,\n                                   max_features=max_features, num_blocks=num_blocks)\n\n\n\n\n        self.scale_factor = scale_factor\n        if self.scale_factor != 1:\n            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)\n        self.conv1 = nn.Conv2d(self.predictor.out_filters, 64, kernel_size=3, stride=1, padding=1,\n                               bias=False)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.relu = nn.ReLU()\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        layers = [2,2,2,2]\n        self.layer1 = self._make_layer(BasicBlock, 64, layers[0])\n        self.layer2 = self._make_layer(BasicBlock, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(BasicBlock, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(BasicBlock, 512, layers[3], stride=2)\n        self.avgpool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes)\n\n        self.embed_fn, self.input_ch = get_embedder(10, 0)\n\n        self.fc_p = nn.Sequential(\n            nn.Linear(10 * 126,1024),\n            nn.ReLU(True),\n            nn.Linear(1024,512),\n            nn.ReLU(True),\n\n            )\n\n        self.fc_all = nn.Sequential(\n            nn.Linear(1024,2048),\n            nn.ReLU(True)\n            )\n\n        self.final = nn.Sequential(\n            nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),#8,8\n            nn.BatchNorm2d(128),\n            nn.ReLU(True),\n            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=True), #16,16\n            nn.BatchNorm2d(64),\n            nn.ReLU(True),\n            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1, bias=True),#32,32\n            nn.BatchNorm2d(64),\n            nn.ReLU(True),\n            nn.ConvTranspose2d(64, 32+3, kernel_size=4, stride=2, padding=1, bias=True),#64,64\n\n            )\n\n\n        self.classify = Classify()\n        self.kp = nn.Conv2d(in_channels=35, out_channels=10, kernel_size=(7, 7),\n                            
padding=0)\n        self.jacobian = nn.Conv2d(in_channels=35,\n                                      out_channels=4 * 10, kernel_size=(7, 7), padding=0)\n        self.jacobian.weight.data.zero_()\n        self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * 10, dtype=torch.float))\n        self.temperature = 0.1\n\n        self.kp_4 = nn.Conv2d(in_channels=35, out_channels=4, kernel_size=(7, 7),\n                            padding=0)\n        self.jacobian_4 = nn.Conv2d(in_channels=35,\n                                      out_channels=4 * 4, kernel_size=(7, 7), padding=0)\n        self.jacobian_4.weight.data.zero_()\n        self.jacobian_4.bias.data.copy_(torch.tensor([1, 0, 0, 1] * 4, dtype=torch.float))\n\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def gaussian2kp(self, heatmap):\n        \"\"\"\n        Extract the mean and from a heatmap\n        \"\"\"\n        shape = heatmap.shape\n        heatmap = heatmap.unsqueeze(-1) #[4,10,58,58,1]\n        grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) #[1,1,58,58,2]\n        value = (heatmap * grid).sum(dim=(2, 3)) #[4,10,2]\n        kp = {'value': value}\n\n        return kp\n\n    def map_4(self, x, value, jacobian): #torch.Size([4, 3, H, W])\n        if self.scale_factor != 1:\n            x = self.down(x) # 0.25 [4, 3, H/4, W/4]\n\n        feature_map = self.predictor(x) #[4,3+32,H/4, W/4]\n        f = self.conv1(feature_map) #[16,64,64,64]\n        f = self.bn1(f) #torch.Size([16, 64, 64, 64])\n        f = self.relu(f)\n        f = self.maxpool(f) #[16, 64, 32, 32]\n\n        f = self.layer1(f) #[16, 64, 32, 32]\n        f = self.layer2(f) #[16, 128, 16, 16])\n        f = self.layer3(f) #[16, 256, 8, 8]\n        f = self.layer4(f) #[16, 512, 4, 4]\n        f = self.avgpool(f) #[16, 512, 1, 1]\n        out = f.squeeze(3).squeeze(2)\n        fake = self.classify(out)\n        jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4)\n        neu_input = torch.cat((value,jacobian),2)\n        posi_input = self.embed_fn(neu_input)\n        posi_input =posi_input.reshape(posi_input.shape[0],-1)\n        ner_feature = self.fc_p(posi_input)\n        all_fc = self.fc_all(torch.cat((out,ner_feature),1)).reshape(-1,128,4,4)\n        feature_map = self.final(all_fc)\n        prediction = self.kp_4(feature_map) #[4,10,H/4-6, W/4-6]\n\n        final_shape = prediction.shape\n\n        heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]\n        heatmap = F.softmax(heatmap / self.temperature, dim=2)\n        heatmap = heatmap.view(*final_shape) #[4,10,58,58]\n\n        out = self.gaussian2kp(heatmap)\n        out['heatmap'] = heatmap\n\n        if self.jacobian is not None:\n            jacobian_map = self.jacobian_4(feature_map) ##[4,40,H/4-6, W/4-6]\n            jacobian_map = jacobian_map.reshape(final_shape[0], 4, 4, 
final_shape[2],\n                                                final_shape[3])\n            heatmap = heatmap.unsqueeze(2)\n\n            jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]\n            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)\n            jacobian = jacobian.sum(dim=-1) #[4,10,4]\n            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2]\n            out['jacobian'] = jacobian\n\n\n\n        return out, fake\n\n    def forward(self, x, value, jacobian): #torch.Size([4, 3, H, W])\n        if self.scale_factor != 1:\n            x = self.down(x) # 0.25 [4, 3, H/4, W/4]\n\n        feature_map = self.predictor(x) #[4,3+32,H/4, W/4]\n        f = self.conv1(feature_map) #[16,64,64,64]\n        f = self.bn1(f) #torch.Size([16, 64, 64, 64])\n        f = self.relu(f)\n        f = self.maxpool(f) #[16, 64, 32, 32]\n\n        f = self.layer1(f) #[16, 64, 32, 32]\n        f = self.layer2(f) #[16, 128, 16, 16])\n        f = self.layer3(f) #[16, 256, 8, 8]\n        f = self.layer4(f) #[16, 512, 4, 4]\n        f = self.avgpool(f) #[16, 512, 1, 1]\n        out = f.squeeze(3).squeeze(2)\n        fake = self.classify(out)\n        jacobian = jacobian.reshape(jacobian.shape[0],jacobian.shape[1],4)\n        neu_input = torch.cat((value,jacobian),2)\n        posi_input = self.embed_fn(neu_input)\n        posi_input =posi_input.reshape(posi_input.shape[0],-1)\n        ner_feature = self.fc_p(posi_input)\n        all_fc = self.fc_all(torch.cat((out,ner_feature),1)).reshape(-1,128,4,4)\n        feature_map = self.final(all_fc)\n\n        prediction = self.kp(feature_map) #[4,10,H/4-6, W/4-6]\n\n        final_shape = prediction.shape\n\n        heatmap = prediction.view(final_shape[0], final_shape[1], -1) #[4, 10, 58*58]\n        heatmap = F.softmax(heatmap / self.temperature, dim=2)\n        heatmap = heatmap.view(*final_shape) #[4,10,58,58]\n\n        out = self.gaussian2kp(heatmap)\n        out['heatmap'] = heatmap\n\n        if self.jacobian is not None:\n            jacobian_map = self.jacobian(feature_map) ##[4,40,H/4-6, W/4-6]\n            jacobian_map = jacobian_map.reshape(final_shape[0], 10, 4, final_shape[2],\n                                                final_shape[3])\n            heatmap = heatmap.unsqueeze(2)\n\n            jacobian = heatmap * jacobian_map #[4,10,4,H/4-6, W/4-6]\n            jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)\n            jacobian = jacobian.sum(dim=-1) #[4,10,4]\n            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) #[4,10,2,2]\n            out['jacobian'] = jacobian\n\n\n\n        return out, fake\n\n\ndef conv2d(channel_in, channel_out,\n           ksize=3, stride=1, padding=1,\n           activation=nn.ReLU,\n           normalizer=nn.BatchNorm2d):\n    layer = list()\n    bias = True if not normalizer else False\n\n    layer.append(nn.Conv2d(channel_in, channel_out,\n                     ksize, stride, padding,\n                     bias=bias))\n    _apply(layer, activation, normalizer, channel_out)\n    # init.kaiming_normal(layer[0].weight)\n\n    return nn.Sequential(*layer)\n\ndef _apply(layer, activation, normalizer, channel_out=None):\n    if normalizer:\n        layer.append(normalizer(channel_out))\n    if activation:\n        layer.append(activation())\n    return layer"
  },
  {
    "path": "ops.py",
    "content": "import torch\nimport torchvision\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom torch.autograd import Variable\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, channel_in, channel_out):\n        super(ResidualBlock, self).__init__()\n\n        self.block = nn.Sequential(\n            conv3d(channel_in, channel_out, 3, 1, 1),\n            conv3d(channel_out, channel_out, 3, 1, 1, activation=None)\n        )\n\n        self.lrelu = nn.ReLU(0.2)\n\n    def forward(self, x):\n        residual = x\n        out = self.block(x)\n       \n        out += residual\n        out = self.lrelu(out)\n        return out\n\ndef linear(channel_in, channel_out,\n           activation=nn.ReLU,\n           normalizer=nn.BatchNorm1d):\n    layer = list()\n    bias = True if not normalizer else False\n\n    layer.append(nn.Linear(channel_in, channel_out, bias=bias))\n    _apply(layer, activation, normalizer, channel_out)\n    # init.kaiming_normal(layer[0].weight)\n\n    return nn.Sequential(*layer)\n\n\ndef conv2d(channel_in, channel_out,\n           ksize=3, stride=1, padding=1,\n           activation=nn.ReLU,\n           normalizer=nn.BatchNorm2d):\n    layer = list()\n    bias = True if not normalizer else False\n\n    layer.append(nn.Conv2d(channel_in, channel_out,\n                     ksize, stride, padding,\n                     bias=bias))\n    _apply(layer, activation, normalizer, channel_out)\n    # init.kaiming_normal(layer[0].weight)\n\n    return nn.Sequential(*layer)\n\n\ndef conv_transpose2d(channel_in, channel_out,\n                     ksize=4, stride=2, padding=1,\n                     activation=nn.ReLU,\n                     normalizer=nn.BatchNorm2d):\n    layer = list()\n    bias = True if not normalizer else False\n\n    layer.append(nn.ConvTranspose2d(channel_in, channel_out,\n                              ksize, stride, padding,\n                              bias=bias))\n    _apply(layer, activation, normalizer, channel_out)\n    # init.kaiming_normal(layer[0].weight)\n\n    return nn.Sequential(*layer)\n\n\ndef nn_conv2d(channel_in, channel_out,\n              ksize=3, stride=1, padding=1,\n              scale_factor=2,\n              activation=nn.ReLU,\n              normalizer=nn.BatchNorm2d):\n    layer = list()\n    bias = True if not normalizer else False\n\n    layer.append(nn.UpsamplingNearest2d(scale_factor=scale_factor))\n    layer.append(nn.Conv2d(channel_in, channel_out,\n                           ksize, stride, padding,\n                           bias=bias))\n    _apply(layer, activation, normalizer, channel_out)\n    # init.kaiming_normal(layer[1].weight)\n\n    return nn.Sequential(*layer)\n\n\ndef _apply(layer, activation, normalizer, channel_out=None):\n    if normalizer:\n        layer.append(normalizer(channel_out))\n    if activation:\n        layer.append(activation())\n    return layer\n\n"
  },
  {
    "path": "process_data.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\nCreated on Thu Jun 24 11:36:01 2021\n\n@author: Xinya\n\"\"\"\n\nimport os\nimport glob\nimport time\nimport numpy as np\nimport csv\nimport cv2\nimport dlib\n\nfrom skimage import transform as tf\n\nimport librosa\nimport python_speech_features\n\ndetector = dlib.get_frontal_face_detector()\npredictor = dlib.shape_predictor('./shape_predictor_68_face_landmarks.dat')\n\n\nimport imageio\n\n\n\ndef save(path, frames, format):\n    if format == '.mp4':\n        imageio.mimsave(path, frames)\n    elif format == '.png':\n        if not os.path.exists(path):\n\n\n            os.makedirs(path)\n        for j, frame in enumerate(frames):\n            cv2.imwrite(path+'/'+str(j)+'.png',frame)\n    #        imageio.imsave(os.path.join(path, str(j) + '.png'), frames[j])\n    else:\n        print (\"Unknown format %s\" % format)\n        exit()\n\ndef crop_image(image_path, out_path):\n    template = np.load('./M003_template.npy')\n    image = cv2.imread(image_path)\n    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n    rects = detector(gray, 1)  #detect human face\n    if len(rects) != 1:\n        return 0\n    for (j, rect) in enumerate(rects):\n        shape = predictor(gray, rect) #detect 68 points\n        shape = shape_to_np(shape)\n\n    pts2 = np.float32(template[:47,:])\n    # pts2 = np.float32(template[17:35,:])\n    # pts1 = np.vstack((landmark[27:36,:], landmark[39,:],landmark[42,:],landmark[45,:]))\n    pts1 = np.float32(shape[:47,:]) #eye and nose\n    # pts1 = np.float32(landmark[17:35,:])\n    tform = tf.SimilarityTransform()\n    tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters.\n    \n    dst = tf.warp(image, tform, output_shape=(256, 256))\n\n    dst = np.array(dst * 255, dtype=np.uint8)\n    \n    \n    cv2.imwrite(out_path,dst)\n\ndef shape_to_np(shape, dtype=\"int\"):\n    # initialize the list of (x, y)-coordinates\n    coords = np.zeros((shape.num_parts, 2), dtype=dtype)\n\n    # loop over all facial landmarks and convert them\n    # to a 2-tuple of (x, y)-coordinates\n    for i in range(0, shape.num_parts):\n        coords[i] = (shape.part(i).x, shape.part(i).y)\n\n    # return the list of (x, y)-coordinates\n    return coords\n\n\n\n\ndef crop_image_tem(video_path, out_path):\n    image_all = []\n    videoCapture = cv2.VideoCapture(video_path)\n    success, frame = videoCapture.read()\n    n = 0\n    while success :\n        image_all.append(frame)\n        n = n + 1\n        success, frame = videoCapture.read()\n        \n    if len(image_all)!=0 :\n        template = np.load('./M003_template.npy')\n        image=image_all[0]\n        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n        rects = detector(gray, 1)  #detect human face\n        if len(rects) != 1:\n            return 0\n        for (j, rect) in enumerate(rects):\n            shape = predictor(gray, rect) #detect 68 points\n            shape = shape_to_np(shape)\n\n        pts2 = np.float32(template[:47,:])\n        # pts2 = np.float32(template[17:35,:])\n        # pts1 = np.vstack((landmark[27:36,:], landmark[39,:],landmark[42,:],landmark[45,:]))\n        pts1 = np.float32(shape[:47,:]) #eye and nose\n        # pts1 = np.float32(landmark[17:35,:])\n        tform = tf.SimilarityTransform()\n        tform.estimate( pts2, pts1) #Set the transformation matrix with the explicit parameters.\n        out = []\n        for i in range(len(image_all)):\n            image = image_all[i]\n            dst = tf.warp(image, tform, 
output_shape=(256, 256))\n\n            dst = np.array(dst * 255, dtype=np.uint8)\n            out.append(dst)\n        if not os.path.exists(out_path):\n            os.makedirs(out_path)\n        save(out_path,out,'.png')\n\ndef proc_audio(src_mouth_path, dst_audio_path):\n    audio_command = 'ffmpeg -i \\\"{}\\\" -loglevel error -y -f wav -acodec pcm_s16le ' \\\n                    '-ar 16000 \\\"{}\\\"'.format(src_mouth_path, dst_audio_path)\n    os.system(audio_command)\n\n\ndef audio2mfcc(audio_file, save, name):\n    speech, sr = librosa.load(audio_file, sr=16000)\n  #  mfcc = python_speech_features.mfcc(speech ,16000,winstep=0.01)\n    speech = np.insert(speech, 0, np.zeros(1920))\n    speech = np.append(speech, np.zeros(1920))\n    mfcc = python_speech_features.mfcc(speech,16000,winstep=0.01)\n    if not os.path.exists(save):\n        os.makedirs(save)\n    time_len = mfcc.shape[0]\n    mfcc_all = []\n    for input_idx in range(int((time_len-28)/4)+1):\n         #   target_idx = input_idx + sample_delay #14\n\n        input_feat = mfcc[4*input_idx:4*input_idx+28,:]\n        mfcc_all.append(input_feat)\n    np.save(os.path.join(save,name+'.npy'), mfcc_all)\n\n    print(input_idx)\n\nif __name__ == \"__main__\":\n    #video alignment\n    video_path = './test/crop/M030_sad_3_001.mp4'\n    out_path = './test/crop/M030_sad_3_001'\n    crop_image_tem(video_path, out_path)\n    \n    #image alignment\n    image_path = './test/raw_image/brade2.jpg'\n    out_path = './test/image/brade2.jpg'\n    crop_image(image_path, out_path)\n    \n    #change_audio_sample_rate\n    src_mouth_path = './test/audio/00015.mp3'\n    dst_audio_path = './test/audio/00015.mov'\n    proc_audio(src_mouth_path, dst_audio_path)\n\n    #audio2mfcc\n    #mead\n    path = './dataset/MEAD/audio/'\n    pathDir = os.listdir(path)\n    for i in range(len(pathDir)):#len(pathDir)\n        name = pathDir[i]\n        filepath = os.path.join(path,name)\n        if os.path.exists(filepath):\n            Dir = os.listdir(filepath)\n            save_path = './dataset/MEAD/MEAD_MFCC/'+name\n            os.makedirs(save_path,exist_ok=True)\n            for j in range(len(Dir)):\n\n                index = Dir[j].split('.')[0]\n                audio_path = os.path.join(filepath,Dir[j])\n                audio2mfcc(audio_path, save_path,index)\n                print(i,name,j,index)\n        else:\n            print('not exist ',filepath)\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch==1.10.1\ntorchvision==0.11.2\nnumpy\nlibrosa\nopencv-python\npython_speech_features\npickle-mixin\nmatplotlib\nscikit-image\nPillow\ntqdm\ndlib\nscipy\npyyaml\nimageio\npandas\n"
  },
  {
    "path": "run.py",
    "content": "import matplotlib\n\nmatplotlib.use('Agg')\n\nimport os, sys\nimport yaml\nfrom argparse import ArgumentParser\nfrom time import gmtime, strftime\nfrom shutil import copy\n\nfrom frames_dataset import MeadDataset, AudioDataset, VoxDataset\n\nfrom modules.generator import OcclusionAwareGenerator\nfrom modules.discriminator import MultiScaleDiscriminator\nfrom modules.keypoint_detector import KPDetector, Audio_Feature, KPDetector_a\nfrom modules.util import AT_net,Emotion_k,get_logger\nimport torch\n\nfrom train import train_part1, train_part1_fine_tune, train_part2\nfrom reconstruction import reconstruction\nfrom animate import animate\n\nif __name__ == \"__main__\":\n\n    if sys.version_info[0] < 3:\n        raise Exception(\"You must use Python 3 or higher. Recommended version is Python 3.7\")\n\n    parser = ArgumentParser()\n    parser.add_argument(\"--config\", default=\"config/train_part1.yaml\", help=\"path to config\")# required=True\n    parser.add_argument(\"--mode\", default=\"train_part1\", choices=[\"train_part1\", \"train_part1_fine_tune\", \"train_part2\"])\n    parser.add_argument(\"--log_dir\", default='log', help=\"path to log into\")\n    parser.add_argument(\"--checkpoint\", default='124_52000.pth.tar', help=\"path to checkpoint to restore\")\n    parser.add_argument(\"--audio_checkpoint\", default=None, help=\"path to audio_checkpoint to restore\")\n    parser.add_argument(\"--emo_checkpoint\", default=None, help=\"path to audio_checkpoint to restore\")\n    parser.add_argument(\"--device_ids\", default=\"0\", type=lambda x: list(map(int, x.split(','))),\n                        help=\"Names of the devices comma separated.\")\n    parser.add_argument(\"--verbose\", dest=\"verbose\", action=\"store_true\", help=\"Print model architecture\")\n    parser.set_defaults(verbose=False)\n\n    opt = parser.parse_args()\n    with open(opt.config) as f:\n        config = yaml.load(f)  \n    \n    name = os.path.basename(opt.config).split('.')[0]\n    if opt.checkpoint is not None:\n   \n        log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])\n        log_dir += ' ' + strftime(\"%d_%m_%y_%H.%M.%S\", gmtime())\n    else:\n        log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])\n        log_dir += ' ' + strftime(\"%d_%m_%y_%H.%M.%S\", gmtime())\n        \n    if not os.path.exists(log_dir):\n        os.makedirs(log_dir)\n    if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):\n        copy(opt.config, log_dir)\n        \n #   logger = get_logger(os.path.join(log_dir, \"log.txt\"))  \n    \n    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],\n                                        **config['model_params']['common_params'])\n\n    if torch.cuda.is_available():\n        generator.to(opt.device_ids[0])\n\n    if opt.verbose:\n        print(generator)\n\n    discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],\n                                            **config['model_params']['common_params'])\n    if torch.cuda.is_available():\n        discriminator.to(opt.device_ids[0])\n        \n    \n    \n    if opt.verbose:\n        print(discriminator)\n   \n    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],\n                             **config['model_params']['common_params'])\n\n    kp_detector_a = KPDetector_a(**config['model_params']['kp_detector_params'],\n                             
**config['model_params']['audio_params'])\n\n    if torch.cuda.is_available():\n        kp_detector.to(opt.device_ids[0])\n        kp_detector_a.to(opt.device_ids[0])\n\n    audio_feature = AT_net()\n    emo_feature = Emotion_k(block_expansion=32, num_channels=3, max_features=1024,\n                 num_blocks=5, scale_factor=0.25, num_classes=8)\n\n    if torch.cuda.is_available():\n        audio_feature.to(opt.device_ids[0])\n        emo_feature.to(opt.device_ids[0])\n\n    if opt.verbose:\n        print(kp_detector)\n        print(kp_detector_a)\n        print(audio_feature)\n        print(emo_feature)\n\n#    logger.info(\"Successfully load models.\")\n\n    if config['dataset_params']['name'] == 'Vox':\n        dataset = VoxDataset(is_train=True, **config['dataset_params'])\n        test_dataset = VoxDataset(is_train=False, **config['dataset_params'])\n    elif config['dataset_params']['name'] == 'Lrw':\n        dataset = AudioDataset(is_train=True, **config['dataset_params'])\n        test_dataset = AudioDataset(is_train=False, **config['dataset_params'])\n    elif config['dataset_params']['name'] == 'MEAD':\n        dataset = MeadDataset(is_train=True, **config['dataset_params'])\n        test_dataset = MeadDataset(is_train=False, **config['dataset_params'])\n\n    if opt.mode == 'train_part1':\n        print(\"Training part1...\")\n        train_part1(config, generator, discriminator, kp_detector, kp_detector_a,audio_feature, opt.checkpoint, opt.audio_checkpoint, log_dir, dataset, test_dataset,opt.device_ids, name)\n    elif opt.mode == 'train_part1_fine_tune':\n        print(\"Finetune part1...\")\n        train_part1_fine_tune(config, generator, discriminator, kp_detector, kp_detector_a,audio_feature, opt.checkpoint, opt.audio_checkpoint, log_dir, dataset, test_dataset,opt.device_ids, name)\n    elif opt.mode == 'train_part2':\n        print(\"Training part2...\")\n        train_part2(config, generator, discriminator, kp_detector, emo_feature,kp_detector_a,audio_feature, opt.checkpoint, opt.audio_checkpoint, opt.emo_checkpoint, log_dir, dataset,test_dataset,opt.device_ids, name)\n"
  },
  {
    "path": "sync_batchnorm/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : __init__.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nfrom .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d\nfrom .replicate import DataParallelWithCallback, patch_replication_callback\n"
  },
  {
    "path": "sync_batchnorm/batchnorm.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : batchnorm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport collections\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch.nn.modules.batchnorm import _BatchNorm\nfrom torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast\n\nfrom .comm import SyncMaster\n\n__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']\n\n\ndef _sum_ft(tensor):\n    \"\"\"sum over the first and last dimention\"\"\"\n    return tensor.sum(dim=0).sum(dim=-1)\n\n\ndef _unsqueeze_ft(tensor):\n    \"\"\"add new dementions at the front and the tail\"\"\"\n    return tensor.unsqueeze(0).unsqueeze(-1)\n\n\n_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])\n_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])\n\n\nclass _SynchronizedBatchNorm(_BatchNorm):\n    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):\n        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)\n\n        self._sync_master = SyncMaster(self._data_parallel_master)\n\n        self._is_parallel = False\n        self._parallel_id = None\n        self._slave_pipe = None\n\n    def forward(self, input):\n        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.\n        if not (self._is_parallel and self.training):\n            return F.batch_norm(\n                input, self.running_mean, self.running_var, self.weight, self.bias,\n                self.training, self.momentum, self.eps)\n\n        # Resize the input to (B, C, -1).\n        input_shape = input.size()\n        input = input.view(input.size(0), self.num_features, -1)\n\n        # Compute the sum and square-sum.\n        sum_size = input.size(0) * input.size(2)\n        input_sum = _sum_ft(input)\n        input_ssum = _sum_ft(input ** 2)\n\n        # Reduce-and-broadcast the statistics.\n        if self._parallel_id == 0:\n            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))\n        else:\n            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))\n\n        # Compute the output.\n        if self.affine:\n            # MJY:: Fuse the multiplication for speed.\n            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)\n        else:\n            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)\n\n        # Reshape it.\n        return output.view(input_shape)\n\n    def __data_parallel_replicate__(self, ctx, copy_id):\n        self._is_parallel = True\n        self._parallel_id = copy_id\n\n        # parallel_id == 0 means master device.\n        if self._parallel_id == 0:\n            ctx.sync_master = self._sync_master\n        else:\n            self._slave_pipe = ctx.sync_master.register_slave(copy_id)\n\n    def _data_parallel_master(self, intermediates):\n        \"\"\"Reduce the sum and square-sum, compute the statistics, and broadcast it.\"\"\"\n\n        # Always using same \"device order\" makes the ReduceAdd operation faster.\n        # Thanks to:: Tete Xiao (http://tetexiao.com/)\n        intermediates = sorted(intermediates, key=lambda i: 
i[1].sum.get_device())\n\n        to_reduce = [i[1][:2] for i in intermediates]\n        to_reduce = [j for i in to_reduce for j in i]  # flatten\n        target_gpus = [i[1].sum.get_device() for i in intermediates]\n\n        sum_size = sum([i[1].sum_size for i in intermediates])\n        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)\n        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)\n\n        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)\n\n        outputs = []\n        for i, rec in enumerate(intermediates):\n            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))\n\n        return outputs\n\n    def _compute_mean_std(self, sum_, ssum, size):\n        \"\"\"Compute the mean and standard-deviation with sum and square-sum. This method\n        also maintains the moving average on the master device.\"\"\"\n        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'\n        mean = sum_ / size\n        sumvar = ssum - sum_ * mean\n        unbias_var = sumvar / (size - 1)\n        bias_var = sumvar / size\n\n        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data\n        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data\n\n        return mean, bias_var.clamp(self.eps) ** -0.5\n\n\nclass SynchronizedBatchNorm1d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a\n    mini-batch.\n\n    .. math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm1d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n    \n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of size\n            `batch_size x num_features [x width]`\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. 
Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C)` or :math:`(N, C, L)`\n        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm1d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm1d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 2 and input.dim() != 3:\n            raise ValueError('expected 2D or 3D input (got {}D input)'\n                             .format(input.dim()))\n        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)\n\n\nclass SynchronizedBatchNorm2d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Batch Normalization over a 4d input that is seen as a mini-batch\n    of 3d inputs\n\n    .. math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm2d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n    \n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of\n            size batch_size x num_features x height x width\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. 
Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, H, W)`\n        - Output: :math:`(N, C, H, W)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm2d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm2d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 4:\n            raise ValueError('expected 4D input (got {}D input)'\n                             .format(input.dim()))\n        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)\n\n\nclass SynchronizedBatchNorm3d(_SynchronizedBatchNorm):\n    r\"\"\"Applies Batch Normalization over a 5d input that is seen as a mini-batch\n    of 4d inputs\n\n    .. math::\n\n        y = \\frac{x - mean[x]}{ \\sqrt{Var[x] + \\epsilon}} * gamma + beta\n\n    This module differs from the built-in PyTorch BatchNorm3d as the mean and\n    standard-deviation are reduced across all devices during training.\n\n    For example, when one uses `nn.DataParallel` to wrap the network during\n    training, PyTorch's implementation normalize the tensor on each device using\n    the statistics only on that device, which accelerated the computation and\n    is also easy to implement, but the statistics might be inaccurate.\n    Instead, in this synchronized version, the statistics will be computed\n    over all training samples distributed on multiple devices.\n    \n    Note that, for one-GPU or CPU-only case, this module behaves exactly same\n    as the built-in PyTorch implementation.\n\n    The mean and standard-deviation are calculated per-dimension over\n    the mini-batches and gamma and beta are learnable parameter vectors\n    of size C (where C is the input size).\n\n    During training, this layer keeps a running estimate of its computed mean\n    and variance. The running sum is kept with a default momentum of 0.1.\n\n    During evaluation, this running mean/variance is used for normalization.\n\n    Because the BatchNorm is done over the `C` dimension, computing statistics\n    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm\n    or Spatio-temporal BatchNorm\n\n    Args:\n        num_features: num_features from an expected input of\n            size batch_size x num_features x depth x height x width\n        eps: a value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum: the value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine: a boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. Default: ``True``\n\n    Shape:\n        - Input: :math:`(N, C, D, H, W)`\n        - Output: :math:`(N, C, D, H, W)` (same shape as input)\n\n    Examples:\n        >>> # With Learnable Parameters\n        >>> m = SynchronizedBatchNorm3d(100)\n        >>> # Without Learnable Parameters\n        >>> m = SynchronizedBatchNorm3d(100, affine=False)\n        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))\n        >>> output = m(input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 5:\n            raise ValueError('expected 5D input (got {}D input)'\n                             .format(input.dim()))\n        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)\n"
  },
  {
    "path": "sync_batchnorm/comm.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : comm.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport queue\nimport collections\nimport threading\n\n__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']\n\n\nclass FutureResult(object):\n    \"\"\"A thread-safe future implementation. Used only as one-to-one pipe.\"\"\"\n\n    def __init__(self):\n        self._result = None\n        self._lock = threading.Lock()\n        self._cond = threading.Condition(self._lock)\n\n    def put(self, result):\n        with self._lock:\n            assert self._result is None, 'Previous result has\\'t been fetched.'\n            self._result = result\n            self._cond.notify()\n\n    def get(self):\n        with self._lock:\n            if self._result is None:\n                self._cond.wait()\n\n            res = self._result\n            self._result = None\n            return res\n\n\n_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])\n_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])\n\n\nclass SlavePipe(_SlavePipeBase):\n    \"\"\"Pipe for master-slave communication.\"\"\"\n\n    def run_slave(self, msg):\n        self.queue.put((self.identifier, msg))\n        ret = self.result.get()\n        self.queue.put(True)\n        return ret\n\n\nclass SyncMaster(object):\n    \"\"\"An abstract `SyncMaster` object.\n\n    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should\n    call `register(id)` and obtain an `SlavePipe` to communicate with the master.\n    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,\n    and passed to a registered callback.\n    - After receiving the messages, the master device should gather the information and determine to message passed\n    back to each slave devices.\n    \"\"\"\n\n    def __init__(self, master_callback):\n        \"\"\"\n\n        Args:\n            master_callback: a callback to be invoked after having collected messages from slave devices.\n        \"\"\"\n        self._master_callback = master_callback\n        self._queue = queue.Queue()\n        self._registry = collections.OrderedDict()\n        self._activated = False\n\n    def __getstate__(self):\n        return {'master_callback': self._master_callback}\n\n    def __setstate__(self, state):\n        self.__init__(state['master_callback'])\n\n    def register_slave(self, identifier):\n        \"\"\"\n        Register an slave device.\n\n        Args:\n            identifier: an identifier, usually is the device id.\n\n        Returns: a `SlavePipe` object which can be used to communicate with the master device.\n\n        \"\"\"\n        if self._activated:\n            assert self._queue.empty(), 'Queue is not clean before next initialization.'\n            self._activated = False\n            self._registry.clear()\n        future = FutureResult()\n        self._registry[identifier] = _MasterRegistry(future)\n        return SlavePipe(identifier, self._queue, future)\n\n    def run_master(self, master_msg):\n        \"\"\"\n        Main entry for the master device in each forward pass.\n        The messages were first collected from each devices (including the master device), and then\n        an callback will 
be invoked to compute the message to be sent back to each devices\n        (including the master device).\n\n        Args:\n            master_msg: the message that the master want to send to itself. This will be placed as the first\n            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.\n\n        Returns: the message to be sent back to the master device.\n\n        \"\"\"\n        self._activated = True\n\n        intermediates = [(0, master_msg)]\n        for i in range(self.nr_slaves):\n            intermediates.append(self._queue.get())\n\n        results = self._master_callback(intermediates)\n        assert results[0][0] == 0, 'The first result should belongs to the master.'\n\n        for i, res in results:\n            if i == 0:\n                continue\n            self._registry[i].result.put(res)\n\n        for i in range(self.nr_slaves):\n            assert self._queue.get() is True\n\n        return results[0][1]\n\n    @property\n    def nr_slaves(self):\n        return len(self._registry)\n"
  },
  {
    "path": "sync_batchnorm/replicate.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : replicate.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport functools\n\nfrom torch.nn.parallel.data_parallel import DataParallel\n\n__all__ = [\n    'CallbackContext',\n    'execute_replication_callbacks',\n    'DataParallelWithCallback',\n    'patch_replication_callback'\n]\n\n\nclass CallbackContext(object):\n    pass\n\n\ndef execute_replication_callbacks(modules):\n    \"\"\"\n    Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.\n\n    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n\n    Note that, as all modules are isomorphism, we assign each sub-module with a context\n    (shared among multiple copies of this module on different devices).\n    Through this context, different copies can share some information.\n\n    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback\n    of any slave copies.\n    \"\"\"\n    master_copy = modules[0]\n    nr_modules = len(list(master_copy.modules()))\n    ctxs = [CallbackContext() for _ in range(nr_modules)]\n\n    for i, module in enumerate(modules):\n        for j, m in enumerate(module.modules()):\n            if hasattr(m, '__data_parallel_replicate__'):\n                m.__data_parallel_replicate__(ctxs[j], i)\n\n\nclass DataParallelWithCallback(DataParallel):\n    \"\"\"\n    Data Parallel with a replication callback.\n\n    An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by\n    original `replicate` function.\n    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`\n\n    Examples:\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n        # sync_bn.__data_parallel_replicate__ will be invoked.\n    \"\"\"\n\n    def replicate(self, module, device_ids):\n        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n\ndef patch_replication_callback(data_parallel):\n    \"\"\"\n    Monkey-patch an existing `DataParallel` object. Add the replication callback.\n    Useful when you have customized `DataParallel` implementation.\n\n    Examples:\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])\n        > patch_replication_callback(sync_bn)\n        # this is equivalent to\n        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)\n        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])\n    \"\"\"\n\n    assert isinstance(data_parallel, DataParallel)\n\n    old_replicate = data_parallel.replicate\n\n    @functools.wraps(old_replicate)\n    def new_replicate(module, device_ids):\n        modules = old_replicate(module, device_ids)\n        execute_replication_callbacks(modules)\n        return modules\n\n    data_parallel.replicate = new_replicate\n"
  },
  {
    "path": "sync_batchnorm/unittest.py",
    "content": "# -*- coding: utf-8 -*-\n# File   : unittest.py\n# Author : Jiayuan Mao\n# Email  : maojiayuan@gmail.com\n# Date   : 27/01/2018\n# \n# This file is part of Synchronized-BatchNorm-PyTorch.\n# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch\n# Distributed under MIT License.\n\nimport unittest\n\nimport numpy as np\nfrom torch.autograd import Variable\n\n\ndef as_numpy(v):\n    if isinstance(v, Variable):\n        v = v.data\n    return v.cpu().numpy()\n\n\nclass TorchTestCase(unittest.TestCase):\n    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):\n        npa, npb = as_numpy(a), as_numpy(b)\n        self.assertTrue(\n                np.allclose(npa, npb, atol=atol),\n                'Tensor close check failed\\n{}\\n{}\\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())\n        )\n"
  },
  {
    "path": "train.py",
    "content": "from tqdm import trange\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader\n\nfrom logger import Logger\nfrom modules.model import DiscriminatorFullModel, TrainPart1Model, TrainPart2Model\nimport itertools\n\nfrom torch.optim.lr_scheduler import MultiStepLR\n\nfrom sync_batchnorm import DataParallelWithCallback\n\nfrom frames_dataset import DatasetRepeater,TestsetRepeater\nimport time\nfrom tensorboardX import SummaryWriter\n\ndef train_part1(config, generator, discriminator, kp_detector, kp_detector_a,audio_feature, checkpoint, audio_checkpoint, log_dir, dataset, test_dataset, device_ids, name):\n    train_params = config['train_params']\n\n    optimizer_audio_feature = torch.optim.Adam(itertools.chain(audio_feature.parameters(),kp_detector_a.parameters()), lr=train_params['lr_audio_feature'], betas=(0.5, 0.999))\n\n\n    if checkpoint is not None:\n        start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector, audio_feature,\n                                      optimizer_generator, optimizer_discriminator,\n                                      None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector,\n                                      None if train_params['lr_audio_feature'] == 0 else optimizer_audio_feature)\n    if audio_checkpoint is not None:\n        pretrain = torch.load(audio_checkpoint)\n        kp_detector_a.load_state_dict(pretrain['kp_detector_a'])\n        audio_feature.load_state_dict(pretrain['audio_feature'])\n        optimizer_audio_feature.load_state_dict(pretrain['optimizer_audio_feature'])\n        start_epoch = pretrain['epoch']\n\n    else:\n        start_epoch = 0\n\n \n    scheduler_audio_feature = MultiStepLR(optimizer_audio_feature, train_params['epoch_milestones'], gamma=0.1,\n                                        last_epoch=-1 + start_epoch * (train_params['lr_audio_feature'] != 0))\n\n    if 'num_repeats' in train_params or train_params['num_repeats'] != 1:\n        dataset = DatasetRepeater(dataset, train_params['num_repeats'])\n        test_dataset = TestsetRepeater(test_dataset, train_params['num_repeats'])\n    dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=0, drop_last=True)#6\n    test_dataloader = DataLoader(test_dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=0, drop_last=True)#6\n    num_steps_per_epoch = len(dataloader)\n    num_steps_test_epoch = len(test_dataloader)\n    generator_full = TrainPart1Model(kp_detector, kp_detector_a, audio_feature, generator, discriminator, train_params,device_ids)\n    discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)\n   \n    if len(device_ids)>1:\n        generator_full=torch.nn.DataParallel(generator_full)\n        discriminator_full=torch.nn.DataParallel(discriminator_full)\n\n    if torch.cuda.is_available():\n        if len(device_ids) == 1:\n            generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)\n            discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)\n        elif len(device_ids)>1:\n            generator_full = generator_full.to(device_ids[0])\n            discriminator_full = discriminator_full.to(device_ids[0])\n\n    step = 0\n    t0 = time.time()\n\n    writer=SummaryWriter(comment=name)\n    train_itr=0\n    test_itr=0\n    with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], 
checkpoint_freq=train_params['checkpoint_freq']) as logger:\n        for epoch in trange(start_epoch, train_params['num_epochs']):\n\n            for x in dataloader:\n\n                losses_generator, generated = generator_full(x)\n        \n                loss_values = [val.mean() for val in losses_generator.values()]\n                loss = sum(loss_values)\n\n                writer.add_scalar('Train',loss,train_itr)\n\n                writer.add_scalar('Train_value',loss_values[0],train_itr)\n                writer.add_scalar('Train_heatmap',loss_values[1],train_itr)\n                writer.add_scalar('Train_jacobian',loss_values[2],train_itr)\n\n                train_itr+=1\n                loss.backward()\n\n              \n                optimizer_audio_feature.step()\n                optimizer_audio_feature.zero_grad()\n                d = time.time()\n         \n                # if train_params['loss_weights']['generator_gan'] != 0:\n                #     optimizer_discriminator.zero_grad()\n                # else:\n                #     losses_discriminator = {}\n\n                # losses_generator.update(losses_discriminator)\n                losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}\n                logger.log_iter(losses=losses)\n                e = time.time()\n          \n                step += 1\n             \n                if(step % 500 == 0):\n                    \n                    logger.log_epoch(epoch,step, {'audio_feature': audio_feature,\n                                     'kp_detector_a':kp_detector_a,\n                                     'optimizer_audio_feature': optimizer_audio_feature}, inp=x, out=generated)\n\n            scheduler_audio_feature.step()\n\n\n            for x in test_dataloader:\n                with torch.no_grad():\n                    losses_generator, generated = generator_full(x)\n\n                    loss_values = [val.mean() for val in losses_generator.values()]\n                    loss = sum(loss_values)\n\n                    writer.add_scalar('Test',loss,test_itr)\n\n                    writer.add_scalar('Test_value',loss_values[0],test_itr)\n                    writer.add_scalar('Test_heatmap',loss_values[1],test_itr)\n                    writer.add_scalar('Test_jacobian',loss_values[2],test_itr)\n\n                    test_itr+=1\n\n\n              \ndef train_part1_fine_tune(config, generator, discriminator, kp_detector, kp_detector_a,audio_feature, checkpoint, audio_checkpoint, log_dir, dataset, dataset2, test_dataset, device_ids, name):\n    train_params = config['train_params']\n\n    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999))\n    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999))\n    optimizer_audio_feature = torch.optim.Adam(itertools.chain(audio_feature.parameters(),kp_detector_a.parameters()), lr=train_params['lr_audio_feature'], betas=(0.5, 0.999))\n  #  optimizer_kp_detector_a = torch.optim.Adam(kp_detector_a.parameters(), lr=train_params['lr_audio_feature'], betas=(0.5, 0.999))\n\n    if checkpoint is not None:\n        start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector, audio_feature,\n                                      optimizer_generator, optimizer_discriminator,\n                                      None if train_params['lr_kp_detector'] == 0 else 
\n                                      None,  # kp_detector stays frozen in the fine-tune stage, so there is no kp_detector optimizer to restore
\n                                      None if train_params['lr_audio_feature'] == 0 else optimizer_audio_feature)
\n    if audio_checkpoint is not None:\n        pretrain = torch.load(audio_checkpoint)\n        kp_detector_a.load_state_dict(pretrain['kp_detector_a'])\n        audio_feature.load_state_dict(pretrain['audio_feature'])\n        optimizer_audio_feature.load_state_dict(pretrain['optimizer_audio_feature'])\n        start_epoch = pretrain['epoch']\n
\n    scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1,\n                                      last_epoch=start_epoch - 1)\n    scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1,\n                                          last_epoch=start_epoch - 1)\n    scheduler_audio_feature = MultiStepLR(optimizer_audio_feature, train_params['epoch_milestones'], gamma=0.1,\n                                          last_epoch=-1 + start_epoch * (train_params['lr_audio_feature'] != 0))\n
\n    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:\n        dataset = DatasetRepeater(dataset, train_params['num_repeats'])\n        test_dataset = TestsetRepeater(test_dataset, train_params['num_repeats'])
\n    dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=0, drop_last=True)\n    test_dataloader = DataLoader(test_dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=0, drop_last=True)\n    num_steps_per_epoch = len(dataloader)\n    num_steps_test_epoch = len(test_dataloader)
\n    generator_full = TrainFullModel(kp_detector, kp_detector_a, audio_feature, generator, discriminator, train_params, device_ids)\n    discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)\n    print('End dataload ', file=open('log/MEAD_LRW_test_a.txt', 'a'))
\n    if len(device_ids) > 1:\n        generator_full = torch.nn.DataParallel(generator_full)\n        discriminator_full = torch.nn.DataParallel(discriminator_full)\n
\n    if torch.cuda.is_available():\n        if len(device_ids) == 1:\n            generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)\n            discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)\n        elif len(device_ids) > 1:\n            generator_full = generator_full.to(device_ids[0])\n            discriminator_full = discriminator_full.to(device_ids[0])\n
\n    step = 0\n    t0 = time.time()\n\n    writer = SummaryWriter(comment=name)\n    train_itr = 0\n    test_itr = 0
\n    with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:\n        for epoch in trange(start_epoch, train_params['num_epochs']):\n
\n            for x in dataloader:\n                losses_generator, generated = generator_full(x)\n\n                loss_values = [val.mean() for val in losses_generator.values()]\n                loss = sum(loss_values)\n
\n                writer.add_scalar('Train', loss, train_itr)\n                writer.add_scalar('Train_value', loss_values[0], train_itr)\n                writer.add_scalar('Train_heatmap', loss_values[1], train_itr)
\n                writer.add_scalar('Train_jacobian', loss_values[2], train_itr)\n                writer.add_scalar('Train_perceptual', loss_values[3], train_itr)\n
\n                train_itr += 1\n                loss.backward()\n\n                optimizer_audio_feature.step()\n                optimizer_audio_feature.zero_grad()\n\n                optimizer_generator.step()\n                optimizer_generator.zero_grad()\n
\n                if train_params['loss_weights']['discriminator_gan'] != 0:\n                    optimizer_discriminator.zero_grad()\n                    losses_discriminator = discriminator_full(x, generated)\n                    loss_values = [val.mean() for val in losses_discriminator.values()]\n                    loss = sum(loss_values)\n\n                    loss.backward()\n                    optimizer_discriminator.step()\n                    optimizer_discriminator.zero_grad()\n                else:\n                    losses_discriminator = {}\n
\n                losses_generator.update(losses_discriminator)\n                losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}\n                logger.log_iter(losses=losses)\n\n                step += 1
\n                if step % 500 == 0:\n                    logger.log_epoch(epoch, step, {'audio_feature': audio_feature,\n                                     'kp_detector_a': kp_detector_a,\n                                     'generator': generator,\n                                     'optimizer_generator': optimizer_generator,\n                                     'optimizer_audio_feature': optimizer_audio_feature}, inp=x, out=generated)\n
\n            scheduler_generator.step()\n            scheduler_discriminator.step()\n            scheduler_audio_feature.step()\n
\n            for x in test_dataloader:\n                with torch.no_grad():\n                    losses_generator, generated = generator_full(x)\n\n                    loss_values = [val.mean() for val in losses_generator.values()]\n                    loss = sum(loss_values)\n
\n                    writer.add_scalar('Test', loss, test_itr)\n                    writer.add_scalar('Test_value', loss_values[0], test_itr)\n                    writer.add_scalar('Test_heatmap', loss_values[1], test_itr)\n                    writer.add_scalar('Test_jacobian', loss_values[2], test_itr)\n                    writer.add_scalar('Test_perceptual', loss_values[3], test_itr)\n\n                    test_itr += 1\n\n
\ndef train_part2(config, generator, discriminator, kp_detector, emo_detector, kp_detector_a, audio_feature, checkpoint, audio_checkpoint, emo_checkpoint, log_dir, dataset, test_dataset, device_ids, exp_name):\n    train_params = config['train_params']\n
\n    optimizer_emo_detector = torch.optim.Adam(emo_detector.parameters(), lr=train_params['lr_audio_feature'], betas=(0.5, 0.999))
\n    # The audio branch is frozen in part 2; this optimizer exists only so that its state can be restored from the
\n    # part-1 checkpoint and saved again with the part-2 checkpoints. It is never stepped in this stage.
\n    optimizer_audio_feature = torch.optim.Adam(itertools.chain(audio_feature.parameters(), kp_detector_a.parameters()), lr=train_params['lr_audio_feature'], betas=(0.5, 0.999))\n
\n    start_epoch = 0\n    if checkpoint is not None:\n        # generator, discriminator and kp_detector are not optimized in this stage.
\n        start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector, audio_feature,\n                                      None, None, None,\n                                      None if train_params['lr_audio_feature'] == 0 else optimizer_audio_feature)
\n    if emo_checkpoint is not None:
\n        # The emotion checkpoint may be either a full training checkpoint (with an 'emo_detector' entry) or a raw
\n        # state dict, possibly saved from nn.DataParallel, hence the optional 'module.' prefix stripping below.
\n        pretrain = torch.load(emo_checkpoint)\n        tgt_state = emo_detector.state_dict()\n        strip = 'module.'
\n        if 'emo_detector' in pretrain:\n            emo_detector.load_state_dict(pretrain['emo_detector'])\n            optimizer_emo_detector.load_state_dict(pretrain['optimizer_emo_detector'])\n            print('emo_detector in pretrain + load', file=open('log/' + exp_name + '.txt', 'a'))
\n        for name, param in pretrain.items():\n            if isinstance(param, nn.Parameter):\n                param = param.data\n            if strip is not None and name.startswith(strip):\n                name = name[len(strip):]\n            if name not in tgt_state:\n                continue\n            tgt_state[name].copy_(param)\n            print(name)
\n    if audio_checkpoint is not None:\n        pretrain = torch.load(audio_checkpoint)\n        kp_detector_a.load_state_dict(pretrain['kp_detector_a'])\n        audio_feature.load_state_dict(pretrain['audio_feature'])\n        optimizer_audio_feature.load_state_dict(pretrain['optimizer_audio_feature'])
\n        if 'emo_detector' in pretrain:\n            emo_detector.load_state_dict(pretrain['emo_detector'])\n            optimizer_emo_detector.load_state_dict(pretrain['optimizer_emo_detector'])\n        start_epoch = pretrain['epoch']\n
\n    scheduler_emo_detector = MultiStepLR(optimizer_emo_detector, train_params['epoch_milestones'], gamma=0.1,\n                                         last_epoch=-1 + start_epoch * (train_params['lr_audio_feature'] != 0))\n
\n    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:\n        dataset = DatasetRepeater(dataset, train_params['num_repeats'])\n        test_dataset = TestsetRepeater(test_dataset, train_params['num_repeats'])
\n    dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=0, drop_last=True)\n    test_dataloader = DataLoader(test_dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=0, drop_last=True)\n    num_steps_per_epoch = len(dataloader)\n    num_steps_test_epoch = len(test_dataloader)
\n    generator_full = TrainPart2Model(kp_detector, emo_detector, kp_detector_a, audio_feature, generator, discriminator, train_params, device_ids)\n    discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)\n
\n    if len(device_ids) > 1:\n        generator_full = torch.nn.DataParallel(generator_full)\n        discriminator_full = torch.nn.DataParallel(discriminator_full)\n
\n    if torch.cuda.is_available():\n        if len(device_ids) == 1:\n            generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)\n            discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)\n        elif len(device_ids) > 1:\n            generator_full = generator_full.to(device_ids[0])\n            discriminator_full = discriminator_full.to(device_ids[0])\n
\n    step = 0\n    t0 = time.time()\n\n    writer = SummaryWriter(comment=exp_name)\n    train_itr = 0\n    test_itr = 0
\n    with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:\n        for epoch in trange(start_epoch, train_params['num_epochs']):\n
\n            for x in dataloader:\n                losses_generator, generated = generator_full(x)\n\n                loss_values = [val.mean() for val in losses_generator.values()]\n                loss = sum(loss_values)\n
\n                # Part 2 losses: value, jacobian and emotion classification (no heatmap term).
\n                writer.add_scalar('Train', loss, train_itr)\n                writer.add_scalar('Train_value', loss_values[0], train_itr)\n                writer.add_scalar('Train_jacobian', loss_values[1], train_itr)\n                writer.add_scalar('Train_classify', loss_values[2], train_itr)\n
\n                train_itr += 1\n                loss.backward()\n
\n                # Only the emotion detector is updated in part 2.
\n                optimizer_emo_detector.step()\n                optimizer_emo_detector.zero_grad()\n
\n                losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}\n                logger.log_iter(losses=losses)\n\n                step += 1
\n                if step % 1000 == 0:\n                    logger.log_epoch(epoch, step, {'audio_feature': audio_feature,\n                                     'kp_detector_a': kp_detector_a,\n                                     'emo_detector': emo_detector,\n                                     'optimizer_emo_detector': optimizer_emo_detector,\n                                     'optimizer_audio_feature': optimizer_audio_feature}, inp=x, out=generated)\n
\n            scheduler_emo_detector.step()\n
\n            for x in test_dataloader:\n                with torch.no_grad():\n                    losses_generator, generated = generator_full(x)\n\n                    loss_values = [val.mean() for val in losses_generator.values()]\n                    loss = sum(loss_values)\n
\n                    writer.add_scalar('Test', loss, test_itr)\n                    writer.add_scalar('Test_value', loss_values[0], test_itr)\n                    writer.add_scalar('Test_jacobian', loss_values[1], test_itr)\n                    writer.add_scalar('Test_classify', loss_values[2], test_itr)\n\n                    test_itr += 1\n"
  }
]