Repository: jayg996/BTC-ISMIR19
Branch: master
Commit: 2682317be668
Files: 21
Total size: 23.4 MB
Directory structure:
gitextract_upyc_iog/
├── LICENSE
├── README.md
├── audio_dataset.py
├── baseline_models.py
├── btc_model.py
├── crf_model.py
├── run_config.yaml
├── test/
│ ├── btc_model.pt
│ └── btc_model_large_voca.pt
├── test.py
├── train.py
├── train_crf.py
└── utils/
├── __init__.py
├── chords.py
├── hparams.py
├── logger.py
├── mir_eval_modules.py
├── preprocess.py
├── pytorch_utils.py
├── tf_logger.py
└── transformer_modules.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2019 Jonggwon Park
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# A Bi-Directional Transformer for Musical Chord Recognition
This repository contains the source code for the paper "A Bi-Directional Transformer for Musical Chord Recognition" (ISMIR 2019).
## Requirements
- pytorch >= 1.0.0
- numpy >= 1.16.2
- pandas >= 0.24.1
- pyrubberband >= 0.3.0
- librosa >= 0.6.3
- pyyaml >= 3.13
- mir_eval >= 0.5
- pretty_midi >= 0.2.8
## File descriptions
* `audio_dataset.py` : loads the data, parsing label files into chord labels and converting mp3 files to constant-Q transform (CQT) features.
* `btc_model.py` : contains the PyTorch implementation of BTC.
* `train.py` : training script.
* `crf_model.py` : contains the PyTorch implementation of Conditional Random Fields (CRFs).
* `baseline_models.py` : contains the code for the baseline models.
* `train_crf.py` : training script for CRFs.
* `run_config.yaml` : contains the required hyperparameters and paths.
* `test.py` : recognizes chords from audio files.
## Using BTC : Recognizing chords from files in audio directory
### Using BTC from command line
```bash
$ python test.py --audio_dir audio_folder --save_dir save_folder --voca False
```
* audio_dir : a folder of audio files for chord recognition (default: './test')
* save_dir : a folder for saving the recognition results (default: './test')
* voca : False uses the major/minor label vocabulary; True uses the large vocabulary (default: False)
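Note that `voca` is parsed from a string: `test.py` declares the flag with `type=lambda x: (str(x).lower() == 'true')`, because argparse's plain `type=bool` would treat any non-empty string, including `"False"`, as truthy. A minimal, self-contained sketch of the same parsing:

```python
import argparse

# String-to-bool conversion as used for the --voca flag in test.py.
# argparse's type=bool would make "False" truthy, so the string itself
# must be compared against 'true'.
def str2bool(x):
    return str(x).lower() == 'true'

parser = argparse.ArgumentParser()
parser.add_argument('--voca', default=False, type=str2bool)

print(parser.parse_args(['--voca', 'False']).voca)  # False
print(parser.parse_args(['--voca', 'True']).voca)   # True
```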
The results are saved as lab files (one "start_time end_time chord_label" segment per line) and midi files.
## Attention Map
The figures show the attention probabilities of self-attention layers 1, 3, 5 and 8, respectively; these layers were chosen because they best illustrate the distinct characteristics of each layer. The input audio is the song "Just A Girl" (0m30s ~ 0m40s) by No Doubt from UsPop2002, which was in the evaluation data.
## Data
We used the Isophonics [1], Robbie Williams [2], and UsPop2002 [3] datasets, which consist of chord label files. Due to copyright issues, these datasets do not include audio files; the audio used in this work was collected from online music service providers.
[1] http://isophonics.net/datasets
[2] B. Di Giorgi, M. Zanoni, A. Sarti, and S. Tubaro. Automatic chord recognition based on the probabilistic modeling of diatonic modal harmony. In Proc. of the 8th International Workshop on Multidimensional Systems, Erlangen, Germany, 2013.
[3] https://github.com/tmc323/Chord-Annotations
## Reference
* PyTorch implementations of the Transformer and CRF: https://github.com/kolloldas/torchnlp
## Comments
* Comments on the code are always welcome.
================================================
FILE: audio_dataset.py
================================================
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader
from utils.preprocess import Preprocess, FeatureTypes
import math
from multiprocessing import Pool
from sortedcontainers import SortedList
class AudioDataset(Dataset):
def __init__(self, config, root_dir='/data/music/chord_recognition', dataset_names=('isophonic',),
featuretype=FeatureTypes.cqt, num_workers=20, train=False, preprocessing=False, resize=None, kfold=4):
super(AudioDataset, self).__init__()
self.config = config
self.root_dir = root_dir
self.dataset_names = dataset_names
self.preprocessor = Preprocess(config, featuretype, dataset_names, self.root_dir)
self.resize = resize
self.train = train
self.ratio = config.experiment['data_ratio']
# preprocessing hyperparameters
# song_hz, n_bins, bins_per_octave, hop_length
mp3_config = config.mp3
feature_config = config.feature
self.mp3_string = "%d_%.1f_%.1f" % \
(mp3_config['song_hz'], mp3_config['inst_len'],
mp3_config['skip_interval'])
self.feature_string = "%s_%d_%d_%d" % \
(featuretype.value, feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length'])
        if feature_config['large_voca']:
# store paths if exists
            is_preprocessed = os.path.exists(os.path.join(root_dir, 'result', dataset_names[0]+'_voca', self.mp3_string, self.feature_string))
            if (not is_preprocessed) or preprocessing:
midi_paths = self.preprocessor.get_all_files()
if num_workers > 1:
num_path_per_process = math.ceil(len(midi_paths) / num_workers)
args = [midi_paths[i * num_path_per_process:(i + 1) * num_path_per_process] for i in range(num_workers)]
# start process
p = Pool(processes=num_workers)
p.map(self.preprocessor.generate_labels_features_voca, args)
p.close()
else:
self.preprocessor.generate_labels_features_voca(midi_paths)
# kfold is 5 fold index ( 0, 1, 2, 3, 4 )
self.song_names, self.paths = self.get_paths_voca(kfold=kfold)
else:
# store paths if exists
            is_preprocessed = os.path.exists(os.path.join(root_dir, 'result', dataset_names[0], self.mp3_string, self.feature_string))
            if (not is_preprocessed) or preprocessing:
midi_paths = self.preprocessor.get_all_files()
if num_workers > 1:
num_path_per_process = math.ceil(len(midi_paths) / num_workers)
args = [midi_paths[i * num_path_per_process:(i + 1) * num_path_per_process]
for i in range(num_workers)]
# start process
p = Pool(processes=num_workers)
p.map(self.preprocessor.generate_labels_features_new, args)
p.close()
else:
self.preprocessor.generate_labels_features_new(midi_paths)
# kfold is 5 fold index ( 0, 1, 2, 3, 4 )
self.song_names, self.paths = self.get_paths(kfold=kfold)
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
instance_path = self.paths[idx]
res = dict()
data = torch.load(instance_path)
res['feature'] = np.log(np.abs(data['feature']) + 1e-6)
res['chord'] = data['chord']
return res
def get_paths(self, kfold=4):
temp = {}
used_song_names = list()
for name in self.dataset_names:
dataset_path = os.path.join(self.root_dir, "result", name, self.mp3_string, self.feature_string)
song_names = os.listdir(dataset_path)
for song_name in song_names:
paths = []
instance_names = os.listdir(os.path.join(dataset_path, song_name))
if len(instance_names) > 0:
used_song_names.append(song_name)
for instance_name in instance_names:
paths.append(os.path.join(dataset_path, song_name, instance_name))
temp[song_name] = paths
# throw away unused song names
song_names = used_song_names
song_names = SortedList(song_names)
print('Total used song length : %d' %len(song_names))
tmp = []
for i in range(len(song_names)):
tmp += temp[song_names[i]]
print('Total instances (train and valid) : %d' %len(tmp))
# divide train/valid dataset using k fold
result = []
total_fold = 5
quotient = len(song_names) // total_fold
remainder = len(song_names) % total_fold
fold_num = [0]
for i in range(total_fold):
fold_num.append(quotient)
for i in range(remainder):
fold_num[i+1] += 1
for i in range(total_fold):
fold_num[i+1] += fold_num[i]
if self.train:
tmp = []
# get not augmented data
for k in range(total_fold):
if k != kfold:
for i in range(fold_num[k], fold_num[k+1]):
result += temp[song_names[i]]
tmp += song_names[fold_num[k]:fold_num[k + 1]]
song_names = tmp
else:
for i in range(fold_num[kfold], fold_num[kfold+1]):
instances = temp[song_names[i]]
instances = [inst for inst in instances if "1.00_0" in inst]
result += instances
song_names = song_names[fold_num[kfold]:fold_num[kfold+1]]
return song_names, result
def get_paths_voca(self, kfold=4):
temp = {}
used_song_names = list()
for name in self.dataset_names:
dataset_path = os.path.join(self.root_dir, "result", name+'_voca', self.mp3_string, self.feature_string)
song_names = os.listdir(dataset_path)
for song_name in song_names:
paths = []
instance_names = os.listdir(os.path.join(dataset_path, song_name))
if len(instance_names) > 0:
used_song_names.append(song_name)
for instance_name in instance_names:
paths.append(os.path.join(dataset_path, song_name, instance_name))
temp[song_name] = paths
# throw away unused song names
song_names = used_song_names
song_names = SortedList(song_names)
print('Total used song length : %d' %len(song_names))
tmp = []
for i in range(len(song_names)):
tmp += temp[song_names[i]]
print('Total instances (train and valid) : %d' %len(tmp))
# divide train/valid dataset using k fold
result = []
total_fold = 5
quotient = len(song_names) // total_fold
remainder = len(song_names) % total_fold
fold_num = [0]
for i in range(total_fold):
fold_num.append(quotient)
for i in range(remainder):
fold_num[i+1] += 1
for i in range(total_fold):
fold_num[i+1] += fold_num[i]
if self.train:
tmp = []
# get not augmented data
for k in range(total_fold):
if k != kfold:
for i in range(fold_num[k], fold_num[k+1]):
result += temp[song_names[i]]
tmp += song_names[fold_num[k]:fold_num[k + 1]]
song_names = tmp
else:
for i in range(fold_num[kfold], fold_num[kfold+1]):
instances = temp[song_names[i]]
instances = [inst for inst in instances if "1.00_0" in inst]
result += instances
song_names = song_names[fold_num[kfold]:fold_num[kfold+1]]
return song_names, result
def _collate_fn(batch):
batch_size = len(batch)
max_len = batch[0]['feature'].shape[1]
input_percentages = torch.empty(batch_size) # for variable length
chord_lens = torch.empty(batch_size, dtype=torch.int64)
chords = []
collapsed_chords = []
features = []
boundaries = []
for i in range(batch_size):
sample = batch[i]
feature = sample['feature']
chord = sample['chord']
        diff = np.diff(chord, axis=0).astype(bool)  # np.bool was removed in NumPy 1.24; use the builtin bool
idx = np.insert(diff, 0, True, axis=0)
chord_lens[i] = np.sum(idx).item(0)
chords.extend(chord)
features.append(feature)
input_percentages[i] = feature.shape[1] / max_len
collapsed_chords.extend(np.array(chord)[idx].tolist())
boundary = np.append([0], diff)
boundaries.extend(boundary.tolist())
    features = torch.tensor(np.array(features), dtype=torch.float32).unsqueeze(1)  # batch_size*1*feature_size*max_len; stack via np.array to avoid the slow list-of-arrays path
chords = torch.tensor(chords, dtype=torch.int64) # (batch_size*time_length)
collapsed_chords = torch.tensor(collapsed_chords, dtype=torch.int64) # total_unique_chord_len
boundaries = torch.tensor(boundaries, dtype=torch.uint8) # (batch_size*time_length)
return features, input_percentages, chords, collapsed_chords, chord_lens, boundaries
class AudioDataLoader(DataLoader):
def __init__(self, *args, **kwargs):
super(AudioDataLoader, self).__init__(*args, **kwargs)
self.collate_fn = _collate_fn
================================================
FILE: baseline_models.py
================================================
from utils.hparams import HParams
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from crf_model import CRF
use_cuda = torch.cuda.is_available()
class CNN(nn.Module):
def __init__(self,config):
super(CNN, self).__init__()
self.timestep = config['timestep']
self.context = 7
self.pad = nn.ConstantPad1d(self.context, 0)
self.probs_out = config['probs_out']
self.num_chords = config['num_chords']
self.drop_out = nn.Dropout2d(p=0.5)
self.conv1 = self.cnn_layers(1, 32, kernel_size=(3,3), padding=1)
self.conv2 = self.cnn_layers(32, 32, kernel_size=(3,3), padding=1)
self.conv3 = self.cnn_layers(32, 32, kernel_size=(3,3), padding=1)
self.conv4 = self.cnn_layers(32, 32, kernel_size=(3,3), padding=1)
self.pool_max = nn.MaxPool2d(kernel_size=(2,1))
self.conv5 = self.cnn_layers(32, 64, kernel_size=(3, 3), padding=0)
self.conv6 = self.cnn_layers(64, 64, kernel_size=(3, 3), padding=0)
self.conv7 = self.cnn_layers(64, 128, kernel_size=(12, 9), padding=0)
self.conv_linear = nn.Conv2d(128, config['num_chords'], kernel_size=(1,1), padding=0)
def cnn_layers(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
layers = []
conv2d = nn.Conv2d(in_channels, out_channels,kernel_size=kernel_size, stride=stride, padding=padding)
batch_norm = nn.BatchNorm2d(out_channels)
relu = nn.ReLU(inplace=True)
layers += [conv2d, batch_norm, relu]
return nn.Sequential(*layers)
def forward(self, x, labels):
x = x.permute(0,2,1)
x = self.pad(x)
batch_size = x.size(0)
for i in range(batch_size):
for j in range(self.timestep):
if i == 0 and j == 0:
inputs = x[i,:,j : j + self.context *2 + 1].unsqueeze(0)
else:
tmp = x[i, :, j : j + self.context *2 + 1].unsqueeze(0)
inputs = torch.cat((inputs,tmp), dim=0)
# inputs : [batchsize * timestep, feature_size, context]
inputs = inputs.unsqueeze(1)
conv = self.conv1(inputs)
conv = self.conv2(conv)
conv = self.conv3(conv)
conv = self.conv4(conv)
pooled = self.pool_max(conv)
pooled = self.drop_out(pooled)
conv = self.conv5(pooled)
conv = self.conv6(conv)
pooled = self.pool_max(conv)
pooled = self.drop_out(pooled)
conv = self.conv7(pooled)
conv = self.drop_out(conv)
conv = self.conv_linear(conv)
avg_pool = nn.AvgPool2d(kernel_size=(conv.size(2), conv.size(3)))
logits = avg_pool(conv).squeeze(2).squeeze(2)
if self.probs_out is True:
crf_input = logits.view(-1, self.timestep, self.num_chords)
return crf_input
log_probs = F.log_softmax(logits, -1)
topk, indices = torch.topk(log_probs, 2)
predictions = indices[:,0]
second = indices[:,1]
prediction = predictions.view(-1)
second = second.view(-1)
loss = F.nll_loss(log_probs.view(-1, self.num_chords), labels.view(-1))
return prediction, loss, 0, second
class Crf(nn.Module):
def __init__(self, num_chords, timestep):
super(Crf, self).__init__()
self.output_size = num_chords
self.timestep = timestep
self.Crf = CRF(self.output_size)
def forward(self, probs, labels):
prediction = self.Crf(probs)
prediction = prediction.view(-1)
labels = labels.view(-1, self.timestep)
loss = self.Crf.loss(probs, labels)
return prediction, loss
class CRNN(nn.Module):
def __init__(self,config):
super(CRNN, self).__init__()
self.feature_size = config['feature_size']
self.timestep = config['timestep']
self.probs_out = config['probs_out']
self.num_chords = config['num_chords']
self.hidden_size = 128
self.relu = nn.ReLU(inplace=True)
self.batch_norm = nn.BatchNorm2d(1)
self.conv1 = nn.Conv2d(1, 1, kernel_size=(5,5), padding=2)
self.conv2 = nn.Conv2d(1, 36, kernel_size=(1,self.feature_size))
self.gru = nn.GRU(input_size=36, hidden_size=self.hidden_size, num_layers=2, batch_first=True, bidirectional=True)
self.fc = nn.Linear(self.hidden_size*2, self.num_chords)
def forward(self, x, labels):
# x : [batchsize * timestep * feature_size]
x = x.unsqueeze(1)
x = self.batch_norm(x)
conv = self.relu(self.conv1(x))
conv = self.relu(self.conv2(conv))
conv = conv.squeeze(3).permute(0,2,1)
h0 = torch.zeros(4, conv.size(0), self.hidden_size).to(torch.device("cuda" if use_cuda else "cpu"))
gru, h = self.gru(conv, h0)
logits = self.fc(gru)
if self.probs_out is True:
# probs = F.softmax(logits, -1)
return logits
log_probs = F.log_softmax(logits, -1)
topk, indices = torch.topk(log_probs, 2)
predictions = indices[:,:,0]
second = indices[:,:,1]
prediction = predictions.view(-1)
second = second.view(-1)
loss = F.nll_loss(log_probs.view(-1, self.num_chords), labels.view(-1))
return prediction, loss, 0, second
if __name__ == "__main__":
config = HParams.load("run_config.yaml")
device = torch.device("cuda" if use_cuda else "cpu")
config.model['probs_out'] = True
batch_size = 2
timestep = config.model['timestep']
feature_size = config.model['feature_size']
num_chords = config.model['num_chords']
features = torch.randn(batch_size,timestep,feature_size,requires_grad=True).to(device)
chords = torch.randint(num_chords,(batch_size*timestep,)).to(device)
model = CNN(config=config.model).to(device)
crf = Crf(num_chords=config.model['num_chords'], timestep=config.model['timestep']).to(device)
probs = model(features, chords)
prediction, total_loss = crf(probs, chords)
print(total_loss)
================================================
FILE: btc_model.py
================================================
from utils.transformer_modules import *
from utils.transformer_modules import _gen_timing_signal, _gen_bias_mask
from utils.hparams import HParams
use_cuda = torch.cuda.is_available()
class self_attention_block(nn.Module):
def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads,
bias_mask=None, layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0, attention_map=False):
super(self_attention_block, self).__init__()
self.attention_map = attention_map
self.multi_head_attention = MultiHeadAttention(hidden_size, total_key_depth, total_value_depth,hidden_size, num_heads, bias_mask, attention_dropout, attention_map)
self.positionwise_convolution = PositionwiseFeedForward(hidden_size, filter_size, hidden_size, layer_config='cc', padding='both', dropout=relu_dropout)
self.dropout = nn.Dropout(layer_dropout)
self.layer_norm_mha = LayerNorm(hidden_size)
self.layer_norm_ffn = LayerNorm(hidden_size)
def forward(self, inputs):
x = inputs
# Layer Normalization
x_norm = self.layer_norm_mha(x)
# Multi-head attention
if self.attention_map is True:
y, weights = self.multi_head_attention(x_norm, x_norm, x_norm)
else:
y = self.multi_head_attention(x_norm, x_norm, x_norm)
# Dropout and residual
x = self.dropout(x + y)
# Layer Normalization
x_norm = self.layer_norm_ffn(x)
# Positionwise Feedforward
y = self.positionwise_convolution(x_norm)
# Dropout and residual
y = self.dropout(x + y)
if self.attention_map is True:
return y, weights
return y
class bi_directional_self_attention(nn.Module):
def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads, max_length,
layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0):
super(bi_directional_self_attention, self).__init__()
self.weights_list = list()
params = (hidden_size,
total_key_depth or hidden_size,
total_value_depth or hidden_size,
filter_size,
num_heads,
_gen_bias_mask(max_length),
layer_dropout,
attention_dropout,
relu_dropout,
True)
self.attn_block = self_attention_block(*params)
params = (hidden_size,
total_key_depth or hidden_size,
total_value_depth or hidden_size,
filter_size,
num_heads,
torch.transpose(_gen_bias_mask(max_length), dim0=2, dim1=3),
layer_dropout,
attention_dropout,
relu_dropout,
True)
self.backward_attn_block = self_attention_block(*params)
self.linear = nn.Linear(hidden_size*2, hidden_size)
def forward(self, inputs):
        x, weights_list = inputs
        # Forward Self-attention Block
        encoder_outputs, weights = self.attn_block(x)
        # Backward Self-attention Block
        reverse_outputs, reverse_weights = self.backward_attn_block(x)
        # Concatenation and Fully-connected Layer
        outputs = torch.cat((encoder_outputs, reverse_outputs), dim=2)
        y = self.linear(outputs)
        # Attention weights for visualization (renamed to avoid shadowing the builtin `list`)
        self.weights_list = weights_list
self.weights_list.append(weights)
self.weights_list.append(reverse_weights)
return y, self.weights_list
class bi_directional_self_attention_layers(nn.Module):
def __init__(self, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth,
filter_size, max_length=100, input_dropout=0.0, layer_dropout=0.0,
attention_dropout=0.0, relu_dropout=0.0):
super(bi_directional_self_attention_layers, self).__init__()
self.timing_signal = _gen_timing_signal(max_length, hidden_size)
params = (hidden_size,
total_key_depth or hidden_size,
total_value_depth or hidden_size,
filter_size,
num_heads,
max_length,
layer_dropout,
attention_dropout,
relu_dropout)
self.embedding_proj = nn.Linear(embedding_size, hidden_size, bias=False)
self.self_attn_layers = nn.Sequential(*[bi_directional_self_attention(*params) for l in range(num_layers)])
self.layer_norm = LayerNorm(hidden_size)
self.input_dropout = nn.Dropout(input_dropout)
def forward(self, inputs):
# Add input dropout
x = self.input_dropout(inputs)
# Project to hidden size
x = self.embedding_proj(x)
# Add timing signal
x += self.timing_signal[:, :inputs.shape[1], :].type_as(inputs.data)
# A Stack of Bi-directional Self-attention Layers
y, weights_list = self.self_attn_layers((x, []))
# Layer Normalization
y = self.layer_norm(y)
return y, weights_list
class BTC_model(nn.Module):
def __init__(self, config):
super(BTC_model, self).__init__()
self.timestep = config['timestep']
self.probs_out = config['probs_out']
params = (config['feature_size'],
config['hidden_size'],
config['num_layers'],
config['num_heads'],
config['total_key_depth'],
config['total_value_depth'],
config['filter_size'],
config['timestep'],
config['input_dropout'],
config['layer_dropout'],
config['attention_dropout'],
config['relu_dropout'])
self.self_attn_layers = bi_directional_self_attention_layers(*params)
self.output_layer = SoftmaxOutputLayer(hidden_size=config['hidden_size'], output_size=config['num_chords'], probs_out=config['probs_out'])
def forward(self, x, labels):
labels = labels.view(-1, self.timestep)
# Output of Bi-directional Self-attention Layers
self_attn_output, weights_list = self.self_attn_layers(x)
# return logit values for CRF
if self.probs_out is True:
logits = self.output_layer(self_attn_output)
return logits
# Output layer and Soft-max
prediction,second = self.output_layer(self_attn_output)
prediction = prediction.view(-1)
second = second.view(-1)
# Loss Calculation
loss = self.output_layer.loss(self_attn_output, labels)
return prediction, loss, weights_list, second
if __name__ == "__main__":
config = HParams.load("run_config.yaml")
device = torch.device("cuda" if use_cuda else "cpu")
batch_size = 2
timestep = 108
feature_size = 144
num_chords = 25
features = torch.randn(batch_size,timestep,feature_size,requires_grad=True).to(device)
chords = torch.randint(25,(batch_size*timestep,)).to(device)
model = BTC_model(config=config.model).to(device)
prediction, loss, weights_list, second = model(features, chords)
print(prediction.size())
print(loss)
================================================
FILE: crf_model.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
class CRF(nn.Module):
"""
Implements Conditional Random Fields that can be trained via
backpropagation.
"""
def __init__(self, num_tags):
super(CRF, self).__init__()
self.num_tags = num_tags
self.transitions = nn.Parameter(torch.Tensor(num_tags, num_tags))
self.start_transitions = nn.Parameter(torch.randn(num_tags))
self.stop_transitions = nn.Parameter(torch.randn(num_tags))
nn.init.xavier_normal_(self.transitions)
def forward(self, feats):
# Shape checks
        if len(feats.shape) != 3:
            raise ValueError("feats must be 3-d but got {}-d".format(len(feats.shape)))
return self._viterbi(feats)
def loss(self, feats, tags):
"""
Computes negative log likelihood between features and tags.
Essentially difference between individual sequence scores and
sum of all possible sequence scores (partition function)
Parameters:
feats: Input features [batch size, sequence length, number of tags]
tags: Target tag indices [batch size, sequence length]. Should be between
0 and num_tags
Returns:
Negative log likelihood [a scalar]
"""
# Shape checks
        if len(feats.shape) != 3:
            raise ValueError("feats must be 3-d but got {}-d".format(len(feats.shape)))
        if len(tags.shape) != 2:
            raise ValueError('tags must be 2-d but got {}-d'.format(len(tags.shape)))
        if feats.shape[:2] != tags.shape:
            raise ValueError('First two dimensions of feats and tags must match')
sequence_score = self._sequence_score(feats, tags)
partition_function = self._partition_function(feats)
log_probability = sequence_score - partition_function
        # Negative log likelihood, averaged across the batch
return -log_probability.mean()
def _sequence_score(self, feats, tags):
"""
Parameters:
feats: Input features [batch size, sequence length, number of tags]
tags: Target tag indices [batch size, sequence length]. Should be between
0 and num_tags
Returns: Sequence score of shape [batch size]
"""
batch_size = feats.shape[0]
# Compute feature scores
feat_score = feats.gather(2, tags.unsqueeze(-1)).squeeze(-1).sum(dim=-1)
# Compute transition scores
# Unfold to get [from, to] tag index pairs
tags_pairs = tags.unfold(1, 2, 1)
# Use advanced indexing to pull out required transition scores
indices = tags_pairs.permute(2, 0, 1).chunk(2)
trans_score = self.transitions[indices].squeeze(0).sum(dim=-1)
# Compute start and stop scores
start_score = self.start_transitions[tags[:, 0]]
stop_score = self.stop_transitions[tags[:, -1]]
return feat_score + start_score + trans_score + stop_score
def _partition_function(self, feats):
"""
        Computes the partition function for the CRF using the forward algorithm,
        i.e. accumulates the scores of all possible tag sequences for
        the given feature vector sequence
Parameters:
feats: Input features [batch size, sequence length, number of tags]
Returns:
Total scores of shape [batch size]
"""
_, seq_size, num_tags = feats.shape
if self.num_tags != num_tags:
raise ValueError('num_tags should be {} but got {}'.format(self.num_tags, num_tags))
a = feats[:, 0] + self.start_transitions.unsqueeze(0) # [batch_size, num_tags]
transitions = self.transitions.unsqueeze(0) # [1, num_tags, num_tags] from -> to
for i in range(1, seq_size):
feat = feats[:, i].unsqueeze(1) # [batch_size, 1, num_tags]
a = self._log_sum_exp(a.unsqueeze(-1) + transitions + feat, 1) # [batch_size, num_tags]
return self._log_sum_exp(a + self.stop_transitions.unsqueeze(0), 1) # [batch_size]
def _viterbi(self, feats):
"""
Uses Viterbi algorithm to predict the best sequence
Parameters:
feats: Input features [batch size, sequence length, number of tags]
Returns: Best tag sequence [batch size, sequence length]
"""
_, seq_size, num_tags = feats.shape
if self.num_tags != num_tags:
raise ValueError('num_tags should be {} but got {}'.format(self.num_tags, num_tags))
v = feats[:, 0] + self.start_transitions.unsqueeze(0) # [batch_size, num_tags]
transitions = self.transitions.unsqueeze(0) # [1, num_tags, num_tags] from -> to
paths = []
for i in range(1, seq_size):
feat = feats[:, i] # [batch_size, num_tags]
v, idx = (v.unsqueeze(-1) + transitions).max(1) # [batch_size, num_tags], [batch_size, num_tags]
paths.append(idx)
v = (v + feat) # [batch_size, num_tags]
v, tag = (v + self.stop_transitions.unsqueeze(0)).max(1, True)
# Backtrack
tags = [tag]
for idx in reversed(paths):
tag = idx.gather(1, tag)
tags.append(tag)
tags.reverse()
return torch.cat(tags, 1)
def _log_sum_exp(self, logits, dim):
"""
Computes log-sum-exp in a stable way
"""
max_val, _ = logits.max(dim)
return max_val + (logits - max_val.unsqueeze(dim)).exp().sum(dim).log()
================================================
FILE: run_config.yaml
================================================
mp3:
song_hz: 22050
inst_len: 10.0
skip_interval: 5.0
feature:
n_bins: 144
bins_per_octave: 24
hop_length: 2048
large_voca: False
# large_voca: True
experiment:
learning_rate : 0.0001
weight_decay : 0.0
max_epoch : 100
batch_size : 128
save_step : 40
data_ratio : 0.8
model:
feature_size : 144
timestep : 108
num_chords : 25
# num_chords : 170
input_dropout : 0.2
layer_dropout : 0.2
attention_dropout : 0.2
relu_dropout : 0.2
num_layers : 8
num_heads : 4
hidden_size : 128
total_key_depth : 128
total_value_depth : 128
filter_size : 128
loss : 'ce'
probs_out : False
path:
ckpt_path : 'model'
result_path : 'result'
asset_path : '/data/music/chord_recognition/jayg996/assets'
root_path : '/data/music/chord_recognition'
================================================
FILE: test/btc_model.pt
================================================
[File too large to display: 11.6 MB]
================================================
FILE: test/btc_model_large_voca.pt
================================================
[File too large to display: 11.7 MB]
================================================
FILE: test.py
================================================
import os
import mir_eval
import pretty_midi as pm
from utils import logger
from btc_model import *
from utils.mir_eval_modules import audio_file_to_features, idx2chord, idx2voca_chord, get_audio_paths
import argparse
import warnings
warnings.filterwarnings('ignore')
logger.logging_verbosity(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# hyperparameters
parser = argparse.ArgumentParser()
parser.add_argument('--voca', default=True, type=lambda x: (str(x).lower() == 'true'))
parser.add_argument('--audio_dir', type=str, default='./test')
parser.add_argument('--save_dir', type=str, default='./test')
args = parser.parse_args()
config = HParams.load("run_config.yaml")
if args.voca is True:
config.feature['large_voca'] = True
config.model['num_chords'] = 170
model_file = './test/btc_model_large_voca.pt'
idx_to_chord = idx2voca_chord()
logger.info("label type: large voca")
else:
model_file = './test/btc_model.pt'
idx_to_chord = idx2chord
logger.info("label type: Major and minor")
model = BTC_model(config=config.model).to(device)
# Load model
if os.path.isfile(model_file):
    checkpoint = torch.load(model_file)
    mean = checkpoint['mean']
    std = checkpoint['std']
    model.load_state_dict(checkpoint['model'])
    logger.info("restore model")
else:
    # mean, std and the model weights are required below; fail early if the checkpoint is missing
    raise FileNotFoundError("model file not found: %s" % model_file)
# Audio files with format of wav and mp3
audio_paths = get_audio_paths(args.audio_dir)
# Chord recognition and save lab file
for i, audio_path in enumerate(audio_paths):
logger.info("======== %d of %d in progress ========" % (i + 1, len(audio_paths)))
# Load mp3
feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, config)
logger.info("audio file loaded and feature computation success : %s" % audio_path)
# Majmin type chord recognition
feature = feature.T
feature = (feature - mean) / std
time_unit = feature_per_second
n_timestep = config.model['timestep']
num_pad = n_timestep - (feature.shape[0] % n_timestep)
feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
num_instance = feature.shape[0] // n_timestep
start_time = 0.0
lines = []
with torch.no_grad():
model.eval()
feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
for t in range(num_instance):
self_attn_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
prediction, _ = model.output_layer(self_attn_output)
prediction = prediction.squeeze()
            for j in range(n_timestep):  # j avoids shadowing the outer file-loop index i
                if t == 0 and j == 0:
                    prev_chord = prediction[j].item()
                    continue
                if prediction[j].item() != prev_chord:
                    lines.append(
                        '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + j), idx_to_chord[prev_chord]))
                    start_time = time_unit * (n_timestep * t + j)
                    prev_chord = prediction[j].item()
                if t == num_instance - 1 and j + num_pad == n_timestep:
                    if start_time != time_unit * (n_timestep * t + j):
                        lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + j), idx_to_chord[prev_chord]))
                    break
# lab file write
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
save_path = os.path.join(args.save_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab')
with open(save_path, 'w') as f:
for line in lines:
f.write(line)
logger.info("label file saved : %s" % save_path)
# lab file to midi file
starts, ends, pitchs = list(), list(), list()
intervals, chords = mir_eval.io.load_labeled_intervals(save_path)
for p in range(12):
for i, (interval, chord) in enumerate(zip(intervals, chords)):
root_num, relative_bitmap, _ = mir_eval.chord.encode(chord)
tmp_label = mir_eval.chord.rotate_bitmap_to_root(relative_bitmap, root_num)[p]
if i == 0:
start_time = interval[0]
label = tmp_label
continue
if tmp_label != label:
if label == 1.0:
starts.append(start_time), ends.append(interval[0]), pitchs.append(p + 48)
start_time = interval[0]
label = tmp_label
if i == (len(intervals) - 1):
if label == 1.0:
starts.append(start_time), ends.append(interval[1]), pitchs.append(p + 48)
midi = pm.PrettyMIDI()
instrument = pm.Instrument(program=0)
for start, end, pitch in zip(starts, ends, pitchs):
pm_note = pm.Note(velocity=120, pitch=pitch, start=start, end=end)
instrument.notes.append(pm_note)
midi.instruments.append(instrument)
midi.write(save_path.replace('.lab', '.midi'))
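The inference loop above run-length encodes per-frame chord indices into '.lab' segments. A minimal standalone sketch of that conversion (`frames_to_lab_lines` is a hypothetical helper, not part of the repo; `idx_to_chord` is any index-to-label mapping):

```python
def frames_to_lab_lines(frame_indices, time_unit, idx_to_chord):
    """Merge consecutive equal frame predictions into 'start end label' lines,
    mirroring the segment-writing logic of test.py (sketch, not repo code)."""
    lines = []
    if not frame_indices:
        return lines
    start_time = 0.0
    prev = frame_indices[0]
    for i in range(1, len(frame_indices)):
        if frame_indices[i] != prev:
            lines.append('%.3f %.3f %s' % (start_time, time_unit * i, idx_to_chord[prev]))
            start_time = time_unit * i
            prev = frame_indices[i]
    # close the last open segment at the end of the sequence
    lines.append('%.3f %.3f %s' % (start_time, time_unit * len(frame_indices), idx_to_chord[prev]))
    return lines
```

Unlike the in-loop version, this operates on one full prediction sequence at a time; the script instead streams over `num_instance` windows of `n_timestep` frames.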
================================================
FILE: train.py
================================================
import os
from torch import optim
from utils import logger
from audio_dataset import AudioDataset, AudioDataLoader
from utils.tf_logger import TF_Logger
from btc_model import *
from baseline_models import CNN, CRNN
from utils.hparams import HParams
import argparse
from utils.pytorch_utils import adjusting_learning_rate
from utils.mir_eval_modules import root_majmin_score_calculation, large_voca_score_calculation
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
logger.logging_verbosity(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
parser = argparse.ArgumentParser()
parser.add_argument('--index', type=int, help='Experiment Number', default=0)
parser.add_argument('--kfold', type=int, help='5-fold cross validation index (0-4)', default=0)
parser.add_argument('--voca', action='store_true', help='use the large chord vocabulary (170 classes)')
parser.add_argument('--model', type=str, help='btc, cnn, crnn', default='btc')
parser.add_argument('--dataset1', type=str, help='Dataset', default='isophonic')
parser.add_argument('--dataset2', type=str, help='Dataset', default='uspop')
parser.add_argument('--dataset3', type=str, help='Dataset', default='robbiewilliams')
parser.add_argument('--restore_epoch', type=int, default=1000)
parser.add_argument('--early_stop', type=lambda s: s.lower() != 'false', help='stop after 10 epochs without improvement', default=True)
args = parser.parse_args()
config = HParams.load("run_config.yaml")
if args.voca:
config.feature['large_voca'] = True
config.model['num_chords'] = 170
# Result save path
asset_path = config.path['asset_path']
ckpt_path = config.path['ckpt_path']
result_path = config.path['result_path']
restore_epoch = args.restore_epoch
experiment_num = str(args.index)
ckpt_file_name = 'idx_'+experiment_num+'_%03d.pth.tar'
tf_logger = TF_Logger(os.path.join(asset_path, 'tensorboard', 'idx_'+experiment_num))
logger.info("==== Experiment Number : %d " % args.index)
if args.model == 'cnn':
config.experiment['batch_size'] = 10
# Data loader
train_dataset1 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset1,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold)
train_dataset2 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset2,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold)
train_dataset3 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset3,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold)
train_dataset = train_dataset1 + train_dataset2 + train_dataset3
valid_dataset1 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset1,), preprocessing=False, train=False, kfold=args.kfold)
valid_dataset2 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset2,), preprocessing=False, train=False, kfold=args.kfold)
valid_dataset3 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset3,), preprocessing=False, train=False, kfold=args.kfold)
valid_dataset = valid_dataset1 + valid_dataset2 + valid_dataset3
train_dataloader = AudioDataLoader(dataset=train_dataset, batch_size=config.experiment['batch_size'], drop_last=False, shuffle=True)
valid_dataloader = AudioDataLoader(dataset=valid_dataset, batch_size=config.experiment['batch_size'], drop_last=False)
# Model and Optimizer
if args.model == 'cnn':
model = CNN(config=config.model).to(device)
elif args.model == 'crnn':
model = CRNN(config=config.model).to(device)
elif args.model == 'btc':
model = BTC_model(config=config.model).to(device)
else: raise NotImplementedError
optimizer = optim.Adam(model.parameters(), lr=config.experiment['learning_rate'], weight_decay=config.experiment['weight_decay'], betas=(0.9, 0.98), eps=1e-9)
# Make asset directory
os.makedirs(os.path.join(asset_path, ckpt_path), exist_ok=True)
os.makedirs(os.path.join(asset_path, result_path), exist_ok=True)
# Load model
if os.path.isfile(os.path.join(asset_path, ckpt_path, ckpt_file_name % restore_epoch)):
checkpoint = torch.load(os.path.join(asset_path, ckpt_path, ckpt_file_name % restore_epoch))
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
logger.info("restore model with %d epochs" % restore_epoch)
else:
logger.info("no checkpoint with %d epochs" % restore_epoch)
restore_epoch = 0
# Global mean and variance calculate
mp3_config = config.mp3
feature_config = config.feature
mp3_string = "%d_%.1f_%.1f" % (mp3_config['song_hz'], mp3_config['inst_len'], mp3_config['skip_interval'])
feature_string = "_%s_%d_%d_%d_" % ('cqt', feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length'])
z_path = os.path.join(config.path['root_path'], 'result', mp3_string + feature_string + 'mix_kfold_'+ str(args.kfold) +'_normalization.pt')
if os.path.exists(z_path):
normalization = torch.load(z_path)
mean = normalization['mean']
std = normalization['std']
logger.info("Global mean and std (k fold index %d) load complete" % args.kfold)
else:
mean = 0
square_mean = 0
k = 0
for i, data in enumerate(train_dataloader):
features, input_percentages, chords, collapsed_chords, chord_lens, boundaries = data
features = features.to(device)
mean += torch.mean(features).item()
square_mean += torch.mean(features.pow(2)).item()
k += 1
square_mean = square_mean / k
mean = mean / k
std = np.sqrt(square_mean - mean * mean)
normalization = dict()
normalization['mean'] = mean
normalization['std'] = std
torch.save(normalization, z_path)
logger.info("Global mean and std (training set, k fold index %d) calculation complete" % args.kfold)
current_step = 0
best_acc = 0
before_acc = 0
early_stop_idx = 0
last_best_epoch = 0
for epoch in range(restore_epoch, config.experiment['max_epoch']):
# Training
model.train()
train_loss_list = []
total = 0.
correct = 0.
second_correct = 0.
for i, data in enumerate(train_dataloader):
features, input_percentages, chords, collapsed_chords, chord_lens, boundaries = data
features, chords = features.to(device), chords.to(device)
features.requires_grad = True
features = (features - mean) / std
# forward
features = features.squeeze(1).permute(0,2,1)
optimizer.zero_grad()
prediction, total_loss, weights, second = model(features, chords)
# save accuracy and loss
total += chords.size(0)
correct += (prediction == chords).type_as(chords).sum()
second_correct += (second == chords).type_as(chords).sum()
train_loss_list.append(total_loss.item())
# optimize step
total_loss.backward()
optimizer.step()
current_step += 1
# logging loss and accuracy using tensorboard
result = {'loss/tr': np.mean(train_loss_list), 'acc/tr': correct.item() / total, 'top2/tr': (correct.item()+second_correct.item()) / total}
for tag, value in result.items(): tf_logger.scalar_summary(tag, value, epoch+1)
logger.info("training loss for %d epoch: %.4f" % (epoch + 1, np.mean(train_loss_list)))
logger.info("training accuracy for %d epoch: %.4f" % (epoch + 1, (correct.item() / total)))
logger.info("training top2 accuracy for %d epoch: %.4f" % (epoch + 1, ((correct.item() + second_correct.item()) / total)))
# Validation
with torch.no_grad():
model.eval()
val_total = 0.
val_correct = 0.
val_second_correct = 0.
validation_loss = 0
n = 0
for i, data in enumerate(valid_dataloader):
val_features, val_input_percentages, val_chords, val_collapsed_chords, val_chord_lens, val_boundaries = data
val_features, val_chords = val_features.to(device), val_chords.to(device)
val_features = (val_features - mean) / std
val_features = val_features.squeeze(1).permute(0, 2, 1)
val_prediction, val_loss, weights, val_second = model(val_features, val_chords)
val_total += val_chords.size(0)
val_correct += (val_prediction == val_chords).type_as(val_chords).sum()
val_second_correct += (val_second == val_chords).type_as(val_chords).sum()
validation_loss += val_loss.item()
n += 1
# logging loss and accuracy using tensorboard
validation_loss /= n
result = {'loss/val': validation_loss, 'acc/val': val_correct.item() / val_total, 'top2/val': (val_correct.item()+val_second_correct.item()) / val_total}
for tag, value in result.items(): tf_logger.scalar_summary(tag, value, epoch + 1)
logger.info("validation loss(%d): %.4f" % (epoch + 1, validation_loss))
logger.info("validation accuracy(%d): %.4f" % (epoch + 1, (val_correct.item() / val_total)))
logger.info("validation top2 accuracy(%d): %.4f" % (epoch + 1, ((val_correct.item() + val_second_correct.item()) / val_total)))
current_acc = val_correct.item() / val_total
if best_acc < current_acc:
early_stop_idx = 0
best_acc = current_acc
logger.info('==== best accuracy is %.4f and epoch is %d' % (best_acc, epoch + 1))
logger.info('saving model, Epoch %d, step %d' % (epoch + 1, current_step + 1))
model_save_path = os.path.join(asset_path, ckpt_path, ckpt_file_name % (epoch + 1))
state_dict = {'model': model.state_dict(),'optimizer': optimizer.state_dict(),'epoch': epoch}
torch.save(state_dict, model_save_path)
last_best_epoch = epoch + 1
# save model
elif (epoch + 1) % config.experiment['save_step'] == 0:
logger.info('saving model, Epoch %d, step %d' % (epoch + 1, current_step + 1))
model_save_path = os.path.join(asset_path, ckpt_path, ckpt_file_name % (epoch + 1))
state_dict = {'model': model.state_dict(),'optimizer': optimizer.state_dict(),'epoch': epoch}
torch.save(state_dict, model_save_path)
early_stop_idx += 1
else:
early_stop_idx += 1
if args.early_stop and early_stop_idx > 9:
logger.info('==== early stopped and epoch is %d' % (epoch + 1))
break
# learning rate decay
if before_acc > current_acc:
adjusting_learning_rate(optimizer=optimizer, factor=0.95, min_lr=5e-6)
before_acc = current_acc
# Load model
if os.path.isfile(os.path.join(asset_path, ckpt_path, ckpt_file_name % last_best_epoch)):
checkpoint = torch.load(os.path.join(asset_path, ckpt_path, ckpt_file_name % last_best_epoch))
model.load_state_dict(checkpoint['model'])
logger.info("restore model with %d epochs" % last_best_epoch)
else:
raise NotImplementedError
# score Validation
if args.voca:
score_metrics = ['root', 'thirds', 'triads', 'sevenths', 'tetrads', 'majmin', 'mirex']
score_list_dict1, song_length_list1, average_score_dict1 = large_voca_score_calculation(valid_dataset=valid_dataset1, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device)
score_list_dict2, song_length_list2, average_score_dict2 = large_voca_score_calculation(valid_dataset=valid_dataset2, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device)
score_list_dict3, song_length_list3, average_score_dict3 = large_voca_score_calculation(valid_dataset=valid_dataset3, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device)
for m in score_metrics:
average_score = (np.sum(song_length_list1) * average_score_dict1[m] + np.sum(song_length_list2) *average_score_dict2[m] + np.sum(song_length_list3) * average_score_dict3[m]) / (np.sum(song_length_list1) + np.sum(song_length_list2) + np.sum(song_length_list3))
logger.info('==== %s score 1 is %.4f' % (m, average_score_dict1[m]))
logger.info('==== %s score 2 is %.4f' % (m, average_score_dict2[m]))
logger.info('==== %s score 3 is %.4f' % (m, average_score_dict3[m]))
logger.info('==== %s mix average score is %.4f' % (m, average_score))
else:
score_metrics = ['root', 'majmin']
score_list_dict1, song_length_list1, average_score_dict1 = root_majmin_score_calculation(valid_dataset=valid_dataset1, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device)
score_list_dict2, song_length_list2, average_score_dict2 = root_majmin_score_calculation(valid_dataset=valid_dataset2, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device)
score_list_dict3, song_length_list3, average_score_dict3 = root_majmin_score_calculation(valid_dataset=valid_dataset3, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device)
for m in score_metrics:
average_score = (np.sum(song_length_list1) * average_score_dict1[m] + np.sum(song_length_list2) *average_score_dict2[m] + np.sum(song_length_list3) * average_score_dict3[m]) / (np.sum(song_length_list1) + np.sum(song_length_list2) + np.sum(song_length_list3))
logger.info('==== %s score 1 is %.4f' % (m, average_score_dict1[m]))
logger.info('==== %s score 2 is %.4f' % (m, average_score_dict2[m]))
logger.info('==== %s score 3 is %.4f' % (m, average_score_dict3[m]))
logger.info('==== %s mix average score is %.4f' % (m, average_score))
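The normalization block in train.py estimates one global mean/std over the whole training set by averaging per-batch E[x] and E[x²] and applying Var[x] = E[x²] − E[x]². A standalone sketch of that aggregation (`global_mean_std` is a hypothetical helper; like the script, it weights every batch equally, which is exact only when batches have the same size):

```python
import numpy as np

def global_mean_std(batches):
    """Estimate a global mean/std across batches the way train.py does:
    average per-batch first and second moments, then std = sqrt(E[x^2] - E[x]^2)."""
    mean = 0.0
    square_mean = 0.0
    for batch in batches:
        arr = np.asarray(batch, dtype=np.float64)
        mean += arr.mean()
        square_mean += (arr ** 2).mean()
    k = len(batches)
    mean /= k
    square_mean /= k
    # variance via the second-moment identity, as in the training scripts
    std = np.sqrt(square_mean - mean * mean)
    return mean, std
```

The resulting pair is cached to `*_normalization.pt` so later runs (and test.py via the checkpoint) can standardize features without touching the dataset again.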
================================================
FILE: train_crf.py
================================================
import os
from torch import optim
from utils import logger
from audio_dataset import AudioDataset, AudioDataLoader
from utils.tf_logger import TF_Logger
from btc_model import *
from baseline_models import CNN, CRNN, Crf
from utils.hparams import HParams
import argparse
from utils.pytorch_utils import adjusting_learning_rate
from utils.mir_eval_modules import large_voca_score_calculation_crf, root_majmin_score_calculation_crf
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
logger.logging_verbosity(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
parser = argparse.ArgumentParser()
parser.add_argument('--index', type=int, help='Experiment Number', default=0)
parser.add_argument('--kfold', type=int, help='5-fold cross validation index (0-4)', default=0)
parser.add_argument('--voca', action='store_true', help='use the large chord vocabulary (170 classes)')
parser.add_argument('--model', type=str, default='crf')
parser.add_argument('--pre_model', type=str, help='btc, cnn, crnn', default='btc')
parser.add_argument('--dataset1', type=str, help='Dataset', default='isophonic_221')
parser.add_argument('--dataset2', type=str, help='Dataset', default='uspop_185')
parser.add_argument('--dataset3', type=str, help='Dataset', default='robbiewilliams')
parser.add_argument('--restore_epoch', type=int, default=1000)
parser.add_argument('--early_stop', type=lambda s: s.lower() != 'false', help='stop after 6 epochs without improvement', default=True)
args = parser.parse_args()
config = HParams.load("run_config.yaml")
if args.voca:
config.feature['large_voca'] = True
config.model['num_chords'] = 170
config.model['probs_out'] = True
# Result save path
asset_path = config.path['asset_path']
ckpt_path = config.path['ckpt_path']
result_path = config.path['result_path']
restore_epoch = args.restore_epoch
experiment_num = str(args.index)
ckpt_file_name = 'idx_'+experiment_num+'_%03d.pth.tar'
tf_logger = TF_Logger(os.path.join(asset_path, 'tensorboard', 'idx_'+experiment_num))
logger.info("==== Experiment Number : %d " % args.index)
if args.pre_model == 'cnn':
config.experiment['batch_size'] = 20
# Data loader
train_dataset1 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset1,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold)
train_dataset2 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset2,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold)
train_dataset3 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset3,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold)
train_dataset = train_dataset1 + train_dataset2 + train_dataset3
valid_dataset1 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset1,), preprocessing=False, train=False, kfold=args.kfold)
valid_dataset2 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset2,), preprocessing=False, train=False, kfold=args.kfold)
valid_dataset3 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset3,), preprocessing=False, train=False, kfold=args.kfold)
valid_dataset = valid_dataset1 + valid_dataset2 + valid_dataset3
train_dataloader = AudioDataLoader(dataset=train_dataset, batch_size=config.experiment['batch_size'], drop_last=False, shuffle=True)
valid_dataloader = AudioDataLoader(dataset=valid_dataset, batch_size=config.experiment['batch_size'], drop_last=False)
# Model and Optimizer
if args.pre_model == 'cnn':
pre_model = CNN(config=config.model).to(device)
elif args.pre_model == 'crnn':
pre_model = CRNN(config=config.model).to(device)
elif args.pre_model == 'btc':
pre_model = BTC_model(config=config.model).to(device)
else: raise NotImplementedError
if args.pre_model == 'cnn':
if args.voca == False:
if args.kfold == 0:
load_ckpt_file_name = 'idx_0_%03d.pth.tar'
load_restore_epoch = 10
else:
if args.kfold == 0:
load_ckpt_file_name = 'idx_1_%03d.pth.tar'
load_restore_epoch = 10
else:
raise NotImplementedError
if os.path.isfile(os.path.join(asset_path, ckpt_path, load_ckpt_file_name % load_restore_epoch)):
checkpoint = torch.load(os.path.join(asset_path, ckpt_path, load_ckpt_file_name % load_restore_epoch))
pre_model.load_state_dict(checkpoint['model'])
logger.info("restore pre model with %d epochs" % load_restore_epoch)
else:
raise NotImplementedError
# Fix Pre Model Parameters
for param in pre_model.parameters():
param.requires_grad = False
# Crf Model and Optimizer
crf = Crf(num_chords=config.model['num_chords'], timestep=config.model['timestep']).to(device)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, crf.parameters()), lr=0.01, weight_decay=config.experiment['weight_decay'], betas=(0.9, 0.98), eps=1e-9)
# Make asset directory
os.makedirs(os.path.join(asset_path, ckpt_path), exist_ok=True)
os.makedirs(os.path.join(asset_path, result_path), exist_ok=True)
# Load model
if os.path.isfile(os.path.join(asset_path, ckpt_path, ckpt_file_name % restore_epoch)):
checkpoint = torch.load(os.path.join(asset_path, ckpt_path, ckpt_file_name % restore_epoch))
crf.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
logger.info("restore model with %d epochs" % restore_epoch)
else:
logger.info("no checkpoint with %d epochs" % restore_epoch)
restore_epoch = 0
# Global mean and variance calculate
mp3_config = config.mp3
feature_config = config.feature
mp3_string = "%d_%.1f_%.1f" % (mp3_config['song_hz'], mp3_config['inst_len'], mp3_config['skip_interval'])
feature_string = "_%s_%d_%d_%d_" % ('cqt', feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length'])
z_path = os.path.join(config.path['root_path'], 'result', mp3_string + feature_string + 'mix_kfold_'+ str(args.kfold) +'_normalization.pt')
if os.path.exists(z_path):
normalization = torch.load(z_path)
mean = normalization['mean']
std = normalization['std']
logger.info("Global mean and std (k fold index %d) load complete" % args.kfold)
else:
mean = 0
square_mean = 0
k = 0
for i, data in enumerate(train_dataloader):
features, input_percentages, chords, collapsed_chords, chord_lens, boundaries = data
features = features.to(device)
mean += torch.mean(features).item()
square_mean += torch.mean(features.pow(2)).item()
k += 1
square_mean = square_mean / k
mean = mean / k
std = np.sqrt(square_mean - mean * mean)
normalization = dict()
normalization['mean'] = mean
normalization['std'] = std
torch.save(normalization, z_path)
logger.info("Global mean and std (training set, k fold index %d) calculation complete" % args.kfold)
current_step = 0
best_acc = 0
before_acc = 0
early_stop_idx = 0
last_best_epoch = 0
pre_model.eval()
for epoch in range(restore_epoch, config.experiment['max_epoch']):
# Training
crf.train()
train_loss_list = []
total = 0.
correct = 0.
second_correct = 0.
for i, data in enumerate(train_dataloader):
features, input_percentages, chords, collapsed_chords, chord_lens, boundaries = data
features, chords = features.to(device), chords.to(device)
features.requires_grad = True
features = (features - mean) / std
# forward
features = features.squeeze(1).permute(0,2,1)
optimizer.zero_grad()
logits = pre_model(features, chords)
if args.pre_model == 'crnn':
logits = logits.detach()
logits.requires_grad = True
prediction, total_loss = crf(logits, chords)
# save accuracy and loss
total += chords.size(0)
correct += (prediction == chords).type_as(chords).sum()
train_loss_list.append(total_loss.item())
# optimize step
total_loss.backward()
optimizer.step()
current_step += 1
# logging loss and accuracy using tensorboard
result = {'loss/tr': np.mean(train_loss_list), 'acc/tr': correct.item() / total}
for tag, value in result.items(): tf_logger.scalar_summary(tag, value, epoch+1)
logger.info("training loss for %d epoch: %.4f" % (epoch + 1, np.mean(train_loss_list)))
logger.info("training accuracy for %d epoch: %.4f" % (epoch + 1, (correct.item() / total)))
# Validation
with torch.no_grad():
crf.eval()
val_total = 0.
val_correct = 0.
val_second_correct = 0.
validation_loss = 0
n = 0
for i, data in enumerate(valid_dataloader):
val_features, val_input_percentages, val_chords, val_collapsed_chords, val_chord_lens, val_boundaries = data
val_features, val_chords = val_features.to(device), val_chords.to(device)
val_features = (val_features - mean) / std
val_features = val_features.squeeze(1).permute(0, 2, 1)
val_logits = pre_model(val_features, val_chords)
val_prediction, val_loss = crf(val_logits, val_chords)
val_total += val_chords.size(0)
val_correct += (val_prediction == val_chords).type_as(val_chords).sum()
validation_loss += val_loss.item()
n += 1
# logging loss and accuracy using tensorboard
validation_loss /= n
result = {'loss/val': validation_loss, 'acc/val': val_correct.item() / val_total}
for tag, value in result.items(): tf_logger.scalar_summary(tag, value, epoch + 1)
logger.info("validation loss(%d): %.4f" % (epoch + 1, validation_loss))
logger.info("validation accuracy(%d): %.4f" % (epoch + 1, (val_correct.item() / val_total)))
current_acc = val_correct.item() / val_total
if best_acc < current_acc:
early_stop_idx = 0
best_acc = current_acc
logger.info('==== best accuracy is %.4f and epoch is %d' % (best_acc, epoch + 1))
logger.info('saving model, Epoch %d, step %d' % (epoch + 1, current_step + 1))
model_save_path = os.path.join(asset_path, ckpt_path, ckpt_file_name % (epoch + 1))
state_dict = {'model': crf.state_dict(),'optimizer': optimizer.state_dict(),'epoch': epoch}
torch.save(state_dict, model_save_path)
last_best_epoch = epoch + 1
# save model
elif (epoch + 1) % config.experiment['save_step'] == 0:
logger.info('saving model, Epoch %d, step %d' % (epoch + 1, current_step + 1))
model_save_path = os.path.join(asset_path, ckpt_path, ckpt_file_name % (epoch + 1))
state_dict = {'model': crf.state_dict(),'optimizer': optimizer.state_dict(),'epoch': epoch}
torch.save(state_dict, model_save_path)
early_stop_idx += 1
else:
early_stop_idx += 1
if args.early_stop and early_stop_idx > 5:
logger.info('==== early stopped and epoch is %d' % (epoch + 1))
break
# learning rate decay
if before_acc > current_acc:
adjusting_learning_rate(optimizer=optimizer, factor=0.95, min_lr=5e-6)
before_acc = current_acc
# Load model
if os.path.isfile(os.path.join(asset_path, ckpt_path, ckpt_file_name % last_best_epoch)):
checkpoint = torch.load(os.path.join(asset_path, ckpt_path, ckpt_file_name % last_best_epoch))
crf.load_state_dict(checkpoint['model'])
logger.info("last best restore model with %d epochs" % last_best_epoch)
else:
raise NotImplementedError
# score Validation
if args.voca:
score_metrics = ['root', 'thirds', 'triads', 'sevenths', 'tetrads', 'majmin', 'mirex']
score_list_dict1, song_length_list1, average_score_dict1 = large_voca_score_calculation_crf(valid_dataset=valid_dataset1, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device)
score_list_dict2, song_length_list2, average_score_dict2 = large_voca_score_calculation_crf(valid_dataset=valid_dataset2, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device)
score_list_dict3, song_length_list3, average_score_dict3 = large_voca_score_calculation_crf(valid_dataset=valid_dataset3, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device)
for m in score_metrics:
average_score = (np.sum(song_length_list1) * average_score_dict1[m] + np.sum(song_length_list2) *average_score_dict2[m] + np.sum(song_length_list3) * average_score_dict3[m]) / (np.sum(song_length_list1) + np.sum(song_length_list2) + np.sum(song_length_list3))
logger.info('==== %s score 1 is %.4f' % (m, average_score_dict1[m]))
logger.info('==== %s score 2 is %.4f' % (m, average_score_dict2[m]))
logger.info('==== %s score 3 is %.4f' % (m, average_score_dict3[m]))
logger.info('==== %s mix average score is %.4f' % (m, average_score))
else:
score_metrics = ['root', 'majmin']
score_list_dict1, song_length_list1, average_score_dict1 = root_majmin_score_calculation_crf(valid_dataset=valid_dataset1, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device)
score_list_dict2, song_length_list2, average_score_dict2 = root_majmin_score_calculation_crf(valid_dataset=valid_dataset2, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device)
score_list_dict3, song_length_list3, average_score_dict3 = root_majmin_score_calculation_crf(valid_dataset=valid_dataset3, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device)
for m in score_metrics:
average_score = (np.sum(song_length_list1) * average_score_dict1[m] + np.sum(song_length_list2) *average_score_dict2[m] + np.sum(song_length_list3) * average_score_dict3[m]) / (np.sum(song_length_list1) + np.sum(song_length_list2) + np.sum(song_length_list3))
logger.info('==== %s score 1 is %.4f' % (m, average_score_dict1[m]))
logger.info('==== %s score 2 is %.4f' % (m, average_score_dict2[m]))
logger.info('==== %s score 3 is %.4f' % (m, average_score_dict3[m]))
logger.info('==== %s mix average score is %.4f' % (m, average_score))
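Both training scripts combine per-dataset evaluation scores into one number by weighting each dataset's average score with its total song length. A minimal sketch of that weighting (`mix_average_score` is a hypothetical helper; `score_dicts` stands for the `average_score_dict*` values and `song_length_lists` for the `song_length_list*` values in the blocks above):

```python
import numpy as np

def mix_average_score(score_dicts, song_length_lists, metric):
    """Song-length-weighted average of one metric across datasets,
    as in the 'score Validation' blocks of train.py and train_crf.py."""
    totals = [float(np.sum(lengths)) for lengths in song_length_lists]
    weighted = sum(t * d[metric] for t, d in zip(totals, score_dicts))
    return weighted / sum(totals)
```

Weighting by duration rather than song count means long songs contribute proportionally more, matching how mir_eval's weighted chord scores are defined.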
================================================
FILE: utils/__init__.py
================================================
================================================
FILE: utils/chords.py
================================================
# encoding: utf-8
"""
This module contains chord evaluation functionality.
It provides the evaluation measures used for the MIREX ACE task, and
tries to follow [1]_ and [2]_ as closely as possible.
Notes
-----
This implementation tries to follow the references and their implementation
(e.g., https://github.com/jpauwels/MusOOEvaluator for [2]_). However, there
are some known (and possibly some unknown) differences. If you find one not
listed in the following, please file an issue:
- Detected chord segments are adjusted to fit the length of the annotations.
In particular, this means that, if necessary, filler segments of 'no chord'
are added at beginnings and ends. This can result in different segmentation
scores compared to the original implementation.
References
----------
.. [1] Christopher Harte, "Towards Automatic Extraction of Harmony Information
from Music Signals." Dissertation,
Department for Electronic Engineering, Queen Mary University of London,
2010.
.. [2] Johan Pauwels and Geoffroy Peeters.
"Evaluating Automatically Estimated Chord Sequences."
In Proceedings of ICASSP 2013, Vancouver, Canada, 2013.
"""
import numpy as np
import pandas as pd
import mir_eval
# NumPy's scalar aliases (np.int, np.float, np.bool) were removed in NumPy 1.24;
# the Python builtins work across all versions.
CHORD_DTYPE = [('root', int),
('bass', int),
('intervals', int, (12,)),
('is_major', bool)]
CHORD_ANN_DTYPE = [('start', float),
('end', float),
('chord', CHORD_DTYPE)]
NO_CHORD = (-1, -1, np.zeros(12, dtype=int), False)
UNKNOWN_CHORD = (-1, -1, np.ones(12, dtype=int) * -1, False)
PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
def idx_to_chord(idx):
if idx == 24:
return "-"
elif idx == 25:
return u"\u03B5"
minmaj = idx % 2
root = idx // 2
return PITCH_CLASS[root] + ("M" if minmaj == 0 else "m")
class Chords:
def __init__(self):
self._shorthands = {
'maj': self.interval_list('(1,3,5)'),
'min': self.interval_list('(1,b3,5)'),
'dim': self.interval_list('(1,b3,b5)'),
'aug': self.interval_list('(1,3,#5)'),
'maj7': self.interval_list('(1,3,5,7)'),
'min7': self.interval_list('(1,b3,5,b7)'),
'7': self.interval_list('(1,3,5,b7)'),
'6': self.interval_list('(1,6)'), # custom
'5': self.interval_list('(1,5)'),
'4': self.interval_list('(1,4)'), # custom
'1': self.interval_list('(1)'),
'dim7': self.interval_list('(1,b3,b5,bb7)'),
'hdim7': self.interval_list('(1,b3,b5,b7)'),
'minmaj7': self.interval_list('(1,b3,5,7)'),
'maj6': self.interval_list('(1,3,5,6)'),
'min6': self.interval_list('(1,b3,5,6)'),
'9': self.interval_list('(1,3,5,b7,9)'),
'maj9': self.interval_list('(1,3,5,7,9)'),
'min9': self.interval_list('(1,b3,5,b7,9)'),
'sus2': self.interval_list('(1,2,5)'),
'sus4': self.interval_list('(1,4,5)'),
'11': self.interval_list('(1,3,5,b7,9,11)'),
'min11': self.interval_list('(1,b3,5,b7,9,11)'),
'13': self.interval_list('(1,3,5,b7,13)'),
'maj13': self.interval_list('(1,3,5,7,13)'),
'min13': self.interval_list('(1,b3,5,b7,13)')
}
def chords(self, labels):
"""
Transform a list of chord labels into an array of internal numeric
representations.
Parameters
----------
labels : list
List of chord labels (str).
Returns
-------
chords : numpy.array
Structured array with columns 'root', 'bass', and 'intervals',
containing a numeric representation of chords.
"""
crds = np.zeros(len(labels), dtype=CHORD_DTYPE)
cache = {}
for i, lbl in enumerate(labels):
cv = cache.get(lbl, None)
if cv is None:
cv = self.chord(lbl)
cache[lbl] = cv
crds[i] = cv
return crds
def label_error_modify(self, label):
if label == 'Emin/4': label = 'E:min/4'
elif label == 'A7/3': label = 'A:7/3'
elif label == 'Bb7/3': label = 'Bb:7/3'
elif label == 'Bb7/5': label = 'Bb:7/5'
elif label.find(':') == -1:
if label.find('min') != -1:
label = label[:label.find('min')] + ':' + label[label.find('min'):]
return label
def chord(self, label):
"""
Transform a chord label into the internal numeric representation of
(root, bass, intervals array).
Parameters
----------
label : str
Chord label.
Returns
-------
chord : tuple
Numeric representation of the chord: (root, bass, intervals array).
"""
try:
is_major = False
if label == 'N':
return NO_CHORD
if label == 'X':
return UNKNOWN_CHORD
label = self.label_error_modify(label)
c_idx = label.find(':')
s_idx = label.find('/')
if c_idx == -1:
quality_str = 'maj'
if s_idx == -1:
root_str = label
bass_str = ''
else:
root_str = label[:s_idx]
bass_str = label[s_idx + 1:]
else:
root_str = label[:c_idx]
if s_idx == -1:
quality_str = label[c_idx + 1:]
bass_str = ''
else:
quality_str = label[c_idx + 1:s_idx]
bass_str = label[s_idx + 1:]
root = self.pitch(root_str)
bass = self.interval(bass_str) if bass_str else 0
ivs = self.chord_intervals(quality_str)
ivs[bass] = 1
if 'min' in quality_str:
is_major = False
else:
is_major = True
except Exception as e:
print(e, label)
raise
return root, bass, ivs, is_major
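The branching above determines the root, quality, and bass substrings from the ':' and '/' separators. A standalone sketch of just that split, without the numeric conversion (the helper name is hypothetical):

```python
def split_chord_label(label):
    """Split a Harte-style chord label into (root, quality, bass) strings.

    Mirrors the branching in Chords.chord(): a missing ':' implies a 'maj'
    quality, and a missing '/' implies an empty bass string.
    """
    c_idx = label.find(':')
    s_idx = label.find('/')
    if c_idx == -1:
        quality = 'maj'
        if s_idx == -1:
            root, bass = label, ''
        else:
            root, bass = label[:s_idx], label[s_idx + 1:]
    else:
        root = label[:c_idx]
        if s_idx == -1:
            quality, bass = label[c_idx + 1:], ''
        else:
            quality, bass = label[c_idx + 1:s_idx], label[s_idx + 1:]
    return root, quality, bass
```

For example, 'Bb:7/5' splits into ('Bb', '7', '5'), while a bare 'G/3' yields ('G', 'maj', '3').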
_l = [0, 1, 1, 0, 1, 1, 1]
_chroma_id = (np.arange(len(_l) * 2) + 1) + np.array(_l + _l).cumsum() - 1
def modify(self, base_pitch, modifier):
"""
Modify a pitch class in integer representation by a given modifier string.
A modifier string can be any sequence of 'b' (one semitone down)
and '#' (one semitone up).
Parameters
----------
base_pitch : int
Pitch class as integer.
modifier : str
String of modifiers ('b' or '#').
Returns
-------
modified_pitch : int
Modified root note.
"""
for m in modifier:
if m == 'b':
base_pitch -= 1
elif m == '#':
base_pitch += 1
else:
raise ValueError('Unknown modifier: {}'.format(m))
return base_pitch
def pitch(self, pitch_str):
"""
Convert a string representation of a pitch class (consisting of root
note and modifiers) to an integer representation.
Parameters
----------
pitch_str : str
String representation of a pitch class.
Returns
-------
pitch : int
Integer representation of a pitch class.
"""
return self.modify(self._chroma_id[(ord(pitch_str[0]) - ord('C')) % 7],
pitch_str[1:]) % 12
def interval(self, interval_str):
"""
Convert a string representation of a musical interval into a pitch class
(e.g. a minor seventh 'b7' into 10, because it is 10 semitones above its
base note).
Parameters
----------
interval_str : str
Musical interval.
Returns
-------
pitch_class : int
Number of semitones to base note of interval.
"""
for i, c in enumerate(interval_str):
if c.isdigit():
return self.modify(self._chroma_id[int(interval_str[i:]) - 1],
interval_str[:i]) % 12
def interval_list(self, intervals_str, given_pitch_classes=None):
"""
Convert a list of intervals given as string to a binary pitch class
representation. For example, 'b3, 5' would become
[0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0].
Parameters
----------
intervals_str : str
List of intervals as comma-separated string (e.g. 'b3, 5').
given_pitch_classes : None or numpy array
If None, start with empty pitch class array, if numpy array of length
12, this array will be modified.
Returns
-------
pitch_classes : numpy array
Binary pitch class representation of intervals.
"""
if given_pitch_classes is None:
given_pitch_classes = np.zeros(12, dtype=int)
for int_def in intervals_str[1:-1].split(','):
int_def = int_def.strip()
if int_def[0] == '*':
given_pitch_classes[self.interval(int_def[1:])] = 0
else:
given_pitch_classes[self.interval(int_def)] = 1
return given_pitch_classes
# mapping of shorthand interval notations to the actual interval representation
def chord_intervals(self, quality_str):
"""
Convert a chord quality string to a pitch class representation. For
example, 'maj' becomes [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0].
Parameters
----------
quality_str : str
String defining the chord quality.
Returns
-------
pitch_classes : numpy array
Binary pitch class representation of chord quality.
"""
list_idx = quality_str.find('(')
if list_idx == -1:
return self._shorthands[quality_str].copy()
if list_idx != 0:
ivs = self._shorthands[quality_str[:list_idx]].copy()
else:
ivs = np.zeros(12, dtype=int)
return self.interval_list(quality_str[list_idx:], ivs)
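chord_intervals() combines a shorthand lookup with optional '(...)' modifiers: intervals inside the parentheses are set, and '*'-prefixed intervals are cleared. A minimal sketch of that combination step, with a few shorthand vectors written out by hand (names are hypothetical; indices are semitones above the root):

```python
# Hand-written binary 12-d pitch-class vectors for a few shorthands.
SHORTHANDS = {
    'maj':  [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0],   # 1, 3, 5
    'min':  [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],   # 1, b3, 5
    'maj7': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1],   # 1, 3, 5, 7
}

def apply_interval_list(base, changes):
    """Apply '(...)'-style modifications to a shorthand vector.

    changes maps a semitone offset to 1 (added interval) or 0
    (removed interval, the '*' syntax in interval_list()).
    """
    out = list(base)
    for semitone, flag in changes.items():
        out[semitone] = flag
    return out
```

Adding a flat seventh (semitone 10) to 'maj' yields the dominant-seventh vector; clearing semitone 11 from 'maj7' recovers plain 'maj'.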
def load_chords(self, filename):
"""
Load chords from a text file.
The chord must follow the syntax defined in [1]_.
Parameters
----------
filename : str
File containing chord segments.
Returns
-------
crds : numpy structured array
Structured array with columns "start", "end", and "chord",
containing the beginning, end, and chord definition of chord
segments.
References
----------
.. [1] Christopher Harte, "Towards Automatic Extraction of Harmony
Information from Music Signals." Dissertation,
Department for Electronic Engineering, Queen Mary University of
London, 2010.
"""
start, end, chord_labels = [], [], []
with open(filename, 'r') as f:
for line in f:
if line:
splits = line.split()
if len(splits) == 3:
s = splits[0]
e = splits[1]
l = splits[2]
start.append(float(s))
end.append(float(e))
chord_labels.append(l)
crds = np.zeros(len(start), dtype=CHORD_ANN_DTYPE)
crds['start'] = start
crds['end'] = end
crds['chord'] = self.chords(chord_labels)
return crds
def reduce_to_triads(self, chords, keep_bass=False):
"""
Reduce chords to triads.
The function follows the reduction rules implemented in [1]_. If a chord
does not contain a third, major second or fourth, it is reduced to
a power chord. If it contains neither a third nor a fifth, it is
reduced to a single note "chord".
Parameters
----------
chords : numpy structured array
Chords to be reduced.
keep_bass : bool
Indicates whether to keep the bass note or set it to 0.
Returns
-------
reduced_chords : numpy structured array
Chords reduced to triads.
References
----------
.. [1] Johan Pauwels and Geoffroy Peeters.
"Evaluating Automatically Estimated Chord Sequences."
In Proceedings of ICASSP 2013, Vancouver, Canada, 2013.
"""
unison = chords['intervals'][:, 0].astype(bool)
maj_sec = chords['intervals'][:, 2].astype(bool)
min_third = chords['intervals'][:, 3].astype(bool)
maj_third = chords['intervals'][:, 4].astype(bool)
perf_fourth = chords['intervals'][:, 5].astype(bool)
dim_fifth = chords['intervals'][:, 6].astype(bool)
perf_fifth = chords['intervals'][:, 7].astype(bool)
aug_fifth = chords['intervals'][:, 8].astype(bool)
no_chord = (chords['intervals'] == NO_CHORD[-1]).all(axis=1)
reduced_chords = chords.copy()
ivs = reduced_chords['intervals']
ivs[~no_chord] = self.interval_list('(1)')
ivs[unison & perf_fifth] = self.interval_list('(1,5)')
ivs[~perf_fourth & maj_sec] = self._shorthands['sus2']
ivs[perf_fourth & ~maj_sec] = self._shorthands['sus4']
ivs[min_third] = self._shorthands['min']
ivs[min_third & aug_fifth & ~perf_fifth] = self.interval_list('(1,b3,#5)')
ivs[min_third & dim_fifth & ~perf_fifth] = self._shorthands['dim']
ivs[maj_third] = self._shorthands['maj']
ivs[maj_third & dim_fifth & ~perf_fifth] = self.interval_list('(1,3,b5)')
ivs[maj_third & aug_fifth & ~perf_fifth] = self._shorthands['aug']
if not keep_bass:
reduced_chords['bass'] = 0
else:
# remove bass notes if they are not part of the intervals anymore
reduced_chords['bass'] *= ivs[range(len(reduced_chords)),
reduced_chords['bass']]
# keep -1 in bass for no chords
reduced_chords['bass'][no_chord] = -1
return reduced_chords
def convert_to_id(self, root, is_major):
if root == -1:
return 24
else:
if is_major:
return root * 2
else:
return root * 2 + 1
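convert_to_id produces the 25-class major/minor vocabulary whose inverse table is idx2chord in utils/mir_eval_modules.py. A standalone restatement (the function name is hypothetical):

```python
def majmin_chord_id(root, is_major):
    """25-class id: root*2 for major, root*2+1 for minor, 24 for no-chord.

    With C == 0, id 18 corresponds to A major and id 19 to A minor,
    matching the idx2chord ordering in utils/mir_eval_modules.py.
    """
    if root == -1:
        return 24
    return root * 2 if is_major else root * 2 + 1
```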
def get_converted_chord(self, filename):
loaded_chord = self.load_chords(filename)
triads = self.reduce_to_triads(loaded_chord['chord'])
df = self.assign_chord_id(triads)
df['start'] = loaded_chord['start']
df['end'] = loaded_chord['end']
return df
def assign_chord_id(self, entry):
# maj, min chord only
# if you want to add other chord, change this part and get_converted_chord(reduce_to_triads)
df = pd.DataFrame(data=entry[['root', 'is_major']])
df['chord_id'] = df.apply(lambda row: self.convert_to_id(row['root'], row['is_major']), axis=1)
return df
def convert_to_id_voca(self, root, quality):
if root == -1:
return 169
else:
if quality == 'min':
return root * 14
elif quality == 'maj':
return root * 14 + 1
elif quality == 'dim':
return root * 14 + 2
elif quality == 'aug':
return root * 14 + 3
elif quality == 'min6':
return root * 14 + 4
elif quality == 'maj6':
return root * 14 + 5
elif quality == 'min7':
return root * 14 + 6
elif quality == 'minmaj7':
return root * 14 + 7
elif quality == 'maj7':
return root * 14 + 8
elif quality == '7':
return root * 14 + 9
elif quality == 'dim7':
return root * 14 + 10
elif quality == 'hdim7':
return root * 14 + 11
elif quality == 'sus2':
return root * 14 + 12
elif quality == 'sus4':
return root * 14 + 13
else:
return 168
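The elif chain above is an index lookup over a fixed quality ordering (the same ordering as quality_list in utils/mir_eval_modules.py). An equivalent sketch using list.index (the function name is hypothetical):

```python
# Same ordering as the elif chain in convert_to_id_voca.
QUALITIES = ['min', 'maj', 'dim', 'aug', 'min6', 'maj6', 'min7',
             'minmaj7', 'maj7', '7', 'dim7', 'hdim7', 'sus2', 'sus4']

def voca_chord_id(root, quality):
    """170-class id: 14 qualities per root, 168 = unknown, 169 = no-chord."""
    if root == -1:
        return 169
    if quality not in QUALITIES:
        return 168
    return root * 14 + QUALITIES.index(quality)
```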
def get_converted_chord_voca(self, filename):
loaded_chord = self.load_chords(filename)
triads = self.reduce_to_triads(loaded_chord['chord'])
df = pd.DataFrame(data=triads[['root', 'is_major']])
(ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(filename)
ref_labels = self.lab_file_error_modify(ref_labels)
idxs = list()
for i in ref_labels:
chord_root, quality, scale_degrees, bass = mir_eval.chord.split(i, reduce_extended_chords=True)
root, bass, ivs, is_major = self.chord(i)
idxs.append(self.convert_to_id_voca(root=root, quality=quality))
df['chord_id'] = idxs
df['start'] = loaded_chord['start']
df['end'] = loaded_chord['end']
return df
def lab_file_error_modify(self, ref_labels):
for i in range(len(ref_labels)):
if ref_labels[i][-2:] == ':4':
ref_labels[i] = ref_labels[i].replace(':4', ':sus4')
elif ref_labels[i][-2:] == ':6':
ref_labels[i] = ref_labels[i].replace(':6', ':maj6')
elif ref_labels[i][-4:] == ':6/2':
ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2')
elif ref_labels[i] == 'Emin/4':
ref_labels[i] = 'E:min/4'
elif ref_labels[i] == 'A7/3':
ref_labels[i] = 'A:7/3'
elif ref_labels[i] == 'Bb7/3':
ref_labels[i] = 'Bb:7/3'
elif ref_labels[i] == 'Bb7/5':
ref_labels[i] = 'Bb:7/5'
elif ref_labels[i].find(':') == -1:
if ref_labels[i].find('min') != -1:
ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):]
return ref_labels
================================================
FILE: utils/hparams.py
================================================
import yaml
# TODO: the add method should be changed to skip keys that already exist
class HParams(object):
# Hyperparameter class using yaml
def __init__(self, **kwargs):
self.__dict__ = kwargs
def add(self, **kwargs):
# TODO: if the key already exists, do not update it
self.__dict__.update(kwargs)
def update(self, **kwargs):
self.__dict__.update(kwargs)
return self
def save(self, path):
with open(path, 'w') as f:
yaml.dump(self.__dict__, f)
return self
def __repr__(self):
return '\nHyperparameters:\n' + '\n'.join([' {}={}'.format(k, v) for k, v in self.__dict__.items()])
@classmethod
def load(cls, path):
with open(path, 'r') as f:
return cls(**yaml.safe_load(f))
if __name__ == '__main__':
hparams = HParams.load('hparams.yaml')
print(hparams)
d = {"MemoryNetwork": 0, "c": 1}
hparams.add(**d)
print(hparams)
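The core trick in HParams is assigning the kwargs dict directly to __dict__, so every hyperparameter becomes an attribute. A minimal yaml-free sketch of that pattern (the class name is hypothetical):

```python
class MiniHParams:
    """Attribute-style hyperparameter container (same idea as HParams)."""

    def __init__(self, **kwargs):
        # The kwargs dict becomes the instance's attribute dict directly.
        self.__dict__ = kwargs

    def update(self, **kwargs):
        self.__dict__.update(kwargs)
        return self  # allow chaining, as HParams.update does
```

Usage: `hp = MiniHParams(lr=0.001); hp.update(batch=32)` then read `hp.lr` and `hp.batch` as plain attributes.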
================================================
FILE: utils/logger.py
================================================
import logging
import os
import sys
import time
project_name = os.getcwd().split('/')[-1]
_logger = logging.getLogger(project_name)
_logger.addHandler(logging.StreamHandler())
def _log_prefix():
# Returns (filename, line number) for the stack frame.
def _get_file_line():
# pylint: disable=protected-access
# noinspection PyProtectedMember
f = sys._getframe()
# pylint: enable=protected-access
our_file = f.f_code.co_filename
f = f.f_back
while f:
code = f.f_code
if code.co_filename != our_file:
return code.co_filename, f.f_lineno
f = f.f_back
return '', 0
# current time
now = time.time()
now_tuple = time.localtime(now)
now_millisecond = int(1e3 * (now % 1.0))
# current filename and line
filename, line = _get_file_line()
basename = os.path.basename(filename)
s = '%02d-%02d %02d:%02d:%02d.%03d %s:%d] ' % (
now_tuple[1], # month
now_tuple[2], # day
now_tuple[3], # hour
now_tuple[4], # min
now_tuple[5], # sec
now_millisecond,
basename,
line)
return s
def logging_verbosity(verbosity=0):
_logger.setLevel(verbosity)
def debug(msg, *args, **kwargs):
_logger.debug('D ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)
def info(msg, *args, **kwargs):
_logger.info('I ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)
def warn(msg, *args, **kwargs):
_logger.warning('W ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)
def error(msg, *args, **kwargs):
_logger.error('E ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)
def fatal(msg, *args, **kwargs):
_logger.fatal('F ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)
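_log_prefix() walks the call stack outward until it leaves the logger's own file, so the prefix reports the caller's location rather than the logger's. A self-contained sketch of that walk (the function name is hypothetical):

```python
import os
import sys

def caller_location():
    """Return (basename, lineno) of the first frame outside this file.

    Same idea as _get_file_line() in _log_prefix(): skip every frame whose
    code object lives in the current file, then report the first outside one.
    """
    f = sys._getframe()
    here = f.f_code.co_filename
    f = f.f_back
    while f:
        if f.f_code.co_filename != here:
            return os.path.basename(f.f_code.co_filename), f.f_lineno
        f = f.f_back
    return '', 0  # no outside frame found (e.g. called from the same module)
```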
================================================
FILE: utils/mir_eval_modules.py
================================================
import numpy as np
import librosa
import mir_eval
import torch
import os
idx2chord = ['C', 'C:min', 'C#', 'C#:min', 'D', 'D:min', 'D#', 'D#:min', 'E', 'E:min', 'F', 'F:min', 'F#',
'F#:min', 'G', 'G:min', 'G#', 'G#:min', 'A', 'A:min', 'A#', 'A#:min', 'B', 'B:min', 'N']
root_list = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
quality_list = ['min', 'maj', 'dim', 'aug', 'min6', 'maj6', 'min7', 'minmaj7', 'maj7', '7', 'dim7', 'hdim7', 'sus2', 'sus4']
def idx2voca_chord():
idx2voca_chord = {}
idx2voca_chord[169] = 'N'
idx2voca_chord[168] = 'X'
for i in range(168):
root = i // 14
root = root_list[root]
quality = i % 14
quality = quality_list[quality]
if i % 14 != 1:
chord = root + ':' + quality
else:
chord = root
idx2voca_chord[i] = chord
return idx2voca_chord
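idx2voca_chord() is the inverse of the 170-class id mapping: 14 qualities per root, with 'maj' rendered as the bare root name. A standalone restatement for a single id (names are hypothetical):

```python
ROOTS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
QUALITIES = ['min', 'maj', 'dim', 'aug', 'min6', 'maj6', 'min7',
             'minmaj7', 'maj7', '7', 'dim7', 'hdim7', 'sus2', 'sus4']

def voca_id_to_label(i):
    """Inverse of the 170-class id: 168 = 'X', 169 = 'N', majors drop ':maj'."""
    if i == 169:
        return 'N'
    if i == 168:
        return 'X'
    root = ROOTS[i // 14]
    quality = QUALITIES[i % 14]
    return root if quality == 'maj' else root + ':' + quality
```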
def audio_file_to_features(audio_file, config):
original_wav, sr = librosa.load(audio_file, sr=config.mp3['song_hz'], mono=True)
current_sec_hz = 0
while len(original_wav) > current_sec_hz + config.mp3['song_hz'] * config.mp3['inst_len']:
start_idx = int(current_sec_hz)
end_idx = int(current_sec_hz + config.mp3['song_hz'] * config.mp3['inst_len'])
tmp = librosa.cqt(original_wav[start_idx:end_idx], sr=sr, n_bins=config.feature['n_bins'], bins_per_octave=config.feature['bins_per_octave'], hop_length=config.feature['hop_length'])
if start_idx == 0:
feature = tmp
else:
feature = np.concatenate((feature, tmp), axis=1)
current_sec_hz = end_idx
tmp = librosa.cqt(original_wav[current_sec_hz:], sr=sr, n_bins=config.feature['n_bins'], bins_per_octave=config.feature['bins_per_octave'], hop_length=config.feature['hop_length'])
feature = np.concatenate((feature, tmp), axis=1)
feature = np.log(np.abs(feature) + 1e-6)
feature_per_second = config.mp3['inst_len'] / config.model['timestep']
song_length_second = len(original_wav)/config.mp3['song_hz']
return feature, feature_per_second, song_length_second
# Collect audio files in wav or mp3 format
def get_audio_paths(audio_dir):
return [os.path.join(root, fname) for (root, dir_names, file_names) in os.walk(audio_dir, followlinks=True)
for fname in file_names if (fname.lower().endswith('.wav') or fname.lower().endswith('.mp3'))]
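The same walk-and-filter can be written with a tuple of extensions, since str.endswith accepts one (the helper name is hypothetical):

```python
import os

def collect_audio_paths(audio_dir, exts=('.wav', '.mp3')):
    """Recursively collect audio files, matching extensions case-insensitively."""
    return [os.path.join(root, fname)
            for root, _, file_names in os.walk(audio_dir, followlinks=True)
            for fname in file_names
            if fname.lower().endswith(exts)]
```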
class metrics():
def __init__(self):
super(metrics, self).__init__()
self.score_metrics = ['root', 'thirds', 'triads', 'sevenths', 'tetrads', 'majmin', 'mirex']
self.score_list_dict = dict()
for i in self.score_metrics:
self.score_list_dict[i] = list()
self.average_score = dict()
def score(self, metric, gt_path, est_path):
# each metric name matches the corresponding comparison function in mir_eval.chord
if metric not in self.score_metrics:
raise NotImplementedError
(ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)
ref_labels = lab_file_error_modify(ref_labels)
(est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)
est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),
ref_intervals.max(), mir_eval.chord.NO_CHORD,
mir_eval.chord.NO_CHORD)
(intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,
est_intervals, est_labels)
durations = mir_eval.util.intervals_to_durations(intervals)
comparisons = getattr(mir_eval.chord, metric)(ref_labels, est_labels)
score = mir_eval.chord.weighted_accuracy(comparisons, durations)
return score
def lab_file_error_modify(ref_labels):
for i in range(len(ref_labels)):
if ref_labels[i][-2:] == ':4':
ref_labels[i] = ref_labels[i].replace(':4', ':sus4')
elif ref_labels[i][-2:] == ':6':
ref_labels[i] = ref_labels[i].replace(':6', ':maj6')
elif ref_labels[i][-4:] == ':6/2':
ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2')
elif ref_labels[i] == 'Emin/4':
ref_labels[i] = 'E:min/4'
elif ref_labels[i] == 'A7/3':
ref_labels[i] = 'A:7/3'
elif ref_labels[i] == 'Bb7/3':
ref_labels[i] = 'Bb:7/3'
elif ref_labels[i] == 'Bb7/5':
ref_labels[i] = 'Bb:7/5'
elif ref_labels[i].find(':') == -1:
if ref_labels[i].find('min') != -1:
ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):]
return ref_labels
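The normalization above can also be expressed as a single-label function, which makes the precedence of the fixes easier to see (the name is hypothetical; same fixes as lab_file_error_modify):

```python
def normalize_label(label):
    """Patch common annotation typos before handing a label to mir_eval."""
    # A handful of known mislabeled chords (missing ':' separator).
    fixes = {'Emin/4': 'E:min/4', 'A7/3': 'A:7/3',
             'Bb7/3': 'Bb:7/3', 'Bb7/5': 'Bb:7/5'}
    if label in fixes:
        return fixes[label]
    if label.endswith(':4'):
        return label[:-2] + ':sus4'
    if label.endswith(':6'):
        return label[:-2] + ':maj6'
    if label.endswith(':6/2'):
        return label[:-4] + ':maj6/2'
    if ':' not in label and 'min' in label:
        i = label.find('min')
        return label[:i] + ':' + label[i:]
    return label
```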
def root_majmin_score_calculation(valid_dataset, config, mean, std, device, model, model_type, verbose=False):
valid_song_names = valid_dataset.song_names
paths = valid_dataset.preprocessor.get_all_files()
metrics_ = metrics()
song_length_list = list()
for path in paths:
song_name, lab_file_path, mp3_file_path, _ = path
if not song_name in valid_song_names:
continue
try:
n_timestep = config.model['timestep']
feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)
feature = feature.T
feature = (feature - mean) / std
time_unit = feature_per_second
num_pad = n_timestep - (feature.shape[0] % n_timestep)
feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
num_instance = feature.shape[0] // n_timestep
start_time = 0.0
lines = []
with torch.no_grad():
model.eval()
feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
for t in range(num_instance):
if model_type == 'btc':
encoder_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
prediction, _ = model.output_layer(encoder_output)
prediction = prediction.squeeze()
elif model_type == 'cnn' or model_type =='crnn':
prediction, _, _, _ = model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
for i in range(n_timestep):
if t == 0 and i == 0:
prev_chord = prediction[i].item()
continue
if prediction[i].item() != prev_chord:
lines.append(
'%.6f %.6f %s\n' % (
start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))
start_time = time_unit * (n_timestep * t + i)
prev_chord = prediction[i].item()
if t == num_instance - 1 and i + num_pad == n_timestep:
if start_time != time_unit * (n_timestep * t + i):
lines.append(
'%.6f %.6f %s\n' % (
start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))
break
pid = os.getpid()
tmp_path = 'tmp_' + str(pid) + '.lab'
with open(tmp_path, 'w') as f:
for line in lines:
f.write(line)
root_majmin = ['root', 'majmin']
for m in root_majmin:
metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))
song_length_list.append(song_length_second)
if verbose:
for m in root_majmin:
print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))
except Exception as e:
print('song name %s lab file error: %s' % (song_name, e))
tmp = song_length_list / np.sum(song_length_list)
for m in root_majmin:
metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))
return metrics_.score_list_dict, song_length_list, metrics_.average_score
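The loop above pads the feature matrix up to a multiple of timestep frames and then runs the model window by window. A sketch of just that layout arithmetic (the name is hypothetical); note that when the frame count is already an exact multiple, a full extra window of padding is added, matching the original code:

```python
def chunk_layout(n_frames, n_timestep):
    """Padding amount and window count used before running the model.

    Mirrors the arithmetic in the score-calculation loops: pad the frame
    axis up to a multiple of n_timestep, then process n_timestep windows.
    """
    num_pad = n_timestep - (n_frames % n_timestep)
    num_instance = (n_frames + num_pad) // n_timestep
    return num_pad, num_instance
```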
def root_majmin_score_calculation_crf(valid_dataset, config, mean, std, device, pre_model, model, model_type, verbose=False):
valid_song_names = valid_dataset.song_names
paths = valid_dataset.preprocessor.get_all_files()
metrics_ = metrics()
song_length_list = list()
for path in paths:
song_name, lab_file_path, mp3_file_path, _ = path
if not song_name in valid_song_names:
continue
try:
n_timestep = config.model['timestep']
feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)
feature = feature.T
feature = (feature - mean) / std
time_unit = feature_per_second
num_pad = n_timestep - (feature.shape[0] % n_timestep)
feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
num_instance = feature.shape[0] // n_timestep
start_time = 0.0
lines = []
with torch.no_grad():
model.eval()
feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
for t in range(num_instance):
if (model_type == 'cnn') or (model_type == 'crnn') or (model_type == 'btc'):
logits = pre_model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
prediction, _ = model(logits, torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
else:
raise NotImplementedError
for i in range(n_timestep):
if t == 0 and i == 0:
prev_chord = prediction[i].item()
continue
if prediction[i].item() != prev_chord:
lines.append(
'%.6f %.6f %s\n' % (
start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))
start_time = time_unit * (n_timestep * t + i)
prev_chord = prediction[i].item()
if t == num_instance - 1 and i + num_pad == n_timestep:
if start_time != time_unit * (n_timestep * t + i):
lines.append(
'%.6f %.6f %s\n' % (
start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))
break
pid = os.getpid()
tmp_path = 'tmp_' + str(pid) + '.lab'
with open(tmp_path, 'w') as f:
for line in lines:
f.write(line)
root_majmin = ['root', 'majmin']
for m in root_majmin:
metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))
song_length_list.append(song_length_second)
if verbose:
for m in root_majmin:
print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))
except Exception as e:
print('song name %s lab file error: %s' % (song_name, e))
tmp = song_length_list / np.sum(song_length_list)
for m in root_majmin:
metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))
return metrics_.score_list_dict, song_length_list, metrics_.average_score
def large_voca_score_calculation(valid_dataset, config, mean, std, device, model, model_type, verbose=False):
idx2voca = idx2voca_chord()
valid_song_names = valid_dataset.song_names
paths = valid_dataset.preprocessor.get_all_files()
metrics_ = metrics()
song_length_list = list()
for path in paths:
song_name, lab_file_path, mp3_file_path, _ = path
if not song_name in valid_song_names:
continue
try:
n_timestep = config.model['timestep']
feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)
feature = feature.T
feature = (feature - mean) / std
time_unit = feature_per_second
num_pad = n_timestep - (feature.shape[0] % n_timestep)
feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
num_instance = feature.shape[0] // n_timestep
start_time = 0.0
lines = []
with torch.no_grad():
model.eval()
feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
for t in range(num_instance):
if model_type == 'btc':
encoder_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
prediction, _ = model.output_layer(encoder_output)
prediction = prediction.squeeze()
elif model_type == 'cnn' or model_type =='crnn':
prediction, _, _, _ = model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
for i in range(n_timestep):
if t == 0 and i == 0:
prev_chord = prediction[i].item()
continue
if prediction[i].item() != prev_chord:
lines.append(
'%.6f %.6f %s\n' % (
start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))
start_time = time_unit * (n_timestep * t + i)
prev_chord = prediction[i].item()
if t == num_instance - 1 and i + num_pad == n_timestep:
if start_time != time_unit * (n_timestep * t + i):
lines.append(
'%.6f %.6f %s\n' % (
start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))
break
pid = os.getpid()
tmp_path = 'tmp_' + str(pid) + '.lab'
with open(tmp_path, 'w') as f:
for line in lines:
f.write(line)
for m in metrics_.score_metrics:
metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))
song_length_list.append(song_length_second)
if verbose:
for m in metrics_.score_metrics:
print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))
except Exception as e:
print('song name %s lab file error: %s' % (song_name, e))
tmp = song_length_list / np.sum(song_length_list)
for m in metrics_.score_metrics:
metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))
return metrics_.score_list_dict, song_length_list, metrics_.average_score
def large_voca_score_calculation_crf(valid_dataset, config, mean, std, device, pre_model, model, model_type, verbose=False):
idx2voca = idx2voca_chord()
valid_song_names = valid_dataset.song_names
paths = valid_dataset.preprocessor.get_all_files()
metrics_ = metrics()
song_length_list = list()
for path in paths:
song_name, lab_file_path, mp3_file_path, _ = path
if not song_name in valid_song_names:
continue
try:
n_timestep = config.model['timestep']
feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)
feature = feature.T
feature = (feature - mean) / std
time_unit = feature_per_second
num_pad = n_timestep - (feature.shape[0] % n_timestep)
feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
num_instance = feature.shape[0] // n_timestep
start_time = 0.0
lines = []
with torch.no_grad():
model.eval()
feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)
for t in range(num_instance):
if (model_type == 'cnn') or (model_type == 'crnn') or (model_type == 'btc'):
logits = pre_model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
prediction, _ = model(logits, torch.randint(config.model['num_chords'], (n_timestep,)).to(device))
else:
raise NotImplementedError
for i in range(n_timestep):
if t == 0 and i == 0:
prev_chord = prediction[i].item()
continue
if prediction[i].item() != prev_chord:
lines.append(
'%.6f %.6f %s\n' % (
start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))
start_time = time_unit * (n_timestep * t + i)
prev_chord = prediction[i].item()
if t == num_instance - 1 and i + num_pad == n_timestep:
if start_time != time_unit * (n_timestep * t + i):
lines.append(
'%.6f %.6f %s\n' % (
start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))
break
pid = os.getpid()
tmp_path = 'tmp_' + str(pid) + '.lab'
with open(tmp_path, 'w') as f:
for line in lines:
f.write(line)
for m in metrics_.score_metrics:
metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))
song_length_list.append(song_length_second)
if verbose:
for m in metrics_.score_metrics:
print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))
except Exception as e:
print('song name %s lab file error: %s' % (song_name, e))
tmp = song_length_list / np.sum(song_length_list)
for m in metrics_.score_metrics:
metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))
return metrics_.score_list_dict, song_length_list, metrics_.average_score
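All four score-calculation functions finish by weighting per-song scores by song length. A standalone restatement of that average (the name is hypothetical):

```python
def weighted_average(scores, song_lengths):
    """Length-weighted mean: longer songs contribute proportionally more."""
    total = sum(song_lengths)
    return sum(s * (l / total) for s, l in zip(scores, song_lengths))
```

For example, a score of 1.0 on a 3-minute song and 0.0 on a 1-minute song averages to 0.75, not 0.5.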
================================================
FILE: utils/preprocess.py
================================================
import os
import librosa
from utils.chords import Chords
import re
from enum import Enum
import pyrubberband as pyrb
import torch
import math
class FeatureTypes(Enum):
cqt = 'cqt'
class Preprocess():
def __init__(self, config, feature_to_use, dataset_names, root_dir):
self.config = config
self.dataset_names = dataset_names
self.root_path = root_dir + '/'
self.time_interval = config.feature["hop_length"]/config.mp3["song_hz"]
self.no_of_chord_datapoints_per_sequence = math.ceil(config.mp3['inst_len'] / self.time_interval)
self.Chord_class = Chords()
# isophonic
self.isophonic_directory = self.root_path + 'isophonic/'
# uspop
self.uspop_directory = self.root_path + 'uspop/'
self.uspop_audio_path = 'audio/'
self.uspop_lab_path = 'annotations/uspopLabels/'
self.uspop_index_path = 'annotations/uspopLabels.txt'
# robbie williams
self.robbie_williams_directory = self.root_path + 'robbiewilliams/'
self.robbie_williams_audio_path = 'audio/'
self.robbie_williams_lab_path = 'chords/'
self.feature_name = feature_to_use
self.is_cut_last_chord = False
def find_mp3_path(self, dirpath, word):
for filename in os.listdir(dirpath):
last_dir = dirpath.split("/")[-2]
if ".mp3" in filename:
tmp = filename.replace(".mp3", "")
tmp = tmp.replace(last_dir, "")
filename_lower = tmp.lower()
filename_lower = " ".join(re.findall("[a-zA-Z]+", filename_lower))
if word.lower().replace(" ", "") in filename_lower.replace(" ", ""):
return filename
def find_mp3_path_robbiewilliams(self, dirpath, word):
for filename in os.listdir(dirpath):
if ".mp3" in filename:
tmp = filename.replace(".mp3", "")
filename_lower = tmp.lower()
filename_lower = filename_lower.replace("robbie williams", "")
filename_lower = " ".join(re.findall("[a-zA-Z]+", filename_lower))
filename_lower = self.song_pre(filename_lower)
if self.song_pre(word.lower()).replace(" ", "") in filename_lower.replace(" ", ""):
return filename
def get_all_files(self):
res_list = []
# isophonic
if "isophonic" in self.dataset_names:
for dirpath, dirnames, filenames in os.walk(self.isophonic_directory):
if not dirnames:
for filename in filenames:
if ".lab" in filename:
tmp = filename.replace(".lab", "")
song_name = " ".join(re.findall("[a-zA-Z]+", tmp)).replace("CD", "")
mp3_path = self.find_mp3_path(dirpath, song_name)
res_list.append([song_name, os.path.join(dirpath, filename), os.path.join(dirpath, mp3_path),
os.path.join(self.root_path, "result", "isophonic")])
# uspop
if "uspop" in self.dataset_names:
with open(os.path.join(self.uspop_directory, self.uspop_index_path)) as f:
uspop_lab_list = f.readlines()
uspop_lab_list = [x.strip() for x in uspop_lab_list]
for lab_path in uspop_lab_list:
spl = lab_path.split('/')
lab_artist = self.uspop_pre(spl[2])
lab_title = self.uspop_pre(spl[4][3:-4])
lab_path = lab_path.replace('./uspopLabels/', '')
lab_path = os.path.join(self.uspop_directory, self.uspop_lab_path, lab_path)
for filename in os.listdir(os.path.join(self.uspop_directory, self.uspop_audio_path)):
if '.csv' not in filename:
spl = filename.split('-')
mp3_artist = self.uspop_pre(spl[0])
mp3_title = self.uspop_pre(spl[1][:-4])
if lab_artist == mp3_artist and lab_title == mp3_title:
res_list.append([mp3_artist + mp3_title, lab_path,
os.path.join(self.uspop_directory, self.uspop_audio_path, filename),
os.path.join(self.root_path, "result", "uspop")])
break
# robbie williams
if "robbiewilliams" in self.dataset_names:
for dirpath, dirnames, filenames in os.walk(self.robbie_williams_directory):
if not dirnames:
for filename in filenames:
if ".txt" in filename and ('README' not in filename):
tmp = filename.replace(".txt", "")
song_name = " ".join(re.findall("[a-zA-Z]+", tmp)).replace("GTChords", "")
mp3_dir = dirpath.replace("chords", "audio")
mp3_path = self.find_mp3_path_robbiewilliams(mp3_dir, song_name)
res_list.append([song_name, os.path.join(dirpath, filename), os.path.join(mp3_dir, mp3_path),
os.path.join(self.root_path, "result", "robbiewilliams")])
return res_list
def uspop_pre(self, text):
text = text.lower()
text = text.replace('_', '')
text = text.replace(' ', '')
text = " ".join(re.findall("[a-zA-Z]+", text))
return text
def song_pre(self, text):
to_remove = ["'", '`', '(', ')', ' ', '&', 'and', 'And']
for remove in to_remove:
text = text.replace(remove, '')
return text
def config_to_folder(self):
mp3_config = self.config.mp3
feature_config = self.config.feature
mp3_string = "%d_%.1f_%.1f" % \
(mp3_config['song_hz'], mp3_config['inst_len'],
mp3_config['skip_interval'])
feature_string = "%s_%d_%d_%d" % \
(self.feature_name.value, feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length'])
return mp3_config, feature_config, mp3_string, feature_string
def generate_labels_features_new(self, all_list):
pid = os.getpid()
mp3_config, feature_config, mp3_str, feature_str = self.config_to_folder()
i = 0 # number of songs
j = 0 # number of impossible songs
k = 0 # number of tried songs
total = 0 # number of generated instances
stretch_factors = [1.0]
shift_factors = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6]
loop_broken = False
for song_name, lab_path, mp3_path, save_path in all_list:
# different song initialization
if loop_broken:
loop_broken = False
i += 1
print(pid, "generating features from ...", os.path.join(mp3_path))
if i % 10 == 0:
print(i, ' th song')
original_wav, sr = librosa.load(os.path.join(mp3_path), sr=mp3_config['song_hz'])
# make result path if not exists
# save_path, mp3_string, feature_string, song_name, aug.pt
result_path = os.path.join(save_path, mp3_str, feature_str, song_name.strip())
if not os.path.exists(result_path):
os.makedirs(result_path)
# calculate result
for stretch_factor in stretch_factors:
if loop_broken:
loop_broken = False
break
for shift_factor in shift_factors:
# for filename
idx = 0
chord_info = self.Chord_class.get_converted_chord(os.path.join(lab_path))
k += 1
# stretch original sound and chord info
x = pyrb.time_stretch(original_wav, sr, stretch_factor)
x = pyrb.pitch_shift(x, sr, shift_factor)
audio_length = x.shape[0]
chord_info['start'] = chord_info['start'] * 1/stretch_factor
chord_info['end'] = chord_info['end'] * 1/stretch_factor
last_sec = chord_info.iloc[-1]['end']
last_sec_hz = int(last_sec * mp3_config['song_hz'])
if audio_length + mp3_config['skip_interval'] < last_sec_hz:
print('loaded song is too short :', song_name)
loop_broken = True
j += 1
break
elif audio_length > last_sec_hz:
x = x[:last_sec_hz]
origin_length = last_sec_hz
origin_length_in_sec = origin_length / mp3_config['song_hz']
current_start_second = 0
# get chord list between current_start_second and current+song_length
while current_start_second + mp3_config['inst_len'] < origin_length_in_sec:
inst_start_sec = current_start_second
curSec = current_start_second
chord_list = []
# extract chord per 1/self.time_interval
while curSec < inst_start_sec + mp3_config['inst_len']:
try:
available_chords = chord_info.loc[(chord_info['start'] <= curSec) & (
chord_info['end'] > curSec + self.time_interval)].copy()
if len(available_chords) == 0:
available_chords = chord_info.loc[((chord_info['start'] >= curSec) & (
chord_info['start'] <= curSec + self.time_interval)) | (
(chord_info['end'] >= curSec) & (
chord_info['end'] <= curSec + self.time_interval))].copy()
if len(available_chords) == 1:
chord = available_chords['chord_id'].iloc[0]
elif len(available_chords) > 1:
max_starts = available_chords.apply(lambda row: max(row['start'], curSec),
axis=1)
available_chords['max_start'] = max_starts
min_ends = available_chords.apply(
lambda row: min(row.end, curSec + self.time_interval), axis=1)
available_chords['min_end'] = min_ends
chords_lengths = available_chords['min_end'] - available_chords['max_start']
available_chords['chord_length'] = chords_lengths
chord = available_chords.loc[available_chords['chord_length'].idxmax()]['chord_id']
else:
chord = 24
except Exception as e:
chord = 24
print(e)
print(pid, "no chord")
raise RuntimeError()
finally:
# convert chord by shift factor
if chord != 24:
chord += shift_factor * 2
chord = chord % 24
chord_list.append(chord)
curSec += self.time_interval
if len(chord_list) == self.no_of_chord_datapoints_per_sequence:
try:
sequence_start_time = current_start_second
sequence_end_time = current_start_second + mp3_config['inst_len']
start_index = int(sequence_start_time * mp3_config['song_hz'])
end_index = int(sequence_end_time * mp3_config['song_hz'])
song_seq = x[start_index:end_index]
etc = '%.1f_%.1f' % (
current_start_second, current_start_second + mp3_config['inst_len'])
aug = '%.2f_%i' % (stretch_factor, shift_factor)
if self.feature_name == FeatureTypes.cqt:
# print(pid, "make feature")
feature = librosa.cqt(song_seq, sr=sr, n_bins=feature_config['n_bins'],
bins_per_octave=feature_config['bins_per_octave'],
hop_length=feature_config['hop_length'])
else:
raise NotImplementedError
if feature.shape[1] > self.no_of_chord_datapoints_per_sequence:
feature = feature[:, :self.no_of_chord_datapoints_per_sequence]
if feature.shape[1] != self.no_of_chord_datapoints_per_sequence:
print('loaded features length is too short :', song_name)
loop_broken = True
j += 1
break
result = {
'feature': feature,
'chord': chord_list,
'etc': etc
}
# save_path, mp3_string, feature_string, song_name, aug.pt
filename = aug + "_" + str(idx) + ".pt"
torch.save(result, os.path.join(result_path, filename))
idx += 1
total += 1
except Exception as e:
print(e)
print(pid, "feature error")
raise RuntimeError()
else:
print("invalid number of chord datapoints in sequence :", len(chord_list))
current_start_second += mp3_config['skip_interval']
print(pid, "total instances: %d" % total)
def generate_labels_features_voca(self, all_list):
pid = os.getpid()
mp3_config, feature_config, mp3_str, feature_str = self.config_to_folder()
i = 0 # number of songs
j = 0 # number of impossible songs
k = 0 # number of tried songs
total = 0 # number of generated instances
stretch_factors = [1.0]
shift_factors = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6]
loop_broken = False
for song_name, lab_path, mp3_path, save_path in all_list:
save_path = save_path + '_voca'
# different song initialization
if loop_broken:
loop_broken = False
i += 1
print(pid, "generating features from ...", os.path.join(mp3_path))
if i % 10 == 0:
print(i, ' th song')
original_wav, sr = librosa.load(os.path.join(mp3_path), sr=mp3_config['song_hz'])
# save_path, mp3_string, feature_string, song_name, aug.pt
result_path = os.path.join(save_path, mp3_str, feature_str, song_name.strip())
if not os.path.exists(result_path):
os.makedirs(result_path)
# calculate result
for stretch_factor in stretch_factors:
if loop_broken:
loop_broken = False
break
for shift_factor in shift_factors:
# for filename
idx = 0
try:
chord_info = self.Chord_class.get_converted_chord_voca(os.path.join(lab_path))
except Exception as e:
print(e)
print(pid, " chord lab file error : %s" % song_name)
loop_broken = True
j += 1
break
k += 1
# stretch original sound and chord info
x = pyrb.time_stretch(original_wav, sr, stretch_factor)
x = pyrb.pitch_shift(x, sr, shift_factor)
audio_length = x.shape[0]
chord_info['start'] = chord_info['start'] * 1/stretch_factor
chord_info['end'] = chord_info['end'] * 1/stretch_factor
last_sec = chord_info.iloc[-1]['end']
last_sec_hz = int(last_sec * mp3_config['song_hz'])
if audio_length + mp3_config['skip_interval'] < last_sec_hz:
print('loaded song is too short :', song_name)
loop_broken = True
j += 1
break
elif audio_length > last_sec_hz:
x = x[:last_sec_hz]
origin_length = last_sec_hz
origin_length_in_sec = origin_length / mp3_config['song_hz']
current_start_second = 0
# get chord list between current_start_second and current+song_length
while current_start_second + mp3_config['inst_len'] < origin_length_in_sec:
inst_start_sec = current_start_second
curSec = current_start_second
chord_list = []
# extract chord per 1/self.time_interval
while curSec < inst_start_sec + mp3_config['inst_len']:
try:
available_chords = chord_info.loc[(chord_info['start'] <= curSec) & (chord_info['end'] > curSec + self.time_interval)].copy()
if len(available_chords) == 0:
available_chords = chord_info.loc[((chord_info['start'] >= curSec) & (chord_info['start'] <= curSec + self.time_interval)) | ((chord_info['end'] >= curSec) & (chord_info['end'] <= curSec + self.time_interval))].copy()
if len(available_chords) == 1:
chord = available_chords['chord_id'].iloc[0]
elif len(available_chords) > 1:
max_starts = available_chords.apply(lambda row: max(row['start'], curSec),axis=1)
available_chords['max_start'] = max_starts
min_ends = available_chords.apply(lambda row: min(row.end, curSec + self.time_interval), axis=1)
available_chords['min_end'] = min_ends
chords_lengths = available_chords['min_end'] - available_chords['max_start']
available_chords['chord_length'] = chords_lengths
chord = available_chords.loc[available_chords['chord_length'].idxmax()]['chord_id']
else:
chord = 169
except Exception as e:
chord = 169
print(e)
print(pid, "no chord")
raise RuntimeError()
finally:
# convert chord by shift factor
if chord != 169 and chord != 168:
chord += shift_factor * 14
chord = chord % 168
chord_list.append(chord)
curSec += self.time_interval
if len(chord_list) == self.no_of_chord_datapoints_per_sequence:
try:
sequence_start_time = current_start_second
sequence_end_time = current_start_second + mp3_config['inst_len']
start_index = int(sequence_start_time * mp3_config['song_hz'])
end_index = int(sequence_end_time * mp3_config['song_hz'])
song_seq = x[start_index:end_index]
etc = '%.1f_%.1f' % (
current_start_second, current_start_second + mp3_config['inst_len'])
aug = '%.2f_%i' % (stretch_factor, shift_factor)
if self.feature_name == FeatureTypes.cqt:
feature = librosa.cqt(song_seq, sr=sr, n_bins=feature_config['n_bins'],
bins_per_octave=feature_config['bins_per_octave'],
hop_length=feature_config['hop_length'])
else:
raise NotImplementedError
if feature.shape[1] > self.no_of_chord_datapoints_per_sequence:
feature = feature[:, :self.no_of_chord_datapoints_per_sequence]
if feature.shape[1] != self.no_of_chord_datapoints_per_sequence:
print('loaded features length is too short :', song_name)
loop_broken = True
j += 1
break
result = {
'feature': feature,
'chord': chord_list,
'etc': etc
}
# save_path, mp3_string, feature_string, song_name, aug.pt
filename = aug + "_" + str(idx) + ".pt"
torch.save(result, os.path.join(result_path, filename))
idx += 1
total += 1
except Exception as e:
print(e)
print(pid, "feature error")
raise RuntimeError()
else:
print("invalid number of chord datapoints in sequence :", len(chord_list))
current_start_second += mp3_config['skip_interval']
print(pid, "total instances: %d" % total)
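Both generators above transpose the chord labels together with the pitch-shifted audio: in the 25-class maj/min vocabulary the no-chord id 24 stays fixed and other ids move by `2 * shift` mod 24, while in the large vocabulary ids 168/169 stay fixed and other ids move by `14 * shift` mod 168. That arithmetic, extracted as a stand-alone helper (the function name is ours, for illustration):

```python
def shift_chord_id(chord, shift, large_voca=False):
    """Transpose a chord class id by `shift` semitones.

    Mirrors the label adjustment inside generate_labels_features_new
    (25-class vocabulary) and generate_labels_features_voca (large
    vocabulary); not itself part of the repository.
    """
    if large_voca:
        if chord in (168, 169):   # special classes are pitch-invariant
            return chord
        return (chord + shift * 14) % 168
    if chord == 24:               # no-chord is pitch-invariant
        return chord
    return (chord + shift * 2) % 24
```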
================================================
FILE: utils/pytorch_utils.py
================================================
import torch
import numpy as np
import os
import math
from utils import logger
use_cuda = torch.cuda.is_available()
# optimization
# reference: http://pytorch.org/docs/master/_modules/torch/optim/lr_scheduler.html#ReduceLROnPlateau
def adjusting_learning_rate(optimizer, factor=.5, min_lr=0.00001):
for i, param_group in enumerate(optimizer.param_groups):
old_lr = float(param_group['lr'])
new_lr = max(old_lr * factor, min_lr)
param_group['lr'] = new_lr
logger.info('adjusting learning rate from %.6f to %.6f' % (old_lr, new_lr))
# model save and loading
def load_model(asset_path, model, optimizer, restore_epoch=0):
checkpoint_path = os.path.join(asset_path, 'model', 'checkpoint_%d.pth.tar' % restore_epoch)
if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
current_step = checkpoint['current_step']
logger.info("restore model with %d epoch" % restore_epoch)
else:
logger.info("no checkpoint with %d epoch" % restore_epoch)
current_step = 0
return model, optimizer, current_step
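`adjusting_learning_rate` above decays every parameter group's rate by `factor` but never lets it fall below `min_lr`. The schedule can be sketched without an optimizer (the helper name here is illustrative, not from the repository):

```python
def next_lr(old_lr, factor=0.5, min_lr=0.00001):
    # same rule as adjusting_learning_rate: decay by `factor`,
    # but clip at the floor `min_lr`
    return max(old_lr * factor, min_lr)

# starting from 1e-3, repeated adjustments approach and then stick at min_lr
lr = 0.001
schedule = []
for _ in range(10):
    lr = next_lr(lr)
    schedule.append(lr)
```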
================================================
FILE: utils/tf_logger.py
================================================
import tensorflow as tf
import numpy as np
import scipy.misc
try:
from StringIO import StringIO # Python 2.7
except ImportError:
from io import BytesIO # Python 3.x
class TF_Logger(object):
def __init__(self, log_dir):
"""Create a summary writer logging to log_dir."""
self.writer = tf.summary.FileWriter(log_dir)
def scalar_summary(self, tag, value, step):
"""Log a scalar variable."""
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)
def image_summary(self, tag, images, step):
"""Log a list of images."""
img_summaries = []
for i, img in enumerate(images):
# Write the image to a string
try:
s = StringIO()
except NameError:
s = BytesIO()
scipy.misc.toimage(img).save(s, format="png")  # note: removed in newer SciPy releases
# Create an Image object
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
height=img.shape[0],
width=img.shape[1])
# Create a Summary value
img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
# Create and write Summary
summary = tf.Summary(value=img_summaries)
self.writer.add_summary(summary, step)
def histo_summary(self, tag, values, step, bins=1000):
"""Log a histogram of the tensor of values."""
# Create a histogram using numpy
counts, bin_edges = np.histogram(values, bins=bins)
# Fill the fields of the histogram proto
hist = tf.HistogramProto()
hist.min = float(np.min(values))
hist.max = float(np.max(values))
hist.num = int(np.prod(values.shape))
hist.sum = float(np.sum(values))
hist.sum_squares = float(np.sum(values ** 2))
# Drop the start of the first bin
bin_edges = bin_edges[1:]
# Add bin edges and counts
for edge in bin_edges:
hist.bucket_limit.append(edge)
for c in counts:
hist.bucket.append(c)
# Create and write Summary
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
self.writer.add_summary(summary, step)
self.writer.flush()
================================================
FILE: utils/transformer_modules.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
def _gen_bias_mask(max_length):
"""
Generates bias values (-Inf) to mask future timesteps during attention
"""
np_mask = np.triu(np.full([max_length, max_length], -np.inf), 1)
torch_mask = torch.from_numpy(np_mask).type(torch.FloatTensor)
return torch_mask.unsqueeze(0).unsqueeze(1)
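The mask above places -inf on the strict upper triangle, so after the softmax in attention a position cannot attend to later timesteps. The core construction, restated in plain NumPy for inspection:

```python
import numpy as np

# strict upper triangle (future positions) is -inf, everything else 0;
# row t of the mask: positions <= t stay reachable, positions > t are masked
mask = np.triu(np.full([4, 4], -np.inf), 1)
```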
def _gen_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4):
"""
Generates a [1, length, channels] timing signal consisting of sinusoids
Adapted from:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
"""
position = np.arange(length)
num_timescales = channels // 2
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(float(num_timescales) - 1))
inv_timescales = min_timescale * np.exp(
np.arange(num_timescales).astype(np.float64) * -log_timescale_increment)
scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, channels % 2]],
'constant', constant_values=[0.0, 0.0])
signal = signal.reshape([1, length, channels])
return torch.from_numpy(signal).type(torch.FloatTensor)
class LayerNorm(nn.Module):
# Borrowed from jekbradbury
# https://github.com/pytorch/pytorch/issues/1959
def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.ones(features))
self.beta = nn.Parameter(torch.zeros(features))
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.gamma * (x - mean) / (std + self.eps) + self.beta
class OutputLayer(nn.Module):
"""
Abstract base class for output layer.
Handles projection to output labels
"""
def __init__(self, hidden_size, output_size, probs_out=False):
super(OutputLayer, self).__init__()
self.output_size = output_size
self.output_projection = nn.Linear(hidden_size, output_size)
self.probs_out = probs_out
self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=int(hidden_size/2), batch_first=True, bidirectional=True)
self.hidden_size = hidden_size
def loss(self, hidden, labels):
raise NotImplementedError('Must implement {}.loss'.format(self.__class__.__name__))
class SoftmaxOutputLayer(OutputLayer):
"""
Implements a softmax based output layer
"""
def forward(self, hidden):
logits = self.output_projection(hidden)
probs = F.softmax(logits, -1)
# _, predictions = torch.max(probs, dim=-1)
topk, indices = torch.topk(probs, 2)
predictions = indices[:,:,0]
second = indices[:,:,1]
if self.probs_out is True:
return logits
# return probs
return predictions, second
def loss(self, hidden, labels):
logits = self.output_projection(hidden)
log_probs = F.log_softmax(logits, -1)
return F.nll_loss(log_probs.view(-1, self.output_size), labels.view(-1))
class MultiHeadAttention(nn.Module):
"""
Multi-head attention as per https://arxiv.org/pdf/1706.03762.pdf
Refer Figure 2
"""
def __init__(self, input_depth, total_key_depth, total_value_depth, output_depth,
num_heads, bias_mask=None, dropout=0.0, attention_map=False):
"""
Parameters:
input_depth: Size of last dimension of input
total_key_depth: Size of last dimension of keys. Must be divisible by num_heads
total_value_depth: Size of last dimension of values. Must be divisible by num_heads
output_depth: Size of last dimension of the final output
num_heads: Number of attention heads
bias_mask: Masking tensor to prevent connections to future elements
dropout: Dropout probability (Should be non-zero only during training)
"""
super(MultiHeadAttention, self).__init__()
# Checks borrowed from
# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
if total_key_depth % num_heads != 0:
raise ValueError("Key depth (%d) must be divisible by the number of "
"attention heads (%d)." % (total_key_depth, num_heads))
if total_value_depth % num_heads != 0:
raise ValueError("Value depth (%d) must be divisible by the number of "
"attention heads (%d)." % (total_value_depth, num_heads))
self.attention_map = attention_map
self.num_heads = num_heads
self.query_scale = (total_key_depth // num_heads) ** -0.5
self.bias_mask = bias_mask
# Key and query depth will be same
self.query_linear = nn.Linear(input_depth, total_key_depth, bias=False)
self.key_linear = nn.Linear(input_depth, total_key_depth, bias=False)
self.value_linear = nn.Linear(input_depth, total_value_depth, bias=False)
self.output_linear = nn.Linear(total_value_depth, output_depth, bias=False)
self.dropout = nn.Dropout(dropout)
def _split_heads(self, x):
"""
Split x to add an extra num_heads dimension
Input:
x: a Tensor with shape [batch_size, seq_length, depth]
Returns:
A Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]
"""
if len(x.shape) != 3:
raise ValueError("x must have rank 3")
shape = x.shape
return x.view(shape[0], shape[1], self.num_heads, shape[2] // self.num_heads).permute(0, 2, 1, 3)
def _merge_heads(self, x):
"""
Merge the extra num_heads into the last dimension
Input:
x: a Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]
Returns:
A Tensor with shape [batch_size, seq_length, depth]
"""
if len(x.shape) != 4:
raise ValueError("x must have rank 4")
shape = x.shape
return x.permute(0, 2, 1, 3).contiguous().view(shape[0], shape[2], shape[3] * self.num_heads)
def forward(self, queries, keys, values):
# Do a linear for each component
queries = self.query_linear(queries)
keys = self.key_linear(keys)
values = self.value_linear(values)
# Split into multiple heads
queries = self._split_heads(queries)
keys = self._split_heads(keys)
values = self._split_heads(values)
# Scale queries
queries *= self.query_scale
# Combine queries and keys
logits = torch.matmul(queries, keys.permute(0, 1, 3, 2))
# Add bias to mask future values
if self.bias_mask is not None:
logits += self.bias_mask[:, :, :logits.shape[-2], :logits.shape[-1]].type_as(logits.data)
# Convert to probabilities
weights = nn.functional.softmax(logits, dim=-1)
# Dropout
weights = self.dropout(weights)
# Combine with values to get context
contexts = torch.matmul(weights, values)
# Merge heads
contexts = self._merge_heads(contexts)
# contexts = torch.tanh(contexts)
# Linear to get output
outputs = self.output_linear(contexts)
if self.attention_map is True:
return outputs, weights
return outputs
class Conv(nn.Module):
"""
Convenience class that does padding and convolution for inputs in the format
[batch_size, sequence length, hidden size]
"""
def __init__(self, input_size, output_size, kernel_size, pad_type):
"""
Parameters:
input_size: Input feature size
output_size: Output feature size
kernel_size: Kernel width
pad_type: left -> pad on the left side (to mask future timesteps),
both -> pad on both sides
"""
super(Conv, self).__init__()
padding = (kernel_size - 1, 0) if pad_type == 'left' else (kernel_size // 2, (kernel_size - 1) // 2)
self.pad = nn.ConstantPad1d(padding, 0)
self.conv = nn.Conv1d(input_size, output_size, kernel_size=kernel_size, padding=0)
def forward(self, inputs):
inputs = self.pad(inputs.permute(0, 2, 1))
outputs = self.conv(inputs).permute(0, 2, 1)
return outputs
class PositionwiseFeedForward(nn.Module):
"""
Does a Linear + ReLU + Linear on each of the timesteps
"""
def __init__(self, input_depth, filter_size, output_depth, layer_config='ll', padding='left', dropout=0.0):
"""
Parameters:
input_depth: Size of last dimension of input
filter_size: Hidden size of the middle layer
output_depth: Size of last dimension of the final output
layer_config: ll -> linear + ReLU + linear
cc -> conv + ReLU + conv etc.
padding: left -> pad on the left side (to mask future timesteps),
both -> pad on both sides
dropout: Dropout probability (Should be non-zero only during training)
"""
super(PositionwiseFeedForward, self).__init__()
layers = []
sizes = ([(input_depth, filter_size)] +
[(filter_size, filter_size)] * (len(layer_config) - 2) +
[(filter_size, output_depth)])
for lc, s in zip(list(layer_config), sizes):
if lc == 'l':
layers.append(nn.Linear(*s))
elif lc == 'c':
layers.append(Conv(*s, kernel_size=3, pad_type=padding))
else:
raise ValueError("Unknown layer type {}".format(lc))
self.layers = nn.ModuleList(layers)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(dropout)
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
# ReLU + dropout follow every layer, including the last
# (the original guard `if i < len(self.layers)` was always true)
x = self.relu(x)
x = self.dropout(x)
return x
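The `sizes` list built in `PositionwiseFeedForward.__init__` pairs input/output widths for however many layers `layer_config` names: the first layer expands to `filter_size`, any middle layers keep it, and the last projects back to `output_depth`. That construction, extracted as a stand-alone helper (the function name is ours, for illustration):

```python
def ffn_layer_sizes(input_depth, filter_size, output_depth, layer_config='ll'):
    # one (in, out) pair per character in layer_config, mirroring
    # the sizes list in PositionwiseFeedForward.__init__
    return ([(input_depth, filter_size)] +
            [(filter_size, filter_size)] * (len(layer_config) - 2) +
            [(filter_size, output_depth)])
```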