Repository: jayg996/BTC-ISMIR19
Branch: master
Commit: 2682317be668
Files: 21
Total size: 23.4 MB

Directory structure:
gitextract_upyc_iog/
├── LICENSE
├── README.md
├── audio_dataset.py
├── baseline_models.py
├── btc_model.py
├── crf_model.py
├── run_config.yaml
├── test/
│   ├── btc_model.pt
│   └── btc_model_large_voca.pt
├── test.py
├── train.py
├── train_crf.py
└── utils/
    ├── __init__.py
    ├── chords.py
    ├── hparams.py
    ├── logger.py
    ├── mir_eval_modules.py
    ├── preprocess.py
    ├── pytorch_utils.py
    ├── tf_logger.py
    └── transformer_modules.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================

MIT License

Copyright (c) 2019 Jonggwon Park

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================

# A Bi-Directional Transformer for Musical Chord Recognition

This repository contains the source code for the paper "A Bi-Directional Transformer for Musical Chord Recognition" (ISMIR 2019).

## Requirements

- pytorch >= 1.0.0
- numpy >= 1.16.2
- pandas >= 0.24.1
- pyrubberband >= 0.3.0
- librosa >= 0.6.3
- pyyaml >= 3.13
- mir_eval >= 0.5
- pretty_midi >= 0.2.8

## File descriptions

* `audio_dataset.py` : loads the data, preprocessing label files into chord labels and mp3 files into constant-Q transform features.
* `btc_model.py` : contains the PyTorch implementation of BTC.
* `train.py` : trains the models.
* `crf_model.py` : contains the PyTorch implementation of Conditional Random Fields (CRFs).
* `baseline_models.py` : contains the code for the baseline models.
* `train_crf.py` : trains the CRFs.
* `run_config.yaml` : contains the required hyperparameters and paths.
* `test.py` : recognizes chords from audio files.

## Using BTC : Recognizing chords from files in an audio directory

### Using BTC from the command line

```bash
$ python test.py --audio_dir audio_folder --save_dir save_folder --voca False
```

* audio_dir : a folder of audio files for chord recognition (default: './test')
* save_dir : a folder for saving the recognition results (default: './test')
* voca : False selects the major/minor label type; True selects the large-vocabulary label type (default: False)

The results are saved as .lab files of the form shown below, together with MIDI files.
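Each line of a .lab file holds a start time, an end time (both in seconds), and a chord label, matching the `'%.3f %.3f %s'` format written by `test.py`. An illustrative excerpt (the chord labels here are made up):

```
0.000 1.578 N
1.578 4.982 G
4.982 8.306 D:min
```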
## Attention Map

The figures represent the attention probability values of self-attention layers 1, 3, 5, and 8, respectively. The layers that best represent the distinct characteristics of each layer were chosen. The input audio is the song "Just A Girl" (0m30s ~ 0m40s) by No Doubt from UsPop2002, which was part of the evaluation data.

## Data

We used the Isophonics [1], Robbie Williams [2], and UsPop2002 [3] datasets, which consist of chord label files. Due to copyright issues, these datasets do not include audio files. The audio files used in this work were collected from online music service providers.

[1] http://isophonics.net/datasets
[2] B. Di Giorgi, M. Zanoni, A. Sarti, and S. Tubaro. Automatic chord recognition based on the probabilistic modeling of diatonic modal harmony. In Proc. of the 8th International Workshop on Multidimensional Systems, Erlangen, Germany, 2013.
[3] https://github.com/tmc323/Chord-Annotations

## Reference

* PyTorch implementation of the Transformer and CRF: https://github.com/kolloldas/torchnlp

## Comments

* Any comments on the code are always welcome.

================================================
FILE: audio_dataset.py
================================================

import numpy as np import os import torch from torch.utils.data import Dataset, DataLoader from utils.preprocess import Preprocess, FeatureTypes import math from multiprocessing import Pool from sortedcontainers import SortedList class AudioDataset(Dataset): def __init__(self, config, root_dir='/data/music/chord_recognition', dataset_names=('isophonic',), featuretype=FeatureTypes.cqt, num_workers=20, train=False, preprocessing=False, resize=None, kfold=4): super(AudioDataset, self).__init__() self.config = config self.root_dir = root_dir self.dataset_names = dataset_names self.preprocessor = Preprocess(config, featuretype, dataset_names, self.root_dir) self.resize = resize self.train = train self.ratio = config.experiment['data_ratio'] # preprocessing hyperparameters # song_hz, n_bins, bins_per_octave, hop_length mp3_config = config.mp3 feature_config = config.feature self.mp3_string = "%d_%.1f_%.1f" % \ (mp3_config['song_hz'], mp3_config['inst_len'], mp3_config['skip_interval']) self.feature_string = "%s_%d_%d_%d" % \ (featuretype.value, feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length']) if feature_config['large_voca'] == True: # store paths if exists is_preprocessed = True if os.path.exists(os.path.join(root_dir, 'result', dataset_names[0]+'_voca', self.mp3_string, self.feature_string)) else False if (not is_preprocessed) | preprocessing: midi_paths = self.preprocessor.get_all_files() if num_workers > 1: num_path_per_process = math.ceil(len(midi_paths) / num_workers) args = [midi_paths[i * num_path_per_process:(i + 1) * num_path_per_process] for i in range(num_workers)] # start process p = Pool(processes=num_workers) p.map(self.preprocessor.generate_labels_features_voca, args) p.close() else: self.preprocessor.generate_labels_features_voca(midi_paths) # kfold is 5 fold index ( 0, 1, 2, 3, 4 ) self.song_names, self.paths = self.get_paths_voca(kfold=kfold) else: # store paths if exists is_preprocessed = True if os.path.exists(os.path.join(root_dir, 'result', dataset_names[0], self.mp3_string, self.feature_string)) else False if (not is_preprocessed) | preprocessing: midi_paths = self.preprocessor.get_all_files() if num_workers > 1: num_path_per_process = math.ceil(len(midi_paths) / num_workers) args = [midi_paths[i * num_path_per_process:(i + 1) * num_path_per_process] for i in range(num_workers)] # start process p = Pool(processes=num_workers) p.map(self.preprocessor.generate_labels_features_new,
args) p.close() else: self.preprocessor.generate_labels_features_new(midi_paths) # kfold is 5 fold index ( 0, 1, 2, 3, 4 ) self.song_names, self.paths = self.get_paths(kfold=kfold) def __len__(self): return len(self.paths) def __getitem__(self, idx): instance_path = self.paths[idx] res = dict() data = torch.load(instance_path) res['feature'] = np.log(np.abs(data['feature']) + 1e-6) res['chord'] = data['chord'] return res def get_paths(self, kfold=4): temp = {} used_song_names = list() for name in self.dataset_names: dataset_path = os.path.join(self.root_dir, "result", name, self.mp3_string, self.feature_string) song_names = os.listdir(dataset_path) for song_name in song_names: paths = [] instance_names = os.listdir(os.path.join(dataset_path, song_name)) if len(instance_names) > 0: used_song_names.append(song_name) for instance_name in instance_names: paths.append(os.path.join(dataset_path, song_name, instance_name)) temp[song_name] = paths # throw away unused song names song_names = used_song_names song_names = SortedList(song_names) print('Total used song length : %d' %len(song_names)) tmp = [] for i in range(len(song_names)): tmp += temp[song_names[i]] print('Total instances (train and valid) : %d' %len(tmp)) # divide train/valid dataset using k fold result = [] total_fold = 5 quotient = len(song_names) // total_fold remainder = len(song_names) % total_fold fold_num = [0] for i in range(total_fold): fold_num.append(quotient) for i in range(remainder): fold_num[i+1] += 1 for i in range(total_fold): fold_num[i+1] += fold_num[i] if self.train: tmp = [] # get not augmented data for k in range(total_fold): if k != kfold: for i in range(fold_num[k], fold_num[k+1]): result += temp[song_names[i]] tmp += song_names[fold_num[k]:fold_num[k + 1]] song_names = tmp else: for i in range(fold_num[kfold], fold_num[kfold+1]): instances = temp[song_names[i]] instances = [inst for inst in instances if "1.00_0" in inst] result += instances song_names = song_names[fold_num[kfold]:fold_num[kfold+1]] return song_names, result def get_paths_voca(self, kfold=4): temp = {} used_song_names = list() for name in self.dataset_names: dataset_path = os.path.join(self.root_dir, "result", name+'_voca', self.mp3_string, self.feature_string) song_names = os.listdir(dataset_path) for song_name in song_names: paths = [] instance_names = os.listdir(os.path.join(dataset_path, song_name)) if len(instance_names) > 0: used_song_names.append(song_name) for instance_name in instance_names: paths.append(os.path.join(dataset_path, song_name, instance_name)) temp[song_name] = paths # throw away unused song names song_names = used_song_names song_names = SortedList(song_names) print('Total used song length : %d' %len(song_names)) tmp = [] for i in range(len(song_names)): tmp += temp[song_names[i]] print('Total instances (train and valid) : %d' %len(tmp)) # divide train/valid dataset using k fold result = [] total_fold = 5 quotient = len(song_names) // total_fold remainder = len(song_names) % total_fold fold_num = [0] for i in range(total_fold): fold_num.append(quotient) for i in range(remainder): fold_num[i+1] += 1 for i in range(total_fold): fold_num[i+1] += fold_num[i] if self.train: tmp = [] # get not augmented data for k in range(total_fold): if k != kfold: for i in range(fold_num[k], fold_num[k+1]): result += temp[song_names[i]] tmp += song_names[fold_num[k]:fold_num[k + 1]] song_names = tmp else: for i in range(fold_num[kfold], fold_num[kfold+1]): instances = temp[song_names[i]] instances = [inst for inst in instances if "1.00_0" 
in inst] result += instances song_names = song_names[fold_num[kfold]:fold_num[kfold+1]] return song_names, result def _collate_fn(batch): batch_size = len(batch) max_len = batch[0]['feature'].shape[1] input_percentages = torch.empty(batch_size) # for variable length chord_lens = torch.empty(batch_size, dtype=torch.int64) chords = [] collapsed_chords = [] features = [] boundaries = [] for i in range(batch_size): sample = batch[i] feature = sample['feature'] chord = sample['chord'] diff = np.diff(chord, axis=0).astype(np.bool) idx = np.insert(diff, 0, True, axis=0) chord_lens[i] = np.sum(idx).item(0) chords.extend(chord) features.append(feature) input_percentages[i] = feature.shape[1] / max_len collapsed_chords.extend(np.array(chord)[idx].tolist()) boundary = np.append([0], diff) boundaries.extend(boundary.tolist()) features = torch.tensor(features, dtype=torch.float32).unsqueeze(1) # batch_size*1*feature_size*max_len chords = torch.tensor(chords, dtype=torch.int64) # (batch_size*time_length) collapsed_chords = torch.tensor(collapsed_chords, dtype=torch.int64) # total_unique_chord_len boundaries = torch.tensor(boundaries, dtype=torch.uint8) # (batch_size*time_length) return features, input_percentages, chords, collapsed_chords, chord_lens, boundaries class AudioDataLoader(DataLoader): def __init__(self, *args, **kwargs): super(AudioDataLoader, self).__init__(*args, **kwargs) self.collate_fn = _collate_fn ================================================ FILE: baseline_models.py ================================================ from utils.hparams import HParams import torch import torch.nn as nn import torch.nn.functional as F import time from crf_model import CRF use_cuda = torch.cuda.is_available() class CNN(nn.Module): def __init__(self,config): super(CNN, self).__init__() self.timestep = config['timestep'] self.context = 7 self.pad = nn.ConstantPad1d(self.context, 0) self.probs_out = config['probs_out'] self.num_chords = config['num_chords'] self.drop_out = nn.Dropout2d(p=0.5) self.conv1 = self.cnn_layers(1, 32, kernel_size=(3,3), padding=1) self.conv2 = self.cnn_layers(32, 32, kernel_size=(3,3), padding=1) self.conv3 = self.cnn_layers(32, 32, kernel_size=(3,3), padding=1) self.conv4 = self.cnn_layers(32, 32, kernel_size=(3,3), padding=1) self.pool_max = nn.MaxPool2d(kernel_size=(2,1)) self.conv5 = self.cnn_layers(32, 64, kernel_size=(3, 3), padding=0) self.conv6 = self.cnn_layers(64, 64, kernel_size=(3, 3), padding=0) self.conv7 = self.cnn_layers(64, 128, kernel_size=(12, 9), padding=0) self.conv_linear = nn.Conv2d(128, config['num_chords'], kernel_size=(1,1), padding=0) def cnn_layers(self, in_channels, out_channels, kernel_size, stride=1, padding=0): layers = [] conv2d = nn.Conv2d(in_channels, out_channels,kernel_size=kernel_size, stride=stride, padding=padding) batch_norm = nn.BatchNorm2d(out_channels) relu = nn.ReLU(inplace=True) layers += [conv2d, batch_norm, relu] return nn.Sequential(*layers) def forward(self, x, labels): x = x.permute(0,2,1) x = self.pad(x) batch_size = x.size(0) for i in range(batch_size): for j in range(self.timestep): if i == 0 and j == 0: inputs = x[i,:,j : j + self.context *2 + 1].unsqueeze(0) else: tmp = x[i, :, j : j + self.context *2 + 1].unsqueeze(0) inputs = torch.cat((inputs,tmp), dim=0) # inputs : [batchsize * timestep, feature_size, context] inputs = inputs.unsqueeze(1) conv = self.conv1(inputs) conv = self.conv2(conv) conv = self.conv3(conv) conv = self.conv4(conv) pooled = self.pool_max(conv) pooled = self.drop_out(pooled) conv = self.conv5(pooled) 
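# second conv stage: unpadded 3x3 convolutions at 64 channels, then max-pooling and dropout before the wide 12x9 convolution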
conv = self.conv6(conv) pooled = self.pool_max(conv) pooled = self.drop_out(pooled) conv = self.conv7(pooled) conv = self.drop_out(conv) conv = self.conv_linear(conv) avg_pool = nn.AvgPool2d(kernel_size=(conv.size(2), conv.size(3))) logits = avg_pool(conv).squeeze(2).squeeze(2) if self.probs_out is True: crf_input = logits.view(-1, self.timestep, self.num_chords) return crf_input log_probs = F.log_softmax(logits, -1) topk, indices = torch.topk(log_probs, 2) predictions = indices[:,0] second = indices[:,1] prediction = predictions.view(-1) second = second.view(-1) loss = F.nll_loss(log_probs.view(-1, self.num_chords), labels.view(-1)) return prediction, loss, 0, second class Crf(nn.Module): def __init__(self, num_chords, timestep): super(Crf, self).__init__() self.output_size = num_chords self.timestep = timestep self.Crf = CRF(self.output_size) def forward(self, probs, labels): prediction = self.Crf(probs) prediction = prediction.view(-1) labels = labels.view(-1, self.timestep) loss = self.Crf.loss(probs, labels) return prediction, loss class CRNN(nn.Module): def __init__(self,config): super(CRNN, self).__init__() self.feature_size = config['feature_size'] self.timestep = config['timestep'] self.probs_out = config['probs_out'] self.num_chords = config['num_chords'] self.hidden_size = 128 self.relu = nn.ReLU(inplace=True) self.batch_norm = nn.BatchNorm2d(1) self.conv1 = nn.Conv2d(1, 1, kernel_size=(5,5), padding=2) self.conv2 = nn.Conv2d(1, 36, kernel_size=(1,self.feature_size)) self.gru = nn.GRU(input_size=36, hidden_size=self.hidden_size, num_layers=2, batch_first=True, bidirectional=True) self.fc = nn.Linear(self.hidden_size*2, self.num_chords) def forward(self, x, labels): # x : [batchsize * timestep * feature_size] x = x.unsqueeze(1) x = self.batch_norm(x) conv = self.relu(self.conv1(x)) conv = self.relu(self.conv2(conv)) conv = conv.squeeze(3).permute(0,2,1) h0 = torch.zeros(4, conv.size(0), self.hidden_size).to(torch.device("cuda" if use_cuda else "cpu")) gru, h = self.gru(conv, h0) logits = self.fc(gru) if self.probs_out is True: # probs = F.softmax(logits, -1) return logits log_probs = F.log_softmax(logits, -1) topk, indices = torch.topk(log_probs, 2) predictions = indices[:,:,0] second = indices[:,:,1] prediction = predictions.view(-1) second = second.view(-1) loss = F.nll_loss(log_probs.view(-1, self.num_chords), labels.view(-1)) return prediction, loss, 0, second if __name__ == "__main__": config = HParams.load("run_config.yaml") device = torch.device("cuda" if use_cuda else "cpu") config.model['probs_out'] = True batch_size = 2 timestep = config.model['timestep'] feature_size = config.model['feature_size'] num_chords = config.model['num_chords'] features = torch.randn(batch_size,timestep,feature_size,requires_grad=True).to(device) chords = torch.randint(num_chords,(batch_size*timestep,)).to(device) model = CNN(config=config.model).to(device) crf = Crf(num_chords=config.model['num_chords'], timestep=config.model['timestep']).to(device) probs = model(features, chords) prediction, total_loss = crf(probs, chords) print(total_loss) ================================================ FILE: btc_model.py ================================================ from utils.transformer_modules import * from utils.transformer_modules import _gen_timing_signal, _gen_bias_mask from utils.hparams import HParams use_cuda = torch.cuda.is_available() class self_attention_block(nn.Module): def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads, bias_mask=None, 
layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0, attention_map=False): super(self_attention_block, self).__init__() self.attention_map = attention_map self.multi_head_attention = MultiHeadAttention(hidden_size, total_key_depth, total_value_depth,hidden_size, num_heads, bias_mask, attention_dropout, attention_map) self.positionwise_convolution = PositionwiseFeedForward(hidden_size, filter_size, hidden_size, layer_config='cc', padding='both', dropout=relu_dropout) self.dropout = nn.Dropout(layer_dropout) self.layer_norm_mha = LayerNorm(hidden_size) self.layer_norm_ffn = LayerNorm(hidden_size) def forward(self, inputs): x = inputs # Layer Normalization x_norm = self.layer_norm_mha(x) # Multi-head attention if self.attention_map is True: y, weights = self.multi_head_attention(x_norm, x_norm, x_norm) else: y = self.multi_head_attention(x_norm, x_norm, x_norm) # Dropout and residual x = self.dropout(x + y) # Layer Normalization x_norm = self.layer_norm_ffn(x) # Positionwise Feedforward y = self.positionwise_convolution(x_norm) # Dropout and residual y = self.dropout(x + y) if self.attention_map is True: return y, weights return y class bi_directional_self_attention(nn.Module): def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads, max_length, layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0): super(bi_directional_self_attention, self).__init__() self.weights_list = list() params = (hidden_size, total_key_depth or hidden_size, total_value_depth or hidden_size, filter_size, num_heads, _gen_bias_mask(max_length), layer_dropout, attention_dropout, relu_dropout, True) self.attn_block = self_attention_block(*params) params = (hidden_size, total_key_depth or hidden_size, total_value_depth or hidden_size, filter_size, num_heads, torch.transpose(_gen_bias_mask(max_length), dim0=2, dim1=3), layer_dropout, attention_dropout, relu_dropout, True) self.backward_attn_block = self_attention_block(*params) self.linear = nn.Linear(hidden_size*2, hidden_size) def forward(self, inputs): x, list = inputs # Forward Self-attention Block encoder_outputs, weights = self.attn_block(x) # Backward Self-attention Block reverse_outputs, reverse_weights = self.backward_attn_block(x) # Concatenation and Fully-connected Layer outputs = torch.cat((encoder_outputs, reverse_outputs), dim=2) y = self.linear(outputs) # Attention weights for Visualization self.weights_list = list self.weights_list.append(weights) self.weights_list.append(reverse_weights) return y, self.weights_list class bi_directional_self_attention_layers(nn.Module): def __init__(self, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth, filter_size, max_length=100, input_dropout=0.0, layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0): super(bi_directional_self_attention_layers, self).__init__() self.timing_signal = _gen_timing_signal(max_length, hidden_size) params = (hidden_size, total_key_depth or hidden_size, total_value_depth or hidden_size, filter_size, num_heads, max_length, layer_dropout, attention_dropout, relu_dropout) self.embedding_proj = nn.Linear(embedding_size, hidden_size, bias=False) self.self_attn_layers = nn.Sequential(*[bi_directional_self_attention(*params) for l in range(num_layers)]) self.layer_norm = LayerNorm(hidden_size) self.input_dropout = nn.Dropout(input_dropout) def forward(self, inputs): # Add input dropout x = self.input_dropout(inputs) # Project to hidden size x = self.embedding_proj(x) # Add timing signal x += 
self.timing_signal[:, :inputs.shape[1], :].type_as(inputs.data) # A Stack of Bi-directional Self-attention Layers y, weights_list = self.self_attn_layers((x, [])) # Layer Normalization y = self.layer_norm(y) return y, weights_list class BTC_model(nn.Module): def __init__(self, config): super(BTC_model, self).__init__() self.timestep = config['timestep'] self.probs_out = config['probs_out'] params = (config['feature_size'], config['hidden_size'], config['num_layers'], config['num_heads'], config['total_key_depth'], config['total_value_depth'], config['filter_size'], config['timestep'], config['input_dropout'], config['layer_dropout'], config['attention_dropout'], config['relu_dropout']) self.self_attn_layers = bi_directional_self_attention_layers(*params) self.output_layer = SoftmaxOutputLayer(hidden_size=config['hidden_size'], output_size=config['num_chords'], probs_out=config['probs_out']) def forward(self, x, labels): labels = labels.view(-1, self.timestep) # Output of Bi-directional Self-attention Layers self_attn_output, weights_list = self.self_attn_layers(x) # return logit values for CRF if self.probs_out is True: logits = self.output_layer(self_attn_output) return logits # Output layer and Soft-max prediction,second = self.output_layer(self_attn_output) prediction = prediction.view(-1) second = second.view(-1) # Loss Calculation loss = self.output_layer.loss(self_attn_output, labels) return prediction, loss, weights_list, second if __name__ == "__main__": config = HParams.load("run_config.yaml") device = torch.device("cuda" if use_cuda else "cpu") batch_size = 2 timestep = 108 feature_size = 144 num_chords = 25 features = torch.randn(batch_size,timestep,feature_size,requires_grad=True).to(device) chords = torch.randint(25,(batch_size*timestep,)).to(device) model = BTC_model(config=config.model).to(device) prediction, loss, weights_list, second = model(features, chords) print(prediction.size()) print(loss) ================================================ FILE: crf_model.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function import torch import torch.nn as nn class CRF(nn.Module): """ Implements Conditional Random Fields that can be trained via backpropagation. """ def __init__(self, num_tags): super(CRF, self).__init__() self.num_tags = num_tags self.transitions = nn.Parameter(torch.Tensor(num_tags, num_tags)) self.start_transitions = nn.Parameter(torch.randn(num_tags)) self.stop_transitions = nn.Parameter(torch.randn(num_tags)) nn.init.xavier_normal_(self.transitions) def forward(self, feats): # Shape checks if len(feats.shape) != 3: raise ValueError("feats must be 3-d got {}-d".format(feats.shape)) return self._viterbi(feats) def loss(self, feats, tags): """ Computes negative log likelihood between features and tags. Essentially difference between individual sequence scores and sum of all possible sequence scores (partition function) Parameters: feats: Input features [batch size, sequence length, number of tags] tags: Target tag indices [batch size, sequence length]. 
Should be between 0 and num_tags - 1 Returns: Negative log likelihood [a scalar] """ # Shape checks if len(feats.shape) != 3: raise ValueError("feats must be 3-d got {}-d".format(feats.shape)) if len(tags.shape) != 2: raise ValueError('tags must be 2-d but got {}-d'.format(tags.shape)) if feats.shape[:2] != tags.shape: raise ValueError('First two dimensions of feats and tags must match') sequence_score = self._sequence_score(feats, tags) partition_function = self._partition_function(feats) log_probability = sequence_score - partition_function # log-likelihood; negated below for the loss # Average across batch return -log_probability.mean() def _sequence_score(self, feats, tags): """ Parameters: feats: Input features [batch size, sequence length, number of tags] tags: Target tag indices [batch size, sequence length]. Should be between 0 and num_tags - 1 Returns: Sequence score of shape [batch size] """ batch_size = feats.shape[0] # Compute feature scores feat_score = feats.gather(2, tags.unsqueeze(-1)).squeeze(-1).sum(dim=-1) # Compute transition scores # Unfold to get [from, to] tag index pairs tags_pairs = tags.unfold(1, 2, 1) # Use advanced indexing to pull out required transition scores indices = tags_pairs.permute(2, 0, 1).chunk(2) trans_score = self.transitions[indices].squeeze(0).sum(dim=-1) # Compute start and stop scores start_score = self.start_transitions[tags[:, 0]] stop_score = self.stop_transitions[tags[:, -1]] return feat_score + start_score + trans_score + stop_score def _partition_function(self, feats): """ Computes the partition function for the CRF using the forward algorithm. Basically calculates scores for all possible tag sequences for the given feature vector sequence Parameters: feats: Input features [batch size, sequence length, number of tags] Returns: Total scores of shape [batch size] """ _, seq_size, num_tags = feats.shape if self.num_tags != num_tags: raise ValueError('num_tags should be {} but got {}'.format(self.num_tags, num_tags)) a = feats[:, 0] + self.start_transitions.unsqueeze(0) # [batch_size, num_tags] transitions = self.transitions.unsqueeze(0) # [1, num_tags, num_tags] from -> to for i in range(1, seq_size): feat = feats[:, i].unsqueeze(1) # [batch_size, 1, num_tags] a = self._log_sum_exp(a.unsqueeze(-1) + transitions + feat, 1) # [batch_size, num_tags] return self._log_sum_exp(a + self.stop_transitions.unsqueeze(0), 1) # [batch_size] def _viterbi(self, feats): """ Uses the Viterbi algorithm to predict the best sequence Parameters: feats: Input features [batch size, sequence length, number of tags] Returns: Best tag sequence [batch size, sequence length] """ _, seq_size, num_tags = feats.shape if self.num_tags != num_tags: raise ValueError('num_tags should be {} but got {}'.format(self.num_tags, num_tags)) v = feats[:, 0] + self.start_transitions.unsqueeze(0) # [batch_size, num_tags] transitions = self.transitions.unsqueeze(0) # [1, num_tags, num_tags] from -> to paths = [] for i in range(1, seq_size): feat = feats[:, i] # [batch_size, num_tags] v, idx = (v.unsqueeze(-1) + transitions).max(1) # [batch_size, num_tags], [batch_size, num_tags] paths.append(idx) v = (v + feat) # [batch_size, num_tags] v, tag = (v + self.stop_transitions.unsqueeze(0)).max(1, True) # Backtrack tags = [tag] for idx in reversed(paths): tag = idx.gather(1, tag) tags.append(tag) tags.reverse() return torch.cat(tags, 1) def _log_sum_exp(self, logits, dim): """ Computes log-sum-exp in a stable way """ max_val, _ = logits.max(dim) return max_val + (logits - max_val.unsqueeze(dim)).exp().sum(dim).log()
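A minimal sketch (not part of the repository) of driving this CRF standalone: the 25-tag major/minor vocabulary and the 108-frame timestep follow run_config.yaml, and all tensors are random placeholders.

```python
import torch
from crf_model import CRF

num_tags, batch_size, seq_len = 25, 2, 108            # majmin vocabulary, BTC timestep
crf = CRF(num_tags)

feats = torch.randn(batch_size, seq_len, num_tags)    # per-frame tag scores (e.g. BTC logits)
tags = torch.randint(num_tags, (batch_size, seq_len)) # gold chord indices

nll = crf.loss(feats, tags)   # negative log likelihood, averaged over the batch
path = crf(feats)             # Viterbi decoding -> [batch_size, seq_len] best tag sequence
nll.backward()                # transitions are nn.Parameters, so this trains the CRF
```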
================================================ FILE: run_config.yaml ================================================ mp3: song_hz: 22050 inst_len: 10.0 skip_interval: 5.0 feature: n_bins: 144 bins_per_octave: 24 hop_length: 2048 large_voca: False # large_voca: True experiment: learning_rate : 0.0001 weight_decay : 0.0 max_epoch : 100 batch_size : 128 save_step : 40 data_ratio : 0.8 model: feature_size : 144 timestep : 108 num_chords : 25 # num_chords : 170 input_dropout : 0.2 layer_dropout : 0.2 attention_dropout : 0.2 relu_dropout : 0.2 num_layers : 8 num_heads : 4 hidden_size : 128 total_key_depth : 128 total_value_depth : 128 filter_size : 128 loss : 'ce' probs_out : False path: ckpt_path : 'model' result_path : 'result' asset_path : '/data/music/chord_recognition/jayg996/assets' root_path : '/data/music/chord_recognition' ================================================ FILE: test/btc_model.pt ================================================ [File too large to display: 11.6 MB] ================================================ FILE: test/btc_model_large_voca.pt ================================================ [File too large to display: 11.7 MB] ================================================ FILE: test.py ================================================ import os import mir_eval import pretty_midi as pm from utils import logger from btc_model import * from utils.mir_eval_modules import audio_file_to_features, idx2chord, idx2voca_chord, get_audio_paths import argparse import warnings warnings.filterwarnings('ignore') logger.logging_verbosity(1) use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") # hyperparameters parser = argparse.ArgumentParser() parser.add_argument('--voca', default=True, type=lambda x: (str(x).lower() == 'true')) parser.add_argument('--audio_dir', type=str, default='./test') parser.add_argument('--save_dir', type=str, default='./test') args = parser.parse_args() config = HParams.load("run_config.yaml") if args.voca is True: config.feature['large_voca'] = True config.model['num_chords'] = 170 model_file = './test/btc_model_large_voca.pt' idx_to_chord = idx2voca_chord() logger.info("label type: large voca") else: model_file = './test/btc_model.pt' idx_to_chord = idx2chord logger.info("label type: Major and minor") model = BTC_model(config=config.model).to(device) # Load model if os.path.isfile(model_file): checkpoint = torch.load(model_file) mean = checkpoint['mean'] std = checkpoint['std'] model.load_state_dict(checkpoint['model']) logger.info("restore model") # Audio files with format of wav and mp3 audio_paths = get_audio_paths(args.audio_dir) # Chord recognition and save lab file for i, audio_path in enumerate(audio_paths): logger.info("======== %d of %d in progress ========" % (i + 1, len(audio_paths))) # Load mp3 feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, config) logger.info("audio file loaded and feature computation success : %s" % audio_path) # Majmin type chord recognition feature = feature.T feature = (feature - mean) / std time_unit = feature_per_second n_timestep = config.model['timestep'] num_pad = n_timestep - (feature.shape[0] % n_timestep) feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0) num_instance = feature.shape[0] // n_timestep start_time = 0.0 lines = [] with torch.no_grad(): model.eval() feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device) for t in range(num_instance): self_attn_output, _ = 
model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :]) prediction, _ = model.output_layer(self_attn_output) prediction = prediction.squeeze() for i in range(n_timestep): if t == 0 and i == 0: prev_chord = prediction[i].item() continue if prediction[i].item() != prev_chord: lines.append( '%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord])) start_time = time_unit * (n_timestep * t + i) prev_chord = prediction[i].item() if t == num_instance - 1 and i + num_pad == n_timestep: if start_time != time_unit * (n_timestep * t + i): lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord])) break # lab file write if not os.path.exists(args.save_dir): os.makedirs(args.save_dir) save_path = os.path.join(args.save_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab') with open(save_path, 'w') as f: for line in lines: f.write(line) logger.info("label file saved : %s" % save_path) # lab file to midi file starts, ends, pitchs = list(), list(), list() intervals, chords = mir_eval.io.load_labeled_intervals(save_path) for p in range(12): for i, (interval, chord) in enumerate(zip(intervals, chords)): root_num, relative_bitmap, _ = mir_eval.chord.encode(chord) tmp_label = mir_eval.chord.rotate_bitmap_to_root(relative_bitmap, root_num)[p] if i == 0: start_time = interval[0] label = tmp_label continue if tmp_label != label: if label == 1.0: starts.append(start_time), ends.append(interval[0]), pitchs.append(p + 48) start_time = interval[0] label = tmp_label if i == (len(intervals) - 1): if label == 1.0: starts.append(start_time), ends.append(interval[1]), pitchs.append(p + 48) midi = pm.PrettyMIDI() instrument = pm.Instrument(program=0) for start, end, pitch in zip(starts, ends, pitchs): pm_note = pm.Note(velocity=120, pitch=pitch, start=start, end=end) instrument.notes.append(pm_note) midi.instruments.append(instrument) midi.write(save_path.replace('.lab', '.midi')) ================================================ FILE: train.py ================================================ import os from torch import optim from utils import logger from audio_dataset import AudioDataset, AudioDataLoader from utils.tf_logger import TF_Logger from btc_model import * from baseline_models import CNN, CRNN from utils.hparams import HParams import argparse from utils.pytorch_utils import adjusting_learning_rate from utils.mir_eval_modules import root_majmin_score_calculation, large_voca_score_calculation import warnings warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=FutureWarning) logger.logging_verbosity(1) use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") parser = argparse.ArgumentParser() parser.add_argument('--index', type=int, help='Experiment Number', default='e') parser.add_argument('--kfold', type=int, help='5 fold (0,1,2,3,4)',default='e') parser.add_argument('--voca', type=bool, help='large voca is True', default=False) parser.add_argument('--model', type=str, help='btc, cnn, crnn', default='btc') parser.add_argument('--dataset1', type=str, help='Dataset', default='isophonic') parser.add_argument('--dataset2', type=str, help='Dataset', default='uspop') parser.add_argument('--dataset3', type=str, help='Dataset', default='robbiewilliams') parser.add_argument('--restore_epoch', type=int, default=1000) parser.add_argument('--early_stop', type=bool, help='no improvement during 10 epoch -> stop', 
default=True) args = parser.parse_args() config = HParams.load("run_config.yaml") if args.voca == True: config.feature['large_voca'] = True config.model['num_chords'] = 170 # Result save path asset_path = config.path['asset_path'] ckpt_path = config.path['ckpt_path'] result_path = config.path['result_path'] restore_epoch = args.restore_epoch experiment_num = str(args.index) ckpt_file_name = 'idx_'+experiment_num+'_%03d.pth.tar' tf_logger = TF_Logger(os.path.join(asset_path, 'tensorboard', 'idx_'+experiment_num)) logger.info("==== Experiment Number : %d " % args.index) if args.model == 'cnn': config.experiment['batch_size'] = 10 # Data loader train_dataset1 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset1,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold) train_dataset2 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset2,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold) train_dataset3 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset3,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold) train_dataset = train_dataset1.__add__(train_dataset2).__add__(train_dataset3) valid_dataset1 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset1,), preprocessing=False, train=False, kfold=args.kfold) valid_dataset2 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset2,), preprocessing=False, train=False, kfold=args.kfold) valid_dataset3 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset3,), preprocessing=False, train=False, kfold=args.kfold) valid_dataset = valid_dataset1.__add__(valid_dataset2).__add__(valid_dataset3) train_dataloader = AudioDataLoader(dataset=train_dataset, batch_size=config.experiment['batch_size'], drop_last=False, shuffle=True) valid_dataloader = AudioDataLoader(dataset=valid_dataset, batch_size=config.experiment['batch_size'], drop_last=False) # Model and Optimizer if args.model == 'cnn': model = CNN(config=config.model).to(device) elif args.model == 'crnn': model = CRNN(config=config.model).to(device) elif args.model == 'btc': model = BTC_model(config=config.model).to(device) else: raise NotImplementedError optimizer = optim.Adam(model.parameters(), lr=config.experiment['learning_rate'], weight_decay=config.experiment['weight_decay'], betas=(0.9, 0.98), eps=1e-9) # Make asset directory if not os.path.exists(os.path.join(asset_path, ckpt_path)): os.makedirs(os.path.join(asset_path, ckpt_path)) os.makedirs(os.path.join(asset_path, result_path)) # Load model if os.path.isfile(os.path.join(asset_path, ckpt_path, ckpt_file_name % restore_epoch)): checkpoint = torch.load(os.path.join(asset_path, ckpt_path, ckpt_file_name % restore_epoch)) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) epoch = checkpoint['epoch'] logger.info("restore model with %d epochs" % restore_epoch) else: logger.info("no checkpoint with %d epochs" % restore_epoch) restore_epoch = 0 # Global mean and variance calculate mp3_config = config.mp3 feature_config = config.feature mp3_string = "%d_%.1f_%.1f" % (mp3_config['song_hz'], mp3_config['inst_len'], mp3_config['skip_interval']) feature_string = "_%s_%d_%d_%d_" % ('cqt', feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length']) z_path = os.path.join(config.path['root_path'], 'result', mp3_string + feature_string + 'mix_kfold_'+ 
str(args.kfold) +'_normalization.pt') if os.path.exists(z_path): normalization = torch.load(z_path) mean = normalization['mean'] std = normalization['std'] logger.info("Global mean and std (k fold index %d) load complete" % args.kfold) else: mean = 0 square_mean = 0 k = 0 for i, data in enumerate(train_dataloader): features, input_percentages, chords, collapsed_chords, chord_lens, boundaries = data features = features.to(device) mean += torch.mean(features).item() square_mean += torch.mean(features.pow(2)).item() k += 1 square_mean = square_mean / k mean = mean / k std = np.sqrt(square_mean - mean * mean) normalization = dict() normalization['mean'] = mean normalization['std'] = std torch.save(normalization, z_path) logger.info("Global mean and std (training set, k fold index %d) calculation complete" % args.kfold) current_step = 0 best_acc = 0 before_acc = 0 early_stop_idx = 0 for epoch in range(restore_epoch, config.experiment['max_epoch']): # Training model.train() train_loss_list = [] total = 0. correct = 0. second_correct = 0. for i, data in enumerate(train_dataloader): features, input_percentages, chords, collapsed_chords, chord_lens, boundaries = data features, chords = features.to(device), chords.to(device) features.requires_grad = True features = (features - mean) / std # forward features = features.squeeze(1).permute(0,2,1) optimizer.zero_grad() prediction, total_loss, weights, second = model(features, chords) # save accuracy and loss total += chords.size(0) correct += (prediction == chords).type_as(chords).sum() second_correct += (second == chords).type_as(chords).sum() train_loss_list.append(total_loss.item()) # optimize step total_loss.backward() optimizer.step() current_step += 1 # logging loss and accuracy using tensorboard result = {'loss/tr': np.mean(train_loss_list), 'acc/tr': correct.item() / total, 'top2/tr': (correct.item()+second_correct.item()) / total} for tag, value in result.items(): tf_logger.scalar_summary(tag, value, epoch+1) logger.info("training loss for %d epoch: %.4f" % (epoch + 1, np.mean(train_loss_list))) logger.info("training accuracy for %d epoch: %.4f" % (epoch + 1, (correct.item() / total))) logger.info("training top2 accuracy for %d epoch: %.4f" % (epoch + 1, ((correct.item() + second_correct.item()) / total))) # Validation with torch.no_grad(): model.eval() val_total = 0. val_correct = 0. val_second_correct = 0. 
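# accumulate loss and frame-level top-1 / top-2 accuracy over the whole validation set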
validation_loss = 0 n = 0 for i, data in enumerate(valid_dataloader): val_features, val_input_percentages, val_chords, val_collapsed_chords, val_chord_lens, val_boundaries = data val_features, val_chords = val_features.to(device), val_chords.to(device) val_features = (val_features - mean) / std val_features = val_features.squeeze(1).permute(0, 2, 1) val_prediction, val_loss, weights, val_second = model(val_features, val_chords) val_total += val_chords.size(0) val_correct += (val_prediction == val_chords).type_as(val_chords).sum() val_second_correct += (val_second == val_chords).type_as(val_chords).sum() validation_loss += val_loss.item() n += 1 # logging loss and accuracy using tensorboard validation_loss /= n result = {'loss/val': validation_loss, 'acc/val': val_correct.item() / val_total, 'top2/val': (val_correct.item()+val_second_correct.item()) / val_total} for tag, value in result.items(): tf_logger.scalar_summary(tag, value, epoch + 1) logger.info("validation loss(%d): %.4f" % (epoch + 1, validation_loss)) logger.info("validation accuracy(%d): %.4f" % (epoch + 1, (val_correct.item() / val_total))) logger.info("validation top2 accuracy(%d): %.4f" % (epoch + 1, ((val_correct.item() + val_second_correct.item()) / val_total))) current_acc = val_correct.item() / val_total if best_acc < val_correct.item() / val_total: early_stop_idx = 0 best_acc = val_correct.item() / val_total logger.info('==== best accuracy is %.4f and epoch is %d' % (best_acc, epoch + 1)) logger.info('saving model, Epoch %d, step %d' % (epoch + 1, current_step + 1)) model_save_path = os.path.join(asset_path, 'model', ckpt_file_name % (epoch + 1)) state_dict = {'model': model.state_dict(),'optimizer': optimizer.state_dict(),'epoch': epoch} torch.save(state_dict, model_save_path) last_best_epoch = epoch + 1 # save model elif (epoch + 1) % config.experiment['save_step'] == 0: logger.info('saving model, Epoch %d, step %d' % (epoch + 1, current_step + 1)) model_save_path = os.path.join(asset_path, 'model', ckpt_file_name % (epoch + 1)) state_dict = {'model': model.state_dict(),'optimizer': optimizer.state_dict(),'epoch': epoch} torch.save(state_dict, model_save_path) early_stop_idx += 1 else: early_stop_idx += 1 if (args.early_stop == True) and (early_stop_idx > 9): logger.info('==== early stopped and epoch is %d' % (epoch + 1)) break # learning rate decay if before_acc > current_acc: adjusting_learning_rate(optimizer=optimizer, factor=0.95, min_lr=5e-6) before_acc = current_acc # Load model if os.path.isfile(os.path.join(asset_path, ckpt_path, ckpt_file_name % last_best_epoch)): checkpoint = torch.load(os.path.join(asset_path, ckpt_path, ckpt_file_name % last_best_epoch)) model.load_state_dict(checkpoint['model']) logger.info("restore model with %d epochs" % last_best_epoch) else: raise NotImplementedError # score Validation if args.voca == True: score_metrics = ['root', 'thirds', 'triads', 'sevenths', 'tetrads', 'majmin', 'mirex'] score_list_dict1, song_length_list1, average_score_dict1 = large_voca_score_calculation(valid_dataset=valid_dataset1, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device) score_list_dict2, song_length_list2, average_score_dict2 = large_voca_score_calculation(valid_dataset=valid_dataset2, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device) score_list_dict3, song_length_list3, average_score_dict3 = large_voca_score_calculation(valid_dataset=valid_dataset3, config=config, model=model, model_type=args.model, mean=mean, std=std, 
device=device) for m in score_metrics: average_score = (np.sum(song_length_list1) * average_score_dict1[m] + np.sum(song_length_list2) *average_score_dict2[m] + np.sum(song_length_list3) * average_score_dict3[m]) / (np.sum(song_length_list1) + np.sum(song_length_list2) + np.sum(song_length_list3)) logger.info('==== %s score 1 is %.4f' % (m, average_score_dict1[m])) logger.info('==== %s score 2 is %.4f' % (m, average_score_dict2[m])) logger.info('==== %s score 3 is %.4f' % (m, average_score_dict3[m])) logger.info('==== %s mix average score is %.4f' % (m, average_score)) else: score_metrics = ['root', 'majmin'] score_list_dict1, song_length_list1, average_score_dict1 = root_majmin_score_calculation(valid_dataset=valid_dataset1, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device) score_list_dict2, song_length_list2, average_score_dict2 = root_majmin_score_calculation(valid_dataset=valid_dataset2, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device) score_list_dict3, song_length_list3, average_score_dict3 = root_majmin_score_calculation(valid_dataset=valid_dataset3, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device) for m in score_metrics: average_score = (np.sum(song_length_list1) * average_score_dict1[m] + np.sum(song_length_list2) *average_score_dict2[m] + np.sum(song_length_list3) * average_score_dict3[m]) / (np.sum(song_length_list1) + np.sum(song_length_list2) + np.sum(song_length_list3)) logger.info('==== %s score 1 is %.4f' % (m, average_score_dict1[m])) logger.info('==== %s score 2 is %.4f' % (m, average_score_dict2[m])) logger.info('==== %s score 3 is %.4f' % (m, average_score_dict3[m])) logger.info('==== %s mix average score is %.4f' % (m, average_score)) ================================================ FILE: train_crf.py ================================================ import os from torch import optim from utils import logger from audio_dataset import AudioDataset, AudioDataLoader from utils.tf_logger import TF_Logger from btc_model import * from baseline_models import CNN, CRNN, Crf from utils.hparams import HParams import argparse from utils.pytorch_utils import adjusting_learning_rate from utils.mir_eval_modules import large_voca_score_calculation_crf, root_majmin_score_calculation_crf import warnings warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=FutureWarning) logger.logging_verbosity(1) use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") parser = argparse.ArgumentParser() parser.add_argument('--index', type=int, help='Experiment Number', default='e') parser.add_argument('--kfold', type=int, help='5 fold (0,1,2,3,4)',default='e') parser.add_argument('--voca', type=bool, help='large voca is True', default=False) parser.add_argument('--model', type=str, default='crf') parser.add_argument('--pre_model', type=str, help='btc, cnn, crnn', default='e') parser.add_argument('--dataset1', type=str, help='Dataset', default='isophonic_221') parser.add_argument('--dataset2', type=str, help='Dataset', default='uspop_185') parser.add_argument('--dataset3', type=str, help='Dataset', default='robbiewilliams') parser.add_argument('--restore_epoch', type=int, default=1000) parser.add_argument('--early_stop', type=bool, help='no improvement during 10 epoch -> stop', default=True) args = parser.parse_args() config = HParams.load("run_config.yaml") if args.voca == True: config.feature['large_voca'] = True 
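# large-vocabulary mode: 170 chord classes instead of the 25 major/minor classes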
config.model['num_chords'] = 170 config.model['probs_out'] = True # Result save path asset_path = config.path['asset_path'] ckpt_path = config.path['ckpt_path'] result_path = config.path['result_path'] restore_epoch = args.restore_epoch experiment_num = str(args.index) ckpt_file_name = 'idx_'+experiment_num+'_%03d.pth.tar' tf_logger = TF_Logger(os.path.join(asset_path, 'tensorboard', 'idx_'+experiment_num)) logger.info("==== Experiment Number : %d " % args.index) if args.pre_model == 'cnn': config.experiment['batch_size'] = 20 # Data loader train_dataset1 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset1,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold) train_dataset2 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset2,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold) train_dataset3 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset3,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold) train_dataset = train_dataset1.__add__(train_dataset2).__add__(train_dataset3) valid_dataset1 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset1,), preprocessing=False, train=False, kfold=args.kfold) valid_dataset2 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset2,), preprocessing=False, train=False, kfold=args.kfold) valid_dataset3 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset3,), preprocessing=False, train=False, kfold=args.kfold) valid_dataset = valid_dataset1.__add__(valid_dataset2).__add__(valid_dataset3) train_dataloader = AudioDataLoader(dataset=train_dataset, batch_size=config.experiment['batch_size'], drop_last=False, shuffle=True) valid_dataloader = AudioDataLoader(dataset=valid_dataset, batch_size=config.experiment['batch_size'], drop_last=False) # Model and Optimizer if args.pre_model == 'cnn': pre_model = CNN(config=config.model).to(device) elif args.pre_model == 'crnn': pre_model = CRNN(config=config.model).to(device) elif args.pre_model == 'btc': pre_model = BTC_model(config=config.model).to(device) else: raise NotImplementedError if args.pre_model == 'cnn': if args.voca == False: if args.kfold == 0: load_ckpt_file_name = 'idx_0_%03d.pth.tar' load_restore_epoch = 10 else: if args.kfold == 0: load_ckpt_file_name = 'idx_1_%03d.pth.tar' load_restore_epoch = 10 else: raise NotImplementedError if os.path.isfile(os.path.join(asset_path, ckpt_path, load_ckpt_file_name % load_restore_epoch)): checkpoint = torch.load(os.path.join(asset_path, ckpt_path, load_ckpt_file_name % load_restore_epoch)) pre_model.load_state_dict(checkpoint['model']) logger.info("restore pre model with %d epochs" % load_restore_epoch) else: raise NotImplementedError # Fix Pre Model Parameters for param in pre_model.parameters(): param.requires_grad = False # Crf Model and Optimizer crf = Crf(num_chords=config.model['num_chords'], timestep=config.model['timestep']).to(device) optimizer = optim.Adam(filter(lambda p: p.requires_grad, crf.parameters()), lr=0.01, weight_decay=config.experiment['weight_decay'], betas=(0.9, 0.98), eps=1e-9) # Make asset directory if not os.path.exists(os.path.join(asset_path, ckpt_path)): os.makedirs(os.path.join(asset_path, ckpt_path)) os.makedirs(os.path.join(asset_path, result_path)) # Load model if os.path.isfile(os.path.join(asset_path, ckpt_path, ckpt_file_name % restore_epoch)): checkpoint = 
torch.load(os.path.join(asset_path, ckpt_path, ckpt_file_name % restore_epoch)) crf.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) epoch = checkpoint['epoch'] logger.info("restore model with %d epochs" % restore_epoch) else: logger.info("no checkpoint with %d epochs" % restore_epoch) restore_epoch = 0 # Global mean and variance calculate mp3_config = config.mp3 feature_config = config.feature mp3_string = "%d_%.1f_%.1f" % (mp3_config['song_hz'], mp3_config['inst_len'], mp3_config['skip_interval']) feature_string = "_%s_%d_%d_%d_" % ('cqt', feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length']) z_path = os.path.join(config.path['root_path'], 'result', mp3_string + feature_string + 'mix_kfold_'+ str(args.kfold) +'_normalization.pt') if os.path.exists(z_path): normalization = torch.load(z_path) mean = normalization['mean'] std = normalization['std'] logger.info("Global mean and std (k fold index %d) load complete" % args.kfold) else: mean = 0 square_mean = 0 k = 0 for i, data in enumerate(train_dataloader): features, input_percentages, chords, collapsed_chords, chord_lens, boundaries = data features = features.to(device) mean += torch.mean(features).item() square_mean += torch.mean(features.pow(2)).item() k += 1 square_mean = square_mean / k mean = mean / k std = np.sqrt(square_mean - mean * mean) normalization = dict() normalization['mean'] = mean normalization['std'] = std torch.save(normalization, z_path) logger.info("Global mean and std (training set, k fold index %d) calculation complete" % args.kfold) current_step = 0 best_acc = 0 before_acc = 0 early_stop_idx = 0 pre_model.eval() for epoch in range(restore_epoch, config.experiment['max_epoch']): # Training crf.train() train_loss_list = [] total = 0. correct = 0. second_correct = 0. for i, data in enumerate(train_dataloader): features, input_percentages, chords, collapsed_chords, chord_lens, boundaries = data features, chords = features.to(device), chords.to(device) features.requires_grad = True features = (features - mean) / std # forward features = features.squeeze(1).permute(0,2,1) optimizer.zero_grad() logits = pre_model(features, chords) if args.pre_model == 'crnn': logits = logits.detach() logits.requires_grad = True prediction, total_loss = crf(logits, chords) # save accuracy and loss total += chords.size(0) correct += (prediction == chords).type_as(chords).sum() train_loss_list.append(total_loss.item()) # optimize step total_loss.backward() optimizer.step() current_step += 1 # logging loss and accuracy using tensorboard result = {'loss/tr': np.mean(train_loss_list), 'acc/tr': correct.item() / total} for tag, value in result.items(): tf_logger.scalar_summary(tag, value, epoch+1) logger.info("training loss for %d epoch: %.4f" % (epoch + 1, np.mean(train_loss_list))) logger.info("training accuracy for %d epoch: %.4f" % (epoch + 1, (correct.item() / total))) # Validation with torch.no_grad(): crf.eval() val_total = 0. val_correct = 0. val_second_correct = 0. 
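# CRF validation: Viterbi-decode the frozen pre-model's logits and accumulate loss and frame accuracy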
validation_loss = 0 n = 0 for i, data in enumerate(valid_dataloader): val_features, val_input_percentages, val_chords, val_collapsed_chords, val_chord_lens, val_boundaries = data val_features, val_chords = val_features.to(device), val_chords.to(device) val_features = (val_features - mean) / std val_features = val_features.squeeze(1).permute(0, 2, 1) val_logits = pre_model(val_features, val_chords) val_prediction, val_loss = crf(val_logits, val_chords) val_total += val_chords.size(0) val_correct += (val_prediction == val_chords).type_as(val_chords).sum() validation_loss += val_loss.item() n += 1 # logging loss and accuracy using tensorboard validation_loss /= n result = {'loss/val': validation_loss, 'acc/val': val_correct.item() / val_total} for tag, value in result.items(): tf_logger.scalar_summary(tag, value, epoch + 1) logger.info("validation loss(%d): %.4f" % (epoch + 1, validation_loss)) logger.info("validation accuracy(%d): %.4f" % (epoch + 1, (val_correct.item() / val_total))) current_acc = val_correct.item() / val_total if best_acc < val_correct.item() / val_total: early_stop_idx = 0 best_acc = val_correct.item() / val_total logger.info('==== best accuracy is %.4f and epoch is %d' % (best_acc, epoch + 1)) logger.info('saving model, Epoch %d, step %d' % (epoch + 1, current_step + 1)) model_save_path = os.path.join(asset_path, 'model', ckpt_file_name % (epoch + 1)) state_dict = {'model': crf.state_dict(),'optimizer': optimizer.state_dict(),'epoch': epoch} torch.save(state_dict, model_save_path) last_best_epoch = epoch + 1 # save model elif (epoch + 1) % config.experiment['save_step'] == 0: logger.info('saving model, Epoch %d, step %d' % (epoch + 1, current_step + 1)) model_save_path = os.path.join(asset_path, 'model', ckpt_file_name % (epoch + 1)) state_dict = {'model': crf.state_dict(),'optimizer': optimizer.state_dict(),'epoch': epoch} torch.save(state_dict, model_save_path) early_stop_idx += 1 else: early_stop_idx += 1 if (args.early_stop == True) and (early_stop_idx > 5): logger.info('==== early stopped and epoch is %d' % (epoch + 1)) break # learning rate decay if before_acc > current_acc: adjusting_learning_rate(optimizer=optimizer, factor=0.95, min_lr=5e-6) before_acc = current_acc # Load model if os.path.isfile(os.path.join(asset_path, ckpt_path, ckpt_file_name % last_best_epoch)): checkpoint = torch.load(os.path.join(asset_path, ckpt_path, ckpt_file_name % last_best_epoch)) crf.load_state_dict(checkpoint['model']) logger.info("last best restore model with %d epochs" % last_best_epoch) else: raise NotImplementedError # score Validation if args.voca == True: score_metrics = ['root', 'thirds', 'triads', 'sevenths', 'tetrads', 'majmin', 'mirex'] score_list_dict1, song_length_list1, average_score_dict1 = large_voca_score_calculation_crf(valid_dataset=valid_dataset1, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device) score_list_dict2, song_length_list2, average_score_dict2 = large_voca_score_calculation_crf(valid_dataset=valid_dataset2, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device) score_list_dict3, song_length_list3, average_score_dict3 = large_voca_score_calculation_crf(valid_dataset=valid_dataset3, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device) for m in score_metrics: average_score = (np.sum(song_length_list1) * average_score_dict1[m] + np.sum(song_length_list2) *average_score_dict2[m] + 
np.sum(song_length_list3) * average_score_dict3[m]) / (np.sum(song_length_list1) + np.sum(song_length_list2) + np.sum(song_length_list3)) logger.info('==== %s score 1 is %.4f' % (m, average_score_dict1[m])) logger.info('==== %s score 2 is %.4f' % (m, average_score_dict2[m])) logger.info('==== %s score 3 is %.4f' % (m, average_score_dict3[m])) logger.info('==== %s mix average score is %.4f' % (m, average_score)) else: score_metrics = ['root', 'majmin'] score_list_dict1, song_length_list1, average_score_dict1 = root_majmin_score_calculation_crf(valid_dataset=valid_dataset1, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device) score_list_dict2, song_length_list2, average_score_dict2 = root_majmin_score_calculation_crf(valid_dataset=valid_dataset2, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device) score_list_dict3, song_length_list3, average_score_dict3 = root_majmin_score_calculation_crf(valid_dataset=valid_dataset3, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device) for m in score_metrics: average_score = (np.sum(song_length_list1) * average_score_dict1[m] + np.sum(song_length_list2) *average_score_dict2[m] + np.sum(song_length_list3) * average_score_dict3[m]) / (np.sum(song_length_list1) + np.sum(song_length_list2) + np.sum(song_length_list3)) logger.info('==== %s score 1 is %.4f' % (m, average_score_dict1[m])) logger.info('==== %s score 2 is %.4f' % (m, average_score_dict2[m])) logger.info('==== %s score 3 is %.4f' % (m, average_score_dict3[m])) logger.info('==== %s mix average score is %.4f' % (m, average_score)) ================================================ FILE: utils/__init__.py ================================================ ================================================ FILE: utils/chords.py ================================================ # encoding: utf-8 """ This module contains chord evaluation functionality. It provides the evaluation measures used for the MIREX ACE task, and tries to follow [1]_ and [2]_ as closely as possible. Notes ----- This implementation tries to follow the references and their implementation (e.g., https://github.com/jpauwels/MusOOEvaluator for [2]_). However, there are some known (and possibly some unknown) differences. If you find one not listed in the following, please file an issue: - Detected chord segments are adjusted to fit the length of the annotations. In particular, this means that, if necessary, filler segments of 'no chord' are added at beginnings and ends. This can result in different segmentation scores compared to the original implementation. References ---------- .. [1] Christopher Harte, "Towards Automatic Extraction of Harmony Information from Music Signals." Dissertation, Department for Electronic Engineering, Queen Mary University of London, 2010. .. [2] Johan Pauwels and Geoffroy Peeters. "Evaluating Automatically Estimated Chord Sequences." In Proceedings of ICASSP 2013, Vancouver, Canada, 2013. 
""" import numpy as np import pandas as pd import mir_eval CHORD_DTYPE = [('root', np.int), ('bass', np.int), ('intervals', np.int, (12,)), ('is_major',np.bool)] CHORD_ANN_DTYPE = [('start', np.float), ('end', np.float), ('chord', CHORD_DTYPE)] NO_CHORD = (-1, -1, np.zeros(12, dtype=np.int), False) UNKNOWN_CHORD = (-1, -1, np.ones(12, dtype=np.int) * -1, False) PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'] def idx_to_chord(idx): if idx == 24: return "-" elif idx == 25: return u"\u03B5" minmaj = idx % 2 root = idx // 2 return PITCH_CLASS[root] + ("M" if minmaj == 0 else "m") class Chords: def __init__(self): self._shorthands = { 'maj': self.interval_list('(1,3,5)'), 'min': self.interval_list('(1,b3,5)'), 'dim': self.interval_list('(1,b3,b5)'), 'aug': self.interval_list('(1,3,#5)'), 'maj7': self.interval_list('(1,3,5,7)'), 'min7': self.interval_list('(1,b3,5,b7)'), '7': self.interval_list('(1,3,5,b7)'), '6': self.interval_list('(1,6)'), # custom '5': self.interval_list('(1,5)'), '4': self.interval_list('(1,4)'), # custom '1': self.interval_list('(1)'), 'dim7': self.interval_list('(1,b3,b5,bb7)'), 'hdim7': self.interval_list('(1,b3,b5,b7)'), 'minmaj7': self.interval_list('(1,b3,5,7)'), 'maj6': self.interval_list('(1,3,5,6)'), 'min6': self.interval_list('(1,b3,5,6)'), '9': self.interval_list('(1,3,5,b7,9)'), 'maj9': self.interval_list('(1,3,5,7,9)'), 'min9': self.interval_list('(1,b3,5,b7,9)'), 'sus2': self.interval_list('(1,2,5)'), 'sus4': self.interval_list('(1,4,5)'), '11': self.interval_list('(1,3,5,b7,9,11)'), 'min11': self.interval_list('(1,b3,5,b7,9,11)'), '13': self.interval_list('(1,3,5,b7,13)'), 'maj13': self.interval_list('(1,3,5,7,13)'), 'min13': self.interval_list('(1,b3,5,b7,13)') } def chords(self, labels): """ Transform a list of chord labels into an array of internal numeric representations. Parameters ---------- labels : list List of chord labels (str). Returns ------- chords : numpy.array Structured array with columns 'root', 'bass', and 'intervals', containing a numeric representation of chords. """ crds = np.zeros(len(labels), dtype=CHORD_DTYPE) cache = {} for i, lbl in enumerate(labels): cv = cache.get(lbl, None) if cv is None: cv = self.chord(lbl) cache[lbl] = cv crds[i] = cv return crds def label_error_modify(self, label): if label == 'Emin/4': label = 'E:min/4' elif label == 'A7/3': label = 'A:7/3' elif label == 'Bb7/3': label = 'Bb:7/3' elif label == 'Bb7/5': label = 'Bb:7/5' elif label.find(':') == -1: if label.find('min') != -1: label = label[:label.find('min')] + ':' + label[label.find('min'):] return label def chord(self, label): """ Transform a chord label into the internal numeric represenation of (root, bass, intervals array). Parameters ---------- label : str Chord label. Returns ------- chord : tuple Numeric representation of the chord: (root, bass, intervals array). 
""" try: is_major = False if label == 'N': return NO_CHORD if label == 'X': return UNKNOWN_CHORD label = self.label_error_modify(label) c_idx = label.find(':') s_idx = label.find('/') if c_idx == -1: quality_str = 'maj' if s_idx == -1: root_str = label bass_str = '' else: root_str = label[:s_idx] bass_str = label[s_idx + 1:] else: root_str = label[:c_idx] if s_idx == -1: quality_str = label[c_idx + 1:] bass_str = '' else: quality_str = label[c_idx + 1:s_idx] bass_str = label[s_idx + 1:] root = self.pitch(root_str) bass = self.interval(bass_str) if bass_str else 0 ivs = self.chord_intervals(quality_str) ivs[bass] = 1 if 'min' in quality_str: is_major = False else: is_major = True except Exception as e: print(e, label) return root, bass, ivs, is_major _l = [0, 1, 1, 0, 1, 1, 1] _chroma_id = (np.arange(len(_l) * 2) + 1) + np.array(_l + _l).cumsum() - 1 def modify(self, base_pitch, modifier): """ Modify a pitch class in integer representation by a given modifier string. A modifier string can be any sequence of 'b' (one semitone down) and '#' (one semitone up). Parameters ---------- base_pitch : int Pitch class as integer. modifier : str String of modifiers ('b' or '#'). Returns ------- modified_pitch : int Modified root note. """ for m in modifier: if m == 'b': base_pitch -= 1 elif m == '#': base_pitch += 1 else: raise ValueError('Unknown modifier: {}'.format(m)) return base_pitch def pitch(self, pitch_str): """ Convert a string representation of a pitch class (consisting of root note and modifiers) to an integer representation. Parameters ---------- pitch_str : str String representation of a pitch class. Returns ------- pitch : int Integer representation of a pitch class. """ return self.modify(self._chroma_id[(ord(pitch_str[0]) - ord('C')) % 7], pitch_str[1:]) % 12 def interval(self, interval_str): """ Convert a string representation of a musical interval into a pitch class (e.g. a minor seventh 'b7' into 10, because it is 10 semitones above its base note). Parameters ---------- interval_str : str Musical interval. Returns ------- pitch_class : int Number of semitones to base note of interval. """ for i, c in enumerate(interval_str): if c.isdigit(): return self.modify(self._chroma_id[int(interval_str[i:]) - 1], interval_str[:i]) % 12 def interval_list(self, intervals_str, given_pitch_classes=None): """ Convert a list of intervals given as string to a binary pitch class representation. For example, 'b3, 5' would become [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]. Parameters ---------- intervals_str : str List of intervals as comma-separated string (e.g. 'b3, 5'). given_pitch_classes : None or numpy array If None, start with empty pitch class array, if numpy array of length 12, this array will be modified. Returns ------- pitch_classes : numpy array Binary pitch class representation of intervals. """ if given_pitch_classes is None: given_pitch_classes = np.zeros(12, dtype=np.int) for int_def in intervals_str[1:-1].split(','): int_def = int_def.strip() if int_def[0] == '*': given_pitch_classes[self.interval(int_def[1:])] = 0 else: given_pitch_classes[self.interval(int_def)] = 1 return given_pitch_classes # mapping of shorthand interval notations to the actual interval representation def chord_intervals(self, quality_str): """ Convert a chord quality string to a pitch class representation. For example, 'maj' becomes [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]. Parameters ---------- quality_str : str String defining the chord quality. 
Returns
        -------
        pitch_classes : numpy array
            Binary pitch class representation of chord quality.

        """
        list_idx = quality_str.find('(')
        if list_idx == -1:
            return self._shorthands[quality_str].copy()
        if list_idx != 0:
            ivs = self._shorthands[quality_str[:list_idx]].copy()
        else:
            ivs = np.zeros(12, dtype=np.int)

        return self.interval_list(quality_str[list_idx:], ivs)

    def load_chords(self, filename):
        """
        Load chords from a text file. The chord must follow the syntax defined
        in [1]_.

        Parameters
        ----------
        filename : str
            File containing chord segments.

        Returns
        -------
        crds : numpy structured array
            Structured array with columns "start", "end", and "chord",
            containing the beginning, end, and chord definition of chord
            segments.

        References
        ----------
        .. [1] Christopher Harte, "Towards Automatic Extraction of Harmony
               Information from Music Signals." Dissertation, Department for
               Electronic Engineering, Queen Mary University of London, 2010.

        """
        start, end, chord_labels = [], [], []
        with open(filename, 'r') as f:
            for line in f:
                if line:
                    splits = line.split()
                    if len(splits) == 3:
                        s = splits[0]
                        e = splits[1]
                        l = splits[2]
                        start.append(float(s))
                        end.append(float(e))
                        chord_labels.append(l)
        crds = np.zeros(len(start), dtype=CHORD_ANN_DTYPE)
        crds['start'] = start
        crds['end'] = end
        crds['chord'] = self.chords(chord_labels)
        return crds

    def reduce_to_triads(self, chords, keep_bass=False):
        """
        Reduce chords to triads.

        The function follows the reduction rules implemented in [1]_. If a
        chord does not contain a third, major second or fourth, it is reduced
        to a power chord. If it contains neither a third nor a fifth, it is
        reduced to a single note "chord".

        Parameters
        ----------
        chords : numpy structured array
            Chords to be reduced.
        keep_bass : bool
            Indicates whether to keep the bass note or set it to 0.

        Returns
        -------
        reduced_chords : numpy structured array
            Chords reduced to triads.

        References
        ----------
        .. [1] Johan Pauwels and Geoffroy Peeters. "Evaluating Automatically
               Estimated Chord Sequences." In Proceedings of ICASSP 2013,
               Vancouver, Canada, 2013.
""" unison = chords['intervals'][:, 0].astype(bool) maj_sec = chords['intervals'][:, 2].astype(bool) min_third = chords['intervals'][:, 3].astype(bool) maj_third = chords['intervals'][:, 4].astype(bool) perf_fourth = chords['intervals'][:, 5].astype(bool) dim_fifth = chords['intervals'][:, 6].astype(bool) perf_fifth = chords['intervals'][:, 7].astype(bool) aug_fifth = chords['intervals'][:, 8].astype(bool) no_chord = (chords['intervals'] == NO_CHORD[-1]).all(axis=1) reduced_chords = chords.copy() ivs = reduced_chords['intervals'] ivs[~no_chord] = self.interval_list('(1)') ivs[unison & perf_fifth] = self.interval_list('(1,5)') ivs[~perf_fourth & maj_sec] = self._shorthands['sus2'] ivs[perf_fourth & ~maj_sec] = self._shorthands['sus4'] ivs[min_third] = self._shorthands['min'] ivs[min_third & aug_fifth & ~perf_fifth] = self.interval_list('(1,b3,#5)') ivs[min_third & dim_fifth & ~perf_fifth] = self._shorthands['dim'] ivs[maj_third] = self._shorthands['maj'] ivs[maj_third & dim_fifth & ~perf_fifth] = self.interval_list('(1,3,b5)') ivs[maj_third & aug_fifth & ~perf_fifth] = self._shorthands['aug'] if not keep_bass: reduced_chords['bass'] = 0 else: # remove bass notes if they are not part of the intervals anymore reduced_chords['bass'] *= ivs[range(len(reduced_chords)), reduced_chords['bass']] # keep -1 in bass for no chords reduced_chords['bass'][no_chord] = -1 return reduced_chords def convert_to_id(self, root, is_major): if root == -1: return 24 else: if is_major: return root * 2 else: return root * 2 + 1 def get_converted_chord(self, filename): loaded_chord = self.load_chords(filename) triads = self.reduce_to_triads(loaded_chord['chord']) df = self.assign_chord_id(triads) df['start'] = loaded_chord['start'] df['end'] = loaded_chord['end'] return df def assign_chord_id(self, entry): # maj, min chord only # if you want to add other chord, change this part and get_converted_chord(reduce_to_triads) df = pd.DataFrame(data=entry[['root', 'is_major']]) df['chord_id'] = df.apply(lambda row: self.convert_to_id(row['root'], row['is_major']), axis=1) return df def convert_to_id_voca(self, root, quality): if root == -1: return 169 else: if quality == 'min': return root * 14 elif quality == 'maj': return root * 14 + 1 elif quality == 'dim': return root * 14 + 2 elif quality == 'aug': return root * 14 + 3 elif quality == 'min6': return root * 14 + 4 elif quality == 'maj6': return root * 14 + 5 elif quality == 'min7': return root * 14 + 6 elif quality == 'minmaj7': return root * 14 + 7 elif quality == 'maj7': return root * 14 + 8 elif quality == '7': return root * 14 + 9 elif quality == 'dim7': return root * 14 + 10 elif quality == 'hdim7': return root * 14 + 11 elif quality == 'sus2': return root * 14 + 12 elif quality == 'sus4': return root * 14 + 13 else: return 168 def get_converted_chord_voca(self, filename): loaded_chord = self.load_chords(filename) triads = self.reduce_to_triads(loaded_chord['chord']) df = pd.DataFrame(data=triads[['root', 'is_major']]) (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(filename) ref_labels = self.lab_file_error_modify(ref_labels) idxs = list() for i in ref_labels: chord_root, quality, scale_degrees, bass = mir_eval.chord.split(i, reduce_extended_chords=True) root, bass, ivs, is_major = self.chord(i) idxs.append(self.convert_to_id_voca(root=root, quality=quality)) df['chord_id'] = idxs df['start'] = loaded_chord['start'] df['end'] = loaded_chord['end'] return df def lab_file_error_modify(self, ref_labels): for i in range(len(ref_labels)): if 
ref_labels[i][-2:] == ':4':
                ref_labels[i] = ref_labels[i].replace(':4', ':sus4')
            elif ref_labels[i][-2:] == ':6':
                ref_labels[i] = ref_labels[i].replace(':6', ':maj6')
            elif ref_labels[i][-4:] == ':6/2':
                ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2')
            elif ref_labels[i] == 'Emin/4':
                ref_labels[i] = 'E:min/4'
            elif ref_labels[i] == 'A7/3':
                ref_labels[i] = 'A:7/3'
            elif ref_labels[i] == 'Bb7/3':
                ref_labels[i] = 'Bb:7/3'
            elif ref_labels[i] == 'Bb7/5':
                ref_labels[i] = 'Bb:7/5'
            elif ref_labels[i].find(':') == -1:
                if ref_labels[i].find('min') != -1:
                    ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):]
        return ref_labels



================================================
FILE: utils/hparams.py
================================================
import yaml

# TODO: the add() method should be changed so that existing keys are not overwritten.
class HParams(object):
    # Hyperparameter class using yaml
    def __init__(self, **kwargs):
        self.__dict__ = kwargs

    def add(self, **kwargs):
        # change needed: if the key already exists, do not update it.
        self.__dict__.update(kwargs)

    def update(self, **kwargs):
        self.__dict__.update(kwargs)
        return self

    def save(self, path):
        with open(path, 'w') as f:
            yaml.dump(self.__dict__, f)
        return self

    def __repr__(self):
        return '\nHyperparameters:\n' + '\n'.join([' {}={}'.format(k, v) for k, v in self.__dict__.items()])

    @classmethod
    def load(cls, path):
        with open(path, 'r') as f:
            # safe_load avoids the arbitrary-object construction of a bare
            # yaml.load(); the config files here are plain mappings.
            return cls(**yaml.safe_load(f))

if __name__ == '__main__':
    hparams = HParams.load('hparams.yaml')
    print(hparams)
    d = {"MemoryNetwork": 0, "c": 1}
    hparams.add(**d)
    print(hparams)



================================================
FILE: utils/logger.py
================================================
import logging
import os
import sys
import time

project_name = os.getcwd().split('/')[-1]
_logger = logging.getLogger(project_name)
_logger.addHandler(logging.StreamHandler())


def _log_prefix():
    # Returns (filename, line number) for the stack frame.
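    # Illustrative example (hypothetical values): a call to info('start')
    # made from train_crf.py line 123 on March 5th at 14:05:07.123 would
    # emit a line shaped like
    #
    #     I BTC-ISMIR19 03-05 14:05:07.123 train_crf.py:123] start
    #
    # because the helper below walks up the call stack until it leaves this
    # file, so the prefix points at the external caller of debug()/info()/etc.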
    def _get_file_line():
        # pylint: disable=protected-access
        # noinspection PyProtectedMember
        f = sys._getframe()
        # pylint: enable=protected-access
        our_file = f.f_code.co_filename
        f = f.f_back
        while f:
            code = f.f_code
            if code.co_filename != our_file:
                return code.co_filename, f.f_lineno
            f = f.f_back
        return '', 0

    # current time
    now = time.time()
    now_tuple = time.localtime(now)
    now_millisecond = int(1e3 * (now % 1.0))

    # current filename and line
    filename, line = _get_file_line()
    basename = os.path.basename(filename)

    s = '%02d-%02d %02d:%02d:%02d.%03d %s:%d] ' % (
        now_tuple[1],  # month
        now_tuple[2],  # day
        now_tuple[3],  # hour
        now_tuple[4],  # min
        now_tuple[5],  # sec
        now_millisecond,
        basename,
        line)
    return s


def logging_verbosity(verbosity=0):
    _logger.setLevel(verbosity)


def debug(msg, *args, **kwargs):
    _logger.debug('D ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)


def info(msg, *args, **kwargs):
    _logger.info('I ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)


def warn(msg, *args, **kwargs):
    _logger.warning('W ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)


def error(msg, *args, **kwargs):
    _logger.error('E ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)


def fatal(msg, *args, **kwargs):
    _logger.fatal('F ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)



================================================
FILE: utils/mir_eval_modules.py
================================================
import numpy as np
import librosa
import mir_eval
import torch
import os

idx2chord = ['C', 'C:min', 'C#', 'C#:min', 'D', 'D:min', 'D#', 'D#:min', 'E', 'E:min',
             'F', 'F:min', 'F#', 'F#:min', 'G', 'G:min', 'G#', 'G#:min', 'A', 'A:min',
             'A#', 'A#:min', 'B', 'B:min', 'N']

root_list = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
quality_list = ['min', 'maj', 'dim', 'aug', 'min6', 'maj6', 'min7', 'minmaj7', 'maj7',
                '7', 'dim7', 'hdim7', 'sus2', 'sus4']

def idx2voca_chord():
    idx2voca_chord = {}
    idx2voca_chord[169] = 'N'
    idx2voca_chord[168] = 'X'
    for i in range(168):
        root = i // 14
        root = root_list[root]
        quality = i % 14
        quality = quality_list[quality]
        if i % 14 != 1:
            chord = root + ':' + quality
        else:
            chord = root
        idx2voca_chord[i] = chord
    return idx2voca_chord

def audio_file_to_features(audio_file, config):
    original_wav, sr = librosa.load(audio_file, sr=config.mp3['song_hz'], mono=True)
    current_sec_hz = 0
    while len(original_wav) > current_sec_hz + config.mp3['song_hz'] * config.mp3['inst_len']:
        start_idx = int(current_sec_hz)
        end_idx = int(current_sec_hz + config.mp3['song_hz'] * config.mp3['inst_len'])
        tmp = librosa.cqt(original_wav[start_idx:end_idx], sr=sr,
                          n_bins=config.feature['n_bins'],
                          bins_per_octave=config.feature['bins_per_octave'],
                          hop_length=config.feature['hop_length'])
        if start_idx == 0:
            feature = tmp
        else:
            feature = np.concatenate((feature, tmp), axis=1)
        current_sec_hz = end_idx
    tmp = librosa.cqt(original_wav[current_sec_hz:], sr=sr,
                      n_bins=config.feature['n_bins'],
                      bins_per_octave=config.feature['bins_per_octave'],
                      hop_length=config.feature['hop_length'])
    feature = np.concatenate((feature, tmp), axis=1)
    feature = np.log(np.abs(feature) + 1e-6)
    feature_per_second = config.mp3['inst_len'] / config.model['timestep']
    song_length_second = len(original_wav) / config.mp3['song_hz']
    return feature, feature_per_second, song_length_second

# Audio files in wav or mp3 format
def get_audio_paths(audio_dir):
    return [os.path.join(root, fname)
            for (root, dir_names, file_names) in os.walk(audio_dir, followlinks=True)
            for fname in
file_names if (fname.lower().endswith('.wav') or fname.lower().endswith('.mp3'))] class metrics(): def __init__(self): super(metrics, self).__init__() self.score_metrics = ['root', 'thirds', 'triads', 'sevenths', 'tetrads', 'majmin', 'mirex'] self.score_list_dict = dict() for i in self.score_metrics: self.score_list_dict[i] = list() self.average_score = dict() def score(self, metric, gt_path, est_path): if metric == 'root': score = self.root_score(gt_path,est_path) elif metric == 'thirds': score = self.thirds_score(gt_path,est_path) elif metric == 'triads': score = self.triads_score(gt_path,est_path) elif metric == 'sevenths': score = self.sevenths_score(gt_path,est_path) elif metric == 'tetrads': score = self.tetrads_score(gt_path,est_path) elif metric == 'majmin': score = self.majmin_score(gt_path,est_path) elif metric == 'mirex': score = self.mirex_score(gt_path,est_path) else: raise NotImplementedError return score def root_score(self, gt_path, est_path): (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path) ref_labels = lab_file_error_modify(ref_labels) (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path) est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(), ref_intervals.max(), mir_eval.chord.NO_CHORD, mir_eval.chord.NO_CHORD) (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels, est_intervals, est_labels) durations = mir_eval.util.intervals_to_durations(intervals) comparisons = mir_eval.chord.root(ref_labels, est_labels) score = mir_eval.chord.weighted_accuracy(comparisons, durations) return score def thirds_score(self, gt_path, est_path): (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path) ref_labels = lab_file_error_modify(ref_labels) (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path) est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(), ref_intervals.max(), mir_eval.chord.NO_CHORD, mir_eval.chord.NO_CHORD) (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels, est_intervals, est_labels) durations = mir_eval.util.intervals_to_durations(intervals) comparisons = mir_eval.chord.thirds(ref_labels, est_labels) score = mir_eval.chord.weighted_accuracy(comparisons, durations) return score def triads_score(self, gt_path, est_path): (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path) ref_labels = lab_file_error_modify(ref_labels) (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path) est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(), ref_intervals.max(), mir_eval.chord.NO_CHORD, mir_eval.chord.NO_CHORD) (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels, est_intervals, est_labels) durations = mir_eval.util.intervals_to_durations(intervals) comparisons = mir_eval.chord.triads(ref_labels, est_labels) score = mir_eval.chord.weighted_accuracy(comparisons, durations) return score def sevenths_score(self, gt_path, est_path): (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path) ref_labels = lab_file_error_modify(ref_labels) (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path) est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(), ref_intervals.max(), 
mir_eval.chord.NO_CHORD, mir_eval.chord.NO_CHORD) (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels, est_intervals, est_labels) durations = mir_eval.util.intervals_to_durations(intervals) comparisons = mir_eval.chord.sevenths(ref_labels, est_labels) score = mir_eval.chord.weighted_accuracy(comparisons, durations) return score def tetrads_score(self, gt_path, est_path): (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path) ref_labels = lab_file_error_modify(ref_labels) (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path) est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(), ref_intervals.max(), mir_eval.chord.NO_CHORD, mir_eval.chord.NO_CHORD) (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels, est_intervals, est_labels) durations = mir_eval.util.intervals_to_durations(intervals) comparisons = mir_eval.chord.tetrads(ref_labels, est_labels) score = mir_eval.chord.weighted_accuracy(comparisons, durations) return score def majmin_score(self, gt_path, est_path): (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path) ref_labels = lab_file_error_modify(ref_labels) (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path) est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(), ref_intervals.max(), mir_eval.chord.NO_CHORD, mir_eval.chord.NO_CHORD) (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels, est_intervals, est_labels) durations = mir_eval.util.intervals_to_durations(intervals) comparisons = mir_eval.chord.majmin(ref_labels, est_labels) score = mir_eval.chord.weighted_accuracy(comparisons, durations) return score def mirex_score(self, gt_path, est_path): (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path) ref_labels = lab_file_error_modify(ref_labels) (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path) est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(), ref_intervals.max(), mir_eval.chord.NO_CHORD, mir_eval.chord.NO_CHORD) (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels, est_intervals, est_labels) durations = mir_eval.util.intervals_to_durations(intervals) comparisons = mir_eval.chord.mirex(ref_labels, est_labels) score = mir_eval.chord.weighted_accuracy(comparisons, durations) return score def lab_file_error_modify(ref_labels): for i in range(len(ref_labels)): if ref_labels[i][-2:] == ':4': ref_labels[i] = ref_labels[i].replace(':4', ':sus4') elif ref_labels[i][-2:] == ':6': ref_labels[i] = ref_labels[i].replace(':6', ':maj6') elif ref_labels[i][-4:] == ':6/2': ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2') elif ref_labels[i] == 'Emin/4': ref_labels[i] = 'E:min/4' elif ref_labels[i] == 'A7/3': ref_labels[i] = 'A:7/3' elif ref_labels[i] == 'Bb7/3': ref_labels[i] = 'Bb:7/3' elif ref_labels[i] == 'Bb7/5': ref_labels[i] = 'Bb:7/5' elif ref_labels[i].find(':') == -1: if ref_labels[i].find('min') != -1: ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):] return ref_labels def root_majmin_score_calculation(valid_dataset, config, mean, std, device, model, model_type, verbose=False): valid_song_names = valid_dataset.song_names paths = 
valid_dataset.preprocessor.get_all_files() metrics_ = metrics() song_length_list = list() for path in paths: song_name, lab_file_path, mp3_file_path, _ = path if not song_name in valid_song_names: continue try: n_timestep = config.model['timestep'] feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config) feature = feature.T feature = (feature - mean) / std time_unit = feature_per_second num_pad = n_timestep - (feature.shape[0] % n_timestep) feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0) num_instance = feature.shape[0] // n_timestep start_time = 0.0 lines = [] with torch.no_grad(): model.eval() feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device) for t in range(num_instance): if model_type == 'btc': encoder_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :]) prediction, _ = model.output_layer(encoder_output) prediction = prediction.squeeze() elif model_type == 'cnn' or model_type =='crnn': prediction, _, _, _ = model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device)) for i in range(n_timestep): if t == 0 and i == 0: prev_chord = prediction[i].item() continue if prediction[i].item() != prev_chord: lines.append( '%.6f %.6f %s\n' % ( start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord])) start_time = time_unit * (n_timestep * t + i) prev_chord = prediction[i].item() if t == num_instance - 1 and i + num_pad == n_timestep: if start_time != time_unit * (n_timestep * t + i): lines.append( '%.6f %.6f %s\n' % ( start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord])) break pid = os.getpid() tmp_path = 'tmp_' + str(pid) + '.lab' with open(tmp_path, 'w') as f: for line in lines: f.write(line) root_majmin = ['root', 'majmin'] for m in root_majmin: metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path)) song_length_list.append(song_length_second) if verbose: for m in root_majmin: print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1])) except: print('song name %s\' lab file error' % song_name) tmp = song_length_list / np.sum(song_length_list) for m in root_majmin: metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp)) return metrics_.score_list_dict, song_length_list, metrics_.average_score def root_majmin_score_calculation_crf(valid_dataset, config, mean, std, device, pre_model, model, model_type, verbose=False): valid_song_names = valid_dataset.song_names paths = valid_dataset.preprocessor.get_all_files() metrics_ = metrics() song_length_list = list() for path in paths: song_name, lab_file_path, mp3_file_path, _ = path if not song_name in valid_song_names: continue try: n_timestep = config.model['timestep'] feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config) feature = feature.T feature = (feature - mean) / std time_unit = feature_per_second num_pad = n_timestep - (feature.shape[0] % n_timestep) feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0) num_instance = feature.shape[0] // n_timestep start_time = 0.0 lines = [] with torch.no_grad(): model.eval() feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device) for t in range(num_instance): if (model_type == 'cnn') or (model_type == 'crnn') or (model_type == 'btc'): logits = pre_model(feature[:, n_timestep * t:n_timestep * 
(t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device)) prediction, _ = model(logits, torch.randint(config.model['num_chords'], (n_timestep,)).to(device)) else: raise NotImplementedError for i in range(n_timestep): if t == 0 and i == 0: prev_chord = prediction[i].item() continue if prediction[i].item() != prev_chord: lines.append( '%.6f %.6f %s\n' % ( start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord])) start_time = time_unit * (n_timestep * t + i) prev_chord = prediction[i].item() if t == num_instance - 1 and i + num_pad == n_timestep: if start_time != time_unit * (n_timestep * t + i): lines.append( '%.6f %.6f %s\n' % ( start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord])) break pid = os.getpid() tmp_path = 'tmp_' + str(pid) + '.lab' with open(tmp_path, 'w') as f: for line in lines: f.write(line) root_majmin = ['root', 'majmin'] for m in root_majmin: metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path)) song_length_list.append(song_length_second) if verbose: for m in root_majmin: print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1])) except: print('song name %s\' lab file error' % song_name) tmp = song_length_list / np.sum(song_length_list) for m in root_majmin: metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp)) return metrics_.score_list_dict, song_length_list, metrics_.average_score def large_voca_score_calculation(valid_dataset, config, mean, std, device, model, model_type, verbose=False): idx2voca = idx2voca_chord() valid_song_names = valid_dataset.song_names paths = valid_dataset.preprocessor.get_all_files() metrics_ = metrics() song_length_list = list() for path in paths: song_name, lab_file_path, mp3_file_path, _ = path if not song_name in valid_song_names: continue try: n_timestep = config.model['timestep'] feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config) feature = feature.T feature = (feature - mean) / std time_unit = feature_per_second num_pad = n_timestep - (feature.shape[0] % n_timestep) feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0) num_instance = feature.shape[0] // n_timestep start_time = 0.0 lines = [] with torch.no_grad(): model.eval() feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device) for t in range(num_instance): if model_type == 'btc': encoder_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :]) prediction, _ = model.output_layer(encoder_output) prediction = prediction.squeeze() elif model_type == 'cnn' or model_type =='crnn': prediction, _, _, _ = model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device)) for i in range(n_timestep): if t == 0 and i == 0: prev_chord = prediction[i].item() continue if prediction[i].item() != prev_chord: lines.append( '%.6f %.6f %s\n' % ( start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord])) start_time = time_unit * (n_timestep * t + i) prev_chord = prediction[i].item() if t == num_instance - 1 and i + num_pad == n_timestep: if start_time != time_unit * (n_timestep * t + i): lines.append( '%.6f %.6f %s\n' % ( start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord])) break pid = os.getpid() tmp_path = 'tmp_' + str(pid) + '.lab' with open(tmp_path, 'w') as f: for line in lines: f.write(line) for m in metrics_.score_metrics: 
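# The per-song scores appended in this loop are later combined into a
# song-length weighted average: with per-song scores s_i and lengths l_i
# (in seconds),
#     average_score[m] = sum_i(l_i * s_i) / sum_i(l_i)
# A minimal numpy sketch of that reduction (names are illustrative):
#     weights = np.asarray(lengths) / np.sum(lengths)
#     average = float(np.sum(np.asarray(scores) * weights))
# Design note: the seven *_score methods of the metrics class differ only
# in which mir_eval comparison they call, so they could be collapsed into
# one method via, e.g., getattr(mir_eval.chord, metric)(ref_labels, est_labels).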
metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path)) song_length_list.append(song_length_second) if verbose: for m in metrics_.score_metrics: print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1])) except: print('song name %s\' lab file error' % song_name) tmp = song_length_list / np.sum(song_length_list) for m in metrics_.score_metrics: metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp)) return metrics_.score_list_dict, song_length_list, metrics_.average_score def large_voca_score_calculation_crf(valid_dataset, config, mean, std, device, pre_model, model, model_type, verbose=False): idx2voca = idx2voca_chord() valid_song_names = valid_dataset.song_names paths = valid_dataset.preprocessor.get_all_files() metrics_ = metrics() song_length_list = list() for path in paths: song_name, lab_file_path, mp3_file_path, _ = path if not song_name in valid_song_names: continue try: n_timestep = config.model['timestep'] feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config) feature = feature.T feature = (feature - mean) / std time_unit = feature_per_second num_pad = n_timestep - (feature.shape[0] % n_timestep) feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0) num_instance = feature.shape[0] // n_timestep start_time = 0.0 lines = [] with torch.no_grad(): model.eval() feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device) for t in range(num_instance): if (model_type == 'cnn') or (model_type == 'crnn') or (model_type == 'btc'): logits = pre_model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device)) prediction, _ = model(logits, torch.randint(config.model['num_chords'], (n_timestep,)).to(device)) else: raise NotImplementedError for i in range(n_timestep): if t == 0 and i == 0: prev_chord = prediction[i].item() continue if prediction[i].item() != prev_chord: lines.append( '%.6f %.6f %s\n' % ( start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord])) start_time = time_unit * (n_timestep * t + i) prev_chord = prediction[i].item() if t == num_instance - 1 and i + num_pad == n_timestep: if start_time != time_unit * (n_timestep * t + i): lines.append( '%.6f %.6f %s\n' % ( start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord])) break pid = os.getpid() tmp_path = 'tmp_' + str(pid) + '.lab' with open(tmp_path, 'w') as f: for line in lines: f.write(line) for m in metrics_.score_metrics: metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path)) song_length_list.append(song_length_second) if verbose: for m in metrics_.score_metrics: print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1])) except: print('song name %s\' lab file error' % song_name) tmp = song_length_list / np.sum(song_length_list) for m in metrics_.score_metrics: metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp)) return metrics_.score_list_dict, song_length_list, metrics_.average_score ================================================ FILE: utils/preprocess.py ================================================ import os import librosa from utils.chords import Chords import re from enum import Enum import pyrubberband as pyrb import torch import math class FeatureTypes(Enum): cqt = 'cqt' class Preprocess(): def __init__(self, config, feature_to_use, 
dataset_names, root_dir): self.config = config self.dataset_names = dataset_names self.root_path = root_dir + '/' self.time_interval = config.feature["hop_length"]/config.mp3["song_hz"] self.no_of_chord_datapoints_per_sequence = math.ceil(config.mp3['inst_len'] / self.time_interval) self.Chord_class = Chords() # isophonic self.isophonic_directory = self.root_path + 'isophonic/' # uspop self.uspop_directory = self.root_path + 'uspop/' self.uspop_audio_path = 'audio/' self.uspop_lab_path = 'annotations/uspopLabels/' self.uspop_index_path = 'annotations/uspopLabels.txt' # robbie williams self.robbie_williams_directory = self.root_path + 'robbiewilliams/' self.robbie_williams_audio_path = 'audio/' self.robbie_williams_lab_path = 'chords/' self.feature_name = feature_to_use self.is_cut_last_chord = False def find_mp3_path(self, dirpath, word): for filename in os.listdir(dirpath): last_dir = dirpath.split("/")[-2] if ".mp3" in filename: tmp = filename.replace(".mp3", "") tmp = tmp.replace(last_dir, "") filename_lower = tmp.lower() filename_lower = " ".join(re.findall("[a-zA-Z]+", filename_lower)) if word.lower().replace(" ", "") in filename_lower.replace(" ", ""): return filename def find_mp3_path_robbiewilliams(self, dirpath, word): for filename in os.listdir(dirpath): if ".mp3" in filename: tmp = filename.replace(".mp3", "") filename_lower = tmp.lower() filename_lower = filename_lower.replace("robbie williams", "") filename_lower = " ".join(re.findall("[a-zA-Z]+", filename_lower)) filename_lower = self.song_pre(filename_lower) if self.song_pre(word.lower()).replace(" ", "") in filename_lower.replace(" ", ""): return filename def get_all_files(self): res_list = [] # isophonic if "isophonic" in self.dataset_names: for dirpath, dirnames, filenames in os.walk(self.isophonic_directory): if not dirnames: for filename in filenames: if ".lab" in filename: tmp = filename.replace(".lab", "") song_name = " ".join(re.findall("[a-zA-Z]+", tmp)).replace("CD", "") mp3_path = self.find_mp3_path(dirpath, song_name) res_list.append([song_name, os.path.join(dirpath, filename), os.path.join(dirpath, mp3_path), os.path.join(self.root_path, "result", "isophonic")]) # uspop if "uspop" in self.dataset_names: with open(os.path.join(self.uspop_directory, self.uspop_index_path)) as f: uspop_lab_list = f.readlines() uspop_lab_list = [x.strip() for x in uspop_lab_list] for lab_path in uspop_lab_list: spl = lab_path.split('/') lab_artist = self.uspop_pre(spl[2]) lab_title = self.uspop_pre(spl[4][3:-4]) lab_path = lab_path.replace('./uspopLabels/', '') lab_path = os.path.join(self.uspop_directory, self.uspop_lab_path, lab_path) for filename in os.listdir(os.path.join(self.uspop_directory, self.uspop_audio_path)): if not '.csv' in filename: spl = filename.split('-') mp3_artist = self.uspop_pre(spl[0]) mp3_title = self.uspop_pre(spl[1][:-4]) if lab_artist == mp3_artist and lab_title == mp3_title: res_list.append([mp3_artist + mp3_title, lab_path, os.path.join(self.uspop_directory, self.uspop_audio_path, filename), os.path.join(self.root_path, "result", "uspop")]) break # robbie williams if "robbiewilliams" in self.dataset_names: for dirpath, dirnames, filenames in os.walk(self.robbie_williams_directory): if not dirnames: for filename in filenames: if ".txt" in filename and (not 'README' in filename): tmp = filename.replace(".txt", "") song_name = " ".join(re.findall("[a-zA-Z]+", tmp)).replace("GTChords", "") mp3_dir = dirpath.replace("chords", "audio") mp3_path = self.find_mp3_path_robbiewilliams(mp3_dir, song_name) 
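# Each entry appended to res_list is [song_name, lab_file_path,
# mp3_file_path, save_path]; an illustrative (hypothetical) record for
# the Robbie Williams subset, with '<root>' standing in for root_path:
#     ['Angels',
#      '<root>/robbiewilliams/chords/<album>/Angels GTChords.txt',
#      '<root>/robbiewilliams/audio/<album>/Angels.mp3',
#      '<root>/result/robbiewilliams']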
res_list.append([song_name, os.path.join(dirpath, filename), os.path.join(mp3_dir, mp3_path), os.path.join(self.root_path, "result", "robbiewilliams")]) return res_list def uspop_pre(self, text): text = text.lower() text = text.replace('_', '') text = text.replace(' ', '') text = " ".join(re.findall("[a-zA-Z]+", text)) return text def song_pre(self, text): to_remove = ["'", '`', '(', ')', ' ', '&', 'and', 'And'] for remove in to_remove: text = text.replace(remove, '') return text def config_to_folder(self): mp3_config = self.config.mp3 feature_config = self.config.feature mp3_string = "%d_%.1f_%.1f" % \ (mp3_config['song_hz'], mp3_config['inst_len'], mp3_config['skip_interval']) feature_string = "%s_%d_%d_%d" % \ (self.feature_name.value, feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length']) return mp3_config, feature_config, mp3_string, feature_string def generate_labels_features_new(self, all_list): pid = os.getpid() mp3_config, feature_config, mp3_str, feature_str = self.config_to_folder() i = 0 # number of songs j = 0 # number of impossible songs k = 0 # number of tried songs total = 0 # number of generated instances stretch_factors = [1.0] shift_factors = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6] loop_broken = False for song_name, lab_path, mp3_path, save_path in all_list: # different song initialization if loop_broken: loop_broken = False i += 1 print(pid, "generating features from ...", os.path.join(mp3_path)) if i % 10 == 0: print(i, ' th song') original_wav, sr = librosa.load(os.path.join(mp3_path), sr=mp3_config['song_hz']) # make result path if not exists # save_path, mp3_string, feature_string, song_name, aug.pt result_path = os.path.join(save_path, mp3_str, feature_str, song_name.strip()) if not os.path.exists(result_path): os.makedirs(result_path) # calculate result for stretch_factor in stretch_factors: if loop_broken: loop_broken = False break for shift_factor in shift_factors: # for filename idx = 0 chord_info = self.Chord_class.get_converted_chord(os.path.join(lab_path)) k += 1 # stretch original sound and chord info x = pyrb.time_stretch(original_wav, sr, stretch_factor) x = pyrb.pitch_shift(x, sr, shift_factor) audio_length = x.shape[0] chord_info['start'] = chord_info['start'] * 1/stretch_factor chord_info['end'] = chord_info['end'] * 1/stretch_factor last_sec = chord_info.iloc[-1]['end'] last_sec_hz = int(last_sec * mp3_config['song_hz']) if audio_length + mp3_config['skip_interval'] < last_sec_hz: print('loaded song is too short :', song_name) loop_broken = True j += 1 break elif audio_length > last_sec_hz: x = x[:last_sec_hz] origin_length = last_sec_hz origin_length_in_sec = origin_length / mp3_config['song_hz'] current_start_second = 0 # get chord list between current_start_second and current+song_length while current_start_second + mp3_config['inst_len'] < origin_length_in_sec: inst_start_sec = current_start_second curSec = current_start_second chord_list = [] # extract chord per 1/self.time_interval while curSec < inst_start_sec + mp3_config['inst_len']: try: available_chords = chord_info.loc[(chord_info['start'] <= curSec) & ( chord_info['end'] > curSec + self.time_interval)].copy() if len(available_chords) == 0: available_chords = chord_info.loc[((chord_info['start'] >= curSec) & ( chord_info['start'] <= curSec + self.time_interval)) | ( (chord_info['end'] >= curSec) & ( chord_info['end'] <= curSec + self.time_interval))].copy() if len(available_chords) == 1: chord = available_chords['chord_id'].iloc[0] elif 
len(available_chords) > 1: max_starts = available_chords.apply(lambda row: max(row['start'], curSec), axis=1) available_chords['max_start'] = max_starts min_ends = available_chords.apply( lambda row: min(row.end, curSec + self.time_interval), axis=1) available_chords['min_end'] = min_ends chords_lengths = available_chords['min_end'] - available_chords['max_start'] available_chords['chord_length'] = chords_lengths chord = available_chords.ix[available_chords['chord_length'].idxmax()]['chord_id'] else: chord = 24 except Exception as e: chord = 24 print(e) print(pid, "no chord") raise RuntimeError() finally: # convert chord by shift factor if chord != 24: chord += shift_factor * 2 chord = chord % 24 chord_list.append(chord) curSec += self.time_interval if len(chord_list) == self.no_of_chord_datapoints_per_sequence: try: sequence_start_time = current_start_second sequence_end_time = current_start_second + mp3_config['inst_len'] start_index = int(sequence_start_time * mp3_config['song_hz']) end_index = int(sequence_end_time * mp3_config['song_hz']) song_seq = x[start_index:end_index] etc = '%.1f_%.1f' % ( current_start_second, current_start_second + mp3_config['inst_len']) aug = '%.2f_%i' % (stretch_factor, shift_factor) if self.feature_name == FeatureTypes.cqt: # print(pid, "make feature") feature = librosa.cqt(song_seq, sr=sr, n_bins=feature_config['n_bins'], bins_per_octave=feature_config['bins_per_octave'], hop_length=feature_config['hop_length']) else: raise NotImplementedError if feature.shape[1] > self.no_of_chord_datapoints_per_sequence: feature = feature[:, :self.no_of_chord_datapoints_per_sequence] if feature.shape[1] != self.no_of_chord_datapoints_per_sequence: print('loaded features length is too short :', song_name) loop_broken = True j += 1 break result = { 'feature': feature, 'chord': chord_list, 'etc': etc } # save_path, mp3_string, feature_string, song_name, aug.pt filename = aug + "_" + str(idx) + ".pt" torch.save(result, os.path.join(result_path, filename)) idx += 1 total += 1 except Exception as e: print(e) print(pid, "feature error") raise RuntimeError() else: print("invalid number of chord datapoints in sequence :", len(chord_list)) current_start_second += mp3_config['skip_interval'] print(pid, "total instances: %d" % total) def generate_labels_features_voca(self, all_list): pid = os.getpid() mp3_config, feature_config, mp3_str, feature_str = self.config_to_folder() i = 0 # number of songs j = 0 # number of impossible songs k = 0 # number of tried songs total = 0 # number of generated instances stretch_factors = [1.0] shift_factors = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6] loop_broken = False for song_name, lab_path, mp3_path, save_path in all_list: save_path = save_path + '_voca' # different song initialization if loop_broken: loop_broken = False i += 1 print(pid, "generating features from ...", os.path.join(mp3_path)) if i % 10 == 0: print(i, ' th song') original_wav, sr = librosa.load(os.path.join(mp3_path), sr=mp3_config['song_hz']) # save_path, mp3_string, feature_string, song_name, aug.pt result_path = os.path.join(save_path, mp3_str, feature_str, song_name.strip()) if not os.path.exists(result_path): os.makedirs(result_path) # calculate result for stretch_factor in stretch_factors: if loop_broken: loop_broken = False break for shift_factor in shift_factors: # for filename idx = 0 try: chord_info = self.Chord_class.get_converted_chord_voca(os.path.join(lab_path)) except Exception as e: print(e) print(pid, " chord lab file error : %s" % song_name) loop_broken = True 
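# Note on the label transposition shared by both generators: in the
# major/minor scheme, ids are root * 2 + (0 for maj, 1 for min), so a
# pitch shift of k semitones maps id -> (id + 2 * k) % 24, with the
# no-chord id 24 left unchanged. In the large vocabulary used here, ids
# are root * 14 + quality_index, so the same shift becomes
# (id + 14 * k) % 168, with 168 ('X') and 169 ('N') left as-is.
# Worked example (ids as defined in utils/chords.py): C major is id 0 in
# the maj/min scheme; shifting up 2 semitones gives (0 + 4) % 24 = 4,
# i.e. D major.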
j += 1 break k += 1 # stretch original sound and chord info x = pyrb.time_stretch(original_wav, sr, stretch_factor) x = pyrb.pitch_shift(x, sr, shift_factor) audio_length = x.shape[0] chord_info['start'] = chord_info['start'] * 1/stretch_factor chord_info['end'] = chord_info['end'] * 1/stretch_factor last_sec = chord_info.iloc[-1]['end'] last_sec_hz = int(last_sec * mp3_config['song_hz']) if audio_length + mp3_config['skip_interval'] < last_sec_hz: print('loaded song is too short :', song_name) loop_broken = True j += 1 break elif audio_length > last_sec_hz: x = x[:last_sec_hz] origin_length = last_sec_hz origin_length_in_sec = origin_length / mp3_config['song_hz'] current_start_second = 0 # get chord list between current_start_second and current+song_length while current_start_second + mp3_config['inst_len'] < origin_length_in_sec: inst_start_sec = current_start_second curSec = current_start_second chord_list = [] # extract chord per 1/self.time_interval while curSec < inst_start_sec + mp3_config['inst_len']: try: available_chords = chord_info.loc[(chord_info['start'] <= curSec) & (chord_info['end'] > curSec + self.time_interval)].copy() if len(available_chords) == 0: available_chords = chord_info.loc[((chord_info['start'] >= curSec) & (chord_info['start'] <= curSec + self.time_interval)) | ((chord_info['end'] >= curSec) & (chord_info['end'] <= curSec + self.time_interval))].copy() if len(available_chords) == 1: chord = available_chords['chord_id'].iloc[0] elif len(available_chords) > 1: max_starts = available_chords.apply(lambda row: max(row['start'], curSec),axis=1) available_chords['max_start'] = max_starts min_ends = available_chords.apply(lambda row: min(row.end, curSec + self.time_interval), axis=1) available_chords['min_end'] = min_ends chords_lengths = available_chords['min_end'] - available_chords['max_start'] available_chords['chord_length'] = chords_lengths chord = available_chords.ix[available_chords['chord_length'].idxmax()]['chord_id'] else: chord = 169 except Exception as e: chord = 169 print(e) print(pid, "no chord") raise RuntimeError() finally: # convert chord by shift factor if chord != 169 and chord != 168: chord += shift_factor * 14 chord = chord % 168 chord_list.append(chord) curSec += self.time_interval if len(chord_list) == self.no_of_chord_datapoints_per_sequence: try: sequence_start_time = current_start_second sequence_end_time = current_start_second + mp3_config['inst_len'] start_index = int(sequence_start_time * mp3_config['song_hz']) end_index = int(sequence_end_time * mp3_config['song_hz']) song_seq = x[start_index:end_index] etc = '%.1f_%.1f' % ( current_start_second, current_start_second + mp3_config['inst_len']) aug = '%.2f_%i' % (stretch_factor, shift_factor) if self.feature_name == FeatureTypes.cqt: feature = librosa.cqt(song_seq, sr=sr, n_bins=feature_config['n_bins'], bins_per_octave=feature_config['bins_per_octave'], hop_length=feature_config['hop_length']) else: raise NotImplementedError if feature.shape[1] > self.no_of_chord_datapoints_per_sequence: feature = feature[:, :self.no_of_chord_datapoints_per_sequence] if feature.shape[1] != self.no_of_chord_datapoints_per_sequence: print('loaded features length is too short :', song_name) loop_broken = True j += 1 break result = { 'feature': feature, 'chord': chord_list, 'etc': etc } # save_path, mp3_string, feature_string, song_name, aug.pt filename = aug + "_" + str(idx) + ".pt" torch.save(result, os.path.join(result_path, filename)) idx += 1 total += 1 except Exception as e: print(e) print(pid, 
"feature error") raise RuntimeError() else: print("invalid number of chord datapoints in sequence :", len(chord_list)) current_start_second += mp3_config['skip_interval'] print(pid, "total instances: %d" % total) ================================================ FILE: utils/pytorch_utils.py ================================================ import torch import numpy as np import os import math from utils import logger use_cuda = torch.cuda.is_available() # optimization # reference: http://pytorch.org/docs/master/_modules/torch/optim/lr_scheduler.html#ReduceLROnPlateau def adjusting_learning_rate(optimizer, factor=.5, min_lr=0.00001): for i, param_group in enumerate(optimizer.param_groups): old_lr = float(param_group['lr']) new_lr = max(old_lr * factor, min_lr) param_group['lr'] = new_lr logger.info('adjusting learning rate from %.6f to %.6f' % (old_lr, new_lr)) # model save and loading def load_model(asset_path, model, optimizer, restore_epoch=0): if os.path.isfile(os.path.join(asset_path, 'model', 'checkpoint_%d.pth.tar' % restore_epoch), map_location=lambda storage, loc: storage): checkpoint = torch.load(os.path.join(asset_path, 'model', 'checkpoint_%d.pth.tar' % restore_epoch)) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) current_step = checkpoint['current_step'] logger.info("restore model with %d epoch" % restore_epoch) else: logger.info("no checkpoint with %d epoch" % restore_epoch) current_step = 0 return model, optimizer, current_step ================================================ FILE: utils/tf_logger.py ================================================ import tensorflow as tf import numpy as np import scipy.misc try: from StringIO import StringIO # Python 2.7 except ImportError: from io import BytesIO # Python 3.x class TF_Logger(object): def __init__(self, log_dir): """Create a summary writer logging to log_dir.""" self.writer = tf.summary.FileWriter(log_dir) def scalar_summary(self, tag, value, step): """Log a scalar variable.""" summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) self.writer.add_summary(summary, step) def image_summary(self, tag, images, step): """Log a list of images.""" img_summaries = [] for i, img in enumerate(images): # Write the image to a string try: s = StringIO() except: s = BytesIO() scipy.misc.toimage(img).save(s, format="png") # Create an Image object img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), height=img.shape[0], width=img.shape[1]) # Create a Summary value img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) # Create and write Summary summary = tf.Summary(value=img_summaries) self.writer.add_summary(summary, step) def histo_summary(self, tag, values, step, bins=1000): """Log a histogram of the tensor of values.""" # Create a histogram using numpy counts, bin_edges = np.histogram(values, bins=bins) # Fill the fields of the histogram proto hist = tf.HistogramProto() hist.min = float(np.min(values)) hist.max = float(np.max(values)) hist.num = int(np.prod(values.shape)) hist.sum = float(np.sum(values)) hist.sum_squares = float(np.sum(values ** 2)) # Drop the start of the first bin bin_edges = bin_edges[1:] # Add bin edges and counts for edge in bin_edges: hist.bucket_limit.append(edge) for c in counts: hist.bucket.append(c) # Create and write Summary summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) self.writer.add_summary(summary, step) self.writer.flush() ================================================ FILE: 
utils/transformer_modules.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import math def _gen_bias_mask(max_length): """ Generates bias values (-Inf) to mask future timesteps during attention """ np_mask = np.triu(np.full([max_length, max_length], -np.inf), 1) torch_mask = torch.from_numpy(np_mask).type(torch.FloatTensor) return torch_mask.unsqueeze(0).unsqueeze(1) def _gen_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4): """ Generates a [1, length, channels] timing signal consisting of sinusoids Adapted from: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py """ position = np.arange(length) num_timescales = channels // 2 log_timescale_increment = ( math.log(float(max_timescale) / float(min_timescale)) / (float(num_timescales) - 1)) inv_timescales = min_timescale * np.exp( np.arange(num_timescales).astype(np.float) * -log_timescale_increment) scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0) signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1) signal = np.pad(signal, [[0, 0], [0, channels % 2]], 'constant', constant_values=[0.0, 0.0]) signal = signal.reshape([1, length, channels]) return torch.from_numpy(signal).type(torch.FloatTensor) class LayerNorm(nn.Module): # Borrowed from jekbradbury # https://github.com/pytorch/pytorch/issues/1959 def __init__(self, features, eps=1e-6): super(LayerNorm, self).__init__() self.gamma = nn.Parameter(torch.ones(features)) self.beta = nn.Parameter(torch.zeros(features)) self.eps = eps def forward(self, x): mean = x.mean(-1, keepdim=True) std = x.std(-1, keepdim=True) return self.gamma * (x - mean) / (std + self.eps) + self.beta class OutputLayer(nn.Module): """ Abstract base class for output layer. Handles projection to output labels """ def __init__(self, hidden_size, output_size, probs_out=False): super(OutputLayer, self).__init__() self.output_size = output_size self.output_projection = nn.Linear(hidden_size, output_size) self.probs_out = probs_out self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=int(hidden_size/2), batch_first=True, bidirectional=True) self.hidden_size = hidden_size def loss(self, hidden, labels): raise NotImplementedError('Must implement {}.loss'.format(self.__class__.__name__)) class SoftmaxOutputLayer(OutputLayer): """ Implements a softmax based output layer """ def forward(self, hidden): logits = self.output_projection(hidden) probs = F.softmax(logits, -1) # _, predictions = torch.max(probs, dim=-1) topk, indices = torch.topk(probs, 2) predictions = indices[:,:,0] second = indices[:,:,1] if self.probs_out is True: return logits # return probs return predictions, second def loss(self, hidden, labels): logits = self.output_projection(hidden) log_probs = F.log_softmax(logits, -1) return F.nll_loss(log_probs.view(-1, self.output_size), labels.view(-1)) class MultiHeadAttention(nn.Module): """ Multi-head attention as per https://arxiv.org/pdf/1706.03762.pdf Refer Figure 2 """ def __init__(self, input_depth, total_key_depth, total_value_depth, output_depth, num_heads, bias_mask=None, dropout=0.0, attention_map=False): """ Parameters: input_depth: Size of last dimension of input total_key_depth: Size of last dimension of keys. Must be divisible by num_head total_value_depth: Size of last dimension of values. 
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention as per https://arxiv.org/pdf/1706.03762.pdf
    Refer Figure 2
    """
    def __init__(self, input_depth, total_key_depth, total_value_depth, output_depth,
                 num_heads, bias_mask=None, dropout=0.0, attention_map=False):
        """
        Parameters:
            input_depth: Size of last dimension of input
            total_key_depth: Size of last dimension of keys. Must be divisible by num_heads
            total_value_depth: Size of last dimension of values. Must be divisible by num_heads
            output_depth: Size of last dimension of the final output
            num_heads: Number of attention heads
            bias_mask: Masking tensor to prevent connections to future elements
            dropout: Dropout probability (Should be non-zero only during training)
        """
        super(MultiHeadAttention, self).__init__()
        # Checks borrowed from
        # https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
        if total_key_depth % num_heads != 0:
            raise ValueError("Key depth (%d) must be divisible by the number of "
                             "attention heads (%d)." % (total_key_depth, num_heads))
        if total_value_depth % num_heads != 0:
            raise ValueError("Value depth (%d) must be divisible by the number of "
                             "attention heads (%d)." % (total_value_depth, num_heads))

        self.attention_map = attention_map
        self.num_heads = num_heads
        self.query_scale = (total_key_depth // num_heads) ** -0.5
        self.bias_mask = bias_mask

        # Key and query depth will be the same
        self.query_linear = nn.Linear(input_depth, total_key_depth, bias=False)
        self.key_linear = nn.Linear(input_depth, total_key_depth, bias=False)
        self.value_linear = nn.Linear(input_depth, total_value_depth, bias=False)
        self.output_linear = nn.Linear(total_value_depth, output_depth, bias=False)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        """
        Split x so as to add an extra num_heads dimension
        Input:
            x: a Tensor with shape [batch_size, seq_length, depth]
        Returns:
            A Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]
        """
        if len(x.shape) != 3:
            raise ValueError("x must have rank 3")
        shape = x.shape
        return x.view(shape[0], shape[1], self.num_heads, shape[2] // self.num_heads).permute(0, 2, 1, 3)

    def _merge_heads(self, x):
        """
        Merge the extra num_heads dimension back into the last dimension
        Input:
            x: a Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]
        Returns:
            A Tensor with shape [batch_size, seq_length, depth]
        """
        if len(x.shape) != 4:
            raise ValueError("x must have rank 4")
        shape = x.shape
        return x.permute(0, 2, 1, 3).contiguous().view(shape[0], shape[2], shape[3] * self.num_heads)

    def forward(self, queries, keys, values):
        # Do a linear projection for each component
        queries = self.query_linear(queries)
        keys = self.key_linear(keys)
        values = self.value_linear(values)

        # Split into multiple heads
        queries = self._split_heads(queries)
        keys = self._split_heads(keys)
        values = self._split_heads(values)

        # Scale queries
        queries *= self.query_scale

        # Combine queries and keys
        logits = torch.matmul(queries, keys.permute(0, 1, 3, 2))

        # Add bias to mask future values
        if self.bias_mask is not None:
            logits += self.bias_mask[:, :, :logits.shape[-2], :logits.shape[-1]].type_as(logits.data)

        # Convert to probabilities
        weights = nn.functional.softmax(logits, dim=-1)

        # Dropout
        weights = self.dropout(weights)

        # Combine with values to get context
        contexts = torch.matmul(weights, values)

        # Merge heads
        contexts = self._merge_heads(contexts)

        # Linear to get output
        outputs = self.output_linear(contexts)

        if self.attention_map is True:
            return outputs, weights
        return outputs
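# Usage sketch (illustrative editorial addition; batch size 2, sequence
# length 10, model depth 128 and 4 heads are hypothetical example sizes):
def _example_multi_head_attention():
    attn = MultiHeadAttention(input_depth=128, total_key_depth=128,
                              total_value_depth=128, output_depth=128,
                              num_heads=4, bias_mask=None, dropout=0.0)
    x = torch.randn(2, 10, 128)
    out = attn(x, x, x)  # queries = keys = values, i.e. self-attention
    assert out.shape == (2, 10, 128)
    # _split_heads and _merge_heads are exact inverses of each other:
    assert torch.equal(attn._merge_heads(attn._split_heads(x)), x)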
class Conv(nn.Module):
    """
    Convenience class that does padding and convolution for inputs in the format
    [batch_size, sequence length, hidden size]
    """
    def __init__(self, input_size, output_size, kernel_size, pad_type):
        """
        Parameters:
            input_size: Input feature size
            output_size: Output feature size
            kernel_size: Kernel width
            pad_type: left -> pad on the left side (to mask future data),
                      both -> pad on both sides
        """
        super(Conv, self).__init__()
        padding = (kernel_size - 1, 0) if pad_type == 'left' else (kernel_size // 2, (kernel_size - 1) // 2)
        self.pad = nn.ConstantPad1d(padding, 0)
        self.conv = nn.Conv1d(input_size, output_size, kernel_size=kernel_size, padding=0)

    def forward(self, inputs):
        inputs = self.pad(inputs.permute(0, 2, 1))
        outputs = self.conv(inputs).permute(0, 2, 1)
        return outputs


class PositionwiseFeedForward(nn.Module):
    """
    Does a Linear + ReLU + Linear on each of the timesteps
    """
    def __init__(self, input_depth, filter_size, output_depth, layer_config='ll', padding='left', dropout=0.0):
        """
        Parameters:
            input_depth: Size of last dimension of input
            filter_size: Hidden size of the middle layer
            output_depth: Size of last dimension of the final output
            layer_config: ll -> linear + ReLU + linear
                          cc -> conv + ReLU + conv
                          etc.
            padding: left -> pad on the left side (to mask future data),
                     both -> pad on both sides
            dropout: Dropout probability (Should be non-zero only during training)
        """
        super(PositionwiseFeedForward, self).__init__()

        layers = []
        sizes = ([(input_depth, filter_size)] +
                 [(filter_size, filter_size)] * (len(layer_config) - 2) +
                 [(filter_size, output_depth)])

        for lc, s in zip(list(layer_config), sizes):
            if lc == 'l':
                layers.append(nn.Linear(*s))
            elif lc == 'c':
                layers.append(Conv(*s, kernel_size=3, pad_type=padding))
            else:
                raise ValueError("Unknown layer type {}".format(lc))

        self.layers = nn.ModuleList(layers)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        x = inputs
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # Note: this condition is always true, so ReLU and dropout are also
            # applied after the final layer; presumably the released checkpoints
            # were trained with this behavior, so it is kept as-is.
            if i < len(self.layers):
                x = self.relu(x)
                x = self.dropout(x)
        return x
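# Usage sketch (illustrative editorial addition; all sizes are hypothetical).
# With layer_config='ll' the block is Linear(128 -> 512) + ReLU + Linear(512 -> 128)
# applied independently at every timestep; 'cc' swaps the linears for kernel-3
# convolutions whose 'left' padding keeps them causal.
def _example_feed_forward_shapes():
    ffn = PositionwiseFeedForward(input_depth=128, filter_size=512, output_depth=128,
                                  layer_config='ll', padding='left', dropout=0.0)
    x = torch.randn(2, 10, 128)
    assert ffn(x).shape == (2, 10, 128)  # per-timestep map, shape preserved
    # Conv pads before convolving, so the sequence length is preserved too:
    conv = Conv(input_size=128, output_size=128, kernel_size=3, pad_type='left')
    assert conv(x).shape == (2, 10, 128)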