Repository: ctr4si/A-Hierarchical-Latent-Structure-for-Variational-Conversation-Modeling
Branch: master
Commit: 83ca9dd96272
Files: 32
Total size: 152.5 KB

Directory structure:
A-Hierarchical-Latent-Structure-for-Variational-Conversation-Modeling/

├── .gitignore
├── LICENSE
├── Readme.md
├── cornell_preprocess.py
├── model/
│   ├── __init__.py
│   ├── configs.py
│   ├── data_loader.py
│   ├── eval.py
│   ├── eval_embed.py
│   ├── layers/
│   │   ├── __init__.py
│   │   ├── beam_search.py
│   │   ├── decoder.py
│   │   ├── encoder.py
│   │   ├── feedforward.py
│   │   ├── loss.py
│   │   └── rnncells.py
│   ├── models.py
│   ├── solver.py
│   ├── train.py
│   └── utils/
│       ├── __init__.py
│       ├── bow.py
│       ├── convert.py
│       ├── embedding_metric.py
│       ├── mask.py
│       ├── pad.py
│       ├── probability.py
│       ├── tensorboard.py
│       ├── time_track.py
│       ├── tokenizer.py
│       └── vocab.py
├── requirements.txt
└── ubuntu_preprocess.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
/etc/
datasets/
/cornell_movie_dialogue/
*.orig
*.lprof

# Remote edit
*.ftpconfig

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2018 Center for SuperIntelligence

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: Readme.md
================================================
# Variational Hierarchical Conversation RNN (VHCR)
[PyTorch 0.4](https://github.com/pytorch/pytorch) Implementation of ["A Hierarchical Latent Structure for Variational Conversation Modeling"](https://arxiv.org/abs/1804.03424) (NAACL 2018 Oral)
* [NAACL 2018 Oral Presentation Video](https://vimeo.com/277671819)

## Prerequisite
Install Python packages

```
pip install -r requirements.txt
```

## Download & Preprocess data
The following scripts will:

1. Create directories `./datasets/cornell/` and `./datasets/ubuntu/` respectively.

2. Download and preprocess conversation data inside each directory.

### for [Cornell Movie Dialogue dataset](https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html)
```
python cornell_preprocess.py
    --max_sentence_length (maximum number of words per sentence; default: 30)
    --max_conversation_length (maximum number of utterance turns per conversation; default: 10)
    --max_vocab_size (maximum size of word vocabulary; default: 20000)
    --min_vocab_frequency (minimum frequency for a word to be included in the vocabulary; default: 5)
    --n_workers (number of workers for multiprocessing; default: os.cpu_count())
```

### for [Ubuntu Dialog Dataset](http://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/)
```
python ubuntu_preprocess.py
    --max_sentence_length (maximum number of words per sentence; default: 30)
    --max_conversation_length (maximum number of utterance turns per conversation; default: 10)
    --max_vocab_size (maximum size of word vocabulary; default: 20000)
    --min_vocab_frequency (minimum frequency for a word to be included in the vocabulary; default: 5)
    --n_workers (number of workers for multiprocessing; default: os.cpu_count())
```
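
Each split is saved as pickled data. For example, `cornell_preprocess.py` writes `sentences.pkl`, `sentence_length.pkl`, and `conversation_length.pkl` per split, plus `word2id.pkl`/`id2word.pkl` built from the training split. A minimal sketch for inspecting the output (paths assume the defaults above):

```
import pickle
from pathlib import Path

data_dir = Path('./datasets/cornell/train')

with open(data_dir / 'sentences.pkl', 'rb') as f:
    sentences = pickle.load(f)  # [n_conversations][n_turns][max_sentence_length] tokens

with open(data_dir / 'conversation_length.pkl', 'rb') as f:
    conversation_length = pickle.load(f)  # [n_conversations]

print(len(sentences), conversation_length[0], sentences[0][0][:10])
```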


## Training
Go to the `model` directory and set `save_dir` in `configs.py` (this is where model checkpoints will be saved).

We provide our implementation of VHCR, as well as our reference implementations for [HRED](https://arxiv.org/abs/1507.02221) and [VHRED](https://arxiv.org/abs/1605.06069).

To run training:
```
python train.py --data=<data> --model=<model> --batch_size=<batch_size>
```

For example:
1. Train HRED on Cornell Movie:
```
python train.py --data=cornell --model=HRED
```

2. Train VHRED with word drop of ratio 0.25 and kl annealing iterations 250000:
```
python train.py --data=ubuntu --model=VHRED --batch_size=40 --word_drop=0.25 --kl_annealing_iter=250000
```

3. Train VHCR with utterance drop of ratio 0.25:
```
python train.py --data=ubuntu --model=VHCR --batch_size=40 --sentence_drop=0.25 --kl_annealing_iter=250000
```

By default, a model checkpoint and a TensorBoard summary are saved to `<save_dir>` every epoch.
For more arguments and options, see `configs.py`.
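
A `--checkpoint` flag is also available; when it is given, `configs.py` reuses the checkpoint's directory as the save path, e.g. to continue training from a saved model:
```
python train.py --data=<data> --model=<model> --checkpoint=<path_to_your_checkpoint>
```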


## Evaluation
To evaluate the word perplexity:
```
python eval.py --model=<model> --checkpoint=<path_to_your_checkpoint>
```

For embedding-based metrics, download the [Google News word vectors](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing), unzip the file, and put it under the `datasets` folder.
Then run:
```
python eval_embed.py --model=<model> --checkpoint=<path_to_your_checkpoint>
```


## Reference

If you use this code or dataset as part of any published research, please cite the following paper.

```
@inproceedings{VHCR:2018:NAACL,
    author    = {Yookoon Park and Jaemin Cho and Gunhee Kim},
    title     = "{A Hierarchical Latent Structure for Variational Conversation Modeling}",
    booktitle = {NAACL},
    year      = 2018
    }
```


================================================
FILE: cornell_preprocess.py
================================================
# Preprocess cornell movie dialogs dataset

from multiprocessing import Pool
import argparse
import ast
import pickle
import random
import os
from urllib.request import urlretrieve
from zipfile import ZipFile
from pathlib import Path
from tqdm import tqdm
from model.utils import Tokenizer, Vocab, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN

project_dir = Path(__file__).resolve().parent
datasets_dir = project_dir.joinpath('datasets/')
cornell_dir = datasets_dir.joinpath('cornell/')

# Tokenizer
tokenizer = Tokenizer('spacy')

def prepare_cornell_data():
    """Download and unpack dialogs"""

    zip_url = 'http://www.mpi-sws.org/~cristian/data/cornell_movie_dialogs_corpus.zip'
    zipfile_path = datasets_dir.joinpath('cornell.zip')

    if not datasets_dir.exists():
        datasets_dir.mkdir()

    # Prepare Dialog data
    if not cornell_dir.exists():
        print(f'Downloading {zip_url} to {zipfile_path}')
        urlretrieve(zip_url, zipfile_path)
        print(f'Successfully downloaded {zipfile_path}')

        zip_ref = ZipFile(zipfile_path, 'r')
        zip_ref.extractall(datasets_dir)
        zip_ref.close()

        datasets_dir.joinpath('cornell movie-dialogs corpus').rename(cornell_dir)

    else:
        print('Cornell Data prepared!')


def loadLines(fileName,
              fields=["lineID", "characterID", "movieID", "character", "text"],
              delimiter=" +++$+++ "):
    """
    Args:
        fileName (str): file to load
        field (set<str>): fields to extract
    Return:
        dict<dict<str>>: the extracted fields for each line
    """
    lines = {}

    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(delimiter)

            # Extract fields
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]

            lines[lineObj['lineID']] = lineObj

    return lines


def loadConversations(fileName, lines,
                      fields=["character1ID", "character2ID", "movieID", "utteranceIDs"],
                      delimiter=" +++$+++ "):
    """
    Args:
        fileName (str): file to load
        field (set<str>): fields to extract
    Return:
        dict<dict<str>>: the extracted fields for each line
    """
    conversations = []

    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(delimiter)

            # Extract fields
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]

            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            lineIds = ast.literal_eval(convObj["utteranceIDs"])

            # Reassemble lines
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])

            conversations.append(convObj)

    return conversations


def train_valid_test_split_by_conversation(conversations, split_ratio=[0.8, 0.1, 0.1]):
    """Train/Validation/Test split by randomly selected movies"""

    train_ratio, valid_ratio, test_ratio = split_ratio
    assert train_ratio + valid_ratio + test_ratio == 1.0

    n_conversations = len(conversations)

    # Randomly shuffle conversations
    random.seed(0)
    random.shuffle(conversations)

    # Train / Validation / Test Split
    train_split = int(n_conversations * train_ratio)
    valid_split = int(n_conversations * (train_ratio + valid_ratio))

    train = conversations[:train_split]
    valid = conversations[train_split:valid_split]
    test = conversations[valid_split:]

    print(f'Train set: {len(train)} conversations')
    print(f'Validation set: {len(valid)} conversations')
    print(f'Test set: {len(test)} conversations')

    return train, valid, test


def tokenize_conversation(lines):
    sentence_list = [tokenizer(line['text']) for line in lines]
    return sentence_list


def pad_sentences(conversations, max_sentence_length=30, max_conversation_length=10):
    def pad_tokens(tokens, max_sentence_length=max_sentence_length):
        n_valid_tokens = len(tokens)
        if n_valid_tokens > max_sentence_length - 1:
            tokens = tokens[:max_sentence_length - 1]
        n_pad = max_sentence_length - n_valid_tokens - 1
        tokens = tokens + [EOS_TOKEN] + [PAD_TOKEN] * n_pad
        return tokens

    def pad_conversation(conversation):
        conversation = [pad_tokens(sentence) for sentence in conversation]
        return conversation

    all_padded_sentences = []
    all_sentence_length = []

    for conversation in conversations:
        if len(conversation) > max_conversation_length:
            conversation = conversation[:max_conversation_length]
        sentence_length = [min(len(sentence) + 1, max_sentence_length) # +1 for EOS token
                           for sentence in conversation]
        all_sentence_length.append(sentence_length)

        sentences = pad_conversation(conversation)
        all_padded_sentences.append(sentences)

    sentences = all_padded_sentences
    sentence_length = all_sentence_length
    return sentences, sentence_length
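
# Worked example for pad_tokens with max_sentence_length=5 (assuming EOS_TOKEN and
# PAD_TOKEN render as '<eos>' and '<pad>'):
#   ['hi', 'there']           -> ['hi', 'there', '<eos>', '<pad>', '<pad>']  # recorded length 3
#   ['a', 'b', 'c', 'd', 'e'] -> ['a', 'b', 'c', 'd', '<eos>']               # truncated, length 5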


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # Maximum valid length of sentence
    # => SOS/EOS will surround sentence (EOS for source / SOS for target)
    # => maximum length of tensor = max_sentence_length + 1
    parser.add_argument('-s', '--max_sentence_length', type=int, default=30)
    parser.add_argument('-c', '--max_conversation_length', type=int, default=10)

    # Split Ratio
    split_ratio = [0.8, 0.1, 0.1]

    # Vocabulary
    parser.add_argument('--max_vocab_size', type=int, default=20000)
    parser.add_argument('--min_vocab_frequency', type=int, default=5)

    # Multiprocess
    parser.add_argument('--n_workers', type=int, default=os.cpu_count())

    args = parser.parse_args()

    max_sent_len = args.max_sentence_length
    max_conv_len = args.max_conversation_length
    max_vocab_size = args.max_vocab_size
    min_freq = args.min_vocab_frequency
    n_workers = args.n_workers

    # Download and extract dialogs if necessary.
    prepare_cornell_data()

    print("Loading lines")
    lines = loadLines(cornell_dir.joinpath("movie_lines.txt"))
    print('Number of lines:', len(lines))

    print("Loading conversations...")
    conversations = loadConversations(cornell_dir.joinpath("movie_conversations.txt"), lines)
    print('Number of conversations:', len(conversations))
    print('Train/Valid/Test Split')
    # train, valid, test = train_valid_test_split_by_movie(conversations, split_ratio)
    train, valid, test = train_valid_test_split_by_conversation(conversations, split_ratio)

    def to_pickle(obj, path):
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    for split_type, conv_objects in [('train', train), ('valid', valid), ('test', test)]:
        print(f'Processing {split_type} dataset...')
        split_data_dir = cornell_dir.joinpath(split_type)
        split_data_dir.mkdir(exist_ok=True)

        print(f'Tokenize.. (n_workers={n_workers})')
        def _tokenize_conversation(conv):
            return tokenize_conversation(conv['lines'])
        with Pool(n_workers) as pool:
            conversations = list(tqdm(pool.imap(_tokenize_conversation, conv_objects),
                                     total=len(conv_objects)))

        conversation_length = [min(len(conv['lines']), max_conv_len)
                               for conv in conv_objects]

        sentences, sentence_length = pad_sentences(
            conversations,
            max_sentence_length=max_sent_len,
            max_conversation_length=max_conv_len)

        print('Saving preprocessed data at', split_data_dir)
        to_pickle(conversation_length, split_data_dir.joinpath('conversation_length.pkl'))
        to_pickle(sentences, split_data_dir.joinpath('sentences.pkl'))
        to_pickle(sentence_length, split_data_dir.joinpath('sentence_length.pkl'))

        if split_type == 'train':

            print('Save Vocabulary...')
            vocab = Vocab(tokenizer)
            vocab.add_dataframe(conversations)
            vocab.update(max_size=max_vocab_size, min_freq=min_freq)

            print('Vocabulary size: ', len(vocab))
            vocab.pickle(cornell_dir.joinpath('word2id.pkl'), cornell_dir.joinpath('id2word.pkl'))

    print('Done!')


================================================
FILE: model/__init__.py
================================================


================================================
FILE: model/configs.py
================================================
import os
import argparse
from datetime import datetime
from collections import defaultdict
from pathlib import Path
import pprint
from torch import optim
import torch.nn as nn
from layers.rnncells import StackedLSTMCell, StackedGRUCell

project_dir = Path(__file__).resolve().parent.parent
data_dir = project_dir.joinpath('datasets')
data_dict = {'cornell': data_dir.joinpath('cornell'), 'ubuntu': data_dir.joinpath('ubuntu')}
optimizer_dict = {'RMSprop': optim.RMSprop, 'Adam': optim.Adam}
rnn_dict = {'lstm': nn.LSTM, 'gru': nn.GRU}
rnncell_dict = {'lstm': StackedLSTMCell, 'gru': StackedGRUCell}
username = Path.home().name
# Checkpoints are saved under save_dir; edit this path for your machine (see Readme).
save_dir = Path(f'/data1/{username}/conversation/')


def str2bool(v):
    """string to boolean"""
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


class Config(object):
    def __init__(self, **kwargs):
        """Configuration Class: set kwargs as class attributes with setattr"""
        if kwargs is not None:
            for key, value in kwargs.items():
                if key == 'optimizer':
                    value = optimizer_dict[value]
                if key == 'rnn':
                    value = rnn_dict[value]
                if key == 'rnncell':
                    value = rnncell_dict[value]
                setattr(self, key, value)

        # Dataset directory: ex) ./datasets/cornell/
        self.dataset_dir = data_dict[self.data.lower()]

        # Data Split ex) 'train', 'valid', 'test'
        self.data_dir = self.dataset_dir.joinpath(self.mode)
        # Pickled Vocabulary
        self.word2id_path = self.dataset_dir.joinpath('word2id.pkl')
        self.id2word_path = self.dataset_dir.joinpath('id2word.pkl')

        # Pickled Dataframes
        self.sentences_path = self.data_dir.joinpath('sentences.pkl')
        self.sentence_length_path = self.data_dir.joinpath('sentence_length.pkl')
        self.conversation_length_path = self.data_dir.joinpath('conversation_length.pkl')

        # Save path
        if self.mode == 'train' and self.checkpoint is None:
            time_now = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
            self.save_path = save_dir.joinpath(self.data, self.model, time_now)
            self.logdir = self.save_path
            os.makedirs(self.save_path, exist_ok=True)
        elif self.checkpoint is not None:
            assert os.path.exists(self.checkpoint)
            self.save_path = os.path.dirname(self.checkpoint)
            self.logdir = self.save_path

    def __str__(self):
        """Pretty-print configurations in alphabetical order"""
        config_str = 'Configurations\n'
        config_str += pprint.pformat(self.__dict__)
        return config_str


def get_config(parse=True, **optional_kwargs):
    """
    Get configurations as attributes of class
    1. Parse configurations with argparse.
    2. Create Config class initialized with parsed kwargs.
    3. Return Config class.
    """
    parser = argparse.ArgumentParser()

    # Mode
    parser.add_argument('--mode', type=str, default='train')

    # Train
    parser.add_argument('--batch_size', type=int, default=80)
    parser.add_argument('--eval_batch_size', type=int, default=80)
    parser.add_argument('--n_epoch', type=int, default=30)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--optimizer', type=str, default='Adam')
    parser.add_argument('--clip', type=float, default=1.0)
    parser.add_argument('--checkpoint', type=str, default=None)

    # Generation
    parser.add_argument('--max_unroll', type=int, default=30)
    parser.add_argument('--sample', type=str2bool, default=False,
                        help='if false, use beam search for decoding')
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--beam_size', type=int, default=1)

    # Model
    parser.add_argument('--model', type=str, default='VHCR',
                        help='one of {HRED, VHRED, VHCR}')
    # Currently does not support lstm
    parser.add_argument('--rnn', type=str, default='gru')
    parser.add_argument('--rnncell', type=str, default='gru')
    parser.add_argument('--num_layers', type=int, default=1)
    parser.add_argument('--embedding_size', type=int, default=500)
    parser.add_argument('--tie_embedding', type=str2bool, default=True)
    parser.add_argument('--encoder_hidden_size', type=int, default=1000)
    parser.add_argument('--bidirectional', type=str2bool, default=True)
    parser.add_argument('--decoder_hidden_size', type=int, default=1000)
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--context_size', type=int, default=1000)
    parser.add_argument('--feedforward', type=str, default='FeedForward')
    parser.add_argument('--activation', type=str, default='Tanh')

    # VAE model
    parser.add_argument('--z_sent_size', type=int, default=100)
    parser.add_argument('--z_conv_size', type=int, default=100)
    parser.add_argument('--word_drop', type=float, default=0.0,
                        help='only applied to variational models')
    parser.add_argument('--kl_threshold', type=float, default=0.0)
    parser.add_argument('--kl_annealing_iter', type=int, default=25000)
    parser.add_argument('--importance_sample', type=int, default=100)
    parser.add_argument('--sentence_drop', type=float, default=0.0)

    # Generation
    parser.add_argument('--n_context', type=int, default=1)
    parser.add_argument('--n_sample_step', type=int, default=1)

    # BOW
    parser.add_argument('--bow', type=str2bool, default=False)

    # Utility
    parser.add_argument('--print_every', type=int, default=100)
    parser.add_argument('--plot_every_epoch', type=int, default=1)
    parser.add_argument('--save_every_epoch', type=int, default=1)

    # Data
    parser.add_argument('--data', type=str, default='ubuntu')

    # Parse arguments
    if parse:
        kwargs = parser.parse_args()
    else:
        kwargs = parser.parse_known_args()[0]

    # Namespace => Dictionary
    kwargs = vars(kwargs)
    kwargs.update(optional_kwargs)

    return Config(**kwargs)
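
# Usage sketch (see model/eval.py for a real call site):
#   config = get_config(mode='test')  # parse sys.argv, then override mode
#   print(config)                     # pretty-prints all settings via __str__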


================================================
FILE: model/data_loader.py
================================================
import random
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from utils import PAD_ID, UNK_ID, SOS_ID, EOS_ID
import numpy as np


class DialogDataset(Dataset):
    def __init__(self, sentences, conversation_length, sentence_length, vocab, data=None):

        # [total_data_size, max_conversation_length, max_sentence_length]
        # tokenized raw text of sentences
        self.sentences = sentences
        self.vocab = vocab

        # conversation length of each batch
        # [total_data_size]
        self.conversation_length = conversation_length

        # list of length of sentences
        # [total_data_size, max_conversation_length]
        self.sentence_length = sentence_length
        self.data = data
        self.len = len(sentences)

    def __getitem__(self, index):
        """Return Single data sentence"""
        # [max_conversation_length, max_sentence_length]
        sentence = self.sentences[index]
        conversation_length = self.conversation_length[index]
        sentence_length = self.sentence_length[index]

        # word => word_ids
        sentence = self.sent2id(sentence)

        return sentence, conversation_length, sentence_length

    def __len__(self):
        return self.len

    def sent2id(self, sentences):
        """word => word id"""
        # [max_conversation_length, max_sentence_length]
        return [self.vocab.sent2id(sentence) for sentence in sentences]


def get_loader(sentences, conversation_length, sentence_length, vocab, batch_size=100, data=None, shuffle=True):
    """Load DataLoader of given DialogDataset"""

    def collate_fn(data):
        """
        Collate list of data in to batch

        Args:
            data: list of tuple(source, target, conversation_length, source_length, target_length)
        Return:
            Batch of each feature
            - source (LongTensor): [batch_size, max_conversation_length, max_source_length]
            - target (LongTensor): [batch_size, max_conversation_length, max_source_length]
            - conversation_length (np.array): [batch_size]
            - source_length (LongTensor): [batch_size, max_conversation_length]
        """
        # Sort by conversation length (descending order) to use 'pack_padded_sequence'
        data.sort(key=lambda x: x[1], reverse=True)

        # Separate
        sentences, conversation_length, sentence_length = zip(*data)

        # return sentences, conversation_length, sentence_length.tolist()
        return sentences, conversation_length, sentence_length

    dataset = DialogDataset(sentences, conversation_length,
                            sentence_length, vocab, data=data)

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn)

    return data_loader
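
# Usage sketch (see model/eval.py for a real call site):
#   data_loader = get_loader(sentences, conversation_length, sentence_length,
#                            vocab, batch_size=80)
#   for sentences, conversation_length, sentence_length in data_loader:
#       pass  # per-batch tuples, sorted by conversation length (descending)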


================================================
FILE: model/eval.py
================================================
from solver import Solver, VariationalSolver
from data_loader import get_loader
from configs import get_config
from utils import Vocab, Tokenizer
import os
import pickle
from models import VariationalModels


def load_pickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


if __name__ == '__main__':
    config = get_config(mode='test')

    print('Loading Vocabulary...')
    vocab = Vocab()
    vocab.load(config.word2id_path, config.id2word_path)
    print(f'Vocabulary size: {vocab.vocab_size}')

    config.vocab_size = vocab.vocab_size

    data_loader = get_loader(
        sentences=load_pickle(config.sentences_path),
        conversation_length=load_pickle(config.conversation_length_path),
        sentence_length=load_pickle(config.sentence_length_path),
        vocab=vocab,
        batch_size=config.batch_size)

    if config.model in VariationalModels:
        solver = VariationalSolver(config, None, data_loader, vocab=vocab, is_train=False)
        solver.build()
        solver.importance_sample()
    else:
        solver = Solver(config, None, data_loader, vocab=vocab, is_train=False)
        solver.build()
        solver.test()


================================================
FILE: model/eval_embed.py
================================================
from solver import Solver, VariationalSolver
from data_loader import get_loader
from configs import get_config
from utils import Vocab, Tokenizer
import os
import pickle
from models import VariationalModels
import re


def load_pickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


if __name__ == '__main__':
    config = get_config(mode='test')

    print('Loading Vocabulary...')
    vocab = Vocab()
    vocab.load(config.word2id_path, config.id2word_path)
    print(f'Vocabulary size: {vocab.vocab_size}')

    config.vocab_size = vocab.vocab_size

    data_loader = get_loader(
        sentences=load_pickle(config.sentences_path),
        conversation_length=load_pickle(config.conversation_length_path),
        sentence_length=load_pickle(config.sentence_length_path),
        vocab=vocab,
        batch_size=config.batch_size,
        shuffle=False)

    if config.model in VariationalModels:
        solver = VariationalSolver(config, None, data_loader, vocab=vocab, is_train=False)
    else:
        solver = Solver(config, None, data_loader, vocab=vocab, is_train=False)

    solver.build()
    solver.embedding_metric()


================================================
FILE: model/layers/__init__.py
================================================
from .encoder import *
from .decoder import *
from .rnncells import StackedLSTMCell, StackedGRUCell
from .loss import *
from .feedforward import *


================================================
FILE: model/layers/beam_search.py
================================================
import torch
from utils import EOS_ID


class Beam(object):
    def __init__(self, batch_size, hidden_size, vocab_size, beam_size, max_unroll, batch_position):
        """Beam class for beam search"""
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.beam_size = beam_size
        self.max_unroll = max_unroll

        # batch_position [batch_size]
        #   [0, beam_size, beam_size * 2, .., beam_size * (batch_size-1)]
        #   Points where batch starts in [batch_size x beam_size] tensors
        #   Ex. position_idx[5]: when 5-th batch starts
        self.batch_position = batch_position

        self.log_probs = list()  # [(batch*k, vocab_size)] * sequence_length
        self.scores = list()  # [(batch*k)] * sequence_length
        self.back_pointers = list()  # [(batch*k)] * sequence_length
        self.token_ids = list()  # [(batch*k)] * sequence_length
        # self.hidden = list()  # [(num_layers, batch*k, hidden_size)] * sequence_length

        self.metadata = {
            'inputs': None,
            'output': None,
            'scores': None,
            'length': None,
            'sequence': None,
        }

    def update(self, score, back_pointer, token_id):  # , h):
        """Append intermediate top-k candidates to beam at each step"""

        # self.log_probs.append(log_prob)
        self.scores.append(score)
        self.back_pointers.append(back_pointer)
        self.token_ids.append(token_id)
        # self.hidden.append(h)

    def backtrack(self):
        """Backtracks over batch to generate optimal k-sequences

        Returns:
            prediction ([batch, k, max_unroll])
                A list of Tensors containing predicted sequence
            final_score [batch, k]
                A list containing the final scores for all top-k sequences
            length [batch, k]
                A list specifying the length of each sequence in the top-k candidates
        """
        prediction = list()

        # Initialize for length of top-k sequences
        length = [[self.max_unroll] * self.beam_size for _ in range(self.batch_size)]

        # Last step output of the beam are not sorted => sort here!
        # Size not changed [batch size, beam_size]
        top_k_score, top_k_idx = self.scores[-1].topk(self.beam_size, dim=1)

        # Initialize sequence scores
        top_k_score = top_k_score.clone()

        n_eos_in_batch = [0] * self.batch_size

        # Initialize Back-pointer from the last step
        # Add self.position_idx for indexing variable with batch x beam as the first dimension
        # [batch x beam]
        back_pointer = (top_k_idx + self.batch_position.unsqueeze(1)).view(-1)

        for t in reversed(range(self.max_unroll)):
            # Reorder variables with the Back-pointer
            # [batch x beam]
            token_id = self.token_ids[t].index_select(0, back_pointer)

            # Reorder the Back-pointer
            # [batch x beam]
            back_pointer = self.back_pointers[t].index_select(0, back_pointer)

            # Indices of ended sequences
            # [< batch x beam]
            eos_indices = self.token_ids[t].data.eq(EOS_ID).nonzero()

            # For each batch, every time we see an EOS while backtracking:
            # - if not all sequences have ended, the lowest-scored surviving
            #   sequence is replaced by the detected ended sequence
            # - if all sequences have ended, the lowest-scored ended sequence
            #   is replaced by the detected ended sequence
            if eos_indices.dim() > 0:
                # Loop over all EOS at current step
                for i in range(eos_indices.size(0) - 1, -1, -1):
                    # absolute index of detected ended sequence
                    eos_idx = eos_indices[i, 0].item()

                    # At which batch EOS is located
                    batch_idx = eos_idx // self.beam_size
                    batch_start_idx = batch_idx * self.beam_size

                    # if n_eos_in_batch[batch_idx] > self.beam_size:

                    # Index of sequence with lowest score
                    _n_eos_in_batch = n_eos_in_batch[batch_idx] % self.beam_size
                    beam_idx_to_be_replaced = self.beam_size - _n_eos_in_batch - 1
                    idx_to_be_replaced = batch_start_idx + beam_idx_to_be_replaced

                    # Replace old information with new sequence information
                    back_pointer[idx_to_be_replaced] = self.back_pointers[t][eos_idx].item()
                    token_id[idx_to_be_replaced] = self.token_ids[t][eos_idx].item()
                    top_k_score[batch_idx,
                                beam_idx_to_be_replaced] = self.scores[t].view(-1)[eos_idx].item()
                    length[batch_idx][beam_idx_to_be_replaced] = t + 1

                    n_eos_in_batch[batch_idx] += 1

            # max_unroll * [batch x beam]
            prediction.append(token_id)

        # Sort and re-order again as the added ended sequences may change the order
        # [batch, beam]
        top_k_score, top_k_idx = top_k_score.topk(self.beam_size, dim=1)
        final_score = top_k_score.data

        for batch_idx in range(self.batch_size):
            length[batch_idx] = [length[batch_idx][beam_idx.item()]
                                 for beam_idx in top_k_idx[batch_idx]]

        # [batch x beam]
        top_k_idx = (top_k_idx + self.batch_position.unsqueeze(1)).view(-1)

        # Reverse the sequences and re-order at the same time
        # It is reversed because the backtracking happens in the reverse order
        # [batch, beam]

        prediction = [step.index_select(0, top_k_idx).view(
            self.batch_size, self.beam_size) for step in reversed(prediction)]

        # [batch, beam, max_unroll]
        prediction = torch.stack(prediction, 2)

        return prediction, final_score, length
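
# Beam is driven from DecoderRNN.beam_decode (model/layers/decoder.py):
# beam.update(score, back_pointer, token_id) is called once per unroll step,
# then beam.backtrack() recovers the top-k sequences in forward order.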


================================================
FILE: model/layers/decoder.py
================================================
import random
import torch
from torch import nn
from torch.nn import functional as F
from .rnncells import StackedLSTMCell, StackedGRUCell
from .beam_search import Beam
from .feedforward import FeedForward
from utils import to_var, SOS_ID, UNK_ID, EOS_ID
import math


class BaseRNNDecoder(nn.Module):
    def __init__(self):
        """Base Decoder Class"""
        super(BaseRNNDecoder, self).__init__()

    @property
    def use_lstm(self):
        return isinstance(self.rnncell, StackedLSTMCell)

    def init_token(self, batch_size, SOS_ID=SOS_ID):
        """Get Variable of <SOS> Index (batch_size)"""
        x = to_var(torch.LongTensor([SOS_ID] * batch_size))
        return x

    def init_h(self, batch_size=None, zero=True, hidden=None):
        """Return RNN initial state"""
        if hidden is not None:
            return hidden

        if self.use_lstm:
            # (h, c)
            return (to_var(torch.zeros(self.num_layers,
                                       batch_size,
                                       self.hidden_size)),
                    to_var(torch.zeros(self.num_layers,
                                       batch_size,
                                       self.hidden_size)))
        else:
            # h
            return to_var(torch.zeros(self.num_layers,
                                      batch_size,
                                      self.hidden_size))

    def batch_size(self, inputs=None, h=None):
        """
        inputs: [batch_size, seq_len]
        h: [num_layers, batch_size, hidden_size] (RNN/GRU)
        h_c: [2, num_layers, batch_size, hidden_size] (LSTMCell)
        """
        if inputs is not None:
            batch_size = inputs.size(0)
            return batch_size

        else:
            if self.use_lstm:
                batch_size = h[0].size(1)
            else:
                batch_size = h.size(1)
            return batch_size

    def decode(self, out):
        """
        Args:
            out: unnormalized word distribution [batch_size, vocab_size]
        Return:
            x: word_index [batch_size]
        """

        # Sample next word from multinomial word distribution
        if self.sample:
            # x: [batch_size] - word index (next input)
            x = torch.multinomial(self.softmax(out / self.temperature), 1).view(-1)

        # Greedy sampling
        else:
            # x: [batch_size] - word index (next input)
            _, x = out.max(dim=1)
        return x

    def forward(self):
        """Base forward function to inherit"""
        raise NotImplementedError

    def forward_step(self):
        """Run RNN single step"""
        raise NotImplementedError

    def embed(self, x):
        """word index: [batch_size] => word vectors: [batch_size, hidden_size]"""

        if self.training and self.word_drop > 0.0:
            if random.random() < self.word_drop:
                embed = self.embedding(to_var(x.data.new([UNK_ID] * x.size(0))))
            else:
                embed = self.embedding(x)
        else:
            embed = self.embedding(x)

        return embed

    def beam_decode(self,
                    init_h=None,
                    encoder_outputs=None, input_valid_length=None,
                    decode=False):
        """
        Args:
            encoder_outputs (Variable, FloatTensor): [batch_size, source_length, hidden_size]
            input_valid_length (Variable, LongTensor): [batch_size] (optional)
            init_h (variable, FloatTensor): [batch_size, hidden_size] (optional)
        Return:
            out   : [batch_size, seq_len]
        """
        batch_size = self.batch_size(h=init_h)

        # [batch_size x beam_size]
        x = self.init_token(batch_size * self.beam_size, SOS_ID)

        # [num_layers, batch_size x beam_size, hidden_size]
        h = self.init_h(batch_size, hidden=init_h).repeat(1, self.beam_size, 1)

        # batch_position [batch_size]
        #   [0, beam_size, beam_size * 2, .., beam_size * (batch_size-1)]
        #   Points where batch starts in [batch_size x beam_size] tensors
        #   Ex. position_idx[5]: when 5-th batch starts
        batch_position = to_var(torch.arange(0, batch_size).long() * self.beam_size)

        # Initialize scores of sequence
        # [batch_size x beam_size]
        # Ex. batch_size: 5, beam_size: 3
        # [0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf]
        score = torch.ones(batch_size * self.beam_size) * -float('inf')
        score.index_fill_(0, torch.arange(0, batch_size).long() * self.beam_size, 0.0)
        score = to_var(score)

        # Initialize Beam that stores decisions for backtracking
        beam = Beam(
            batch_size,
            self.hidden_size,
            self.vocab_size,
            self.beam_size,
            self.max_unroll,
            batch_position)

        for i in range(self.max_unroll):

            # x: [batch_size x beam_size]; (token index)
            # =>
            # out: [batch_size x beam_size, vocab_size]
            # h: [num_layers, batch_size x beam_size, hidden_size]
            out, h = self.forward_step(x, h,
                                       encoder_outputs=encoder_outputs,
                                       input_valid_length=input_valid_length)
            # log_prob: [batch_size x beam_size, vocab_size]
            log_prob = F.log_softmax(out, dim=1)

            # [batch_size x beam_size]
            # => [batch_size x beam_size, vocab_size]
            score = score.view(-1, 1) + log_prob

            # Select `beam size` transitions out of `vocab size` combinations

            # [batch_size x beam_size, vocab_size]
            # => [batch_size, beam_size x vocab_size]
            # Cutoff and retain candidates with top-k scores
            # score: [batch_size, beam_size]
            # top_k_idx: [batch_size, beam_size]
            #       each element of top_k_idx [0 ~ beam x vocab)

            score, top_k_idx = score.view(batch_size, -1).topk(self.beam_size, dim=1)

            # Get token ids: the remainder of top_k_idx divided by vocab_size
            # Each element is in [0, vocab_size)
            # Ex. token 3 in beam 4:
            # (4 * vocab_size) + 3 => 3
            # x: [batch_size x beam_size]
            x = (top_k_idx % self.vocab_size).view(-1)

            # top-k-pointer [batch_size x beam_size]
            #       Points top-k beam that scored best at current step
            #       Later used as back-pointer at backtracking
            #       Each element is beam index: 0 ~ beam_size
            #                     + position index: 0 ~ beam_size x (batch_size-1)
            beam_idx = top_k_idx // self.vocab_size  # [batch_size, beam_size]
            top_k_pointer = (beam_idx + batch_position.unsqueeze(1)).view(-1)

            # Select next h (size doesn't change)
            # [num_layers, batch_size * beam_size, hidden_size]
            h = h.index_select(1, top_k_pointer)

            # Update sequence scores at beam
            beam.update(score.clone(), top_k_pointer, x)  # , h)

            # Erase scores for EOS so that they are not expanded
            # [batch_size, beam_size]
            eos_idx = x.data.eq(EOS_ID).view(batch_size, self.beam_size)
            if eos_idx.nonzero().dim() > 0:
                score.data.masked_fill_(eos_idx, -float('inf'))

        # prediction ([batch, k, max_unroll])
        #     A list of Tensors containing predicted sequence
        # final_score [batch, k]
        #     A list containing the final scores for all top-k sequences
        # length [batch, k]
        #     A list specifying the length of each sequence in the top-k candidates
        prediction, final_score, length = beam.backtrack()

        return prediction, final_score, length


class DecoderRNN(BaseRNNDecoder):
    def __init__(self, vocab_size, embedding_size,
                 hidden_size, rnncell=StackedGRUCell, num_layers=1,
                 dropout=0.0, word_drop=0.0,
                 max_unroll=30, sample=True, temperature=1.0, beam_size=1):
        super(DecoderRNN, self).__init__()

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.temperature = temperature
        self.word_drop = word_drop
        self.max_unroll = max_unroll
        self.sample = sample
        self.beam_size = beam_size

        self.embedding = nn.Embedding(vocab_size, embedding_size)

        self.rnncell = rnncell(num_layers,
                               embedding_size,
                               hidden_size,
                               dropout)
        self.out = nn.Linear(hidden_size, vocab_size)
        self.softmax = nn.Softmax(dim=1)

    def forward_step(self, x, h,
                     encoder_outputs=None,
                     input_valid_length=None):
        """
        Single RNN Step
        1. Input Embedding (vocab_size => hidden_size)
        2. RNN Step (hidden_size => hidden_size)
        3. Output Projection (hidden_size => vocab size)

        Args:
            x: [batch_size]
            h: [num_layers, batch_size, hidden_size] (h and c from all layers)

        Return:
            out: [batch_size,vocab_size] (Unnormalized word distribution)
            h: [num_layers, batch_size, hidden_size] (h and c from all layers)
        """
        # x: [batch_size] => [batch_size, hidden_size]
        x = self.embed(x)

        # last_h: [batch_size, hidden_size] (h from Top RNN layer)
        # h: [num_layers, batch_size, hidden_size] (h and c from all layers)
        last_h, h = self.rnncell(x, h)

        if self.use_lstm:
            # last_h_c: [2, batch_size, hidden_size] (h from Top RNN layer)
            # h_c: [2, num_layers, batch_size, hidden_size] (h and c from all layers)
            last_h = last_h[0]

        # Unnormalized word distribution
        # out: [batch_size, vocab_size]
        out = self.out(last_h)
        return out, h

    def forward(self, inputs, init_h=None, encoder_outputs=None, input_valid_length=None,
                decode=False):
        """
        Train (decode=False)
            Args:
                inputs (Variable, LongTensor): [batch_size, seq_len]
                init_h: (Variable, FloatTensor): [num_layers, batch_size, hidden_size]
            Return:
                out   : [batch_size, seq_len, vocab_size]
        Test (decode=True)
            Args:
                inputs: None
                init_h: (Variable, FloatTensor): [num_layers, batch_size, hidden_size]
            Return:
                out   : [batch_size, seq_len]
        """
        batch_size = self.batch_size(inputs, init_h)

        # x: [batch_size]
        x = self.init_token(batch_size, SOS_ID)

        # h: [num_layers, batch_size, hidden_size]
        h = self.init_h(batch_size, hidden=init_h)


        if not decode:
            out_list = []
            seq_len = inputs.size(1)
            for i in range(seq_len):

                # x: [batch_size]
                # =>
                # out: [batch_size, vocab_size]
                # h: [num_layers, batch_size, hidden_size] (h and c from all layers)
                out, h = self.forward_step(x, h)

                out_list.append(out)
                x = inputs[:, i]

            # [batch_size, max_target_len, vocab_size]
            return torch.stack(out_list, dim=1)
        else:
            x_list = []
            for i in range(self.max_unroll):

                # x: [batch_size]
                # =>
                # out: [batch_size, vocab_size]
                # h: [num_layers, batch_size, hidden_size] (h and c from all layers)
                out, h = self.forward_step(x, h)

                # out: [batch_size, vocab_size]
                # => x: [batch_size]
                x = self.decode(out)
                x_list.append(x)

            # [batch_size, max_target_len]
            return torch.stack(x_list, dim=1)


================================================
FILE: model/layers/encoder.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence
from utils import to_var, reverse_order_valid, PAD_ID
from .rnncells import StackedGRUCell, StackedLSTMCell

import copy

class BaseRNNEncoder(nn.Module):
    def __init__(self):
        """Base RNN Encoder Class"""
        super(BaseRNNEncoder, self).__init__()

    @property
    def use_lstm(self):
        if hasattr(self, 'rnn'):
            return isinstance(self.rnn, nn.LSTM)
        else:
            raise AttributeError('no rnn selected')

    def init_h(self, batch_size=None, hidden=None):
        """Return RNN initial state"""
        if hidden is not None:
            return hidden

        if self.use_lstm:
            return (to_var(torch.zeros(self.num_layers*self.num_directions,
                                      batch_size,
                                      self.hidden_size)),
                    to_var(torch.zeros(self.num_layers*self.num_directions,
                                      batch_size,
                                      self.hidden_size)))
        else:
            return to_var(torch.zeros(self.num_layers*self.num_directions,
                                        batch_size,
                                        self.hidden_size))

    def batch_size(self, inputs=None, h=None):
        """
        inputs: [batch_size, seq_len]
        h: [num_layers, batch_size, hidden_size] (RNN/GRU)
        h_c: [2, num_layers, batch_size, hidden_size] (LSTM)
        """
        if inputs is not None:
            batch_size = inputs.size(0)
            return batch_size

        else:
            if self.use_lstm:
                batch_size = h[0].size(1)
            else:
                batch_size = h.size(1)
            return batch_size

    def forward(self):
        raise NotImplementedError


class EncoderRNN(BaseRNNEncoder):
    def __init__(self, vocab_size, embedding_size,
                 hidden_size, rnn=nn.GRU, num_layers=1, bidirectional=False,
                 dropout=0.0, bias=True, batch_first=True):
        """Sentence-level Encoder"""
        super(EncoderRNN, self).__init__()

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.batch_first = batch_first
        self.bidirectional = bidirectional

        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1

        # word embedding
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=PAD_ID)

        self.rnn = rnn(input_size=embedding_size,
                        hidden_size=hidden_size,
                        num_layers=num_layers,
                        bias=bias,
                        batch_first=batch_first,
                        dropout=dropout,
                        bidirectional=bidirectional)

    def forward(self, inputs, input_length, hidden=None):
        """
        Args:
            inputs (Variable, LongTensor): [num_setences, max_seq_len]
            input_length (Variable, LongTensor): [num_sentences]
        Return:
            outputs (Variable): [max_source_length, batch_size, hidden_size]
                - list of all hidden states
            hidden ((tuple of) Variable): [num_layers*num_directions, batch_size, hidden_size]
                - last hidden state
                - (h, c) or h
        """
        batch_size, seq_len = inputs.size()

        # Sort in decreasing order of length for pack_padded_sequence()
        input_length_sorted, indices = input_length.sort(descending=True)

        input_length_sorted = input_length_sorted.data.tolist()

        # [num_sentences, max_source_length]
        inputs_sorted = inputs.index_select(0, indices)

        # [num_sentences, max_source_length, embedding_dim]
        embedded = self.embedding(inputs_sorted)

        # batch_first=True
        rnn_input = pack_padded_sequence(embedded, input_length_sorted,
                                            batch_first=self.batch_first)

        hidden = self.init_h(batch_size, hidden=hidden)

        # outputs: [batch, seq_len, hidden_size * num_directions]
        # hidden: [num_layers * num_directions, batch, hidden_size]
        self.rnn.flatten_parameters()
        outputs, hidden = self.rnn(rnn_input, hidden)
        outputs, outputs_lengths = pad_packed_sequence(outputs, batch_first=self.batch_first)

        # Reorder outputs and hidden
        _, inverse_indices = indices.sort()
        outputs = outputs.index_select(0, inverse_indices)

        if self.use_lstm:
            hidden = (hidden[0].index_select(1, inverse_indices),
                        hidden[1].index_select(1, inverse_indices))
        else:
            hidden = hidden.index_select(1, inverse_indices)

        return outputs, hidden

class ContextRNN(BaseRNNEncoder):
    def __init__(self, input_size, context_size, rnn=nn.GRU, num_layers=1, dropout=0.0,
                 bidirectional=False, bias=True, batch_first=True):
        """Context-level Encoder"""
        super(ContextRNN, self).__init__()

        self.input_size = input_size
        self.context_size = context_size
        self.hidden_size = self.context_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.batch_first = batch_first

        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1

        self.rnn = rnn(input_size=input_size,
                        hidden_size=context_size,
                        num_layers=num_layers,
                        bias=bias,
                        batch_first=batch_first,
                        dropout=dropout,
                        bidirectional=bidirectional)

    def forward(self, encoder_hidden, conversation_length, hidden=None):
        """
        Args:
            encoder_hidden (Variable, FloatTensor): [batch_size, max_len, num_layers * direction * hidden_size]
            conversation_length (Variable, LongTensor): [batch_size]
        Return:
            outputs (Variable): [batch_size, max_seq_len, hidden_size]
                - list of all hidden states
            hidden ((tuple of) Variable): [num_layers*num_directions, batch_size, hidden_size]
                - last hidden state
                - (h, c) or h
        """
        batch_size, seq_len, _  = encoder_hidden.size()

        # Sort for PackedSequence
        conv_length_sorted, indices = conversation_length.sort(descending=True)
        conv_length_sorted = conv_length_sorted.data.tolist()
        encoder_hidden_sorted = encoder_hidden.index_select(0, indices)

        rnn_input = pack_padded_sequence(encoder_hidden_sorted, conv_length_sorted, batch_first=True)

        hidden = self.init_h(batch_size, hidden=hidden)

        self.rnn.flatten_parameters()
        outputs, hidden = self.rnn(rnn_input, hidden)

        # outputs: [batch_size, max_conversation_length, context_size]
        outputs, outputs_length = pad_packed_sequence(outputs, batch_first=True)

        # reorder outputs and hidden
        _, inverse_indices = indices.sort()
        outputs = outputs.index_select(0, inverse_indices)

        if self.use_lstm:
            hidden = (hidden[0].index_select(1, inverse_indices),
                    hidden[1].index_select(1, inverse_indices))
        else:
            hidden = hidden.index_select(1, inverse_indices)

        # outputs: [batch, seq_len, hidden_size * num_directions]
        # hidden: [num_layers * num_directions, batch, hidden_size]
        return outputs, hidden

    def step(self, encoder_hidden, hidden):

        batch_size = encoder_hidden.size(0)
        # encoder_hidden: [1, batch_size, hidden_size]
        encoder_hidden = torch.unsqueeze(encoder_hidden, 1)

        if hidden is None:
            hidden = self.init_h(batch_size, hidden=None)

        outputs, hidden = self.rnn(encoder_hidden, hidden)
        return outputs, hidden


================================================
FILE: model/layers/feedforward.py
================================================
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    def __init__(self, input_size, output_size, num_layers=1, hidden_size=None,
                 activation="Tanh", bias=True):
        super(FeedForward, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.activation = getattr(nn, activation)()
        n_inputs = [input_size] + [hidden_size] * (num_layers - 1)
        n_outputs = [hidden_size] * (num_layers - 1) + [output_size]
        self.linears = nn.ModuleList([nn.Linear(n_in, n_out, bias=bias)
                                      for n_in, n_out in zip(n_inputs, n_outputs)])

    def forward(self, input):
        x = input
        for linear in self.linears:
            x = linear(x)
            x = self.activation(x)

        return x
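
# Usage sketch (sizes are illustrative):
#   ff = FeedForward(input_size=1000, output_size=200, num_layers=2, hidden_size=500)
#   y = ff(torch.randn(16, 1000))  # -> [16, 200]; activation applied after every layer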


================================================
FILE: model/layers/loss.py
================================================
import torch
from torch.nn import functional as F
import torch.nn as nn
from utils import to_var, sequence_mask


# https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def masked_cross_entropy(logits, target, length, per_example=False):
    """
    Args:
        logits (Variable, FloatTensor): [batch, max_len, num_classes]
            - unnormalized probability for each class
        target (Variable, LongTensor): [batch, max_len]
            - index of true class for each corresponding step
        length (Variable, LongTensor): [batch]
            - length of each data in a batch
    Returns:
        loss (Variable): []
            - An average loss value masked by the length
    """
    batch_size, max_len, num_classes = logits.size()

    # [batch_size * max_len, num_classes]
    logits_flat = logits.view(-1, num_classes)

    # [batch_size * max_len, num_classes]
    log_probs_flat = F.log_softmax(logits_flat, dim=1)

    # [batch_size * max_len, 1]
    target_flat = target.view(-1, 1)

    # Negative Log-likelihood: -sum {  1* log P(target)  + 0 log P(non-target)} = -sum( log P(target) )
    # [batch_size * max_len, 1]
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)

    # [batch_size, max_len]
    losses = losses_flat.view(batch_size, max_len)

    # [batch_size, max_len]
    mask = sequence_mask(sequence_length=length, max_len=max_len)

    # Apply masking on loss
    losses = losses * mask.float()

    # word-wise cross entropy
    # loss = losses.sum() / length.float().sum()

    if per_example:
        # loss: [batch_size]
        return losses.sum(1)
    else:
        loss = losses.sum()
        return loss, length.float().sum()
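
# Usage sketch (editorial example; assumes it is run from the model/
# directory so that `utils` resolves): two sequences with lengths 3 and
# 1 over a 5-class vocabulary; dividing the loss by the returned word
# count gives the per-word cross entropy.
if __name__ == '__main__':
    logits = to_var(torch.randn(2, 3, 5))
    target = to_var(torch.LongTensor([[1, 2, 0], [3, 0, 0]]))
    length = to_var(torch.LongTensor([3, 1]))
    loss, n_words = masked_cross_entropy(logits, target, length)
    print(loss.item() / n_words.item())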


================================================
FILE: model/layers/rnncells.py
================================================
# Modified from OpenNMT.py, Z-forcing

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class StackedLSTMCell(nn.Module):

    def __init__(self, num_layers, input_size, rnn_size, dropout):
        super(StackedLSTMCell, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(nn.LSTMCell(input_size, rnn_size))
            input_size = rnn_size

    def forward(self, x, h_c):
        """
        Args:
            x: [batch_size, input_size]
            h_c: [2, num_layers, batch_size, hidden_size]
        Return:
            last_h_c: [2, batch_size, hidden_size] (h and c from the last layer)
            h_c_list: [2, num_layers, batch_size, hidden_size] (h and c from all layers)
        """
        h_0, c_0 = h_c
        h_list, c_list = [], []
        for i, layer in enumerate(self.layers):
            # h of i-th layer
            h_i, c_i = layer(x, (h_0[i], c_0[i]))

            # x for next layer
            x = h_i
            if i + 1 != self.num_layers:
                x = self.dropout(x)
            h_list += [h_i]
            c_list += [c_i]

        last_h_c = (h_list[-1], c_list[-1])
        h_list = torch.stack(h_list)
        c_list = torch.stack(c_list)
        h_c_list = (h_list, c_list)

        return last_h_c, h_c_list


class StackedGRUCell(nn.Module):

    def __init__(self, num_layers, input_size, rnn_size, dropout):
        super(StackedGRUCell, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(nn.GRUCell(input_size, rnn_size))
            input_size = rnn_size

    def forward(self, x, h):
        """
        Args:
            x: [batch_size, input_size]
            h: [num_layers, batch_size, hidden_size]
        Return:
            last_h: [batch_size, hidden_size] (h from last layer)
            h_list: [num_layers, batch_size, hidden_size] (h from all layers)
        """
        # h of all layers
        h_list = []
        for i, layer in enumerate(self.layers):
            # h of i-th layer
            h_i = layer(x, h[i])

            # x for next layer
            x = h_i
            if i + 1 != self.num_layers:
                x = self.dropout(x)
            h_list.append(h_i)

        last_h = h_list[-1]
        h_list = torch.stack(h_list)

        return last_h, h_list
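
# Usage sketch (editorial example): one timestep through a two-layer
# stacked GRU cell; StackedLSTMCell is driven the same way but carries
# an (h, c) pair instead of a single h.
if __name__ == '__main__':
    cell = StackedGRUCell(num_layers=2, input_size=8, rnn_size=16, dropout=0.1)
    x = torch.randn(4, 8)        # [batch_size, input_size]
    h = torch.zeros(2, 4, 16)    # [num_layers, batch_size, hidden_size]
    last_h, h_all = cell(x, h)
    assert last_h.size() == (4, 16) and h_all.size() == (2, 4, 16)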


================================================
FILE: model/models.py
================================================
import torch
import torch.nn as nn
from utils import to_var, pad, normal_kl_div, normal_logpdf, bag_of_words_loss, to_bow, EOS_ID
import layers
import numpy as np
import random

VariationalModels = ['VHRED', 'VHCR']

class HRED(nn.Module):
    def __init__(self, config):
        super(HRED, self).__init__()

        self.config = config
        self.encoder = layers.EncoderRNN(config.vocab_size,
                                         config.embedding_size,
                                         config.encoder_hidden_size,
                                         config.rnn,
                                         config.num_layers,
                                         config.bidirectional,
                                         config.dropout)

        context_input_size = (config.num_layers
                              * config.encoder_hidden_size
                              * self.encoder.num_directions)
        self.context_encoder = layers.ContextRNN(context_input_size,
                                                 config.context_size,
                                                 config.rnn,
                                                 config.num_layers,
                                                 config.dropout)

        self.decoder = layers.DecoderRNN(config.vocab_size,
                                         config.embedding_size,
                                         config.decoder_hidden_size,
                                         config.rnncell,
                                         config.num_layers,
                                         config.dropout,
                                         config.word_drop,
                                         config.max_unroll,
                                         config.sample,
                                         config.temperature,
                                         config.beam_size)

        self.context2decoder = layers.FeedForward(config.context_size,
                                                  config.num_layers * config.decoder_hidden_size,
                                                  num_layers=1,
                                                  activation=config.activation)

        if config.tie_embedding:
            self.decoder.embedding = self.encoder.embedding

    def forward(self, input_sentences, input_sentence_length,
                input_conversation_length, target_sentences, decode=False):
        """
        Args:
            input_sentences: (Variable, LongTensor) [num_sentences, seq_len]
            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
        Return:
            decoder_outputs: (Variable, FloatTensor)
                - train: [batch_size, seq_len, vocab_size]
                - eval: [batch_size, seq_len]
        """
        num_sentences = input_sentences.size(0)
        max_len = input_conversation_length.data.max().item()

        # encoder_outputs: [num_sentences, max_source_length, hidden_size * direction]
        # encoder_hidden: [num_layers * direction, num_sentences, hidden_size]
        encoder_outputs, encoder_hidden = self.encoder(input_sentences,
                                                       input_sentence_length)

        # encoder_hidden: [num_sentences, num_layers * direction * hidden_size]
        encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(num_sentences, -1)

        # pad and pack encoder_hidden
        start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                                        input_conversation_length[:-1])), 0)
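        # e.g. conversation lengths [3, 2, 4] give start offsets [0, 3, 5]
        # into the flattened [num_sentences, ...] tensor above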

        # encoder_hidden: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l), max_len)
                                      for s, l in zip(start.data.tolist(),
                                                      input_conversation_length.data.tolist())], 0)

        # context_outputs: [batch_size, max_len, context_size]
        context_outputs, context_last_hidden = self.context_encoder(encoder_hidden,
                                                                    input_conversation_length)

        # flatten outputs
        # context_outputs: [num_sentences, context_size]
        context_outputs = torch.cat([context_outputs[i, :l, :]
                                     for i, l in enumerate(input_conversation_length.data)])

        # project context_outputs to decoder init state
        decoder_init = self.context2decoder(context_outputs)

        # [num_layers, batch_size, hidden_size]
        decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size)

        # train: [batch_size, seq_len, vocab_size]
        # eval: [batch_size, seq_len]
        if not decode:

            decoder_outputs = self.decoder(target_sentences,
                                           init_h=decoder_init,
                                           decode=decode)
            return decoder_outputs

        else:
            # decoder_outputs = self.decoder(target_sentences,
            #                                init_h=decoder_init,
            #                                decode=decode)
            # return decoder_outputs.unsqueeze(1)
            # prediction: [batch_size, beam_size, max_unroll]
            prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)

            # Get top prediction only
            # [batch_size, max_unroll]
            # prediction = prediction[:, 0]

            # [batch_size, beam_size, max_unroll]
            return prediction

    def generate(self, context, sentence_length, n_context):
        # context: [batch_size, n_context, seq_len]
        batch_size = context.size(0)
        # n_context = context.size(1)
        samples = []

        # Run for context
        context_hidden = None
        for i in range(n_context):
            # encoder_outputs: [batch_size, seq_len, hidden_size * direction]
            # encoder_hidden: [num_layers * direction, batch_size, hidden_size]
            encoder_outputs, encoder_hidden = self.encoder(context[:, i, :],
                                                           sentence_length[:, i])

            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)
            # context_outputs: [batch_size, 1, context_hidden_size * direction]
            # context_hidden: [num_layers * direction, batch_size, context_hidden_size]
            context_outputs, context_hidden = self.context_encoder.step(encoder_hidden,
                                                                        context_hidden)

        # Run for generation
        for j in range(self.config.n_sample_step):
            # context_outputs: [batch_size, context_hidden_size * direction]
            context_outputs = context_outputs.squeeze(1)
            decoder_init = self.context2decoder(context_outputs)
            decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size)

            prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)
            # prediction: [batch_size, seq_len]
            prediction = prediction[:, 0, :]
            # length: [batch_size]
            length = [l[0] for l in length]
            length = to_var(torch.LongTensor(length))
            samples.append(prediction)

            encoder_outputs, encoder_hidden = self.encoder(prediction,
                                                           length)

            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)

            context_outputs, context_hidden = self.context_encoder.step(encoder_hidden,
                                                                        context_hidden)

        samples = torch.stack(samples, 1)
        return samples


class VHRED(nn.Module):
    def __init__(self, config):
        super(VHRED, self).__init__()

        self.config = config
        self.encoder = layers.EncoderRNN(config.vocab_size,
                                         config.embedding_size,
                                         config.encoder_hidden_size,
                                         config.rnn,
                                         config.num_layers,
                                         config.bidirectional,
                                         config.dropout)

        context_input_size = (config.num_layers
                              * config.encoder_hidden_size
                              * self.encoder.num_directions)
        self.context_encoder = layers.ContextRNN(context_input_size,
                                                 config.context_size,
                                                 config.rnn,
                                                 config.num_layers,
                                                 config.dropout)

        self.decoder = layers.DecoderRNN(config.vocab_size,
                                         config.embedding_size,
                                         config.decoder_hidden_size,
                                         config.rnncell,
                                         config.num_layers,
                                         config.dropout,
                                         config.word_drop,
                                         config.max_unroll,
                                         config.sample,
                                         config.temperature,
                                         config.beam_size)

        self.context2decoder = layers.FeedForward(config.context_size + config.z_sent_size,
                                                  config.num_layers * config.decoder_hidden_size,
                                                  num_layers=1,
                                                  activation=config.activation)

        self.softplus = nn.Softplus()
        self.prior_h = layers.FeedForward(config.context_size,
                                          config.context_size,
                                          num_layers=2,
                                          hidden_size=config.context_size,
                                          activation=config.activation)
        self.prior_mu = nn.Linear(config.context_size,
                                  config.z_sent_size)
        self.prior_var = nn.Linear(config.context_size,
                                   config.z_sent_size)

        self.posterior_h = layers.FeedForward(config.encoder_hidden_size * self.encoder.num_directions * config.num_layers + config.context_size,
                                              config.context_size,
                                              num_layers=2,
                                              hidden_size=config.context_size,
                                              activation=config.activation)
        self.posterior_mu = nn.Linear(config.context_size,
                                      config.z_sent_size)
        self.posterior_var = nn.Linear(config.context_size,
                                       config.z_sent_size)
        if config.tie_embedding:
            self.decoder.embedding = self.encoder.embedding

        if config.bow:
            self.bow_h = layers.FeedForward(config.z_sent_size,
                                            config.decoder_hidden_size,
                                            num_layers=1,
                                            hidden_size=config.decoder_hidden_size,
                                            activation=config.activation)
            self.bow_predict = nn.Linear(config.decoder_hidden_size, config.vocab_size)

    def prior(self, context_outputs):
        # Context dependent prior
        h_prior = self.prior_h(context_outputs)
        mu_prior = self.prior_mu(h_prior)
        var_prior = self.softplus(self.prior_var(h_prior))
        return mu_prior, var_prior

    def posterior(self, context_outputs, encoder_hidden):
        h_posterior = self.posterior_h(torch.cat([context_outputs, encoder_hidden], 1))
        mu_posterior = self.posterior_mu(h_posterior)
        var_posterior = self.softplus(self.posterior_var(h_posterior))
        return mu_posterior, var_posterior

    def compute_bow_loss(self, target_conversations):
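        # auxiliary bag-of-words loss: predict the target words directly
        # from z_sent, encouraging the latent variable to stay informative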
        target_bow = np.stack([to_bow(sent, self.config.vocab_size) for conv in target_conversations for sent in conv], axis=0)
        target_bow = to_var(torch.FloatTensor(target_bow))
        bow_logits = self.bow_predict(self.bow_h(self.z_sent))
        bow_loss = bag_of_words_loss(bow_logits, target_bow)
        return bow_loss

    def forward(self, sentences, sentence_length,
                input_conversation_length, target_sentences, decode=False):
        """
        Args:
            sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]
            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
        Return:
            decoder_outputs: (Variable, FloatTensor)
                - train: [batch_size, seq_len, vocab_size]
                - eval: [batch_size, seq_len]
        """
        batch_size = input_conversation_length.size(0)
        num_sentences = sentences.size(0) - batch_size
        max_len = input_conversation_length.data.max().item()

        # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size]
        # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size]
        encoder_outputs, encoder_hidden = self.encoder(sentences,
                                                       sentence_length)

        # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size]
        encoder_hidden = encoder_hidden.transpose(
            1, 0).contiguous().view(num_sentences + batch_size, -1)

        # pad and pack encoder_hidden
        start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                                        input_conversation_length[:-1] + 1)), 0)
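        # each conversation stores l + 1 utterances (inputs plus the final
        # target), e.g. lengths [3, 2] give start offsets [0, 4]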
        # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size]
        encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1)
                                      for s, l in zip(start.data.tolist(),
                                                      input_conversation_length.data.tolist())], 0)

        # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_inference = encoder_hidden[:, 1:, :]
        encoder_hidden_inference_flat = torch.cat(
            [encoder_hidden_inference[i, :l, :] for i, l in enumerate(input_conversation_length.data)])

        # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_input = encoder_hidden[:, :-1, :]

        # context_outputs: [batch_size, max_len, context_size]
        context_outputs, context_last_hidden = self.context_encoder(encoder_hidden_input,
                                                                    input_conversation_length)
        # flatten outputs
        # context_outputs: [num_sentences, context_size]
        context_outputs = torch.cat([context_outputs[i, :l, :]
                                     for i, l in enumerate(input_conversation_length.data)])

        mu_prior, var_prior = self.prior(context_outputs)
        eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))
        if not decode:
            mu_posterior, var_posterior = self.posterior(
                context_outputs, encoder_hidden_inference_flat)
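            # reparameterization trick: z = mu + sqrt(var) * eps with
            # eps ~ N(0, I), keeping the sample differentiable w.r.t.
            # mu and var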
            z_sent = mu_posterior + torch.sqrt(var_posterior) * eps
            log_q_zx = normal_logpdf(z_sent, mu_posterior, var_posterior).sum()

            log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
            # kl_div: [num_sentences]
            kl_div = normal_kl_div(mu_posterior, var_posterior,
                                   mu_prior, var_prior)
            kl_div = torch.sum(kl_div)
        else:
            z_sent = mu_prior + torch.sqrt(var_prior) * eps
            kl_div = None
            log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
            log_q_zx = None

        self.z_sent = z_sent
        latent_context = torch.cat([context_outputs, z_sent], 1)
        decoder_init = self.context2decoder(latent_context)
        decoder_init = decoder_init.view(-1,
                                         self.decoder.num_layers,
                                         self.decoder.hidden_size)
        decoder_init = decoder_init.transpose(1, 0).contiguous()

        # train: [batch_size, seq_len, vocab_size]
        # eval: [batch_size, seq_len]
        if not decode:

            decoder_outputs = self.decoder(target_sentences,
                                           init_h=decoder_init,
                                           decode=decode)

            return decoder_outputs, kl_div, log_p_z, log_q_zx

        else:
            # prediction: [batch_size, beam_size, max_unroll]
            prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)

            return prediction, kl_div, log_p_z, log_q_zx

    def generate(self, context, sentence_length, n_context):
        # context: [batch_size, n_context, seq_len]
        batch_size = context.size(0)
        # n_context = context.size(1)
        samples = []

        # Run for context
        context_hidden = None
        for i in range(n_context):
            # encoder_outputs: [batch_size, seq_len, hidden_size * direction]
            # encoder_hidden: [num_layers * direction, batch_size, hidden_size]
            encoder_outputs, encoder_hidden = self.encoder(context[:, i, :],
                                                           sentence_length[:, i])

            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)
            # context_outputs: [batch_size, 1, context_hidden_size * direction]
            # context_hidden: [num_layers * direction, batch_size, context_hidden_size]
            context_outputs, context_hidden = self.context_encoder.step(encoder_hidden,
                                                                        context_hidden)

        # Run for generation
        for j in range(self.config.n_sample_step):
            # context_outputs: [batch_size, context_hidden_size * direction]
            context_outputs = context_outputs.squeeze(1)

            mu_prior, var_prior = self.prior(context_outputs)
            eps = to_var(torch.randn((batch_size, self.config.z_sent_size)))
            z_sent = mu_prior + torch.sqrt(var_prior) * eps

            latent_context = torch.cat([context_outputs, z_sent], 1)
            decoder_init = self.context2decoder(latent_context)
            decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size)

            if self.config.sample:
                prediction = self.decoder(None, decoder_init, decode=True)
                p = prediction.data.cpu().numpy()
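                # taking np.where(...)[1] assumes each sampled row contains
                # exactly one EOS token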
                length = torch.from_numpy(np.where(p == EOS_ID)[1])
            else:
                prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)
                # prediction: [batch_size, seq_len]
                prediction = prediction[:, 0, :]
                # length: [batch_size]
                length = [l[0] for l in length]
                length = to_var(torch.LongTensor(length))

            samples.append(prediction)

            encoder_outputs, encoder_hidden = self.encoder(prediction,
                                                           length)

            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)

            context_outputs, context_hidden = self.context_encoder.step(encoder_hidden,
                                                                        context_hidden)

        samples = torch.stack(samples, 1)
        return samples


class VHCR(nn.Module):
    def __init__(self, config):
        super(VHCR, self).__init__()

        self.config = config
        self.encoder = layers.EncoderRNN(config.vocab_size,
                                         config.embedding_size,
                                         config.encoder_hidden_size,
                                         config.rnn,
                                         config.num_layers,
                                         config.bidirectional,
                                         config.dropout)

        context_input_size = (config.num_layers
                              * config.encoder_hidden_size
                              * self.encoder.num_directions + config.z_conv_size)
        self.context_encoder = layers.ContextRNN(context_input_size,
                                                 config.context_size,
                                                 config.rnn,
                                                 config.num_layers,
                                                 config.dropout)

        self.unk_sent = nn.Parameter(torch.randn(context_input_size - config.z_conv_size))

        self.z_conv2context = layers.FeedForward(config.z_conv_size,
                                                 config.num_layers * config.context_size,
                                                 num_layers=1,
                                                 activation=config.activation)

        context_input_size = (config.num_layers
                              * config.encoder_hidden_size
                              * self.encoder.num_directions)
        self.context_inference = layers.ContextRNN(context_input_size,
                                                   config.context_size,
                                                   config.rnn,
                                                   config.num_layers,
                                                   config.dropout,
                                                   bidirectional=True)

        self.decoder = layers.DecoderRNN(config.vocab_size,
                                        config.embedding_size,
                                        config.decoder_hidden_size,
                                        config.rnncell,
                                        config.num_layers,
                                        config.dropout,
                                        config.word_drop,
                                        config.max_unroll,
                                        config.sample,
                                        config.temperature,
                                        config.beam_size)

        self.context2decoder = layers.FeedForward(config.context_size + config.z_sent_size + config.z_conv_size,
                                                  config.num_layers * config.decoder_hidden_size,
                                                  num_layers=1,
                                                  activation=config.activation)

        self.softplus = nn.Softplus()

        self.conv_posterior_h = layers.FeedForward(config.num_layers * self.context_inference.num_directions * config.context_size,
                                                    config.context_size,
                                                    num_layers=2,
                                                    hidden_size=config.context_size,
                                                    activation=config.activation)
        self.conv_posterior_mu = nn.Linear(config.context_size,
                                            config.z_conv_size)
        self.conv_posterior_var = nn.Linear(config.context_size,
                                             config.z_conv_size)

        self.sent_prior_h = layers.FeedForward(config.context_size + config.z_conv_size,
                                               config.context_size,
                                               num_layers=1,
                                               hidden_size=config.z_sent_size,
                                               activation=config.activation)
        self.sent_prior_mu = nn.Linear(config.context_size,
                                       config.z_sent_size)
        self.sent_prior_var = nn.Linear(config.context_size,
                                        config.z_sent_size)

        self.sent_posterior_h = layers.FeedForward(config.z_conv_size + config.encoder_hidden_size * self.encoder.num_directions * config.num_layers + config.context_size,
                                                   config.context_size,
                                                   num_layers=2,
                                                   hidden_size=config.context_size,
                                                   activation=config.activation)
        self.sent_posterior_mu = nn.Linear(config.context_size,
                                           config.z_sent_size)
        self.sent_posterior_var = nn.Linear(config.context_size,
                                            config.z_sent_size)

        if config.tie_embedding:
            self.decoder.embedding = self.encoder.embedding

    def conv_prior(self):
        # Standard Gaussian prior
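        # the scalar mu/var broadcast against [batch_size, z_conv_size]
        # samples in forward()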
        return to_var(torch.FloatTensor([0.0])), to_var(torch.FloatTensor([1.0]))

    def conv_posterior(self, context_inference_hidden):
        h_posterior = self.conv_posterior_h(context_inference_hidden)
        mu_posterior = self.conv_posterior_mu(h_posterior)
        var_posterior = self.softplus(self.conv_posterior_var(h_posterior))
        return mu_posterior, var_posterior

    def sent_prior(self, context_outputs, z_conv):
        # Context dependent prior
        h_prior = self.sent_prior_h(torch.cat([context_outputs, z_conv], dim=1))
        mu_prior = self.sent_prior_mu(h_prior)
        var_prior = self.softplus(self.sent_prior_var(h_prior))
        return mu_prior, var_prior

    def sent_posterior(self, context_outputs, encoder_hidden, z_conv):
        h_posterior = self.sent_posterior_h(torch.cat([context_outputs, encoder_hidden, z_conv], 1))
        mu_posterior = self.sent_posterior_mu(h_posterior)
        var_posterior = self.softplus(self.sent_posterior_var(h_posterior))
        return mu_posterior, var_posterior

    def forward(self, sentences, sentence_length,
                input_conversation_length, target_sentences, decode=False):
        """
        Args:
            sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len]
            target_sentences: (Variable, LongTensor) [num_sentences, seq_len]
        Return:
            decoder_outputs: (Variable, FloatTensor)
                - train: [batch_size, seq_len, vocab_size]
                - eval: [batch_size, seq_len]
        """
        batch_size = input_conversation_length.size(0)
        num_sentences = sentences.size(0) - batch_size
        max_len = input_conversation_length.data.max().item()

        # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size]
        # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size]
        encoder_outputs, encoder_hidden = self.encoder(sentences,
                                                       sentence_length)

        # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size]
        encoder_hidden = encoder_hidden.transpose(
            1, 0).contiguous().view(num_sentences + batch_size, -1)

        # pad and pack encoder_hidden
        start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()),
                                        input_conversation_length[:-1] + 1)), 0)
        # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size]
        encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1)
                                      for s, l in zip(start.data.tolist(),
                                                      input_conversation_length.data.tolist())], 0)

        # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_inference = encoder_hidden[:, 1:, :]
        encoder_hidden_inference_flat = torch.cat(
            [encoder_hidden_inference[i, :l, :] for i, l in enumerate(input_conversation_length.data)])

        # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size]
        encoder_hidden_input = encoder_hidden[:, :-1, :]

        # Standard Gaussian prior
        conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size]))
        conv_mu_prior, conv_var_prior = self.conv_prior()

        if not decode:
            if self.config.sentence_drop > 0.0:
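                # utterance drop regularization: the same randomly chosen
                # timesteps are replaced with the learned unk_sent vector
                # for every conversation in the batch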
                indices = np.where(np.random.rand(max_len) < self.config.sentence_drop)[0]
                if len(indices) > 0:
                    encoder_hidden_input[:, indices, :] = self.unk_sent

            # context_inference_outputs: [batch_size, max_len, num_directions * context_size]
            # context_inference_hidden: [num_layers * num_directions, batch_size, hidden_size]
            context_inference_outputs, context_inference_hidden = self.context_inference(encoder_hidden,
                                                                                         input_conversation_length + 1)

            # context_inference_hidden: [batch_size, num_layers * num_directions * hidden_size]
            context_inference_hidden = context_inference_hidden.transpose(
                1, 0).contiguous().view(batch_size, -1)
            conv_mu_posterior, conv_var_posterior = self.conv_posterior(context_inference_hidden)
            z_conv = conv_mu_posterior + torch.sqrt(conv_var_posterior) * conv_eps
            log_q_zx_conv = normal_logpdf(z_conv, conv_mu_posterior, conv_var_posterior).sum()

            log_p_z_conv = normal_logpdf(z_conv, conv_mu_prior, conv_var_prior).sum()
            kl_div_conv = normal_kl_div(conv_mu_posterior, conv_var_posterior,
                                        conv_mu_prior, conv_var_prior).sum()

            context_init = self.z_conv2context(z_conv).view(
                self.config.num_layers, batch_size, self.config.context_size)

            z_conv_expand = z_conv.view(z_conv.size(0), 1, z_conv.size(
                1)).expand(z_conv.size(0), max_len, z_conv.size(1))
            context_outputs, context_last_hidden = self.context_encoder(
                torch.cat([encoder_hidden_input, z_conv_expand], 2),
                input_conversation_length,
                hidden=context_init)

            # flatten outputs
            # context_outputs: [num_sentences, context_size]
            context_outputs = torch.cat([context_outputs[i, :l, :]
                                         for i, l in enumerate(input_conversation_length.data)])

            z_conv_flat = torch.cat(
                [z_conv_expand[i, :l, :] for i, l in enumerate(input_conversation_length.data)])
            sent_mu_prior, sent_var_prior = self.sent_prior(context_outputs, z_conv_flat)
            eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))

            sent_mu_posterior, sent_var_posterior = self.sent_posterior(
                context_outputs, encoder_hidden_inference_flat, z_conv_flat)
            z_sent = sent_mu_posterior + torch.sqrt(sent_var_posterior) * eps
            log_q_zx_sent = normal_logpdf(z_sent, sent_mu_posterior, sent_var_posterior).sum()

            log_p_z_sent = normal_logpdf(z_sent, sent_mu_prior, sent_var_prior).sum()
            # kl_div: [num_sentences]
            kl_div_sent = normal_kl_div(sent_mu_posterior, sent_var_posterior,
                                        sent_mu_prior, sent_var_prior).sum()

            kl_div = kl_div_conv + kl_div_sent
            log_q_zx = log_q_zx_conv + log_q_zx_sent
            log_p_z = log_p_z_conv + log_p_z_sent
        else:
            z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps
            context_init = self.z_conv2context(z_conv).view(
                self.config.num_layers, batch_size, self.config.context_size)

            z_conv_expand = z_conv.view(z_conv.size(0), 1, z_conv.size(
                1)).expand(z_conv.size(0), max_len, z_conv.size(1))
            # context_outputs: [batch_size, max_len, context_size]
            context_outputs, context_last_hidden = self.context_encoder(
                torch.cat([encoder_hidden_input, z_conv_expand], 2),
                input_conversation_length,
                hidden=context_init)
            # flatten outputs
            # context_outputs: [num_sentences, context_size]
            context_outputs = torch.cat([context_outputs[i, :l, :]
                                         for i, l in enumerate(input_conversation_length.data)])


            z_conv_flat = torch.cat(
                [z_conv_expand[i, :l, :] for i, l in enumerate(input_conversation_length.data)])
            sent_mu_prior, sent_var_prior = self.sent_prior(context_outputs, z_conv_flat)
            eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))

            z_sent = sent_mu_prior + torch.sqrt(sent_var_prior) * eps
            kl_div = None
            log_p_z = normal_logpdf(z_sent, sent_mu_prior, sent_var_prior).sum()
            log_p_z += normal_logpdf(z_conv, conv_mu_prior, conv_var_prior).sum()
            log_q_zx = None

        # expand z_conv to all associated sentences
        z_conv = torch.cat([z.view(1, -1).expand(m.item(), self.config.z_conv_size)
                             for z, m in zip(z_conv, input_conversation_length)])

        # latent_context: [num_sentences, context_size + z_sent_size +
        # z_conv_size]
        latent_context = torch.cat([context_outputs, z_sent, z_conv], 1)
        decoder_init = self.context2decoder(latent_context)
        decoder_init = decoder_init.view(-1,
                                         self.decoder.num_layers,
                                         self.decoder.hidden_size)
        decoder_init = decoder_init.transpose(1, 0).contiguous()

        # train: [batch_size, seq_len, vocab_size]
        # eval: [batch_size, seq_len]
        if not decode:
            decoder_outputs = self.decoder(target_sentences,
                                            init_h=decoder_init,
                                            decode=decode)
            return decoder_outputs, kl_div, log_p_z, log_q_zx

        else:
            # prediction: [batch_size, beam_size, max_unroll]
            prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)
            return prediction, kl_div, log_p_z, log_q_zx

    def generate(self, context, sentence_length, n_context):
        # context: [batch_size, n_context, seq_len]
        batch_size = context.size(0)
        # n_context = context.size(1)
        samples = []

        # Run for context

        conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size]))
        # conv_mu_prior, conv_var_prior = self.conv_prior()
        # z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps

        encoder_hidden_list = []
        for i in range(n_context):
            # encoder_outputs: [batch_size, seq_len, hidden_size * direction]
            # encoder_hidden: [num_layers * direction, batch_size, hidden_size]
            encoder_outputs, encoder_hidden = self.encoder(context[:, i, :],
                                                           sentence_length[:, i])

            # encoder_hidden: [batch_size, num_layers * direction * hidden_size]
            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)
            encoder_hidden_list.append(encoder_hidden)

        encoder_hidden = torch.stack(encoder_hidden_list, 1)
        context_inference_outputs, context_inference_hidden = self.context_inference(encoder_hidden,
                                                                                     to_var(torch.LongTensor([n_context] * batch_size)))
        context_inference_hidden = context_inference_hidden.transpose(
            1, 0).contiguous().view(batch_size, -1)
        conv_mu_posterior, conv_var_posterior = self.conv_posterior(context_inference_hidden)
        z_conv = conv_mu_posterior + torch.sqrt(conv_var_posterior) * conv_eps

        context_init = self.z_conv2context(z_conv).view(
            self.config.num_layers, batch_size, self.config.context_size)

        context_hidden = context_init
        for i in range(n_context):
            # encoder_outputs: [batch_size, seq_len, hidden_size * direction]
            # encoder_hidden: [num_layers * direction, batch_size, hidden_size]
            encoder_outputs, encoder_hidden = self.encoder(context[:, i, :],
                                                           sentence_length[:, i])

            # encoder_hidden: [batch_size, num_layers * direction * hidden_size]
            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)
            # context_outputs: [batch_size, 1, context_hidden_size * direction]
            # context_hidden: [num_layers * direction, batch_size, context_hidden_size]
            context_outputs, context_hidden = self.context_encoder.step(torch.cat([encoder_hidden, z_conv], 1),
                                                                        context_hidden)

        # Run for generation
        for j in range(self.config.n_sample_step):
            # context_outputs: [batch_size, context_hidden_size * direction]
            context_outputs = context_outputs.squeeze(1)

            mu_prior, var_prior = self.sent_prior(context_outputs, z_conv)
            eps = to_var(torch.randn((batch_size, self.config.z_sent_size)))
            z_sent = mu_prior + torch.sqrt(var_prior) * eps

            latent_context = torch.cat([context_outputs, z_sent, z_conv], 1)
            decoder_init = self.context2decoder(latent_context)
            decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size)

            if self.config.sample:
                prediction = self.decoder(None, decoder_init, decode=True)
                p = prediction.data.cpu().numpy()
                length = torch.from_numpy(np.where(p == EOS_ID)[1])
            else:
                prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)
                # prediction: [batch_size, seq_len]
                prediction = prediction[:, 0, :]
                # length: [batch_size]
                length = [l[0] for l in length]
                length = to_var(torch.LongTensor(length))

            samples.append(prediction)

            encoder_outputs, encoder_hidden = self.encoder(prediction,
                                                           length)

            encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)

            context_outputs, context_hidden = self.context_encoder.step(torch.cat([encoder_hidden, z_conv], 1),
                                                                        context_hidden)

        samples = torch.stack(samples, 1)
        return samples


================================================
FILE: model/solver.py
================================================
from itertools import cycle
import numpy as np
import torch
import torch.nn as nn
import models
from layers import masked_cross_entropy
from utils import to_var, time_desc_decorator, TensorboardWriter, pad_and_pack, normal_kl_div, to_bow, bag_of_words_loss, embedding_metric
import os
from tqdm import tqdm
from math import isnan
import re
import math
import pickle
import gensim

word2vec_path = "../datasets/GoogleNews-vectors-negative300.bin"

class Solver(object):
    def __init__(self, config, train_data_loader, eval_data_loader, vocab, is_train=True, model=None):
        self.config = config
        self.epoch_i = 0
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.vocab = vocab
        self.is_train = is_train
        self.model = model

    @time_desc_decorator('Build Graph')
    def build(self, cuda=True):

        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            # orthogonal initialization for hidden-to-hidden weights
            # update gate bias for GRUs
            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initialization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)

                    # bias_hh is the concatenation of reset, update, new gate biases
                    # only set the update gate bias to 2.0
                    if 'bias_hh' in name:
                        print('\t' + name)
                        dim = param.size(0) // 3
                        param.data[dim:2 * dim].fill_(2.0)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        # Overview Parameters
        print('Model Parameters')
        for name, param in self.model.named_parameters():
            print('\t' + name + '\t', list(param.size()))

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            self.optimizer = self.config.optimizer(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate)

    def save_model(self, epoch):
        """Save parameters to checkpoint"""
        ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
        print(f'Save parameters to {ckpt_path}')
        torch.save(self.model.state_dict(), ckpt_path)

    def load_model(self, checkpoint):
        """Load parameters from checkpoint"""
        print(f'Load parameters from {checkpoint}')
        epoch = re.match(r"[0-9]*", os.path.basename(checkpoint)).group(0)
        self.epoch_i = int(epoch)
        self.model.load_state_dict(torch.load(checkpoint))

    def write_summary(self, epoch_i):
        epoch_loss = getattr(self, 'epoch_loss', None)
        if epoch_loss is not None:
            self.writer.update_loss(
                loss=epoch_loss,
                step_i=epoch_i + 1,
                name='train_loss')

        epoch_recon_loss = getattr(self, 'epoch_recon_loss', None)
        if epoch_recon_loss is not None:
            self.writer.update_loss(
                loss=epoch_recon_loss,
                step_i=epoch_i + 1,
                name='train_recon_loss')

        epoch_kl_div = getattr(self, 'epoch_kl_div', None)
        if epoch_kl_div is not None:
            self.writer.update_loss(
                loss=epoch_kl_div,
                step_i=epoch_i + 1,
                name='train_kl_div')

        kl_mult = getattr(self, 'kl_mult', None)
        if kl_mult is not None:
            self.writer.update_loss(
                loss=kl_mult,
                step_i=epoch_i + 1,
                name='kl_mult')

        epoch_bow_loss = getattr(self, 'epoch_bow_loss', None)
        if epoch_bow_loss is not None:
            self.writer.update_loss(
                loss=epoch_bow_loss,
                step_i=epoch_i + 1,
                name='bow_loss')

        validation_loss = getattr(self, 'validation_loss', None)
        if validation_loss is not None:
            self.writer.update_loss(
                loss=validation_loss,
                step_i=epoch_i + 1,
                name='validation_loss')

    @time_desc_decorator('Training Start!')
    def train(self):
        epoch_loss_history = []
        for epoch_i in range(self.epoch_i, self.config.n_epoch):
            self.epoch_i = epoch_i
            batch_loss_history = []
            self.model.train()
            n_total_words = 0
            for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.train_data_loader, ncols=80)):
                # conversations: (batch_size) list of conversations
                #   conversation: list of sentences
                #   sentence: list of tokens
                # conversation_length: list of int
                # sentence_length: (batch_size) list of conversation list of sentence_lengths

                input_conversations = [conv[:-1] for conv in conversations]
                target_conversations = [conv[1:] for conv in conversations]

                # flatten input and target conversations
                input_sentences = [sent for conv in input_conversations for sent in conv]
                target_sentences = [sent for conv in target_conversations for sent in conv]
                input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]]
                target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
                input_conversation_length = [l - 1 for l in conversation_length]

                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(torch.LongTensor(input_conversation_length))

                # reset gradient
                self.optimizer.zero_grad()

                sentence_logits = self.model(
                    input_sentences,
                    input_sentence_length,
                    input_conversation_length,
                    target_sentences,
                    decode=False)

                batch_loss, n_words = masked_cross_entropy(
                    sentence_logits,
                    target_sentences,
                    target_sentence_length)

                assert not isnan(batch_loss.item())
                batch_loss_history.append(batch_loss.item())
                n_total_words += n_words.item()

                if batch_i % self.config.print_every == 0:
                    tqdm.write(
                        f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item() / n_words.item():.3f}')

                # Back-propagation
                batch_loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip)

                # Run optimizer
                self.optimizer.step()

            epoch_loss = np.sum(batch_loss_history) / n_total_words
            epoch_loss_history.append(epoch_loss)
            self.epoch_loss = epoch_loss

            print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}'
            print(print_str)

            if epoch_i % self.config.save_every_epoch == 0:
                self.save_model(epoch_i + 1)

            print('\n<Validation>...')
            self.validation_loss = self.evaluate()

            if epoch_i % self.config.plot_every_epoch == 0:
                self.write_summary(epoch_i)

        self.save_model(self.config.n_epoch)

        return epoch_loss_history

    def generate_sentence(self, input_sentences, input_sentence_length,
                          input_conversation_length, target_sentences):
        self.model.eval()

        # [batch_size, max_seq_len, vocab_size]
        generated_sentences = self.model(
            input_sentences,
            input_sentence_length,
            input_conversation_length,
            target_sentences,
            decode=True)

        # write output to file
        with open(os.path.join(self.config.save_path, 'samples.txt'), 'a') as f:
            f.write(f'<Epoch {self.epoch_i}>\n\n')

            tqdm.write('\n<Samples>')
            for input_sent, target_sent, output_sent in zip(input_sentences, target_sentences, generated_sentences):
                input_sent = self.vocab.decode(input_sent)
                target_sent = self.vocab.decode(target_sent)
                output_sent = '\n'.join([self.vocab.decode(sent) for sent in output_sent])
                s = '\n'.join(['Input sentence: ' + input_sent,
                               'Ground truth: ' + target_sent,
                               'Generated response: ' + output_sent + '\n'])
                f.write(s + '\n')
                print(s)
            print('')

    def evaluate(self):
        self.model.eval()
        batch_loss_history = []
        n_total_words = 0
        for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            input_conversations = [conv[:-1] for conv in conversations]
            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            input_sentences = [sent for conv in input_conversations for sent in conv]
            target_sentences = [sent for conv in target_conversations for sent in conv]
            input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]]
            target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
            input_conversation_length = [l - 1 for l in conversation_length]

            with torch.no_grad():
                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))

            if batch_i == 0:
                self.generate_sentence(input_sentences,
                                       input_sentence_length,
                                       input_conversation_length,
                                       target_sentences)

            sentence_logits = self.model(
                input_sentences,
                input_sentence_length,
                input_conversation_length,
                target_sentences)

            batch_loss, n_words = masked_cross_entropy(
                sentence_logits,
                target_sentences,
                target_sentence_length)

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            n_total_words += n_words.item()

        epoch_loss = np.sum(batch_loss_history) / n_total_words

        print_str = f'Validation loss: {epoch_loss:.3f}\n'
        print(print_str)

        return epoch_loss

    def test(self):
        self.model.eval()
        batch_loss_history = []
        n_total_words = 0
        for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            input_conversations = [conv[:-1] for conv in conversations]
            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            input_sentences = [sent for conv in input_conversations for sent in conv]
            target_sentences = [sent for conv in target_conversations for sent in conv]
            input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]]
            target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
            input_conversation_length = [l - 1 for l in conversation_length]

            with torch.no_grad():
                input_sentences = to_var(torch.LongTensor(input_sentences))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                input_sentence_length = to_var(torch.LongTensor(input_sentence_length))
                target_sentence_length = to_var(torch.LongTensor(target_sentence_length))
                input_conversation_length = to_var(torch.LongTensor(input_conversation_length))

            sentence_logits = self.model(
                input_sentences,
                input_sentence_length,
                input_conversation_length,
                target_sentences)

            batch_loss, n_words = masked_cross_entropy(
                sentence_logits,
                target_sentences,
                target_sentence_length)

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            n_total_words += n_words.item()

        epoch_loss = np.sum(batch_loss_history) / n_total_words

        print(f'Number of words: {n_total_words}')
        print(f'Per-word loss (nats): {epoch_loss:.3f}')
        word_perplexity = np.exp(epoch_loss)

        print_str = f'Word perplexity: {word_perplexity:.3f}\n'
        print(print_str)

        return word_perplexity
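
    # Note on the numbers above: `epoch_loss` is the average negative
    # log-likelihood per word in nats, so np.exp(epoch_loss) is the per-word
    # perplexity. Hypothetical arithmetic (illustrative, not from a real run):
    # a loss of ~4.605 nats/word gives exp(4.605) ~= 100, i.e. perplexity ~= 100.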

    def embedding_metric(self):
        word2vec = getattr(self, 'word2vec', None)
        if word2vec is None:
            print('Loading word2vec model')
            word2vec = gensim.models.KeyedVectors.load_word2vec_format(word2vec_path, binary=True)
            self.word2vec = word2vec
        keys = word2vec.vocab
        self.model.eval()
        n_context = self.config.n_context
        n_sample_step = self.config.n_sample_step
        metric_average_history = []
        metric_extrema_history = []
        metric_greedy_history = []
        context_history = []
        sample_history = []
        n_sent = 0
        n_conv = 0
        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            conv_indices = [i for i in range(len(conversations))
                            if len(conversations[i]) >= n_context + n_sample_step]
            context = [conversations[i][:n_context] for i in conv_indices]
            ground_truth = [conversations[i][n_context:n_context + n_sample_step]
                            for i in conv_indices]
            sentence_length = [sentence_length[i][:n_context] for i in conv_indices]

            with torch.no_grad():
                context = to_var(torch.LongTensor(context))
                sentence_length = to_var(torch.LongTensor(sentence_length))

            samples = self.model.generate(context, sentence_length, n_context)

            context = context.data.cpu().numpy().tolist()
            samples = samples.data.cpu().numpy().tolist()
            context_history.append(context)
            sample_history.append(samples)

            samples = [[self.vocab.decode(sent) for sent in c] for c in samples]
            ground_truth = [[self.vocab.decode(sent) for sent in c] for c in ground_truth]

            samples = [sent for c in samples for sent in c]
            ground_truth = [sent for c in ground_truth for sent in c]

            samples = [[word2vec[s] for s in sent.split() if s in keys] for sent in samples]
            ground_truth = [[word2vec[s] for s in sent.split() if s in keys] for sent in ground_truth]

            indices = [i for i, (s, g) in enumerate(zip(samples, ground_truth))
                       if s != [] and g != []]
            samples = [samples[i] for i in indices]
            ground_truth = [ground_truth[i] for i in indices]
            n = len(samples)
            n_sent += n

            metric_average = embedding_metric(samples, ground_truth, word2vec, 'average')
            metric_extrema = embedding_metric(samples, ground_truth, word2vec, 'extrema')
            metric_greedy = embedding_metric(samples, ground_truth, word2vec, 'greedy')
            metric_average_history.append(metric_average)
            metric_extrema_history.append(metric_extrema)
            metric_greedy_history.append(metric_greedy)

        epoch_average = np.mean(np.concatenate(metric_average_history), axis=0)
        epoch_extrema = np.mean(np.concatenate(metric_extrema_history), axis=0)
        epoch_greedy = np.mean(np.concatenate(metric_greedy_history), axis=0)

        print('n_sentences:', n_sent)
        print_str = f'Metrics - Average: {epoch_average:.3f}, Extrema: {epoch_extrema:.3f}, Greedy: {epoch_greedy:.3f}'
        print(print_str)
        print('\n')

        return epoch_average, epoch_extrema, epoch_greedy


class VariationalSolver(Solver):

    def __init__(self, config, train_data_loader, eval_data_loader, vocab, is_train=True, model=None):
        self.config = config
        self.epoch_i = 0
        self.train_data_loader = train_data_loader
        self.eval_data_loader = eval_data_loader
        self.vocab = vocab
        self.is_train = is_train
        self.model = model

    @time_desc_decorator('Training Start!')
    def train(self):
        epoch_loss_history = []
        kl_mult = 0.0
        conv_kl_mult = 0.0
        for epoch_i in range(self.epoch_i, self.config.n_epoch):
            self.epoch_i = epoch_i
            batch_loss_history = []
            recon_loss_history = []
            kl_div_history = []
            kl_div_sent_history = []
            kl_div_conv_history = []
            bow_loss_history = []
            self.model.train()
            n_total_words = 0

            # self.evaluate()

            for batch_i, (conversations, conversation_length, sentence_length) \
                    in enumerate(tqdm(self.train_data_loader, ncols=80)):
                # conversations: (batch_size) list of conversations
                #   conversation: list of sentences
                #   sentence: list of tokens
                # conversation_length: list of int
                # sentence_length: (batch_size) list of conversation list of sentence_lengths

                target_conversations = [conv[1:] for conv in conversations]

                # flatten input and target conversations
                sentences = [sent for conv in conversations for sent in conv]
                input_conversation_length = [l - 1 for l in conversation_length]
                target_sentences = [sent for conv in target_conversations for sent in conv]
                target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
                sentence_length = [l for len_list in sentence_length for l in len_list]

                sentences = to_var(torch.LongTensor(sentences))
                sentence_length = to_var(torch.LongTensor(sentence_length))
                input_conversation_length = to_var(torch.LongTensor(input_conversation_length))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                target_sentence_length = to_var(torch.LongTensor(target_sentence_length))

                # reset gradient
                self.optimizer.zero_grad()

                sentence_logits, kl_div, _, _ = self.model(
                    sentences,
                    sentence_length,
                    input_conversation_length,
                    target_sentences)

                recon_loss, n_words = masked_cross_entropy(
                    sentence_logits,
                    target_sentences,
                    target_sentence_length)

                batch_loss = recon_loss + kl_mult * kl_div
                batch_loss_history.append(batch_loss.item())
                recon_loss_history.append(recon_loss.item())
                kl_div_history.append(kl_div.item())
                n_total_words += n_words.item()

                if self.config.bow:
                    bow_loss = self.model.compute_bow_loss(target_conversations)
                    batch_loss += bow_loss
                    bow_loss_history.append(bow_loss.item())

                assert not isnan(batch_loss.item())

                if batch_i % self.config.print_every == 0:
                    print_str = f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item() / n_words.item():.3f}, recon = {recon_loss.item() / n_words.item():.3f}, kl_div = {kl_div.item() / n_words.item():.3f}'
                    if self.config.bow:
                        print_str += f', bow_loss = {bow_loss.item() / n_words.item():.3f}'
                    tqdm.write(print_str)

                # Back-propagation
                batch_loss.backward()

                # Gradient cliping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip)

                # Run optimizer
                self.optimizer.step()
                kl_mult = min(kl_mult + 1.0 / self.config.kl_annealing_iter, 1.0)

            epoch_loss = np.sum(batch_loss_history) / n_total_words
            epoch_loss_history.append(epoch_loss)

            epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
            epoch_kl_div = np.sum(kl_div_history) / n_total_words

            self.kl_mult = kl_mult
            self.epoch_loss = epoch_loss
            self.epoch_recon_loss = epoch_recon_loss
            self.epoch_kl_div = epoch_kl_div

            print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
            if bow_loss_history:
                self.epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
                print_str += f', bow_loss = {self.epoch_bow_loss:.3f}'
            print(print_str)

            if epoch_i % self.config.save_every_epoch == 0:
                self.save_model(epoch_i + 1)

            print('\n<Validation>...')
            self.validation_loss = self.evaluate()

            if epoch_i % self.config.plot_every_epoch == 0:
                self.write_summary(epoch_i)

        return epoch_loss_history

    def generate_sentence(self, sentences, sentence_length,
                          input_conversation_length, input_sentences, target_sentences):
        """Generate output of decoder (single batch)"""
        self.model.eval()

        # [batch_size, max_seq_len, vocab_size]
        generated_sentences, _, _, _ = self.model(
            sentences,
            sentence_length,
            input_conversation_length,
            target_sentences,
            decode=True)

        # write output to file
        with open(os.path.join(self.config.save_path, 'samples.txt'), 'a') as f:
            f.write(f'<Epoch {self.epoch_i}>\n\n')

            tqdm.write('\n<Samples>')
            for input_sent, target_sent, output_sent in zip(input_sentences, target_sentences, generated_sentences):
                input_sent = self.vocab.decode(input_sent)
                target_sent = self.vocab.decode(target_sent)
                output_sent = '\n'.join([self.vocab.decode(sent) for sent in output_sent])
                s = '\n'.join(['Input sentence: ' + input_sent,
                               'Ground truth: ' + target_sent,
                               'Generated response: ' + output_sent + '\n'])
                f.write(s + '\n')
                print(s)
            print('')

    def evaluate(self):
        self.model.eval()
        batch_loss_history = []
        recon_loss_history = []
        kl_div_history = []
        bow_loss_history = []
        n_total_words = 0
        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            sentences = [sent for conv in conversations for sent in conv]
            input_conversation_length = [l - 1 for l in conversation_length]
            target_sentences = [sent for conv in target_conversations for sent in conv]
            target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
            sentence_length = [l for len_list in sentence_length for l in len_list]

            with torch.no_grad():
                sentences = to_var(torch.LongTensor(sentences))
                sentence_length = to_var(torch.LongTensor(sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                target_sentence_length = to_var(torch.LongTensor(target_sentence_length))

            if batch_i == 0:
                input_conversations = [conv[:-1] for conv in conversations]
                input_sentences = [sent for conv in input_conversations for sent in conv]
                with torch.no_grad():
                    input_sentences = to_var(torch.LongTensor(input_sentences))
                self.generate_sentence(sentences,
                                       sentence_length,
                                       input_conversation_length,
                                       input_sentences,
                                       target_sentences)

            sentence_logits, kl_div, _, _ = self.model(
                sentences,
                sentence_length,
                input_conversation_length,
                target_sentences)

            recon_loss, n_words = masked_cross_entropy(
                sentence_logits,
                target_sentences,
                target_sentence_length)

            batch_loss = recon_loss + kl_div
            if self.config.bow:
                bow_loss = self.model.compute_bow_loss(target_conversations)
                bow_loss_history.append(bow_loss.item())

            assert not isnan(batch_loss.item())
            batch_loss_history.append(batch_loss.item())
            recon_loss_history.append(recon_loss.item())
            kl_div_history.append(kl_div.item())
            n_total_words += n_words.item()

        epoch_loss = np.sum(batch_loss_history) / n_total_words
        epoch_recon_loss = np.sum(recon_loss_history) / n_total_words
        epoch_kl_div = np.sum(kl_div_history) / n_total_words

        print_str = f'Validation loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}'
        if bow_loss_history:
            epoch_bow_loss = np.sum(bow_loss_history) / n_total_words
            print_str += f', bow_loss = {epoch_bow_loss:.3f}'
        print(print_str)
        print('\n')

        return epoch_loss

    def importance_sample(self):
        ''' Perform importance sampling to get tighter bound
        '''
        self.model.eval()
        weight_history = []
        n_total_words = 0
        kl_div_history = []
        for batch_i, (conversations, conversation_length, sentence_length) \
                in enumerate(tqdm(self.eval_data_loader, ncols=80)):
            # conversations: (batch_size) list of conversations
            #   conversation: list of sentences
            #   sentence: list of tokens
            # conversation_length: list of int
            # sentence_length: (batch_size) list of conversation list of sentence_lengths

            target_conversations = [conv[1:] for conv in conversations]

            # flatten input and target conversations
            sentences = [sent for conv in conversations for sent in conv]
            input_conversation_length = [l - 1 for l in conversation_length]
            target_sentences = [sent for conv in target_conversations for sent in conv]
            target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
            sentence_length = [l for len_list in sentence_length for l in len_list]

            # n_words += sum([len([word for word in sent if word != PAD_ID]) for sent in target_sentences])
            with torch.no_grad():
                sentences = to_var(torch.LongTensor(sentences))
                sentence_length = to_var(torch.LongTensor(sentence_length))
                input_conversation_length = to_var(
                    torch.LongTensor(input_conversation_length))
                target_sentences = to_var(torch.LongTensor(target_sentences))
                target_sentence_length = to_var(torch.LongTensor(target_sentence_length))

            # treat whole batch as one data sample
            weights = []
            for j in range(self.config.importance_sample):
                sentence_logits, kl_div, log_p_z, log_q_zx = self.model(
                    sentences,
                    sentence_length,
                    input_conversation_length,
                    target_sentences)

                recon_loss, n_words = masked_cross_entropy(
                    sentence_logits,
                    target_sentences,
                    target_sentence_length)

                log_w = (-recon_loss.sum() + log_p_z - log_q_zx).data
                weights.append(log_w)
                if j == 0:
                    n_total_words += n_words.item()
                    kl_div_history.append(kl_div.item())

            # weights: [n_samples]
            # Stable log-mean-exp: log((1/K) * sum_j w_j)
            #   = m + log(sum_j exp(log w_j - m)) - log K,  with m ~ max_j log w_j
            weights = torch.stack(weights, 0)
            m = np.floor(weights.max().item())
            weights = np.log(torch.exp(weights - m).sum().item())
            weights = m + weights - np.log(self.config.importance_sample)
            weight_history.append(weights)

        print(f'Number of words: {n_total_words}')
        nats_per_word = -np.sum(weight_history) / n_total_words
        print(f'Per-word loss (nats): {nats_per_word:.3f}')
        word_perplexity = np.exp(nats_per_word)

        epoch_kl_div = np.sum(kl_div_history) / n_total_words

        print_str = f'Word perplexity upperbound using {self.config.importance_sample} importance samples: {word_perplexity:.3f}, kl_div: {epoch_kl_div:.3f}\n'
        print(print_str)

        return word_perplexity
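

# A minimal, self-contained sketch of the stable log-mean-exp computed in
# `VariationalSolver.importance_sample` (the log-weights below are hypothetical):
if __name__ == '__main__':
    example_log_w = np.array([-10.0, -11.0, -9.5])
    m = example_log_w.max()
    log_mean_w = m + np.log(np.exp(example_log_w - m).sum()) - np.log(len(example_log_w))
    print(f'stable log-mean-weight: {log_mean_w:.4f}')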



================================================
FILE: model/train.py
================================================
from solver import *
from data_loader import get_loader
from configs import get_config
from utils import Vocab
import os
import pickle
from models import VariationalModels

def load_pickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


if __name__ == '__main__':
    config = get_config(mode='train')
    val_config = get_config(mode='valid')
    print(config)
    with open(os.path.join(config.save_path, 'config.txt'), 'w') as f:
        print(config, file=f)

    print('Loading Vocabulary...')
    vocab = Vocab()
    vocab.load(config.word2id_path, config.id2word_path)
    print(f'Vocabulary size: {vocab.vocab_size}')

    config.vocab_size = vocab.vocab_size

    train_data_loader = get_loader(
        sentences=load_pickle(config.sentences_path),
        conversation_length=load_pickle(config.conversation_length_path),
        sentence_length=load_pickle(config.sentence_length_path),
        vocab=vocab,
        batch_size=config.batch_size)

    eval_data_loader = get_loader(
        sentences=load_pickle(val_config.sentences_path),
        conversation_length=load_pickle(val_config.conversation_length_path),
        sentence_length=load_pickle(val_config.sentence_length_path),
        vocab=vocab,
        batch_size=val_config.eval_batch_size,
        shuffle=False)

    # for testing
    # train_data_loader = eval_data_loader
    if config.model in VariationalModels:
        solver = VariationalSolver
    else:
        solver = Solver

    solver = solver(config, train_data_loader, eval_data_loader, vocab=vocab, is_train=True)

    solver.build()
    solver.train()


================================================
FILE: model/utils/__init__.py
================================================
from .convert import *
from .time_track import time_desc_decorator
from .tensorboard import TensorboardWriter
from .vocab import *
from .mask import *
from .tokenizer import *
from .probability import *
from .pad import *
from .bow import *
from .embedding_metric import *


================================================
FILE: model/utils/bow.py
================================================
import numpy as np
from collections import Counter
import torch
from torch.nn import functional as F
from .vocab import PAD_ID, EOS_ID


def to_bow(sentence, vocab_size):
    '''Convert a sentence into a bag-of-words representation
    Args
        - sentence: a list of token ids
        - vocab_size: V
    Returns
        - bow: an integer vector of size V
    '''
    bow = Counter(sentence)
    # Zero out counts of padding and EOS tokens
    bow[PAD_ID] = 0
    bow[EOS_ID] = 0

    x = np.zeros(vocab_size, dtype=np.int64)
    x[list(bow.keys())] = list(bow.values())

    return x


def bag_of_words_loss(bow_logits, target_bow, weight=None):
    ''' Calculate bag-of-words loss (cross-entropy minus target entropy,
    a KL-style objective)
    Args
        - bow_logits: [num_sentences, vocab_size]
        - target_bow: [num_sentences, vocab_size]
    '''
    log_probs = F.log_softmax(bow_logits, dim=1)
    target_distribution = target_bow / (target_bow.sum(1).view(-1, 1) + 1e-23) + 1e-23
    entropy = -(torch.log(target_distribution) * target_bow).sum()
    loss = -(log_probs * target_bow).sum() - entropy

    return loss
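

# A minimal usage sketch (token ids and vocab size are hypothetical):
if __name__ == '__main__':
    sentence = [4, 5, 5, 3, 0, 0]            # ids 0 (<pad>) and 3 (<eos>) are dropped
    bow = to_bow(sentence, vocab_size=8)
    print(bow)                               # -> [0 0 0 0 1 2 0 0]

    bow_logits = torch.randn(1, 8)           # stand-in decoder BoW logits
    target_bow = torch.from_numpy(bow).float().unsqueeze(0)
    print(bag_of_words_loss(bow_logits, target_bow).item())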


================================================
FILE: model/utils/convert.py
================================================
import torch


def to_var(x, on_cpu=False, gpu_id=None, non_blocking=False):
    """Move a tensor to GPU (if available), unless on_cpu is True"""
    if torch.cuda.is_available() and not on_cpu:
        x = x.cuda(gpu_id, non_blocking)
    return x


def to_tensor(x):
    """Variable => Tensor"""
    if torch.cuda.is_available():
        x = x.cpu()
    return x.data

def reverse_order(tensor, dim=0):
    """Reverse a Tensor along `dim` (CPU or CUDA; Tensors and Variables
    are merged as of PyTorch 0.4, so one code path handles both)"""
    idx = torch.arange(tensor.size(dim) - 1, -1, -1,
                       dtype=torch.long, device=tensor.device)
    return tensor.index_select(dim, idx)

def reverse_order_valid(tensor, length_list, dim=0):
    """
    Reverse Tensor of Variable only in given length
    Ex)
    Args:
        - tensor (Tensor or Variable)
         1   2   3   4   5   6
         6   7   8   9   0   0
        11  12  13   0   0   0
        16  17   0   0   0   0
        21  22  23  24  25  26

        - length_list (list)
        [6, 4, 3, 2, 6]
 
    Return:
        tensor (Tensor or Variable; in-place)
         6   5   4   3   2   1
         0   0   9   8   7   6
         0   0   0  13  12  11
         0   0   0   0  17  16
        26  25  24  23  22  21
    """
    for row, length in zip(tensor, length_list):
        valid_row = row[:length]
        reversed_valid_row = reverse_order(valid_row, dim=dim)
        row[:length] = reversed_valid_row
    return tensor
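

# A minimal usage sketch of reverse_order_valid (mirrors the docstring example):
if __name__ == '__main__':
    t = torch.LongTensor([[1, 2, 3, 4, 5, 6],
                          [6, 7, 8, 9, 0, 0]])
    print(reverse_order_valid(t, [6, 4]))
    # -> [[6, 5, 4, 3, 2, 1],
    #     [9, 8, 7, 6, 0, 0]]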


================================================
FILE: model/utils/embedding_metric.py
================================================
import numpy as np


def cosine_similarity(s, g):
    similarity = np.sum(s * g, axis=1) / np.sqrt(np.sum(s * s, axis=1) * np.sum(g * g, axis=1))
    return similarity


def embedding_metric(samples, ground_truth, word2vec, method='average'):

    if method == 'average':
        # s, g: [n_samples, word_dim]
        s = [np.mean(sample, axis=0) for sample in samples]
        g = [np.mean(gt, axis=0) for gt in ground_truth]
        return cosine_similarity(np.array(s), np.array(g))
    elif method == 'extrema':
        s_list = []
        g_list = []
        for sample, gt in zip(samples, ground_truth):
            # Keep, per dimension, whichever of max/min has the larger magnitude
            s_max = np.max(sample, axis=0)
            s_min = np.min(sample, axis=0)
            s_plus = np.absolute(s_min) <= s_max
            s = s_max * s_plus + s_min * np.logical_not(s_plus)
            s_list.append(s)

            g_max = np.max(gt, axis=0)
            g_min = np.min(gt, axis=0)
            g_plus = np.absolute(g_min) <= g_max
            g = g_max * g_plus + g_min * np.logical_not(g_plus)
            g_list.append(g)

        return cosine_similarity(np.array(s_list), np.array(g_list))
    elif method == 'greedy':
        sim_list = []
        for s, g in zip(samples, ground_truth):
            s = np.array(s)
            g = np.array(g).T
            sim = (np.matmul(s, g)
                   / np.sqrt(np.matmul(np.sum(s * s, axis=1, keepdims=True), np.sum(g * g, axis=0, keepdims=True))))
            sim = np.max(sim, axis=0)
            sim_list.append(np.mean(sim))

        return np.array(sim_list)
    else:
        raise NotImplementedError
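

# A minimal usage sketch with random stand-in word vectors; the word2vec
# argument is not used inside embedding_metric, so None suffices here:
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    samples = [rng.randn(4, 300), rng.randn(6, 300)]        # 2 sampled replies
    ground_truth = [rng.randn(5, 300), rng.randn(3, 300)]   # 2 references
    for method in ['average', 'extrema', 'greedy']:
        print(method, embedding_metric(samples, ground_truth, None, method))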


================================================
FILE: model/utils/mask.py
================================================
import torch
from .convert import to_var


def sequence_mask(sequence_length, max_len=None):
    """
    Args:
        sequence_length (Variable, LongTensor) [batch_size]
            - list of sequence length of each batch
        max_len (int)
    Return:
        masks (bool): [batch_size, max_len]
            - True if current sequence is valid (not padded), False otherwise

    Ex.
    sequence length: [3, 2, 1]

    seq_length_expand
    [[3, 3, 3],
     [2, 2, 2]
     [1, 1, 1]]

    seq_range_expand
    [[0, 1, 2]
     [0, 1, 2],
     [0, 1, 2]]

    masks
    [[True, True, True],
     [True, True, False],
     [True, False, False]]
    """
    if max_len is None:
        max_len = sequence_length.max()
    batch_size = sequence_length.size(0)

    # [max_len]
    seq_range = torch.arange(0, max_len).long()  # [0, 1, ... max_len-1]

    # [batch_size, max_len]
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_range_expand = to_var(seq_range_expand)

    # [batch_size, max_len]
    seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)

    # [batch_size, max_len]
    masks = seq_range_expand < seq_length_expand

    return masks
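

# A minimal usage sketch mirroring the docstring example:
if __name__ == '__main__':
    lengths = to_var(torch.LongTensor([3, 2, 1]))
    print(sequence_mask(lengths))
    # -> [[1, 1, 1],
    #     [1, 1, 0],
    #     [1, 0, 0]]  (a uint8/bool mask, depending on the PyTorch version)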


================================================
FILE: model/utils/pad.py
================================================
import torch


def pad(tensor, length):
    """Zero-pad `tensor` along dim 0 up to `length` (keeps dtype and device)"""
    if length > tensor.size(0):
        padding = tensor.new_zeros(length - tensor.size(0), *tensor.size()[1:])
        return torch.cat([tensor, padding])
    return tensor


def pad_and_pack(tensor_list):
    length_list = [t.size(0) for t in tensor_list]
    max_len = max(length_list)
    padded = [pad(t, max_len) for t in tensor_list]
    packed = torch.stack(padded, 0)
    return packed, length_list
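

# A minimal usage sketch: pack three variable-length feature tensors
# (shapes are hypothetical):
if __name__ == '__main__':
    tensors = [torch.ones(2, 4), torch.ones(3, 4), torch.ones(1, 4)]
    packed, lengths = pad_and_pack(tensors)
    print(packed.size(), lengths)   # -> torch.Size([3, 3, 4]) [2, 3, 1]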


================================================
FILE: model/utils/probability.py
================================================
import torch
import numpy as np
from .convert import to_var


def normal_logpdf(x, mean, var):
    """
    Args:
        x: (Variable, FloatTensor) [batch_size, dim]
        mean: (Variable, FloatTensor) [batch_size, dim] or [batch_size] or [1]
        var: (Variable, FloatTensor) [batch_size, dim]: positive value
    Return:
        log_p: (Variable, FloatTensor) [batch_size]
    """

    pi = to_var(torch.FloatTensor([np.pi]))
    return 0.5 * torch.sum(-torch.log(2.0 * pi) - torch.log(var) - ((x - mean).pow(2) / var), dim=1)


def normal_kl_div(mu1, var1,
                  mu2=to_var(torch.FloatTensor([0.0])),
                  var2=to_var(torch.FloatTensor([1.0]))):
    one = to_var(torch.FloatTensor([1.0]))
    return torch.sum(0.5 * (torch.log(var2) - torch.log(var1)
                            + (var1 + (mu1 - mu2).pow(2)) / var2 - one), 1)
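

# A minimal sanity check: KL(N(0, I) || N(0, I)) is 0, and the log-density of
# a standard normal at its mean is -0.5 * dim * log(2 * pi):
if __name__ == '__main__':
    mu = to_var(torch.zeros(1, 5))
    var = to_var(torch.ones(1, 5))
    print(normal_kl_div(mu, var))       # -> tensor([0.])
    print(normal_logpdf(mu, mu, var))   # -> tensor([-4.5947]) (= -2.5 * log(2*pi))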


================================================
FILE: model/utils/tensorboard.py
================================================
from tensorboardX import SummaryWriter

class TensorboardWriter(SummaryWriter):
    def __init__(self, logdir):
        """
        Extended SummaryWriter Class from tensorboard-pytorch (tensorbaordX)
        https://github.com/lanpa/tensorboard-pytorch/blob/master/tensorboardX/writer.py

        Internally calls self.file_writer
        """
        super(TensorboardWriter, self).__init__(logdir)
        self.logdir = self.file_writer.get_logdir()

    def update_parameters(self, module, step_i):
        """
        module: nn.Module
        """
        for name, param in module.named_parameters():
            self.add_histogram(name, param.clone().cpu().data.numpy(), step_i)

    def update_loss(self, loss, step_i, name='loss'):
        self.add_scalar(name, loss, step_i)

    def update_histogram(self, values, step_i, name='hist'):
        self.add_histogram(name, values, step_i)
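

# A minimal usage sketch (the log directory name below is arbitrary):
if __name__ == '__main__':
    tf_writer = TensorboardWriter('logs/example')
    tf_writer.update_loss(0.5, step_i=1, name='train_loss')
    tf_writer.close()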


================================================
FILE: model/utils/time_track.py
================================================
import time
from functools import partial


def base_time_desc_decorator(method, desc='test_description'):
    def timed(*args, **kwargs):

        # Print Description
        # print('#' * 50)
        print(desc)
        # print('#' * 50 + '\n')

        # Calculation Runtime
        start = time.time()

        # Run Method
        try:
            result = method(*args, **kwargs)
        except TypeError:
            result = method(**kwargs)

        # Print Runtime
        print('Done! It took {:.2f} secs\n'.format(time.time() - start))

        if result is not None:
            return result

    return timed


def time_desc_decorator(desc): return partial(base_time_desc_decorator, desc=desc)


@time_desc_decorator('this is description')
def time_test(arg, kwarg='this is kwarg'):
    time.sleep(3)
    print('Inside of time_test')
    print('printing arg: ', arg)
    print('printing kwarg: ',  kwarg)


@time_desc_decorator('this is second description')
def no_arg_method():
    print('this method has no argument')


if __name__ == '__main__':
    time_test('hello', kwarg=3)
    time_test(3)
    no_arg_method()


================================================
FILE: model/utils/tokenizer.py
================================================
import re


def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Every dataset is lower cased except for TREC
    """
    string = re.sub(r"[^A-Za-z0-9,!?\'\`\.]", " ", string)
    string = re.sub(r"\.{3}", " ...", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\?", " ? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()


class Tokenizer():
    def __init__(self, tokenizer='whitespace', clean_string=True):
        self.clean_string = clean_string
        tokenizer = tokenizer.lower()

        # Tokenize with whitespace
        if tokenizer == 'whitespace':
            print('Loading whitespace tokenizer')
            self.tokenize = lambda string: string.strip().split()

        elif tokenizer == 'regex':
            print('Loading regex tokenizer')
            pattern = r"[A-Z]{2,}(?![a-z])|[A-Z][a-z]+(?=[A-Z])|[\'\w\-]+"
            self.tokenize = lambda string: re.findall(pattern, string)

        elif tokenizer == 'spacy':
            print('Loading SpaCy')
            import spacy
            nlp = spacy.load('en')
            self.tokenize = lambda string: [token.text for token in nlp(string)]

        # Tokenize with punctuations other than periods
        elif tokenizer == 'nltk':
            print('Loading NLTK word tokenizer')
            from nltk import word_tokenize
            self.tokenize = word_tokenize

        else:
            raise ValueError(f'Unknown tokenizer: {tokenizer}')

    def __call__(self, string):
        if self.clean_string:
            string = clean_str(string)
        return self.tokenize(string)


if __name__ == '__main__':
    tokenizer = Tokenizer()
    print(tokenizer("Hello, how are you doin'?"))

    tokenizer = Tokenizer('spacy')
    print(tokenizer("Hello, how are you doin'?"))


================================================
FILE: model/utils/vocab.py
================================================
from collections import defaultdict
import pickle
import torch
from torch import Tensor
from torch.autograd import Variable
from nltk import FreqDist
from .convert import to_tensor, to_var

PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
SOS_TOKEN = '<sos>'
EOS_TOKEN = '<eos>'

PAD_ID, UNK_ID, SOS_ID, EOS_ID = [0, 1, 2, 3]


class Vocab(object):
    def __init__(self, tokenizer=None, max_size=None, min_freq=1):
        """Basic Vocabulary object"""

        self.vocab_size = 0
        self.freqdist = FreqDist()
        self.tokenizer = tokenizer

    def update(self, max_size=None, min_freq=1):
        """
        Initialize id2word & word2id based on self.freqdist
        max_size include 4 special tokens
        """

        # {0: '<pad>', 1: '<unk>', 2: '<sos>', 3: '<eos>'}
        self.id2word = {
            PAD_ID: PAD_TOKEN, UNK_ID: UNK_TOKEN,
            SOS_ID: SOS_TOKEN, EOS_ID: EOS_TOKEN
        }
        # {'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3}
        self.word2id = defaultdict(lambda: UNK_ID)  # Not in vocab => return UNK
        self.word2id.update({
            PAD_TOKEN: PAD_ID, UNK_TOKEN: UNK_ID,
            SOS_TOKEN: SOS_ID, EOS_TOKEN: EOS_ID
        })

        vocab_size = 4
        min_freq = max(min_freq, 1)

        # Reset frequencies of special tokens
        # [...('<eos>', 0), ('<pad>', 0), ('<sos>', 0), ('<unk>', 0)]
        freqdist = self.freqdist.copy()
        special_freqdist = {token: freqdist[token]
                            for token in [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN]}
        freqdist.subtract(special_freqdist)

        # Sort: by frequency, then alphabetically
        # Ex) freqdist = { 'a': 4,   'b': 5,   'c': 3 }
        #  =>   sorted = [('b', 5), ('a', 4), ('c', 3)]
        sorted_frequency_counter = sorted(freqdist.items(), key=lambda k_v: k_v[0])
        sorted_frequency_counter.sort(key=lambda k_v: k_v[1], reverse=True)

        for word, freq in sorted_frequency_counter:

            if freq < min_freq or vocab_size == max_size:
                break
            self.id2word[vocab_size] = word
            self.word2id[word] = vocab_size
            vocab_size += 1

        self.vocab_size = vocab_size

    def __len__(self):
        return len(self.id2word)

    def load(self, word2id_path=None, id2word_path=None):
        if word2id_path:
            with open(word2id_path, 'rb') as f:
                word2id = pickle.load(f)
            # Can't pickle lambda function
            self.word2id = defaultdict(lambda: UNK_ID)
            self.word2id.update(word2id)
            self.vocab_size = len(self.word2id)

        if id2word_path:
            with open(id2word_path, 'rb') as f:
                id2word = pickle.load(f)
            self.id2word = id2word

    def add_word(self, word):
        assert isinstance(word, str), 'Input should be str'
        self.freqdist.update([word])

    def add_sentence(self, sentence, tokenized=False):
        if not tokenized:
            sentence = self.tokenizer(sentence)
        for word in sentence:
            self.add_word(word)

    def add_dataframe(self, conversation_df, tokenized=True):
        for conversation in conversation_df:
            for sentence in conversation:
                self.add_sentence(sentence, tokenized=tokenized)

    def pickle(self, word2id_path, id2word_path):
        with open(word2id_path, 'wb') as f:
            pickle.dump(dict(self.word2id), f)

        with open(id2word_path, 'wb') as f:
            pickle.dump(self.id2word, f)

    def to_list(self, list_like):
        """Convert list-like containers to list"""
        if isinstance(list_like, list):
            return list_like

        if isinstance(list_like, Variable):
            return list(to_tensor(list_like).numpy())
        elif isinstance(list_like, Tensor):
            return list(list_like.numpy())

    def id2sent(self, id_list):
        """list of id => list of tokens (Single sentence)"""
        id_list = self.to_list(id_list)
        sentence = []
        for id in id_list:
            word = self.id2word[id]
            if word not in [EOS_TOKEN, SOS_TOKEN, PAD_TOKEN]:
                sentence.append(word)
            if word == EOS_TOKEN:
                break
        return sentence

    def sent2id(self, sentence, var=False):
        """list of tokens => list of id (Single sentence)"""
        id_list = [self.word2id[word] for word in sentence]
        if var:
            id_list = to_var(torch.LongTensor(id_list))
        return id_list

    def decode(self, id_list):
        sentence = self.id2sent(id_list)
        return ' '.join(sentence)
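

# A minimal usage sketch: build a tiny vocabulary and round-trip a sentence
# (words are hypothetical):
if __name__ == '__main__':
    vocab = Vocab()
    vocab.add_sentence(['hello', 'world', 'hello'], tokenized=True)
    vocab.update()
    ids = vocab.sent2id(['hello', 'world', 'unseen'])
    print(ids)                   # 'unseen' falls back to UNK_ID
    print(vocab.decode(ids))     # -> 'hello world <unk>'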


================================================
FILE: requirements.txt
================================================
pandas==0.20.3
numpy==1.14.0
gensim==3.1.0
spacy==1.9.0
tqdm==4.15.0
nltk==3.4.5
tensorboardX==1.1
torch==0.4


================================================
FILE: ubuntu_preprocess.py
================================================
# Load the Ubuntu dialog corpus
# Available from here:
# http://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/ubuntu_dialogs.tgz

from multiprocessing import Pool
from pathlib import Path
from collections import OrderedDict
from urllib.request import urlretrieve
import os
import argparse
import tarfile
import pickle

from tqdm import tqdm
import pandas as pd

from model.utils import Tokenizer, Vocab, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN

project_dir = Path(__file__).resolve().parent
datasets_dir = project_dir.joinpath('datasets/')
ubuntu_dir = datasets_dir.joinpath('ubuntu/')

ubuntu_meta_dir = ubuntu_dir.joinpath('meta/')
dialogs_dir = ubuntu_dir.joinpath('dialogs/')

# Tokenizer
tokenizer = Tokenizer('spacy')


def prepare_ubuntu_data():
    """Download and unpack dialogs"""

    tar_filename = 'ubuntu_dialogs.tgz'
    url = 'http://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/ubuntu_dialogs.tgz'
    tarfile_path = ubuntu_dir.joinpath(tar_filename)
    metadata_url = 'https://raw.githubusercontent.com/rkadlec/ubuntu-ranking-dataset-creator/master/src/meta/'

    if not datasets_dir.exists():
        datasets_dir.mkdir()
    if not ubuntu_dir.exists():
        ubuntu_dir.mkdir()
    if not ubuntu_meta_dir.exists():
        ubuntu_meta_dir.mkdir()

    # Prepare Dialog data
    if not dialogs_dir.joinpath("10/1.tst").exists():
        # Download Dialog tarfile
        if not tarfile_path.exists():
            print(f"Downloading {url} to {tarfile_path}")
            urlretrieve(url, tarfile_path)
            print(f"Successfully downloaded {tarfile_path}")

        # Unpack tarfile
        if not dialogs_dir.exists():
            print("Unpacking dialogs ... (This can take 5~10 mins.)")
            with tarfile.open(tarfile_path) as tar:
                tar.extractall(path=ubuntu_dir)
            print("Archive unpacked.")

    # Download metadata
    if not ubuntu_meta_dir.joinpath('trainfiles.csv').exists():
        print('Downloading metadata ... (This can take 5~10 mins.)')
        for filename in ['trainfiles.csv', 'valfiles.csv', 'testfiles.csv']:
            csv_path = ubuntu_meta_dir.joinpath(filename)
            print(f"Downloading {metadata_url+filename} to {csv_path}")
            urlretrieve(metadata_url + filename, csv_path)
            print(f"Successfully downloaded {csv_path}")

    print('Ubuntu Data prepared!')


def get_dialog_path_list(dataset='train'):
    if dataset == 'train':
        filename = 'trainfiles.csv'
    elif dataset == 'test':
        filename = 'testfiles.csv'
    elif dataset == 'valid':
        filename = 'valfiles.csv'
    with open(ubuntu_meta_dir.joinpath(filename)) as f:
        dialog_path_list = []
        for line in f:
            file, dir = line.strip().split(",")
            path = dialogs_dir.joinpath(dir, file)
            dialog_path_list.append(path)

    return dialog_path_list


def read_and_tokenize(dialog_path, min_turn=3):
    """
    Read conversation
    Args:
        dialog_path (str): path of dialog (tsv format)
    Return:
        dialogs: (list of list of str) [dialog_length, sentence_length]
        users: (list of str); [2]
    """
    with open(dialog_path, 'r', encoding='utf-8') as f:

        # Go through the dialog
        first_turn = True
        dialog = []
        users = []
        same_user_utterances = []  # list of sentences of current user
        dialog.append(same_user_utterances)

        for line in f:
            _time, speaker, _listener, sentence = line.split('\t')
            users.append(speaker)

            if first_turn:
                last_speaker = speaker
                first_turn = False

            # Speaker has changed
            if last_speaker != speaker:
                same_user_utterances = []
                dialog.append(same_user_utterances)

            same_user_utterances.append(sentence)
            last_speaker = speaker

        # All users in conversation (len: 2)
        users = list(OrderedDict.fromkeys(users))

        # 1. Concatenate consecutive sentences of single user
        # 2. Tokenize
        dialog = [tokenizer(" ".join(sentence)) for sentence in dialog]

        if len(dialog) < min_turn:
            print(f"Dialog {dialog_path} length ({len(dialog)}) < minimum required length {min_turn}")
            return []

    return dialog


def pad_sentences(conversations, max_sentence_length=30, max_conversation_length=10):

    def pad_tokens(tokens, max_sentence_length=max_sentence_length):
        n_valid_tokens = len(tokens)
        if n_valid_tokens > max_sentence_length - 1:
            tokens = tokens[:max_sentence_length - 1]
        n_pad = max_sentence_length - n_valid_tokens - 1
        tokens = tokens + [EOS_TOKEN] + [PAD_TOKEN] * n_pad
        return tokens

    def pad_conversation(conversation):
        conversation = [pad_tokens(sentence) for sentence in conversation]
        return conversation

    all_padded_sentences = []
    all_sentence_length = []

    for conversation in conversations:
        if len(conversation) > max_conversation_length:
            conversation = conversation[:max_conversation_length]
        sentence_length = [min(len(sentence) + 1, max_sentence_length) # +1 for EOS token
                           for sentence in conversation]
        all_sentence_length.append(sentence_length)

        sentences = pad_conversation(conversation)
        all_padded_sentences.append(sentences)

    # [n_conversations, n_sentence (various), max_sentence_length]
    sentences = all_padded_sentences
    # [n_conversations, n_sentence (various)]
    sentence_length = all_sentence_length
    return sentences, sentence_length
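

# Example of pad_sentences on one sentence with max_sentence_length=6:
#   ['hi', 'there'] -> ['hi', 'there', '<eos>', '<pad>', '<pad>', '<pad>']
#   recorded sentence_length = 3 (2 valid tokens + EOS)
# Sentences longer than max_sentence_length - 1 are truncated first; n_pad is
# then negative, so no padding is appended after the EOS token.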


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Maximum valid length of sentence
    # => SOS/EOS will surround sentence (EOS for source / SOS for target)
    # => maximum length of tensor = max_sentence_length + 1
    parser.add_argument('-s', '--max_sentence_length', type=int, default=30)
    parser.add_argument('-c', '--max_conversation_length', type=int, default=10)

    # Vocabulary
    parser.add_argument('--max_vocab_size', type=int, default=20000)
    parser.add_argument('--min_vocab_frequency', type=int, default=5)

    # Multiprocess
    parser.add_argument('--n_workers', type=int, default=os.cpu_count())

    args = parser.parse_args()

    max_sent_len = args.max_sentence_length
    max_conv_len = args.max_conversation_length
    max_vocab_size = args.max_vocab_size
    min_freq = args.min_vocab_frequency
    n_workers = args.n_workers

    min_turn = 3

    # Download and unpack dialogs if necessary.
    prepare_ubuntu_data()

    def to_pickle(obj, path):
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    for split_type in ['train', 'test', 'valid']:
        print(f'Processing {split_type} dataset...')
        split_data_dir = ubuntu_dir.joinpath(split_type)
        split_data_dir.mkdir(exist_ok=True)

        # List of dialogs (tsv)
        dialog_path_list = get_dialog_path_list(split_type)

        print(f'Tokenize.. (n_workers={n_workers})')
        with Pool(n_workers) as pool:
            conversations = list(tqdm(pool.imap(read_and_tokenize, dialog_path_list),
                                      total=len(dialog_path_list)))

        # Filter too short conversations
        conversations = list(filter(lambda x: len(x) >= min_turn, conversations))

        conversation_length = [min(len(conversation), max_conv_len)
                               for conversation in conversations]

        # sentences: padded sentences
        #   [n_conversations, conversation_length (various), max_sentence_length]
        # sentence_length: number of valid tokens (incl. EOS) per sentence
        #   [n_conversations, conversation_length (various)]
        sentences, sentence_length = pad_sentences(
            conversations,
            max_sentence_length=max_sent_len,
            max_conversation_length=max_conv_len)

        print('Saving preprocessed data at', split_data_dir)
        to_pickle(conversation_length, split_data_dir.joinpath('conversation_length.pkl'))
        to_pickle(sentences, split_data_dir.joinpath('sentences.pkl'))
        to_pickle(sentence_length, split_data_dir.joinpath('sentence_length.pkl'))

        if split_type == 'train':
            print('Save Vocabulary...')
            vocab = Vocab(tokenizer)
            vocab.add_dataframe(conversations)
            vocab.update(max_size=max_vocab_size, min_freq=min_freq)

            print('Vocabulary size: ', len(vocab))
            vocab.pickle(ubuntu_dir.joinpath('word2id.pkl'), ubuntu_dir.joinpath('id2word.pkl'))

        print('Done!')
SYMBOL INDEX (145 symbols across 25 files)

FILE: cornell_preprocess.py
  function prepare_cornell_data (line 21) | def prepare_cornell_data():
  function loadLines (line 46) | def loadLines(fileName,
  function loadConversations (line 72) | def loadConversations(fileName, lines,
  function train_valid_test_split_by_conversation (line 106) | def train_valid_test_split_by_conversation(conversations, split_ratio=[0...
  function tokenize_conversation (line 133) | def tokenize_conversation(lines):
  function pad_sentences (line 138) | def pad_sentences(conversations, max_sentence_length=30, max_conversatio...
  function to_pickle (line 211) | def to_pickle(obj, path):
  function _tokenize_conversation (line 221) | def _tokenize_conversation(conv):

FILE: model/configs.py
  function str2bool (line 21) | def str2bool(v):
  class Config (line 31) | class Config(object):
    method __init__ (line 32) | def __init__(self, **kwargs):
    method __str__ (line 69) | def __str__(self):
  function get_config (line 76) | def get_config(parse=True, **optional_kwargs):

FILE: model/data_loader.py
  class DialogDataset (line 8) | class DialogDataset(Dataset):
    method __init__ (line 9) | def __init__(self, sentences, conversation_length, sentence_length, vo...
    method __getitem__ (line 26) | def __getitem__(self, index):
    method __len__ (line 38) | def __len__(self):
    method sent2id (line 41) | def sent2id(self, sentences):
  function get_loader (line 47) | def get_loader(sentences, conversation_length, sentence_length, vocab, b...

FILE: model/eval.py
  function load_pickle (line 10) | def load_pickle(path):

FILE: model/eval_embed.py
  function load_pickle (line 11) | def load_pickle(path):

FILE: model/layers/beam_search.py
  class Beam (line 5) | class Beam(object):
    method __init__ (line 6) | def __init__(self, batch_size, hidden_size, vocab_size, beam_size, max...
    method update (line 34) | def update(self, score, back_pointer, token_id):  # , h):
    method backtrack (line 43) | def backtrack(self):

FILE: model/layers/decoder.py
  class BaseRNNDecoder (line 12) | class BaseRNNDecoder(nn.Module):
    method __init__ (line 13) | def __init__(self):
    method use_lstm (line 18) | def use_lstm(self):
    method init_token (line 21) | def init_token(self, batch_size, SOS_ID=SOS_ID):
    method init_h (line 26) | def init_h(self, batch_size=None, zero=True, hidden=None):
    method batch_size (line 45) | def batch_size(self, inputs=None, h=None):
    method decode (line 62) | def decode(self, out):
    method forward (line 81) | def forward(self):
    method forward_step (line 85) | def forward_step(self):
    method embed (line 89) | def embed(self, x):
    method beam_decode (line 102) | def beam_decode(self,
  class DecoderRNN (line 212) | class DecoderRNN(BaseRNNDecoder):
    method __init__ (line 213) | def __init__(self, vocab_size, embedding_size,
    method forward_step (line 239) | def forward_step(self, x, h,
    method forward (line 273) | def forward(self, inputs, init_h=None, encoder_outputs=None, input_val...

FILE: model/layers/encoder.py
  class BaseRNNEncoder (line 10) | class BaseRNNEncoder(nn.Module):
    method __init__ (line 11) | def __init__(self):
    method use_lstm (line 16) | def use_lstm(self):
    method init_h (line 22) | def init_h(self, batch_size=None, hidden=None):
    method batch_size (line 39) | def batch_size(self, inputs=None, h=None):
    method forward (line 56) | def forward(self):
  class EncoderRNN (line 60) | class EncoderRNN(BaseRNNEncoder):
    method __init__ (line 61) | def __init__(self, vocab_size, embedding_size,
    method forward (line 91) | def forward(self, inputs, input_length, hidden=None):
  class ContextRNN (line 140) | class ContextRNN(BaseRNNEncoder):
    method __init__ (line 141) | def __init__(self, input_size, context_size, rnn=nn.GRU, num_layers=1,...
    method forward (line 167) | def forward(self, encoder_hidden, conversation_length, hidden=None):
    method step (line 210) | def step(self, encoder_hidden, hidden):

FILE: model/layers/feedforward.py
  class FeedForward (line 5) | class FeedForward(nn.Module):
    method __init__ (line 6) | def __init__(self, input_size, output_size, num_layers=1, hidden_size=...
    method forward (line 19) | def forward(self, input):

FILE: model/layers/loss.py
  function masked_cross_entropy (line 8) | def masked_cross_entropy(logits, target, length, per_example=False):

FILE: model/layers/rnncells.py
  class StackedLSTMCell (line 10) | class StackedLSTMCell(nn.Module):
    method __init__ (line 12) | def __init__(self, num_layers, input_size, rnn_size, dropout):
    method forward (line 22) | def forward(self, x, h_c):
  class StackedGRUCell (line 52) | class StackedGRUCell(nn.Module):
    method __init__ (line 54) | def __init__(self, num_layers, input_size, rnn_size, dropout):
    method forward (line 64) | def forward(self, x, h):

FILE: model/models.py
  class HRED (line 10) | class HRED(nn.Module):
    method __init__ (line 11) | def __init__(self, config):
    method forward (line 52) | def forward(self, input_sentences, input_sentence_length,
    method generate (line 122) | def generate(self, context, sentence_length, n_context):
  class VHRED (line 169) | class VHRED(nn.Module):
    method __init__ (line 170) | def __init__(self, config):
    method prior (line 239) | def prior(self, context_outputs):
    method posterior (line 246) | def posterior(self, context_outputs, encoder_hidden):
    method compute_bow_loss (line 252) | def compute_bow_loss(self, target_conversations):
    method forward (line 259) | def forward(self, sentences, sentence_length,
    method generate (line 350) | def generate(self, context, sentence_length, n_context):
  class VHCR (line 409) | class VHCR(nn.Module):
    method __init__ (line 410) | def __init__(self, config):
    method conv_prior (line 500) | def conv_prior(self):
    method conv_posterior (line 504) | def conv_posterior(self, context_inference_hidden):
    method sent_prior (line 510) | def sent_prior(self, context_outputs, z_conv):
    method sent_posterior (line 517) | def sent_posterior(self, context_outputs, encoder_hidden, z_conv):
    method forward (line 523) | def forward(self, sentences, sentence_length,
    method generate (line 677) | def generate(self, context, sentence_length, n_context):

FILE: model/solver.py
  class Solver (line 18) | class Solver(object):
    method __init__ (line 19) | def __init__(self, config, train_data_loader, eval_data_loader, vocab,...
    method build (line 29) | def build(self, cuda=True):
    method save_model (line 67) | def save_model(self, epoch):
    method load_model (line 73) | def load_model(self, checkpoint):
    method write_summary (line 80) | def write_summary(self, epoch_i):
    method train (line 124) | def train(self):
    method generate_sentence (line 206) | def generate_sentence(self, input_sentences, input_sentence_length,
    method evaluate (line 234) | def evaluate(self):
    method test (line 291) | def test(self):
    method embedding_metric (line 345) | def embedding_metric(self):
  class VariationalSolver (line 420) | class VariationalSolver(Solver):
    method __init__ (line 422) | def __init__(self, config, train_data_loader, eval_data_loader, vocab,...
    method train (line 432) | def train(self):
    method generate_sentence (line 543) | def generate_sentence(self, sentences, sentence_length,
    method evaluate (line 572) | def evaluate(self):
    method importance_sample (line 650) | def importance_sample(self):

FILE: model/train.py
  function load_pickle (line 9) | def load_pickle(path):

FILE: model/utils/bow.py
  function to_bow (line 10) | def to_bow(sentence, vocab_size):
  function bag_of_words_loss (line 29) | def bag_of_words_loss(bow_logits, target_bow, weight=None):

FILE: model/utils/convert.py
  function to_var (line 5) | def to_var(x, on_cpu=False, gpu_id=None, async=False):
  function to_tensor (line 13) | def to_tensor(x):
  function reverse_order (line 19) | def reverse_order(tensor, dim=0):
  function reverse_order_valid (line 35) | def reverse_order_valid(tensor, length_list, dim=0):

FILE: model/utils/embedding_metric.py
  function cosine_similarity (line 4) | def cosine_similarity(s, g):
  function embedding_metric (line 11) | def embedding_metric(samples, ground_truth, word2vec, method='average'):

FILE: model/utils/mask.py
  function sequence_mask (line 5) | def sequence_mask(sequence_length, max_len=None):

FILE: model/utils/pad.py
  function pad (line 6) | def pad(tensor, length):
  function pad_and_pack (line 22) | def pad_and_pack(tensor_list):

FILE: model/utils/probability.py
  function normal_logpdf (line 6) | def normal_logpdf(x, mean, var):
  function normal_kl_div (line 20) | def normal_kl_div(mu1, var1,
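
Both densities in probability.py have closed forms for diagonal Gaussians: log N(x; mu, var) = -0.5 * (log(2*pi*var) + (x - mu)^2 / var), and KL(N1 || N2) = 0.5 * (log(var2/var1) + (var1 + (mu1 - mu2)^2)/var2 - 1), each summed over the feature dimension. A sketch matching those standard formulas (argument layout is an assumption; the indexed signature is truncated):

import numpy as np
import torch

def normal_logpdf(x, mean, var):
    # Elementwise Gaussian log-density, summed over the last dimension.
    return -0.5 * (torch.log(2.0 * np.pi * var)
                   + (x - mean).pow(2) / var).sum(dim=-1)

def normal_kl_div(mu1, var1, mu2, var2):
    # KL between two diagonal Gaussians, summed over the last dimension.
    return 0.5 * (torch.log(var2 / var1)
                  + (var1 + (mu1 - mu2).pow(2)) / var2
                  - 1.0).sum(dim=-1)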

FILE: model/utils/tensorboard.py
  class TensorboardWriter (line 3) | class TensorboardWriter(SummaryWriter):
    method __init__ (line 4) | def __init__(self, logdir):
    method update_parameters (line 14) | def update_parameters(self, module, step_i):
    method update_loss (line 21) | def update_loss(self, loss, step_i, name='loss'):
    method update_histogram (line 24) | def update_histogram(self, values, step_i, name='hist'):

FILE: model/utils/time_track.py
  function base_time_desc_decorator (line 5) | def base_time_desc_decorator(method, desc='test_description'):
  function time_desc_decorator (line 31) | def time_desc_decorator(desc): return partial(base_time_desc_decorator, ...
  function time_test (line 35) | def time_test(arg, kwarg='this is kwarg'):
  function no_arg_method (line 43) | def no_arg_method():
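
time_desc_decorator wraps long-running methods (training, evaluation) with a description and a wall-clock report; time_test and no_arg_method look like its self-tests. The indexed signatures suggest roughly this shape (body details are a guess):

import time
from functools import partial

def base_time_desc_decorator(method, desc='test_description'):
    def timed(*args, **kwargs):
        print(desc)
        start = time.time()
        result = method(*args, **kwargs)
        print('%.1f seconds taken' % (time.time() - start))
        return result
    return timed

def time_desc_decorator(desc):
    # Usage: @time_desc_decorator('Training Start!') above a method.
    return partial(base_time_desc_decorator, desc=desc)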

FILE: model/utils/tokenizer.py
  function clean_str (line 4) | def clean_str(string):
  class Tokenizer (line 24) | class Tokenizer():
    method __init__ (line 25) | def __init__(self, tokenizer='whitespace', clean_string=True):
    method __call__ (line 53) | def __call__(self, string):
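
clean_str's docstring in the preview ("Tokenization/string cleaning for all datasets except for SST") is the one attached to Yoon Kim's widely copied text-CNN preprocessing snippet, so the body is very likely the standard regex pipeline, abridged here:

import re

def clean_str(string):
    # Keep alphanumerics and basic punctuation, split common contractions,
    # space out punctuation, collapse whitespace, lowercase.
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"([,!?()])", r" \1 ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()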

FILE: model/utils/vocab.py
  class Vocab (line 17) | class Vocab(object):
    method __init__ (line 18) | def __init__(self, tokenizer=None, max_size=None, min_freq=1):
    method update (line 25) | def update(self, max_size=None, min_freq=1):
    method __len__ (line 73) | def __len__(self):
    method load (line 76) | def load(self, word2id_path=None, id2word_path=None):
    method add_word (line 90) | def add_word(self, word):
    method add_sentence (line 94) | def add_sentence(self, sentence, tokenized=False):
    method add_dataframe (line 100) | def add_dataframe(self, conversation_df, tokenized=True):
    method pickle (line 105) | def pickle(self, word2id_path, id2word_path):
    method to_list (line 112) | def to_list(self, list_like):
    method id2sent (line 122) | def id2sent(self, id_list):
    method sent2id (line 134) | def sent2id(self, sentence, var=False):
    method decode (line 141) | def decode(self, id_list):
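
Vocab's round trip is the conventional one: sent2id maps tokens to ids with an UNK fallback, and id2sent maps ids back, stopping at EOS. A sketch assuming the usual special-token ids (the repository defines the real EOS_ID/PAD_ID constants in its utils package; the values below are assumptions):

from collections import defaultdict

PAD_ID, UNK_ID, EOS_ID = 0, 1, 3  # assumed values

class VocabSketch(object):
    def __init__(self):
        self.word2id = defaultdict(lambda: UNK_ID)
        self.id2word = {}

    def sent2id(self, tokens):
        # Unknown tokens fall back to UNK_ID via the defaultdict.
        return [self.word2id[w] for w in tokens]

    def id2sent(self, id_list):
        # Stop decoding at EOS and skip padding.
        words = []
        for i in id_list:
            if i == EOS_ID:
                break
            if i != PAD_ID:
                words.append(self.id2word.get(i, '<unk>'))
        return ' '.join(words)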

FILE: ubuntu_preprocess.py
  function prepare_ubuntu_data (line 30) | def prepare_ubuntu_data():
  function get_dialog_path_list (line 72) | def get_dialog_path_list(dataset='train'):
  function read_and_tokenize (line 89) | def read_and_tokenize(dialog_path, min_turn=3):
  function pad_sentences (line 137) | def pad_sentences(conversations, max_sentence_length=30, max_conversatio...
  function to_pickle (line 200) | def to_pickle(obj, path):
  function _tokenize_conversation (line 213) | def _tokenize_conversation(dialog_path):
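
pad_sentences (here and in cornell_preprocess.py) clips conversations and utterances to fixed budgets before pickling. A sketch of the token-level step, assuming '<eos>'/'<pad>' markers; the conversation-length default below is illustrative, since the original signature is truncated in the index above:

def pad_sentences(conversations, max_sentence_length=30, max_conversation_length=10):
    # Clip each conversation to max_conversation_length turns and each turn
    # to max_sentence_length tokens, terminating and right-padding each turn.
    def pad_tokens(tokens):
        tokens = tokens[:max_sentence_length - 1] + ['<eos>']
        return tokens + ['<pad>'] * (max_sentence_length - len(tokens))
    return [[pad_tokens(sent) for sent in conv[:max_conversation_length]]
            for conv in conversations]
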
Condensed preview — 32 files, each showing path, character count, and a content snippet (full structured content: 162K chars).
[
  {
    "path": ".gitignore",
    "chars": 1241,
    "preview": "/etc/\ndatasets/\n/cornell_movie_dialogue/\n*.orig\n*.lprof\n\n# Remote edit\n*.ftpconfig\n\n# Byte-compiled / optimized / DLL fi"
  },
  {
    "path": "LICENSE",
    "chars": 1085,
    "preview": "MIT License\n\nCopyright (c) 2018 Center for SuperIntelligence\n\nPermission is hereby granted, free of charge, to any perso"
  },
  {
    "path": "Readme.md",
    "chars": 3498,
    "preview": "# Variational Hierarchical Conversation RNN (VHCR)\n[PyTorch 0.4](https://github.com/pytorch/pytorch) Implementation of ["
  },
  {
    "path": "cornell_preprocess.py",
    "chars": 8421,
    "preview": "# Preprocess cornell movie dialogs dataset\n\nfrom multiprocessing import Pool\nimport argparse\nimport pickle\nimport random"
  },
  {
    "path": "model/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "model/configs.py",
    "chars": 6285,
    "preview": "import os\nimport argparse\nfrom datetime import datetime\nfrom collections import defaultdict\nfrom pathlib import Path\nimp"
  },
  {
    "path": "model/data_loader.py",
    "chars": 2861,
    "preview": "import random\nfrom collections import defaultdict\nfrom torch.utils.data import Dataset, DataLoader\nfrom utils import PAD"
  },
  {
    "path": "model/eval.py",
    "chars": 1176,
    "preview": "from solver import Solver, VariationalSolver\nfrom data_loader import get_loader\nfrom configs import get_config\nfrom util"
  },
  {
    "path": "model/eval_embed.py",
    "chars": 1156,
    "preview": "from solver import Solver, VariationalSolver\nfrom data_loader import get_loader\nfrom configs import get_config\nfrom util"
  },
  {
    "path": "model/layers/__init__.py",
    "chars": 147,
    "preview": "from .encoder import *\nfrom .decoder import *\nfrom .rnncells import StackedLSTMCell, StackedGRUCell\nfrom .loss import *\n"
  },
  {
    "path": "model/layers/beam_search.py",
    "chars": 5997,
    "preview": "import torch\nfrom utils import EOS_ID\n\n\nclass Beam(object):\n    def __init__(self, batch_size, hidden_size, vocab_size, "
  },
  {
    "path": "model/layers/decoder.py",
    "chars": 12185,
    "preview": "import random\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .rnncells import StackedLSTMCe"
  },
  {
    "path": "model/layers/encoder.py",
    "chars": 8195,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.utils.rnn import pad_packed_sequence, p"
  },
  {
    "path": "model/layers/feedforward.py",
    "chars": 899,
    "preview": "import torch\nimport torch.nn as nn\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, input_size, output_size, num_l"
  },
  {
    "path": "model/layers/loss.py",
    "chars": 1710,
    "preview": "import torch\nfrom torch.nn import functional as F\nimport torch.nn as nn\nfrom utils import to_var, sequence_mask\n\n\n# http"
  },
  {
    "path": "model/layers/rnncells.py",
    "chars": 2652,
    "preview": "# Modified from OpenNMT.py, Z-forcing\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional "
  },
  {
    "path": "model/models.py",
    "chars": 39562,
    "preview": "import torch\nimport torch.nn as nn\nfrom utils import to_var, pad, normal_kl_div, normal_logpdf, bag_of_words_loss, to_bo"
  },
  {
    "path": "model/solver.py",
    "chars": 31651,
    "preview": "from itertools import cycle\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport models\nfrom layers import maske"
  },
  {
    "path": "model/train.py",
    "chars": 1620,
    "preview": "from solver import *\nfrom data_loader import get_loader\nfrom configs import get_config\nfrom utils import Vocab\nimport os"
  },
  {
    "path": "model/utils/__init__.py",
    "chars": 273,
    "preview": "from .convert import *\nfrom .time_track import time_desc_decorator\nfrom .tensorboard import TensorboardWriter\nfrom .voca"
  },
  {
    "path": "model/utils/bow.py",
    "chars": 1102,
    "preview": "import numpy as np\nfrom collections import Counter\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport tor"
  },
  {
    "path": "model/utils/convert.py",
    "chars": 1922,
    "preview": "import torch\nfrom torch.autograd import Variable\n\n\ndef to_var(x, on_cpu=False, gpu_id=None, async=False):\n    \"\"\"Tensor "
  },
  {
    "path": "model/utils/embedding_metric.py",
    "chars": 1774,
    "preview": "import numpy as np\n\n\ndef cosine_similarity(s, g):\n    similarity = np.sum(s * g, axis=1) / np.sqrt((np.sum(s * s, axis=1"
  },
  {
    "path": "model/utils/mask.py",
    "chars": 1207,
    "preview": "import torch\nfrom .convert import to_var\n\n\ndef sequence_mask(sequence_length, max_len=None):\n    \"\"\"\n    Args:\n        s"
  },
  {
    "path": "model/utils/pad.py",
    "chars": 813,
    "preview": "import torch\nfrom torch.autograd import Variable\nfrom .convert import to_var\n\n\ndef pad(tensor, length):\n    if isinstanc"
  },
  {
    "path": "model/utils/probability.py",
    "chars": 860,
    "preview": "import torch\nimport numpy as np\nfrom .convert import to_var\n\n\ndef normal_logpdf(x, mean, var):\n    \"\"\"\n    Args:\n       "
  },
  {
    "path": "model/utils/tensorboard.py",
    "chars": 895,
    "preview": "from tensorboardX import SummaryWriter\n\nclass TensorboardWriter(SummaryWriter):\n    def __init__(self, logdir):\n        "
  },
  {
    "path": "model/utils/time_track.py",
    "chars": 1132,
    "preview": "import time\nfrom functools import partial\n\n\ndef base_time_desc_decorator(method, desc='test_description'):\n    def timed"
  },
  {
    "path": "model/utils/tokenizer.py",
    "chars": 2115,
    "preview": "import re\n\n\ndef clean_str(string):\n    \"\"\"\n    Tokenization/string cleaning for all datasets except for SST.\n    Every d"
  },
  {
    "path": "model/utils/vocab.py",
    "chars": 4795,
    "preview": "from collections import defaultdict\nimport pickle\nimport torch\nfrom torch import Tensor\nfrom torch.autograd import Varia"
  },
  {
    "path": "requirements.txt",
    "chars": 110,
    "preview": "pandas==0.20.3\nnumpy==1.14.0\ngensim==3.1.0\nspacy==1.9.0\ntqdm==4.15.0\nnltk==3.4.5\ntensorboardX==1.1\ntorch==0.4\n"
  },
  {
    "path": "ubuntu_preprocess.py",
    "chars": 8775,
    "preview": "# Load the Ubuntu dialog corpus\n# Available from here:\n# http://dataset.cs.mcgill.ca/ubuntu-corpus-1.0/ubuntu_dialogs.tg"
  }
]

About this extraction

This page contains the full source code of the ctr4si/A-Hierarchical-Latent-Structure-for-Variational-Conversation-Modeling GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 32 files (152.5 KB), approximately 34.5k tokens, and a symbol index of 145 extracted functions, classes, methods, constants, and types. Use it with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
