[
  {
    "path": ".gitignore",
    "content": "/etc/\ndatasets/\n/cornell_movie_dialogue/\n*.orig\n*.lprof\n\n# Remote edit\n*.ftpconfig\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nenv/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# dotenv\n.env\n\n# virtualenv\n.venv\nvenv/\nENV/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2018 Center for SuperIntelligence\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "Readme.md",
    "content": "# Variational Hierarchical Conversation RNN (VHCR)\n[PyTorch 0.4](https://github.com/pytorch/pytorch) Implementation of [\"A Hierarchical Latent Structure for Variational Conversation Modeling\"](https://arxiv.org/abs/1804.03424) (NAACL 2018 Oral)\n* [NAACL 2018 Oral Presentation Video](https://vimeo.com/277671819)\n\n## Prerequisite\nInstall Python packages\n\n```\npip install -r requirements.txt\n```\n\n## Download & Preprocess data\nFollowing scripts will\n\n1. Create directories `./datasets/cornell/` and `./datasets/ubuntu/` respectively.\n\n2. Download and preprocess conversation data inside each directory.\n\n### for [Cornell Movie Dialogue dataset](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html)\n```\npython cornell_preprocess.py\n    --max_sentence_length (maximum number of words in sentence; default: 30)\n    --max_conversation_length (maximum turns of utterances in single conversation; default: 10)\n    --max_vocab_size (maximum size of word vocabulary; default: 20000)\n    --max_vocab_frequency (minimum frequency of word to be included in vocabulary; default: 5)\n    --n_workers (number of workers for multiprocessing; default: os.cpu_count())\n```\n\n### for [Ubuntu Dialog Dataset](http://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/)\n```\npython ubuntu_preprocess.py\n    --max_sentence_length (maximum number of words in sentence; default: 30)\n    --max_conversation_length (maximum turns of utterances in single conversation; default: 10)\n    --max_vocab_size (maximum size of word vocabulary; default: 20000)\n    --max_vocab_frequency (minimum frequency of word to be included in vocabulary; default: 5)\n    --n_workers (number of workers for multiprocessing; default: os.cpu_count())\n```\n\n\n## Training\nGo to the model directory and set the save_dir in configs.py (this is where the model checkpoints will be saved)\n\nWe provide our implementation of VHCR, as well as our reference implementations for [HRED](https://arxiv.org/abs/1507.02221) and [VHRED](https://arxiv.org/abs/1605.06069).\n\nTo run training:\n```\npython train.py --data=<data> --model=<model> --batch_size=<batch_size>\n```\n\nFor example:\n1. Train HRED on Cornell Movie:\n```\npython train.py --data=cornell --model=HRED\n```\n\n2. Train VHRED with word drop of ratio 0.25 and kl annealing iterations 250000:\n```\npython train.py --data=ubuntu --model=VHRED --batch_size=40 --word_drop=0.25 --kl_annealing_iter=250000\n```\n\n3. 
Train VHCR with utterance drop of ratio 0.25:\n```\npython train.py --data=ubuntu --model=VHCR --batch_size=40 --sentence_drop=0.25 --kl_annealing_iter=250000\n```\n\nBy default, it will save a model checkpoint every epoch to <save_dir> and a tensorboard summary.\nFor more arguments and options, see config.py.\n\n\n## Evaluation\nTo evaluate the word perplexity:\n```\npython eval.py --model=<model> --checkpoint=<path_to_your_checkpoint>\n```\n\nFor embedding based metrics, you need to download [Google News word vectors](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing), unzip it and put it under the datasets folder.\nThen run:\n```\npython eval_embed.py --model=<model> --checkpoint=<path_to_your_checkpoint>\n```\n\n\n## Reference\n\nIf you use this code or dataset as part of any published research, please refer the following paper.\n\n```\n@inproceedings{VHCR:2018:NAACL,\n    author    = {Yookoon Park and Jaemin Cho and Gunhee Kim},\n    title     = \"{A Hierarchical Latent Structure for Variational Conversation Modeling}\",\n    booktitle = {NAACL},\n    year      = 2018\n    }\n```\n"
  },
  {
    "path": "cornell_preprocess.py",
    "content": "# Preprocess cornell movie dialogs dataset\n\nfrom multiprocessing import Pool\nimport argparse\nimport pickle\nimport random\nimport os\nfrom urllib.request import urlretrieve\nfrom zipfile import ZipFile\nfrom pathlib import Path\nfrom tqdm import tqdm\nfrom model.utils import Tokenizer, Vocab, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN\n\nproject_dir = Path(__file__).resolve().parent\ndatasets_dir = project_dir.joinpath('datasets/')\ncornell_dir = datasets_dir.joinpath('cornell/')\n\n# Tokenizer\ntokenizer = Tokenizer('spacy')\n\ndef prepare_cornell_data():\n    \"\"\"Download and unpack dialogs\"\"\"\n\n    zip_url = 'http://www.mpi-sws.org/~cristian/data/cornell_movie_dialogs_corpus.zip'\n    zipfile_path = datasets_dir.joinpath('cornell.zip')\n\n    if not datasets_dir.exists():\n        datasets_dir.mkdir()\n\n    # Prepare Dialog data\n    if not cornell_dir.exists():\n        print(f'Downloading {zip_url} to {zipfile_path}')\n        urlretrieve(zip_url, zipfile_path)\n        print(f'Successfully downloaded {zipfile_path}')\n\n        zip_ref = ZipFile(zipfile_path, 'r')\n        zip_ref.extractall(datasets_dir)\n        zip_ref.close()\n\n        datasets_dir.joinpath('cornell movie-dialogs corpus').rename(cornell_dir)\n\n    else:\n        print('Cornell Data prepared!')\n\n\ndef loadLines(fileName,\n              fields=[\"lineID\", \"characterID\", \"movieID\", \"character\", \"text\"],\n              delimiter=\" +++$+++ \"):\n    \"\"\"\n    Args:\n        fileName (str): file to load\n        field (set<str>): fields to extract\n    Return:\n        dict<dict<str>>: the extracted fields for each line\n    \"\"\"\n    lines = {}\n\n    with open(fileName, 'r', encoding='iso-8859-1') as f:\n        for line in f:\n            values = line.split(delimiter)\n\n            # Extract fields\n            lineObj = {}\n            for i, field in enumerate(fields):\n                lineObj[field] = values[i]\n\n            lines[lineObj['lineID']] = lineObj\n\n    return lines\n\n\ndef loadConversations(fileName, lines,\n                      fields=[\"character1ID\", \"character2ID\", \"movieID\", \"utteranceIDs\"],\n                      delimiter=\" +++$+++ \"):\n    \"\"\"\n    Args:\n        fileName (str): file to load\n        field (set<str>): fields to extract\n    Return:\n        dict<dict<str>>: the extracted fields for each line\n    \"\"\"\n    conversations = []\n\n    with open(fileName, 'r', encoding='iso-8859-1') as f:\n        for line in f:\n            values = line.split(delimiter)\n\n            # Extract fields\n            convObj = {}\n            for i, field in enumerate(fields):\n                convObj[field] = values[i]\n\n            # Convert string to list (convObj[\"utteranceIDs\"] == \"['L598485', 'L598486', ...]\")\n            lineIds = eval(convObj[\"utteranceIDs\"])\n\n            # Reassemble lines\n            convObj[\"lines\"] = []\n            for lineId in lineIds:\n                convObj[\"lines\"].append(lines[lineId])\n\n            conversations.append(convObj)\n\n    return conversations\n\n\ndef train_valid_test_split_by_conversation(conversations, split_ratio=[0.8, 0.1, 0.1]):\n    \"\"\"Train/Validation/Test split by randomly selected movies\"\"\"\n\n    train_ratio, valid_ratio, test_ratio = split_ratio\n    assert train_ratio + valid_ratio + test_ratio == 1.0\n\n    n_conversations = len(conversations)\n\n    # Random shuffle movie list\n    random.seed(0)\n    random.shuffle(conversations)\n\n    # Train / Validation 
/ Test Split\n    train_split = int(n_conversations * train_ratio)\n    valid_split = int(n_conversations * (train_ratio + valid_ratio))\n\n    train = conversations[:train_split]\n    valid = conversations[train_split:valid_split]\n    test = conversations[valid_split:]\n\n    print(f'Train set: {len(train)} conversations')\n    print(f'Validation set: {len(valid)} conversations')\n    print(f'Test set: {len(test)} conversations')\n\n    return train, valid, test\n\n\ndef tokenize_conversation(lines):\n    sentence_list = [tokenizer(line['text']) for line in lines]\n    return sentence_list\n\n\ndef pad_sentences(conversations, max_sentence_length=30, max_conversation_length=10):\n    def pad_tokens(tokens, max_sentence_length=max_sentence_length):\n        n_valid_tokens = len(tokens)\n        if n_valid_tokens > max_sentence_length - 1:\n            tokens = tokens[:max_sentence_length - 1]\n        n_pad = max_sentence_length - n_valid_tokens - 1\n        tokens = tokens + [EOS_TOKEN] + [PAD_TOKEN] * n_pad\n        return tokens\n\n    def pad_conversation(conversation):\n        conversation = [pad_tokens(sentence) for sentence in conversation]\n        return conversation\n\n    all_padded_sentences = []\n    all_sentence_length = []\n\n    for conversation in conversations:\n        if len(conversation) > max_conversation_length:\n            conversation = conversation[:max_conversation_length]\n        sentence_length = [min(len(sentence) + 1, max_sentence_length) # +1 for EOS token\n                           for sentence in conversation]\n        all_sentence_length.append(sentence_length)\n\n        sentences = pad_conversation(conversation)\n        all_padded_sentences.append(sentences)\n\n    sentences = all_padded_sentences\n    sentence_length = all_sentence_length\n    return sentences, sentence_length\n\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n\n    # Maximum valid length of sentence\n    # => SOS/EOS will surround sentence (EOS for source / SOS for target)\n    # => maximum length of tensor = max_sentence_length + 1\n    parser.add_argument('-s', '--max_sentence_length', type=int, default=30)\n    parser.add_argument('-c', '--max_conversation_length', type=int, default=10)\n\n    # Split Ratio\n    split_ratio = [0.8, 0.1, 0.1]\n\n    # Vocabulary\n    parser.add_argument('--max_vocab_size', type=int, default=20000)\n    parser.add_argument('--min_vocab_frequency', type=int, default=5)\n\n    # Multiprocess\n    parser.add_argument('--n_workers', type=int, default=os.cpu_count())\n\n    args = parser.parse_args()\n\n    max_sent_len = args.max_sentence_length\n    max_conv_len = args.max_conversation_length\n    max_vocab_size = args.max_vocab_size\n    min_freq = args.min_vocab_frequency\n    n_workers = args.n_workers\n\n    # Download and extract dialogs if necessary.\n    prepare_cornell_data()\n\n    print(\"Loading lines\")\n    lines = loadLines(cornell_dir.joinpath(\"movie_lines.txt\"))\n    print('Number of lines:', len(lines))\n\n    print(\"Loading conversations...\")\n    conversations = loadConversations(cornell_dir.joinpath(\"movie_conversations.txt\"), lines)\n    print('Number of conversations:', len(conversations))\n    print('Train/Valid/Test Split')\n    # train, valid, test = train_valid_test_split_by_movie(conversations, split_ratio)\n    train, valid, test = train_valid_test_split_by_conversation(conversations, split_ratio)\n\n    def to_pickle(obj, path):\n        with open(path, 'wb') as f:\n            
pickle.dump(obj, f)\n\n    for split_type, conv_objects in [('train', train), ('valid', valid), ('test', test)]:\n        print(f'Processing {split_type} dataset...')\n        split_data_dir = cornell_dir.joinpath(split_type)\n        split_data_dir.mkdir(exist_ok=True)\n\n        print(f'Tokenizing... (n_workers={n_workers})')\n        def _tokenize_conversation(conv):\n            return tokenize_conversation(conv['lines'])\n        with Pool(n_workers) as pool:\n            conversations = list(tqdm(pool.imap(_tokenize_conversation, conv_objects),\n                                     total=len(conv_objects)))\n\n        conversation_length = [min(len(conv['lines']), max_conv_len)\n                               for conv in conv_objects]\n\n        sentences, sentence_length = pad_sentences(\n            conversations,\n            max_sentence_length=max_sent_len,\n            max_conversation_length=max_conv_len)\n\n        print('Saving preprocessed data at', split_data_dir)\n        to_pickle(conversation_length, split_data_dir.joinpath('conversation_length.pkl'))\n        to_pickle(sentences, split_data_dir.joinpath('sentences.pkl'))\n        to_pickle(sentence_length, split_data_dir.joinpath('sentence_length.pkl'))\n\n        if split_type == 'train':\n\n            print('Saving vocabulary...')\n            vocab = Vocab(tokenizer)\n            vocab.add_dataframe(conversations)\n            vocab.update(max_size=max_vocab_size, min_freq=min_freq)\n\n            print('Vocabulary size: ', len(vocab))\n            vocab.pickle(cornell_dir.joinpath('word2id.pkl'), cornell_dir.joinpath('id2word.pkl'))\n\n    print('Done!')\n"
  },
  {
    "path": "model/__init__.py",
    "content": ""
  },
  {
    "path": "model/configs.py",
    "content": "import os\nimport argparse\nfrom datetime import datetime\nfrom collections import defaultdict\nfrom pathlib import Path\nimport pprint\nfrom torch import optim\nimport torch.nn as nn\nfrom layers.rnncells import StackedLSTMCell, StackedGRUCell\n\nproject_dir = Path(__file__).resolve().parent.parent\ndata_dir = project_dir.joinpath('datasets')\ndata_dict = {'cornell': data_dir.joinpath('cornell'), 'ubuntu': data_dir.joinpath('ubuntu')}\noptimizer_dict = {'RMSprop': optim.RMSprop, 'Adam': optim.Adam}\nrnn_dict = {'lstm': nn.LSTM, 'gru': nn.GRU}\nrnncell_dict = {'lstm': StackedLSTMCell, 'gru': StackedGRUCell}\nusername = Path.home().name\nsave_dir = Path(f'/data1/{username}/conversation/')\n\n\ndef str2bool(v):\n    \"\"\"string to boolean\"\"\"\n    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n        return True\n    elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value expected.')\n\n\nclass Config(object):\n    def __init__(self, **kwargs):\n        \"\"\"Configuration Class: set kwargs as class attributes with setattr\"\"\"\n        if kwargs is not None:\n            for key, value in kwargs.items():\n                if key == 'optimizer':\n                    value = optimizer_dict[value]\n                if key == 'rnn':\n                    value = rnn_dict[value]\n                if key == 'rnncell':\n                    value = rnncell_dict[value]\n                setattr(self, key, value)\n\n        # Dataset directory: ex) ./datasets/cornell/\n        self.dataset_dir = data_dict[self.data.lower()]\n\n        # Data Split ex) 'train', 'valid', 'test'\n        self.data_dir = self.dataset_dir.joinpath(self.mode)\n        # Pickled Vocabulary\n        self.word2id_path = self.dataset_dir.joinpath('word2id.pkl')\n        self.id2word_path = self.dataset_dir.joinpath('id2word.pkl')\n\n        # Pickled Dataframes\n        self.sentences_path = self.data_dir.joinpath('sentences.pkl')\n        self.sentence_length_path = self.data_dir.joinpath('sentence_length.pkl')\n        self.conversation_length_path = self.data_dir.joinpath('conversation_length.pkl')\n\n        # Save path\n        if self.mode == 'train' and self.checkpoint is None:\n            time_now = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')\n            self.save_path = save_dir.joinpath(self.data, self.model, time_now)\n            self.logdir = self.save_path\n            os.makedirs(self.save_path, exist_ok=True)\n        elif self.checkpoint is not None:\n            assert os.path.exists(self.checkpoint)\n            self.save_path = os.path.dirname(self.checkpoint)\n            self.logdir = self.save_path\n\n    def __str__(self):\n        \"\"\"Pretty-print configurations in alphabetical order\"\"\"\n        config_str = 'Configurations\\n'\n        config_str += pprint.pformat(self.__dict__)\n        return config_str\n\n\ndef get_config(parse=True, **optional_kwargs):\n    \"\"\"\n    Get configurations as attributes of class\n    1. Parse configurations with argparse.\n    2. Create Config class initilized with parsed kwargs.\n    3. 
Return Config class.\n    \"\"\"\n    parser = argparse.ArgumentParser()\n\n    # Mode\n    parser.add_argument('--mode', type=str, default='train')\n\n    # Train\n    parser.add_argument('--batch_size', type=int, default=80)\n    parser.add_argument('--eval_batch_size', type=int, default=80)\n    parser.add_argument('--n_epoch', type=int, default=30)\n    parser.add_argument('--learning_rate', type=float, default=1e-4)\n    parser.add_argument('--optimizer', type=str, default='Adam')\n    parser.add_argument('--clip', type=float, default=1.0)\n    parser.add_argument('--checkpoint', type=str, default=None)\n\n    # Generation\n    parser.add_argument('--max_unroll', type=int, default=30)\n    parser.add_argument('--sample', type=str2bool, default=False,\n                        help='if false, use beam search for decoding')\n    parser.add_argument('--temperature', type=float, default=1.0)\n    parser.add_argument('--beam_size', type=int, default=1)\n\n    # Model\n    parser.add_argument('--model', type=str, default='VHCR',\n                        help='one of {HRED, VHRED, VHCR}')\n    # Currently does not support lstm\n    parser.add_argument('--rnn', type=str, default='gru')\n    parser.add_argument('--rnncell', type=str, default='gru')\n    parser.add_argument('--num_layers', type=int, default=1)\n    parser.add_argument('--embedding_size', type=int, default=500)\n    parser.add_argument('--tie_embedding', type=str2bool, default=True)\n    parser.add_argument('--encoder_hidden_size', type=int, default=1000)\n    parser.add_argument('--bidirectional', type=str2bool, default=True)\n    parser.add_argument('--decoder_hidden_size', type=int, default=1000)\n    parser.add_argument('--dropout', type=float, default=0.2)\n    parser.add_argument('--context_size', type=int, default=1000)\n    parser.add_argument('--feedforward', type=str, default='FeedForward')\n    parser.add_argument('--activation', type=str, default='Tanh')\n\n    # VAE model\n    parser.add_argument('--z_sent_size', type=int, default=100)\n    parser.add_argument('--z_conv_size', type=int, default=100)\n    parser.add_argument('--word_drop', type=float, default=0.0,\n                        help='only applied to variational models')\n    parser.add_argument('--kl_threshold', type=float, default=0.0)\n    parser.add_argument('--kl_annealing_iter', type=int, default=25000)\n    parser.add_argument('--importance_sample', type=int, default=100)\n    parser.add_argument('--sentence_drop', type=float, default=0.0)\n\n    # Generation\n    parser.add_argument('--n_context', type=int, default=1)\n    parser.add_argument('--n_sample_step', type=int, default=1)\n\n    # BOW\n    parser.add_argument('--bow', type=str2bool, default=False)\n\n    # Utility\n    parser.add_argument('--print_every', type=int, default=100)\n    parser.add_argument('--plot_every_epoch', type=int, default=1)\n    parser.add_argument('--save_every_epoch', type=int, default=1)\n\n    # Data\n    parser.add_argument('--data', type=str, default='ubuntu')\n\n    # Parse arguments\n    if parse:\n        kwargs = parser.parse_args()\n    else:\n        kwargs = parser.parse_known_args()[0]\n\n    # Namespace => Dictionary\n    kwargs = vars(kwargs)\n    kwargs.update(optional_kwargs)\n\n    return Config(**kwargs)\n"
  },
  {
    "path": "model/data_loader.py",
    "content": "import random\nfrom collections import defaultdict\nfrom torch.utils.data import Dataset, DataLoader\nfrom utils import PAD_ID, UNK_ID, SOS_ID, EOS_ID\nimport numpy as np\n\n\nclass DialogDataset(Dataset):\n    def __init__(self, sentences, conversation_length, sentence_length, vocab, data=None):\n\n        # [total_data_size, max_conversation_length, max_sentence_length]\n        # tokenized raw text of sentences\n        self.sentences = sentences\n        self.vocab = vocab\n\n        # conversation length of each batch\n        # [total_data_size]\n        self.conversation_length = conversation_length\n\n        # list of length of sentences\n        # [total_data_size, max_conversation_length]\n        self.sentence_length = sentence_length\n        self.data = data\n        self.len = len(sentences)\n\n    def __getitem__(self, index):\n        \"\"\"Return Single data sentence\"\"\"\n        # [max_conversation_length, max_sentence_length]\n        sentence = self.sentences[index]\n        conversation_length = self.conversation_length[index]\n        sentence_length = self.sentence_length[index]\n\n        # word => word_ids\n        sentence = self.sent2id(sentence)\n\n        return sentence, conversation_length, sentence_length\n\n    def __len__(self):\n        return self.len\n\n    def sent2id(self, sentences):\n        \"\"\"word => word id\"\"\"\n        # [max_conversation_length, max_sentence_length]\n        return [self.vocab.sent2id(sentence) for sentence in sentences]\n\n\ndef get_loader(sentences, conversation_length, sentence_length, vocab, batch_size=100, data=None, shuffle=True):\n    \"\"\"Load DataLoader of given DialogDataset\"\"\"\n\n    def collate_fn(data):\n        \"\"\"\n        Collate list of data in to batch\n\n        Args:\n            data: list of tuple(source, target, conversation_length, source_length, target_length)\n        Return:\n            Batch of each feature\n            - source (LongTensor): [batch_size, max_conversation_length, max_source_length]\n            - target (LongTensor): [batch_size, max_conversation_length, max_source_length]\n            - conversation_length (np.array): [batch_size]\n            - source_length (LongTensor): [batch_size, max_conversation_length]\n        \"\"\"\n        # Sort by conversation length (descending order) to use 'pack_padded_sequence'\n        data.sort(key=lambda x: x[1], reverse=True)\n\n        # Separate\n        sentences, conversation_length, sentence_length = zip(*data)\n\n        # return sentences, conversation_length, sentence_length.tolist()\n        return sentences, conversation_length, sentence_length\n\n    dataset = DialogDataset(sentences, conversation_length,\n                            sentence_length, vocab, data=data)\n\n    data_loader = DataLoader(\n        dataset=dataset,\n        batch_size=batch_size,\n        shuffle=shuffle,\n        collate_fn=collate_fn)\n\n    return data_loader\n"
  },
  {
    "path": "model/eval.py",
    "content": "from solver import Solver, VariationalSolver\nfrom data_loader import get_loader\nfrom configs import get_config\nfrom utils import Vocab, Tokenizer\nimport os\nimport pickle\nfrom models import VariationalModels\n\n\ndef load_pickle(path):\n    with open(path, 'rb') as f:\n        return pickle.load(f)\n\n\nif __name__ == '__main__':\n    config = get_config(mode='test')\n\n    print('Loading Vocabulary...')\n    vocab = Vocab()\n    vocab.load(config.word2id_path, config.id2word_path)\n    print(f'Vocabulary size: {vocab.vocab_size}')\n\n    config.vocab_size = vocab.vocab_size\n\n    data_loader = get_loader(\n        sentences=load_pickle(config.sentences_path),\n        conversation_length=load_pickle(config.conversation_length_path),\n        sentence_length=load_pickle(config.sentence_length_path),\n        vocab=vocab,\n        batch_size=config.batch_size)\n\n    if config.model in VariationalModels:\n        solver = VariationalSolver(config, None, data_loader, vocab=vocab, is_train=False)\n        solver.build()\n        solver.importance_sample()\n    else:\n        solver = Solver(config, None, data_loader, vocab=vocab, is_train=False)\n        solver.build()\n        solver.test()\n"
  },
  {
    "path": "model/eval_embed.py",
    "content": "from solver import Solver, VariationalSolver\nfrom data_loader import get_loader\nfrom configs import get_config\nfrom utils import Vocab, Tokenizer\nimport os\nimport pickle\nfrom models import VariationalModels\nimport re\n\n\ndef load_pickle(path):\n    with open(path, 'rb') as f:\n        return pickle.load(f)\n\n\nif __name__ == '__main__':\n    config = get_config(mode='test')\n\n    print('Loading Vocabulary...')\n    vocab = Vocab()\n    vocab.load(config.word2id_path, config.id2word_path)\n    print(f'Vocabulary size: {vocab.vocab_size}')\n\n    config.vocab_size = vocab.vocab_size\n\n    data_loader = get_loader(\n        sentences=load_pickle(config.sentences_path),\n        conversation_length=load_pickle(config.conversation_length_path),\n        sentence_length=load_pickle(config.sentence_length_path),\n        vocab=vocab,\n        batch_size=config.batch_size,\n        shuffle=False)\n\n    if config.model in VariationalModels:\n        solver = VariationalSolver(config, None, data_loader, vocab=vocab, is_train=False)\n    else:\n        solver = Solver(config, None, data_loader, vocab=vocab, is_train=False)\n\n    solver.build()\n    solver.embedding_metric()\n"
  },
  {
    "path": "model/layers/__init__.py",
    "content": "from .encoder import *\nfrom .decoder import *\nfrom .rnncells import StackedLSTMCell, StackedGRUCell\nfrom .loss import *\nfrom .feedforward import *\n"
  },
  {
    "path": "model/layers/beam_search.py",
    "content": "import torch\nfrom utils import EOS_ID\n\n\nclass Beam(object):\n    def __init__(self, batch_size, hidden_size, vocab_size, beam_size, max_unroll, batch_position):\n        \"\"\"Beam class for beam search\"\"\"\n        self.batch_size = batch_size\n        self.hidden_size = hidden_size\n        self.vocab_size = vocab_size\n        self.beam_size = beam_size\n        self.max_unroll = max_unroll\n\n        # batch_position [batch_size]\n        #   [0, beam_size, beam_size * 2, .., beam_size * (batch_size-1)]\n        #   Points where batch starts in [batch_size x beam_size] tensors\n        #   Ex. position_idx[5]: when 5-th batch starts\n        self.batch_position = batch_position\n\n        self.log_probs = list()  # [(batch*k, vocab_size)] * sequence_length\n        self.scores = list()  # [(batch*k)] * sequence_length\n        self.back_pointers = list()  # [(batch*k)] * sequence_length\n        self.token_ids = list()  # [(batch*k)] * sequence_length\n        # self.hidden = list()  # [(num_layers, batch*k, hidden_size)] * sequence_length\n\n        self.metadata = {\n            'inputs': None,\n            'output': None,\n            'scores': None,\n            'length': None,\n            'sequence': None,\n        }\n\n    def update(self, score, back_pointer, token_id):  # , h):\n        \"\"\"Append intermediate top-k candidates to beam at each step\"\"\"\n\n        # self.log_probs.append(log_prob)\n        self.scores.append(score)\n        self.back_pointers.append(back_pointer)\n        self.token_ids.append(token_id)\n        # self.hidden.append(h)\n\n    def backtrack(self):\n        \"\"\"Backtracks over batch to generate optimal k-sequences\n\n        Returns:\n            prediction ([batch, k, max_unroll])\n                A list of Tensors containing predicted sequence\n            final_score [batch, k]\n                A list containing the final scores for all top-k sequences\n            length [batch, k]\n                A list specifying the length of each sequence in the top-k candidates\n        \"\"\"\n        prediction = list()\n\n        # import ipdb\n        # ipdb.set_trace()\n        # Initialize for length of top-k sequences\n        length = [[self.max_unroll] * self.beam_size for _ in range(self.batch_size)]\n\n        # Last step output of the beam are not sorted => sort here!\n        # Size not changed [batch size, beam_size]\n        top_k_score, top_k_idx = self.scores[-1].topk(self.beam_size, dim=1)\n\n        # Initialize sequence scores\n        top_k_score = top_k_score.clone()\n\n        n_eos_in_batch = [0] * self.batch_size\n\n        # Initialize Back-pointer from the last step\n        # Add self.position_idx for indexing variable with batch x beam as the first dimension\n        # [batch x beam]\n        back_pointer = (top_k_idx + self.batch_position.unsqueeze(1)).view(-1)\n\n        for t in reversed(range(self.max_unroll)):\n            # Reorder variables with the Back-pointer\n            # [batch x beam]\n            token_id = self.token_ids[t].index_select(0, back_pointer)\n\n            # Reorder the Back-pointer\n            # [batch x beam]\n            back_pointer = self.back_pointers[t].index_select(0, back_pointer)\n\n            # Indices of ended sequences\n            # [< batch x beam]\n            eos_indices = self.token_ids[t].data.eq(EOS_ID).nonzero()\n\n            # For each batch, every time we see an EOS in the backtracking process,\n            # If not all sequences are ended\n     
       #    lowest scored survived sequence <- detected ended sequence\n            # if all sequences are ended\n            #    lowest scored ended sequence <- detected ended sequence\n            if eos_indices.dim() > 0:\n                # Loop over all EOS at current step\n                for i in range(eos_indices.size(0) - 1, -1, -1):\n                    # absolute index of detected ended sequence\n                    eos_idx = eos_indices[i, 0].item()\n\n                    # At which batch EOS is located\n                    batch_idx = eos_idx // self.beam_size\n                    batch_start_idx = batch_idx * self.beam_size\n\n                    # if n_eos_in_batch[batch_idx] > self.beam_size:\n\n                    # Index of sequence with lowest score\n                    _n_eos_in_batch = n_eos_in_batch[batch_idx] % self.beam_size\n                    beam_idx_to_be_replaced = self.beam_size - _n_eos_in_batch - 1\n                    idx_to_be_replaced = batch_start_idx + beam_idx_to_be_replaced\n\n                    # Replace old information with new sequence information\n                    back_pointer[idx_to_be_replaced] = self.back_pointers[t][eos_idx].item()\n                    token_id[idx_to_be_replaced] = self.token_ids[t][eos_idx].item()\n                    top_k_score[batch_idx,\n                                beam_idx_to_be_replaced] = self.scores[t].view(-1)[eos_idx].item()\n                    length[batch_idx][beam_idx_to_be_replaced] = t + 1\n\n                    n_eos_in_batch[batch_idx] += 1\n\n            # max_unroll * [batch x beam]\n            prediction.append(token_id)\n\n        # Sort and re-order again as the added ended sequences may change the order\n        # [batch, beam]\n        top_k_score, top_k_idx = top_k_score.topk(self.beam_size, dim=1)\n        final_score = top_k_score.data\n\n        for batch_idx in range(self.batch_size):\n            length[batch_idx] = [length[batch_idx][beam_idx.item()]\n                                 for beam_idx in top_k_idx[batch_idx]]\n\n        # [batch x beam]\n        top_k_idx = (top_k_idx + self.batch_position.unsqueeze(1)).view(-1)\n\n        # Reverse the sequences and re-order at the same time\n        # It is reversed because the backtracking happens in the reverse order\n        # [batch, beam]\n\n        prediction = [step.index_select(0, top_k_idx).view(\n            self.batch_size, self.beam_size) for step in reversed(prediction)]\n\n        # [batch, beam, max_unroll]\n        prediction = torch.stack(prediction, 2)\n\n        return prediction, final_score, length\n"
  },
  {
    "path": "model/layers/decoder.py",
    "content": "import random\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .rnncells import StackedLSTMCell, StackedGRUCell\nfrom .beam_search import Beam\nfrom .feedforward import FeedForward\nfrom utils import to_var, SOS_ID, UNK_ID, EOS_ID\nimport math\n\n\nclass BaseRNNDecoder(nn.Module):\n    def __init__(self):\n        \"\"\"Base Decoder Class\"\"\"\n        super(BaseRNNDecoder, self).__init__()\n\n    @property\n    def use_lstm(self):\n        return isinstance(self.rnncell, StackedLSTMCell)\n\n    def init_token(self, batch_size, SOS_ID=SOS_ID):\n        \"\"\"Get Variable of <SOS> Index (batch_size)\"\"\"\n        x = to_var(torch.LongTensor([SOS_ID] * batch_size))\n        return x\n\n    def init_h(self, batch_size=None, zero=True, hidden=None):\n        \"\"\"Return RNN initial state\"\"\"\n        if hidden is not None:\n            return hidden\n\n        if self.use_lstm:\n            # (h, c)\n            return (to_var(torch.zeros(self.num_layers,\n                                       batch_size,\n                                       self.hidden_size)),\n                    to_var(torch.zeros(self.num_layers,\n                                       batch_size,\n                                       self.hidden_size)))\n        else:\n            # h\n            return to_var(torch.zeros(self.num_layers,\n                                      batch_size,\n                                      self.hidden_size))\n\n    def batch_size(self, inputs=None, h=None):\n        \"\"\"\n        inputs: [batch_size, seq_len]\n        h: [num_layers, batch_size, hidden_size] (RNN/GRU)\n        h_c: [2, num_layers, batch_size, hidden_size] (LSTMCell)\n        \"\"\"\n        if inputs is not None:\n            batch_size = inputs.size(0)\n            return batch_size\n\n        else:\n            if self.use_lstm:\n                batch_size = h[0].size(1)\n            else:\n                batch_size = h.size(1)\n            return batch_size\n\n    def decode(self, out):\n        \"\"\"\n        Args:\n            out: unnormalized word distribution [batch_size, vocab_size]\n        Return:\n            x: word_index [batch_size]\n        \"\"\"\n\n        # Sample next word from multinomial word distribution\n        if self.sample:\n            # x: [batch_size] - word index (next input)\n            x = torch.multinomial(self.softmax(out / self.temperature), 1).view(-1)\n\n        # Greedy sampling\n        else:\n            # x: [batch_size] - word index (next input)\n            _, x = out.max(dim=1)\n        return x\n\n    def forward(self):\n        \"\"\"Base forward function to inherit\"\"\"\n        raise NotImplementedError\n\n    def forward_step(self):\n        \"\"\"Run RNN single step\"\"\"\n        raise NotImplementedError\n\n    def embed(self, x):\n        \"\"\"word index: [batch_size] => word vectors: [batch_size, hidden_size]\"\"\"\n\n        if self.training and self.word_drop > 0.0:\n            if random.random() < self.word_drop:\n                embed = self.embedding(to_var(x.data.new([UNK_ID] * x.size(0))))\n            else:\n                embed = self.embedding(x)\n        else:\n            embed = self.embedding(x)\n\n        return embed\n\n    def beam_decode(self,\n                    init_h=None,\n                    encoder_outputs=None, input_valid_length=None,\n                    decode=False):\n        \"\"\"\n        Args:\n            encoder_outputs (Variable, FloatTensor): [batch_size, 
source_length, hidden_size]\n            input_valid_length (Variable, LongTensor): [batch_size] (optional)\n            init_h (Variable, FloatTensor): [batch_size, hidden_size] (optional)\n        Return:\n            prediction: [batch_size, beam_size, max_unroll]\n            final_score: [batch_size, beam_size]\n            length: [batch_size, beam_size]\n        \"\"\"\n        batch_size = self.batch_size(h=init_h)\n\n        # [batch_size x beam_size]\n        x = self.init_token(batch_size * self.beam_size, SOS_ID)\n\n        # [num_layers, batch_size x beam_size, hidden_size]\n        h = self.init_h(batch_size, hidden=init_h).repeat(1, self.beam_size, 1)\n\n        # batch_position [batch_size]\n        #   [0, beam_size, beam_size * 2, .., beam_size * (batch_size-1)]\n        #   Points where batch starts in [batch_size x beam_size] tensors\n        #   Ex. position_idx[5]: when 5-th batch starts\n        batch_position = to_var(torch.arange(0, batch_size).long() * self.beam_size)\n\n        # Initialize scores of sequence\n        # [batch_size x beam_size]\n        # Ex. batch_size: 5, beam_size: 3\n        # [0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf]\n        score = torch.ones(batch_size * self.beam_size) * -float('inf')\n        score.index_fill_(0, torch.arange(0, batch_size).long() * self.beam_size, 0.0)\n        score = to_var(score)\n\n        # Initialize Beam that stores decisions for backtracking\n        beam = Beam(\n            batch_size,\n            self.hidden_size,\n            self.vocab_size,\n            self.beam_size,\n            self.max_unroll,\n            batch_position)\n\n        for i in range(self.max_unroll):\n\n            # x: [batch_size x beam_size]; (token index)\n            # =>\n            # out: [batch_size x beam_size, vocab_size]\n            # h: [num_layers, batch_size x beam_size, hidden_size]\n            out, h = self.forward_step(x, h,\n                                       encoder_outputs=encoder_outputs,\n                                       input_valid_length=input_valid_length)\n            # log_prob: [batch_size x beam_size, vocab_size]\n            log_prob = F.log_softmax(out, dim=1)\n\n            # [batch_size x beam_size]\n            # => [batch_size x beam_size, vocab_size]\n            score = score.view(-1, 1) + log_prob\n\n            # Select `beam size` transitions out of `vocab size` combinations\n\n            # [batch_size x beam_size, vocab_size]\n            # => [batch_size, beam_size x vocab_size]\n            # Cutoff and retain candidates with top-k scores\n            # score: [batch_size, beam_size]\n            # top_k_idx: [batch_size, beam_size]\n            #       each element of top_k_idx [0 ~ beam x vocab)\n\n            score, top_k_idx = score.view(batch_size, -1).topk(self.beam_size, dim=1)\n\n            # Get token ids: the remainder of top_k_idx divided by vocab_size\n            # Each element is among [0, vocab_size)\n            # Ex. 
Index of token 3 in beam 4\n            # (4 * vocab size) + 3 => 3\n            # x: [batch_size x beam_size]\n            x = (top_k_idx % self.vocab_size).view(-1)\n\n            # top-k-pointer [batch_size x beam_size]\n            #       Points top-k beam that scored best at current step\n            #       Later used as back-pointer at backtracking\n            #       Each element is beam index: 0 ~ beam_size\n            #                     + position index: 0 ~ beam_size x (batch_size-1)\n            beam_idx = top_k_idx / self.vocab_size  # [batch_size, beam_size]\n            top_k_pointer = (beam_idx + batch_position.unsqueeze(1)).view(-1)\n\n            # Select next h (size doesn't change)\n            # [num_layers, batch_size * beam_size, hidden_size]\n            h = h.index_select(1, top_k_pointer)\n\n            # Update sequence scores at beam\n            beam.update(score.clone(), top_k_pointer, x)  # , h)\n\n            # Erase scores for EOS so that they are not expanded\n            # [batch_size, beam_size]\n            eos_idx = x.data.eq(EOS_ID).view(batch_size, self.beam_size)\n            if eos_idx.nonzero().dim() > 0:\n                score.data.masked_fill_(eos_idx, -float('inf'))\n\n        # prediction ([batch, k, max_unroll])\n        #     A list of Tensors containing predicted sequence\n        # final_score [batch, k]\n        #     A list containing the final scores for all top-k sequences\n        # length [batch, k]\n        #     A list specifying the length of each sequence in the top-k candidates\n        # prediction, final_score, length = beam.backtrack()\n        prediction, final_score, length = beam.backtrack()\n\n        return prediction, final_score, length\n\n\nclass DecoderRNN(BaseRNNDecoder):\n    def __init__(self, vocab_size, embedding_size,\n                 hidden_size, rnncell=StackedGRUCell, num_layers=1,\n                 dropout=0.0, word_drop=0.0,\n                 max_unroll=30, sample=True, temperature=1.0, beam_size=1):\n        super(DecoderRNN, self).__init__()\n\n        self.vocab_size = vocab_size\n        self.embedding_size = embedding_size\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.temperature = temperature\n        self.word_drop = word_drop\n        self.max_unroll = max_unroll\n        self.sample = sample\n        self.beam_size = beam_size\n\n        self.embedding = nn.Embedding(vocab_size, embedding_size)\n\n        self.rnncell = rnncell(num_layers,\n                               embedding_size,\n                               hidden_size,\n                               dropout)\n        self.out = nn.Linear(hidden_size, vocab_size)\n        self.softmax = nn.Softmax(dim=1)\n\n    def forward_step(self, x, h,\n                     encoder_outputs=None,\n                     input_valid_length=None):\n        \"\"\"\n        Single RNN Step\n        1. Input Embedding (vocab_size => hidden_size)\n        2. RNN Step (hidden_size => hidden_size)\n        3. 
Output Projection (hidden_size => vocab_size)\n\n        Args:\n            x: [batch_size]\n            h: [num_layers, batch_size, hidden_size] (h and c from all layers)\n\n        Return:\n            out: [batch_size, vocab_size] (Unnormalized word distribution)\n            h: [num_layers, batch_size, hidden_size] (h and c from all layers)\n        \"\"\"\n        # x: [batch_size] => [batch_size, hidden_size]\n        x = self.embed(x)\n\n        # last_h: [batch_size, hidden_size] (h from Top RNN layer)\n        # h: [num_layers, batch_size, hidden_size] (h and c from all layers)\n        last_h, h = self.rnncell(x, h)\n\n        if self.use_lstm:\n            # last_h_c: [2, batch_size, hidden_size] (h from Top RNN layer)\n            # h_c: [2, num_layers, batch_size, hidden_size] (h and c from all layers)\n            last_h = last_h[0]\n\n        # Unnormalized word distribution\n        # out: [batch_size, vocab_size]\n        out = self.out(last_h)\n        return out, h\n\n    def forward(self, inputs, init_h=None, encoder_outputs=None, input_valid_length=None,\n                decode=False):\n        \"\"\"\n        Train (decode=False)\n            Args:\n                inputs (Variable, LongTensor): [batch_size, seq_len]\n                init_h: (Variable, FloatTensor): [num_layers, batch_size, hidden_size]\n            Return:\n                out   : [batch_size, seq_len, vocab_size]\n        Test (decode=True)\n            Args:\n                inputs: None\n                init_h: (Variable, FloatTensor): [num_layers, batch_size, hidden_size]\n            Return:\n                out   : [batch_size, seq_len]\n        \"\"\"\n        batch_size = self.batch_size(inputs, init_h)\n\n        # x: [batch_size]\n        x = self.init_token(batch_size, SOS_ID)\n\n        # h: [num_layers, batch_size, hidden_size]\n        h = self.init_h(batch_size, hidden=init_h)\n\n        if not decode:\n            out_list = []\n            seq_len = inputs.size(1)\n            for i in range(seq_len):\n\n                # x: [batch_size]\n                # =>\n                # out: [batch_size, vocab_size]\n                # h: [num_layers, batch_size, hidden_size] (h and c from all layers)\n                out, h = self.forward_step(x, h)\n\n                out_list.append(out)\n                # Teacher forcing: the ground-truth token is fed as the next input\n                x = inputs[:, i]\n\n            # [batch_size, max_target_len, vocab_size]\n            return torch.stack(out_list, dim=1)\n        else:\n            x_list = []\n            for i in range(self.max_unroll):\n\n                # x: [batch_size]\n                # =>\n                # out: [batch_size, vocab_size]\n                # h: [num_layers, batch_size, hidden_size] (h and c from all layers)\n                out, h = self.forward_step(x, h)\n\n                # out: [batch_size, vocab_size]\n                # => x: [batch_size]\n                x = self.decode(out)\n                x_list.append(x)\n\n            # [batch_size, max_target_len]\n            return torch.stack(x_list, dim=1)\n"
  },
  {
    "path": "model/layers/encoder.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence\nfrom utils import to_var, reverse_order_valid, PAD_ID\nfrom .rnncells import StackedGRUCell, StackedLSTMCell\n\nimport copy\n\nclass BaseRNNEncoder(nn.Module):\n    def __init__(self):\n        \"\"\"Base RNN Encoder Class\"\"\"\n        super(BaseRNNEncoder, self).__init__()\n\n    @property\n    def use_lstm(self):\n        if hasattr(self, 'rnn'):\n            return isinstance(self.rnn, nn.LSTM)\n        else:\n            raise AttributeError('no rnn selected')\n\n    def init_h(self, batch_size=None, hidden=None):\n        \"\"\"Return RNN initial state\"\"\"\n        if hidden is not None:\n            return hidden\n\n        if self.use_lstm:\n            return (to_var(torch.zeros(self.num_layers*self.num_directions,\n                                      batch_size,\n                                      self.hidden_size)),\n                    to_var(torch.zeros(self.num_layers*self.num_directions,\n                                      batch_size,\n                                      self.hidden_size)))\n        else:\n            return to_var(torch.zeros(self.num_layers*self.num_directions,\n                                        batch_size,\n                                        self.hidden_size))\n\n    def batch_size(self, inputs=None, h=None):\n        \"\"\"\n        inputs: [batch_size, seq_len]\n        h: [num_layers, batch_size, hidden_size] (RNN/GRU)\n        h_c: [2, num_layers, batch_size, hidden_size] (LSTM)\n        \"\"\"\n        if inputs is not None:\n            batch_size = inputs.size(0)\n            return batch_size\n\n        else:\n            if self.use_lstm:\n                batch_size = h[0].size(1)\n            else:\n                batch_size = h.size(1)\n            return batch_size\n\n    def forward(self):\n        raise NotImplementedError\n\n\nclass EncoderRNN(BaseRNNEncoder):\n    def __init__(self, vocab_size, embedding_size,\n                 hidden_size, rnn=nn.GRU, num_layers=1, bidirectional=False,\n                 dropout=0.0, bias=True, batch_first=True):\n        \"\"\"Sentence-level Encoder\"\"\"\n        super(EncoderRNN, self).__init__()\n\n        self.vocab_size = vocab_size\n        self.embedding_size = embedding_size\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.batch_first = batch_first\n        self.bidirectional = bidirectional\n\n        if bidirectional:\n            self.num_directions = 2\n        else:\n            self.num_directions = 1\n\n        # word embedding\n        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=PAD_ID)\n\n        self.rnn = rnn(input_size=embedding_size,\n                        hidden_size=hidden_size,\n                        num_layers=num_layers,\n                        bias=bias,\n                        batch_first=batch_first,\n                        dropout=dropout,\n                        bidirectional=bidirectional)\n\n    def forward(self, inputs, input_length, hidden=None):\n        \"\"\"\n        Args:\n            inputs (Variable, LongTensor): [num_setences, max_seq_len]\n            input_length (Variable, LongTensor): [num_sentences]\n        Return:\n            outputs (Variable): [max_source_length, batch_size, hidden_size]\n                - list of all hidden states\n            
hidden ((tuple of) Variable): [num_layers*num_directions, batch_size, hidden_size]\n                - last hidden state\n                - (h, c) or h\n        \"\"\"\n        batch_size, seq_len = inputs.size()\n\n        # Sort in decreasing order of length for pack_padded_sequence()\n        input_length_sorted, indices = input_length.sort(descending=True)\n\n        input_length_sorted = input_length_sorted.data.tolist()\n\n        # [num_sentences, max_source_length]\n        inputs_sorted = inputs.index_select(0, indices)\n\n        # [num_sentences, max_source_length, embedding_dim]\n        embedded = self.embedding(inputs_sorted)\n\n        # batch_first=True\n        rnn_input = pack_padded_sequence(embedded, input_length_sorted,\n                                            batch_first=self.batch_first)\n\n        hidden = self.init_h(batch_size, hidden=hidden)\n\n        # outputs: [batch, seq_len, hidden_size * num_directions]\n        # hidden: [num_layers * num_directions, batch, hidden_size]\n        self.rnn.flatten_parameters()\n        outputs, hidden = self.rnn(rnn_input, hidden)\n        outputs, outputs_lengths = pad_packed_sequence(outputs, batch_first=self.batch_first)\n\n        # Reorder outputs and hidden\n        _, inverse_indices = indices.sort()\n        outputs = outputs.index_select(0, inverse_indices)\n\n        if self.use_lstm:\n            hidden = (hidden[0].index_select(1, inverse_indices),\n                        hidden[1].index_select(1, inverse_indices))\n        else:\n            hidden = hidden.index_select(1, inverse_indices)\n\n        return outputs, hidden\n\nclass ContextRNN(BaseRNNEncoder):\n    def __init__(self, input_size, context_size, rnn=nn.GRU, num_layers=1, dropout=0.0,\n                 bidirectional=False, bias=True, batch_first=True):\n        \"\"\"Context-level Encoder\"\"\"\n        super(ContextRNN, self).__init__()\n\n        self.input_size = input_size\n        self.context_size = context_size\n        self.hidden_size = self.context_size\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.bidirectional = bidirectional\n        self.batch_first = batch_first\n\n        if bidirectional:\n            self.num_directions = 2\n        else:\n            self.num_directions = 1\n\n        self.rnn = rnn(input_size=input_size,\n                        hidden_size=context_size,\n                        num_layers=num_layers,\n                        bias=bias,\n                        batch_first=batch_first,\n                        dropout=dropout,\n                        bidirectional=bidirectional)\n\n    def forward(self, encoder_hidden, conversation_length, hidden=None):\n        \"\"\"\n        Args:\n            encoder_hidden (Variable, FloatTensor): [batch_size, max_len, num_layers * direction * hidden_size]\n            conversation_length (Variable, LongTensor): [batch_size]\n        Return:\n            outputs (Variable): [batch_size, max_seq_len, hidden_size]\n                - list of all hidden states\n            hidden ((tuple of) Variable): [num_layers*num_directions, batch_size, hidden_size]\n                - last hidden state\n                - (h, c) or h\n        \"\"\"\n        batch_size, seq_len, _  = encoder_hidden.size()\n\n        # Sort for PackedSequence\n        conv_length_sorted, indices = conversation_length.sort(descending=True)\n        conv_length_sorted = conv_length_sorted.data.tolist()\n        encoder_hidden_sorted = encoder_hidden.index_select(0, 
indices)\n\n        rnn_input = pack_padded_sequence(encoder_hidden_sorted, conv_length_sorted, batch_first=True)\n\n        hidden = self.init_h(batch_size, hidden=hidden)\n\n        self.rnn.flatten_parameters()\n        outputs, hidden = self.rnn(rnn_input, hidden)\n\n        # outputs: [batch_size, max_conversation_length, context_size]\n        outputs, outputs_length = pad_packed_sequence(outputs, batch_first=True)\n\n        # reorder outputs and hidden\n        _, inverse_indices = indices.sort()\n        outputs = outputs.index_select(0, inverse_indices)\n\n        if self.use_lstm:\n            hidden = (hidden[0].index_select(1, inverse_indices),\n                    hidden[1].index_select(1, inverse_indices))\n        else:\n            hidden = hidden.index_select(1, inverse_indices)\n\n        # outputs: [batch, seq_len, hidden_size * num_directions]\n        # hidden: [num_layers * num_directions, batch, hidden_size]\n        return outputs, hidden\n\n    def step(self, encoder_hidden, hidden):\n\n        batch_size = encoder_hidden.size(0)\n        # encoder_hidden: [1, batch_size, hidden_size]\n        encoder_hidden = torch.unsqueeze(encoder_hidden, 1)\n\n        if hidden is None:\n            hidden = self.init_h(batch_size, hidden=None)\n\n        outputs, hidden = self.rnn(encoder_hidden, hidden)\n        return outputs, hidden\n"
  },
  {
    "path": "model/layers/feedforward.py",
    "content": "import torch\nimport torch.nn as nn\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, input_size, output_size, num_layers=1, hidden_size=None,\n                 activation=\"Tanh\", bias=True):\n        super(FeedForward, self).__init__()\n        self.input_size = input_size\n        self.output_size = output_size\n        self.hidden_size = hidden_size\n        self.num_layers = num_layers\n        self.activation = getattr(nn, activation)()\n        n_inputs = [input_size] + [hidden_size] * (num_layers - 1)\n        n_outputs = [hidden_size] * (num_layers - 1) + [output_size]\n        self.linears = nn.ModuleList([nn.Linear(n_in, n_out, bias=bias)\n                                      for n_in, n_out in zip(n_inputs, n_outputs)])\n\n    def forward(self, input):\n        x = input\n        for linear in self.linears:\n            x = linear(x)\n            x = self.activation(x)\n\n        return x\n"
  },
  {
    "path": "model/layers/loss.py",
    "content": "import torch\nfrom torch.nn import functional as F\nimport torch.nn as nn\nfrom utils import to_var, sequence_mask\n\n\n# https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1\ndef masked_cross_entropy(logits, target, length, per_example=False):\n    \"\"\"\n    Args:\n        logits (Variable, FloatTensor): [batch, max_len, num_classes]\n            - unnormalized probability for each class\n        target (Variable, LongTensor): [batch, max_len]\n            - index of true class for each corresponding step\n        length (Variable, LongTensor): [batch]\n            - length of each data in a batch\n    Returns:\n        loss (Variable): []\n            - An average loss value masked by the length\n    \"\"\"\n    batch_size, max_len, num_classes = logits.size()\n\n    # [batch_size * max_len, num_classes]\n    logits_flat = logits.view(-1, num_classes)\n\n    # [batch_size * max_len, num_classes]\n    log_probs_flat = F.log_softmax(logits_flat, dim=1)\n\n    # [batch_size * max_len, 1]\n    target_flat = target.view(-1, 1)\n\n    # Negative Log-likelihood: -sum {  1* log P(target)  + 0 log P(non-target)} = -sum( log P(target) )\n    # [batch_size * max_len, 1]\n    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)\n\n    # [batch_size, max_len]\n    losses = losses_flat.view(batch_size, max_len)\n\n    # [batch_size, max_len]\n    mask = sequence_mask(sequence_length=length, max_len=max_len)\n\n    # Apply masking on loss\n    losses = losses * mask.float()\n\n    # word-wise cross entropy\n    # loss = losses.sum() / length.float().sum()\n\n    if per_example:\n        # loss: [batch_size]\n        return losses.sum(1)\n    else:\n        loss = losses.sum()\n        return loss, length.float().sum()\n"
  },
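  {
    "path": "examples/masked_loss_demo.py",
    "content": "# Illustrative sketch of masked_cross_entropy on a toy batch. Assumes it is run\n# from the model/ directory so that layers and utils are importable; all numbers\n# are arbitrary.\nimport torch\nfrom layers import masked_cross_entropy\n\nbatch_size, max_len, vocab_size = 2, 5, 10\n\nlogits = torch.randn(batch_size, max_len, vocab_size)\ntarget = torch.LongTensor(batch_size, max_len).random_(0, vocab_size)\nlength = torch.LongTensor([5, 3])  # the second example has 2 padded steps\n\n# loss is summed over unmasked tokens; n_words counts them\nloss, n_words = masked_cross_entropy(logits, target, length)\nprint(loss.item() / n_words.item())  # per-word cross entropy (nats)\n\n# per_example=True returns one summed loss per example instead\nper_example = masked_cross_entropy(logits, target, length, per_example=True)\nprint(per_example.size())  # torch.Size([2])\n"
  },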
  {
    "path": "model/layers/rnncells.py",
    "content": "# Modified from OpenNMT.py, Z-forcing\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn._functions.thnn.rnnFusedPointwise import LSTMFused, GRUFused\n\n\nclass StackedLSTMCell(nn.Module):\n\n    def __init__(self, num_layers, input_size, rnn_size, dropout):\n        super(StackedLSTMCell, self).__init__()\n        self.dropout = nn.Dropout(dropout)\n        self.num_layers = num_layers\n\n        self.layers = nn.ModuleList()\n        for i in range(num_layers):\n            self.layers.append(nn.LSTMCell(input_size, rnn_size))\n            input_size = rnn_size\n\n    def forward(self, x, h_c):\n        \"\"\"\n        Args:\n            x: [batch_size, input_size]\n            h_c: [2, num_layers, batch_size, hidden_size]\n        Return:\n            last_h_c: [2, batch_size, hidden_size] (h from last layer)\n            h_c_list: [2, num_layers, batch_size, hidden_size] (h and c from all layers)\n        \"\"\"\n        h_0, c_0 = h_c\n        h_list, c_list = [], []\n        for i, layer in enumerate(self.layers):\n            # h of i-th layer\n            h_i, c_i = layer(x, (h_0[i], c_0[i]))\n\n            # x for next layer\n            x = h_i\n            if i + 1 != self.num_layers:\n                x = self.dropout(x)\n            h_list += [h_i]\n            c_list += [c_i]\n\n        last_h_c = (h_list[-1], c_list[-1])\n        h_list = torch.stack(h_list)\n        c_list = torch.stack(c_list)\n        h_c_list = (h_list, c_list)\n\n        return last_h_c, h_c_list\n\n\nclass StackedGRUCell(nn.Module):\n\n    def __init__(self, num_layers, input_size, rnn_size, dropout):\n        super(StackedGRUCell, self).__init__()\n        self.dropout = nn.Dropout(dropout)\n        self.num_layers = num_layers\n\n        self.layers = nn.ModuleList()\n        for i in range(num_layers):\n            self.layers.append(nn.GRUCell(input_size, rnn_size))\n            input_size = rnn_size\n\n    def forward(self, x, h):\n        \"\"\"\n        Args:\n            x: [batch_size, input_size]\n            h: [num_layers, batch_size, hidden_size]\n        Return:\n            last_h: [batch_size, hidden_size] (h from last layer)\n            h_list: [num_layers, batch_size, hidden_size] (h from all layers)\n        \"\"\"\n        # h of all layers\n        h_list = []\n        for i, layer in enumerate(self.layers):\n            # h of i-th layer\n            h_i = layer(x, h[i])\n\n            # x for next layer\n            x = h_i\n            if i + 1 is not self.num_layers:\n                x = self.dropout(x)\n            h_list.append(h_i)\n\n        last_h = h_list[-1]\n        h_list = torch.stack(h_list)\n\n        return last_h, h_list\n"
  },
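  {
    "path": "examples/stacked_cell_demo.py",
    "content": "# Toy forward pass through StackedGRUCell, one decoding step at a time. The\n# import path and all sizes are illustrative; only the rnncells module above is\n# needed (run from the model/ directory).\nimport torch\nfrom layers.rnncells import StackedGRUCell\n\nnum_layers, batch_size, input_size, rnn_size = 2, 4, 8, 16\n\ncell = StackedGRUCell(num_layers, input_size, rnn_size, dropout=0.1)\n\nx = torch.randn(batch_size, input_size)            # input at one time step\nh = torch.zeros(num_layers, batch_size, rnn_size)  # initial hidden state\n\nfor _ in range(3):  # unroll a few steps (a real decoder would feed in new embeddings)\n    last_h, h = cell(x, h)\n    # last_h: [batch_size, rnn_size] is what a decoder would project to logits\n\nprint(last_h.size(), h.size())\n"
  },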
  {
    "path": "model/models.py",
    "content": "import torch\nimport torch.nn as nn\nfrom utils import to_var, pad, normal_kl_div, normal_logpdf, bag_of_words_loss, to_bow, EOS_ID\nimport layers\nimport numpy as np\nimport random\n\nVariationalModels = ['VHRED', 'VHCR']\n\nclass HRED(nn.Module):\n    def __init__(self, config):\n        super(HRED, self).__init__()\n\n        self.config = config\n        self.encoder = layers.EncoderRNN(config.vocab_size,\n                                         config.embedding_size,\n                                         config.encoder_hidden_size,\n                                         config.rnn,\n                                         config.num_layers,\n                                         config.bidirectional,\n                                         config.dropout)\n\n        context_input_size = (config.num_layers\n                              * config.encoder_hidden_size\n                              * self.encoder.num_directions)\n        self.context_encoder = layers.ContextRNN(context_input_size,\n                                                 config.context_size,\n                                                 config.rnn,\n                                                 config.num_layers,\n                                                 config.dropout)\n\n        self.decoder = layers.DecoderRNN(config.vocab_size,\n                                         config.embedding_size,\n                                         config.decoder_hidden_size,\n                                         config.rnncell,\n                                         config.num_layers,\n                                         config.dropout,\n                                         config.word_drop,\n                                         config.max_unroll,\n                                         config.sample,\n                                         config.temperature,\n                                         config.beam_size)\n\n        self.context2decoder = layers.FeedForward(config.context_size,\n                                                  config.num_layers * config.decoder_hidden_size,\n                                                  num_layers=1,\n                                                  activation=config.activation)\n\n        if config.tie_embedding:\n            self.decoder.embedding = self.encoder.embedding\n\n    def forward(self, input_sentences, input_sentence_length,\n                input_conversation_length, target_sentences, decode=False):\n        \"\"\"\n        Args:\n            input_sentences: (Variable, LongTensor) [num_sentences, seq_len]\n            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]\n        Return:\n            decoder_outputs: (Variable, FloatTensor)\n                - train: [batch_size, seq_len, vocab_size]\n                - eval: [batch_size, seq_len]\n        \"\"\"\n        num_sentences = input_sentences.size(0)\n        max_len = input_conversation_length.data.max().item()\n\n        # encoder_outputs: [num_sentences, max_source_length, hidden_size * direction]\n        # encoder_hidden: [num_layers * direction, num_sentences, hidden_size]\n        encoder_outputs, encoder_hidden = self.encoder(input_sentences,\n                                                       input_sentence_length)\n\n        # encoder_hidden: [num_sentences, num_layers * direction * hidden_size]\n        encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(num_sentences, -1)\n\n        # pad 
and pack encoder_hidden\n        start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()),\n                                        input_conversation_length[:-1])), 0)\n\n        # encoder_hidden: [batch_size, max_len, num_layers * direction * hidden_size]\n        encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l), max_len)\n                                      for s, l in zip(start.data.tolist(),\n                                                      input_conversation_length.data.tolist())], 0)\n\n        # context_outputs: [batch_size, max_len, context_size]\n        context_outputs, context_last_hidden = self.context_encoder(encoder_hidden,\n                                                                    input_conversation_length)\n\n        # flatten outputs\n        # context_outputs: [num_sentences, context_size]\n        context_outputs = torch.cat([context_outputs[i, :l, :]\n                                     for i, l in enumerate(input_conversation_length.data)])\n\n        # project context_outputs to decoder init state\n        decoder_init = self.context2decoder(context_outputs)\n\n        # [num_layers, batch_size, hidden_size]\n        decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size)\n\n        # train: [batch_size, seq_len, vocab_size]\n        # eval: [batch_size, seq_len]\n        if not decode:\n\n            decoder_outputs = self.decoder(target_sentences,\n                                           init_h=decoder_init,\n                                           decode=decode)\n            return decoder_outputs\n\n        else:\n            # prediction: [batch_size, beam_size, max_unroll]\n            prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)\n\n            # all beam candidates are returned; take prediction[:, 0] for the top one\n            return prediction\n\n    def generate(self, context, sentence_length, n_context):\n        # context: [batch_size, n_context, seq_len]\n        batch_size = context.size(0)\n        samples = []\n\n        # Run for context\n        context_hidden = None\n        for i in range(n_context):\n            # encoder_outputs: [batch_size, seq_len, hidden_size * direction]\n            # encoder_hidden: [num_layers * direction, batch_size, hidden_size]\n            encoder_outputs, encoder_hidden = self.encoder(context[:, i, :],\n                                                           sentence_length[:, i])\n\n            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)\n            # context_outputs: [batch_size, 1, context_hidden_size * direction]\n            # context_hidden: [num_layers * direction, batch_size, context_hidden_size]\n            context_outputs, context_hidden = self.context_encoder.step(encoder_hidden,\n                                                                        context_hidden)\n\n        # Run for generation\n        for j in range(self.config.n_sample_step):\n            # context_outputs: [batch_size, context_hidden_size * direction]\n            context_outputs = 
context_outputs.squeeze(1)\n            decoder_init = self.context2decoder(context_outputs)\n            decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size)\n\n            prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)\n            # prediction: [batch_size, seq_len]\n            prediction = prediction[:, 0, :]\n            # length: [batch_size]\n            length = [l[0] for l in length]\n            length = to_var(torch.LongTensor(length))\n            samples.append(prediction)\n\n            encoder_outputs, encoder_hidden = self.encoder(prediction,\n                                                           length)\n\n            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)\n\n            context_outputs, context_hidden = self.context_encoder.step(encoder_hidden,\n                                                                        context_hidden)\n\n        samples = torch.stack(samples, 1)\n        return samples\n\n\nclass VHRED(nn.Module):\n    def __init__(self, config):\n        super(VHRED, self).__init__()\n\n        self.config = config\n        self.encoder = layers.EncoderRNN(config.vocab_size,\n                                         config.embedding_size,\n                                         config.encoder_hidden_size,\n                                         config.rnn,\n                                         config.num_layers,\n                                         config.bidirectional,\n                                         config.dropout)\n\n        context_input_size = (config.num_layers\n                              * config.encoder_hidden_size\n                              * self.encoder.num_directions)\n        self.context_encoder = layers.ContextRNN(context_input_size,\n                                                 config.context_size,\n                                                 config.rnn,\n                                                 config.num_layers,\n                                                 config.dropout)\n\n        self.decoder = layers.DecoderRNN(config.vocab_size,\n                                         config.embedding_size,\n                                         config.decoder_hidden_size,\n                                         config.rnncell,\n                                         config.num_layers,\n                                         config.dropout,\n                                         config.word_drop,\n                                         config.max_unroll,\n                                         config.sample,\n                                         config.temperature,\n                                         config.beam_size)\n\n        self.context2decoder = layers.FeedForward(config.context_size + config.z_sent_size,\n                                                  config.num_layers * config.decoder_hidden_size,\n                                                  num_layers=1,\n                                                  activation=config.activation)\n\n        self.softplus = nn.Softplus()\n        self.prior_h = layers.FeedForward(config.context_size,\n                                          config.context_size,\n                                          num_layers=2,\n                                          hidden_size=config.context_size,\n                                          activation=config.activation)\n        self.prior_mu = 
nn.Linear(config.context_size,\n                                  config.z_sent_size)\n        self.prior_var = nn.Linear(config.context_size,\n                                   config.z_sent_size)\n\n        self.posterior_h = layers.FeedForward(config.encoder_hidden_size * self.encoder.num_directions * config.num_layers + config.context_size,\n                                              config.context_size,\n                                              num_layers=2,\n                                              hidden_size=config.context_size,\n                                              activation=config.activation)\n        self.posterior_mu = nn.Linear(config.context_size,\n                                      config.z_sent_size)\n        self.posterior_var = nn.Linear(config.context_size,\n                                       config.z_sent_size)\n        if config.tie_embedding:\n            self.decoder.embedding = self.encoder.embedding\n\n        if config.bow:\n            self.bow_h = layers.FeedForward(config.z_sent_size,\n                                            config.decoder_hidden_size,\n                                            num_layers=1,\n                                            hidden_size=config.decoder_hidden_size,\n                                            activation=config.activation)\n            self.bow_predict = nn.Linear(config.decoder_hidden_size, config.vocab_size)\n\n    def prior(self, context_outputs):\n        # Context dependent prior\n        h_prior = self.prior_h(context_outputs)\n        mu_prior = self.prior_mu(h_prior)\n        var_prior = self.softplus(self.prior_var(h_prior))\n        return mu_prior, var_prior\n\n    def posterior(self, context_outputs, encoder_hidden):\n        h_posterior = self.posterior_h(torch.cat([context_outputs, encoder_hidden], 1))\n        mu_posterior = self.posterior_mu(h_posterior)\n        var_posterior = self.softplus(self.posterior_var(h_posterior))\n        return mu_posterior, var_posterior\n\n    def compute_bow_loss(self, target_conversations):\n        target_bow = np.stack([to_bow(sent, self.config.vocab_size) for conv in target_conversations for sent in conv], axis=0)\n        target_bow = to_var(torch.FloatTensor(target_bow))\n        bow_logits = self.bow_predict(self.bow_h(self.z_sent))\n        bow_loss = bag_of_words_loss(bow_logits, target_bow)\n        return bow_loss\n\n    def forward(self, sentences, sentence_length,\n                input_conversation_length, target_sentences, decode=False):\n        \"\"\"\n        Args:\n            sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]\n            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]\n        Return:\n            decoder_outputs: (Variable, FloatTensor)\n                - train: [batch_size, seq_len, vocab_size]\n                - eval: [batch_size, seq_len]\n        \"\"\"\n        batch_size = input_conversation_length.size(0)\n        num_sentences = sentences.size(0) - batch_size\n        max_len = input_conversation_length.data.max().item()\n\n        # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size]\n        # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size]\n        encoder_outputs, encoder_hidden = self.encoder(sentences,\n                                                       sentence_length)\n\n        # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size]\n        
encoder_hidden = encoder_hidden.transpose(\n            1, 0).contiguous().view(num_sentences + batch_size, -1)\n\n        # pad and pack encoder_hidden\n        start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()),\n                                        input_conversation_length[:-1] + 1)), 0)\n        # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size]\n        encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1)\n                                      for s, l in zip(start.data.tolist(),\n                                                      input_conversation_length.data.tolist())], 0)\n\n        # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size]\n        encoder_hidden_inference = encoder_hidden[:, 1:, :]\n        encoder_hidden_inference_flat = torch.cat(\n            [encoder_hidden_inference[i, :l, :] for i, l in enumerate(input_conversation_length.data)])\n\n        # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size]\n        encoder_hidden_input = encoder_hidden[:, :-1, :]\n\n        # context_outputs: [batch_size, max_len, context_size]\n        context_outputs, context_last_hidden = self.context_encoder(encoder_hidden_input,\n                                                                    input_conversation_length)\n        # flatten outputs\n        # context_outputs: [num_sentences, context_size]\n        context_outputs = torch.cat([context_outputs[i, :l, :]\n                                     for i, l in enumerate(input_conversation_length.data)])\n\n        mu_prior, var_prior = self.prior(context_outputs)\n        eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))\n        if not decode:\n            mu_posterior, var_posterior = self.posterior(\n                context_outputs, encoder_hidden_inference_flat)\n            z_sent = mu_posterior + torch.sqrt(var_posterior) * eps\n            log_q_zx = normal_logpdf(z_sent, mu_posterior, var_posterior).sum()\n\n            log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()\n            # kl_div: [num_sentences]\n            kl_div = normal_kl_div(mu_posterior, var_posterior,\n                                   mu_prior, var_prior)\n            kl_div = torch.sum(kl_div)\n        else:\n            z_sent = mu_prior + torch.sqrt(var_prior) * eps\n            kl_div = None\n            log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()\n            log_q_zx = None\n\n        self.z_sent = z_sent\n        latent_context = torch.cat([context_outputs, z_sent], 1)\n        decoder_init = self.context2decoder(latent_context)\n        decoder_init = decoder_init.view(-1,\n                                         self.decoder.num_layers,\n                                         self.decoder.hidden_size)\n        decoder_init = decoder_init.transpose(1, 0).contiguous()\n\n        # train: [batch_size, seq_len, vocab_size]\n        # eval: [batch_size, seq_len]\n        if not decode:\n\n            decoder_outputs = self.decoder(target_sentences,\n                                           init_h=decoder_init,\n                                           decode=decode)\n\n            return decoder_outputs, kl_div, log_p_z, log_q_zx\n\n        else:\n            # prediction: [batch_size, beam_size, max_unroll]\n            prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)\n\n            return 
prediction, kl_div, log_p_z, log_q_zx\n\n    def generate(self, context, sentence_length, n_context):\n        # context: [batch_size, n_context, seq_len]\n        batch_size = context.size(0)\n        samples = []\n\n        # Run for context\n        context_hidden = None\n        for i in range(n_context):\n            # encoder_outputs: [batch_size, seq_len, hidden_size * direction]\n            # encoder_hidden: [num_layers * direction, batch_size, hidden_size]\n            encoder_outputs, encoder_hidden = self.encoder(context[:, i, :],\n                                                           sentence_length[:, i])\n\n            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)\n            # context_outputs: [batch_size, 1, context_hidden_size * direction]\n            # context_hidden: [num_layers * direction, batch_size, context_hidden_size]\n            context_outputs, context_hidden = self.context_encoder.step(encoder_hidden,\n                                                                        context_hidden)\n\n        # Run for generation\n        for j in range(self.config.n_sample_step):\n            # context_outputs: [batch_size, context_hidden_size * direction]\n            context_outputs = context_outputs.squeeze(1)\n\n            mu_prior, var_prior = self.prior(context_outputs)\n            eps = to_var(torch.randn((batch_size, self.config.z_sent_size)))\n            z_sent = mu_prior + torch.sqrt(var_prior) * eps\n\n            latent_context = torch.cat([context_outputs, z_sent], 1)\n            decoder_init = self.context2decoder(latent_context)\n            decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size)\n\n            if self.config.sample:\n                # decode=True so the decoder returns sampled token ids (as in VHCR.generate)\n                prediction = self.decoder(None, decoder_init, decode=True)\n                p = prediction.data.cpu().numpy()\n                length = torch.from_numpy(np.where(p == EOS_ID)[1])\n            else:\n                prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)\n                # prediction: [batch_size, seq_len]\n                prediction = prediction[:, 0, :]\n                # length: [batch_size]\n                length = [l[0] for l in length]\n                length = to_var(torch.LongTensor(length))\n\n            samples.append(prediction)\n\n            encoder_outputs, encoder_hidden = self.encoder(prediction,\n                                                           length)\n\n            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)\n\n            context_outputs, context_hidden = self.context_encoder.step(encoder_hidden,\n                                                                        context_hidden)\n\n        samples = torch.stack(samples, 1)\n        return samples\n\n\nclass VHCR(nn.Module):\n    def __init__(self, config):\n        super(VHCR, self).__init__()\n\n        self.config = config\n        self.encoder = layers.EncoderRNN(config.vocab_size,\n                                         config.embedding_size,\n                                         config.encoder_hidden_size,\n                                         config.rnn,\n                                         config.num_layers,\n                                         config.bidirectional,\n                                         config.dropout)\n\n        context_input_size = (config.num_layers\n                              * 
config.encoder_hidden_size\n                              * self.encoder.num_directions + config.z_conv_size)\n        self.context_encoder = layers.ContextRNN(context_input_size,\n                                                 config.context_size,\n                                                 config.rnn,\n                                                 config.num_layers,\n                                                 config.dropout)\n\n        self.unk_sent = nn.Parameter(torch.randn(context_input_size - config.z_conv_size))\n\n        self.z_conv2context = layers.FeedForward(config.z_conv_size,\n                                                 config.num_layers * config.context_size,\n                                                 num_layers=1,\n                                                 activation=config.activation)\n\n        context_input_size = (config.num_layers\n                              * config.encoder_hidden_size\n                              * self.encoder.num_directions)\n        self.context_inference = layers.ContextRNN(context_input_size,\n                                                   config.context_size,\n                                                   config.rnn,\n                                                   config.num_layers,\n                                                   config.dropout,\n                                                   bidirectional=True)\n\n        self.decoder = layers.DecoderRNN(config.vocab_size,\n                                        config.embedding_size,\n                                        config.decoder_hidden_size,\n                                        config.rnncell,\n                                        config.num_layers,\n                                        config.dropout,\n                                        config.word_drop,\n                                        config.max_unroll,\n                                        config.sample,\n                                        config.temperature,\n                                        config.beam_size)\n\n        self.context2decoder = layers.FeedForward(config.context_size + config.z_sent_size + config.z_conv_size,\n                                                  config.num_layers * config.decoder_hidden_size,\n                                                  num_layers=1,\n                                                  activation=config.activation)\n\n        self.softplus = nn.Softplus()\n\n        self.conv_posterior_h = layers.FeedForward(config.num_layers * self.context_inference.num_directions * config.context_size,\n                                                    config.context_size,\n                                                    num_layers=2,\n                                                    hidden_size=config.context_size,\n                                                    activation=config.activation)\n        self.conv_posterior_mu = nn.Linear(config.context_size,\n                                            config.z_conv_size)\n        self.conv_posterior_var = nn.Linear(config.context_size,\n                                             config.z_conv_size)\n\n        self.sent_prior_h = layers.FeedForward(config.context_size + config.z_conv_size,\n                                               config.context_size,\n                                               num_layers=1,\n                                               hidden_size=config.z_sent_size,\n                                   
            activation=config.activation)\n        self.sent_prior_mu = nn.Linear(config.context_size,\n                                       config.z_sent_size)\n        self.sent_prior_var = nn.Linear(config.context_size,\n                                        config.z_sent_size)\n\n        self.sent_posterior_h = layers.FeedForward(config.z_conv_size + config.encoder_hidden_size * self.encoder.num_directions * config.num_layers + config.context_size,\n                                                   config.context_size,\n                                                   num_layers=2,\n                                                   hidden_size=config.context_size,\n                                                   activation=config.activation)\n        self.sent_posterior_mu = nn.Linear(config.context_size,\n                                           config.z_sent_size)\n        self.sent_posterior_var = nn.Linear(config.context_size,\n                                            config.z_sent_size)\n\n        if config.tie_embedding:\n            self.decoder.embedding = self.encoder.embedding\n\n    def conv_prior(self):\n        # Standard gaussian prior\n        return to_var(torch.FloatTensor([0.0])), to_var(torch.FloatTensor([1.0]))\n\n    def conv_posterior(self, context_inference_hidden):\n        h_posterior = self.conv_posterior_h(context_inference_hidden)\n        mu_posterior = self.conv_posterior_mu(h_posterior)\n        var_posterior = self.softplus(self.conv_posterior_var(h_posterior))\n        return mu_posterior, var_posterior\n\n    def sent_prior(self, context_outputs, z_conv):\n        # Context dependent prior\n        h_prior = self.sent_prior_h(torch.cat([context_outputs, z_conv], dim=1))\n        mu_prior = self.sent_prior_mu(h_prior)\n        var_prior = self.softplus(self.sent_prior_var(h_prior))\n        return mu_prior, var_prior\n\n    def sent_posterior(self, context_outputs, encoder_hidden, z_conv):\n        h_posterior = self.sent_posterior_h(torch.cat([context_outputs, encoder_hidden, z_conv], 1))\n        mu_posterior = self.sent_posterior_mu(h_posterior)\n        var_posterior = self.softplus(self.sent_posterior_var(h_posterior))\n        return mu_posterior, var_posterior\n\n    def forward(self, sentences, sentence_length,\n                input_conversation_length, target_sentences, decode=False):\n        \"\"\"\n        Args:\n            sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]\n            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]\n        Return:\n            decoder_outputs: (Variable, FloatTensor)\n                - train: [batch_size, seq_len, vocab_size]\n                - eval: [batch_size, seq_len]\n        \"\"\"\n        batch_size = input_conversation_length.size(0)\n        num_sentences = sentences.size(0) - batch_size\n        max_len = input_conversation_length.data.max().item()\n\n        # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size]\n        # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size]\n        encoder_outputs, encoder_hidden = self.encoder(sentences,\n                                                       sentence_length)\n\n        # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size]\n        encoder_hidden = encoder_hidden.transpose(\n            1, 0).contiguous().view(num_sentences + batch_size, -1)\n\n        # pad and pack encoder_hidden\n        start = 
torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()),\n                                        input_conversation_length[:-1] + 1)), 0)\n        # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size]\n        encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1)\n                                      for s, l in zip(start.data.tolist(),\n                                                      input_conversation_length.data.tolist())], 0)\n\n        # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size]\n        encoder_hidden_inference = encoder_hidden[:, 1:, :]\n        encoder_hidden_inference_flat = torch.cat(\n            [encoder_hidden_inference[i, :l, :] for i, l in enumerate(input_conversation_length.data)])\n\n        # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size]\n        encoder_hidden_input = encoder_hidden[:, :-1, :]\n\n        # Standard Gaussian prior\n        conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size]))\n        conv_mu_prior, conv_var_prior = self.conv_prior()\n\n        if not decode:\n            if self.config.sentence_drop > 0.0:\n                indices = np.where(np.random.rand(max_len) < self.config.sentence_drop)[0]\n                if len(indices) > 0:\n                    encoder_hidden_input[:, indices, :] = self.unk_sent\n\n            # context_inference_outputs: [batch_size, max_len, num_directions * context_size]\n            # context_inference_hidden: [num_layers * num_directions, batch_size, hidden_size]\n            context_inference_outputs, context_inference_hidden = self.context_inference(encoder_hidden,\n                                                                                            input_conversation_length + 1)\n\n            # context_inference_hidden: [batch_size, num_layers * num_directions * hidden_size]\n            context_inference_hidden = context_inference_hidden.transpose(\n                1, 0).contiguous().view(batch_size, -1)\n            conv_mu_posterior, conv_var_posterior = self.conv_posterior(context_inference_hidden)\n            z_conv = conv_mu_posterior + torch.sqrt(conv_var_posterior) * conv_eps\n            log_q_zx_conv = normal_logpdf(z_conv, conv_mu_posterior, conv_var_posterior).sum()\n\n            log_p_z_conv = normal_logpdf(z_conv, conv_mu_prior, conv_var_prior).sum()\n            kl_div_conv = normal_kl_div(conv_mu_posterior, conv_var_posterior,\n                                            conv_mu_prior, conv_var_prior).sum()\n\n            context_init = self.z_conv2context(z_conv).view(\n                self.config.num_layers, batch_size, self.config.context_size)\n\n            z_conv_expand = z_conv.view(z_conv.size(0), 1, z_conv.size(\n                1)).expand(z_conv.size(0), max_len, z_conv.size(1))\n            context_outputs, context_last_hidden = self.context_encoder(\n                torch.cat([encoder_hidden_input, z_conv_expand], 2),\n                input_conversation_length,\n                hidden=context_init)\n\n            # flatten outputs\n            # context_outputs: [num_sentences, context_size]\n            context_outputs = torch.cat([context_outputs[i, :l, :]\n                                         for i, l in enumerate(input_conversation_length.data)])\n\n            z_conv_flat = torch.cat(\n                [z_conv_expand[i, :l, :] for i, l in enumerate(input_conversation_length.data)])\n            
sent_mu_prior, sent_var_prior = self.sent_prior(context_outputs, z_conv_flat)\n            eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))\n\n            sent_mu_posterior, sent_var_posterior = self.sent_posterior(\n                context_outputs, encoder_hidden_inference_flat, z_conv_flat)\n            z_sent = sent_mu_posterior + torch.sqrt(sent_var_posterior) * eps\n            log_q_zx_sent = normal_logpdf(z_sent, sent_mu_posterior, sent_var_posterior).sum()\n\n            log_p_z_sent = normal_logpdf(z_sent, sent_mu_prior, sent_var_prior).sum()\n            # kl_div: [num_sentences]\n            kl_div_sent = normal_kl_div(sent_mu_posterior, sent_var_posterior,\n                                        sent_mu_prior, sent_var_prior).sum()\n\n            kl_div = kl_div_conv + kl_div_sent\n            log_q_zx = log_q_zx_conv + log_q_zx_sent\n            log_p_z = log_p_z_conv + log_p_z_sent\n        else:\n            z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps\n            context_init = self.z_conv2context(z_conv).view(\n                self.config.num_layers, batch_size, self.config.context_size)\n\n            z_conv_expand = z_conv.view(z_conv.size(0), 1, z_conv.size(\n                1)).expand(z_conv.size(0), max_len, z_conv.size(1))\n            # context_outputs: [batch_size, max_len, context_size]\n            context_outputs, context_last_hidden = self.context_encoder(\n                torch.cat([encoder_hidden_input, z_conv_expand], 2),\n                input_conversation_length,\n                hidden=context_init)\n            # flatten outputs\n            # context_outputs: [num_sentences, context_size]\n            context_outputs = torch.cat([context_outputs[i, :l, :]\n                                         for i, l in enumerate(input_conversation_length.data)])\n\n\n            z_conv_flat = torch.cat(\n                [z_conv_expand[i, :l, :] for i, l in enumerate(input_conversation_length.data)])\n            sent_mu_prior, sent_var_prior = self.sent_prior(context_outputs, z_conv_flat)\n            eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))\n\n            z_sent = sent_mu_prior + torch.sqrt(sent_var_prior) * eps\n            kl_div = None\n            log_p_z = normal_logpdf(z_sent, sent_mu_prior, sent_var_prior).sum()\n            log_p_z += normal_logpdf(z_conv, conv_mu_prior, conv_var_prior).sum()\n            log_q_zx = None\n\n        # expand z_conv to all associated sentences\n        z_conv = torch.cat([z.view(1, -1).expand(m.item(), self.config.z_conv_size)\n                             for z, m in zip(z_conv, input_conversation_length)])\n\n        # latent_context: [num_sentences, context_size + z_sent_size +\n        # z_conv_size]\n        latent_context = torch.cat([context_outputs, z_sent, z_conv], 1)\n        decoder_init = self.context2decoder(latent_context)\n        decoder_init = decoder_init.view(-1,\n                                         self.decoder.num_layers,\n                                         self.decoder.hidden_size)\n        decoder_init = decoder_init.transpose(1, 0).contiguous()\n\n        # train: [batch_size, seq_len, vocab_size]\n        # eval: [batch_size, seq_len]\n        if not decode:\n            decoder_outputs = self.decoder(target_sentences,\n                                            init_h=decoder_init,\n                                            decode=decode)\n            return decoder_outputs, kl_div, log_p_z, log_q_zx\n\n        else:\n  
          # prediction: [batch_size, beam_size, max_unroll]\n            prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)\n            return prediction, kl_div, log_p_z, log_q_zx\n\n    def generate(self, context, sentence_length, n_context):\n        # context: [batch_size, n_context, seq_len]\n        batch_size = context.size(0)\n        samples = []\n\n        # Run for context\n\n        conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size]))\n\n        # infer z_conv from the given context through the conversation posterior\n        encoder_hidden_list = []\n        for i in range(n_context):\n            # encoder_outputs: [batch_size, seq_len, hidden_size * direction]\n            # encoder_hidden: [num_layers * direction, batch_size, hidden_size]\n            encoder_outputs, encoder_hidden = self.encoder(context[:, i, :],\n                                                           sentence_length[:, i])\n\n            # encoder_hidden: [batch_size, num_layers * direction * hidden_size]\n            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)\n            encoder_hidden_list.append(encoder_hidden)\n\n        encoder_hidden = torch.stack(encoder_hidden_list, 1)\n        context_inference_outputs, context_inference_hidden = self.context_inference(encoder_hidden,\n                                                                                     to_var(torch.LongTensor([n_context] * batch_size)))\n        context_inference_hidden = context_inference_hidden.transpose(\n            1, 0).contiguous().view(batch_size, -1)\n        conv_mu_posterior, conv_var_posterior = self.conv_posterior(context_inference_hidden)\n        z_conv = conv_mu_posterior + torch.sqrt(conv_var_posterior) * conv_eps\n\n        context_init = self.z_conv2context(z_conv).view(\n            self.config.num_layers, batch_size, self.config.context_size)\n\n        context_hidden = context_init\n        for i in range(n_context):\n            # encoder_outputs: [batch_size, seq_len, hidden_size * direction]\n            # encoder_hidden: [num_layers * direction, batch_size, hidden_size]\n            encoder_outputs, encoder_hidden = self.encoder(context[:, i, :],\n                                                           sentence_length[:, i])\n\n            # encoder_hidden: [batch_size, num_layers * direction * hidden_size]\n            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)\n            # context_outputs: [batch_size, 1, context_hidden_size * direction]\n            # context_hidden: [num_layers * direction, batch_size, context_hidden_size]\n            context_outputs, context_hidden = self.context_encoder.step(torch.cat([encoder_hidden, z_conv], 1),\n                                                                        context_hidden)\n\n        # Run for generation\n        for j in range(self.config.n_sample_step):\n            # context_outputs: [batch_size, context_hidden_size * direction]\n            context_outputs = context_outputs.squeeze(1)\n\n            mu_prior, var_prior = self.sent_prior(context_outputs, z_conv)\n            eps = to_var(torch.randn((batch_size, self.config.z_sent_size)))\n            z_sent = mu_prior + torch.sqrt(var_prior) * eps\n\n            latent_context = torch.cat([context_outputs, z_sent, z_conv], 1)\n      
      decoder_init = self.context2decoder(latent_context)\n            decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size)\n\n            if self.config.sample:\n                prediction = self.decoder(None, decoder_init, decode=True)\n                p = prediction.data.cpu().numpy()\n                length = torch.from_numpy(np.where(p == EOS_ID)[1])\n            else:\n                prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)\n                # prediction: [batch_size, seq_len]\n                prediction = prediction[:, 0, :]\n                # length: [batch_size]\n                length = [l[0] for l in length]\n                length = to_var(torch.LongTensor(length))\n\n            samples.append(prediction)\n\n            encoder_outputs, encoder_hidden = self.encoder(prediction,\n                                                           length)\n\n            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)\n\n            context_outputs, context_hidden = self.context_encoder.step(torch.cat([encoder_hidden, z_conv], 1),\n                                                                        context_hidden)\n\n        samples = torch.stack(samples, 1)\n        return samples\n"
  },
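  {
    "path": "examples/reparameterization_demo.py",
    "content": "# Standalone sketch of the reparameterization trick and the diagonal-Gaussian KL\n# term that VHRED/VHCR use (z = mu + sqrt(var) * eps). The closed-form KL below\n# is the standard formula and is assumed to mirror utils.normal_kl_div; this\n# file and all sizes are illustrative only.\nimport torch\n\n\ndef gaussian_kl(mu_q, var_q, mu_p, var_p):\n    # KL( N(mu_q, var_q) || N(mu_p, var_p) ), summed over the latent dimensions\n    return 0.5 * torch.sum(torch.log(var_p) - torch.log(var_q)\n                           + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0, dim=1)\n\n\nnum_sentences, z_size = 6, 4\n\n# pretend posterior/prior parameters (softplus keeps the variances positive,\n# as in VHRED.prior/posterior)\nmu_q = torch.randn(num_sentences, z_size)\nvar_q = torch.nn.functional.softplus(torch.randn(num_sentences, z_size))\nmu_p = torch.zeros(num_sentences, z_size)\nvar_p = torch.ones(num_sentences, z_size)\n\n# reparameterized sample: gradients flow through mu_q and var_q, not eps\neps = torch.randn(num_sentences, z_size)\nz = mu_q + torch.sqrt(var_q) * eps\n\nprint(z.size(), gaussian_kl(mu_q, var_q, mu_p, var_p).size())\n"
  },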
  {
    "path": "model/solver.py",
    "content": "from itertools import cycle\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport models\nfrom layers import masked_cross_entropy\nfrom utils import to_var, time_desc_decorator, TensorboardWriter, pad_and_pack, normal_kl_div, to_bow, bag_of_words_loss, normal_kl_div, embedding_metric\nimport os\nfrom tqdm import tqdm\nfrom math import isnan\nimport re\nimport math\nimport pickle\nimport gensim\n\nword2vec_path = \"../datasets/GoogleNews-vectors-negative300.bin\"\n\nclass Solver(object):\n    def __init__(self, config, train_data_loader, eval_data_loader, vocab, is_train=True, model=None):\n        self.config = config\n        self.epoch_i = 0\n        self.train_data_loader = train_data_loader\n        self.eval_data_loader = eval_data_loader\n        self.vocab = vocab\n        self.is_train = is_train\n        self.model = model\n\n    @time_desc_decorator('Build Graph')\n    def build(self, cuda=True):\n\n        if self.model is None:\n            self.model = getattr(models, self.config.model)(self.config)\n\n            # orthogonal initialiation for hidden weights\n            # input gate bias for GRUs\n            if self.config.mode == 'train' and self.config.checkpoint is None:\n                print('Parameter initiailization')\n                for name, param in self.model.named_parameters():\n                    if 'weight_hh' in name:\n                        print('\\t' + name)\n                        nn.init.orthogonal_(param)\n\n                    # bias_hh is concatenation of reset, input, new gates\n                    # only set the input gate bias to 2.0\n                    if 'bias_hh' in name:\n                        print('\\t' + name)\n                        dim = int(param.size(0) / 3)\n                        param.data[dim:2 * dim].fill_(2.0)\n\n        if torch.cuda.is_available() and cuda:\n            self.model.cuda()\n\n        # Overview Parameters\n        print('Model Parameters')\n        for name, param in self.model.named_parameters():\n            print('\\t' + name + '\\t', list(param.size()))\n\n        if self.config.checkpoint:\n            self.load_model(self.config.checkpoint)\n\n        if self.is_train:\n            self.writer = TensorboardWriter(self.config.logdir)\n            self.optimizer = self.config.optimizer(\n                filter(lambda p: p.requires_grad, self.model.parameters()),\n                lr=self.config.learning_rate)\n\n    def save_model(self, epoch):\n        \"\"\"Save parameters to checkpoint\"\"\"\n        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')\n        print(f'Save parameters to {ckpt_path}')\n        torch.save(self.model.state_dict(), ckpt_path)\n\n    def load_model(self, checkpoint):\n        \"\"\"Load parameters from checkpoint\"\"\"\n        print(f'Load parameters from {checkpoint}')\n        epoch = re.match(r\"[0-9]*\", os.path.basename(checkpoint)).group(0)\n        self.epoch_i = int(epoch)\n        self.model.load_state_dict(torch.load(checkpoint))\n\n    def write_summary(self, epoch_i):\n        epoch_loss = getattr(self, 'epoch_loss', None)\n        if epoch_loss is not None:\n            self.writer.update_loss(\n                loss=epoch_loss,\n                step_i=epoch_i + 1,\n                name='train_loss')\n\n        epoch_recon_loss = getattr(self, 'epoch_recon_loss', None)\n        if epoch_recon_loss is not None:\n            self.writer.update_loss(\n                loss=epoch_recon_loss,\n                step_i=epoch_i + 
1,\n                name='train_recon_loss')\n\n        epoch_kl_div = getattr(self, 'epoch_kl_div', None)\n        if epoch_kl_div is not None:\n            self.writer.update_loss(\n                loss=epoch_kl_div,\n                step_i=epoch_i + 1,\n                name='train_kl_div')\n\n        kl_mult = getattr(self, 'kl_mult', None)\n        if kl_mult is not None:\n            self.writer.update_loss(\n                loss=kl_mult,\n                step_i=epoch_i + 1,\n                name='kl_mult')\n\n        epoch_bow_loss = getattr(self, 'epoch_bow_loss', None)\n        if epoch_bow_loss is not None:\n            self.writer.update_loss(\n                loss=epoch_bow_loss,\n                step_i=epoch_i + 1,\n                name='bow_loss')\n\n        validation_loss = getattr(self, 'validation_loss', None)\n        if validation_loss is not None:\n            self.writer.update_loss(\n                loss=validation_loss,\n                step_i=epoch_i + 1,\n                name='validation_loss')\n\n    @time_desc_decorator('Training Start!')\n    def train(self):\n        epoch_loss_history = []\n        for epoch_i in range(self.epoch_i, self.config.n_epoch):\n            self.epoch_i = epoch_i\n            batch_loss_history = []\n            self.model.train()\n            n_total_words = 0\n            for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.train_data_loader, ncols=80)):\n                # conversations: (batch_size) list of conversations\n                #   conversation: list of sentences\n                #   sentence: list of tokens\n                # conversation_length: list of int\n                # sentence_length: (batch_size) list of conversation list of sentence_lengths\n\n                input_conversations = [conv[:-1] for conv in conversations]\n                target_conversations = [conv[1:] for conv in conversations]\n\n                # flatten input and target conversations\n                input_sentences = [sent for conv in input_conversations for sent in conv]\n                target_sentences = [sent for conv in target_conversations for sent in conv]\n                input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]]\n                target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]\n                input_conversation_length = [l - 1 for l in conversation_length]\n\n                input_sentences = to_var(torch.LongTensor(input_sentences))\n                target_sentences = to_var(torch.LongTensor(target_sentences))\n                input_sentence_length = to_var(torch.LongTensor(input_sentence_length))\n                target_sentence_length = to_var(torch.LongTensor(target_sentence_length))\n                input_conversation_length = to_var(torch.LongTensor(input_conversation_length))\n\n                # reset gradient\n                self.optimizer.zero_grad()\n\n                sentence_logits = self.model(\n                    input_sentences,\n                    input_sentence_length,\n                    input_conversation_length,\n                    target_sentences,\n                    decode=False)\n\n                batch_loss, n_words = masked_cross_entropy(\n                    sentence_logits,\n                    target_sentences,\n                    target_sentence_length)\n\n                assert not isnan(batch_loss.item())\n                batch_loss_history.append(batch_loss.item())\n              
  n_total_words += n_words.item()\n\n                if batch_i % self.config.print_every == 0:\n                    tqdm.write(\n                        f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item() / n_words.item():.3f}')\n\n                # Back-propagation\n                batch_loss.backward()\n\n                # Gradient clipping\n                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip)\n\n                # Run optimizer\n                self.optimizer.step()\n\n            epoch_loss = np.sum(batch_loss_history) / n_total_words\n            epoch_loss_history.append(epoch_loss)\n            self.epoch_loss = epoch_loss\n\n            print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}'\n            print(print_str)\n\n            if epoch_i % self.config.save_every_epoch == 0:\n                self.save_model(epoch_i + 1)\n\n            print('\\n<Validation>...')\n            self.validation_loss = self.evaluate()\n\n            if epoch_i % self.config.plot_every_epoch == 0:\n                self.write_summary(epoch_i)\n\n        self.save_model(self.config.n_epoch)\n\n        return epoch_loss_history\n\n    def generate_sentence(self, input_sentences, input_sentence_length,\n                          input_conversation_length, target_sentences):\n        self.model.eval()\n\n        # [batch_size, max_seq_len, vocab_size]\n        generated_sentences = self.model(\n            input_sentences,\n            input_sentence_length,\n            input_conversation_length,\n            target_sentences,\n            decode=True)\n\n        # write output to file\n        with open(os.path.join(self.config.save_path, 'samples.txt'), 'a') as f:\n            f.write(f'<Epoch {self.epoch_i}>\\n\\n')\n\n            tqdm.write('\\n<Samples>')\n            for input_sent, target_sent, output_sent in zip(input_sentences, target_sentences, generated_sentences):\n                input_sent = self.vocab.decode(input_sent)\n                target_sent = self.vocab.decode(target_sent)\n                output_sent = '\\n'.join([self.vocab.decode(sent) for sent in output_sent])\n                s = '\\n'.join(['Input sentence: ' + input_sent,\n                               'Ground truth: ' + target_sent,\n                               'Generated response: ' + output_sent + '\\n'])\n                f.write(s + '\\n')\n                print(s)\n            print('')\n\n    def evaluate(self):\n        self.model.eval()\n        batch_loss_history = []\n        n_total_words = 0\n        for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.eval_data_loader, ncols=80)):\n            # conversations: (batch_size) list of conversations\n            #   conversation: list of sentences\n            #   sentence: list of tokens\n            # conversation_length: list of int\n            # sentence_length: (batch_size) list of conversation list of sentence_lengths\n\n            input_conversations = [conv[:-1] for conv in conversations]\n            target_conversations = [conv[1:] for conv in conversations]\n\n            # flatten input and target conversations\n            input_sentences = [sent for conv in input_conversations for sent in conv]\n            target_sentences = [sent for conv in target_conversations for sent in conv]\n            input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]]\n            target_sentence_length = [l for len_list in sentence_length 
for l in len_list[1:]]\n            input_conversation_length = [l - 1 for l in conversation_length]\n\n            with torch.no_grad():\n                input_sentences = to_var(torch.LongTensor(input_sentences))\n                target_sentences = to_var(torch.LongTensor(target_sentences))\n                input_sentence_length = to_var(torch.LongTensor(input_sentence_length))\n                target_sentence_length = to_var(torch.LongTensor(target_sentence_length))\n                input_conversation_length = to_var(\n                    torch.LongTensor(input_conversation_length))\n\n            if batch_i == 0:\n                self.generate_sentence(input_sentences,\n                                       input_sentence_length,\n                                       input_conversation_length,\n                                       target_sentences)\n\n            sentence_logits = self.model(\n                input_sentences,\n                input_sentence_length,\n                input_conversation_length,\n                target_sentences)\n\n            batch_loss, n_words = masked_cross_entropy(\n                sentence_logits,\n                target_sentences,\n                target_sentence_length)\n\n            assert not isnan(batch_loss.item())\n            batch_loss_history.append(batch_loss.item())\n            n_total_words += n_words.item()\n\n        epoch_loss = np.sum(batch_loss_history) / n_total_words\n\n        print_str = f'Validation loss: {epoch_loss:.3f}\\n'\n        print(print_str)\n\n        return epoch_loss\n\n    def test(self):\n        self.model.eval()\n        batch_loss_history = []\n        n_total_words = 0\n        for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.eval_data_loader, ncols=80)):\n            # conversations: (batch_size) list of conversations\n            #   conversation: list of sentences\n            #   sentence: list of tokens\n            # conversation_length: list of int\n            # sentence_length: (batch_size) list of conversation list of sentence_lengths\n\n            input_conversations = [conv[:-1] for conv in conversations]\n            target_conversations = [conv[1:] for conv in conversations]\n\n            # flatten input and target conversations\n            input_sentences = [sent for conv in input_conversations for sent in conv]\n            target_sentences = [sent for conv in target_conversations for sent in conv]\n            input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]]\n            target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]\n            input_conversation_length = [l - 1 for l in conversation_length]\n\n            with torch.no_grad():\n                input_sentences = to_var(torch.LongTensor(input_sentences))\n                target_sentences = to_var(torch.LongTensor(target_sentences))\n                input_sentence_length = to_var(torch.LongTensor(input_sentence_length))\n                target_sentence_length = to_var(torch.LongTensor(target_sentence_length))\n                input_conversation_length = to_var(torch.LongTensor(input_conversation_length))\n\n            sentence_logits = self.model(\n                input_sentences,\n                input_sentence_length,\n                input_conversation_length,\n                target_sentences)\n\n            batch_loss, n_words = masked_cross_entropy(\n                sentence_logits,\n                
target_sentences,\n                target_sentence_length)\n\n            assert not isnan(batch_loss.item())\n            batch_loss_history.append(batch_loss.item())\n            n_total_words += n_words.item()\n\n        epoch_loss = np.sum(batch_loss_history) / n_total_words\n\n        print(f'Number of words: {n_total_words}')\n        print(f'Nats per word: {epoch_loss:.3f}')\n        word_perplexity = np.exp(epoch_loss)\n\n        print_str = f'Word perplexity: {word_perplexity:.3f}\\n'\n        print(print_str)\n\n        return word_perplexity\n\n    def embedding_metric(self):\n        word2vec = getattr(self, 'word2vec', None)\n        if word2vec is None:\n            print('Loading word2vec model')\n            word2vec = gensim.models.KeyedVectors.load_word2vec_format(word2vec_path, binary=True)\n            self.word2vec = word2vec\n        keys = word2vec.vocab\n        self.model.eval()\n        n_context = self.config.n_context\n        n_sample_step = self.config.n_sample_step\n        metric_average_history = []\n        metric_extrema_history = []\n        metric_greedy_history = []\n        context_history = []\n        sample_history = []\n        n_sent = 0\n        for batch_i, (conversations, conversation_length, sentence_length) \\\n                in enumerate(tqdm(self.eval_data_loader, ncols=80)):\n            # conversations: (batch_size) list of conversations\n            #   conversation: list of sentences\n            #   sentence: list of tokens\n            # conversation_length: list of int\n            # sentence_length: (batch_size) list of conversation list of sentence_lengths\n\n            # Keep only conversations long enough for context + sampled continuation\n            conv_indices = [i for i in range(len(conversations)) if len(conversations[i]) >= n_context + n_sample_step]\n            context = [conversations[i][:n_context] for i in conv_indices]\n            ground_truth = [conversations[i][n_context:n_context + n_sample_step] for i in conv_indices]\n            sentence_length = [sentence_length[i][:n_context] for i in conv_indices]\n\n            with torch.no_grad():\n                context = to_var(torch.LongTensor(context))\n                sentence_length = to_var(torch.LongTensor(sentence_length))\n\n            samples = self.model.generate(context, sentence_length, n_context)\n\n            context = context.data.cpu().numpy().tolist()\n            samples = samples.data.cpu().numpy().tolist()\n            context_history.append(context)\n            sample_history.append(samples)\n\n            samples = [[self.vocab.decode(sent) for sent in c] for c in samples]\n            ground_truth = [[self.vocab.decode(sent) for sent in c] for c in ground_truth]\n\n            samples = [sent for c in samples for sent in c]\n            ground_truth = [sent for c in ground_truth for sent in c]\n\n            samples = [[word2vec[s] for s in sent.split() if s in keys] for sent in samples]\n            ground_truth = [[word2vec[s] for s in sent.split() if s in keys] for sent in ground_truth]\n\n            # Drop pairs where either side has no in-vocabulary word\n            indices = [i for i, s, g in zip(range(len(samples)), samples, ground_truth) if s != [] and g != []]\n            samples = [samples[i] for i in indices]\n            ground_truth = [ground_truth[i] for i in indices]\n            n_sent += len(samples)\n\n            metric_average = embedding_metric(samples, ground_truth, word2vec, 'average')\n            metric_extrema = embedding_metric(samples, ground_truth, word2vec, 'extrema')\n            
metric_greedy = embedding_metric(samples, ground_truth, word2vec, 'greedy')\n            metric_average_history.append(metric_average)\n            metric_extrema_history.append(metric_extrema)\n            metric_greedy_history.append(metric_greedy)\n\n        epoch_average = np.mean(np.concatenate(metric_average_history), axis=0)\n        epoch_extrema = np.mean(np.concatenate(metric_extrema_history), axis=0)\n        epoch_greedy = np.mean(np.concatenate(metric_greedy_history), axis=0)\n\n        print('n_sentences:', n_sent)\n        print_str = f'Metrics - Average: {epoch_average:.3f}, Extrema: {epoch_extrema:.3f}, Greedy: {epoch_greedy:.3f}'\n        print(print_str)\n        print('\\n')\n\n        return epoch_average, epoch_extrema, epoch_greedy\n\n\nclass VariationalSolver(Solver):\n\n    def __init__(self, config, train_data_loader, eval_data_loader, vocab, is_train=True, model=None):\n        self.config = config\n        self.epoch_i = 0\n        self.train_data_loader = train_data_loader\n        self.eval_data_loader = eval_data_loader\n        self.vocab = vocab\n        self.is_train = is_train\n        self.model = model\n\n    @time_desc_decorator('Training Start!')\n    def train(self):\n        epoch_loss_history = []\n        kl_mult = 0.0\n        conv_kl_mult = 0.0\n        for epoch_i in range(self.epoch_i, self.config.n_epoch):\n            self.epoch_i = epoch_i\n            batch_loss_history = []\n            recon_loss_history = []\n            kl_div_history = []\n            kl_div_sent_history = []\n            kl_div_conv_history = []\n            bow_loss_history = []\n            self.model.train()\n            n_total_words = 0\n\n            # self.evaluate()\n\n            for batch_i, (conversations, conversation_length, sentence_length) \\\n                    in enumerate(tqdm(self.train_data_loader, ncols=80)):\n                # conversations: (batch_size) list of conversations\n                #   conversation: list of sentences\n                #   sentence: list of tokens\n                # conversation_length: list of int\n                # sentence_length: (batch_size) list of conversation list of sentence_lengths\n\n                target_conversations = [conv[1:] for conv in conversations]\n\n                # flatten input and target conversations\n                sentences = [sent for conv in conversations for sent in conv]\n                input_conversation_length = [l - 1 for l in conversation_length]\n                target_sentences = [sent for conv in target_conversations for sent in conv]\n                target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]\n                sentence_length = [l for len_list in sentence_length for l in len_list]\n\n                sentences = to_var(torch.LongTensor(sentences))\n                sentence_length = to_var(torch.LongTensor(sentence_length))\n                input_conversation_length = to_var(torch.LongTensor(input_conversation_length))\n                target_sentences = to_var(torch.LongTensor(target_sentences))\n                target_sentence_length = to_var(torch.LongTensor(target_sentence_length))\n\n                # reset gradient\n                self.optimizer.zero_grad()\n\n                sentence_logits, kl_div, _, _ = self.model(\n                    sentences,\n                    sentence_length,\n                    input_conversation_length,\n                    target_sentences)\n\n                recon_loss, n_words = 
masked_cross_entropy(\n                    sentence_logits,\n                    target_sentences,\n                    target_sentence_length)\n\n                batch_loss = recon_loss + kl_mult * kl_div\n                batch_loss_history.append(batch_loss.item())\n                recon_loss_history.append(recon_loss.item())\n                kl_div_history.append(kl_div.item())\n                n_total_words += n_words.item()\n\n                if self.config.bow:\n                    bow_loss = self.model.compute_bow_loss(target_conversations)\n                    batch_loss += bow_loss\n                    bow_loss_history.append(bow_loss.item())\n\n                assert not isnan(batch_loss.item())\n\n                if batch_i % self.config.print_every == 0:\n                    print_str = f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item() / n_words.item():.3f}, recon = {recon_loss.item() / n_words.item():.3f}, kl_div = {kl_div.item() / n_words.item():.3f}'\n                    if self.config.bow:\n                        print_str += f', bow_loss = {bow_loss.item() / n_words.item():.3f}'\n                    tqdm.write(print_str)\n\n                # Back-propagation\n                batch_loss.backward()\n\n                # Gradient clipping\n                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip)\n\n                # Run optimizer\n                self.optimizer.step()\n\n                # Linear KL annealing from 0 to 1 over kl_annealing_iter steps\n                kl_mult = min(kl_mult + 1.0 / self.config.kl_annealing_iter, 1.0)\n\n            epoch_loss = np.sum(batch_loss_history) / n_total_words\n            epoch_loss_history.append(epoch_loss)\n\n            epoch_recon_loss = np.sum(recon_loss_history) / n_total_words\n            epoch_kl_div = np.sum(kl_div_history) / n_total_words\n\n            self.kl_mult = kl_mult\n            self.epoch_loss = epoch_loss\n            self.epoch_recon_loss = epoch_recon_loss\n            self.epoch_kl_div = epoch_kl_div\n\n            print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'\n            if bow_loss_history:\n                self.epoch_bow_loss = np.sum(bow_loss_history) / n_total_words\n                print_str += f', bow_loss = {self.epoch_bow_loss:.3f}'\n            print(print_str)\n\n            if epoch_i % self.config.save_every_epoch == 0:\n                self.save_model(epoch_i + 1)\n\n            print('\\n<Validation>...')\n            self.validation_loss = self.evaluate()\n\n            if epoch_i % self.config.plot_every_epoch == 0:\n                self.write_summary(epoch_i)\n\n        self.save_model(self.config.n_epoch)\n\n        return epoch_loss_history\n\n    def generate_sentence(self, sentences, sentence_length,\n                          input_conversation_length, input_sentences, target_sentences):\n        \"\"\"Generate output of decoder (single batch)\"\"\"\n        self.model.eval()\n\n        # [batch_size, max_seq_len, vocab_size]\n        generated_sentences, _, _, _ = self.model(\n            sentences,\n            sentence_length,\n            input_conversation_length,\n            target_sentences,\n            decode=True)\n\n        # write output to file\n        with open(os.path.join(self.config.save_path, 'samples.txt'), 'a') as f:\n            f.write(f'<Epoch {self.epoch_i}>\\n\\n')\n\n            tqdm.write('\\n<Samples>')\n            for input_sent, target_sent, output_sent in zip(input_sentences, target_sentences, generated_sentences):\n                input_sent = 
self.vocab.decode(input_sent)\n                target_sent = self.vocab.decode(target_sent)\n                output_sent = '\\n'.join([self.vocab.decode(sent) for sent in output_sent])\n                s = '\\n'.join(['Input sentence: ' + input_sent,\n                               'Ground truth: ' + target_sent,\n                               'Generated response: ' + output_sent + '\\n'])\n                f.write(s + '\\n')\n                print(s)\n            print('')\n\n    def evaluate(self):\n        self.model.eval()\n        batch_loss_history = []\n        recon_loss_history = []\n        kl_div_history = []\n        bow_loss_history = []\n        n_total_words = 0\n        for batch_i, (conversations, conversation_length, sentence_length) \\\n                in enumerate(tqdm(self.eval_data_loader, ncols=80)):\n            # conversations: (batch_size) list of conversations\n            #   conversation: list of sentences\n            #   sentence: list of tokens\n            # conversation_length: list of int\n            # sentence_length: (batch_size) list of conversation list of sentence_lengths\n\n            target_conversations = [conv[1:] for conv in conversations]\n\n            # flatten input and target conversations\n            sentences = [sent for conv in conversations for sent in conv]\n            input_conversation_length = [l - 1 for l in conversation_length]\n            target_sentences = [sent for conv in target_conversations for sent in conv]\n            target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]\n            sentence_length = [l for len_list in sentence_length for l in len_list]\n\n            with torch.no_grad():\n                sentences = to_var(torch.LongTensor(sentences))\n                sentence_length = to_var(torch.LongTensor(sentence_length))\n                input_conversation_length = to_var(\n                    torch.LongTensor(input_conversation_length))\n                target_sentences = to_var(torch.LongTensor(target_sentences))\n                target_sentence_length = to_var(torch.LongTensor(target_sentence_length))\n\n            if batch_i == 0:\n                input_conversations = [conv[:-1] for conv in conversations]\n                input_sentences = [sent for conv in input_conversations for sent in conv]\n                with torch.no_grad():\n                    input_sentences = to_var(torch.LongTensor(input_sentences))\n                self.generate_sentence(sentences,\n                                       sentence_length,\n                                       input_conversation_length,\n                                       input_sentences,\n                                       target_sentences)\n\n            sentence_logits, kl_div, _, _ = self.model(\n                sentences,\n                sentence_length,\n                input_conversation_length,\n                target_sentences)\n\n            recon_loss, n_words = masked_cross_entropy(\n                sentence_logits,\n                target_sentences,\n                target_sentence_length)\n\n            batch_loss = recon_loss + kl_div\n            if self.config.bow:\n                bow_loss = self.model.compute_bow_loss(target_conversations)\n                bow_loss_history.append(bow_loss.item())\n\n            assert not isnan(batch_loss.item())\n            batch_loss_history.append(batch_loss.item())\n            recon_loss_history.append(recon_loss.item())\n            
kl_div_history.append(kl_div.item())\n            n_total_words += n_words.item()\n\n        epoch_loss = np.sum(batch_loss_history) / n_total_words\n        epoch_recon_loss = np.sum(recon_loss_history) / n_total_words\n        epoch_kl_div = np.sum(kl_div_history) / n_total_words\n\n        print_str = f'Validation loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'\n        if bow_loss_history:\n            epoch_bow_loss = np.sum(bow_loss_history) / n_total_words\n            print_str += f', bow_loss = {epoch_bow_loss:.3f}'\n        print(print_str)\n        print('\\n')\n\n        return epoch_loss\n\n    def importance_sample(self):\n        '''Perform importance sampling to get a tighter bound on the\n        log-likelihood (and hence on the word perplexity)\n        '''\n        self.model.eval()\n        weight_history = []\n        n_total_words = 0\n        kl_div_history = []\n        for batch_i, (conversations, conversation_length, sentence_length) \\\n                in enumerate(tqdm(self.eval_data_loader, ncols=80)):\n            # conversations: (batch_size) list of conversations\n            #   conversation: list of sentences\n            #   sentence: list of tokens\n            # conversation_length: list of int\n            # sentence_length: (batch_size) list of conversation list of sentence_lengths\n\n            target_conversations = [conv[1:] for conv in conversations]\n\n            # flatten input and target conversations\n            sentences = [sent for conv in conversations for sent in conv]\n            input_conversation_length = [l - 1 for l in conversation_length]\n            target_sentences = [sent for conv in target_conversations for sent in conv]\n            target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]\n            sentence_length = [l for len_list in sentence_length for l in len_list]\n\n            with torch.no_grad():\n                sentences = to_var(torch.LongTensor(sentences))\n                sentence_length = to_var(torch.LongTensor(sentence_length))\n                input_conversation_length = to_var(\n                    torch.LongTensor(input_conversation_length))\n                target_sentences = to_var(torch.LongTensor(target_sentences))\n                target_sentence_length = to_var(torch.LongTensor(target_sentence_length))\n\n            # treat whole batch as one data sample\n            weights = []\n            for j in range(self.config.importance_sample):\n                sentence_logits, kl_div, log_p_z, log_q_zx = self.model(\n                    sentences,\n                    sentence_length,\n                    input_conversation_length,\n                    target_sentences)\n\n                recon_loss, n_words = masked_cross_entropy(\n                    sentence_logits,\n                    target_sentences,\n                    target_sentence_length)\n\n                log_w = (-recon_loss.sum() + log_p_z - log_q_zx).data\n                weights.append(log_w)\n                if j == 0:\n                    n_total_words += n_words.item()\n                    kl_div_history.append(kl_div.item())\n\n            # weights: [n_samples]\n            # Numerically stable log-mean-exp of the importance weights\n            weights = torch.stack(weights, 0)\n            m = np.floor(weights.max())\n            weights = np.log(torch.exp(weights - m).sum())\n            weights = m + weights - np.log(self.config.importance_sample)\n            
weight_history.append(weights)\n\n        print(f'Number of words: {n_total_words}')\n        nats_per_word = -np.sum(weight_history) / n_total_words\n        print(f'Nats per word: {nats_per_word:.3f}')\n        word_perplexity = np.exp(nats_per_word)\n\n        epoch_kl_div = np.sum(kl_div_history) / n_total_words\n\n        print_str = f'Word perplexity upper bound using {self.config.importance_sample} importance samples: {word_perplexity:.3f}, kl_div: {epoch_kl_div:.3f}\\n'\n        print(print_str)\n\n        return word_perplexity\n\n"
  },
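  {
    "path": "examples/importance_sample_sketch.py",
    "content": "'''Illustrative sketch (not part of the original repo): the numerically\nstable log-mean-exp that VariationalSolver.importance_sample uses to turn\nK sampled importance weights into a log-likelihood bound.\n\nGiven log-weights log w_k = log p(x, z_k) - log q(z_k | x), the bound is\nlog(1/K * sum_k exp(log w_k)). Subtracting m = max_k log w_k before\nexponentiating avoids underflow/overflow. All numbers below are toy values.\n'''\nimport numpy as np\n\n\ndef log_mean_exp(log_weights):\n    # m + log(sum_k exp(log w_k - m)) - log K, with m = max_k log w_k\n    log_weights = np.asarray(log_weights, dtype=np.float64)\n    m = log_weights.max()\n    return m + np.log(np.exp(log_weights - m).sum()) - np.log(len(log_weights))\n\n\nif __name__ == '__main__':\n    # Naively exponentiating these would underflow to zero\n    log_w = [-1000.0, -1001.0, -999.5]\n    print(f'log-mean-exp estimate: {log_mean_exp(log_w):.4f}')\n"
  },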
  {
    "path": "model/train.py",
    "content": "from solver import *\nfrom data_loader import get_loader\nfrom configs import get_config\nfrom utils import Vocab\nimport os\nimport pickle\nfrom models import VariationalModels\n\ndef load_pickle(path):\n    with open(path, 'rb') as f:\n        return pickle.load(f)\n\n\nif __name__ == '__main__':\n    config = get_config(mode='train')\n    val_config = get_config(mode='valid')\n    print(config)\n    with open(os.path.join(config.save_path, 'config.txt'), 'w') as f:\n        print(config, file=f)\n\n    print('Loading Vocabulary...')\n    vocab = Vocab()\n    vocab.load(config.word2id_path, config.id2word_path)\n    print(f'Vocabulary size: {vocab.vocab_size}')\n\n    config.vocab_size = vocab.vocab_size\n\n    train_data_loader = get_loader(\n        sentences=load_pickle(config.sentences_path),\n        conversation_length=load_pickle(config.conversation_length_path),\n        sentence_length=load_pickle(config.sentence_length_path),\n        vocab=vocab,\n        batch_size=config.batch_size)\n\n    eval_data_loader = get_loader(\n        sentences=load_pickle(val_config.sentences_path),\n        conversation_length=load_pickle(val_config.conversation_length_path),\n        sentence_length=load_pickle(val_config.sentence_length_path),\n        vocab=vocab,\n        batch_size=val_config.eval_batch_size,\n        shuffle=False)\n\n    # for testing\n    # train_data_loader = eval_data_loader\n    if config.model in VariationalModels:\n        solver = VariationalSolver\n    else:\n        solver = Solver\n\n    solver = solver(config, train_data_loader, eval_data_loader, vocab=vocab, is_train=True)\n\n    solver.build()\n    solver.train()\n"
  },
  {
    "path": "model/utils/__init__.py",
    "content": "from .convert import *\nfrom .time_track import time_desc_decorator\nfrom .tensorboard import TensorboardWriter\nfrom .vocab import *\nfrom .mask import *\nfrom .tokenizer import *\nfrom .probability import *\nfrom .pad import *\nfrom .bow import *\nfrom .embedding_metric import *\n"
  },
  {
    "path": "model/utils/bow.py",
    "content": "import numpy as np\nfrom collections import Counter\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport torch\nfrom math import isnan\nfrom .vocab import PAD_ID, EOS_ID\n\n\ndef to_bow(sentence, vocab_size):\n    '''  Convert a sentence into a bag of words representation\n    Args\n        - sentence: a list of token ids\n        - vocab_size: V\n    Returns\n        - bow: a integer vector of size V\n    '''\n    bow = Counter(sentence)\n    # Remove EOS tokens\n    bow[PAD_ID] = 0\n    bow[EOS_ID] = 0\n\n    x = np.zeros(vocab_size, dtype=np.int64)\n    x[list(bow.keys())] = list(bow.values())\n\n    return x\n\n\ndef bag_of_words_loss(bow_logits, target_bow, weight=None):\n    ''' Calculate bag of words representation loss\n    Args\n        - bow_logits: [num_sentences, vocab_size]\n        - target_bow: [num_sentences]\n    '''\n    log_probs = F.log_softmax(bow_logits, dim=1)\n    target_distribution = target_bow / (target_bow.sum(1).view(-1, 1) + 1e-23) + 1e-23\n    entropy = -(torch.log(target_distribution) * target_bow).sum()\n    loss = -(log_probs * target_bow).sum() - entropy\n\n    return loss\n"
  },
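  {
    "path": "examples/bow_loss_sketch.py",
    "content": "'''Illustrative sketch (not part of the original repo): what\nmodel/utils/bow.py computes, on toy values. Assumes the repo root is on\nPYTHONPATH and the requirements are installed.\n\nto_bow turns a token-id sentence into a count vector over the vocabulary\n(PAD/EOS zeroed out); bag_of_words_loss is the cross entropy between the\npredicted BoW distribution and the target counts minus the target entropy,\nso it is 0 when the predicted distribution matches the empirical one.\n'''\nimport torch\n\nfrom model.utils.bow import to_bow, bag_of_words_loss\n\nif __name__ == '__main__':\n    vocab_size = 10\n    # Toy ids: 0 = <pad> and 3 = <eos> are ignored by to_bow\n    sentence = [5, 7, 7, 3, 0, 0]\n    bow = to_bow(sentence, vocab_size)\n    print('bag-of-words counts:', bow)  # index 5 -> 1, index 7 -> 2\n\n    target_bow = torch.FloatTensor(bow).unsqueeze(0)  # [1, vocab_size]\n    bow_logits = torch.randn(1, vocab_size)           # stand-in decoder output\n    loss = bag_of_words_loss(bow_logits, target_bow)\n    print(f'bow loss: {loss.item():.3f}')\n"
  },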
  {
    "path": "model/utils/convert.py",
    "content": "import torch\nfrom torch.autograd import Variable\n\n\ndef to_var(x, on_cpu=False, gpu_id=None, async=False):\n    \"\"\"Tensor => Variable\"\"\"\n    if torch.cuda.is_available() and not on_cpu:\n        x = x.cuda(gpu_id, async)\n        #x = Variable(x)\n    return x\n\n\ndef to_tensor(x):\n    \"\"\"Variable => Tensor\"\"\"\n    if torch.cuda.is_available():\n        x = x.cpu()\n    return x.data\n\ndef reverse_order(tensor, dim=0):\n    \"\"\"Reverse Tensor or Variable\"\"\"\n    if isinstance(tensor, torch.Tensor) or isinstance(tensor, torch.LongTensor):\n        idx = [i for i in range(tensor.size(dim)-1, -1, -1)]\n        idx = torch.LongTensor(idx)\n        inverted_tensor = tensor.index_select(dim, idx)\n    if isinstance(tensor, torch.cuda.FloatTensor) or isinstance(tensor, torch.cuda.LongTensor):\n        idx = [i for i in range(tensor.size(dim)-1, -1, -1)]\n        idx = torch.cuda.LongTensor(idx)\n        inverted_tensor = tensor.index_select(dim, idx)\n        return inverted_tensor\n    elif isinstance(tensor, Variable):\n        variable = tensor\n        variable.data = reverse_order(variable.data, dim)\n        return variable\n\ndef reverse_order_valid(tensor, length_list, dim=0):\n    \"\"\"\n    Reverse Tensor of Variable only in given length\n    Ex)\n    Args:\n        - tensor (Tensor or Variable)\n         1   2   3   4   5   6\n         6   7   8   9   0   0\n        11  12  13   0   0   0\n        16  17   0   0   0   0\n        21  22  23  24  25  26\n\n        - length_list (list)\n        [6, 4, 3, 2, 6]\n \n    Return:\n        tensor (Tensor or Variable; in-place)\n         6   5   4   3   2   1\n         0   0   9   8   7   6\n         0   0   0  13  12  11\n         0   0   0   0  17  16\n        26  25  24  23  22  21\n    \"\"\"\n    for row, length in zip(tensor, length_list):\n        valid_row = row[:length]\n        reversed_valid_row = reverse_order(valid_row, dim=dim)\n        row[:length] = reversed_valid_row\n    return tensor\n"
  },
  {
    "path": "model/utils/embedding_metric.py",
    "content": "import numpy as np\n\n\ndef cosine_similarity(s, g):\n    similarity = np.sum(s * g, axis=1) / np.sqrt((np.sum(s * s, axis=1) * np.sum(g * g, axis=1)))\n\n    # return np.sum(similarity)\n    return similarity\n\n\ndef embedding_metric(samples, ground_truth, word2vec, method='average'):\n\n    if method == 'average':\n        # s, g: [n_samples, word_dim]\n        s = [np.mean(sample, axis=0) for sample in samples]\n        g = [np.mean(gt, axis=0) for gt in ground_truth]\n        return cosine_similarity(np.array(s), np.array(g))\n    elif method == 'extrema':\n        s_list = []\n        g_list = []\n        for sample, gt in zip(samples, ground_truth):\n            s_max = np.max(sample, axis=0)\n            s_min = np.min(sample, axis=0)\n            s_plus = np.absolute(s_min) <= s_max\n            s_abs = np.max(np.absolute(sample), axis=0)\n            s = s_max * s_plus + s_min * np.logical_not(s_plus)\n            s_list.append(s)\n\n            g_max = np.max(gt, axis=0)\n            g_min = np.min(gt, axis=0)\n            g_plus = np.absolute(g_min) <= g_max\n            g_abs = np.max(np.absolute(gt), axis=0)\n            g = g_max * g_plus + g_min * np.logical_not(g_plus)\n            g_list.append(g)\n\n        return cosine_similarity(np.array(s_list), np.array(g_list))\n    elif method == 'greedy':\n        sim_list = []\n        for s, g in zip(samples, ground_truth):\n            s = np.array(s)\n            g = np.array(g).T\n            sim = (np.matmul(s, g)\n                   / np.sqrt(np.matmul(np.sum(s * s, axis=1, keepdims=True), np.sum(g * g, axis=0, keepdims=True))))\n            sim = np.max(sim, axis=0)\n            sim_list.append(np.mean(sim))\n\n        # return np.sum(sim_list)\n        return np.array(sim_list)\n    else:\n        raise NotImplementedError\n"
  },
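  {
    "path": "examples/embedding_metric_sketch.py",
    "content": "'''Illustrative sketch (not part of the original repo): the three\nembedding-based response metrics from model/utils/embedding_metric.py on\ntoy word vectors. Assumes the repo root is on PYTHONPATH.\n\n'average' compares mean word vectors, 'extrema' keeps the per-dimension\nvalue of largest magnitude, and 'greedy' matches each sampled word to its\nclosest ground-truth word. Inputs are lists of [n_words, dim] arrays; the\nword2vec argument is not used by the metric itself, so None is passed.\n'''\nimport numpy as np\n\nfrom model.utils.embedding_metric import embedding_metric\n\nif __name__ == '__main__':\n    rng = np.random.RandomState(0)\n    dim = 4\n    # One (sample, ground truth) pair with 3 and 2 words respectively\n    samples = [rng.randn(3, dim)]\n    ground_truth = [rng.randn(2, dim)]\n\n    for method in ['average', 'extrema', 'greedy']:\n        score = embedding_metric(samples, ground_truth, None, method)\n        print(f'{method}: {score[0]:.3f}')\n"
  },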
  {
    "path": "model/utils/mask.py",
    "content": "import torch\nfrom .convert import to_var\n\n\ndef sequence_mask(sequence_length, max_len=None):\n    \"\"\"\n    Args:\n        sequence_length (Variable, LongTensor) [batch_size]\n            - list of sequence length of each batch\n        max_len (int)\n    Return:\n        masks (bool): [batch_size, max_len]\n            - True if current sequence is valid (not padded), False otherwise\n\n    Ex.\n    sequence length: [3, 2, 1]\n\n    seq_length_expand\n    [[3, 3, 3],\n     [2, 2, 2]\n     [1, 1, 1]]\n\n    seq_range_expand\n    [[0, 1, 2]\n     [0, 1, 2],\n     [0, 1, 2]]\n\n    masks\n    [[True, True, True],\n     [True, True, False],\n     [True, False, False]]\n    \"\"\"\n    if max_len is None:\n        max_len = sequence_length.max()\n    batch_size = sequence_length.size(0)\n\n    # [max_len]\n    seq_range = torch.arange(0, max_len).long()  # [0, 1, ... max_len-1]\n\n    # [batch_size, max_len]\n    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)\n    seq_range_expand = to_var(seq_range_expand)\n\n    # [batch_size, max_len]\n    seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)\n\n    # [batch_size, max_len]\n    masks = seq_range_expand < seq_length_expand\n\n    return masks\n"
  },
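  {
    "path": "examples/sequence_mask_sketch.py",
    "content": "'''Illustrative sketch (not part of the original repo): using\nmodel/utils/mask.py to ignore padded positions, as masked_cross_entropy\ndoes when averaging a per-token loss. Assumes the repo root is on\nPYTHONPATH; to_var keeps everything on the same device.\n'''\nimport torch\n\nfrom model.utils import to_var\nfrom model.utils.mask import sequence_mask\n\nif __name__ == '__main__':\n    lengths = to_var(torch.LongTensor([3, 2, 1]))\n    masks = sequence_mask(lengths, max_len=3)  # [3, 3], 1 where valid\n    print(masks)\n\n    # Zero out loss terms that fall on padding\n    token_loss = to_var(torch.ones(3, 3))\n    masked = token_loss * masks.float()\n    print(masked.sum().item())  # 6.0 = number of valid tokens\n"
  },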
  {
    "path": "model/utils/pad.py",
    "content": "import torch\nfrom torch.autograd import Variable\nfrom .convert import to_var\n\n\ndef pad(tensor, length):\n    if isinstance(tensor, Variable):\n        var = tensor\n        if length > var.size(0):\n            return torch.cat([var,\n                              torch.zeros(length - var.size(0), *var.size()[1:]).cuda()])\n        else:\n            return var\n    else:\n        if length > tensor.size(0):\n            return torch.cat([tensor,\n                              torch.zeros(length - tensor.size(0), *tensor.size()[1:]).cuda()])\n        else:\n            return tensor\n\n\ndef pad_and_pack(tensor_list):\n    length_list = ([t.size(0) for t in tensor_list])\n    max_len = max(length_list)\n    padded = [pad(t, max_len) for t in tensor_list]\n    packed = torch.stack(padded, 0)\n    return packed, length_list\n"
  },
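  {
    "path": "examples/pad_and_pack_sketch.py",
    "content": "'''Illustrative sketch (not part of the original repo): batching\nvariable-length id sequences with model/utils/pad.py. Assumes the repo\nroot is on PYTHONPATH.\n\npad_and_pack right-pads every tensor with zeros (the PAD id) to the length\nof the longest one and stacks them into [batch_size, max_len], returning\nthe original lengths as well.\n'''\nimport torch\n\nfrom model.utils.pad import pad_and_pack\n\nif __name__ == '__main__':\n    sentences = [torch.LongTensor([5, 6, 7]),\n                 torch.LongTensor([8, 9]),\n                 torch.LongTensor([4])]\n    packed, lengths = pad_and_pack(sentences)\n    print(packed)   # [[5, 6, 7], [8, 9, 0], [4, 0, 0]]\n    print(lengths)  # [3, 2, 1]\n"
  },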
  {
    "path": "model/utils/probability.py",
    "content": "import torch\nimport numpy as np\nfrom .convert import to_var\n\n\ndef normal_logpdf(x, mean, var):\n    \"\"\"\n    Args:\n        x: (Variable, FloatTensor) [batch_size, dim]\n        mean: (Variable, FloatTensor) [batch_size, dim] or [batch_size] or [1]\n        var: (Variable, FloatTensor) [batch_size, dim]: positive value\n    Return:\n        log_p: (Variable, FloatTensor) [batch_size]\n    \"\"\"\n\n    pi = to_var(torch.FloatTensor([np.pi]))\n    return 0.5 * torch.sum(-torch.log(2.0 * pi) - torch.log(var) - ((x - mean).pow(2) / var), dim=1)\n\n\ndef normal_kl_div(mu1, var1,\n                  mu2=to_var(torch.FloatTensor([0.0])),\n                  var2=to_var(torch.FloatTensor([1.0]))):\n    one = to_var(torch.FloatTensor([1.0]))\n    return torch.sum(0.5 * (torch.log(var2) - torch.log(var1)\n                            + (var1 + (mu1 - mu2).pow(2)) / var2 - one), 1)\n"
  },
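  {
    "path": "examples/probability_sketch.py",
    "content": "'''Illustrative sketch (not part of the original repo): sanity checks for\nmodel/utils/probability.py. Assumes the repo root is on PYTHONPATH;\nto_var keeps all tensors on the same device.\n\nnormal_kl_div implements the closed form\n    KL(N(mu1, var1) || N(mu2, var2))\n        = 0.5 * sum(log(var2 / var1) + (var1 + (mu1 - mu2)^2) / var2 - 1)\nover the last dimension, defaulting to a standard-normal prior.\n'''\nimport numpy as np\nimport torch\n\nfrom model.utils import to_var\nfrom model.utils.probability import normal_logpdf, normal_kl_div\n\nif __name__ == '__main__':\n    mu = to_var(torch.FloatTensor([[0.5, -1.0]]))\n    var = to_var(torch.FloatTensor([[2.0, 0.5]]))\n\n    # KL of a Gaussian against itself is 0\n    print(normal_kl_div(mu, var, mu, var))\n\n    # log N(x=0 | mu=0, var=1) summed over 2 dims = -log(2*pi)\n    zero = to_var(torch.zeros(1, 2))\n    one = to_var(torch.ones(1, 2))\n    print(normal_logpdf(zero, zero, one))  # ~ -1.8379\n    print(-np.log(2 * np.pi))              # same value\n"
  },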
  {
    "path": "model/utils/tensorboard.py",
    "content": "from tensorboardX import SummaryWriter\n\nclass TensorboardWriter(SummaryWriter):\n    def __init__(self, logdir):\n        \"\"\"\n        Extended SummaryWriter Class from tensorboard-pytorch (tensorbaordX)\n        https://github.com/lanpa/tensorboard-pytorch/blob/master/tensorboardX/writer.py\n\n        Internally calls self.file_writer\n        \"\"\"\n        super(TensorboardWriter, self).__init__(logdir)\n        self.logdir = self.file_writer.get_logdir()\n\n    def update_parameters(self, module, step_i):\n        \"\"\"\n        module: nn.Module\n        \"\"\"\n        for name, param in module.named_parameters():\n            self.add_histogram(name, param.clone().cpu().data.numpy(), step_i)\n\n    def update_loss(self, loss, step_i, name='loss'):\n        self.add_scalar(name, loss, step_i)\n\n    def update_histogram(self, values, step_i, name='hist'):\n        self.add_histogram(name, values, step_i)\n"
  },
  {
    "path": "model/utils/time_track.py",
    "content": "import time\nfrom functools import partial\n\n\ndef base_time_desc_decorator(method, desc='test_description'):\n    def timed(*args, **kwargs):\n\n        # Print Description\n        # print('#' * 50)\n        print(desc)\n        # print('#' * 50 + '\\n')\n\n        # Calculation Runtime\n        start = time.time()\n\n        # Run Method\n        try:\n            result = method(*args, **kwargs)\n        except TypeError:\n            result = method(**kwargs)\n\n        # Print Runtime\n        print('Done! It took {:.2} secs\\n'.format(time.time() - start))\n\n        if result is not None:\n            return result\n\n    return timed\n\n\ndef time_desc_decorator(desc): return partial(base_time_desc_decorator, desc=desc)\n\n\n@time_desc_decorator('this is description')\ndef time_test(arg, kwarg='this is kwarg'):\n    time.sleep(3)\n    print('Inside of time_test')\n    print('printing arg: ', arg)\n    print('printing kwarg: ',  kwarg)\n\n\n@time_desc_decorator('this is second description')\ndef no_arg_method():\n    print('this method has no argument')\n\n\nif __name__ == '__main__':\n    time_test('hello', kwarg=3)\n    time_test(3)\n    no_arg_method()\n"
  },
  {
    "path": "model/utils/tokenizer.py",
    "content": "import re\n\n\ndef clean_str(string):\n    \"\"\"\n    Tokenization/string cleaning for all datasets except for SST.\n    Every dataset is lower cased except for TREC\n    \"\"\"\n    string = re.sub(r\"[^A-Za-z0-9,!?\\'\\`\\.]\", \" \", string)\n    string = re.sub(r\"\\.{3}\", \" ...\", string)\n    string = re.sub(r\"\\'s\", \" \\'s\", string)\n    string = re.sub(r\"\\'ve\", \" \\'ve\", string)\n    string = re.sub(r\"n\\'t\", \" n\\'t\", string)\n    string = re.sub(r\"\\'re\", \" \\'re\", string)\n    string = re.sub(r\"\\'d\", \" \\'d\", string)\n    string = re.sub(r\"\\'ll\", \" \\'ll\", string)\n    string = re.sub(r\",\", \" , \", string)\n    string = re.sub(r\"!\", \" ! \", string)\n    string = re.sub(r\"\\?\", \" ? \", string)\n    string = re.sub(r\"\\s{2,}\", \" \", string)\n    return string.strip().lower()\n\n\nclass Tokenizer():\n    def __init__(self, tokenizer='whitespace', clean_string=True):\n        self.clean_string = clean_string\n        tokenizer = tokenizer.lower()\n\n        # Tokenize with whitespace\n        if tokenizer == 'whitespace':\n            print('Loading whitespace tokenizer')\n            self.tokenize = lambda string: string.strip().split()\n\n        if tokenizer == 'regex':\n            print('Loading regex tokenizer')\n            import re\n            pattern = r\"[A-Z]{2,}(?![a-z])|[A-Z][a-z]+(?=[A-Z])|[\\'\\w\\-]+\"\n            self.tokenize = lambda string: re.findall(pattern, string)\n\n        if tokenizer == 'spacy':\n            print('Loading SpaCy')\n            import spacy\n            nlp = spacy.load('en')\n            self.tokenize = lambda string: [token.text for token in nlp(string)]\n\n        # Tokenize with punctuations other than periods\n        if tokenizer == 'nltk':\n            print('Loading NLTK word tokenizer')\n            from nltk import word_tokenize\n\n            self.tokenize = word_tokenize\n\n    def __call__(self, string):\n        if self.clean_string:\n            string = clean_str(string)\n        return self.tokenize(string)\n\n\nif __name__ == '__main__':\n    tokenizer = Tokenizer()\n    print(tokenizer(\"Hello, how are you doin'?\"))\n\n    tokenizer = Tokenizer('spacy')\n    print(tokenizer(\"Hello, how are you doin'?\"))\n"
  },
  {
    "path": "model/utils/vocab.py",
    "content": "from collections import defaultdict\nimport pickle\nimport torch\nfrom torch import Tensor\nfrom torch.autograd import Variable\nfrom nltk import FreqDist\nfrom .convert import to_tensor, to_var\n\nPAD_TOKEN = '<pad>'\nUNK_TOKEN = '<unk>'\nSOS_TOKEN = '<sos>'\nEOS_TOKEN = '<eos>'\n\nPAD_ID, UNK_ID, SOS_ID, EOS_ID = [0, 1, 2, 3]\n\n\nclass Vocab(object):\n    def __init__(self, tokenizer=None, max_size=None, min_freq=1):\n        \"\"\"Basic Vocabulary object\"\"\"\n\n        self.vocab_size = 0\n        self.freqdist = FreqDist()\n        self.tokenizer = tokenizer\n\n    def update(self, max_size=None, min_freq=1):\n        \"\"\"\n        Initialize id2word & word2id based on self.freqdist\n        max_size include 4 special tokens\n        \"\"\"\n\n        # {0: '<pad>', 1: '<unk>', 2: '<sos>', 3: '<eos>'}\n        self.id2word = {\n            PAD_ID: PAD_TOKEN, UNK_ID: UNK_TOKEN,\n            SOS_ID: SOS_TOKEN, EOS_ID: EOS_TOKEN\n        }\n        # {'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3}\n        self.word2id = defaultdict(lambda: UNK_ID)  # Not in vocab => return UNK\n        self.word2id.update({\n            PAD_TOKEN: PAD_ID, UNK_TOKEN: UNK_ID,\n            SOS_TOKEN: SOS_ID, EOS_TOKEN: EOS_ID\n        })\n        # self.word2id = {\n        #     PAD_TOKEN: PAD_ID, UNK_TOKEN: UNK_ID,\n        #     SOS_TOKEN: SOS_ID, EOS_TOKEN: EOS_ID\n        # }\n\n        vocab_size = 4\n        min_freq = max(min_freq, 1)\n\n        # Reset frequencies of special tokens\n        # [...('<eos>', 0), ('<pad>', 0), ('<sos>', 0), ('<unk>', 0)]\n        freqdist = self.freqdist.copy()\n        special_freqdist = {token: freqdist[token]\n                            for token in [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN]}\n        freqdist.subtract(special_freqdist)\n\n        # Sort: by frequency, then alphabetically\n        # Ex) freqdist = { 'a': 4,   'b': 5,   'c': 3 }\n        #  =>   sorted = [('b', 5), ('a', 4), ('c', 3)]\n        sorted_frequency_counter = sorted(freqdist.items(), key=lambda k_v: k_v[0])\n        sorted_frequency_counter.sort(key=lambda k_v: k_v[1], reverse=True)\n\n        for word, freq in sorted_frequency_counter:\n\n            if freq < min_freq or vocab_size == max_size:\n                break\n            self.id2word[vocab_size] = word\n            self.word2id[word] = vocab_size\n            vocab_size += 1\n\n        self.vocab_size = vocab_size\n\n    def __len__(self):\n        return len(self.id2word)\n\n    def load(self, word2id_path=None, id2word_path=None):\n        if word2id_path:\n            with open(word2id_path, 'rb') as f:\n                word2id = pickle.load(f)\n            # Can't pickle lambda function\n            self.word2id = defaultdict(lambda: UNK_ID)\n            self.word2id.update(word2id)\n            self.vocab_size = len(self.word2id)\n\n        if id2word_path:\n            with open(id2word_path, 'rb') as f:\n                id2word = pickle.load(f)\n            self.id2word = id2word\n\n    def add_word(self, word):\n        assert isinstance(word, str), 'Input should be str'\n        self.freqdist.update([word])\n\n    def add_sentence(self, sentence, tokenized=False):\n        if not tokenized:\n            sentence = self.tokenizer(sentence)\n        for word in sentence:\n            self.add_word(word)\n\n    def add_dataframe(self, conversation_df, tokenized=True):\n        for conversation in conversation_df:\n            for sentence in conversation:\n                
self.add_sentence(sentence, tokenized=tokenized)\n\n    def pickle(self, word2id_path, id2word_path):\n        with open(word2id_path, 'wb') as f:\n            pickle.dump(dict(self.word2id), f)\n\n        with open(id2word_path, 'wb') as f:\n            pickle.dump(self.id2word, f)\n\n    def to_list(self, list_like):\n        \"\"\"Convert list-like containers to list\"\"\"\n        if isinstance(list_like, list):\n            return list_like\n\n        if isinstance(list_like, Variable):\n            return list(to_tensor(list_like).numpy())\n        elif isinstance(list_like, Tensor):\n            return list(list_like.numpy())\n\n    def id2sent(self, id_list):\n        \"\"\"list of id => list of tokens (single sentence)\"\"\"\n        id_list = self.to_list(id_list)\n        sentence = []\n        for word_id in id_list:\n            word = self.id2word[word_id]\n            if word not in [EOS_TOKEN, SOS_TOKEN, PAD_TOKEN]:\n                sentence.append(word)\n            if word == EOS_TOKEN:\n                break\n        return sentence\n\n    def sent2id(self, sentence, var=False):\n        \"\"\"list of tokens => list of id (single sentence)\"\"\"\n        id_list = [self.word2id[word] for word in sentence]\n        if var:\n            # to_var has no 'eval' argument; plain conversion is enough here\n            id_list = to_var(torch.LongTensor(id_list))\n        return id_list\n\n    def decode(self, id_list):\n        sentence = self.id2sent(id_list)\n        return ' '.join(sentence)\n"
  },
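  {
    "path": "examples/vocab_sketch.py",
    "content": "'''Illustrative sketch (not part of the original repo): building and using\nmodel/utils/vocab.py's Vocab on a toy corpus. Assumes the repo root is on\nPYTHONPATH.\n\nIds 0-3 are reserved for <pad>/<unk>/<sos>/<eos>; words rarer than\nmin_freq map to <unk>, and decode() drops special tokens.\n'''\nfrom model.utils import Tokenizer, Vocab\n\nif __name__ == '__main__':\n    vocab = Vocab(tokenizer=Tokenizer('whitespace'))\n    vocab.add_sentence('the cat sat on the mat')\n    vocab.add_sentence('the dog sat')\n    vocab.update(max_size=20000, min_freq=2)\n\n    print(len(vocab))                  # 6: 4 special tokens + {the, sat}\n    ids = vocab.sent2id(['the', 'cat', 'sat'])\n    print(ids)                         # [4, 1, 5]; 'cat' (freq 1) -> <unk>\n    print(vocab.decode(ids))           # 'the <unk> sat'\n"
  },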
  {
    "path": "requirements.txt",
    "content": "pandas==0.20.3\nnumpy==1.14.0\ngensim==3.1.0\nspacy==1.9.0\ntqdm==4.15.0\nnltk==3.4.5\ntensorboardX==1.1\ntorch==0.4\n"
  },
  {
    "path": "ubuntu_preprocess.py",
    "content": "# Load the Ubuntu dialog corpus\n# Available from here:\n# http://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/ubuntu_dialogs.tgz\n\nfrom multiprocessing import Pool\nfrom pathlib import Path\nfrom collections import OrderedDict\nfrom urllib.request import urlretrieve\nimport os\nimport argparse\nimport tarfile\nimport pickle\n\nfrom tqdm import tqdm\nimport pandas as pd\n\nfrom model.utils import Tokenizer, Vocab, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN\n\nproject_dir = Path(__file__).resolve().parent\ndatasets_dir = project_dir.joinpath('datasets/')\nubuntu_dir = datasets_dir.joinpath('ubuntu/')\n\nubuntu_meta_dir = ubuntu_dir.joinpath('meta/')\ndialogs_dir = ubuntu_dir.joinpath('dialogs/')\n\n# Tokenizer\ntokenizer = Tokenizer('spacy')\n\n\ndef prepare_ubuntu_data():\n    \"\"\"Download and unpack dialogs\"\"\"\n\n    tar_filename = 'ubuntu_dialogs.tgz'\n    url = 'http://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/ubuntu_dialogs.tgz'\n    tarfile_path = ubuntu_dir.joinpath(tar_filename)\n    metadata_url = 'https://raw.githubusercontent.com/rkadlec/ubuntu-ranking-dataset-creator/master/src/meta/'\n\n    if not datasets_dir.exists():\n        datasets_dir.mkdir()\n    if not ubuntu_dir.exists():\n        ubuntu_dir.mkdir()\n    if not ubuntu_meta_dir.exists():\n        ubuntu_meta_dir.mkdir()\n\n    # Prepare Dialog data\n    if not dialogs_dir.joinpath(\"10/1.tst\").exists():\n        # Download Dialog tarfile\n        if not tarfile_path.exists():\n            print(f\"Downloading {url} to {tarfile_path}\")\n            urlretrieve(url, tarfile_path)\n            print(f\"Successfully downloaded {tarfile_path}\")\n\n        # Unpack tarfile\n        if not dialogs_dir.exists():\n            print(\"Unpacking dialogs ... (This can take 5~10 mins.)\")\n            with tarfile.open(tarfile_path) as tar:\n                tar.extractall(path=ubuntu_dir)\n            print(\"Archive unpacked.\")\n\n    # Download metadata\n    if not ubuntu_meta_dir.joinpath('trainfiles.csv').exists():\n        print('Downloading metadata ... 
(This can take 5~10 mins.)')\n        for filename in ['trainfiles.csv', 'valfiles.csv', 'testfiles.csv']:\n            csv_path = ubuntu_meta_dir.joinpath(filename)\n            print(f\"Downloading {metadata_url+filename} to {csv_path}\")\n            urlretrieve(metadata_url + filename, csv_path)\n            print(f\"Successfully downloaded {csv_path}\")\n\n    print('Ubuntu Data prepared!')\n\n\ndef get_dialog_path_list(dataset='train'):\n    if dataset == 'train':\n        filename = 'trainfiles.csv'\n    elif dataset == 'test':\n        filename = 'testfiles.csv'\n    elif dataset == 'valid':\n        filename = 'valfiles.csv'\n    with open(ubuntu_meta_dir.joinpath(filename)) as f:\n        dialog_path_list = []\n        for line in f:\n            file, dir = line.strip().split(\",\")\n            path = dialogs_dir.joinpath(dir, file)\n            dialog_path_list.append(path)\n\n    return dialog_path_list\n\n\ndef read_and_tokenize(dialog_path, min_turn=3):\n    \"\"\"\n    Read a conversation and merge consecutive utterances of the same user\n    Args:\n        dialog_path (str): path of dialog (tsv format)\n    Return:\n        dialog: (list of list of str) [dialog_length, sentence_length],\n            or [] if the dialog has fewer than min_turn turns\n    \"\"\"\n    with open(dialog_path, 'r', encoding='utf-8') as f:\n\n        # Go through the dialog\n        first_turn = True\n        dialog = []\n        users = []\n        same_user_utterances = []  # list of sentences of current user\n        dialog.append(same_user_utterances)\n\n        for line in f:\n            _time, speaker, _listener, sentence = line.split('\\t')\n            users.append(speaker)\n\n            if first_turn:\n                last_speaker = speaker\n                first_turn = False\n\n            # Speaker has changed\n            if last_speaker != speaker:\n                same_user_utterances = []\n                dialog.append(same_user_utterances)\n\n            same_user_utterances.append(sentence)\n            last_speaker = speaker\n\n        # All users in conversation (len: 2)\n        users = list(OrderedDict.fromkeys(users))\n\n        # 1. Concatenate consecutive sentences of single user\n        # 2. 
Tokenize\n        dialog = [tokenizer(\" \".join(sentence)) for sentence in dialog]\n\n        if len(dialog) < min_turn:\n            print(f\"Dialog {dialog_path} length ({len(dialog)}) < minimum required length {min_turn}\")\n            return []\n\n    return dialog #, users\n\n\ndef pad_sentences(conversations, max_sentence_length=30, max_conversation_length=10):\n\n    def pad_tokens(tokens, max_sentence_length=max_sentence_length):\n        n_valid_tokens = len(tokens)\n        if n_valid_tokens > max_sentence_length - 1:\n            tokens = tokens[:max_sentence_length - 1]\n        n_pad = max_sentence_length - n_valid_tokens - 1\n        tokens = tokens + [EOS_TOKEN] + [PAD_TOKEN] * n_pad\n        return tokens\n\n    def pad_conversation(conversation):\n        conversation = [pad_tokens(sentence) for sentence in conversation]\n        return conversation\n\n    all_padded_sentences = []\n    all_sentence_length = []\n\n    for conversation in conversations:\n        if len(conversation) > max_conversation_length:\n            conversation = conversation[:max_conversation_length]\n        sentence_length = [min(len(sentence) + 1, max_sentence_length) # +1 for EOS token\n                           for sentence in conversation]\n        all_sentence_length.append(sentence_length)\n\n        sentences = pad_conversation(conversation)\n        all_padded_sentences.append(sentences)\n\n    # [n_conversations, n_sentence (various), max_sentence_length]\n    sentences = all_padded_sentences\n    # [n_conversations, n_sentence (various)]\n    sentence_length = all_sentence_length\n    return sentences, sentence_length\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n\n    # Maximum valid length of sentence\n    # => SOS/EOS will surround sentence (EOS for source / SOS for target)\n    # => maximum length of tensor = max_sentence_length + 1\n    parser.add_argument('-s', '--max_sentence_length', type=int, default=30)\n    parser.add_argument('-c', '--max_conversation_length', type=int, default=10)\n\n    # Vocabulary\n    parser.add_argument('--max_vocab_size', type=int, default=20000)\n    parser.add_argument('--min_vocab_frequency', type=int, default=5)\n\n    # Multiprocess\n    parser.add_argument('--n_workers', type=int, default=os.cpu_count())\n\n    args = parser.parse_args()\n\n    max_sent_len = args.max_sentence_length\n    max_conv_len = args.max_conversation_length\n    max_vocab_size = args.max_vocab_size\n    min_freq = args.min_vocab_frequency\n    n_workers = args.n_workers\n\n    min_turn = 3\n\n    # Download and unpack dialogs if necessary.\n    prepare_ubuntu_data()\n\n    def to_pickle(obj, path):\n        with open(path, 'wb') as f:\n            pickle.dump(obj, f)\n\n    for split_type in ['train', 'test', 'valid']:\n        print(f'Processing {split_type} dataset...')\n        split_data_dir = ubuntu_dir.joinpath(split_type)\n        split_data_dir.mkdir(exist_ok=True)\n\n        # List of dialogs (tsv)\n        dialog_path_list = get_dialog_path_list(split_type)\n\n        print(f'Tokenize.. 
(n_workers={n_workers})')\n        def _tokenize_conversation(dialog_path):\n            return read_and_tokenize(dialog_path)\n        with Pool(n_workers) as pool:\n            conversations = list(tqdm(pool.imap(_tokenize_conversation, dialog_path_list),\n                                      total=len(dialog_path_list)))\n\n        # Filter too short conversations\n        conversations = list(filter(lambda x: len(x) >= min_turn, conversations))\n\n        # conversations: padded_sentences\n        # [n_conversations, conversation_length (various), max_sentence_length]\n\n        # sentence_length: list of length of sentences\n        # [n_conversations, conversation_length (various)]\n\n        conversation_length = [min(len(conversation), max_conv_len)\n                               for conversation in conversations]\n\n        sentences, sentence_length = pad_sentences(\n            conversations,\n            max_sentence_length=max_sent_len,\n            max_conversation_length=max_conv_len)\n\n        print('Saving preprocessed data at', split_data_dir)\n        to_pickle(conversation_length, split_data_dir.joinpath('conversation_length.pkl'))\n        to_pickle(sentences, split_data_dir.joinpath('sentences.pkl'))\n        to_pickle(sentence_length, split_data_dir.joinpath('sentence_length.pkl'))\n\n        if split_type == 'train':\n            print('Save Vocabulary...')\n            vocab = Vocab(tokenizer)\n            vocab.add_dataframe(conversations)\n            vocab.update(max_size=max_vocab_size, min_freq=min_freq)\n\n            print('Vocabulary size: ', len(vocab))\n            vocab.pickle(ubuntu_dir.joinpath('word2id.pkl'), ubuntu_dir.joinpath('id2word.pkl'))\n\n        print('Done!')\n"
  }
]