Repository: ctr4si/A-Hierarchical-Latent-Structure-for-Variational-Conversation-Modeling Branch: master Commit: 83ca9dd96272 Files: 32 Total size: 152.5 KB Directory structure: gitextract_erg9sdck/ ├── .gitignore ├── LICENSE ├── Readme.md ├── cornell_preprocess.py ├── model/ │ ├── __init__.py │ ├── configs.py │ ├── data_loader.py │ ├── eval.py │ ├── eval_embed.py │ ├── layers/ │ │ ├── __init__.py │ │ ├── beam_search.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ ├── feedforward.py │ │ ├── loss.py │ │ └── rnncells.py │ ├── models.py │ ├── solver.py │ ├── train.py │ └── utils/ │ ├── __init__.py │ ├── bow.py │ ├── convert.py │ ├── embedding_metric.py │ ├── mask.py │ ├── pad.py │ ├── probability.py │ ├── tensorboard.py │ ├── time_track.py │ ├── tokenizer.py │ └── vocab.py ├── requirements.txt └── ubuntu_preprocess.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ /etc/ datasets/ /cornell_movie_dialogue/ *.orig *.lprof # Remote edit *.ftpconfig # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # dotenv .env # virtualenv .venv venv/ ENV/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2018 Center for SuperIntelligence Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: Readme.md ================================================
# Variational Hierarchical Conversation RNN (VHCR)
[PyTorch 0.4](https://github.com/pytorch/pytorch) Implementation of ["A Hierarchical Latent Structure for Variational Conversation Modeling"](https://arxiv.org/abs/1804.03424) (NAACL 2018 Oral)

* [NAACL 2018 Oral Presentation Video](https://vimeo.com/277671819)

## Prerequisites
Install the required Python packages:
```
pip install -r requirements.txt
```

## Download & Preprocess data
The following scripts will:
1. Create the directories `./datasets/cornell/` and `./datasets/ubuntu/` respectively.
2. Download and preprocess conversation data inside each directory.

### for [Cornell Movie Dialogue dataset](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html)
```
python cornell_preprocess.py
    --max_sentence_length (maximum number of words in a sentence; default: 30)
    --max_conversation_length (maximum number of utterance turns in a single conversation; default: 10)
    --max_vocab_size (maximum size of the word vocabulary; default: 20000)
    --min_vocab_frequency (minimum frequency of a word to be included in the vocabulary; default: 5)
    --n_workers (number of workers for multiprocessing; default: os.cpu_count())
```

### for [Ubuntu Dialog Dataset](http://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/)
```
python ubuntu_preprocess.py
    --max_sentence_length (maximum number of words in a sentence; default: 30)
    --max_conversation_length (maximum number of utterance turns in a single conversation; default: 10)
    --max_vocab_size (maximum size of the word vocabulary; default: 20000)
    --min_vocab_frequency (minimum frequency of a word to be included in the vocabulary; default: 5)
    --n_workers (number of workers for multiprocessing; default: os.cpu_count())
```

## Training
Go to the model directory and set `save_dir` in configs.py (this is where model checkpoints will be saved).

We provide our implementation of VHCR, as well as our reference implementations for [HRED](https://arxiv.org/abs/1507.02221) and [VHRED](https://arxiv.org/abs/1605.06069).

To run training:
```
python train.py --data=<data> --model=<model> --batch_size=<batch_size>
```

For example:
1. Train HRED on Cornell Movie:
```
python train.py --data=cornell --model=HRED
```
2. Train VHRED with a word-drop ratio of 0.25 and 250000 KL annealing iterations:
```
python train.py --data=ubuntu --model=VHRED --batch_size=40 --word_drop=0.25 --kl_annealing_iter=250000
```
3. Train VHCR with an utterance-drop ratio of 0.25:
```
python train.py --data=ubuntu --model=VHCR --batch_size=40 --sentence_drop=0.25 --kl_annealing_iter=250000
```

By default, it will save a model checkpoint and a TensorBoard summary to `save_dir` every epoch. For more arguments and options, see configs.py.

## Evaluation
To evaluate the word perplexity:
```
python eval.py --model=<model> --checkpoint=<path_to_checkpoint>
```
For embedding-based metrics, you need to download the [Google News word vectors](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing), unzip them, and put them under the datasets folder. Then run:
```
python eval_embed.py --model=<model> --checkpoint=<path_to_checkpoint>
```
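The metric implementations themselves live in `model/utils/embedding_metric.py`. As a rough illustration only (this sketch is not the repository's code, and all names in it are illustrative), the embedding-average metric boils down to the cosine similarity between mean word vectors, assuming the Google News vectors are loaded with `gensim`:
```
# Illustrative sketch of the embedding-average metric; not the repository's code.
import numpy as np
from gensim.models import KeyedVectors

# Assumes the unzipped Google News vectors sit under ./datasets/
w2v = KeyedVectors.load_word2vec_format(
    'datasets/GoogleNews-vectors-negative300.bin', binary=True)

def sentence_vector(tokens):
    # Mean of word vectors, skipping out-of-vocabulary tokens
    vectors = [w2v[t] for t in tokens if t in w2v]
    return np.mean(vectors, axis=0) if vectors else np.zeros(w2v.vector_size)

def embedding_average(hypothesis, reference):
    # Cosine similarity between the two mean sentence vectors
    h, r = sentence_vector(hypothesis), sentence_vector(reference)
    denom = np.linalg.norm(h) * np.linalg.norm(r)
    return float(h @ r / denom) if denom > 0 else 0.0

print(embedding_average('i am fine'.split(), 'i am okay'.split()))
```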
## Reference
If you use this code or dataset as part of any published research, please cite the following paper:
```
@inproceedings{VHCR:2018:NAACL,
    author = {Yookoon Park and Jaemin Cho and Gunhee Kim},
    title = "{A Hierarchical Latent Structure for Variational Conversation Modeling}",
    booktitle = {NAACL},
    year = 2018
}
```

================================================ FILE: cornell_preprocess.py ================================================
# Preprocess cornell movie dialogs dataset
from multiprocessing import Pool
import argparse
import pickle
import random
import os
from urllib.request import urlretrieve
from zipfile import ZipFile
from pathlib import Path
from tqdm import tqdm
from model.utils import Tokenizer, Vocab, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN

project_dir = Path(__file__).resolve().parent
datasets_dir = project_dir.joinpath('datasets/')
cornell_dir = datasets_dir.joinpath('cornell/')

# Tokenizer
tokenizer = Tokenizer('spacy')


def prepare_cornell_data():
    """Download and unpack dialogs"""
    zip_url = 'http://www.mpi-sws.org/~cristian/data/cornell_movie_dialogs_corpus.zip'
    zipfile_path = datasets_dir.joinpath('cornell.zip')

    if not datasets_dir.exists():
        datasets_dir.mkdir()

    # Prepare Dialog data
    if not cornell_dir.exists():
        print(f'Downloading {zip_url} to {zipfile_path}')
        urlretrieve(zip_url, zipfile_path)
        print(f'Successfully downloaded {zipfile_path}')

        zip_ref = ZipFile(zipfile_path, 'r')
        zip_ref.extractall(datasets_dir)
        zip_ref.close()

        datasets_dir.joinpath('cornell movie-dialogs corpus').rename(cornell_dir)
    else:
        print('Cornell Data prepared!')


def loadLines(fileName, fields=["lineID", "characterID", "movieID", "character", "text"],
              delimiter=" +++$+++ "):
    """
    Args:
        fileName (str): file to load
        fields (list): fields to extract
    Return:
        dict of dict: the extracted fields for each line, keyed by lineID
    """
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(delimiter)
            # Extract fields
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines


def loadConversations(fileName, lines,
                      fields=["character1ID", "character2ID", "movieID", "utteranceIDs"],
                      delimiter=" +++$+++ "):
    """
    Args:
        fileName (str): file to load
        lines (dict): lines as returned by loadLines
        fields (list): fields to extract
    Return:
        list of dict: the extracted fields for each conversation
    """
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(delimiter)
            # Extract fields
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            lineIds = eval(convObj["utteranceIDs"])
            # Reassemble lines
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
    return conversations


def train_valid_test_split_by_conversation(conversations, split_ratio=[0.8, 0.1, 0.1]):
    """Train/Validation/Test split over randomly shuffled conversations"""
    train_ratio, valid_ratio, test_ratio = split_ratio
    assert train_ratio + valid_ratio + test_ratio == 1.0

    n_conversations = len(conversations)

    # Randomly shuffle the conversation list
    random.seed(0)
    random.shuffle(conversations)

    # Train / Validation / Test Split
    train_split = int(n_conversations * train_ratio)
    valid_split = int(n_conversations * (train_ratio + valid_ratio))

    train = conversations[:train_split]
    valid = conversations[train_split:valid_split]
    test = conversations[valid_split:]

    print(f'Train set: {len(train)} conversations')
    print(f'Validation set: {len(valid)} conversations')
    print(f'Test set: {len(test)} conversations')

    return train, valid, test


def tokenize_conversation(lines):
sentence_list = [tokenizer(line['text']) for line in lines] return sentence_list def pad_sentences(conversations, max_sentence_length=30, max_conversation_length=10): def pad_tokens(tokens, max_sentence_length=max_sentence_length): n_valid_tokens = len(tokens) if n_valid_tokens > max_sentence_length - 1: tokens = tokens[:max_sentence_length - 1] n_pad = max_sentence_length - n_valid_tokens - 1 tokens = tokens + [EOS_TOKEN] + [PAD_TOKEN] * n_pad return tokens def pad_conversation(conversation): conversation = [pad_tokens(sentence) for sentence in conversation] return conversation all_padded_sentences = [] all_sentence_length = [] for conversation in conversations: if len(conversation) > max_conversation_length: conversation = conversation[:max_conversation_length] sentence_length = [min(len(sentence) + 1, max_sentence_length) # +1 for EOS token for sentence in conversation] all_sentence_length.append(sentence_length) sentences = pad_conversation(conversation) all_padded_sentences.append(sentences) sentences = all_padded_sentences sentence_length = all_sentence_length return sentences, sentence_length if __name__ == '__main__': parser = argparse.ArgumentParser() # Maximum valid length of sentence # => SOS/EOS will surround sentence (EOS for source / SOS for target) # => maximum length of tensor = max_sentence_length + 1 parser.add_argument('-s', '--max_sentence_length', type=int, default=30) parser.add_argument('-c', '--max_conversation_length', type=int, default=10) # Split Ratio split_ratio = [0.8, 0.1, 0.1] # Vocabulary parser.add_argument('--max_vocab_size', type=int, default=20000) parser.add_argument('--min_vocab_frequency', type=int, default=5) # Multiprocess parser.add_argument('--n_workers', type=int, default=os.cpu_count()) args = parser.parse_args() max_sent_len = args.max_sentence_length max_conv_len = args.max_conversation_length max_vocab_size = args.max_vocab_size min_freq = args.min_vocab_frequency n_workers = args.n_workers # Download and extract dialogs if necessary. prepare_cornell_data() print("Loading lines") lines = loadLines(cornell_dir.joinpath("movie_lines.txt")) print('Number of lines:', len(lines)) print("Loading conversations...") conversations = loadConversations(cornell_dir.joinpath("movie_conversations.txt"), lines) print('Number of conversations:', len(conversations)) print('Train/Valid/Test Split') # train, valid, test = train_valid_test_split_by_movie(conversations, split_ratio) train, valid, test = train_valid_test_split_by_conversation(conversations, split_ratio) def to_pickle(obj, path): with open(path, 'wb') as f: pickle.dump(obj, f) for split_type, conv_objects in [('train', train), ('valid', valid), ('test', test)]: print(f'Processing {split_type} dataset...') split_data_dir = cornell_dir.joinpath(split_type) split_data_dir.mkdir(exist_ok=True) print(f'Tokenize.. 
(n_workers={n_workers})') def _tokenize_conversation(conv): return tokenize_conversation(conv['lines']) with Pool(n_workers) as pool: conversations = list(tqdm(pool.imap(_tokenize_conversation, conv_objects), total=len(conv_objects))) conversation_length = [min(len(conv['lines']), max_conv_len) for conv in conv_objects] sentences, sentence_length = pad_sentences( conversations, max_sentence_length=max_sent_len, max_conversation_length=max_conv_len) print('Saving preprocessed data at', split_data_dir) to_pickle(conversation_length, split_data_dir.joinpath('conversation_length.pkl')) to_pickle(sentences, split_data_dir.joinpath('sentences.pkl')) to_pickle(sentence_length, split_data_dir.joinpath('sentence_length.pkl')) if split_type == 'train': print('Save Vocabulary...') vocab = Vocab(tokenizer) vocab.add_dataframe(conversations) vocab.update(max_size=max_vocab_size, min_freq=min_freq) print('Vocabulary size: ', len(vocab)) vocab.pickle(cornell_dir.joinpath('word2id.pkl'), cornell_dir.joinpath('id2word.pkl')) print('Done!') ================================================ FILE: model/__init__.py ================================================ ================================================ FILE: model/configs.py ================================================ import os import argparse from datetime import datetime from collections import defaultdict from pathlib import Path import pprint from torch import optim import torch.nn as nn from layers.rnncells import StackedLSTMCell, StackedGRUCell project_dir = Path(__file__).resolve().parent.parent data_dir = project_dir.joinpath('datasets') data_dict = {'cornell': data_dir.joinpath('cornell'), 'ubuntu': data_dir.joinpath('ubuntu')} optimizer_dict = {'RMSprop': optim.RMSprop, 'Adam': optim.Adam} rnn_dict = {'lstm': nn.LSTM, 'gru': nn.GRU} rnncell_dict = {'lstm': StackedLSTMCell, 'gru': StackedGRUCell} username = Path.home().name save_dir = Path(f'/data1/{username}/conversation/') def str2bool(v): """string to boolean""" if v.lower() in ('yes', 'true', 't', 'y', '1'): return True elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False else: raise argparse.ArgumentTypeError('Boolean value expected.') class Config(object): def __init__(self, **kwargs): """Configuration Class: set kwargs as class attributes with setattr""" if kwargs is not None: for key, value in kwargs.items(): if key == 'optimizer': value = optimizer_dict[value] if key == 'rnn': value = rnn_dict[value] if key == 'rnncell': value = rnncell_dict[value] setattr(self, key, value) # Dataset directory: ex) ./datasets/cornell/ self.dataset_dir = data_dict[self.data.lower()] # Data Split ex) 'train', 'valid', 'test' self.data_dir = self.dataset_dir.joinpath(self.mode) # Pickled Vocabulary self.word2id_path = self.dataset_dir.joinpath('word2id.pkl') self.id2word_path = self.dataset_dir.joinpath('id2word.pkl') # Pickled Dataframes self.sentences_path = self.data_dir.joinpath('sentences.pkl') self.sentence_length_path = self.data_dir.joinpath('sentence_length.pkl') self.conversation_length_path = self.data_dir.joinpath('conversation_length.pkl') # Save path if self.mode == 'train' and self.checkpoint is None: time_now = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') self.save_path = save_dir.joinpath(self.data, self.model, time_now) self.logdir = self.save_path os.makedirs(self.save_path, exist_ok=True) elif self.checkpoint is not None: assert os.path.exists(self.checkpoint) self.save_path = os.path.dirname(self.checkpoint) self.logdir = self.save_path def __str__(self): 
"""Pretty-print configurations in alphabetical order""" config_str = 'Configurations\n' config_str += pprint.pformat(self.__dict__) return config_str def get_config(parse=True, **optional_kwargs): """ Get configurations as attributes of class 1. Parse configurations with argparse. 2. Create Config class initilized with parsed kwargs. 3. Return Config class. """ parser = argparse.ArgumentParser() # Mode parser.add_argument('--mode', type=str, default='train') # Train parser.add_argument('--batch_size', type=int, default=80) parser.add_argument('--eval_batch_size', type=int, default=80) parser.add_argument('--n_epoch', type=int, default=30) parser.add_argument('--learning_rate', type=float, default=1e-4) parser.add_argument('--optimizer', type=str, default='Adam') parser.add_argument('--clip', type=float, default=1.0) parser.add_argument('--checkpoint', type=str, default=None) # Generation parser.add_argument('--max_unroll', type=int, default=30) parser.add_argument('--sample', type=str2bool, default=False, help='if false, use beam search for decoding') parser.add_argument('--temperature', type=float, default=1.0) parser.add_argument('--beam_size', type=int, default=1) # Model parser.add_argument('--model', type=str, default='VHCR', help='one of {HRED, VHRED, VHCR}') # Currently does not support lstm parser.add_argument('--rnn', type=str, default='gru') parser.add_argument('--rnncell', type=str, default='gru') parser.add_argument('--num_layers', type=int, default=1) parser.add_argument('--embedding_size', type=int, default=500) parser.add_argument('--tie_embedding', type=str2bool, default=True) parser.add_argument('--encoder_hidden_size', type=int, default=1000) parser.add_argument('--bidirectional', type=str2bool, default=True) parser.add_argument('--decoder_hidden_size', type=int, default=1000) parser.add_argument('--dropout', type=float, default=0.2) parser.add_argument('--context_size', type=int, default=1000) parser.add_argument('--feedforward', type=str, default='FeedForward') parser.add_argument('--activation', type=str, default='Tanh') # VAE model parser.add_argument('--z_sent_size', type=int, default=100) parser.add_argument('--z_conv_size', type=int, default=100) parser.add_argument('--word_drop', type=float, default=0.0, help='only applied to variational models') parser.add_argument('--kl_threshold', type=float, default=0.0) parser.add_argument('--kl_annealing_iter', type=int, default=25000) parser.add_argument('--importance_sample', type=int, default=100) parser.add_argument('--sentence_drop', type=float, default=0.0) # Generation parser.add_argument('--n_context', type=int, default=1) parser.add_argument('--n_sample_step', type=int, default=1) # BOW parser.add_argument('--bow', type=str2bool, default=False) # Utility parser.add_argument('--print_every', type=int, default=100) parser.add_argument('--plot_every_epoch', type=int, default=1) parser.add_argument('--save_every_epoch', type=int, default=1) # Data parser.add_argument('--data', type=str, default='ubuntu') # Parse arguments if parse: kwargs = parser.parse_args() else: kwargs = parser.parse_known_args()[0] # Namespace => Dictionary kwargs = vars(kwargs) kwargs.update(optional_kwargs) return Config(**kwargs) ================================================ FILE: model/data_loader.py ================================================ import random from collections import defaultdict from torch.utils.data import Dataset, DataLoader from utils import PAD_ID, UNK_ID, SOS_ID, EOS_ID import numpy as np class DialogDataset(Dataset): 
    def __init__(self, sentences, conversation_length, sentence_length, vocab, data=None):

        # [total_data_size, max_conversation_length, max_sentence_length]
        # tokenized raw text of sentences
        self.sentences = sentences
        self.vocab = vocab

        # conversation length of each batch
        # [total_data_size]
        self.conversation_length = conversation_length

        # list of length of sentences
        # [total_data_size, max_conversation_length]
        self.sentence_length = sentence_length
        self.data = data
        self.len = len(sentences)

    def __getitem__(self, index):
        """Return Single data sentence"""
        # [max_conversation_length, max_sentence_length]
        sentence = self.sentences[index]
        conversation_length = self.conversation_length[index]
        sentence_length = self.sentence_length[index]

        # word => word_ids
        sentence = self.sent2id(sentence)

        return sentence, conversation_length, sentence_length

    def __len__(self):
        return self.len

    def sent2id(self, sentences):
        """word => word id"""
        # [max_conversation_length, max_sentence_length]
        return [self.vocab.sent2id(sentence) for sentence in sentences]


def get_loader(sentences, conversation_length, sentence_length, vocab, batch_size=100, data=None, shuffle=True):
    """Load DataLoader of given DialogDataset"""

    def collate_fn(data):
        """
        Collate a list of data into a batch.
        Args:
            data: list of tuple(sentence, conversation_length, sentence_length)
        Return:
            Batch of each feature
            - sentences: [batch_size, max_conversation_length, max_sentence_length]
            - conversation_length: [batch_size]
            - sentence_length: [batch_size, max_conversation_length]
        """
        # Sort by conversation length (descending order) to use 'pack_padded_sequence'
        data.sort(key=lambda x: x[1], reverse=True)

        # Separate
        sentences, conversation_length, sentence_length = zip(*data)

        # return sentences, conversation_length, sentence_length.tolist()
        return sentences, conversation_length, sentence_length

    dataset = DialogDataset(sentences, conversation_length, sentence_length, vocab, data=data)

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn)

    return data_loader


================================================ FILE: model/eval.py ================================================
from solver import Solver, VariationalSolver
from data_loader import get_loader
from configs import get_config
from utils import Vocab, Tokenizer
import os
import pickle
from models import VariationalModels


def load_pickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


if __name__ == '__main__':
    config = get_config(mode='test')

    print('Loading Vocabulary...')
    vocab = Vocab()
    vocab.load(config.word2id_path, config.id2word_path)
    print(f'Vocabulary size: {vocab.vocab_size}')

    config.vocab_size = vocab.vocab_size

    data_loader = get_loader(
        sentences=load_pickle(config.sentences_path),
        conversation_length=load_pickle(config.conversation_length_path),
        sentence_length=load_pickle(config.sentence_length_path),
        vocab=vocab,
        batch_size=config.batch_size)

    if config.model in VariationalModels:
        solver = VariationalSolver(config, None, data_loader, vocab=vocab, is_train=False)
        solver.build()
        solver.importance_sample()
    else:
        solver = Solver(config, None, data_loader, vocab=vocab, is_train=False)
        solver.build()
        solver.test()


================================================ FILE: model/eval_embed.py ================================================
from solver import Solver, VariationalSolver
from data_loader import get_loader
from configs import get_config from utils import Vocab, Tokenizer import os import pickle from models import VariationalModels import re def load_pickle(path): with open(path, 'rb') as f: return pickle.load(f) if __name__ == '__main__': config = get_config(mode='test') print('Loading Vocabulary...') vocab = Vocab() vocab.load(config.word2id_path, config.id2word_path) print(f'Vocabulary size: {vocab.vocab_size}') config.vocab_size = vocab.vocab_size data_loader = get_loader( sentences=load_pickle(config.sentences_path), conversation_length=load_pickle(config.conversation_length_path), sentence_length=load_pickle(config.sentence_length_path), vocab=vocab, batch_size=config.batch_size, shuffle=False) if config.model in VariationalModels: solver = VariationalSolver(config, None, data_loader, vocab=vocab, is_train=False) else: solver = Solver(config, None, data_loader, vocab=vocab, is_train=False) solver.build() solver.embedding_metric() ================================================ FILE: model/layers/__init__.py ================================================ from .encoder import * from .decoder import * from .rnncells import StackedLSTMCell, StackedGRUCell from .loss import * from .feedforward import * ================================================ FILE: model/layers/beam_search.py ================================================ import torch from utils import EOS_ID class Beam(object): def __init__(self, batch_size, hidden_size, vocab_size, beam_size, max_unroll, batch_position): """Beam class for beam search""" self.batch_size = batch_size self.hidden_size = hidden_size self.vocab_size = vocab_size self.beam_size = beam_size self.max_unroll = max_unroll # batch_position [batch_size] # [0, beam_size, beam_size * 2, .., beam_size * (batch_size-1)] # Points where batch starts in [batch_size x beam_size] tensors # Ex. position_idx[5]: when 5-th batch starts self.batch_position = batch_position self.log_probs = list() # [(batch*k, vocab_size)] * sequence_length self.scores = list() # [(batch*k)] * sequence_length self.back_pointers = list() # [(batch*k)] * sequence_length self.token_ids = list() # [(batch*k)] * sequence_length # self.hidden = list() # [(num_layers, batch*k, hidden_size)] * sequence_length self.metadata = { 'inputs': None, 'output': None, 'scores': None, 'length': None, 'sequence': None, } def update(self, score, back_pointer, token_id): # , h): """Append intermediate top-k candidates to beam at each step""" # self.log_probs.append(log_prob) self.scores.append(score) self.back_pointers.append(back_pointer) self.token_ids.append(token_id) # self.hidden.append(h) def backtrack(self): """Backtracks over batch to generate optimal k-sequences Returns: prediction ([batch, k, max_unroll]) A list of Tensors containing predicted sequence final_score [batch, k] A list containing the final scores for all top-k sequences length [batch, k] A list specifying the length of each sequence in the top-k candidates """ prediction = list() # import ipdb # ipdb.set_trace() # Initialize for length of top-k sequences length = [[self.max_unroll] * self.beam_size for _ in range(self.batch_size)] # Last step output of the beam are not sorted => sort here! 
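        # Backtracking walks from t = max_unroll - 1 down to 0, following the
        # back-pointers to recover each token of the surviving sequences.
        # Whenever an EOS is encountered on the way, that finished hypothesis
        # replaces the lowest-scoring live candidate of its batch, so
        # sequences that ended early are kept alongside full-length ones.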
# Size not changed [batch size, beam_size] top_k_score, top_k_idx = self.scores[-1].topk(self.beam_size, dim=1) # Initialize sequence scores top_k_score = top_k_score.clone() n_eos_in_batch = [0] * self.batch_size # Initialize Back-pointer from the last step # Add self.position_idx for indexing variable with batch x beam as the first dimension # [batch x beam] back_pointer = (top_k_idx + self.batch_position.unsqueeze(1)).view(-1) for t in reversed(range(self.max_unroll)): # Reorder variables with the Back-pointer # [batch x beam] token_id = self.token_ids[t].index_select(0, back_pointer) # Reorder the Back-pointer # [batch x beam] back_pointer = self.back_pointers[t].index_select(0, back_pointer) # Indices of ended sequences # [< batch x beam] eos_indices = self.token_ids[t].data.eq(EOS_ID).nonzero() # For each batch, every time we see an EOS in the backtracking process, # If not all sequences are ended # lowest scored survived sequence <- detected ended sequence # if all sequences are ended # lowest scored ended sequence <- detected ended sequence if eos_indices.dim() > 0: # Loop over all EOS at current step for i in range(eos_indices.size(0) - 1, -1, -1): # absolute index of detected ended sequence eos_idx = eos_indices[i, 0].item() # At which batch EOS is located batch_idx = eos_idx // self.beam_size batch_start_idx = batch_idx * self.beam_size # if n_eos_in_batch[batch_idx] > self.beam_size: # Index of sequence with lowest score _n_eos_in_batch = n_eos_in_batch[batch_idx] % self.beam_size beam_idx_to_be_replaced = self.beam_size - _n_eos_in_batch - 1 idx_to_be_replaced = batch_start_idx + beam_idx_to_be_replaced # Replace old information with new sequence information back_pointer[idx_to_be_replaced] = self.back_pointers[t][eos_idx].item() token_id[idx_to_be_replaced] = self.token_ids[t][eos_idx].item() top_k_score[batch_idx, beam_idx_to_be_replaced] = self.scores[t].view(-1)[eos_idx].item() length[batch_idx][beam_idx_to_be_replaced] = t + 1 n_eos_in_batch[batch_idx] += 1 # max_unroll * [batch x beam] prediction.append(token_id) # Sort and re-order again as the added ended sequences may change the order # [batch, beam] top_k_score, top_k_idx = top_k_score.topk(self.beam_size, dim=1) final_score = top_k_score.data for batch_idx in range(self.batch_size): length[batch_idx] = [length[batch_idx][beam_idx.item()] for beam_idx in top_k_idx[batch_idx]] # [batch x beam] top_k_idx = (top_k_idx + self.batch_position.unsqueeze(1)).view(-1) # Reverse the sequences and re-order at the same time # It is reversed because the backtracking happens in the reverse order # [batch, beam] prediction = [step.index_select(0, top_k_idx).view( self.batch_size, self.beam_size) for step in reversed(prediction)] # [batch, beam, max_unroll] prediction = torch.stack(prediction, 2) return prediction, final_score, length ================================================ FILE: model/layers/decoder.py ================================================ import random import torch from torch import nn from torch.nn import functional as F from .rnncells import StackedLSTMCell, StackedGRUCell from .beam_search import Beam from .feedforward import FeedForward from utils import to_var, SOS_ID, UNK_ID, EOS_ID import math class BaseRNNDecoder(nn.Module): def __init__(self): """Base Decoder Class""" super(BaseRNNDecoder, self).__init__() @property def use_lstm(self): return isinstance(self.rnncell, StackedLSTMCell) def init_token(self, batch_size, SOS_ID=SOS_ID): """Get Variable of Index (batch_size)""" x = 
to_var(torch.LongTensor([SOS_ID] * batch_size)) return x def init_h(self, batch_size=None, zero=True, hidden=None): """Return RNN initial state""" if hidden is not None: return hidden if self.use_lstm: # (h, c) return (to_var(torch.zeros(self.num_layers, batch_size, self.hidden_size)), to_var(torch.zeros(self.num_layers, batch_size, self.hidden_size))) else: # h return to_var(torch.zeros(self.num_layers, batch_size, self.hidden_size)) def batch_size(self, inputs=None, h=None): """ inputs: [batch_size, seq_len] h: [num_layers, batch_size, hidden_size] (RNN/GRU) h_c: [2, num_layers, batch_size, hidden_size] (LSTMCell) """ if inputs is not None: batch_size = inputs.size(0) return batch_size else: if self.use_lstm: batch_size = h[0].size(1) else: batch_size = h.size(1) return batch_size def decode(self, out): """ Args: out: unnormalized word distribution [batch_size, vocab_size] Return: x: word_index [batch_size] """ # Sample next word from multinomial word distribution if self.sample: # x: [batch_size] - word index (next input) x = torch.multinomial(self.softmax(out / self.temperature), 1).view(-1) # Greedy sampling else: # x: [batch_size] - word index (next input) _, x = out.max(dim=1) return x def forward(self): """Base forward function to inherit""" raise NotImplementedError def forward_step(self): """Run RNN single step""" raise NotImplementedError def embed(self, x): """word index: [batch_size] => word vectors: [batch_size, hidden_size]""" if self.training and self.word_drop > 0.0: if random.random() < self.word_drop: embed = self.embedding(to_var(x.data.new([UNK_ID] * x.size(0)))) else: embed = self.embedding(x) else: embed = self.embedding(x) return embed def beam_decode(self, init_h=None, encoder_outputs=None, input_valid_length=None, decode=False): """ Args: encoder_outputs (Variable, FloatTensor): [batch_size, source_length, hidden_size] input_valid_length (Variable, LongTensor): [batch_size] (optional) init_h (variable, FloatTensor): [batch_size, hidden_size] (optional) Return: out : [batch_size, seq_len] """ batch_size = self.batch_size(h=init_h) # [batch_size x beam_size] x = self.init_token(batch_size * self.beam_size, SOS_ID) # [num_layers, batch_size x beam_size, hidden_size] h = self.init_h(batch_size, hidden=init_h).repeat(1, self.beam_size, 1) # batch_position [batch_size] # [0, beam_size, beam_size * 2, .., beam_size * (batch_size-1)] # Points where batch starts in [batch_size x beam_size] tensors # Ex. position_idx[5]: when 5-th batch starts batch_position = to_var(torch.arange(0, batch_size).long() * self.beam_size) # Initialize scores of sequence # [batch_size x beam_size] # Ex. 
batch_size: 5, beam_size: 3 # [0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf] score = torch.ones(batch_size * self.beam_size) * -float('inf') score.index_fill_(0, torch.arange(0, batch_size).long() * self.beam_size, 0.0) score = to_var(score) # Initialize Beam that stores decisions for backtracking beam = Beam( batch_size, self.hidden_size, self.vocab_size, self.beam_size, self.max_unroll, batch_position) for i in range(self.max_unroll): # x: [batch_size x beam_size]; (token index) # => # out: [batch_size x beam_size, vocab_size] # h: [num_layers, batch_size x beam_size, hidden_size] out, h = self.forward_step(x, h, encoder_outputs=encoder_outputs, input_valid_length=input_valid_length) # log_prob: [batch_size x beam_size, vocab_size] log_prob = F.log_softmax(out, dim=1) # [batch_size x beam_size] # => [batch_size x beam_size, vocab_size] score = score.view(-1, 1) + log_prob # Select `beam size` transitions out of `vocab size` combinations # [batch_size x beam_size, vocab_size] # => [batch_size, beam_size x vocab_size] # Cutoff and retain candidates with top-k scores # score: [batch_size, beam_size] # top_k_idx: [batch_size, beam_size] # each element of top_k_idx [0 ~ beam x vocab) score, top_k_idx = score.view(batch_size, -1).topk(self.beam_size, dim=1) # Get token ids with remainder after dividing by top_k_idx # Each element is among [0, vocab_size) # Ex. Index of token 3 in beam 4 # (4 * vocab size) + 3 => 3 # x: [batch_size x beam_size] x = (top_k_idx % self.vocab_size).view(-1) # top-k-pointer [batch_size x beam_size] # Points top-k beam that scored best at current step # Later used as back-pointer at backtracking # Each element is beam index: 0 ~ beam_size # + position index: 0 ~ beam_size x (batch_size-1) beam_idx = top_k_idx / self.vocab_size # [batch_size, beam_size] top_k_pointer = (beam_idx + batch_position.unsqueeze(1)).view(-1) # Select next h (size doesn't change) # [num_layers, batch_size * beam_size, hidden_size] h = h.index_select(1, top_k_pointer) # Update sequence scores at beam beam.update(score.clone(), top_k_pointer, x) # , h) # Erase scores for EOS so that they are not expanded # [batch_size, beam_size] eos_idx = x.data.eq(EOS_ID).view(batch_size, self.beam_size) if eos_idx.nonzero().dim() > 0: score.data.masked_fill_(eos_idx, -float('inf')) # prediction ([batch, k, max_unroll]) # A list of Tensors containing predicted sequence # final_score [batch, k] # A list containing the final scores for all top-k sequences # length [batch, k] # A list specifying the length of each sequence in the top-k candidates # prediction, final_score, length = beam.backtrack() prediction, final_score, length = beam.backtrack() return prediction, final_score, length class DecoderRNN(BaseRNNDecoder): def __init__(self, vocab_size, embedding_size, hidden_size, rnncell=StackedGRUCell, num_layers=1, dropout=0.0, word_drop=0.0, max_unroll=30, sample=True, temperature=1.0, beam_size=1): super(DecoderRNN, self).__init__() self.vocab_size = vocab_size self.embedding_size = embedding_size self.hidden_size = hidden_size self.num_layers = num_layers self.dropout = dropout self.temperature = temperature self.word_drop = word_drop self.max_unroll = max_unroll self.sample = sample self.beam_size = beam_size self.embedding = nn.Embedding(vocab_size, embedding_size) self.rnncell = rnncell(num_layers, embedding_size, hidden_size, dropout) self.out = nn.Linear(hidden_size, vocab_size) self.softmax = nn.Softmax(dim=1) def forward_step(self, x, h, encoder_outputs=None, 
input_valid_length=None): """ Single RNN Step 1. Input Embedding (vocab_size => hidden_size) 2. RNN Step (hidden_size => hidden_size) 3. Output Projection (hidden_size => vocab size) Args: x: [batch_size] h: [num_layers, batch_size, hidden_size] (h and c from all layers) Return: out: [batch_size,vocab_size] (Unnormalized word distribution) h: [num_layers, batch_size, hidden_size] (h and c from all layers) """ # x: [batch_size] => [batch_size, hidden_size] x = self.embed(x) # last_h: [batch_size, hidden_size] (h from Top RNN layer) # h: [num_layers, batch_size, hidden_size] (h and c from all layers) last_h, h = self.rnncell(x, h) if self.use_lstm: # last_h_c: [2, batch_size, hidden_size] (h from Top RNN layer) # h_c: [2, num_layers, batch_size, hidden_size] (h and c from all layers) last_h = last_h[0] # Unormalized word distribution # out: [batch_size, vocab_size] out = self.out(last_h) return out, h def forward(self, inputs, init_h=None, encoder_outputs=None, input_valid_length=None, decode=False): """ Train (decode=False) Args: inputs (Variable, LongTensor): [batch_size, seq_len] init_h: (Variable, FloatTensor): [num_layers, batch_size, hidden_size] Return: out : [batch_size, seq_len, vocab_size] Test (decode=True) Args: inputs: None init_h: (Variable, FloatTensor): [num_layers, batch_size, hidden_size] Return: out : [batch_size, seq_len] """ batch_size = self.batch_size(inputs, init_h) # x: [batch_size] x = self.init_token(batch_size, SOS_ID) # h: [num_layers, batch_size, hidden_size] h = self.init_h(batch_size, hidden=init_h) if not decode: out_list = [] seq_len = inputs.size(1) for i in range(seq_len): # x: [batch_size] # => # out: [batch_size, vocab_size] # h: [num_layers, batch_size, hidden_size] (h and c from all layers) out, h = self.forward_step(x, h) out_list.append(out) x = inputs[:, i] # [batch_size, max_target_len, vocab_size] return torch.stack(out_list, dim=1) else: x_list = [] for i in range(self.max_unroll): # x: [batch_size] # => # out: [batch_size, vocab_size] # h: [num_layers, batch_size, hidden_size] (h and c from all layers) out, h = self.forward_step(x, h) # out: [batch_size, vocab_size] # => x: [batch_size] x = self.decode(out) x_list.append(x) # [batch_size, max_target_len] return torch.stack(x_list, dim=1) ================================================ FILE: model/layers/encoder.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence from utils import to_var, reverse_order_valid, PAD_ID from .rnncells import StackedGRUCell, StackedLSTMCell import copy class BaseRNNEncoder(nn.Module): def __init__(self): """Base RNN Encoder Class""" super(BaseRNNEncoder, self).__init__() @property def use_lstm(self): if hasattr(self, 'rnn'): return isinstance(self.rnn, nn.LSTM) else: raise AttributeError('no rnn selected') def init_h(self, batch_size=None, hidden=None): """Return RNN initial state""" if hidden is not None: return hidden if self.use_lstm: return (to_var(torch.zeros(self.num_layers*self.num_directions, batch_size, self.hidden_size)), to_var(torch.zeros(self.num_layers*self.num_directions, batch_size, self.hidden_size))) else: return to_var(torch.zeros(self.num_layers*self.num_directions, batch_size, self.hidden_size)) def batch_size(self, inputs=None, h=None): """ inputs: [batch_size, seq_len] h: [num_layers, batch_size, hidden_size] (RNN/GRU) h_c: [2, num_layers, batch_size, hidden_size] (LSTM) """ if inputs is not 
None: batch_size = inputs.size(0) return batch_size else: if self.use_lstm: batch_size = h[0].size(1) else: batch_size = h.size(1) return batch_size def forward(self): raise NotImplementedError class EncoderRNN(BaseRNNEncoder): def __init__(self, vocab_size, embedding_size, hidden_size, rnn=nn.GRU, num_layers=1, bidirectional=False, dropout=0.0, bias=True, batch_first=True): """Sentence-level Encoder""" super(EncoderRNN, self).__init__() self.vocab_size = vocab_size self.embedding_size = embedding_size self.hidden_size = hidden_size self.num_layers = num_layers self.dropout = dropout self.batch_first = batch_first self.bidirectional = bidirectional if bidirectional: self.num_directions = 2 else: self.num_directions = 1 # word embedding self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=PAD_ID) self.rnn = rnn(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional) def forward(self, inputs, input_length, hidden=None): """ Args: inputs (Variable, LongTensor): [num_setences, max_seq_len] input_length (Variable, LongTensor): [num_sentences] Return: outputs (Variable): [max_source_length, batch_size, hidden_size] - list of all hidden states hidden ((tuple of) Variable): [num_layers*num_directions, batch_size, hidden_size] - last hidden state - (h, c) or h """ batch_size, seq_len = inputs.size() # Sort in decreasing order of length for pack_padded_sequence() input_length_sorted, indices = input_length.sort(descending=True) input_length_sorted = input_length_sorted.data.tolist() # [num_sentences, max_source_length] inputs_sorted = inputs.index_select(0, indices) # [num_sentences, max_source_length, embedding_dim] embedded = self.embedding(inputs_sorted) # batch_first=True rnn_input = pack_padded_sequence(embedded, input_length_sorted, batch_first=self.batch_first) hidden = self.init_h(batch_size, hidden=hidden) # outputs: [batch, seq_len, hidden_size * num_directions] # hidden: [num_layers * num_directions, batch, hidden_size] self.rnn.flatten_parameters() outputs, hidden = self.rnn(rnn_input, hidden) outputs, outputs_lengths = pad_packed_sequence(outputs, batch_first=self.batch_first) # Reorder outputs and hidden _, inverse_indices = indices.sort() outputs = outputs.index_select(0, inverse_indices) if self.use_lstm: hidden = (hidden[0].index_select(1, inverse_indices), hidden[1].index_select(1, inverse_indices)) else: hidden = hidden.index_select(1, inverse_indices) return outputs, hidden class ContextRNN(BaseRNNEncoder): def __init__(self, input_size, context_size, rnn=nn.GRU, num_layers=1, dropout=0.0, bidirectional=False, bias=True, batch_first=True): """Context-level Encoder""" super(ContextRNN, self).__init__() self.input_size = input_size self.context_size = context_size self.hidden_size = self.context_size self.num_layers = num_layers self.dropout = dropout self.bidirectional = bidirectional self.batch_first = batch_first if bidirectional: self.num_directions = 2 else: self.num_directions = 1 self.rnn = rnn(input_size=input_size, hidden_size=context_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional) def forward(self, encoder_hidden, conversation_length, hidden=None): """ Args: encoder_hidden (Variable, FloatTensor): [batch_size, max_len, num_layers * direction * hidden_size] conversation_length (Variable, LongTensor): [batch_size] Return: outputs (Variable): [batch_size, max_seq_len, hidden_size] - list of all 
hidden states
            hidden ((tuple of) Variable): [num_layers*num_directions, batch_size, hidden_size]
                - last hidden state
                - (h, c) or h
        """
        batch_size, seq_len, _ = encoder_hidden.size()

        # Sort for PackedSequence
        conv_length_sorted, indices = conversation_length.sort(descending=True)
        conv_length_sorted = conv_length_sorted.data.tolist()
        encoder_hidden_sorted = encoder_hidden.index_select(0, indices)

        rnn_input = pack_padded_sequence(encoder_hidden_sorted, conv_length_sorted, batch_first=True)

        hidden = self.init_h(batch_size, hidden=hidden)

        self.rnn.flatten_parameters()
        outputs, hidden = self.rnn(rnn_input, hidden)

        # outputs: [batch_size, max_conversation_length, context_size]
        outputs, outputs_length = pad_packed_sequence(outputs, batch_first=True)

        # reorder outputs and hidden
        _, inverse_indices = indices.sort()
        outputs = outputs.index_select(0, inverse_indices)

        if self.use_lstm:
            hidden = (hidden[0].index_select(1, inverse_indices),
                      hidden[1].index_select(1, inverse_indices))
        else:
            hidden = hidden.index_select(1, inverse_indices)

        # outputs: [batch, seq_len, hidden_size * num_directions]
        # hidden: [num_layers * num_directions, batch, hidden_size]
        return outputs, hidden

    def step(self, encoder_hidden, hidden):
        batch_size = encoder_hidden.size(0)
        # encoder_hidden: [batch_size, 1, hidden_size] after unsqueeze (batch_first)
        encoder_hidden = torch.unsqueeze(encoder_hidden, 1)

        if hidden is None:
            hidden = self.init_h(batch_size, hidden=None)

        outputs, hidden = self.rnn(encoder_hidden, hidden)
        return outputs, hidden


================================================ FILE: model/layers/feedforward.py ================================================
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    def __init__(self, input_size, output_size, num_layers=1,
                 hidden_size=None, activation="Tanh", bias=True):
        super(FeedForward, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.activation = getattr(nn, activation)()

        n_inputs = [input_size] + [hidden_size] * (num_layers - 1)
        n_outputs = [hidden_size] * (num_layers - 1) + [output_size]
        self.linears = nn.ModuleList([nn.Linear(n_in, n_out, bias=bias)
                                      for n_in, n_out in zip(n_inputs, n_outputs)])

    def forward(self, input):
        x = input
        for linear in self.linears:
            x = linear(x)
            x = self.activation(x)
        return x


================================================ FILE: model/layers/loss.py ================================================
import torch
from torch.nn import functional as F
import torch.nn as nn
from utils import to_var, sequence_mask

# https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1


def masked_cross_entropy(logits, target, length, per_example=False):
    """
    Args:
        logits (Variable, FloatTensor): [batch, max_len, num_classes]
            - unnormalized probability for each class
        target (Variable, LongTensor): [batch, max_len]
            - index of true class for each corresponding step
        length (Variable, LongTensor): [batch]
            - length of each data in a batch
    Returns:
        per_example=False: (loss, n_words), where loss is the sum of the
            length-masked token negative log-likelihoods and
            n_words = length.sum(), so callers can average per word
        per_example=True: [batch] - masked loss summed per example
    """
    batch_size, max_len, num_classes = logits.size()

    # [batch_size * max_len, num_classes]
    logits_flat = logits.view(-1, num_classes)

    # [batch_size * max_len, num_classes]
    log_probs_flat = F.log_softmax(logits_flat, dim=1)

    # [batch_size * max_len, 1]
    target_flat = target.view(-1, 1)

    # Negative Log-likelihood: -sum { 1 * log P(target) + 0 * log P(non-target) } = -sum( log P(target) )
    # [batch_size * max_len, 1]
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)

    # [batch_size, max_len]
    losses = losses_flat.view(batch_size, max_len)

    # [batch_size, max_len]
    mask = sequence_mask(sequence_length=length, max_len=max_len)

    # Apply masking on loss
    losses = losses * mask.float()

    # word-wise cross entropy
    # loss = losses.sum() / length.float().sum()

    if per_example:
        # loss: [batch_size]
        return losses.sum(1)
    else:
        loss = losses.sum()
        return loss, length.float().sum()
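
# Example (illustrative): for logits of shape [2, 5, vocab_size],
# target of shape [2, 5] and length = LongTensor([5, 3]),
# masked_cross_entropy(logits, target, length) sums the token NLL over the
# first 5 and first 3 steps respectively and returns (loss_sum, 8.0),
# so per-word cross entropy is loss_sum / 8.0.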
================================================ FILE: model/layers/rnncells.py ================================================
# Modified from OpenNMT.py, Z-forcing

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn._functions.thnn.rnnFusedPointwise import LSTMFused, GRUFused


class StackedLSTMCell(nn.Module):
    def __init__(self, num_layers, input_size, rnn_size, dropout):
        super(StackedLSTMCell, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        for i in range(num_layers):
            self.layers.append(nn.LSTMCell(input_size, rnn_size))
            input_size = rnn_size

    def forward(self, x, h_c):
        """
        Args:
            x: [batch_size, input_size]
            h_c: [2, num_layers, batch_size, hidden_size]
        Return:
            last_h_c: [2, batch_size, hidden_size] (h and c from last layer)
            h_c_list: [2, num_layers, batch_size, hidden_size] (h and c from all layers)
        """
        h_0, c_0 = h_c
        h_list, c_list = [], []
        for i, layer in enumerate(self.layers):
            # h of i-th layer
            h_i, c_i = layer(x, (h_0[i], c_0[i]))

            # x for next layer
            x = h_i
            if i + 1 != self.num_layers:
                x = self.dropout(x)
            h_list += [h_i]
            c_list += [c_i]

        last_h_c = (h_list[-1], c_list[-1])
        h_list = torch.stack(h_list)
        c_list = torch.stack(c_list)
        h_c_list = (h_list, c_list)

        return last_h_c, h_c_list


class StackedGRUCell(nn.Module):
    def __init__(self, num_layers, input_size, rnn_size, dropout):
        super(StackedGRUCell, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        for i in range(num_layers):
            self.layers.append(nn.GRUCell(input_size, rnn_size))
            input_size = rnn_size

    def forward(self, x, h):
        """
        Args:
            x: [batch_size, input_size]
            h: [num_layers, batch_size, hidden_size]
        Return:
            last_h: [batch_size, hidden_size] (h from last layer)
            h_list: [num_layers, batch_size, hidden_size] (h from all layers)
        """
        # h of all layers
        h_list = []
        for i, layer in enumerate(self.layers):
            # h of i-th layer
            h_i = layer(x, h[i])

            # x for next layer
            x = h_i
            if i + 1 != self.num_layers:
                x = self.dropout(x)
            h_list.append(h_i)

        last_h = h_list[-1]
        h_list = torch.stack(h_list)

        return last_h, h_list


================================================ FILE: model/models.py ================================================
import torch
import torch.nn as nn
from utils import to_var, pad, normal_kl_div, normal_logpdf, bag_of_words_loss, to_bow, EOS_ID
import layers
import numpy as np
import random

VariationalModels = ['VHRED', 'VHCR']


class HRED(nn.Module):
    def __init__(self, config):
        super(HRED, self).__init__()

        self.config = config
        self.encoder = layers.EncoderRNN(config.vocab_size,
                                         config.embedding_size,
                                         config.encoder_hidden_size,
                                         config.rnn,
                                         config.num_layers,
                                         config.bidirectional,
                                         config.dropout)

        context_input_size = (config.num_layers
                              * config.encoder_hidden_size
                              * self.encoder.num_directions)
        self.context_encoder = layers.ContextRNN(context_input_size,
                                                 config.context_size,
                                                 config.rnn,
                                                 config.num_layers,
                                                 config.dropout)

        self.decoder = layers.DecoderRNN(config.vocab_size,
                                         config.embedding_size,
                                         config.decoder_hidden_size,
                                         config.rnncell,
                                         config.num_layers,
                                         config.dropout,
config.word_drop, config.max_unroll, config.sample, config.temperature, config.beam_size) self.context2decoder = layers.FeedForward(config.context_size, config.num_layers * config.decoder_hidden_size, num_layers=1, activation=config.activation) if config.tie_embedding: self.decoder.embedding = self.encoder.embedding def forward(self, input_sentences, input_sentence_length, input_conversation_length, target_sentences, decode=False): """ Args: input_sentences: (Variable, LongTensor) [num_sentences, seq_len] target_sentences: (Variable, LongTensor) [num_sentences, seq_len] Return: decoder_outputs: (Variable, FloatTensor) - train: [batch_size, seq_len, vocab_size] - eval: [batch_size, seq_len] """ num_sentences = input_sentences.size(0) max_len = input_conversation_length.data.max().item() # encoder_outputs: [num_sentences, max_source_length, hidden_size * direction] # encoder_hidden: [num_layers * direction, num_sentences, hidden_size] encoder_outputs, encoder_hidden = self.encoder(input_sentences, input_sentence_length) # encoder_hidden: [num_sentences, num_layers * direction * hidden_size] encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(num_sentences, -1) # pad and pack encoder_hidden start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()), input_conversation_length[:-1])), 0) # encoder_hidden: [batch_size, max_len, num_layers * direction * hidden_size] encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l), max_len) for s, l in zip(start.data.tolist(), input_conversation_length.data.tolist())], 0) # context_outputs: [batch_size, max_len, context_size] context_outputs, context_last_hidden = self.context_encoder(encoder_hidden, input_conversation_length) # flatten outputs # context_outputs: [num_sentences, context_size] context_outputs = torch.cat([context_outputs[i, :l, :] for i, l in enumerate(input_conversation_length.data)]) # project context_outputs to decoder init state decoder_init = self.context2decoder(context_outputs) # [num_layers, batch_size, hidden_size] decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size) # train: [batch_size, seq_len, vocab_size] # eval: [batch_size, seq_len] if not decode: decoder_outputs = self.decoder(target_sentences, init_h=decoder_init, decode=decode) return decoder_outputs else: # decoder_outputs = self.decoder(target_sentences, # init_h=decoder_init, # decode=decode) # return decoder_outputs.unsqueeze(1) # prediction: [batch_size, beam_size, max_unroll] prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init) # Get top prediction only # [batch_size, max_unroll] # prediction = prediction[:, 0] # [batch_size, beam_size, max_unroll] return prediction def generate(self, context, sentence_length, n_context): # context: [batch_size, n_context, seq_len] batch_size = context.size(0) # n_context = context.size(1) samples = [] # Run for context context_hidden=None for i in range(n_context): # encoder_outputs: [batch_size, seq_len, hidden_size * direction] # encoder_hidden: [num_layers * direction, batch_size, hidden_size] encoder_outputs, encoder_hidden = self.encoder(context[:, i, :], sentence_length[:, i]) encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1) # context_outputs: [batch_size, 1, context_hidden_size * direction] # context_hidden: [num_layers * direction, batch_size, context_hidden_size] context_outputs, context_hidden = self.context_encoder.step(encoder_hidden, context_hidden) # Run for generation for j 
in range(self.config.n_sample_step): # context_outputs: [batch_size, context_hidden_size * direction] context_outputs = context_outputs.squeeze(1) decoder_init = self.context2decoder(context_outputs) decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size) prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init) # prediction: [batch_size, seq_len] prediction = prediction[:, 0, :] # length: [batch_size] length = [l[0] for l in length] length = to_var(torch.LongTensor(length)) samples.append(prediction) encoder_outputs, encoder_hidden = self.encoder(prediction, length) encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1) context_outputs, context_hidden = self.context_encoder.step(encoder_hidden, context_hidden) samples = torch.stack(samples, 1) return samples class VHRED(nn.Module): def __init__(self, config): super(VHRED, self).__init__() self.config = config self.encoder = layers.EncoderRNN(config.vocab_size, config.embedding_size, config.encoder_hidden_size, config.rnn, config.num_layers, config.bidirectional, config.dropout) context_input_size = (config.num_layers * config.encoder_hidden_size * self.encoder.num_directions) self.context_encoder = layers.ContextRNN(context_input_size, config.context_size, config.rnn, config.num_layers, config.dropout) self.decoder = layers.DecoderRNN(config.vocab_size, config.embedding_size, config.decoder_hidden_size, config.rnncell, config.num_layers, config.dropout, config.word_drop, config.max_unroll, config.sample, config.temperature, config.beam_size) self.context2decoder = layers.FeedForward(config.context_size + config.z_sent_size, config.num_layers * config.decoder_hidden_size, num_layers=1, activation=config.activation) self.softplus = nn.Softplus() self.prior_h = layers.FeedForward(config.context_size, config.context_size, num_layers=2, hidden_size=config.context_size, activation=config.activation) self.prior_mu = nn.Linear(config.context_size, config.z_sent_size) self.prior_var = nn.Linear(config.context_size, config.z_sent_size) self.posterior_h = layers.FeedForward(config.encoder_hidden_size * self.encoder.num_directions * config.num_layers + config.context_size, config.context_size, num_layers=2, hidden_size=config.context_size, activation=config.activation) self.posterior_mu = nn.Linear(config.context_size, config.z_sent_size) self.posterior_var = nn.Linear(config.context_size, config.z_sent_size) if config.tie_embedding: self.decoder.embedding = self.encoder.embedding if config.bow: self.bow_h = layers.FeedForward(config.z_sent_size, config.decoder_hidden_size, num_layers=1, hidden_size=config.decoder_hidden_size, activation=config.activation) self.bow_predict = nn.Linear(config.decoder_hidden_size, config.vocab_size) def prior(self, context_outputs): # Context dependent prior h_prior = self.prior_h(context_outputs) mu_prior = self.prior_mu(h_prior) var_prior = self.softplus(self.prior_var(h_prior)) return mu_prior, var_prior def posterior(self, context_outputs, encoder_hidden): h_posterior = self.posterior_h(torch.cat([context_outputs, encoder_hidden], 1)) mu_posterior = self.posterior_mu(h_posterior) var_posterior = self.softplus(self.posterior_var(h_posterior)) return mu_posterior, var_posterior def compute_bow_loss(self, target_conversations): target_bow = np.stack([to_bow(sent, self.config.vocab_size) for conv in target_conversations for sent in conv], axis=0) target_bow = to_var(torch.FloatTensor(target_bow)) bow_logits = 
self.bow_predict(self.bow_h(self.z_sent)) bow_loss = bag_of_words_loss(bow_logits, target_bow) return bow_loss def forward(self, sentences, sentence_length, input_conversation_length, target_sentences, decode=False): """ Args: sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len] target_sentences: (Variable, LongTensor) [num_sentences, seq_len] Return: decoder_outputs: (Variable, FloatTensor) - train: [batch_size, seq_len, vocab_size] - eval: [batch_size, seq_len] """ batch_size = input_conversation_length.size(0) num_sentences = sentences.size(0) - batch_size max_len = input_conversation_length.data.max().item() # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size] # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size] encoder_outputs, encoder_hidden = self.encoder(sentences, sentence_length) # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size] encoder_hidden = encoder_hidden.transpose( 1, 0).contiguous().view(num_sentences + batch_size, -1) # pad and pack encoder_hidden start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()), input_conversation_length[:-1] + 1)), 0) # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size] encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1) for s, l in zip(start.data.tolist(), input_conversation_length.data.tolist())], 0) # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size] encoder_hidden_inference = encoder_hidden[:, 1:, :] encoder_hidden_inference_flat = torch.cat( [encoder_hidden_inference[i, :l, :] for i, l in enumerate(input_conversation_length.data)]) # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size] encoder_hidden_input = encoder_hidden[:, :-1, :] # context_outputs: [batch_size, max_len, context_size] context_outputs, context_last_hidden = self.context_encoder(encoder_hidden_input, input_conversation_length) # flatten outputs # context_outputs: [num_sentences, context_size] context_outputs = torch.cat([context_outputs[i, :l, :] for i, l in enumerate(input_conversation_length.data)]) mu_prior, var_prior = self.prior(context_outputs) eps = to_var(torch.randn((num_sentences, self.config.z_sent_size))) if not decode: mu_posterior, var_posterior = self.posterior( context_outputs, encoder_hidden_inference_flat) z_sent = mu_posterior + torch.sqrt(var_posterior) * eps log_q_zx = normal_logpdf(z_sent, mu_posterior, var_posterior).sum() log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum() # kl_div: [num_sentneces] kl_div = normal_kl_div(mu_posterior, var_posterior, mu_prior, var_prior) kl_div = torch.sum(kl_div) else: z_sent = mu_prior + torch.sqrt(var_prior) * eps kl_div = None log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum() log_q_zx = None self.z_sent = z_sent latent_context = torch.cat([context_outputs, z_sent], 1) decoder_init = self.context2decoder(latent_context) decoder_init = decoder_init.view(-1, self.decoder.num_layers, self.decoder.hidden_size) decoder_init = decoder_init.transpose(1, 0).contiguous() # train: [batch_size, seq_len, vocab_size] # eval: [batch_size, seq_len] if not decode: decoder_outputs = self.decoder(target_sentences, init_h=decoder_init, decode=decode) return decoder_outputs, kl_div, log_p_z, log_q_zx else: # prediction: [batch_size, beam_size, max_unroll] prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init) return prediction, 
kl_div, log_p_z, log_q_zx def generate(self, context, sentence_length, n_context): # context: [batch_size, n_context, seq_len] batch_size = context.size(0) # n_context = context.size(1) samples = [] # Run for context context_hidden=None for i in range(n_context): # encoder_outputs: [batch_size, seq_len, hidden_size * direction] # encoder_hidden: [num_layers * direction, batch_size, hidden_size] encoder_outputs, encoder_hidden = self.encoder(context[:, i, :], sentence_length[:, i]) encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1) # context_outputs: [batch_size, 1, context_hidden_size * direction] # context_hidden: [num_layers * direction, batch_size, context_hidden_size] context_outputs, context_hidden = self.context_encoder.step(encoder_hidden, context_hidden) # Run for generation for j in range(self.config.n_sample_step): # context_outputs: [batch_size, context_hidden_size * direction] context_outputs = context_outputs.squeeze(1) mu_prior, var_prior = self.prior(context_outputs) eps = to_var(torch.randn((batch_size, self.config.z_sent_size))) z_sent = mu_prior + torch.sqrt(var_prior) * eps latent_context = torch.cat([context_outputs, z_sent], 1) decoder_init = self.context2decoder(latent_context) decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size) if self.config.sample: prediction = self.decoder(None, decoder_init) p = prediction.data.cpu().numpy() length = torch.from_numpy(np.where(p == EOS_ID)[1]) else: prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init) # prediction: [batch_size, seq_len] prediction = prediction[:, 0, :] # length: [batch_size] length = [l[0] for l in length] length = to_var(torch.LongTensor(length)) samples.append(prediction) encoder_outputs, encoder_hidden = self.encoder(prediction, length) encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1) context_outputs, context_hidden = self.context_encoder.step(encoder_hidden, context_hidden) samples = torch.stack(samples, 1) return samples class VHCR(nn.Module): def __init__(self, config): super(VHCR, self).__init__() self.config = config self.encoder = layers.EncoderRNN(config.vocab_size, config.embedding_size, config.encoder_hidden_size, config.rnn, config.num_layers, config.bidirectional, config.dropout) context_input_size = (config.num_layers * config.encoder_hidden_size * self.encoder.num_directions + config.z_conv_size) self.context_encoder = layers.ContextRNN(context_input_size, config.context_size, config.rnn, config.num_layers, config.dropout) self.unk_sent = nn.Parameter(torch.randn(context_input_size - config.z_conv_size)) self.z_conv2context = layers.FeedForward(config.z_conv_size, config.num_layers * config.context_size, num_layers=1, activation=config.activation) context_input_size = (config.num_layers * config.encoder_hidden_size * self.encoder.num_directions) self.context_inference = layers.ContextRNN(context_input_size, config.context_size, config.rnn, config.num_layers, config.dropout, bidirectional=True) self.decoder = layers.DecoderRNN(config.vocab_size, config.embedding_size, config.decoder_hidden_size, config.rnncell, config.num_layers, config.dropout, config.word_drop, config.max_unroll, config.sample, config.temperature, config.beam_size) self.context2decoder = layers.FeedForward(config.context_size + config.z_sent_size + config.z_conv_size, config.num_layers * config.decoder_hidden_size, num_layers=1, activation=config.activation) self.softplus = nn.Softplus() 
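# Hierarchical latent structure (descriptive comment): z_conv is a
# conversation-level latent with a standard Gaussian prior N(0, I)
# (conv_prior below), while z_sent is an utterance-level latent whose prior
# is conditioned on the context RNN state and on z_conv (sent_prior below).
# Both are sampled with the reparameterization trick,
#     z = mu + torch.sqrt(var) * eps    # eps ~ N(0, I)
# and the Softplus above keeps the predicted variances positive. The heads
# defined next produce (mu, var) for the z_conv posterior and for the
# z_sent prior/posterior.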
self.conv_posterior_h = layers.FeedForward(config.num_layers * self.context_inference.num_directions * config.context_size, config.context_size, num_layers=2, hidden_size=config.context_size, activation=config.activation) self.conv_posterior_mu = nn.Linear(config.context_size, config.z_conv_size) self.conv_posterior_var = nn.Linear(config.context_size, config.z_conv_size) self.sent_prior_h = layers.FeedForward(config.context_size + config.z_conv_size, config.context_size, num_layers=1, hidden_size=config.z_sent_size, activation=config.activation) self.sent_prior_mu = nn.Linear(config.context_size, config.z_sent_size) self.sent_prior_var = nn.Linear(config.context_size, config.z_sent_size) self.sent_posterior_h = layers.FeedForward(config.z_conv_size + config.encoder_hidden_size * self.encoder.num_directions * config.num_layers + config.context_size, config.context_size, num_layers=2, hidden_size=config.context_size, activation=config.activation) self.sent_posterior_mu = nn.Linear(config.context_size, config.z_sent_size) self.sent_posterior_var = nn.Linear(config.context_size, config.z_sent_size) if config.tie_embedding: self.decoder.embedding = self.encoder.embedding def conv_prior(self): # Standard gaussian prior return to_var(torch.FloatTensor([0.0])), to_var(torch.FloatTensor([1.0])) def conv_posterior(self, context_inference_hidden): h_posterior = self.conv_posterior_h(context_inference_hidden) mu_posterior = self.conv_posterior_mu(h_posterior) var_posterior = self.softplus(self.conv_posterior_var(h_posterior)) return mu_posterior, var_posterior def sent_prior(self, context_outputs, z_conv): # Context dependent prior h_prior = self.sent_prior_h(torch.cat([context_outputs, z_conv], dim=1)) mu_prior = self.sent_prior_mu(h_prior) var_prior = self.softplus(self.sent_prior_var(h_prior)) return mu_prior, var_prior def sent_posterior(self, context_outputs, encoder_hidden, z_conv): h_posterior = self.sent_posterior_h(torch.cat([context_outputs, encoder_hidden, z_conv], 1)) mu_posterior = self.sent_posterior_mu(h_posterior) var_posterior = self.softplus(self.sent_posterior_var(h_posterior)) return mu_posterior, var_posterior def forward(self, sentences, sentence_length, input_conversation_length, target_sentences, decode=False): """ Args: sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len] target_sentences: (Variable, LongTensor) [num_sentences, seq_len] Return: decoder_outputs: (Variable, FloatTensor) - train: [batch_size, seq_len, vocab_size] - eval: [batch_size, seq_len] """ batch_size = input_conversation_length.size(0) num_sentences = sentences.size(0) - batch_size max_len = input_conversation_length.data.max().item() # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size] # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size] encoder_outputs, encoder_hidden = self.encoder(sentences, sentence_length) # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size] encoder_hidden = encoder_hidden.transpose( 1, 0).contiguous().view(num_sentences + batch_size, -1) # pad and pack encoder_hidden start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()), input_conversation_length[:-1] + 1)), 0) # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size] encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1) for s, l in zip(start.data.tolist(), input_conversation_length.data.tolist())], 0) # encoder_hidden_inference: 
[batch_size, max_len, num_layers * direction * hidden_size] encoder_hidden_inference = encoder_hidden[:, 1:, :] encoder_hidden_inference_flat = torch.cat( [encoder_hidden_inference[i, :l, :] for i, l in enumerate(input_conversation_length.data)]) # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size] encoder_hidden_input = encoder_hidden[:, :-1, :] # Standard Gaussian prior conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size])) conv_mu_prior, conv_var_prior = self.conv_prior() if not decode: if self.config.sentence_drop > 0.0: indices = np.where(np.random.rand(max_len) < self.config.sentence_drop)[0] if len(indices) > 0: encoder_hidden_input[:, indices, :] = self.unk_sent # context_inference_outputs: [batch_size, max_len, num_directions * context_size] # context_inference_hidden: [num_layers * num_directions, batch_size, hidden_size] context_inference_outputs, context_inference_hidden = self.context_inference(encoder_hidden, input_conversation_length + 1) # context_inference_hidden: [batch_size, num_layers * num_directions * hidden_size] context_inference_hidden = context_inference_hidden.transpose( 1, 0).contiguous().view(batch_size, -1) conv_mu_posterior, conv_var_posterior = self.conv_posterior(context_inference_hidden) z_conv = conv_mu_posterior + torch.sqrt(conv_var_posterior) * conv_eps log_q_zx_conv = normal_logpdf(z_conv, conv_mu_posterior, conv_var_posterior).sum() log_p_z_conv = normal_logpdf(z_conv, conv_mu_prior, conv_var_prior).sum() kl_div_conv = normal_kl_div(conv_mu_posterior, conv_var_posterior, conv_mu_prior, conv_var_prior).sum() context_init = self.z_conv2context(z_conv).view( self.config.num_layers, batch_size, self.config.context_size) z_conv_expand = z_conv.view(z_conv.size(0), 1, z_conv.size( 1)).expand(z_conv.size(0), max_len, z_conv.size(1)) context_outputs, context_last_hidden = self.context_encoder( torch.cat([encoder_hidden_input, z_conv_expand], 2), input_conversation_length, hidden=context_init) # flatten outputs # context_outputs: [num_sentences, context_size] context_outputs = torch.cat([context_outputs[i, :l, :] for i, l in enumerate(input_conversation_length.data)]) z_conv_flat = torch.cat( [z_conv_expand[i, :l, :] for i, l in enumerate(input_conversation_length.data)]) sent_mu_prior, sent_var_prior = self.sent_prior(context_outputs, z_conv_flat) eps = to_var(torch.randn((num_sentences, self.config.z_sent_size))) sent_mu_posterior, sent_var_posterior = self.sent_posterior( context_outputs, encoder_hidden_inference_flat, z_conv_flat) z_sent = sent_mu_posterior + torch.sqrt(sent_var_posterior) * eps log_q_zx_sent = normal_logpdf(z_sent, sent_mu_posterior, sent_var_posterior).sum() log_p_z_sent = normal_logpdf(z_sent, sent_mu_prior, sent_var_prior).sum() # kl_div: [num_sentences] kl_div_sent = normal_kl_div(sent_mu_posterior, sent_var_posterior, sent_mu_prior, sent_var_prior).sum() kl_div = kl_div_conv + kl_div_sent log_q_zx = log_q_zx_conv + log_q_zx_sent log_p_z = log_p_z_conv + log_p_z_sent else: z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps context_init = self.z_conv2context(z_conv).view( self.config.num_layers, batch_size, self.config.context_size) z_conv_expand = z_conv.view(z_conv.size(0), 1, z_conv.size( 1)).expand(z_conv.size(0), max_len, z_conv.size(1)) # context_outputs: [batch_size, max_len, context_size] context_outputs, context_last_hidden = self.context_encoder( torch.cat([encoder_hidden_input, z_conv_expand], 2), input_conversation_length, hidden=context_init) # flatten outputs # 
context_outputs: [num_sentences, context_size] context_outputs = torch.cat([context_outputs[i, :l, :] for i, l in enumerate(input_conversation_length.data)]) z_conv_flat = torch.cat( [z_conv_expand[i, :l, :] for i, l in enumerate(input_conversation_length.data)]) sent_mu_prior, sent_var_prior = self.sent_prior(context_outputs, z_conv_flat) eps = to_var(torch.randn((num_sentences, self.config.z_sent_size))) z_sent = sent_mu_prior + torch.sqrt(sent_var_prior) * eps kl_div = None log_p_z = normal_logpdf(z_sent, sent_mu_prior, sent_var_prior).sum() log_p_z += normal_logpdf(z_conv, conv_mu_prior, conv_var_prior).sum() log_q_zx = None # expand z_conv to all associated sentences z_conv = torch.cat([z.view(1, -1).expand(m.item(), self.config.z_conv_size) for z, m in zip(z_conv, input_conversation_length)]) # latent_context: [num_sentences, context_size + z_sent_size + # z_conv_size] latent_context = torch.cat([context_outputs, z_sent, z_conv], 1) decoder_init = self.context2decoder(latent_context) decoder_init = decoder_init.view(-1, self.decoder.num_layers, self.decoder.hidden_size) decoder_init = decoder_init.transpose(1, 0).contiguous() # train: [batch_size, seq_len, vocab_size] # eval: [batch_size, seq_len] if not decode: decoder_outputs = self.decoder(target_sentences, init_h=decoder_init, decode=decode) return decoder_outputs, kl_div, log_p_z, log_q_zx else: # prediction: [batch_size, beam_size, max_unroll] prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init) return prediction, kl_div, log_p_z, log_q_zx def generate(self, context, sentence_length, n_context): # context: [batch_size, n_context, seq_len] batch_size = context.size(0) # n_context = context.size(1) samples = [] # Run for context conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size])) # conv_mu_prior, conv_var_prior = self.conv_prior() # z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps encoder_hidden_list = [] for i in range(n_context): # encoder_outputs: [batch_size, seq_len, hidden_size * direction] # encoder_hidden: [num_layers * direction, batch_size, hidden_size] encoder_outputs, encoder_hidden = self.encoder(context[:, i, :], sentence_length[:, i]) # encoder_hidden: [batch_size, num_layers * direction * hidden_size] encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1) encoder_hidden_list.append(encoder_hidden) encoder_hidden = torch.stack(encoder_hidden_list, 1) context_inference_outputs, context_inference_hidden = self.context_inference(encoder_hidden, to_var(torch.LongTensor([n_context] * batch_size))) context_inference_hidden = context_inference_hidden.transpose( 1, 0).contiguous().view(batch_size, -1) conv_mu_posterior, conv_var_posterior = self.conv_posterior(context_inference_hidden) z_conv = conv_mu_posterior + torch.sqrt(conv_var_posterior) * conv_eps context_init = self.z_conv2context(z_conv).view( self.config.num_layers, batch_size, self.config.context_size) context_hidden = context_init for i in range(n_context): # encoder_outputs: [batch_size, seq_len, hidden_size * direction] # encoder_hidden: [num_layers * direction, batch_size, hidden_size] encoder_outputs, encoder_hidden = self.encoder(context[:, i, :], sentence_length[:, i]) # encoder_hidden: [batch_size, num_layers * direction * encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1) encoder_hidden_list.append(encoder_hidden) # context_outputs: [batch_size, 1, context_hidden_size * direction] # context_hidden: [num_layers * direction, batch_size, 
context_hidden_size] context_outputs, context_hidden = self.context_encoder.step(torch.cat([encoder_hidden, z_conv], 1), context_hidden) # Run for generation for j in range(self.config.n_sample_step): # context_outputs: [batch_size, context_hidden_size * direction] context_outputs = context_outputs.squeeze(1) mu_prior, var_prior = self.sent_prior(context_outputs, z_conv) eps = to_var(torch.randn((batch_size, self.config.z_sent_size))) z_sent = mu_prior + torch.sqrt(var_prior) * eps latent_context = torch.cat([context_outputs, z_sent, z_conv], 1) decoder_init = self.context2decoder(latent_context) decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size) if self.config.sample: prediction = self.decoder(None, decoder_init, decode=True) p = prediction.data.cpu().numpy() length = torch.from_numpy(np.where(p == EOS_ID)[1]) else: prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init) # prediction: [batch_size, seq_len] prediction = prediction[:, 0, :] # length: [batch_size] length = [l[0] for l in length] length = to_var(torch.LongTensor(length)) samples.append(prediction) encoder_outputs, encoder_hidden = self.encoder(prediction, length) encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1) context_outputs, context_hidden = self.context_encoder.step(torch.cat([encoder_hidden, z_conv], 1), context_hidden) samples = torch.stack(samples, 1) return samples ================================================ FILE: model/solver.py ================================================ from itertools import cycle import numpy as np import torch import torch.nn as nn import models from layers import masked_cross_entropy from utils import to_var, time_desc_decorator, TensorboardWriter, pad_and_pack, normal_kl_div, to_bow, bag_of_words_loss, normal_kl_div, embedding_metric import os from tqdm import tqdm from math import isnan import re import math import pickle import gensim word2vec_path = "../datasets/GoogleNews-vectors-negative300.bin" class Solver(object): def __init__(self, config, train_data_loader, eval_data_loader, vocab, is_train=True, model=None): self.config = config self.epoch_i = 0 self.train_data_loader = train_data_loader self.eval_data_loader = eval_data_loader self.vocab = vocab self.is_train = is_train self.model = model @time_desc_decorator('Build Graph') def build(self, cuda=True): if self.model is None: self.model = getattr(models, self.config.model)(self.config) # orthogonal initialiation for hidden weights # input gate bias for GRUs if self.config.mode == 'train' and self.config.checkpoint is None: print('Parameter initiailization') for name, param in self.model.named_parameters(): if 'weight_hh' in name: print('\t' + name) nn.init.orthogonal_(param) # bias_hh is concatenation of reset, input, new gates # only set the input gate bias to 2.0 if 'bias_hh' in name: print('\t' + name) dim = int(param.size(0) / 3) param.data[dim:2 * dim].fill_(2.0) if torch.cuda.is_available() and cuda: self.model.cuda() # Overview Parameters print('Model Parameters') for name, param in self.model.named_parameters(): print('\t' + name + '\t', list(param.size())) if self.config.checkpoint: self.load_model(self.config.checkpoint) if self.is_train: self.writer = TensorboardWriter(self.config.logdir) self.optimizer = self.config.optimizer( filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.config.learning_rate) def save_model(self, epoch): """Save parameters to checkpoint""" ckpt_path = 
os.path.join(self.config.save_path, f'{epoch}.pkl') print(f'Save parameters to {ckpt_path}') torch.save(self.model.state_dict(), ckpt_path) def load_model(self, checkpoint): """Load parameters from checkpoint""" print(f'Load parameters from {checkpoint}') epoch = re.match(r"[0-9]*", os.path.basename(checkpoint)).group(0) self.epoch_i = int(epoch) self.model.load_state_dict(torch.load(checkpoint)) def write_summary(self, epoch_i): epoch_loss = getattr(self, 'epoch_loss', None) if epoch_loss is not None: self.writer.update_loss( loss=epoch_loss, step_i=epoch_i + 1, name='train_loss') epoch_recon_loss = getattr(self, 'epoch_recon_loss', None) if epoch_recon_loss is not None: self.writer.update_loss( loss=epoch_recon_loss, step_i=epoch_i + 1, name='train_recon_loss') epoch_kl_div = getattr(self, 'epoch_kl_div', None) if epoch_kl_div is not None: self.writer.update_loss( loss=epoch_kl_div, step_i=epoch_i + 1, name='train_kl_div') kl_mult = getattr(self, 'kl_mult', None) if kl_mult is not None: self.writer.update_loss( loss=kl_mult, step_i=epoch_i + 1, name='kl_mult') epoch_bow_loss = getattr(self, 'epoch_bow_loss', None) if epoch_bow_loss is not None: self.writer.update_loss( loss=epoch_bow_loss, step_i=epoch_i + 1, name='bow_loss') validation_loss = getattr(self, 'validation_loss', None) if validation_loss is not None: self.writer.update_loss( loss=validation_loss, step_i=epoch_i + 1, name='validation_loss') @time_desc_decorator('Training Start!') def train(self): epoch_loss_history = [] for epoch_i in range(self.epoch_i, self.config.n_epoch): self.epoch_i = epoch_i batch_loss_history = [] self.model.train() n_total_words = 0 for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.train_data_loader, ncols=80)): # conversations: (batch_size) list of conversations # conversation: list of sentences # sentence: list of tokens # conversation_length: list of int # sentence_length: (batch_size) list of conversation list of sentence_lengths input_conversations = [conv[:-1] for conv in conversations] target_conversations = [conv[1:] for conv in conversations] # flatten input and target conversations input_sentences = [sent for conv in input_conversations for sent in conv] target_sentences = [sent for conv in target_conversations for sent in conv] input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]] target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]] input_conversation_length = [l - 1 for l in conversation_length] input_sentences = to_var(torch.LongTensor(input_sentences)) target_sentences = to_var(torch.LongTensor(target_sentences)) input_sentence_length = to_var(torch.LongTensor(input_sentence_length)) target_sentence_length = to_var(torch.LongTensor(target_sentence_length)) input_conversation_length = to_var(torch.LongTensor(input_conversation_length)) # reset gradient self.optimizer.zero_grad() sentence_logits = self.model( input_sentences, input_sentence_length, input_conversation_length, target_sentences, decode=False) batch_loss, n_words = masked_cross_entropy( sentence_logits, target_sentences, target_sentence_length) assert not isnan(batch_loss.item()) batch_loss_history.append(batch_loss.item()) n_total_words += n_words.item() if batch_i % self.config.print_every == 0: tqdm.write( f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item()/ n_words.item():.3f}') # Back-propagation batch_loss.backward() # Gradient cliping torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip) # 
Run optimizer self.optimizer.step() epoch_loss = np.sum(batch_loss_history) / n_total_words epoch_loss_history.append(epoch_loss) self.epoch_loss = epoch_loss print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}' print(print_str) if epoch_i % self.config.save_every_epoch == 0: self.save_model(epoch_i + 1) print('\n...') self.validation_loss = self.evaluate() if epoch_i % self.config.plot_every_epoch == 0: self.write_summary(epoch_i) self.save_model(self.config.n_epoch) return epoch_loss_history def generate_sentence(self, input_sentences, input_sentence_length, input_conversation_length, target_sentences): self.model.eval() # [batch_size, max_seq_len, vocab_size] generated_sentences = self.model( input_sentences, input_sentence_length, input_conversation_length, target_sentences, decode=True) # write output to file with open(os.path.join(self.config.save_path, 'samples.txt'), 'a') as f: f.write(f'\n\n') tqdm.write('\n') for input_sent, target_sent, output_sent in zip(input_sentences, target_sentences, generated_sentences): input_sent = self.vocab.decode(input_sent) target_sent = self.vocab.decode(target_sent) output_sent = '\n'.join([self.vocab.decode(sent) for sent in output_sent]) s = '\n'.join(['Input sentence: ' + input_sent, 'Ground truth: ' + target_sent, 'Generated response: ' + output_sent + '\n']) f.write(s + '\n') print(s) print('') def evaluate(self): self.model.eval() batch_loss_history = [] n_total_words = 0 for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.eval_data_loader, ncols=80)): # conversations: (batch_size) list of conversations # conversation: list of sentences # sentence: list of tokens # conversation_length: list of int # sentence_length: (batch_size) list of conversation list of sentence_lengths input_conversations = [conv[:-1] for conv in conversations] target_conversations = [conv[1:] for conv in conversations] # flatten input and target conversations input_sentences = [sent for conv in input_conversations for sent in conv] target_sentences = [sent for conv in target_conversations for sent in conv] input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]] target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]] input_conversation_length = [l - 1 for l in conversation_length] with torch.no_grad(): input_sentences = to_var(torch.LongTensor(input_sentences)) target_sentences = to_var(torch.LongTensor(target_sentences)) input_sentence_length = to_var(torch.LongTensor(input_sentence_length)) target_sentence_length = to_var(torch.LongTensor(target_sentence_length)) input_conversation_length = to_var( torch.LongTensor(input_conversation_length)) if batch_i == 0: self.generate_sentence(input_sentences, input_sentence_length, input_conversation_length, target_sentences) sentence_logits = self.model( input_sentences, input_sentence_length, input_conversation_length, target_sentences) batch_loss, n_words = masked_cross_entropy( sentence_logits, target_sentences, target_sentence_length) assert not isnan(batch_loss.item()) batch_loss_history.append(batch_loss.item()) n_total_words += n_words.item() epoch_loss = np.sum(batch_loss_history) / n_total_words print_str = f'Validation loss: {epoch_loss:.3f}\n' print(print_str) return epoch_loss def test(self): self.model.eval() batch_loss_history = [] n_total_words = 0 for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.eval_data_loader, ncols=80)): # conversations: (batch_size) list of 
conversations # conversation: list of sentences # sentence: list of tokens # conversation_length: list of int # sentence_length: (batch_size) list of conversation list of sentence_lengths input_conversations = [conv[:-1] for conv in conversations] target_conversations = [conv[1:] for conv in conversations] # flatten input and target conversations input_sentences = [sent for conv in input_conversations for sent in conv] target_sentences = [sent for conv in target_conversations for sent in conv] input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]] target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]] input_conversation_length = [l - 1 for l in conversation_length] with torch.no_grad(): input_sentences = to_var(torch.LongTensor(input_sentences)) target_sentences = to_var(torch.LongTensor(target_sentences)) input_sentence_length = to_var(torch.LongTensor(input_sentence_length)) target_sentence_length = to_var(torch.LongTensor(target_sentence_length)) input_conversation_length = to_var(torch.LongTensor(input_conversation_length)) sentence_logits = self.model( input_sentences, input_sentence_length, input_conversation_length, target_sentences) batch_loss, n_words = masked_cross_entropy( sentence_logits, target_sentences, target_sentence_length) assert not isnan(batch_loss.item()) batch_loss_history.append(batch_loss.item()) n_total_words += n_words.item() epoch_loss = np.sum(batch_loss_history) / n_total_words print(f'Number of words: {n_total_words}') print(f'Bits per word: {epoch_loss:.3f}') word_perplexity = np.exp(epoch_loss) print_str = f'Word perplexity : {word_perplexity:.3f}\n' print(print_str) return word_perplexity def embedding_metric(self): word2vec = getattr(self, 'word2vec', None) if word2vec is None: print('Loading word2vec model') word2vec = gensim.models.KeyedVectors.load_word2vec_format(word2vec_path, binary=True) self.word2vec = word2vec keys = word2vec.vocab self.model.eval() n_context = self.config.n_context n_sample_step = self.config.n_sample_step metric_average_history = [] metric_extrema_history = [] metric_greedy_history = [] context_history = [] sample_history = [] n_sent = 0 n_conv = 0 for batch_i, (conversations, conversation_length, sentence_length) \ in enumerate(tqdm(self.eval_data_loader, ncols=80)): # conversations: (batch_size) list of conversations # conversation: list of sentences # sentence: list of tokens # conversation_length: list of int # sentence_length: (batch_size) list of conversation list of sentence_lengths conv_indices = [i for i in range(len(conversations)) if len(conversations[i]) >= n_context + n_sample_step] context = [c for i in conv_indices for c in [conversations[i][:n_context]]] ground_truth = [c for i in conv_indices for c in [conversations[i][n_context:n_context + n_sample_step]]] sentence_length = [c for i in conv_indices for c in [sentence_length[i][:n_context]]] with torch.no_grad(): context = to_var(torch.LongTensor(context)) sentence_length = to_var(torch.LongTensor(sentence_length)) samples = self.model.generate(context, sentence_length, n_context) context = context.data.cpu().numpy().tolist() samples = samples.data.cpu().numpy().tolist() context_history.append(context) sample_history.append(samples) samples = [[self.vocab.decode(sent) for sent in c] for c in samples] ground_truth = [[self.vocab.decode(sent) for sent in c] for c in ground_truth] samples = [sent for c in samples for sent in c] ground_truth = [sent for c in ground_truth for sent in c] samples = [[word2vec[s] 
for s in sent.split() if s in keys] for sent in samples] ground_truth = [[word2vec[s] for s in sent.split() if s in keys] for sent in ground_truth] indices = [i for i, s, g in zip(range(len(samples)), samples, ground_truth) if s != [] and g != []] samples = [samples[i] for i in indices] ground_truth = [ground_truth[i] for i in indices] n = len(samples) n_sent += n metric_average = embedding_metric(samples, ground_truth, word2vec, 'average') metric_extrema = embedding_metric(samples, ground_truth, word2vec, 'extrema') metric_greedy = embedding_metric(samples, ground_truth, word2vec, 'greedy') metric_average_history.append(metric_average) metric_extrema_history.append(metric_extrema) metric_greedy_history.append(metric_greedy) epoch_average = np.mean(np.concatenate(metric_average_history), axis=0) epoch_extrema = np.mean(np.concatenate(metric_extrema_history), axis=0) epoch_greedy = np.mean(np.concatenate(metric_greedy_history), axis=0) print('n_sentences:', n_sent) print_str = f'Metrics - Average: {epoch_average:.3f}, Extrema: {epoch_extrema:.3f}, Greedy: {epoch_greedy:.3f}' print(print_str) print('\n') return epoch_average, epoch_extrema, epoch_greedy class VariationalSolver(Solver): def __init__(self, config, train_data_loader, eval_data_loader, vocab, is_train=True, model=None): self.config = config self.epoch_i = 0 self.train_data_loader = train_data_loader self.eval_data_loader = eval_data_loader self.vocab = vocab self.is_train = is_train self.model = model @time_desc_decorator('Training Start!') def train(self): epoch_loss_history = [] kl_mult = 0.0 conv_kl_mult = 0.0 for epoch_i in range(self.epoch_i, self.config.n_epoch): self.epoch_i = epoch_i batch_loss_history = [] recon_loss_history = [] kl_div_history = [] kl_div_sent_history = [] kl_div_conv_history = [] bow_loss_history = [] self.model.train() n_total_words = 0 # self.evaluate() for batch_i, (conversations, conversation_length, sentence_length) \ in enumerate(tqdm(self.train_data_loader, ncols=80)): # conversations: (batch_size) list of conversations # conversation: list of sentences # sentence: list of tokens # conversation_length: list of int # sentence_length: (batch_size) list of conversation list of sentence_lengths target_conversations = [conv[1:] for conv in conversations] # flatten input and target conversations sentences = [sent for conv in conversations for sent in conv] input_conversation_length = [l - 1 for l in conversation_length] target_sentences = [sent for conv in target_conversations for sent in conv] target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]] sentence_length = [l for len_list in sentence_length for l in len_list] sentences = to_var(torch.LongTensor(sentences)) sentence_length = to_var(torch.LongTensor(sentence_length)) input_conversation_length = to_var(torch.LongTensor(input_conversation_length)) target_sentences = to_var(torch.LongTensor(target_sentences)) target_sentence_length = to_var(torch.LongTensor(target_sentence_length)) # reset gradient self.optimizer.zero_grad() sentence_logits, kl_div, _, _ = self.model( sentences, sentence_length, input_conversation_length, target_sentences) recon_loss, n_words = masked_cross_entropy( sentence_logits, target_sentences, target_sentence_length) batch_loss = recon_loss + kl_mult * kl_div batch_loss_history.append(batch_loss.item()) recon_loss_history.append(recon_loss.item()) kl_div_history.append(kl_div.item()) n_total_words += n_words.item() if self.config.bow: bow_loss = self.model.compute_bow_loss(target_conversations) 
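# Auxiliary bag-of-words loss (descriptive comment): bow_logits predict the
# words of the target sentence directly from z_sent (see
# VHRED.compute_bow_loss in models.py), which is intended to force the
# latent variable to encode sentence content and thereby counteract
# posterior collapse. Note it is added to batch_loss unannealed, i.e. it is
# not scaled by kl_mult.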
batch_loss += bow_loss bow_loss_history.append(bow_loss.item()) assert not isnan(batch_loss.item()) if batch_i % self.config.print_every == 0: print_str = f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item() / n_words.item():.3f}, recon = {recon_loss.item() / n_words.item():.3f}, kl_div = {kl_div.item() / n_words.item():.3f}' if self.config.bow: print_str += f', bow_loss = {bow_loss.item() / n_words.item():.3f}' tqdm.write(print_str) # Back-propagation batch_loss.backward() # Gradient cliping torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip) # Run optimizer self.optimizer.step() kl_mult = min(kl_mult + 1.0 / self.config.kl_annealing_iter, 1.0) epoch_loss = np.sum(batch_loss_history) / n_total_words epoch_loss_history.append(epoch_loss) epoch_recon_loss = np.sum(recon_loss_history) / n_total_words epoch_kl_div = np.sum(kl_div_history) / n_total_words self.kl_mult = kl_mult self.epoch_loss = epoch_loss self.epoch_recon_loss = epoch_recon_loss self.epoch_kl_div = epoch_kl_div print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}' if bow_loss_history: self.epoch_bow_loss = np.sum(bow_loss_history) / n_total_words print_str += f', bow_loss = {self.epoch_bow_loss:.3f}' print(print_str) if epoch_i % self.config.save_every_epoch == 0: self.save_model(epoch_i + 1) print('\n...') self.validation_loss = self.evaluate() if epoch_i % self.config.plot_every_epoch == 0: self.write_summary(epoch_i) return epoch_loss_history def generate_sentence(self, sentences, sentence_length, input_conversation_length, input_sentences, target_sentences): """Generate output of decoder (single batch)""" self.model.eval() # [batch_size, max_seq_len, vocab_size] generated_sentences, _, _, _ = self.model( sentences, sentence_length, input_conversation_length, target_sentences, decode=True) # write output to file with open(os.path.join(self.config.save_path, 'samples.txt'), 'a') as f: f.write(f'\n\n') tqdm.write('\n') for input_sent, target_sent, output_sent in zip(input_sentences, target_sentences, generated_sentences): input_sent = self.vocab.decode(input_sent) target_sent = self.vocab.decode(target_sent) output_sent = '\n'.join([self.vocab.decode(sent) for sent in output_sent]) s = '\n'.join(['Input sentence: ' + input_sent, 'Ground truth: ' + target_sent, 'Generated response: ' + output_sent + '\n']) f.write(s + '\n') print(s) print('') def evaluate(self): self.model.eval() batch_loss_history = [] recon_loss_history = [] kl_div_history = [] bow_loss_history = [] n_total_words = 0 for batch_i, (conversations, conversation_length, sentence_length) \ in enumerate(tqdm(self.eval_data_loader, ncols=80)): # conversations: (batch_size) list of conversations # conversation: list of sentences # sentence: list of tokens # conversation_length: list of int # sentence_length: (batch_size) list of conversation list of sentence_lengths target_conversations = [conv[1:] for conv in conversations] # flatten input and target conversations sentences = [sent for conv in conversations for sent in conv] input_conversation_length = [l - 1 for l in conversation_length] target_sentences = [sent for conv in target_conversations for sent in conv] target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]] sentence_length = [l for len_list in sentence_length for l in len_list] with torch.no_grad(): sentences = to_var(torch.LongTensor(sentences)) sentence_length = to_var(torch.LongTensor(sentence_length)) 
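# Descriptive comment: the tensor conversions around this point run under
# torch.no_grad(), since evaluation only needs forward passes and skipping
# the autograd graph saves memory. The validation loss computed below is
# recon_loss + kl_div (the negative ELBO) with no KL annealing, whereas
# training scales kl_div by kl_mult.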
input_conversation_length = to_var( torch.LongTensor(input_conversation_length)) target_sentences = to_var(torch.LongTensor(target_sentences)) target_sentence_length = to_var(torch.LongTensor(target_sentence_length)) if batch_i == 0: input_conversations = [conv[:-1] for conv in conversations] input_sentences = [sent for conv in input_conversations for sent in conv] with torch.no_grad(): input_sentences = to_var(torch.LongTensor(input_sentences)) self.generate_sentence(sentences, sentence_length, input_conversation_length, input_sentences, target_sentences) sentence_logits, kl_div, _, _ = self.model( sentences, sentence_length, input_conversation_length, target_sentences) recon_loss, n_words = masked_cross_entropy( sentence_logits, target_sentences, target_sentence_length) batch_loss = recon_loss + kl_div if self.config.bow: bow_loss = self.model.compute_bow_loss(target_conversations) bow_loss_history.append(bow_loss.item()) assert not isnan(batch_loss.item()) batch_loss_history.append(batch_loss.item()) recon_loss_history.append(recon_loss.item()) kl_div_history.append(kl_div.item()) n_total_words += n_words.item() epoch_loss = np.sum(batch_loss_history) / n_total_words epoch_recon_loss = np.sum(recon_loss_history) / n_total_words epoch_kl_div = np.sum(kl_div_history) / n_total_words print_str = f'Validation loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}' if bow_loss_history: epoch_bow_loss = np.sum(bow_loss_history) / n_total_words print_str += f', bow_loss = {epoch_bow_loss:.3f}' print(print_str) print('\n') return epoch_loss def importance_sample(self): ''' Perform importance sampling to get tighter bound ''' self.model.eval() weight_history = [] n_total_words = 0 kl_div_history = [] for batch_i, (conversations, conversation_length, sentence_length) \ in enumerate(tqdm(self.eval_data_loader, ncols=80)): # conversations: (batch_size) list of conversations # conversation: list of sentences # sentence: list of tokens # conversation_length: list of int # sentence_length: (batch_size) list of conversation list of sentence_lengths target_conversations = [conv[1:] for conv in conversations] # flatten input and target conversations sentences = [sent for conv in conversations for sent in conv] input_conversation_length = [l - 1 for l in conversation_length] target_sentences = [sent for conv in target_conversations for sent in conv] target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]] sentence_length = [l for len_list in sentence_length for l in len_list] # n_words += sum([len([word for word in sent if word != PAD_ID]) for sent in target_sentences]) with torch.no_grad(): sentences = to_var(torch.LongTensor(sentences)) sentence_length = to_var(torch.LongTensor(sentence_length)) input_conversation_length = to_var( torch.LongTensor(input_conversation_length)) target_sentences = to_var(torch.LongTensor(target_sentences)) target_sentence_length = to_var(torch.LongTensor(target_sentence_length)) # treat whole batch as one data sample weights = [] for j in range(self.config.importance_sample): sentence_logits, kl_div, log_p_z, log_q_zx = self.model( sentences, sentence_length, input_conversation_length, target_sentences) recon_loss, n_words = masked_cross_entropy( sentence_logits, target_sentences, target_sentence_length) log_w = (-recon_loss.sum() + log_p_z - log_q_zx).data weights.append(log_w) if j == 0: n_total_words += n_words.item() kl_div_history.append(kl_div.item()) # weights: [n_samples] weights = torch.stack(weights, 0) m = 
np.floor(weights.max()) weights = np.log(torch.exp(weights - m).sum()) weights = m + weights - np.log(self.config.importance_sample) weight_history.append(weights) print(f'Number of words: {n_total_words}') bits_per_word = -np.sum(weight_history) / n_total_words print(f'Bits per word: {bits_per_word:.3f}') word_perplexity = np.exp(bits_per_word) epoch_kl_div = np.sum(kl_div_history) / n_total_words print_str = f'Word perplexity upperbound using {self.config.importance_sample} importance samples: {word_perplexity:.3f}, kl_div: {epoch_kl_div:.3f}\n' print(print_str) return word_perplexity ================================================ FILE: model/train.py ================================================ from solver import * from data_loader import get_loader from configs import get_config from utils import Vocab import os import pickle from models import VariationalModels def load_pickle(path): with open(path, 'rb') as f: return pickle.load(f) if __name__ == '__main__': config = get_config(mode='train') val_config = get_config(mode='valid') print(config) with open(os.path.join(config.save_path, 'config.txt'), 'w') as f: print(config, file=f) print('Loading Vocabulary...') vocab = Vocab() vocab.load(config.word2id_path, config.id2word_path) print(f'Vocabulary size: {vocab.vocab_size}') config.vocab_size = vocab.vocab_size train_data_loader = get_loader( sentences=load_pickle(config.sentences_path), conversation_length=load_pickle(config.conversation_length_path), sentence_length=load_pickle(config.sentence_length_path), vocab=vocab, batch_size=config.batch_size) eval_data_loader = get_loader( sentences=load_pickle(val_config.sentences_path), conversation_length=load_pickle(val_config.conversation_length_path), sentence_length=load_pickle(val_config.sentence_length_path), vocab=vocab, batch_size=val_config.eval_batch_size, shuffle=False) # for testing # train_data_loader = eval_data_loader if config.model in VariationalModels: solver = VariationalSolver else: solver = Solver solver = solver(config, train_data_loader, eval_data_loader, vocab=vocab, is_train=True) solver.build() solver.train() ================================================ FILE: model/utils/__init__.py ================================================ from .convert import * from .time_track import time_desc_decorator from .tensorboard import TensorboardWriter from .vocab import * from .mask import * from .tokenizer import * from .probability import * from .pad import * from .bow import * from .embedding_metric import * ================================================ FILE: model/utils/bow.py ================================================ import numpy as np from collections import Counter import torch.nn as nn from torch.nn import functional as F import torch from math import isnan from .vocab import PAD_ID, EOS_ID def to_bow(sentence, vocab_size): ''' Convert a sentence into a bag of words representation Args - sentence: a list of token ids - vocab_size: V Returns - bow: a integer vector of size V ''' bow = Counter(sentence) # Remove EOS tokens bow[PAD_ID] = 0 bow[EOS_ID] = 0 x = np.zeros(vocab_size, dtype=np.int64) x[list(bow.keys())] = list(bow.values()) return x def bag_of_words_loss(bow_logits, target_bow, weight=None): ''' Calculate bag of words representation loss Args - bow_logits: [num_sentences, vocab_size] - target_bow: [num_sentences] ''' log_probs = F.log_softmax(bow_logits, dim=1) target_distribution = target_bow / (target_bow.sum(1).view(-1, 1) + 1e-23) + 1e-23 entropy = -(torch.log(target_distribution) * 
target_bow).sum() loss = -(log_probs * target_bow).sum() - entropy return loss ================================================ FILE: model/utils/convert.py ================================================ import torch from torch.autograd import Variable def to_var(x, on_cpu=False, gpu_id=None, async=False): """Tensor => Variable""" if torch.cuda.is_available() and not on_cpu: x = x.cuda(gpu_id, async) #x = Variable(x) return x def to_tensor(x): """Variable => Tensor""" if torch.cuda.is_available(): x = x.cpu() return x.data def reverse_order(tensor, dim=0): """Reverse Tensor or Variable""" if isinstance(tensor, torch.Tensor) or isinstance(tensor, torch.LongTensor): idx = [i for i in range(tensor.size(dim)-1, -1, -1)] idx = torch.LongTensor(idx) inverted_tensor = tensor.index_select(dim, idx) if isinstance(tensor, torch.cuda.FloatTensor) or isinstance(tensor, torch.cuda.LongTensor): idx = [i for i in range(tensor.size(dim)-1, -1, -1)] idx = torch.cuda.LongTensor(idx) inverted_tensor = tensor.index_select(dim, idx) return inverted_tensor elif isinstance(tensor, Variable): variable = tensor variable.data = reverse_order(variable.data, dim) return variable def reverse_order_valid(tensor, length_list, dim=0): """ Reverse Tensor of Variable only in given length Ex) Args: - tensor (Tensor or Variable) 1 2 3 4 5 6 6 7 8 9 0 0 11 12 13 0 0 0 16 17 0 0 0 0 21 22 23 24 25 26 - length_list (list) [6, 4, 3, 2, 6] Return: tensor (Tensor or Variable; in-place) 6 5 4 3 2 1 0 0 9 8 7 6 0 0 0 13 12 11 0 0 0 0 17 16 26 25 24 23 22 21 """ for row, length in zip(tensor, length_list): valid_row = row[:length] reversed_valid_row = reverse_order(valid_row, dim=dim) row[:length] = reversed_valid_row return tensor ================================================ FILE: model/utils/embedding_metric.py ================================================ import numpy as np def cosine_similarity(s, g): similarity = np.sum(s * g, axis=1) / np.sqrt((np.sum(s * s, axis=1) * np.sum(g * g, axis=1))) # return np.sum(similarity) return similarity def embedding_metric(samples, ground_truth, word2vec, method='average'): if method == 'average': # s, g: [n_samples, word_dim] s = [np.mean(sample, axis=0) for sample in samples] g = [np.mean(gt, axis=0) for gt in ground_truth] return cosine_similarity(np.array(s), np.array(g)) elif method == 'extrema': s_list = [] g_list = [] for sample, gt in zip(samples, ground_truth): s_max = np.max(sample, axis=0) s_min = np.min(sample, axis=0) s_plus = np.absolute(s_min) <= s_max s_abs = np.max(np.absolute(sample), axis=0) s = s_max * s_plus + s_min * np.logical_not(s_plus) s_list.append(s) g_max = np.max(gt, axis=0) g_min = np.min(gt, axis=0) g_plus = np.absolute(g_min) <= g_max g_abs = np.max(np.absolute(gt), axis=0) g = g_max * g_plus + g_min * np.logical_not(g_plus) g_list.append(g) return cosine_similarity(np.array(s_list), np.array(g_list)) elif method == 'greedy': sim_list = [] for s, g in zip(samples, ground_truth): s = np.array(s) g = np.array(g).T sim = (np.matmul(s, g) / np.sqrt(np.matmul(np.sum(s * s, axis=1, keepdims=True), np.sum(g * g, axis=0, keepdims=True)))) sim = np.max(sim, axis=0) sim_list.append(np.mean(sim)) # return np.sum(sim_list) return np.array(sim_list) else: raise NotImplementedError ================================================ FILE: model/utils/mask.py ================================================ import torch from .convert import to_var def sequence_mask(sequence_length, max_len=None): """ Args: sequence_length (Variable, LongTensor) [batch_size] - list of 
sequence length of each batch max_len (int) Return: masks (bool): [batch_size, max_len] - True if current sequence is valid (not padded), False otherwise Ex. sequence length: [3, 2, 1] seq_length_expand [[3, 3, 3], [2, 2, 2] [1, 1, 1]] seq_range_expand [[0, 1, 2] [0, 1, 2], [0, 1, 2]] masks [[True, True, True], [True, True, False], [True, False, False]] """ if max_len is None: max_len = sequence_length.max() batch_size = sequence_length.size(0) # [max_len] seq_range = torch.arange(0, max_len).long() # [0, 1, ... max_len-1] # [batch_size, max_len] seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) seq_range_expand = to_var(seq_range_expand) # [batch_size, max_len] seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand) # [batch_size, max_len] masks = seq_range_expand < seq_length_expand return masks ================================================ FILE: model/utils/pad.py ================================================ import torch from torch.autograd import Variable from .convert import to_var def pad(tensor, length): if isinstance(tensor, Variable): var = tensor if length > var.size(0): return torch.cat([var, torch.zeros(length - var.size(0), *var.size()[1:]).cuda()]) else: return var else: if length > tensor.size(0): return torch.cat([tensor, torch.zeros(length - tensor.size(0), *tensor.size()[1:]).cuda()]) else: return tensor def pad_and_pack(tensor_list): length_list = ([t.size(0) for t in tensor_list]) max_len = max(length_list) padded = [pad(t, max_len) for t in tensor_list] packed = torch.stack(padded, 0) return packed, length_list ================================================ FILE: model/utils/probability.py ================================================ import torch import numpy as np from .convert import to_var def normal_logpdf(x, mean, var): """ Args: x: (Variable, FloatTensor) [batch_size, dim] mean: (Variable, FloatTensor) [batch_size, dim] or [batch_size] or [1] var: (Variable, FloatTensor) [batch_size, dim]: positive value Return: log_p: (Variable, FloatTensor) [batch_size] """ pi = to_var(torch.FloatTensor([np.pi])) return 0.5 * torch.sum(-torch.log(2.0 * pi) - torch.log(var) - ((x - mean).pow(2) / var), dim=1) def normal_kl_div(mu1, var1, mu2=to_var(torch.FloatTensor([0.0])), var2=to_var(torch.FloatTensor([1.0]))): one = to_var(torch.FloatTensor([1.0])) return torch.sum(0.5 * (torch.log(var2) - torch.log(var1) + (var1 + (mu1 - mu2).pow(2)) / var2 - one), 1) ================================================ FILE: model/utils/tensorboard.py ================================================ from tensorboardX import SummaryWriter class TensorboardWriter(SummaryWriter): def __init__(self, logdir): """ Extended SummaryWriter Class from tensorboard-pytorch (tensorbaordX) https://github.com/lanpa/tensorboard-pytorch/blob/master/tensorboardX/writer.py Internally calls self.file_writer """ super(TensorboardWriter, self).__init__(logdir) self.logdir = self.file_writer.get_logdir() def update_parameters(self, module, step_i): """ module: nn.Module """ for name, param in module.named_parameters(): self.add_histogram(name, param.clone().cpu().data.numpy(), step_i) def update_loss(self, loss, step_i, name='loss'): self.add_scalar(name, loss, step_i) def update_histogram(self, values, step_i, name='hist'): self.add_histogram(name, values, step_i) ================================================ FILE: model/utils/time_track.py ================================================ import time from functools import partial def 
base_time_desc_decorator(method, desc='test_description'): def timed(*args, **kwargs): # Print Description # print('#' * 50) print(desc) # print('#' * 50 + '\n') # Calculation Runtime start = time.time() # Run Method try: result = method(*args, **kwargs) except TypeError: result = method(**kwargs) # Print Runtime print('Done! It took {:.2} secs\n'.format(time.time() - start)) if result is not None: return result return timed def time_desc_decorator(desc): return partial(base_time_desc_decorator, desc=desc) @time_desc_decorator('this is description') def time_test(arg, kwarg='this is kwarg'): time.sleep(3) print('Inside of time_test') print('printing arg: ', arg) print('printing kwarg: ', kwarg) @time_desc_decorator('this is second description') def no_arg_method(): print('this method has no argument') if __name__ == '__main__': time_test('hello', kwarg=3) time_test(3) no_arg_method() ================================================ FILE: model/utils/tokenizer.py ================================================ import re def clean_str(string): """ Tokenization/string cleaning for all datasets except for SST. Every dataset is lower cased except for TREC """ string = re.sub(r"[^A-Za-z0-9,!?\'\`\.]", " ", string) string = re.sub(r"\.{3}", " ...", string) string = re.sub(r"\'s", " \'s", string) string = re.sub(r"\'ve", " \'ve", string) string = re.sub(r"n\'t", " n\'t", string) string = re.sub(r"\'re", " \'re", string) string = re.sub(r"\'d", " \'d", string) string = re.sub(r"\'ll", " \'ll", string) string = re.sub(r",", " , ", string) string = re.sub(r"!", " ! ", string) string = re.sub(r"\?", " ? ", string) string = re.sub(r"\s{2,}", " ", string) return string.strip().lower() class Tokenizer(): def __init__(self, tokenizer='whitespace', clean_string=True): self.clean_string = clean_string tokenizer = tokenizer.lower() # Tokenize with whitespace if tokenizer == 'whitespace': print('Loading whitespace tokenizer') self.tokenize = lambda string: string.strip().split() if tokenizer == 'regex': print('Loading regex tokenizer') import re pattern = r"[A-Z]{2,}(?![a-z])|[A-Z][a-z]+(?=[A-Z])|[\'\w\-]+" self.tokenize = lambda string: re.findall(pattern, string) if tokenizer == 'spacy': print('Loading SpaCy') import spacy nlp = spacy.load('en') self.tokenize = lambda string: [token.text for token in nlp(string)] # Tokenize with punctuations other than periods if tokenizer == 'nltk': print('Loading NLTK word tokenizer') from nltk import word_tokenize self.tokenize = word_tokenize def __call__(self, string): if self.clean_string: string = clean_str(string) return self.tokenize(string) if __name__ == '__main__': tokenizer = Tokenizer() print(tokenizer("Hello, how are you doin'?")) tokenizer = Tokenizer('spacy') print(tokenizer("Hello, how are you doin'?")) ================================================ FILE: model/utils/vocab.py ================================================ from collections import defaultdict import pickle import torch from torch import Tensor from torch.autograd import Variable from nltk import FreqDist from .convert import to_tensor, to_var PAD_TOKEN = '' UNK_TOKEN = '' SOS_TOKEN = '' EOS_TOKEN = '' PAD_ID, UNK_ID, SOS_ID, EOS_ID = [0, 1, 2, 3] class Vocab(object): def __init__(self, tokenizer=None, max_size=None, min_freq=1): """Basic Vocabulary object""" self.vocab_size = 0 self.freqdist = FreqDist() self.tokenizer = tokenizer def update(self, max_size=None, min_freq=1): """ Initialize id2word & word2id based on self.freqdist max_size include 4 special tokens """ # {0: '', 1: '', 2: 
'', 3: ''} self.id2word = { PAD_ID: PAD_TOKEN, UNK_ID: UNK_TOKEN, SOS_ID: SOS_TOKEN, EOS_ID: EOS_TOKEN } # {'': 0, '': 1, '': 2, '': 3} self.word2id = defaultdict(lambda: UNK_ID) # Not in vocab => return UNK self.word2id.update({ PAD_TOKEN: PAD_ID, UNK_TOKEN: UNK_ID, SOS_TOKEN: SOS_ID, EOS_TOKEN: EOS_ID }) # self.word2id = { # PAD_TOKEN: PAD_ID, UNK_TOKEN: UNK_ID, # SOS_TOKEN: SOS_ID, EOS_TOKEN: EOS_ID # } vocab_size = 4 min_freq = max(min_freq, 1) # Reset frequencies of special tokens # [...('', 0), ('', 0), ('', 0), ('', 0)] freqdist = self.freqdist.copy() special_freqdist = {token: freqdist[token] for token in [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN]} freqdist.subtract(special_freqdist) # Sort: by frequency, then alphabetically # Ex) freqdist = { 'a': 4, 'b': 5, 'c': 3 } # => sorted = [('b', 5), ('a', 4), ('c', 3)] sorted_frequency_counter = sorted(freqdist.items(), key=lambda k_v: k_v[0]) sorted_frequency_counter.sort(key=lambda k_v: k_v[1], reverse=True) for word, freq in sorted_frequency_counter: if freq < min_freq or vocab_size == max_size: break self.id2word[vocab_size] = word self.word2id[word] = vocab_size vocab_size += 1 self.vocab_size = vocab_size def __len__(self): return len(self.id2word) def load(self, word2id_path=None, id2word_path=None): if word2id_path: with open(word2id_path, 'rb') as f: word2id = pickle.load(f) # Can't pickle lambda function self.word2id = defaultdict(lambda: UNK_ID) self.word2id.update(word2id) self.vocab_size = len(self.word2id) if id2word_path: with open(id2word_path, 'rb') as f: id2word = pickle.load(f) self.id2word = id2word def add_word(self, word): assert isinstance(word, str), 'Input should be str' self.freqdist.update([word]) def add_sentence(self, sentence, tokenized=False): if not tokenized: sentence = self.tokenizer(sentence) for word in sentence: self.add_word(word) def add_dataframe(self, conversation_df, tokenized=True): for conversation in conversation_df: for sentence in conversation: self.add_sentence(sentence, tokenized=tokenized) def pickle(self, word2id_path, id2word_path): with open(word2id_path, 'wb') as f: pickle.dump(dict(self.word2id), f) with open(id2word_path, 'wb') as f: pickle.dump(self.id2word, f) def to_list(self, list_like): """Convert list-like containers to list""" if isinstance(list_like, list): return list_like if isinstance(list_like, Variable): return list(to_tensor(list_like).numpy()) elif isinstance(list_like, Tensor): return list(list_like.numpy()) def id2sent(self, id_list): """list of id => list of tokens (Single sentence)""" id_list = self.to_list(id_list) sentence = [] for id in id_list: word = self.id2word[id] if word not in [EOS_TOKEN, SOS_TOKEN, PAD_TOKEN]: sentence.append(word) if word == EOS_TOKEN: break return sentence def sent2id(self, sentence, var=False): """list of tokens => list of id (Single sentence)""" id_list = [self.word2id[word] for word in sentence] if var: id_list = to_var(torch.LongTensor(id_list), eval=True) return id_list def decode(self, id_list): sentence = self.id2sent(id_list) return ' '.join(sentence) ================================================ FILE: requirements.txt ================================================ pandas==0.20.3 numpy==1.14.0 gensim==3.1.0 spacy==1.9.0 tqdm==4.15.0 nltk==3.4.5 tensorboardX==1.1 torch==0.4 ================================================ FILE: ubuntu_preprocess.py ================================================ # Load the Ubuntu dialog corpus # Available from here: # http://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/ubuntu_dialogs.tgz from 
================================================
FILE: requirements.txt
================================================
pandas==0.20.3
numpy==1.14.0
gensim==3.1.0
spacy==1.9.0
tqdm==4.15.0
nltk==3.4.5
tensorboardX==1.1
torch==0.4


================================================
FILE: ubuntu_preprocess.py
================================================
# Load the Ubuntu dialog corpus
# Available from here:
# http://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/ubuntu_dialogs.tgz

from multiprocessing import Pool
from pathlib import Path
from collections import OrderedDict
from urllib.request import urlretrieve
import os
import argparse
import tarfile
import pickle

from tqdm import tqdm
import pandas as pd

from model.utils import Tokenizer, Vocab, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN

project_dir = Path(__file__).resolve().parent
datasets_dir = project_dir.joinpath('datasets/')
ubuntu_dir = datasets_dir.joinpath('ubuntu/')
ubuntu_meta_dir = ubuntu_dir.joinpath('meta/')
dialogs_dir = ubuntu_dir.joinpath('dialogs/')

# Tokenizer
tokenizer = Tokenizer('spacy')


def prepare_ubuntu_data():
    """Download and unpack dialogs"""
    tar_filename = 'ubuntu_dialogs.tgz'
    url = 'http://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/ubuntu_dialogs.tgz'
    tarfile_path = ubuntu_dir.joinpath(tar_filename)
    metadata_url = 'https://raw.githubusercontent.com/rkadlec/ubuntu-ranking-dataset-creator/master/src/meta/'

    if not datasets_dir.exists():
        datasets_dir.mkdir()
    if not ubuntu_dir.exists():
        ubuntu_dir.mkdir()
    if not ubuntu_meta_dir.exists():
        ubuntu_meta_dir.mkdir()

    # Prepare Dialog data
    if not dialogs_dir.joinpath("10/1.tst").exists():
        # Download Dialog tarfile
        if not tarfile_path.exists():
            print(f"Downloading {url} to {tarfile_path}")
            urlretrieve(url, tarfile_path)
            print(f"Successfully downloaded {tarfile_path}")

        # Unpack tarfile
        if not dialogs_dir.exists():
            print("Unpacking dialogs ... (This can take 5~10 mins.)")
            with tarfile.open(tarfile_path) as tar:
                tar.extractall(path=ubuntu_dir)
            print("Archive unpacked.")

    # Download metadata
    if not ubuntu_meta_dir.joinpath('trainfiles.csv').exists():
        print('Downloading metadata ... (This can take 5~10 mins.)')
        for filename in ['trainfiles.csv', 'valfiles.csv', 'testfiles.csv']:
            csv_path = ubuntu_meta_dir.joinpath(filename)
            print(f"Downloading {metadata_url + filename} to {csv_path}")
            urlretrieve(metadata_url + filename, csv_path)
            print(f"Successfully downloaded {csv_path}")

    print('Ubuntu Data prepared!')


def get_dialog_path_list(dataset='train'):
    if dataset == 'train':
        filename = 'trainfiles.csv'
    elif dataset == 'test':
        filename = 'testfiles.csv'
    elif dataset == 'valid':
        filename = 'valfiles.csv'

    with open(ubuntu_meta_dir.joinpath(filename)) as f:
        dialog_path_list = []
        for line in f:
            file, dir = line.strip().split(",")
            path = dialogs_dir.joinpath(dir, file)
            dialog_path_list.append(path)

    return dialog_path_list
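# Illustrative note (not in the original file): each dialog file listed in the
# CSVs above is a TSV whose rows read_and_tokenize() below unpacks as
# (time, speaker, listener, sentence). A made-up example row:
#
#   2008-06-06T01:00:00.000Z<TAB>userA<TAB>userB<TAB>how do I mount a usb drive?
#
# Consecutive rows by the same speaker are merged into a single turn before
# tokenization.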
def read_and_tokenize(dialog_path, min_turn=3):
    """
    Read conversation
    Args:
        dialog_path (str): path of dialog (tsv format)
    Return:
        dialog: (list of list of str) [dialog_length, sentence_length]
        users: (list of str); [2]
    """
    with open(dialog_path, 'r', encoding='utf-8') as f:

        # Go through the dialog
        first_turn = True
        dialog = []
        users = []
        same_user_utterances = []  # list of sentences of current user
        dialog.append(same_user_utterances)

        for line in f:
            _time, speaker, _listener, sentence = line.split('\t')
            users.append(speaker)

            if first_turn:
                last_speaker = speaker
                first_turn = False

            # Speaker has changed
            if last_speaker != speaker:
                same_user_utterances = []
                dialog.append(same_user_utterances)

            same_user_utterances.append(sentence)
            last_speaker = speaker

    # All users in conversation (len: 2)
    users = list(OrderedDict.fromkeys(users))

    # 1. Concatenate consecutive sentences of single user
    # 2. Tokenize
    dialog = [tokenizer(" ".join(sentence)) for sentence in dialog]

    if len(dialog) < min_turn:
        print(f"Dialog {dialog_path} length ({len(dialog)}) "
              f"< minimum required length {min_turn}")
        return []

    return dialog  # , users


def pad_sentences(conversations, max_sentence_length=30, max_conversation_length=10):

    def pad_tokens(tokens, max_sentence_length=max_sentence_length):
        n_valid_tokens = len(tokens)
        if n_valid_tokens > max_sentence_length - 1:
            tokens = tokens[:max_sentence_length - 1]
        n_pad = max_sentence_length - n_valid_tokens - 1
        tokens = tokens + [EOS_TOKEN] + [PAD_TOKEN] * n_pad
        return tokens

    def pad_conversation(conversation):
        conversation = [pad_tokens(sentence) for sentence in conversation]
        return conversation

    all_padded_sentences = []
    all_sentence_length = []

    for conversation in conversations:
        if len(conversation) > max_conversation_length:
            conversation = conversation[:max_conversation_length]
        sentence_length = [min(len(sentence) + 1, max_sentence_length)  # +1 for EOS token
                           for sentence in conversation]
        all_sentence_length.append(sentence_length)

        sentences = pad_conversation(conversation)
        all_padded_sentences.append(sentences)

    # [n_conversations, n_sentence (various), max_sentence_length]
    sentences = all_padded_sentences
    # [n_conversations, n_sentence (various)]
    sentence_length = all_sentence_length

    return sentences, sentence_length
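# Worked example (illustrative, not in the original file) of pad_tokens above
# with max_sentence_length=5, assuming EOS_TOKEN = '<eos>' and PAD_TOKEN = '<pad>':
#
#   ['hi', 'there']           -> ['hi', 'there', '<eos>', '<pad>', '<pad>']
#   ['a', 'b', 'c', 'd', 'e'] -> ['a', 'b', 'c', 'd', '<eos>']
#
# Over-long inputs are truncated to max_sentence_length - 1 tokens so that the
# appended EOS keeps every padded sentence exactly max_sentence_length long;
# in that case n_pad goes negative and '[PAD_TOKEN] * n_pad' contributes nothing.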
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Maximum valid length of sentence
    # => SOS/EOS will surround sentence (EOS for source / SOS for target)
    # => maximum length of tensor = max_sentence_length + 1
    parser.add_argument('-s', '--max_sentence_length', type=int, default=30)
    parser.add_argument('-c', '--max_conversation_length', type=int, default=10)

    # Vocabulary
    parser.add_argument('--max_vocab_size', type=int, default=20000)
    parser.add_argument('--min_vocab_frequency', type=int, default=5)

    # Multiprocess
    parser.add_argument('--n_workers', type=int, default=os.cpu_count())

    args = parser.parse_args()

    max_sent_len = args.max_sentence_length
    max_conv_len = args.max_conversation_length
    max_vocab_size = args.max_vocab_size
    min_freq = args.min_vocab_frequency
    n_workers = args.n_workers
    min_turn = 3

    # Download and unpack dialogs if necessary.
    prepare_ubuntu_data()

    def to_pickle(obj, path):
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    for split_type in ['train', 'test', 'valid']:
        print(f'Processing {split_type} dataset...')
        split_data_dir = ubuntu_dir.joinpath(split_type)
        split_data_dir.mkdir(exist_ok=True)

        # List of dialogs (tsv)
        dialog_path_list = get_dialog_path_list(split_type)

        print(f'Tokenize.. (n_workers={n_workers})')

        def _tokenize_conversation(dialog_path):
            return read_and_tokenize(dialog_path)

        with Pool(n_workers) as pool:
            conversations = list(tqdm(pool.imap(_tokenize_conversation,
                                                dialog_path_list),
                                      total=len(dialog_path_list)))

        # Filter too short conversations
        conversations = list(filter(lambda x: len(x) >= min_turn, conversations))

        # conversations: padded_sentences
        # [n_conversations, conversation_length (various), max_sentence_length]
        # sentence_length: list of length of sentences
        # [n_conversations, conversation_length (various)]
        conversation_length = [min(len(conversation), max_conv_len)
                               for conversation in conversations]

        sentences, sentence_length = pad_sentences(
            conversations,
            max_sentence_length=max_sent_len,
            max_conversation_length=max_conv_len)

        print('Saving preprocessed data at', split_data_dir)
        to_pickle(conversation_length, split_data_dir.joinpath('conversation_length.pkl'))
        to_pickle(sentences, split_data_dir.joinpath('sentences.pkl'))
        to_pickle(sentence_length, split_data_dir.joinpath('sentence_length.pkl'))

        if split_type == 'train':
            print('Save Vocabulary...')
            vocab = Vocab(tokenizer)
            vocab.add_dataframe(conversations)
            vocab.update(max_size=max_vocab_size, min_freq=min_freq)

            print('Vocabulary size: ', len(vocab))
            vocab.pickle(ubuntu_dir.joinpath('word2id.pkl'),
                         ubuntu_dir.joinpath('id2word.pkl'))

    print('Done!')
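A minimal sketch (not a repository file) of loading the pickles that ubuntu_preprocess.py writes, using the script's default output paths:

```
import pickle
from pathlib import Path

split_dir = Path('datasets/ubuntu/train')

with open(split_dir.joinpath('sentences.pkl'), 'rb') as f:
    # [n_conversations][n_turns][max_sentence_length] padded token lists
    sentences = pickle.load(f)

with open(split_dir.joinpath('conversation_length.pkl'), 'rb') as f:
    # number of (truncated) turns per conversation
    conversation_length = pickle.load(f)

print(len(sentences), conversation_length[0], sentences[0][0])
```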