[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019 Jonggwon Park\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# A Bi-Directional Transformer for Musical Chord Recognition\n\nThis repository has the source codes for the paper \"A Bi-Directional Transformer for Musical Chord Recognition\"(ISMIR19).\n\n<img src=\"png/model.png\">\n\n## Requirements\n- pytorch >= 1.0.0\n- numpy >= 1.16.2\n- pandas >= 0.24.1\n- pyrubberband >= 0.3.0\n- librosa >= 0.6.3\n- pyyaml >= 3.13\n- mir_eval >= 0.5\n- pretty_midi >= 0.2.8\n\n## File descriptions\n  * `audio_dataset.py` : loads data and preprocesses label files to chord labels and mp3 files to constant-q transformation. \n  * `btc_model.py` : contains pytorch implementation of BTC.\n  * `train.py` : for training. \n  * `crf_model.py` : contatins pytorch implementation of Conditional Random Fields (CRFs) .\n  * `baseline_models.py` : contains the codes of baseline models.\n  * `train_crf.py` : for training CRFs.  \n  * `run_config.yaml` : includes hyper parameters and paths that are needed.\n  * `test.py` : for recognizing chord from audio file. \n\n## Using BTC : Recognizing chords from files in audio directory\n\n### Using BTC from command line\n```bash \n$ python test.py --audio_dir audio_folder --save_dir save_folder --voca False\n```\n  * audio_dir : a folder of audio files for chord recognition (default: './test')\n  * save_dir : a forder for saving recognition results (default: './test')\n  * voca : False means major and minor label type, and True means large vocabulary label type (default: False)\n  \nThe resulting files are lab files of the form shown below and midi files.\n\n  <img src=\"png/example.png\">\n\n## Attention Map\nThe figures represent the probability values of the attention of self-attention layers 1, 3, 5 and 8 respectively. The\nlayers that best represent the different characteristics of each layers were chosen. The input audio is the song \"Just A Girl\"\n(0m30s ~ 0m40s) by No Doubt from UsPop2002, which was in evaluation data.\n  <img src=\"png/attention.png\">\n\n## Data\nWe used Isophonics[1], Robbie Williams[2], UsPop2002[3] dataset which consists of chord label files. Due to copyright issue, these datasets do not include audio files. The audio files used in this work were collected from online music service providers.\n\n[1] http://isophonics.net/datasets \n\n[2] B. Di Giorgi, M. Zanoni, A. Sarti, and S. Tubaro. Automatic\nchord recognition based on the probabilistic\nmodeling of diatonic modal harmony. In Proc. of the\n8th International Workshop on Multidimensional Systems,\nErlangen, Germany, 2013.\n\n[3] https://github.com/tmc323/Chord-Annotations\n\n## Reference\n  * pytorch implementation of Transformer and Crf: https://github.com/kolloldas/torchnlp \n\n## Comments\n  * Any comments for the codes are always welcome.\n\n"
  },
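  {
    "path": "examples/parse_lab_sketch.py",
    "content": "\"\"\"A minimal sketch of reading the .lab files written by test.py.\n\nEach line has the form 'start_time end_time chord_label', as in the README\nexample image. The path used below is a placeholder.\n\"\"\"\n\ndef parse_lab(path):\n    # returns a list of (start_sec, end_sec, chord_label) tuples\n    segments = []\n    with open(path) as f:\n        for line in f:\n            start, end, chord = line.strip().split(' ', 2)\n            segments.append((float(start), float(end), chord))\n    return segments\n\nif __name__ == \"__main__\":\n    for start, end, chord in parse_lab('./test/example.lab'):\n        print('%7.3f %7.3f %s' % (start, end, chord))\n"
  },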
  {
    "path": "audio_dataset.py",
    "content": "import numpy as np\nimport os\nimport torch\nfrom torch.utils.data import Dataset, DataLoader\nfrom utils.preprocess import Preprocess, FeatureTypes\nimport math\nfrom multiprocessing import Pool\nfrom sortedcontainers import SortedList\n\nclass AudioDataset(Dataset):\n    def __init__(self, config, root_dir='/data/music/chord_recognition', dataset_names=('isophonic',),\n                 featuretype=FeatureTypes.cqt, num_workers=20, train=False, preprocessing=False, resize=None, kfold=4):\n        super(AudioDataset, self).__init__()\n\n        self.config = config\n        self.root_dir = root_dir\n        self.dataset_names = dataset_names\n        self.preprocessor = Preprocess(config, featuretype, dataset_names, self.root_dir)\n        self.resize = resize\n        self.train = train\n        self.ratio = config.experiment['data_ratio']\n\n        # preprocessing hyperparameters\n        # song_hz, n_bins, bins_per_octave, hop_length\n        mp3_config = config.mp3\n        feature_config = config.feature\n        self.mp3_string = \"%d_%.1f_%.1f\" % \\\n                          (mp3_config['song_hz'], mp3_config['inst_len'],\n                           mp3_config['skip_interval'])\n        self.feature_string = \"%s_%d_%d_%d\" % \\\n                              (featuretype.value, feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length'])\n\n        if feature_config['large_voca'] == True:\n            # store paths if exists\n            is_preprocessed = True if os.path.exists(os.path.join(root_dir, 'result', dataset_names[0]+'_voca', self.mp3_string, self.feature_string)) else False\n            if (not is_preprocessed) | preprocessing:\n                midi_paths = self.preprocessor.get_all_files()\n\n                if num_workers > 1:\n                    num_path_per_process = math.ceil(len(midi_paths) / num_workers)\n                    args = [midi_paths[i * num_path_per_process:(i + 1) * num_path_per_process] for i in range(num_workers)]\n\n                    # start process\n                    p = Pool(processes=num_workers)\n                    p.map(self.preprocessor.generate_labels_features_voca, args)\n\n                    p.close()\n                else:\n                    self.preprocessor.generate_labels_features_voca(midi_paths)\n\n            # kfold is 5 fold index ( 0, 1, 2, 3, 4 )\n            self.song_names, self.paths = self.get_paths_voca(kfold=kfold)\n        else:\n            # store paths if exists\n            is_preprocessed = True if os.path.exists(os.path.join(root_dir, 'result', dataset_names[0], self.mp3_string, self.feature_string)) else False\n            if (not is_preprocessed) | preprocessing:\n                midi_paths = self.preprocessor.get_all_files()\n\n                if num_workers > 1:\n                    num_path_per_process = math.ceil(len(midi_paths) / num_workers)\n                    args = [midi_paths[i * num_path_per_process:(i + 1) * num_path_per_process]\n                            for i in range(num_workers)]\n\n                    # start process\n                    p = Pool(processes=num_workers)\n                    p.map(self.preprocessor.generate_labels_features_new, args)\n\n                    p.close()\n                else:\n                    self.preprocessor.generate_labels_features_new(midi_paths)\n\n            # kfold is 5 fold index ( 0, 1, 2, 3, 4 )\n            self.song_names, self.paths = self.get_paths(kfold=kfold)\n\n    def __len__(self):\n    
    return len(self.paths)\n\n    def __getitem__(self, idx):\n        instance_path = self.paths[idx]\n\n        res = dict()\n        data = torch.load(instance_path)\n        res['feature'] = np.log(np.abs(data['feature']) + 1e-6)\n        res['chord'] = data['chord']\n        return res\n\n    def get_paths(self, kfold=4):\n        temp = {}\n        used_song_names = list()\n        for name in self.dataset_names:\n            dataset_path = os.path.join(self.root_dir, \"result\", name, self.mp3_string, self.feature_string)\n            song_names = os.listdir(dataset_path)\n            for song_name in song_names:\n                paths = []\n                instance_names = os.listdir(os.path.join(dataset_path, song_name))\n                if len(instance_names) > 0:\n                    used_song_names.append(song_name)\n                for instance_name in instance_names:\n                    paths.append(os.path.join(dataset_path, song_name, instance_name))\n                temp[song_name] = paths\n        # throw away unused song names\n        song_names = used_song_names\n        song_names = SortedList(song_names)\n\n        print('Total used song length : %d' %len(song_names))\n        tmp = []\n        for i in range(len(song_names)):\n            tmp += temp[song_names[i]]\n        print('Total instances (train and valid) : %d' %len(tmp))\n\n        # divide train/valid dataset using k fold\n        result = []\n        total_fold = 5\n        quotient = len(song_names) // total_fold\n        remainder = len(song_names) % total_fold\n        fold_num = [0]\n        for i in range(total_fold):\n            fold_num.append(quotient)\n        for i in range(remainder):\n            fold_num[i+1] += 1\n        for i in range(total_fold):\n                fold_num[i+1] += fold_num[i]\n\n        if self.train:\n            tmp = []\n            # get not augmented data\n            for k in range(total_fold):\n                if k != kfold:\n                    for i in range(fold_num[k], fold_num[k+1]):\n                        result += temp[song_names[i]]\n                    tmp += song_names[fold_num[k]:fold_num[k + 1]]\n            song_names = tmp\n        else:\n            for i in range(fold_num[kfold], fold_num[kfold+1]):\n                instances = temp[song_names[i]]\n                instances = [inst for inst in instances if \"1.00_0\" in inst]\n                result += instances\n            song_names = song_names[fold_num[kfold]:fold_num[kfold+1]]\n        return song_names, result\n\n    def get_paths_voca(self, kfold=4):\n        temp = {}\n        used_song_names = list()\n        for name in self.dataset_names:\n            dataset_path = os.path.join(self.root_dir, \"result\", name+'_voca', self.mp3_string, self.feature_string)\n            song_names = os.listdir(dataset_path)\n            for song_name in song_names:\n                paths = []\n                instance_names = os.listdir(os.path.join(dataset_path, song_name))\n                if len(instance_names) > 0:\n                    used_song_names.append(song_name)\n                for instance_name in instance_names:\n                    paths.append(os.path.join(dataset_path, song_name, instance_name))\n                temp[song_name] = paths\n        # throw away unused song names\n        song_names = used_song_names\n        song_names = SortedList(song_names)\n\n        print('Total used song length : %d' %len(song_names))\n        tmp = []\n        for i in range(len(song_names)):\n     
       tmp += temp[song_names[i]]\n        print('Total instances (train and valid) : %d' %len(tmp))\n\n        # divide train/valid dataset using k fold\n        result = []\n        total_fold = 5\n        quotient = len(song_names) // total_fold\n        remainder = len(song_names) % total_fold\n        fold_num = [0]\n        for i in range(total_fold):\n            fold_num.append(quotient)\n        for i in range(remainder):\n            fold_num[i+1] += 1\n        for i in range(total_fold):\n                fold_num[i+1] += fold_num[i]\n\n        if self.train:\n            tmp = []\n            # get not augmented data\n            for k in range(total_fold):\n                if k != kfold:\n                    for i in range(fold_num[k], fold_num[k+1]):\n                        result += temp[song_names[i]]\n                    tmp += song_names[fold_num[k]:fold_num[k + 1]]\n            song_names = tmp\n        else:\n            for i in range(fold_num[kfold], fold_num[kfold+1]):\n                instances = temp[song_names[i]]\n                instances = [inst for inst in instances if \"1.00_0\" in inst]\n                result += instances\n            song_names = song_names[fold_num[kfold]:fold_num[kfold+1]]\n        return song_names, result\n\ndef _collate_fn(batch):\n    batch_size = len(batch)\n    max_len = batch[0]['feature'].shape[1]\n\n    input_percentages = torch.empty(batch_size)  # for variable length\n    chord_lens = torch.empty(batch_size, dtype=torch.int64)\n    chords = []\n    collapsed_chords = []\n    features = []\n    boundaries = []\n    for i in range(batch_size):\n        sample = batch[i]\n        feature = sample['feature']\n        chord = sample['chord']\n        diff = np.diff(chord, axis=0).astype(bool)\n        idx = np.insert(diff, 0, True, axis=0)\n        chord_lens[i] = np.sum(idx).item(0)\n        chords.extend(chord)\n        features.append(feature)\n        input_percentages[i] = feature.shape[1] / max_len\n        collapsed_chords.extend(np.array(chord)[idx].tolist())\n        boundary = np.append([0], diff)\n        boundaries.extend(boundary.tolist())\n\n    features = torch.tensor(features, dtype=torch.float32).unsqueeze(1)  # batch_size*1*feature_size*max_len\n    chords = torch.tensor(chords, dtype=torch.int64)  # (batch_size*time_length)\n    collapsed_chords = torch.tensor(collapsed_chords, dtype=torch.int64)  # total_unique_chord_len\n    boundaries = torch.tensor(boundaries, dtype=torch.uint8)  # (batch_size*time_length)\n\n    return features, input_percentages, chords, collapsed_chords, chord_lens, boundaries\n\nclass AudioDataLoader(DataLoader):\n    def __init__(self, *args, **kwargs):\n        super(AudioDataLoader, self).__init__(*args, **kwargs)\n        self.collate_fn = _collate_fn\n"
  },
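  {
    "path": "examples/kfold_split_sketch.py",
    "content": "\"\"\"A minimal sketch of the 5-fold split bookkeeping used in\nAudioDataset.get_paths and get_paths_voca.\n\nfold_num holds cumulative boundaries, so fold k covers song indices\n[fold_num[k], fold_num[k+1]); the remainder songs are spread over the\nfirst folds. The song count below is arbitrary.\n\"\"\"\n\ndef fold_boundaries(num_songs, total_fold=5):\n    quotient = num_songs // total_fold\n    remainder = num_songs % total_fold\n    fold_num = [0]\n    for _ in range(total_fold):\n        fold_num.append(quotient)\n    for i in range(remainder):\n        fold_num[i + 1] += 1\n    for i in range(total_fold):\n        fold_num[i + 1] += fold_num[i]\n    return fold_num\n\nif __name__ == \"__main__\":\n    print(fold_boundaries(223))  # [0, 45, 90, 135, 179, 223]\n"
  },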
  {
    "path": "baseline_models.py",
    "content": "from utils.hparams import HParams\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport time\nfrom crf_model import CRF\n\nuse_cuda = torch.cuda.is_available()\n\nclass CNN(nn.Module):\n    def __init__(self,config):\n        super(CNN, self).__init__()\n\n        self.timestep = config['timestep']\n        self.context = 7\n        self.pad = nn.ConstantPad1d(self.context, 0)\n        self.probs_out = config['probs_out']\n        self.num_chords = config['num_chords']\n\n        self.drop_out = nn.Dropout2d(p=0.5)\n        self.conv1 = self.cnn_layers(1, 32, kernel_size=(3,3), padding=1)\n        self.conv2 = self.cnn_layers(32, 32, kernel_size=(3,3), padding=1)\n        self.conv3 = self.cnn_layers(32, 32, kernel_size=(3,3), padding=1)\n        self.conv4 = self.cnn_layers(32, 32, kernel_size=(3,3), padding=1)\n        self.pool_max = nn.MaxPool2d(kernel_size=(2,1))\n        self.conv5 = self.cnn_layers(32, 64, kernel_size=(3, 3), padding=0)\n        self.conv6 = self.cnn_layers(64, 64, kernel_size=(3, 3), padding=0)\n        self.conv7 = self.cnn_layers(64, 128, kernel_size=(12, 9), padding=0)\n        self.conv_linear = nn.Conv2d(128, config['num_chords'], kernel_size=(1,1), padding=0)\n\n    def cnn_layers(self, in_channels, out_channels, kernel_size, stride=1, padding=0):\n        layers = []\n        conv2d = nn.Conv2d(in_channels, out_channels,kernel_size=kernel_size, stride=stride, padding=padding)\n        batch_norm = nn.BatchNorm2d(out_channels)\n        relu = nn.ReLU(inplace=True)\n        layers += [conv2d, batch_norm, relu]\n        return nn.Sequential(*layers)\n\n    def forward(self, x, labels):\n        x = x.permute(0,2,1)\n        x = self.pad(x)\n        batch_size = x.size(0)\n        for i in range(batch_size):\n            for j in range(self.timestep):\n                if i == 0 and j == 0:\n                    inputs = x[i,:,j : j + self.context *2 + 1].unsqueeze(0)\n                else:\n                    tmp = x[i, :, j : j + self.context *2 + 1].unsqueeze(0)\n                    inputs = torch.cat((inputs,tmp), dim=0)\n        # inputs : [batchsize * timestep, feature_size, context]\n        inputs = inputs.unsqueeze(1)\n        conv = self.conv1(inputs)\n        conv = self.conv2(conv)\n        conv = self.conv3(conv)\n        conv = self.conv4(conv)\n        pooled = self.pool_max(conv)\n        pooled = self.drop_out(pooled)\n        conv = self.conv5(pooled)\n        conv = self.conv6(conv)\n        pooled = self.pool_max(conv)\n        pooled = self.drop_out(pooled)\n        conv = self.conv7(pooled)\n        conv = self.drop_out(conv)\n        conv = self.conv_linear(conv)\n        avg_pool = nn.AvgPool2d(kernel_size=(conv.size(2), conv.size(3)))\n        logits = avg_pool(conv).squeeze(2).squeeze(2)\n        if self.probs_out is True:\n            crf_input = logits.view(-1, self.timestep, self.num_chords)\n            return crf_input\n        log_probs = F.log_softmax(logits, -1)\n        topk, indices = torch.topk(log_probs, 2)\n        predictions = indices[:,0]\n        second = indices[:,1]\n        prediction = predictions.view(-1)\n        second = second.view(-1)\n        loss = F.nll_loss(log_probs.view(-1, self.num_chords), labels.view(-1))\n        return prediction, loss, 0, second\n\nclass Crf(nn.Module):\n    def __init__(self, num_chords, timestep):\n        super(Crf, self).__init__()\n        self.output_size = num_chords\n        self.timestep = timestep\n        self.Crf = 
CRF(self.output_size)\n\n    def forward(self, probs, labels):\n        prediction = self.Crf(probs)\n        prediction = prediction.view(-1)\n        labels = labels.view(-1, self.timestep)\n        loss = self.Crf.loss(probs, labels)\n        return prediction, loss\n\nclass CRNN(nn.Module):\n    def __init__(self,config):\n        super(CRNN, self).__init__()\n\n        self.feature_size = config['feature_size']\n        self.timestep = config['timestep']\n        self.probs_out = config['probs_out']\n        self.num_chords = config['num_chords']\n        self.hidden_size = 128\n\n        self.relu = nn.ReLU(inplace=True)\n        self.batch_norm = nn.BatchNorm2d(1)\n        self.conv1 = nn.Conv2d(1, 1, kernel_size=(5,5), padding=2)\n        self.conv2 = nn.Conv2d(1, 36, kernel_size=(1,self.feature_size))\n        self.gru = nn.GRU(input_size=36, hidden_size=self.hidden_size, num_layers=2, batch_first=True, bidirectional=True)\n        self.fc = nn.Linear(self.hidden_size*2, self.num_chords)\n\n    def forward(self, x, labels):\n        # x : [batchsize * timestep * feature_size]\n        x = x.unsqueeze(1)\n        x = self.batch_norm(x)\n        conv = self.relu(self.conv1(x))\n        conv = self.relu(self.conv2(conv))\n        conv = conv.squeeze(3).permute(0,2,1)\n\n        h0 = torch.zeros(4, conv.size(0), self.hidden_size).to(torch.device(\"cuda\" if use_cuda else \"cpu\"))\n        gru, h = self.gru(conv, h0)\n        logits = self.fc(gru)\n        if self.probs_out is True:\n            # probs = F.softmax(logits, -1)\n            return logits\n        log_probs = F.log_softmax(logits, -1)\n        topk, indices = torch.topk(log_probs, 2)\n        predictions = indices[:,:,0]\n        second = indices[:,:,1]\n        prediction = predictions.view(-1)\n        second = second.view(-1)\n        loss = F.nll_loss(log_probs.view(-1, self.num_chords), labels.view(-1))\n        return prediction, loss, 0, second\n\n\nif __name__ == \"__main__\":\n    config = HParams.load(\"run_config.yaml\")\n    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n    config.model['probs_out'] = True\n    batch_size = 2\n    timestep = config.model['timestep']\n    feature_size = config.model['feature_size']\n    num_chords = config.model['num_chords']\n\n    features = torch.randn(batch_size,timestep,feature_size,requires_grad=True).to(device)\n    chords = torch.randint(num_chords,(batch_size*timestep,)).to(device)\n\n    model = CNN(config=config.model).to(device)\n    crf = Crf(num_chords=config.model['num_chords'], timestep=config.model['timestep']).to(device)\n\n    probs = model(features, chords)\n    prediction, total_loss = crf(probs, chords)\n\n    print(total_loss)\n"
  },
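  {
    "path": "examples/context_window_sketch.py",
    "content": "\"\"\"A minimal sketch showing that the per-frame context windows built by\nthe nested loop in baseline_models.CNN.forward can also be produced with\ntorch.Tensor.unfold. Shapes follow run_config.yaml (feature_size 144,\ntimestep 108) and the CNN's context radius of 7.\n\"\"\"\nimport torch\n\nbatch, feature_size, timestep, context = 2, 144, 108, 7\n\nx = torch.randn(batch, feature_size, timestep)\nx = torch.nn.ConstantPad1d(context, 0)(x)  # pad the time axis on both sides\n\n# one window of width 2*context+1 per time step:\n# [batch, feature_size, timestep, 2*context+1]\nwindows = x.unfold(2, 2 * context + 1, 1)\n\n# [batch*timestep, feature_size, 2*context+1], the same layout the loop builds\ninputs = windows.permute(0, 2, 1, 3).reshape(-1, feature_size, 2 * context + 1)\nprint(inputs.shape)  # torch.Size([216, 144, 15])\n"
  },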
  {
    "path": "btc_model.py",
    "content": "from utils.transformer_modules import *\nfrom utils.transformer_modules import _gen_timing_signal, _gen_bias_mask\nfrom utils.hparams import HParams\n\nuse_cuda = torch.cuda.is_available()\n\nclass self_attention_block(nn.Module):\n    def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads,\n                 bias_mask=None, layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0, attention_map=False):\n        super(self_attention_block, self).__init__()\n\n        self.attention_map = attention_map\n        self.multi_head_attention = MultiHeadAttention(hidden_size, total_key_depth, total_value_depth,hidden_size, num_heads, bias_mask, attention_dropout, attention_map)\n        self.positionwise_convolution = PositionwiseFeedForward(hidden_size, filter_size, hidden_size, layer_config='cc', padding='both', dropout=relu_dropout)\n        self.dropout = nn.Dropout(layer_dropout)\n        self.layer_norm_mha = LayerNorm(hidden_size)\n        self.layer_norm_ffn = LayerNorm(hidden_size)\n\n    def forward(self, inputs):\n        x = inputs\n\n        # Layer Normalization\n        x_norm = self.layer_norm_mha(x)\n\n        # Multi-head attention\n        if self.attention_map is True:\n            y, weights = self.multi_head_attention(x_norm, x_norm, x_norm)\n        else:\n            y = self.multi_head_attention(x_norm, x_norm, x_norm)\n\n        # Dropout and residual\n        x = self.dropout(x + y)\n\n        # Layer Normalization\n        x_norm = self.layer_norm_ffn(x)\n\n        # Positionwise Feedforward\n        y = self.positionwise_convolution(x_norm)\n\n        # Dropout and residual\n        y = self.dropout(x + y)\n\n        if self.attention_map is True:\n            return y, weights\n        return y\n\nclass bi_directional_self_attention(nn.Module):\n    def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads, max_length,\n                 layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0):\n\n        super(bi_directional_self_attention, self).__init__()\n\n        self.weights_list = list()\n\n        params = (hidden_size,\n                  total_key_depth or hidden_size,\n                  total_value_depth or hidden_size,\n                  filter_size,\n                  num_heads,\n                  _gen_bias_mask(max_length),\n                  layer_dropout,\n                  attention_dropout,\n                  relu_dropout,\n                  True)\n\n        self.attn_block = self_attention_block(*params)\n\n        params = (hidden_size,\n                  total_key_depth or hidden_size,\n                  total_value_depth or hidden_size,\n                  filter_size,\n                  num_heads,\n                  torch.transpose(_gen_bias_mask(max_length), dim0=2, dim1=3),\n                  layer_dropout,\n                  attention_dropout,\n                  relu_dropout,\n                  True)\n\n        self.backward_attn_block = self_attention_block(*params)\n\n        self.linear = nn.Linear(hidden_size*2, hidden_size)\n\n    def forward(self, inputs):\n        x, list = inputs\n\n        # Forward Self-attention Block\n        encoder_outputs, weights = self.attn_block(x)\n        # Backward Self-attention Block\n        reverse_outputs, reverse_weights = self.backward_attn_block(x)\n        # Concatenation and Fully-connected Layer\n        outputs = torch.cat((encoder_outputs, reverse_outputs), dim=2)\n        y = self.linear(outputs)\n\n      
  # Attention weights for Visualization\n        self.weights_list = list\n        self.weights_list.append(weights)\n        self.weights_list.append(reverse_weights)\n        return y, self.weights_list\n\nclass bi_directional_self_attention_layers(nn.Module):\n    def __init__(self, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth,\n                 filter_size, max_length=100, input_dropout=0.0, layer_dropout=0.0,\n                 attention_dropout=0.0, relu_dropout=0.0):\n        super(bi_directional_self_attention_layers, self).__init__()\n\n        self.timing_signal = _gen_timing_signal(max_length, hidden_size)\n        params = (hidden_size,\n                  total_key_depth or hidden_size,\n                  total_value_depth or hidden_size,\n                  filter_size,\n                  num_heads,\n                  max_length,\n                  layer_dropout,\n                  attention_dropout,\n                  relu_dropout)\n        self.embedding_proj = nn.Linear(embedding_size, hidden_size, bias=False)\n        self.self_attn_layers = nn.Sequential(*[bi_directional_self_attention(*params) for l in range(num_layers)])\n        self.layer_norm = LayerNorm(hidden_size)\n        self.input_dropout = nn.Dropout(input_dropout)\n\n    def forward(self, inputs):\n        # Add input dropout\n        x = self.input_dropout(inputs)\n\n        # Project to hidden size\n        x = self.embedding_proj(x)\n\n        # Add timing signal\n        x += self.timing_signal[:, :inputs.shape[1], :].type_as(inputs.data)\n\n        # A Stack of Bi-directional Self-attention Layers\n        y, weights_list = self.self_attn_layers((x, []))\n\n        # Layer Normalization\n        y = self.layer_norm(y)\n        return y, weights_list\n\nclass BTC_model(nn.Module):\n    def __init__(self, config):\n        super(BTC_model, self).__init__()\n\n        self.timestep = config['timestep']\n        self.probs_out = config['probs_out']\n\n        params = (config['feature_size'],\n                  config['hidden_size'],\n                  config['num_layers'],\n                  config['num_heads'],\n                  config['total_key_depth'],\n                  config['total_value_depth'],\n                  config['filter_size'],\n                  config['timestep'],\n                  config['input_dropout'],\n                  config['layer_dropout'],\n                  config['attention_dropout'],\n                  config['relu_dropout'])\n\n        self.self_attn_layers = bi_directional_self_attention_layers(*params)\n        self.output_layer = SoftmaxOutputLayer(hidden_size=config['hidden_size'], output_size=config['num_chords'], probs_out=config['probs_out'])\n\n    def forward(self, x, labels):\n        labels = labels.view(-1, self.timestep)\n        # Output of Bi-directional Self-attention Layers\n        self_attn_output, weights_list = self.self_attn_layers(x)\n\n        # return logit values for CRF\n        if self.probs_out is True:\n            logits = self.output_layer(self_attn_output)\n            return logits\n\n        # Output layer and Soft-max\n        prediction,second = self.output_layer(self_attn_output)\n        prediction = prediction.view(-1)\n        second = second.view(-1)\n\n        # Loss Calculation\n        loss = self.output_layer.loss(self_attn_output, labels)\n        return prediction, loss, weights_list, second\n\nif __name__ == \"__main__\":\n    config = HParams.load(\"run_config.yaml\")\n    device = 
torch.device(\"cuda\" if use_cuda else \"cpu\")\n\n    batch_size = 2\n    timestep = 108\n    feature_size = 144\n    num_chords = 25\n\n    features = torch.randn(batch_size,timestep,feature_size,requires_grad=True).to(device)\n    chords = torch.randint(25,(batch_size*timestep,)).to(device)\n\n    model = BTC_model(config=config.model).to(device)\n\n    prediction, loss, weights_list, second = model(features, chords)\n    print(prediction.size())\n    print(loss)\n\n\n"
  },
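  {
    "path": "examples/bias_mask_sketch.py",
    "content": "\"\"\"A minimal sketch of the forward/backward masking idea in\nbi_directional_self_attention. The real mask comes from\nutils.transformer_modules._gen_bias_mask; the helper below is an assumed\nstand-in (an additive mask of -inf above the diagonal), written only to\nillustrate why transposing the last two dimensions reverses the attention\ndirection.\n\"\"\"\nimport torch\n\ndef gen_bias_mask(max_length):\n    # position t may only attend to positions <= t\n    mask = torch.triu(torch.full((max_length, max_length), float('-inf')), diagonal=1)\n    return mask.unsqueeze(0).unsqueeze(1)  # [1, 1, max_length, max_length]\n\nforward_mask = gen_bias_mask(4)\n# transposing flips the constraint: t may only attend to positions >= t\nbackward_mask = torch.transpose(forward_mask, dim0=2, dim1=3)\nprint(forward_mask[0, 0])\nprint(backward_mask[0, 0])\n"
  },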
  {
    "path": "crf_model.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport torch\nimport torch.nn as nn\n\nclass CRF(nn.Module):\n    \"\"\"\n    Implements Conditional Random Fields that can be trained via\n    backpropagation.\n    \"\"\"\n\n    def __init__(self, num_tags):\n        super(CRF, self).__init__()\n\n        self.num_tags = num_tags\n        self.transitions = nn.Parameter(torch.Tensor(num_tags, num_tags))\n        self.start_transitions = nn.Parameter(torch.randn(num_tags))\n        self.stop_transitions = nn.Parameter(torch.randn(num_tags))\n\n        nn.init.xavier_normal_(self.transitions)\n\n    def forward(self, feats):\n        # Shape checks\n        if len(feats.shape) != 3:\n            raise ValueError(\"feats must be 3-d got {}-d\".format(feats.shape))\n\n        return self._viterbi(feats)\n\n    def loss(self, feats, tags):\n        \"\"\"\n        Computes negative log likelihood between features and tags.\n        Essentially difference between individual sequence scores and\n        sum of all possible sequence scores (partition function)\n        Parameters:\n            feats: Input features [batch size, sequence length, number of tags]\n            tags: Target tag indices [batch size, sequence length]. Should be between\n                    0 and num_tags\n        Returns:\n            Negative log likelihood [a scalar]\n        \"\"\"\n        # Shape checks\n        if len(feats.shape) != 3:\n            raise ValueError(\"feats must be 3-d got {}-d\".format(feats.shape))\n\n        if len(tags.shape) != 2:\n            raise ValueError('tags must be 2-d but got {}-d'.format(tags.shape))\n\n        if feats.shape[:2] != tags.shape:\n            raise ValueError('First two dimensions of feats and tags must match')\n\n        sequence_score = self._sequence_score(feats, tags)\n        partition_function = self._partition_function(feats)\n        log_probability = sequence_score - partition_function\n\n        # -ve of l()\n        # Average across batch\n        return -log_probability.mean()\n\n    def _sequence_score(self, feats, tags):\n        \"\"\"\n        Parameters:\n            feats: Input features [batch size, sequence length, number of tags]\n            tags: Target tag indices [batch size, sequence length]. 
Should be between\n                    0 and num_tags\n        Returns: Sequence score of shape [batch size]\n        \"\"\"\n\n        batch_size = feats.shape[0]\n\n        # Compute feature scores\n        feat_score = feats.gather(2, tags.unsqueeze(-1)).squeeze(-1).sum(dim=-1)\n\n        # Compute transition scores\n        # Unfold to get [from, to] tag index pairs\n        tags_pairs = tags.unfold(1, 2, 1)\n\n        # Use advanced indexing to pull out required transition scores\n        indices = tags_pairs.permute(2, 0, 1).chunk(2)\n        trans_score = self.transitions[indices].squeeze(0).sum(dim=-1)\n\n        # Compute start and stop scores\n        start_score = self.start_transitions[tags[:, 0]]\n        stop_score = self.stop_transitions[tags[:, -1]]\n\n        return feat_score + start_score + trans_score + stop_score\n\n    def _partition_function(self, feats):\n        \"\"\"\n        Computes the partition function for the CRF using the forward algorithm.\n        Basically calculates scores for all possible tag sequences for\n        the given feature vector sequence\n        Parameters:\n            feats: Input features [batch size, sequence length, number of tags]\n        Returns:\n            Total scores of shape [batch size]\n        \"\"\"\n        _, seq_size, num_tags = feats.shape\n\n        if self.num_tags != num_tags:\n            raise ValueError('num_tags should be {} but got {}'.format(self.num_tags, num_tags))\n\n        a = feats[:, 0] + self.start_transitions.unsqueeze(0)  # [batch_size, num_tags]\n        transitions = self.transitions.unsqueeze(0)  # [1, num_tags, num_tags] from -> to\n\n        for i in range(1, seq_size):\n            feat = feats[:, i].unsqueeze(1)  # [batch_size, 1, num_tags]\n            a = self._log_sum_exp(a.unsqueeze(-1) + transitions + feat, 1)  # [batch_size, num_tags]\n\n        return self._log_sum_exp(a + self.stop_transitions.unsqueeze(0), 1)  # [batch_size]\n\n    def _viterbi(self, feats):\n        \"\"\"\n        Uses the Viterbi algorithm to predict the best sequence\n        Parameters:\n            feats: Input features [batch size, sequence length, number of tags]\n        Returns: Best tag sequence [batch size, sequence length]\n        \"\"\"\n        _, seq_size, num_tags = feats.shape\n\n        if self.num_tags != num_tags:\n            raise ValueError('num_tags should be {} but got {}'.format(self.num_tags, num_tags))\n\n        v = feats[:, 0] + self.start_transitions.unsqueeze(0)  # [batch_size, num_tags]\n        transitions = self.transitions.unsqueeze(0)  # [1, num_tags, num_tags] from -> to\n        paths = []\n\n        for i in range(1, seq_size):\n            feat = feats[:, i]  # [batch_size, num_tags]\n            v, idx = (v.unsqueeze(-1) + transitions).max(1)  # [batch_size, num_tags], [batch_size, num_tags]\n\n            paths.append(idx)\n            v = (v + feat)  # [batch_size, num_tags]\n\n        v, tag = (v + self.stop_transitions.unsqueeze(0)).max(1, True)\n\n        # Backtrack\n        tags = [tag]\n        for idx in reversed(paths):\n            tag = idx.gather(1, tag)\n            tags.append(tag)\n\n        tags.reverse()\n        return torch.cat(tags, 1)\n\n    def _log_sum_exp(self, logits, dim):\n        \"\"\"\n        Computes log-sum-exp in a stable way\n        \"\"\"\n        max_val, _ = logits.max(dim)\n        return max_val + (logits - max_val.unsqueeze(dim)).exp().sum(dim).log()"
  },
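  {
    "path": "examples/crf_smoke_test.py",
    "content": "\"\"\"A minimal smoke test for crf_model.CRF with random emissions: the loss\nshould come out as a positive scalar and Viterbi decoding should return one\ntag per time step. All sizes below are arbitrary.\n\"\"\"\nimport torch\nfrom crf_model import CRF\n\nbatch_size, seq_len, num_tags = 2, 10, 25\n\ncrf = CRF(num_tags)\nfeats = torch.randn(batch_size, seq_len, num_tags)\ntags = torch.randint(num_tags, (batch_size, seq_len))\n\nloss = crf.loss(feats, tags)  # average negative log likelihood, a scalar\nbest_paths = crf(feats)       # Viterbi decoding\nprint(loss.item())\nprint(best_paths.shape)       # torch.Size([2, 10])\n"
  },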
  {
    "path": "run_config.yaml",
    "content": "mp3:\n  song_hz: 22050\n  inst_len: 10.0\n  skip_interval: 5.0\n\nfeature:\n  n_bins: 144\n  bins_per_octave: 24\n  hop_length: 2048\n  large_voca: False\n#  large_voca: True\n\nexperiment:\n  learning_rate : 0.0001\n  weight_decay : 0.0\n  max_epoch : 100\n  batch_size : 128\n  save_step : 40\n  data_ratio : 0.8\n\nmodel:\n  feature_size : 144\n  timestep : 108\n  num_chords : 25\n#  num_chords : 170\n  input_dropout : 0.2\n  layer_dropout : 0.2\n  attention_dropout : 0.2\n  relu_dropout : 0.2\n  num_layers : 8\n  num_heads : 4\n  hidden_size : 128\n  total_key_depth : 128\n  total_value_depth : 128\n  filter_size : 128\n  loss : 'ce'\n  probs_out : False\n\npath:\n  ckpt_path : 'model'\n  result_path : 'result'\n  asset_path : '/data/music/chord_recognition/jayg996/assets'\n  root_path : '/data/music/chord_recognition'\n"
  },
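  {
    "path": "examples/load_config_sketch.py",
    "content": "\"\"\"A minimal sketch of reading run_config.yaml through utils.hparams.HParams,\nfollowing the access pattern used in train.py and test.py: attribute access\nfor sections and dict-style access for individual keys.\n\"\"\"\nfrom utils.hparams import HParams\n\nconfig = HParams.load(\"run_config.yaml\")\n\nprint(config.model['num_chords'])       # 25 for the major/minor vocabulary\nprint(config.experiment['batch_size'])  # 128\n\n# switching to the large vocabulary, exactly as train.py and test.py do\nconfig.feature['large_voca'] = True\nconfig.model['num_chords'] = 170\n"
  },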
  {
    "path": "test.py",
    "content": "import os\nimport mir_eval\nimport pretty_midi as pm\nfrom utils import logger\nfrom btc_model import *\nfrom utils.mir_eval_modules import audio_file_to_features, idx2chord, idx2voca_chord, get_audio_paths\nimport argparse\nimport warnings\n\nwarnings.filterwarnings('ignore')\nlogger.logging_verbosity(1)\nuse_cuda = torch.cuda.is_available()\ndevice = torch.device(\"cuda\" if use_cuda else \"cpu\")\n\n# hyperparameters\nparser = argparse.ArgumentParser()\nparser.add_argument('--voca', default=True, type=lambda x: (str(x).lower() == 'true'))\nparser.add_argument('--audio_dir', type=str, default='./test')\nparser.add_argument('--save_dir', type=str, default='./test')\nargs = parser.parse_args()\n\nconfig = HParams.load(\"run_config.yaml\")\n\nif args.voca is True:\n    config.feature['large_voca'] = True\n    config.model['num_chords'] = 170\n    model_file = './test/btc_model_large_voca.pt'\n    idx_to_chord = idx2voca_chord()\n    logger.info(\"label type: large voca\")\nelse:\n    model_file = './test/btc_model.pt'\n    idx_to_chord = idx2chord\n    logger.info(\"label type: Major and minor\")\n\nmodel = BTC_model(config=config.model).to(device)\n\n# Load model\nif os.path.isfile(model_file):\n    checkpoint = torch.load(model_file)\n    mean = checkpoint['mean']\n    std = checkpoint['std']\n    model.load_state_dict(checkpoint['model'])\n    logger.info(\"restore model\")\n\n# Audio files with format of wav and mp3\naudio_paths = get_audio_paths(args.audio_dir)\n\n# Chord recognition and save lab file\nfor i, audio_path in enumerate(audio_paths):\n    logger.info(\"======== %d of %d in progress ========\" % (i + 1, len(audio_paths)))\n    # Load mp3\n    feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, config)\n    logger.info(\"audio file loaded and feature computation success : %s\" % audio_path)\n\n    # Majmin type chord recognition\n    feature = feature.T\n    feature = (feature - mean) / std\n    time_unit = feature_per_second\n    n_timestep = config.model['timestep']\n\n    num_pad = n_timestep - (feature.shape[0] % n_timestep)\n    feature = np.pad(feature, ((0, num_pad), (0, 0)), mode=\"constant\", constant_values=0)\n    num_instance = feature.shape[0] // n_timestep\n\n    start_time = 0.0\n    lines = []\n    with torch.no_grad():\n        model.eval()\n        feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)\n        for t in range(num_instance):\n            self_attn_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])\n            prediction, _ = model.output_layer(self_attn_output)\n            prediction = prediction.squeeze()\n            for i in range(n_timestep):\n                if t == 0 and i == 0:\n                    prev_chord = prediction[i].item()\n                    continue\n                if prediction[i].item() != prev_chord:\n                    lines.append(\n                        '%.3f %.3f %s\\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))\n                    start_time = time_unit * (n_timestep * t + i)\n                    prev_chord = prediction[i].item()\n                if t == num_instance - 1 and i + num_pad == n_timestep:\n                    if start_time != time_unit * (n_timestep * t + i):\n                        lines.append('%.3f %.3f %s\\n' % (start_time, time_unit * (n_timestep * t + i), idx_to_chord[prev_chord]))\n                    break\n\n    # lab file write\n    if not 
os.path.exists(args.save_dir):\n        os.makedirs(args.save_dir)\n    save_path = os.path.join(args.save_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab')\n    with open(save_path, 'w') as f:\n        for line in lines:\n            f.write(line)\n\n    logger.info(\"label file saved : %s\" % save_path)\n\n    # lab file to midi file\n    \n\n    starts, ends, pitchs = list(), list(), list()\n\n    intervals, chords = mir_eval.io.load_labeled_intervals(save_path)\n    for p in range(12):\n        for i, (interval, chord) in enumerate(zip(intervals, chords)):\n            root_num, relative_bitmap, _ = mir_eval.chord.encode(chord)\n            tmp_label = mir_eval.chord.rotate_bitmap_to_root(relative_bitmap, root_num)[p]\n            if i == 0:\n                start_time = interval[0]\n                label = tmp_label\n                continue\n            if tmp_label != label:\n                if label == 1.0:\n                    starts.append(start_time), ends.append(interval[0]), pitchs.append(p + 48)\n                start_time = interval[0]\n                label = tmp_label\n            if i == (len(intervals) - 1): \n                if label == 1.0:\n                    starts.append(start_time), ends.append(interval[1]), pitchs.append(p + 48)\n\n    midi = pm.PrettyMIDI()\n    instrument = pm.Instrument(program=0)\n\n    for start, end, pitch in zip(starts, ends, pitchs):\n        pm_note = pm.Note(velocity=120, pitch=pitch, start=start, end=end)\n        instrument.notes.append(pm_note)\n\n    midi.instruments.append(instrument)\n    midi.write(save_path.replace('.lab', '.midi'))    \n\n"
  },
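  {
    "path": "examples/lab_eval_sketch.py",
    "content": "\"\"\"A minimal sketch of scoring a .lab file produced by test.py against a\nreference annotation with mir_eval, mirroring the metrics reported by\ntrain.py. Both file paths below are placeholders.\n\"\"\"\nimport mir_eval\n\nref_intervals, ref_labels = mir_eval.io.load_labeled_intervals('reference.lab')\nest_intervals, est_labels = mir_eval.io.load_labeled_intervals('./test/estimated.lab')\n\n# mir_eval aligns the interval grids internally and returns a dict of scores\nscores = mir_eval.chord.evaluate(ref_intervals, ref_labels, est_intervals, est_labels)\nprint(scores['root'])\nprint(scores['majmin'])\n"
  },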
  {
    "path": "train.py",
    "content": "import os\nfrom torch import optim\nfrom utils import logger\nfrom audio_dataset import AudioDataset, AudioDataLoader\nfrom utils.tf_logger import TF_Logger\nfrom btc_model import *\nfrom baseline_models import CNN, CRNN\nfrom utils.hparams import HParams\nimport argparse\nfrom utils.pytorch_utils import adjusting_learning_rate\nfrom utils.mir_eval_modules import root_majmin_score_calculation, large_voca_score_calculation\nimport warnings\n\nwarnings.filterwarnings(\"ignore\", category=UserWarning)\nwarnings.filterwarnings(\"ignore\", category=FutureWarning)\nlogger.logging_verbosity(1)\nuse_cuda = torch.cuda.is_available()\ndevice = torch.device(\"cuda\" if use_cuda else \"cpu\")\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--index', type=int, help='Experiment Number', default='e')\nparser.add_argument('--kfold', type=int, help='5 fold (0,1,2,3,4)',default='e')\nparser.add_argument('--voca', type=bool, help='large voca is True', default=False)\nparser.add_argument('--model', type=str, help='btc, cnn, crnn', default='btc')\nparser.add_argument('--dataset1', type=str, help='Dataset', default='isophonic')\nparser.add_argument('--dataset2', type=str, help='Dataset', default='uspop')\nparser.add_argument('--dataset3', type=str, help='Dataset', default='robbiewilliams')\nparser.add_argument('--restore_epoch', type=int, default=1000)\nparser.add_argument('--early_stop', type=bool, help='no improvement during 10 epoch -> stop', default=True)\nargs = parser.parse_args()\n\nconfig = HParams.load(\"run_config.yaml\")\nif args.voca == True:\n    config.feature['large_voca'] = True\n    config.model['num_chords'] = 170\n\n# Result save path\nasset_path = config.path['asset_path']\nckpt_path = config.path['ckpt_path']\nresult_path = config.path['result_path']\nrestore_epoch = args.restore_epoch\nexperiment_num = str(args.index)\nckpt_file_name = 'idx_'+experiment_num+'_%03d.pth.tar'\ntf_logger = TF_Logger(os.path.join(asset_path, 'tensorboard', 'idx_'+experiment_num))\nlogger.info(\"==== Experiment Number : %d \" % args.index)\n\nif args.model == 'cnn':\n    config.experiment['batch_size'] = 10\n\n# Data loader\ntrain_dataset1 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset1,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold)\ntrain_dataset2 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset2,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold)\ntrain_dataset3 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset3,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold)\ntrain_dataset = train_dataset1.__add__(train_dataset2).__add__(train_dataset3)\nvalid_dataset1 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset1,), preprocessing=False, train=False, kfold=args.kfold)\nvalid_dataset2 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset2,), preprocessing=False, train=False, kfold=args.kfold)\nvalid_dataset3 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset3,), preprocessing=False, train=False, kfold=args.kfold)\nvalid_dataset = valid_dataset1.__add__(valid_dataset2).__add__(valid_dataset3)\ntrain_dataloader = AudioDataLoader(dataset=train_dataset, batch_size=config.experiment['batch_size'], drop_last=False, shuffle=True)\nvalid_dataloader = AudioDataLoader(dataset=valid_dataset, 
batch_size=config.experiment['batch_size'], drop_last=False)\n\n# Model and Optimizer\nif args.model == 'cnn':\n    model = CNN(config=config.model).to(device)\nelif args.model == 'crnn':\n    model = CRNN(config=config.model).to(device)\nelif args.model == 'btc':\n    model = BTC_model(config=config.model).to(device)\nelse: raise NotImplementedError\noptimizer = optim.Adam(model.parameters(), lr=config.experiment['learning_rate'], weight_decay=config.experiment['weight_decay'], betas=(0.9, 0.98), eps=1e-9)\n\n# Make asset directory\nif not os.path.exists(os.path.join(asset_path, ckpt_path)):\n    os.makedirs(os.path.join(asset_path, ckpt_path))\n    os.makedirs(os.path.join(asset_path, result_path))\n\n# Load model\nif os.path.isfile(os.path.join(asset_path, ckpt_path, ckpt_file_name % restore_epoch)):\n    checkpoint = torch.load(os.path.join(asset_path, ckpt_path, ckpt_file_name % restore_epoch))\n    model.load_state_dict(checkpoint['model'])\n    optimizer.load_state_dict(checkpoint['optimizer'])\n    epoch = checkpoint['epoch']\n    logger.info(\"restore model with %d epochs\" % restore_epoch)\nelse:\n    logger.info(\"no checkpoint with %d epochs\" % restore_epoch)\n    restore_epoch = 0\n\n# Global mean and variance calculate\nmp3_config = config.mp3\nfeature_config = config.feature\nmp3_string = \"%d_%.1f_%.1f\" % (mp3_config['song_hz'], mp3_config['inst_len'], mp3_config['skip_interval'])\nfeature_string = \"_%s_%d_%d_%d_\" % ('cqt', feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length'])\nz_path = os.path.join(config.path['root_path'], 'result', mp3_string + feature_string + 'mix_kfold_'+ str(args.kfold) +'_normalization.pt')\nif os.path.exists(z_path):\n    normalization = torch.load(z_path)\n    mean = normalization['mean']\n    std = normalization['std']\n    logger.info(\"Global mean and std (k fold index %d) load complete\" % args.kfold)\nelse:\n    mean = 0\n    square_mean = 0\n    k = 0\n    for i, data in enumerate(train_dataloader):\n        features, input_percentages, chords, collapsed_chords, chord_lens, boundaries = data\n        features = features.to(device)\n        mean += torch.mean(features).item()\n        square_mean += torch.mean(features.pow(2)).item()\n        k += 1\n    square_mean = square_mean / k\n    mean = mean / k\n    std = np.sqrt(square_mean - mean * mean)\n    normalization = dict()\n    normalization['mean'] = mean\n    normalization['std'] = std\n    torch.save(normalization, z_path)\n    logger.info(\"Global mean and std (training set, k fold index %d) calculation complete\" % args.kfold)\n\ncurrent_step = 0\nbest_acc = 0\nbefore_acc = 0\nearly_stop_idx = 0\nfor epoch in range(restore_epoch, config.experiment['max_epoch']):\n    # Training\n    model.train()\n    train_loss_list = []\n    total = 0.\n    correct = 0.\n    second_correct = 0.\n    for i, data in enumerate(train_dataloader):\n        features, input_percentages, chords, collapsed_chords, chord_lens, boundaries = data\n        features, chords = features.to(device), chords.to(device)\n\n        features.requires_grad = True\n        features = (features - mean) / std\n\n        # forward\n        features = features.squeeze(1).permute(0,2,1)\n        optimizer.zero_grad()\n        prediction, total_loss, weights, second = model(features, chords)\n\n        # save accuracy and loss\n        total += chords.size(0)\n        correct += (prediction == chords).type_as(chords).sum()\n        second_correct += (second == 
chords).type_as(chords).sum()\n        train_loss_list.append(total_loss.item())\n\n        # optimize step\n        total_loss.backward()\n        optimizer.step()\n\n        current_step += 1\n\n    # logging loss and accuracy using tensorboard\n    result = {'loss/tr': np.mean(train_loss_list), 'acc/tr': correct.item() / total, 'top2/tr': (correct.item()+second_correct.item()) / total}\n    for tag, value in result.items(): tf_logger.scalar_summary(tag, value, epoch+1)\n    logger.info(\"training loss for %d epoch: %.4f\" % (epoch + 1, np.mean(train_loss_list)))\n    logger.info(\"training accuracy for %d epoch: %.4f\" % (epoch + 1, (correct.item() / total)))\n    logger.info(\"training top2 accuracy for %d epoch: %.4f\" % (epoch + 1, ((correct.item() + second_correct.item()) / total)))\n\n    # Validation\n    with torch.no_grad():\n        model.eval()\n        val_total = 0.\n        val_correct = 0.\n        val_second_correct = 0.\n        validation_loss = 0\n        n = 0\n        for i, data in enumerate(valid_dataloader):\n            val_features, val_input_percentages, val_chords, val_collapsed_chords, val_chord_lens, val_boundaries = data\n            val_features, val_chords = val_features.to(device), val_chords.to(device)\n\n            val_features = (val_features - mean) / std\n\n            val_features = val_features.squeeze(1).permute(0, 2, 1)\n            val_prediction, val_loss, weights, val_second = model(val_features, val_chords)\n\n            val_total += val_chords.size(0)\n            val_correct += (val_prediction == val_chords).type_as(val_chords).sum()\n            val_second_correct += (val_second == val_chords).type_as(val_chords).sum()\n            validation_loss += val_loss.item()\n\n            n += 1\n\n        # logging loss and accuracy using tensorboard\n        validation_loss /= n\n        result = {'loss/val': validation_loss, 'acc/val': val_correct.item() / val_total, 'top2/val': (val_correct.item()+val_second_correct.item()) / val_total}\n        for tag, value in result.items(): tf_logger.scalar_summary(tag, value, epoch + 1)\n        logger.info(\"validation loss(%d): %.4f\" % (epoch + 1, validation_loss))\n        logger.info(\"validation accuracy(%d): %.4f\" % (epoch + 1, (val_correct.item() / val_total)))\n        logger.info(\"validation top2 accuracy(%d): %.4f\" % (epoch + 1, ((val_correct.item() + val_second_correct.item()) / val_total)))\n\n        current_acc = val_correct.item() / val_total\n\n        if best_acc < val_correct.item() / val_total:\n            early_stop_idx = 0\n            best_acc = val_correct.item() / val_total\n            logger.info('==== best accuracy is %.4f and epoch is %d' % (best_acc, epoch + 1))\n            logger.info('saving model, Epoch %d, step %d' % (epoch + 1, current_step + 1))\n            model_save_path = os.path.join(asset_path, 'model', ckpt_file_name % (epoch + 1))\n            state_dict = {'model': model.state_dict(),'optimizer': optimizer.state_dict(),'epoch': epoch}\n            torch.save(state_dict, model_save_path)\n            last_best_epoch = epoch + 1\n\n        # save model\n        elif (epoch + 1) % config.experiment['save_step'] == 0:\n            logger.info('saving model, Epoch %d, step %d' % (epoch + 1, current_step + 1))\n            model_save_path = os.path.join(asset_path, 'model', ckpt_file_name % (epoch + 1))\n            state_dict = {'model': model.state_dict(),'optimizer': optimizer.state_dict(),'epoch': epoch}\n            torch.save(state_dict, 
model_save_path)\n            early_stop_idx += 1\n        else:\n            early_stop_idx += 1\n\n    if (args.early_stop == True) and (early_stop_idx > 9):\n        logger.info('==== early stopped and epoch is %d' % (epoch + 1))\n        break\n    # learning rate decay\n    if before_acc > current_acc:\n        adjusting_learning_rate(optimizer=optimizer, factor=0.95, min_lr=5e-6)\n    before_acc = current_acc\n\n# Load model\nif os.path.isfile(os.path.join(asset_path, ckpt_path, ckpt_file_name % last_best_epoch)):\n    checkpoint = torch.load(os.path.join(asset_path, ckpt_path, ckpt_file_name % last_best_epoch))\n    model.load_state_dict(checkpoint['model'])\n    logger.info(\"restore model with %d epochs\" % last_best_epoch)\nelse:\n    raise NotImplementedError\n\n# score Validation\nif args.voca == True:\n    score_metrics = ['root', 'thirds', 'triads', 'sevenths', 'tetrads', 'majmin', 'mirex']\n    score_list_dict1, song_length_list1, average_score_dict1 = large_voca_score_calculation(valid_dataset=valid_dataset1, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device)\n    score_list_dict2, song_length_list2, average_score_dict2 = large_voca_score_calculation(valid_dataset=valid_dataset2, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device)\n    score_list_dict3, song_length_list3, average_score_dict3 = large_voca_score_calculation(valid_dataset=valid_dataset3, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device)\n    for m in score_metrics:\n        average_score = (np.sum(song_length_list1) * average_score_dict1[m] + np.sum(song_length_list2) *average_score_dict2[m] + np.sum(song_length_list3) * average_score_dict3[m]) / (np.sum(song_length_list1) + np.sum(song_length_list2) + np.sum(song_length_list3))\n        logger.info('==== %s score 1 is %.4f' % (m, average_score_dict1[m]))\n        logger.info('==== %s score 2 is %.4f' % (m, average_score_dict2[m]))\n        logger.info('==== %s score 3 is %.4f' % (m, average_score_dict3[m]))\n        logger.info('==== %s mix average score is %.4f' % (m, average_score))\nelse:\n    score_metrics = ['root', 'majmin']\n    score_list_dict1, song_length_list1, average_score_dict1 = root_majmin_score_calculation(valid_dataset=valid_dataset1, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device)\n    score_list_dict2, song_length_list2, average_score_dict2 = root_majmin_score_calculation(valid_dataset=valid_dataset2, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device)\n    score_list_dict3, song_length_list3, average_score_dict3 = root_majmin_score_calculation(valid_dataset=valid_dataset3, config=config, model=model, model_type=args.model, mean=mean, std=std, device=device)\n    for m in score_metrics:\n        average_score = (np.sum(song_length_list1) * average_score_dict1[m] + np.sum(song_length_list2) *average_score_dict2[m] + np.sum(song_length_list3) * average_score_dict3[m]) / (np.sum(song_length_list1) + np.sum(song_length_list2) + np.sum(song_length_list3))\n        logger.info('==== %s score 1 is %.4f' % (m, average_score_dict1[m]))\n        logger.info('==== %s score 2 is %.4f' % (m, average_score_dict2[m]))\n        logger.info('==== %s score 3 is %.4f' % (m, average_score_dict3[m]))\n        logger.info('==== %s mix average score is %.4f' % (m, average_score))\n"
  },
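  {
    "path": "examples/global_stats_sketch.py",
    "content": "\"\"\"A minimal sketch of the global normalization statistics computed in\ntrain.py and train_crf.py: the mean of per-batch means, and a standard\ndeviation from Var[x] = E[x^2] - E[x]^2. Random tensors stand in for the\nCQT feature batches; the estimate is exact only when all batches have the\nsame size, and train.py applies the same approximation.\n\"\"\"\nimport numpy as np\nimport torch\n\nmean = 0.0\nsquare_mean = 0.0\nk = 0\nfor _ in range(100):  # stand-in for iterating over train_dataloader\n    features = torch.randn(8, 1, 144, 108)\n    mean += torch.mean(features).item()\n    square_mean += torch.mean(features.pow(2)).item()\n    k += 1\nmean = mean / k\nsquare_mean = square_mean / k\nstd = np.sqrt(square_mean - mean * mean)\nprint(mean, std)  # approximately 0.0 and 1.0\n"
  },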
  {
    "path": "train_crf.py",
    "content": "import os\nfrom torch import optim\nfrom utils import logger\nfrom audio_dataset import AudioDataset, AudioDataLoader\nfrom utils.tf_logger import TF_Logger\nfrom btc_model import *\nfrom baseline_models import CNN, CRNN, Crf\nfrom utils.hparams import HParams\nimport argparse\nfrom utils.pytorch_utils import adjusting_learning_rate\nfrom utils.mir_eval_modules import large_voca_score_calculation_crf, root_majmin_score_calculation_crf\nimport warnings\n\nwarnings.filterwarnings(\"ignore\", category=UserWarning)\nwarnings.filterwarnings(\"ignore\", category=FutureWarning)\nlogger.logging_verbosity(1)\nuse_cuda = torch.cuda.is_available()\ndevice = torch.device(\"cuda\" if use_cuda else \"cpu\")\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--index', type=int, help='Experiment Number', default='e')\nparser.add_argument('--kfold', type=int, help='5 fold (0,1,2,3,4)',default='e')\nparser.add_argument('--voca', type=bool, help='large voca is True', default=False)\nparser.add_argument('--model', type=str, default='crf')\nparser.add_argument('--pre_model', type=str, help='btc, cnn, crnn', default='e')\nparser.add_argument('--dataset1', type=str, help='Dataset', default='isophonic_221')\nparser.add_argument('--dataset2', type=str, help='Dataset', default='uspop_185')\nparser.add_argument('--dataset3', type=str, help='Dataset', default='robbiewilliams')\nparser.add_argument('--restore_epoch', type=int, default=1000)\nparser.add_argument('--early_stop', type=bool, help='no improvement during 10 epoch -> stop', default=True)\nargs = parser.parse_args()\n\nconfig = HParams.load(\"run_config.yaml\")\nif args.voca == True:\n    config.feature['large_voca'] = True\n    config.model['num_chords'] = 170\n\nconfig.model['probs_out'] = True\n\n# Result save path\nasset_path = config.path['asset_path']\nckpt_path = config.path['ckpt_path']\nresult_path = config.path['result_path']\nrestore_epoch = args.restore_epoch\nexperiment_num = str(args.index)\nckpt_file_name = 'idx_'+experiment_num+'_%03d.pth.tar'\ntf_logger = TF_Logger(os.path.join(asset_path, 'tensorboard', 'idx_'+experiment_num))\nlogger.info(\"==== Experiment Number : %d \" % args.index)\n\nif args.pre_model == 'cnn':\n    config.experiment['batch_size'] = 20\n\n# Data loader\ntrain_dataset1 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset1,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold)\ntrain_dataset2 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset2,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold)\ntrain_dataset3 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset3,), num_workers=20, preprocessing=False, train=True, kfold=args.kfold)\ntrain_dataset = train_dataset1.__add__(train_dataset2).__add__(train_dataset3)\nvalid_dataset1 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset1,), preprocessing=False, train=False, kfold=args.kfold)\nvalid_dataset2 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset2,), preprocessing=False, train=False, kfold=args.kfold)\nvalid_dataset3 = AudioDataset(config, root_dir=config.path['root_path'], dataset_names=(args.dataset3,), preprocessing=False, train=False, kfold=args.kfold)\nvalid_dataset = valid_dataset1.__add__(valid_dataset2).__add__(valid_dataset3)\ntrain_dataloader = AudioDataLoader(dataset=train_dataset, batch_size=config.experiment['batch_size'], 
drop_last=False, shuffle=True)\nvalid_dataloader = AudioDataLoader(dataset=valid_dataset, batch_size=config.experiment['batch_size'], drop_last=False)\n\n# Model and Optimizer\nif args.pre_model == 'cnn':\n    pre_model = CNN(config=config.model).to(device)\nelif args.pre_model == 'crnn':\n    pre_model = CRNN(config=config.model).to(device)\nelif args.pre_model == 'btc':\n    pre_model = BTC_model(config=config.model).to(device)\nelse: raise NotImplementedError\n\n# The pretrained front-end checkpoints below are hard-coded; extend this mapping for other folds or models\nif args.pre_model == 'cnn':\n    if not args.voca:\n        if args.kfold == 0:\n            load_ckpt_file_name = 'idx_0_%03d.pth.tar'\n            load_restore_epoch = 10\n        else:\n            raise NotImplementedError\n    else:\n        if args.kfold == 0:\n            load_ckpt_file_name = 'idx_1_%03d.pth.tar'\n            load_restore_epoch = 10\n        else:\n            raise NotImplementedError\nelse:\n    raise NotImplementedError\n\nif os.path.isfile(os.path.join(asset_path, ckpt_path, load_ckpt_file_name % load_restore_epoch)):\n    checkpoint = torch.load(os.path.join(asset_path, ckpt_path, load_ckpt_file_name % load_restore_epoch))\n    pre_model.load_state_dict(checkpoint['model'])\n    logger.info(\"restored pre model from epoch %d\" % load_restore_epoch)\nelse:\n    raise NotImplementedError\n\n# Freeze the pre-trained front-end; only the CRF is trained\nfor param in pre_model.parameters():\n    param.requires_grad = False\n\n# Crf Model and Optimizer\ncrf = Crf(num_chords=config.model['num_chords'], timestep=config.model['timestep']).to(device)\noptimizer = optim.Adam(filter(lambda p: p.requires_grad, crf.parameters()), lr=0.01, weight_decay=config.experiment['weight_decay'], betas=(0.9, 0.98), eps=1e-9)\n\n# Make asset directory\nif not os.path.exists(os.path.join(asset_path, ckpt_path)):\n    os.makedirs(os.path.join(asset_path, ckpt_path))\n    os.makedirs(os.path.join(asset_path, result_path))\n\n# Load model\nif os.path.isfile(os.path.join(asset_path, ckpt_path, ckpt_file_name % restore_epoch)):\n    checkpoint = torch.load(os.path.join(asset_path, ckpt_path, ckpt_file_name % restore_epoch))\n    crf.load_state_dict(checkpoint['model'])\n    optimizer.load_state_dict(checkpoint['optimizer'])\n    epoch = checkpoint['epoch']\n    logger.info(\"restore model with %d epochs\" % restore_epoch)\nelse:\n    logger.info(\"no checkpoint with %d epochs\" % restore_epoch)\n    restore_epoch = 0\n\n# Global mean and variance calculation\nmp3_config = config.mp3\nfeature_config = config.feature\nmp3_string = \"%d_%.1f_%.1f\" % (mp3_config['song_hz'], mp3_config['inst_len'], mp3_config['skip_interval'])\nfeature_string = \"_%s_%d_%d_%d_\" % ('cqt', feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length'])\nz_path = os.path.join(config.path['root_path'], 'result', mp3_string + feature_string + 'mix_kfold_' + str(args.kfold) + '_normalization.pt')\nif os.path.exists(z_path):\n    normalization = torch.load(z_path)\n    mean = normalization['mean']\n    std = normalization['std']\n    logger.info(\"Global mean and std (k fold index %d) load complete\" % args.kfold)\nelse:\n    mean = 0\n    square_mean = 0\n    k = 0\n    for i, data in enumerate(train_dataloader):\n        features, input_percentages, chords, collapsed_chords, chord_lens, boundaries = data\n        features = features.to(device)\n        mean += torch.mean(features).item()\n        square_mean += torch.mean(features.pow(2)).item()\n        k += 1\n    square_mean = square_mean / k\n    mean = mean / k\n    std = np.sqrt(square_mean - mean * mean)\n    normalization = dict()\n    normalization['mean'] = mean\n    normalization['std'] = std\n    
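# the stats above use running moments, std = sqrt(E[x^2] - E[x]^2), with batch-wise\n    # averaging (a slight approximation when the last batch is smaller than the rest)\n    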
torch.save(normalization, z_path)\n    logger.info(\"Global mean and std (training set, k fold index %d) calculation complete\" % args.kfold)\n\ncurrent_step = 0\nbest_acc = 0\nbefore_acc = 0\nearly_stop_idx = 0\nlast_best_epoch = 0  # epoch of the best validation accuracy seen so far\npre_model.eval()\nfor epoch in range(restore_epoch, config.experiment['max_epoch']):\n    # Training\n    crf.train()\n    train_loss_list = []\n    total = 0.\n    correct = 0.\n    for i, data in enumerate(train_dataloader):\n        features, input_percentages, chords, collapsed_chords, chord_lens, boundaries = data\n        features, chords = features.to(device), chords.to(device)\n\n        features.requires_grad = True\n        features = (features - mean) / std\n\n        # forward\n        features = features.squeeze(1).permute(0, 2, 1)\n        optimizer.zero_grad()\n        logits = pre_model(features, chords)\n        if args.pre_model == 'crnn':\n            logits = logits.detach()\n            logits.requires_grad = True\n        prediction, total_loss = crf(logits, chords)\n\n        # save accuracy and loss\n        total += chords.size(0)\n        correct += (prediction == chords).type_as(chords).sum()\n        train_loss_list.append(total_loss.item())\n\n        # optimize step\n        total_loss.backward()\n        optimizer.step()\n\n        current_step += 1\n\n    # logging loss and accuracy using tensorboard\n    result = {'loss/tr': np.mean(train_loss_list), 'acc/tr': correct.item() / total}\n    for tag, value in result.items(): tf_logger.scalar_summary(tag, value, epoch+1)\n    logger.info(\"training loss for %d epoch: %.4f\" % (epoch + 1, np.mean(train_loss_list)))\n    logger.info(\"training accuracy for %d epoch: %.4f\" % (epoch + 1, (correct.item() / total)))\n\n    # Validation\n    with torch.no_grad():\n        crf.eval()\n        val_total = 0.\n        val_correct = 0.\n        validation_loss = 0\n        n = 0\n        for i, data in enumerate(valid_dataloader):\n            val_features, val_input_percentages, val_chords, val_collapsed_chords, val_chord_lens, val_boundaries = data\n            val_features, val_chords = val_features.to(device), val_chords.to(device)\n\n            val_features = (val_features - mean) / std\n\n            val_features = val_features.squeeze(1).permute(0, 2, 1)\n            val_logits = pre_model(val_features, val_chords)\n            val_prediction, val_loss = crf(val_logits, val_chords)\n\n            val_total += val_chords.size(0)\n            val_correct += (val_prediction == val_chords).type_as(val_chords).sum()\n            validation_loss += val_loss.item()\n\n            n += 1\n\n        # logging loss and accuracy using tensorboard\n        validation_loss /= n\n        result = {'loss/val': validation_loss, 'acc/val': val_correct.item() / val_total}\n        for tag, value in result.items(): tf_logger.scalar_summary(tag, value, epoch + 1)\n        logger.info(\"validation loss(%d): %.4f\" % (epoch + 1, validation_loss))\n        logger.info(\"validation accuracy(%d): %.4f\" % (epoch + 1, (val_correct.item() / val_total)))\n\n        current_acc = val_correct.item() / val_total\n\n        if best_acc < current_acc:\n            early_stop_idx = 0\n            best_acc = current_acc\n            logger.info('==== best accuracy is %.4f and epoch is %d' % (best_acc, epoch + 1))\n            logger.info('saving model, Epoch %d, step %d' % (epoch + 1, current_step + 1))\n            model_save_path = 
os.path.join(asset_path, 'model', ckpt_file_name % (epoch + 1))\n            state_dict = {'model': crf.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}\n            torch.save(state_dict, model_save_path)\n            last_best_epoch = epoch + 1\n\n        # save model\n        elif (epoch + 1) % config.experiment['save_step'] == 0:\n            logger.info('saving model, Epoch %d, step %d' % (epoch + 1, current_step + 1))\n            model_save_path = os.path.join(asset_path, 'model', ckpt_file_name % (epoch + 1))\n            state_dict = {'model': crf.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}\n            torch.save(state_dict, model_save_path)\n            early_stop_idx += 1\n        else:\n            early_stop_idx += 1\n\n    if args.early_stop and early_stop_idx > 5:\n        logger.info('==== early stopped and epoch is %d' % (epoch + 1))\n        break\n    # learning rate decay when validation accuracy drops\n    if before_acc > current_acc:\n        adjusting_learning_rate(optimizer=optimizer, factor=0.95, min_lr=5e-6)\n    before_acc = current_acc\n\n# Load the best model\nif os.path.isfile(os.path.join(asset_path, ckpt_path, ckpt_file_name % last_best_epoch)):\n    checkpoint = torch.load(os.path.join(asset_path, ckpt_path, ckpt_file_name % last_best_epoch))\n    crf.load_state_dict(checkpoint['model'])\n    logger.info(\"restored the best model from epoch %d\" % last_best_epoch)\nelse:\n    raise NotImplementedError\n\n# score Validation\nif args.voca:\n    score_metrics = ['root', 'thirds', 'triads', 'sevenths', 'tetrads', 'majmin', 'mirex']\n    score_list_dict1, song_length_list1, average_score_dict1 = large_voca_score_calculation_crf(valid_dataset=valid_dataset1, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device)\n    score_list_dict2, song_length_list2, average_score_dict2 = large_voca_score_calculation_crf(valid_dataset=valid_dataset2, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device)\n    score_list_dict3, song_length_list3, average_score_dict3 = large_voca_score_calculation_crf(valid_dataset=valid_dataset3, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device)\n    for m in score_metrics:\n        # average over the three datasets, weighted by total audio length\n        average_score = (np.sum(song_length_list1) * average_score_dict1[m] + np.sum(song_length_list2) * average_score_dict2[m] + np.sum(song_length_list3) * average_score_dict3[m]) / (np.sum(song_length_list1) + np.sum(song_length_list2) + np.sum(song_length_list3))\n        logger.info('==== %s score 1 is %.4f' % (m, average_score_dict1[m]))\n        logger.info('==== %s score 2 is %.4f' % (m, average_score_dict2[m]))\n        logger.info('==== %s score 3 is %.4f' % (m, average_score_dict3[m]))\n        logger.info('==== %s mix average score is %.4f' % (m, average_score))\nelse:\n    score_metrics = ['root', 'majmin']\n    score_list_dict1, song_length_list1, average_score_dict1 = root_majmin_score_calculation_crf(valid_dataset=valid_dataset1, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device)\n    score_list_dict2, song_length_list2, average_score_dict2 = root_majmin_score_calculation_crf(valid_dataset=valid_dataset2, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device)\n    score_list_dict3, song_length_list3, average_score_dict3 = 
root_majmin_score_calculation_crf(valid_dataset=valid_dataset3, config=config, pre_model=pre_model, model=crf, model_type=args.pre_model, mean=mean, std=std, device=device)\n    for m in score_metrics:\n        # average over the three datasets, weighted by total audio length\n        average_score = (np.sum(song_length_list1) * average_score_dict1[m] + np.sum(song_length_list2) * average_score_dict2[m] + np.sum(song_length_list3) * average_score_dict3[m]) / (np.sum(song_length_list1) + np.sum(song_length_list2) + np.sum(song_length_list3))\n        logger.info('==== %s score 1 is %.4f' % (m, average_score_dict1[m]))\n        logger.info('==== %s score 2 is %.4f' % (m, average_score_dict2[m]))\n        logger.info('==== %s score 3 is %.4f' % (m, average_score_dict3[m]))\n        logger.info('==== %s mix average score is %.4f' % (m, average_score))\n"
  },
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/chords.py",
    "content": "# encoding: utf-8\n\"\"\"\nThis module contains chord evaluation functionality.\n\nIt provides the evaluation measures used for the MIREX ACE task, and\ntries to follow [1]_ and [2]_ as closely as possible.\n\nNotes\n-----\nThis implementation tries to follow the references and their implementation\n(e.g., https://github.com/jpauwels/MusOOEvaluator for [2]_). However, there\nare some known (and possibly some unknown) differences. If you find one not\nlisted in the following, please file an issue:\n\n - Detected chord segments are adjusted to fit the length of the annotations.\n   In particular, this means that, if necessary, filler segments of 'no chord'\n   are added at beginnings and ends. This can result in different segmentation\n   scores compared to the original implementation.\n\nReferences\n----------\n.. [1] Christopher Harte, \"Towards Automatic Extraction of Harmony Information\n       from Music Signals.\" Dissertation,\n       Department for Electronic Engineering, Queen Mary University of London,\n       2010.\n.. [2] Johan Pauwels and Geoffroy Peeters.\n       \"Evaluating Automatically Estimated Chord Sequences.\"\n       In Proceedings of ICASSP 2013, Vancouver, Canada, 2013.\n\n\"\"\"\n\nimport numpy as np\nimport pandas as pd\nimport mir_eval\n\n\nCHORD_DTYPE = [('root', np.int),\n               ('bass', np.int),\n               ('intervals', np.int, (12,)),\n               ('is_major',np.bool)]\n\nCHORD_ANN_DTYPE = [('start', np.float),\n                   ('end', np.float),\n                   ('chord', CHORD_DTYPE)]\n\nNO_CHORD = (-1, -1, np.zeros(12, dtype=np.int), False)\nUNKNOWN_CHORD = (-1, -1, np.ones(12, dtype=np.int) * -1, False)\n\nPITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']\n\n\ndef idx_to_chord(idx):\n    if idx == 24:\n        return \"-\"\n    elif idx == 25:\n        return u\"\\u03B5\"\n\n    minmaj = idx % 2\n    root = idx // 2\n\n    return PITCH_CLASS[root] + (\"M\" if minmaj == 0 else \"m\")\n\nclass Chords:\n\n    def __init__(self):\n        self._shorthands = {\n            'maj': self.interval_list('(1,3,5)'),\n            'min': self.interval_list('(1,b3,5)'),\n            'dim': self.interval_list('(1,b3,b5)'),\n            'aug': self.interval_list('(1,3,#5)'),\n            'maj7': self.interval_list('(1,3,5,7)'),\n            'min7': self.interval_list('(1,b3,5,b7)'),\n            '7': self.interval_list('(1,3,5,b7)'),\n            '6': self.interval_list('(1,6)'),  # custom\n            '5': self.interval_list('(1,5)'),\n            '4': self.interval_list('(1,4)'),  # custom\n            '1': self.interval_list('(1)'),\n            'dim7': self.interval_list('(1,b3,b5,bb7)'),\n            'hdim7': self.interval_list('(1,b3,b5,b7)'),\n            'minmaj7': self.interval_list('(1,b3,5,7)'),\n            'maj6': self.interval_list('(1,3,5,6)'),\n            'min6': self.interval_list('(1,b3,5,6)'),\n            '9': self.interval_list('(1,3,5,b7,9)'),\n            'maj9': self.interval_list('(1,3,5,7,9)'),\n            'min9': self.interval_list('(1,b3,5,b7,9)'),\n            'sus2': self.interval_list('(1,2,5)'),\n            'sus4': self.interval_list('(1,4,5)'),\n            '11': self.interval_list('(1,3,5,b7,9,11)'),\n            'min11': self.interval_list('(1,b3,5,b7,9,11)'),\n            '13': self.interval_list('(1,3,5,b7,13)'),\n            'maj13': self.interval_list('(1,3,5,7,13)'),\n            'min13': self.interval_list('(1,b3,5,b7,13)')\n        }\n\n    def chords(self, 
labels):\n        \"\"\"\n        Transform a list of chord labels into an array of internal numeric\n        representations.\n\n        Parameters\n        ----------\n        labels : list\n            List of chord labels (str).\n\n        Returns\n        -------\n        chords : numpy.array\n            Structured array with columns 'root', 'bass', and 'intervals',\n            containing a numeric representation of chords.\n\n        \"\"\"\n        crds = np.zeros(len(labels), dtype=CHORD_DTYPE)\n        cache = {}\n        for i, lbl in enumerate(labels):\n            cv = cache.get(lbl, None)\n            if cv is None:\n                cv = self.chord(lbl)\n                cache[lbl] = cv\n            crds[i] = cv\n\n        return crds\n\n    def label_error_modify(self, label):\n        if label == 'Emin/4': label = 'E:min/4'\n        elif label == 'A7/3': label = 'A:7/3'\n        elif label == 'Bb7/3': label = 'Bb:7/3'\n        elif label == 'Bb7/5': label = 'Bb:7/5'\n        elif label.find(':') == -1:\n            if label.find('min') != -1:\n                label = label[:label.find('min')] + ':' + label[label.find('min'):]\n        return label\n\n    def chord(self, label):\n        \"\"\"\n        Transform a chord label into the internal numeric representation of\n        (root, bass, intervals array).\n\n        Parameters\n        ----------\n        label : str\n            Chord label.\n\n        Returns\n        -------\n        chord : tuple\n            Numeric representation of the chord: (root, bass, intervals array).\n\n        \"\"\"\n\n        try:\n            if label == 'N':\n                return NO_CHORD\n            if label == 'X':\n                return UNKNOWN_CHORD\n\n            label = self.label_error_modify(label)\n\n            c_idx = label.find(':')\n            s_idx = label.find('/')\n\n            if c_idx == -1:\n                quality_str = 'maj'\n                if s_idx == -1:\n                    root_str = label\n                    bass_str = ''\n                else:\n                    root_str = label[:s_idx]\n                    bass_str = label[s_idx + 1:]\n            else:\n                root_str = label[:c_idx]\n                if s_idx == -1:\n                    quality_str = label[c_idx + 1:]\n                    bass_str = ''\n                else:\n                    quality_str = label[c_idx + 1:s_idx]\n                    bass_str = label[s_idx + 1:]\n\n            root = self.pitch(root_str)\n            bass = self.interval(bass_str) if bass_str else 0\n            ivs = self.chord_intervals(quality_str)\n            ivs[bass] = 1\n\n            # any quality that is not explicitly minor is treated as major here\n            is_major = 'min' not in quality_str\n\n        except Exception as e:\n            print(e, label)\n            # re-raise: silently continuing would return undefined values below\n            raise\n\n        return root, bass, ivs, is_major\n\n    _l = [0, 1, 1, 0, 1, 1, 1]\n    _chroma_id = (np.arange(len(_l) * 2) + 1) + np.array(_l + _l).cumsum() - 1\n\n    def modify(self, base_pitch, modifier):\n        \"\"\"\n        Modify a pitch class in integer representation by a given modifier string.\n\n        A modifier string can be any sequence of 'b' (one semitone down)\n        and '#' (one semitone up).\n\n        Parameters\n        ----------\n        base_pitch : int\n            Pitch class as integer.\n        modifier : str\n            String of modifiers ('b' or '#').\n\n        Returns\n        -------\n        modified_pitch : int\n           
 Modified root note.\n\n        \"\"\"\n        for m in modifier:\n            if m == 'b':\n                base_pitch -= 1\n            elif m == '#':\n                base_pitch += 1\n            else:\n                raise ValueError('Unknown modifier: {}'.format(m))\n        return base_pitch\n\n    def pitch(self, pitch_str):\n        \"\"\"\n        Convert a string representation of a pitch class (consisting of root\n        note and modifiers) to an integer representation.\n\n        Parameters\n        ----------\n        pitch_str : str\n            String representation of a pitch class.\n\n        Returns\n        -------\n        pitch : int\n            Integer representation of a pitch class.\n\n        \"\"\"\n        return self.modify(self._chroma_id[(ord(pitch_str[0]) - ord('C')) % 7],\n                      pitch_str[1:]) % 12\n\n    def interval(self, interval_str):\n        \"\"\"\n        Convert a string representation of a musical interval into a pitch class\n        (e.g. a minor seventh 'b7' into 10, because it is 10 semitones above its\n        base note).\n\n        Parameters\n        ----------\n        interval_str : str\n            Musical interval.\n\n        Returns\n        -------\n        pitch_class : int\n            Number of semitones to base note of interval.\n\n        \"\"\"\n        for i, c in enumerate(interval_str):\n            if c.isdigit():\n                return self.modify(self._chroma_id[int(interval_str[i:]) - 1],\n                              interval_str[:i]) % 12\n\n    def interval_list(self, intervals_str, given_pitch_classes=None):\n        \"\"\"\n        Convert a list of intervals given as string to a binary pitch class\n        representation. For example, 'b3, 5' would become\n        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0].\n\n        Parameters\n        ----------\n        intervals_str : str\n            List of intervals as comma-separated string (e.g. 'b3, 5').\n        given_pitch_classes : None or numpy array\n            If None, start with empty pitch class array, if numpy array of length\n            12, this array will be modified.\n\n        Returns\n        -------\n        pitch_classes : numpy array\n            Binary pitch class representation of intervals.\n\n        \"\"\"\n        if given_pitch_classes is None:\n            given_pitch_classes = np.zeros(12, dtype=np.int)\n        for int_def in intervals_str[1:-1].split(','):\n            int_def = int_def.strip()\n            if int_def[0] == '*':\n                given_pitch_classes[self.interval(int_def[1:])] = 0\n            else:\n                given_pitch_classes[self.interval(int_def)] = 1\n        return given_pitch_classes\n\n    # mapping of shorthand interval notations to the actual interval representation\n\n    def chord_intervals(self, quality_str):\n        \"\"\"\n        Convert a chord quality string to a pitch class representation. 
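A quality may be a shorthand name (e.g. 'maj7'), an explicit interval list such as '(1,3,5,7)', or a shorthand followed by a parenthesised modification list (a leading '*' removes a note). 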
For\n        example, 'maj' becomes [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0].\n\n        Parameters\n        ----------\n        quality_str : str\n            String defining the chord quality.\n\n        Returns\n        -------\n        pitch_classes : numpy array\n            Binary pitch class representation of chord quality.\n\n        \"\"\"\n        list_idx = quality_str.find('(')\n        if list_idx == -1:\n            return self._shorthands[quality_str].copy()\n        if list_idx != 0:\n            ivs = self._shorthands[quality_str[:list_idx]].copy()\n        else:\n            ivs = np.zeros(12, dtype=np.int)\n\n        return self.interval_list(quality_str[list_idx:], ivs)\n\n    def load_chords(self, filename):\n        \"\"\"\n        Load chords from a text file.\n\n        The chord labels must follow the syntax defined in [1]_.\n\n        Parameters\n        ----------\n        filename : str\n            File containing chord segments.\n\n        Returns\n        -------\n        crds : numpy structured array\n            Structured array with columns \"start\", \"end\", and \"chord\",\n            containing the beginning, end, and chord definition of chord\n            segments.\n\n        References\n        ----------\n        .. [1] Christopher Harte, \"Towards Automatic Extraction of Harmony\n               Information from Music Signals.\" Dissertation,\n               Department for Electronic Engineering, Queen Mary University of\n               London, 2010.\n\n        \"\"\"\n        start, end, chord_labels = [], [], []\n        with open(filename, 'r') as f:\n            for line in f:\n                if line:\n                    splits = line.split()\n                    if len(splits) == 3:\n                        s, e, label = splits\n                        start.append(float(s))\n                        end.append(float(e))\n                        chord_labels.append(label)\n\n        crds = np.zeros(len(start), dtype=CHORD_ANN_DTYPE)\n        crds['start'] = start\n        crds['end'] = end\n        crds['chord'] = self.chords(chord_labels)\n\n        return crds\n\n    def reduce_to_triads(self, chords, keep_bass=False):\n        \"\"\"\n        Reduce chords to triads.\n\n        The function follows the reduction rules implemented in [1]_. If a chord\n        does not contain a third, major second or fourth, it is reduced to a\n        power chord. If it contains neither a third nor a fifth, it is reduced\n        to a single-note \"chord\".\n\n        Parameters\n        ----------\n        chords : numpy structured array\n            Chords to be reduced.\n        keep_bass : bool\n            Indicates whether to keep the bass note or set it to 0.\n\n        Returns\n        -------\n        reduced_chords : numpy structured array\n            Chords reduced to triads.\n\n        References\n        ----------\n        .. 
[1] Johan Pauwels and Geoffroy Peeters.\n               \"Evaluating Automatically Estimated Chord Sequences.\"\n               In Proceedings of ICASSP 2013, Vancouver, Canada, 2013.\n\n        \"\"\"\n        unison = chords['intervals'][:, 0].astype(bool)\n        maj_sec = chords['intervals'][:, 2].astype(bool)\n        min_third = chords['intervals'][:, 3].astype(bool)\n        maj_third = chords['intervals'][:, 4].astype(bool)\n        perf_fourth = chords['intervals'][:, 5].astype(bool)\n        dim_fifth = chords['intervals'][:, 6].astype(bool)\n        perf_fifth = chords['intervals'][:, 7].astype(bool)\n        aug_fifth = chords['intervals'][:, 8].astype(bool)\n        no_chord = (chords['intervals'] == NO_CHORD[-1]).all(axis=1)\n\n        reduced_chords = chords.copy()\n        ivs = reduced_chords['intervals']\n\n        ivs[~no_chord] = self.interval_list('(1)')\n        ivs[unison & perf_fifth] = self.interval_list('(1,5)')\n        ivs[~perf_fourth & maj_sec] = self._shorthands['sus2']\n        ivs[perf_fourth & ~maj_sec] = self._shorthands['sus4']\n\n        ivs[min_third] = self._shorthands['min']\n        ivs[min_third & aug_fifth & ~perf_fifth] = self.interval_list('(1,b3,#5)')\n        ivs[min_third & dim_fifth & ~perf_fifth] = self._shorthands['dim']\n\n        ivs[maj_third] = self._shorthands['maj']\n        ivs[maj_third & dim_fifth & ~perf_fifth] = self.interval_list('(1,3,b5)')\n        ivs[maj_third & aug_fifth & ~perf_fifth] = self._shorthands['aug']\n\n        if not keep_bass:\n            reduced_chords['bass'] = 0\n        else:\n            # remove bass notes if they are not part of the intervals anymore\n            reduced_chords['bass'] *= ivs[range(len(reduced_chords)),\n                                          reduced_chords['bass']]\n        # keep -1 in bass for no chords\n        reduced_chords['bass'][no_chord] = -1\n\n        return reduced_chords\n\n    def convert_to_id(self, root, is_major):\n        if root == -1:\n            return 24\n        else:\n            if is_major:\n                return root * 2\n            else:\n                return root * 2 + 1\n\n    def get_converted_chord(self, filename):\n        loaded_chord = self.load_chords(filename)\n        triads = self.reduce_to_triads(loaded_chord['chord'])\n\n        df = self.assign_chord_id(triads)\n        df['start'] = loaded_chord['start']\n        df['end'] = loaded_chord['end']\n\n        return df\n\n    def assign_chord_id(self, entry):\n        # maj, min chord only\n        # if you want to add other chord, change this part and get_converted_chord(reduce_to_triads)\n        df = pd.DataFrame(data=entry[['root', 'is_major']])\n        df['chord_id'] = df.apply(lambda row: self.convert_to_id(row['root'], row['is_major']), axis=1)\n        return df\n\n    def convert_to_id_voca(self, root, quality):\n        if root == -1:\n            return 169\n        else:\n            if quality == 'min':\n                return root * 14\n            elif quality == 'maj':\n                return root * 14 + 1\n            elif quality == 'dim':\n                return root * 14 + 2\n            elif quality == 'aug':\n                return root * 14 + 3\n            elif quality == 'min6':\n                return root * 14 + 4\n            elif quality == 'maj6':\n                return root * 14 + 5\n            elif quality == 'min7':\n                return root * 14 + 6\n            elif quality == 'minmaj7':\n                return root * 14 + 7\n            elif 
quality == 'maj7':\n                return root * 14 + 8\n            elif quality == '7':\n                return root * 14 + 9\n            elif quality == 'dim7':\n                return root * 14 + 10\n            elif quality == 'hdim7':\n                return root * 14 + 11\n            elif quality == 'sus2':\n                return root * 14 + 12\n            elif quality == 'sus4':\n                return root * 14 + 13\n            else:\n                return 168\n\n    def get_converted_chord_voca(self, filename):\n        loaded_chord = self.load_chords(filename)\n        triads = self.reduce_to_triads(loaded_chord['chord'])\n        df = pd.DataFrame(data=triads[['root', 'is_major']])\n\n        (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(filename)\n        ref_labels = self.lab_file_error_modify(ref_labels)\n        idxs = list()\n        for i in ref_labels:\n            chord_root, quality, scale_degrees, bass = mir_eval.chord.split(i, reduce_extended_chords=True)\n            root, bass, ivs, is_major = self.chord(i)\n            idxs.append(self.convert_to_id_voca(root=root, quality=quality))\n        df['chord_id'] = idxs\n\n        df['start'] = loaded_chord['start']\n        df['end'] = loaded_chord['end']\n\n        return df\n\n    def lab_file_error_modify(self, ref_labels):\n        for i in range(len(ref_labels)):\n            if ref_labels[i][-2:] == ':4':\n                ref_labels[i] = ref_labels[i].replace(':4', ':sus4')\n            elif ref_labels[i][-2:] == ':6':\n                ref_labels[i] = ref_labels[i].replace(':6', ':maj6')\n            elif ref_labels[i][-4:] == ':6/2':\n                ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2')\n            elif ref_labels[i] == 'Emin/4':\n                ref_labels[i] = 'E:min/4'\n            elif ref_labels[i] == 'A7/3':\n                ref_labels[i] = 'A:7/3'\n            elif ref_labels[i] == 'Bb7/3':\n                ref_labels[i] = 'Bb:7/3'\n            elif ref_labels[i] == 'Bb7/5':\n                ref_labels[i] = 'Bb:7/5'\n            elif ref_labels[i].find(':') == -1:\n                if ref_labels[i].find('min') != -1:\n                    ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):]\n        return ref_labels\n\n"
  },
  {
    "path": "utils/hparams.py",
    "content": "import yaml\n\n\n# TODO: add function should be changed\nclass HParams(object):\n    # Hyperparameter class using yaml\n    def __init__(self, **kwargs):\n        self.__dict__ = kwargs\n\n    def add(self, **kwargs):\n        # change is needed - if key is existed, do not update.\n        self.__dict__.update(kwargs)\n\n    def update(self, **kwargs):\n        self.__dict__.update(kwargs)\n        return self\n\n    def save(self, path):\n        with open(path, 'w') as f:\n            yaml.dump(self.__dict__, f)\n        return self\n\n    def __repr__(self):\n        return '\\nHyperparameters:\\n' + '\\n'.join([' {}={}'.format(k, v) for k, v in self.__dict__.items()])\n\n    @classmethod\n    def load(cls, path):\n        with open(path, 'r') as f:\n            return cls(**yaml.load(f))\n\n\nif __name__ == '__main__':\n    hparams = HParams.load('hparams.yaml')\n    print(hparams)\n    d = {\"MemoryNetwork\": 0, \"c\": 1}\n    hparams.add(**d)\n    print(hparams)\n"
  },
  {
    "path": "utils/logger.py",
    "content": "import logging\nimport os\nimport sys\nimport time\n\n\nproject_name = os.getcwd().split('/')[-1]\n_logger = logging.getLogger(project_name)\n_logger.addHandler(logging.StreamHandler())\n\ndef _log_prefix():\n\n    # Returns (filename, line number) for the stack frame.\n    def _get_file_line():\n\n        # pylint: disable=protected-access\n        # noinspection PyProtectedMember\n        f = sys._getframe()\n        # pylint: enable=protected-access\n        our_file = f.f_code.co_filename\n        f = f.f_back\n        while f:\n            code = f.f_code\n            if code.co_filename != our_file:\n                return code.co_filename, f.f_lineno\n            f = f.f_back\n        return '<unknown>', 0\n\n    # current time\n    now = time.time()\n    now_tuple = time.localtime(now)\n    now_millisecond = int(1e3 * (now % 1.0))\n\n    # current filename and line\n    filename, line = _get_file_line()\n    basename = os.path.basename(filename)\n\n    s = '%02d-%02d %02d:%02d:%02d.%03d %s:%d] ' % (\n        now_tuple[1],  # month\n        now_tuple[2],  # day\n        now_tuple[3],  # hour\n        now_tuple[4],  # min\n        now_tuple[5],  # sec\n        now_millisecond,\n        basename,\n        line)\n\n    return s\n\n\ndef logging_verbosity(verbosity=0):\n    _logger.setLevel(verbosity)\n\n\ndef debug(msg, *args, **kwargs):\n    _logger.debug('D ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)\n\n\ndef info(msg, *args, **kwargs):\n    _logger.info('I ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)\n\n\ndef warn(msg, *args, **kwargs):\n    _logger.warning('W ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)\n\n\ndef error(msg, *args, **kwargs):\n    _logger.error('E ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)\n\n\ndef fatal(msg, *args, **kwargs):\n    _logger.fatal('F ' + project_name + ' ' + _log_prefix() + msg, *args, **kwargs)\n"
  },
  {
    "path": "utils/mir_eval_modules.py",
    "content": "import numpy as np\nimport librosa\nimport mir_eval\nimport torch\nimport os\n\nidx2chord = ['C', 'C:min', 'C#', 'C#:min', 'D', 'D:min', 'D#', 'D#:min', 'E', 'E:min', 'F', 'F:min', 'F#',\n             'F#:min', 'G', 'G:min', 'G#', 'G#:min', 'A', 'A:min', 'A#', 'A#:min', 'B', 'B:min', 'N']\n\nroot_list = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']\nquality_list = ['min', 'maj', 'dim', 'aug', 'min6', 'maj6', 'min7', 'minmaj7', 'maj7', '7', 'dim7', 'hdim7', 'sus2', 'sus4']\n\ndef idx2voca_chord():\n    idx2voca_chord = {}\n    idx2voca_chord[169] = 'N'\n    idx2voca_chord[168] = 'X'\n    for i in range(168):\n        root = i // 14\n        root = root_list[root]\n        quality = i % 14\n        quality = quality_list[quality]\n        if i % 14 != 1:\n            chord = root + ':' + quality\n        else:\n            chord = root\n        idx2voca_chord[i] = chord\n    return idx2voca_chord\n\ndef audio_file_to_features(audio_file, config):\n    original_wav, sr = librosa.load(audio_file, sr=config.mp3['song_hz'], mono=True)\n    currunt_sec_hz = 0\n    while len(original_wav) > currunt_sec_hz + config.mp3['song_hz'] * config.mp3['inst_len']:\n        start_idx = int(currunt_sec_hz)\n        end_idx = int(currunt_sec_hz + config.mp3['song_hz'] * config.mp3['inst_len'])\n        tmp = librosa.cqt(original_wav[start_idx:end_idx], sr=sr, n_bins=config.feature['n_bins'], bins_per_octave=config.feature['bins_per_octave'], hop_length=config.feature['hop_length'])\n        if start_idx == 0:\n            feature = tmp\n        else:\n            feature = np.concatenate((feature, tmp), axis=1)\n        currunt_sec_hz = end_idx\n    tmp = librosa.cqt(original_wav[currunt_sec_hz:], sr=sr, n_bins=config.feature['n_bins'], bins_per_octave=config.feature['bins_per_octave'], hop_length=config.feature['hop_length'])\n    feature = np.concatenate((feature, tmp), axis=1)\n    feature = np.log(np.abs(feature) + 1e-6)\n    feature_per_second = config.mp3['inst_len'] / config.model['timestep']\n    song_length_second = len(original_wav)/config.mp3['song_hz']\n    return feature, feature_per_second, song_length_second\n\n# Audio files with format of wav and mp3\ndef get_audio_paths(audio_dir):\n    return [os.path.join(root, fname) for (root, dir_names, file_names) in os.walk(audio_dir, followlinks=True)\n            for fname in file_names if (fname.lower().endswith('.wav') or fname.lower().endswith('.mp3'))]\n\nclass metrics():\n    def __init__(self):\n        super(metrics, self).__init__()\n        self.score_metrics = ['root', 'thirds', 'triads', 'sevenths', 'tetrads', 'majmin', 'mirex']\n        self.score_list_dict = dict()\n        for i in self.score_metrics:\n            self.score_list_dict[i] = list()\n        self.average_score = dict()\n\n    def score(self, metric, gt_path, est_path):\n        if metric == 'root':\n            score = self.root_score(gt_path,est_path)\n        elif metric == 'thirds':\n            score = self.thirds_score(gt_path,est_path)\n        elif metric == 'triads':\n            score = self.triads_score(gt_path,est_path)\n        elif metric == 'sevenths':\n            score = self.sevenths_score(gt_path,est_path)\n        elif metric == 'tetrads':\n            score = self.tetrads_score(gt_path,est_path)\n        elif metric == 'majmin':\n            score = self.majmin_score(gt_path,est_path)\n        elif metric == 'mirex':\n            score = self.mirex_score(gt_path,est_path)\n        else:\n            raise 
NotImplementedError\n        return score\n\n    def root_score(self, gt_path, est_path):\n        (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)\n        ref_labels = lab_file_error_modify(ref_labels)\n        (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)\n        est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),\n                                                                   ref_intervals.max(), mir_eval.chord.NO_CHORD,\n                                                                   mir_eval.chord.NO_CHORD)\n        (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,\n                                                                                    est_intervals, est_labels)\n        durations = mir_eval.util.intervals_to_durations(intervals)\n        comparisons = mir_eval.chord.root(ref_labels, est_labels)\n        score = mir_eval.chord.weighted_accuracy(comparisons, durations)\n        return score\n\n    def thirds_score(self, gt_path, est_path):\n        (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)\n        ref_labels = lab_file_error_modify(ref_labels)\n        (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)\n        est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),\n                                                                   ref_intervals.max(), mir_eval.chord.NO_CHORD,\n                                                                   mir_eval.chord.NO_CHORD)\n        (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,\n                                                                                    est_intervals, est_labels)\n        durations = mir_eval.util.intervals_to_durations(intervals)\n        comparisons = mir_eval.chord.thirds(ref_labels, est_labels)\n        score = mir_eval.chord.weighted_accuracy(comparisons, durations)\n        return score\n\n    def triads_score(self, gt_path, est_path):\n        (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)\n        ref_labels = lab_file_error_modify(ref_labels)\n        (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)\n        est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),\n                                                                   ref_intervals.max(), mir_eval.chord.NO_CHORD,\n                                                                   mir_eval.chord.NO_CHORD)\n        (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,\n                                                                                    est_intervals, est_labels)\n        durations = mir_eval.util.intervals_to_durations(intervals)\n        comparisons = mir_eval.chord.triads(ref_labels, est_labels)\n        score = mir_eval.chord.weighted_accuracy(comparisons, durations)\n        return score\n\n    def sevenths_score(self, gt_path, est_path):\n        (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)\n        ref_labels = lab_file_error_modify(ref_labels)\n        (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)\n        est_intervals, est_labels = 
mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),\n                                                                   ref_intervals.max(), mir_eval.chord.NO_CHORD,\n                                                                   mir_eval.chord.NO_CHORD)\n        (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,\n                                                                                    est_intervals, est_labels)\n        durations = mir_eval.util.intervals_to_durations(intervals)\n        comparisons = mir_eval.chord.sevenths(ref_labels, est_labels)\n        score = mir_eval.chord.weighted_accuracy(comparisons, durations)\n        return score\n\n    def tetrads_score(self, gt_path, est_path):\n        (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)\n        ref_labels = lab_file_error_modify(ref_labels)\n        (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)\n        est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),\n                                                                   ref_intervals.max(), mir_eval.chord.NO_CHORD,\n                                                                   mir_eval.chord.NO_CHORD)\n        (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,\n                                                                                    est_intervals, est_labels)\n        durations = mir_eval.util.intervals_to_durations(intervals)\n        comparisons = mir_eval.chord.tetrads(ref_labels, est_labels)\n        score = mir_eval.chord.weighted_accuracy(comparisons, durations)\n        return score\n\n    def majmin_score(self, gt_path, est_path):\n        (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)\n        ref_labels = lab_file_error_modify(ref_labels)\n        (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)\n        est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),\n                                                                   ref_intervals.max(), mir_eval.chord.NO_CHORD,\n                                                                   mir_eval.chord.NO_CHORD)\n        (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, ref_labels,\n                                                                                    est_intervals, est_labels)\n        durations = mir_eval.util.intervals_to_durations(intervals)\n        comparisons = mir_eval.chord.majmin(ref_labels, est_labels)\n        score = mir_eval.chord.weighted_accuracy(comparisons, durations)\n        return score\n\n    def mirex_score(self, gt_path, est_path):\n        (ref_intervals, ref_labels) = mir_eval.io.load_labeled_intervals(gt_path)\n        ref_labels = lab_file_error_modify(ref_labels)\n        (est_intervals, est_labels) = mir_eval.io.load_labeled_intervals(est_path)\n        est_intervals, est_labels = mir_eval.util.adjust_intervals(est_intervals, est_labels, ref_intervals.min(),\n                                                                   ref_intervals.max(), mir_eval.chord.NO_CHORD,\n                                                                   mir_eval.chord.NO_CHORD)\n        (intervals, ref_labels, est_labels) = mir_eval.util.merge_labeled_intervals(ref_intervals, 
ref_labels,\n                                                                                    est_intervals, est_labels)\n        durations = mir_eval.util.intervals_to_durations(intervals)\n        comparisons = mir_eval.chord.mirex(ref_labels, est_labels)\n        score = mir_eval.chord.weighted_accuracy(comparisons, durations)\n        return score\n\ndef lab_file_error_modify(ref_labels):\n    for i in range(len(ref_labels)):\n        if ref_labels[i][-2:] == ':4':\n            ref_labels[i] = ref_labels[i].replace(':4', ':sus4')\n        elif ref_labels[i][-2:] == ':6':\n            ref_labels[i] = ref_labels[i].replace(':6', ':maj6')\n        elif ref_labels[i][-4:] == ':6/2':\n            ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2')\n        elif ref_labels[i] == 'Emin/4':\n            ref_labels[i] = 'E:min/4'\n        elif ref_labels[i] == 'A7/3':\n            ref_labels[i] = 'A:7/3'\n        elif ref_labels[i] == 'Bb7/3':\n            ref_labels[i] = 'Bb:7/3'\n        elif ref_labels[i] == 'Bb7/5':\n            ref_labels[i] = 'Bb:7/5'\n        elif ref_labels[i].find(':') == -1:\n            if ref_labels[i].find('min') != -1:\n                ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):]\n    return ref_labels\n\ndef root_majmin_score_calculation(valid_dataset, config, mean, std, device, model, model_type, verbose=False):\n    valid_song_names = valid_dataset.song_names\n    paths = valid_dataset.preprocessor.get_all_files()\n\n    metrics_ = metrics()\n    song_length_list = list()\n    for path in paths:\n        song_name, lab_file_path, mp3_file_path, _ = path\n        if not song_name in valid_song_names:\n            continue\n        try:\n            n_timestep = config.model['timestep']\n            feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)\n            feature = feature.T\n            feature = (feature - mean) / std\n            time_unit = feature_per_second\n\n            num_pad = n_timestep - (feature.shape[0] % n_timestep)\n            feature = np.pad(feature, ((0, num_pad), (0, 0)), mode=\"constant\", constant_values=0)\n            num_instance = feature.shape[0] // n_timestep\n\n            start_time = 0.0\n            lines = []\n            with torch.no_grad():\n                model.eval()\n                feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)\n                for t in range(num_instance):\n                    if model_type == 'btc':\n                        encoder_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])\n                        prediction, _ = model.output_layer(encoder_output)\n                        prediction = prediction.squeeze()\n                    elif model_type == 'cnn' or model_type =='crnn':\n                        prediction, _, _, _ = model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))\n                    for i in range(n_timestep):\n                        if t == 0 and i == 0:\n                            prev_chord = prediction[i].item()\n                            continue\n                        if prediction[i].item() != prev_chord:\n                            lines.append(\n                                '%.6f %.6f %s\\n' % (\n                                    start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))\n      
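# a chord change closes the previous segment; start a new one from this frame\n      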
                      start_time = time_unit * (n_timestep * t + i)\n                            prev_chord = prediction[i].item()\n                        if t == num_instance - 1 and i + num_pad == n_timestep:\n                            if start_time != time_unit * (n_timestep * t + i):\n                                lines.append(\n                                    '%.6f %.6f %s\\n' % (\n                                        start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))\n                            break\n            pid = os.getpid()\n            tmp_path = 'tmp_' + str(pid) + '.lab'\n            with open(tmp_path, 'w') as f:\n                for line in lines:\n                    f.write(line)\n\n            root_majmin = ['root', 'majmin']\n            for m in root_majmin:\n                metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))\n            song_length_list.append(song_length_second)\n            if verbose:\n                for m in root_majmin:\n                    print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))\n        except:\n            print('song name %s\\' lab file error' % song_name)\n\n    tmp = song_length_list / np.sum(song_length_list)\n    for m in root_majmin:\n        metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))\n\n    return metrics_.score_list_dict, song_length_list, metrics_.average_score\n\ndef root_majmin_score_calculation_crf(valid_dataset, config, mean, std, device, pre_model, model, model_type, verbose=False):\n    valid_song_names = valid_dataset.song_names\n    paths = valid_dataset.preprocessor.get_all_files()\n\n    metrics_ = metrics()\n    song_length_list = list()\n    for path in paths:\n        song_name, lab_file_path, mp3_file_path, _ = path\n        if not song_name in valid_song_names:\n            continue\n        try:\n            n_timestep = config.model['timestep']\n            feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)\n            feature = feature.T\n            feature = (feature - mean) / std\n            time_unit = feature_per_second\n\n            num_pad = n_timestep - (feature.shape[0] % n_timestep)\n            feature = np.pad(feature, ((0, num_pad), (0, 0)), mode=\"constant\", constant_values=0)\n            num_instance = feature.shape[0] // n_timestep\n\n            start_time = 0.0\n            lines = []\n            with torch.no_grad():\n                model.eval()\n                feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)\n                for t in range(num_instance):\n                    if (model_type == 'cnn') or (model_type == 'crnn') or (model_type == 'btc'):\n                        logits = pre_model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))\n                        prediction, _ = model(logits, torch.randint(config.model['num_chords'], (n_timestep,)).to(device))\n                    else:\n                        raise NotImplementedError\n                    for i in range(n_timestep):\n                        if t == 0 and i == 0:\n                            prev_chord = prediction[i].item()\n                            continue\n                        if prediction[i].item() != prev_chord:\n                            lines.append(\n                                '%.6f 
%.6f %s\\n' % (\n                                    start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))\n                            start_time = time_unit * (n_timestep * t + i)\n                            prev_chord = prediction[i].item()\n                        if t == num_instance - 1 and i + num_pad == n_timestep:\n                            if start_time != time_unit * (n_timestep * t + i):\n                                lines.append(\n                                    '%.6f %.6f %s\\n' % (\n                                        start_time, time_unit * (n_timestep * t + i), idx2chord[prev_chord]))\n                            break\n            pid = os.getpid()\n            tmp_path = 'tmp_' + str(pid) + '.lab'\n            with open(tmp_path, 'w') as f:\n                for line in lines:\n                    f.write(line)\n\n            root_majmin = ['root', 'majmin']\n            for m in root_majmin:\n                metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))\n            song_length_list.append(song_length_second)\n            if verbose:\n                for m in root_majmin:\n                    print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))\n        except:\n            print('song name %s\\' lab file error' % song_name)\n\n    tmp = song_length_list / np.sum(song_length_list)\n    for m in root_majmin:\n        metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))\n\n    return metrics_.score_list_dict, song_length_list, metrics_.average_score\n\n\ndef large_voca_score_calculation(valid_dataset, config, mean, std, device, model, model_type, verbose=False):\n    idx2voca = idx2voca_chord()\n    valid_song_names = valid_dataset.song_names\n    paths = valid_dataset.preprocessor.get_all_files()\n\n    metrics_ = metrics()\n    song_length_list = list()\n    for path in paths:\n        song_name, lab_file_path, mp3_file_path, _ = path\n        if not song_name in valid_song_names:\n            continue\n        try:\n            n_timestep = config.model['timestep']\n            feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)\n            feature = feature.T\n            feature = (feature - mean) / std\n            time_unit = feature_per_second\n\n            num_pad = n_timestep - (feature.shape[0] % n_timestep)\n            feature = np.pad(feature, ((0, num_pad), (0, 0)), mode=\"constant\", constant_values=0)\n            num_instance = feature.shape[0] // n_timestep\n\n            start_time = 0.0\n            lines = []\n            with torch.no_grad():\n                model.eval()\n                feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)\n                for t in range(num_instance):\n                    if model_type == 'btc':\n                        encoder_output, _ = model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])\n                        prediction, _ = model.output_layer(encoder_output)\n                        prediction = prediction.squeeze()\n                    elif model_type == 'cnn' or model_type =='crnn':\n                        prediction, _, _, _ = model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))\n                    for i in range(n_timestep):\n                        if t == 0 and i == 0:\n                  
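# the very first frame only opens the first segment\n                  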
          prev_chord = prediction[i].item()\n                            continue\n                        if prediction[i].item() != prev_chord:\n                            lines.append(\n                                '%.6f %.6f %s\\n' % (\n                                    start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))\n                            start_time = time_unit * (n_timestep * t + i)\n                            prev_chord = prediction[i].item()\n                        if t == num_instance - 1 and i + num_pad == n_timestep:\n                            if start_time != time_unit * (n_timestep * t + i):\n                                lines.append(\n                                    '%.6f %.6f %s\\n' % (\n                                        start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))\n                            break\n            pid = os.getpid()\n            tmp_path = 'tmp_' + str(pid) + '.lab'\n            with open(tmp_path, 'w') as f:\n                for line in lines:\n                    f.write(line)\n\n            for m in metrics_.score_metrics:\n                metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))\n            song_length_list.append(song_length_second)\n            if verbose:\n                for m in metrics_.score_metrics:\n                    print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))\n        except:\n            print('song name %s\\' lab file error' % song_name)\n\n    tmp = song_length_list / np.sum(song_length_list)\n    for m in metrics_.score_metrics:\n        metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))\n\n    return metrics_.score_list_dict, song_length_list, metrics_.average_score\n\ndef large_voca_score_calculation_crf(valid_dataset, config, mean, std, device, pre_model, model, model_type, verbose=False):\n    idx2voca = idx2voca_chord()\n    valid_song_names = valid_dataset.song_names\n    paths = valid_dataset.preprocessor.get_all_files()\n\n    metrics_ = metrics()\n    song_length_list = list()\n    for path in paths:\n        song_name, lab_file_path, mp3_file_path, _ = path\n        if not song_name in valid_song_names:\n            continue\n        try:\n            n_timestep = config.model['timestep']\n            feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)\n            feature = feature.T\n            feature = (feature - mean) / std\n            time_unit = feature_per_second\n\n            num_pad = n_timestep - (feature.shape[0] % n_timestep)\n            feature = np.pad(feature, ((0, num_pad), (0, 0)), mode=\"constant\", constant_values=0)\n            num_instance = feature.shape[0] // n_timestep\n\n            start_time = 0.0\n            lines = []\n            with torch.no_grad():\n                model.eval()\n                feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)\n                for t in range(num_instance):\n                    if (model_type == 'cnn') or (model_type == 'crnn') or (model_type == 'btc'):\n                        logits = pre_model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))\n                        prediction, _ = model(logits, torch.randint(config.model['num_chords'], (n_timestep,)).to(device))\n                    else:\n                  
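# only the cnn / crnn / btc front-ends are supported here\n                  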
def large_voca_score_calculation_crf(valid_dataset, config, mean, std, device, pre_model, model, model_type, verbose=False):\n    idx2voca = idx2voca_chord()\n    valid_song_names = valid_dataset.song_names\n    paths = valid_dataset.preprocessor.get_all_files()\n\n    metrics_ = metrics()\n    song_length_list = list()\n    for path in paths:\n        song_name, lab_file_path, mp3_file_path, _ = path\n        if song_name not in valid_song_names:\n            continue\n        try:\n            n_timestep = config.model['timestep']\n            feature, feature_per_second, song_length_second = audio_file_to_features(mp3_file_path, config)\n            feature = feature.T\n            feature = (feature - mean) / std\n            time_unit = feature_per_second\n\n            num_pad = n_timestep - (feature.shape[0] % n_timestep)\n            feature = np.pad(feature, ((0, num_pad), (0, 0)), mode=\"constant\", constant_values=0)\n            num_instance = feature.shape[0] // n_timestep\n\n            start_time = 0.0\n            lines = []\n            with torch.no_grad():\n                model.eval()\n                feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(device)\n                for t in range(num_instance):\n                    if (model_type == 'cnn') or (model_type == 'crnn') or (model_type == 'btc'):\n                        logits = pre_model(feature[:, n_timestep * t:n_timestep * (t + 1), :], torch.randint(config.model['num_chords'], (n_timestep,)).to(device))\n                        prediction, _ = model(logits, torch.randint(config.model['num_chords'], (n_timestep,)).to(device))\n                    else:\n                        raise NotImplementedError\n                    for i in range(n_timestep):\n                        if t == 0 and i == 0:\n                            prev_chord = prediction[i].item()\n                            continue\n                        if prediction[i].item() != prev_chord:\n                            lines.append('%.6f %.6f %s\\n' % (start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))\n                            start_time = time_unit * (n_timestep * t + i)\n                            prev_chord = prediction[i].item()\n                        if t == num_instance - 1 and i + num_pad == n_timestep:\n                            if start_time != time_unit * (n_timestep * t + i):\n                                lines.append('%.6f %.6f %s\\n' % (start_time, time_unit * (n_timestep * t + i), idx2voca[prev_chord]))\n                            break\n            pid = os.getpid()\n            tmp_path = 'tmp_' + str(pid) + '.lab'\n            with open(tmp_path, 'w') as f:\n                for line in lines:\n                    f.write(line)\n\n            for m in metrics_.score_metrics:\n                metrics_.score_list_dict[m].append(metrics_.score(metric=m, gt_path=lab_file_path, est_path=tmp_path))\n            song_length_list.append(song_length_second)\n            if verbose:\n                for m in metrics_.score_metrics:\n                    print('song name %s, %s score : %.4f' % (song_name, m, metrics_.score_list_dict[m][-1]))\n        except Exception:\n            print('song name %s : lab file error' % song_name)\n\n    tmp = np.array(song_length_list) / np.sum(song_length_list)\n    for m in metrics_.score_metrics:\n        metrics_.average_score[m] = np.sum(np.multiply(metrics_.score_list_dict[m], tmp))\n\n    return metrics_.score_list_dict, song_length_list, metrics_.average_score\n"
  },
  {
    "path": "utils/preprocess.py",
    "content": "import os\nimport librosa\nfrom utils.chords import Chords\nimport re\nfrom enum import Enum\nimport pyrubberband as pyrb\nimport torch\nimport math\n\nclass FeatureTypes(Enum):\n    cqt = 'cqt'\n\nclass Preprocess():\n    def __init__(self, config, feature_to_use, dataset_names, root_dir):\n        self.config = config\n        self.dataset_names = dataset_names\n        self.root_path = root_dir + '/'\n\n        self.time_interval = config.feature[\"hop_length\"]/config.mp3[\"song_hz\"]\n        self.no_of_chord_datapoints_per_sequence = math.ceil(config.mp3['inst_len'] / self.time_interval)\n        self.Chord_class = Chords()\n\n        # isophonic\n        self.isophonic_directory = self.root_path + 'isophonic/'\n\n        # uspop\n        self.uspop_directory = self.root_path + 'uspop/'\n        self.uspop_audio_path = 'audio/'\n        self.uspop_lab_path = 'annotations/uspopLabels/'\n        self.uspop_index_path = 'annotations/uspopLabels.txt'\n\n        # robbie williams\n        self.robbie_williams_directory = self.root_path + 'robbiewilliams/'\n        self.robbie_williams_audio_path = 'audio/'\n        self.robbie_williams_lab_path = 'chords/'\n\n        self.feature_name = feature_to_use\n        self.is_cut_last_chord = False\n\n    def find_mp3_path(self, dirpath, word):\n        for filename in os.listdir(dirpath):\n            last_dir = dirpath.split(\"/\")[-2]\n            if \".mp3\" in filename:\n                tmp = filename.replace(\".mp3\", \"\")\n                tmp = tmp.replace(last_dir, \"\")\n                filename_lower = tmp.lower()\n                filename_lower = \" \".join(re.findall(\"[a-zA-Z]+\", filename_lower))\n                if word.lower().replace(\" \", \"\") in filename_lower.replace(\" \", \"\"):\n                    return filename\n\n    def find_mp3_path_robbiewilliams(self, dirpath, word):\n        for filename in os.listdir(dirpath):\n            if \".mp3\" in filename:\n                tmp = filename.replace(\".mp3\", \"\")\n                filename_lower = tmp.lower()\n                filename_lower = filename_lower.replace(\"robbie williams\", \"\")\n                filename_lower = \" \".join(re.findall(\"[a-zA-Z]+\", filename_lower))\n                filename_lower = self.song_pre(filename_lower)\n                if self.song_pre(word.lower()).replace(\" \", \"\") in filename_lower.replace(\" \", \"\"):\n                    return filename\n\n    def get_all_files(self):\n        res_list = []\n\n        # isophonic\n        if \"isophonic\" in self.dataset_names:\n            for dirpath, dirnames, filenames in os.walk(self.isophonic_directory):\n                if not dirnames:\n                    for filename in filenames:\n                        if \".lab\" in filename:\n                            tmp = filename.replace(\".lab\", \"\")\n                            song_name = \" \".join(re.findall(\"[a-zA-Z]+\", tmp)).replace(\"CD\", \"\")\n                            mp3_path = self.find_mp3_path(dirpath, song_name)\n                            res_list.append([song_name, os.path.join(dirpath, filename), os.path.join(dirpath, mp3_path),\n                                             os.path.join(self.root_path, \"result\", \"isophonic\")])\n\n        # uspop\n        if \"uspop\" in self.dataset_names:\n            with open(os.path.join(self.uspop_directory, self.uspop_index_path)) as f:\n                uspop_lab_list = f.readlines()\n            uspop_lab_list = [x.strip() for x in uspop_lab_list]\n\n   
         for lab_path in uspop_lab_list:\n                spl = lab_path.split('/')\n                lab_artist = self.uspop_pre(spl[2])\n                lab_title = self.uspop_pre(spl[4][3:-4])\n                lab_path = lab_path.replace('./uspopLabels/', '')\n                lab_path = os.path.join(self.uspop_directory, self.uspop_lab_path, lab_path)\n\n                for filename in os.listdir(os.path.join(self.uspop_directory, self.uspop_audio_path)):\n                    if not '.csv' in filename:\n                        spl = filename.split('-')\n                        mp3_artist = self.uspop_pre(spl[0])\n                        mp3_title = self.uspop_pre(spl[1][:-4])\n\n                        if lab_artist == mp3_artist and lab_title == mp3_title:\n                            res_list.append([mp3_artist + mp3_title, lab_path,\n                                             os.path.join(self.uspop_directory, self.uspop_audio_path, filename),\n                                             os.path.join(self.root_path, \"result\", \"uspop\")])\n                            break\n\n        # robbie williams\n        if \"robbiewilliams\" in self.dataset_names:\n            for dirpath, dirnames, filenames in os.walk(self.robbie_williams_directory):\n                if not dirnames:\n                    for filename in filenames:\n                        if \".txt\" in filename and (not 'README' in filename):\n                            tmp = filename.replace(\".txt\", \"\")\n                            song_name = \" \".join(re.findall(\"[a-zA-Z]+\", tmp)).replace(\"GTChords\", \"\")\n                            mp3_dir = dirpath.replace(\"chords\", \"audio\")\n                            mp3_path = self.find_mp3_path_robbiewilliams(mp3_dir, song_name)\n                            res_list.append([song_name, os.path.join(dirpath, filename), os.path.join(mp3_dir, mp3_path),\n                                             os.path.join(self.root_path, \"result\", \"robbiewilliams\")])\n        return res_list\n\n    def uspop_pre(self, text):\n        text = text.lower()\n        text = text.replace('_', '')\n        text = text.replace(' ', '')\n        text = \" \".join(re.findall(\"[a-zA-Z]+\", text))\n        return text\n\n    def song_pre(self, text):\n        to_remove = [\"'\", '`', '(', ')', ' ', '&', 'and', 'And']\n\n        for remove in to_remove:\n            text = text.replace(remove, '')\n\n        return text\n\n    def config_to_folder(self):\n        mp3_config = self.config.mp3\n        feature_config = self.config.feature\n        mp3_string = \"%d_%.1f_%.1f\" % \\\n                     (mp3_config['song_hz'], mp3_config['inst_len'],\n                      mp3_config['skip_interval'])\n        feature_string = \"%s_%d_%d_%d\" % \\\n                         (self.feature_name.value, feature_config['n_bins'], feature_config['bins_per_octave'], feature_config['hop_length'])\n\n        return mp3_config, feature_config, mp3_string, feature_string\n\n    def generate_labels_features_new(self, all_list):\n        pid = os.getpid()\n        mp3_config, feature_config, mp3_str, feature_str = self.config_to_folder()\n\n        i = 0  # number of songs\n        j = 0  # number of impossible songs\n        k = 0  # number of tried songs\n        total = 0  # number of generated instances\n\n        stretch_factors = [1.0]\n        shift_factors = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6]\n\n        loop_broken = False\n        for song_name, lab_path, mp3_path, save_path in 
all_list:\n\n            # different song initialization\n            if loop_broken:\n                loop_broken = False\n\n            i += 1\n            print(pid, \"generating features from ...\", os.path.join(mp3_path))\n            if i % 10 == 0:\n                print(i, ' th song')\n\n            original_wav, sr = librosa.load(os.path.join(mp3_path), sr=mp3_config['song_hz'])\n\n            # make result path if not exists\n            # save_path, mp3_string, feature_string, song_name, aug.pt\n            result_path = os.path.join(save_path, mp3_str, feature_str, song_name.strip())\n            if not os.path.exists(result_path):\n                os.makedirs(result_path)\n\n            # calculate result\n            for stretch_factor in stretch_factors:\n                if loop_broken:\n                    loop_broken = False\n                    break\n\n                for shift_factor in shift_factors:\n                    # for filename\n                    idx = 0\n\n                    chord_info = self.Chord_class.get_converted_chord(os.path.join(lab_path))\n\n                    k += 1\n                    # stretch original sound and chord info\n                    x = pyrb.time_stretch(original_wav, sr, stretch_factor)\n                    x = pyrb.pitch_shift(x, sr, shift_factor)\n                    audio_length = x.shape[0]\n                    chord_info['start'] = chord_info['start'] * 1/stretch_factor\n                    chord_info['end'] = chord_info['end'] * 1/stretch_factor\n\n                    last_sec = chord_info.iloc[-1]['end']\n                    last_sec_hz = int(last_sec * mp3_config['song_hz'])\n\n                    if audio_length + mp3_config['skip_interval'] < last_sec_hz:\n                        print('loaded song is too short :', song_name)\n                        loop_broken = True\n                        j += 1\n                        break\n                    elif audio_length > last_sec_hz:\n                        x = x[:last_sec_hz]\n\n                    origin_length = last_sec_hz\n                    origin_length_in_sec = origin_length / mp3_config['song_hz']\n\n                    current_start_second = 0\n\n                    # get chord list between current_start_second and current+song_length\n                    while current_start_second + mp3_config['inst_len'] < origin_length_in_sec:\n                        inst_start_sec = current_start_second\n                        curSec = current_start_second\n\n                        chord_list = []\n                        # extract chord per 1/self.time_interval\n                        while curSec < inst_start_sec + mp3_config['inst_len']:\n                            try:\n                                available_chords = chord_info.loc[(chord_info['start'] <= curSec) & (\n                                        chord_info['end'] > curSec + self.time_interval)].copy()\n                                if len(available_chords) == 0:\n                                    available_chords = chord_info.loc[((chord_info['start'] >= curSec) & (\n                                            chord_info['start'] <= curSec + self.time_interval)) | (\n                                                                              (chord_info['end'] >= curSec) & (\n                                                                              chord_info['end'] <= curSec + self.time_interval))].copy()\n                                if len(available_chords) == 1:\n                    
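# Note (added): a frame takes the chord of the single annotation covering it; when\n                    # several annotations overlap the frame, the longest-overlapping one wins below,\n                    # e.g. a frame covered 70% by G:maj and 30% by C:maj is labelled G:maj.\n                    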
                chord = available_chords['chord_id'].iloc[0]\n                                elif len(available_chords) > 1:\n                                    max_starts = available_chords.apply(lambda row: max(row['start'], curSec), axis=1)\n                                    available_chords['max_start'] = max_starts\n                                    min_ends = available_chords.apply(lambda row: min(row.end, curSec + self.time_interval), axis=1)\n                                    available_chords['min_end'] = min_ends\n                                    chords_lengths = available_chords['min_end'] - available_chords['max_start']\n                                    available_chords['chord_length'] = chords_lengths\n                                    # .ix was removed from pandas; .loc is the label-based equivalent\n                                    chord = available_chords.loc[available_chords['chord_length'].idxmax()]['chord_id']\n                                else:\n                                    chord = 24\n                            except Exception as e:\n                                chord = 24\n                                print(e)\n                                print(pid, \"no chord\")\n                                raise RuntimeError()\n                            finally:\n                                # convert chord by shift factor\n                                if chord != 24:\n                                    chord += shift_factor * 2\n                                    chord = chord % 24\n\n                                chord_list.append(chord)\n                                curSec += self.time_interval\n\n                        if len(chord_list) == self.no_of_chord_datapoints_per_sequence:\n                            try:\n                                sequence_start_time = current_start_second\n                                sequence_end_time = current_start_second + mp3_config['inst_len']\n\n                                start_index = int(sequence_start_time * mp3_config['song_hz'])\n                                end_index = int(sequence_end_time * mp3_config['song_hz'])\n\n                                song_seq = x[start_index:end_index]\n\n                                etc = '%.1f_%.1f' % (current_start_second, current_start_second + mp3_config['inst_len'])\n                                aug = '%.2f_%i' % (stretch_factor, shift_factor)\n\n                                if self.feature_name == FeatureTypes.cqt:\n                                    # print(pid, \"make feature\")\n                                    feature = librosa.cqt(song_seq, sr=sr, n_bins=feature_config['n_bins'], bins_per_octave=feature_config['bins_per_octave'], hop_length=feature_config['hop_length'])\n                                else:\n                                    raise NotImplementedError\n\n                                if feature.shape[1] > self.no_of_chord_datapoints_per_sequence:\n                                    feature = feature[:, :self.no_of_chord_datapoints_per_sequence]\n\n                                if feature.shape[1] != self.no_of_chord_datapoints_per_sequence:\n                                    print('loaded features length is too short :', song_name)\n                                    loop_broken = True\n                             
       j += 1\n                                    break\n\n                                result = {\n                                    'feature': feature,\n                                    'chord': chord_list,\n                                    'etc': etc\n                                }\n\n                                # save_path, mp3_string, feature_string, song_name, aug.pt\n                                filename = aug + \"_\" + str(idx) + \".pt\"\n                                torch.save(result, os.path.join(result_path, filename))\n                                idx += 1\n                                total += 1\n                            except Exception as e:\n                                print(e)\n                                print(pid, \"feature error\")\n                                raise RuntimeError()\n                        else:\n                            print(\"invalid number of chord datapoints in sequence :\", len(chord_list))\n                        current_start_second += mp3_config['skip_interval']\n        print(pid, \"total instances: %d\" % total)\n\n    def generate_labels_features_voca(self, all_list):\n        pid = os.getpid()\n        mp3_config, feature_config, mp3_str, feature_str = self.config_to_folder()\n\n        i = 0  # number of songs\n        j = 0  # number of impossible songs\n        k = 0  # number of tried songs\n        total = 0  # number of generated instances\n        stretch_factors = [1.0]\n        shift_factors = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6]\n\n        loop_broken = False\n        for song_name, lab_path, mp3_path, save_path in all_list:\n            save_path = save_path + '_voca'\n\n            # different song initialization\n            if loop_broken:\n                loop_broken = False\n\n            i += 1\n            print(pid, \"generating features from ...\", os.path.join(mp3_path))\n            if i % 10 == 0:\n                print(i, ' th song')\n\n            original_wav, sr = librosa.load(os.path.join(mp3_path), sr=mp3_config['song_hz'])\n\n            # save_path, mp3_string, feature_string, song_name, aug.pt\n            result_path = os.path.join(save_path, mp3_str, feature_str, song_name.strip())\n            if not os.path.exists(result_path):\n                os.makedirs(result_path)\n\n            # calculate result\n            for stretch_factor in stretch_factors:\n                if loop_broken:\n                    loop_broken = False\n                    break\n\n                for shift_factor in shift_factors:\n                    # for filename\n                    idx = 0\n\n                    try:\n                        chord_info = self.Chord_class.get_converted_chord_voca(os.path.join(lab_path))\n                    except Exception as e:\n                        print(e)\n                        print(pid, \" chord lab file error : %s\" % song_name)\n                        loop_broken = True\n                        j += 1\n                        break\n\n                    k += 1\n                    # stretch original sound and chord info\n                    x = pyrb.time_stretch(original_wav, sr, stretch_factor)\n                    x = pyrb.pitch_shift(x, sr, shift_factor)\n                    audio_length = x.shape[0]\n                    chord_info['start'] = chord_info['start'] * 1/stretch_factor\n                    chord_info['end'] = chord_info['end'] * 1/stretch_factor\n\n                    last_sec = chord_info.iloc[-1]['end']\n    
                last_sec_hz = int(last_sec * mp3_config['song_hz'])\n\n                    if audio_length + mp3_config['skip_interval'] < last_sec_hz:\n                        print('loaded song is too short :', song_name)\n                        loop_broken = True\n                        j += 1\n                        break\n                    elif audio_length > last_sec_hz:\n                        x = x[:last_sec_hz]\n\n                    origin_length = last_sec_hz\n                    origin_length_in_sec = origin_length / mp3_config['song_hz']\n\n                    current_start_second = 0\n\n                    # get chord list between current_start_second and current+song_length\n                    while current_start_second + mp3_config['inst_len'] < origin_length_in_sec:\n                        inst_start_sec = current_start_second\n                        curSec = current_start_second\n\n                        chord_list = []\n                        # extract chord per 1/self.time_interval\n                        while curSec < inst_start_sec + mp3_config['inst_len']:\n                            try:\n                                available_chords = chord_info.loc[(chord_info['start'] <= curSec) & (chord_info['end'] > curSec + self.time_interval)].copy()\n                                if len(available_chords) == 0:\n                                    available_chords = chord_info.loc[((chord_info['start'] >= curSec) & (chord_info['start'] <= curSec + self.time_interval)) | ((chord_info['end'] >= curSec) & (chord_info['end'] <= curSec + self.time_interval))].copy()\n\n                                if len(available_chords) == 1:\n                                    chord = available_chords['chord_id'].iloc[0]\n                                elif len(available_chords) > 1:\n                                    max_starts = available_chords.apply(lambda row: max(row['start'], curSec), axis=1)\n                                    available_chords['max_start'] = max_starts\n                                    min_ends = available_chords.apply(lambda row: min(row.end, curSec + self.time_interval), axis=1)\n                                    available_chords['min_end'] = min_ends\n                                    chords_lengths = available_chords['min_end'] - available_chords['max_start']\n                                    available_chords['chord_length'] = chords_lengths\n                                    # .ix was removed from pandas; .loc is the label-based equivalent\n                                    chord = available_chords.loc[available_chords['chord_length'].idxmax()]['chord_id']\n                                else:\n                                    chord = 169\n                            except Exception as e:\n                                chord = 169\n                                print(e)\n                                print(pid, \"no chord\")\n                                raise RuntimeError()\n                            finally:\n                                # convert chord by shift factor\n                                if chord != 169 and chord != 168:\n                                    chord += shift_factor * 14\n                                    chord = chord % 168\n\n                                chord_list.append(chord)\n                                curSec += self.time_interval\n\n                        if len(chord_list) == self.no_of_chord_datapoints_per_sequence:\n                            try:\n                                sequence_start_time = current_start_second\n                                
sequence_end_time = current_start_second + mp3_config['inst_len']\n\n                                start_index = int(sequence_start_time * mp3_config['song_hz'])\n                                end_index = int(sequence_end_time * mp3_config['song_hz'])\n\n                                song_seq = x[start_index:end_index]\n\n                                etc = '%.1f_%.1f' % (\n                                    current_start_second, current_start_second + mp3_config['inst_len'])\n                                aug = '%.2f_%i' % (stretch_factor, shift_factor)\n\n                                if self.feature_name == FeatureTypes.cqt:\n                                    feature = librosa.cqt(song_seq, sr=sr, n_bins=feature_config['n_bins'],\n                                                          bins_per_octave=feature_config['bins_per_octave'],\n                                                          hop_length=feature_config['hop_length'])\n                                else:\n                                    raise NotImplementedError\n\n                                if feature.shape[1] > self.no_of_chord_datapoints_per_sequence:\n                                    feature = feature[:, :self.no_of_chord_datapoints_per_sequence]\n\n                                if feature.shape[1] != self.no_of_chord_datapoints_per_sequence:\n                                    print('loaded features length is too short :', song_name)\n                                    loop_broken = True\n                                    j += 1\n                                    break\n\n                                result = {\n                                    'feature': feature,\n                                    'chord': chord_list,\n                                    'etc': etc\n                                }\n\n                                # save_path, mp3_string, feature_string, song_name, aug.pt\n                                filename = aug + \"_\" + str(idx) + \".pt\"\n                                torch.save(result, os.path.join(result_path, filename))\n                                idx += 1\n                                total += 1\n                            except Exception as e:\n                                print(e)\n                                print(pid, \"feature error\")\n                                raise RuntimeError()\n                        else:\n                            print(\"invalid number of chord datapoints in sequence :\", len(chord_list))\n                        current_start_second += mp3_config['skip_interval']\n        print(pid, \"total instances: %d\" % total)"
  },
  {
    "path": "utils/pytorch_utils.py",
    "content": "\nimport torch\nimport numpy as np\nimport os\nimport math\nfrom utils import logger\n\nuse_cuda = torch.cuda.is_available()\n\n\n# optimization\n# reference: http://pytorch.org/docs/master/_modules/torch/optim/lr_scheduler.html#ReduceLROnPlateau\ndef adjusting_learning_rate(optimizer, factor=.5, min_lr=0.00001):\n    for i, param_group in enumerate(optimizer.param_groups):\n        old_lr = float(param_group['lr'])\n        new_lr = max(old_lr * factor, min_lr)\n        param_group['lr'] = new_lr\n        logger.info('adjusting learning rate from %.6f to %.6f' % (old_lr, new_lr))\n\n\n# model save and loading\ndef load_model(asset_path, model, optimizer, restore_epoch=0):\n    if os.path.isfile(os.path.join(asset_path, 'model', 'checkpoint_%d.pth.tar' % restore_epoch), map_location=lambda storage, loc: storage):\n        checkpoint = torch.load(os.path.join(asset_path, 'model', 'checkpoint_%d.pth.tar' % restore_epoch))\n        model.load_state_dict(checkpoint['model'])\n        optimizer.load_state_dict(checkpoint['optimizer'])\n        current_step = checkpoint['current_step']\n        logger.info(\"restore model with %d epoch\" % restore_epoch)\n    else:\n        logger.info(\"no checkpoint with %d epoch\" % restore_epoch)\n        current_step = 0\n\n    return model, optimizer, current_step\n"
  },
  {
    "path": "utils/tf_logger.py",
    "content": "import tensorflow as tf\nimport numpy as np\nimport scipy.misc\n\ntry:\n    from StringIO import StringIO  # Python 2.7\nexcept ImportError:\n    from io import BytesIO  # Python 3.x\n\n\nclass TF_Logger(object):\n    def __init__(self, log_dir):\n        \"\"\"Create a summary writer logging to log_dir.\"\"\"\n        self.writer = tf.summary.FileWriter(log_dir)\n\n    def scalar_summary(self, tag, value, step):\n        \"\"\"Log a scalar variable.\"\"\"\n        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])\n        self.writer.add_summary(summary, step)\n\n    def image_summary(self, tag, images, step):\n        \"\"\"Log a list of images.\"\"\"\n\n        img_summaries = []\n        for i, img in enumerate(images):\n            # Write the image to a string\n            try:\n                s = StringIO()\n            except:\n                s = BytesIO()\n            scipy.misc.toimage(img).save(s, format=\"png\")\n\n            # Create an Image object\n            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),\n                                       height=img.shape[0],\n                                       width=img.shape[1])\n            # Create a Summary value\n            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))\n\n        # Create and write Summary\n        summary = tf.Summary(value=img_summaries)\n        self.writer.add_summary(summary, step)\n\n    def histo_summary(self, tag, values, step, bins=1000):\n        \"\"\"Log a histogram of the tensor of values.\"\"\"\n\n        # Create a histogram using numpy\n        counts, bin_edges = np.histogram(values, bins=bins)\n\n        # Fill the fields of the histogram proto\n        hist = tf.HistogramProto()\n        hist.min = float(np.min(values))\n        hist.max = float(np.max(values))\n        hist.num = int(np.prod(values.shape))\n        hist.sum = float(np.sum(values))\n        hist.sum_squares = float(np.sum(values ** 2))\n\n        # Drop the start of the first bin\n        bin_edges = bin_edges[1:]\n\n        # Add bin edges and counts\n        for edge in bin_edges:\n            hist.bucket_limit.append(edge)\n        for c in counts:\n            hist.bucket.append(c)\n\n        # Create and write Summary\n        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])\n        self.writer.add_summary(summary, step)\n        self.writer.flush()"
  },
  {
    "path": "utils/transformer_modules.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport math\n\ndef _gen_bias_mask(max_length):\n    \"\"\"\n    Generates bias values (-Inf) to mask future timesteps during attention\n    \"\"\"\n    np_mask = np.triu(np.full([max_length, max_length], -np.inf), 1)\n    torch_mask = torch.from_numpy(np_mask).type(torch.FloatTensor)\n    return torch_mask.unsqueeze(0).unsqueeze(1)\n\ndef _gen_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4):\n    \"\"\"\n    Generates a [1, length, channels] timing signal consisting of sinusoids\n    Adapted from:\n    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py\n    \"\"\"\n    position = np.arange(length)\n    num_timescales = channels // 2\n    log_timescale_increment = (\n            math.log(float(max_timescale) / float(min_timescale)) /\n            (float(num_timescales) - 1))\n    inv_timescales = min_timescale * np.exp(\n        np.arange(num_timescales).astype(np.float) * -log_timescale_increment)\n    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)\n\n    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)\n    signal = np.pad(signal, [[0, 0], [0, channels % 2]],\n                    'constant', constant_values=[0.0, 0.0])\n    signal = signal.reshape([1, length, channels])\n\n    return torch.from_numpy(signal).type(torch.FloatTensor)\n\nclass LayerNorm(nn.Module):\n    # Borrowed from jekbradbury\n    # https://github.com/pytorch/pytorch/issues/1959\n    def __init__(self, features, eps=1e-6):\n        super(LayerNorm, self).__init__()\n        self.gamma = nn.Parameter(torch.ones(features))\n        self.beta = nn.Parameter(torch.zeros(features))\n        self.eps = eps\n\n    def forward(self, x):\n        mean = x.mean(-1, keepdim=True)\n        std = x.std(-1, keepdim=True)\n        return self.gamma * (x - mean) / (std + self.eps) + self.beta\n\nclass OutputLayer(nn.Module):\n    \"\"\"\n    Abstract base class for output layer.\n    Handles projection to output labels\n    \"\"\"\n    def __init__(self, hidden_size, output_size, probs_out=False):\n        super(OutputLayer, self).__init__()\n        self.output_size = output_size\n        self.output_projection = nn.Linear(hidden_size, output_size)\n        self.probs_out = probs_out\n        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=int(hidden_size/2), batch_first=True, bidirectional=True)\n        self.hidden_size = hidden_size\n\n    def loss(self, hidden, labels):\n        raise NotImplementedError('Must implement {}.loss'.format(self.__class__.__name__))\n\nclass SoftmaxOutputLayer(OutputLayer):\n    \"\"\"\n    Implements a softmax based output layer\n    \"\"\"\n    def forward(self, hidden):\n        logits = self.output_projection(hidden)\n        probs = F.softmax(logits, -1)\n        # _, predictions = torch.max(probs, dim=-1)\n        topk, indices = torch.topk(probs, 2)\n        predictions = indices[:,:,0]\n        second = indices[:,:,1]\n        if self.probs_out is True:\n            return logits\n            # return probs\n        return predictions, second\n\n    def loss(self, hidden, labels):\n        logits = self.output_projection(hidden)\n        log_probs = F.log_softmax(logits, -1)\n        return F.nll_loss(log_probs.view(-1, self.output_size), 
 labels.view(-1))\n\nclass MultiHeadAttention(nn.Module):\n    \"\"\"\n    Multi-head attention as per https://arxiv.org/pdf/1706.03762.pdf (see Figure 2)\n    \"\"\"\n\n    def __init__(self, input_depth, total_key_depth, total_value_depth, output_depth, num_heads, bias_mask=None, dropout=0.0, attention_map=False):\n        \"\"\"\n        Parameters:\n            input_depth: Size of the last dimension of the input\n            total_key_depth: Size of the last dimension of keys. Must be divisible by num_heads\n            total_value_depth: Size of the last dimension of values. Must be divisible by num_heads\n            output_depth: Size of the last dimension of the final output\n            num_heads: Number of attention heads\n            bias_mask: Masking tensor to prevent connections to future elements\n            dropout: Dropout probability (should be non-zero only during training)\n            attention_map: If True, forward also returns the attention weights\n        \"\"\"\n        super(MultiHeadAttention, self).__init__()\n        # Checks borrowed from\n        # https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py\n        if total_key_depth % num_heads != 0:\n            raise ValueError(\"Key depth (%d) must be divisible by the number of attention heads (%d).\" % (total_key_depth, num_heads))\n        if total_value_depth % num_heads != 0:\n            raise ValueError(\"Value depth (%d) must be divisible by the number of attention heads (%d).\" % (total_value_depth, num_heads))\n\n        self.attention_map = attention_map\n\n        self.num_heads = num_heads\n        self.query_scale = (total_key_depth // num_heads) ** -0.5\n        self.bias_mask = bias_mask\n\n        # Key and query depth will be same\n        self.query_linear = nn.Linear(input_depth, total_key_depth, bias=False)\n        self.key_linear = nn.Linear(input_depth, total_key_depth, bias=False)\n        self.value_linear = nn.Linear(input_depth, total_value_depth, bias=False)\n        self.output_linear = nn.Linear(total_value_depth, output_depth, bias=False)\n\n        self.dropout = nn.Dropout(dropout)\n\n    def _split_heads(self, x):\n        \"\"\"\n        Split x to add an extra num_heads dimension\n        Input:\n            x: a Tensor with shape [batch_size, seq_length, depth]\n        Returns:\n            A Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]\n        \"\"\"\n        if len(x.shape) != 3:\n            raise ValueError(\"x must have rank 3\")\n        shape = x.shape\n        return x.view(shape[0], shape[1], self.num_heads, shape[2] // self.num_heads).permute(0, 2, 1, 3)\n\n    def _merge_heads(self, x):\n        \"\"\"\n        Merge the extra num_heads into the last dimension\n        Input:\n            x: a Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]\n        Returns:\n            A Tensor with shape [batch_size, seq_length, depth]\n        \"\"\"\n        if len(x.shape) != 4:\n            raise ValueError(\"x must have rank 4\")\n        shape = x.shape\n        return x.permute(0, 2, 1, 3).contiguous().view(shape[0], shape[2], shape[3] * self.num_heads)\n\n    def forward(self, queries, keys, values):\n\n        # Project queries, keys and values\n        queries = self.query_linear(queries)\n        keys = self.key_linear(keys)\n        values = self.value_linear(values)\n\n        # Split into multiple heads\n        queries = self._split_heads(queries)\n
        keys = self._split_heads(keys)\n        values = self._split_heads(values)\n\n        # Scale queries\n        queries *= self.query_scale\n\n        # Combine queries and keys\n        logits = torch.matmul(queries, keys.permute(0, 1, 3, 2))\n\n        # Add bias to mask future values\n        if self.bias_mask is not None:\n            logits += self.bias_mask[:, :, :logits.shape[-2], :logits.shape[-1]].type_as(logits.data)\n\n        # Convert to probabilities\n        weights = nn.functional.softmax(logits, dim=-1)\n\n        # Dropout\n        weights = self.dropout(weights)\n\n        # Combine with values to get context\n        contexts = torch.matmul(weights, values)\n\n        # Merge heads\n        contexts = self._merge_heads(contexts)\n\n        # Linear to get output\n        outputs = self.output_linear(contexts)\n\n        if self.attention_map:\n            return outputs, weights\n\n        return outputs\n\n\nclass Conv(nn.Module):\n    \"\"\"\n    Convenience class that does padding and convolution for inputs in the format\n    [batch_size, sequence length, hidden size]\n    \"\"\"\n\n    def __init__(self, input_size, output_size, kernel_size, pad_type):\n        \"\"\"\n        Parameters:\n            input_size: Input feature size\n            output_size: Output feature size\n            kernel_size: Kernel width\n            pad_type: left -> pad on the left side (to mask future timesteps),\n                      both -> pad on both sides\n        \"\"\"\n        super(Conv, self).__init__()\n        padding = (kernel_size - 1, 0) if pad_type == 'left' else (kernel_size // 2, (kernel_size - 1) // 2)\n        self.pad = nn.ConstantPad1d(padding, 0)\n        self.conv = nn.Conv1d(input_size, output_size, kernel_size=kernel_size, padding=0)\n\n    def forward(self, inputs):\n        inputs = self.pad(inputs.permute(0, 2, 1))\n        outputs = self.conv(inputs).permute(0, 2, 1)\n\n        return outputs\n\n\nclass PositionwiseFeedForward(nn.Module):\n    \"\"\"\n    Does a Linear + ReLU + Linear on each of the timesteps\n    \"\"\"\n\n    def __init__(self, input_depth, filter_size, output_depth, layer_config='ll', padding='left', dropout=0.0):\n        \"\"\"\n        Parameters:\n            input_depth: Size of the last dimension of the input\n            filter_size: Hidden size of the middle layer\n            output_depth: Size of the last dimension of the final output\n            layer_config: ll -> linear + ReLU + linear\n                          cc -> conv + ReLU + conv etc.\n            padding: left -> pad on the left side (to mask future timesteps),\n                     both -> pad on both sides\n            dropout: Dropout probability (should be non-zero only during training)\n        \"\"\"\n        super(PositionwiseFeedForward, self).__init__()\n\n        layers = []\n        sizes = ([(input_depth, filter_size)] +\n                 [(filter_size, filter_size)] * (len(layer_config) - 2) +\n                 [(filter_size, output_depth)])\n\n        for lc, s in zip(list(layer_config), sizes):\n            if lc == 'l':\n                layers.append(nn.Linear(*s))\n            elif lc == 'c':\n                layers.append(Conv(*s, kernel_size=3, pad_type=padding))\n            else:\n                raise ValueError(\"Unknown layer type {}\".format(lc))\n\n        self.layers = nn.ModuleList(layers)\n        self.relu = nn.ReLU()\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, inputs):\n
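        # Illustrative usage (added; the sizes are assumptions, not taken from run_config.yaml):\n        # PositionwiseFeedForward(128, 512, 128, layer_config='cc', padding='both') maps a\n        # [batch, timestep, 128] input through Conv(k=3) -> ReLU -> dropout -> Conv(k=3)\n        # back to [batch, timestep, 128].\n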
        x = inputs\n        for i, layer in enumerate(self.layers):\n            x = layer(x)\n            if i < len(self.layers) - 1:  # activation and dropout between layers only, per the docstring\n                x = self.relu(x)\n                x = self.dropout(x)\n\n        return x"
  }
]