[
  {
    "path": "License.txt",
    "content": "Nvidia Source Code License-NC\n\n1. Definitions\n\n“Licensor” means any person or entity that distributes its Work.\n\n“Software” means the original work of authorship made available under this License.\n“Work” means the Software and any additions to or derivative works of the Software that are made available under this License.\n\n“Nvidia Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by Nvidia or its affiliates.\n\nThe terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.\n\nWorks, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.\n\n2. License Grants\n\n2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.\n\n3. Limitations\n\n3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.\n\n3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.\n\n3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially.  The Work or derivative works thereof may be used or intended for use by Nvidia or its affiliates commercially or non-commercially.  As used herein, “non-commercially” means for research or evaluation purposes only.\n\n3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grants in Sections 2.1 and 2.2) will terminate immediately.\n\n3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License.\n\n3.6 Termination. If you violate any term of this License, then your rights under this License (including the grants in Sections 2.1 and 2.2) will terminate immediately.\n\n4. Disclaimer of Warranty.\n\nTHE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF M ERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.\n\n5. Limitation of Liability.\n\nEXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMM ERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.\n\n"
  },
  {
    "path": "README.md",
    "content": "![Python 2.7](https://img.shields.io/badge/python-2.7-green.svg)\n![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg)\n## Dancing to Music\nPyTorch implementation of the cross-modality generative model that synthesizes dance from music.\n\n\n### Paper \n[Hsin-Ying Lee](http://vllab.ucmerced.edu/hylee/), [Xiaodong Yang](https://xiaodongyang.org/), [Ming-Yu Liu](http://mingyuliu.net/), [Ting-Chun Wang](https://tcwang0509.github.io/), [Yu-Ding Lu](https://jonlu0602.github.io/), [Ming-Hsuan Yang](https://faculty.ucmerced.edu/mhyang/), [Jan Kautz](http://jankautz.com/)  \nDancing to Music\nNeural Information Processing Systems (**NeurIPS**) 2019     \n[[Paper]](https://arxiv.org/abs/1911.02001) [[YouTube]](https://youtu.be/-e9USqfwZ4A) [[Project]](http://vllab.ucmerced.edu/hylee/Dancing2Music/script.txt) [[Blog]](https://news.developer.nvidia.com/nvidia-dance-to-music-neurips/) [[Supp]](http://xiaodongyang.org/publications/papers/dance2music-supp-neurips19.pdf)\n\n### Example Videos\n- Beat-Matching     \n1st row: generated dance sequences, 2nd row: music beats, 3rd row: kinematics beats         \n<p align='left'>\n  <img src='imgs/example.gif' width='400'/>\n</p>\n\n- Multimodality    \nGenerate various dance sequences with the same music and the same initial pose. \n<p align='left'>\n  <img src='imgs/multimodal.gif' width='400'/>\n</p>\n\n- Long-Term Generation    \nSeamlessly generate a dance sequence with arbitrary length. \n<p align='left'>\n  <kbd>\n   <img src='imgs/long.gif' width='300'/>\n  </kbd>\n</p>\n\n- Photo-Realisitc Videos    \nMap generated dance sequences to photo-realistic videos.\n<p align='left'>\n  <img src='imgs/v2v.gif' width='800'/>\n</p>\n\n\n## Train Decomposition \n```\npython train_decomp.py --name Decomp\n```\n\n## Train Composition \n```\npython train_comp.py --name Decomp --decomp_snapshot DECOMP_SNAPSHOT\n```\n\n## Demo\n```\npython demo.py --decomp_snapshot DECOMP_SNAPSHOT --comp_snapshot COMP_SNAPSHOT --aud_path AUD_PATH --out_file OUT_FILE --out_dir OUT_DIR --thr THR\n```\n- Flags\n  - `aud_path`: input .wav file\n  - `out_file`: location of output .mp4 file\n  - `out_dir`: directory of output frames\n  - `thr`: threshold based on motion magnitude\n  - `modulate`: whether to do beat warping\n\n- Example\n```\npython demo.py -decomp_snapshot snapshot/Stage1.ckpt --comp_snapshot snapshot/Stage2.ckpt --aud_path demo/demo.wav --out_file demo/out.mp4 --out_dir demo/out_frame\n```\n\n\n### Citation\nIf you find this code useful for your research, please cite our paper:\n```bibtex\n@inproceedings{lee2019dancing2music,\n  title={Dancing to Music},\n  author={Lee, Hsin-Ying and Yang, Xiaodong and Liu, Ming-Yu and Wang, Ting-Chun and Lu, Yu-Ding and Yang, Ming-Hsuan and Kautz, Jan},\n  booktitle={NeurIPS},\n  year={2019}\n}\n```\n\n### License\nCopyright (C) 2020 NVIDIA Corporation. All rights reserved. This work is made available under NVIDIA Source Code License (1-Way Commercial). To view a copy of this license, visit https://nvlabs.github.io/Dancing2Music/LICENSE.txt. \n"
  },
  {
    "path": "data.py",
    "content": "# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.\n#\n# This work is made available\n# under the Nvidia Source Code License (1-way Commercial).\n# To view a copy of this license, visit\n# https://nvlabs.github.io/Dancing2Music/License.txt\nimport os  \nimport pickle\nimport numpy as np\nimport random\nimport torch.utils.data\nfrom torchvision.datasets import ImageFolder\nimport utils\n\n\nclass PoseDataset(torch.utils.data.Dataset):\n  def __init__(self, data_dir, tolerance=False):\n    self.data_dir = data_dir\n    z_fname = '{}/unitList/zumba_unit.txt'.format(data_dir)\n    b_fname = '{}/unitList/ballet_unit.txt'.format(data_dir)\n    h_fname = '{}/unitList/hiphop_unit.txt'.format(data_dir)\n    self.z_data = []\n    self.b_data = []\n    self.h_data = []\n    with open(z_fname, 'r') as f:\n      for line in f:\n        self.z_data.append([s for s in line.strip().split(' ')])\n    with open(b_fname, 'r') as f:\n      for line in f:\n        self.b_data.append([s for s in line.strip().split(' ')])\n    with open(h_fname, 'r') as f:\n      for line in f:\n        self.h_data.append([s for s in line.strip().split(' ')])\n    self.data = [self.z_data, self.b_data, self.h_data]\n\n    self.tolerance = tolerance\n    if self.tolerance:\n      z3_fname = '{}/unitList/zumba_unitseq3.txt'.format(data_dir)\n      b3_fname = '{}/unitList/ballet_unitseq3.txt'.format(data_dir)\n      h3_fname = '{}/unitList/hiphop_unitseq3.txt'.format(data_dir)\n      z4_fname = '{}/unitList/zumba_unitseq4.txt'.format(data_dir)\n      b4_fname = '{}/unitList/ballet_unitseq4.txt'.format(data_dir)\n      h4_fname = '{}/unitList/hiphop_unitseq4.txt'.format(data_dir)\n      z3_data = []; b3_data = []; h3_data = []; z4_data = []; b4_data = []; h4_data = []\n      with open(z3_fname, 'r') as f:\n        for line in f:\n          z3_data.append([s for s in line.strip().split(' ')])\n      with open(b3_fname, 'r') as f:\n        for line in f:\n          b3_data.append([s for s in line.strip().split(' ')])\n      with open(h3_fname, 'r') as f:\n        for line in f:\n          h3_data.append([s for s in line.strip().split(' ')])\n      with open(z4_fname, 'r') as f:\n        for line in f:\n          z4_data.append([s for s in line.strip().split(' ')])\n      with open(b4_fname, 'r') as f:\n        for line in f:\n          b4_data.append([s for s in line.strip().split(' ')])\n      with open(h4_fname, 'r') as f:\n        for line in f:\n          h4_data.append([s for s in line.strip().split(' ')])\n      self.zt_data = z3_data + z4_data\n      self.bt_data = b3_data + b4_data\n      self.ht_data = h3_data + h4_data\n      self.t_data = [self.zt_data, self.bt_data, self.ht_data]\n\n    self.mean_pose=np.load(data_dir+'/stats/all_onbeat_mean.npy')\n    self.std_pose=np.load(data_dir+'/stats/all_onbeat_std.npy')\n\n  def __getitem__(self, index):\n    cls = random.randint(0,2)\n    cls = random.randint(0,1)\n    if self.tolerance and random.randint(0,9)==0:\n      index = random.randint(0, len(self.t_data[cls])-1)\n      path = self.t_data[cls][index][0]\n      path = os.path.join(self.data_dir, path[5:])\n      orig_poses = np.load(path)\n      sel = random.randint(0, orig_poses.shape[0]-1)\n      orig_poses = orig_poses[sel]\n    else:\n      index = random.randint(0, len(self.data[cls])-1)\n      path = self.data[cls][index][0]\n      path = os.path.join(self.data_dir, path[5:])\n      orig_poses = np.load(path)\n\n    xjit = np.random.uniform(low=-50, high=50)\n    yjit = np.random.uniform(low=-20, high=20)\n    poses = orig_poses.copy()\n    poses[:,:,0] += xjit\n    poses[:,:,1] += yjit\n    xjit = np.random.uniform(low=-50, high=50)\n    yjit = np.random.uniform(low=-20, high=20)\n    poses2 = orig_poses.copy()\n    poses2[:,:,0] += xjit\n    poses2[:,:,1] += yjit\n\n    poses = poses.reshape(poses.shape[0], poses.shape[1]*poses.shape[2])\n    poses2 = poses2.reshape(poses2.shape[0], poses2.shape[1]*poses2.shape[2])\n    for i in range(poses.shape[0]):\n      poses[i] = (poses[i]-self.mean_pose)/self.std_pose\n      poses2[i] = (poses2[i]-self.mean_pose)/self.std_pose\n\n    return torch.Tensor(poses), torch.Tensor(poses2)\n\n  def __len__(self):\n    return len(self.z_data)+len(self.b_data)\n\n\nclass MovementAudDataset(torch.utils.data.Dataset):\n  def __init__(self, data_dir):\n    self.data_dir = data_dir\n    z3_fname = '{}/unitList/zumba_unitseq3.txt'.format(data_dir)\n    b3_fname = '{}/unitList/ballet_unitseq3.txt'.format(data_dir)\n    h3_fname = '{}/unitList/hiphop_unitseq3.txt'.format(data_dir)\n    z4_fname = '{}/unitList/zumba_unitseq4.txt'.format(data_dir)\n    b4_fname = '{}/unitList/ballet_unitseq4.txt'.format(data_dir)\n    h4_fname = '{}/unitList/hiphop_unitseq4.txt'.format(data_dir)\n    self.z3_data = []\n    self.b3_data = []\n    self.h3_data = []\n    self.z4_data = []\n    self.b4_data = []\n    self.h4_data = []\n    with open(z3_fname, 'r') as f:\n      for line in f:\n        self.z3_data.append([s for s in line.strip().split(' ')])\n    with open(b3_fname, 'r') as f:\n      for line in f:\n        self.b3_data.append([s for s in line.strip().split(' ')])\n    with open(h3_fname, 'r') as f:\n      for line in f:\n        self.h3_data.append([s for s in line.strip().split(' ')])\n    with open(z4_fname, 'r') as f:\n      for line in f:\n        self.z4_data.append([s for s in line.strip().split(' ')])\n    with open(b4_fname, 'r') as f:\n      for line in f:\n        self.b4_data.append([s for s in line.strip().split(' ')])\n    with open(h4_fname, 'r') as f:\n      for line in f:\n        self.h4_data.append([s for s in line.strip().split(' ')])\n    self.data_3 = [self.z3_data, self.b3_data, self.h3_data]\n    self.data_4 = [self.z4_data, self.b4_data, self.h4_data]\n\n    z_data_root = 'zumba/'\n    b_data_root = 'ballet/'\n    h_data_root = 'hiphop/'\n    self.data_root = [z_data_root, b_data_root, h_data_root ]\n    self.mean_pose=np.load(data_dir+'/stats/all_onbeat_mean.npy')\n    self.std_pose=np.load(data_dir+'/stats/all_onbeat_std.npy')\n    self.mean_aud=np.load(data_dir+'/stats/all_aud_mean.npy')\n    self.std_aud=np.load(data_dir+'/stats/all_aud_std.npy')\n\n  def __getitem__(self, index):\n    cls = random.randint(0,2)\n    cls = random.randint(0,1)\n    isthree = random.randint(0,1)\n\n    if isthree == 0:\n      index = random.randint(0, len(self.data_4[cls])-1)\n      path = self.data_4[cls][index][0]\n    else:\n      index = random.randint(0, len(self.data_3[cls])-1)\n      path = self.data_3[cls][index][0]\n    path = os.path.join(self.data_dir, path[5:])\n    stdpSeq = np.load(path)\n    vid, cid = path.split('/')[-4], path.split('/')[-3]\n    #vid, cid = vid_cid[:11], vid_cid[12:]\n    aud = np.load('{}/{}/{}/{}/aud/c{}_fps15.npy'.format(self.data_dir, self.data_root[cls], vid, cid, cid))\n\n    stdpSeq = stdpSeq.reshape(stdpSeq.shape[0], stdpSeq.shape[1], stdpSeq.shape[2]*stdpSeq.shape[3])\n    for i in range(stdpSeq.shape[0]):\n      for j in range(stdpSeq.shape[1]):\n        stdpSeq[i,j] = (stdpSeq[i,j]-self.mean_pose)/self.std_pose\n    if isthree == 0:\n      start = random.randint(0,1)\n      stdpSeq = stdpSeq[start:start+3]\n\n    for i in range(aud.shape[0]):\n      aud[i] = (aud[i]-self.mean_aud)/self.std_aud\n    aud = aud[:30]\n    return torch.Tensor(stdpSeq), torch.Tensor(aud)\n\n  def __len__(self):\n    return len(self.z3_data)+len(self.b3_data)+len(self.z4_data)+len(self.b4_data)+len(self.h3_data)+len(self.h4_data)\n\ndef get_loader(batch_size, shuffle, num_workers, dataset, data_dir, tolerance=False):\n  if dataset == 0:\n    a2d = PoseDataset(data_dir, tolerance)\n  elif dataset == 2:\n    a2d = MovementAudDataset(data_dir)\n  data_loader = torch.utils.data.DataLoader(dataset=a2d,\n                                            batch_size=batch_size,\n                                            shuffle=shuffle,\n                                            num_workers=num_workers,\n                                            )\n  return data_loader\n"
  },
  {
    "path": "demo.py",
    "content": "# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.\n#\n# This work is made available\n# under the Nvidia Source Code License (1-way Commercial).\n# To view a copy of this license, visit\n# https://nvlabs.github.io/Dancing2Music/License.txt\nimport os\nimport argparse\nimport functools\nimport librosa\nimport shutil\nimport sys \nsys.path.insert(0, 'preprocess')\nimport preprocess as p\nimport subprocess as sp\nfrom shutil import copyfile\n\nimport torch\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms\n\nfrom model_comp import *\nfrom networks import *\nfrom options import TestOptions\nimport modulate\nimport utils\n\ndef loadDecompModel(args):\n  initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)\n  stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,\n                          hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)\n  movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,\n                             hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))\n  checkpoint = torch.load(args.decomp_snapshot)\n  initp_enc.load_state_dict(checkpoint['initp_enc'])\n  stdp_dec.load_state_dict(checkpoint['stdp_dec'])\n  movement_enc.load_state_dict(checkpoint['movement_enc'])\n  return initp_enc, stdp_dec, movement_enc\n\ndef loadCompModel(args):\n  dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,\n                               hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1))\n  dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,\n                               hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers)\n  audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance)\n  dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance)\n  danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3)\n  zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1)\n  checkpoint = torch.load(args.comp_snapshot)\n  dance_enc.load_state_dict(checkpoint['dance_enc'])\n  dance_dec.load_state_dict(checkpoint['dance_dec'])\n  audstyle_enc.load_state_dict(checkpoint['audstyle_enc'])\n\n  checkpoint2 = torch.load(args.neta_snapshot)\n  neta_cls = AudioClassifier_rnn(10,30,28,cls=3)\n  neta_cls.load_state_dict(checkpoint2)\n\n  return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls\n\nif __name__ == \"__main__\":\n  parser = TestOptions()\n  args = parser.parse()\n  args.train = False\n\n  thr = args.thr\n\n  # Process music and get feature\n  infile = args.aud_path\n  outfile = 'style.npy'\n  p.preprocess(infile, outfile)\n\n  y, sr = librosa.load(infile)\n  onset_env = librosa.onset.onset_strength(y, sr=sr,aggregate=np.median)\n  times = librosa.frames_to_time(np.arange(len(onset_env)),sr=sr, hop_length=512)\n  tempo, beats = librosa.beat.beat_track(onset_envelope=onset_env,sr=sr)\n  np.save('beats.npy', times[beats])\n  beats = np.round(librosa.frames_to_time(beats, sr=sr)*15)\n\n  beats = np.load('beats.npy')\n  aud = np.load('style.npy')\n  os.remove('beats.npy')\n  os.remove('style.npy')\n  shutil.rmtree('normalized')\n\n  #### Pretrain network from Decomp\n  initp_enc, stdp_dec, movement_enc = loadDecompModel(args)\n\n  #### Comp network\n  dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = loadCompModel(args)\n\n  trainer = Trainer_Comp(data_loader=None,\n                    movement_enc = movement_enc,\n                    initp_enc = initp_enc,\n                    stdp_dec = stdp_dec,\n                    dance_enc = dance_enc,\n                    dance_dec = dance_dec,\n                    danceAud_dis = danceAud_dis,\n                    zdance_dis = zdance_dis,\n                    aud_enc=neta_cls,\n                    audstyle_enc=audstyle_enc,\n                    dance_reg=dance_reg,\n                    args = args\n                    )\n\n  print('Loading Done')\n\n  mean_pose=np.load('{}/stats/all_onbeat_mean.npy'.format(args.data_dir))\n  std_pose=np.load('{}/stats/all_onbeat_std.npy'.format(args.data_dir))\n  mean_aud=np.load('{}/stats/all_aud_mean.npy'.format(args.data_dir))\n  std_aud=np.load('{}/stats/all_aud_std.npy'.format(args.data_dir))\n\n\n  length = aud.shape[0]\n\n  initpose = np.zeros((14, 2))\n  initpose = initpose.reshape(-1)\n  #initpose = (initpose-mean_pose)/std_pose\n\n  for j in range(aud.shape[0]):\n    aud[j] = (aud[j]-mean_aud)/std_aud\n\n  total_t = int(length/32+1)\n  final_stdpSeq = np.zeros((total_t*3*32, 14, 2))\n  initpose, aud = torch.Tensor(initpose).cuda(), torch.Tensor(aud).cuda()\n  initpose, aud = initpose.view(1, initpose.shape[0]), aud.view(1, aud.shape[0], aud.shape[1])\n  for t in range(total_t):\n    print('process {}/{}'.format(t, total_t))\n    fake_stdpSeq = trainer.test_final(initpose, aud, 3, thr)\n    while True:\n      fake_stdpSeq = trainer.test_final(initpose, aud, 3, thr)\n      if not fake_stdpSeq is None:\n        break\n    initpose = fake_stdpSeq[2,-1]\n    initpose = torch.Tensor(initpose).cuda()\n    initpose = initpose.view(1,-1)\n    fake_stdpSeq = fake_stdpSeq.squeeze()\n    for j in range(fake_stdpSeq.shape[0]):\n      for k in range(fake_stdpSeq.shape[1]):\n        fake_stdpSeq[j,k] = fake_stdpSeq[j,k]*std_pose + mean_pose\n    fake_stdpSeq = np.resize(fake_stdpSeq, (fake_stdpSeq.shape[0],32, 14, 2))\n    for j in range(3):\n      final_stdpSeq[96*t+32*j:96*t+32*(j+1)] = fake_stdpSeq[j]\n\n  if args.modulate:\n    final_stdpSeq = modulate.modulate(final_stdpSeq, beats, length)\n\n  out_dir = args.out_dir\n  if not os.path.exists(out_dir):\n    os.mkdir(out_dir)\n  utils.vis(final_stdpSeq, out_dir)\n  sp.call('ffmpeg -r 15  -i {}/frame%03d.png -i {} -c:v libx264 -pix_fmt yuv420p  -crf 23 -r 30  -y -strict -2  {}'.format(out_dir, args.aud_path, args.out_file), shell=True)\n\n"
  },
  {
    "path": "model_comp.py",
    "content": "# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.\n#\n# This work is made available\n# under the Nvidia Source Code License (1-way Commercial).\n# To view a copy of this license, visit\n# https://nvlabs.github.io/Dancing2Music/License.txt\nimport os\nimport time\nimport numpy as np\nimport random\nimport math\n\nimport torch\nfrom torch import nn\nfrom torch.autograd import Variable\nimport torch.optim as optim\nfrom torch.nn.utils import clip_grad_norm_\n\nfrom utils import Logger\n\nif torch.cuda.is_available():\n  T = torch.cuda\nelse:\n  T = torch\n\nclass Trainer_Comp(object):\n  def __init__(self, data_loader, dance_enc, dance_dec, danceAud_dis, movement_enc, initp_enc, stdp_dec, aud_enc, audstyle_enc, dance_reg=None, args=None, zdance_dis=None):\n    self.data_loader = data_loader\n    self.movement_enc = movement_enc\n    self.initp_enc = initp_enc\n    self.stdp_dec = stdp_dec\n    self.dance_enc = dance_enc\n    self.dance_dec = dance_dec\n    self.danceAud_dis = danceAud_dis\n    self.aud_enc = aud_enc\n    self.audstyle_enc = audstyle_enc\n    self.train = args.train\n    self.args = args\n\n    if args.train:\n      self.zdance_dis = zdance_dis\n      self.dance_reg = dance_reg\n\n      self.logger = Logger(args.log_dir)\n      self.logs = self.init_logs()\n      self.log_interval = args.log_interval\n      self.snapshot_ep = args.snapshot_ep\n      self.snapshot_dir = args.snapshot_dir\n\n      self.opt_dance_enc = torch.optim.Adam(self.dance_enc.parameters(), lr=args.lr)\n      self.opt_dance_dec = torch.optim.Adam(self.dance_dec.parameters(), lr=args.lr)\n      self.opt_danceAud_dis = torch.optim.Adam(self.danceAud_dis.parameters(), lr=args.lr)\n      self.opt_audstyle_enc = torch.optim.Adam(self.audstyle_enc.parameters(), lr=args.lr)\n      self.opt_zdance_dis = torch.optim.Adam(self.zdance_dis.parameters(), lr=args.lr)\n      self.opt_dance_reg = torch.optim.Adam(self.dance_reg.parameters(), lr=args.lr)\n\n      self.opt_stdp_dec = torch.optim.Adam(self.stdp_dec.parameters(), lr=args.lr*0.1)\n      self.opt_movement_enc = torch.optim.Adam(self.movement_enc.parameters(), lr=args.lr*0.1)\n\n    self.latent_dropout = nn.Dropout(p=args.latent_dropout)\n    self.l1_criterion = torch.nn.L1Loss()\n    self.gan_criterion = nn.BCEWithLogitsLoss()\n    self.mse_criterion = nn.MSELoss().cuda()\n\n  def init_logs(self):\n    return {'l_kl_zdance':0, 'l_kl_zmovement':0, 'l_kl_fake_zdance':0, 'l_kl_fake_zmovement':0,\n            'l_l1_zmovement_mu':0, 'l_l1_zmovement_logvar':0, 'l_l1_stdpSeq':0, 'l_l1_zdance':0,\n            'l_dis':0, 'l_dis_true':0, 'l_dis_fake':0,\n            'l_info':0, 'l_info_real':0, 'l_info_fake':0,\n            'l_gen':0\n            }\n\n  def get_z_random(self, batchSize, nz, random_type='gauss'):\n    z = torch.randn(batchSize, nz).cuda()\n    return z\n\n  @staticmethod\n  def ones_like(tensor, val=1.):\n    return T.FloatTensor(tensor.size()).fill_(val)\n\n  @staticmethod\n  def zeros_like(tensor, val=0.):\n    return T.FloatTensor(tensor.size()).fill_(val)\n  def kld_coef(self, i):\n    return float(1/(1+np.exp(-0.0005*(i-15000))))\n\n\n  def forward(self, stdpSeq, batchsize, aud_style, aud):\n    self.aud = torch.mean(aud, dim=1)\n\n    self.batchsize = batchsize\n    self.stdpSeq = stdpSeq\n    self.aud_style = aud_style\n    ### stdpSeq -> z_inits, z_movements\n    self.pose_0 = stdpSeq[:,0,:]\n    self.z_init_mu, self.z_init_logvar = self.initp_enc(self.pose_0)\n    z_init_std = self.z_init_logvar.mul(0.5).exp_()\n    z_init_eps = self.get_z_random(z_init_std.size(0), z_init_std.size(1), 'gauss')\n    self.z_init = z_init_eps.mul(z_init_std).add_(self.z_init_mu)\n\n    self.z_movement_mus, self.z_movement_logvars = self.movement_enc(stdpSeq)\n    z_movement_stds = self.z_movement_logvars.mul(0.5).exp_()\n    z_movement_epss = self.get_z_random(z_movement_stds.size(0), z_movement_stds.size(1), 'gauss')\n    self.z_movements = z_movement_epss.mul(z_movement_stds).add_(self.z_movement_mus)\n    self.z_movementSeq_mu = self.z_movement_mus.view(batchsize, -1, self.z_movements.shape[1])\n    self.z_movementSeq_logvar = self.z_movement_logvars.view(batchsize, -1, self.z_movements.shape[1])\n\n    self.z_init, self.z_movements = self.z_init.detach(), self.z_movements.detach()\n    self.z_movement_mus, self.z_movement_logvars = self.z_movement_mus.detach(), self.z_movement_logvars.detach()\n\n    ### z_movements -> z_dance\n    self.z_dance_mu, self.z_dance_logvar = self.dance_enc(self.z_movementSeq_mu, self.z_movementSeq_logvar)\n    z_dance_std = self.z_dance_logvar.mul(0.5).exp_()\n    z_dance_eps = self.get_z_random(z_dance_std.size(0), z_dance_std.size(1), 'gauss')\n    self.z_dance = z_dance_eps.mul(z_dance_std).add_(self.z_dance_mu)\n    ### z_dance -> z_movements\n    self.recon_z_movements_mu, self.recon_z_movements_logvar  = self.dance_dec(self.z_dance)\n    recon_z_movement_std = self.recon_z_movements_logvar.mul(0.5).exp_()\n    recon_z_movement_eps = self.get_z_random(recon_z_movement_std.size(0), recon_z_movement_std.size(1), 'gauss')\n    self.recon_z_movements = recon_z_movement_eps.mul(recon_z_movement_std).add_(self.recon_z_movements_mu)\n    ### z_movements -> stdpSeq\n    self.recon_stdpSeq = self.stdp_dec(self.z_init, self.recon_z_movements)\n\n    ### Music to z_dance to z_movements\n    self.fake_z_dance_mu, self.fake_z_dance_logvar = self.audstyle_enc(aud_style)\n    fake_z_dance_std = self.fake_z_dance_logvar.mul(0.5).exp_()\n    fake_z_dance_eps = self.get_z_random(fake_z_dance_std.size(0), fake_z_dance_std.size(1), 'gauss')\n    self.fake_z_dance = fake_z_dance_eps.mul(fake_z_dance_std).add_(self.fake_z_dance_mu)\n    self.fake_z_movements_mu, self.fake_z_movements_logvar = self.dance_dec(self.fake_z_dance)\n    fake_z_movements_std = self.fake_z_movements_logvar.mul(0.5).exp_()\n    fake_z_movements_eps = self.get_z_random(fake_z_movements_std.size(0), fake_z_movements_std.size(1), 'gauss')\n    self.fake_z_movements = fake_z_movements_eps.mul(fake_z_movements_std).add_(self.fake_z_movements_mu)\n\n    fake_z_movementSeq_mu = self.fake_z_movements_mu.view(batchsize, -1, self.fake_z_movements_mu.shape[1])\n    fake_z_movementSeq_logvar = self.fake_z_movements_logvar.view(batchsize, -1, self.fake_z_movements_logvar.shape[1])\n    self.fake_z_movementSeq = torch.cat((fake_z_movementSeq_mu, fake_z_movementSeq_logvar),2)\n\n  def backward_D(self):\n    #real_movements = torch.cat((self.z_movementSeq_mu, self.z_movementSeq_logvar),2)\n    tmp_recon_mu = self.recon_z_movements_mu.view(self.batchsize, -1, self.z_movements.shape[1])\n    tmp_recon_logvar = self.recon_z_movements_logvar.view(self.batchsize, -1, self.z_movements.shape[1])\n    real_movements = torch.cat((tmp_recon_mu, tmp_recon_logvar),2)\n    fake_movements = self.fake_z_movementSeq\n\n    real_labels,_ = self.danceAud_dis(real_movements.detach(), self.aud)\n    fake_labels,_ = self.danceAud_dis(fake_movements.detach(), self.aud)\n\n    ones = self.ones_like(real_labels)\n    zeros = self.zeros_like(fake_labels)\n\n    self.loss_dis_true = self.gan_criterion(real_labels, ones)\n    self.loss_dis_fake = self.gan_criterion(fake_labels, zeros)\n    self.loss_dis = (self.loss_dis_true + self.loss_dis_fake)*self.args.lambda_gan\n\n    real_dance = torch.cat((self.z_dance_mu, self.z_dance_logvar), 1)\n    fake_dance = torch.cat((self.fake_z_dance_mu, self.fake_z_dance_logvar), 1)\n    real_labels, _ = self.zdance_dis(real_dance.detach(), self.aud)\n    fake_labels, _ = self.zdance_dis(fake_dance.detach(), self.aud)\n    ones = self.ones_like(real_labels)\n    zeros = self.zeros_like(fake_labels)\n\n    self.loss_zdis_true = self.gan_criterion(real_labels, ones)\n    self.loss_zdis_fake = self.gan_criterion(fake_labels, zeros)\n    self.loss_dis += (self.loss_zdis_true + self.loss_zdis_fake)*self.args.lambda_gan\n\n\n  def backward_danceED(self):\n    # z_dance KL\n    kl_element = self.z_dance_mu.pow(2).add_(self.z_dance_logvar.exp()).mul_(-1).add_(1).add_(self.z_dance_logvar)\n    self.loss_kl_z_dance = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl_dance))\n    kl_element = self.fake_z_dance_mu.pow(2).add_(self.fake_z_dance_logvar.exp()).mul_(-1).add_(1).add_(self.fake_z_dance_logvar)\n    self.loss_kl_fake_z_dance = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl_dance))\n    # z_movement KL\n    kl_element = self.recon_z_movements_mu.pow(2).add_(self.recon_z_movements_logvar.exp()).mul_(-1).add_(1).add_(self.recon_z_movements_logvar)\n    self.loss_kl_z_movement = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))\n    kl_element = self.fake_z_movements_mu.pow(2).add_(self.fake_z_movements_logvar.exp()).mul_(-1).add_(1).add_(self.fake_z_movements_logvar)\n    self.loss_kl_fake_z_movements = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))\n    # z_movement reconstruction\n    self.loss_l1_z_movement_mu = self.l1_criterion(self.recon_z_movements_mu, self.z_movement_mus) * self.args.lambda_zmovements_recon\n    self.loss_l1_z_movement_logvar = self.l1_criterion(self.recon_z_movements_logvar, self.z_movement_logvars) * self.args.lambda_zmovements_recon\n\n    # stdp reconstruction\n    self.loss_l1_stdpSeq = self.l1_criterion(self.recon_stdpSeq, self.stdpSeq) * self.args.lambda_stdpSeq_recon\n\n    # Music2Dance GAN\n    fake_movements = self.fake_z_movementSeq\n    fake_labels, _ = self.danceAud_dis(fake_movements, self.aud)\n\n    ones = self.ones_like(fake_labels)\n    self.loss_gen = self.gan_criterion(fake_labels, ones) * self.args.lambda_gan\n\n    fake_dance = torch.cat((self.fake_z_dance_mu, self.fake_z_dance_logvar), 1)\n    fake_labels, _ = self.zdance_dis(fake_dance, self.aud)\n    ones = self.ones_like(fake_labels)\n    self.loss_gen += self.gan_criterion(fake_labels, ones) * self.args.lambda_gan\n\n    self.loss = self.loss_kl_z_movement + self.loss_kl_z_dance + self.loss_l1_z_movement_mu + self.loss_l1_z_movement_logvar + self.loss_l1_stdpSeq + self.loss_gen\n\n  def backward_info_ondance(self):\n    real_pred = self.dance_reg(self.z_dance)\n    fake_pred = self.dance_reg(self.fake_z_dance)\n    self.loss_info_real = self.mse_criterion(real_pred, self.aud_style)\n    self.loss_info_fake = self.mse_criterion(fake_pred, self.aud_style)\n    self.loss_info = self.loss_info_real + self.loss_info_fake\n\n  def zero_grad(self, opt_list):\n    for opt in opt_list:\n      opt.zero_grad()\n\n  def clip_norm(self, network_list):\n    for network in network_list:\n      clip_grad_norm_(network.parameters(), 0.5)\n\n  def step(self, opt_list):\n    for opt in opt_list:\n      opt.step()\n\n  def update(self):\n    self.zero_grad([self.opt_danceAud_dis, self.opt_zdance_dis])\n    self.backward_D()\n    self.loss_dis.backward(retain_graph=True)\n    self.clip_norm([self.danceAud_dis, self.zdance_dis])\n    self.step([self.opt_danceAud_dis, self.opt_zdance_dis])\n\n    self.zero_grad([self.opt_dance_enc, self.opt_dance_dec, self.opt_audstyle_enc, self.opt_stdp_dec])\n    self.backward_danceED()\n    self.loss.backward(retain_graph=True)\n    self.clip_norm([self.dance_enc, self.dance_dec, self.audstyle_enc, self.stdp_dec])\n    self.step([self.opt_dance_enc, self.opt_dance_dec, self.opt_audstyle_enc, self.opt_stdp_dec])\n\n    self.zero_grad([self.opt_dance_enc, self.opt_audstyle_enc, self.opt_dance_reg, self.opt_stdp_dec])\n    self.backward_info_ondance()\n    self.loss_info.backward()\n    self.clip_norm([self.dance_enc, self.audstyle_enc, self.dance_reg, self.stdp_dec])\n    self.step([self.opt_dance_enc, self.opt_audstyle_enc, self.opt_dance_reg, self.opt_stdp_dec])\n\n  def test_final(self, initpose, aud, n, thr=0):\n    self.cuda()\n    self.movement_enc.eval()\n    self.stdp_dec.eval()\n    self.initp_enc.eval()\n    self.dance_enc.eval()\n    self.dance_dec.eval()\n    self.aud_enc.eval()\n    self.audstyle_enc.eval()\n    aud_style = self.aud_enc.get_style(aud).detach()\n\n    self.fake_z_dance_mu, self.fake_z_dance_logvar = self.audstyle_enc(aud_style)\n    fake_z_dance_std = self.fake_z_dance_logvar.mul(0.5).exp_()\n    fake_z_dance_eps = self.get_z_random(fake_z_dance_std.size(0), fake_z_dance_std.size(1), 'gauss')\n    self.fake_z_dance = fake_z_dance_eps.mul(fake_z_dance_std).add_(self.fake_z_dance_mu)\n\n    self.fake_z_movements_mu, self.fake_z_movements_logvar = self.dance_dec(self.fake_z_dance, length=3)\n    fake_z_movements_std = self.fake_z_movements_logvar.mul(0.5).exp_()\n    fake_z_movements_eps = self.get_z_random(fake_z_movements_std.size(0), fake_z_movements_std.size(1), 'gauss')\n    self.fake_z_movements = fake_z_movements_eps.mul(fake_z_movements_std).add_(self.fake_z_movements_mu)\n\n    fake_stdpSeq=[]\n    for i in range(n):\n      z_init_mus, z_init_logvars = self.initp_enc(initpose)\n      z_init_stds = z_init_logvars.mul(0.5).exp_()\n      z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss')\n      z_init = z_init_epss.mul(z_init_stds).add_(z_init_mus)\n      fake_stdp = self.stdp_dec(z_init, self.fake_z_movements[i:i+1])\n      fake_stdpSeq.append(fake_stdp)\n      initpose  = fake_stdp[:,-1,:]\n    fake_stdpSeq = torch.cat(fake_stdpSeq, dim=0)\n    flag = False\n    for i in range(n):\n      s = fake_stdpSeq[i]\n      diff = torch.abs(s[1:]-s[:-1])\n      diffsum = torch.sum(diff)\n      if diffsum.cpu().detach().numpy() < thr:\n        flag = True\n\n    if flag:\n      return None\n    else:\n      return fake_stdpSeq.cpu().detach().numpy()\n\n\n  def resume(self, model_dir, train=True):\n    checkpoint = torch.load(model_dir)\n    self.dance_enc.load_state_dict(checkpoint['dance_enc'])\n    self.dance_dec.load_state_dict(checkpoint['dance_dec'])\n    self.audstyle_enc.load_state_dict(checkpoint['audstyle_enc'])\n    self.stdp_dec.load_state_dict(checkpoint['stdp_dec'])\n    self.movement_enc.load_state_dict(checkpoint['movement_enc'])\n    if train:\n      self.danceAud_dis.load_state_dict(checkpoint['danceAud_dis'])\n      self.dance_reg.load_state_dict(checkpoint['dance_reg'])\n      self.opt_dance_enc.load_state_dict(checkpoint['opt_dance_enc'])\n      self.opt_dance_dec.load_state_dict(checkpoint['opt_dance_dec'])\n      self.opt_stdp_dec.load_state_dict(checkpoint['opt_stdp_dec'])\n      self.opt_audstyle_enc.load_state_dict(checkpoint['opt_audstyle_enc'])\n      self.opt_danceAud_dis.load_state_dict(checkpoint['opt_danceAud_dis'])\n      self.opt_dance_reg.load_state_dict(checkpoint['opt_dance_reg'])\n    return checkpoint['ep'], checkpoint['total_it']\n\n  def save(self, filename, ep, total_it):\n    state = {\n             'stdp_dec': self.stdp_dec.state_dict(),\n             'movement_enc': self.movement_enc.state_dict(),\n             'dance_enc': self.dance_enc.state_dict(),\n             'dance_dec': self.dance_dec.state_dict(),\n             'audstyle_enc': self.audstyle_enc.state_dict(),\n             'danceAud_dis': self.danceAud_dis.state_dict(),\n             'zdance_dis': self.zdance_dis.state_dict(),\n             'dance_reg': self.dance_reg.state_dict(),\n             'opt_stdp_dec': self.opt_stdp_dec.state_dict(),\n             'opt_movement_enc': self.opt_movement_enc.state_dict(),\n             'opt_dance_enc': self.opt_dance_enc.state_dict(),\n             'opt_dance_dec': self.opt_dance_dec.state_dict(),\n             'opt_audstyle_enc': self.opt_audstyle_enc.state_dict(),\n             'opt_danceAud_dis': self.opt_danceAud_dis.state_dict(),\n             'opt_zdance_dis': self.opt_zdance_dis.state_dict(),\n             'opt_dance_reg': self.opt_dance_reg.state_dict(),\n             'ep': ep,\n             'total_it': total_it\n              }\n    torch.save(state, filename)\n    return\n\n  def cuda(self):\n    if self.train:\n      self.dance_reg.cuda()\n      self.danceAud_dis.cuda()\n      self.zdance_dis.cuda()\n    self.stdp_dec.cuda()\n    self.initp_enc.cuda()\n    self.movement_enc.cuda()\n    self.dance_enc.cuda()\n    self.dance_dec.cuda()\n    self.aud_enc.cuda()\n    self.audstyle_enc.cuda()\n    self.gan_criterion.cuda()\n\n  def train(self, ep=0, it=0):\n    self.cuda()\n    for epoch in range(ep, self.args.num_epochs):\n      self.movement_enc.train()\n      self.stdp_dec.train()\n      self.initp_enc.train()\n      self.dance_enc.train()\n      self.dance_dec.train()\n      self.danceAud_dis.train()\n      self.zdance_dis.train()\n      self.audstyle_enc.train()\n      self.dance_reg.train()\n      self.aud_enc.eval()\n      stdp_recon = 0\n\n      for i, (stdpSeq, aud) in enumerate(self.data_loader):\n        stdpSeq, aud = stdpSeq.cuda().detach(), aud.cuda().detach()\n        stdpSeq = stdpSeq.view(stdpSeq.shape[0]*stdpSeq.shape[1], stdpSeq.shape[2], stdpSeq.shape[3])\n        aud_style = self.aud_enc.get_style(aud).detach()\n\n        self.forward(stdpSeq, aud.shape[0], aud_style, aud)\n        self.update()\n        self.logs['l_kl_zmovement'] += self.loss_kl_z_movement.data\n        self.logs['l_kl_zdance'] += self.loss_kl_z_dance.data\n        self.logs['l_l1_zmovement_mu'] += self.loss_l1_z_movement_mu.data\n        self.logs['l_l1_zmovement_logvar'] += self.loss_l1_z_movement_logvar.data\n        self.logs['l_l1_stdpSeq'] += self.loss_l1_stdpSeq.data\n        self.logs['l_kl_fake_zdance'] += self.loss_kl_fake_z_dance.data\n        self.logs['l_kl_fake_zmovement'] += self.loss_kl_fake_z_movements\n        self.logs['l_dis'] += self.loss_dis.data\n        self.logs['l_dis_true'] += self.loss_dis_true.data\n        self.logs['l_dis_fake'] += self.loss_dis_fake.data\n        self.logs['l_gen'] += self.loss_gen.data\n        self.logs['l_info'] += self.loss_info\n        self.logs['l_info_real'] += self.loss_info_real\n        self.logs['l_info_fake'] += self.loss_info_fake\n\n        print('Epoch:{:3} Iter{}/{}\\tl_l1_zmovement mu{:.3f} logvar{:.3f}\\tl_l1_stdpSeq {:.3f}\\tl_kl_dance {:.3f}\\tl_kl_movement {:.3f}\\n'.format(epoch, i, len(self.data_loader),\n            self.loss_l1_z_movement_mu, self.loss_l1_z_movement_logvar, self.loss_l1_stdpSeq, self.loss_kl_z_dance, self.loss_kl_z_movement) +\n             '\\t\\t\\tl_kl_f_dance {:.3f}\\tl_dis {:.3f} {:.3f}\\tl_gen {:.3f}'.format(self.loss_kl_fake_z_dance, self.loss_dis_true, self.loss_dis_fake, self.loss_gen))\n\n        it += 1\n        if it % self.log_interval == 0:\n          for tag, value in self.logs.items():\n            self.logger.scalar_summary(tag, value/self.log_interval, it)\n          self.logs = self.init_logs()\n      if epoch % self.snapshot_ep == 0:\n        self.save(os.path.join(self.snapshot_dir, '{:04}.ckpt'.format(epoch)), epoch, it)\n"
  },
  {
    "path": "model_decomp.py",
    "content": "# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.\n#\n# This work is made available\n# under the Nvidia Source Code License (1-way Commercial).\n# To view a copy of this license, visit\n# https://nvlabs.github.io/Dancing2Music/License.txt\nimport os\nimport time\nimport numpy as np\nimport random\nimport math\n\nimport torch\nfrom torch import nn\nfrom torch.autograd import Variable\nimport torch.optim as optim\nfrom torch.nn.utils import clip_grad_norm_\n\nfrom utils import Logger\n\nif torch.cuda.is_available():\n  T = torch.cuda\nelse:\n  T = torch\n\nclass Trainer_Decomp(object):\n  def __init__(self, data_loader, initp_enc, initp_dec, movement_enc, stdp_dec, args=None):\n    self.data_loader = data_loader\n    self.initp_enc = initp_enc\n    self.initp_dec = initp_dec\n    self.movement_enc = movement_enc\n    self.stdp_dec = stdp_dec\n\n\n    self.args = args\n    if args.train:\n      self.logger = Logger(args.log_dir)\n      self.logs = self.init_logs()\n      self.log_interval = args.log_interval\n      self.snapshot_ep = args.snapshot_ep\n      self.snapshot_dir = args.snapshot_dir\n\n      self.opt_initp_enc = torch.optim.Adam(self.initp_enc.parameters(), lr=args.lr)\n      self.opt_initp_dec = torch.optim.Adam(self.initp_dec.parameters(), lr=args.lr)\n      self.opt_movement_enc = torch.optim.Adam(self.movement_enc.parameters(), lr=args.lr)\n      self.opt_stdp_dec = torch.optim.Adam(self.stdp_dec.parameters(), lr=args.lr)\n\n    self.latent_dropout = nn.Dropout(p=args.latent_dropout)\n    self.l1_criterion = torch.nn.L1Loss()\n    self.gan_criterion = nn.BCEWithLogitsLoss()\n\n\n  def init_logs(self):\n    return {'l_kl_zinit':0, 'l_kl_zmovement':0, 'l_l1_stdp':0, 'l_l1_cross_stdp':0, 'l_dist_zmovement':0,\n            'l_l1_initp':0, 'l_l1_initp_con':0,\n            'kld_coef':0\n            }\n\n  def get_z_random(self, batchSize, nz, random_type='gauss'):\n    z = torch.randn(batchSize, nz).cuda()\n    return z\n\n  @staticmethod\n  def ones_like(tensor, val=1.):\n    return T.FloatTensor(tensor.size()).fill_(val)\n\n  @staticmethod\n  def zeros_like(tensor, val=0.):\n    return T.FloatTensor(tensor.size()).fill_(val)\n\n\n  def random_generate_stdp(self, init_p):\n    self.pose_0 = init_p\n    self.z_init_mu, self.z_init_logvar = self.initp_enc(self.pose_0)\n    z_init_std = self.z_init_logvar.mul(0.5).exp_()\n    z_init_eps = self.get_z_random(z_init_std.size(0), z_init_std.size(1), 'gauss')\n    self.z_init = z_init_eps.mul(z_init_std).add_(self.z_init_mu)\n    self.z_random_movement = self.get_z_random(self.z_init.size(0), 512, 'gauss')\n    self.fake_stdpose = self.stdp_dec(self.z_init, self.z_random_movement)\n    return self.fake_stdpose\n\n  def forward(self, stdpose1, stdpose2):\n    self.stdpose1 = stdpose1\n    self.stdpose2 = stdpose2\n\n    # stdpose -> stdpose[0] -> z_init\n    self.pose1_0 = stdpose1[:,0,:]\n    self.pose2_0 = stdpose2[:,0,:]\n    self.poses_0 = torch.cat((self.pose1_0, self.pose2_0), 0)\n    self.z_init_mus, self.z_init_logvars = self.initp_enc(self.poses_0)\n    z_init_stds = self.z_init_logvars.mul(0.5).exp_()\n    z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss')\n    self.z_inits = z_init_epss.mul(z_init_stds).add_(self.z_init_mus)\n    self.z_init1, self.z_init2 = torch.split(self.z_inits, self.stdpose1.size(0), dim=0)\n\n    # stdpose -> z_movement\n    stdposes = torch.cat((stdpose1, stdpose2), 0)\n    self.z_movement_mus, self.z_movement_logvars = self.movement_enc(stdposes)\n    z_movement_stds = self.z_movement_logvars.mul(0.5).exp_()\n    z_movement_epss = self.get_z_random(z_movement_stds.size(0), z_movement_stds.size(1), 'gauss')\n    self.z_movements = z_movement_epss.mul(z_movement_stds).add_(self.z_movement_mus)\n    self.z_movement1, self.z_movement2 = torch.split(self.z_movements, self.stdpose1.size(0), dim=0)\n\n    # zinit1+zmovement1->stdpose1   zinit2+zmovement2->stdpose2\n    self.recon_stdpose1 = self.stdp_dec(self.z_init1, self.z_movement1)\n    self.recon_stdpose2 = self.stdp_dec(self.z_init2, self.z_movement2)\n\n    # zinit1+zmovement2->stdpose1   zinit2+zmovement1->stdpose2\n    self.recon_stdpose1_cross = self.stdp_dec(self.z_init1, self.z_movement2)\n    self.recon_stdpose2_cross = self.stdp_dec(self.z_init2, self.z_movement1)\n\n    # z_init -> \\hat{stdpose[0]}\n    self.recon_pose1_0 = self.initp_dec(self.z_init1)\n    self.recon_pose2_0 = self.initp_dec(self.z_init2)\n\n    # single pose reconstruction\n    randomlist = np.random.permutation(31)[:4]\n    singlepose = []\n    for r in randomlist:\n      singlepose.append(self.stdpose1[:,r,:])\n    self.singleposes = torch.cat(singlepose, dim=0).detach()\n    self.z_single_mus, self.z_single_logvars = self.initp_enc(self.singleposes)\n    z_single_stds = self.z_single_logvars.mul(0.5).exp_()\n    z_single_epss = self.get_z_random(z_single_stds.size(0), z_single_stds.size(1), 'gauss')\n    z_single = z_single_epss.mul(z_single_stds).add_(self.z_single_mus)\n    self.recon_singleposes = self.initp_dec(z_single)\n\n  def backward_initp_ED(self):\n    # z_init KL\n    kl_element = self.z_init_mus.pow(2).add_(self.z_init_logvars.exp()).mul_(-1).add_(1).add_(self.z_init_logvars)\n    self.loss_kl_z_init = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))\n\n    # initpose reconstruction\n    self.loss_l1_initp = self.l1_criterion(self.recon_singleposes, self.singleposes) * self.args.lambda_initp_recon\n\n    self.loss_initp = self.loss_kl_z_init + self.loss_l1_initp\n\n  def backward_movement_ED(self):\n    # z_movement KL\n    kl_element = self.z_movement_mus.pow(2).add_(self.z_movement_logvars.exp()).mul_(-1).add_(1).add_(self.z_movement_logvars)\n    #self.loss_kl_z_movement = torch.mean(kl_element).mul_(-0.5) * self.args.lambda_kl\n    self.loss_kl_z_movement = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))\n\n    # stdpose self reconstruction\n    loss_l1_stdp1 = self.l1_criterion(self.recon_stdpose1, self.stdpose1) * self.args.lambda_stdp_recon\n    loss_l1_stdp2 = self.l1_criterion(self.recon_stdpose2, self.stdpose2) * self.args.lambda_stdp_recon\n    self.loss_l1_stdp = loss_l1_stdp1 + loss_l1_stdp2\n\n    # stdpose cross reconstruction\n    loss_l1_cross_stdp1 = self.l1_criterion(self.recon_stdpose1_cross, self.stdpose1) * self.args.lambda_stdp_recon\n    loss_l1_cross_stdp2 = self.l1_criterion(self.recon_stdpose2_cross, self.stdpose2) * self.args.lambda_stdp_recon\n    self.loss_l1_cross_stdp = loss_l1_cross_stdp1 + loss_l1_cross_stdp2\n\n    # Movement dist\n    self.loss_dist_z_movement = torch.mean(torch.abs(self.z_movement1-self.z_movement2)) * self.args.lambda_dist_z_movement\n\n    self.loss_movement = self.loss_kl_z_movement + self.loss_l1_stdp + self.loss_l1_cross_stdp + self.loss_dist_z_movement\n\n\n  def update(self):\n    self.opt_initp_enc.zero_grad()\n    self.opt_initp_dec.zero_grad()\n    self.opt_movement_enc.zero_grad()\n    self.opt_stdp_dec.zero_grad()\n    self.backward_initp_ED()\n    self.backward_movement_ED()\n    self.g_loss = self.loss_initp + self.loss_movement\n    self.g_loss.backward(retain_graph=True)\n    clip_grad_norm_(self.movement_enc.parameters(), 0.5)\n    clip_grad_norm_(self.stdp_dec.parameters(), 0.5)\n    self.opt_initp_enc.step()\n    self.opt_initp_dec.step()\n    self.opt_movement_enc.step()\n    self.opt_stdp_dec.step()\n\n\n  def save(self, filename, ep, total_it):\n    state = {\n             'stdp_dec': self.stdp_dec.state_dict(),\n             'movement_enc': self.movement_enc.state_dict(),\n             'initp_enc': self.initp_enc.state_dict(),\n             'initp_dec': self.initp_dec.state_dict(),\n             'opt_stdp_dec': self.opt_stdp_dec.state_dict(),\n             'opt_movement_enc': self.opt_movement_enc.state_dict(),\n             'opt_initp_enc': self.opt_initp_enc.state_dict(),\n             'opt_initp_dec': self.opt_initp_dec.state_dict(),\n             'ep': ep,\n             'total_it': total_it\n              }\n    torch.save(state, filename)\n    return\n\n  def resume(self, model_dir, train=True):\n    checkpoint = torch.load(model_dir)\n    # weight\n    self.stdp_dec.load_state_dict(checkpoint['stdp_dec'])\n    self.movement_enc.load_state_dict(checkpoint['movement_enc'])\n    self.initp_enc.load_state_dict(checkpoint['initp_enc'])\n    self.initp_dec.load_state_dict(checkpoint['initp_dec'])\n    # optimizer\n    if train:\n      self.opt_stdp_dec.load_state_dict(checkpoint['opt_stdp_dec'])\n      self.opt_movement_enc.load_state_dict(checkpoint['opt_movement_enc'])\n      self.opt_initp_enc.load_state_dict(checkpoint['opt_initp_enc'])\n      self.opt_initp_dec.load_state_dict(checkpoint['opt_initp_dec'])\n    return checkpoint['ep'], checkpoint['total_it']\n\n  def kld_coef(self, i):\n    return float(1/(1+np.exp(-0.0005*(i-15000)))) #v3\n\n\n  def generate_stdp_sequence(self, initpose, aud, num_stdp):\n    self.initp_enc.cuda()\n    self.initp_dec.cuda()\n    self.movement_enc.cuda()\n    self.stdp_dec.cuda()\n    self.initp_enc.eval()\n    self.initp_dec.eval()\n    self.movement_enc.eval()\n    self.stdp_dec.eval()\n    initpose = initpose.cuda()\n\n    aud_style = self.aud_enc.get_style(aud)\n\n    stdp_seq = []\n    cnt = 0\n    #for i in range(num_stdp):\n    while not cnt == num_stdp:\n      if cnt==0:\n        z_inits = self.get_z_random(1, 10, 'gauss')\n      else:\n        z_init_mus, z_init_logvars = self.initp_enc(initpose)\n        z_init_stds = z_init_logvars.mul(0.5).exp_()\n        z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss')\n        z_inits = z_init_epss.mul(z_init_stds).add_(z_init_mus)\n\n      z_audstyle_mu, z_audstyle_logvar = self.audstyle_enc(aud_style)\n      z_as_std = z_audstyle_logvar.mul(0.5).exp_()\n      z_as_eps = self.get_z_random(z_as_std.size(0), z_as_std.size(1), 'gauss')\n      z_audstyle = z_as_eps.mul(z_as_std).add_(z_audstyle_mu)\n      if random.randint(0,5)==100:\n        z_audstyle = self.get_z_random(z_inits.size(0), 512, 'gauss')\n\n      fake_stdpose = self.stdp_dec(z_inits, z_audstyle)\n\n      s = fake_stdpose[0]\n      diff = torch.abs(s[1:]-s[:-1])\n      diffsum = torch.sum(diff)\n      if diffsum.cpu().detach().numpy() < 70:\n        continue\n\n      cnt += 1\n      stdp_seq.append(fake_stdpose.cpu().detach().numpy())\n      initpose = fake_stdpose[:,-1,:]\n    return stdp_seq\n\n\n  def cuda(self):\n    self.initp_enc.cuda()\n    self.initp_dec.cuda()\n    self.movement_enc.cuda()\n    self.stdp_dec.cuda()\n    self.l1_criterion.cuda()\n\n  def train(self, ep=0, it=0):\n    self.cuda()\n\n    full_kl = self.args.lambda_kl\n    kl_w = 0\n    kl_step = 0.05\n    best_stdp_recon = 100\n    for epoch in range(ep, self.args.num_epochs):\n      self.initp_enc.train()\n      self.initp_dec.train()\n      self.movement_enc.train()\n      self.stdp_dec.train()\n      stdp_recon = 0\n      for i, (stdpose, stdpose2) in enumerate(self.data_loader):\n        self.args.lambda_kl = full_kl*self.kld_coef(it)\n        stdpose, stdpose2  = stdpose.cuda().detach(), stdpose2.cuda().detach()\n\n        self.forward(stdpose, stdpose2)\n        self.update()\n        self.logs['l_kl_zinit'] += self.loss_kl_z_init.data\n        self.logs['l_kl_zmovement'] += self.loss_kl_z_movement.data\n        self.logs['l_l1_initp'] += self.loss_l1_initp.data\n        self.logs['l_l1_stdp'] += self.loss_l1_stdp.data\n        self.logs['l_l1_cross_stdp'] += self.loss_l1_cross_stdp.data\n        self.logs['l_dist_zmovement'] += self.loss_dist_z_movement.data\n        self.logs['kld_coef'] += self.args.lambda_kl\n\n        print('Epoch:{:3} Iter{}/{}\\tl_l1_initp {:.3f}\\tl_l1_stdp {:.3f}\\tl_l1_cross_stdp {:.3f}\\tl_dist_zmove {:.3f}\\tl_kl_zinit {:.3f}\\t l_kl_zmove {:.3f}'.format(\n              epoch, i, len(self.data_loader), self.loss_l1_initp, self.loss_l1_stdp, self.loss_l1_cross_stdp, self.loss_dist_z_movement, self.loss_kl_z_init, self.loss_kl_z_movement))\n\n        it += 1\n        if it % self.log_interval == 0:\n          for tag, value in self.logs.items():\n            self.logger.scalar_summary(tag, value/self.log_interval, it)\n          self.logs = self.init_logs()\n      if epoch % self.snapshot_ep == 0:\n        self.save(os.path.join(self.snapshot_dir, '{:04}.ckpt'.format(epoch)), epoch, it)\n"
  },
  {
    "path": "modulate.py",
    "content": "# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.\n#\n# This work is made available\n# under the Nvidia Source Code License (1-way Commercial).\n# To view a copy of this license, visit\n# https://nvlabs.github.io/Dancing2Music/License.txt\n\nimport os\nimport numpy as np\nimport librosa\nimport utils\n\n\ndef modulate(dance, beats, length):\n  sec_interframe = 1/15\n\n  beats_frame = np.around(beats)\n  t_beat = beats_frame.astype(int)\n  s_beat = np.arange(3,dance.shape[0],8)\n  final_pose = np.zeros((length, 14, 2))\n\n  if t_beat[0] >3:\n    final_pose[t_beat[0]-3:t_beat[0]] = dance[:3]\n  else:\n    final_pose[:t_beat[0]] = dance[:t_beat[0]]\n  if t_beat[0]-3 > 0:\n    final_pose[:t_beat[0]-3] = dance[0]\n  for t in range(t_beat.shape[0]-1):\n    begin = int(t_beat[t])\n    end = int(t_beat[t+1])\n    interval = end-begin\n    if t==s_beat.shape[0]-1:\n      rest = min(final_pose.shape[0]-begin-1, dance.shape[0]-s_beat[t]-1)\n      break\n    if t+1 < s_beat.shape[0] and s_beat[t+1]<dance.shape[0]:\n      pose = get_pose(dance[s_beat[t]:s_beat[t+1]+1], interval+1)\n      if t==0 and begin>=3:\n        final_pose[begin-s_beat[t]:begin] = dance[:s_beat[t]]\n      final_pose[begin:end+1]=pose\n      rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t+1]-1)\n    else:\n      end = begin\n      if t+1 < s_beat.shape[0]:\n        rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t+1]-1)\n      else:\n        print(t_beat.shape, s_beat.shape, t)\n        rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t]-1)\n  if rest > 0:\n    if t+1 < s_beat.shape[0]:\n      final_pose[end+1:end+1+rest] = dance[s_beat[t+1]+1:s_beat[t+1]+1+rest]\n    else:\n      final_pose[end+1:end+1+rest] = dance[s_beat[t]+1:s_beat[t]+1+rest]\n\n  return final_pose\n\ndef get_pose(pose, n):\n  t_pose = np.zeros((n, 14, 2))\n  if n==11:\n    t_pose[0] = pose[0]\n    t_pose[1] = (pose[0]*1+pose[1]*4)/5\n    t_pose[2] = (pose[1]*2+pose[2]*3)/5\n    t_pose[3] = (pose[2]*3+pose[3]*2)/5\n    t_pose[4] = (pose[3]*4+pose[4]*1)/5\n    t_pose[5] = pose[4]\n    t_pose[6] = (pose[4]*1+pose[5]*4)/5\n    t_pose[7] = (pose[5]*2+pose[6]*3)/5\n    t_pose[8] = (pose[6]*3+pose[7]*2)/5\n    t_pose[9] = (pose[7]*4+pose[8]*1)/5\n    t_pose[10] = pose[8]\n  elif n==10:\n    t_pose[0] = pose[0]\n    t_pose[1] = (pose[0]*1+pose[1]*8)/9\n    t_pose[2] = (pose[1]*2+pose[2]*7)/9\n    t_pose[3] = (pose[2]*3+pose[3]*6)/9\n    t_pose[4] = (pose[3]*4+pose[4]*5)/9\n    t_pose[5] = (pose[4]*5+pose[5]*4)/9\n    t_pose[6] = (pose[5]*6+pose[6]*3)/9\n    t_pose[7] = (pose[6]*7+pose[7]*2)/9\n    t_pose[8] = (pose[7]*8+pose[8]*1)/9\n    t_pose[9] = pose[8]\n  elif n==12:\n    t_pose[0] = pose[0]\n    t_pose[1] = (pose[0]*3+pose[1]*8)/11\n    t_pose[2] = (pose[1]*6+pose[2]*5)/11\n    t_pose[3] = (pose[2]*9+pose[3]*2)/11\n    t_pose[4] = (pose[2]*1+pose[3]*10)/11\n    t_pose[5] = (pose[3]*4+pose[4]*7)/11\n    t_pose[6] = (pose[4]*7+pose[5]*4)/11\n    t_pose[7] = (pose[5]*10+pose[6]*1)/11\n    t_pose[8] = (pose[5]*2+pose[6]*9)/11\n    t_pose[9] = (pose[6]*5+pose[7]*6)/11\n    t_pose[10] = (pose[7]*8+pose[8]*3)/11\n    t_pose[11] = pose[8]\n  elif n==13:\n    t_pose[0] = pose[0]\n    t_pose[1] = (pose[0]*1+pose[1]*2)/3\n    t_pose[2] = (pose[1]*2+pose[2]*1)/3\n    t_pose[3] = pose[2]\n    t_pose[4] = (pose[2]*1+pose[3]*2)/3\n    t_pose[5] = (pose[3]*2+pose[4]*1)/3\n    t_pose[6] = pose[4]\n    t_pose[7] = (pose[4]*1+pose[5]*2)/3\n    t_pose[8] = (pose[5]*2+pose[6]*1)/3\n    t_pose[9] = pose[6]\n    t_pose[10] = (pose[6]*1+pose[7]*2)/3\n    t_pose[11] = (pose[7]*2+pose[8]*1)/3\n    t_pose[12] = pose[8]\n  elif n==14:\n    t_pose[0] = pose[0]\n    t_pose[1] = (pose[0]*5+pose[1]*8)/13\n    t_pose[2] = (pose[1]*10+pose[2]*3)/13\n    t_pose[3] = (pose[1]*2+pose[2]*11)/13\n    t_pose[4] = (pose[2]*7+pose[3]*6)/13\n    t_pose[5] = (pose[3]*12+pose[4]*1)/13\n    t_pose[6] = (pose[3]*4+pose[4]*9)/13\n    t_pose[7] = (pose[4]*9+pose[5]*4)/13\n    t_pose[8] = (pose[4]*12+pose[5]*1)/13\n    t_pose[9] = (pose[5]*6+pose[6]*7)/13\n    t_pose[10] = (pose[6]*11+pose[7]*2)/13\n    t_pose[11] = (pose[6]*3+pose[7]*10)/13\n    t_pose[12] = (pose[7]*8+pose[8]*5)/13\n    t_pose[13] = pose[8]\n  elif n==9:\n    t_pose = pose\n  elif n==8:\n    t_pose[0] = pose[0]\n    t_pose[1] = (pose[1]*6+pose[2]*1)/7\n    t_pose[2] = (pose[2]*5+pose[3]*2)/7\n    t_pose[3] = (pose[3]*4+pose[4]*3)/7\n    t_pose[4] = (pose[4]*3+pose[5]*4)/7\n    t_pose[5] = (pose[5]*2+pose[6]*5)/7\n    t_pose[6] = (pose[6]*1+pose[7]*6)/7\n    t_pose[7] = pose[8]\n  elif n==7:\n    t_pose[0] = pose[0]\n    t_pose[1] = (pose[1]*2+pose[2]*1)/3\n    t_pose[2] = (pose[2]*1+pose[3]*2)/3\n    t_pose[3] = pose[4]\n    t_pose[4] = (pose[5]*2+pose[6]*1)/3\n    t_pose[5] = (pose[6]*1+pose[7]*2)/3\n    t_pose[6] = pose[8]\n  elif n==6:\n    t_pose[0] = pose[0]\n    t_pose[1] = (pose[1]*2+pose[2]*3)/5\n    t_pose[2] = (pose[3]*4+pose[4]*1)/5\n    t_pose[3] = (pose[4]*1+pose[5]*4)/5\n    t_pose[4] = (pose[6]*3+pose[7]*2)/5\n    t_pose[5] = pose[8]\n  elif n<6:\n    t_pose[0] = pose[0]\n    t_pose[n-1] = pose[8]\n    for i in range(1,n-1):\n      t_pose[i] = pose[4]\n  elif n>14:\n    t_pose[0] = pose[0]\n    t_pose[n-1] = pose[8]\n    for i in range(1, n-1):\n      k = int(8/(n-1)*i)\n      t_pose[i] = t_pose[k]\n  else:\n    print('NOT IMPLEMENT {}'.format(n))\n\n  return t_pose\n"
  },
  {
    "path": "networks.py",
    "content": "# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.\n#\n# This work is made available\n# under the Nvidia Source Code License (1-way Commercial).\n# To view a copy of this license, visit\n# https://nvlabs.github.io/Dancing2Music/License.txt\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.utils.data\nfrom torch.autograd import Variable\n\nimport numpy as np\n\nif torch.cuda.is_available():\n  T = torch.cuda\nelse:\n  T = torch\n\n###########################################################\n##########\n##########         Stage 1: Movement\n##########\n###########################################################\nclass InitPose_Enc(nn.Module):\n  def __init__(self, pose_size, dim_z_init):\n    super(InitPose_Enc, self).__init__()\n    nf = 64\n    #nf = 32\n    self.enc = nn.Sequential(\n      nn.Linear(pose_size, nf),\n      nn.LayerNorm(nf),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nf, nf),\n      nn.LayerNorm(nf),\n      nn.LeakyReLU(0.2, inplace=True),\n    )\n    self.mean = nn.Sequential(\n      nn.Linear(nf,dim_z_init),\n    )\n    self.std = nn.Sequential(\n      nn.Linear(nf,dim_z_init),\n    )\n  def forward(self, pose):\n    enc = self.enc(pose)\n    return self.mean(enc), self.std(enc)\n\nclass InitPose_Dec(nn.Module):\n  def __init__(self, pose_size, dim_z_init):\n    super(InitPose_Dec, self).__init__()\n    nf = 64\n    #nf = dim_z_init\n    self.dec = nn.Sequential(\n      nn.Linear(dim_z_init, nf),\n      nn.LayerNorm(nf),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nf, nf),\n      nn.LayerNorm(nf),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nf,pose_size),\n    )\n  def forward(self, z_init):\n    return self.dec(z_init)\n\nclass Movement_Enc(nn.Module):\n  def __init__(self, pose_size, dim_z_movement, length, hidden_size, num_layers, bidirection=False):\n    super(Movement_Enc, self).__init__()\n    self.hidden_size = hidden_size\n    self.bidirection = bidirection\n    if bidirection:\n      self.num_dir = 2\n    else:\n      self.num_dir = 1\n    self.recurrent = nn.GRU(pose_size, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=bidirection)\n    self.init_h = nn.Parameter(torch.randn(num_layers*self.num_dir, 1, hidden_size).type(T.FloatTensor), requires_grad=True)\n    if bidirection:\n      self.mean = nn.Sequential(\n        nn.Linear(hidden_size*2,dim_z_movement),\n      )\n      self.std = nn.Sequential(\n        nn.Linear(hidden_size*2,dim_z_movement),\n      )\n    else:\n      '''\n      self.enc = nn.Sequential(\n        nn.Linear(hidden_size, hidden_size//2),\n        nn.LayerNorm(hidden_size//2),\n        nn.ReLU(inplace=True),\n      )\n      '''\n      self.mean = nn.Sequential(\n        nn.Linear(hidden_size,dim_z_movement),\n      )\n      self.std = nn.Sequential(\n        nn.Linear(hidden_size,dim_z_movement),\n      )\n  def forward(self, poses):\n    num_samples = poses.shape[0]\n    h_t = [self.init_h.repeat(1, num_samples, 1)]\n    output, hidden = self.recurrent(poses, h_t[0])\n    if self.bidirection:\n      output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1)\n    else:\n      output = output[:,-1,:]\n    #enc = self.enc(output)\n    #return self.mean(enc), self.std(enc)\n    return self.mean(output), self.std(output)\n\n  def getFeature(self, poses):\n    num_samples = poses.shape[0]\n    h_t = [self.init_h.repeat(1, num_samples, 1)]\n    output, hidden = self.recurrent(poses, h_t[0])\n    if self.bidirection:\n      output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1)\n    else:\n      output = output[:,-1,:]\n    return output\n\nclass StandardPose_Dec(nn.Module):\n  def __init__(self, pose_size, dim_z_init, dim_z_movement, length, hidden_size, num_layers):\n    super(StandardPose_Dec, self).__init__()\n    self.length = length\n    self.pose_size = pose_size\n    self.hidden_size = hidden_size\n    self.num_layers = num_layers\n    #dim_z_init=0\n    '''\n    self.z2init = nn.Sequential(\n      nn.Linear(dim_z_init+dim_z_movement, hidden_size),\n      nn.LayerNorm(hidden_size),\n      nn.ReLU(True),\n      nn.Linear(hidden_size, num_layers*hidden_size)\n    )\n    '''\n    self.z2init = nn.Sequential(\n      nn.Linear(dim_z_init+dim_z_movement, num_layers*hidden_size)\n    )\n    self.recurrent = nn.GRU(dim_z_movement, hidden_size, num_layers=num_layers, batch_first=True)\n    self.pose_g = nn.Sequential(\n      nn.Linear(hidden_size, hidden_size),\n      nn.LayerNorm(hidden_size),\n      nn.ReLU(True),\n      nn.Linear(hidden_size, pose_size)\n    )\n\n  def forward(self, z_init, z_movement):\n    h_init = self.z2init(torch.cat((z_init, z_movement), 1))\n    #h_init = self.z2init(z_movement)\n    h_init = h_init.view(self.num_layers, h_init.size(0), self.hidden_size)\n    z_movements = z_movement.view(z_movement.size(0),1,z_movement.size(1)).repeat(1, self.length, 1)\n    z_m_t, _ = self.recurrent(z_movements, h_init)\n    z_m = z_m_t.contiguous().view(-1, self.hidden_size)\n    poses = self.pose_g(z_m)\n    poses = poses.view(z_movement.shape[0], self.length, self.pose_size)\n    return poses\n\nclass StandardPose_Dis(nn.Module):\n  def __init__(self, pose_size, length):\n    super(StandardPose_Dis, self).__init__()\n    self.pose_size = pose_size\n    self.length = length\n    nd = 1024\n    self.main = nn.Sequential(\n      nn.Linear(length*pose_size, nd),\n      nn.LayerNorm(nd),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nd,nd//2),\n      nn.LayerNorm(nd//2),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nd//2,nd//4),\n      nn.LayerNorm(nd//4),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nd//4, 1)\n    )\n  def forward(self, pose_seq):\n    pose_seq = pose_seq.view(-1, self.pose_size*self.length)\n    return self.main(pose_seq).squeeze()\n\n###########################################################\n##########\n##########         Stage 2: Dance\n##########\n###########################################################\nclass Dance_Enc(nn.Module):\n  def __init__(self, dim_z_movement, dim_z_dance, hidden_size, num_layers, bidirection=False):\n    super(Dance_Enc, self).__init__()\n    self.hidden_size = hidden_size\n    self.bidirection = bidirection\n    if bidirection:\n      self.num_dir = 2\n    else:\n      self.num_dir = 1\n    self.recurrent = nn.GRU(2*dim_z_movement, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=bidirection)\n    self.init_h = nn.Parameter(torch.randn(num_layers*self.num_dir, 1, hidden_size).type(T.FloatTensor), requires_grad=True)\n    if bidirection:\n      self.mean = nn.Sequential(\n        nn.Linear(hidden_size*2,dim_z_dance),\n      )\n      self.std = nn.Sequential(\n        nn.Linear(hidden_size*2,dim_z_dance),\n      )\n    else:\n      self.mean = nn.Sequential(\n        nn.Linear(hidden_size,dim_z_dance),\n      )\n      self.std = nn.Sequential(\n        nn.Linear(hidden_size,dim_z_dance),\n      )\n  def forward(self, movements_mean, movements_std):\n    movements = torch.cat((movements_mean, movements_std),2)\n    num_samples = movements.shape[0]\n    h_t = [self.init_h.repeat(1, num_samples, 1)]\n    output, hidden = self.recurrent(movements, h_t[0])\n    if self.bidirection:\n      output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1)\n    else:\n      output = output[:,-1,:]\n    return self.mean(output), self.std(output)\n\nclass Dance_Dec(nn.Module):\n  def __init__(self, dim_z_dance, dim_z_movement, hidden_size, num_layers):\n    super(Dance_Dec, self).__init__()\n    #self.length = length\n    self.num_layers = num_layers\n    self.hidden_size = hidden_size\n    self.dim_z_movement = dim_z_movement\n    #dim_z_init=0\n    '''\n    self.z2init = nn.Sequential(\n      nn.Linear(dim_z_init+dim_z_movement, hidden_size),\n      nn.LayerNorm(hidden_size),\n      nn.ReLU(True),\n      nn.Linear(hidden_size, num_layers*hidden_size)\n    )\n    '''\n    self.z2init = nn.Sequential(\n      nn.Linear(dim_z_dance, num_layers*hidden_size)\n    )\n    self.recurrent = nn.GRU(dim_z_dance, hidden_size, num_layers=num_layers, batch_first=True)\n    self.movement_g = nn.Sequential(\n      nn.Linear(hidden_size, hidden_size),\n      nn.LayerNorm(hidden_size),\n      nn.ReLU(True),\n      #nn.Linear(hidden_size, dim_z_movement)\n    )\n    self.mean = nn.Sequential(\n      nn.Linear(hidden_size,dim_z_movement),\n    )\n    self.std = nn.Sequential(\n      nn.Linear(hidden_size,dim_z_movement),\n    )\n\n  def forward(self, z_dance, length=3):\n    h_init = self.z2init(z_dance)\n    h_init = h_init.view(self.num_layers, h_init.size(0), self.hidden_size)\n    z_dance = z_dance.view(z_dance.size(0),1,z_dance.size(1)).repeat(1, length, 1)\n    z_d_t, _ = self.recurrent(z_dance, h_init)\n    z_d = z_d_t.contiguous().view(-1, self.hidden_size)\n    z_movement = self.movement_g(z_d)\n    z_movement_mean, z_movement_std = self.mean(z_movement), self.std(z_movement)\n    #z_movement = z_movement.view(z_dance.shape[0], length, self.dim_z_movement)\n    return z_movement_mean, z_movement_std\n\n\nclass DanceAud_Dis2(nn.Module):\n  def __init__(self, aud_size, dim_z_movement, length=3):\n    super(DanceAud_Dis2, self).__init__()\n    self.aud_size = aud_size\n    self.dim_z_movement = dim_z_movement\n    self.length = length\n    nd = 1024\n    self.movementd = nn.Sequential(\n      nn.Linear(dim_z_movement*2*length, nd),\n      nn.LayerNorm(nd),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nd,nd//2),\n      nn.LayerNorm(nd//2),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nd//2,nd//4),\n      nn.LayerNorm(nd//4),\n      nn.LeakyReLU(0.2, inplace=True),\n      #nn.Linear(nd//4, 30),\n      nn.Linear(nd//4, 30),\n    )\n\n    self.audd = nn.Sequential(\n      nn.Linear(aud_size, 30),\n      nn.LayerNorm(30),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(30, 30),\n      nn.LayerNorm(30),\n      nn.LeakyReLU(0.2, inplace=True),\n    )\n    self.jointd = nn.Sequential(\n      nn.Linear(60, 1)\n    )\n\n  def forward(self, movements, aud):\n    if len(movements.shape) == 3:\n      movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2])\n    m = self.movementd(movements)\n    a = self.audd(aud)\n    ma = torch.cat((m,a),1)\n\n    return self.jointd(ma).squeeze(), None\n\nclass DanceAud_Dis(nn.Module):\n  def __init__(self, aud_size, dim_z_movement, length=3):\n    super(DanceAud_Dis, self).__init__()\n    self.aud_size = aud_size\n    self.dim_z_movement = dim_z_movement\n    self.length = length\n    nd = 1024\n    self.movementd = nn.Sequential(\n      #nn.Linear(dim_z_movement*3, nd),\n      nn.Linear(dim_z_movement*2, nd),\n      nn.LayerNorm(nd),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nd,nd//2),\n      nn.LayerNorm(nd//2),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nd//2,nd//4),\n      nn.LayerNorm(nd//4),\n      nn.LeakyReLU(0.2, inplace=True),\n      #nn.Linear(nd//4, 30),\n      nn.Linear(nd//4, 30),\n    )\n\n\n  def forward(self, movements, aud):\n    #movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2])\n    m = self.movementd(movements)\n    return m.squeeze()\n    #a = self.audd(aud)\n    #ma = torch.cat((m,a),1)\n\n    #return self.jointd(ma).squeeze()\n\nclass DanceAud_InfoDis(nn.Module):\n  def __init__(self, aud_size, dim_z_movement, length):\n    super(DanceAud_InfoDis, self).__init__()\n    self.aud_size = aud_size\n    self.dim_z_movement = dim_z_movement\n    self.length = length\n    nd = 1024\n\n    self.movementd = nn.Sequential(\n      nn.Linear(dim_z_movement*6, nd*2),\n      nn.LayerNorm(nd*2),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nd*2, nd),\n      nn.LayerNorm(nd),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nd,nd//2),\n      nn.LayerNorm(nd//2),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nd//2,nd//4),\n      nn.LayerNorm(nd//4),\n      nn.LeakyReLU(0.2, inplace=True),\n    )\n\n    self.dis = nn.Sequential(\n      nn.Linear(nd//4, 1)\n    )\n    self.reg = nn.Sequential(\n      nn.Linear(nd//4, aud_size)\n    )\n\n  def forward(self, movements, aud):\n    movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2])\n    m = self.movementd(movements)\n    return self.dis(m).squeeze(), self.reg(m)\n\nclass Dance2Style(nn.Module):\n  def __init__(self, dim_z_dance, aud_size):\n    super(Dance2Style, self).__init__()\n    self.aud_size = aud_size\n    self.dim_z_dance = dim_z_dance\n    nd = 512\n    self.main = nn.Sequential(\n      nn.Linear(dim_z_dance, nd),\n      nn.LayerNorm(nd),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nd, nd//2),\n      nn.LayerNorm(nd//2),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nd//2, nd//4),\n      nn.LayerNorm(nd//4),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nd//4, aud_size),\n    )\n  def forward(self, zdance):\n    return self.main(zdance)\n\n###########################################################\n##########\n##########         Audio\n##########\n###########################################################\nclass AudioClassifier_rnn(nn.Module):\n  def __init__(self, dim_z_motion, hidden_size, pose_size, cls, num_layers=1, h_init=2):\n    super(AudioClassifier_rnn, self).__init__()\n    self.dim_z_motion = dim_z_motion\n    self.hidden_size = hidden_size\n    self.pose_size = pose_size\n    self.h_init = h_init\n    self.num_layers = num_layers\n\n    self.init_h = nn.Parameter(torch.randn(1, 1, self.hidden_size).type(T.FloatTensor), requires_grad=True)\n    self.recurrent = nn.GRU(pose_size, hidden_size, num_layers=num_layers, batch_first=True)\n    self.classifier = nn.Sequential(\n      #nn.Dropout(p=0.2),\n      nn.Linear(hidden_size, hidden_size),\n      nn.ReLU(True),\n      #nn.Dropout(p=0.2),\n      nn.Linear(hidden_size, cls)\n    )\n  def forward(self, poses):\n    hidden, _ = self.recurrent(poses, self.init_h.repeat(1, poses.shape[0], 1))\n    last_hidden = hidden[:,-1,:]\n    cls = self.classifier(last_hidden)\n    return cls\n  def get_style(self, auds):\n    hidden, _ = self.recurrent(auds, self.init_h.repeat(1, auds.shape[0], 1))\n    last_hidden = hidden[:,-1,:]\n    return last_hidden\n\n\nclass Audstyle_Enc(nn.Module):\n  def __init__(self, aud_size, dim_z, dim_noise=30):\n    super(Audstyle_Enc, self).__init__()\n    self.dim_noise = dim_noise\n    nf = 64\n    #nf = 32\n    self.enc = nn.Sequential(\n      nn.Linear(aud_size+dim_noise, nf),\n      nn.LayerNorm(nf),\n      nn.LeakyReLU(0.2, inplace=True),\n      nn.Linear(nf, nf*2),\n      nn.LayerNorm(nf*2),\n      nn.LeakyReLU(0.2, inplace=True),\n    )\n    self.mean = nn.Sequential(\n      nn.Linear(nf*2,dim_z),\n    )\n    self.std = nn.Sequential(\n      nn.Linear(nf*2,dim_z),\n    )\n  def forward(self, aud):\n    noise = torch.randn(aud.shape[0], self.dim_noise).cuda()\n    y = torch.cat((aud, noise), 1)\n    enc = self.enc(y)\n    return self.mean(enc), self.std(enc)\n"
  },
  {
    "path": "options.py",
    "content": "# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.\n#\n# This work is made available\n# under the Nvidia Source Code License (1-way Commercial).\n# To view a copy of this license, visit\n# https://nvlabs.github.io/Dancing2Music/License.txt\n\nimport argparse\n\n\nclass DecompOptions():\n  def __init__(self):\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--name', default=None)\n\n    parser.add_argument('--log_interval', type=int, default=50)\n    parser.add_argument('--log_dir', default='./logs')\n    parser.add_argument('--snapshot_ep', type=int, default=1)\n    parser.add_argument('--snapshot_dir', default='./snapshot')\n    parser.add_argument('--data_dir', default='./data')\n\n    # Model architecture\n    parser.add_argument('--pose_size', type=int, default=28)\n    parser.add_argument('--dim_z_init', type=int, default=10)\n    parser.add_argument('--dim_z_movement', type=int, default=512)\n    parser.add_argument('--stdp_length', type=int, default=32)\n    parser.add_argument('--movement_enc_bidirection', type=int, default=1)\n    parser.add_argument('--movement_enc_hidden_size', type=int, default=1024)\n    parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024)\n    parser.add_argument('--movement_enc_num_layers', type=int, default=1)\n    parser.add_argument('--stdp_dec_num_layers', type=int, default=1)\n    # Training\n    parser.add_argument('--lr', type=float, default=2e-4)\n    parser.add_argument('--batch_size', type=int, default=256)\n    parser.add_argument('--num_epochs', type=int, default=1000)\n    parser.add_argument('--latent_dropout', type=float, default=0.3)\n    parser.add_argument('--lambda_kl', type=float, default=0.01)\n    parser.add_argument('--lambda_initp_recon', type=float, default=1)\n    parser.add_argument('--lambda_initp_consistency', type=float, default=1)\n    parser.add_argument('--lambda_stdp_recon', type=float, default=1)\n    parser.add_argument('--lambda_dist_z_movement', type=float, default=1)\n    # Others\n    parser.add_argument('--num_workers', type=int,  default=4)\n    parser.add_argument('--resume', default=None)\n    parser.add_argument('--dataset', type=int, default=0)\n    parser.add_argument('--tolerance', action='store_true')\n\n    self.parser = parser\n\n  def parse(self):\n    self.opt = self.parser.parse_args()\n    args = vars(self.opt)\n    return self.opt\n\nclass CompOptions():\n  def __init__(self):\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--name', default=None)\n\n    parser.add_argument('--log_interval', type=int, default=50)\n    parser.add_argument('--log_dir', default='./logs')\n    parser.add_argument('--snapshot_ep', type=int, default=1)\n    parser.add_argument('--snapshot_dir', default='./snapshot')\n    parser.add_argument('--data_dir', default='./data')\n    # Network architecture\n    parser.add_argument('--pose_size', type=int, default=28)\n    parser.add_argument('--aud_style_size', type=int, default=30)\n    parser.add_argument('--dim_z_init', type=int, default=10)\n    parser.add_argument('--dim_z_movement', type=int, default=512)\n    parser.add_argument('--dim_z_dance', type=int, default=512)\n    parser.add_argument('--stdp_length', type=int, default=32)\n    parser.add_argument('--movement_enc_bidirection', type=int, default=1)\n    parser.add_argument('--movement_enc_hidden_size', type=int, default=1024)\n    parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024)\n    parser.add_argument('--movement_enc_num_layers', type=int, default=1)\n    parser.add_argument('--stdp_dec_num_layers', type=int, default=1)\n    parser.add_argument('--dance_enc_bidirection', type=int, default=0)\n    parser.add_argument('--dance_enc_hidden_size', type=int, default=1024)\n    parser.add_argument('--dance_enc_num_layers', type=int, default=1)\n    parser.add_argument('--dance_dec_hidden_size', type=int, default=1024)\n    parser.add_argument('--dance_dec_num_layers', type=int, default=1)\n    # Training\n    parser.add_argument('--lr', type=float, default=2e-4)\n    parser.add_argument('--batch_size', type=int, default=256)\n    parser.add_argument('--num_epochs', type=int, default=1500)\n    parser.add_argument('--latent_dropout', type=float, default=0.3)\n    parser.add_argument('--lambda_kl', type=float, default=0.01)\n    parser.add_argument('--lambda_kl_dance', type=float, default=0.01)\n    parser.add_argument('--lambda_gan', type=float, default=1)\n    parser.add_argument('--lambda_zmovements_recon', type=float, default=1)\n    parser.add_argument('--lambda_stdpSeq_recon', type=float, default=10)\n    parser.add_argument('--lambda_dist_z_movement', type=float, default=1)\n    # Other\n    parser.add_argument('--num_workers', type=int,  default=4)\n    parser.add_argument('--decomp_snapshot', required=True)\n    parser.add_argument('--neta_snapshot', default='./data/stats/aud_3cls.ckpt')\n    parser.add_argument('--resume', default=None)\n    parser.add_argument('--dataset', type=int, default=2)\n    self.parser = parser\n\n  def parse(self):\n    self.opt = self.parser.parse_args()\n    args = vars(self.opt)\n    return self.opt\n\nclass TestOptions():\n  def __init__(self):\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--name', default=None)\n\n    parser.add_argument('--log_interval', type=int, default=50)\n    parser.add_argument('--log_dir', default='./logs')\n    parser.add_argument('--snapshot_ep', type=int, default=1)\n    parser.add_argument('--snapshot_dir', default='./snapshot')\n    parser.add_argument('--data_dir', default='./data')\n    # Network architecture\n    parser.add_argument('--pose_size', type=int, default=28)\n    parser.add_argument('--aud_style_size', type=int, default=30)\n    parser.add_argument('--dim_z_init', type=int, default=10)\n    parser.add_argument('--dim_z_movement', type=int, default=512)\n    parser.add_argument('--dim_z_dance', type=int, default=512)\n    parser.add_argument('--stdp_length', type=int, default=32)\n    parser.add_argument('--movement_enc_bidirection', type=int, default=1)\n    parser.add_argument('--movement_enc_hidden_size', type=int, default=1024)\n    parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024)\n    parser.add_argument('--movement_enc_num_layers', type=int, default=1)\n    parser.add_argument('--stdp_dec_num_layers', type=int, default=1)\n    parser.add_argument('--dance_enc_bidirection', type=int, default=0)\n    parser.add_argument('--dance_enc_hidden_size', type=int, default=1024)\n    parser.add_argument('--dance_enc_num_layers', type=int, default=1)\n    parser.add_argument('--dance_dec_hidden_size', type=int, default=1024)\n    parser.add_argument('--dance_dec_num_layers', type=int, default=1)\n    # Training\n    parser.add_argument('--lr', type=float, default=2e-4)\n    parser.add_argument('--batch_size', type=int, default=256)\n    parser.add_argument('--num_epochs', type=int, default=1500)\n    parser.add_argument('--latent_dropout', type=float, default=0.3)\n    parser.add_argument('--lambda_kl', type=float, default=0.01)\n    parser.add_argument('--lambda_kl_dance', type=float, default=0.01)\n    parser.add_argument('--lambda_gan', type=float, default=1)\n    parser.add_argument('--lambda_zmovements_recon', type=float, default=1)\n    parser.add_argument('--lambda_stdpSeq_recon', type=float, default=10)\n    parser.add_argument('--lambda_dist_z_movement', type=float, default=1)\n    # Other\n    parser.add_argument('--num_workers', type=int,  default=4)\n    parser.add_argument('--decomp_snapshot', required=True)\n    parser.add_argument('--comp_snapshot', required=True)\n    parser.add_argument('--neta_snapshot', default='./data/stats/aud_3cls.ckpt')\n    parser.add_argument('--dataset', type=int, default=2)\n    parser.add_argument('--thr', type=int, default=50)\n    parser.add_argument('--aud_path', type=str, required=True)\n    parser.add_argument('--modulate', action='store_true')\n    parser.add_argument('--out_file', type=str, default='demo/out.mp4')\n    parser.add_argument('--out_dir', type=str, default='demo/out_frame')\n    self.parser = parser\n\n  def parse(self):\n    self.opt = self.parser.parse_args()\n    args = vars(self.opt)\n    return self.opt\n"
  },
  {
    "path": "test.py",
    "content": "# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.\n#\n# This work is made available\n# under the Nvidia Source Code License (1-way Commercial).\n# To view a copy of this license, visit\n# https://nvlabs.github.io/Dancing2Music/License.txt\n\nimport os\nimport argparse\nimport functools\n\nimport torch\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms\n\nfrom model_comp import *\nfrom networks import *\nfrom options import CompOptions\nfrom data import get_loader\n\n\nif __name__ == \"__main__\":\n  parser = CompOptions()\n  args = parser.parse()\n  #### Pretrain network from Decomp\n  initp_enc, stdp_dec, movement_enc = loadDecompModel(args)\n\n  #### Comp network\n  dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = loadCompModel(args)\n\n  mean_pose=np.load('../onbeat/all_onbeat_mean.npy')\n  std_pose=np.load('../onbeat/all_onbeat_std.npy')\n  mean_aud=np.load('../onbeat/all_aud_mean.npy')\n  std_aud=np.load('../onbeat/all_aud_std.npy')\n\n\ndef loadDecompModel(args):\n  initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)\n  stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,\n                          hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)\n  movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,\n                             hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))\n  checkpoint = torch.load(args.decomp_snapshot)\n  initp_enc.load_state_dict(checkpoint['initp_enc'])\n  stdp_dec.load_state_dict(checkpoint['stdp_dec'])\n  movement_enc.load_state_dict(checkpoint['movement_enc'])\n  return initp_enc, stdp_dec, movement_enc\n\ndef loadCompModel(args):\n  dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,\n                               hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1))\n  dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,\n                               hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers)\n  audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance)\n  dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance)\n  danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3)\n  zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1)\n  checkpoint = torch.load(args.resume)\n  dance_enc.load_state_dict(checkpoint['dance_enc'])\n  dance_dec.load_state_dict(checkpoint['dance_dec'])\n  audstyle_enc.load_state_dict(checkpoint['audstyle_enc'])\n  dance_reg.load_state_dict(checkpoint['dance_reg'])\n  danceAud_dis.load_state_dict(checkpoint['danceAud_dis'])\n  zdance_dis.load_state_dict(checkpoint['zdance_dis'])\n\n  checkpoint2 = torch.load(args.neta_snapshot)\n  neta_cls = AudioClassifier_rnn(10,30,28,cls=3)\n  neta_cls.load_state_dict(checkpoint2)\n\n  return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls\n"
  },
  {
    "path": "train_comp.py",
    "content": "# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.\n#\n# This work is made available\n# under the Nvidia Source Code License (1-way Commercial).\n# To view a copy of this license, visit\n# https://nvlabs.github.io/Dancing2Music/License.txt\n\nimport os\nimport argparse\nimport functools\n\nimport torch\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms\n\nfrom model_comp import *\nfrom networks import *\nfrom options import CompOptions\nfrom data import get_loader\n\ndef loadDecompModel(args):\n  initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)\n  stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,\n                          hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)\n  movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,\n                             hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))\n  checkpoint = torch.load(args.decomp_snapshot)\n  initp_enc.load_state_dict(checkpoint['initp_enc'])\n  stdp_dec.load_state_dict(checkpoint['stdp_dec'])\n  movement_enc.load_state_dict(checkpoint['movement_enc'])\n  return initp_enc, stdp_dec, movement_enc\n\ndef getCompNetworks(args):\n  dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,\n                               hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1))\n  dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,\n                               hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers)\n  audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance)\n  dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance)\n  danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3)\n  zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1)\n\n  checkpoint2 = torch.load(args.neta_snapshot)\n  neta_cls = AudioClassifier_rnn(10,30,28,cls=3)\n  neta_cls.load_state_dict(checkpoint2)\n\n  return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls\n\nif __name__ == \"__main__\":\n  parser = CompOptions()\n  args = parser.parse()\n\n  args.train = True\n\n  if args.name is None:\n    args.name = 'Comp'\n\n  args.log_dir = os.path.join(args.log_dir, args.name)\n  if not os.path.exists(args.log_dir):\n    os.mkdir(args.log_dir)\n  args.snapshot_dir = os.path.join(args.snapshot_dir, args.name)\n  if not os.path.exists(args.snapshot_dir):\n    os.mkdir(args.snapshot_dir)\n\n  data_loader = get_loader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, dataset=args.dataset, data_dir=args.data_dir)\n\n  #### Pretrain network from Decomp\n  initp_enc, stdp_dec, movement_enc = loadDecompModel(args)\n\n  #### Comp network\n  dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = getCompNetworks(args)\n\n\n  trainer = Trainer_Comp(data_loader,\n                    movement_enc = movement_enc,\n                    initp_enc = initp_enc,\n                    stdp_dec = stdp_dec,\n                    dance_enc = dance_enc,\n                    dance_dec = dance_dec,\n                    danceAud_dis = danceAud_dis,\n                    zdance_dis = zdance_dis,\n                    aud_enc=neta_cls,\n                    audstyle_enc=audstyle_enc,\n                    dance_reg=dance_reg,\n                    args = args\n                    )\n\n  if not args.resume is None:\n    ep, it = trainer.resume(args.resume, True)\n  else:\n    ep, it = 0, 0\n  trainer.train(ep, it)\n\n"
  },
  {
    "path": "train_decomp.py",
    "content": "# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.\n#\n# This work is made available\n# under the Nvidia Source Code License (1-way Commercial).\n# To view a copy of this license, visit\n# https://nvlabs.github.io/Dancing2Music/License.txt\n\nimport os\nimport argparse\nimport functools\n\nimport torch\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms\n\nfrom model_decomp import *\nfrom networks import *\nfrom options import DecompOptions\nfrom data import get_loader\n\ndef getDecompNetworks(args):\n  initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)\n  initp_dec = InitPose_Dec(pose_size=args.pose_size, dim_z_init=args.dim_z_init)\n  movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,\n                              hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))\n  stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,\n                             hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)\n  return initp_enc, initp_dec, movement_enc, stdp_dec\n\nif __name__ == \"__main__\":\n  parser = DecompOptions()\n  args = parser.parse()\n\n  args.train = True\n\n  if args.name is None:\n    args.name = 'Decomp'\n\n  args.log_dir = os.path.join(args.log_dir, args.name)\n  if not os.path.exists(args.log_dir):\n    os.mkdir(args.log_dir)\n  args.snapshot_dir = os.path.join(args.snapshot_dir, args.name)\n  if not os.path.exists(args.snapshot_dir):\n    os.mkdir(args.snapshot_dir)\n\n  data_loader = get_loader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, dataset=args.dataset, data_dir=args.data_dir, tolerance=args.tolerance)\n\n  initp_enc, initp_dec, movement_enc, stdp_dec = getDecompNetworks(args)\n\n  trainer = Trainer_Decomp(data_loader,\n                    initp_enc = initp_enc,\n                    initp_dec = initp_dec,\n                    movement_enc = movement_enc,\n                    stdp_dec = stdp_dec,\n                    args = args\n                    )\n\n  if not args.resume is None:\n    ep, it = trainer.resume(args.resume, False)\n  else:\n    ep, it = 0, 0\n\n  trainer.train(ep=ep, it=it)\n\n"
  },
  {
    "path": "utils.py",
    "content": "# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.\n#\n# This work is made available\n# under the Nvidia Source Code License (1-way Commercial).\n# To view a copy of this license, visit\n# https://nvlabs.github.io/Dancing2Music/License.txt\n\nimport numpy as np\nimport pickle\nimport cv2\nimport math\nimport os\nimport random\nimport tensorflow as tf\n\nclass Logger(object):\n  def __init__(self, log_dir):\n    self.writer = tf.summary.FileWriter(log_dir)\n\n  def scalar_summary(self, tag, value, step):\n    summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])\n    self.writer.add_summary(summary, step)\n\ndef vis(poses, outdir, aud=None):\n  colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \\\n          [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \\\n          [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]\n\n  # find connection in the specified sequence, center 29 is in the position 15\n  limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \\\n           [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \\\n           [1,16], [16,18], [3,17], [6,18]]\n\n  neglect = [14,15,16,17]\n\n  for t in range(poses.shape[0]):\n    #break\n    canvas = np.ones((256,500,3), np.uint8)*255\n\n    thisPeak = poses[t]\n    for i in range(18):\n      if i in neglect:\n        continue\n      if thisPeak[i,0] == -1:\n        continue\n      cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1)\n\n    for i in range(17):\n      limbid = np.array(limbSeq[i])-1\n      if limbid[0] in neglect or limbid[1] in neglect:\n        continue\n      X = thisPeak[[limbid[0],limbid[1]], 1]\n      Y = thisPeak[[limbid[0],limbid[1]], 0]\n      if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1:\n        continue\n      stickwidth = 4\n      cur_canvas = canvas.copy()\n      mX = np.mean(X)\n      mY = np.mean(Y)\n      #print(X, Y, limbid)\n      length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n      angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))\n      polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1)\n      #print(i, n, int(mY), int(mX), limbid, X, Y)\n      cv2.fillConvexPoly(cur_canvas, polygon, colors[i])\n      canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)\n    if aud is not None:\n      if aud[:,t] == 1:\n        cv2.circle(canvas, (30, 30), 20, (0,0,255), -1)\n        #canvas = cv2.copyMakeBorder(canvas,10,10,10,10,cv2.BORDER_CONSTANT,value=[255,0,0])\n    cv2.imwrite(os.path.join(outdir, 'frame{0:03d}.png'.format(t)),canvas)\n\ndef vis2(poses, outdir, fibeat):\n  colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \\\n          [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \\\n          [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]\n\n  # find connection in the specified sequence, center 29 is in the position 15\n  limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \\\n           [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \\\n           [1,16], [16,18], [3,17], [6,18]]\n\n\n  neglect = [14,15,16,17]\n\n  ibeat = cv2.imread(fibeat);\n  ibeat = cv2.resize(ibeat, (500,200))\n\n  for t in range(poses.shape[0]):\n    subibeat = ibeat.copy()\n    canvas = np.ones((256+200,500,3), np.uint8)*255\n    canvas[256:,:,:] = subibeat\n\n    overlay = canvas.copy()\n    cv2.rectangle(overlay, (int(500/poses.shape[0]*(t+1)),256),(500,256+200), (100,100,100), -1)\n    cv2.addWeighted(overlay, 0.4, canvas, 1-0.4, 0, canvas)\n    thisPeak = poses[t]\n    for i in range(18):\n      if i in neglect:\n        continue\n      if thisPeak[i,0] == -1:\n        continue\n      cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1)\n\n    for i in range(17):\n      limbid = np.array(limbSeq[i])-1\n      if limbid[0] in neglect or limbid[1] in neglect:\n        continue\n      X = thisPeak[[limbid[0],limbid[1]], 1]\n      Y = thisPeak[[limbid[0],limbid[1]], 0]\n      if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1:\n        continue\n      stickwidth = 4\n      cur_canvas = canvas.copy()\n      mX = np.mean(X)\n      mY = np.mean(Y)\n      #print(X, Y, limbid)\n      length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n      angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))\n      polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1)\n      #print(i, n, int(mY), int(mX), limbid, X, Y)\n      cv2.fillConvexPoly(cur_canvas, polygon, colors[i])\n      canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)\n    cv2.imwrite(os.path.join(outdir, 'frame{0:03d}.png'.format(t)),canvas)\n\ndef vis_single(pose, outfile):\n  colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \\\n          [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \\\n          [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]\n\n  # find connection in the specified sequence, center 29 is in the position 15\n  limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \\\n           [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \\\n           [1,16], [16,18], [3,17], [6,18]]\n\n  neglect = [14,15,16,17]\n\n  for t in range(1):\n    #break\n    canvas = np.ones((256,500,3), np.uint8)*255\n\n    thisPeak = pose\n    for i in range(18):\n      if i in neglect:\n        continue\n      if thisPeak[i,0] == -1:\n        continue\n      cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1)\n\n    for i in range(17):\n      limbid = np.array(limbSeq[i])-1\n      if limbid[0] in neglect or limbid[1] in neglect:\n        continue\n      X = thisPeak[[limbid[0],limbid[1]], 1]\n      Y = thisPeak[[limbid[0],limbid[1]], 0]\n      if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1:\n        continue\n      stickwidth = 4\n      cur_canvas = canvas.copy()\n      mX = np.mean(X)\n      mY = np.mean(Y)\n      length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5\n      angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))\n      polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1)\n      cv2.fillConvexPoly(cur_canvas, polygon, colors[i])\n      canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)\n    cv2.imwrite(outfile,canvas)\n"
  }
]