Repository: NVlabs/Dancing2Music
Branch: master
Commit: 7ff1d95f9f3d
Files: 13
Total size: 93.6 KB
Directory structure:
gitextract_2r5pf7ve/
├── License.txt
├── README.md
├── data.py
├── demo.py
├── model_comp.py
├── model_decomp.py
├── modulate.py
├── networks.py
├── options.py
├── test.py
├── train_comp.py
├── train_decomp.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: License.txt
================================================
Nvidia Source Code License-NC
1. Definitions
“Licensor” means any person or entity that distributes its Work.
“Software” means the original work of authorship made available under this License.
“Work” means the Software and any additions to or derivative works of the Software that are made available under this License.
“Nvidia Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by Nvidia or its affiliates.
The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.
2. License Grants
2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
3. Limitations
3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. The Work or derivative works thereof may be used or intended for use by Nvidia or its affiliates commercially or non-commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grants in Sections 2.1 and 2.2) will terminate immediately.
3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License.
3.6 Termination. If you violate any term of this License, then your rights under this License (including the grants in Sections 2.1 and 2.2) will terminate immediately.
4. Disclaimer of Warranty.
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
5. Limitation of Liability.
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
================================================
FILE: README.md
================================================


## Dancing to Music
PyTorch implementation of the cross-modality generative model that synthesizes dance from music.
### Paper
[Hsin-Ying Lee](http://vllab.ucmerced.edu/hylee/), [Xiaodong Yang](https://xiaodongyang.org/), [Ming-Yu Liu](http://mingyuliu.net/), [Ting-Chun Wang](https://tcwang0509.github.io/), [Yu-Ding Lu](https://jonlu0602.github.io/), [Ming-Hsuan Yang](https://faculty.ucmerced.edu/mhyang/), [Jan Kautz](http://jankautz.com/)

Dancing to Music

Neural Information Processing Systems (**NeurIPS**) 2019

[[Paper]](https://arxiv.org/abs/1911.02001) [[YouTube]](https://youtu.be/-e9USqfwZ4A) [[Project]](http://vllab.ucmerced.edu/hylee/Dancing2Music/script.txt) [[Blog]](https://news.developer.nvidia.com/nvidia-dance-to-music-neurips/) [[Supp]](http://xiaodongyang.org/publications/papers/dance2music-supp-neurips19.pdf)
### Example Videos
- Beat-Matching

  1st row: generated dance sequences; 2nd row: music beats; 3rd row: kinematic beats
  <p align='left'>
  <img src='imgs/example.gif' width='400'/>
  </p>
- Multimodality

  Generate various dance sequences with the same music and the same initial pose.
  <p align='left'>
  <img src='imgs/multimodal.gif' width='400'/>
  </p>
- Long-Term Generation

  Seamlessly generate dance sequences of arbitrary length.
  <p align='left'>
  <kbd>
  <img src='imgs/long.gif' width='300'/>
  </kbd>
  </p>
- Photo-Realistic Videos

  Map generated dance sequences to photo-realistic videos.
  <p align='left'>
  <img src='imgs/v2v.gif' width='800'/>
  </p>
## Train Decomposition
```
python train_decomp.py --name Decomp
```
## Train Composition
```
python train_comp.py --name Decomp --decomp_snapshot DECOMP_SNAPSHOT
```
## Demo
```
python demo.py --decomp_snapshot DECOMP_SNAPSHOT --comp_snapshot COMP_SNAPSHOT --aud_path AUD_PATH --out_file OUT_FILE --out_dir OUT_DIR --thr THR
```
- Flags
  - `aud_path`: path to the input `.wav` file
  - `out_file`: path to the output `.mp4` file
  - `out_dir`: directory for the output frames
  - `thr`: threshold on motion magnitude; generated units whose total motion falls below it are resampled
  - `modulate`: whether to apply beat warping
- Example
```
python demo.py --decomp_snapshot snapshot/Stage1.ckpt --comp_snapshot snapshot/Stage2.ckpt --aud_path demo/demo.wav --out_file demo/out.mp4 --out_dir demo/out_frame
```
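The `thr` flag mirrors the rejection test inside `Trainer_Comp.test_final`: a generated pose unit is discarded (and resampled) when its summed frame-to-frame motion falls below the threshold. A minimal sketch of that check, using plain Python lists in place of the actual tensors (the helper name `is_too_static` is ours, not part of the repo):

```python
def is_too_static(poses, thr):
    """Return True if a pose sequence's total frame-to-frame motion is below thr.

    poses: a list of frames, each frame a flat list of joint coordinates.
    This is a sketch of the magnitude check in Trainer_Comp.test_final,
    not the exact tensor code.
    """
    total = 0.0
    for prev, cur in zip(poses, poses[1:]):
        total += sum(abs(c - p) for p, c in zip(prev, cur))
    return total < thr

# A nearly frozen 3-frame sequence vs. a clearly moving one.
static = [[0.0, 0.0], [0.0, 0.01], [0.0, 0.02]]
moving = [[0.0, 0.0], [5.0, 3.0], [10.0, 6.0]]
print(is_too_static(static, thr=1.0))  # True: total motion is 0.02
print(is_too_static(moving, thr=1.0))  # False: total motion is 16.0
```

A larger `thr` therefore trades generation speed for livelier output, since more candidate units get rejected and resampled.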
### Citation
If you find this code useful for your research, please cite our paper:
```bibtex
@inproceedings{lee2019dancing2music,
title={Dancing to Music},
author={Lee, Hsin-Ying and Yang, Xiaodong and Liu, Ming-Yu and Wang, Ting-Chun and Lu, Yu-Ding and Yang, Ming-Hsuan and Kautz, Jan},
booktitle={NeurIPS},
year={2019}
}
```
### License
Copyright (C) 2020 NVIDIA Corporation. All rights reserved. This work is made available under NVIDIA Source Code License (1-Way Commercial). To view a copy of this license, visit https://nvlabs.github.io/Dancing2Music/LICENSE.txt.
================================================
FILE: data.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import pickle
import numpy as np
import random
import torch.utils.data
from torchvision.datasets import ImageFolder
import utils
def read_unit_list(fname):
  """Read one whitespace-separated record per line."""
  with open(fname, 'r') as f:
    return [line.strip().split(' ') for line in f]


class PoseDataset(torch.utils.data.Dataset):
  def __init__(self, data_dir, tolerance=False):
    self.data_dir = data_dir
    self.z_data = read_unit_list('{}/unitList/zumba_unit.txt'.format(data_dir))
    self.b_data = read_unit_list('{}/unitList/ballet_unit.txt'.format(data_dir))
    self.h_data = read_unit_list('{}/unitList/hiphop_unit.txt'.format(data_dir))
    self.data = [self.z_data, self.b_data, self.h_data]

    self.tolerance = tolerance
    if self.tolerance:
      z3_data = read_unit_list('{}/unitList/zumba_unitseq3.txt'.format(data_dir))
      b3_data = read_unit_list('{}/unitList/ballet_unitseq3.txt'.format(data_dir))
      h3_data = read_unit_list('{}/unitList/hiphop_unitseq3.txt'.format(data_dir))
      z4_data = read_unit_list('{}/unitList/zumba_unitseq4.txt'.format(data_dir))
      b4_data = read_unit_list('{}/unitList/ballet_unitseq4.txt'.format(data_dir))
      h4_data = read_unit_list('{}/unitList/hiphop_unitseq4.txt'.format(data_dir))
      self.zt_data = z3_data + z4_data
      self.bt_data = b3_data + b4_data
      self.ht_data = h3_data + h4_data
      self.t_data = [self.zt_data, self.bt_data, self.ht_data]

    self.mean_pose = np.load(data_dir+'/stats/all_onbeat_mean.npy')
    self.std_pose = np.load(data_dir+'/stats/all_onbeat_std.npy')

  def __getitem__(self, index):
    # Sample only zumba (0) and ballet (1) here.
    cls = random.randint(0, 1)
    if self.tolerance and random.randint(0, 9) == 0:
      index = random.randint(0, len(self.t_data[cls])-1)
      path = os.path.join(self.data_dir, self.t_data[cls][index][0][5:])
      orig_poses = np.load(path)
      sel = random.randint(0, orig_poses.shape[0]-1)
      orig_poses = orig_poses[sel]
    else:
      index = random.randint(0, len(self.data[cls])-1)
      path = os.path.join(self.data_dir, self.data[cls][index][0][5:])
      orig_poses = np.load(path)

    # Two independently jittered copies of the same sequence.
    xjit = np.random.uniform(low=-50, high=50)
    yjit = np.random.uniform(low=-20, high=20)
    poses = orig_poses.copy()
    poses[:,:,0] += xjit
    poses[:,:,1] += yjit

    xjit = np.random.uniform(low=-50, high=50)
    yjit = np.random.uniform(low=-20, high=20)
    poses2 = orig_poses.copy()
    poses2[:,:,0] += xjit
    poses2[:,:,1] += yjit

    poses = poses.reshape(poses.shape[0], poses.shape[1]*poses.shape[2])
    poses2 = poses2.reshape(poses2.shape[0], poses2.shape[1]*poses2.shape[2])
    for i in range(poses.shape[0]):
      poses[i] = (poses[i]-self.mean_pose)/self.std_pose
      poses2[i] = (poses2[i]-self.mean_pose)/self.std_pose
    return torch.Tensor(poses), torch.Tensor(poses2)

  def __len__(self):
    return len(self.z_data) + len(self.b_data)


class MovementAudDataset(torch.utils.data.Dataset):
  def __init__(self, data_dir):
    self.data_dir = data_dir
    self.z3_data = read_unit_list('{}/unitList/zumba_unitseq3.txt'.format(data_dir))
    self.b3_data = read_unit_list('{}/unitList/ballet_unitseq3.txt'.format(data_dir))
    self.h3_data = read_unit_list('{}/unitList/hiphop_unitseq3.txt'.format(data_dir))
    self.z4_data = read_unit_list('{}/unitList/zumba_unitseq4.txt'.format(data_dir))
    self.b4_data = read_unit_list('{}/unitList/ballet_unitseq4.txt'.format(data_dir))
    self.h4_data = read_unit_list('{}/unitList/hiphop_unitseq4.txt'.format(data_dir))
    self.data_3 = [self.z3_data, self.b3_data, self.h3_data]
    self.data_4 = [self.z4_data, self.b4_data, self.h4_data]
    self.data_root = ['zumba/', 'ballet/', 'hiphop/']

    self.mean_pose = np.load(data_dir+'/stats/all_onbeat_mean.npy')
    self.std_pose = np.load(data_dir+'/stats/all_onbeat_std.npy')
    self.mean_aud = np.load(data_dir+'/stats/all_aud_mean.npy')
    self.std_aud = np.load(data_dir+'/stats/all_aud_std.npy')

  def __getitem__(self, index):
    # Sample only zumba (0) and ballet (1) here.
    cls = random.randint(0, 1)
    isthree = random.randint(0, 1)
    if isthree == 0:
      index = random.randint(0, len(self.data_4[cls])-1)
      path = self.data_4[cls][index][0]
    else:
      index = random.randint(0, len(self.data_3[cls])-1)
      path = self.data_3[cls][index][0]
    path = os.path.join(self.data_dir, path[5:])
    stdpSeq = np.load(path)
    vid, cid = path.split('/')[-4], path.split('/')[-3]
    aud = np.load('{}/{}/{}/{}/aud/c{}_fps15.npy'.format(self.data_dir, self.data_root[cls], vid, cid, cid))

    stdpSeq = stdpSeq.reshape(stdpSeq.shape[0], stdpSeq.shape[1], stdpSeq.shape[2]*stdpSeq.shape[3])
    for i in range(stdpSeq.shape[0]):
      for j in range(stdpSeq.shape[1]):
        stdpSeq[i,j] = (stdpSeq[i,j]-self.mean_pose)/self.std_pose
    if isthree == 0:
      # 4-unit sequences: take a random 3-unit window.
      start = random.randint(0, 1)
      stdpSeq = stdpSeq[start:start+3]

    for i in range(aud.shape[0]):
      aud[i] = (aud[i]-self.mean_aud)/self.std_aud
    aud = aud[:30]
    return torch.Tensor(stdpSeq), torch.Tensor(aud)

  def __len__(self):
    return len(self.z3_data)+len(self.b3_data)+len(self.h3_data)+len(self.z4_data)+len(self.b4_data)+len(self.h4_data)


def get_loader(batch_size, shuffle, num_workers, dataset, data_dir, tolerance=False):
  if dataset == 0:
    a2d = PoseDataset(data_dir, tolerance)
  elif dataset == 2:
    a2d = MovementAudDataset(data_dir)
  else:
    raise ValueError('unsupported dataset id: {}'.format(dataset))
  return torch.utils.data.DataLoader(dataset=a2d,
                                     batch_size=batch_size,
                                     shuffle=shuffle,
                                     num_workers=num_workers)
================================================
FILE: demo.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import argparse
import functools
import librosa
import shutil
import sys
sys.path.insert(0, 'preprocess')
import preprocess as p
import subprocess as sp
from shutil import copyfile
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from model_comp import *
from networks import *
from options import TestOptions
import modulate
import utils
def loadDecompModel(args):
  initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
  stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement,
                              dim_z_init=args.dim_z_init, length=args.stdp_length,
                              hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)
  movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement,
                              length=args.stdp_length, hidden_size=args.movement_enc_hidden_size,
                              num_layers=args.movement_enc_num_layers,
                              bidirection=(args.movement_enc_bidirection==1))
  checkpoint = torch.load(args.decomp_snapshot)
  initp_enc.load_state_dict(checkpoint['initp_enc'])
  stdp_dec.load_state_dict(checkpoint['stdp_dec'])
  movement_enc.load_state_dict(checkpoint['movement_enc'])
  return initp_enc, stdp_dec, movement_enc


def loadCompModel(args):
  dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
                        hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers,
                        bidirection=(args.dance_enc_bidirection==1))
  dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
                        hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers)
  audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance)
  dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance)
  danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3)
  zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1)
  checkpoint = torch.load(args.comp_snapshot)
  dance_enc.load_state_dict(checkpoint['dance_enc'])
  dance_dec.load_state_dict(checkpoint['dance_dec'])
  audstyle_enc.load_state_dict(checkpoint['audstyle_enc'])
  checkpoint2 = torch.load(args.neta_snapshot)
  neta_cls = AudioClassifier_rnn(10, 30, 28, cls=3)
  neta_cls.load_state_dict(checkpoint2)
  return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls


if __name__ == "__main__":
  parser = TestOptions()
  args = parser.parse()
  args.train = False
  thr = args.thr

  # Process the music and extract features.
  infile = args.aud_path
  outfile = 'style.npy'
  p.preprocess(infile, outfile)
  y, sr = librosa.load(infile)
  onset_env = librosa.onset.onset_strength(y, sr=sr, aggregate=np.median)
  times = librosa.frames_to_time(np.arange(len(onset_env)), sr=sr, hop_length=512)
  tempo, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr)
  np.save('beats.npy', times[beats])
  beats = np.round(librosa.frames_to_time(beats, sr=sr)*15)
  beats = np.load('beats.npy')  # note: this reload replaces the 15 fps conversion above with beat times in seconds
  aud = np.load('style.npy')
  os.remove('beats.npy')
  os.remove('style.npy')
  shutil.rmtree('normalized')

  # Pretrained networks from the decomposition stage.
  initp_enc, stdp_dec, movement_enc = loadDecompModel(args)
  # Composition networks.
  dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = loadCompModel(args)

  trainer = Trainer_Comp(data_loader=None,
                         movement_enc=movement_enc,
                         initp_enc=initp_enc,
                         stdp_dec=stdp_dec,
                         dance_enc=dance_enc,
                         dance_dec=dance_dec,
                         danceAud_dis=danceAud_dis,
                         zdance_dis=zdance_dis,
                         aud_enc=neta_cls,
                         audstyle_enc=audstyle_enc,
                         dance_reg=dance_reg,
                         args=args)
  print('Loading Done')

  mean_pose = np.load('{}/stats/all_onbeat_mean.npy'.format(args.data_dir))
  std_pose = np.load('{}/stats/all_onbeat_std.npy'.format(args.data_dir))
  mean_aud = np.load('{}/stats/all_aud_mean.npy'.format(args.data_dir))
  std_aud = np.load('{}/stats/all_aud_std.npy'.format(args.data_dir))

  length = aud.shape[0]
  initpose = np.zeros((14, 2)).reshape(-1)
  for j in range(aud.shape[0]):
    aud[j] = (aud[j]-mean_aud)/std_aud

  total_t = int(length/32+1)
  final_stdpSeq = np.zeros((total_t*3*32, 14, 2))
  initpose, aud = torch.Tensor(initpose).cuda(), torch.Tensor(aud).cuda()
  initpose, aud = initpose.view(1, initpose.shape[0]), aud.view(1, aud.shape[0], aud.shape[1])
  for t in range(total_t):
    print('process {}/{}'.format(t, total_t))
    # Resample until the generated units pass the motion-magnitude threshold.
    while True:
      fake_stdpSeq = trainer.test_final(initpose, aud, 3, thr)
      if fake_stdpSeq is not None:
        break
    # Chain the next block from the last generated pose.
    initpose = torch.Tensor(fake_stdpSeq[2, -1]).cuda().view(1, -1)
    fake_stdpSeq = fake_stdpSeq.squeeze()
    for j in range(fake_stdpSeq.shape[0]):
      for k in range(fake_stdpSeq.shape[1]):
        fake_stdpSeq[j,k] = fake_stdpSeq[j,k]*std_pose + mean_pose
    fake_stdpSeq = np.resize(fake_stdpSeq, (fake_stdpSeq.shape[0], 32, 14, 2))
    for j in range(3):
      final_stdpSeq[96*t+32*j:96*t+32*(j+1)] = fake_stdpSeq[j]

  if args.modulate:
    final_stdpSeq = modulate.modulate(final_stdpSeq, beats, length)

  out_dir = args.out_dir
  if not os.path.exists(out_dir):
    os.mkdir(out_dir)
  utils.vis(final_stdpSeq, out_dir)
  sp.call('ffmpeg -r 15 -i {}/frame%03d.png -i {} -c:v libx264 -pix_fmt yuv420p -crf 23 -r 30 -y -strict -2 {}'.format(out_dir, args.aud_path, args.out_file), shell=True)
================================================
FILE: model_comp.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import time
import numpy as np
import random
import math
import torch
from torch import nn
from torch.autograd import Variable
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from utils import Logger
if torch.cuda.is_available():
  T = torch.cuda
else:
  T = torch
class Trainer_Comp(object):
  def __init__(self, data_loader, dance_enc, dance_dec, danceAud_dis, movement_enc, initp_enc, stdp_dec,
               aud_enc, audstyle_enc, dance_reg=None, args=None, zdance_dis=None):
    self.data_loader = data_loader
    self.movement_enc = movement_enc
    self.initp_enc = initp_enc
    self.stdp_dec = stdp_dec
    self.dance_enc = dance_enc
    self.dance_dec = dance_dec
    self.danceAud_dis = danceAud_dis
    self.aud_enc = aud_enc
    self.audstyle_enc = audstyle_enc
    self.train = args.train
    self.args = args

    if args.train:
      self.zdance_dis = zdance_dis
      self.dance_reg = dance_reg
      self.logger = Logger(args.log_dir)
      self.logs = self.init_logs()
      self.log_interval = args.log_interval
      self.snapshot_ep = args.snapshot_ep
      self.snapshot_dir = args.snapshot_dir

      self.opt_dance_enc = torch.optim.Adam(self.dance_enc.parameters(), lr=args.lr)
      self.opt_dance_dec = torch.optim.Adam(self.dance_dec.parameters(), lr=args.lr)
      self.opt_danceAud_dis = torch.optim.Adam(self.danceAud_dis.parameters(), lr=args.lr)
      self.opt_audstyle_enc = torch.optim.Adam(self.audstyle_enc.parameters(), lr=args.lr)
      self.opt_zdance_dis = torch.optim.Adam(self.zdance_dis.parameters(), lr=args.lr)
      self.opt_dance_reg = torch.optim.Adam(self.dance_reg.parameters(), lr=args.lr)
      # Fine-tune the decomposition networks with a smaller learning rate.
      self.opt_stdp_dec = torch.optim.Adam(self.stdp_dec.parameters(), lr=args.lr*0.1)
      self.opt_movement_enc = torch.optim.Adam(self.movement_enc.parameters(), lr=args.lr*0.1)

      self.latent_dropout = nn.Dropout(p=args.latent_dropout)
      self.l1_criterion = torch.nn.L1Loss()
      self.gan_criterion = nn.BCEWithLogitsLoss()
      self.mse_criterion = nn.MSELoss().cuda()

  def init_logs(self):
    return {'l_kl_zdance':0, 'l_kl_zmovement':0, 'l_kl_fake_zdance':0, 'l_kl_fake_zmovement':0,
            'l_l1_zmovement_mu':0, 'l_l1_zmovement_logvar':0, 'l_l1_stdpSeq':0, 'l_l1_zdance':0,
            'l_dis':0, 'l_dis_true':0, 'l_dis_fake':0,
            'l_info':0, 'l_info_real':0, 'l_info_fake':0,
            'l_gen':0}

  def get_z_random(self, batchSize, nz, random_type='gauss'):
    return torch.randn(batchSize, nz).cuda()

  @staticmethod
  def ones_like(tensor, val=1.):
    return T.FloatTensor(tensor.size()).fill_(val)

  @staticmethod
  def zeros_like(tensor, val=0.):
    return T.FloatTensor(tensor.size()).fill_(val)

  def kld_coef(self, i):
    # Sigmoid KL-annealing schedule over iterations.
    return float(1/(1+np.exp(-0.0005*(i-15000))))

  def forward(self, stdpSeq, batchsize, aud_style, aud):
    self.aud = torch.mean(aud, dim=1)
    self.batchsize = batchsize
    self.stdpSeq = stdpSeq
    self.aud_style = aud_style

    ### stdpSeq -> z_init, z_movements
    self.pose_0 = stdpSeq[:,0,:]
    self.z_init_mu, self.z_init_logvar = self.initp_enc(self.pose_0)
    z_init_std = self.z_init_logvar.mul(0.5).exp_()
    z_init_eps = self.get_z_random(z_init_std.size(0), z_init_std.size(1), 'gauss')
    self.z_init = z_init_eps.mul(z_init_std).add_(self.z_init_mu)

    self.z_movement_mus, self.z_movement_logvars = self.movement_enc(stdpSeq)
    z_movement_stds = self.z_movement_logvars.mul(0.5).exp_()
    z_movement_epss = self.get_z_random(z_movement_stds.size(0), z_movement_stds.size(1), 'gauss')
    self.z_movements = z_movement_epss.mul(z_movement_stds).add_(self.z_movement_mus)
    self.z_movementSeq_mu = self.z_movement_mus.view(batchsize, -1, self.z_movements.shape[1])
    self.z_movementSeq_logvar = self.z_movement_logvars.view(batchsize, -1, self.z_movements.shape[1])
    self.z_init, self.z_movements = self.z_init.detach(), self.z_movements.detach()
    self.z_movement_mus, self.z_movement_logvars = self.z_movement_mus.detach(), self.z_movement_logvars.detach()

    ### z_movements -> z_dance
    self.z_dance_mu, self.z_dance_logvar = self.dance_enc(self.z_movementSeq_mu, self.z_movementSeq_logvar)
    z_dance_std = self.z_dance_logvar.mul(0.5).exp_()
    z_dance_eps = self.get_z_random(z_dance_std.size(0), z_dance_std.size(1), 'gauss')
    self.z_dance = z_dance_eps.mul(z_dance_std).add_(self.z_dance_mu)

    ### z_dance -> z_movements
    self.recon_z_movements_mu, self.recon_z_movements_logvar = self.dance_dec(self.z_dance)
    recon_z_movement_std = self.recon_z_movements_logvar.mul(0.5).exp_()
    recon_z_movement_eps = self.get_z_random(recon_z_movement_std.size(0), recon_z_movement_std.size(1), 'gauss')
    self.recon_z_movements = recon_z_movement_eps.mul(recon_z_movement_std).add_(self.recon_z_movements_mu)

    ### z_movements -> stdpSeq
    self.recon_stdpSeq = self.stdp_dec(self.z_init, self.recon_z_movements)

    ### Music -> z_dance -> z_movements
    self.fake_z_dance_mu, self.fake_z_dance_logvar = self.audstyle_enc(aud_style)
    fake_z_dance_std = self.fake_z_dance_logvar.mul(0.5).exp_()
    fake_z_dance_eps = self.get_z_random(fake_z_dance_std.size(0), fake_z_dance_std.size(1), 'gauss')
    self.fake_z_dance = fake_z_dance_eps.mul(fake_z_dance_std).add_(self.fake_z_dance_mu)
    self.fake_z_movements_mu, self.fake_z_movements_logvar = self.dance_dec(self.fake_z_dance)
    fake_z_movements_std = self.fake_z_movements_logvar.mul(0.5).exp_()
    fake_z_movements_eps = self.get_z_random(fake_z_movements_std.size(0), fake_z_movements_std.size(1), 'gauss')
    self.fake_z_movements = fake_z_movements_eps.mul(fake_z_movements_std).add_(self.fake_z_movements_mu)
    fake_z_movementSeq_mu = self.fake_z_movements_mu.view(batchsize, -1, self.fake_z_movements_mu.shape[1])
    fake_z_movementSeq_logvar = self.fake_z_movements_logvar.view(batchsize, -1, self.fake_z_movements_logvar.shape[1])
    self.fake_z_movementSeq = torch.cat((fake_z_movementSeq_mu, fake_z_movementSeq_logvar), 2)

  def backward_D(self):
    tmp_recon_mu = self.recon_z_movements_mu.view(self.batchsize, -1, self.z_movements.shape[1])
    tmp_recon_logvar = self.recon_z_movements_logvar.view(self.batchsize, -1, self.z_movements.shape[1])
    real_movements = torch.cat((tmp_recon_mu, tmp_recon_logvar), 2)
    fake_movements = self.fake_z_movementSeq
    real_labels, _ = self.danceAud_dis(real_movements.detach(), self.aud)
    fake_labels, _ = self.danceAud_dis(fake_movements.detach(), self.aud)
    ones = self.ones_like(real_labels)
    zeros = self.zeros_like(fake_labels)
    self.loss_dis_true = self.gan_criterion(real_labels, ones)
    self.loss_dis_fake = self.gan_criterion(fake_labels, zeros)
    self.loss_dis = (self.loss_dis_true + self.loss_dis_fake)*self.args.lambda_gan

    real_dance = torch.cat((self.z_dance_mu, self.z_dance_logvar), 1)
    fake_dance = torch.cat((self.fake_z_dance_mu, self.fake_z_dance_logvar), 1)
    real_labels, _ = self.zdance_dis(real_dance.detach(), self.aud)
    fake_labels, _ = self.zdance_dis(fake_dance.detach(), self.aud)
    ones = self.ones_like(real_labels)
    zeros = self.zeros_like(fake_labels)
    self.loss_zdis_true = self.gan_criterion(real_labels, ones)
    self.loss_zdis_fake = self.gan_criterion(fake_labels, zeros)
    self.loss_dis += (self.loss_zdis_true + self.loss_zdis_fake)*self.args.lambda_gan

  def backward_danceED(self):
    # z_dance KL
    kl_element = self.z_dance_mu.pow(2).add_(self.z_dance_logvar.exp()).mul_(-1).add_(1).add_(self.z_dance_logvar)
    self.loss_kl_z_dance = torch.mean(torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl_dance)
    kl_element = self.fake_z_dance_mu.pow(2).add_(self.fake_z_dance_logvar.exp()).mul_(-1).add_(1).add_(self.fake_z_dance_logvar)
    self.loss_kl_fake_z_dance = torch.mean(torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl_dance)
    # z_movement KL
    kl_element = self.recon_z_movements_mu.pow(2).add_(self.recon_z_movements_logvar.exp()).mul_(-1).add_(1).add_(self.recon_z_movements_logvar)
    self.loss_kl_z_movement = torch.mean(torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl)
    kl_element = self.fake_z_movements_mu.pow(2).add_(self.fake_z_movements_logvar.exp()).mul_(-1).add_(1).add_(self.fake_z_movements_logvar)
    self.loss_kl_fake_z_movements = torch.mean(torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl)
    # z_movement reconstruction
    self.loss_l1_z_movement_mu = self.l1_criterion(self.recon_z_movements_mu, self.z_movement_mus) * self.args.lambda_zmovements_recon
    self.loss_l1_z_movement_logvar = self.l1_criterion(self.recon_z_movements_logvar, self.z_movement_logvars) * self.args.lambda_zmovements_recon
    # stdp reconstruction
    self.loss_l1_stdpSeq = self.l1_criterion(self.recon_stdpSeq, self.stdpSeq) * self.args.lambda_stdpSeq_recon
    # Music2Dance GAN
    fake_movements = self.fake_z_movementSeq
    fake_labels, _ = self.danceAud_dis(fake_movements, self.aud)
    ones = self.ones_like(fake_labels)
    self.loss_gen = self.gan_criterion(fake_labels, ones) * self.args.lambda_gan
    fake_dance = torch.cat((self.fake_z_dance_mu, self.fake_z_dance_logvar), 1)
    fake_labels, _ = self.zdance_dis(fake_dance, self.aud)
    ones = self.ones_like(fake_labels)
    self.loss_gen += self.gan_criterion(fake_labels, ones) * self.args.lambda_gan

    self.loss = self.loss_kl_z_movement + self.loss_kl_z_dance + self.loss_l1_z_movement_mu + \
                self.loss_l1_z_movement_logvar + self.loss_l1_stdpSeq + self.loss_gen

  def backward_info_ondance(self):
    real_pred = self.dance_reg(self.z_dance)
    fake_pred = self.dance_reg(self.fake_z_dance)
    self.loss_info_real = self.mse_criterion(real_pred, self.aud_style)
    self.loss_info_fake = self.mse_criterion(fake_pred, self.aud_style)
    self.loss_info = self.loss_info_real + self.loss_info_fake

  def zero_grad(self, opt_list):
    for opt in opt_list:
      opt.zero_grad()

  def clip_norm(self, network_list):
    for network in network_list:
      clip_grad_norm_(network.parameters(), 0.5)

  def step(self, opt_list):
    for opt in opt_list:
      opt.step()

  def update(self):
    # Discriminators.
    self.zero_grad([self.opt_danceAud_dis, self.opt_zdance_dis])
    self.backward_D()
    self.loss_dis.backward(retain_graph=True)
    self.clip_norm([self.danceAud_dis, self.zdance_dis])
    self.step([self.opt_danceAud_dis, self.opt_zdance_dis])

    # Dance encoder/decoder and audio style encoder.
    self.zero_grad([self.opt_dance_enc, self.opt_dance_dec, self.opt_audstyle_enc, self.opt_stdp_dec])
    self.backward_danceED()
    self.loss.backward(retain_graph=True)
    self.clip_norm([self.dance_enc, self.dance_dec, self.audstyle_enc, self.stdp_dec])
    self.step([self.opt_dance_enc, self.opt_dance_dec, self.opt_audstyle_enc, self.opt_stdp_dec])

    # Style-consistency (info) regressor.
    self.zero_grad([self.opt_dance_enc, self.opt_audstyle_enc, self.opt_dance_reg, self.opt_stdp_dec])
    self.backward_info_ondance()
    self.loss_info.backward()
    self.clip_norm([self.dance_enc, self.audstyle_enc, self.dance_reg, self.stdp_dec])
    self.step([self.opt_dance_enc, self.opt_audstyle_enc, self.opt_dance_reg, self.opt_stdp_dec])

  def test_final(self, initpose, aud, n, thr=0):
    self.cuda()
    self.movement_enc.eval()
    self.stdp_dec.eval()
    self.initp_enc.eval()
    self.dance_enc.eval()
    self.dance_dec.eval()
    self.aud_enc.eval()
    self.audstyle_enc.eval()

    aud_style = self.aud_enc.get_style(aud).detach()
    self.fake_z_dance_mu, self.fake_z_dance_logvar = self.audstyle_enc(aud_style)
    fake_z_dance_std = self.fake_z_dance_logvar.mul(0.5).exp_()
    fake_z_dance_eps = self.get_z_random(fake_z_dance_std.size(0), fake_z_dance_std.size(1), 'gauss')
    self.fake_z_dance = fake_z_dance_eps.mul(fake_z_dance_std).add_(self.fake_z_dance_mu)
    self.fake_z_movements_mu, self.fake_z_movements_logvar = self.dance_dec(self.fake_z_dance, length=3)
    fake_z_movements_std = self.fake_z_movements_logvar.mul(0.5).exp_()
    fake_z_movements_eps = self.get_z_random(fake_z_movements_std.size(0), fake_z_movements_std.size(1), 'gauss')
    self.fake_z_movements = fake_z_movements_eps.mul(fake_z_movements_std).add_(self.fake_z_movements_mu)

    fake_stdpSeq = []
    for i in range(n):
      z_init_mus, z_init_logvars = self.initp_enc(initpose)
      z_init_stds = z_init_logvars.mul(0.5).exp_()
      z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss')
      z_init = z_init_epss.mul(z_init_stds).add_(z_init_mus)
      fake_stdp = self.stdp_dec(z_init, self.fake_z_movements[i:i+1])
      fake_stdpSeq.append(fake_stdp)
      initpose = fake_stdp[:,-1,:]
    fake_stdpSeq = torch.cat(fake_stdpSeq, dim=0)

    # Reject near-static results: each unit's total frame-to-frame motion must exceed thr.
    for i in range(n):
      s = fake_stdpSeq[i]
      diff = torch.abs(s[1:]-s[:-1])
      if torch.sum(diff).cpu().detach().numpy() < thr:
        return None
    return fake_stdpSeq.cpu().detach().numpy()

  def resume(self, model_dir, train=True):
    checkpoint = torch.load(model_dir)
    self.dance_enc.load_state_dict(checkpoint['dance_enc'])
    self.dance_dec.load_state_dict(checkpoint['dance_dec'])
    self.audstyle_enc.load_state_dict(checkpoint['audstyle_enc'])
    self.stdp_dec.load_state_dict(checkpoint['stdp_dec'])
    self.movement_enc.load_state_dict(checkpoint['movement_enc'])
    if train:
      self.danceAud_dis.load_state_dict(checkpoint['danceAud_dis'])
      self.dance_reg.load_state_dict(checkpoint['dance_reg'])
      self.opt_dance_enc.load_state_dict(checkpoint['opt_dance_enc'])
      self.opt_dance_dec.load_state_dict(checkpoint['opt_dance_dec'])
      self.opt_stdp_dec.load_state_dict(checkpoint['opt_stdp_dec'])
      self.opt_audstyle_enc.load_state_dict(checkpoint['opt_audstyle_enc'])
      self.opt_danceAud_dis.load_state_dict(checkpoint['opt_danceAud_dis'])
      self.opt_dance_reg.load_state_dict(checkpoint['opt_dance_reg'])
    return checkpoint['ep'], checkpoint['total_it']

  def save(self, filename, ep, total_it):
    state = {
      'stdp_dec': self.stdp_dec.state_dict(),
      'movement_enc': self.movement_enc.state_dict(),
      'dance_enc': self.dance_enc.state_dict(),
      'dance_dec': self.dance_dec.state_dict(),
      'audstyle_enc': self.audstyle_enc.state_dict(),
      'danceAud_dis': self.danceAud_dis.state_dict(),
'zdance_dis': self.zdance_dis.state_dict(),
'dance_reg': self.dance_reg.state_dict(),
'opt_stdp_dec': self.opt_stdp_dec.state_dict(),
'opt_movement_enc': self.opt_movement_enc.state_dict(),
'opt_dance_enc': self.opt_dance_enc.state_dict(),
'opt_dance_dec': self.opt_dance_dec.state_dict(),
'opt_audstyle_enc': self.opt_audstyle_enc.state_dict(),
'opt_danceAud_dis': self.opt_danceAud_dis.state_dict(),
'opt_zdance_dis': self.opt_zdance_dis.state_dict(),
'opt_dance_reg': self.opt_dance_reg.state_dict(),
'ep': ep,
'total_it': total_it
}
torch.save(state, filename)
return
def cuda(self):
if self.train:  # note: self.train is the bound train() method here, so this check is always truthy
self.dance_reg.cuda()
self.danceAud_dis.cuda()
self.zdance_dis.cuda()
self.stdp_dec.cuda()
self.initp_enc.cuda()
self.movement_enc.cuda()
self.dance_enc.cuda()
self.dance_dec.cuda()
self.aud_enc.cuda()
self.audstyle_enc.cuda()
self.gan_criterion.cuda()
def train(self, ep=0, it=0):
self.cuda()
for epoch in range(ep, self.args.num_epochs):
self.movement_enc.train()
self.stdp_dec.train()
self.initp_enc.train()
self.dance_enc.train()
self.dance_dec.train()
self.danceAud_dis.train()
self.zdance_dis.train()
self.audstyle_enc.train()
self.dance_reg.train()
self.aud_enc.eval()
stdp_recon = 0
for i, (stdpSeq, aud) in enumerate(self.data_loader):
stdpSeq, aud = stdpSeq.cuda().detach(), aud.cuda().detach()
stdpSeq = stdpSeq.view(stdpSeq.shape[0]*stdpSeq.shape[1], stdpSeq.shape[2], stdpSeq.shape[3])
aud_style = self.aud_enc.get_style(aud).detach()
self.forward(stdpSeq, aud.shape[0], aud_style, aud)
self.update()
self.logs['l_kl_zmovement'] += self.loss_kl_z_movement.data
self.logs['l_kl_zdance'] += self.loss_kl_z_dance.data
self.logs['l_l1_zmovement_mu'] += self.loss_l1_z_movement_mu.data
self.logs['l_l1_zmovement_logvar'] += self.loss_l1_z_movement_logvar.data
self.logs['l_l1_stdpSeq'] += self.loss_l1_stdpSeq.data
self.logs['l_kl_fake_zdance'] += self.loss_kl_fake_z_dance.data
self.logs['l_kl_fake_zmovement'] += self.loss_kl_fake_z_movements.data
self.logs['l_dis'] += self.loss_dis.data
self.logs['l_dis_true'] += self.loss_dis_true.data
self.logs['l_dis_fake'] += self.loss_dis_fake.data
self.logs['l_gen'] += self.loss_gen.data
self.logs['l_info'] += self.loss_info.data
self.logs['l_info_real'] += self.loss_info_real.data
self.logs['l_info_fake'] += self.loss_info_fake.data
print('Epoch:{:3} Iter{}/{}\tl_l1_zmovement mu{:.3f} logvar{:.3f}\tl_l1_stdpSeq {:.3f}\tl_kl_dance {:.3f}\tl_kl_movement {:.3f}\n'.format(epoch, i, len(self.data_loader),
self.loss_l1_z_movement_mu, self.loss_l1_z_movement_logvar, self.loss_l1_stdpSeq, self.loss_kl_z_dance, self.loss_kl_z_movement) +
'\t\t\tl_kl_f_dance {:.3f}\tl_dis {:.3f} {:.3f}\tl_gen {:.3f}'.format(self.loss_kl_fake_z_dance, self.loss_dis_true, self.loss_dis_fake, self.loss_gen))
it += 1
if it % self.log_interval == 0:
for tag, value in self.logs.items():
self.logger.scalar_summary(tag, value/self.log_interval, it)
self.logs = self.init_logs()
if epoch % self.snapshot_ep == 0:
self.save(os.path.join(self.snapshot_dir, '{:04}.ckpt'.format(epoch)), epoch, it)
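The trainers above repeatedly build latent samples with the pattern `logvar.mul(0.5).exp_()` followed by `eps.mul(std).add_(mu)`. The following standalone NumPy sketch (illustrative only, not part of the repository) shows the reparameterization trick this implements:

```python
import numpy as np

def reparameterize(mu, logvar, rng=None):
    """Sample z = mu + eps * std with std = exp(0.5 * logvar).

    Mirrors the `logvar.mul(0.5).exp_()` / `eps.mul(std).add_(mu)`
    pattern used throughout the trainers, in NumPy for illustration.
    """
    rng = np.random.default_rng() if rng is None else rng
    std = np.exp(0.5 * logvar)           # log-variance -> standard deviation
    eps = rng.standard_normal(mu.shape)  # unit Gaussian noise
    return mu + eps * std
```

With a very negative log-variance the standard deviation collapses toward zero and the sample reduces to the mean, which is a quick sanity check for the formula.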
================================================
FILE: model_decomp.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import time
import numpy as np
import random
import math
import torch
from torch import nn
from torch.autograd import Variable
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from utils import Logger
if torch.cuda.is_available():
T = torch.cuda
else:
T = torch
class Trainer_Decomp(object):
def __init__(self, data_loader, initp_enc, initp_dec, movement_enc, stdp_dec, args=None):
self.data_loader = data_loader
self.initp_enc = initp_enc
self.initp_dec = initp_dec
self.movement_enc = movement_enc
self.stdp_dec = stdp_dec
self.args = args
if args.train:
self.logger = Logger(args.log_dir)
self.logs = self.init_logs()
self.log_interval = args.log_interval
self.snapshot_ep = args.snapshot_ep
self.snapshot_dir = args.snapshot_dir
self.opt_initp_enc = torch.optim.Adam(self.initp_enc.parameters(), lr=args.lr)
self.opt_initp_dec = torch.optim.Adam(self.initp_dec.parameters(), lr=args.lr)
self.opt_movement_enc = torch.optim.Adam(self.movement_enc.parameters(), lr=args.lr)
self.opt_stdp_dec = torch.optim.Adam(self.stdp_dec.parameters(), lr=args.lr)
self.latent_dropout = nn.Dropout(p=args.latent_dropout)
self.l1_criterion = torch.nn.L1Loss()
self.gan_criterion = nn.BCEWithLogitsLoss()
def init_logs(self):
return {'l_kl_zinit':0, 'l_kl_zmovement':0, 'l_l1_stdp':0, 'l_l1_cross_stdp':0, 'l_dist_zmovement':0,
'l_l1_initp':0, 'l_l1_initp_con':0,
'kld_coef':0
}
def get_z_random(self, batchSize, nz, random_type='gauss'):
z = torch.randn(batchSize, nz).cuda()
return z
@staticmethod
def ones_like(tensor, val=1.):
return T.FloatTensor(tensor.size()).fill_(val)
@staticmethod
def zeros_like(tensor, val=0.):
return T.FloatTensor(tensor.size()).fill_(val)
def random_generate_stdp(self, init_p):
self.pose_0 = init_p
self.z_init_mu, self.z_init_logvar = self.initp_enc(self.pose_0)
z_init_std = self.z_init_logvar.mul(0.5).exp_()
z_init_eps = self.get_z_random(z_init_std.size(0), z_init_std.size(1), 'gauss')
self.z_init = z_init_eps.mul(z_init_std).add_(self.z_init_mu)
self.z_random_movement = self.get_z_random(self.z_init.size(0), 512, 'gauss')
self.fake_stdpose = self.stdp_dec(self.z_init, self.z_random_movement)
return self.fake_stdpose
def forward(self, stdpose1, stdpose2):
self.stdpose1 = stdpose1
self.stdpose2 = stdpose2
# stdpose -> stdpose[0] -> z_init
self.pose1_0 = stdpose1[:,0,:]
self.pose2_0 = stdpose2[:,0,:]
self.poses_0 = torch.cat((self.pose1_0, self.pose2_0), 0)
self.z_init_mus, self.z_init_logvars = self.initp_enc(self.poses_0)
z_init_stds = self.z_init_logvars.mul(0.5).exp_()
z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss')
self.z_inits = z_init_epss.mul(z_init_stds).add_(self.z_init_mus)
self.z_init1, self.z_init2 = torch.split(self.z_inits, self.stdpose1.size(0), dim=0)
# stdpose -> z_movement
stdposes = torch.cat((stdpose1, stdpose2), 0)
self.z_movement_mus, self.z_movement_logvars = self.movement_enc(stdposes)
z_movement_stds = self.z_movement_logvars.mul(0.5).exp_()
z_movement_epss = self.get_z_random(z_movement_stds.size(0), z_movement_stds.size(1), 'gauss')
self.z_movements = z_movement_epss.mul(z_movement_stds).add_(self.z_movement_mus)
self.z_movement1, self.z_movement2 = torch.split(self.z_movements, self.stdpose1.size(0), dim=0)
# zinit1+zmovement1->stdpose1 zinit2+zmovement2->stdpose2
self.recon_stdpose1 = self.stdp_dec(self.z_init1, self.z_movement1)
self.recon_stdpose2 = self.stdp_dec(self.z_init2, self.z_movement2)
# zinit1+zmovement2->stdpose1 zinit2+zmovement1->stdpose2
self.recon_stdpose1_cross = self.stdp_dec(self.z_init1, self.z_movement2)
self.recon_stdpose2_cross = self.stdp_dec(self.z_init2, self.z_movement1)
# z_init -> \hat{stdpose[0]}
self.recon_pose1_0 = self.initp_dec(self.z_init1)
self.recon_pose2_0 = self.initp_dec(self.z_init2)
# single pose reconstruction
randomlist = np.random.permutation(31)[:4]
singlepose = []
for r in randomlist:
singlepose.append(self.stdpose1[:,r,:])
self.singleposes = torch.cat(singlepose, dim=0).detach()
self.z_single_mus, self.z_single_logvars = self.initp_enc(self.singleposes)
z_single_stds = self.z_single_logvars.mul(0.5).exp_()
z_single_epss = self.get_z_random(z_single_stds.size(0), z_single_stds.size(1), 'gauss')
z_single = z_single_epss.mul(z_single_stds).add_(self.z_single_mus)
self.recon_singleposes = self.initp_dec(z_single)
def backward_initp_ED(self):
# z_init KL
kl_element = self.z_init_mus.pow(2).add_(self.z_init_logvars.exp()).mul_(-1).add_(1).add_(self.z_init_logvars)
self.loss_kl_z_init = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))
# initpose reconstruction
self.loss_l1_initp = self.l1_criterion(self.recon_singleposes, self.singleposes) * self.args.lambda_initp_recon
self.loss_initp = self.loss_kl_z_init + self.loss_l1_initp
def backward_movement_ED(self):
# z_movement KL
kl_element = self.z_movement_mus.pow(2).add_(self.z_movement_logvars.exp()).mul_(-1).add_(1).add_(self.z_movement_logvars)
#self.loss_kl_z_movement = torch.mean(kl_element).mul_(-0.5) * self.args.lambda_kl
self.loss_kl_z_movement = torch.mean( (torch.sum(kl_element, dim=1).mul_(-0.5) * self.args.lambda_kl))
# stdpose self reconstruction
loss_l1_stdp1 = self.l1_criterion(self.recon_stdpose1, self.stdpose1) * self.args.lambda_stdp_recon
loss_l1_stdp2 = self.l1_criterion(self.recon_stdpose2, self.stdpose2) * self.args.lambda_stdp_recon
self.loss_l1_stdp = loss_l1_stdp1 + loss_l1_stdp2
# stdpose cross reconstruction
loss_l1_cross_stdp1 = self.l1_criterion(self.recon_stdpose1_cross, self.stdpose1) * self.args.lambda_stdp_recon
loss_l1_cross_stdp2 = self.l1_criterion(self.recon_stdpose2_cross, self.stdpose2) * self.args.lambda_stdp_recon
self.loss_l1_cross_stdp = loss_l1_cross_stdp1 + loss_l1_cross_stdp2
# Movement dist
self.loss_dist_z_movement = torch.mean(torch.abs(self.z_movement1-self.z_movement2)) * self.args.lambda_dist_z_movement
self.loss_movement = self.loss_kl_z_movement + self.loss_l1_stdp + self.loss_l1_cross_stdp + self.loss_dist_z_movement
def update(self):
self.opt_initp_enc.zero_grad()
self.opt_initp_dec.zero_grad()
self.opt_movement_enc.zero_grad()
self.opt_stdp_dec.zero_grad()
self.backward_initp_ED()
self.backward_movement_ED()
self.g_loss = self.loss_initp + self.loss_movement
self.g_loss.backward(retain_graph=True)
clip_grad_norm_(self.movement_enc.parameters(), 0.5)
clip_grad_norm_(self.stdp_dec.parameters(), 0.5)
self.opt_initp_enc.step()
self.opt_initp_dec.step()
self.opt_movement_enc.step()
self.opt_stdp_dec.step()
def save(self, filename, ep, total_it):
state = {
'stdp_dec': self.stdp_dec.state_dict(),
'movement_enc': self.movement_enc.state_dict(),
'initp_enc': self.initp_enc.state_dict(),
'initp_dec': self.initp_dec.state_dict(),
'opt_stdp_dec': self.opt_stdp_dec.state_dict(),
'opt_movement_enc': self.opt_movement_enc.state_dict(),
'opt_initp_enc': self.opt_initp_enc.state_dict(),
'opt_initp_dec': self.opt_initp_dec.state_dict(),
'ep': ep,
'total_it': total_it
}
torch.save(state, filename)
return
def resume(self, model_dir, train=True):
checkpoint = torch.load(model_dir)
# weight
self.stdp_dec.load_state_dict(checkpoint['stdp_dec'])
self.movement_enc.load_state_dict(checkpoint['movement_enc'])
self.initp_enc.load_state_dict(checkpoint['initp_enc'])
self.initp_dec.load_state_dict(checkpoint['initp_dec'])
# optimizer
if train:
self.opt_stdp_dec.load_state_dict(checkpoint['opt_stdp_dec'])
self.opt_movement_enc.load_state_dict(checkpoint['opt_movement_enc'])
self.opt_initp_enc.load_state_dict(checkpoint['opt_initp_enc'])
self.opt_initp_dec.load_state_dict(checkpoint['opt_initp_dec'])
return checkpoint['ep'], checkpoint['total_it']
def kld_coef(self, i):
return float(1/(1+np.exp(-0.0005*(i-15000)))) #v3
def generate_stdp_sequence(self, initpose, aud, num_stdp):
self.initp_enc.cuda()
self.initp_dec.cuda()
self.movement_enc.cuda()
self.stdp_dec.cuda()
self.initp_enc.eval()
self.initp_dec.eval()
self.movement_enc.eval()
self.stdp_dec.eval()
initpose = initpose.cuda()
# NOTE: this method assumes audio networks have been attached to the trainer;
# self.aud_enc and self.audstyle_enc are not set in Trainer_Decomp.__init__.
aud_style = self.aud_enc.get_style(aud)
stdp_seq = []
cnt = 0
#for i in range(num_stdp):
while not cnt == num_stdp:
if cnt==0:
z_inits = self.get_z_random(1, 10, 'gauss')
else:
z_init_mus, z_init_logvars = self.initp_enc(initpose)
z_init_stds = z_init_logvars.mul(0.5).exp_()
z_init_epss = self.get_z_random(z_init_stds.size(0), z_init_stds.size(1), 'gauss')
z_inits = z_init_epss.mul(z_init_stds).add_(z_init_mus)
z_audstyle_mu, z_audstyle_logvar = self.audstyle_enc(aud_style)
z_as_std = z_audstyle_logvar.mul(0.5).exp_()
z_as_eps = self.get_z_random(z_as_std.size(0), z_as_std.size(1), 'gauss')
z_audstyle = z_as_eps.mul(z_as_std).add_(z_audstyle_mu)
if random.randint(0,5)==100:  # note: randint(0, 5) never returns 100, so this random-style branch is effectively disabled
z_audstyle = self.get_z_random(z_inits.size(0), 512, 'gauss')
fake_stdpose = self.stdp_dec(z_inits, z_audstyle)
s = fake_stdpose[0]
diff = torch.abs(s[1:]-s[:-1])
diffsum = torch.sum(diff)
if diffsum.cpu().detach().numpy() < 70:
continue
cnt += 1
stdp_seq.append(fake_stdpose.cpu().detach().numpy())
initpose = fake_stdpose[:,-1,:]
return stdp_seq
def cuda(self):
self.initp_enc.cuda()
self.initp_dec.cuda()
self.movement_enc.cuda()
self.stdp_dec.cuda()
self.l1_criterion.cuda()
def train(self, ep=0, it=0):
self.cuda()
full_kl = self.args.lambda_kl
kl_w = 0
kl_step = 0.05
best_stdp_recon = 100
for epoch in range(ep, self.args.num_epochs):
self.initp_enc.train()
self.initp_dec.train()
self.movement_enc.train()
self.stdp_dec.train()
stdp_recon = 0
for i, (stdpose, stdpose2) in enumerate(self.data_loader):
self.args.lambda_kl = full_kl*self.kld_coef(it)
stdpose, stdpose2 = stdpose.cuda().detach(), stdpose2.cuda().detach()
self.forward(stdpose, stdpose2)
self.update()
self.logs['l_kl_zinit'] += self.loss_kl_z_init.data
self.logs['l_kl_zmovement'] += self.loss_kl_z_movement.data
self.logs['l_l1_initp'] += self.loss_l1_initp.data
self.logs['l_l1_stdp'] += self.loss_l1_stdp.data
self.logs['l_l1_cross_stdp'] += self.loss_l1_cross_stdp.data
self.logs['l_dist_zmovement'] += self.loss_dist_z_movement.data
self.logs['kld_coef'] += self.args.lambda_kl
print('Epoch:{:3} Iter{}/{}\tl_l1_initp {:.3f}\tl_l1_stdp {:.3f}\tl_l1_cross_stdp {:.3f}\tl_dist_zmove {:.3f}\tl_kl_zinit {:.3f}\t l_kl_zmove {:.3f}'.format(
epoch, i, len(self.data_loader), self.loss_l1_initp, self.loss_l1_stdp, self.loss_l1_cross_stdp, self.loss_dist_z_movement, self.loss_kl_z_init, self.loss_kl_z_movement))
it += 1
if it % self.log_interval == 0:
for tag, value in self.logs.items():
self.logger.scalar_summary(tag, value/self.log_interval, it)
self.logs = self.init_logs()
if epoch % self.snapshot_ep == 0:
self.save(os.path.join(self.snapshot_dir, '{:04}.ckpt'.format(epoch)), epoch, it)
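`Trainer_Decomp.kld_coef` implements sigmoid KL annealing: the KL weight `lambda_kl` is scaled from roughly 0 toward 1 as the iteration count passes 15000. A self-contained sketch of the same schedule (the parameter names `midpoint` and `slope` are illustrative, not from the repository):

```python
import numpy as np

def kld_coef(i, midpoint=15000, slope=0.0005):
    # Sigmoid KL-annealing schedule matching Trainer_Decomp.kld_coef:
    # ~0 early in training, 0.5 at `midpoint`, approaching 1 afterwards.
    return float(1.0 / (1.0 + np.exp(-slope * (i - midpoint))))
```

In train(), this coefficient multiplies the full `lambda_kl` each iteration, so the KL terms are phased in gradually while the reconstruction losses dominate early training.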
================================================
FILE: modulate.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import numpy as np
import librosa
import utils
def modulate(dance, beats, length):
sec_interframe = 1/15  # seconds per frame at 15 fps (currently unused)
beats_frame = np.around(beats)
t_beat = beats_frame.astype(int)
s_beat = np.arange(3,dance.shape[0],8)
final_pose = np.zeros((length, 14, 2))
if t_beat[0] >3:
final_pose[t_beat[0]-3:t_beat[0]] = dance[:3]
else:
final_pose[:t_beat[0]] = dance[:t_beat[0]]
if t_beat[0]-3 > 0:
final_pose[:t_beat[0]-3] = dance[0]
for t in range(t_beat.shape[0]-1):
begin = int(t_beat[t])
end = int(t_beat[t+1])
interval = end-begin
if t==s_beat.shape[0]-1:
rest = min(final_pose.shape[0]-begin-1, dance.shape[0]-s_beat[t]-1)
break
if t+1 < s_beat.shape[0] and s_beat[t+1]<dance.shape[0]:
pose = get_pose(dance[s_beat[t]:s_beat[t+1]+1], interval+1)
if t==0 and begin>=3:
final_pose[begin-s_beat[t]:begin] = dance[:s_beat[t]]
final_pose[begin:end+1]=pose
rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t+1]-1)
else:
end = begin
if t+1 < s_beat.shape[0]:
rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t+1]-1)
else:
print(t_beat.shape, s_beat.shape, t)
rest = min(final_pose.shape[0]-end-1, dance.shape[0]-s_beat[t]-1)
if rest > 0:
if t+1 < s_beat.shape[0]:
final_pose[end+1:end+1+rest] = dance[s_beat[t+1]+1:s_beat[t+1]+1+rest]
else:
final_pose[end+1:end+1+rest] = dance[s_beat[t]+1:s_beat[t]+1+rest]
return final_pose
def get_pose(pose, n):
t_pose = np.zeros((n, 14, 2))
if n==11:
t_pose[0] = pose[0]
t_pose[1] = (pose[0]*1+pose[1]*4)/5
t_pose[2] = (pose[1]*2+pose[2]*3)/5
t_pose[3] = (pose[2]*3+pose[3]*2)/5
t_pose[4] = (pose[3]*4+pose[4]*1)/5
t_pose[5] = pose[4]
t_pose[6] = (pose[4]*1+pose[5]*4)/5
t_pose[7] = (pose[5]*2+pose[6]*3)/5
t_pose[8] = (pose[6]*3+pose[7]*2)/5
t_pose[9] = (pose[7]*4+pose[8]*1)/5
t_pose[10] = pose[8]
elif n==10:
t_pose[0] = pose[0]
t_pose[1] = (pose[0]*1+pose[1]*8)/9
t_pose[2] = (pose[1]*2+pose[2]*7)/9
t_pose[3] = (pose[2]*3+pose[3]*6)/9
t_pose[4] = (pose[3]*4+pose[4]*5)/9
t_pose[5] = (pose[4]*5+pose[5]*4)/9
t_pose[6] = (pose[5]*6+pose[6]*3)/9
t_pose[7] = (pose[6]*7+pose[7]*2)/9
t_pose[8] = (pose[7]*8+pose[8]*1)/9
t_pose[9] = pose[8]
elif n==12:
t_pose[0] = pose[0]
t_pose[1] = (pose[0]*3+pose[1]*8)/11
t_pose[2] = (pose[1]*6+pose[2]*5)/11
t_pose[3] = (pose[2]*9+pose[3]*2)/11
t_pose[4] = (pose[2]*1+pose[3]*10)/11
t_pose[5] = (pose[3]*4+pose[4]*7)/11
t_pose[6] = (pose[4]*7+pose[5]*4)/11
t_pose[7] = (pose[5]*10+pose[6]*1)/11
t_pose[8] = (pose[5]*2+pose[6]*9)/11
t_pose[9] = (pose[6]*5+pose[7]*6)/11
t_pose[10] = (pose[7]*8+pose[8]*3)/11
t_pose[11] = pose[8]
elif n==13:
t_pose[0] = pose[0]
t_pose[1] = (pose[0]*1+pose[1]*2)/3
t_pose[2] = (pose[1]*2+pose[2]*1)/3
t_pose[3] = pose[2]
t_pose[4] = (pose[2]*1+pose[3]*2)/3
t_pose[5] = (pose[3]*2+pose[4]*1)/3
t_pose[6] = pose[4]
t_pose[7] = (pose[4]*1+pose[5]*2)/3
t_pose[8] = (pose[5]*2+pose[6]*1)/3
t_pose[9] = pose[6]
t_pose[10] = (pose[6]*1+pose[7]*2)/3
t_pose[11] = (pose[7]*2+pose[8]*1)/3
t_pose[12] = pose[8]
elif n==14:
t_pose[0] = pose[0]
t_pose[1] = (pose[0]*5+pose[1]*8)/13
t_pose[2] = (pose[1]*10+pose[2]*3)/13
t_pose[3] = (pose[1]*2+pose[2]*11)/13
t_pose[4] = (pose[2]*7+pose[3]*6)/13
t_pose[5] = (pose[3]*12+pose[4]*1)/13
t_pose[6] = (pose[3]*4+pose[4]*9)/13
t_pose[7] = (pose[4]*9+pose[5]*4)/13
t_pose[8] = (pose[4]*1+pose[5]*12)/13  # frame 8 sits at 64/13 of the source, closer to pose[5]; the original weights were swapped
t_pose[9] = (pose[5]*6+pose[6]*7)/13
t_pose[10] = (pose[6]*11+pose[7]*2)/13
t_pose[11] = (pose[6]*3+pose[7]*10)/13
t_pose[12] = (pose[7]*8+pose[8]*5)/13
t_pose[13] = pose[8]
elif n==9:
t_pose = pose
elif n==8:
t_pose[0] = pose[0]
t_pose[1] = (pose[1]*6+pose[2]*1)/7
t_pose[2] = (pose[2]*5+pose[3]*2)/7
t_pose[3] = (pose[3]*4+pose[4]*3)/7
t_pose[4] = (pose[4]*3+pose[5]*4)/7
t_pose[5] = (pose[5]*2+pose[6]*5)/7
t_pose[6] = (pose[6]*1+pose[7]*6)/7
t_pose[7] = pose[8]
elif n==7:
t_pose[0] = pose[0]
t_pose[1] = (pose[1]*2+pose[2]*1)/3
t_pose[2] = (pose[2]*1+pose[3]*2)/3
t_pose[3] = pose[4]
t_pose[4] = (pose[5]*2+pose[6]*1)/3
t_pose[5] = (pose[6]*1+pose[7]*2)/3
t_pose[6] = pose[8]
elif n==6:
t_pose[0] = pose[0]
t_pose[1] = (pose[1]*2+pose[2]*3)/5
t_pose[2] = (pose[3]*4+pose[4]*1)/5
t_pose[3] = (pose[4]*1+pose[5]*4)/5
t_pose[4] = (pose[6]*3+pose[7]*2)/5
t_pose[5] = pose[8]
elif n<6:
t_pose[0] = pose[0]
t_pose[n-1] = pose[8]
for i in range(1,n-1):
t_pose[i] = pose[4]
elif n>14:
t_pose[0] = pose[0]
t_pose[n-1] = pose[8]
for i in range(1, n-1):
k = int(8/(n-1)*i)
t_pose[i] = pose[k]  # sample the source poses; t_pose is still zero-initialized here
else:
print('NOT IMPLEMENTED {}'.format(n))
return t_pose
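The hand-unrolled branches in `get_pose` are per-`n` specializations of linear resampling: output frame i sits at fractional position `8*i/(n-1)` in the 9-frame input and blends its two neighbouring input frames. A generic sketch of that interpolation (illustrative, not part of the repository):

```python
import numpy as np

def resample_poses(pose, n):
    """Linearly resample a (9, 14, 2) pose block to n frames.

    A compact alternative sketch to the hand-unrolled branches in
    get_pose(): output frame i sits at position 8*i/(n-1) in the
    9-frame input and blends the two neighbouring input frames.
    """
    out = np.zeros((n,) + pose.shape[1:])
    for i in range(n):
        p = (pose.shape[0] - 1) * i / (n - 1)  # fractional source position
        lo = int(np.floor(p))
        hi = min(lo + 1, pose.shape[0] - 1)
        w = p - lo                             # blend weight toward the later frame
        out[i] = (1 - w) * pose[lo] + w * pose[hi]
    return out
```

When n equals the input length every position is integral, so the input is returned unchanged, matching the `n==9` branch above.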
================================================
FILE: networks.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
if torch.cuda.is_available():
T = torch.cuda
else:
T = torch
###########################################################
##########
########## Stage 1: Movement
##########
###########################################################
class InitPose_Enc(nn.Module):
def __init__(self, pose_size, dim_z_init):
super(InitPose_Enc, self).__init__()
nf = 64
#nf = 32
self.enc = nn.Sequential(
nn.Linear(pose_size, nf),
nn.LayerNorm(nf),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nf, nf),
nn.LayerNorm(nf),
nn.LeakyReLU(0.2, inplace=True),
)
self.mean = nn.Sequential(
nn.Linear(nf,dim_z_init),
)
self.std = nn.Sequential(
nn.Linear(nf,dim_z_init),
)
def forward(self, pose):
enc = self.enc(pose)
return self.mean(enc), self.std(enc)
class InitPose_Dec(nn.Module):
def __init__(self, pose_size, dim_z_init):
super(InitPose_Dec, self).__init__()
nf = 64
#nf = dim_z_init
self.dec = nn.Sequential(
nn.Linear(dim_z_init, nf),
nn.LayerNorm(nf),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nf, nf),
nn.LayerNorm(nf),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nf,pose_size),
)
def forward(self, z_init):
return self.dec(z_init)
class Movement_Enc(nn.Module):
def __init__(self, pose_size, dim_z_movement, length, hidden_size, num_layers, bidirection=False):
super(Movement_Enc, self).__init__()
self.hidden_size = hidden_size
self.bidirection = bidirection
if bidirection:
self.num_dir = 2
else:
self.num_dir = 1
self.recurrent = nn.GRU(pose_size, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=bidirection)
self.init_h = nn.Parameter(torch.randn(num_layers*self.num_dir, 1, hidden_size).type(T.FloatTensor), requires_grad=True)
if bidirection:
self.mean = nn.Sequential(
nn.Linear(hidden_size*2,dim_z_movement),
)
self.std = nn.Sequential(
nn.Linear(hidden_size*2,dim_z_movement),
)
else:
'''
self.enc = nn.Sequential(
nn.Linear(hidden_size, hidden_size//2),
nn.LayerNorm(hidden_size//2),
nn.ReLU(inplace=True),
)
'''
self.mean = nn.Sequential(
nn.Linear(hidden_size,dim_z_movement),
)
self.std = nn.Sequential(
nn.Linear(hidden_size,dim_z_movement),
)
def forward(self, poses):
num_samples = poses.shape[0]
h_t = [self.init_h.repeat(1, num_samples, 1)]
output, hidden = self.recurrent(poses, h_t[0])
if self.bidirection:
output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1)
else:
output = output[:,-1,:]
#enc = self.enc(output)
#return self.mean(enc), self.std(enc)
return self.mean(output), self.std(output)
def getFeature(self, poses):
num_samples = poses.shape[0]
h_t = [self.init_h.repeat(1, num_samples, 1)]
output, hidden = self.recurrent(poses, h_t[0])
if self.bidirection:
output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1)
else:
output = output[:,-1,:]
return output
class StandardPose_Dec(nn.Module):
def __init__(self, pose_size, dim_z_init, dim_z_movement, length, hidden_size, num_layers):
super(StandardPose_Dec, self).__init__()
self.length = length
self.pose_size = pose_size
self.hidden_size = hidden_size
self.num_layers = num_layers
#dim_z_init=0
'''
self.z2init = nn.Sequential(
nn.Linear(dim_z_init+dim_z_movement, hidden_size),
nn.LayerNorm(hidden_size),
nn.ReLU(True),
nn.Linear(hidden_size, num_layers*hidden_size)
)
'''
self.z2init = nn.Sequential(
nn.Linear(dim_z_init+dim_z_movement, num_layers*hidden_size)
)
self.recurrent = nn.GRU(dim_z_movement, hidden_size, num_layers=num_layers, batch_first=True)
self.pose_g = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.LayerNorm(hidden_size),
nn.ReLU(True),
nn.Linear(hidden_size, pose_size)
)
def forward(self, z_init, z_movement):
h_init = self.z2init(torch.cat((z_init, z_movement), 1))
#h_init = self.z2init(z_movement)
h_init = h_init.view(self.num_layers, h_init.size(0), self.hidden_size)
z_movements = z_movement.view(z_movement.size(0),1,z_movement.size(1)).repeat(1, self.length, 1)
z_m_t, _ = self.recurrent(z_movements, h_init)
z_m = z_m_t.contiguous().view(-1, self.hidden_size)
poses = self.pose_g(z_m)
poses = poses.view(z_movement.shape[0], self.length, self.pose_size)
return poses
class StandardPose_Dis(nn.Module):
def __init__(self, pose_size, length):
super(StandardPose_Dis, self).__init__()
self.pose_size = pose_size
self.length = length
nd = 1024
self.main = nn.Sequential(
nn.Linear(length*pose_size, nd),
nn.LayerNorm(nd),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd,nd//2),
nn.LayerNorm(nd//2),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd//2,nd//4),
nn.LayerNorm(nd//4),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd//4, 1)
)
def forward(self, pose_seq):
pose_seq = pose_seq.view(-1, self.pose_size*self.length)
return self.main(pose_seq).squeeze()
###########################################################
##########
########## Stage 2: Dance
##########
###########################################################
class Dance_Enc(nn.Module):
def __init__(self, dim_z_movement, dim_z_dance, hidden_size, num_layers, bidirection=False):
super(Dance_Enc, self).__init__()
self.hidden_size = hidden_size
self.bidirection = bidirection
if bidirection:
self.num_dir = 2
else:
self.num_dir = 1
self.recurrent = nn.GRU(2*dim_z_movement, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=bidirection)
self.init_h = nn.Parameter(torch.randn(num_layers*self.num_dir, 1, hidden_size).type(T.FloatTensor), requires_grad=True)
if bidirection:
self.mean = nn.Sequential(
nn.Linear(hidden_size*2,dim_z_dance),
)
self.std = nn.Sequential(
nn.Linear(hidden_size*2,dim_z_dance),
)
else:
self.mean = nn.Sequential(
nn.Linear(hidden_size,dim_z_dance),
)
self.std = nn.Sequential(
nn.Linear(hidden_size,dim_z_dance),
)
def forward(self, movements_mean, movements_std):
movements = torch.cat((movements_mean, movements_std),2)
num_samples = movements.shape[0]
h_t = [self.init_h.repeat(1, num_samples, 1)]
output, hidden = self.recurrent(movements, h_t[0])
if self.bidirection:
output = torch.cat((output[:,-1,:self.hidden_size], output[:,0,self.hidden_size:]), 1)
else:
output = output[:,-1,:]
return self.mean(output), self.std(output)
class Dance_Dec(nn.Module):
def __init__(self, dim_z_dance, dim_z_movement, hidden_size, num_layers):
super(Dance_Dec, self).__init__()
#self.length = length
self.num_layers = num_layers
self.hidden_size = hidden_size
self.dim_z_movement = dim_z_movement
#dim_z_init=0
'''
self.z2init = nn.Sequential(
nn.Linear(dim_z_init+dim_z_movement, hidden_size),
nn.LayerNorm(hidden_size),
nn.ReLU(True),
nn.Linear(hidden_size, num_layers*hidden_size)
)
'''
self.z2init = nn.Sequential(
nn.Linear(dim_z_dance, num_layers*hidden_size)
)
self.recurrent = nn.GRU(dim_z_dance, hidden_size, num_layers=num_layers, batch_first=True)
self.movement_g = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.LayerNorm(hidden_size),
nn.ReLU(True),
#nn.Linear(hidden_size, dim_z_movement)
)
self.mean = nn.Sequential(
nn.Linear(hidden_size,dim_z_movement),
)
self.std = nn.Sequential(
nn.Linear(hidden_size,dim_z_movement),
)
def forward(self, z_dance, length=3):
h_init = self.z2init(z_dance)
h_init = h_init.view(self.num_layers, h_init.size(0), self.hidden_size)
z_dance = z_dance.view(z_dance.size(0),1,z_dance.size(1)).repeat(1, length, 1)
z_d_t, _ = self.recurrent(z_dance, h_init)
z_d = z_d_t.contiguous().view(-1, self.hidden_size)
z_movement = self.movement_g(z_d)
z_movement_mean, z_movement_std = self.mean(z_movement), self.std(z_movement)
#z_movement = z_movement.view(z_dance.shape[0], length, self.dim_z_movement)
return z_movement_mean, z_movement_std
class DanceAud_Dis2(nn.Module):
def __init__(self, aud_size, dim_z_movement, length=3):
super(DanceAud_Dis2, self).__init__()
self.aud_size = aud_size
self.dim_z_movement = dim_z_movement
self.length = length
nd = 1024
self.movementd = nn.Sequential(
nn.Linear(dim_z_movement*2*length, nd),
nn.LayerNorm(nd),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd,nd//2),
nn.LayerNorm(nd//2),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd//2,nd//4),
nn.LayerNorm(nd//4),
nn.LeakyReLU(0.2, inplace=True),
#nn.Linear(nd//4, 30),
nn.Linear(nd//4, 30),
)
self.audd = nn.Sequential(
nn.Linear(aud_size, 30),
nn.LayerNorm(30),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(30, 30),
nn.LayerNorm(30),
nn.LeakyReLU(0.2, inplace=True),
)
self.jointd = nn.Sequential(
nn.Linear(60, 1)
)
def forward(self, movements, aud):
if len(movements.shape) == 3:
movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2])
m = self.movementd(movements)
a = self.audd(aud)
ma = torch.cat((m,a),1)
return self.jointd(ma).squeeze(), None
class DanceAud_Dis(nn.Module):
def __init__(self, aud_size, dim_z_movement, length=3):
super(DanceAud_Dis, self).__init__()
self.aud_size = aud_size
self.dim_z_movement = dim_z_movement
self.length = length
nd = 1024
self.movementd = nn.Sequential(
#nn.Linear(dim_z_movement*3, nd),
nn.Linear(dim_z_movement*2, nd),
nn.LayerNorm(nd),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd,nd//2),
nn.LayerNorm(nd//2),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd//2,nd//4),
nn.LayerNorm(nd//4),
nn.LeakyReLU(0.2, inplace=True),
#nn.Linear(nd//4, 30),
nn.Linear(nd//4, 30),
)
def forward(self, movements, aud):
#movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2])
m = self.movementd(movements)
return m.squeeze()
#a = self.audd(aud)
#ma = torch.cat((m,a),1)
#return self.jointd(ma).squeeze()
class DanceAud_InfoDis(nn.Module):
def __init__(self, aud_size, dim_z_movement, length):
super(DanceAud_InfoDis, self).__init__()
self.aud_size = aud_size
self.dim_z_movement = dim_z_movement
self.length = length
nd = 1024
self.movementd = nn.Sequential(
nn.Linear(dim_z_movement*6, nd*2),
nn.LayerNorm(nd*2),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd*2, nd),
nn.LayerNorm(nd),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd,nd//2),
nn.LayerNorm(nd//2),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd//2,nd//4),
nn.LayerNorm(nd//4),
nn.LeakyReLU(0.2, inplace=True),
)
self.dis = nn.Sequential(
nn.Linear(nd//4, 1)
)
self.reg = nn.Sequential(
nn.Linear(nd//4, aud_size)
)
def forward(self, movements, aud):
movements = movements.view(movements.shape[0], movements.shape[1]*movements.shape[2])
m = self.movementd(movements)
return self.dis(m).squeeze(), self.reg(m)
class Dance2Style(nn.Module):
def __init__(self, dim_z_dance, aud_size):
super(Dance2Style, self).__init__()
self.aud_size = aud_size
self.dim_z_dance = dim_z_dance
nd = 512
self.main = nn.Sequential(
nn.Linear(dim_z_dance, nd),
nn.LayerNorm(nd),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd, nd//2),
nn.LayerNorm(nd//2),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd//2, nd//4),
nn.LayerNorm(nd//4),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nd//4, aud_size),
)
def forward(self, zdance):
return self.main(zdance)
###########################################################
##########
########## Audio
##########
###########################################################
class AudioClassifier_rnn(nn.Module):
def __init__(self, dim_z_motion, hidden_size, pose_size, cls, num_layers=1, h_init=2):
super(AudioClassifier_rnn, self).__init__()
self.dim_z_motion = dim_z_motion
self.hidden_size = hidden_size
self.pose_size = pose_size
self.h_init = h_init
self.num_layers = num_layers
        self.init_h = nn.Parameter(torch.randn(1, 1, self.hidden_size).type(torch.FloatTensor), requires_grad=True)
self.recurrent = nn.GRU(pose_size, hidden_size, num_layers=num_layers, batch_first=True)
self.classifier = nn.Sequential(
#nn.Dropout(p=0.2),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(True),
#nn.Dropout(p=0.2),
nn.Linear(hidden_size, cls)
)
def forward(self, poses):
hidden, _ = self.recurrent(poses, self.init_h.repeat(1, poses.shape[0], 1))
last_hidden = hidden[:,-1,:]
cls = self.classifier(last_hidden)
return cls
def get_style(self, auds):
hidden, _ = self.recurrent(auds, self.init_h.repeat(1, auds.shape[0], 1))
last_hidden = hidden[:,-1,:]
return last_hidden
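Both `forward` and `get_style` read only the final time step of the `batch_first` GRU output, whose shape is `(batch, time, hidden)`. A small numpy sketch of that indexing pattern (shapes here are made up for illustration):

```python
import numpy as np

# hidden[:, -1, :] selects the last time step for every sequence in the batch
batch, time, hidden_size = 2, 5, 4
hidden = np.arange(batch * time * hidden_size).reshape(batch, time, hidden_size)
last_hidden = hidden[:, -1, :]
print(last_hidden.shape)  # (2, 4)
```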
class Audstyle_Enc(nn.Module):
def __init__(self, aud_size, dim_z, dim_noise=30):
super(Audstyle_Enc, self).__init__()
self.dim_noise = dim_noise
nf = 64
#nf = 32
self.enc = nn.Sequential(
nn.Linear(aud_size+dim_noise, nf),
nn.LayerNorm(nf),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(nf, nf*2),
nn.LayerNorm(nf*2),
nn.LeakyReLU(0.2, inplace=True),
)
self.mean = nn.Sequential(
nn.Linear(nf*2,dim_z),
)
self.std = nn.Sequential(
nn.Linear(nf*2,dim_z),
)
    def forward(self, aud):
        # Draw the noise on the same device as the input rather than
        # hard-coding .cuda(), so the module also runs on CPU.
        noise = torch.randn(aud.shape[0], self.dim_noise, device=aud.device)
        y = torch.cat((aud, noise), 1)
        enc = self.enc(y)
        return self.mean(enc), self.std(enc)
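`Audstyle_Enc` returns a mean and a std head over the dance latent; how the trainer consumes them is defined in model_comp.py, but the usual pattern is a reparameterized sample. A minimal numpy sketch of that trick, under the assumption that the second head is a plain standard deviation:

```python
import numpy as np

def reparameterize(mean, std, rng):
    # z = mean + std * eps with eps ~ N(0, I); sampling stays
    # differentiable w.r.t. mean and std in the autograd setting.
    eps = rng.standard_normal(mean.shape)
    return mean + std * eps

rng = np.random.default_rng(0)
mean = np.zeros((4, 512))  # batch of 4, dim_z_dance = 512
std = np.ones((4, 512))
z = reparameterize(mean, std, rng)
print(z.shape)  # (4, 512)
```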
================================================
FILE: options.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import argparse
class DecompOptions():
def __init__(self):
parser = argparse.ArgumentParser()
parser.add_argument('--name', default=None)
parser.add_argument('--log_interval', type=int, default=50)
parser.add_argument('--log_dir', default='./logs')
parser.add_argument('--snapshot_ep', type=int, default=1)
parser.add_argument('--snapshot_dir', default='./snapshot')
parser.add_argument('--data_dir', default='./data')
# Model architecture
parser.add_argument('--pose_size', type=int, default=28)
parser.add_argument('--dim_z_init', type=int, default=10)
parser.add_argument('--dim_z_movement', type=int, default=512)
parser.add_argument('--stdp_length', type=int, default=32)
parser.add_argument('--movement_enc_bidirection', type=int, default=1)
parser.add_argument('--movement_enc_hidden_size', type=int, default=1024)
parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024)
parser.add_argument('--movement_enc_num_layers', type=int, default=1)
parser.add_argument('--stdp_dec_num_layers', type=int, default=1)
# Training
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--num_epochs', type=int, default=1000)
parser.add_argument('--latent_dropout', type=float, default=0.3)
parser.add_argument('--lambda_kl', type=float, default=0.01)
parser.add_argument('--lambda_initp_recon', type=float, default=1)
parser.add_argument('--lambda_initp_consistency', type=float, default=1)
parser.add_argument('--lambda_stdp_recon', type=float, default=1)
parser.add_argument('--lambda_dist_z_movement', type=float, default=1)
# Others
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--resume', default=None)
parser.add_argument('--dataset', type=int, default=0)
parser.add_argument('--tolerance', action='store_true')
self.parser = parser
    def parse(self):
        self.opt = self.parser.parse_args()
        return self.opt
class CompOptions():
def __init__(self):
parser = argparse.ArgumentParser()
parser.add_argument('--name', default=None)
parser.add_argument('--log_interval', type=int, default=50)
parser.add_argument('--log_dir', default='./logs')
parser.add_argument('--snapshot_ep', type=int, default=1)
parser.add_argument('--snapshot_dir', default='./snapshot')
parser.add_argument('--data_dir', default='./data')
# Network architecture
parser.add_argument('--pose_size', type=int, default=28)
parser.add_argument('--aud_style_size', type=int, default=30)
parser.add_argument('--dim_z_init', type=int, default=10)
parser.add_argument('--dim_z_movement', type=int, default=512)
parser.add_argument('--dim_z_dance', type=int, default=512)
parser.add_argument('--stdp_length', type=int, default=32)
parser.add_argument('--movement_enc_bidirection', type=int, default=1)
parser.add_argument('--movement_enc_hidden_size', type=int, default=1024)
parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024)
parser.add_argument('--movement_enc_num_layers', type=int, default=1)
parser.add_argument('--stdp_dec_num_layers', type=int, default=1)
parser.add_argument('--dance_enc_bidirection', type=int, default=0)
parser.add_argument('--dance_enc_hidden_size', type=int, default=1024)
parser.add_argument('--dance_enc_num_layers', type=int, default=1)
parser.add_argument('--dance_dec_hidden_size', type=int, default=1024)
parser.add_argument('--dance_dec_num_layers', type=int, default=1)
# Training
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--num_epochs', type=int, default=1500)
parser.add_argument('--latent_dropout', type=float, default=0.3)
parser.add_argument('--lambda_kl', type=float, default=0.01)
parser.add_argument('--lambda_kl_dance', type=float, default=0.01)
parser.add_argument('--lambda_gan', type=float, default=1)
parser.add_argument('--lambda_zmovements_recon', type=float, default=1)
parser.add_argument('--lambda_stdpSeq_recon', type=float, default=10)
parser.add_argument('--lambda_dist_z_movement', type=float, default=1)
# Other
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--decomp_snapshot', required=True)
parser.add_argument('--neta_snapshot', default='./data/stats/aud_3cls.ckpt')
parser.add_argument('--resume', default=None)
parser.add_argument('--dataset', type=int, default=2)
self.parser = parser
    def parse(self):
        self.opt = self.parser.parse_args()
        return self.opt
class TestOptions():
def __init__(self):
parser = argparse.ArgumentParser()
parser.add_argument('--name', default=None)
parser.add_argument('--log_interval', type=int, default=50)
parser.add_argument('--log_dir', default='./logs')
parser.add_argument('--snapshot_ep', type=int, default=1)
parser.add_argument('--snapshot_dir', default='./snapshot')
parser.add_argument('--data_dir', default='./data')
# Network architecture
parser.add_argument('--pose_size', type=int, default=28)
parser.add_argument('--aud_style_size', type=int, default=30)
parser.add_argument('--dim_z_init', type=int, default=10)
parser.add_argument('--dim_z_movement', type=int, default=512)
parser.add_argument('--dim_z_dance', type=int, default=512)
parser.add_argument('--stdp_length', type=int, default=32)
parser.add_argument('--movement_enc_bidirection', type=int, default=1)
parser.add_argument('--movement_enc_hidden_size', type=int, default=1024)
parser.add_argument('--stdp_dec_hidden_size', type=int, default=1024)
parser.add_argument('--movement_enc_num_layers', type=int, default=1)
parser.add_argument('--stdp_dec_num_layers', type=int, default=1)
parser.add_argument('--dance_enc_bidirection', type=int, default=0)
parser.add_argument('--dance_enc_hidden_size', type=int, default=1024)
parser.add_argument('--dance_enc_num_layers', type=int, default=1)
parser.add_argument('--dance_dec_hidden_size', type=int, default=1024)
parser.add_argument('--dance_dec_num_layers', type=int, default=1)
# Training
parser.add_argument('--lr', type=float, default=2e-4)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--num_epochs', type=int, default=1500)
parser.add_argument('--latent_dropout', type=float, default=0.3)
parser.add_argument('--lambda_kl', type=float, default=0.01)
parser.add_argument('--lambda_kl_dance', type=float, default=0.01)
parser.add_argument('--lambda_gan', type=float, default=1)
parser.add_argument('--lambda_zmovements_recon', type=float, default=1)
parser.add_argument('--lambda_stdpSeq_recon', type=float, default=10)
parser.add_argument('--lambda_dist_z_movement', type=float, default=1)
# Other
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--decomp_snapshot', required=True)
parser.add_argument('--comp_snapshot', required=True)
parser.add_argument('--neta_snapshot', default='./data/stats/aud_3cls.ckpt')
parser.add_argument('--dataset', type=int, default=2)
parser.add_argument('--thr', type=int, default=50)
parser.add_argument('--aud_path', type=str, required=True)
parser.add_argument('--modulate', action='store_true')
parser.add_argument('--out_file', type=str, default='demo/out.mp4')
parser.add_argument('--out_dir', type=str, default='demo/out_frame')
self.parser = parser
    def parse(self):
        self.opt = self.parser.parse_args()
        return self.opt
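All three options classes follow the same pattern: a class wraps an `argparse.ArgumentParser` and `parse()` returns the namespace. A stripped-down, self-contained sketch of that pattern (`MiniOptions` and its two flags are illustrative, not part of the repo):

```python
import argparse

class MiniOptions:
    def __init__(self):
        parser = argparse.ArgumentParser()
        parser.add_argument('--batch_size', type=int, default=256)
        parser.add_argument('--lr', type=float, default=2e-4)
        self.parser = parser

    def parse(self, argv=None):
        # argv=None falls back to sys.argv, as in the classes above
        self.opt = self.parser.parse_args(argv)
        return self.opt

opt = MiniOptions().parse(['--batch_size', '32'])
print(opt.batch_size, opt.lr)  # 32 0.0002
```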
================================================
FILE: test.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import argparse
import functools
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from model_comp import *
from networks import *
from options import CompOptions
from data import get_loader
def loadDecompModel(args):
    initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
    stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,
                                hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)
    movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,
                                hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))
    checkpoint = torch.load(args.decomp_snapshot)
    initp_enc.load_state_dict(checkpoint['initp_enc'])
    stdp_dec.load_state_dict(checkpoint['stdp_dec'])
    movement_enc.load_state_dict(checkpoint['movement_enc'])
    return initp_enc, stdp_dec, movement_enc

def loadCompModel(args):
    dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
                          hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1))
    dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
                          hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers)
    audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance)
    dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance)
    danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3)
    zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1)
    checkpoint = torch.load(args.resume)
    dance_enc.load_state_dict(checkpoint['dance_enc'])
    dance_dec.load_state_dict(checkpoint['dance_dec'])
    audstyle_enc.load_state_dict(checkpoint['audstyle_enc'])
    dance_reg.load_state_dict(checkpoint['dance_reg'])
    danceAud_dis.load_state_dict(checkpoint['danceAud_dis'])
    zdance_dis.load_state_dict(checkpoint['zdance_dis'])
    checkpoint2 = torch.load(args.neta_snapshot)
    neta_cls = AudioClassifier_rnn(10, 30, 28, cls=3)
    neta_cls.load_state_dict(checkpoint2)
    return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls

if __name__ == "__main__":
    parser = CompOptions()
    args = parser.parse()
    #### Pretrained networks from Decomp
    initp_enc, stdp_dec, movement_enc = loadDecompModel(args)
    #### Comp networks
    dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = loadCompModel(args)
    mean_pose = np.load('../onbeat/all_onbeat_mean.npy')
    std_pose = np.load('../onbeat/all_onbeat_std.npy')
    mean_aud = np.load('../onbeat/all_aud_mean.npy')
    std_aud = np.load('../onbeat/all_aud_std.npy')
================================================
FILE: train_comp.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import argparse
import functools
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from model_comp import *
from networks import *
from options import CompOptions
from data import get_loader
def loadDecompModel(args):
initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,
hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)
movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,
hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))
checkpoint = torch.load(args.decomp_snapshot)
initp_enc.load_state_dict(checkpoint['initp_enc'])
stdp_dec.load_state_dict(checkpoint['stdp_dec'])
movement_enc.load_state_dict(checkpoint['movement_enc'])
return initp_enc, stdp_dec, movement_enc
def getCompNetworks(args):
dance_enc = Dance_Enc(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
hidden_size=args.dance_enc_hidden_size, num_layers=args.dance_enc_num_layers, bidirection=(args.dance_enc_bidirection==1))
dance_dec = Dance_Dec(dim_z_dance=args.dim_z_dance, dim_z_movement=args.dim_z_movement,
hidden_size=args.dance_dec_hidden_size, num_layers=args.dance_dec_num_layers)
audstyle_enc = Audstyle_Enc(aud_size=args.aud_style_size, dim_z=args.dim_z_dance)
dance_reg = Dance2Style(aud_size=args.aud_style_size, dim_z_dance=args.dim_z_dance)
danceAud_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_movement, length=3)
zdance_dis = DanceAud_Dis2(aud_size=28, dim_z_movement=args.dim_z_dance, length=1)
checkpoint2 = torch.load(args.neta_snapshot)
neta_cls = AudioClassifier_rnn(10,30,28,cls=3)
neta_cls.load_state_dict(checkpoint2)
return dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls
if __name__ == "__main__":
parser = CompOptions()
args = parser.parse()
args.train = True
if args.name is None:
args.name = 'Comp'
args.log_dir = os.path.join(args.log_dir, args.name)
if not os.path.exists(args.log_dir):
os.mkdir(args.log_dir)
args.snapshot_dir = os.path.join(args.snapshot_dir, args.name)
if not os.path.exists(args.snapshot_dir):
os.mkdir(args.snapshot_dir)
data_loader = get_loader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, dataset=args.dataset, data_dir=args.data_dir)
#### Pretrain network from Decomp
initp_enc, stdp_dec, movement_enc = loadDecompModel(args)
#### Comp network
dance_enc, dance_dec, audstyle_enc, dance_reg, danceAud_dis, zdance_dis, neta_cls = getCompNetworks(args)
trainer = Trainer_Comp(data_loader,
movement_enc = movement_enc,
initp_enc = initp_enc,
stdp_dec = stdp_dec,
dance_enc = dance_enc,
dance_dec = dance_dec,
danceAud_dis = danceAud_dis,
zdance_dis = zdance_dis,
aud_enc=neta_cls,
audstyle_enc=audstyle_enc,
dance_reg=dance_reg,
args = args
)
    if args.resume is not None:
ep, it = trainer.resume(args.resume, True)
else:
ep, it = 0, 0
trainer.train(ep, it)
================================================
FILE: train_decomp.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import os
import argparse
import functools
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from model_decomp import *
from networks import *
from options import DecompOptions
from data import get_loader
def getDecompNetworks(args):
initp_enc = InitPose_Enc(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
initp_dec = InitPose_Dec(pose_size=args.pose_size, dim_z_init=args.dim_z_init)
movement_enc = Movement_Enc(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, length=args.stdp_length,
hidden_size=args.movement_enc_hidden_size, num_layers=args.movement_enc_num_layers, bidirection=(args.movement_enc_bidirection==1))
stdp_dec = StandardPose_Dec(pose_size=args.pose_size, dim_z_movement=args.dim_z_movement, dim_z_init=args.dim_z_init, length=args.stdp_length,
hidden_size=args.stdp_dec_hidden_size, num_layers=args.stdp_dec_num_layers)
return initp_enc, initp_dec, movement_enc, stdp_dec
if __name__ == "__main__":
parser = DecompOptions()
args = parser.parse()
args.train = True
if args.name is None:
args.name = 'Decomp'
args.log_dir = os.path.join(args.log_dir, args.name)
if not os.path.exists(args.log_dir):
os.mkdir(args.log_dir)
args.snapshot_dir = os.path.join(args.snapshot_dir, args.name)
if not os.path.exists(args.snapshot_dir):
os.mkdir(args.snapshot_dir)
data_loader = get_loader(batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, dataset=args.dataset, data_dir=args.data_dir, tolerance=args.tolerance)
initp_enc, initp_dec, movement_enc, stdp_dec = getDecompNetworks(args)
trainer = Trainer_Decomp(data_loader,
initp_enc = initp_enc,
initp_dec = initp_dec,
movement_enc = movement_enc,
stdp_dec = stdp_dec,
args = args
)
    if args.resume is not None:
ep, it = trainer.resume(args.resume, False)
else:
ep, it = 0, 0
trainer.train(ep=ep, it=it)
================================================
FILE: utils.py
================================================
# Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/Dancing2Music/License.txt
import numpy as np
import pickle
import cv2
import math
import os
import random
import tensorflow as tf
class Logger(object):
    def __init__(self, log_dir):
        # Uses the TensorFlow 1.x summary API (tf.summary.FileWriter).
        self.writer = tf.summary.FileWriter(log_dir)
def scalar_summary(self, tag, value, step):
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)
def vis(poses, outdir, aud=None):
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
# find connection in the specified sequence, center 29 is in the position 15
limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \
[10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \
[1,16], [16,18], [3,17], [6,18]]
neglect = [14,15,16,17]
for t in range(poses.shape[0]):
#break
canvas = np.ones((256,500,3), np.uint8)*255
thisPeak = poses[t]
for i in range(18):
if i in neglect:
continue
if thisPeak[i,0] == -1:
continue
cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1)
for i in range(17):
limbid = np.array(limbSeq[i])-1
if limbid[0] in neglect or limbid[1] in neglect:
continue
X = thisPeak[[limbid[0],limbid[1]], 1]
Y = thisPeak[[limbid[0],limbid[1]], 0]
if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1:
continue
stickwidth = 4
cur_canvas = canvas.copy()
mX = np.mean(X)
mY = np.mean(Y)
#print(X, Y, limbid)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1)
#print(i, n, int(mY), int(mX), limbid, X, Y)
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
if aud is not None:
if aud[:,t] == 1:
cv2.circle(canvas, (30, 30), 20, (0,0,255), -1)
#canvas = cv2.copyMakeBorder(canvas,10,10,10,10,cv2.BORDER_CONSTANT,value=[255,0,0])
cv2.imwrite(os.path.join(outdir, 'frame{0:03d}.png'.format(t)),canvas)
def vis2(poses, outdir, fibeat):
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
# find connection in the specified sequence, center 29 is in the position 15
limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \
[10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \
[1,16], [16,18], [3,17], [6,18]]
neglect = [14,15,16,17]
    ibeat = cv2.imread(fibeat)
    ibeat = cv2.resize(ibeat, (500, 200))
for t in range(poses.shape[0]):
subibeat = ibeat.copy()
canvas = np.ones((256+200,500,3), np.uint8)*255
canvas[256:,:,:] = subibeat
overlay = canvas.copy()
cv2.rectangle(overlay, (int(500/poses.shape[0]*(t+1)),256),(500,256+200), (100,100,100), -1)
cv2.addWeighted(overlay, 0.4, canvas, 1-0.4, 0, canvas)
thisPeak = poses[t]
for i in range(18):
if i in neglect:
continue
if thisPeak[i,0] == -1:
continue
cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1)
for i in range(17):
limbid = np.array(limbSeq[i])-1
if limbid[0] in neglect or limbid[1] in neglect:
continue
X = thisPeak[[limbid[0],limbid[1]], 1]
Y = thisPeak[[limbid[0],limbid[1]], 0]
if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1:
continue
stickwidth = 4
cur_canvas = canvas.copy()
mX = np.mean(X)
mY = np.mean(Y)
#print(X, Y, limbid)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1)
#print(i, n, int(mY), int(mX), limbid, X, Y)
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
cv2.imwrite(os.path.join(outdir, 'frame{0:03d}.png'.format(t)),canvas)
def vis_single(pose, outfile):
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
# find connection in the specified sequence, center 29 is in the position 15
limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \
[10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \
[1,16], [16,18], [3,17], [6,18]]
neglect = [14,15,16,17]
for t in range(1):
#break
canvas = np.ones((256,500,3), np.uint8)*255
thisPeak = pose
for i in range(18):
if i in neglect:
continue
if thisPeak[i,0] == -1:
continue
cv2.circle(canvas, tuple(thisPeak[i,0:2].astype(int)), 4, colors[i], thickness=-1)
for i in range(17):
limbid = np.array(limbSeq[i])-1
if limbid[0] in neglect or limbid[1] in neglect:
continue
X = thisPeak[[limbid[0],limbid[1]], 1]
Y = thisPeak[[limbid[0],limbid[1]], 0]
if X[0] == -1 or Y[0]==-1 or X[1]==-1 or Y[1]==-1:
continue
stickwidth = 4
cur_canvas = canvas.copy()
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY),int(mX)), (int(length/2), stickwidth), int(angle), 0, 360, 1)
cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
cv2.imwrite(outfile,canvas)
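All three visualization functions draw each limb as a filled rotated ellipse: they take the two joint endpoints, compute the midpoint, half-length, and rotation angle, and pass those to `cv2.ellipse2Poly`. A self-contained sketch of just that geometry, with generic coordinate names standing in for the `X`/`Y` rows above:

```python
import math

def limb_ellipse_params(x0, y0, x1, y1):
    # Midpoint, half-length, and rotation angle (degrees) of the stick
    # ellipse drawn between two joints -- the values the vis functions
    # feed to cv2.ellipse2Poly.
    mx, my = (x0 + x1) / 2.0, (y0 + y1) / 2.0
    length = math.hypot(x0 - x1, y0 - y1)
    angle = math.degrees(math.atan2(x0 - x1, y0 - y1))
    return (int(my), int(mx)), int(length / 2), int(angle)

center, half_len, angle = limb_ellipse_params(0, 0, 6, 8)
print(center, half_len, angle)  # (4, 3) 5 -143
```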
gitextract_2r5pf7ve/ ├── License.txt ├── README.md ├── data.py ├── demo.py ├── model_comp.py ├── model_decomp.py ├── modulate.py ├── networks.py ├── options.py ├── test.py ├── train_comp.py ├── train_decomp.py └── utils.py
SYMBOL INDEX (111 symbols across 11 files)
FILE: data.py
class PoseDataset (line 16) | class PoseDataset(torch.utils.data.Dataset):
method __init__ (line 17) | def __init__(self, data_dir, tolerance=False):
method __getitem__ (line 71) | def __getitem__(self, index):
method __len__ (line 106) | def __len__(self):
class MovementAudDataset (line 110) | class MovementAudDataset(torch.utils.data.Dataset):
method __init__ (line 111) | def __init__(self, data_dir):
method __getitem__ (line 155) | def __getitem__(self, index):
method __len__ (line 185) | def __len__(self):
function get_loader (line 188) | def get_loader(batch_size, shuffle, num_workers, dataset, data_dir, tole...
FILE: demo.py
function loadDecompModel (line 28) | def loadDecompModel(args):
function loadCompModel (line 40) | def loadCompModel(args):
FILE: model_comp.py
class Trainer_Comp (line 26) | class Trainer_Comp(object):
method __init__ (line 27) | def __init__(self, data_loader, dance_enc, dance_dec, danceAud_dis, mo...
method init_logs (line 65) | def init_logs(self):
method get_z_random (line 73) | def get_z_random(self, batchSize, nz, random_type='gauss'):
method ones_like (line 78) | def ones_like(tensor, val=1.):
method zeros_like (line 82) | def zeros_like(tensor, val=0.):
method kld_coef (line 84) | def kld_coef(self, i):
method forward (line 88) | def forward(self, stdpSeq, batchsize, aud_style, aud):
method backward_D (line 138) | def backward_D(self):
method backward_danceED (line 167) | def backward_danceED(self):
method backward_info_ondance (line 199) | def backward_info_ondance(self):
method zero_grad (line 206) | def zero_grad(self, opt_list):
method clip_norm (line 210) | def clip_norm(self, network_list):
method step (line 214) | def step(self, opt_list):
method update (line 218) | def update(self):
method test_final (line 237) | def test_final(self, initpose, aud, n, thr=0):
method resume (line 282) | def resume(self, model_dir, train=True):
method save (line 300) | def save(self, filename, ep, total_it):
method cuda (line 324) | def cuda(self):
method train (line 338) | def train(self, ep=0, it=0):
FILE: model_decomp.py
class Trainer_Decomp (line 26) | class Trainer_Decomp(object):
method __init__ (line 27) | def __init__(self, data_loader, initp_enc, initp_dec, movement_enc, st...
method init_logs (line 53) | def init_logs(self):
method get_z_random (line 59) | def get_z_random(self, batchSize, nz, random_type='gauss'):
method ones_like (line 64) | def ones_like(tensor, val=1.):
method zeros_like (line 68) | def zeros_like(tensor, val=0.):
method random_generate_stdp (line 72) | def random_generate_stdp(self, init_p):
method forward (line 82) | def forward(self, stdpose1, stdpose2):
method backward_initp_ED (line 128) | def backward_initp_ED(self):
method backward_movement_ED (line 138) | def backward_movement_ED(self):
method update (line 160) | def update(self):
method save (line 177) | def save(self, filename, ep, total_it):
method resume (line 193) | def resume(self, model_dir, train=True):
method kld_coef (line 208) | def kld_coef(self, i):
method generate_stdp_sequence (line 212) | def generate_stdp_sequence(self, initpose, aud, num_stdp):
method cuda (line 258) | def cuda(self):
method train (line 265) | def train(self, ep=0, it=0):
FILE: modulate.py
function modulate (line 14) | def modulate(dance, beats, length):
function get_pose (line 56) | def get_pose(pose, n):
FILE: networks.py
class InitPose_Enc (line 26) | class InitPose_Enc(nn.Module):
method __init__ (line 27) | def __init__(self, pose_size, dim_z_init):
method forward (line 45) | def forward(self, pose):
class InitPose_Dec (line 49) | class InitPose_Dec(nn.Module):
method __init__ (line 50) | def __init__(self, pose_size, dim_z_init):
method forward (line 63) | def forward(self, z_init):
class Movement_Enc (line 66) | class Movement_Enc(nn.Module):
method __init__ (line 67) | def __init__(self, pose_size, dim_z_movement, length, hidden_size, num...
method forward (line 98) | def forward(self, poses):
method getFeature (line 110) | def getFeature(self, poses):
class StandardPose_Dec (line 120) | class StandardPose_Dec(nn.Module):
method __init__ (line 121) | def __init__(self, pose_size, dim_z_init, dim_z_movement, length, hidd...
method forward (line 147) | def forward(self, z_init, z_movement):
class StandardPose_Dis (line 158) | class StandardPose_Dis(nn.Module):
method __init__ (line 159) | def __init__(self, pose_size, length):
method forward (line 176) | def forward(self, pose_seq):
class Dance_Enc (line 185) | class Dance_Enc(nn.Module):
method __init__ (line 186) | def __init__(self, dim_z_movement, dim_z_dance, hidden_size, num_layer...
method forward (line 210) | def forward(self, movements_mean, movements_std):
class Dance_Dec (line 221) | class Dance_Dec(nn.Module):
method __init__ (line 222) | def __init__(self, dim_z_dance, dim_z_movement, hidden_size, num_layers):
method forward (line 254) | def forward(self, z_dance, length=3):
class DanceAud_Dis2 (line 266) | class DanceAud_Dis2(nn.Module):
method __init__ (line 267) | def __init__(self, aud_size, dim_z_movement, length=3):
method forward (line 299) | def forward(self, movements, aud):
class DanceAud_Dis (line 308) | class DanceAud_Dis(nn.Module):
method __init__ (line 309) | def __init__(self, aud_size, dim_z_movement, length=3):
method forward (line 331) | def forward(self, movements, aud):
class DanceAud_InfoDis (line 340) | class DanceAud_InfoDis(nn.Module):
method __init__ (line 341) | def __init__(self, aud_size, dim_z_movement, length):
method forward (line 370) | def forward(self, movements, aud):
class Dance2Style (line 375) | class Dance2Style(nn.Module):
method __init__ (line 376) | def __init__(self, dim_z_dance, aud_size):
method forward (line 393) | def forward(self, zdance):
class AudioClassifier_rnn (line 401) | class AudioClassifier_rnn(nn.Module):
method __init__ (line 402) | def __init__(self, dim_z_motion, hidden_size, pose_size, cls, num_laye...
method forward (line 419) | def forward(self, poses):
method get_style (line 424) | def get_style(self, auds):
class Audstyle_Enc (line 430) | class Audstyle_Enc(nn.Module):
method __init__ (line 431) | def __init__(self, aud_size, dim_z, dim_noise=30):
method forward (line 450) | def forward(self, aud):
FILE: options.py
class DecompOptions (line 11) | class DecompOptions():
method __init__ (line 12) | def __init__(self):
method parse (line 51) | def parse(self):
class CompOptions (line 56) | class CompOptions():
method __init__ (line 57) | def __init__(self):
method parse (line 103) | def parse(self):
class TestOptions (line 108) | class TestOptions():
method __init__ (line 109) | def __init__(self):
method parse (line 160) | def parse(self):
FILE: test.py
function loadDecompModel (line 37) | def loadDecompModel(args):
function loadCompModel (line 49) | def loadCompModel(args):
FILE: train_comp.py
function loadDecompModel (line 21) | def loadDecompModel(args):
function getCompNetworks (line 33) | def getCompNetworks(args):
FILE: train_decomp.py
function getDecompNetworks (line 21) | def getDecompNetworks(args):
FILE: utils.py
class Logger (line 16) | class Logger(object):
method __init__ (line 17) | def __init__(self, log_dir):
method scalar_summary (line 20) | def scalar_summary(self, tag, value, step):
function vis (line 24) | def vis(poses, outdir, aud=None):
function vis2 (line 73) | def vis2(poses, outdir, fibeat):
function vis_single (line 126) | def vis_single(pose, outfile):
About this extraction
This document contains the full source code of the NVlabs/Dancing2Music GitHub repository, extracted as plain text: 13 files (93.6 KB, approximately 28.2k tokens) plus a symbol index covering 111 extracted functions, classes, methods, constants, and types.