Repository: seanie12/Info-HCVAE
Branch: master
Commit: fe6fba1d3686
Files: 31
Total size: 272.6 KB

Directory structure:
Info-HCVAE/

├── .gitignore
├── LICENSE
├── README.md
├── qa-eval/
│   ├── distributed_run.py
│   ├── main.py
│   ├── squad_utils.py
│   ├── trainer.py
│   └── utils.py
└── vae/
    ├── eval.py
    ├── generate_qa.py
    ├── main.py
    ├── models.py
    ├── qgevalcap/
    │   ├── .gitignore
    │   ├── README.md
    │   ├── bleu/
    │   │   ├── .gitignore
    │   │   ├── LICENSE
    │   │   ├── bleu.py
    │   │   └── bleu_scorer.py
    │   ├── cider/
    │   │   ├── __init__.py
    │   │   ├── cider.py
    │   │   └── cider_scorer.py
    │   ├── eval.py
    │   ├── meteor/
    │   │   ├── __init__.py
    │   │   ├── meteor-1.5.jar
    │   │   └── meteor.py
    │   └── rouge/
    │       ├── __init__.py
    │       └── rouge.py
    ├── squad_utils.py
    ├── trainer.py
    ├── translate.py
    └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


./data

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# Generating Diverse and Consistent QA pairs from Contexts with Information-Maximizing Hierarchical Conditional VAEs
This is the **PyTorch implementation** of the paper Generating Diverse and Consistent QA pairs from Contexts with
Information-Maximizing Hierarchical Conditional VAEs (**ACL 2020**, **long paper**):
[[Paper]](https://www.aclweb.org/anthology/2020.acl-main.20/) [[Slide]](https://docs.google.com/presentation/d/1oWPYnYxY1Ne2-tHDbY5W1l34lkdyeS9a/edit?usp=sharing&ouid=100985927142420891504&rtpof=true&sd=true) [[Video]](https://slideslive.com/38928851/generating-diverse-and-consistent-qa-pairs-from-contexts-with-informationmaximizing-hierarchical-conditional-vaes).



## Abstract
<img align="middle" width="800" src="https://github.com/seanie12/Info-HCVAE/blob/master/images/concept.png">
One of the most crucial challenges in question answering (QA) is the scarcity of labeled data, since it is costly to obtain question-answer (QA) pairs for a target text domain with human annotation. An alternative approach to tackle the problem is to use automatically generated QA pairs from either the problem context or from large amounts of unstructured text (e.g. Wikipedia). In this work, we propose a hierarchical conditional variational autoencoder (HCVAE) for generating QA pairs given unstructured texts as contexts, while maximizing the mutual information between generated QA pairs to ensure their consistency. We validate our Information Maximizing Hierarchical Conditional Variational AutoEncoder (Info-HCVAE) on several benchmark datasets by evaluating the performance of the QA model (BERT-base) using only the generated QA pairs (QA-based evaluation) or by using both the generated and human-labeled pairs (semi-supervised learning) for training, against state-of-the-art baseline models. The results show that our model obtains impressive performance gains over all baselines on both tasks, using only a fraction of the data for training.

__Contributions of this work__
- We propose a novel hierarchical variational framework for generating diverse QA pairs from a single context, which is, to our knowledge, the first probabilistic generative model for question-answer pair generation (QAG).
- We propose an InfoMax regularizer which effectively enforces consistency between the generated QA pairs by maximizing their mutual information. This is a novel approach to resolving consistency between QA pairs for QAG.
- We evaluate our framework on several benchmark datasets by either training a new model entirely on generated QA pairs (QA-based evaluation) or using both ground-truth and generated QA pairs (semi-supervised QA). Our model achieves impressive performance on both tasks, largely outperforming existing QAG baselines.


## Dependencies
This code is written in Python. Dependencies include
* python >= 3.6
* pytorch >= 1.4
* json-lines
* tqdm
* [pytorch_scatter](https://github.com/rusty1s/pytorch_scatter)
* [transformers](https://github.com/huggingface/transformers)
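
These can typically be installed from PyPI in one step (package names below are the PyPI names; note that torch-scatter must be built against your installed PyTorch/CUDA version, so pin exact versions as your environment requires):

```bash
pip install "torch>=1.4" tqdm json-lines transformers torch-scatter
```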


## Download SQuAD 
Download the data from [here](https://drive.google.com/file/d/1FKqjEQ1AEnKI5xfVA8zFTwcLJgQ8pMoq/view?usp=sharing). It contains the SQuAD training file (data/squad/train-v1.1.json) and our own dev/test split (data/squad/my_dev.json, data/squad/my_test.json). We preprocess them and convert them to examples.pkl and features.pkl; those pickle files are in the data/pickle-file folder. If you want to download the original data instead, run the following commands:

```bash
mkdir squad
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json -O ./squad/train-v1.1.json
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json -O ./squad/dev-v1.1.json
```

## Train Info-HCVAE
Train Info-HCVAE with the following command. The checkpoint will be saved at ./save/vae-checkpoint.
```bash
cd vae
python main.py
```
## Generate QA pairs 
Generate QA pairs from unlabeled paragraphs. If you generate QA pairs from SQuAD, use the --squad option.
```bash
cd vae
python translate.py --data_file "DATA DIRECTORY for paragraph" --checkpoint "directory for Info-HCVAE model" --output_file "output file directory" --k "the number of QA pairs to sample for each paragraph" --ratio "the percentage of context to use"
```
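
For example, to sample 10 QA pairs per paragraph from the SQuAD training set (the data path and checkpoint filename below are illustrative, not taken from the repo):

```bash
python translate.py --squad --data_file ../data/squad/train-v1.1.json \
    --checkpoint ../save/vae-checkpoint/best_model.pt \
    --output_file ./synthetic_qa.json --k 10 --ratio 1.0
```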

## QA-based Evaluation (QAE)
It requires **3 GTX 1080 Ti GPUs (11GB memory each)** to reproduce the results. Download the data from [here](https://drive.google.com/file/d/1FKqjEQ1AEnKI5xfVA8zFTwcLJgQ8pMoq/view?usp=sharing) and place it under the root directory. Uncompress it; the "data" folder contains all the files required for QAE and semi-supervised learning.
```bash
cd qa-eval
python main.py --devices 0_1_2 --pretrain_file $PATH_TO_qaeval --unlabel_ratio 0.1 --lazy_loader --batch_size 24
```

## Semi-Supervised Learning for SQuAD
It requires **4 GTX 1080 Ti GPUs (11GB memory each)**. As with QAE, download the data from [here](https://drive.google.com/file/d/1FKqjEQ1AEnKI5xfVA8zFTwcLJgQ8pMoq/view?usp=sharing) and place it under the root directory.
```bash
cd qa-eval
python main.py --devices 0_1_2_3 --pretrain_file $PATH_TO_semieval --unlabel_ratio 1.0 --lazy_loader --batch_size 32
```

## Synthetic QA pairs


Download the data from [here](https://drive.google.com/file/d/1FKqjEQ1AEnKI5xfVA8zFTwcLJgQ8pMoq/view?usp=sharing) and uncompress it under the root directory. The folder data/harv_synthetic_data_qae contains QA pairs generated from the Harvesting QA dataset without any filtering. The folder data/harv_synthetic_data_semi contains the same generated QA pairs with postprocessing: a generated answer is replaced by the prediction of a pretrained BERT QA model whenever its F1 score falls below a threshold.
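
The sketch below illustrates this replacement rule; the helper names and the threshold value are hypothetical, since the actual postprocessing script is not part of this repository:

```python
import collections


def f1_score(prediction, ground_truth):
    # Standard SQuAD-style token-overlap F1 between two answer strings.
    pred_tokens = prediction.lower().split()
    gt_tokens = ground_truth.lower().split()
    common = collections.Counter(pred_tokens) & collections.Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)


def replace_low_f1_answer(generated, qa_model_prediction, threshold=0.4):
    # Hypothetical threshold: fall back to the BERT QA model's prediction
    # when it disagrees too strongly with the generated answer.
    if f1_score(generated, qa_model_prediction) < threshold:
        return qa_model_prediction
    return generated
```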


## Reference
To cite the code/data/paper, please use this BibTeX entry:
```bibtex
@inproceedings{lee2020generating,
  title={Generating Diverse and Consistent QA pairs from Contexts with Information-Maximizing Hierarchical Conditional VAEs},
  author={Lee, Dong Bok and Lee, Seanie and Jeong, Woo Tae and Kim, Donghwan and Hwang, Sung Ju},
  booktitle={Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
  year={2020}
}
```


================================================
FILE: qa-eval/distributed_run.py
================================================
import torch
import torch.multiprocessing as mp
from trainer import Trainer


def distributed_main(args):
    ngpus_per_node = len(args.devices)
    assert ngpus_per_node <= torch.cuda.device_count(), "GPU device num exceeds max capacity."

    # Since we have ngpus_per_node processes per node, the total world_size
    # needs to be adjusted accordingly
    args.world_size = ngpus_per_node * args.world_size
    # Use torch.multiprocessing.spawn to launch distributed processes: the
    # main_worker process function

    mp.spawn(worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))


def worker(gpu, ngpus_per_node, args):
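    # mp.spawn passes the process index (0..nprocs-1) as the first positional
    # argument; here it doubles as the local GPU id for this worker.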
    trainer = Trainer(args)

    trainer.make_model_env(gpu, ngpus_per_node)
    # model.make_run_env()
    trainer.train()


================================================
FILE: qa-eval/main.py
================================================
import argparse
import linecache
import os
import pickle
import subprocess
import time

import torch
from torch.utils.data import Dataset, TensorDataset

from distributed_run import distributed_main
from trainer import Trainer


class HarvestingQADataset(Dataset):
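    """Lazily loads pre-tokenized QA features from a tab-separated text file.

    Each line is expected to hold five tab-separated fields: input_ids,
    input_mask, segment_ids, start_position and end_position, where the first
    three are space-separated integer lists (see __getitem__). Only the first
    int(num_lines * ratio) lines are exposed through __len__/__getitem__.
    """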
    def __init__(self, filename, ratio):
        self.filename = filename
        self.total_size = int(int(subprocess.check_output("wc -l " + filename, shell=True).split()[0]) * ratio)

    def __getitem__(self, idx):
        line = linecache.getline(self.filename, idx + 1)
        str_loaded = line.split("\t")

        input_ids = str_loaded[0].split()
        input_mask = str_loaded[1].split()
        segment_ids = str_loaded[2].split()
        start_position = str_loaded[3]
        end_position = str_loaded[4]

        input_ids = torch.tensor([int(idx) for idx in input_ids], dtype=torch.long)
        input_mask = torch.tensor([int(idx) for idx in input_mask], dtype=torch.long)
        segment_ids = torch.tensor([int(idx) for idx in segment_ids], dtype=torch.long)
        start_position = torch.tensor([int(start_position)], dtype=torch.long)
        end_position = torch.tensor([int(end_position)], dtype=torch.long)

        return input_ids, input_mask, segment_ids, start_position, end_position

    def __len__(self):
        return self.total_size


def main(args):

    args.workers = int(args.workers)

    args.dev_features_file = "../data/pickle-file/dev_features.pkl"
    args.dev_examples_file = "../data/pickle-file/dev_examples.pkl"
    args.dev_json_file = "../data/squad/my_dev.json"
    args.test_features_file = "../data/pickle-file/test_features.pkl"
    args.test_examples_file = "../data/pickle-file/test_examples.pkl"
    args.test_json_file = "../data/squad/my_test.json"

    args.distributed = True

    if args.debug:
        args.pretrain_epochs = 1

    if args.unlabel_ratio > 1.0:
        args.unlabel_ratio = 1.0

    args.devices = [int(gpu) for gpu in args.devices.split('_')]
    args.use_cuda = args.use_cuda and torch.cuda.is_available()

    if args.lazy_loader:
        args.pretrain_dataset = HarvestingQADataset(args.pretrain_file, args.unlabel_ratio)
    else:
        with open(args.pretrain_file, "rb") as f:
            features = pickle.load(f)
        features = features[:int(len(features) * args.unlabel_ratio)]
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
        all_seg_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
        all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
        args.pretrain_dataset = TensorDataset(all_input_ids, all_input_mask, all_seg_ids, all_start_positions, all_end_positions)

    with open(args.dev_examples_file, "rb") as f:
        args.dev_examples = pickle.load(f)
    with open(args.dev_features_file, "rb") as f:
        features = pickle.load(f)
    args.dev_features = features
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_seg_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    args.dev_dataset = TensorDataset(all_input_ids, all_input_mask, all_seg_ids)

    with open(args.test_examples_file, "rb") as f:
        args.test_examples = pickle.load(f)
    with open(args.test_features_file, "rb") as f:
        features = pickle.load(f)
    args.test_features = features
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_seg_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    args.test_dataset = TensorDataset(all_input_ids, all_input_mask, all_seg_ids)

    distributed_main(args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true", help="debugging mode")

    # preprocess option
    parser.add_argument("--max_seq_length", default=384, type=int, help="max sequence length")
    parser.add_argument("--max_query_length", default=64, type=int, help="max query length")
    parser.add_argument("--doc_stride", default=128, type=int)
    parser.add_argument("--do_lower_case", default=True, help="do lower case on text")

    # training option
    parser.add_argument("--bert_model", default="bert-base-uncased", type=str)
    parser.add_argument("--pretrain_epochs", default=2, type=int, help="number of epochs")
    parser.add_argument("--batch_size", default=24, type=int, help="batch size")
    parser.add_argument("--pretrain_lr", default=5e-5, type=float)
    parser.add_argument("--max_grad_norm", default=5.0, type=float)
    parser.add_argument("--warmup_steps", default=0, type=int)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float)
    parser.add_argument("--weight_decay", default=0.0, type=float)
    parser.add_argument("--unlabel_ratio", default=0.1, type=float)

    # directory option
    parser.add_argument("--lazy_loader", action="store_true", help="lazy loader")
    parser.add_argument("--pretrain_file",
    default="../data/synthetic_data/vae98/harv1.0/0.2_replaced_1.0_harv_features.txt",
    type=str, help="path of training data file")
    # gpu option
    parser.add_argument("--use_cuda", default=True, help="use cuda or not")
    parser.add_argument("--devices", type=str, default='0_1_2_3', help="gpu device ids to use")
    parser.add_argument("--workers", default=4, help="Number of processes(workers) per node." "It should be equal to the number of gpu devices to use in one node")
    parser.add_argument("--world_size", default=1, help="Number of total workers. Initial value should be set to the number of nodes." "Final value will be Num.nodes * Num.devices")
    parser.add_argument("--rank", default=0, help="The priority rank of current node.")
    parser.add_argument("--dist_backend", default="nccl", help="Backend communication method. NCCL is used for DistributedDataParallel")
    parser.add_argument("--dist_url", default="tcp://127.0.0.1:9990", help="DistributedDataParallel server")
    parser.add_argument("--multiprocessing_distributed", default=True, help="Use multiprocess distribution or not")
    parser.add_argument("--random_seed", default=2019, help="random state (seed)")
    args = parser.parse_args()

    main(args)


================================================
FILE: qa-eval/squad_utils.py
================================================
import collections
import gzip
import json
import math
import re
import string
import sys
from copy import deepcopy

import json_lines
import numpy as np
from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
from tqdm import tqdm


class SquadExample(object):
    """
       A single training/test example for the Squad dataset.
       For examples without an answer, the start and end position are -1.
       """
    def __init__(self,
                 qas_id,
                 question_text,
                 doc_tokens,
                 orig_answer_text=None,
                 start_position=None,
                 end_position=None,
                 is_impossible=None):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        s = ""
        s += "qas_id: %s" % self.qas_id
        s += ", question_text: %s" % self.question_text
        s += ", doc_tokens: [%s]" % " ".join(self.doc_tokens)
        if self.start_position:
            s += ", start_position: %d" % self.start_position
        if self.end_position:
            s += ", end_position: %d" % self.end_position
        if self.is_impossible:
            s += ", is_impossible: %r" % self.is_impossible
        return s



class InputFeatures(object):
    """A single set of features of data."""
    def __init__(self,
                 unique_id,
                 example_index,
                 doc_span_index,
                 tokens,
                 token_to_orig_map,
                 token_is_max_context,
                 input_ids,
                 c_ids,
                 context_tokens,
                 q_ids,
                 q_tokens,
                 answer_text,
                 tag_ids,
                 input_mask,
                 segment_ids,
                 context_segment_ids=None,
                 noq_start_position=None,
                 noq_end_position=None,
                 start_position=None,
                 end_position=None,
                 is_impossible=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.doc_span_index = doc_span_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.c_ids = c_ids
        self.context_tokens = context_tokens
        self.q_ids = q_ids
        self.q_tokens = q_tokens
        self.answer_text = answer_text
        self.tag_ids = tag_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.context_segment_ids = context_segment_ids
        self.noq_start_position = noq_start_position
        self.noq_end_position = noq_end_position
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible


def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 doc_stride, max_query_length, is_training):
    """Loads a data file into a list of `InputBatch`s."""

    unique_id = 1000000000

    features = []
    for (example_index, example) in tqdm(enumerate(examples), total=len(examples)):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
                example.orig_answer_text)

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of the up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")
            segment_ids.append(0)

            context_tokens = list()
            context_tokens.append("[CLS]")
            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(
                    tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                                       split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
                context_tokens.append(all_doc_tokens[split_token_index])
            tokens.append("[SEP]")
            segment_ids.append(1)
            context_tokens.append("[SEP]")

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            input_ids = np.asarray(input_ids, dtype=np.int32)
            input_mask = np.asarray(input_mask, dtype=np.uint8)
            segment_ids = np.asarray(segment_ids, dtype=np.uint8)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            start_position = None
            end_position = None

            if is_training and not example.is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start and
                        tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                else:
                    doc_offset = len(query_tokens) + 2
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

                if out_of_span:
                    continue

            if is_training and example.is_impossible:
                start_position = 0
                end_position = 0
            c_ids = tokenizer.convert_tokens_to_ids(context_tokens)

            while len(c_ids) < max_seq_length:
                c_ids.append(0)
            c_ids = np.asarray(c_ids, dtype=np.int32)

            features.append(
                InputFeatures(
                    unique_id=unique_id,
                    example_index=example_index,
                    doc_span_index=doc_span_index,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    c_ids=c_ids,
                    context_tokens=None,
                    q_ids=None,
                    q_tokens=None,
                    answer_text=example.orig_answer_text,
                    tag_ids=None,
                    segment_ids=segment_ids,
                    noq_start_position=None,
                    noq_end_position=None,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=example.is_impossible))
            unique_id += 1

    return features


def convert_examples_to_harv_features(examples, tokenizer, max_seq_length,
                                      doc_stride, max_query_length, is_training):
    """Loads a data file into a list of `InputBatch`s.
       each example only contains a sequence of ids for context(paragraph)
    """

    unique_id = 1000000000

    features = []
    for example in tqdm(examples, total=len(examples)):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
                example.orig_answer_text)

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of the up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")

            context_tokens = list()
            context_tokens.append("[CLS]")
            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(
                    tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                                       split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                context_tokens.append(all_doc_tokens[split_token_index])

            tokens.append("[SEP]")
            context_tokens.append("[SEP]")

            if is_training and not example.is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start and
                        tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    continue

            c_ids = tokenizer.convert_tokens_to_ids(context_tokens)

            while len(c_ids) < max_seq_length:
                c_ids.append(0)

            features.append(
                InputFeatures(
                    unique_id=None,
                    example_index=None,
                    doc_span_index=None,
                    tokens=None,
                    token_to_orig_map=None,
                    token_is_max_context=None,
                    input_ids=None,
                    input_mask=None,
                    c_ids=c_ids,
                    context_tokens=None,
                    q_ids=None,
                    q_tokens=None,
                    answer_text=None,
                    tag_ids=None,
                    segment_ids=None,
                    noq_start_position=None,
                    noq_end_position=None,
                    start_position=None,
                    end_position=None,
                    is_impossible=None))
            unique_id += 1

    return features


def convert_examples_to_features_answer_id(examples, tokenizer, max_seq_length,
                                           doc_stride, max_query_length, max_ans_length, is_training):
    """Loads a data file into a list of `InputBatch`s.
       In addition to the original InputFeature class, it contains 
       c_ids: ids for context
       tag ids: indicate the answer span of context,
       noq_start_position: start position of answer in context without concatenation of question
       noq_end_position: end position of answer in context without concatenation of question
    """

    unique_id = 1000000000

    features = []
    for (example_index, example) in tqdm(enumerate(examples), total=len(examples)):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
                example.orig_answer_text)

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of the up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")
            segment_ids.append(0)

            context_tokens = list()
            context_tokens.append("[CLS]")
            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(
                    tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                                       split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
                context_tokens.append(all_doc_tokens[split_token_index])
            tokens.append("[SEP]")
            segment_ids.append(1)
            context_tokens.append("[SEP]")

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            start_position = None
            end_position = None
            noq_start_position = None
            noq_end_position = None

            if is_training and not example.is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start and
                        tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                    noq_start_position = 0
                    noq_end_position = 0
                else:
                    doc_offset = len(query_tokens) + 2
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

                    # plus one for [CLS] token
                    noq_start_position = tok_start_position - doc_start + 1
                    noq_end_position = tok_end_position - doc_start + 1

                # skip the context that does not contain any answer span
                if out_of_span:
                    continue

            if is_training and example.is_impossible:
                start_position = 0
                end_position = 0
                noq_start_position = 0
                noq_end_position = 0
            q_tokens = deepcopy(query_tokens)[:max_query_length - 2]
            q_tokens.insert(0, "[CLS]")
            q_tokens.append("[SEP]")
            q_ids = tokenizer.convert_tokens_to_ids(q_tokens)
            c_ids = tokenizer.convert_tokens_to_ids(context_tokens)

            # pad up to maximum length
            while len(q_ids) < max_query_length:
                q_ids.append(0)

            while len(c_ids) < max_seq_length:
                c_ids.append(0)

            # answer_text = example.orig_answer_text

            # answer_tokens = tokenizer.tokenize(answer_text)[:max_ans_length - 2]
            # answer_tokens.insert(0, "[CLS]")
            # answer_tokens.append("[SEP]")
            # answer_ids = tokenizer.convert_tokens_to_ids(answer_tokens)

            # while len(answer_ids) < max_ans_length:
            #     answer_ids.append(0)

            context_segment_ids = [0] * len(c_ids)
            # Guard against None positions (e.g. when is_training is False),
            # mirroring the check used for tag_ids below.
            if noq_start_position is not None and noq_end_position is not None:
                for answer_idx in range(noq_start_position, noq_end_position + 1):
                    context_segment_ids[answer_idx] = 1
            # BIO tagging scheme
            tag_ids = [0] * len(c_ids)  # Outside
            if noq_start_position is not None and noq_end_position is not None:
                tag_ids[noq_start_position] = 1  # Begin
                # Inside tag
                for idx in range(noq_start_position + 1, noq_end_position + 1):
                    tag_ids[idx] = 2
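                # e.g. an answer spanning positions 5-7 gives tag_ids[5] = 1 (B)
                # and tag_ids[6] = tag_ids[7] = 2 (I), with every other
                # position left as 0 (O).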

            assert len(tag_ids) == len(c_ids), "length of tag :{}, length of c :{}".format(
                len(tag_ids), len(c_ids))
            features.append(
                InputFeatures(
                    unique_id=unique_id,
                    example_index=example_index,
                    doc_span_index=doc_span_index,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    c_ids=c_ids,
                    context_tokens=context_tokens,
                    q_ids=q_ids,
                    q_tokens=q_tokens,
                    answer_text=example.orig_answer_text,
                    tag_ids=tag_ids,
                    segment_ids=segment_ids,
                    context_segment_ids=context_segment_ids,
                    noq_start_position=noq_start_position,
                    noq_end_position=noq_end_position,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=example.is_impossible))
            unique_id += 1

    return features



def read_examples(input_file, debug=False, is_training=False):
    # Read data
    unproc_data = []
    with gzip.open(input_file, 'rt', encoding='utf-8') as f:  # open the gzipped json-lines file in text mode
        for item in json_lines.reader(f):
            # print(item) #or use print(item['X']) for printing specific data
            unproc_data.append(item)

    # Delete header
    unproc_data = unproc_data[1:]
    if debug:
        unproc_data = unproc_data[:100]

    examples = []
    skip_tags = ['<Table>', '<Tr>', '<Td>', '<Ol>', '<Ul>', '<Li>']
    for item in unproc_data:
        # in case of NQ dataset, context containing tags is excluded for training
        context = item["context"]
        skip_flag = False
        for tag in skip_tags:
            if tag in context:
                skip_flag = True
                break
        if skip_flag and is_training:
            continue

        doc_tokens = []
        for token in item["context_tokens"]:
            if token[0] in ['[TLE]', '[PAR]', '[DOC]']:
                token[0] = '[SEP]'
            doc_tokens.append(token[0])

        # 2. qas
        for qa in item['qas']:
            qas_id = qa['qid']
            question_text = qa['question']

            # Only take the first answer
            answer = qa['detected_answers'][0]
            orig_answer_text = answer['text']
            # Only take the first span
            start_position = answer['token_spans'][0][0]
            end_position = answer['token_spans'][0][1]

            example = SquadExample(
                qas_id=qas_id,
                question_text=question_text,
                doc_tokens=doc_tokens,
                orig_answer_text=orig_answer_text,
                start_position=start_position,
                end_position=end_position)
            examples.append(example)

    return examples


def read_squad_examples(input_file, is_training, version_2_with_negative=False,
                        debug=False, reduce_size=False, ratio=1.0):
    """Read a SQuAD json file into a list of SquadExample."""
    with open(input_file, "r", encoding='utf-8') as reader:
        input_data = json.load(reader)["data"]

    def is_whitespace(c):
        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
            return True
        return False

    examples = []
    if debug:
        input_data = input_data[:5]

    for entry in input_data:
        paragraphs = entry["paragraphs"]
        for paragraph in paragraphs:
            paragraph_text = paragraph["context"]
            doc_tokens = []
            char_to_word_offset = []
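            # char_to_word_offset[i] maps character i of the context to the
            # index of the whitespace-delimited word that contains it.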
            prev_is_whitespace = True
            for c in paragraph_text:
                if is_whitespace(c):
                    prev_is_whitespace = True
                else:
                    if prev_is_whitespace:
                        doc_tokens.append(c)
                    else:
                        doc_tokens[-1] += c
                    prev_is_whitespace = False
                char_to_word_offset.append(len(doc_tokens) - 1)

            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                start_position = None
                end_position = None
                orig_answer_text = None
                is_impossible = False
                if is_training:
                    if version_2_with_negative:
                        is_impossible = qa["is_impossible"]
                    if not is_impossible:
                        answer = qa["answers"][0]
                        orig_answer_text = answer["text"]
                        answer_offset = answer["answer_start"]
                        answer_length = len(orig_answer_text)
                        start_position = char_to_word_offset[answer_offset]
                        end_position = char_to_word_offset[answer_offset +
                                                           answer_length - 1]
                        # Only add answers where the text can be exactly recovered from the
                        # document. If this CAN'T happen it's likely due to weird Unicode
                        # stuff so we will just skip the example.
                        #
                        # Note that this means for training mode, every example is NOT
                        # guaranteed to be preserved.
                        actual_text = " ".join(
                            doc_tokens[start_position:(end_position + 1)])
                        cleaned_answer_text = " ".join(
                            whitespace_tokenize(orig_answer_text))
                        if actual_text.find(cleaned_answer_text) == -1:
                            continue
                    else:
                        start_position = -1
                        end_position = -1
                        orig_answer_text = ""

                example = SquadExample(
                    qas_id=qas_id,
                    question_text=question_text,
                    doc_tokens=doc_tokens,
                    orig_answer_text=orig_answer_text,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=is_impossible)
                examples.append(example)
    return examples


def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
                         orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer."""

    # The SQuAD annotations are character based. We first project them to
    # whitespace-tokenized words. But then after WordPiece tokenization, we can
    # often find a "better match". For example:
    #
    #   Question: What year was John Smith born?
    #   Context: The leader was John Smith (1895-1943).
    #   Answer: 1895
    #
    # The original whitespace-tokenized answer will be "(1895-1943).". However
    # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
    # the exact answer, 1895.
    #
    # However, this is not always possible. Consider the following:
    #
    #   Question: What country is the top exporter of electronics?
    #   Context: The Japanese electronics industry is the largest in the world.
    #   Answer: Japan
    #
    # In this case, the annotator chose "Japan" as a character sub-span of
    # the word "Japanese". Since our WordPiece tokenizer does not split
    # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
    # in SQuAD, but does happen.
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)
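

# Editor's note: a hypothetical usage sketch, not part of the original repo.
# It replays the "(1895-1943)" example from the comment above: the
# whitespace-level span covers the whole parenthetical, and the search in
# _improve_answer_span tightens it to the single WordPiece token '1895'.
def _demo_improve_answer_span(tokenizer):
    # assumes `tokenizer` is a WordPiece tokenizer, e.g.
    # BertTokenizer.from_pretrained('bert-base-uncased')
    doc_tokens = tokenizer.tokenize("The leader was John Smith (1895-1943).")
    # doc_tokens ends with: '(', '1895', '-', '1943', ')', '.'
    start, end = doc_tokens.index("("), len(doc_tokens) - 1
    # returns the indices of the single token '1895'
    return _improve_answer_span(doc_tokens, start, end, tokenizer, "1895")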


def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""

    # Because of the sliding window approach taken to scoring documents, a single
    # token can appear in multiple document spans. E.g.
    #  Doc: the man went to the store and bought a gallon of milk
    #  Span A: the man went to the
    #  Span B: to the store and bought
    #  Span C: and bought a gallon of
    #  ...
    #
    # Now the word 'bought' will have two scores from spans B and C. We only
    # want to consider the score with "maximum context", which we define as
    # the *minimum* of its left and right context (the *sum* of left and
    # right context will always be the same, of course).
    #
    # In the example the maximum context for 'bought' would be span C since
    # it has 1 left context and 3 right context, while span B has 4 left context
    # and 0 right context.
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + \
            0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index
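

# Editor's note: a hypothetical sanity check, not part of the original repo.
# It reproduces the 'bought' example from the comment above: for token 7
# ('bought'), span C wins with min(left, right) context of 1, versus 0 for
# span B, while span A does not contain the token at all.
def _demo_check_is_max_context():
    _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
    doc_spans = [_DocSpan(start=0, length=5),  # A: 'the man went to the'
                 _DocSpan(start=3, length=5),  # B: 'to the store and bought'
                 _DocSpan(start=6, length=5)]  # C: 'and bought a gallon of'
    position = 7  # the token 'bought'
    # -> [False, False, True]
    return [_check_is_max_context(doc_spans, i, position) for i in range(3)]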


def write_predictions(all_examples, all_features, all_results, n_best_size,
                      max_answer_length, do_lower_case, output_prediction_file,
                      verbose_logging, version_2_with_negative, null_score_diff_threshold,
                      noq_position=False):
    """Write final predictions to the json file and log-odds of null if needed."""
    # logger.info("Writing predictions to: %s" % (output_prediction_file))
    # logger.info("Writing nbest to: %s" % (output_nbest_file))

    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for (example_index, example) in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive
        min_null_feature_index = 0  # the paragraph slice with min null score
        null_start_logit = 0  # the start logit at the slice with min null score
        null_end_logit = 0  # the end logit at the slice with min null score
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of irrelevant
            if version_2_with_negative:
                feature_null_score = result.start_logits[0] + \
                    result.end_logits[0]
                if feature_null_score < score_null:
                    score_null = feature_null_score
                    min_null_feature_index = feature_index
                    null_start_logit = result.start_logits[0]
                    null_end_logit = result.end_logits[0]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    if noq_position:
                        # start and end were predicted without question tokens:
                        # shift by the question length and subtract 1 for the
                        # [CLS] token counted in q_ids. Use fresh names so the
                        # loop variables are not re-shifted on every inner
                        # iteration.
                        q_ids = feature.q_ids
                        q_len = np.sum(np.sign(q_ids))
                        noq_start_index = start_index
                        noq_end_index = end_index
                        tok_start_index = start_index + q_len - 1
                        tok_end_index = end_index + q_len - 1
                    else:
                        tok_start_index = start_index
                        tok_end_index = end_index
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if tok_start_index >= len(feature.tokens):
                        continue
                    if tok_end_index >= len(feature.tokens):
                        continue
                    if tok_start_index not in feature.token_to_orig_map:
                        continue
                    if tok_end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(tok_start_index, False):
                        continue
                    if tok_end_index < tok_start_index:
                        continue
                    length = tok_end_index - tok_start_index + 1
                    if length > max_answer_length:
                        continue
                    if noq_position:
                        prelim_predictions.append(
                            _PrelimPrediction(
                                feature_index=feature_index,
                                start_index=tok_start_index,
                                end_index=tok_end_index,
                                start_logit=result.start_logits[noq_start_index],
                                end_logit=result.end_logits[noq_end_index]))
                    else:
                        prelim_predictions.append(
                            _PrelimPrediction(
                                feature_index=feature_index,
                                start_index=tok_start_index,
                                end_index=tok_end_index,
                                start_logit=result.start_logits[start_index],
                                end_logit=result.end_logits[end_index]))

        if version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
                    feature_index=min_null_feature_index,
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit))
        prelim_predictions = sorted(
            prelim_predictions,
            key=lambda x: (x.start_logit + x.end_logit),
            reverse=True)

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"])

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index:(
                    pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start:(
                    orig_doc_end + 1)]
                tok_text = " ".join(tok_tokens)

                # De-tokenize WordPieces that have been split off.
                tok_text = tok_text.replace(" ##", "")
                tok_text = tok_text.replace("##", "")

                # Clean whitespace
                tok_text = tok_text.strip()
                tok_text = " ".join(tok_text.split())
                orig_text = " ".join(orig_tokens)

                final_text = get_final_text(
                    tok_text, orig_text, do_lower_case, verbose_logging)
                if final_text in seen_predictions:
                    continue

                seen_predictions[final_text] = True
            else:
                final_text = ""
                seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(text=final_text,
                                 start_logit=pred.start_logit,
                                 end_logit=pred.end_logit))
        # if we didn't include the empty option in the n-best, include it
        if version_2_with_negative:
            if "" not in seen_predictions:
                nbest.append(
                    _NbestPrediction(text="",
                                     start_logit=null_start_logit,
                                     end_logit=null_end_logit))

            # In very rare edge cases we could have only a single null prediction.
            # So we just create a nonce prediction in this case to avoid failure.
            if len(nbest) == 1:
                nbest.insert(0, _NbestPrediction(text="empty",
                                                 start_logit=0.0,
                                                 end_logit=0.0))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(_NbestPrediction(text="empty",
                                          start_logit=0.0,
                                          end_logit=0.0))

        assert len(nbest) >= 1

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            if not best_non_null_entry:
                if entry.text:
                    best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            nbest_json.append(output)
        assert len(nbest_json) >= 1

        if not version_2_with_negative:
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        else:
            # predict "" iff the null score - the score of best non-null > threshold
            score_diff = score_null - best_non_null_entry.start_logit - (
                best_non_null_entry.end_logit)
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
            else:
                all_predictions[example.qas_id] = best_non_null_entry.text
            all_nbest_json[example.qas_id] = nbest_json
    with open(output_prediction_file, "w") as writer:
        writer.write(json.dumps(all_predictions, indent=4) + "\n")


def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
    """Project the tokenized prediction back to the original text."""

    # When we created the data, we kept track of the alignment between original
    # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
    # now `orig_text` contains the span of our original text corresponding to the
    # span that we predicted.
    #
    # However, `orig_text` may contain extra characters that we don't want in
    # our prediction.
    #
    # For example, let's say:
    #   pred_text = steve smith
    #   orig_text = Steve Smith's
    #
    # We don't want to return `orig_text` because it contains the extra "'s".
    #
    # We don't want to return `pred_text` because it's already been normalized
    # (the SQuAD eval script also does punctuation stripping/lower casing but
    # our tokenizer does additional normalization like stripping accent
    # characters).
    #
    # What we really want to return is "Steve Smith".
    #
    # Therefore, we have to apply a semi-complicated alignment heuristic between
    # `pred_text` and `orig_text` to get a character-to-character alignment. This
    # can fail in certain cases in which case we just return `orig_text`.

    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == " ":
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = "".join(ns_chars)
        return (ns_text, ns_to_s_map)

    # We first tokenize `orig_text`, strip whitespace from the result
    # and `pred_text`, and check if they are the same length. If they are
    # NOT the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
    tokenizer = BasicTokenizer(do_lower_case=do_lower_case)

    tok_text = " ".join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        # if verbose_logging:
        #     logger.info(
        #         "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
        return orig_text
    end_position = start_position + len(pred_text) - 1

    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        # if verbose_logging:
        #     logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
        #                 orig_ns_text, tok_ns_text)
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
    # the character-to-character alignment.
    tok_s_to_ns_map = {}
    for (i, tok_index) in tok_ns_to_s_map.items():
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]

    if orig_start_position is None:
        # if verbose_logging:
        #     logger.info("Couldn't map start position")
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]

    if orig_end_position is None:
        return orig_text

    output_text = orig_text[orig_start_position:(orig_end_position + 1)]
    return output_text
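

# Editor's note: a hypothetical sanity check, not part of the original repo.
# It replays the "Steve Smith" example from the comment above: the
# normalized prediction is aligned back onto the original text, restoring
# the casing and trimming the trailing "'s".
def _demo_get_final_text():
    # -> "Steve Smith"
    return get_final_text("steve smith", "Steve Smith's", do_lower_case=True)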


def _get_best_indexes(logits, n_best_size):
    """Get the n-best logits from a list."""
    index_and_score = sorted(
        enumerate(logits), key=lambda x: x[1], reverse=True)

    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes


def _compute_softmax(scores):
    """Compute softmax probability over raw logits."""
    if not scores:
        return []

    max_score = None
    for score in scores:
        if max_score is None or score > max_score:
            max_score = score

    exp_scores = []
    total_sum = 0.0
    for score in scores:
        x = math.exp(score - max_score)
        exp_scores.append(x)
        total_sum += x

    probs = []
    for score in exp_scores:
        probs.append(score / total_sum)
    return probs
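

# Editor's note: a hypothetical sanity check, not part of the original repo.
# _compute_softmax subtracts the max logit before exponentiating, i.e. the
# numerically stable form of softmax(x)_i = exp(x_i) / sum_j exp(x_j).
def _demo_compute_softmax():
    probs = _compute_softmax([1.0, 2.0, 3.0])
    assert abs(sum(probs) - 1.0) < 1e-9  # probabilities sum to one
    assert probs[0] < probs[1] < probs[2]  # monotone in the logits
    return probs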


def write_answer_predictions(all_examples, all_features, all_results, n_best_size,
                             max_answer_length, do_lower_case, output_prediction_file,
                             output_nbest_file, output_null_log_odds_file, verbose_logging,
                             version_2_with_negative, null_score_diff_threshold):
    """Write final predictions to the json file and log-odds of null if needed."""
    # logger.info("Writing predictions to: %s" % (output_prediction_file))
    # logger.info("Writing nbest to: %s" % (output_nbest_file))

    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict()

    for (example_index, example) in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        # keep track of the minimum score of null start+end of position 0
        score_null = 1000000  # large and positive
        min_null_feature_index = 0  # the paragraph slice with min null score
        null_start_logit = 0  # the start logit at the slice with min null score
        null_end_logit = 0  # the end logit at the slice with min null score
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]

            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of irrelevant
            if version_2_with_negative:
                feature_null_score = result.start_logits[0] + \
                    result.end_logits[0]
                if feature_null_score < score_null:
                    score_null = feature_null_score
                    min_null_feature_index = feature_index
                    null_start_logit = result.start_logits[0]
                    null_end_logit = result.end_logits[0]

            # the model input is [CLS] [UNK] [SEP] C [SEP], while feature.tokens
            # is [CLS] Q [SEP] C [SEP], so each predicted index must be shifted:
            # add the question length and subtract 1 for the [UNK] placeholder
            offset = len(feature.q_tokens) - 2  # -2 for [CLS] and [SEP]
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # use fresh names so the loop variables are not re-shifted
                    # on every inner iteration; the raw indices still index the
                    # logits, which live in the [UNK] coordinate system
                    tok_start_index = offset + start_index - 1  # -1 for [UNK]
                    tok_end_index = offset + end_index - 1  # -1 for [UNK]
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if tok_start_index >= len(feature.tokens):
                        continue
                    if tok_end_index >= len(feature.tokens):
                        continue
                    if tok_start_index not in feature.token_to_orig_map:
                        continue
                    if tok_end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(tok_start_index, False):
                        continue
                    if tok_end_index < tok_start_index:
                        continue
                    length = tok_end_index - tok_start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=tok_start_index,
                            end_index=tok_end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index]))
        if version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
                    feature_index=min_null_feature_index,
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit))
        prelim_predictions = sorted(
            prelim_predictions,
            key=lambda x: (x.start_logit + x.end_logit),
            reverse=True)

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"])

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
                tok_tokens = feature.tokens[pred.start_index:(
                    pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
                orig_tokens = example.doc_tokens[orig_doc_start:(
                    orig_doc_end + 1)]
                tok_text = " ".join(tok_tokens)

                # De-tokenize WordPieces that have been split off.
                tok_text = tok_text.replace(" ##", "")
                tok_text = tok_text.replace("##", "")

                # Clean whitespace
                tok_text = tok_text.strip()
                tok_text = " ".join(tok_text.split())
                orig_text = " ".join(orig_tokens)

                final_text = get_final_text(
                    tok_text, orig_text, do_lower_case, verbose_logging)
                if final_text in seen_predictions:
                    continue

                seen_predictions[final_text] = True
            else:
                final_text = ""
                seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_logit=pred.start_logit,
                    end_logit=pred.end_logit))
        # if we didn't include the empty option in the n-best, include it
        if version_2_with_negative:
            if "" not in seen_predictions:
                nbest.append(
                    _NbestPrediction(
                        text="",
                        start_logit=null_start_logit,
                        end_logit=null_end_logit))

            # In very rare edge cases we could have only a single null prediction.
            # So we just create a nonce prediction in this case to avoid failure.
            if len(nbest) == 1:
                nbest.insert(0,
                             _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)
            if not best_non_null_entry:
                if entry.text:
                    best_non_null_entry = entry

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            nbest_json.append(output)
        assert len(nbest_json) >= 1

        if not version_2_with_negative:
            all_predictions[example.qas_id] = nbest_json[0]["text"]
        else:
            # predict "" iff the null score - the score of best non-null > threshold
            score_diff = score_null - best_non_null_entry.start_logit - (
                best_non_null_entry.end_logit)
            scores_diff_json[example.qas_id] = score_diff
            if score_diff > null_score_diff_threshold:
                all_predictions[example.qas_id] = ""
            else:
                all_predictions[example.qas_id] = best_non_null_entry.text
        all_nbest_json[example.qas_id] = nbest_json
    with open(output_prediction_file, "w") as writer:
        writer.write(json.dumps(all_predictions, indent=4) + "\n")

    with open(output_nbest_file, "w") as writer:
        writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

    if version_2_with_negative:
        with open(output_null_log_odds_file, "w") as writer:
            writer.write(json.dumps(scores_diff_json, indent=4) + "\n")


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
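

# Editor's note: a hypothetical sanity check, not part of the original repo.
# normalize_answer lower-cases, strips punctuation and the articles a/an/the,
# and collapses whitespace, so predictions and references are compared in a
# canonical form.
def _demo_normalize_answer():
    assert normalize_answer("The  Eiffel Tower!") == "eiffel tower"
    return normalize_answer("The  Eiffel Tower!")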


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = collections.Counter(
        prediction_tokens) & collections.Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
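

# Editor's note: a hypothetical worked example, not part of the original
# repo. For prediction "the cat sat" vs. ground truth "a cat sat down", the
# normalized token overlap is {'cat', 'sat'}, so precision = 2/2 = 1.0
# (articles are stripped), recall = 2/3, and
# F1 = 2 * 1.0 * (2/3) / (1.0 + 2/3) = 0.8.
def _demo_f1_score():
    assert abs(f1_score("the cat sat", "a cat sat down") - 0.8) < 1e-9
    return f1_score("the cat sat", "a cat sat down")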


def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                total += 1
                if qa['id'] not in predictions:
                    message = 'Unanswered question ' + qa['id'] + \
                              ' will receive score 0.'
                    print(message, file=sys.stderr)
                    continue
                ground_truths = list(map(lambda x: x['text'], qa['answers']))
                prediction = predictions[qa['id']]
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return {'exact_match': exact_match, 'f1': f1}


def read_predictions(prediction_file):
    with open(prediction_file) as f:
        predictions = json.load(f)
    return predictions


def read_answers(gold_file):
    answers = {}
    with gzip.open(gold_file, 'rb') as f:
        for i, line in enumerate(f):
            example = json.loads(line)
            if i == 0 and 'header' in example:
                continue
            for qa in example['qas']:
                answers[qa['qid']] = qa['answers']
    return answers


def evaluate_mrqa(answers, predictions, skip_no_answer=False):
    f1 = exact_match = total = 0
    for qid, ground_truths in answers.items():
        if qid not in predictions:
            if not skip_no_answer:
                message = 'Unanswered question %s will receive score 0.' % qid
                print(message)
                total += 1
            continue
        total += 1
        prediction = predictions[qid]
        exact_match += metric_max_over_ground_truths(
            exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(
            f1_score, prediction, ground_truths)

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return {'exact_match': exact_match, 'f1': f1}


================================================
FILE: qa-eval/trainer.py
================================================
import collections
import math
import os
import time
import json
import socket
from tqdm import tqdm
import numpy as np
import pickle
import random

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim as optim
from transformers import BertForQuestionAnswering, BertTokenizer, AdamW
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.utils import clip_grad_norm_

from squad_utils import read_squad_examples, read_examples, \
                        convert_examples_to_features, \
                        write_predictions, read_answers, \
                        evaluate, evaluate_mrqa

from utils import eta, progress_bar, user_friendly_time, time_since

def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """ Create a schedule with a learning rate that decreases linearly after
    linearly increasing during a warmup period.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
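

# Editor's note: a hypothetical usage sketch, not part of the original repo.
# With num_warmup_steps=100 and num_training_steps=1000, the multiplier ramps
# linearly from 0 to 1 over the first 100 steps, then decays linearly back to
# 0 by step 1000.
def _demo_warmup_schedule():
    param = torch.nn.Parameter(torch.zeros(1))
    opt = optim.SGD([param], lr=1.0)
    sched = get_linear_schedule_with_warmup(opt, num_warmup_steps=100,
                                            num_training_steps=1000)
    lrs = []
    for _ in range(1000):
        opt.step()
        sched.step()
        lrs.append(sched.get_last_lr()[0])
    return lrs  # lrs[99] == 1.0, lrs[-1] == 0.0
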

class Trainer(object):
    def __init__(self, args):
        self.args = args
        self.set_random_seed(random_seed=args.random_seed)
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_model)

    def make_model_env(self, gpu, ngpus_per_node):
        if gpu is not None:
            self.args.gpu = self.args.devices[gpu]

        if self.args.use_cuda and self.args.distributed:
            if self.args.multiprocessing_distributed:
                # For multiprocessing distributed training, rank needs to be the
                # global rank among all the processes
                self.args.rank = self.args.rank * ngpus_per_node + gpu
            dist.init_process_group(backend=self.args.dist_backend, init_method=self.args.dist_url,
                                    world_size=self.args.world_size, rank=self.args.rank)

        if self.args.rank == 0:
            if self.args.debug:
                print("debugging mode on.")

        self.model = BertForQuestionAnswering.from_pretrained(self.args.bert_model)
        if self.args.rank == 0:
            self.get_dev_loader()
            self.get_test_loader()

        self.args.batch_size = int(self.args.batch_size / ngpus_per_node)
        self.get_pretrain_loader()

        self.pretrain_t_total = len(self.pretrain_loader) * self.args.pretrain_epochs

        no_decay = ['bias', 'LayerNorm.weight']
        self.optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.args.weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        self.optimizer = AdamW(self.optimizer_grouped_parameters,
                               lr=self.args.pretrain_lr,
                               eps=self.args.adam_epsilon)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=self.pretrain_t_total)

        if self.args.use_cuda:
            torch.cuda.set_device(self.args.gpu)
            self.model.cuda(self.args.gpu)
            self.args.workers = int((self.args.workers + ngpus_per_node - 1) / ngpus_per_node)
            self.model = DistributedDataParallel(self.model, device_ids=[self.args.gpu],
                                                 find_unused_parameters=True)

        cudnn.benchmark = True

    def get_pretrain_loader(self):
        data = self.args.pretrain_dataset

        self.pretrain_sampler = DistributedSampler(data)
        self.pretrain_loader = DataLoader(data, num_workers=self.args.workers, pin_memory=True,
                                          sampler=self.pretrain_sampler, batch_size=self.args.batch_size)


    def get_dev_loader(self):
        data = self.args.dev_dataset

        self.dev_loader = DataLoader(data, shuffle=False, batch_size=self.args.batch_size)
        self.dev_examples = self.args.dev_examples
        self.dev_features = self.args.dev_features

    def get_test_loader(self):
        data = self.args.test_dataset

        self.test_loader = DataLoader(data, shuffle=False, batch_size=self.args.batch_size)
        self.test_examples = self.args.test_examples
        self.test_features = self.args.test_features

    def train(self):

        self.model.zero_grad()

        for epoch in range(0, self.args.pretrain_epochs):

            num_batches = len(self.pretrain_loader)
            self.pretrain_sampler.set_epoch(epoch)
            start = time.time()

            # pretrain with unsupervised dataset
            for step, batch in enumerate(self.pretrain_loader, start=1):
                self.model.train()
                input_ids, input_mask, seg_ids, start_positions, end_positions = batch

                seq_len = torch.sum(torch.sign(input_ids), 1)
                max_len = torch.max(seq_len)
                input_ids = input_ids[:, :max_len].clone().cuda(self.args.gpu, non_blocking=True)
                input_mask = input_mask[:, :max_len].clone().cuda(self.args.gpu, non_blocking=True)
                seg_ids = seg_ids[:, :max_len].clone().cuda(self.args.gpu, non_blocking=True)
                start_positions = start_positions.clone().cuda(self.args.gpu, non_blocking=True)
                end_positions = end_positions.clone().cuda(self.args.gpu, non_blocking=True)

                inputs = {
                    "input_ids": input_ids,
                    "attention_mask": input_mask,
                    "token_type_ids": seg_ids,
                    "start_positions": start_positions,
                    "end_positions": end_positions
                }
                loss = self.model(**inputs)[0]
                loss.backward()

                clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
                self.optimizer.step()
                self.scheduler.step()
                self.model.zero_grad()

                if self.args.rank == 0:
                    msg = "PRETRAIN {}/{} {} - ETA : {} - LOSS : {:.4f}".format(step,
                                num_batches, progress_bar(step, num_batches),
                                eta(start, step, num_batches),
                                float(loss.item()))
                    print(msg, end="\r")

                if self.args.debug:
                    break

            # save model
            if self.args.rank == 0:
                result_dict = self.evaluate_model(msg)
                em = result_dict["exact_match"]
                f1 = result_dict["f1"]
                print("\nPRETRAIN took {} DEV - F1: {:.4f}, EM: {:.4f}\n"
                      .format(user_friendly_time(time_since(start)), f1, em))

        if self.args.rank == 0:

            result_dict = self.evaluate_model("TEST", False)
            em = result_dict["exact_match"]
            f1 = result_dict["f1"]
            print("\nFINAL TEST - F1: {:.4f}, EM: {:.4f}\n"
                  .format(f1, em))

    def evaluate_model(self, msg, dev=True):
        if dev:
            eval_examples = self.dev_examples
            eval_features = self.dev_features
            eval_loader = self.dev_loader
            eval_file = self.args.dev_json_file
        else:
            eval_examples = self.test_examples
            eval_features = self.test_features
            eval_loader = self.test_loader
            eval_file = self.args.test_json_file

        RawResult = collections.namedtuple("RawResult",
                                           ["unique_id", "start_logits", "end_logits"])
        all_results = []
        example_index = -1
        self.model.eval()
        num_val_batches = len(eval_loader)
        for i, batch in enumerate(eval_loader):
            input_ids, input_mask, seg_ids = batch
            seq_len = torch.sum(torch.sign(input_ids), 1)
            max_len = torch.max(seq_len)

            input_ids = input_ids[:, :max_len].clone().cuda(self.args.gpu, non_blocking=True)
            input_mask = input_mask[:, :max_len].clone().cuda(self.args.gpu, non_blocking=True)
            seg_ids = seg_ids[:, :max_len].clone().cuda(self.args.gpu, non_blocking=True)

            with torch.no_grad():
                inputs = {
                    "input_ids": input_ids,
                    "attention_mask": input_mask,
                    "token_type_ids": seg_ids
                }
                if hasattr(self.model, 'module'):
                    outputs = self.model.module(**inputs)
                else:
                    outputs = self.model(**inputs)
                batch_start_logits, batch_end_logits = outputs[0], outputs[1]
                batch_size = batch_start_logits.size(0)
            for j in range(batch_size):
                example_index += 1
                start_logits = batch_start_logits[j].detach().cpu().tolist()
                end_logits = batch_end_logits[j].detach().cpu().tolist()
                eval_feature = eval_features[example_index]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))
            msg2 = "{} => Evaluating :{}/{}".format(msg, i, num_val_batches)
            print(msg2, end="\r")

        os.makedirs("./save/", exist_ok=True)
        output_prediction_file = os.path.join("./save/", "prediction.json")
        write_predictions(eval_examples, eval_features, all_results,
                          n_best_size=20, max_answer_length=30, do_lower_case=True,
                          output_prediction_file=output_prediction_file,
                          verbose_logging=False,
                          version_2_with_negative=False,
                          null_score_diff_threshold=0,
                          noq_position=False)

        with open(output_prediction_file) as prediction_file:
            predictions = json.load(prediction_file)

        with open(eval_file) as f:
            data_json = json.load(f)
            dataset = data_json["data"]

        result_dict = evaluate(dataset, predictions)
        
        return result_dict

    def set_random_seed(self, random_seed=2019):
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed(random_seed)
        np.random.seed(random_seed)


================================================
FILE: qa-eval/utils.py
================================================
import time
import math
import torch

def time_since(t):
    """ Function for time. """
    return time.time() - t


def progress_bar(completed, total, step=5):
    """ Function returning a string progress bar. """
    percent = int((completed / total) * 100)
    bar = '[='
    arrow_reached = False
    for t in range(step, 101, step):
        if arrow_reached:
            bar += ' '
        else:
            if percent // t != 0:
                bar += '='
            else:
                bar = bar[:-1]
                bar += '>'
                arrow_reached = True
    if percent == 100:
        bar = bar[:-1]
        bar += '='
    bar += ']'
    return bar


def user_friendly_time(s):
    """ Display a user friendly time from number of second. """
    s = int(s)
    if s < 60:
        return "{}s".format(s)

    m = s // 60
    s = s % 60
    if m < 60:
        return "{}m {}s".format(m, s)

    h = m // 60
    m = m % 60
    if h < 24:
        return "{}h {}m {}s".format(h, m, s)

    d = h // 24
    h = h % 24
    return "{}d {}h {}m {}s".format(d, h, m, s)


def eta(start, completed, total):
    """ Function returning an ETA. """
    # Computation
    took = time_since(start)
    time_per_step = took / completed
    remaining_steps = total - completed
    remaining_time = time_per_step * remaining_steps

    return user_friendly_time(remaining_time)


def cal_running_avg_loss(loss, running_avg_loss, decay=0.99):
    if running_avg_loss == 0:
        return loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    return running_avg_loss


def kl_coef(i):
    # coefficient for KL annealing: a shifted tanh that is ~0 for small i,
    # 0.5 at i = 3500, and saturates toward 1 a few thousand steps later
    # https://github.com/kefirski/pytorch_RVAE/blob/master/utils/functional.py
    return (math.tanh((i - 3500) / 1000) + 1) / 2
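

# Editor's note: a hypothetical sanity check, not part of the original repo.
# Sampling the annealing curve at a few steps shows the tanh schedule:
def _demo_kl_coef():
    # -> approximately [0.0009, 0.5, 0.9975, 1.0]
    return [round(kl_coef(i), 4) for i in (0, 3500, 6500, 22000)]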


def compute_kernel(x, y):
    x_size = x.size(0)
    y_size = y.size(0)
    dim = x.size(1)
    x = x.unsqueeze(1)  # (x_size, 1, dim)
    y = y.unsqueeze(0)  # (1, y_size, dim)
    tiled_x = x.expand(x_size, y_size, dim)
    tiled_y = y.expand(x_size, y_size, dim)
    kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim)
    return torch.exp(-kernel_input)  # (x_size, y_size)


def compute_mmd(x, y):
    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    mmd = x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean()
    return mmd
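

# Editor's note: a hypothetical sanity check, not part of the original repo.
# MMD with this RBF kernel is close to 0 when both samples are drawn from
# the same distribution and grows as the distributions separate.
def _demo_compute_mmd():
    x = torch.randn(256, 16)
    y = torch.randn(256, 16)        # same distribution -> mmd near 0
    z = torch.randn(256, 16) + 3.0  # shifted distribution -> larger mmd
    return compute_mmd(x, y).item(), compute_mmd(x, z).item()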


class EMA(object):
    """Exponential moving average of model parameters.
    Args:
        model (torch.nn.Module): Model with parameters whose EMA will be kept.
        decay (float): Decay rate for exponential moving average.
    """

    def __init__(self, model, decay):
        self.decay = decay
        self.shadow = {}
        self.original = {}

        # Register model parameters
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def __call__(self, model, num_updates):
        decay = min(self.decay, (1.0 + num_updates) / (10.0 + num_updates))
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = \
                    (1.0 - decay) * param.data + decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def assign(self, model):
        """Assign exponential moving average of parameter values to the
        respective parameters.
        Args:
            model (torch.nn.Module): Model to assign parameter values.
        """
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.original[name] = param.data.clone()
                param.data = self.shadow[name]

    def resume(self, model):
        """Restore original parameters to a model. That is, put back
        the values that were in each parameter at the last call to `assign`.
        Args:
            model (torch.nn.Module): Model to assign parameter values.
        """
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                param.data = self.original[name]


================================================
FILE: vae/eval.py
================================================
import collections
import json
import os

import torch
from transformers import BertTokenizer
from tqdm import tqdm

from qgevalcap.eval import eval_qg
from squad_utils import evaluate, write_predictions
from utils import batch_to_device

def to_string(index, tokenizer):
    tok_tokens = tokenizer.convert_ids_to_tokens(index)
    tok_text = " ".join(tok_tokens)

    # De-tokenize WordPieces that have been split off.
    tok_text = tok_text.replace("[PAD]", "")
    tok_text = tok_text.replace("[SEP]", "")
    tok_text = tok_text.replace("[CLS]", "")
    tok_text = tok_text.replace(" ##", "")
    tok_text = tok_text.replace("##", "")

    # Clean whitespace
    tok_text = tok_text.strip()
    tok_text = " ".join(tok_text.split())
    return tok_text

class Result(object):
    def __init__(self,
                 context,
                 real_question,
                 posterior_question,
                 prior_question,
                 real_answer,
                 posterior_answer,
                 prior_answer,
                 posterior_z_prob,
                 prior_z_prob):
        self.context = context
        self.real_question = real_question
        self.posterior_question = posterior_question
        self.prior_question = prior_question
        self.real_answer = real_answer
        self.posterior_answer = posterior_answer
        self.prior_answer = prior_answer
        self.posterior_z_prob = posterior_z_prob
        self.prior_z_prob = prior_z_prob


def eval_vae(epoch, args, trainer, eval_data):
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    RawResult = collections.namedtuple("RawResult",
                                       ["unique_id", "start_logits", "end_logits"])

    eval_loader, eval_examples, eval_features = eval_data

    all_results = []
    qa_results = []
    qg_results = {}
    res_dict = {}
    example_index = -1

    for batch in tqdm(eval_loader, desc="Eval iter", leave=False, position=4):
        c_ids, q_ids, a_ids, start, end = batch_to_device(batch, args.device)
        batch_size = c_ids.size(0)
        batch_c_ids = c_ids.cpu().tolist()
        batch_q_ids = q_ids.cpu().tolist()
        batch_start = start.cpu().tolist()
        batch_end = end.cpu().tolist()

        batch_posterior_q_ids, \
        batch_posterior_start, batch_posterior_end, \
        posterior_z_prob = trainer.generate_posterior(c_ids, q_ids, a_ids)

        batch_start_logits, batch_end_logits \
        = trainer.generate_answer_logits(c_ids, q_ids, a_ids)

        batch_posterior_q_ids, \
        batch_posterior_start, batch_posterior_end = \
        batch_posterior_q_ids.cpu().tolist(), \
        batch_posterior_start.cpu().tolist(), batch_posterior_end.cpu().tolist()
        posterior_z_prob = posterior_z_prob.cpu()

        batch_prior_q_ids, \
        batch_prior_start, batch_prior_end, \
        prior_z_prob = trainer.generate_prior(c_ids)

        batch_prior_q_ids, \
        batch_prior_start, batch_prior_end = \
        batch_prior_q_ids.cpu().tolist(), \
        batch_prior_start.cpu().tolist(), batch_prior_end.cpu().tolist()
        prior_z_prob = prior_z_prob.cpu()

        for i in range(batch_size):
            example_index += 1
            start_logits = batch_start_logits[i].detach().cpu().tolist()
            end_logits = batch_end_logits[i].detach().cpu().tolist()
            eval_feature = eval_features[example_index]
            unique_id = int(eval_feature.unique_id)

            context = to_string(batch_c_ids[i], tokenizer)

            real_question = to_string(batch_q_ids[i], tokenizer)
            posterior_question = to_string(batch_posterior_q_ids[i], tokenizer)
            prior_question = to_string(batch_prior_q_ids[i], tokenizer)

            real_answer = to_string(batch_c_ids[i][batch_start[i]:(batch_end[i] + 1)], tokenizer)
            posterior_answer = to_string(batch_c_ids[i][batch_posterior_start[i]:(batch_posterior_end[i] + 1)], tokenizer)
            prior_answer = to_string(batch_c_ids[i][batch_prior_start[i]:(batch_prior_end[i] + 1)], tokenizer)

            all_results.append(Result(context=context,
                                      real_question=real_question,
                                      posterior_question=posterior_question,
                                      prior_question=prior_question,
                                      real_answer=real_answer,
                                      posterior_answer=posterior_answer,
                                      prior_answer=prior_answer,
                                      posterior_z_prob=posterior_z_prob[i],
                                      prior_z_prob=prior_z_prob[i]))

            qg_results[unique_id] = posterior_question
            res_dict[unique_id] = real_question
            qa_results.append(RawResult(unique_id=unique_id,
                                        start_logits=start_logits,
                                        end_logits=end_logits))

    output_prediction_file = os.path.join(args.model_dir, "pred.json")
    write_predictions(eval_examples, eval_features, qa_results, n_best_size=20,
                      max_answer_length=30, do_lower_case=True,
                      output_prediction_file=output_prediction_file,
                      verbose_logging=False,
                      version_2_with_negative=False,
                      null_score_diff_threshold=0,
                      noq_position=True)

    with open(args.dev_dir) as f:
        dataset_json = json.load(f)
        dataset = dataset_json["data"]
    with open(os.path.join(args.model_dir, "pred.json")) as prediction_file:
        predictions = json.load(prediction_file)
    ret = evaluate(dataset, predictions)
    bleu = eval_qg(res_dict, qg_results)

    return ret, bleu, all_results


================================================
FILE: vae/generate_qa.py
================================================
import argparse
import json
import os

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm
from transformers import BertTokenizer

from models import DiscreteVAE


class CustomDataset(Dataset):
    def __init__(self, tokenizer, input_file, max_length=512):
        self.lines = open(input_file, "r").readlines()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_total = len(self.lines)

    def __getitem__(self, idx):
        context = self.lines[idx].strip()
        tokens = self.tokenizer.tokenize(context)[:self.max_length]
        ids = self.tokenizer.convert_tokens_to_ids(tokens)
        
        # padding up to the maximum length
        while len(ids) < self.max_length:
            ids.append(0)
        ids = torch.tensor(ids, dtype=torch.long)
        
        return ids

    def __len__(self):
        return self.num_total


def main(args):
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    data = CustomDataset(tokenizer, args.data_file, args.max_length)
    data_loader = DataLoader(data, shuffle=False, batch_size=args.batch_size)

    device = torch.cuda.current_device()
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    vae = DiscreteVAE(checkpoint["args"])
    vae.load_state_dict(checkpoint["state_dict"])
    vae.eval()
    vae = vae.to(device)
    
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    output_file = os.path.join(args.output_dir, "synthetic_qa.jsonl")
    
    fw = open(output_file, "w")
    for batch in tqdm(data_loader, total=len(data_loader)):
        c_ids = batch
        c_len = torch.sum(torch.sign(c_ids), 1)
        max_c_len = torch.max(c_len)
        c_ids = c_ids[:, :max_c_len].to(device)

        # sample latent variable K times
        for _ in range(args.k):
            with torch.no_grad():
                _, _, zq, _, za = vae.prior_encoder(c_ids)
                batch_q_ids, batch_start, batch_end = vae.generate(
                    zq, za, c_ids)

            for i in range(c_ids.size(0)):
                _c_ids = c_ids[i].cpu().tolist()
                q_ids = batch_q_ids[i].cpu().tolist()
                start_pos = batch_start[i].item()
                end_pos = batch_end[i].item()
                
                a_ids = _c_ids[start_pos: end_pos+1]
                c_text = tokenizer.decode(_c_ids, skip_special_tokens=True)
                q_text = tokenizer.decode(q_ids, skip_special_tokens=True)
                a_text = tokenizer.decode(a_ids, skip_special_tokens=True)
                json_dict = {
                    "context": c_text,
                    "question": q_text,
                    "answer": a_text
                }
                fw.write(json.dumps(json_dict) + "\n")
                fw.flush()

    fw.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", default=1004, type=int)
    parser.add_argument("--bert_model", default='bert-base-uncased', type=str)
    parser.add_argument("--max_length", default=384,
                        type=int, help="max context length")
    
    parser.add_argument("--batch_size", default=64,
                        type=int, help="batch_size")
    parser.add_argument("--data_file", type=str,
                        required=True, help="text file of paragraphs")
    parser.add_argument("--checkpoint", default="../save/vae-checkpoint/best_f1_model.pt",
                        type=str, help="checkpoint for vae model")
    parser.add_argument("--output_dir", default="../data/synthetic_data/", type=str)

    parser.add_argument("--ratio", default=1.0, type=float)
    parser.add_argument("--k", default=1, type=int,
                        help="the number of QA pairs for each paragraph")

    args = parser.parse_args()
    main(args)
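
# A usage sketch added for illustration (not part of the original script):
# a typical invocation, assuming a plain-text file with one paragraph per
# line. All paths here are illustrative.
#
#   python generate_qa.py \
#       --data_file ../data/paragraphs.txt \
#       --checkpoint ../save/vae-checkpoint/best_f1_model.pt \
#       --output_dir ../data/synthetic_data/ \
#       --k 10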


================================================
FILE: vae/main.py
================================================
import argparse
import os
import random

import numpy as np
import torch
from tqdm import tqdm, trange
from transformers import BertTokenizer

from eval import eval_vae
from trainer import VAETrainer
from utils import batch_to_device, get_harv_data_loader, get_squad_data_loader


def main(args):
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    train_loader, _, _ = get_squad_data_loader(tokenizer, args.train_dir,
                                               shuffle=True, args=args)
    eval_data = get_squad_data_loader(tokenizer, args.dev_dir,
                                      shuffle=False, args=args)

    args.device = torch.cuda.current_device()

    trainer = VAETrainer(args)

    loss_log1 = tqdm(total=0, bar_format='{desc}', position=2)
    loss_log2 = tqdm(total=0, bar_format='{desc}', position=3)
    eval_log = tqdm(total=0, bar_format='{desc}', position=5)
    best_eval_log = tqdm(total=0, bar_format='{desc}', position=6)

    print("MODEL DIR: " + args.model_dir)

    best_bleu, best_em, best_f1 = 0.0, 0.0, 0.0
    for epoch in trange(int(args.epochs), desc="Epoch", position=0):
        for batch in tqdm(train_loader, desc="Train iter", leave=False, position=1):
            c_ids, q_ids, a_ids, start_positions, end_positions \
            = batch_to_device(batch, args.device)
            trainer.train(c_ids, q_ids, a_ids, start_positions, end_positions)
            
            str1 = 'Q REC : {:06.4f} A REC : {:06.4f}'
            str2 = 'ZQ KL : {:06.4f} ZA KL : {:06.4f} INFO : {:06.4f}'
            str1 = str1.format(float(trainer.loss_q_rec), float(trainer.loss_a_rec))
            str2 = str2.format(float(trainer.loss_zq_kl), float(trainer.loss_za_kl), float(trainer.loss_info))
            loss_log1.set_description_str(str1)
            loss_log2.set_description_str(str2)

        if epoch > 10:
            metric_dict, bleu, _ = eval_vae(epoch, args, trainer, eval_data)
            f1 = metric_dict["f1"]
            em = metric_dict["exact_match"]
            bleu = bleu * 100
            _str = '{}-th Epochs BLEU : {:02.2f} EM : {:02.2f} F1 : {:02.2f}'
            _str = _str.format(epoch, bleu, em, f1)
            eval_log.set_description_str(_str)
            if em > best_em:
                best_em = em
            if f1 > best_f1:
                best_f1 = f1
                trainer.save(os.path.join(args.model_dir, "best_f1_model.pt"))
            if bleu > best_bleu:
                best_bleu = bleu
                trainer.save(os.path.join(args.model_dir, "best_bleu_model.pt"))

            _str = 'BEST BLEU : {:02.2f} EM : {:02.2f} F1 : {:02.2f}'
            _str = _str.format(best_bleu, best_em, best_f1)
            best_eval_log.set_description_str(_str)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", default=1004, type=int)
    parser.add_argument('--debug', dest='debug', action='store_true')
    parser.add_argument('--train_dir', default='../data/squad/train-v1.1.json')
    parser.add_argument('--dev_dir', default='../data/squad/my_dev.json')
    
    parser.add_argument("--max_c_len", default=384, type=int, help="max context length")
    parser.add_argument("--max_q_len", default=64, type=int, help="max query length")

    parser.add_argument("--model_dir", default="../save/vae-checkpoint", type=str)
    parser.add_argument("--epochs", default=20, type=int)
    parser.add_argument("--lr", default=1e-3, type=float, help="lr")
    parser.add_argument("--batch_size", default=64, type=int, help="batch_size")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="weight decay")
    parser.add_argument("--clip", default=5.0, type=float, help="max grad norm")

    parser.add_argument("--bert_model", default='bert-base-uncased', type=str)
    parser.add_argument('--enc_nhidden', type=int, default=300)
    parser.add_argument('--enc_nlayers', type=int, default=1)
    parser.add_argument('--enc_dropout', type=float, default=0.2)
    parser.add_argument('--dec_a_nhidden', type=int, default=300)
    parser.add_argument('--dec_a_nlayers', type=int, default=1)
    parser.add_argument('--dec_a_dropout', type=float, default=0.2)
    parser.add_argument('--dec_q_nhidden', type=int, default=900)
    parser.add_argument('--dec_q_nlayers', type=int, default=2)
    parser.add_argument('--dec_q_dropout', type=float, default=0.3)
    parser.add_argument('--nzqdim', type=int, default=50)
    parser.add_argument('--nza', type=int, default=20)
    parser.add_argument('--nzadim', type=int, default=10)
    parser.add_argument('--lambda_kl', type=float, default=0.1)
    parser.add_argument('--lambda_info', type=float, default=1.0)

    args = parser.parse_args()

    if args.debug:
        args.model_dir = "./dummy"
    # set model dir
    model_dir = args.model_dir
    os.makedirs(model_dir, exist_ok=True)
    args.model_dir = os.path.abspath(model_dir)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    main(args)
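
# A usage sketch added for illustration (not part of the original script):
# a typical training run with the defaults above; data paths are illustrative.
#
#   python main.py \
#       --train_dir ../data/squad/train-v1.1.json \
#       --dev_dir ../data/squad/my_dev.json \
#       --model_dir ../save/vae-checkpoint \
#       --epochs 20 --batch_size 64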


================================================
FILE: vae/models.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch_scatter import scatter_max
# top_k_top_p_filtering is the sampling filter exported by transformers (v3+)
from transformers import BertModel, BertTokenizer, top_k_top_p_filtering


def return_mask_lengths(ids):
    mask = torch.sign(ids).float()
    lengths = torch.sum(mask, 1)
    return mask, lengths


def cal_attn(query, memories, mask):
    mask = (1.0 - mask.float()) * -10000.0
    attn_logits = torch.matmul(query, memories.transpose(-1, -2).contiguous())
    attn_logits = attn_logits + mask
    attn_weights = F.softmax(attn_logits, dim=-1)
    attn_outputs = torch.matmul(attn_weights, memories)
    return attn_outputs, attn_logits
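

# A minimal sketch added for illustration (the `_demo_cal_attn` helper is not
# part of the original file): masked positions get a -10000 bias, so their
# softmax weight is ~0 and attn_outputs mixes only unmasked memory slots.
def _demo_cal_attn():
    query = torch.randn(1, 1, 4)                 # [batch, 1, dim]
    memories = torch.randn(1, 3, 4)              # [batch, mem_len, dim]
    mask = torch.tensor([[[1.0, 1.0, 0.0]]])     # last memory slot is padding
    attn_outputs, attn_logits = cal_attn(query, memories, mask)
    assert F.softmax(attn_logits, dim=-1)[..., -1].item() < 1e-4
    return attn_outputs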


def gumbel_softmax(logits, tau=1, hard=False, eps=1e-20, dim=-1):
    # type: (Tensor, float, bool, float, int) -> Tensor

    gumbels = -(torch.empty_like(logits).exponential_() +
                eps).log()  # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Re-parametrization trick.
        ret = y_soft
    return ret
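

# A minimal sketch added for illustration (`_demo_gumbel_softmax` is not part
# of the original file): with hard=True the straight-through estimator returns
# exact one-hot samples while gradients flow through the soft relaxation.
def _demo_gumbel_softmax():
    logits = torch.randn(2, 5, requires_grad=True)
    y_hard = gumbel_softmax(logits, tau=1.0, hard=True)
    assert torch.allclose(y_hard.sum(dim=-1), torch.ones(2))  # one-hot rows
    (y_hard * torch.arange(5.0)).sum().backward()             # grads reach logits
    assert logits.grad is not None
    return y_hard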


class CategoricalKLLoss(nn.Module):
    def __init__(self):
        super(CategoricalKLLoss, self).__init__()

    def forward(self, P, Q):
        log_P = P.log()
        log_Q = Q.log()
        kl = (P * (log_P - log_Q)).sum(dim=-1).sum(dim=-1)
        return kl.mean(dim=0)


class GaussianKLLoss(nn.Module):
    def __init__(self):
        super(GaussianKLLoss, self).__init__()

    def forward(self, mu1, logvar1, mu2, logvar2):
        numerator = logvar1.exp() + torch.pow(mu1 - mu2, 2)
        fraction = torch.div(numerator, (logvar2.exp()))
        kl = 0.5 * torch.sum(logvar2 - logvar1 + fraction - 1, dim=1)
        return kl.mean(dim=0)
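

# A minimal check added for illustration (`_demo_gaussian_kl` is not part of
# the original file): the closed form above should agree with
# torch.distributions' KL between diagonal Gaussians.
def _demo_gaussian_kl():
    from torch.distributions import Normal, kl_divergence
    mu1, logvar1 = torch.randn(4, 3), torch.randn(4, 3)
    mu2, logvar2 = torch.randn(4, 3), torch.randn(4, 3)
    ours = GaussianKLLoss()(mu1, logvar1, mu2, logvar2)
    ref = kl_divergence(Normal(mu1, (0.5 * logvar1).exp()),
                        Normal(mu2, (0.5 * logvar2).exp())).sum(dim=1).mean()
    assert torch.allclose(ours, ref, atol=1e-5)
    return ours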


class Embedding(nn.Module):
    def __init__(self, bert_model):
        super(Embedding, self).__init__()
        bert_embeddings = BertModel.from_pretrained(bert_model).embeddings
        self.word_embeddings = bert_embeddings.word_embeddings
        self.token_type_embeddings = bert_embeddings.token_type_embeddings
        self.position_embeddings = bert_embeddings.position_embeddings
        self.LayerNorm = bert_embeddings.LayerNorm
        self.dropout = bert_embeddings.dropout

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if position_ids is None:
            seq_length = input_ids.size(1)
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = words_embeddings + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings


class ContextualizedEmbedding(nn.Module):
    def __init__(self, bert_model):
        super(ContextualizedEmbedding, self).__init__()
        bert = BertModel.from_pretrained(bert_model)
        self.embedding = bert.embeddings
        self.encoder = bert.encoder
        self.num_hidden_layers = bert.config.num_hidden_layers

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        seq_length = input_ids.size(1)
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        extended_attention_mask = attention_mask.unsqueeze(
            1).unsqueeze(2).float()
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        head_mask = [None] * self.num_hidden_layers

        embedding_output = self.embedding(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]

        return sequence_output


class CustomLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional=False):
        super(CustomLSTM, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.dropout = nn.Dropout(dropout)
        # nn.LSTM warns when inter-layer dropout is set with a single layer,
        # so disable it there; the output dropout above still applies
        if dropout > 0.0 and num_layers == 1:
            dropout = 0.0

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, dropout=dropout,
                            bidirectional=bidirectional, batch_first=True)

    def forward(self, inputs, input_lengths, state=None):
        _, total_length, _ = inputs.size()

        input_packed = pack_padded_sequence(inputs, input_lengths,
                                            batch_first=True, enforce_sorted=False)

        self.lstm.flatten_parameters()
        output_packed, state = self.lstm(input_packed, state)

        output = pad_packed_sequence(
            output_packed, batch_first=True, total_length=total_length)[0]
        output = self.dropout(output)

        return output, state
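

# A minimal sketch added for illustration (`_demo_custom_lstm` is not part of
# the original file): packing lets variable-length rows share one LSTM call;
# padded steps are skipped and the output is re-padded to the input length.
def _demo_custom_lstm():
    lstm = CustomLSTM(input_size=8, hidden_size=6, num_layers=1,
                      dropout=0.0, bidirectional=True)
    inputs = torch.randn(2, 5, 8)        # [batch, seq, emb]
    lengths = torch.tensor([5, 3])       # true lengths per row
    output, (h, c) = lstm(inputs, lengths)
    assert output.shape == (2, 5, 12)    # 2 * hidden_size for bidirectional
    return output, h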


class PosteriorEncoder(nn.Module):
    def __init__(self, embedding, emsize,
                 nhidden, nlayers,
                 nzqdim, nza, nzadim,
                 dropout=0.0):
        super(PosteriorEncoder, self).__init__()

        self.embedding = embedding
        self.nhidden = nhidden
        self.nlayers = nlayers
        self.nzqdim = nzqdim
        self.nza = nza
        self.nzadim = nzadim

        self.encoder = CustomLSTM(input_size=emsize,
                                  hidden_size=nhidden,
                                  num_layers=nlayers,
                                  dropout=dropout,
                                  bidirectional=True)

        self.question_attention = nn.Linear(2 * nhidden, 2 * nhidden)
        self.context_attention = nn.Linear(2 * nhidden, 2 * nhidden)
        self.zq_attention = nn.Linear(nzqdim, 2 * nhidden)

        self.zq_linear = nn.Linear(4 * 2 * nhidden, 2 * nzqdim)
        self.za_linear = nn.Linear(nzqdim + 2 * 2 * nhidden, nza * nzadim)

    def forward(self, c_ids, q_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        q_mask, q_lengths = return_mask_lengths(q_ids)

        # question enc
        q_embeddings = self.embedding(q_ids)
        q_hs, q_state = self.encoder(q_embeddings, q_lengths)
        q_h = q_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        q_h = q_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)

        # context enc
        c_embeddings = self.embedding(c_ids)
        c_hs, c_state = self.encoder(c_embeddings, c_lengths)
        c_h = c_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        c_h = c_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)

        # context and answer enc (answer span ids double as token-type ids)
        c_a_embeddings = self.embedding(c_ids, a_ids, None)
        c_a_hs, c_a_state = self.encoder(c_a_embeddings, c_lengths)
        c_a_h = c_a_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        c_a_h = c_a_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)

        # attention: context attended by question
        mask = c_mask.unsqueeze(1)
        c_attned_by_q, _ = cal_attn(self.question_attention(q_h).unsqueeze(1),
                                    c_hs,
                                    mask)
        c_attned_by_q = c_attned_by_q.squeeze(1)

        # attention: question attended by context
        mask = q_mask.unsqueeze(1)
        q_attned_by_c, _ = cal_attn(self.context_attention(c_h).unsqueeze(1),
                                    q_hs,
                                    mask)
        q_attned_by_c = q_attned_by_c.squeeze(1)

        h = torch.cat([q_h, q_attned_by_c, c_h, c_attned_by_q], dim=-1)

        zq_mu, zq_logvar = torch.split(self.zq_linear(h), self.nzqdim, dim=1)
        zq = zq_mu + torch.randn_like(zq_mu) * torch.exp(0.5 * zq_logvar)

        # attention zq, c_a
        mask = c_mask.unsqueeze(1)
        c_a_attned_by_zq, _ = cal_attn(self.zq_attention(zq).unsqueeze(1),
                                       c_a_hs,
                                       mask)
        c_a_attned_by_zq = c_a_attned_by_zq.squeeze(1)

        h = torch.cat([zq, c_a_attned_by_zq, c_a_h], dim=-1)

        za_logits = self.za_linear(h).view(-1, self.nza, self.nzadim)
        za_prob = F.softmax(za_logits, dim=-1)
        za = gumbel_softmax(za_logits, hard=True)

        return zq_mu, zq_logvar, zq, za_prob, za


class PriorEncoder(nn.Module):
    def __init__(self, embedding, emsize,
                 nhidden, nlayers,
                 nzqdim, nza, nzadim,
                 dropout=0):
        super(PriorEncoder, self).__init__()

        self.embedding = embedding
        self.nhidden = nhidden
        self.nlayers = nlayers
        self.nzqdim = nzqdim
        self.nza = nza
        self.nzadim = nzadim

        self.context_encoder = CustomLSTM(input_size=emsize,
                                          hidden_size=nhidden,
                                          num_layers=nlayers,
                                          dropout=dropout,
                                          bidirectional=True)

        self.zq_attention = nn.Linear(nzqdim, 2 * nhidden)

        self.zq_linear = nn.Linear(2 * nhidden, 2 * nzqdim)
        self.za_linear = nn.Linear(nzqdim + 2 * 2 * nhidden, nza * nzadim)

    def forward(self, c_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)

        c_embeddings = self.embedding(c_ids)
        c_hs, c_state = self.context_encoder(c_embeddings, c_lengths)
        c_h = c_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        c_h = c_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)

        zq_mu, zq_logvar = torch.split(self.zq_linear(c_h), self.nzqdim, dim=1)
        zq = zq_mu + torch.randn_like(zq_mu) * torch.exp(0.5 * zq_logvar)

        mask = c_mask.unsqueeze(1)
        c_attned_by_zq, _ = cal_attn(self.zq_attention(zq).unsqueeze(1),
                                     c_hs,
                                     mask)
        c_attned_by_zq = c_attned_by_zq.squeeze(1)

        h = torch.cat([zq, c_attned_by_zq, c_h], dim=-1)

        za_logits = self.za_linear(h).view(-1, self.nza, self.nzadim)
        za_prob = F.softmax(za_logits, dim=-1)
        za = gumbel_softmax(za_logits, hard=True)

        return zq_mu, zq_logvar, zq, za_prob, za

    def interpolation(self, c_ids, zq):

        c_mask, c_lengths = return_mask_lengths(c_ids)

        c_embeddings = self.embedding(c_ids)
        c_hs, c_state = self.context_encoder(c_embeddings, c_lengths)
        c_h = c_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        c_h = c_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)

        mask = c_mask.unsqueeze(1)
        c_attned_by_zq, _ = cal_attn(
            self.zq_attention(zq).unsqueeze(1), c_hs, mask)
        c_attned_by_zq = c_attned_by_zq.squeeze(1)

        h = torch.cat([zq, c_attned_by_zq, c_h], dim=-1)

        za_logits = self.za_linear(h).view(-1, self.nza, self.nzadim)
        za = gumbel_softmax(za_logits, hard=True)

        return za


class AnswerDecoder(nn.Module):
    def __init__(self, embedding, emsize,
                 nhidden, nlayers,
                 dropout=0.0):
        super(AnswerDecoder, self).__init__()

        self.embedding = embedding

        self.context_lstm = CustomLSTM(input_size=4 * emsize,
                                       hidden_size=nhidden,
                                       num_layers=nlayers,
                                       dropout=dropout,
                                       bidirectional=True)

        self.start_linear = nn.Linear(2 * nhidden, 1)
        self.end_linear = nn.Linear(2 * nhidden, 1)
        self.ls = nn.LogSoftmax(dim=1)

    def forward(self, init_state, c_ids):
        _, max_c_len = c_ids.size()
        c_mask, c_lengths = return_mask_lengths(c_ids)

        H = self.embedding(c_ids, c_mask)
        U = init_state.unsqueeze(1).repeat(1, max_c_len, 1)
        G = torch.cat([H, U, H * U, torch.abs(H - U)], dim=-1)
        M, _ = self.context_lstm(G, c_lengths)

        start_logits = self.start_linear(M).squeeze(-1)
        end_logits = self.end_linear(M).squeeze(-1)

        start_end_mask = (c_mask == 0)
        masked_start_logits = start_logits.masked_fill(
            start_end_mask, -10000.0)
        masked_end_logits = end_logits.masked_fill(start_end_mask, -10000.0)

        return masked_start_logits, masked_end_logits

    def generate(self, init_state, c_ids):
        start_logits, end_logits = self.forward(init_state, c_ids)
        c_mask, _ = return_mask_lengths(c_ids)
        batch_size, max_c_len = c_ids.size()

        mask = torch.matmul(c_mask.unsqueeze(2).float(),
                            c_mask.unsqueeze(1).float())
        mask = torch.triu(mask) == 0
        score = (self.ls(start_logits).unsqueeze(2)
                 + self.ls(end_logits).unsqueeze(1))
        score = score.masked_fill(mask, -10000.0)
        score, start_positions = score.max(dim=1)
        score, end_positions = score.max(dim=1)
        start_positions = torch.gather(start_positions,
                                       1,
                                       end_positions.view(-1, 1)).squeeze(1)

        idxes = torch.arange(0, max_c_len, dtype=torch.long)
        idxes = idxes.unsqueeze(0).to(
            start_logits.device).repeat(batch_size, 1)

        start_positions = start_positions.unsqueeze(1)
        start_mask = (idxes >= start_positions).long()
        end_positions = end_positions.unsqueeze(1)
        end_mask = (idxes <= end_positions).long()
        a_ids = start_mask + end_mask - 1

        return a_ids, start_positions.squeeze(1), end_positions.squeeze(1)
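

# A minimal sketch added for illustration (`_demo_span_argmax` is not part of
# the original file) of the span selection used in AnswerDecoder.generate: an
# upper-triangular mask enforces start <= end, and two successive max() calls
# plus a gather recover the argmax (start, end) pair.
def _demo_span_argmax():
    start_logp = torch.log_softmax(torch.tensor([[0.1, 2.0, 0.3]]), dim=1)
    end_logp = torch.log_softmax(torch.tensor([[0.2, 0.1, 3.0]]), dim=1)
    score = start_logp.unsqueeze(2) + end_logp.unsqueeze(1)  # [1, 3, 3]
    score = score.masked_fill(torch.triu(torch.ones(3, 3)) == 0, -10000.0)
    best_per_end, start_idx = score.max(dim=1)   # best start for each end
    _, end_idx = best_per_end.max(dim=1)         # best end overall
    start_idx = start_idx.gather(1, end_idx.view(-1, 1)).squeeze(1)
    assert start_idx.item() == 1 and end_idx.item() == 2
    return start_idx, end_idx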


class ContextEncoderforQG(nn.Module):
    def __init__(self, embedding, emsize,
                 nhidden, nlayers,
                 dropout=0.0):
        super(ContextEncoderforQG, self).__init__()
        self.embedding = embedding
        self.context_lstm = CustomLSTM(input_size=emsize,
                                       hidden_size=nhidden,
                                       num_layers=nlayers,
                                       dropout=dropout,
                                       bidirectional=True)
        self.context_linear = nn.Linear(2 * nhidden, 2 * nhidden)
        self.fusion = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)
        self.gate = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)

    def forward(self, c_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        c_embeddings = self.embedding(c_ids, c_mask, a_ids)
        c_outputs, _ = self.context_lstm(c_embeddings, c_lengths)
        # attention
        mask = torch.matmul(c_mask.unsqueeze(2), c_mask.unsqueeze(1))
        c_attned_by_c, _ = cal_attn(self.context_linear(c_outputs),
                                    c_outputs,
                                    mask)
        c_concat = torch.cat([c_outputs, c_attned_by_c], dim=2)
        c_fused = self.fusion(c_concat).tanh()
        c_gate = self.gate(c_concat).sigmoid()
        c_outputs = c_gate * c_fused + (1 - c_gate) * c_outputs
        return c_outputs


class QuestionDecoder(nn.Module):
    def __init__(self, sos_id, eos_id,
                 embedding, contextualized_embedding, emsize,
                 nhidden, ntokens, nlayers,
                 dropout=0.0,
                 max_q_len=64):
        super(QuestionDecoder, self).__init__()

        self.sos_id = sos_id
        self.eos_id = eos_id
        self.emsize = emsize
        self.embedding = embedding
        self.nhidden = nhidden
        self.ntokens = ntokens
        self.nlayers = nlayers
        # this max_len includes the SOS and EOS tokens
        self.max_q_len = max_q_len

        self.context_lstm = ContextEncoderforQG(contextualized_embedding, emsize,
                                                nhidden // 2, nlayers, dropout)

        self.question_lstm = CustomLSTM(input_size=emsize,
                                        hidden_size=nhidden,
                                        num_layers=nlayers,
                                        dropout=dropout,
                                        bidirectional=False)

        self.question_linear = nn.Linear(nhidden, nhidden)

        self.concat_linear = nn.Sequential(nn.Linear(2*nhidden, 2*nhidden),
                                           nn.Dropout(dropout),
                                           nn.Linear(2*nhidden, 2*emsize))

        self.logit_linear = nn.Linear(emsize, ntokens, bias=False)

        # fix output word matrix
        self.logit_linear.weight = embedding.word_embeddings.weight
        for param in self.logit_linear.parameters():
            param.requires_grad = False

        self.discriminator = nn.Bilinear(emsize, nhidden, 1)

    def postprocess(self, q_ids):
        eos_mask = q_ids == self.eos_id
        no_eos_idx_sum = (eos_mask.sum(dim=1) == 0).long() * \
            (self.max_q_len - 1)
        eos_mask = eos_mask.cpu().numpy()
        q_lengths = np.argmax(eos_mask, axis=1) + 1
        q_lengths = torch.tensor(q_lengths).to(
            q_ids.device).long() + no_eos_idx_sum
        batch_size, max_len = q_ids.size()
        idxes = torch.arange(0, max_len).to(q_ids.device)
        idxes = idxes.unsqueeze(0).repeat(batch_size, 1)
        q_mask = (idxes < q_lengths.unsqueeze(1))
        q_ids = q_ids.long() * q_mask.long()
        return q_ids

    def forward(self, init_state, c_ids, q_ids, a_ids):
        batch_size, max_q_len = q_ids.size()

        c_outputs = self.context_lstm(c_ids, a_ids)

        c_mask, _ = return_mask_lengths(c_ids)
        q_mask, q_lengths = return_mask_lengths(q_ids)

        # question dec
        q_embeddings = self.embedding(q_ids)
        q_outputs, _ = self.question_lstm(q_embeddings, q_lengths, init_state)

        # attention
        mask = torch.matmul(q_mask.unsqueeze(2), c_mask.unsqueeze(1))
        c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs),
                                              c_outputs,
                                              mask)

        # gen logits
        q_concated = torch.cat([q_outputs, c_attned_by_q], dim=2)
        q_concated = self.concat_linear(q_concated)
        q_maxouted, _ = q_concated.view(
            batch_size, max_q_len, self.emsize, 2).max(dim=-1)
        gen_logits = self.logit_linear(q_maxouted)

        # copy logits
        bq = batch_size * max_q_len
        c_ids = c_ids.unsqueeze(1).repeat(
            1, max_q_len, 1).view(bq, -1).contiguous()
        attn_logits = attn_logits.view(bq, -1).contiguous()
        copy_logits = torch.zeros(bq, self.ntokens).to(c_ids.device)
        copy_logits = copy_logits - 10000.0
        copy_logits, _ = scatter_max(attn_logits, c_ids, out=copy_logits)
        copy_logits = copy_logits.masked_fill(copy_logits == -10000.0, 0)
        copy_logits = copy_logits.view(batch_size, max_q_len, -1).contiguous()

        logits = gen_logits + copy_logits

        # mutual information btw answer and question
        a_emb = c_outputs * a_ids.float().unsqueeze(2)
        a_mean_emb = torch.sum(a_emb, 1) / a_ids.sum(1).unsqueeze(1).float()
        fake_a_mean_emb = torch.cat([a_mean_emb[-1].unsqueeze(0),
                                     a_mean_emb[:-1]], dim=0)

        q_emb = q_maxouted * q_mask.unsqueeze(2)
        q_mean_emb = torch.sum(q_emb, 1) / q_lengths.unsqueeze(1).float()
        fake_q_mean_emb = torch.cat([q_mean_emb[-1].unsqueeze(0),
                                     q_mean_emb[:-1]], dim=0)

        bce_loss = nn.BCEWithLogitsLoss()
        true_logits = self.discriminator(q_mean_emb, a_mean_emb)
        true_labels = torch.ones_like(true_logits)

        fake_a_logits = self.discriminator(q_mean_emb, fake_a_mean_emb)
        fake_q_logits = self.discriminator(fake_q_mean_emb, a_mean_emb)
        fake_logits = torch.cat([fake_a_logits, fake_q_logits], dim=0)
        fake_labels = torch.zeros_like(fake_logits)

        true_loss = bce_loss(true_logits, true_labels)
        fake_loss = 0.5 * bce_loss(fake_logits, fake_labels)
        loss_info = 0.5 * (true_loss + fake_loss)

        return logits, loss_info

    def generate(self, init_state, c_ids, a_ids):
        c_mask, _ = return_mask_lengths(c_ids)
        c_outputs = self.context_lstm(c_ids, a_ids)

        batch_size = c_ids.size(0)

        q_ids = torch.LongTensor([self.sos_id] * batch_size).unsqueeze(1)
        q_ids = q_ids.to(c_ids.device)
        token_type_ids = torch.zeros_like(q_ids)
        position_ids = torch.zeros_like(q_ids)
        q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)

        state = init_state

        # unroll
        all_q_ids = list()
        all_q_ids.append(q_ids)
        for _ in range(self.max_q_len - 1):
            position_ids = position_ids + 1
            q_outputs, state = self.question_lstm.lstm(q_embeddings, state)

            # attention
            mask = c_mask.unsqueeze(1)
            c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs),
                                                  c_outputs,
                                                  mask)

            # gen logits
            q_concated = torch.cat([q_outputs, c_attned_by_q], dim=2)
            q_concated = self.concat_linear(q_concated)
            q_maxouted, _ = q_concated.view(
                batch_size, 1, self.emsize, 2).max(dim=-1)
            gen_logits = self.logit_linear(q_maxouted)

            # copy logits
            attn_logits = attn_logits.squeeze(1)
            copy_logits = torch.zeros(
                batch_size, self.ntokens).to(c_ids.device)
            copy_logits = copy_logits - 10000.0
            copy_logits, _ = scatter_max(attn_logits, c_ids, out=copy_logits)
            copy_logits = copy_logits.masked_fill(copy_logits == -10000.0, 0)

            logits = gen_logits + copy_logits.unsqueeze(1)

            q_ids = torch.argmax(logits, 2)
            all_q_ids.append(q_ids)

            q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)

        q_ids = torch.cat(all_q_ids, 1)
        q_ids = self.postprocess(q_ids)

        return q_ids

    def sample(self, init_state, c_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        c_outputs = self.context_lstm(c_ids, a_ids)

        batch_size = c_ids.size(0)

        q_ids = torch.LongTensor([self.sos_id] * batch_size).unsqueeze(1)
        q_ids = q_ids.to(c_ids.device)
        token_type_ids = torch.zeros_like(q_ids)
        position_ids = torch.zeros_like(q_ids)
        q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)

        state = init_state

        # unroll
        all_q_ids = list()
        all_q_ids.append(q_ids)
        for _ in range(self.max_q_len - 1):
            position_ids = position_ids + 1
            q_outputs, state = self.question_lstm.lstm(q_embeddings, state)

            # attention
            mask = c_mask.unsqueeze(1)
            c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs),
                                                  c_outputs,
                                                  mask)

            # gen logits
            q_concated = torch.cat([q_outputs, c_attned_by_q], dim=2)
            q_concated = self.concat_linear(q_concated)
            q_maxouted, _ = q_concated.view(batch_size, 1, self.emsize, 2).max(dim=-1)
            gen_logits = self.logit_linear(q_maxouted)

            # copy logits
            attn_logits = attn_logits.squeeze(1)
            copy_logits = torch.zeros(batch_size, self.ntokens).to(c_ids.device)
            copy_logits = copy_logits - 10000.0
            copy_logits, _ = scatter_max(attn_logits, c_ids, out=copy_logits)
            copy_logits = copy_logits.masked_fill(copy_logits == -10000.0, 0)

            logits = gen_logits + copy_logits.unsqueeze(1)
            logits = logits.squeeze(1)
            # keep only top-k / nucleus (top-p) candidates before sampling
            logits = top_k_top_p_filtering(logits, top_k=2, top_p=0.8)
            probs = F.softmax(logits, dim=-1)
            q_ids = torch.multinomial(probs, num_samples=1)  # [b,1]
            all_q_ids.append(q_ids)

            q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)

        q_ids = torch.cat(all_q_ids, 1)
        q_ids = self.postprocess(q_ids)

        return q_ids
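

# A minimal sketch added for illustration (`_demo_copy_logits` is not part of
# the original file) of the copy mechanism above: scatter_max projects
# attention logits over context positions onto the vocabulary, keeping the
# best logit per token id so context tokens can be copied directly.
def _demo_copy_logits():
    ntokens = 10
    c_ids = torch.tensor([[3, 5, 3]])             # token 3 appears twice
    attn_logits = torch.tensor([[0.2, 1.5, 0.9]])
    out = torch.full((1, ntokens), -10000.0)
    copy_logits, _ = scatter_max(attn_logits, c_ids, out=out)
    copy_logits = copy_logits.masked_fill(copy_logits == -10000.0, 0)
    assert torch.isclose(copy_logits[0, 3], torch.tensor(0.9))  # max of 0.2, 0.9
    assert torch.isclose(copy_logits[0, 5], torch.tensor(1.5))
    return copy_logits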


class DiscreteVAE(nn.Module):
    def __init__(self, args):
        super(DiscreteVAE, self).__init__()
        tokenizer = BertTokenizer.from_pretrained(args.bert_model)
        padding_idx = tokenizer.vocab['[PAD]']
        sos_id = tokenizer.vocab['[CLS]']
        eos_id = tokenizer.vocab['[SEP]']
        ntokens = len(tokenizer.vocab)

        bert_model = args.bert_model
        if "large" in bert_model:
            emsize = 1024
        else:
            emsize = 768

        enc_nhidden = args.enc_nhidden
        enc_nlayers = args.enc_nlayers
        enc_dropout = args.enc_dropout
        dec_a_nhidden = args.dec_a_nhidden
        dec_a_nlayers = args.dec_a_nlayers
        dec_a_dropout = args.dec_a_dropout
        self.dec_q_nhidden = dec_q_nhidden = args.dec_q_nhidden
        self.dec_q_nlayers = dec_q_nlayers = args.dec_q_nlayers
        dec_q_dropout = args.dec_q_dropout
        self.nzqdim = nzqdim = args.nzqdim
        self.nza = nza = args.nza
        self.nzadim = nzadim = args.nzadim

        self.lambda_kl = args.lambda_kl
        self.lambda_info = args.lambda_info

        max_q_len = args.max_q_len

        embedding = Embedding(bert_model)
        contextualized_embedding = ContextualizedEmbedding(bert_model)
        # freeze embedding
        for param in embedding.parameters():
            param.requires_grad = False
        for param in contextualized_embedding.parameters():
            param.requires_grad = False

        self.posterior_encoder = PosteriorEncoder(embedding, emsize,
                                                  enc_nhidden, enc_nlayers,
                                                  nzqdim, nza, nzadim,
                                                  enc_dropout)

        self.prior_encoder = PriorEncoder(embedding, emsize,
                                          enc_nhidden, enc_nlayers,
                                          nzqdim, nza, nzadim, enc_dropout)

        self.answer_decoder = AnswerDecoder(contextualized_embedding, emsize,
                                            dec_a_nhidden, dec_a_nlayers,
                                            dec_a_dropout)

        self.question_decoder = QuestionDecoder(sos_id, eos_id,
                                                embedding, contextualized_embedding, emsize,
                                                dec_q_nhidden, ntokens, dec_q_nlayers,
                                                dec_q_dropout,
                                                max_q_len)

        self.q_h_linear = nn.Linear(nzqdim, dec_q_nlayers * dec_q_nhidden)
        self.q_c_linear = nn.Linear(nzqdim, dec_q_nlayers * dec_q_nhidden)
        self.a_linear = nn.Linear(nza * nzadim, emsize, False)

        self.q_rec_criterion = nn.CrossEntropyLoss(ignore_index=padding_idx)
        self.gaussian_kl_criterion = GaussianKLLoss()
        self.categorical_kl_criterion = CategoricalKLLoss()

    def return_init_state(self, zq, za):

        q_init_h = self.q_h_linear(zq)
        q_init_c = self.q_c_linear(zq)
        q_init_h = q_init_h.view(-1, self.dec_q_nlayers,
                                 self.dec_q_nhidden).transpose(0, 1).contiguous()
        q_init_c = q_init_c.view(-1, self.dec_q_nlayers,
                                 self.dec_q_nhidden).transpose(0, 1).contiguous()
        q_init_state = (q_init_h, q_init_c)

        za_flatten = za.view(-1, self.nza * self.nzadim)
        a_init_state = self.a_linear(za_flatten)

        return q_init_state, a_init_state


    def forward(self, c_ids, q_ids, a_ids, start_positions, end_positions):

        posterior_zq_mu, posterior_zq_logvar, posterior_zq, \
            posterior_za_prob, posterior_za \
            = self.posterior_encoder(c_ids, q_ids, a_ids)

        prior_zq_mu, prior_zq_logvar, _, \
            prior_za_prob, _ \
            = self.prior_encoder(c_ids)

        q_init_state, a_init_state = self.return_init_state(
            posterior_zq, posterior_za)

        # answer decoding
        start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)
        # question decoding
        q_logits, loss_info = self.question_decoder(
            q_init_state, c_ids, q_ids, a_ids)

        # q rec loss
        loss_q_rec = self.q_rec_criterion(q_logits[:, :-1, :].transpose(1, 2).contiguous(),
                                          q_ids[:, 1:])

        # a rec loss
        max_c_len = c_ids.size(1)
        a_rec_criterion = nn.CrossEntropyLoss(ignore_index=max_c_len)
        start_positions.clamp_(0, max_c_len)
        end_positions.clamp_(0, max_c_len)
        loss_start_a_rec = a_rec_criterion(start_logits, start_positions)
        loss_end_a_rec = a_rec_criterion(end_logits, end_positions)
        loss_a_rec = 0.5 * (loss_start_a_rec + loss_end_a_rec)

        # kl loss
        loss_zq_kl = self.gaussian_kl_criterion(posterior_zq_mu,
                                                posterior_zq_logvar,
                                                prior_zq_mu,
                                                prior_zq_logvar)

        loss_za_kl = self.categorical_kl_criterion(posterior_za_prob,
                                                   prior_za_prob)

        loss_kl = self.lambda_kl * (loss_zq_kl + loss_za_kl)
        loss_info = self.lambda_info * loss_info

        loss = loss_q_rec + loss_a_rec + loss_kl + loss_info

        return loss, \
            loss_q_rec, loss_a_rec, \
            loss_zq_kl, loss_za_kl, \
            loss_info

    def generate(self, zq, za, c_ids):
        q_init_state, a_init_state = self.return_init_state(zq, za)

        a_ids, start_positions, end_positions = self.answer_decoder.generate(
            a_init_state, c_ids)

        q_ids = self.question_decoder.generate(q_init_state, c_ids, a_ids)

        return q_ids, start_positions, end_positions

    def return_answer_logits(self, zq, za, c_ids):
        _, a_init_state = self.return_init_state(zq, za)

        start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)

        return start_logits, end_logits
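

# An end-to-end sketch added for illustration (mirroring vae/generate_qa.py):
# sample latents from the prior and decode a QA pair for a batch of contexts.
#
#   checkpoint = torch.load("best_f1_model.pt", map_location="cpu")
#   vae = DiscreteVAE(checkpoint["args"])
#   vae.load_state_dict(checkpoint["state_dict"])
#   vae.eval()
#   with torch.no_grad():
#       _, _, zq, _, za = vae.prior_encoder(c_ids)   # c_ids: [batch, seq] LongTensor
#       q_ids, start_positions, end_positions = vae.generate(zq, za, c_ids)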


================================================
FILE: vae/qgevalcap/.gitignore
================================================
*.pyc


================================================
FILE: vae/qgevalcap/README.md
================================================
## evaluation scripts

./eval.py --out_file \<path to output file\>


================================================
FILE: vae/qgevalcap/bleu/.gitignore
================================================
*.pyc


================================================
FILE: vae/qgevalcap/bleu/LICENSE
================================================
Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.


================================================
FILE: vae/qgevalcap/bleu/bleu.py
================================================
#!/usr/bin/env python
#
# File Name : bleu.py
#
# Description : Wrapper for BLEU scorer.
#
# Creation Date : 06-01-2015
# Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
# Authors : Hao Fang <hfang@uw.edu> and Tsung-Yi Lin <tl483@cornell.edu>

# bleu_scorer.py
# David Chiang <chiang@isi.edu>

# Copyright (c) 2004-2006 University of Maryland. All rights
# reserved. Do not redistribute without permission from the
# author. Not for commercial use.

# Modified by:
# Hao Fang <hfang@uw.edu>
# Tsung-Yi Lin <tl483@cornell.edu>

'''Provides:
cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
'''

import copy
import sys, math, re
from collections import defaultdict

def precook(s, n=4, out=False):
    """Takes a string as input and returns an object that can be given to
    either cook_refs or cook_test. This is optional: cook_refs and cook_test
    can take string arguments as well."""
    words = s.split()
    counts = defaultdict(int)
    for k in range(1,n+1):
        for i in range(len(words)-k+1):
            ngram = tuple(words[i:i+k])
            counts[ngram] += 1
    return (len(words), counts)
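
# A minimal sketch added for illustration (`_demo_precook` is not part of the
# original file): precook returns the sentence length and the n-gram counts
# that the scorer consumes.
def _demo_precook():
    length, counts = precook("the cat sat on the mat", n=2)
    assert length == 6
    assert counts[("the",)] == 2          # unigram "the" occurs twice
    assert counts[("the", "cat")] == 1    # one matching bigram
    return counts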

def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them.'''

    reflen = []
    maxcounts = {}
    for ref in refs:
        rl, counts = precook(ref, n)
        reflen.append(rl)
        for (ngram,count) in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram,0), count)

    # Calculate effective reference sentence length.
    if eff == "shortest":
        reflen = min(reflen)
    elif eff == "average":
        reflen = float(sum(reflen))/len(reflen)

    ## lhuang: N.B.: leave reflen computaiton to the very end!!

    ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)

    return (reflen, maxcounts)

def cook_test(test, reflen, refmaxcounts, eff=None, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.'''

    testlen, counts = precook(test, n, True)

    result = {}

    # Calculate effective reference sentence length.

    if eff == "closest":
        result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
    else: ## i.e., "average" or "shortest" or None
        result["reflen"] = reflen

    result["testlen"] = testlen

    result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]

    result['correct'] = [0]*n
    for (ngram, count) in counts.items():
        result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)

    return result

class BleuScorer(object):
    """Bleu scorer.
    """

    __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
    # special_reflen is used in oracle (proportional effective ref len for a node).

    def copy(self):
        ''' copy the refs.'''
        new = BleuScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        new._score = None
        return new

    def __init__(self, test=None, refs=None, n=4, special_reflen=None):
        ''' singular instance '''

        self.n = n
        self.crefs = []
        self.ctest = []
        self.cook_append(test, refs)
        self.special_reflen = special_reflen

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''

        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                cooked_test = cook_test(test, self.crefs[-1][0],self.crefs[-1][1])
                self.ctest.append(cooked_test) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

        self._score = None ## need to recompute

    def ratio(self, option=None):
        self.compute_score(option=option)
        return self._ratio

    def score_ratio(self, option=None):
        '''return (bleu, len_ratio) pair'''
        return (self.fscore(option=option), self.ratio(option=option))

    def score_ratio_str(self, option=None):
        return "%.4f (%.2f)" % self.score_ratio(option)

    def reflen(self, option=None):
        self.compute_score(option=option)
        return self._reflen

    def testlen(self, option=None):
        self.compute_score(option=option)
        return self._testlen

    def retest(self, new_test):
        if type(new_test) is str:
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = []
        for t, rs in zip(new_test, self.crefs):
            self.ctest.append(cook_test(t, rs[0], rs[1]))
        self._score = None

        return self

    def rescore(self, new_test):
        ''' replace test(s) with new test(s), and returns the new score.'''

        return self.retest(new_test).compute_score()

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new BleuScorer instances
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None ## need to recompute

        return self

    def compatible(self, other):
        return isinstance(other, BleuScorer) and self.n == other.n

    def single_reflen(self, option="average"):
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):

        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens))/len(reflens)
        elif option == "closest":
            reflen = min((abs(l-testlen), l) for l in reflens)[1]
        else:
            assert False, "unsupported reflen option %s" % option

        return reflen

    def recompute_score(self, option=None, verbose=0):
        self._score = None
        return self.compute_score(option, verbose)

    def compute_score(self, option=None, verbose=0):
        n = self.n
        small = 1e-9
        tiny = 1e-15 ## so that if guess is 0 still return 0
        bleu_list = [[] for _ in range(n)]

        if self._score is not None:
            return self._score

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}

        # for each sentence
        for comps in self.ctest:
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None: ## need computation
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen

            for key in ['guess','correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # append per image bleu score
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        /(float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1./(k+1)))
            ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
            if ratio < 1:
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1/ratio)

            if verbose > 1:
                print (comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1./(k+1)))
        ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
        if ratio < 1:
            for k in range(n):
                bleus[k] *= math.exp(1 - 1/ratio)

        if verbose > 0:
            print (totalcomps)
            print ("ratio:", ratio)

        self._score = bleus
        return self._score, bleu_list



class Bleu:
    def __init__(self, n=4):
        # by default, compute BLEU score up to n = 4
        self._n = n
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res):

        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        bleu_scorer = BleuScorer(n=self._n)
        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) >= 1)

            bleu_scorer += (hypo[0], ref)

        #score, scores = bleu_scorer.compute_score(option='shortest')
        score, scores = bleu_scorer.compute_score(option='closest', verbose=0)
        #score, scores = bleu_scorer.compute_score(option='average', verbose=1)

        # return (bleu, bleu_info)
        return score, scores

    def method(self):
        return "Bleu"


================================================
FILE: vae/qgevalcap/bleu/bleu_scorer.py
================================================
#!/usr/bin/env python

# bleu_scorer.py
# David Chiang <chiang@isi.edu>

# Copyright (c) 2004-2006 University of Maryland. All rights
# reserved. Do not redistribute without permission from the
# author. Not for commercial use.

# Modified by: 
# Hao Fang <hfang@uw.edu>
# Tsung-Yi Lin <tl483@cornell.edu>

'''Provides:
cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
'''

import copy
import sys, math, re
from collections import defaultdict

def precook(s, n=4, out=False):
    """Takes a string as input and returns an object that can be given to
    either cook_refs or cook_test. This is optional: cook_refs and cook_test
    can take string arguments as well."""
    words = s.split()
    counts = defaultdict(int)
    for k in range(1,n+1):
        for i in range(len(words)-k+1):
            ngram = tuple(words[i:i+k])
            counts[ngram] += 1
    return (len(words), counts)

def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them.'''

    reflen = []
    maxcounts = {}
    for ref in refs:
        rl, counts = precook(ref, n)
        reflen.append(rl)
        for (ngram,count) in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram,0), count)

    # Calculate effective reference sentence length.
    if eff == "shortest":
        reflen = min(reflen)
    elif eff == "average":
        reflen = float(sum(reflen))/len(reflen)

    ## lhuang: N.B.: leave reflen computaiton to the very end!!
    
    ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)

    return (reflen, maxcounts)

def cook_test(test, reflen, refmaxcounts, eff=None, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.'''

    testlen, counts = precook(test, n, True)

    result = {}

    # Calculate effective reference sentence length.
    
    if eff == "closest":
        result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
    else: ## i.e., "average" or "shortest" or None
        result["reflen"] = reflen

    result["testlen"] = testlen

    result["guess"] = [max(0,testlen-k+1) for k in xrange(1,n+1)]

    result['correct'] = [0]*n
    for (ngram, count) in counts.items():
        result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)

    return result

class BleuScorer(object):
    """Bleu scorer.
    """

    __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
    # special_reflen is used in oracle (proportional effective ref len for a node).

    def copy(self):
        ''' copy the refs.'''
        new = BleuScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        new._score = None
        return new

    def __init__(self, test=None, refs=None, n=4, special_reflen=None):
        ''' singular instance '''

        self.n = n
        self.crefs = []
        self.ctest = []
        self.cook_append(test, refs)
        self.special_reflen = special_reflen

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''
        
        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                cooked_test = cook_test(test, self.crefs[-1][0], self.crefs[-1][1])
                self.ctest.append(cooked_test) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

        self._score = None ## need to recompute

    def ratio(self, option=None):
        self.compute_score(option=option)
        return self._ratio

    def score_ratio(self, option=None):
        '''return (bleu, len_ratio) pair'''
        return (self.fscore(option=option), self.ratio(option=option))

    def score_ratio_str(self, option=None):
        return "%.4f (%.2f)" % self.score_ratio(option)

    def reflen(self, option=None):
        self.compute_score(option=option)
        return self._reflen

    def testlen(self, option=None):
        self.compute_score(option=option)
        return self._testlen        

    def retest(self, new_test):
        if type(new_test) is str:
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = []
        for t, rs in zip(new_test, self.crefs):
            self.ctest.append(cook_test(t, rs[0], rs[1]))
        self._score = None

        return self

    def rescore(self, new_test):
        ''' replace test(s) with new test(s), and returns the new score.'''
        
        return self.retest(new_test).compute_score()

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new BleuScorer instances
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None ## need to recompute

        return self        

    def compatible(self, other):
        return isinstance(other, BleuScorer) and self.n == other.n

    def single_reflen(self, option="average"):
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):
        
        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens))/len(reflens)
        elif option == "closest":
            reflen = min((abs(l-testlen), l) for l in reflens)[1]
        else:
            assert False, "unsupported reflen option %s" % option

        return reflen

    def recompute_score(self, option=None, verbose=0):
        self._score = None
        return self.compute_score(option, verbose)
        
    def compute_score(self, option=None, verbose=0):
        n = self.n
        small = 1e-9
        tiny = 1e-15 ## so that if guess is 0 still return 0
        bleu_list = [[] for _ in range(n)]

        if self._score is not None:
            return self._score

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}

        # for each sentence
        for comps in self.ctest:            
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None: ## need computation
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen
                
            for key in ['guess','correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # append per image bleu score
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        /(float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1./(k+1)))
            ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
            if ratio < 1:
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1/ratio)

            if verbose > 1:
                print(comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1./(k+1)))
        ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
        if ratio < 1:
            for k in range(n):
                bleus[k] *= math.exp(1 - 1/ratio)

        if verbose > 0:
            print(totalcomps)
            print("ratio:", ratio)

        self._score = bleus
        return self._score, bleu_list
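
# In formula form, compute_score above implements, for each order k = 1..n:
#   p_k    = correct_k / guess_k                     (modified k-gram precision)
#   BLEU_k = BP * (p_1 * p_2 * ... * p_k) ** (1/k)   (geometric mean)
#   BP     = exp(1 - reflen/testlen) if testlen < reflen else 1
# The tiny/small constants exist only to avoid division by zero.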


================================================
FILE: vae/qgevalcap/cider/__init__.py
================================================
__author__ = 'tylin'


================================================
FILE: vae/qgevalcap/cider/cider.py
================================================
# Filename: cider.py
#
# Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
#               by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
#
# Creation Date: Sun Feb  8 14:16:54 2015
#
# Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu>

#!/usr/bin/env python
# Tsung-Yi Lin <tl483@cornell.edu>
# Ramakrishna Vedantam <vrama91@vt.edu>

import copy
import math
import pdb
from collections import defaultdict

import numpy as np

def precook(s, n=4, out=False):
    """
    Takes a string as input and returns an object that can be given to
    either cook_refs or cook_test. This is optional: cook_refs and cook_test
    can take string arguments as well.
    :param s: string : sentence to be converted into ngrams
    :param n: int    : number of ngrams for which representation is calculated
    :return: term frequency vector for occurring ngrams
    """
    words = s.split()
    counts = defaultdict(int)
    for k in range(1,n+1):
        for i in range(len(words)-k+1):
            ngram = tuple(words[i:i+k])
            counts[ngram] += 1
    return counts
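
# e.g. (illustrative) precook("the cat the cat", n=2) returns
#   {('the',): 2, ('cat',): 2, ('the', 'cat'): 2, ('cat', 'the'): 1}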

def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that CIDEr
    needs to know about them.
    :param refs: list of string : reference sentences for some image
    :param n: int : number of ngrams for which (ngram) representation is calculated
    :return: result (list of dict)
    '''
    return [precook(ref, n) for ref in refs]

def cook_test(test, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that CIDEr needs to know about it.
    :param test: list of string : hypothesis sentence for some image
    :param n: int : number of ngrams for which (ngram) representation is calculated
    :return: result (dict)
    '''
    return precook(test, n, True)

class CiderScorer(object):
    """CIDEr scorer.
    """

    def copy(self):
        ''' copy the refs.'''
        new = CiderScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        return new

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        ''' singular instance '''
        self.n = n
        self.sigma = sigma
        self.crefs = []
        self.ctest = []
        self.document_frequency = defaultdict(float)
        self.cook_append(test, refs)
        self.ref_len = None

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''

        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                self.ctest.append(cook_test(test)) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new CiderScorer instances
            self.cook_append(other[0], other[1])
        else:
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)

        return self

    def compute_doc_freq(self):
        '''
        Compute document frequency for the reference data.
        This will be used to compute idf (inverse document frequency) later.
        The document frequency is stored in the object.
        :return: None
        '''
        for refs in self.crefs:
            # refs: k reference captions of one image
            for ngram in set(ngram for ref in refs for (ngram, count) in ref.items()):
                self.document_frequency[ngram] += 1

    def compute_cider(self):
        def counts2vec(cnts):
            """
            Function maps counts of ngram to vector of tfidf weights.
            The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
            The n-th entry of array denotes length of n-grams.
            :param cnts:
            :return: vec (array of dict), norm (array of float), length (int)
            """
            vec = [defaultdict(float) for _ in range(self.n)]
            length = 0
            norm = [0.0 for _ in range(self.n)]
            for (ngram, term_freq) in cnts.items():
                # clamp document frequency to 1 so unseen ngrams get log-df 0 (i.e. maximum idf)
                df = np.log(max(1.0, self.document_frequency[ngram]))
                # ngram index
                n = len(ngram)-1
                # tf (term_freq) * idf (precomputed idf) for n-grams
                vec[n][ngram] = float(term_freq)*(self.ref_len - df)
                # compute norm for the vector.  the norm will be used for computing similarity
                norm[n] += pow(vec[n][ngram], 2)

                if n == 1:  # NB: index 1 = bigrams; this length quirk is inherited from the original CIDEr code
                    length += term_freq
            norm = [np.sqrt(n) for n in norm]
            return vec, norm, length

        def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
            '''
            Compute the cosine similarity of two vectors.
            :param vec_hyp: array of dictionary for vector corresponding to hypothesis
            :param vec_ref: array of dictionary for vector corresponding to reference
            :param norm_hyp: array of float for vector corresponding to hypothesis
            :param norm_ref: array of float for vector corresponding to reference
            :param length_hyp: int containing length of hypothesis
            :param length_ref: int containing length of reference
            :return: array of score for each n-grams cosine similarity
            '''
            delta = float(length_hyp - length_ref)
            # measure cosine similarity
            val = np.array([0.0 for _ in range(self.n)])
            for n in range(self.n):
                # ngram
                for (ngram, count) in vec_hyp[n].items():
                    # vrama91 : added clipping
                    val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]

                if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
                    val[n] /= (norm_hyp[n]*norm_ref[n])

                assert(not math.isnan(val[n]))
                # vrama91: added a length based gaussian penalty
                val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
            return val

        # compute log reference length
        self.ref_len = np.log(float(len(self.crefs)))

        scores = []
        for test, refs in zip(self.ctest, self.crefs):
            # compute vector for test captions
            vec, norm, length = counts2vec(test)
            # compute vector for ref captions
            score = np.array([0.0 for _ in range(self.n)])
            for ref in refs:
                vec_ref, norm_ref, length_ref = counts2vec(ref)
                score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
            # change by vrama91 - mean of ngram scores, instead of sum
            score_avg = np.mean(score)
            # divide by number of references
            score_avg /= len(refs)
            # multiply score by 10
            score_avg *= 10.0
            # append score of an image to the score list
            scores.append(score_avg)
        return scores
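
    # In formula form, compute_cider above implements, per n-gram order n
    # (hypothesis c, references S = {s_j}):
    #   g_n(w)  = tf(w) * (log(#images) - log(max(1, df(w))))    # tf-idf weight
    #   CIDEr_n = (1/|S|) * sum_j clipped-cosine(g_n(c), g_n(s_j))
    #             * exp(-(len_c - len_j)^2 / (2 * sigma^2))      # length penalty
    #   CIDEr   = 10 * mean over n of CIDEr_n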

    def compute_score(self, option=None, verbose=0):
        # compute idf
        self.compute_doc_freq()
        # assert to check document frequency
        assert(len(self.ctest) >= max(self.document_frequency.values()))
        # compute cider score
        score = self.compute_cider()
        # debug
        # print score
        return np.mean(np.array(score)), np.array(score)


class Cider:
    """
    Main Class to compute the CIDEr metric

    """
    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        # set cider to sum over 1 to 4-grams
        self._n = n
        # set the standard deviation parameter for gaussian penalty
        self._sigma = sigma

    def compute_score(self, gts, res):
        """
        Main function to compute CIDEr score
        :param  hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence>
                ref_for_image (dict)  : dictionary with key <image> and value <tokenized reference sentence>
        :return: cider (float) : computed CIDEr score for the corpus
        """

        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)

        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) > 0)

            cider_scorer += (hypo[0], ref)

        (score, scores) = cider_scorer.compute_score()

        return score, scores

    def method(self):
        return "CIDEr"


================================================
FILE: vae/qgevalcap/cider/cider_scorer.py
================================================
#!/usr/bin/env python
# Tsung-Yi Lin <tl483@cornell.edu>
# Ramakrishna Vedantam <vrama91@vt.edu>

import copy
from collections import defaultdict
import numpy as np
import pdb
import math

def precook(s, n=4, out=False):
    """
    Takes a string as input and returns an object that can be given to
    either cook_refs or cook_test. This is optional: cook_refs and cook_test
    can take string arguments as well.
    :param s: string : sentence to be converted into ngrams
    :param n: int    : number of ngrams for which representation is calculated
    :return: term frequency vector for occurring ngrams
    """
    words = s.split()
    counts = defaultdict(int)
    for k in range(1,n+1):
        for i in range(len(words)-k+1):
            ngram = tuple(words[i:i+k])
            counts[ngram] += 1
    return counts

def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that CIDEr
    needs to know about them.
    :param refs: list of string : reference sentences for some image
    :param n: int : number of ngrams for which (ngram) representation is calculated
    :return: result (list of dict)
    '''
    return [precook(ref, n) for ref in refs]

def cook_test(test, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that CIDEr needs to know about it.
    :param test: list of string : hypothesis sentence for some image
    :param n: int : number of ngrams for which (ngram) representation is calculated
    :return: result (dict)
    '''
    return precook(test, n, True)

class CiderScorer(object):
    """CIDEr scorer.
    """

    def copy(self):
        ''' copy the refs.'''
        new = CiderScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        return new

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        ''' singular instance '''
        self.n = n
        self.sigma = sigma
        self.crefs = []
        self.ctest = []
        self.document_frequency = defaultdict(float)
        self.cook_append(test, refs)
        self.ref_len = None

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''

        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                self.ctest.append(cook_test(test)) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new CiderScorer instances
            self.cook_append(other[0], other[1])
        else:
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)

        return self

    def compute_doc_freq(self):
        '''
        Compute document frequency for the reference data.
        This will be used to compute idf (inverse document frequency) later.
        The document frequency is stored in the object.
        :return: None
        '''
        for refs in self.crefs:
            # refs: k reference captions of one image
            for ngram in set(ngram for ref in refs for (ngram, count) in ref.items()):
                self.document_frequency[ngram] += 1

    def compute_cider(self):
        def counts2vec(cnts):
            """
            Function maps counts of ngram to vector of tfidf weights.
            The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
            The n-th entry of array denotes length of n-grams.
            :param cnts:
            :return: vec (array of dict), norm (array of float), length (int)
            """
            vec = [defaultdict(float) for _ in range(self.n)]
            length = 0
            norm = [0.0 for _ in range(self.n)]
            for (ngram, term_freq) in cnts.items():
                # clamp document frequency to 1 so unseen ngrams get log-df 0 (i.e. maximum idf)
                df = np.log(max(1.0, self.document_frequency[ngram]))
                # ngram index
                n = len(ngram)-1
                # tf (term_freq) * idf (precomputed idf) for n-grams
                vec[n][ngram] = float(term_freq)*(self.ref_len - df)
                # compute norm for the vector.  the norm will be used for computing similarity
                norm[n] += pow(vec[n][ngram], 2)

                if n == 1:  # NB: index 1 = bigrams (see cider.py)
                    length += term_freq
            norm = [np.sqrt(n) for n in norm]
            return vec, norm, length

        def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
            '''
            Compute the cosine similarity of two vectors.
            :param vec_hyp: array of dictionary for vector corresponding to hypothesis
            :param vec_ref: array of dictionary for vector corresponding to reference
            :param norm_hyp: array of float for vector corresponding to hypothesis
            :param norm_ref: array of float for vector corresponding to reference
            :param length_hyp: int containing length of hypothesis
            :param length_ref: int containing length of reference
            :return: array of score for each n-grams cosine similarity
            '''
            delta = float(length_hyp - length_ref)
            # measure cosine similarity
            val = np.array([0.0 for _ in range(self.n)])
            for n in range(self.n):
                # ngram
                for (ngram, count) in vec_hyp[n].items():
                    # vrama91 : added clipping
                    val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]

                if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
                    val[n] /= (norm_hyp[n]*norm_ref[n])

                assert(not math.isnan(val[n]))
                # vrama91: added a length based gaussian penalty
                val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
            return val

        # compute log reference length
        self.ref_len = np.log(float(len(self.crefs)))

        scores = []
        for test, refs in zip(self.ctest, self.crefs):
            # compute vector for test captions
            vec, norm, length = counts2vec(test)
            # compute vector for ref captions
            score = np.array([0.0 for _ in range(self.n)])
            for ref in refs:
                vec_ref, norm_ref, length_ref = counts2vec(ref)
                score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
            # change by vrama91 - mean of ngram scores, instead of sum
            score_avg = np.mean(score)
            # divide by number of references
            score_avg /= len(refs)
            # multiply score by 10
            score_avg *= 10.0
            # append score of an image to the score list
            scores.append(score_avg)
        return scores

    def compute_score(self, option=None, verbose=0):
        # compute idf
        self.compute_doc_freq()
        # assert to check document frequency
        assert(len(self.ctest) >= max(self.document_frequency.values()))
        # compute cider score
        score = self.compute_cider()
        # debug
        # print score
        return np.mean(np.array(score)), np.array(score)

================================================
FILE: vae/qgevalcap/eval.py
================================================
import os
import copy
from collections import defaultdict
from argparse import ArgumentParser
import pickle
import json
from json import encoder
import math


def precook(s, n=4, out=False):
    """Takes a string as input and returns an object that can be given to
    either cook_refs or cook_test. This is optional: cook_refs and cook_test
    can take string arguments as well."""
    words = s.split()
    counts = defaultdict(int)
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            ngram = tuple(words[i:i + k])
            counts[ngram] += 1
    return (len(words), counts)
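
# e.g. (illustrative) precook("a b a", n=2) returns
#   (3, {('a',): 2, ('b',): 1, ('a', 'b'): 1, ('b', 'a'): 1})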


def cook_refs(refs, eff=None, n=4):  ## lhuang: oracle will call with "average"
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them.'''

    reflen = []
    maxcounts = {}
    for ref in refs:
        rl, counts = precook(ref, n)
        reflen.append(rl)
        for (ngram, count) in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)

    # Calculate effective reference sentence length.
    if eff == "shortest":
        reflen = min(reflen)
    elif eff == "average":
        reflen = float(sum(reflen)) / len(reflen)

    ## lhuang: N.B.: leave reflen computation to the very end!!

    ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)

    return (reflen, maxcounts)


def cook_test(test, reflen, refmaxcounts, eff=None, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.'''

    testlen, counts = precook(test, n, True)

    result = {}

    # Calculate effective reference sentence length.

    if eff == "closest":
        result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1]
    else:  ## i.e., "average" or "shortest" or None
        result["reflen"] = reflen

    result["testlen"] = testlen

    result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)]

    result['correct'] = [0] * n
    for (ngram, count) in counts.items():
        result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)

    return result
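
# e.g. (illustrative) for a 5-token hypothesis with n=4, result["guess"] is
# [5, 4, 3, 2] (the number of k-grams for k = 1..4), and result["correct"][k-1]
# counts hypothesis k-grams matched against refmaxcounts, clipped to the
# per-reference maximum.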


class BleuScorer(object):
    """Bleu scorer.
    """

    __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"

    # special_reflen is used in oracle (proportional effective ref len for a node).

    def copy(self):
        ''' copy the refs.'''
        new = BleuScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        new._score = None
        return new

    def __init__(self, test=None, refs=None, n=4, special_reflen=None):
        ''' singular instance '''

        self.n = n
        self.crefs = []
        self.ctest = []
        self.cook_append(test, refs)
        self.special_reflen = special_reflen

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''

        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                cooked_test = cook_test(test, self.crefs[-1][0], self.crefs[-1][1])
                self.ctest.append(cooked_test)  ## N.B.: -1
            else:
                self.ctest.append(None)  # lens of crefs and ctest have to match

        self._score = None  ## need to recompute

    def ratio(self, option=None):
        # NB: compute_score never sets `_ratio` and `fscore` (used by
        # score_ratio below) is undefined; these three helpers are dead code.
        self.compute_score(option=option)
        return self._ratio

    def score_ratio(self, option=None):
        '''return (bleu, len_ratio) pair'''
        return (self.fscore(option=option), self.ratio(option=option))

    def score_ratio_str(self, option=None):
        return "%.4f (%.2f)" % self.score_ratio(option)

    def reflen(self, option=None):
        self.compute_score(option=option)
        return self._reflen

    def testlen(self, option=None):
        self.compute_score(option=option)
        return self._testlen

    def retest(self, new_test):
        if type(new_test) is str:
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = []
        for t, rs in zip(new_test, self.crefs):
            self.ctest.append(cook_test(t, rs[0], rs[1]))
        self._score = None

        return self

    def rescore(self, new_test):
        ''' replace test(s) with new test(s), and returns the new score.'''

        return self.retest(new_test).compute_score()

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new BleuScorer instances
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None  ## need to recompute

        return self

    def compatible(self, other):
        return isinstance(other, BleuScorer) and self.n == other.n

    def single_reflen(self, option="average"):
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):

        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens)) / len(reflens)
        elif option == "closest":
            reflen = min((abs(l - testlen), l) for l in reflens)[1]
        else:
            assert False, "unsupported reflen option %s" % option

        return reflen

    def recompute_score(self, option=None, verbose=0):
        self._score = None
        return self.compute_score(option, verbose)

    def compute_score(self, option=None, verbose=0):
        n = self.n
        small = 1e-9
        tiny = 1e-15  ## so that if guess is 0 still return 0
        bleu_list = [[] for _ in range(n)]

        if self._score is not None:
            # NB: the cached path returns only the score list, not the
            # (score, bleu_list) pair returned by the fresh computation below.
            return self._score

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n}

        # for each sentence
        for comps in self.ctest:
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None:  ## need computation
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen

            for key in ['guess', 'correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # append per image bleu score
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        / (float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1. / (k + 1)))
            ratio = (testlen + tiny) / (reflen + small)  ## N.B.: avoid zero division
            if ratio < 1:
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1 / ratio)

            if verbose > 1:
                print(comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1. / (k + 1)))
        ratio = (self._testlen + tiny) / (self._reflen + small)  ## N.B.: avoid zero division
        if ratio < 1:
            for k in range(n):
                bleus[k] *= math.exp(1 - 1 / ratio)

        if verbose > 0:
            print(totalcomps)
            print("ratio:", ratio)

        self._score = bleus
        return self._score, bleu_list


class Bleu:
    def __init__(self, n=4):
        # by default, compute BLEU score up to 4-grams
        self._n = n
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res):
        assert (gts.keys() == res.keys())
        imgIds = gts.keys()

        bleu_scorer = BleuScorer(n=self._n)
        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            # Sanity check.
            assert (type(hypo) is list)
            assert (len(hypo) == 1)
            assert (type(ref) is list)
            assert (len(ref) >= 1)

            bleu_scorer += (hypo[0], ref)

        # score, scores = bleu_scorer.compute_score(option='shortest')
        score, scores = bleu_scorer.compute_score(option='closest', verbose=0)
        # score, scores = bleu_scorer.compute_score(option='average', verbose=1)

        # return (bleu, bleu_info)
        return score, scores

    def method(self):
        return "Bleu"


class QGEvalCap:
    def __init__(self, gts, res):
        self.gts = gts
        self.res = res

    def evaluate(self, not_print=True):
        output = []
        scorers = [
            (Bleu(4), "Bleu_4"),
            # (meteor.Meteor(),"METEOR"),
            # (rouge.Rouge(), "ROUGE_L"),
            # (cider.Cider(), "CIDEr")
        ]

        # =================================================
        # Compute scores
        # =================================================
        for scorer, method in scorers:
            # print 'computing %s score...'%(scorer.method())
            score, scores = scorer.compute_score(self.gts, self.res)
            if type(method) == list:
                for sc, scs, m in zip(score, scores, method):
                    if not not_print:
                        print("%s: %0.5f" % (m, sc))
                    output.append(sc)
            else:
                if not not_print:
                    print("%s: %0.5f" % (method, score))
                output.append(score)
        return output


def eval_qg(res_dict, gts_dict, not_print=True):
    # NB: FLOAT_REPR is ignored by the json module in Python 3; the line below
    # is kept from the original script and has no effect there.
    encoder.FLOAT_REPR = lambda o: format(o, '.4f')

    res = defaultdict(lambda: [])
    gts = defaultdict(lambda: [])

    for key in gts_dict.keys():
        # NB: .encode() yields bytes under Python 3; n-gram matching still works
        # because hypothesis and reference are encoded identically.
        res[key] = [res_dict[key].encode('utf-8')]
        gts[key].append(gts_dict[key].encode('utf-8'))

    QGEval = QGEvalCap(gts, res)
    # evaluate() returns [[Bleu_1, Bleu_2, Bleu_3, Bleu_4]]; pick out Bleu_4
    return QGEval.evaluate(not_print)[0][-1]
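
# A minimal usage sketch (illustrative; the keys of the two dicts must coincide,
# and hypotheses go in the first argument):
#
#   hyp = {"q1": "what is the capital city of france ?"}
#   ref = {"q1": "what is the capital of france ?"}
#   bleu4 = eval_qg(hyp, ref)   # corpus-level Bleu_4 as a float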


if __name__ == "__main__":
    # ... (remainder of the file is truncated in this preview)
================================================
SYMBOL INDEX (265 symbols across 20 files)
================================================

FILE: qa-eval/distributed_run.py
  function distributed_main (line 6) | def distributed_main(args):
  function worker (line 19) | def worker(gpu, ngpus_per_node, args):

FILE: qa-eval/main.py
  class HarvestingQADataset (line 15) | class HarvestingQADataset(Dataset):
    method __init__ (line 16) | def __init__(self, filename, ratio):
    method __getitem__ (line 20) | def __getitem__(self, idx):
    method __len__ (line 38) | def __len__(self):
  function main (line 42) | def main(args):

FILE: qa-eval/squad_utils.py
  class SquadExample (line 16) | class SquadExample(object):
    method __init__ (line 21) | def __init__(self,
    method __str__ (line 37) | def __str__(self):
    method __repr__ (line 40) | def __repr__(self):
  class InputFeatures (line 55) | class InputFeatures(object):
    method __init__ (line 57) | def __init__(self,
  function convert_examples_to_features (line 102) | def convert_examples_to_features(examples, tokenizer, max_seq_length,
  function convert_examples_to_harv_features (line 268) | def convert_examples_to_harv_features(examples, tokenizer, max_seq_length,
  function convert_examples_to_features_answer_id (line 399) | def convert_examples_to_features_answer_id(examples, tokenizer, max_seq_...
  function read_examples (line 610) | def read_examples(input_file, debug=False, is_training=False):
  function read_squad_examples (line 666) | def read_squad_examples(input_file, is_training, version_2_with_negative...
  function _improve_answer_span (line 746) | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
  function _check_is_max_context (line 783) | def _check_is_max_context(doc_spans, cur_span_index, position):
  function write_predictions (line 821) | def write_predictions(all_examples, all_features, all_results, n_best_size,
  function get_final_text (line 1027) | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=...
  function _get_best_indexes (line 1122) | def _get_best_indexes(logits, n_best_size):
  function _compute_softmax (line 1135) | def _compute_softmax(scores):
  function write_answer_predictions (line 1158) | def write_answer_predictions(all_examples, all_features, all_results, n_...
  function normalize_answer (line 1359) | def normalize_answer(s):
  function f1_score (line 1378) | def f1_score(prediction, ground_truth):
  function exact_match_score (line 1392) | def exact_match_score(prediction, ground_truth):
  function metric_max_over_ground_truths (line 1396) | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
  function evaluate (line 1404) | def evaluate(dataset, predictions):
  function read_predictions (line 1428) | def read_predictions(prediction_file):
  function read_answers (line 1434) | def read_answers(gold_file):
  function evaluate_mrqa (line 1446) | def evaluate_mrqa(answers, predictions, skip_no_answer=False):

FILE: qa-eval/trainer.py
  function get_linear_schedule_with_warmup (line 30) | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_tra...
  class Trainer (line 41) | class Trainer(object):
    method __init__ (line 42) | def __init__(self, args):
    method make_model_env (line 47) | def make_model_env(self, gpu, ngpus_per_node):
    method get_pretrain_loader (line 94) | def get_pretrain_loader(self):
    method get_dev_loader (line 102) | def get_dev_loader(self):
    method get_test_loader (line 109) | def get_test_loader(self):
    method train (line 116) | def train(self):
    method evaluate_model (line 180) | def evaluate_model(self, msg, dev=True):
    method set_random_seed (line 252) | def set_random_seed(self, random_seed=2019):

FILE: qa-eval/utils.py
  function time_since (line 5) | def time_since(t):
  function progress_bar (line 10) | def progress_bar(completed, total, step=5):
  function user_friendly_time (line 32) | def user_friendly_time(s):
  function eta (line 53) | def eta(start, completed, total):
  function cal_running_avg_loss (line 64) | def cal_running_avg_loss(loss, running_avg_loss, decay=0.99):
  function kl_coef (line 72) | def kl_coef(i):
  function compute_kernel (line 79) | def compute_kernel(x, y):
  function compute_mmd (line 91) | def compute_mmd(x, y):
  class EMA (line 99) | class EMA(object):
    method __init__ (line 106) | def __init__(self, model, decay):
    method __call__ (line 116) | def __call__(self, model, num_updates):
    method assign (line 125) | def assign(self, model):
    method resume (line 137) | def resume(self, model):

FILE: vae/eval.py
  function to_string (line 13) | def to_string(index, tokenizer):
  class Result (line 29) | class Result(object):
    method __init__ (line 30) | def __init__(self,
  function eval_vae (line 51) | def eval_vae(epoch, args, trainer, eval_data):

FILE: vae/generate_qa.py
  class CustomDatset (line 13) | class CustomDatset(Dataset):
    method __init__ (line 14) | def __init__(self, tokenizer, input_file, max_length=512):
    method __getitem__ (line 20) | def __getitem__(self, idx):
    method __len__ (line 32) | def __len__(self):
  function main (line 36) | def main(args):

FILE: vae/main.py
  function main (line 15) | def main(args):

FILE: vae/models.py
  function return_mask_lengths (line 10) | def return_mask_lengths(ids):
  function cal_attn (line 16) | def cal_attn(query, memories, mask):
  function gumbel_softmax (line 25) | def gumbel_softmax(logits, tau=1, hard=False, eps=1e-20, dim=-1):
  class CategoricalKLLoss (line 44) | class CategoricalKLLoss(nn.Module):
    method __init__ (line 45) | def __init__(self):
    method forward (line 48) | def forward(self, P, Q):
  class GaussianKLLoss (line 55) | class GaussianKLLoss(nn.Module):
    method __init__ (line 56) | def __init__(self):
    method forward (line 59) | def forward(self, mu1, logvar1, mu2, logvar2):
  class Embedding (line 66) | class Embedding(nn.Module):
    method __init__ (line 67) | def __init__(self, bert_model):
    method forward (line 76) | def forward(self, input_ids, token_type_ids=None, position_ids=None):
  class ContextualizedEmbedding (line 96) | class ContextualizedEmbedding(nn.Module):
    method __init__ (line 97) | def __init__(self, bert_model):
    method forward (line 104) | def forward(self, input_ids, attention_mask, token_type_ids=None):
  class CustomLSTM (line 128) | class CustomLSTM(nn.Module):
    method __init__ (line 129) | def __init__(self, input_size, hidden_size, num_layers, dropout, bidir...
    method forward (line 142) | def forward(self, inputs, input_lengths, state=None):
  class PosteriorEncoder (line 158) | class PosteriorEncoder(nn.Module):
    method __init__ (line 159) | def __init__(self, embedding, emsize,
    method forward (line 187) | def forward(self, c_ids, q_ids, a_ids):
  class PriorEncoder (line 244) | class PriorEncoder(nn.Module):
    method __init__ (line 245) | def __init__(self, embedding, emsize,
    method forward (line 269) | def forward(self, c_ids):
    method interpolation (line 294) | def interpolation(self, c_ids, zq):
  class AnswerDecoder (line 316) | class AnswerDecoder(nn.Module):
    method __init__ (line 317) | def __init__(self, embedding, emsize,
    method forward (line 334) | def forward(self, init_state, c_ids):
    method generate (line 353) | def generate(self, init_state, c_ids):
  class ContextEncoderforQG (line 383) | class ContextEncoderforQG(nn.Module):
    method __init__ (line 384) | def __init__(self, embedding, emsize,
    method forward (line 398) | def forward(self, c_ids, a_ids):
  class QuestionDecoder (line 414) | class QuestionDecoder(nn.Module):
    method __init__ (line 415) | def __init__(self, sos_id, eos_id,
    method postprocess (line 456) | def postprocess(self, q_ids):
    method forward (line 471) | def forward(self, init_state, c_ids, q_ids, a_ids):
    method generate (line 535) | def generate(self, init_state, c_ids, a_ids):
    method sample (line 589) | def sample(self, init_state, c_ids, a_ids):
  class DiscreteVAE (line 644) | class DiscreteVAE(nn.Module):
    method __init__ (line 645) | def __init__(self, args):
    method return_init_state (line 712) | def return_init_state(self, zq, za):
    method forward (line 728) | def forward(self, c_ids, q_ids, a_ids, start_positions, end_positions):
    method generate (line 779) | def generate(self, zq, za, c_ids):
    method return_answer_logits (line 789) | def return_answer_logits(self, zq, za, c_ids):

FILE: vae/qgevalcap/bleu/bleu.py
  function precook (line 33) | def precook(s, n=4, out=False):
  function cook_refs (line 45) | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "av...
  function cook_test (line 70) | def cook_test(test, reflen, refmaxcounts, eff=None, n=4):
  class BleuScorer (line 95) | class BleuScorer(object):
    method copy (line 102) | def copy(self):
    method __init__ (line 110) | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
    method cook_append (line 119) | def cook_append(self, test, refs):
    method ratio (line 132) | def ratio(self, option=None):
    method score_ratio (line 136) | def score_ratio(self, option=None):
    method score_ratio_str (line 140) | def score_ratio_str(self, option=None):
    method reflen (line 143) | def reflen(self, option=None):
    method testlen (line 147) | def testlen(self, option=None):
    method retest (line 151) | def retest(self, new_test):
    method rescore (line 162) | def rescore(self, new_test):
    method size (line 167) | def size(self):
    method __iadd__ (line 171) | def __iadd__(self, other):
    method compatible (line 185) | def compatible(self, other):
    method single_reflen (line 188) | def single_reflen(self, option="average"):
    method _single_reflen (line 191) | def _single_reflen(self, reflens, option=None, testlen=None):
    method recompute_score (line 204) | def recompute_score(self, option=None, verbose=0):
    method compute_score (line 208) | def compute_score(self, option=None, verbose=0):
  class Bleu (line 277) | class Bleu:
    method __init__ (line 278) | def __init__(self, n=4):
    method compute_score (line 284) | def compute_score(self, gts, res):
    method method (line 309) | def method(self):

FILE: vae/qgevalcap/bleu/bleu_scorer.py
  function precook (line 23) | def precook(s, n=4, out=False):
  function cook_refs (line 35) | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "av...
  function cook_test (line 60) | def cook_test(test, (reflen, refmaxcounts), eff=None, n=4):
  class BleuScorer (line 85) | class BleuScorer(object):
    method copy (line 92) | def copy(self):
    method __init__ (line 100) | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
    method cook_append (line 109) | def cook_append(self, test, refs):
    method ratio (line 122) | def ratio(self, option=None):
    method score_ratio (line 126) | def score_ratio(self, option=None):
    method score_ratio_str (line 130) | def score_ratio_str(self, option=None):
    method reflen (line 133) | def reflen(self, option=None):
    method testlen (line 137) | def testlen(self, option=None):
    method retest (line 141) | def retest(self, new_test):
    method rescore (line 152) | def rescore(self, new_test):
    method size (line 157) | def size(self):
    method __iadd__ (line 161) | def __iadd__(self, other):
    method compatible (line 175) | def compatible(self, other):
    method single_reflen (line 178) | def single_reflen(self, option="average"):
    method _single_reflen (line 181) | def _single_reflen(self, reflens, option=None, testlen=None):
    method recompute_score (line 194) | def recompute_score(self, option=None, verbose=0):
    method compute_score (line 198) | def compute_score(self, option=None, verbose=0):

FILE: vae/qgevalcap/cider/cider.py
  function precook (line 21) | def precook(s, n=4, out=False):
  function cook_refs (line 38) | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
  function cook_test (line 48) | def cook_test(test, n=4):
  class CiderScorer (line 57) | class CiderScorer(object):
    method copy (line 61) | def copy(self):
    method __init__ (line 68) | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
    method cook_append (line 78) | def cook_append(self, test, refs):
    method size (line 88) | def size(self):
    method __iadd__ (line 92) | def __iadd__(self, other):
    method compute_doc_freq (line 103) | def compute_doc_freq(self):
    method compute_cider (line 116) | def compute_cider(self):
    method compute_score (line 193) | def compute_score(self, option=None, verbose=0):
  class Cider (line 203) | class Cider:
    method __init__ (line 208) | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
    method compute_score (line 214) | def compute_score(self, gts, res):
    method method (line 243) | def method(self):

FILE: vae/qgevalcap/cider/cider_scorer.py
  function precook (line 11) | def precook(s, n=4, out=False):
  function cook_refs (line 28) | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
  function cook_test (line 38) | def cook_test(test, n=4):
  class CiderScorer (line 47) | class CiderScorer(object):
    method copy (line 51) | def copy(self):
    method __init__ (line 58) | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
    method cook_append (line 68) | def cook_append(self, test, refs):
    method size (line 78) | def size(self):
    method __iadd__ (line 82) | def __iadd__(self, other):
    method compute_doc_freq (line 93) | def compute_doc_freq(self):
    method compute_cider (line 106) | def compute_cider(self):
    method compute_score (line 183) | def compute_score(self, option=None, verbose=0):

FILE: vae/qgevalcap/eval.py
  function precook (line 11) | def precook(s, n=4, out=False):
  function cook_refs (line 24) | def cook_refs(refs, eff=None, n=4):  ## lhuang: oracle will call with "a...
  function cook_test (line 50) | def cook_test(test, reflen, refmaxcounts, eff=None, n=4):
  class BleuScorer (line 76) | class BleuScorer(object):
    method copy (line 84) | def copy(self):
    method __init__ (line 92) | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
    method cook_append (line 101) | def cook_append(self, test, refs):
    method ratio (line 114) | def ratio(self, option=None):
    method score_ratio (line 118) | def score_ratio(self, option=None):
    method score_ratio_str (line 122) | def score_ratio_str(self, option=None):
    method reflen (line 125) | def reflen(self, option=None):
    method testlen (line 129) | def testlen(self, option=None):
    method retest (line 133) | def retest(self, new_test):
    method rescore (line 144) | def rescore(self, new_test):
    method size (line 149) | def size(self):
    method __iadd__ (line 153) | def __iadd__(self, other):
    method compatible (line 167) | def compatible(self, other):
    method single_reflen (line 170) | def single_reflen(self, option="average"):
    method _single_reflen (line 173) | def _single_reflen(self, reflens, option=None, testlen=None):
    method recompute_score (line 186) | def recompute_score(self, option=None, verbose=0):
    method compute_score (line 190) | def compute_score(self, option=None, verbose=0):
  class Bleu (line 258) | class Bleu:
    method __init__ (line 259) | def __init__(self, n=4):
    method compute_score (line 265) | def compute_score(self, gts, res):
    method method (line 289) | def method(self):
  class QGEvalCap (line 293) | class QGEvalCap:
    method __init__ (line 294) | def __init__(self, gts, res):
    method evaluate (line 298) | def evaluate(self, not_print=True):
  function eval_qg (line 325) | def eval_qg(res_dict, gts_dict, not_print=True):

FILE: vae/qgevalcap/meteor/meteor.py
  class Meteor (line 15) | class Meteor:
    method __init__ (line 17) | def __init__(self):
    method compute_score (line 33) | def compute_score(self, gts, res):
    method method (line 53) | def method(self):
    method _stat (line 56) | def _stat(self, hypothesis_str, reference_list):
    method _score (line 64) | def _score(self, hypothesis_str, reference_list):
    method __del__ (line 81) | def __del__(self):

FILE: vae/qgevalcap/rouge/rouge.py
  function my_lcs (line 13) | def my_lcs(string, sub):
  class Rouge (line 36) | class Rouge():
    method __init__ (line 41) | def __init__(self):
    method calc_score (line 45) | def calc_score(self, candidate, refs):
    method compute_score (line 77) | def compute_score(self, gts, res):
    method method (line 104) | def method(self):

FILE: vae/squad_utils.py
  class SquadExample (line 17) | class SquadExample(object):
    method __init__ (line 22) | def __init__(self,
    method __str__ (line 38) | def __str__(self):
    method __repr__ (line 41) | def __repr__(self):
  class InputFeatures (line 56) | class InputFeatures(object):
    method __init__ (line 58) | def __init__(self,
  function convert_examples_to_features (line 103) | def convert_examples_to_features(examples, tokenizer, max_seq_length,
  function convert_examples_to_harv_features (line 269) | def convert_examples_to_harv_features(examples, tokenizer, max_seq_length,
  function convert_examples_to_features_answer_id (line 400) | def convert_examples_to_features_answer_id(examples, tokenizer, max_seq_...
  function read_examples (line 611) | def read_examples(input_file, debug=False, is_training=False):
  function read_squad_examples (line 667) | def read_squad_examples(input_file, is_training, version_2_with_negative...
  function _improve_answer_span (line 747) | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
  function _check_is_max_context (line 784) | def _check_is_max_context(doc_spans, cur_span_index, position):
  function write_predictions (line 822) | def write_predictions(all_examples, all_features, all_results, n_best_size,
  function get_final_text (line 1028) | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=...
  function _get_best_indexes (line 1123) | def _get_best_indexes(logits, n_best_size):
  function _compute_softmax (line 1136) | def _compute_softmax(scores):
  function write_answer_predictions (line 1159) | def write_answer_predictions(all_examples, all_features, all_results, n_...
  function normalize_answer (line 1360) | def normalize_answer(s):
  function f1_score (line 1379) | def f1_score(prediction, ground_truth):
  function exact_match_score (line 1393) | def exact_match_score(prediction, ground_truth):
  function metric_max_over_ground_truths (line 1397) | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
  function evaluate (line 1405) | def evaluate(dataset, predictions):
  function read_predictions (line 1429) | def read_predictions(prediction_file):
  function read_answers (line 1435) | def read_answers(gold_file):
  function evaluate_mrqa (line 1447) | def evaluate_mrqa(answers, predictions, skip_no_answer=False):

FILE: vae/trainer.py
  class VAETrainer (line 7) | class VAETrainer(object):
    method __init__ (line 8) | def __init__(self, args):
    method train (line 23) | def train(self, c_ids, q_ids, a_ids, start_positions, end_positions):
    method generate_posterior (line 46) | def generate_posterior(self, c_ids, q_ids, a_ids):
    method generate_answer_logits (line 53) | def generate_answer_logits(self, c_ids, q_ids, a_ids):
    method generate_prior (line 60) | def generate_prior(self, c_ids):
    method save (line 67) | def save(self, filename):

FILE: vae/translate.py
  function return_mask_lengths (line 14) | def return_mask_lengths(ids):
  function post_process (line 20) | def post_process(q_ids, start_positions, end_positions, c_ids, total_max...
  function main (line 63) | def main(args):

FILE: vae/utils.py
  function get_squad_data_loader (line 11) | def get_squad_data_loader(tokenizer, file, shuffle, args):
  function get_harv_data_loader (line 33) | def get_harv_data_loader(tokenizer, file, shuffle, ratio, args):
  function batch_to_device (line 50) | def batch_to_device(batch, device):
