Repository: iamyuanchung/Autoregressive-Predictive-Coding Branch: master Commit: 890c40dcaad3 Files: 7 Total size: 23.1 KB Directory structure: gitextract_f00nvz2h/ ├── README.md ├── apc_model.py ├── datasets.py ├── load_pretrained_model.py ├── prepare_data.py ├── train_apc.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: README.md ================================================ ## Autoregressive Predictive Coding This repository contains the official implementation (in PyTorch) of Autoregressive Predictive Coding (APC) proposed in [An Unsupervised Autoregressive Model for Speech Representation Learning](https://arxiv.org/abs/1904.03240). APC is a speech feature extractor trained on a large amount of unlabeled data. With an unsupervised, autoregressive training objective, representations learned by APC not only capture general acoustic characteristics such as speaker and phone information from the speech signals, but are also highly accessible to downstream models--our experimental results on phone classification show that a linear classifier taking the APC representations as the input features significantly outperforms a multi-layer perceptron using the surface features. ## Dependencies * Python 3.5 * PyTorch 1.0 ## Dataset In the paper, we used the train-clean-360 split from the [LibriSpeech](http://www.openslr.org/12/) corpus for training the APC models, and the dev-clean split for keeping track of the training loss. We used the log Mel spectrograms, which were generated by running the Kaldi scripts, as the input acoustic features to the APC models. Of course you can generate the log Mel spectrograms yourself, but to help you better reproduce our results, here we provide the links to the data preprocessed by us that can be directly fed to the APC models.
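These Kaldi-generated features are plain text archives, which `prepare_data.py` (further down in this repository) parses line by line: a line with a single non-`.` token starts a new utterance, each subsequent line holds one frame of floats, and a lone `.` closes the utterance. A minimal sketch of that layout and parse loop, using a made-up utterance ID and 3-dimensional frames in place of the real 80-dimensional log Mel features:

```python
# Toy stand-in for the Kaldi text archive layout (the ID and values are made
# up; real frames carry 80 log Mel coefficients, not 3).
lines = [
    "103-1240-0000",       # a single non-'.' token starts a new utterance
    "0.12 -0.03 0.88",     # one frame per line
    "0.10 -0.01 0.92",
    ".",                   # a lone '.' ends the current utterance
]

utts = {}
for line in lines:
    data = line.strip().split()
    if len(data) == 1:
        if data[0] == ".":
            utts[utt_id] = log_mel         # utterance complete
        else:
            utt_id, log_mel = data[0], []  # a new utterance begins
    else:
        log_mel.append([float(v) for v in data])
```

This mirrors the branching in `prepare_data.py`, minus the tensor conversion, padding, and saving.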
We also include other data splits that we did not use in the paper for you to explore, e.g., you can try training an APC model on a larger and noisier set (e.g., train-other-500) and see if it learns more robust speech representations. * [train-clean-100](https://www.dropbox.com/s/kl6ivulhucukdz1/train-clean-100.xz?dl=0) * [train-clean-360](https://www.dropbox.com/s/0hzg2momellrpoj/train-clean-360.xz?dl=0) (used for training APC models in our paper) * [train-other-500](https://www.dropbox.com/s/uy0aex30ufq2po8/train-other.xz?dl=0) * [dev-clean](https://www.dropbox.com/s/4f1ypyowwmkfapx/dev-clean.xz?dl=0) (used for tracking the training loss) ## Training APC Below we will follow the paper and use train-clean-360 and dev-clean as a demonstration. Once you have downloaded the data, decompress them by running: ```bash xz -d train-clean-360.xz xz -d dev-clean.xz ``` Then, create a directory `librispeech_data/kaldi` and move the data into it: ```bash mkdir -p librispeech_data/kaldi mv train-clean-360-hires-norm.blogmel librispeech_data/kaldi mv dev-clean-hires-norm.blogmel librispeech_data/kaldi ``` Now we will have to transform the data into the format loadable by the PyTorch DataLoader. To do so, simply run: ```bash # Prepare the training set python prepare_data.py --librispeech_from_kaldi librispeech_data/kaldi/train-clean-360-hires-norm.blogmel --save_dir librispeech_data/preprocessed/train-clean-360 # Prepare the validation set python prepare_data.py --librispeech_from_kaldi librispeech_data/kaldi/dev-clean-hires-norm.blogmel --save_dir librispeech_data/preprocessed/dev-clean ``` Note that the `--save_dir` names above match the directories `train_apc.py` loads by default. Once the program is done, you will see a directory `preprocessed/` inside `librispeech_data/` that contains all the preprocessed PyTorch tensors. To train an APC model, simply run: ```bash python train_apc.py ``` By default, the trained models will be put in `logs/`. You can also use Tensorboard to trace the training progress.
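Under the hood, the training loop in `train_apc.py` implements the APC objective by slicing each utterance so that the model reads frame f_t and predicts f_{t+n}, where n is the `--time_shift` flag (the `batch_x[:, :-time_shift, :]` vs. `batch_x[:, time_shift:, :]` slices). A toy sketch of that slicing, with integers standing in for acoustic frames:

```python
# Integers stand in for acoustic frames; n mirrors the --time_shift flag.
n = 3
frames = list(range(10))   # a toy "utterance" of 10 frames

inputs = frames[:-n]       # what the model sees:       f_0 .. f_6
targets = frames[n:]       # what the Postnet predicts: f_3 .. f_9

assert len(inputs) == len(targets)  # the L1 loss is computed frame-by-frame
```

Larger n forces the model to rely on longer-range structure rather than local smoothness, which is why checkpoints are released for several values of n.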
There are many other configurations you can try, check `train_apc.py` for more details--it is highly documented and should be self-explanatory. ## Feature extraction Once you have trained your APC model, you can use it to extract speech features from your target dataset. To do so, feed-forward the trained model on the target dataset and retrieve the extracted features by running: ```python _, feats = model.forward(inputs, lengths) ``` `feats` is a PyTorch tensor of shape (`num_layers`, `batch_size`, `seq_len`, `rnn_hidden_size`) where: - `num_layers` is the RNN depth of your APC model - `batch_size` is your inference batch size - `seq_len` is the maximum sequence length and is determined when you run `prepare_data.py`. By default this value is 1600. - `rnn_hidden_size` is the dimensionality of the RNN hidden unit. As you can see, `feats` is essentially the RNN hidden states in an APC model. You can think of APC as a speech version of [ELMo](https://www.aclweb.org/anthology/N18-1202) if you are familiar with it. There are many ways to incorporate `feats` into your downstream task. One of the easiest ways is to take only the outputs of the last RNN layer (i.e., `feats[-1, :, :, :]`) as the input features to your downstream model, which is what we did in our paper. Feel free to explore other mechanisms. ## Pre-trained models We release the pre-trained models that were used to produce the numbers reported in the paper. `load_pretrained_model.py` provides a simple example of loading a pre-trained model.
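As a concrete illustration of the last-layer selection described in the feature extraction section, here is a small NumPy sketch (shapes only; the utterance lengths are hypothetical) that also trims the zero-padded tail of each utterance, which you would typically drop before feeding the features to a downstream model:

```python
import numpy as np

# Hypothetical shapes matching the defaults: 3 RNN layers, padded length
# 1600, 512 hidden units; a batch of 2 utterances.
num_layers, batch_size, seq_len, hidden = 3, 2, 1600, 512
feats = np.zeros((num_layers, batch_size, seq_len, hidden))
lengths = [1213, 874]      # hypothetical true (unpadded) utterance lengths

last_layer = feats[-1]     # (batch_size, seq_len, hidden)
# Keep only the real frames of each utterance, discarding the padding.
utt_feats = [last_layer[i, :lengths[i], :] for i in range(batch_size)]
```

The same indexing applies unchanged to the PyTorch tensor returned by the model.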
* [n = 1](https://www.dropbox.com/s/qyb1gicjkhv0wz9/bs32-rhl3-rhs512-rd0-adam-res-ts1.pt?dl=0) * [n = 2](https://www.dropbox.com/s/76amvx3fccfmp2n/bs32-rhl3-rhs512-rd0-adam-res-ts2.pt?dl=0) * [n = 3](https://www.dropbox.com/s/9nwj8y0djiw9pek/bs32-rhl3-rhs512-rd0-adam-res-ts3.pt?dl=0) * [n = 5](https://www.dropbox.com/s/8pqlr5wg89eicwk/bs32-rhl3-rhs512-rd0-adam-res-ts5.pt?dl=0) * [n = 10](https://www.dropbox.com/s/ucpf66k89xkm1jw/bs32-rhl3-rhs512-rd0-adam-res-ts10.pt?dl=0) * [n = 20](https://www.dropbox.com/s/wa01myucfifloqo/bs32-rhl3-rhs512-rd0-adam-res-ts20.pt?dl=0) ## Reference Please cite our paper(s) if you find this repository useful. The first paper proposes the APC objective, while the second paper applies it to speech recognition, speech translation, and speaker identification, and provides a more systematic analysis of the learned representations. Cite both if you are kind enough! ``` @inproceedings{chung2019unsupervised, title = {An unsupervised autoregressive model for speech representation learning}, author = {Chung, Yu-An and Hsu, Wei-Ning and Tang, Hao and Glass, James}, booktitle = {Interspeech}, year = {2019} } ``` ``` @inproceedings{chung2020generative, title = {Generative pre-training for speech with autoregressive predictive coding}, author = {Chung, Yu-An and Glass, James}, booktitle = {ICASSP}, year = {2020} } ``` ## Contact Feel free to shoot me an email for any inquiries about the paper and this repository. ================================================ FILE: apc_model.py ================================================ import torch from torch import nn from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence class Prenet(nn.Module): """Prenet is a multi-layer fully-connected network with ReLU activations. During training and testing (i.e., feature extraction), each input frame is passed into the Prenet, and the Prenet output is then fed to the RNN.
If Prenet configuration is None, the input frames will be directly fed to the RNN without any transformation. """ def __init__(self, input_size, num_layers, hidden_size, dropout): super(Prenet, self).__init__() input_sizes = [input_size] + [hidden_size] * (num_layers - 1) output_sizes = [hidden_size] * num_layers self.layers = nn.ModuleList( [nn.Linear(in_features=in_size, out_features=out_size) for (in_size, out_size) in zip(input_sizes, output_sizes)]) self.relu = nn.ReLU() self.dropout = nn.Dropout(dropout) def forward(self, inputs): # inputs: (batch_size, seq_len, mel_dim) for layer in self.layers: inputs = self.dropout(self.relu(layer(inputs))) return inputs # inputs: (batch_size, seq_len, out_dim) class Postnet(nn.Module): """Postnet is a simple linear layer for predicting the target frames given the RNN context during training. We don't need the Postnet for feature extraction. """ def __init__(self, input_size, output_size=80): super(Postnet, self).__init__() self.layer = nn.Conv1d(in_channels=input_size, out_channels=output_size, kernel_size=1, stride=1) def forward(self, inputs): # inputs: (batch_size, seq_len, hidden_size) inputs = torch.transpose(inputs, 1, 2) # inputs: (batch_size, hidden_size, seq_len) -- for conv1d operation return torch.transpose(self.layer(inputs), 1, 2) # (batch_size, seq_len, output_size) -- back to the original shape class APCModel(nn.Module): """This class defines Autoregressive Predictive Coding (APC), a model that learns to extract general speech features from unlabeled speech data. These features are shown to contain rich speaker and phone information, and are useful for a wide range of downstream tasks such as speaker verification and phone classification. An APC model consists of a Prenet (optional), a multi-layer GRU network, and a Postnet. 
For each time step during training, the Prenet transforms the input frame into a latent representation, which is then consumed by the GRU network for generating internal representations across the layers. Finally, the Postnet takes the output of the last GRU layer and attempts to predict the target frame. After training, to extract features from the data of your interest, which do not have to be i.i.d. with the training data, simply feed-forward the data through the APC model, and take the internal representations (i.e., the GRU hidden states) as the extracted features and use them in your tasks. """ def __init__(self, mel_dim, prenet_config, rnn_config): super(APCModel, self).__init__() self.mel_dim = mel_dim if prenet_config is not None: # Make sure the dimensionalities are correct assert prenet_config.input_size == mel_dim assert prenet_config.hidden_size == rnn_config.input_size assert rnn_config.input_size == rnn_config.hidden_size self.prenet = Prenet( input_size=prenet_config.input_size, num_layers=prenet_config.num_layers, hidden_size=prenet_config.hidden_size, dropout=prenet_config.dropout) else: assert rnn_config.input_size == mel_dim self.prenet = None in_sizes = ([rnn_config.input_size] + [rnn_config.hidden_size] * (rnn_config.num_layers - 1)) out_sizes = [rnn_config.hidden_size] * rnn_config.num_layers self.rnns = nn.ModuleList( [nn.GRU(input_size=in_size, hidden_size=out_size, batch_first=True) for (in_size, out_size) in zip(in_sizes, out_sizes)]) self.rnn_dropout = nn.Dropout(rnn_config.dropout) self.rnn_residual = rnn_config.residual self.postnet = Postnet( input_size=rnn_config.hidden_size, output_size=self.mel_dim) def forward(self, inputs, lengths): """Forward function for both training and testing (feature extraction).
input: inputs: (batch_size, seq_len, mel_dim) lengths: (batch_size,) return: predicted_mel: (batch_size, seq_len, mel_dim) internal_reps: (num_layers + x, batch_size, seq_len, rnn_hidden_size), where x is 1 if there's a prenet, otherwise 0 """ seq_len = inputs.size(1) if self.prenet is not None: rnn_inputs = self.prenet(inputs) # rnn_inputs: (batch_size, seq_len, rnn_input_size) internal_reps = [rnn_inputs] # also include prenet_outputs in internal_reps else: rnn_inputs = inputs internal_reps = [] packed_rnn_inputs = pack_padded_sequence(rnn_inputs, lengths, True) for i, layer in enumerate(self.rnns): packed_rnn_outputs, _ = layer(packed_rnn_inputs) rnn_outputs, _ = pad_packed_sequence( packed_rnn_outputs, True, total_length=seq_len) # outputs: (batch_size, seq_len, rnn_hidden_size) if i + 1 < len(self.rnns): # apply dropout except the last rnn layer rnn_outputs = self.rnn_dropout(rnn_outputs) rnn_inputs, _ = pad_packed_sequence( packed_rnn_inputs, True, total_length=seq_len) # rnn_inputs: (batch_size, seq_len, rnn_hidden_size) if self.rnn_residual and rnn_inputs.size(-1) == rnn_outputs.size(-1): # Residual connections rnn_outputs = rnn_outputs + rnn_inputs internal_reps.append(rnn_outputs) packed_rnn_inputs = pack_padded_sequence(rnn_outputs, lengths, True) predicted_mel = self.postnet(rnn_outputs) # predicted_mel: (batch_size, seq_len, mel_dim) internal_reps = torch.stack(internal_reps) return predicted_mel, internal_reps # predicted_mel is only for training; internal_reps is the extracted # features ================================================ FILE: datasets.py ================================================ from os import listdir from os.path import join import pickle import torch from torch.utils import data class LibriSpeech(data.Dataset): def __init__(self, path): self.path = path self.ids = [f for f in listdir(self.path) if f.endswith('.pt')] with open(join(path, 'lengths.pkl'), 'rb') as f: self.lengths = pickle.load(f) def __len__(self): return 
len(self.ids) def __getitem__(self, index): x = torch.load(join(self.path, self.ids[index])) l = self.lengths[self.ids[index]] return x, l ================================================ FILE: load_pretrained_model.py ================================================ """Example of loading a pre-trained APC model.""" import torch from apc_model import APCModel from utils import PrenetConfig, RNNConfig def main(): prenet_config = None # residual=True matches the released checkpoints (note the "res" in their # filenames); RNNConfig requires all five fields rnn_config = RNNConfig(input_size=80, hidden_size=512, num_layers=3, dropout=0., residual=True) pretrained_apc = APCModel(mel_dim=80, prenet_config=prenet_config, rnn_config=rnn_config).cuda() pretrained_weights_path = 'bs32-rhl3-rhs512-rd0-adam-res-ts3.pt' pretrained_apc.load_state_dict(torch.load(pretrained_weights_path)) # Load data and perform your task ... ================================================ FILE: prepare_data.py ================================================ import os import argparse import pickle import torch import torch.nn.functional as F def main(): parser = argparse.ArgumentParser("Configuration for data preparation") parser.add_argument("--librispeech_from_kaldi", default="./librispeech_data/kaldi/dev-clean-hires-norm.blogmel", type=str, help="Path to the librispeech log Mel features generated by the Kaldi scripts") parser.add_argument("--max_seq_len", default=1600, type=int, help="The maximum length (number of frames) of each sequence; sequences will be truncated or padded (with zero vectors) to this length") parser.add_argument("--save_dir", default="./librispeech_data/preprocessed/dev-clean", type=str, help="Directory to save the preprocessed pytorch tensors") config = parser.parse_args() os.makedirs(config.save_dir, exist_ok=True) id2len = {} with open(config.librispeech_from_kaldi, 'r') as f: # process the file line by line for line in f: data = line.strip().split() if len(data) == 1: if data[0] == '.': # end of the current utterance id2len[utt_id + '.pt'] = min(len(log_mel), config.max_seq_len) log_mel =
torch.FloatTensor(log_mel) # convert the 2D list to a pytorch tensor log_mel = F.pad(log_mel, (0, 0, 0, config.max_seq_len - log_mel.size(0))) # pad or truncate torch.save(log_mel, os.path.join(config.save_dir, utt_id + '.pt')) else: # here starts a new utterance utt_id = data[0] log_mel = [] else: log_mel.append([float(i) for i in data]) with open(os.path.join(config.save_dir, 'lengths.pkl'), 'wb') as f: pickle.dump(id2len, f, protocol=4) if __name__ == '__main__': main() ================================================ FILE: train_apc.py ================================================ import os import logging import argparse import numpy as np import torch from torch.autograd import Variable from torch import nn, optim from torch.utils import data import tensorboard_logger from tensorboard_logger import log_value from apc_model import APCModel from datasets import LibriSpeech from utils import PrenetConfig, RNNConfig def main(): parser = argparse.ArgumentParser( description="Configuration for training an APC model") # Prenet architecture (note that all APC models in the paper DO NOT # incorporate a prenet) parser.add_argument("--prenet_num_layers", default=0, type=int, help="Number of ReLU layers as prenet") parser.add_argument("--prenet_dropout", default=0., type=float, help="Dropout for prenet") # RNN architecture parser.add_argument("--rnn_num_layers", default=3, type=int, help="Number of RNN layers in the APC model") parser.add_argument("--rnn_hidden_size", default=512, type=int, help="Number of hidden units in each RNN layer") parser.add_argument("--rnn_dropout", default=0., type=float, help="Dropout for each RNN output layer except the last one") parser.add_argument("--rnn_residual", action="store_true", help="Apply residual connections between RNN layers if specified") # Training configuration parser.add_argument("--optimizer", default="adam", type=str, help="The gradient descent optimizer (e.g., sgd, adam, etc.)") parser.add_argument("--batch_size",
default=32, type=int, help="Training minibatch size") parser.add_argument("--learning_rate", default=0.0001, type=float, help="Initial learning rate") parser.add_argument("--epochs", default=100, type=int, help="Number of training epochs") parser.add_argument("--time_shift", default=1, type=int, help="Given f_{t}, predict f_{t + n}, where n is the time_shift") parser.add_argument("--clip_thresh", default=1.0, type=float, help="Threshold for clipping the gradients") # Misc configurations parser.add_argument("--feature_dim", default=80, type=int, help="The dimension of the input frame") parser.add_argument("--load_data_workers", default=2, type=int, help="Number of parallel data loaders") parser.add_argument("--experiment_name", default="foo", type=str, help="Name of this experiment") parser.add_argument("--store_path", default="./logs", type=str, help="Where to save the trained models and logs") parser.add_argument("--librispeech_path", default="./librispeech_data/preprocessed", type=str, help="Path to the librispeech directory") config = parser.parse_args() model_dir = os.path.join(config.store_path, config.experiment_name + '.dir') os.makedirs(config.store_path, exist_ok=True) os.makedirs(model_dir, exist_ok=True) logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filename=os.path.join(model_dir, config.experiment_name), filemode='w') # define a new Handler to log to console as well console = logging.StreamHandler() console.setLevel(logging.INFO) formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') console.setFormatter(formatter) logging.getLogger('').addHandler(console) logging.info('Model Parameters: ') logging.info('Prenet Depth: %d' % (config.prenet_num_layers)) logging.info('Prenet Dropout: %f' % (config.prenet_dropout)) logging.info('RNN Depth: %d ' % (config.rnn_num_layers)) logging.info('RNN Hidden Dim: %d' % (config.rnn_hidden_size)) logging.info('RNN Residual Connections: %s' % 
(config.rnn_residual)) logging.info('RNN Dropout: %f' % (config.rnn_dropout)) logging.info('Optimizer: %s ' % (config.optimizer)) logging.info('Batch Size: %d ' % (config.batch_size)) logging.info('Initial Learning Rate: %f ' % (config.learning_rate)) logging.info('Time Shift: %d' % (config.time_shift)) logging.info('Gradient Clip Threshold: %f' % (config.clip_thresh)) if config.prenet_num_layers == 0: prenet_config = None rnn_config = RNNConfig( config.feature_dim, config.rnn_hidden_size, config.rnn_num_layers, config.rnn_dropout, config.rnn_residual) else: prenet_config = PrenetConfig( config.feature_dim, config.rnn_hidden_size, config.prenet_num_layers, config.prenet_dropout) rnn_config = RNNConfig( config.rnn_hidden_size, config.rnn_hidden_size, config.rnn_num_layers, config.rnn_dropout, config.rnn_residual) model = APCModel( mel_dim=config.feature_dim, prenet_config=prenet_config, rnn_config=rnn_config).cuda() criterion = nn.L1Loss() if config.optimizer == 'adam': optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) elif config.optimizer == 'adadelta': optimizer = optim.Adadelta(model.parameters()) elif config.optimizer == 'sgd': optimizer = optim.SGD(model.parameters(), lr=config.learning_rate) elif config.optimizer == 'adagrad': optimizer = optim.Adagrad(model.parameters(), lr=config.learning_rate) elif config.optimizer == 'rmsprop': optimizer = optim.RMSprop(model.parameters(), lr=config.learning_rate) else: raise NotImplementedError("Learning method not supported for the task") # setup tensorboard logger tensorboard_logger.configure( os.path.join(model_dir, config.experiment_name + '.tb_log')) train_set = LibriSpeech(os.path.join( config.librispeech_path, 'train-clean-360')) train_data_loader = data.DataLoader( train_set, batch_size=config.batch_size, num_workers=config.load_data_workers, shuffle=True) val_set = LibriSpeech(os.path.join(config.librispeech_path, 'dev-clean')) val_data_loader = data.DataLoader( val_set, 
batch_size=config.batch_size, num_workers=config.load_data_workers, shuffle=False) torch.save(model.state_dict(), open(os.path.join(model_dir, config.experiment_name + '__epoch_0.model'), 'wb')) global_step = 0 for epoch_i in range(config.epochs): #################### ##### Training ##### #################### model.train() train_losses = [] for batch_x, batch_l in train_data_loader: _, indices = torch.sort(batch_l, descending=True) batch_x = Variable(batch_x[indices]).cuda() batch_l = Variable(batch_l[indices]).cuda() outputs, _ = model( batch_x[:, :-config.time_shift, :], batch_l - config.time_shift) optimizer.zero_grad() loss = criterion(outputs, batch_x[:, config.time_shift:, :]) train_losses.append(loss.item()) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip_thresh) optimizer.step() log_value("training loss (step-wise)", float(loss.item()), global_step) log_value("gradient norm", grad_norm, global_step) global_step += 1 ###################### ##### Validation ##### ###################### model.eval() val_losses = [] with torch.set_grad_enabled(False): for val_batch_x, val_batch_l in val_data_loader: _, val_indices = torch.sort(val_batch_l, descending=True) val_batch_x = Variable(val_batch_x[val_indices]).cuda() val_batch_l = Variable(val_batch_l[val_indices]).cuda() val_outputs, _ = model( val_batch_x[:, :-config.time_shift, :], val_batch_l - config.time_shift) val_loss = criterion(val_outputs, val_batch_x[:, config.time_shift:, :]) val_losses.append(val_loss.item()) logging.info('Epoch: %d Training Loss: %.5f Validation Loss: %.5f' % (epoch_i + 1, np.mean(train_losses), np.mean(val_losses))) log_value("training loss (epoch-wise)", np.mean(train_losses), epoch_i) log_value("validation loss (epoch-wise)", np.mean(val_losses), epoch_i) torch.save(model.state_dict(), open(os.path.join(model_dir, config.experiment_name + '__epoch_%d' % (epoch_i + 1) + '.model'), 'wb')) if __name__ == '__main__': main() 
================================================ FILE: utils.py ================================================ from collections import namedtuple PrenetConfig = namedtuple( 'PrenetConfig', ['input_size', 'hidden_size', 'num_layers', 'dropout']) RNNConfig = namedtuple( 'RNNConfig', ['input_size', 'hidden_size', 'num_layers', 'dropout', 'residual'])
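Since the configs in `utils.py` are plain namedtuples, every field is required at construction time. A small usage sketch (the values mirror the repository's default 3-layer, 512-unit architecture; pairing the config with `APCModel` works as in `load_pretrained_model.py`):

```python
from collections import namedtuple

# Same definition as in utils.py.
RNNConfig = namedtuple(
    'RNNConfig',
    ['input_size', 'hidden_size', 'num_layers', 'dropout', 'residual'])

# Values mirror the default 3-layer, 512-unit residual GRU configuration.
rnn_config = RNNConfig(input_size=80, hidden_size=512, num_layers=3,
                       dropout=0., residual=True)
```

Note that omitting any field (e.g., `residual`) raises a `TypeError`, since these namedtuples declare no defaults.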