Repository: NVlabs/Dancing2Music
Branch: master
Commit: 7ff1d95f9f3d
Files: 13
Total size: 93.6 KB
Directory structure:
gitextract_2r5pf7ve/
├── License.txt
├── README.md
├── data.py
├── demo.py
├── model_comp.py
├── model_decomp.py
├── modulate.py
├── networks.py
├── options.py
├── test.py
├── train_comp.py
├── train_decomp.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: License.txt
================================================
Nvidia Source Code License-NC
1. Definitions
“Licensor” means any person or entity that distributes its Work.
“Software” means the original work of authorship made available under this License.
“Work” means the Software and any additions to or derivative works of the Software that are made available under this License.
“Nvidia Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by Nvidia or its affiliates.
The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.
2. License Grants
2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
3. Limitations
3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. The Work or derivative works thereof may be used or intended for use by Nvidia or its affiliates commercially or non-commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grants in Sections 2.1 and 2.2) will terminate immediately.
3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License.
3.6 Termination. If you violate any term of this License, then your rights under this License (including the grants in Sections 2.1 and 2.2) will terminate immediately.
4. Disclaimer of Warranty.
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
5. Limitation of Liability.
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
================================================
FILE: README.md
================================================


## Dancing to Music
PyTorch implementation of the cross-modality generative model that synthesizes dance from music.
### Paper
[Hsin-Ying Lee](http://vllab.ucmerced.edu/hylee/), [Xiaodong Yang](https://xiaodongyang.org/), [Ming-Yu Liu](http://mingyuliu.net/), [Ting-Chun Wang](https://tcwang0509.github.io/), [Yu-Ding Lu](https://jonlu0602.github.io/), [Ming-Hsuan Yang](https://faculty.ucmerced.edu/mhyang/), [Jan Kautz](http://jankautz.com/)
Dancing to Music
Neural Information Processing Systems (**NeurIPS**) 2019
[[Paper]](https://arxiv.org/abs/1911.02001) [[YouTube]](https://youtu.be/-e9USqfwZ4A) [[Project]](http://vllab.ucmerced.edu/hylee/Dancing2Music/script.txt) [[Blog]](https://news.developer.nvidia.com/nvidia-dance-to-music-neurips/) [[Supp]](http://xiaodongyang.org/publications/papers/dance2music-supp-neurips19.pdf)
### Example Videos
- Beat-Matching
1st row: generated dance sequences, 2nd row: music beats, 3rd row: kinematic beats
- Multimodality
Generate various dance sequences with the same music and the same initial pose.
- Long-Term Generation
Seamlessly generate a dance sequence of arbitrary length.
- Photo-Realistic Videos
Map generated dance sequences to photo-realistic videos.
## Train Decomposition
```
python train_decomp.py --name Decomp
```
## Train Composition
```
python train_comp.py --name Comp --decomp_snapshot DECOMP_SNAPSHOT
```
## Demo
```
python demo.py --decomp_snapshot DECOMP_SNAPSHOT --comp_snapshot COMP_SNAPSHOT --aud_path AUD_PATH --out_file OUT_FILE --out_dir OUT_DIR --thr THR
```
- Flags
  - `aud_path`: path to the input `.wav` file
  - `out_file`: path of the output `.mp4` file
  - `out_dir`: directory for the output frames
  - `thr`: motion-magnitude threshold; generated units whose total frame-to-frame motion falls below it are re-sampled
  - `modulate`: whether to apply beat warping to the generated sequence
- Example
```
python demo.py --decomp_snapshot snapshot/Stage1.ckpt --comp_snapshot snapshot/Stage2.ckpt --aud_path demo/demo.wav --out_file demo/out.mp4 --out_dir demo/out_frame
```
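
For reference, the beat tracking that `demo.py` runs on `aud_path` before synthesis is essentially the following (a minimal sketch using librosa, with its default hop length assumed):
```python
import librosa
import numpy as np

# Track beats on the onset-strength envelope, then convert the beat
# positions to frame indices at 15 fps (the pose frame rate).
y, sr = librosa.load('demo/demo.wav')
onset_env = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median)
tempo, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr)
beat_frames = np.round(librosa.frames_to_time(beats, sr=sr) * 15)
```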
### Citation
If you find this code useful for your research, please cite our paper:
```bibtex
@inproceedings{lee2019dancing2music,
title={Dancing to Music},
author={Lee, Hsin-Ying and Yang, Xiaodong and Liu, Ming-Yu and Wang, Ting-Chun and Lu, Yu-Ding and Yang, Ming-Hsuan and Kautz, Jan},
booktitle={NeurIPS},
year={2019}
}
```
### License
Copyright (C) 2020 NVIDIA Corporation. All rights reserved. This work is made available under NVIDIA Source Code License (1-Way Commercial). To view a copy of this license, visit https://nvlabs.github.io/Dancing2Music/LICENSE.txt.
================================================
FILE: data.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import pickle
import numpy as np
import random
import torch.utils.data
from torchvision.datasets import ImageFolder
import utils
class PoseDataset(torch.utils.data.Dataset):
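  """Stage-1 (decomposition) dataset.

  Each item is one fixed-length pose unit (32 frames) sampled from the
  zumba/ballet unit lists, returned as two copies with independent random
  x/y translation jitter, flattened to (frames, 28) and normalized with
  the on-beat mean/std statistics.
  """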
def __init__(self, data_dir, tolerance=False):
self.data_dir = data_dir
z_fname = '{}/unitList/zumba_unit.txt'.format(data_dir)
b_fname = '{}/unitList/ballet_unit.txt'.format(data_dir)
h_fname = '{}/unitList/hiphop_unit.txt'.format(data_dir)
self.z_data = []
self.b_data = []
self.h_data = []
with open(z_fname, 'r') as f:
for line in f:
self.z_data.append([s for s in line.strip().split(' ')])
with open(b_fname, 'r') as f:
for line in f:
self.b_data.append([s for s in line.strip().split(' ')])
with open(h_fname, 'r') as f:
for line in f:
self.h_data.append([s for s in line.strip().split(' ')])
self.data = [self.z_data, self.b_data, self.h_data]
self.tolerance = tolerance
if self.tolerance:
z3_fname = '{}/unitList/zumba_unitseq3.txt'.format(data_dir)
b3_fname = '{}/unitList/ballet_unitseq3.txt'.format(data_dir)
h3_fname = '{}/unitList/hiphop_unitseq3.txt'.format(data_dir)
z4_fname = '{}/unitList/zumba_unitseq4.txt'.format(data_dir)
b4_fname = '{}/unitList/ballet_unitseq4.txt'.format(data_dir)
h4_fname = '{}/unitList/hiphop_unitseq4.txt'.format(data_dir)
z3_data = []; b3_data = []; h3_data = []; z4_data = []; b4_data = []; h4_data = []
with open(z3_fname, 'r') as f:
for line in f:
z3_data.append([s for s in line.strip().split(' ')])
with open(b3_fname, 'r') as f:
for line in f:
b3_data.append([s for s in line.strip().split(' ')])
with open(h3_fname, 'r') as f:
for line in f:
h3_data.append([s for s in line.strip().split(' ')])
with open(z4_fname, 'r') as f:
for line in f:
z4_data.append([s for s in line.strip().split(' ')])
with open(b4_fname, 'r') as f:
for line in f:
b4_data.append([s for s in line.strip().split(' ')])
with open(h4_fname, 'r') as f:
for line in f:
h4_data.append([s for s in line.strip().split(' ')])
self.zt_data = z3_data + z4_data
self.bt_data = b3_data + b4_data
self.ht_data = h3_data + h4_data
self.t_data = [self.zt_data, self.bt_data, self.ht_data]
self.mean_pose=np.load(data_dir+'/stats/all_onbeat_mean.npy')
self.std_pose=np.load(data_dir+'/stats/all_onbeat_std.npy')
def __getitem__(self, index):
    # Sample only from zumba/ballet; the hiphop list is loaded but excluded here.
    cls = random.randint(0, 1)
if self.tolerance and random.randint(0,9)==0:
index = random.randint(0, len(self.t_data[cls])-1)
path = self.t_data[cls][index][0]
path = os.path.join(self.data_dir, path[5:])
orig_poses = np.load(path)
sel = random.randint(0, orig_poses.shape[0]-1)
orig_poses = orig_poses[sel]
else:
index = random.randint(0, len(self.data[cls])-1)
path = self.data[cls][index][0]
path = os.path.join(self.data_dir, path[5:])
orig_poses = np.load(path)
xjit = np.random.uniform(low=-50, high=50)
yjit = np.random.uniform(low=-20, high=20)
poses = orig_poses.copy()
poses[:,:,0] += xjit
poses[:,:,1] += yjit
xjit = np.random.uniform(low=-50, high=50)
yjit = np.random.uniform(low=-20, high=20)
poses2 = orig_poses.copy()
poses2[:,:,0] += xjit
poses2[:,:,1] += yjit
poses = poses.reshape(poses.shape[0], poses.shape[1]*poses.shape[2])
poses2 = poses2.reshape(poses2.shape[0], poses2.shape[1]*poses2.shape[2])
for i in range(poses.shape[0]):
poses[i] = (poses[i]-self.mean_pose)/self.std_pose
poses2[i] = (poses2[i]-self.mean_pose)/self.std_pose
return torch.Tensor(poses), torch.Tensor(poses2)
def __len__(self):
return len(self.z_data)+len(self.b_data)
class MovementAudDataset(torch.utils.data.Dataset):
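  """Stage-2 (composition) dataset.

  Each item pairs a normalized sequence of three consecutive pose units
  (a random 3-unit window when the source sequence has four units) with
  the first 30 frames of the corresponding clip's normalized audio features.
  """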
def __init__(self, data_dir):
self.data_dir = data_dir
z3_fname = '{}/unitList/zumba_unitseq3.txt'.format(data_dir)
b3_fname = '{}/unitList/ballet_unitseq3.txt'.format(data_dir)
h3_fname = '{}/unitList/hiphop_unitseq3.txt'.format(data_dir)
z4_fname = '{}/unitList/zumba_unitseq4.txt'.format(data_dir)
b4_fname = '{}/unitList/ballet_unitseq4.txt'.format(data_dir)
h4_fname = '{}/unitList/hiphop_unitseq4.txt'.format(data_dir)
self.z3_data = []
self.b3_data = []
self.h3_data = []
self.z4_data = []
self.b4_data = []
self.h4_data = []
with open(z3_fname, 'r') as f:
for line in f:
self.z3_data.append([s for s in line.strip().split(' ')])
with open(b3_fname, 'r') as f:
for line in f:
self.b3_data.append([s for s in line.strip().split(' ')])
with open(h3_fname, 'r') as f:
for line in f:
self.h3_data.append([s for s in line.strip().split(' ')])
with open(z4_fname, 'r') as f:
for line in f:
self.z4_data.append([s for s in line.strip().split(' ')])
with open(b4_fname, 'r') as f:
for line in f:
self.b4_data.append([s for s in line.strip().split(' ')])
with open(h4_fname, 'r') as f:
for line in f:
self.h4_data.append([s for s in line.strip().split(' ')])
self.data_3 = [self.z3_data, self.b3_data, self.h3_data]
self.data_4 = [self.z4_data, self.b4_data, self.h4_data]
z_data_root = 'zumba/'
b_data_root = 'ballet/'
h_data_root = 'hiphop/'
self.data_root = [z_data_root, b_data_root, h_data_root ]
self.mean_pose=np.load(data_dir+'/stats/all_onbeat_mean.npy')
self.std_pose=np.load(data_dir+'/stats/all_onbeat_std.npy')
self.mean_aud=np.load(data_dir+'/stats/all_aud_mean.npy')
self.std_aud=np.load(data_dir+'/stats/all_aud_std.npy')
def __getitem__(self, index):
    # Sample only from zumba/ballet; the hiphop list is loaded but excluded here.
    cls = random.randint(0, 1)
isthree = random.randint(0,1)
if isthree == 0:
index = random.randint(0, len(self.data_4[cls])-1)
path = self.data_4[cls][index][0]
else:
index = random.randint(0, len(self.data_3[cls])-1)
path = self.data_3[cls][index][0]
path = os.path.join(self.data_dir, path[5:])
stdpSeq = np.load(path)
vid, cid = path.split('/')[-4], path.split('/')[-3]
#vid, cid = vid_cid[:11], vid_cid[12:]
aud = np.load('{}/{}/{}/{}/aud/c{}_fps15.npy'.format(self.data_dir, self.data_root[cls], vid, cid, cid))
stdpSeq = stdpSeq.reshape(stdpSeq.shape[0], stdpSeq.shape[1], stdpSeq.shape[2]*stdpSeq.shape[3])
for i in range(stdpSeq.shape[0]):
for j in range(stdpSeq.shape[1]):
stdpSeq[i,j] = (stdpSeq[i,j]-self.mean_pose)/self.std_pose
if isthree == 0:
start = random.randint(0,1)
stdpSeq = stdpSeq[start:start+3]
for i in range(aud.shape[0]):
aud[i] = (aud[i]-self.mean_aud)/self.std_aud
aud = aud[:30]
return torch.Tensor(stdpSeq), torch.Tensor(aud)
def __len__(self):
return len(self.z3_data)+len(self.b3_data)+len(self.z4_data)+len(self.b4_data)+len(self.h3_data)+len(self.h4_data)
def get_loader(batch_size, shuffle, num_workers, dataset, data_dir, tolerance=False):
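  """Build a DataLoader: dataset=0 -> PoseDataset (decomposition stage),
  dataset=2 -> MovementAudDataset (composition stage)."""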
  if dataset == 0:
    a2d = PoseDataset(data_dir, tolerance)
  elif dataset == 2:
    a2d = MovementAudDataset(data_dir)
  else:
    raise ValueError('unknown dataset type: {}'.format(dataset))
data_loader = torch.utils.data.DataLoader(dataset=a2d,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
)
return data_loader
================================================
FILE: demo.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import argparse
import functools
import librosa
import shutil
import sys
sys.path.insert(0, 'preprocess')
import preprocess as p
import subprocess as sp
from shutil import copyfile
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from model_comp import *
from networks import *
from options import TestOptions
import modulate
import utils
def loadDecompModel(args):
initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,
hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)
movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,
hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))
checkpoint = torch.load(args.decomp_snapshot)
initp_enc.load_state_dict(checkpoint['initp_enc'])
stdp_dec.load_state_dict(checkpoint['stdp_dec'])
movement_enc.load_state_dict(checkpoint['movement_enc'])
return initp_enc, stdp_dec, movement_enc
def loadCompModel(args):
dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1))
dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers)
audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance)
dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance)
danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3)
zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1)
checkpoint = torch.load(args.comp_snapshot)
dance_enc.load_state_dict(checkpoint['dance_enc'])
dance_dec.load_state_dict(checkpoint['dance_dec'])
audstyle_enc.load_state_dict(checkpoint['audstyle_enc'])
checkpoint2 = torch.load(args.neta_snapshot)
neta_cls = AudioClassifier_rnn(10,30,28,cls=3)
neta_cls.load_state_dict(checkpoint2)
return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls
if __name__ == "__main__":
parser = TestOptions()
args = parser.parse()
args.train = False
thr = args.thr
# Process music and get feature
infile = args.aud_path
outfile = 'style.npy'
p.preprocess(infile, outfile)
  y, sr = librosa.load(infile)
  onset_env = librosa.onset.onset_strength(y=y, sr=sr, aggregate=np.median)
  tempo, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr)
  # Convert beat positions to frame indices at 15 fps, the pose frame rate
  # that modulate() expects.
  beats = np.round(librosa.frames_to_time(beats, sr=sr) * 15)
  aud = np.load('style.npy')
  os.remove('style.npy')
  shutil.rmtree('normalized')
#### Pretrain network from Decomp
initp_enc, stdp_dec, movement_enc = loadDecompModel(args)
#### Comp network
dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = loadCompModel(args)
trainer = Trainer_Comp(data_loader=None,
movement_enc = movement_enc,
initp_enc = initp_enc,
stdp_dec = stdp_dec,
dance_enc = dance_enc,
dance_dec = dance_dec,
danceAud_dis = danceAud_dis,
zdance_dis = zdance_dis,
aud_enc=neta_cls,
audstyle_enc=audstyle_enc,
dance_reg=dance_reg,
args = args
)
print('Loading Done')
mean_pose=np.load('{}/stats/all_onbeat_mean.npy'.format(args.data_dir))
std_pose=np.load('{}/stats/all_onbeat_std.npy'.format(args.data_dir))
mean_aud=np.load('{}/stats/all_aud_mean.npy'.format(args.data_dir))
std_aud=np.load('{}/stats/all_aud_std.npy'.format(args.data_dir))
length = aud.shape[0]
initpose = np.zeros((14, 2))
initpose = initpose.reshape(-1)
#initpose = (initpose-mean_pose)/std_pose
for j in range(aud.shape[0]):
aud[j] = (aud[j]-mean_aud)/std_aud
total_t = int(length/32+1)
final_stdpSeq = np.zeros((total_t*3*32, 14, 2))
initpose, aud = torch.Tensor(initpose).cuda(), torch.Tensor(aud).cuda()
initpose, aud = initpose.view(1, initpose.shape[0]), aud.view(1, aud.shape[0], aud.shape[1])
for t in range(total_t):
print('process {}/{}'.format(t, total_t))
    # Re-sample until the generated units pass the motion-magnitude threshold.
    while True:
      fake_stdpSeq = trainer.test_final(initpose, aud, 3, thr)
      if fake_stdpSeq is not None:
        break
initpose = fake_stdpSeq[2,-1]
initpose = torch.Tensor(initpose).cuda()
initpose = initpose.view(1,-1)
fake_stdpSeq = fake_stdpSeq.squeeze()
for j in range(fake_stdpSeq.shape[0]):
for k in range(fake_stdpSeq.shape[1]):
fake_stdpSeq[j,k] = fake_stdpSeq[j,k]*std_pose + mean_pose
fake_stdpSeq = np.resize(fake_stdpSeq, (fake_stdpSeq.shape[0],32, 14, 2))
for j in range(3):
final_stdpSeq[96*t+32*j:96*t+32*(j+1)] = fake_stdpSeq[j]
if args.modulate:
final_stdpSeq = modulate.modulate(final_stdpSeq, beats, length)
out_dir = args.out_dir
if not os.path.exists(out_dir):
os.mkdir(out_dir)
utils.vis(final_stdpSeq, out_dir)
sp.call('ffmpeg -r 15 -i {}/frame%03d.png -i {} -c:v libx264 -pix_fmt yuv420p -crf 23 -r 30 -y -strict -2 {}'.format(out_dir, args.aud_path, args.out_file), shell=True)
================================================
FILE: model_comp.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import time
import numpy as np
import random
import math
import torch
from torch import nn
from torch.autograd import Variable
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from utils import Logger
if torch.cuda.is_available():
T = torch.cuda
else:
T = torch
class Trainer_Comp(object):
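  """Stage-2 (composition) trainer.

  On top of the stage-1 movement networks, it encodes sequences of movement
  latents into a dance latent z_dance, maps a music style embedding into the
  same z_dance space, and aligns the two with adversarial (danceAud_dis,
  zdance_dis) and style-regression (dance_reg) objectives.
  """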
def __init__(self, data_loader, dance_enc, dance_dec, danceAud_dis, movement_enc, initp_enc, stdp_dec, aud_enc, audstyle_enc, dance_reg=None, args=None, zdance_dis=None):
self.data_loader = data_loader
self.movement_enc = movement_enc
self.initp_enc = initp_enc
self.stdp_dec = stdp_dec
self.dance_enc = dance_enc
self.dance_dec = dance_dec
self.danceAud_dis = danceAud_dis
self.aud_enc = aud_enc
self.audstyle_enc = audstyle_enc
    self.is_train = args.train  # note: a plain `self.train` would shadow the train() method
self.args = args
if args.train:
self.zdance_dis = zdance_dis
self.dance_reg = dance_reg
self.logger = Logger(args.log_dir)
self.logs = self.init_logs()
self.log_interval = args.log_interval
self.snapshot_ep = args.snapshot_ep
self.snapshot_dir = args.snapshot_dir
self.opt_dance_enc = torch.optim.Adam(self.dance_enc.parameters(), lr=args.lr)
self.opt_dance_dec = torch.optim.Adam(self.dance_dec.parameters(), lr=args.lr)
self.opt_danceAud_dis = torch.optim.Adam(self.danceAud_dis.parameters(), lr=args.lr)
self.opt_audstyle_enc = torch.optim.Adam(self.audstyle_enc.parameters(), lr=args.lr)
self.opt_zdance_dis = torch.optim.Adam(self.zdance_dis.parameters(), lr=args.lr)
self.opt_dance_reg = torch.optim.Adam(self.dance_reg.parameters(), lr=args.lr)
self.opt_stdp_dec = torch.optim.Adam(self.stdp_dec.parameters(), lr=args.lr*0.1)
self.opt_movement_enc = torch.optim.Adam(self.movement_enc.parameters(), lr=args.lr*0.1)
self.latent_dropout = nn.Dropout(p=args.latent_dropout)
self.l1_criterion = torch.nn.L1Loss()
self.gan_criterion = nn.BCEWithLogitsLoss()
self.mse_criterion = nn.MSELoss().cuda()
def init_logs(self):
return {'l_kl_zdance':0, 'l_kl_zmovement':0, 'l_kl_fake_zdance':0, 'l_kl_fake_zmovement':0,
'l_l1_zmovement_mu':0, 'l_l1_zmovement_logvar':0, 'l_l1_stdpSeq':0, 'l_l1_zdance':0,
'l_dis':0, 'l_dis_true':0, 'l_dis_fake':0,
'l_info':0, 'l_info_real':0, 'l_info_fake':0,
'l_gen':0
}
def get_z_random(self, batchSize, nz, random_type='gauss'):
z = torch.randn(batchSize, nz).cuda()
return z
@staticmethod
def ones_like(tensor, val=1.):
return T.FloatTensor(tensor.size()).fill_(val)
@staticmethod
def zeros_like(tensor, val=0.):
return T.FloatTensor(tensor.size()).fill_(val)
def kld_coef(self, i):
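    # Sigmoid KL-annealing weight in (0, 1), centered at iteration 15000.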
return float(1/(1+np.exp(-0.0005*(i-15000))))
def forward(self, stdpSeq, batchsize, aud_style, aud):
self.aud = torch.mean(aud, dim=1)
self.batchsize = batchsize
self.stdpSeq = stdpSeq
self.aud_style = aud_style
### stdpSeq -> z_inits, z_movements
self.pose_0 = stdpSeq[:,0,:]
self.z_init_mu, self.z_init_logvar = self.initp_enc(self.pose_0)
z_init_std = self.z_init_logvar.mul(0.5).exp_()
z_init_eps = self.get_z_random(z_init_std.size(0), z_init_std.size(1), 'gauss')
self.z_init = z_init_eps.mul(z_init_std).add_(self.z_init_mu)
self.z_movement_mus, self.z_movement_logvars = self.movement_enc(stdpSeq)
z_movement_stds = self.z_movement_logvars.mul(0.5).exp_()
z_movement_epss = self.get_z_random(z_movement_stds.size(0), z_movement_stds.size(1), 'gauss')
self.z_movements = z_movement_epss.mul(z_movement_stds).add_(self.z_movement_mus)
self.z_movementSeq_mu = self.z_movement_mus.view(batchsize, -1, self.z_movements.shape[1])
self.z_movementSeq_logvar = self.z_movement_logvars.view(batchsize, -1, self.z_movements.shape[1])
self.z_init, self.z_movements = self.z_init.detach(), self.z_movements.detach()
self.z_movement_mus, self.z_movement_logvars = self.z_movement_mus.detach(), self.z_movement_logvars.detach()
### z_movements -> z_dance
self.z_dance_mu, self.z_dance_logvar = self.dance_enc(self.z_movementSeq_mu, self.z_movementSeq_logvar)
z_dance_std = self.z_dance_logvar.mul(0.5).exp_()
z_dance_eps = self.get_z_random(z_dance_std.size(0), z_dance_std.size(1), 'gauss')
self.z_dance = z_dance_eps.mul(z_dance_std).add_(self.z_dance_mu)
### z_dance -> z_movements
self.recon_z_movements_mu, self.recon_z_movements_logvar = self.dance_dec(self.z_dance)
recon_z_movement_std = self.recon_z_movements_logvar.mul(0.5).exp_()
recon_z_movement_eps = self.get_z_random(recon_z_movement_std.size(0), recon_z_movement_std.size(1), 'gauss')
self.recon_z_movements = recon_z_movement_eps.mul(recon_z_movement_std).add_(self.recon_z_movements_mu)
### z_movements -> stdpSeq
self.recon_stdpSeq = self.stdp_dec(self.z_init, self.recon_z_movements)
### Music to z_dance to z_movements
self.fake_z_dance_mu, self.fake_z_dance_logvar = self.audstyle_enc(aud_style)
fake_z_dance_std = self.fake_z_dance_logvar.mul(0.5).exp_()
fake_z_dance_eps = self.get_z_random(fake_z_dance_std.size(0), fake_z_dance_std.size(1), 'gauss')
self.fake_z_dance = fake_z_dance_eps.mul(fake_z_dance_std).add_(self.fake_z_dance_mu)
self.fake_z_movements_mu, self.fake_z_movements_logvar = self.dance_dec(self.fake_z_dance)
fake_z_movements_std = self.fake_z_movements_logvar.mul(0.5).exp_()
fake_z_movements_eps = self.get_z_random(fake_z_movements_std.size(0), fake_z_movements_std.size(1), 'gauss')
self.fake_z_movements = fake_z_movements_eps.mul(fake_z_movements_std).add_(self.fake_z_movements_mu)
fake_z_movementSeq_mu = self.fake_z_movements_mu.view(batchsize, -1, self.fake_z_movements_mu.shape[1])
fake_z_movementSeq_logvar = self.fake_z_movements_logvar.view(batchsize, -1, self.fake_z_movements_logvar.shape[1])
self.fake_z_movementSeq = torch.cat((fake_z_movementSeq_mu, fake_z_movementSeq_logvar),2)
def backward_D(self):
#real_movements = torch.cat((self.z_movementSeq_mu, self.z_movementSeq_logvar),2)
tmp_recon_mu = self.recon_z_movements_mu.view(self.batchsize, -1, self.z_movements.shape[1])
tmp_recon_logvar = self.recon_z_movements_logvar.view(self.batchsize, -1, self.z_movements.shape[1])
real_movements = torch.cat((tmp_recon_mu, tmp_recon_logvar),2)
fake_movements = self.fake_z_movementSeq
real_labels,_ = self.danceAud_dis(real_movements.detach(), self.aud)
fake_labels,_ = self.danceAud_dis(fake_movements.detach(), self.aud)
ones = self.ones_like(real_labels)
zeros = self.zeros_like(fake_labels)
self.loss_dis_true = self.gan_criterion(real_labels, ones)
self.loss_dis_fake = self.gan_criterion(fake_labels, zeros)
self.loss_dis = (self.loss_dis_true + self.loss_dis_fake)*self.args.lambda_gan
real_dance = torch.cat((self.z_dance_mu, self.z_dance_logvar), 1)
fake_dance = torch.cat((self.fake_z_dance_mu, self.fake_z_dance_logvar), 1)
real_labels, _ = self.zdance_dis(real_dance.detach(), self.aud)
fake_labels, _ = self.zdance_dis(fake_dance.detach(), self.aud)
ones = self.ones_like(real_labels)
zeros = self.zeros_like(fake_labels)
self.loss_zdis_true = self.gan_criterion(real_labels, ones)
self.loss_zdis_fake = self.gan_criterion(fake_labels, zeros)
self.loss_dis += (self.loss_zdis_true + self.loss_zdis_fake)*self.args.lambda_gan
def backward_danceED(self):
# z_dance KL
kl_element = self.z_dance_mu.pow(2).add_(self.z_dance_logvar.exp()).mul_(-1).add_(1).add_(self.z_dance_logvar)
self.loss_kl_z_dance = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl_dance))
kl_element = self.fake_z_dance_mu.pow(2).add_(self.fake_z_dance_logvar.exp()).mul_(-1).add_(1).add_(self.fake_z_dance_logvar)
self.loss_kl_fake_z_dance = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl_dance))
# z_movement KL
kl_element = self.recon_z_movements_mu.pow(2).add_(self.recon_z_movements_logvar.exp()).mul_(-1).add_(1).add_(self.recon_z_movements_logvar)
self.loss_kl_z_movement = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))
kl_element = self.fake_z_movements_mu.pow(2).add_(self.fake_z_movements_logvar.exp()).mul_(-1).add_(1).add_(self.fake_z_movements_logvar)
self.loss_kl_fake_z_movements = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))
# z_movement reconstruction
self.loss_l1_z_movement_mu = self.l1_criterion(self.recon_z_movements_mu, self.z_movement_mus) * self.args.lambda_zmovements_recon
self.loss_l1_z_movement_logvar = self.l1_criterion(self.recon_z_movements_logvar, self.z_movement_logvars) * self.args.lambda_zmovements_recon
# stdp reconstruction
self.loss_l1_stdpSeq = self.l1_criterion(self.recon_stdpSeq, self.stdpSeq) * self.args.lambda_stdpSeq_recon
# Music2Dance GAN
fake_movements = self.fake_z_movementSeq
fake_labels, _ = self.danceAud_dis(fake_movements, self.aud)
ones = self.ones_like(fake_labels)
self.loss_gen = self.gan_criterion(fake_labels, ones) * self.args.lambda_gan
fake_dance = torch.cat((self.fake_z_dance_mu, self.fake_z_dance_logvar), 1)
fake_labels, _ = self.zdance_dis(fake_dance, self.aud)
ones = self.ones_like(fake_labels)
self.loss_gen += self.gan_criterion(fake_labels, ones) * self.args.lambda_gan
self.loss = self.loss_kl_z_movement + self.loss_kl_z_dance + self.loss_l1_z_movement_mu + self.loss_l1_z_movement_logvar + self.loss_l1_stdpSeq + self.loss_gen
def backward_info_ondance(self):
real_pred = self.dance_reg(self.z_dance)
fake_pred = self.dance_reg(self.fake_z_dance)
self.loss_info_real = self.mse_criterion(real_pred, self.aud_style)
self.loss_info_fake = self.mse_criterion(fake_pred, self.aud_style)
self.loss_info = self.loss_info_real + self.loss_info_fake
def zero_grad(self, opt_list):
for opt in opt_list:
opt.zero_grad()
def clip_norm(self, network_list):
for network in network_list:
clip_grad_norm_(network.parameters(), 0.5)
def step(self, opt_list):
for opt in opt_list:
opt.step()
def update(self):
self.zero_grad([self.opt_danceAud_dis, self.opt_zdance_dis])
self.backward_D()
self.loss_dis.backward(retain_graph=True)
self.clip_norm([self.danceAud_dis, self.zdance_dis])
self.step([self.opt_danceAud_dis, self.opt_zdance_dis])
self.zero_grad([self.opt_dance_enc, self.opt_dance_dec, self.opt_audstyle_enc, self.opt_stdp_dec])
self.backward_danceED()
self.loss.backward(retain_graph=True)
self.clip_norm([self.dance_enc, self.dance_dec, self.audstyle_enc, self.stdp_dec])
self.step([self.opt_dance_enc, self.opt_dance_dec, self.opt_audstyle_enc, self.opt_stdp_dec])
self.zero_grad([self.opt_dance_enc, self.opt_audstyle_enc, self.opt_dance_reg, self.opt_stdp_dec])
self.backward_info_ondance()
self.loss_info.backward()
self.clip_norm([self.dance_enc, self.audstyle_enc, self.dance_reg, self.stdp_dec])
self.step([self.opt_dance_enc, self.opt_audstyle_enc, self.opt_dance_reg, self.opt_stdp_dec])
def test_final(self, initpose, aud, n, thr=0):
self.cuda()
self.movement_enc.eval()
self.stdp_dec.eval()
self.initp_enc.eval()
self.dance_enc.eval()
self.dance_dec.eval()
self.aud_enc.eval()
self.audstyle_enc.eval()
aud_style = self.aud_enc.get_style(aud).detach()
self.fake_z_dance_mu, self.fake_z_dance_logvar = self.audstyle_enc(aud_style)
fake_z_dance_std = self.fake_z_dance_logvar.mul(0.5).exp_()
fake_z_dance_eps = self.get_z_random(fake_z_dance_std.size(0), fake_z_dance_std.size(1), 'gauss')
self.fake_z_dance = fake_z_dance_eps.mul(fake_z_dance_std).add_(self.fake_z_dance_mu)
self.fake_z_movements_mu, self.fake_z_movements_logvar = self.dance_dec(self.fake_z_dance, length=3)
fake_z_movements_std = self.fake_z_movements_logvar.mul(0.5).exp_()
fake_z_movements_eps = self.get_z_random(fake_z_movements_std.size(0), fake_z_movements_std.size(1), 'gauss')
self.fake_z_movements = fake_z_movements_eps.mul(fake_z_movements_std).add_(self.fake_z_movements_mu)
fake_stdpSeq=[]
for i in range(n):
z_init_mus, z_init_logvars = self.initp_enc(initpose)
z_init_stds = z_init_logvars.mul(0.5).exp_()
z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss')
z_init = z_init_epss.mul(z_init_stds).add_(z_init_mus)
fake_stdp = self.stdp_dec(z_init, self.fake_z_movements[i:i+1])
fake_stdpSeq.append(fake_stdp)
initpose = fake_stdp[:,-1,:]
fake_stdpSeq = torch.cat(fake_stdpSeq, dim=0)
flag = False
for i in range(n):
s = fake_stdpSeq[i]
diff = torch.abs(s[1:]-s[:-1])
diffsum = torch.sum(diff)
if diffsum.cpu().detach().numpy() < thr:
flag = True
if flag:
return None
else:
return fake_stdpSeq.cpu().detach().numpy()
def resume(self, model_dir, train=True):
checkpoint = torch.load(model_dir)
self.dance_enc.load_state_dict(checkpoint['dance_enc'])
self.dance_dec.load_state_dict(checkpoint['dance_dec'])
self.audstyle_enc.load_state_dict(checkpoint['audstyle_enc'])
self.stdp_dec.load_state_dict(checkpoint['stdp_dec'])
self.movement_enc.load_state_dict(checkpoint['movement_enc'])
if train:
self.danceAud_dis.load_state_dict(checkpoint['danceAud_dis'])
self.dance_reg.load_state_dict(checkpoint['dance_reg'])
self.opt_dance_enc.load_state_dict(checkpoint['opt_dance_enc'])
self.opt_dance_dec.load_state_dict(checkpoint['opt_dance_dec'])
self.opt_stdp_dec.load_state_dict(checkpoint['opt_stdp_dec'])
self.opt_audstyle_enc.load_state_dict(checkpoint['opt_audstyle_enc'])
self.opt_danceAud_dis.load_state_dict(checkpoint['opt_danceAud_dis'])
self.opt_dance_reg.load_state_dict(checkpoint['opt_dance_reg'])
return checkpoint['ep'], checkpoint['total_it']
def save(self, filename, ep, total_it):
state = {
'stdp_dec': self.stdp_dec.state_dict(),
'movement_enc': self.movement_enc.state_dict(),
'dance_enc': self.dance_enc.state_dict(),
'dance_dec': self.dance_dec.state_dict(),
'audstyle_enc': self.audstyle_enc.state_dict(),
'danceAud_dis': self.danceAud_dis.state_dict(),
'zdance_dis': self.zdance_dis.state_dict(),
'dance_reg': self.dance_reg.state_dict(),
'opt_stdp_dec': self.opt_stdp_dec.state_dict(),
'opt_movement_enc': self.opt_movement_enc.state_dict(),
'opt_dance_enc': self.opt_dance_enc.state_dict(),
'opt_dance_dec': self.opt_dance_dec.state_dict(),
'opt_audstyle_enc': self.opt_audstyle_enc.state_dict(),
'opt_danceAud_dis': self.opt_danceAud_dis.state_dict(),
'opt_zdance_dis': self.opt_zdance_dis.state_dict(),
'opt_dance_reg': self.opt_dance_reg.state_dict(),
'ep': ep,
'total_it': total_it
}
torch.save(state, filename)
return
  def cuda(self):
    if self.is_train:
self.dance_reg.cuda()
self.danceAud_dis.cuda()
self.zdance_dis.cuda()
self.stdp_dec.cuda()
self.initp_enc.cuda()
self.movement_enc.cuda()
self.dance_enc.cuda()
self.dance_dec.cuda()
self.aud_enc.cuda()
self.audstyle_enc.cuda()
self.gan_criterion.cuda()
def train(self, ep=0, it=0):
self.cuda()
for epoch in range(ep, self.args.num_epochs):
self.movement_enc.train()
self.stdp_dec.train()
self.initp_enc.train()
self.dance_enc.train()
self.dance_dec.train()
self.danceAud_dis.train()
self.zdance_dis.train()
self.audstyle_enc.train()
self.dance_reg.train()
self.aud_enc.eval()
stdp_recon = 0
for i, (stdpSeq, aud) in enumerate(self.data_loader):
stdpSeq, aud = stdpSeq.cuda().detach(), aud.cuda().detach()
stdpSeq = stdpSeq.view(stdpSeq.shape[0]*stdpSeq.shape[1], stdpSeq.shape[2], stdpSeq.shape[3])
aud_style = self.aud_enc.get_style(aud).detach()
self.forward(stdpSeq, aud.shape[0], aud_style, aud)
self.update()
self.logs['l_kl_zmovement'] += self.loss_kl_z_movement.data
self.logs['l_kl_zdance'] += self.loss_kl_z_dance.data
self.logs['l_l1_zmovement_mu'] += self.loss_l1_z_movement_mu.data
self.logs['l_l1_zmovement_logvar'] += self.loss_l1_z_movement_logvar.data
self.logs['l_l1_stdpSeq'] += self.loss_l1_stdpSeq.data
self.logs['l_kl_fake_zdance'] += self.loss_kl_fake_z_dance.data
      self.logs['l_kl_fake_zmovement'] += self.loss_kl_fake_z_movements.data
self.logs['l_dis'] += self.loss_dis.data
self.logs['l_dis_true'] += self.loss_dis_true.data
self.logs['l_dis_fake'] += self.loss_dis_fake.data
self.logs['l_gen'] += self.loss_gen.data
      self.logs['l_info'] += self.loss_info.data
      self.logs['l_info_real'] += self.loss_info_real.data
      self.logs['l_info_fake'] += self.loss_info_fake.data
print('Epoch:{:3} Iter{}/{}\tl_l1_zmovement mu{:.3f} logvar{:.3f}\tl_l1_stdpSeq {:.3f}\tl_kl_dance {:.3f}\tl_kl_movement {:.3f}\n'.format(epoch, i, len(self.data_loader),
self.loss_l1_z_movement_mu, self.loss_l1_z_movement_logvar, self.loss_l1_stdpSeq, self.loss_kl_z_dance, self.loss_kl_z_movement) +
'\t\t\tl_kl_f_dance {:.3f}\tl_dis {:.3f} {:.3f}\tl_gen {:.3f}'.format(self.loss_kl_fake_z_dance, self.loss_dis_true, self.loss_dis_fake, self.loss_gen))
it += 1
if it % self.log_interval == 0:
for tag, value in self.logs.items():
self.logger.scalar_summary(tag, value/self.log_interval, it)
self.logs = self.init_logs()
if epoch % self.snapshot_ep == 0:
self.save(os.path.join(self.snapshot_dir, '{:04}.ckpt'.format(epoch)), epoch, it)
================================================
FILE: model_decomp.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import time
import numpy as np
import random
import math
import torch
from torch import nn
from torch.autograd import Variable
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from utils import Logger
if torch.cuda.is_available():
T = torch.cuda
else:
T = torch
class Trainer_Decomp(object):
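  """Stage-1 (decomposition) trainer.

  Learns a VAE-style factorization of a pose unit into an initial-pose
  latent z_init and a movement latent z_movement, trained with self- and
  cross-reconstruction between two jittered copies of the same unit.
  """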
def __init__(self, data_loader, initp_enc, initp_dec, movement_enc, stdp_dec, args=None):
self.data_loader = data_loader
self.initp_enc = initp_enc
self.initp_dec = initp_dec
self.movement_enc = movement_enc
self.stdp_dec = stdp_dec
self.args = args
if args.train:
self.logger = Logger(args.log_dir)
self.logs = self.init_logs()
self.log_interval = args.log_interval
self.snapshot_ep = args.snapshot_ep
self.snapshot_dir = args.snapshot_dir
self.opt_initp_enc = torch.optim.Adam(self.initp_enc.parameters(), lr=args.lr)
self.opt_initp_dec = torch.optim.Adam(self.initp_dec.parameters(), lr=args.lr)
self.opt_movement_enc = torch.optim.Adam(self.movement_enc.parameters(), lr=args.lr)
self.opt_stdp_dec = torch.optim.Adam(self.stdp_dec.parameters(), lr=args.lr)
self.latent_dropout = nn.Dropout(p=args.latent_dropout)
self.l1_criterion = torch.nn.L1Loss()
self.gan_criterion = nn.BCEWithLogitsLoss()
def init_logs(self):
return {'l_kl_zinit':0, 'l_kl_zmovement':0, 'l_l1_stdp':0, 'l_l1_cross_stdp':0, 'l_dist_zmovement':0,
'l_l1_initp':0, 'l_l1_initp_con':0,
'kld_coef':0
}
def get_z_random(self, batchSize, nz, random_type='gauss'):
z = torch.randn(batchSize, nz).cuda()
return z
@staticmethod
def ones_like(tensor, val=1.):
return T.FloatTensor(tensor.size()).fill_(val)
@staticmethod
def zeros_like(tensor, val=0.):
return T.FloatTensor(tensor.size()).fill_(val)
def random_generate_stdp(self, init_p):
self.pose_0 = init_p
self.z_init_mu, self.z_init_logvar = self.initp_enc(self.pose_0)
z_init_std = self.z_init_logvar.mul(0.5).exp_()
z_init_eps = self.get_z_random(z_init_std.size(0), z_init_std.size(1), 'gauss')
self.z_init = z_init_eps.mul(z_init_std).add_(self.z_init_mu)
self.z_random_movement = self.get_z_random(self.z_init.size(0), 512, 'gauss')
self.fake_stdpose = self.stdp_dec(self.z_init, self.z_random_movement)
return self.fake_stdpose
def forward(self, stdpose1, stdpose2):
self.stdpose1 = stdpose1
self.stdpose2 = stdpose2
# stdpose -> stdpose[0] -> z_init
self.pose1_0 = stdpose1[:,0,:]
self.pose2_0 = stdpose2[:,0,:]
self.poses_0 = torch.cat((self.pose1_0, self.pose2_0), 0)
self.z_init_mus, self.z_init_logvars = self.initp_enc(self.poses_0)
z_init_stds = self.z_init_logvars.mul(0.5).exp_()
z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss')
self.z_inits = z_init_epss.mul(z_init_stds).add_(self.z_init_mus)
self.z_init1, self.z_init2 = torch.split(self.z_inits, self.stdpose1.size(0), dim=0)
# stdpose -> z_movement
stdposes = torch.cat((stdpose1, stdpose2), 0)
self.z_movement_mus, self.z_movement_logvars = self.movement_enc(stdposes)
z_movement_stds = self.z_movement_logvars.mul(0.5).exp_()
z_movement_epss = self.get_z_random(z_movement_stds.size(0), z_movement_stds.size(1), 'gauss')
self.z_movements = z_movement_epss.mul(z_movement_stds).add_(self.z_movement_mus)
self.z_movement1, self.z_movement2 = torch.split(self.z_movements, self.stdpose1.size(0), dim=0)
# zinit1+zmovement1->stdpose1 zinit2+zmovement2->stdpose2
self.recon_stdpose1 = self.stdp_dec(self.z_init1, self.z_movement1)
self.recon_stdpose2 = self.stdp_dec(self.z_init2, self.z_movement2)
# zinit1+zmovement2->stdpose1 zinit2+zmovement1->stdpose2
self.recon_stdpose1_cross = self.stdp_dec(self.z_init1, self.z_movement2)
self.recon_stdpose2_cross = self.stdp_dec(self.z_init2, self.z_movement1)
# z_init -> \hat{stdpose[0]}
self.recon_pose1_0 = self.initp_dec(self.z_init1)
self.recon_pose2_0 = self.initp_dec(self.z_init2)
# single pose reconstruction
randomlist = np.random.permutation(31)[:4]
singlepose = []
for r in randomlist:
singlepose.append(self.stdpose1[:,r,:])
self.singleposes = torch.cat(singlepose, dim=0).detach()
self.z_single_mus, self.z_single_logvars = self.initp_enc(self.singleposes)
z_single_stds = self.z_single_logvars.mul(0.5).exp_()
z_single_epss = self.get_z_random(z_single_stds.size(0), z_single_stds.size(1), 'gauss')
z_single = z_single_epss.mul(z_single_stds).add_(self.z_single_mus)
self.recon_singleposes = self.initp_dec(z_single)
def backward_initp_ED(self):
# z_init KL
kl_element = self.z_init_mus.pow(2).add_(self.z_init_logvars.exp()).mul_(-1).add_(1).add_(self.z_init_logvars)
self.loss_kl_z_init = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))
# initpose reconstruction
self.loss_l1_initp = self.l1_criterion(self.recon_singleposes, self.singleposes) * self.args.lambda_initp_recon
self.loss_initp = self.loss_kl_z_init + self.loss_l1_initp
def backward_movement_ED(self):
# z_movement KL
kl_element = self.z_movement_mus.pow(2).add_(self.z_movement_logvars.exp()).mul_(-1).add_(1).add_(self.z_movement_logvars)
#self.loss_kl_z_movement = torch.mean(kl_element).mul_(-0.5) * self.args.lambda_kl
self.loss_kl_z_movement = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))
# stdpose self reconstruction
loss_l1_stdp1 = self.l1_criterion(self.recon_stdpose1, self.stdpose1) * self.args.lambda_stdp_recon
loss_l1_stdp2 = self.l1_criterion(self.recon_stdpose2, self.stdpose2) * self.args.lambda_stdp_recon
self.loss_l1_stdp = loss_l1_stdp1 + loss_l1_stdp2
# stdpose cross reconstruction
loss_l1_cross_stdp1 = self.l1_criterion(self.recon_stdpose1_cross, self.stdpose1) * self.args.lambda_stdp_recon
loss_l1_cross_stdp2 = self.l1_criterion(self.recon_stdpose2_cross, self.stdpose2) * self.args.lambda_stdp_recon
self.loss_l1_cross_stdp = loss_l1_cross_stdp1 + loss_l1_cross_stdp2
# Movement dist
self.loss_dist_z_movement = torch.mean(torch.abs(self.z_movement1-self.z_movement2)) * self.args.lambda_dist_z_movement
self.loss_movement = self.loss_kl_z_movement + self.loss_l1_stdp + self.loss_l1_cross_stdp + self.loss_dist_z_movement
def update(self):
self.opt_initp_enc.zero_grad()
self.opt_initp_dec.zero_grad()
self.opt_movement_enc.zero_grad()
self.opt_stdp_dec.zero_grad()
self.backward_initp_ED()
self.backward_movement_ED()
self.g_loss = self.loss_initp + self.loss_movement
self.g_loss.backward(retain_graph=True)
clip_grad_norm_(self.movement_enc.parameters(), 0.5)
clip_grad_norm_(self.stdp_dec.parameters(), 0.5)
self.opt_initp_enc.step()
self.opt_initp_dec.step()
self.opt_movement_enc.step()
self.opt_stdp_dec.step()
def save(self, filename, ep, total_it):
state = {
'stdp_dec': self.stdp_dec.state_dict(),
'movement_enc': self.movement_enc.state_dict(),
'initp_enc': self.initp_enc.state_dict(),
'initp_dec': self.initp_dec.state_dict(),
'opt_stdp_dec': self.opt_stdp_dec.state_dict(),
'opt_movement_enc': self.opt_movement_enc.state_dict(),
'opt_initp_enc': self.opt_initp_enc.state_dict(),
'opt_initp_dec': self.opt_initp_dec.state_dict(),
'ep': ep,
'total_it': total_it
}
torch.save(state, filename)
return
def resume(self, model_dir, train=True):
checkpoint = torch.load(model_dir)
# weight
self.stdp_dec.load_state_dict(checkpoint['stdp_dec'])
self.movement_enc.load_state_dict(checkpoint['movement_enc'])
self.initp_enc.load_state_dict(checkpoint['initp_enc'])
self.initp_dec.load_state_dict(checkpoint['initp_dec'])
# optimizer
if train:
self.opt_stdp_dec.load_state_dict(checkpoint['opt_stdp_dec'])
self.opt_movement_enc.load_state_dict(checkpoint['opt_movement_enc'])
self.opt_initp_enc.load_state_dict(checkpoint['opt_initp_enc'])
self.opt_initp_dec.load_state_dict(checkpoint['opt_initp_dec'])
return checkpoint['ep'], checkpoint['total_it']
def kld_coef(self, i):
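    # Sigmoid KL-annealing schedule: the KL weight ramps smoothly from ~0
    # toward 1, centered at iteration 15000.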
return float(1/(1+np.exp(-0.0005*(i-15000)))) #v3
def generate_stdp_sequence(self, initpose, aud, num_stdp):
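    # Note: this method uses self.aud_enc and self.audstyle_enc, which are
    # stage-2 networks that this trainer does not construct; they must be
    # attached externally before calling it.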
self.initp_enc.cuda()
self.initp_dec.cuda()
self.movement_enc.cuda()
self.stdp_dec.cuda()
self.initp_enc.eval()
self.initp_dec.eval()
self.movement_enc.eval()
self.stdp_dec.eval()
initpose = initpose.cuda()
aud_style = self.aud_enc.get_style(aud)
stdp_seq = []
cnt = 0
#for i in range(num_stdp):
while not cnt == num_stdp:
if cnt==0:
z_inits = self.get_z_random(1, 10, 'gauss')
else:
z_init_mus, z_init_logvars = self.initp_enc(initpose)
z_init_stds = z_init_logvars.mul(0.5).exp_()
z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss')
z_inits = z_init_epss.mul(z_init_stds).add_(z_init_mus)
z_audstyle_mu, z_audstyle_logvar = self.audstyle_enc(aud_style)
z_as_std = z_audstyle_logvar.mul(0.5).exp_()
z_as_eps = self.get_z_random(z_as_std.size(0), z_as_std.size(1), 'gauss')
z_audstyle = z_as_eps.mul(z_as_std).add_(z_audstyle_mu)
      if random.randint(0,5)==100:  # condition never holds; this branch is effectively disabled
z_audstyle = self.get_z_random(z_inits.size(0), 512, 'gauss')
fake_stdpose = self.stdp_dec(z_inits, z_audstyle)
s = fake_stdpose[0]
diff = torch.abs(s[1:]-s[:-1])
diffsum = torch.sum(diff)
if diffsum.cpu().detach().numpy() < 70:
continue
cnt += 1
stdp_seq.append(fake_stdpose.cpu().detach().numpy())
initpose = fake_stdpose[:,-1,:]
return stdp_seq
def cuda(self):
self.initp_enc.cuda()
self.initp_dec.cuda()
self.movement_enc.cuda()
self.stdp_dec.cuda()
self.l1_criterion.cuda()
def train(self, ep=0, it=0):
self.cuda()
full_kl = self.args.lambda_kl
kl_w = 0
kl_step = 0.05
best_stdp_recon = 100
for epoch in range(ep, self.args.num_epochs):
self.initp_enc.train()
self.initp_dec.train()
self.movement_enc.train()
self.stdp_dec.train()
stdp_recon = 0
for i, (stdpose, stdpose2) in enumerate(self.data_loader):
self.args.lambda_kl = full_kl*self.kld_coef(it)
stdpose, stdpose2 = stdpose.cuda().detach(), stdpose2.cuda().detach()
self.forward(stdpose, stdpose2)
self.update()
self.logs['l_kl_zinit'] += self.loss_kl_z_init.data
self.logs['l_kl_zmovement'] += self.loss_kl_z_movement.data
self.logs['l_l1_initp'] += self.loss_l1_initp.data
self.logs['l_l1_stdp'] += self.loss_l1_stdp.data
self.logs['l_l1_cross_stdp'] += self.loss_l1_cross_stdp.data
self.logs['l_dist_zmovement'] += self.loss_dist_z_movement.data
self.logs['kld_coef'] += self.args.lambda_kl
print('Epoch:{:3} Iter{}/{}\tl_l1_initp {:.3f}\tl_l1_stdp {:.3f}\tl_l1_cross_stdp {:.3f}\tl_dist_zmove {:.3f}\tl_kl_zinit {:.3f}\t l_kl_zmove {:.3f}'.format(
epoch, i, len(self.data_loader), self.loss_l1_initp, self.loss_l1_stdp, self.loss_l1_cross_stdp, self.loss_dist_z_movement, self.loss_kl_z_init, self.loss_kl_z_movement))
it += 1
if it % self.log_interval == 0:
for tag, value in self.logs.items():
self.logger.scalar_summary(tag, value/self.log_interval, it)
self.logs = self.init_logs()
if epoch % self.snapshot_ep == 0:
self.save(os.path.join(self.snapshot_dir, '{:04}.ckpt'.format(epoch)), epoch, it)
================================================
FILE: modulate.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import numpy as np
import librosa
import utils
def modulate(dance, beats, length):
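  """Beat-warp a generated dance so its kinematic beats land on music beats.

  dance:  generated pose sequence with a kinematic beat every 8 frames,
          starting at frame 3 (see s_beat below)
  beats:  music-beat positions as frame indices at 15 fps
  length: number of output frames
  """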
sec_interframe = 1/15
beats_frame = np.around(beats)
t_beat = beats_frame.astype(int)
s_beat = np.arange(3,dance.shape[0],8)
final_pose = np.zeros((length, 14, 2))
if t_beat[0] >3:
final_pose[t_beat[0]-3:t_beat[0]] = dance[:3]
else:
final_pose[:t_beat[0]] = dance[:t_beat[0]]
if t_beat[0]-3 > 0:
final_pose[:t_beat[0]-3] = dance[0]
for t in range(t_beat.shape[0]-1):
begin = int(t_beat[t])
end = int(t_beat[t+1])
interval = end-begin
if t==s_beat.shape[0]-1:
rest = min(final_pose.shape[0]-begin-1, dance.shape[0]-s_beat[t]-1)
break
    if t+1 < s_beat.shape[0] and interval >= 3:
      # Warp the unit between kinematic beats s_beat[t] and s_beat[t+1]
      # onto this music-beat interval.
      pose = get_pose(dance[s_beat[t]:s_beat[t+1]+1], interval+1)
      if t == 0:
        final_pose[begin-s_beat[t]:begin] = dance[:s_beat[t]]
      final_pose[begin:end+1] = pose
      rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t+1]-1)
else:
end = begin
if t+1 < s_beat.shape[0]:
rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t+1]-1)
else:
print(t_beat.shape, s_beat.shape, t)
rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t]-1)
if rest > 0:
if t+1 < s_beat.shape[0]:
final_pose[end+1:end+1+rest] = dance[s_beat[t+1]+1:s_beat[t+1]+1+rest]
else:
final_pose[end+1:end+1+rest] = dance[s_beat[t]+1:s_beat[t]+1+rest]
return final_pose
def get_pose(pose, n):
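  """Linearly resample a 9-frame beat-to-beat segment `pose` to `n` frames,
  keeping the two endpoint poses fixed."""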
t_pose = np.zeros((n, 14, 2))
if n==11:
t_pose[0] = pose[0]
t_pose[1] = (pose[0]*1+pose[1]*4)/5
t_pose[2] = (pose[1]*2+pose[2]*3)/5
t_pose[3] = (pose[2]*3+pose[3]*2)/5
t_pose[4] = (pose[3]*4+pose[4]*1)/5
t_pose[5] = pose[4]
t_pose[6] = (pose[4]*1+pose[5]*4)/5
t_pose[7] = (pose[5]*2+pose[6]*3)/5
t_pose[8] = (pose[6]*3+pose[7]*2)/5
t_pose[9] = (pose[7]*4+pose[8]*1)/5
t_pose[10] = pose[8]
elif n==10:
t_pose[0] = pose[0]
t_pose[1] = (pose[0]*1+pose[1]*8)/9
t_pose[2] = (pose[1]*2+pose[2]*7)/9
t_pose[3] = (pose[2]*3+pose[3]*6)/9
t_pose[4] = (pose[3]*4+pose[4]*5)/9
t_pose[5] = (pose[4]*5+pose[5]*4)/9
t_pose[6] = (pose[5]*6+pose[6]*3)/9
t_pose[7] = (pose[6]*7+pose[7]*2)/9
t_pose[8] = (pose[7]*8+pose[8]*1)/9
t_pose[9] = pose[8]
elif n==12:
t_pose[0] = pose[0]
t_pose[1] = (pose[0]*3+pose[1]*8)/11
t_pose[2] = (pose[1]*6+pose[2]*5)/11
t_pose[3] = (pose[2]*9+pose[3]*2)/11
t_pose[4] = (pose[2]*1+pose[3]*10)/11
t_pose[5] = (pose[3]*4+pose[4]*7)/11
t_pose[6] = (pose[4]*7+pose[5]*4)/11
t_pose[7] = (pose[5]*10+pose[6]*1)/11
t_pose[8] = (pose[5]*2+pose[6]*9)/11
t_pose[9] = (pose[6]*5+pose[7]*6)/11
t_pose[10] = (pose[7]*8+pose[8]*3)/11
t_pose[11] = pose[8]
elif n==13:
t_pose[0] = pose[0]
t_pose[1] = (pose[0]*1+pose[1]*2)/3
t_pose[2] = (pose[1]*2+pose[2]*1)/3
t_pose[3] = pose[2]
t_pose[4] = (pose[2]*1+pose[3]*2)/3
t_pose[5] = (pose[3]*2+pose[4]*1)/3
t_pose[6] = pose[4]
t_pose[7] = (pose[4]*1+pose[5]*2)/3
t_pose[8] = (pose[5]*2+pose[6]*1)/3
t_pose[9] = pose[6]
t_pose[10] = (pose[6]*1+pose[7]*2)/3
t_pose[11] = (pose[7]*2+pose[8]*1)/3
t_pose[12] = pose[8]
elif n==14:
t_pose[0] = pose[0]
t_pose[1] = (pose[0]*5+pose[1]*8)/13
t_pose[2] = (pose[1]*10+pose[2]*3)/13
t_pose[3] = (pose[1]*2+pose[2]*11)/13
t_pose[4] = (pose[2]*7+pose[3]*6)/13
t_pose[5] = (pose[3]*12+pose[4]*1)/13
t_pose[6] = (pose[3]*4+pose[4]*9)/13
t_pose[7] = (pose[4]*9+pose[5]*4)/13
t_pose[8] = (pose[4]*12+pose[5]*1)/13
t_pose[9] = (pose[5]*6+pose[6]*7)/13
t_pose[10] = (pose[6]*11+pose[7]*2)/13
t_pose[11] = (pose[6]*3+pose[7]*10)/13
t_pose[12] = (pose[7]*8+pose[8]*5)/13
t_pose[13] = pose[8]
elif n==9:
t_pose = pose
elif n==8:
t_pose[0] = pose[0]
t_pose[1] = (pose[1]*6+pose[2]*1)/7
t_pose[2] = (pose[2]*5+pose[3]*2)/7
t_pose[3] = (pose[3]*4+pose[4]*3)/7
t_pose[4] = (pose[4]*3+pose[5]*4)/7
t_pose[5] = (pose[5]*2+pose[6]*5)/7
t_pose[6] = (pose[6]*1+pose[7]*6)/7
t_pose[7] = pose[8]
elif n==7:
t_pose[0] = pose[0]
t_pose[1] = (pose[1]*2+pose[2]*1)/3
t_pose[2] = (pose[2]*1+pose[3]*2)/3
t_pose[3] = pose[4]
t_pose[4] = (pose[5]*2+pose[6]*1)/3
t_pose[5] = (pose[6]*1+pose[7]*2)/3
t_pose[6] = pose[8]
elif n==6:
t_pose[0] = pose[0]
t_pose[1] = (pose[1]*2+pose[2]*3)/5
t_pose[2] = (pose[3]*4+pose[4]*1)/5
t_pose[3] = (pose[4]*1+pose[5]*4)/5
t_pose[4] = (pose[6]*3+pose[7]*2)/5
t_pose[5] = pose[8]
elif n<6:
t_pose[0] = pose[0]
t_pose[n-1] = pose[8]
for i in range(1,n-1):
t_pose[i] = pose[4]
  elif n>14:
    t_pose[0] = pose[0]
    t_pose[n-1] = pose[8]
    for i in range(1, n-1):
      # map each output frame to its nearest source frame among the 9 inputs
      k = int(8/(n-1)*i)
      t_pose[i] = pose[k]
  else:
    print('NOT IMPLEMENTED {}'.format(n))
return t_pose
================================================
FILE: networks.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
if torch.cuda.is_available():
T = torch.cuda
else:
T = torch
###########################################################
##########
########## Stage 1: Movement
##########
###########################################################
class InitPose_Enc(nn.Module):
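  """MLP encoder mapping a single pose to (mean, log-variance) of the
  initial-pose latent z_init; the `std` head outputs log-variance, which
  callers convert to a std via .mul(0.5).exp_()."""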
def __init__(self, pose_size, dim_z_init):
super(InitPose_Enc, self).__init__()
nf = 64
#nf = 32
self.enc = nn.Sequential(
nn.Linear(pose_size, nf),
nn.LayerNorm(nf),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nf, nf),
nn.LayerNorm(nf),
nn.LeakyReLU(0.2, inplace=True),
)
self.mean = nn.Sequential(
nn.Linear(nf,dim_z_init),
)
self.std = nn.Sequential(
nn.Linear(nf,dim_z_init),
)
def forward(self, pose):
enc = self.enc(pose)
return self.mean(enc), self.std(enc)
class InitPose_Dec(nn.Module):
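  """MLP decoder mapping an initial-pose latent z_init back to a single pose."""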
def __init__(self, pose_size, dim_z_init):
super(InitPose_Dec, self).__init__()
nf = 64
#nf = dim_z_init
self.dec = nn.Sequential(
nn.Linear(dim_z_init, nf),
nn.LayerNorm(nf),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nf, nf),
nn.LayerNorm(nf),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nf,pose_size),
)
def forward(self, z_init):
return self.dec(z_init)
class Movement_Enc(nn.Module):
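  """GRU encoder over a pose unit; the final hidden output (or the
  concatenation of both directions when bidirectional) parameterizes the
  movement latent z_movement as (mean, log-variance)."""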
def __init__(self, pose_size, dim_z_movement, length, hidden_size, num_layers, bidirection=False):
super(Movement_Enc, self).__init__()
self.hidden_size = hidden_size
self.bidirection = bidirection
if bidirection:
self.num_dir = 2
else:
self.num_dir = 1
self.recurrent = nn.GRU(pose_size, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=bidirection)
self.init_h = nn.Parameter(torch.randn(num_layers*self.num_dir, 1, hidden_size).type(T.FloatTensor), requires_grad=True)
if bidirection:
self.mean = nn.Sequential(
nn.Linear(hidden_size*2,dim_z_movement),
)
self.std = nn.Sequential(
nn.Linear(hidden_size*2,dim_z_movement),
)
else:
'''
self.enc = nn.Sequential(
nn.Linear(hidden_size, hidden_size//2),
nn.LayerNorm(hidden_size//2),
nn.ReLU(inplace=True),
)
'''
self.mean = nn.Sequential(
nn.Linear(hidden_size,dim_z_movement),
)
self.std = nn.Sequential(
nn.Linear(hidden_size,dim_z_movement),
)
def forward(self, poses):
num_samples = poses.shape[0]
h_t = [self.init_h.repeat(1, num_samples, 1)]
output, hidden = self.recurrent(poses, h_t[0])
if self.bidirection:
output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1)
else:
output = output[:,-1,:]
#enc = self.enc(output)
#return self.mean(enc), self.std(enc)
return self.mean(output), self.std(output)
def getFeature(self, poses):
num_samples = poses.shape[0]
h_t = [self.init_h.repeat(1, num_samples, 1)]
output, hidden = self.recurrent(poses, h_t[0])
if self.bidirection:
output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1)
else:
output = output[:,-1,:]
return output
class StandardPose_Dec(nn.Module):
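  """GRU decoder generating a `length`-frame pose unit: (z_init, z_movement)
  initializes the hidden state and z_movement is fed as input at every
  step."""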
def __init__(self, pose_size, dim_z_init, dim_z_movement, length, hidden_size, num_layers):
super(StandardPose_Dec, self).__init__()
self.length = length
self.pose_size = pose_size
self.hidden_size = hidden_size
self.num_layers = num_layers
#dim_z_init=0
'''
self.z2init = nn.Sequential(
nn.Linear(dim_z_init+dim_z_movement, hidden_size),
nn.LayerNorm(hidden_size),
nn.ReLU(True),
nn.Linear(hidden_size, num_layers*hidden_size)
)
'''
self.z2init = nn.Sequential(
nn.Linear(dim_z_init+dim_z_movement, num_layers*hidden_size)
)
self.recurrent = nn.GRU(dim_z_movement, hidden_size, num_layers=num_layers, batch_first=True)
self.pose_g = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.LayerNorm(hidden_size),
nn.ReLU(True),
nn.Linear(hidden_size, pose_size)
)
  def forward(self, z_init, z_movement):
    # initialize the GRU hidden state from the concatenated latents
    h_init = self.z2init(torch.cat((z_init, z_movement), 1))
    h_init = h_init.view(self.num_layers, h_init.size(0), self.hidden_size)
    # feed z_movement as the input at every time step
    z_movements = z_movement.view(z_movement.size(0),1,z_movement.size(1)).repeat(1, self.length, 1)
    z_m_t, _ = self.recurrent(z_movements, h_init)
    z_m = z_m_t.contiguous().view(-1, self.hidden_size)
    poses = self.pose_g(z_m)
    poses = poses.view(z_movement.shape[0], self.length, self.pose_size)
    return poses
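# Illustrative sketch, mirroring getDecompNetworks() in train_decomp.py but
# with dummy tensors: encode a 32-frame pose sequence into a movement latent
# and decode it back. Taking only the mean head (no sampling) is a
# simplification for this example.
def _example_movement_roundtrip():
  movement_enc = Movement_Enc(pose_size=28, dim_z_movement=512, length=32,
                              hidden_size=1024, num_layers=1, bidirection=True)
  stdp_dec = StandardPose_Dec(pose_size=28, dim_z_init=10, dim_z_movement=512,
                              length=32, hidden_size=1024, num_layers=1)
  poses = torch.randn(2, 32, 28)                     # (batch, frames, pose)
  z_move, _ = movement_enc(poses)                    # use the mean head only
  z_init = torch.randn(2, 10)                        # e.g. drawn from the prior
  recon = stdp_dec(z_init, z_move)
  return recon.shape                                 # torch.Size([2, 32, 28])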
class StandardPose_Dis(nn.Module):
  """MLP critic scoring a whole flattened pose sequence; one scalar per sample."""
  def __init__(self, pose_size, length):
    super(StandardPose_Dis, self).__init__()
self.pose_size = pose_size
self.length = length
nd = 1024
self.main = nn.Sequential(
nn.Linear(length*pose_size, nd),
nn.LayerNorm(nd),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd,nd//2),
nn.LayerNorm(nd//2),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd//2,nd//4),
nn.LayerNorm(nd//4),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd//4, 1)
)
def forward(self, pose_seq):
pose_seq = pose_seq.view(-1, self.pose_size*self.length)
return self.main(pose_seq).squeeze()
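# Illustrative sketch: the critic maps each 32-frame sequence to one scalar,
# as used for the adversarial loss on decoded sequences (the exact loss lives
# in model_decomp.py).
def _example_stdp_critic():
  dis = StandardPose_Dis(pose_size=28, length=32)
  real = torch.randn(4, 32, 28)
  fake = torch.randn(4, 32, 28)
  return dis(real) - dis(fake)                       # shape: torch.Size([4])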
###########################################################
##########
########## Stage 2: Dance
##########
###########################################################
class Dance_Enc(nn.Module):
  """GRU encoder mapping a sequence of movement-latent (mean, std) pairs to
  the mean/std heads of a dance-latent distribution."""
  def __init__(self, dim_z_movement, dim_z_dance, hidden_size, num_layers, bidirection=False):
    super(Dance_Enc, self).__init__()
    self.hidden_size = hidden_size
    self.bidirection = bidirection
    self.num_dir = 2 if bidirection else 1
self.recurrent = nn.GRU(2*dim_z_movement, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=bidirection)
self.init_h = nn.Parameter(torch.randn(num_layers*self.num_dir, 1, hidden_size).type(T.FloatTensor), requires_grad=True)
if bidirection:
self.mean = nn.Sequential(
nn.Linear(hidden_size*2,dim_z_dance),
)
self.std = nn.Sequential(
nn.Linear(hidden_size*2,dim_z_dance),
)
else:
self.mean = nn.Sequential(
nn.Linear(hidden_size,dim_z_dance),
)
self.std = nn.Sequential(
nn.Linear(hidden_size,dim_z_dance),
)
  def forward(self, movements_mean, movements_std):
    movements = torch.cat((movements_mean, movements_std), 2)
    num_samples = movements.shape[0]
    h_init = self.init_h.repeat(1, num_samples, 1)
    output, hidden = self.recurrent(movements, h_init)
    if self.bidirection:
      output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1)
    else:
      output = output[:,-1,:]
    return self.mean(output), self.std(output)
class Dance_Dec(nn.Module):
  """GRU decoder unrolling a dance latent into per-unit movement-latent
  (mean, std) pairs, flattened to (batch*length, dim_z_movement)."""
  def __init__(self, dim_z_dance, dim_z_movement, hidden_size, num_layers):
    super(Dance_Dec, self).__init__()
    self.num_layers = num_layers
    self.hidden_size = hidden_size
    self.dim_z_movement = dim_z_movement
    self.z2init = nn.Sequential(
      nn.Linear(dim_z_dance, num_layers*hidden_size)
    )
    self.recurrent = nn.GRU(dim_z_dance, hidden_size, num_layers=num_layers, batch_first=True)
    self.movement_g = nn.Sequential(
      nn.Linear(hidden_size, hidden_size),
      nn.LayerNorm(hidden_size),
      nn.ReLU(True),
    )
self.mean = nn.Sequential(
nn.Linear(hidden_size,dim_z_movement),
)
self.std = nn.Sequential(
nn.Linear(hidden_size,dim_z_movement),
)
  def forward(self, z_dance, length=3):
    h_init = self.z2init(z_dance)
    h_init = h_init.view(self.num_layers, h_init.size(0), self.hidden_size)
    # feed z_dance as the input at every time step
    z_dance_seq = z_dance.view(z_dance.size(0),1,z_dance.size(1)).repeat(1, length, 1)
    z_d_t, _ = self.recurrent(z_dance_seq, h_init)
    z_d = z_d_t.contiguous().view(-1, self.hidden_size)
    z_movement = self.movement_g(z_d)
    z_movement_mean, z_movement_std = self.mean(z_movement), self.std(z_movement)
    return z_movement_mean, z_movement_std
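# Illustrative sketch of the two-level composition: unroll one dance latent
# into three movement-latent distributions, sample them (VAE-style sampling is
# an assumption here), and decode each into a 32-frame unit.
def _example_dance_decode():
  dance_dec = Dance_Dec(dim_z_dance=512, dim_z_movement=512,
                        hidden_size=1024, num_layers=1)
  stdp_dec = StandardPose_Dec(pose_size=28, dim_z_init=10, dim_z_movement=512,
                              length=32, hidden_size=1024, num_layers=1)
  z_dance = torch.randn(2, 512)
  m_mean, m_logvar = dance_dec(z_dance, length=3)    # each (2*3, 512), flattened
  z_move = m_mean + torch.randn_like(m_mean) * torch.exp(0.5 * m_logvar)
  z_init = torch.randn(z_move.shape[0], 10)
  poses = stdp_dec(z_init, z_move)                   # (6, 32, 28): six 32-frame units
  return poses.shape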
class DanceAud_Dis2(nn.Module):
  """Joint critic over flattened movement latents and an audio style vector;
  returns (score, None)."""
  def __init__(self, aud_size, dim_z_movement, length=3):
    super(DanceAud_Dis2, self).__init__()
self.aud_size = aud_size
self.dim_z_movement = dim_z_movement
self.length = length
nd = 1024
self.movementd = nn.Sequential(
nn.Linear(dim_z_movement*2*length, nd),
nn.LayerNorm(nd),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd,nd//2),
nn.LayerNorm(nd//2),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd//2,nd//4),
nn.LayerNorm(nd//4),
nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(nd//4, 30),
)
self.audd = nn.Sequential(
nn.Linear(aud_size, 30),
nn.LayerNorm(30),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(30, 30),
nn.LayerNorm(30),
nn.LeakyReLU(0.2, inplace=True),
)
self.jointd = nn.Sequential(
nn.Linear(60, 1)
)
def forward(self, movements, aud):
if len(movements.shape) == 3:
movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2])
m = self.movementd(movements)
a = self.audd(aud)
ma = torch.cat((m,a),1)
return self.jointd(ma).squeeze(), None
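# Illustrative sketch: the joint critic consumes a sequence of per-unit
# movement-latent statistics plus a style-level audio vector and returns one
# realism score per sample; the second return value is unused.
def _example_danceaud_critic():
  dis = DanceAud_Dis2(aud_size=28, dim_z_movement=512, length=3)
  movements = torch.randn(4, 3, 512 * 2)             # assumed (mean, std) pairs per unit
  aud = torch.randn(4, 28)
  score, _ = dis(movements, aud)
  return score.shape                                 # torch.Size([4])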
class DanceAud_Dis(nn.Module):
  """Movement-only critic; the audio argument is currently unused."""
  def __init__(self, aud_size, dim_z_movement, length=3):
    super(DanceAud_Dis, self).__init__()
    self.aud_size = aud_size
    self.dim_z_movement = dim_z_movement
    self.length = length
    nd = 1024
    self.movementd = nn.Sequential(
      nn.Linear(dim_z_movement*2, nd),
      nn.LayerNorm(nd),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(nd, nd//2),
      nn.LayerNorm(nd//2),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(nd//2, nd//4),
      nn.LayerNorm(nd//4),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(nd//4, 30),
    )
  def forward(self, movements, aud):
    return self.movementd(movements).squeeze()
class DanceAud_InfoDis(nn.Module):
  """Critic with two heads: a realism score and a regression of the audio
  style vector back from the movement latents."""
  def __init__(self, aud_size, dim_z_movement, length):
    super(DanceAud_InfoDis, self).__init__()
self.aud_size = aud_size
self.dim_z_movement = dim_z_movement
self.length = length
nd = 1024
self.movementd = nn.Sequential(
nn.Linear(dim_z_movement*6, nd*2),
nn.LayerNorm(nd*2),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd*2, nd),
nn.LayerNorm(nd),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd,nd//2),
nn.LayerNorm(nd//2),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd//2,nd//4),
nn.LayerNorm(nd//4),
nn.LeakyReLU(0.2, inplace=True),
)
self.dis = nn.Sequential(
nn.Linear(nd//4, 1)
)
self.reg = nn.Sequential(
nn.Linear(nd//4, aud_size)
)
def forward(self, movements, aud):
movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2])
m = self.movementd(movements)
return self.dis(m).squeeze(), self.reg(m)
class Dance2Style(nn.Module):
  """MLP regressing the audio style vector from a dance latent."""
  def __init__(self, dim_z_dance, aud_size):
    super(Dance2Style, self).__init__()
self.aud_size = aud_size
self.dim_z_dance = dim_z_dance
nd = 512
self.main = nn.Sequential(
nn.Linear(dim_z_dance, nd),
nn.LayerNorm(nd),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd, nd//2),
nn.LayerNorm(nd//2),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd//2, nd//4),
nn.LayerNorm(nd//4),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd//4, aud_size),
)
def forward(self, zdance):
return self.main(zdance)
###########################################################
##########
########## Audio
##########
###########################################################
class AudioClassifier_rnn(nn.Module):
  """GRU classifier over per-frame audio features; forward() returns class
  logits and get_style() returns the last hidden state as a style embedding."""
  def __init__(self, dim_z_motion, hidden_size, pose_size, cls, num_layers=1, h_init=2):
    super(AudioClassifier_rnn, self).__init__()
    self.dim_z_motion = dim_z_motion
    self.hidden_size = hidden_size
    self.pose_size = pose_size
    self.h_init = h_init
    self.num_layers = num_layers
    self.init_h = nn.Parameter(torch.randn(1, 1, self.hidden_size).type(T.FloatTensor), requires_grad=True)
    self.recurrent = nn.GRU(pose_size, hidden_size, num_layers=num_layers, batch_first=True)
    self.classifier = nn.Sequential(
      nn.Linear(hidden_size, hidden_size),
      nn.ReLU(True),
      nn.Linear(hidden_size, cls)
    )
def forward(self, poses):
hidden, _ = self.recurrent(poses, self.init_h.repeat(1, poses.shape[0], 1))
last_hidden = hidden[:,-1,:]
cls = self.classifier(last_hidden)
return cls
def get_style(self, auds):
hidden, _ = self.recurrent(auds, self.init_h.repeat(1, auds.shape[0], 1))
last_hidden = hidden[:,-1,:]
return last_hidden
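# Illustrative sketch: the classifier doubles as a style extractor. The
# construction AudioClassifier_rnn(10, 30, 28, cls=3) follows train_comp.py;
# the 25-frame sequence length is a placeholder.
def _example_audio_style():
  neta = AudioClassifier_rnn(10, 30, 28, cls=3)
  auds = torch.randn(2, 25, 28)                      # (batch, frames, features)
  logits = neta(auds)                                # (2, 3) class scores
  style = neta.get_style(auds)                       # (2, 30) style embedding
  return logits.shape, style.shape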
class Audstyle_Enc(nn.Module):
  """Encode an audio style vector, concatenated with Gaussian noise, into the
  mean/std heads of a dance-latent distribution."""
  def __init__(self, aud_size, dim_z, dim_noise=30):
    super(Audstyle_Enc, self).__init__()
    self.dim_noise = dim_noise
    nf = 64
self.enc = nn.Sequential(
nn.Linear(aud_size+dim_noise, nf),
nn.LayerNorm(nf),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nf, nf*2),
nn.LayerNorm(nf*2),
nn.LeakyReLU(0.2, inplace=True),
)
self.mean = nn.Sequential(
nn.Linear(nf*2,dim_z),
)
self.std = nn.Sequential(
nn.Linear(nf*2,dim_z),
)
  def forward(self, aud):
    # draw noise on the same device as the input so the module also runs on CPU
    noise = torch.randn(aud.shape[0], self.dim_noise, device=aud.device)
    y = torch.cat((aud, noise), 1)
    enc = self.enc(y)
    return self.mean(enc), self.std(enc)
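# Illustrative end-to-end sketch of the generation path that test.py/demo.py
# wire up: audio style -> dance latent -> movement latents -> pose sequences.
# Shapes follow the defaults in options.py; sampling each latent with a
# VAE-style reparameterization is an assumption of this example.
def _example_music_to_dance():
  aud_style = torch.randn(1, 30)                     # e.g. AudioClassifier_rnn.get_style output
  aud_enc = Audstyle_Enc(aud_size=30, dim_z=512)
  dance_dec = Dance_Dec(dim_z_dance=512, dim_z_movement=512,
                        hidden_size=1024, num_layers=1)
  stdp_dec = StandardPose_Dec(pose_size=28, dim_z_init=10, dim_z_movement=512,
                              length=32, hidden_size=1024, num_layers=1)
  d_mean, d_logvar = aud_enc(aud_style)
  z_dance = d_mean + torch.randn_like(d_mean) * torch.exp(0.5 * d_logvar)
  m_mean, m_logvar = dance_dec(z_dance, length=3)
  z_move = m_mean + torch.randn_like(m_mean) * torch.exp(0.5 * m_logvar)
  z_init = torch.randn(z_move.shape[0], 10)
  poses = stdp_dec(z_init, z_move)                   # (3, 32, 28): three 32-frame units
  return poses.shape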
================================================
FILE: options.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import argparse
class DecompOptions():
def __init__(self):
parser = argparse.ArgumentParser()
parser.add_argument('--name', default=None)
parser.add_argument('--log_interval', type=int, default=50)
parser.add_argument('--log_dir', default='./logs')
parser.add_argument('--snapshot_ep', type=int, default=1)
parser.add_argument('--snapshot_dir', default='./snapshot')
parser.add_argument('--data_dir', default='./data')
# Model architecture
parser.add_argument('--pose_size', type=int, default=28)
parser.add_argument('--dim_z_init', type=int, default=10)
parser.add_argument('--dim_z_movement', type=int, default=512)
parser.add_argument('--stdp_length', type=int, default=32)
parser.add_argument('--movement_enc_bidirection', type=int, default=1)
parser.add_argument('--movement_enc_hidden_size', type=int, default=1024)
parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024)
parser.add_argument('--movement_enc_num_layers', type=int, default=1)
parser.add_argument('--stdp_dec_num_layers', type=int, default=1)
# Training
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--num_epochs', type=int, default=1000)
parser.add_argument('--latent_dropout', type=float, default=0.3)
parser.add_argument('--lambda_kl', type=float, default=0.01)
parser.add_argument('--lambda_initp_recon', type=float, default=1)
parser.add_argument('--lambda_initp_consistency', type=float, default=1)
parser.add_argument('--lambda_stdp_recon', type=float, default=1)
parser.add_argument('--lambda_dist_z_movement', type=float, default=1)
# Others
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--resume', default=None)
parser.add_argument('--dataset', type=int, default=0)
parser.add_argument('--tolerance', action='store_true')
self.parser = parser
  def parse(self):
    self.opt = self.parser.parse_args()
    return self.opt
class CompOptions():
def __init__(self):
parser = argparse.ArgumentParser()
parser.add_argument('--name', default=None)
parser.add_argument('--log_interval', type=int, default=50)
parser.add_argument('--log_dir', default='./logs')
parser.add_argument('--snapshot_ep', type=int, default=1)
parser.add_argument('--snapshot_dir', default='./snapshot')
parser.add_argument('--data_dir', default='./data')
# Network architecture
parser.add_argument('--pose_size', type=int, default=28)
parser.add_argument('--aud_style_size', type=int, default=30)
parser.add_argument('--dim_z_init', type=int, default=10)
parser.add_argument('--dim_z_movement', type=int, default=512)
parser.add_argument('--dim_z_dance', type=int, default=512)
parser.add_argument('--stdp_length', type=int, default=32)
parser.add_argument('--movement_enc_bidirection', type=int, default=1)
parser.add_argument('--movement_enc_hidden_size', type=int, default=1024)
parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024)
parser.add_argument('--movement_enc_num_layers', type=int, default=1)
parser.add_argument('--stdp_dec_num_layers', type=int, default=1)
parser.add_argument('--dance_enc_bidirection', type=int, default=0)
parser.add_argument('--dance_enc_hidden_size', type=int, default=1024)
parser.add_argument('--dance_enc_num_layers', type=int, default=1)
parser.add_argument('--dance_dec_hidden_size', type=int, default=1024)
parser.add_argument('--dance_dec_num_layers', type=int, default=1)
# Training
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--num_epochs', type=int, default=1500)
parser.add_argument('--latent_dropout', type=float, default=0.3)
parser.add_argument('--lambda_kl', type=float, default=0.01)
parser.add_argument('--lambda_kl_dance', type=float, default=0.01)
parser.add_argument('--lambda_gan', type=float, default=1)
parser.add_argument('--lambda_zmovements_recon', type=float, default=1)
parser.add_argument('--lambda_stdpSeq_recon', type=float, default=10)
parser.add_argument('--lambda_dist_z_movement', type=float, default=1)
# Other
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--decomp_snapshot', required=True)
parser.add_argument('--neta_snapshot', default='./data/stats/aud_3cls.ckpt')
parser.add_argument('--resume', default=None)
parser.add_argument('--dataset', type=int, default=2)
self.parser = parser
  def parse(self):
    self.opt = self.parser.parse_args()
    return self.opt
class TestOptions():
def __init__(self):
parser = argparse.ArgumentParser()
parser.add_argument('--name', default=None)
parser.add_argument('--log_interval', type=int, default=50)
parser.add_argument('--log_dir', default='./logs')
parser.add_argument('--snapshot_ep', type=int, default=1)
parser.add_argument('--snapshot_dir', default='./snapshot')
parser.add_argument('--data_dir', default='./data')
# Network architecture
parser.add_argument('--pose_size', type=int, default=28)
parser.add_argument('--aud_style_size', type=int, default=30)
parser.add_argument('--dim_z_init', type=int, default=10)
parser.add_argument('--dim_z_movement', type=int, default=512)
parser.add_argument('--dim_z_dance', type=int, default=512)
parser.add_argument('--stdp_length', type=int, default=32)
parser.add_argument('--movement_enc_bidirection', type=int, default=1)
parser.add_argument('--movement_enc_hidden_size', type=int, default=1024)
parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024)
parser.add_argument('--movement_enc_num_layers', type=int, default=1)
parser.add_argument('--stdp_dec_num_layers', type=int, default=1)
parser.add_argument('--dance_enc_bidirection', type=int, default=0)
parser.add_argument('--dance_enc_hidden_size', type=int, default=1024)
parser.add_argument('--dance_enc_num_layers', type=int, default=1)
parser.add_argument('--dance_dec_hidden_size', type=int, default=1024)
parser.add_argument('--dance_dec_num_layers', type=int, default=1)
# Training
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--num_epochs', type=int, default=1500)
parser.add_argument('--latent_dropout', type=float, default=0.3)
parser.add_argument('--lambda_kl', type=float, default=0.01)
parser.add_argument('--lambda_kl_dance', type=float, default=0.01)
parser.add_argument('--lambda_gan', type=float, default=1)
parser.add_argument('--lambda_zmovements_recon', type=float, default=1)
parser.add_argument('--lambda_stdpSeq_recon', type=float, default=10)
parser.add_argument('--lambda_dist_z_movement', type=float, default=1)
# Other
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--decomp_snapshot', required=True)
parser.add_argument('--comp_snapshot', required=True)
parser.add_argument('--neta_snapshot', default='./data/stats/aud_3cls.ckpt')
parser.add_argument('--dataset', type=int, default=2)
parser.add_argument('--thr', type=int, default=50)
parser.add_argument('--aud_path', type=str, required=True)
parser.add_argument('--modulate', action='store_true')
parser.add_argument('--out_file', type=str, default='demo/out.mp4')
parser.add_argument('--out_dir', type=str, default='demo/out_frame')
self.parser = parser
  def parse(self):
    self.opt = self.parser.parse_args()
    return self.opt
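# Illustrative sketch (not in the original file): obtain an all-defaults
# namespace for DecompOptions without touching sys.argv; CompOptions and
# TestOptions have required flags, so they need explicit values.
if __name__ == "__main__":
  _args = DecompOptions().parser.parse_args([])      # empty argv -> pure defaults
  print(_args.pose_size, _args.dim_z_movement, _args.stdp_length)  # 28 512 32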
================================================
FILE: test.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import argparse
import functools
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from model_comp import *
from networks import *
from options import CompOptions
from data import get_loader
def loadDecompModel(args):
  initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
  stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,
                              hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)
  movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,
                              hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))
  checkpoint = torch.load(args.decomp_snapshot)
  initp_enc.load_state_dict(checkpoint['initp_enc'])
  stdp_dec.load_state_dict(checkpoint['stdp_dec'])
  movement_enc.load_state_dict(checkpoint['movement_enc'])
  return initp_enc, stdp_dec, movement_enc

def loadCompModel(args):
  dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
                        hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1))
  dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
                        hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers)
  audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance)
  dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance)
  danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3)
  zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1)
  # the Comp checkpoint is supplied through --resume
  checkpoint = torch.load(args.resume)
  dance_enc.load_state_dict(checkpoint['dance_enc'])
  dance_dec.load_state_dict(checkpoint['dance_dec'])
  audstyle_enc.load_state_dict(checkpoint['audstyle_enc'])
  dance_reg.load_state_dict(checkpoint['dance_reg'])
  danceAud_dis.load_state_dict(checkpoint['danceAud_dis'])
  zdance_dis.load_state_dict(checkpoint['zdance_dis'])
  checkpoint2 = torch.load(args.neta_snapshot)
  neta_cls = AudioClassifier_rnn(10, 30, 28, cls=3)
  neta_cls.load_state_dict(checkpoint2)
  return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls

if __name__ == "__main__":
  # module-level code runs top to bottom, so the loaders above must be defined
  # before this block executes
  parser = CompOptions()
  args = parser.parse()
  #### Pretrained networks from Decomp
  initp_enc, stdp_dec, movement_enc = loadDecompModel(args)
  #### Comp networks
  dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = loadCompModel(args)
  mean_pose = np.load('../onbeat/all_onbeat_mean.npy')
  std_pose = np.load('../onbeat/all_onbeat_std.npy')
  mean_aud = np.load('../onbeat/all_aud_mean.npy')
  std_aud = np.load('../onbeat/all_aud_std.npy')
================================================
FILE: train_comp.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import argparse
import functools
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from model_comp import *
from networks import *
from options import CompOptions
from data import get_loader
def loadDecompModel(args):
initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,
hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)
movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,
hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))
checkpoint = torch.load(args.decomp_snapshot)
initp_enc.load_state_dict(checkpoint['initp_enc'])
stdp_dec.load_state_dict(checkpoint['stdp_dec'])
movement_enc.load_state_dict(checkpoint['movement_enc'])
return initp_enc, stdp_dec, movement_enc
def getCompNetworks(args):
dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1))
dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers)
audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance)
dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance)
danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3)
zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1)
checkpoint2 = torch.load(args.neta_snapshot)
neta_cls = AudioClassifier_rnn(10,30,28,cls=3)
neta_cls.load_state_dict(checkpoint2)
return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls
if __name__ == "__main__":
parser = CompOptions()
args = parser.parse()
args.train = True
if args.name is None:
args.name = 'Comp'
  args.log_dir = os.path.join(args.log_dir, args.name)
  os.makedirs(args.log_dir, exist_ok=True)    # also creates missing parents
  args.snapshot_dir = os.path.join(args.snapshot_dir, args.name)
  os.makedirs(args.snapshot_dir, exist_ok=True)
data_loader = get_loader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, dataset=args.dataset, data_dir=args.data_dir)
#### Pretrain network from Decomp
initp_enc, stdp_dec, movement_enc = loadDecompModel(args)
#### Comp network
dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = getCompNetworks(args)
trainer = Trainer_Comp(data_loader,
movement_enc = movement_enc,
initp_enc = initp_enc,
stdp_dec = stdp_dec,
dance_enc = dance_enc,
dance_dec = dance_dec,
danceAud_dis = danceAud_dis,
zdance_dis = zdance_dis,
aud_enc=neta_cls,
audstyle_enc=audstyle_enc,
dance_reg=dance_reg,
args = args
)
  if args.resume is not None:
ep, it = trainer.resume(args.resume, True)
else:
ep, it = 0, 0
trainer.train(ep, it)
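# Example invocation (illustrative; the checkpoint path is a placeholder, and
# all flags come from CompOptions):
#   python train_comp.py --name Comp --data_dir ./data \
#       --decomp_snapshot ./snapshot/Decomp/<checkpoint>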
================================================
FILE: train_decomp.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import argparse
import functools
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from model_decomp import *
from networks import *
from options import DecompOptions
from data import get_loader
def getDecompNetworks(args):
initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
initp_dec = InitPose_Dec(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,
hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))
stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,
hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)
return initp_enc, initp_dec, movement_enc, stdp_dec
if __name__ == "__main__":
parser = DecompOptions()
args = parser.parse()
args.train = True
if args.name is None:
args.name = 'Decomp'
  args.log_dir = os.path.join(args.log_dir, args.name)
  os.makedirs(args.log_dir, exist_ok=True)    # also creates missing parents
  args.snapshot_dir = os.path.join(args.snapshot_dir, args.name)
  os.makedirs(args.snapshot_dir, exist_ok=True)
data_loader = get_loader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, dataset=args.dataset, data_dir=args.data_dir, tolerance=args.tolerance)
initp_enc, initp_dec, movement_enc, stdp_dec = getDecompNetworks(args)
trainer = Trainer_Decomp(data_loader,
initp_enc = initp_enc,
initp_dec = initp_dec,
movement_enc = movement_enc,
stdp_dec = stdp_dec,
args = args
)
  if args.resume is not None:
ep, it = trainer.resume(args.resume, False)
else:
ep, it = 0, 0
trainer.train(ep=ep, it=it)
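# Example invocation (illustrative; all flags come from DecompOptions):
#   python train_decomp.py --name Decomp --data_dir ./data --batch_size 256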
================================================
FILE: utils.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import numpy as np
import cv2
import math
import os
import tensorflow as tf
class Logger(object):
  """Minimal TensorBoard scalar logger built on the TensorFlow 1.x summary API."""
  def __init__(self, log_dir):
    self.writer = tf.summary.FileWriter(log_dir)
  def scalar_summary(self, tag, value, step):
    summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
    self.writer.add_summary(summary, step)
def vis(poses, outdir, aud=None):
  """Render each (18, 2) pose in `poses` as a colored stick figure and write
  frame{t:03d}.png files to `outdir`; if `aud` is given, mark beat frames
  with a red dot."""
  colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
            [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
            [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
  # limb connections between 1-indexed OpenPose joint ids
  limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \
             [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \
             [1,16], [16,18], [3,17], [6,18]]
  neglect = [14,15,16,17]  # 0-indexed face joints (eyes/ears) to skip
  for t in range(poses.shape[0]):
    canvas = np.ones((256,500,3), np.uint8)*255
    thisPeak = poses[t]
for i in range(18):
if i in neglect:
continue
if thisPeak[i,0] == -1:
continue
cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1)
for i in range(17):
limbid = np.array(limbSeq[i])-1
if limbid[0] in neglect or limbid[1] in neglect:
continue
X = thisPeak[[limbid[0],limbid[1]], 1]
Y = thisPeak[[limbid[0],limbid[1]], 0]
if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1:
continue
stickwidth = 4
cur_canvas = canvas.copy()
mX = np.mean(X)
mY = np.mean(Y)
      length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
      angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
      polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1)
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
if aud is not None:
if aud[:,t] == 1:
cv2.circle(canvas, (30, 30), 20, (0,0,255), -1)
cv2.imwrite(os.path.join(outdir, 'frame{0:03d}.png'.format(t)),canvas)
def vis2(poses, outdir, fibeat):
  """Like vis(), but stacks the beat image `fibeat` under each frame with a
  gray progress overlay that advances with the frame index."""
  colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
            [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
            [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
  # limb connections between 1-indexed OpenPose joint ids
  limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \
             [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \
             [1,16], [16,18], [3,17], [6,18]]
  neglect = [14,15,16,17]  # 0-indexed face joints (eyes/ears) to skip
  ibeat = cv2.imread(fibeat)
  ibeat = cv2.resize(ibeat, (500,200))
for t in range(poses.shape[0]):
subibeat = ibeat.copy()
canvas = np.ones((256+200,500,3), np.uint8)*255
canvas[256:,:,:] = subibeat
overlay = canvas.copy()
cv2.rectangle(overlay, (int(500/poses.shape[0]*(t+1)),256),(500,256+200), (100,100,100), -1)
cv2.addWeighted(overlay, 0.4, canvas, 1-0.4, 0, canvas)
thisPeak = poses[t]
for i in range(18):
if i in neglect:
continue
if thisPeak[i,0] == -1:
continue
cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1)
for i in range(17):
limbid = np.array(limbSeq[i])-1
if limbid[0] in neglect or limbid[1] in neglect:
continue
X = thisPeak[[limbid[0],limbid[1]], 1]
Y = thisPeak[[limbid[0],limbid[1]], 0]
if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1:
continue
stickwidth = 4
cur_canvas = canvas.copy()
mX = np.mean(X)
mY = np.mean(Y)
      length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
      angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
      polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1)
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
cv2.imwrite(os.path.join(outdir, 'frame{0:03d}.png'.format(t)),canvas)
def vis_single(pose, outfile):
  """Render a single (18, 2) pose as a stick figure and write it to `outfile`."""
  colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
            [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
            [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
  # limb connections between 1-indexed OpenPose joint ids
  limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \
             [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \
             [1,16], [16,18], [3,17], [6,18]]
  neglect = [14,15,16,17]  # 0-indexed face joints (eyes/ears) to skip
  canvas = np.ones((256,500,3), np.uint8)*255
  thisPeak = pose
  for i in range(18):
    if i in neglect:
      continue
    if thisPeak[i,0] == -1:
      continue
    cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1)
  for i in range(17):
    limbid = np.array(limbSeq[i])-1
    if limbid[0] in neglect or limbid[1] in neglect:
      continue
    X = thisPeak[[limbid[0],limbid[1]], 1]
    Y = thisPeak[[limbid[0],limbid[1]], 0]
    if X[0] == -1 or Y[0] == -1 or X[1] == -1 or Y[1] == -1:
      continue
    stickwidth = 4
    cur_canvas = canvas.copy()
    mX = np.mean(X)
    mY = np.mean(Y)
    length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
    angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
    polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1)
    cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
    canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
  cv2.imwrite(outfile, canvas)
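# Illustrative sketch: render a few random poses with vis(). The (T, 18, 2)
# layout in (x, y) image coordinates on the 256x500 canvas is inferred from
# the drawing code above; real pose sequences come from the data loader.
def _example_vis(outdir='demo/out_frame'):
  os.makedirs(outdir, exist_ok=True)
  poses = np.random.uniform(low=(50, 50), high=(450, 200), size=(8, 18, 2))
  vis(poses, outdir)                                 # writes frame000.png ... frame007.png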