Repository: facebookresearch/PAQ
Branch: main
Commit: 2bfd2c85e58e
Files: 52
Total size: 194.1 KB

Directory structure:
gitextract_5g97rwbv/

├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── full_models_list.md
├── generator_configs/
│   ├── answer_extractor_configs/
│   │   ├── learnt_answer_extractor_config.json
│   │   └── named_entity_answer_extractor_config.json
│   ├── filterer_configs/
│   │   ├── dummy_filtering_config.json
│   │   ├── global_filtering_config.json
│   │   └── local_filtering_config.json
│   ├── paq_L1_config.json
│   ├── paq_L1_with_local_filtering_config.json
│   ├── paq_L4_config.json
│   ├── paq_NE_config.json
│   ├── passage_ranker_configs/
│   │   ├── dummy_passage_scorer_config.json
│   │   ├── learnt_passage_scorer_config.json
│   │   └── lookup_passage_scorer_config.json
│   └── question_generator_configs/
│       └── question_generation_config.json
├── paq/
│   ├── __init__.py
│   ├── download.py
│   ├── evaluation/
│   │   ├── __init__.py
│   │   ├── eval_reranker.py
│   │   ├── eval_retriever.py
│   │   └── eval_utils.py
│   ├── generation/
│   │   ├── __init__.py
│   │   ├── answer_extractor/
│   │   │   ├── __init__.py
│   │   │   ├── extract_answers.py
│   │   │   ├── extractors.py
│   │   │   └── span2D_model.py
│   │   ├── filtering/
│   │   │   ├── __init__.py
│   │   │   ├── filter_questions.py
│   │   │   └── filterer.py
│   │   ├── generate_qa_pairs.py
│   │   ├── passage_scorer/
│   │   │   ├── __init__.py
│   │   │   ├── score_passages.py
│   │   │   └── scorer.py
│   │   └── question_generator/
│   │       ├── __init__.py
│   │       ├── generate_questions.py
│   │       └── generator.py
│   ├── paq_utils.py
│   ├── rerankers/
│   │   ├── __init__.py
│   │   └── rerank.py
│   ├── retrievers/
│   │   ├── __init__.py
│   │   ├── build_index.py
│   │   ├── embed.py
│   │   ├── retrieve.py
│   │   └── retriever_utils.py
│   └── server/
│       ├── __init__.py
│       ├── client.py
│       ├── launch_server.sh
│       └── server.py
└── requirements.txt

================================================
FILE CONTENTS
================================================

================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct

Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.fb.com/codeofconduct/)
so that you can understand what actions will and will not be tolerated.


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to this repo

## Pull Requests

In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

## License
By contributing to this repo, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.


================================================
FILE: LICENSE
================================================
Attribution-NonCommercial 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
	wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More_considerations
     for the public: 
	wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial 4.0 International Public
License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.

Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.
  d. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  e. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  f. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  g. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  h. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  i. NonCommercial means not primarily intended for or directed towards
     commercial advantage or monetary compensation. For purposes of
     this Public License, the exchange of the Licensed Material for
     other material subject to Copyright and Similar Rights by digital
     file-sharing or similar means is NonCommercial provided there is
     no payment of monetary compensation in connection with the
     exchange.

  j. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  k. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  l. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.

Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part, for NonCommercial purposes only; and

            b. produce, reproduce, and Share Adapted Material for
               NonCommercial purposes only.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.

       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties, including when
          the Licensed Material is used other than for NonCommercial
          purposes.

Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.

       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

       4. If You Share Adapted Material You produce, the Adapter's
          License You apply must not prevent recipients of the Adapted
          Material from complying with this Public License.

Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the database for NonCommercial purposes
     only;

  b. if You include all or a substantial portion of the database
     contents in a database in which You have Sui Generis Database
     Rights, then the database in which You have Sui Generis Database
     Rights (but not its individual contents) is Adapted Material; and

  c. You must comply with the conditions in Section 3(a) if You Share
     all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.

Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided
     above shall be interpreted in a manner that, to the extent
     possible, most closely approximates an absolute disclaimer and
     waiver of all liability.

Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and
     Similar Rights licensed here. However, if You fail to comply with
     this Public License, then Your rights under this Public License
     terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under
     Section 6(a), it reinstates:

       1. automatically as of the date the violation is cured, provided
          it is cured within 30 days of Your discovery of the
          violation; or

       2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any
     right the Licensor may have to seek remedies for Your violations
     of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the
     Licensed Material under separate terms or conditions or stop
     distributing the Licensed Material at any time; however, doing so
     will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
     License.

Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different
     terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the
     Licensed Material not stated herein are separate from and
     independent of the terms and conditions of this Public License.

Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and
     shall not be interpreted to, reduce, limit, restrict, or impose
     conditions on any use of the Licensed Material that could lawfully
     be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is
     deemed unenforceable, it shall be automatically reformed to the
     minimum extent necessary to make it enforceable. If the provision
     cannot be reformed, it shall be severed from this Public License
     without affecting the enforceability of the remaining terms and
     conditions.

  c. No term or condition of this Public License will be waived and no
     failure to comply consented to unless expressly agreed to by the
     Licensor.

  d. Nothing in this Public License constitutes or may be interpreted
     as a limitation upon, or waiver of, any privileges and immunities
     that apply to the Licensor or You, including from the legal
     processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.


================================================
FILE: README.md
================================================
# PAQ: 65 Million Probably-Asked Questions and What You Can Do With Them


This repository contains code and models to support the research paper [PAQ: 65 Million Probably-Asked Questions and What You Can Do With Them](https://arxiv.org/abs/2102.07033)

<br>
<p align="center">
  <img src="https://dl.fbaipublicfiles.com/MLQA/logos.png" alt="Facebook AI Research and UCL NLP"  width="60%"/>
  <br>
</p>
<br>

## Table of Contents

* [Table of Contents](#table-of-contents)
* [Data Downloads](#data-downloads)
  * [PAQ QA-pairs](#paq-qa-pairs)
  * [PAQ Metadata](#paq-metadata)
  * [Preprocessed Wikipedia Dump](#preprocessed-wikipedia-dump)
  * [Passage Selector Scores](#passage-selector-scores)
  * [PAQ QA-pair metadata](#paq-qa-pair-metadata)
  * [PAQ <em>unfiltered</em> QA-pair metadata](#paq-unfiltered-qa-pair-metadata)
  * [Training/Dev/Test QA Pairs](#trainingdevtest-qa-pairs)
* [Code and Models](#code-and-models)
  * [Installation and Setup:](#installation-and-setup)
  * [Download Tool](#download-tool)
  * [Question Answering with RePAQ](#question-answering-with-repaq)
     * [RePAQ Retrievers:](#repaq-retrievers)
        * [Minimal Retrieval Inference Example:](#minimal-retrieval-inference-example)
        * [Retriever Models, Precomputed Vectors and Indexes:](#retriever-models-precomputed-vectors-and-indexes)
        * [Embedding QA pairs:](#embedding-qa-pairs)
        * [Building indices:](#building-indices)
        * [Retriever Inference:](#retriever-inference)
        * [Evaluating Retriever Results:](#evaluating-retriever-results)
     * [RePAQ ReRankers:](#repaq-rerankers)
        * [Minimal Reranker Inference example:](#minimal-reranker-inference-example)
        * [Reranker Models:](#reranker-models)
        * [ReRanker Inference:](#reranker-inference)
        * [Evaluating Rerankers:](#evaluating-rerankers)
  * [Question-Answer Pair Generation](#question-answer-pair-generation)
     * [Passage Scoring/Ranking](#passage-scoringranking)
     * [Answer Extraction](#answer-extraction)
     * [Question Generation](#question-generation)
     * [Filtering Generated QA-pairs](#filtering-generated-qa-pairs)
     * [End2End Generation Tool](#end2end-generation-tool)
* [Citing](#citing)
* [LICENSE](#license)
  * [Code License:](#code-license)
  * [Data License:](#data-license)
 
## Data Downloads

This section describes the downloads for the PAQ QA pairs, their metadata, the preprocessed Wikipedia dump, and the train/dev/test QA pairs. To download models, indices, etc., see the [Code and Models](#code-and-models) section.
You can also fetch this data with the `paq.download` tool, which is the recommended way to download models, indices, etc.; see the [Download Tool](#download-tool) section for usage.

### PAQ QA-pairs

The PAQ QA pairs can be downloaded below. We use the same format as for NQ-open (see [here](https://github.com/google-research-datasets/natural-questions/tree/master/nq_open)).
`TQA_TRAIN_NQ_TRAIN_PAQ` is the concatenation of the TriviaQA and NQ training QA pairs with the PAQ QA pairs.

| Dataset  | # QAs | Size (unzipped)| link | License |
| ------------- | ------------- | --------- | ---- | -----|
| PAQ     | 64.9M   | 5.8 GB| [download](https://dl.fbaipublicfiles.com/paq/v1/PAQ.tar.gz) |  [CC-BY-SA](https://creativecommons.org/licenses/by-sa/3.0/)|
| PAQ-L1  | 14.1M   | 1.3 GB| [download](https://dl.fbaipublicfiles.com/paq/v1/PAQ_L1.tar.gz) | [CC-BY-SA](https://creativecommons.org/licenses/by-sa/3.0/)|
| PAQ-L4  |  53.8M  | 4.9 GB| [download](https://dl.fbaipublicfiles.com/paq/v1/PAQ_L4.tar.gz) | [CC-BY-SA](https://creativecommons.org/licenses/by-sa/3.0/)|
| PAQ-NE1 | 12.0M   | 1.0 GB| [download](https://dl.fbaipublicfiles.com/paq/v1/PAQ_NE1.tar.gz) | [CC-BY-SA](https://creativecommons.org/licenses/by-sa/3.0/)|
| TQA_TRAIN_NQ_TRAIN_PAQ | 65.00M   | 5.9 GB| [download](https://dl.fbaipublicfiles.com/paq/v1/TQA_TRAIN_NQ_TRAIN_PAQ.tar.gz) | [CC-BY-SA](https://creativecommons.org/licenses/by-sa/3.0/)|
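
As a minimal sketch of reading one of these files once the archive has been extracted: the code below assumes the NQ-open convention of one JSON object per line with a `question` string and an `answer` list. The function names and the file path are hypothetical and are not part of the `paq` package.

```python
import json
from typing import Any, Dict, Iterable, Iterator


def parse_qa_lines(lines: Iterable[str]) -> Iterator[Dict[str, Any]]:
    """Yield one QA pair per non-empty line of jsonl text."""
    for line in lines:
        line = line.strip()
        if line:
            yield json.loads(line)


def read_qa_pairs(path: str) -> Iterator[Dict[str, Any]]:
    """Stream QA pairs from a jsonl file without loading it all into memory."""
    with open(path, encoding="utf-8") as handle:
        yield from parse_qa_lines(handle)


# Hypothetical usage; the path depends on where you extracted the archive:
# for pair in read_qa_pairs("path/to/extracted_qa_pairs.jsonl"):
#     print(pair["question"], pair["answer"])
```

Streaming line by line matters here because the largest files are several gigabytes unzipped.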


###  PAQ Metadata

Metadata to support PAQ can be downloaded from the following table. See the descriptions below for details:

| Dataset  | Size (unzipped)| link | License |
| ------------- | ------------- | --------- |  ----|
| Preprocessed Wikipedia Dump   | 13 GB| [download](https://dl.fbaipublicfiles.com/paq/v1/psgs_w100.tsv.gz) | [CC-BY-SA](https://creativecommons.org/licenses/by-sa/3.0/)|
| Passage Selector Scores  | 560 MB| [download](https://dl.fbaipublicfiles.com/paq/v1/PASSAGE_SCORES.tar.gz) | [CC-BY-SA](https://creativecommons.org/licenses/by-sa/3.0/)|
| PAQ QA-Pair metadata  |  16 GB| [download](https://dl.fbaipublicfiles.com/paq/v1/PAQ.metadata.jsonl.gz) | [CC-BY-SA](https://creativecommons.org/licenses/by-sa/3.0/)|
| PAQ *unfiltered* QA-pairs and metadata | 95 GB| [download](https://dl.fbaipublicfiles.com/paq/v1/PAQ.unfiltered_metadata.jsonl.gz) | [CC-BY-SA](https://creativecommons.org/licenses/by-sa/3.0/)|

### Preprocessed Wikipedia Dump

This file contains the preprocessed Wikipedia dump used to generate PAQ. It consists of 100-word passages from a 2018 Wikipedia dump, and was produced by [Karpukhin et al.](https://github.com/facebookresearch/DPR) for [DPR](https://github.com/facebookresearch/DPR).
The file is in TSV format with 3 columns: the passage id, the passage text, and the Wikipedia article title.
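
The three-column layout can be read with Python's standard `csv` module. This is a sketch under two assumptions that the description above does not confirm: that the file may start with a header row whose first cell is `id`, and that text fields may be wrapped in double quotes. The names `Passage` and `read_passages` are illustrative only.

```python
import csv
from typing import Iterator, NamedTuple, TextIO


class Passage(NamedTuple):
    passage_id: str
    text: str
    title: str


def read_passages(handle: TextIO, skip_header: bool = True) -> Iterator[Passage]:
    """Yield passages from a tab-separated stream of (id, text, title) rows."""
    reader = csv.reader(handle, delimiter="\t")
    for i, row in enumerate(reader):
        if not row:
            continue  # blank line
        # Assumption: an optional header row starts with the literal "id".
        if i == 0 and skip_header and row[0] == "id":
            continue
        passage_id, text, title = row  # raises ValueError on malformed rows
        yield Passage(passage_id, text, title)
```

For the compressed download, the stream could be opened with `gzip.open(path, "rt", encoding="utf-8", newline="")` and passed to `read_passages` directly.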

### Passage Selector Scores

This file contains the passage selection scores for passages, using the passage selection model described in the paper.
The file is in TSV format, with 2 columns. The first column is passage id (see "Preprocessed Wikipedia Dump"), the second column is the logprob score from the passage selector for that passage.
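
A small sketch of how these scores might be used to pick the highest-scoring passages. It assumes each line is exactly `passage_id<TAB>score` with no header row; the function names are hypothetical.

```python
import heapq
from typing import Dict, Iterable, List, Tuple


def parse_passage_scores(lines: Iterable[str]) -> Dict[str, float]:
    """Map passage id -> log-probability score from two-column TSV lines."""
    scores: Dict[str, float] = {}
    for line in lines:
        line = line.rstrip("\n")
        if not line:
            continue
        passage_id, score = line.split("\t")
        scores[passage_id] = float(score)
    return scores


def top_k_passages(scores: Dict[str, float], k: int) -> List[Tuple[str, float]]:
    """Return the k (passage_id, score) pairs with the highest scores."""
    return heapq.nlargest(k, scores.items(), key=lambda item: item[1])
```

Because the scores are log-probabilities, values closer to zero indicate passages the selector rates more highly.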

### PAQ QA-pair metadata

This file contains metadata for the QA pairs in PAQ. The file is in jsonl format. Each line is a json dict with metadata for one question-answer pair in PAQ.
The format is as follows:
```
{
    "question":  question string
    "subsets":   list of PAQ subsets the question appears in ("L1", "L4" or "NE")
    "answer":  the question's answer produced by the consistency filter model
    "passage_score": passage selection score of highest scoring passage that generated this question
    "answers": [
        {
            "passage_id": id of wiki passage this answer was extracted from (see "Preprocessed Wikipedia Dump")
            "offset": character offset to start of answer span
            "text": text of answer span
            "extractor": answer extractor model, either "L" (for learnt extractor), or "NE" (for Named Entity extractor)
        },
        ...
    ]
}
```
For a small number of questions, the subset is "NE-legacy". These questions were generated by an earlier iteration of the "NE" generation pipeline.
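
The record layout above can be consumed one line at a time. The sketch below filters records by subset and turns each answer into a character span. It assumes `offset` is stored as an integer and that the span ends at `offset + len(text)`; both helper names are hypothetical.

```python
import json
from typing import Any, Dict, Iterable, Iterator, List, Tuple


def iter_subset(lines: Iterable[str], subset: str) -> Iterator[Dict[str, Any]]:
    """Yield metadata records whose "subsets" list contains `subset`."""
    for line in lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        if subset in record.get("subsets", []):
            yield record


def answer_spans(record: Dict[str, Any]) -> List[Tuple[str, int, int]]:
    """Return (passage_id, start, end) character spans for each answer."""
    return [
        (a["passage_id"], a["offset"], a["offset"] + len(a["text"]))
        for a in record["answers"]
    ]
```

The spans index into the passage text identified by `passage_id` in the preprocessed Wikipedia dump.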

### PAQ *unfiltered* QA-pair metadata

This file contains similar metadata to that described above in "PAQ QA pair metadata", but for *all* generated questions, even those that do not pass the consistency filter. 
As such, this is a very large file, and is provided for completeness, but should not be of interest to most users interested in PAQ metadata.
The file is in jsonl format. Each line is a json dict with metadata for one question-answer pair.
The format is as follows:
```
{
    "question":  question string
    "subsets":   list of PAQ subsets the question appears in ("L1", "L4" or "NE")
    "consistent_subsets":  list of PAQ subsets the question appears in, which pass the consistency filters ("L1", "L4" or "NE")
    "canonical_answer":  the question's answer produced by the consistency filter model
    "consistent": boolean. If true, the question passes the global consistency filter
    "passage_score": passage selection score of highest scoring passage that generated this question
    "answers": [
        {
            "passage_id": id of wiki passage this answer was extracted from (see "Preprocessed Wikipedia Dump")
            "offset": character offset to start of answer span
            "text": text of answer span
            "extractor": answer extractor model, either "L" (for learnt extractor), or "NE" (for Named Entity extractor)
            "consistent": boolean. If true, this answer span is consistent with the answer from the global consistency filter
        },
        ...
    ]
}
```
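
To recover only the consistent material from the unfiltered file, the two `consistent` flags can be combined. This is a sketch that assumes both flags are JSON booleans as described above; the helper name is hypothetical.

```python
from typing import Any, Dict, List


def consistent_answers(record: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return the answer spans that agree with the global consistency filter.

    Records that fail the question-level filter yield an empty list.
    """
    if not record.get("consistent", False):
        return []
    return [a for a in record.get("answers", []) if a.get("consistent", False)]
```

Given the size of this file, it is probably best to apply this per line while streaming rather than after loading everything.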

### Training/Dev/Test QA Pairs

The QA Pairs in the Open Domain NaturalQuestions and TriviaQA Train/Dev/Test sets are available below, as well as a file with the concatenation of the training sets and PAQ (useful for retrieval later).


| Dataset  | Description | Link | 
| ------------- |------------- | --------- |
| NQ-open.train-train.jsonl | Open-NaturalQuestions Training set | [download](https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/NQ-open.train-train.jsonl)|
| NQ-open.train-dev.jsonl |  Open-NaturalQuestions Development set| [download](https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/NQ-open.train-dev.jsonl)|
| NQ-open.test.jsonl |  Open-NaturalQuestions Test set| [download](https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/NQ-open.test.jsonl)|
| triviaqa.train-train.jsonl | Open-TriviaQA Training set  | [download](https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/triviaqa.train-train.jsonl)|
| triviaqa.train-dev.jsonl | Open-TriviaQA Development set| [download](https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/triviaqa.train-dev.jsonl)|
| triviaqa.test.jsonl | Open-TriviaQA Test set| [download](https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/triviaqa.test.jsonl)|
| tqa-train-nq-train-PAQ.jsonl | Concatenation of NQ-open.train-train.jsonl, triviaqa.train-train.jsonl and PAQ | [download](https://dl.fbaipublicfiles.com/paq/v1/TQA_TRAIN_NQ_TRAIN_PAQ.tar.gz)|

## Code and Models

All users should follow the instructions in [Installation and Setup](#installation-and-setup), and use the [Download Tool](#download-tool), which makes downloading models and assets much easier.

Code to run inference for question answering with RePAQ and the full question generation pipeline is now available. Functionality to help train your own models is coming soon.

Users interested in running question answering with RePAQ should read [Question Answering with RePAQ](#question-answering-with-repaq).

Users interested in running question generation with the PAQ generation pipeline should read [Question-Answer Pair Generation](#question-answer-pair-generation).



### Installation and Setup:

We highly recommend you use conda environments. The requirements are pytorch, spacy, Transformers 4.1.0 (other versions are unlikely to work), FID, and the packages listed in `requirements.txt`.
The following script should install all necessary code dependencies:

```bash
conda create -n paq python=3.7
conda activate paq

# install pytorch
conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.1 -c pytorch
conda install -c pytorch faiss-gpu cudatoolkit=10.1

# For Spacy:
conda install -c conda-forge spacy
conda install -c conda-forge cupy
python -m spacy download en_core_web_sm
pip install -r requirements.txt

# Install FID for QA-pair consistency filtering:
git clone git@github.com:facebookresearch/FiD.git
cd FiD; git checkout baf533c3f7a26c1cac624ee9252ce5ccf344a935

```

### Download Tool

To make downloading resources easier, we've built a download tool.
This is the recommended way for downloading data, trained models, precomputed vectors and indices. 
This will download and uncompress resources to the `./data` directory, where the code will expect these resources to be, and handle path management.
Run it by supplying a resource key name (run with `-h` to see available resources): 
```bash
# Downloads a RePAQ retriever model:
$ python -m paq.download -v -n models.retrievers.retriever_multi_base_256
```

### Question Answering with RePAQ

Question Answering over PAQ with RePAQ is accomplished using [Dense Retrieval](#repaq-retrievers), optionally followed by [Reranking](#repaq-rerankers). 
Reranking will improve accuracy, but is slower.

To enable wider use of our work,
we have trained more compact retrievers and indices than those used in the original paper.
These still give strong results, but run on machines with smaller GPUs and modest amounts of CPU RAM (64GB CPU RAM should be plenty). 
These models are only marginally less accurate than the larger ones used in the paper, and we list them as "recommended" in the tables below.



#### RePAQ Retrievers:

##### Minimal Retrieval Inference Example:
TL;DR if you just want to run retrieval:

First, download 1) A retrieval model, 2) A KB of QA Pairs (in our case, TQA train set, NQ train set and PAQ) and 3) a pre-built index for those QA Pairs.

```bash
# download retriever model
$ python -m paq.download -v -n models.retrievers.retriever_multi_base_256

# Download QA Pairs, and a corresponding faiss index:
$ python -m paq.download -v -n paq.TQA_TRAIN_NQ_TRAIN_PAQ
$ python -m paq.download -v -n indices.multi_base_256_hnsw_sq8

# Download NaturalQuestions data, we'll run inference on the test set
$ python -m paq.download -v -n annotated_datasets.naturalquestions

```

Then, run retrieval inference (here we're using the very fast but slightly less accurate HNSW faiss index):

```bash
$ python -m paq.retrievers.retrieve \
    --model_name_or_path ./data/models/retrievers/retriever_multi_base_256 \
    --qas_to_answer data/annotated_datasets/NQ-open.test.jsonl \
    --qas_to_retrieve_from ./data/paq/TQA_TRAIN_NQ_TRAIN_PAQ/tqa-train-nq-train-PAQ.jsonl \
    --top_k 50 \
    --output_file my_retrieval_results.jsonl \
    --faiss_index_path data/indices/multi_base_256_hnsw_sq8.faiss \
    --fp16 \
    --memory_friendly_parsing \
    --verbose
```


Finally, either use a reranker to rerank the top K results (see [here](#minimal-reranker-inference-example)), or evaluate retrieval performance:

```bash
$ python -m paq.evaluation.eval_retriever \
    --predictions my_retrieval_results.jsonl \
    --references data/annotated_datasets/NQ-open.test.jsonl \
    --hits_at_k 1,10,50
1: 40.0%
(1443 / 3610)
10: 55.7%
(2010 / 3610)
50: 63.9%
(2306 / 3610)
```


##### Retriever Models, Precomputed Vectors and Indexes:

The following table lists the recommended models for inference. 
For an exhaustive list of models available, see [full_models_list.md](./full_models_list.md). 
We highly recommend using `retriever_multi_base_256`. 
This model has been designed to be compute and memory-friendly. 
Its embedding dimension is 256, compared to the 768 used in the original paper, saving RAM when performing retrieval.
It outperforms the base model from the paper, and loses only 0.7% on average over the xlarge model from the paper.


| Model  | Training data |  Architecture | Embedding Dim | NQ EM | + rerank | TQA EM | + rerank |  Download Resource Key Name |
| ------------- |----------| --- | --------- | ---------- |---- |---- | ---- | ---- |
| retriever_multi_base_256  (recommended)| NQ + TQA | AlBERT-base | 256  | 41.4 | 47.3 | 40.2 | 50.9| `models.retrievers.retriever_multi_base_256` |
| retriever_multi_base | NQ + TQA | AlBERT-base |  768 | 40.9| 47.4 | 39.7 | 51.2 | `models.retrievers.retriever_multi_base`  |
| retriever_multi_xlarge | NQ + TQA | AlBERT-xlarge| 768  | 41.7 | 47.6 | 41.3 | 52.1 |`models.retrievers.retriever_multi_xlarge`|

The table below lists available precomputed embeddings and indices for download. The embeddings are stored according to the order in tqa-train-nq-train-PAQ.jsonl, corresponding to the TQA training set, the NQ training set and PAQ.
I.e. the kth QA pair in the file is embedded in the kth vector in these files.
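
Because of this row-for-row alignment, a vector id returned by a search can be mapped back to its QA pair by line position. A minimal sketch follows, assuming 0-based row ids. The helper `lookup_qa_pairs` and the sample field names are hypothetical and not part of the PAQ codebase.

```python
import json
from typing import Dict, Iterable, List


def lookup_qa_pairs(lines: Iterable[str], vector_ids: List[int]) -> Dict[int, dict]:
    """Return {row id: parsed QA pair} for the requested 0-based rows.

    The input is streamed once and reading stops after the largest
    requested row, so the whole KB never has to fit in memory.
    """
    wanted = set(vector_ids)
    found: Dict[int, dict] = {}
    if not wanted:
        return found
    last = max(wanted)
    for row, line in enumerate(lines):
        if row in wanted:
            found[row] = json.loads(line)
        if row >= last:
            break
    return found


# Hypothetical QA pairs; the field names are illustrative only.
sample_lines = [
    json.dumps({"question": f"question {i}", "answer": [f"answer {i}"]})
    for i in range(5)
]

print(lookup_qa_pairs(sample_lines, [3, 1]))
```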

To download precomputed vectors, use the `paq/download.py` script, as indicated in the table. 
We recommend using the FAISS indexes for running inference, either `multi_base_256.flat.sq8.faiss`
(slower, ones to tens of questions/sec, but more accurate, with the lowest memory requirement, ~16GB RAM), 
or `multi_base_256.hnsw.sq8.faiss` (very fast, hundreds to thousands of questions/sec depending on machine, slightly less accurate (0.8% on average), but with a higher memory requirement, ~32GB RAM).


| File  | Description | Size |  Download Resource Key Name  | 
| ------------- |------------- | --------- |---- |
| tqa-train-nq-train-PAQ.jsonl (required) | Concatenation of NQ-open.train-train.jsonl, triviaqa.train-train.jsonl and PAQ | | `paq.TQA_TRAIN_NQ_TRAIN_PAQ` |
| multi_base_256_vectors | embeddings for QAS in `tqa-train-nq-train-PAQ.jsonl` using `retriever_multi_base_256` | 16GB | `vectors.multi_base_vectors_256` |
| multi_base_vectors | embeddings for QAS in `tqa-train-nq-train-PAQ.jsonl` using `retriever_multi_base` | 48GB |`vectors.multi_base_vectors` |
| multi_xlarge_vectors| embeddings for QAS in `tqa-train-nq-train-PAQ.jsonl` using `retriever_multi_xlarge` | 48GB| `vectors.multi_xlarge_vectors` |
| multi_base_256.flat.sq8.faiss (recommended) | Flat FAISS index for `retriever_multi_base_256` - slower (1-10s questions / sec) | 16GB | `indices.multi_base_256_flat_sq8.faiss`|
| multi_base_256.hnsw.sq8.faiss (recommended) | Fast FAISS index for `retriever_multi_base_256` - faster (100-1000s queries / sec) | 32GB | `indices.multi_base_256_hnsw_sq8.faiss`|
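
The index sizes above can be sanity-checked with a back-of-the-envelope estimate. The sketch below rests on several assumptions that the table does not state: roughly 65 million QA pairs (the figure in the paper title), one byte per dimension for 8-bit scalar quantization, and an HNSW graph that stores `2 * M` 32-bit neighbour ids per vector at its base layer with `M = 32`. Upper graph layers and other overheads are ignored.

```python
def sq8_flat_bytes(n_vectors: int, dim: int) -> int:
    """8-bit scalar quantization stores one byte per dimension."""
    return n_vectors * dim


def hnsw_base_layer_link_bytes(n_vectors: int, m: int) -> int:
    """Rough graph overhead: 2 * M int32 neighbour ids per vector at the base layer."""
    return n_vectors * 2 * m * 4


N_QA_PAIRS = 65_000_000  # assumption: approximate size of the QA-pair KB
DIM = 256                # embedding dimension of retriever_multi_base_256
M = 32                   # assumption: HNSW links per node

flat_gb = sq8_flat_bytes(N_QA_PAIRS, DIM) / 1e9
hnsw_gb = flat_gb + hnsw_base_layer_link_bytes(N_QA_PAIRS, M) / 1e9
print(f"flat SQ8: ~{flat_gb:.1f} GB, HNSW SQ8: ~{hnsw_gb:.1f} GB")
```

Under these assumptions the estimates come out at roughly 16.6 GB and 33.3 GB, in line with the 16GB and 32GB listed in the table.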



##### Embedding QA pairs:

To embed a set of QA pairs in the NaturalQuestions jsonl format, use the `paq/retrievers/embed.py` file.

E.g. to embed the NQ training set to vectors using the `retriever_multi_base_256` model, and write them to disk,
run the following command:

```
python -m paq.retrievers.embed \
    --model_name_or_path ./data/models/retrievers/retriever_multi_base_256 \
    --qas_to_embed data/annotated_datasets/NQ-open.train-train.jsonl \
    --output_dir ./my_vectors \
    --fp16 \
    --batch_size 128 \
    --verbose \
    --n_jobs -1 
# see below for explanation of --n_jobs
```

For very large numbers of QA pairs, you may want to run this in parallel. 
This script is set up to work with submitit, and by default, can submit a slurm job array to embed the QA pairs in parallel.
For example, to run embedding locally, set `--n_jobs -1` (as above), or to run 10 parallel jobs to embed a file, run with `--n_jobs 10`.
The full command is given below:

```
python -m paq.retrievers.embed \
    --model_name_or_path ./data/models/retrievers/retriever_multi_base_256 \
    --qas_to_embed data/annotated_datasets/NQ-open.train-train.jsonl \
    --output_dir ./my_vectors_distributed \
    --fp16 \
    --batch_size 128 \
    --verbose \
    --memory_friendly_parsing \
    --n_jobs 10 \
    --slurm_partition my_clusters_partition \
    --slurm_comment "my embedding job"
    
```

The submitit job array config can be seen and edited for your cluster's needs in `paq/paq_utils.py` (the `get_submitit_executor` function).

##### Building indices:

To build faiss MIPS indices on vectors produced by `paq.retrievers.embed` (for improved quantization and speed over raw exact search in pytorch), use `paq/retrievers/build_index.py`.
This will allow you to build indices like the ones used in the paper (specifically, Flat and HNSW indices, optionally with scalar quantization).

```
# build a flat index with scalar quantization (slower queries, but slightly more accurate)
python -m paq.retrievers.build_index \
    --embeddings_dir ./my_vectors \
    --output_path ./my_index.faiss \
    --SQ8 \
    --verbose

# or, build an hnsw index with scalar quantization (much, much faster querying, slightly less accurate)
python -m paq.retrievers.build_index \
    --embeddings_dir ./my_vectors \
    --output_path ./my_index.hnsw.faiss \
    --hnsw \
    --SQ8 \
    --store_n 32 \
    --ef_construction 128 \
    --ef_search 128 \
    --verbose

```

Building indices is a deep, nuanced and complex area. The scripts we provide to build indices are mostly a convenience and reproducibility wrapper.
It's likely that stronger compression is possible without losing performance (e.g. by using Product Quantization), as is faster inference. 
If the indexes we provide are too large or slow, consider building your own by referring to the [faiss documentation](https://github.com/facebookresearch/faiss) directly.
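
To give an intuition for what 8-bit scalar quantization trades away, here is a simplified pure-Python illustration. It is not faiss's implementation. It assumes each dimension is mapped linearly from its observed `[min, max]` range onto 256 levels, and all function names are made up for this sketch.

```python
from typing import List, Sequence, Tuple

Range = Tuple[float, float]


def train_sq8(vectors: Sequence[Sequence[float]]) -> List[Range]:
    """Record the (min, max) of each dimension over the training vectors."""
    n_dims = len(vectors[0])
    return [
        (min(v[d] for v in vectors), max(v[d] for v in vectors))
        for d in range(n_dims)
    ]


def encode_sq8(vector: Sequence[float], ranges: Sequence[Range]) -> bytes:
    """Map each float to one byte (0..255) within its dimension's range."""
    codes = []
    for x, (lo, hi) in zip(vector, ranges):
        scale = (hi - lo) or 1.0  # avoid dividing by zero for constant dimensions
        clipped = min(max(x, lo), hi)
        codes.append(round((clipped - lo) / scale * 255))
    return bytes(codes)


def decode_sq8(codes: bytes, ranges: Sequence[Range]) -> List[float]:
    """Reconstruct approximate floats from the one-byte codes."""
    return [lo + (c / 255) * (hi - lo) for c, (lo, hi) in zip(codes, ranges)]


vectors = [[0.0, -1.0], [1.0, 1.0], [0.5, 0.0]]
ranges = train_sq8(vectors)
codes = encode_sq8(vectors[2], ranges)
print(len(codes), decode_sq8(codes, ranges))
```

Each dimension costs one byte instead of four (fp32), and the reconstruction error per dimension is at most half a quantization step, i.e. `(max - min) / 255 / 2`.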

##### Retriever Inference:

Run QA-pair Retrieval using `paq/retrievers/retrieve.py`. You can see argument help by passing `-h`. 
You must pass in a jsonl file of QA pairs to retrieve from, using the `--qas_to_retrieve_from` argument. 
You can also pass in either a directory of embeddings for the qa-pairs to retrieve from using the `--precomputed_embeddings_dir` (e.g. the output of `paq.retrievers.embed`) 
or a faiss index of the QA pairs to retrieve from, using the `--faiss_index_path`. If neither `--faiss_index_path` nor `--precomputed_embeddings_dir` is given, the QA pairs to retrieve from will be embedded on the fly. This may be slow for large QA-pair KBs.

The following command will retrieve the top 50 QA pairs from the PAQ KB for the NQ test set, using the fast HNSW faiss index, and write the results to `my_retrieval_results.jsonl`:

```bash
#Download the relevant artefacts
$ python -m paq.download -v -n models.retrievers.retriever_multi_base_256
$ python -m paq.download -v -n paq.TQA_TRAIN_NQ_TRAIN_PAQ
$ python -m paq.download -v -n indices.multi_base_256_hnsw_sq8
$ python -m paq.download -v -n annotated_datasets.naturalquestions

$ python -m paq.retrievers.retrieve \
    --model_name_or_path ./data/models/retrievers/retriever_multi_base_256 \
    --qas_to_answer data/annotated_datasets/NQ-open.test.jsonl \
    --qas_to_retrieve_from ./data/paq/TQA_TRAIN_NQ_TRAIN_PAQ/tqa-train-nq-train-PAQ.jsonl \
    --top_k 50 \
    --output_file my_retrieval_results.jsonl \
    --faiss_index_path data/indices/multi_base_256_hnsw_sq8.faiss \
    --fp16 \
    --memory_friendly_parsing \
    --verbose
```
    
##### Evaluating Retriever Results:

Evaluate retrieval performance using the `paq.evaluation.eval_retriever` tool. 
It will return the hits@k (whether the correct answer is in the top K retrieved questions' answers). Hits@1 is equivalent to the Exact Match score.

```bash
$ python -m paq.evaluation.eval_retriever \
    --predictions my_retrieval_results.jsonl \
    --references data/annotated_datasets/NQ-open.test.jsonl \
    --hits_at_k 1,10,50
1: 40.0%
(1443 / 3610)
10: 55.7%
(2010 / 3610)
50: 63.9%
(2306 / 3610)
```
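
For intuition, hits@k can be sketched in a few lines. This is not the code in `paq.evaluation.eval_retriever`. It assumes one answer string per retrieved QA pair, ranked best first, and a SQuAD-style answer normalisation (lowercasing, dropping punctuation and articles). Both are assumptions made for illustration.

```python
import re
import string
from typing import List, Sequence

_PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and articles, and squeeze whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in _PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def hits_at_k(
    retrieved_answers: Sequence[Sequence[str]],
    gold_answers: Sequence[Sequence[str]],
    k: int,
) -> float:
    """Fraction of questions with a gold answer among the top-k retrieved answers."""
    hits = 0
    for retrieved, gold in zip(retrieved_answers, gold_answers):
        gold_set = {normalize_answer(g) for g in gold}
        if any(normalize_answer(a) in gold_set for a in retrieved[:k]):
            hits += 1
    return hits / len(gold_answers)


# Toy data: two questions, two retrieved answers each.
retrieved = [["Paris", "Lyon"], ["Berlin", "The Hague"]]
gold = [["paris"], ["Hague"]]
print(hits_at_k(retrieved, gold, 1), hits_at_k(retrieved, gold, 2))
```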


#### RePAQ ReRankers:

##### Minimal Reranker Inference example:
TL;DR if you just want to run reranking:
First, download a reranker model (and, if you don't already have retrieval results to rerank, download some):

```bash
# download reranker model (here we're using the ALBERT-xxlarge model; smaller ones are available)
$ python -m paq.download -v -n models.rerankers.reranker_multi_xxlarge

# download some retrieval results to rerank if you don't already have some
$ python -m paq.download -v -n predictions.retriever_results.multi_xlarge_nq

```
Next, run reranking:
```
$ python -m paq.rerankers.rerank \
    --model_name_or_path data/models/rerankers/reranker_multi_xxlarge  \
    --qas_to_rerank data/predictions/retriever_results/multi_xlarge_nq_test.jsonl \
    --output_file my_reranker_results.jsonl \
    --top_k 50 \
    --fp16 \
    --batch_size 4 --verbose --n_jobs -1
```

Then calculate results:
```
$ python -m paq.evaluation.eval_reranker --predictions my_reranker_results.jsonl --references data/annotated_datasets/NQ-open.test.jsonl
47.6%
(1699 / 3610)
```


##### Reranker Models:

The following table lists the recommended models for inference. 
For an exhaustive list of models available, see [full_models_list.md](./full_models_list.md). 

| Model  | Training data |  Architecture | NQ EM | TQA EM |  Download Resource Key Name |
| ------------- |----------| --- | --------- | ---------- |---- |
|reranker_multi_base| NQ + TQA| AlBERT-base |46.0 |48.9 | `models.rerankers.reranker_multi_base`| 
|reranker_multi_large| NQ + TQA|AlBERT-large | 46.2| 49.4|`models.rerankers.reranker_multi_large`| 
|reranker_multi_xlarge| NQ + TQA|AlBERT-xlarge | 46.0| 49.1| `models.rerankers.reranker_multi_xlarge`| 
|reranker_multi_xxlarge| NQ + TQA|AlBERT-xxlarge | 47.7| 52.1 | `models.rerankers.reranker_multi_xxlarge`| 

##### ReRanker Inference:

Run QA-pair reranking using `paq/rerankers/rerank.py`. You can see argument help by passing `-h`. 
Pass retrieval results files in the format produced by `paq/retrievers/retrieve.py` to the `--qas_to_rerank` argument.

If you have many retrieval results files to rerank, it might be useful to submit them to a cluster using `submitit` to run in parallel rather than run them one by one locally.

You can pass in a comma-separated list of retrieval results filepaths to `--qas_to_rerank` (and a corresponding comma-separated list of output paths to `--output_file`) to do this, and specify the number of parallel jobs to schedule using `--n_jobs`. To run reranking locally, pass in `--n_jobs -1`.

An example of reranking the top 50 retrieved QA pairs on the NQ test set, using the ALBERT-xxlarge model running locally is shown below:
```bash
# download resources if needed:
python -m paq.download -v -n annotated_datasets.naturalquestions
python -m paq.download -v -n models.rerankers.reranker_multi_xxlarge
python -m paq.download -v -n predictions.retriever_results.multi_xlarge_nq

# run reranking
python -m paq.rerankers.rerank \
    --model_name_or_path data/models/rerankers/reranker_multi_xxlarge  \
    --qas_to_rerank data/predictions/retriever_results/multi_xlarge_nq_test.jsonl \
    --output_file my_reranker_results.jsonl \
    --top_k 50 \
    --fp16 \
    --batch_size 4 --verbose --n_jobs -1
```
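
When reranking several files at once, the two comma-separated lists must line up one to one. A small helper can build both values together so they cannot drift apart. In this sketch the helper name and the `.reranked.jsonl` output naming scheme are assumptions for illustration, and only the two flag values are produced.

```python
from pathlib import PurePosixPath
from typing import List, Tuple


def build_rerank_args(retrieval_files: List[str], output_dir: str) -> Tuple[str, str]:
    """Return the values for --qas_to_rerank and --output_file, in matching order."""
    outputs = []
    for path in retrieval_files:
        stem = PurePosixPath(path).stem  # file name without its final extension
        outputs.append(str(PurePosixPath(output_dir) / f"{stem}.reranked.jsonl"))
    return ",".join(retrieval_files), ",".join(outputs)


qas_to_rerank, output_file = build_rerank_args(
    ["results/nq_test.jsonl", "results/tqa_test.jsonl"], "reranked"
)
print("--qas_to_rerank", qas_to_rerank)
print("--output_file", output_file)
```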

##### Evaluating Rerankers:
Evaluate the results of reranking using the `eval_reranker.py` file, which returns the Exact Match score:

```
$ python -m paq.evaluation.eval_reranker --predictions my_reranker_results.jsonl --references data/annotated_datasets/NQ-open.test.jsonl
47.6%
(1699 / 3610)
```

### Question-Answer Pair Generation

The following sections detail how to run the PAQ QA-pair generation.

TL;DR for users who just want to generate QA pairs: the easiest way to generate QA pairs is to follow the [End2End Generation Tool](#end2end-generation-tool) section.

Each step in the pipeline can be run by itself, as described in the [Passage Scoring/Ranking](#passage-scoringranking), [Answer Extraction](#answer-extraction), [Question Generation](#question-generation) and [Filtering Generated QA-pairs](#filtering-generated-qa-pairs) sections,
or the generation pipeline can be run fully end2end (from passages to filtered QA pairs), as described in the [End2End Generation Tool](#end2end-generation-tool) section.

Training code for training your own models is coming soon.

The pipelines have a lot of configurations and options, so to keep track of these, we use json config files to specify pipeline behaviours.
A number of example configs are listed in the `generator_configs` directory, or you can adapt them or write your own to fit your own needs.

#### Passage Scoring/Ranking

To perform passage ranking, use the `paq.generation.passage_scorer.score_passages` program, which takes as input a config json file and file of passages formatted as a tsv (passage id, passage text, passage title).

There are three passage rankers implemented:
* `DummyPassageScorer`: Applies the same score to all documents. An example config for this scorer is `generator_configs/passage_ranker_configs/dummy_passage_scorer_config.json`
* `LookupPassageScorer`: Looks up precomputed scores based on passage id (useful if you run the same passages through the pipeline a lot, and want to save compute). An example config for this scorer is `generator_configs/passage_ranker_configs/lookup_passage_scorer_config.json`
* `LearntPassageScorer`: Use a trained Passage Scorer (as done in the Paper). An example config for this scorer is `generator_configs/passage_ranker_configs/learnt_passage_scorer_config.json`

A trained passage scorer is available for download: 

| Model  | Training data |  Architecture |  Download Resource Key Name |
| ------------- |---------- |---- | ---- |
| passage_ranker_base| NQ | BERT-base |  `models.passage_rankers.passage_ranker_base`| 

Note: the original passage ranker model used in the paper was unfortunately lost due to a storage corruption issue.
The model available here is a reproduction using the same hardware and hyperparameters, but differs a little due to the stochastic training sampling procedure.

Below is an example that computes passage scores for the first 1000 passages of Wikipedia:

```bash
# download the passage scorer model, and wikipedia text
python -m paq.download -v -n models.passage_rankers.passage_ranker_base
python -m paq.download -v -n paq.psgs_w100

# get 1000 passages to score
head -n 1000 data/paq/psgs_w100.tsv > data/paq/psgs_w100.first_1000.tsv

# run scoring
python -m paq.generation.passage_scorer.score_passages \
    --passages_to_score data/paq/psgs_w100.first_1000.tsv \
    --output_path my_passages_with_scores.jsonl \
    --path_to_config generator_configs/passage_ranker_configs/learnt_passage_scorer_config.json \
    --verbose

```

This will output a jsonl file with the following format (which is accepted by the [Answer Extraction](#answer-extraction) component below)
```json
{
  "passage_id": "ID for passage", 
  "passage": "Main text of passage.",
  "metadata": {"title": "Title of passage", "ps_score": "passage score"}
}
```
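
The scored output can then be used to pick which passages to send further down the pipeline. The sketch below selects the `n` highest-scoring passages from such a file. It assumes `ps_score` is numeric (or a numeric string), and the helper name and sample values are hypothetical.

```python
import heapq
import json
from typing import Iterable, List


def top_scoring_passages(lines: Iterable[str], n: int) -> List[dict]:
    """Return the n records with the highest metadata.ps_score, best first."""
    records = (json.loads(line) for line in lines if line.strip())
    return heapq.nlargest(n, records, key=lambda r: float(r["metadata"]["ps_score"]))


# Made-up scored passages in the documented output format.
sample_lines = [
    json.dumps({"passage_id": pid, "passage": f"Text of passage {pid}.",
                "metadata": {"title": f"Title {pid}", "ps_score": score}})
    for pid, score in [("a", -1.2), ("b", 0.3), ("c", 2.5)]
]

print([p["passage_id"] for p in top_scoring_passages(sample_lines, 2)])
```

`heapq.nlargest` keeps only `n` records in memory at a time, which matters when scoring a full Wikipedia dump.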

#### Answer Extraction

To perform answer extraction on passages, use the `paq.generation.answer_extractor.extract_answers` program, which takes as input a config file and passages formatted in the output format of the [Passage Scoring/Ranking](#passage-scoringranking) functionality.

There are two answer extractors implemented:
* `SpacyNERExtractor`: This answer extractor will extract named entities from passages as answers (as used in PAQ-NE). An example config for this extractor is `generator_configs/answer_extractor_configs/named_entity_answer_extractor_config.json` 
* `Span2DAnswerExtractor`: This answer extractor uses a learnt answer span extractor to extract answers (as used in PAQ-L).  An example config for this extractor is `generator_configs/answer_extractor_configs/learnt_answer_extractor_config.json`

The learnt answer span extractor model used in the paper is available for download:


| Model  | Description | Training data |  Architecture |  Download Resource Key Name |
| ----------| --- |---------- |---- | ---- |
| answer_extractor_nq_base| Learnt Answer Span Extractor, BERT-base, NQ-trained | NQ | BERT-base |  `models.answer_extractors.answer_extractor_nq_base`| 


Below is an example to extract answers from passages, using the learnt extractor:

```bash

# download the span extractor model:
python -m paq.download -v -n models.answer_extractors.answer_extractor_nq_base

# run answer extraction
python -m paq.generation.answer_extractor.extract_answers \
    --passages_to_extract_from my_passages_with_scores.jsonl \
    --output_path my_pasages_with_answers.jsonl \
    --path_to_config generator_configs/answer_extractor_configs/learnt_answer_extractor_config.json \
    --verbose
```
This will output a jsonl file with the following format (which is accepted by the [Question Generation](#question-generation) component below)

```json
{
  "passage_id": "ID for passage", 
  "passage": "Main text of passage.",
  "metadata": {"title": "Title of passage", "ps_score": "passage score"},
  "answers": [{"text": "Main", "start": 0, "end": 4, "score": "score for answer"}, {"text": "passage", "start": 13, "end": 20, "score": "score for answer"}]
}
```
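
A quick sanity check on extraction output is to confirm that each answer's offsets actually slice its text out of the passage. The sketch below assumes `start` and `end` are character offsets with an exclusive `end`, which is what the "passage" span (13 to 20) in the example suggests. The helper name is made up for illustration.

```python
from typing import List


def misaligned_answers(record: dict) -> List[dict]:
    """Return the answers whose [start, end) slice does not equal their text."""
    passage = record["passage"]
    return [
        answer for answer in record["answers"]
        if passage[answer["start"]:answer["end"]] != answer["text"]
    ]


sample_record = {
    "passage_id": "0",
    "passage": "Main text of passage.",
    "answers": [
        {"text": "passage", "start": 13, "end": 20},
        {"text": "text", "start": 5, "end": 9},
        {"text": "of", "start": 9, "end": 12},  # deliberately off by one
    ],
}

print(misaligned_answers(sample_record))
```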

#### Question Generation

To perform Question Generation on passages with extracted answers, use the `paq.generation.question_generator.generate_questions` program, which takes as input a config file and passages with answers formatted in the output format of the [Answer Extraction](#answer-extraction) functionality. 

An example config for question generation can be found here: `generator_configs/question_generator_configs/question_generation_config.json`

The following trained question generators are available:

| Model  | Training data |  Architecture |  Download Resource Key Name |
| ------------- |---------- |---- | ---- |
| qgen_nq_base| NQ | BART-base |  `models.qgen.qgen_nq_base`| 
| qgen_multi_base| Multitask | BART-base |  `models.qgen.qgen_multi_base`| 


Below is an example to generate questions from passages with extracted answers, using the multitask generator:

```bash

# download the qgen model:
python -m paq.download -v -n models.qgen.qgen_multi_base

# run question generation
python -m paq.generation.question_generator.generate_questions \
    --passage_answer_pairs_to_generate_from my_pasages_with_answers.jsonl \
    --output_path my_generated_questions.jsonl \
    --path_to_config generator_configs/question_generator_configs/question_generation_config.json \
    --verbose
```

This will output a jsonl file with the following format (which is accepted by the [Filtering Generated QA-pairs](#filtering-generated-qa-pairs) component below)

```json
{
  "passage_id": "ID for passage", 
  "answer": "Benedict", 
  "question": "which pope has the middle name gregory",
  "metadata": {"answer_start": 617, "answer_end": 625, "ae_score": "score for answer", "qg_score": "currently not implemented, but score for question can go here"}
}
```

#### Filtering Generated QA-pairs

Generated questions can be inconsistent, or poor quality, or overly ambiguous. 
Empirically, we find it important to filter the generated questions for answer consistency. 
To perform filtering on generated questions, use the `paq.generation.filtering.filter_questions` program, which takes as input a config file, and generated questions formatted in the output format of the [Question Generation](#question-generation) functionality.

Filtering is split into two parts: retrieval and reading. 
The retriever retrieves passages from a corpus using the generated question, and the reader reads the passages and computes an answer.

We have implemented the following filterers:
* *Dummy filtering*: uses a `DummyFilteringRetriever` and `DummyReader`, assigns all answers as consistent. An example config is `generator_configs/filterer_configs/dummy_filtering_config.json` 
* *Local filtering* (fast but not as good): essentially performs reading comprehension. Uses a `LocalFilteringRetriever` to "retrieve" the passage the question was generated from. The reader (`FiDReader`) generates an answer using only this single gold passage. We use FID supplied with a single passage as the reader, which worked as well as standard readers in our experiments. An example config is `generator_configs/filterer_configs/local_filtering_config.json`.
* *Global filtering* (slow but important for strong performance): uses a `GlobalFilteringRetriever` to retrieve relevant passages for the question (this uses DPR under the hood). The reader is a `FiDReader` (this is FID under the hood). An example config is `generator_configs/filterer_configs/global_filtering_config.json`

The following trained models are available for download:

| Model  | Description | Training data |  Architecture |  Download Resource Key Name |
| ----------| --- |---------- |---- | ---- |
| dpr_nq_passage_retriever| DPR Passage retriever and faiss index, from the DPR Paper, used for retrieving passage for the reader in global filtering, NQ-trained| NQ | BERT-base |  `models.filtering.dpr_nq_passage_retriever`| 
| fid_reader_nq_base| FID-base reader, from the Fusion-in-Decoder paper, used in global and local filtering, NQ-trained | NQ | t5-base |  `models.filtering.fid_reader_nq_base`| 

Below is an example of how to filter questions (both with local and global filtering):

```bash
# download the corpus to retrieve from, the DPR retriever and the reader:
python -m paq.download -v -n paq.psgs_w100
python -m paq.download -v -n models.filtering.dpr_nq_passage_retriever
python -m paq.download -v -n models.filtering.fid_reader_nq_base

# run filtering using local filtering...
python -m paq.generation.filtering.filter_questions \
    --generated_questions_to_filter my_generated_questions.jsonl \
    --output_path my_locally_filtered_questions.jsonl \
    --path_to_config generator_configs/filterer_configs/local_filtering_config.json \
    --verbose

# or, run filtering using global filtering 
python -m paq.generation.filtering.filter_questions \
    --generated_questions_to_filter my_generated_questions.jsonl \
    --output_path my_globally_filtered_questions.jsonl \
    --path_to_config generator_configs/filterer_configs/global_filtering_config.json \
    --verbose

```

This will output a jsonl file with the following format:

```json
{
  "passage_id": "ID for passage", 
  "answer": "Benedict", 
  "question": "which pope has the middle name gregory",
  "metadata": {"filter_answer": "benedict", "consistent": true, "answer_start": 617, "answer_end": 625, "ae_score": "score for answer", "qg_score": "currently not implemented, but score for question can go here"}
}
```
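
Downstream, you will typically keep only the QA pairs marked as consistent. A minimal sketch of that post-processing step is shown below. It relies only on the `metadata.consistent` field from the format above, and the helper name and sample records are illustrative.

```python
import json
from typing import Iterable, List, Tuple


def split_by_consistency(lines: Iterable[str]) -> Tuple[List[dict], List[dict]]:
    """Split filtered QA pairs into (consistent, inconsistent) lists."""
    kept: List[dict] = []
    dropped: List[dict] = []
    for line in lines:
        if not line.strip():
            continue
        qa = json.loads(line)
        (kept if qa["metadata"]["consistent"] else dropped).append(qa)
    return kept, dropped


# Made-up filter outputs in the documented format.
sample_lines = [
    json.dumps({"passage_id": "1", "answer": "Benedict",
                "question": "which pope has the middle name gregory",
                "metadata": {"filter_answer": "benedict", "consistent": True}}),
    json.dumps({"passage_id": "2", "answer": "1859",
                "question": "when was the book published",
                "metadata": {"filter_answer": "1872", "consistent": False}}),
]

kept, dropped = split_by_consistency(sample_lines)
print(f"kept {len(kept)} of {len(kept) + len(dropped)} QA pairs")
```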

#### End2End Generation Tool

To run all the steps in the pipeline end2end, use the `paq.generation.generate_qa_pairs` program.
This will run passage ranking, then answer extraction, then generation, then finally filtering automatically.

The tool takes as input a config json file, and a file of passages to generate QA pairs from, formatted as a tsv (passage id, passage text, passage title).
The tool will create an output directory, and write intermediate results to it, including the final generated QA pairs in the `final_qas.jsonl` file.

The following example configs can be used with this tool to replicate the generation pipelines used in the paper:
* `generator_configs/paq_L1_config.json`: run a generation pipeline replicating PAQ-L1
* `generator_configs/paq_L4_config.json`: run a generation pipeline replicating PAQ-L4
* `generator_configs/paq_NE_config.json`: run a generation pipeline replicating PAQ-NE
* `generator_configs/paq_L1_with_local_filtering_config.json`: run a generation pipeline replicating PAQ-L1, but with local rather than global filtering.

Or, write your own config to fit your generation needs.

The following code will run the PAQ-L1 generation pipeline on the first 1000 passages in the preprocessed Wikipedia dump:

```bash
# Download the models and data we need:
python -m paq.download -v -n models.passage_rankers.passage_ranker_base
python -m paq.download -v -n models.answer_extractors.answer_extractor_nq_base
python -m paq.download -v -n models.qgen.qgen_multi_base
python -m paq.download -v -n paq.psgs_w100
python -m paq.download -v -n models.filtering.dpr_nq_passage_retriever
python -m paq.download -v -n models.filtering.fid_reader_nq_base

head -n 1000 data/paq/psgs_w100.tsv > data/paq/psgs_w100.first_1000.tsv

python -m paq.generation.generate_qa_pairs \
    --passage_files_to_generate data/paq/psgs_w100.first_1000.tsv \
    --output_dirs my_generated_qas \
    --path_to_config generator_configs/paq_L1_config.json \
    --verbose --n_jobs -1
```

The `paq.generation.generate_qa_pairs` tool can use submitit to run generation on a cluster.
The `--n_jobs` flag indicates how many concurrent submitit jobs to submit; use `--n_jobs -1` to run locally.
To run generation in several jobs in parallel, pass a comma-separated list of input files to `--passage_files_to_generate`
and a corresponding comma-separated list of output directories to `--output_dirs`.
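
One way to prepare such a parallel run is to split the passage tsv into contiguous shards and derive the two comma-separated values from the shard count. The sketch below only does the bookkeeping in memory. The helper names and the shard naming scheme are assumptions for illustration and are not part of the PAQ tooling.

```python
from typing import List, Sequence, Tuple


def shard_lines(lines: Sequence[str], n_shards: int) -> List[List[str]]:
    """Split lines into n_shards contiguous chunks whose sizes differ by at most one."""
    base, extra = divmod(len(lines), n_shards)
    shards, start = [], 0
    for i in range(n_shards):
        size = base + (1 if i < extra else 0)
        shards.append(list(lines[start:start + size]))
        start += size
    return shards


def job_arguments(n_shards: int, input_prefix: str, output_prefix: str) -> Tuple[str, str]:
    """Return values for --passage_files_to_generate and --output_dirs, in matching order."""
    inputs = [f"{input_prefix}.{i}.tsv" for i in range(n_shards)]
    outputs = [f"{output_prefix}_{i}" for i in range(n_shards)]
    return ",".join(inputs), ",".join(outputs)


passages = [f"{i}\tpassage text {i}\ttitle {i}" for i in range(10)]
shards = shard_lines(passages, 3)
files_arg, dirs_arg = job_arguments(3, "psgs_shard", "my_generated_qas")
print([len(s) for s in shards], files_arg, dirs_arg)
```

Each shard would be written to its own tsv file before launching, with `--n_jobs` set to the number of shards.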



## Citing

To cite this work, please use the following bibtex:
```
@article{lewis2021paq,
      title={PAQ: 65 Million Probably-Asked Questions and What You Can Do With Them}, 
      author={Patrick Lewis and Yuxiang Wu and Linqing Liu and Pasquale Minervini and Heinrich Küttler and Aleksandra Piktus and Pontus Stenetorp and Sebastian Riedel},
      year={2021},
      eprint={2102.07033},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```

## LICENSE

### Code License:

The majority of the PAQ code is licensed under [CC-BY-NC](./LICENSE); however, portions of the project are available under separate license terms: HuggingFace Transformers is licensed under the Apache License 2.0; spaCy and wandb are licensed under the MIT License.
The code in this repository is licensed according to the [LICENSE](./LICENSE) file.

### Data License:

The PAQ QA-pairs and metadata are licensed under [CC-BY-SA](https://creativecommons.org/licenses/by-sa/3.0/). 
Other data is licensed according to the accompanying license files.


================================================
FILE: full_models_list.md
================================================
# Full List of Models Available for Download

## BiEncoder Retrievers


| Model  | Training data |  Architecture | Embedding Dim | NQ EM | + rerank | TQA EM | + rerank |  Download Resource Key Name |
| ------------- |----------| --- | --------- | ---------- |---- |---- | ---- | ---- |
| retriever_multi_base_256 (recommended) | NQ + TQA | AlBERT-base | 256 | 41.4 | 47.3 | 40.2 | 50.9 | `models.retrievers.retriever_multi_base_256` |
| retriever_multi_base | NQ + TQA | AlBERT-base | 768 | 40.9 | 47.4 | 39.7 | 51.2 | `models.retrievers.retriever_multi_base` |
| retriever_multi_large | NQ + TQA | AlBERT-large | 768 | 41.2 | 47.5 | 41.0 | 51.9 | `models.retrievers.retriever_multi_large` |
| retriever_multi_xlarge | NQ + TQA | AlBERT-xlarge | 768 | 41.7 | 47.6 | 41.3 | 52.1 | `models.retrievers.retriever_multi_xlarge` |
| retriever_nq_base | NQ | AlBERT-base | 768 | 41.0 | 47.2 | 35.6 | 49.0 | `models.retrievers.retriever_nq_base` |
| retriever_nq_large | NQ | AlBERT-large | 768 | 40.4 | 47.3 | 34.1 | 48.1 | `models.retrievers.retriever_nq_large` |
| retriever_nq_xlarge | NQ | AlBERT-xlarge | 768 | 41.1 | 47.7 | 35.7 | 48.9 | `models.retrievers.retriever_nq_xlarge` |
| retriever_tqa_base | TQA | AlBERT-base | 768 | 37.5 | 46.8 | 38.7 | 51.0 | `models.retrievers.retriever_tqa_base` |
| retriever_tqa_large | TQA | AlBERT-large | 768 | 38.2 | 47.0 | 39.6 | 51.4 | `models.retrievers.retriever_tqa_large` |
| retriever_tqa_xlarge | TQA | AlBERT-xlarge | 768 | 38.0 | 46.5 | 38.9 | 51.2 | `models.retrievers.retriever_tqa_xlarge` |

(Rerank scores calculated with `reranker_multi_xxlarge`)

## QA Rerankers


| Model  | Training data |  Architecture | NQ EM | TQA EM |  Download Resource Key Name |
| ------------- |----------| --- | --------- | ---------- |---- |
|reranker_multi_base| NQ + TQA| AlBERT-base |46.0 |48.9 | `models.rerankers.reranker_multi_base`| 
|reranker_multi_large| NQ + TQA|AlBERT-large | 46.2| 49.4|`models.rerankers.reranker_multi_large`| 
|reranker_multi_xlarge| NQ + TQA|AlBERT-xlarge | 46.0| 49.1| `models.rerankers.reranker_multi_xlarge`| 
|reranker_multi_xxlarge| NQ + TQA|AlBERT-xxlarge | 47.7| 52.1 | `models.rerankers.reranker_multi_xxlarge`| 
|reranker_nq_xlarge| NQ | AlBERT-xlarge | 45.2| 46.7 | `models.rerankers.reranker_nq_xlarge`| 
|reranker_nq_xxlarge| NQ| AlBERT-xxlarge |46.4 | 49.6| `models.rerankers.reranker_nq_xxlarge`| 
|reranker_tqa_xlarge| TQA | AlBERT-xlarge | 45.0|49.7 | `models.rerankers.reranker_tqa_xlarge`| 
|reranker_tqa_xxlarge| TQA | AlBERT-xxlarge | 46.0|51.7 | `models.rerankers.reranker_tqa_xxlarge`| 

(EM scores in this table calculated using  `retriever_multi_xlarge` retriever)

## Qgen Models

| Model  | Training data |  Architecture |  Download Resource Key Name |
| ------------- |---------- |---- | ---- |
| qgen_nq_base| NQ | BART-base |  `models.qgen.qgen_nq_base`| 
| qgen_multi_base| Multitask | BART-base |  `models.qgen.qgen_multi_base`| 


## Passage Ranker Models

Models used for selecting passages to generate questions from:

| Model  | Training data |  Architecture |  Download Resource Key Name |
| ------------- |---------- |---- | ---- |
| passage_ranker_base| NQ | BERT-base |  `models.passage_rankers.passage_ranker_base`| 

Note: the original passage ranker model used in the paper was unfortunately lost due to a storage corruption issue.
The model here is a reproduction using the same hardware and hyperparameters, but it differs slightly because of the stochastic sampling procedure used in training.

## Answer Extractor Models


| Model  | Description | Training data |  Architecture |  Download Resource Key Name |
| ----------| --- |---------- |---- | ---- |
| answer_extractor_nq_base| Learnt Answer Span Extractor, BERT-base, NQ-trained | NQ | BERT-base |  `models.answer_extractors.answer_extractor_nq_base`| 

## Filterer Models

| Model  | Description | Training data |  Architecture |  Download Resource Key Name |
| ----------| --- |---------- |---- | ---- |
| dpr_nq_passage_retriever| DPR passage retriever and faiss index, from the DPR paper, used for retrieving passages for the reader in global filtering, NQ-trained| NQ | BERT-base |  `models.filtering.dpr_nq_passage_retriever`| 
| fid_reader_nq_base| FID-base reader, from the Fusion-in-Decoder paper, used in global and local filtering, NQ-trained | NQ | t5-base |  `models.filtering.fid_reader_nq_base`| 
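
The model resource keys above correspond to local directories in a predictable way. As a minimal sketch, assuming models are downloaded under a `data/` directory (as the paths in `generator_configs/` suggest), the dots in a model key become path separators. The helper name `model_key_to_path` and the `data_dir` default are illustrative and not part of the PAQ codebase.

```python
def model_key_to_path(resource_key: str, data_dir: str = "data") -> str:
    """Return the expected local directory for a downloaded model resource key."""
    # Only model keys are handled here; other resources (e.g. single files
    # with an extension) may not follow the same key-to-path pattern.
    if not resource_key.startswith("models."):
        raise ValueError(f"Not a model resource key: {resource_key}")
    return "/".join([data_dir] + resource_key.split("."))


print(model_key_to_path("models.qgen.qgen_multi_base"))
# data/models/qgen/qgen_multi_base
```

This matches the `model_path` values used in the generator configs, such as `data/models/qgen/qgen_multi_base`.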


================================================
FILE: generator_configs/answer_extractor_configs/learnt_answer_extractor_config.json
================================================
{
  "answer_extractor": {
    "name": "answer_extractor/span2D",
    "config": {
      "model_path": "data/models/answer_extractors/answer_extractor_nq_base",
      "config_path": "data/models/answer_extractors/answer_extractor_nq_base",
      "tokenizer_path": "data/models/answer_extractors/answer_extractor_nq_base",
      "topk": 8,
      "max_answer_len": 30,
      "max_seq_len": 256,
      "doc_stride": 128,
      "batch_size": 128,
      "device": 0
    }
  }
}

================================================
FILE: generator_configs/answer_extractor_configs/named_entity_answer_extractor_config.json
================================================
{
    "answer_extractor": {
    "name": "answer_extractor/spacy_ner",
    "config": {
      "model": "en_core_web_sm"
    }
  }
}

================================================
FILE: generator_configs/filterer_configs/dummy_filtering_config.json
================================================
{
  "filterer": {
    "retriever": {
      "name": "filtering/dummy_filtering_retriever",
      "config": {
      }
    },
    "reader": {
      "name": "filtering/dummy_reader",
      "config": {
      }
    }
  }
}

================================================
FILE: generator_configs/filterer_configs/global_filtering_config.json
================================================
{
  "filterer": {
    "retriever": {
      "name": "filtering/global_filtering_retriever",
      "config": {
        "corpus_path": "data/paq/psgs_w100.tsv",
        "index_path": "data/models/filtering/dpr_nq_passage_retriever/dpr_index.hnsw.SQ8.index.dpr",
        "index_id_to_db_id_path": "data/models/filtering/dpr_nq_passage_retriever/dpr_index.hnsw.SQ8.index_meta.dpr",
        "model_path": "data/models/filtering/dpr_nq_passage_retriever",
        "batch_size": 128,
        "n_queries_to_parallelize": 2048,
        "max_seq_len":256,
        "n_docs": 50,
        "device": 0
      }
    },
    "reader": {
      "name": "filtering/fid_reader",
      "config": {
        "model_path": "data/models/filtering/fid_reader_nq_base",
        "batch_size": 4,
        "device": 0,
        "max_seq_len": 200,
        "n_docs": 50
      }
    }
  }
}

================================================
FILE: generator_configs/filterer_configs/local_filtering_config.json
================================================
{
    "filterer": {
    "retriever": {
      "name": "filtering/local_filtering_retriever",
      "config": {
        "corpus_path": "data/paq/psgs_w100.tsv"
      }
    },
    "reader": {
      "name": "filtering/fid_reader",
      "config": {
        "model_path": "data/models/filtering/fid_reader_nq_base",
        "batch_size": 32,
        "device": 0,
        "max_seq_len": 200,
        "n_docs": 1
      }
    }
  }
}

================================================
FILE: generator_configs/paq_L1_config.json
================================================
{
  "passage_scorer": {
    "name": "passage_scorer/learnt",
    "config": {
      "model_path":"data/models/passage_rankers/passage_ranker_base",
      "config_path":"data/models/passage_rankers/passage_ranker_base",
      "tokenizer_path":"data/models/passage_rankers/passage_ranker_base",
      "device": 0,
      "batch_size": 64,
      "max_seq_len": 256
    }
  },
  "answer_extractor": {
    "name": "answer_extractor/span2D",
    "config": {
      "model_path": "data/models/answer_extractors/answer_extractor_nq_base",
      "config_path": "data/models/answer_extractors/answer_extractor_nq_base",
      "tokenizer_path": "data/models/answer_extractors/answer_extractor_nq_base",
      "topk": 8,
      "max_answer_len": 30,
      "max_seq_len": 256,
      "doc_stride": 128,
      "batch_size": 128,
      "device": 0
    }
  },
  "question_generator": {
    "name": "question_generator/standard",
    "config": {
      "model_path": "data/models/qgen/qgen_multi_base",
      "config_path": null,
      "tokenizer_path": "data/models/qgen/qgen_multi_base",
      "include_title": true,
      "num_beams": 4,
      "num_return_sequences": 1,
      "max_question_len": 20,
      "batch_size": 64,
      "device": 0
    }
  },
  "filterer": {
    "retriever": {
      "name": "filtering/global_filtering_retriever",
      "config": {
        "corpus_path": "data/paq/psgs_w100.tsv",
        "index_path": "data/models/filtering/dpr_nq_passage_retriever/dpr_index.hnsw.SQ8.index.dpr",
        "index_id_to_db_id_path": "data/models/filtering/dpr_nq_passage_retriever/dpr_index.hnsw.SQ8.index_meta.dpr",
        "model_path": "data/models/filtering/dpr_nq_passage_retriever",
        "batch_size": 128,
        "n_queries_to_parallelize": 2048,
        "max_seq_len":256,
        "n_docs": 50,
        "device": 0
      }
    },
    "reader": {
      "name": "filtering/fid_reader",
      "config": {
        "model_path": "data/models/filtering/fid_reader_nq_base",
        "batch_size": 4,
        "device": 0,
        "max_seq_len": 200,
        "n_docs": 50
      }
    }
  }
}

================================================
FILE: generator_configs/paq_L1_with_local_filtering_config.json
================================================
{
  "passage_scorer": {
    "name": "passage_scorer/learnt",
    "config": {
      "model_path":"data/models/passage_rankers/passage_ranker_base",
      "config_path":"data/models/passage_rankers/passage_ranker_base",
      "tokenizer_path":"data/models/passage_rankers/passage_ranker_base",
      "device": 0,
      "batch_size": 64,
      "max_seq_len": 256
    }
  },
  "answer_extractor": {
    "name": "answer_extractor/span2D",
    "config": {
      "model_path": "data/models/answer_extractors/answer_extractor_nq_base",
      "config_path": "data/models/answer_extractors/answer_extractor_nq_base",
      "tokenizer_path": "data/models/answer_extractors/answer_extractor_nq_base",
      "topk": 8,
      "max_answer_len": 30,
      "max_seq_len": 256,
      "doc_stride": 128,
      "batch_size": 128,
      "device": 0
    }
  },
  "question_generator": {
    "name": "question_generator/standard",
    "config": {
      "model_path": "data/models/qgen/qgen_multi_base",
      "config_path": null,
      "tokenizer_path": "data/models/qgen/qgen_multi_base",
      "include_title": true,
      "num_beams": 4,
      "num_return_sequences": 1,
      "max_question_len": 20,
      "batch_size": 64,
      "device": 0
    }
  },
    "filterer": {
    "retriever": {
      "name": "filtering/local_filtering_retriever",
      "config": {
        "corpus_path": "data/paq/psgs_w100.tsv"
      }
    },
    "reader": {
      "name": "filtering/fid_reader",
      "config": {
        "model_path": "data/models/filtering/fid_reader_nq_base",
        "batch_size": 32,
        "device": 0,
        "max_seq_len": 200,
        "n_docs": 1
      }
    }
  }
}

================================================
FILE: generator_configs/paq_L4_config.json
================================================
{
  "passage_scorer": {
    "name": "passage_scorer/learnt",
    "config": {
      "model_path":"data/models/passage_rankers/passage_ranker_base",
      "config_path":"data/models/passage_rankers/passage_ranker_base",
      "tokenizer_path":"data/models/passage_rankers/passage_ranker_base",
      "device": 0,
      "batch_size": 64,
      "max_seq_len": 256
    }
  },
  "answer_extractor": {
    "name": "answer_extractor/span2D",
    "config": {
      "model_path": "data/models/answer_extractors/answer_extractor_nq_base",
      "config_path": "data/models/answer_extractors/answer_extractor_nq_base",
      "tokenizer_path": "data/models/answer_extractors/answer_extractor_nq_base",
      "topk": 8,
      "max_answer_len": 30,
      "max_seq_len": 256,
      "doc_stride": 128,
      "batch_size": 128,
      "device": 0
    }
  },
  "question_generator": {
    "name": "question_generator/standard",
    "config": {
      "model_path": "data/models/qgen/qgen_multi_base",
      "config_path": null,
      "tokenizer_path": "data/models/qgen/qgen_multi_base",
      "include_title": true,
      "num_beams": 4,
      "num_return_sequences": 4,
      "max_question_len": 20,
      "batch_size": 64,
      "device": 0
    }
  },
  "filterer": {
    "retriever": {
      "name": "filtering/global_filtering_retriever",
      "config": {
        "corpus_path": "data/paq/psgs_w100.tsv",
        "index_path": "data/models/filtering/dpr_nq_passage_retriever/dpr_index.hnsw.SQ8.index.dpr",
        "index_id_to_db_id_path": "data/models/filtering/dpr_nq_passage_retriever/dpr_index.hnsw.SQ8.index_meta.dpr",
        "model_path": "data/models/filtering/dpr_nq_passage_retriever",
        "batch_size": 128,
        "n_queries_to_parallelize": 2048,
        "max_seq_len":256,
        "n_docs": 50,
        "device": 0
      }
    },
    "reader": {
      "name": "filtering/fid_reader",
      "config": {
        "model_path": "data/models/filtering/fid_reader_nq_base",
        "batch_size": 4,
        "device": 0,
        "max_seq_len": 200,
        "n_docs": 50
      }
    }
  }
}

================================================
FILE: generator_configs/paq_NE_config.json
================================================
{
  "passage_scorer": {
    "name": "passage_scorer/learnt",
    "config": {
      "model_path":"data/models/passage_rankers/passage_ranker_base",
      "config_path":"data/models/passage_rankers/passage_ranker_base",
      "tokenizer_path":"data/models/passage_rankers/passage_ranker_base",
      "device": 0,
      "batch_size": 64,
      "max_seq_len": 256
    }
  },
  "answer_extractor": {
    "name": "answer_extractor/spacy_ner",
    "config": {
      "model": "en_core_web_sm"
    }
  },
  "question_generator": {
    "name": "question_generator/standard",
    "config": {
      "model_path": "data/models/qgen/qgen_nq_base",
      "config_path": null,
      "tokenizer_path": "data/models/qgen/qgen_nq_base",
      "include_title": true,
      "num_beams": 4,
      "num_return_sequences": 1,
      "max_question_len": 20,
      "batch_size": 64,
      "device": 0
    }
  },
  "filterer": {
    "retriever": {
      "name": "filtering/global_filtering_retriever",
      "config": {
        "corpus_path": "data/paq/psgs_w100.tsv",
        "index_path": "data/models/filtering/dpr_nq_passage_retriever/dpr_index.hnsw.SQ8.index.dpr",
        "index_id_to_db_id_path": "data/models/filtering/dpr_nq_passage_retriever/dpr_index.hnsw.SQ8.index_meta.dpr",
        "model_path": "data/models/filtering/dpr_nq_passage_retriever",
        "batch_size": 128,
        "n_queries_to_parallelize": 2048,
        "max_seq_len":256,
        "n_docs": 50,
        "device": 0
      }
    },
    "reader": {
      "name": "filtering/fid_reader",
      "config": {
        "model_path": "data/models/filtering/fid_reader_nq_base",
        "batch_size": 4,
        "device": 0,
        "max_seq_len": 200,
        "n_docs": 50
      }
    }
  }
}

================================================
FILE: generator_configs/passage_ranker_configs/dummy_passage_scorer_config.json
================================================
{
  "passage_scorer": {
    "name": "passage_scorer/dummy",
    "config":{
      "default_score": -1000
    }
  }
}

================================================
FILE: generator_configs/passage_ranker_configs/learnt_passage_scorer_config.json
================================================
{
  "passage_scorer": {
    "name": "passage_scorer/learnt",
    "config": {
      "model_path":"data/models/passage_rankers/passage_ranker_base",
      "config_path":"data/models/passage_rankers/passage_ranker_base",
      "tokenizer_path":"data/models/passage_rankers/passage_ranker_base",
      "device": 0,
      "batch_size": 64,
      "max_seq_len": 256
    }
  }
}

================================================
FILE: generator_configs/passage_ranker_configs/lookup_passage_scorer_config.json
================================================
{
  "passage_scorer": {
    "name": "passage_scorer/lookup",
    "config":{
      "default_score": -1000,
      "scores_file": "data/paq/PASSAGE_SCORES/passage_scores.tsv"
    }
  }
}

================================================
FILE: generator_configs/question_generator_configs/question_generation_config.json
================================================
{
  "question_generator": {
    "name": "question_generator/standard",
    "config": {
      "model_path": "data/models/qgen/qgen_multi_base",
      "config_path": null,
      "tokenizer_path": "data/models/qgen/qgen_multi_base",
      "include_title": true,
      "num_beams": 4,
      "num_return_sequences": 1,
      "max_question_len": 20,
      "batch_size": 64,
      "device": 0
    }
  }
}

================================================
FILE: paq/__init__.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

================================================
FILE: paq/download.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import gzip
import logging
import os
import pathlib
import wget
import tarfile

from typing import Tuple, List

logger = logging.getLogger(__name__)


NQ_LICENSE_FILES = [
    "https://dl.fbaipublicfiles.com/dpr/nq_license/LICENSE",
    "https://dl.fbaipublicfiles.com/dpr/nq_license/README",
]

RESOURCES_MAP = {
    "paq.PAQ": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/PAQ.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "Full PAQ generated QA pairs (PAQ-L + PAQ-NE)",
        "skip_if_exists_path": "paq/PAQ"
    },
    "paq.PAQ_L1": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/PAQ_L1.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "PAQ-L1 subset of PAQ generated QA pairs",
        "skip_if_exists_path": "paq/PAQ_L1"
    },
    "paq.PAQ_L4": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/PAQ_L4.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "PAQ-L4 subset of PAQ generated QA pairs",
        "skip_if_exists_path": "paq/PAQ_L4"
    },
    "paq.PAQ_NE1": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/PAQ_NE1.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "PAQ-NE1 subset of PAQ generated QA pairs",
        "skip_if_exists_path": "paq/PAQ_NE1"
    },
    "paq.TQA_TRAIN_NQ_TRAIN_PAQ": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/TQA_TRAIN_NQ_TRAIN_PAQ.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "TriviaQA train set QA pairs, NQ train set QA pairs and Full PAQ generated QA pairs",
        "skip_if_exists_path": "paq/TQA_TRAIN_NQ_TRAIN_PAQ"

    },
    "paq.psgs_w100": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/psgs_w100.tsv.gz',
        "original_ext": ".tsv",
        "compressed": True,
        "desc": "Preprocessed Wikipedia dump, split into 100-word passages",
        "skip_if_exists_path": "paq/psgs_w100.tsv"
    },
    "paq.PASSAGE_SCORES": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/PASSAGE_SCORES.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "Passage selection scores for the passages in `psgs_w100`",
        "skip_if_exists_path": "paq/PASSAGE_SCORES"
    },
    "paq.PAQ_metadata": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/PAQ.metadata.jsonl.gz',
        "original_ext": ".jsonl",
        "compressed": True,
        "desc": "PAQ QA pairs metadata ",
        "skip_if_exists_path": "paq/PAQ.metadata.jsonl"
    },
    "paq.PAQ_unfiltered_metadata": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/PAQ.unfiltered_metadata.jsonl.gz',
        "original_ext": ".jsonl",
        "compressed": True,
        "desc": "PAQ QA pairs metadata for unfiltered QA pairs",
        "skip_if_exists_path": "paq/PAQ.unfiltered_metadata.jsonl"
    },

    'annotated_datasets.naturalquestions': {
        's3_url': [
            'https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/NQ-open.train-train.jsonl',
            'https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/NQ-open.train-dev.jsonl',
            'https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/NQ-open.test.jsonl',
            'https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/NQ_LICENSE',
            'https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/NQ_README',
        ],
        "original_ext": [".jsonl", ".jsonl", ".jsonl", "", ""],
        "compressed": False,
        "desc": "The Open NaturalQuestions QA Pairs used in our experiments",
        "skip_if_exists_path": "annotated_datasets/naturalquestions"
    },
    'annotated_datasets.triviaqa': {
        's3_url': [
            'https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/triviaqa.train-train.jsonl',
            'https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/triviaqa.train-dev.jsonl',
            'https://dl.fbaipublicfiles.com/paq/v1/annotated_datasets/triviaqa.test.jsonl'
        ],
        "original_ext": ".jsonl",
        "compressed": False,
        "desc": "The TriviaQA QA Pairs used in our experiments",
        "skip_if_exists_path": "annotated_datasets/triviaqa"
    },

    "models.retrievers.retriever_multi_base_256": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/retrievers/retriever_multi_base_256.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Retriever Albert-Base model with 256 output embedding dim, multitask. Recommended RePAQ retriever",
        "skip_if_exists_path": "models/retrievers/retriever_multi_base_256"
    },
    "models.retrievers.retriever_multi_base": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/retrievers/retriever_multi_base.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Retriever Albert-Base model with 768 output embedding dim, multitask",
        "skip_if_exists_path": "models/retrievers/retriever_multi_base"
    },
    "models.retrievers.retriever_multi_large": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/retrievers/retriever_multi_large.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Retriever Albert-Large model with 768 output embedding dim, multitask",
        "skip_if_exists_path": "models/retrievers/retriever_multi_large"
    },
    "models.retrievers.retriever_multi_xlarge": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/retrievers/retriever_multi_xlarge.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Retriever Albert-xlarge model with 768 output embedding dim, multitask",
        "skip_if_exists_path": "models/retrievers/retriever_multi_xlarge"
    },
    "models.retrievers.retriever_nq_base": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/retrievers/retriever_nq_base.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Retriever Albert-base model with 768 output embedding dim, trained on NQ",
        "skip_if_exists_path": "models/retrievers/retriever_nq_base"
    },
    "models.retrievers.retriever_nq_large": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/retrievers/retriever_nq_large.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Retriever Albert-large model with 768 output embedding dim, trained on NQ",
        "skip_if_exists_path": "models/retrievers/retriever_nq_large"
    },
    "models.retrievers.retriever_nq_xlarge": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/retrievers/retriever_nq_xlarge.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Retriever Albert-xlarge model with 768 output embedding dim, trained on NQ",
        "skip_if_exists_path": "models/retrievers/retriever_nq_xlarge"

    },
    "models.retrievers.retriever_tqa_base": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/retrievers/retriever_tqa_base.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Retriever Albert-base model with 768 output embedding dim, trained on TriviaQA",
        "skip_if_exists_path": "models/retrievers/retriever_tqa_base"
    },
    "models.retrievers.retriever_tqa_large": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/retrievers/retriever_tqa_large.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Retriever Albert-large model with 768 output embedding dim, trained on TriviaQA",
        "skip_if_exists_path": "models/retrievers/retriever_tqa_large"
    },
    "models.retrievers.retriever_tqa_xlarge": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/retrievers/retriever_tqa_xlarge.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Retriever Albert-xlarge model with 768 output embedding dim, trained on TriviaQA",
        "skip_if_exists_path": "models/retrievers/retriever_tqa_xlarge"
    },



    'vectors.multi_base_256_vectors': {
        "s3_url": [
            "https://dl.fbaipublicfiles.com/paq/v1/models/vectors/multi_base_256_vectors/embeddings.pt.{}".format(
                i
            )
            for i in range(50)
        ],
        "original_ext": ".pt",
        "compressed": False,
        "desc": "Precomputed vectors for tqa-train-nq-train-PAQ.jsonl, using the `multi_base_256` model",
        "skip_if_exists_path": "vectors/multi_base_256_vectors"
    },


    "indices.multi_base_256_flat_sq8": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/indices/multi_base_256.flat.sq8.faiss',
        "original_ext": ".faiss",
        "compressed": False,
        "desc": "Precomputed Flat Faiss Index for tqa-train-nq-train-PAQ.jsonl, using the `multi_base_256` model. Slow but exact",
        "skip_if_exists_path": "indices/multi_base_256_flat_sq8"
    },
    "indices.multi_base_256_hnsw_sq8": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/indices/multi_base_256.hnsw.sq8.faiss',
        "original_ext": ".faiss",
        "compressed": False,
        "desc": "Precomputed HNSW Faiss Index for tqa-train-nq-train-PAQ.jsonl, using the `multi_base_256` model. "
                "Very fast but slightly less accurate than `multi_base_256_flat_sq8`",
        "skip_if_exists_path": "indices/multi_base_256_hnsw_sq8"
    },

    "models.rerankers.reranker_multi_base": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/rerankers/reranker_multi_base.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Reranker AlBERT-Base model, multitask",
        "skip_if_exists_path": "models/rerankers/reranker_multi_base"

    },
    "models.rerankers.reranker_multi_large": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/rerankers/reranker_multi_large.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Reranker AlBERT-Large model, multitask",
        "skip_if_exists_path": "models/rerankers/reranker_multi_large"
    },
    "models.rerankers.reranker_multi_xlarge": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/rerankers/reranker_multi_xlarge.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Reranker AlBERT-xlarge model, multitask",
        "skip_if_exists_path": "models/rerankers/reranker_multi_xlarge"
    },
    "models.rerankers.reranker_multi_xxlarge": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/rerankers/reranker_multi_xxlarge.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Reranker AlBERT-xxlarge model, multitask",
        "skip_if_exists_path": "models/rerankers/reranker_multi_xxlarge"
    },
    "models.rerankers.reranker_tqa_xlarge": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/rerankers/reranker_tqa_xlarge.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Reranker AlBERT-xlarge model, TriviaQA-trained",
        "skip_if_exists_path": "models/rerankers/reranker_tqa_xlarge"
    },
    "models.rerankers.reranker_tqa_xxlarge": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/rerankers/reranker_tqa_xxlarge.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Reranker AlBERT-xxlarge model, TriviaQA-trained",
        "skip_if_exists_path": "models/rerankers/reranker_tqa_xxlarge"
    },
    "models.rerankers.reranker_nq_xlarge": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/rerankers/reranker_nq_xlarge.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Reranker AlBERT-xlarge model, NQ-trained",
        "skip_if_exists_path": "models/rerankers/reranker_nq_xlarge"
    },
    "models.rerankers.reranker_nq_xxlarge": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/rerankers/reranker_nq_xxlarge.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "RePAQ Reranker AlBERT-xxlarge model, NQ-trained",
        "skip_if_exists_path": "models/rerankers/reranker_nq_xxlarge"
    },

    "models.passage_rankers.passage_ranker_base": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/passage_rankers/passage_ranker_base.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "Passage Ranker model, BERT-base model, trained on NQ passages with hard negatives",
        "skip_if_exists_path": "models/passage_rankers/passage_ranker_base"
    },


    "models.qgen.qgen_multi_base": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/qgen/qgen_multi_base.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "Question Generator model. BART-base model, multitask-trained",
        "skip_if_exists_path": "models/qgen/qgen_multi_base"
    },
    "models.qgen.qgen_nq_base": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/qgen/qgen_nq_base.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "Question Generator model. BART-base model, NQ-trained",
        "skip_if_exists_path": "models/qgen/qgen_nq_base"
    },


    "models.filtering.dpr_nq_passage_retriever": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/filtering/dpr_nq_passage_retriever.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "DPR Passage retriever and faiss index, from the DPR Paper, used in global filtering, NQ-trained",
        "skip_if_exists_path": "models/filtering/dpr_nq_passage_retriever",

    },
    "models.filtering.fid_reader_nq_base": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/filtering/fid_reader_nq_base.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "FID-base reader, from the Fusion-in-Decoder paper, used in global and local filtering, NQ-trained",
        "skip_if_exists_path": "models/filtering/fid_reader_nq_base",
    },

    "models.answer_extractors.answer_extractor_nq_base": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/models/answer_extractors/answer_extractor_nq_base.tar.gz',
        "original_ext": ".tar.gz",
        "compressed": True,
        "desc": "Learnt Answer Span Extractor, BERT-base, NQ-trained",
        "skip_if_exists_path": "models/answer_extractors/answer_extractor_nq_base",
    },


    "predictions.retriever_results.multi_xlarge_nq_test": {
        's3_url': 'https://dl.fbaipublicfiles.com/paq/v1/predictions/retriever_results/multi_xlarge_nq_test.jsonl.gz',
        "original_ext": ".jsonl",
        "compressed": True,
        "desc": "Retriever results for the multi_xlarge retriever on the NQ test set",
        "skip_if_exists_path": "predictions/retriever_results/multi_xlarge_nq_test.jsonl",
    },


}


def untar(tar_filename: str) -> List[str]:
    logger.info("Uncompressing %s", tar_filename)
    # The context manager closes the archive even if extraction fails
    with tarfile.open(tar_filename) as tar:
        tar.extractall(path=os.path.dirname(tar_filename))
    tar_dir = tar_filename.split('.tar.gz.tmp')[0]
    return [os.path.join(tar_dir, f) for f in os.listdir(tar_dir)]


def unpack(gzip_file: str, out_file: str):
    logger.info("Uncompressing %s", gzip_file)
    # Context managers close both handles even if reading or writing fails,
    # and avoid shadowing the `input` builtin
    with gzip.GzipFile(gzip_file, "rb") as gz_in, open(out_file, "wb") as out:
        out.write(gz_in.read())
    logger.info(" Saved to %s", out_file)


def _get_root_dir(out_dir):
    if out_dir:
        root_dir = out_dir
    else:
        # since hydra overrides the location for the 'current dir' for every run and we don't want to duplicate
        # resources multiple times, remove the current folder's volatile part
        root_dir = os.path.abspath("./")
        if "/outputs/" in root_dir:
            root_dir = root_dir[: root_dir.index("/outputs/")]
    return root_dir

def download_resource(
    s3_url: str, original_ext: str, compressed: bool, resource_key: str, out_dir: str, use_url_fname=False,
) -> Tuple[str, str]:
    logger.info("Requested resource from %s", s3_url)
    path_names = resource_key.split(".")

    root_dir = _get_root_dir(out_dir)
    logger.info("Download root_dir %s", root_dir)
    save_root = os.path.join(root_dir, "data", *path_names[:-1])  # last segment is for file name

    pathlib.Path(save_root).mkdir(parents=True, exist_ok=True)
    if use_url_fname:
        local_file_uncompressed = os.path.abspath(
            os.path.join(save_root, s3_url.split('/')[-1])
        )
    else:
        local_file_uncompressed = os.path.abspath(
            os.path.join(save_root, path_names[-1] + original_ext)
        )
    logger.info("File to be downloaded as %s", local_file_uncompressed)

    if os.path.exists(local_file_uncompressed):
        logger.info("File already exists %s", local_file_uncompressed)
        return save_root, local_file_uncompressed

    local_file = local_file_uncompressed if not compressed else local_file_uncompressed + '.tmp'
    wget.download(s3_url, out=local_file)

    logger.info("Downloaded to %s", local_file)

    if compressed:
        if original_ext == '.tar.gz':
            local_files = untar(local_file)
            os.remove(local_file)
            local_file = ','.join(local_files)
        else:
            uncompressed_file = os.path.join(save_root, path_names[-1] + original_ext)
            unpack(local_file, uncompressed_file)
            os.remove(local_file)
            local_file = uncompressed_file
    return save_root, local_file


def download_file(s3_url: str, out_dir: str, file_name: str):
    logger.info("Loading from %s", s3_url)
    local_file = os.path.join(out_dir, file_name)

    if os.path.exists(local_file):
        logger.info("File already exists %s", local_file)
        return

    wget.download(s3_url, out=local_file)
    logger.info("Downloaded to %s", local_file)


def download(resource_key: str, out_dir: str = None):
    if resource_key not in RESOURCES_MAP:
        # match by prefix
        resources = [k for k in RESOURCES_MAP.keys() if k.startswith(resource_key)]
        data_files = []
        if resources:
            for key in resources:
                # collect the files of every matching resource so the caller can report them
                data_files.extend(download(key, out_dir))
        else:
            logger.info("no resources found for specified key")
        return data_files
    download_info = RESOURCES_MAP[resource_key]

    if "skip_if_exists_path" in download_info:
        root_dir = _get_root_dir(out_dir)
        save_root = os.path.join(root_dir, "data", download_info['skip_if_exists_path'])
        if os.path.exists(save_root):
            logger.info(f"Resource: {resource_key} already exists here: {save_root}, "
                        f"delete this directory to force re-download")
            return []


    s3_url = download_info["s3_url"]

    save_root_dir = None
    data_files = []
    if isinstance(s3_url, list):
        if isinstance(download_info["original_ext"], str):
            exts = [download_info["original_ext"] for _ in s3_url]
        else:
            exts = download_info['original_ext']
        for i, (url, ext) in enumerate(zip(s3_url, exts)):
            save_root_dir, local_file = download_resource(
                url,
                ext,
                download_info["compressed"],
                resource_key,
                # "{}_{}".format(resource_key, i),
                out_dir,
                True
            )
            data_files.append(local_file)
    else:
        save_root_dir, local_file = download_resource(
            s3_url,
            download_info["original_ext"],
            download_info["compressed"],
            resource_key,
            out_dir,
        )
        data_files.append(local_file)

    license_files = download_info.get("license_files", None)
    if license_files:
        download_file(license_files[0], save_root_dir, "LICENSE")
        download_file(license_files[1], save_root_dir, "README")
    return data_files


def main():
    NL = '\n'
    parser = argparse.ArgumentParser("Tool for downloading resources", formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument(
        "--output_dir",
        default="./",
        type=str,
        help="The output directory to download files to",
    )
    parser.add_argument(
        "--name", "-n",
        type=str,
        required=True,
        help=f"Resource name. Choose between: {NL + NL.join([str(k) + ' : ' + str(v['desc']) for k, v in RESOURCES_MAP.items()])}",
    )
    parser.add_argument('-v', '--verbose', action="store_true")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    if args.name:
        downloaded_files = download(args.name, args.output_dir)
        logger.info(f'\nDownloaded the following files for resource {args.name} :')
        for d in downloaded_files:
            if ',' in d:
                for d2 in d.split(','):
                    logger.info(d2)
            else:
                logger.info(f'Downloaded {d}')
    else:
        logger.error("Please specify resource value. Possible options are:")
        for k, v in RESOURCES_MAP.items():
            logger.error("Resource key=%s  :  %s", k, v["desc"])


if __name__ == "__main__":
    main()


================================================
FILE: paq/evaluation/__init__.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

================================================
FILE: paq/evaluation/eval_reranker.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from paq.evaluation.eval_utils import metric_max_over_ground_truths, exact_match_score
from paq.paq_utils import load_jsonl


def evaluate_exact_match(preds, refs):
    assert len(refs) == len(preds)

    scores = []
    for ref, pred in zip(refs, preds):
        score = metric_max_over_ground_truths(exact_match_score, pred, ref)
        scores.append(score)

    return sum(scores) / len(scores)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--predictions', type=str, help="path to predicted answers in jsonl format {'question': question, 'prediction': predicted answer}")
    parser.add_argument('--references', type=str, help="path to gold answers, in jsonl format")
    args = parser.parse_args()

    refs = load_jsonl(args.references)
    preds = load_jsonl(args.predictions)
    assert len(refs) == len(preds), "number of references doesn't match number of predictions"

    scores = []
    for r, p in zip(refs, preds):
        ref_answers = r['answer']
        pred_answer = p['prediction']
        score = metric_max_over_ground_truths(exact_match_score, pred_answer, ref_answers)
        scores.append(score)

    print(f'{100 * sum(scores) / len(scores):0.1f}% \n({sum(scores)} / {len(scores)})')


================================================
FILE: paq/evaluation/eval_retriever.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from paq.evaluation.eval_utils import metric_max_over_ground_truths, exact_match_score
from paq.paq_utils import load_jsonl


def eval_retriever(refs, preds, hits_at_k):
    for k in hits_at_k:
        scores = []
        dont_print = False
        for r, p in zip(refs, preds):
            if k > len(p['retrieved_qas']):
                # Report the skip once per k, not once per question
                if not dont_print:
                    print(f'Skipping hits@{k} eval as {k} is larger than number of retrieved results')
                dont_print = True
            ref_answers = r['answer']
            em = any([
                metric_max_over_ground_truths(exact_match_score, pred_answer['answer'][0], ref_answers)
                for pred_answer in p['retrieved_qas'][:k]
            ])
            scores.append(em)

        if not dont_print:
            print(f'{k}: {100 * sum(scores) / len(scores):0.1f}% \n({sum(scores)} / {len(scores)})')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--predictions', type=str, help="path to retrieval results to eval, in PAQ's retrieved results jsonl format")
    parser.add_argument('--references', type=str, help="path to gold answers, in jsonl format")
    parser.add_argument('--hits_at_k', type=str, help='comma separated list of K to eval hits@k for', default="1,10,50")
    args = parser.parse_args()

    refs = load_jsonl(args.references)
    preds = load_jsonl(args.predictions)
    assert len(refs) == len(preds), "number of references doesn't match number of predictions"

    hits_at_k = sorted([int(k) for k in args.hits_at_k.split(',')])
    eval_retriever(refs, preds, hits_at_k)


================================================
FILE: paq/evaluation/eval_utils.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import re
import string
from typing import List, Union


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def metric_max_over_ground_truths(metric_fn, predictions: Union[str, List[str]], ground_truths: List[str]):
    scores_for_ground_truths = []

    if isinstance(predictions, str):
        predictions = [predictions]

    for prediction in predictions:
        for ground_truth in ground_truths:
            score = metric_fn(prediction, ground_truth)
            scores_for_ground_truths.append(score)

    return max(scores_for_ground_truths)



================================================
FILE: paq/generation/__init__.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

================================================
FILE: paq/generation/answer_extractor/__init__.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

================================================
FILE: paq/generation/answer_extractor/extract_answers.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
from paq.paq_utils import load_jsonl, dump_jsonl, load_dpr_tsv
from paq.generation.answer_extractor.extractors import load_answer_extractor
import logging
import argparse

logger = logging.getLogger(__name__)


def load_passages(path):
    try:
        return load_jsonl(path)
    except Exception:  # not jsonl: fall back to DPR tsv (a bare except would also swallow KeyboardInterrupt)
        return load_dpr_tsv(path)


def extract_answers(config, input_file, verbose):
    answer_extractor = load_answer_extractor(config)
    passages = load_passages(input_file)
    logger.info("Running answer extractor...")
    annotations = answer_extractor.extract_answers_from_passages(passages, disable_tqdm=not verbose)
    return annotations


def extract_answers_and_write_to_file(config, input_path, output_path, verbose):
    annotations = extract_answers(config, input_path, verbose)
    logger.info('writing extracted answers to file...')
    dump_jsonl(annotations, output_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser("Extract answers from passages")
    parser.add_argument('--passages_to_extract_from', type=str, required=True, help='path to passages to extract from, in jsonl or DPR tsv format')
    parser.add_argument('--output_path', type=str, required=True, help='Path to dump results to')
    parser.add_argument('--path_to_config', type=str, required=True, help='path to answer extractor config file')
    parser.add_argument('-v', '--verbose', action="store_true")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    with open(args.path_to_config) as f:
        config = json.load(f)

    if 'answer_extractor' in config:
        config = config['answer_extractor']

    extract_answers_and_write_to_file(config, args.passages_to_extract_from, args.output_path, args.verbose)


================================================
FILE: paq/generation/answer_extractor/extractors.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
from typing import List, Dict
import torch
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoTokenizer
from paq.paq_utils import is_spacy_available
from paq.generation.answer_extractor.span2D_model import AnswerSpanExtractor2DModel, postprocess_span2d_output


def get_output_format(all_passages, all_answers):
    all_results = []
    assert len(all_passages) == len(all_answers)
    for passage, answers in zip(all_passages, all_answers):
        result = {
            "passage_id": passage["passage_id"],
            "passage": passage["passage"],
            "answers": answers,
            "metadata": passage["metadata"],
        }
        all_results.append(result)
    return all_results


class SpacyNERExtractor:
    """
    Spacy NER extractor
    """
    name = "answer_extractor/spacy_ner"

    def __init__(self, model="en_core_web_sm"):
        assert is_spacy_available(), "Spacy is not installed. Please install with `pip install spacy`."
        import spacy
        self.nlp = spacy.load(model)

    def extract_from_passage(self, passage: str) -> List[Dict]:
        doc = self.nlp(passage)
        answers = []
        for ent in doc.ents:
            answers.append({
                "text": ent.text,
                "start": ent.start_char,
                "end": ent.end_char,
                "score": None
            })
        return answers

    def extract_answers_from_passages(self, passages_to_label, disable_tqdm=False):
        all_answers = []
        for doc in tqdm(self.nlp.pipe([p['passage'] for p in passages_to_label], batch_size=128), disable=disable_tqdm):
            answers = []
            for ent in doc.ents:
                answers.append({
                    "text": ent.text,
                    "start": ent.start_char,
                    "end": ent.end_char,
                    "score": None
                })
            all_answers.append(answers)

        # Post-process
        all_results = get_output_format(passages_to_label, all_answers)
        return all_results


class Span2DAnswerExtractor:
    """
    Predict answer spans with their joint span probability P(start, end|context).
    """
    name = "answer_extractor/span2D"

    def __init__(
        self,
        model_path: str,
        config_path: str = None,
        tokenizer_path: str = None,
        topk: int = 5,
        max_answer_len: int = 30,
        max_seq_len: int = 256,
        doc_stride: int = 128,
        batch_size: int = 10,
        device: int = 0,
        **kwargs
    ):
        assert model_path is not None
        self.device = torch.device(f"cuda:{device}") if device is not None else torch.device("cpu")

        config = AutoConfig.from_pretrained(config_path if config_path is not None else model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path if tokenizer_path is not None else model_path)
        self.model = AnswerSpanExtractor2DModel.from_pretrained(model_path, config=config)

        self.model.to(self.device)
        self.model.eval()

        self.topk = topk
        self.max_answer_len = max_answer_len
        self.max_seq_len = max_seq_len
        self.doc_stride = doc_stride
        logging.info(f"Extract top {self.topk} answer spans with "
                     f"max_answer_len={self.max_answer_len}, max_seq_len={self.max_seq_len}, "
                     f"doc_stride={self.doc_stride}.")

        self.kwargs = kwargs
        self.batch_size = batch_size

    def _tokenize(self, passage: str):
        input_features = self.tokenizer(
            passage,
            truncation=True,
            max_length=self.max_seq_len,
            stride=self.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )
        input_features["input_ids"] = torch.tensor(input_features["input_ids"]).to(self.device)
        input_features["token_type_ids"] = torch.tensor(input_features["token_type_ids"]).to(self.device)
        input_features["attention_mask"] = torch.tensor(input_features["attention_mask"]).to(self.device)
        return input_features

    def extract_from_passage(self, passage: str):
        input_features = self._tokenize(passage)
        # Inference only: skip autograd bookkeeping to save memory
        with torch.no_grad():
            model_output = self.model(**input_features, return_dict=True)
        answers = postprocess_span2d_output(model_output, input_features, self.max_answer_len, passage, self.topk)
        for answer in answers:
            answer['score'] = np.log(answer['score'])
        return answers

    def extract_answers_from_passages(self, passages_to_label, disable_tqdm=False):

        # Run the pipeline (model) to extract the answer spans
        all_answers = []
        for passage in tqdm(passages_to_label, disable=disable_tqdm):
            answers = self.extract_from_passage(passage["passage"])
            all_answers.append(answers)

        # Post-process
        all_results = get_output_format(passages_to_label, all_answers)
        return all_results


def load_answer_extractor(config):
    ANS_EXT_MAP = {m.name: m for m in [SpacyNERExtractor, Span2DAnswerExtractor]}
    answer_extractor = ANS_EXT_MAP[config['name']](**config['config'])
    return answer_extractor


================================================
FILE: paq/generation/answer_extractor/span2D_model.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Dict
import numpy as np
import math

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, ModuleList

from transformers import BertPreTrainedModel, BertModel
from transformers.file_utils import ModelOutput


@dataclass
class AnswerSpanExtractor2DModelOutput(ModelOutput):
    """
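    Output of the 2D answer span extraction model.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`start_positions` and :obj:`end_positions` are provided):
            Span extraction loss: binary cross-entropy over the (start, end) span matrix, weighted by ``span_masks`` and summed.
        span_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, sequence_length)`):
            Span scores (before sigmoid), indexed as ``[batch, start, end]``.
        span_masks (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length, sequence_length)`):
            Mask that is 1 for valid spans (both tokens attended to and ``start <= end <= start + max_answer_length``) and 0 elsewhere.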
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    span_logits: torch.FloatTensor = None
    span_masks: torch.Tensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class AnswerSpanExtractor2DModel(BertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)

        # Linear mapping for start and end representation
        self.start_outputs = nn.Linear(config.hidden_size, config.span_output_size)
        self.end_outputs = nn.Linear(config.hidden_size, config.span_output_size)
        prev_out_size = config.span_output_size * 2

        # Add final MLP output layers to produce probabilities
        self.output_mlp = None
        mlp_sizes = getattr(config, "output_mlp_sizes", None)
        if mlp_sizes and len(mlp_sizes) > 0:
            self.output_mlp = ModuleList()
            for cur_size in mlp_sizes:
                self.output_mlp.append(nn.Linear(prev_out_size, cur_size))
                self.output_mlp.append(nn.ReLU())
                prev_out_size = cur_size

        self.span_outputs = nn.Linear(prev_out_size, 1)

        self.max_answer_length = getattr(config, "max_answer_length", 30)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_num_answers)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_num_answers)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        start_hidden = self.start_outputs(sequence_output)  # [B, L, D]
        end_hidden = self.end_outputs(sequence_output)  # [B, L, D]

        sequence_length = sequence_output.shape[1]
        start_hidden = start_hidden.unsqueeze(2).expand(-1, -1, sequence_length, -1)  # [B, L, L, D]
        end_hidden = end_hidden.unsqueeze(1).expand(-1, sequence_length, -1, -1)  # [B, L, L, D]
        # Concat the start and end representation to form span representation
        span_hidden = torch.cat((start_hidden, end_hidden), -1)  # [B, L, L, D*2]

        # Run MLP layers
        if self.output_mlp is not None:
            for layer in self.output_mlp:
                span_hidden = layer(span_hidden)  # [B, L, L, ?]

        span_logits = self.span_outputs(span_hidden)  # [B, L, L, 1]
        span_logits = span_logits.squeeze(-1)  # [B, L, L]

        span_masks = torch.einsum('bi,bj->bij', attention_mask, attention_mask)  # [B, L, L]
        span_masks = torch.triu(span_masks)
        span_masks = torch.tril(span_masks, diagonal=self.max_answer_length)

        def _convert_to_span_matrix(start_positions, end_positions):
            span_labels = torch.zeros_like(span_logits)  # [B, L, L]
            for i, (start_post, end_post) in enumerate(zip(start_positions, end_positions)):
                for start_idx, end_idx in zip(start_post, end_post):
                    if 0 <= start_idx and 0 <= end_idx:  # we use -1 as null indicator
                        assert start_idx < sequence_length and end_idx < sequence_length
                        span_labels[i, start_idx, end_idx] = 1.
                    else:
                        break
            return span_labels

        total_loss = None
        if start_positions is not None and end_positions is not None:
            span_labels = _convert_to_span_matrix(start_positions, end_positions)

            loss_fct = BCEWithLogitsLoss(weight=span_masks, reduction="sum")
            total_loss = loss_fct(span_logits, span_labels)  # / torch.sum(span_masks)

        if not return_dict:
            output = (span_logits,) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return AnswerSpanExtractor2DModelOutput(
            loss=total_loss,
            span_logits=span_logits,
            span_masks=span_masks,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


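def sigmoid(x):
    # Numerically stable form: never exponentiates a large positive number,
    # so very negative logits cannot overflow math.exp
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    z = math.exp(x)
    return z / (1 + z)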


def postprocess_span2d_output(span2D_output: AnswerSpanExtractor2DModelOutput, features,
        max_answer_length, passage: str, n_best_size:int) -> List[Dict]:
    all_span_logits = span2D_output.span_logits.detach().cpu().numpy()
    all_span_masks = span2D_output.span_masks.detach().cpu().numpy()

    prelim_predictions = []
    # Looping through all the features associated to the current example.
    for feature_index in range(len(all_span_logits)):
        # We grab the predictions of the model for this feature.
        span_logits = all_span_logits[feature_index]
        span_masks = all_span_masks[feature_index]
        span_logits += -100 * (1 - span_masks)  # mask the span logits

        # This is what will allow us to map the positions in our logits to spans of text in the original
        # context.
        offset_mapping = features["offset_mapping"][feature_index]
        # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
        # available in the current feature.
        token_is_max_context = None

        # Update minimum null prediction.
        feature_null_score = span_logits[0, 0]
        min_null_prediction = {"offsets": (0, 0), "score": feature_null_score}

        # Go through the highest-scoring (start, end) pairs in the span logit matrix.
        # start_indexes = np.argsort(start_logits)[-1: -n_best_size - 1: -1].tolist()
        # end_indexes = np.argsort(end_logits)[-1: -n_best_size - 1: -1].tolist()
        start_indexes, end_indexes = np.unravel_index(
            np.argsort(span_logits, axis=None)[-1:-n_best_size - 10:-1],  # a buffer of 10 in case some are invalid
            span_logits.shape
        )
        start_indexes, end_indexes = start_indexes.tolist(), end_indexes.tolist()
        for start_index, end_index in zip(start_indexes, end_indexes):
            # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
            # to part of the input_ids that are not in the context.
            if (
                start_index >= len(offset_mapping)
                or end_index >= len(offset_mapping)
                or offset_mapping[start_index] is None
                or offset_mapping[end_index] is None
            ):
                continue
            # Don't consider answers with a length that is either < 0 or > max_answer_length.
            if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                continue
            # Don't consider answer that don't have the maximum context available (if such information is
            # provided).
            if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                continue
            prelim_predictions.append(
                {
                    "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
                    "score": span_logits[start_index, end_index],
                }
            )

    # Only keep the best `n_best_size` predictions.
    predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

    # Use the offsets to gather the answer text in the original context.
    for pred in predictions:
        offsets = pred.pop("offsets")
        pred["text"] = passage[offsets[0]: offsets[1]]
        pred["start"] = offsets[0]
        pred["end"] = offsets[1]

    # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
    # failure.
    if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
        predictions.insert(0, {"text": "null", "score": -100.0, "start": 0, "end": 0})

    # Include the probabilities in our predictions.
    for pred in predictions:
        score = pred.get("score")
        pred["score"] = sigmoid(score)

    return predictions


================================================
FILE: paq/generation/filtering/__init__.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

================================================
FILE: paq/generation/filtering/filter_questions.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
from paq.paq_utils import load_jsonl, dump_jsonl
from paq.generation.filtering.filterer import load_retriever, load_reader
import logging
import argparse

logger = logging.getLogger(__name__)


def retrieve_documents_for_generated_questions(config, input_file, verbose):
    retriever = load_retriever(config["retriever"])
    generated_questions = load_jsonl(input_file)
    logger.info("Running Filterer Retriever...")
    generated_questions_with_retrieved_docs = retriever.retrieve_documents(generated_questions)
    return generated_questions_with_retrieved_docs


def generate_answers_for_generated_questions_with_retrieved_docs(config, input_file, verbose):
    reader = load_reader(config["reader"])
    generated_questions_with_retrieved_docs = load_jsonl(input_file)

    logger.info("Running Filterer Reader...")
    results = reader.generate_answers(generated_questions_with_retrieved_docs)
    return results


def filter_generated_questions_and_write_to_file(config, input_path, output_path, verbose):
    results = retrieve_documents_for_generated_questions(config, input_path, verbose)
    retrieval_results_fi = output_path + '.retrieval_results'
    dump_jsonl(results, retrieval_results_fi)
    results = generate_answers_for_generated_questions_with_retrieved_docs(config, retrieval_results_fi, verbose)
    logger.info('Writing filtering results to file...')
    dump_jsonl(results, output_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Filter generated questions for answer consistency")
    parser.add_argument('--generated_questions_to_filter',
                        type=str,
                        required=True,
                        help='path to generated questions to filter (in jsonl format, produced by `question_generator`)')
    parser.add_argument('--output_path', type=str, required=True, help='Path to dump results to')
    parser.add_argument('--path_to_config', type=str, required=True, help='path to filterer config file')
    parser.add_argument('-v', '--verbose', action="store_true")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    with open(args.path_to_config) as f:
        config = json.load(f)

    if 'filterer' in config:
        config = config['filterer']

    filter_generated_questions_and_write_to_file(config, args.generated_questions_to_filter, args.output_path, args.verbose)


================================================
FILE: paq/generation/filtering/filterer.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys
import faiss
import os
# hack to add FID to path:
file_path = os.path.realpath(__file__)
fid_path = os.path.join(os.path.dirname(file_path), '../../../FiD')
sys.path.append(fid_path)
import torch
import transformers
import numpy as np
from torch.utils.data import DataLoader, SequentialSampler
from paq.paq_utils import load_dpr_tsv, load_jsonl
from transformers import AutoModel, AutoConfig, AutoTokenizer
import src.util
import src.data
import src.evaluation
import src.model
import logging
from paq.retrievers.embed import embed
from paq.retrievers.retrieve import mips
import pickle
from torch import nn
logger = logging.getLogger(__name__)


def _load_corpus(path):
    if 'tsv' in path or 'csv' in path:
        docs = load_dpr_tsv(path)
    else:
        docs = load_jsonl(path)
    logger.info('Parsed Corpus for retrieval')
    return {d['passage_id']: {'title': d['metadata']['title'], 'text': d['passage']} for d in docs}


class DummyFilteringRetriever:
    """Dummy filterer - does not retrieve any evidence"""
    name = "filtering/dummy_filtering_retriever"

    def retrieve_documents(self, data):
        return [{'question': d['question'], 'answers': [d['answer']], 'ctxs': [], 'metadata': d} for d in data]


class LocalFilteringRetriever:
    """Retrieves a single document (the gold context the question was generated from)"""
    name = "filtering/local_filtering_retriever"
    corpus = None

    def __init__(self, corpus_path):
        self.corpus_path = corpus_path

    def retrieve_documents(self, data):
        self.corpus = _load_corpus(self.corpus_path) if self.corpus is None else self.corpus

        examples = []
        for d in data:
            assert d['passage_id'] in self.corpus
            gold_doc = self.corpus[d['passage_id']]
            examples.append(
                {'question': d['question'].strip(), 'answers': [d['answer']], 'ctxs': [gold_doc], 'metadata': d}
            )
        return examples


class DPRQuestionEncoder(nn.Module):
    """simple wrapper on DPR Question Encoder Bert model"""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        seq_outputs = self.model(*args, **kwargs)['last_hidden_state']
        return seq_outputs[:, 0]


class GlobalFilteringRetriever:
    """Uses DPR to retrieve relevant documents for the question"""
    name = "filtering/global_filtering_retriever"
    corpus = None
    index_id_to_db_id = None
    index = None

    def __init__(self,
                 corpus_path,
                 index_path,
                 index_id_to_db_id_path,
                 model_path,
                 batch_size,
                 n_queries_to_parallelize,
                 max_seq_len,
                 n_docs,
                 device
                 ):
        self.corpus_path = corpus_path
        self.index_path = index_path
        self.index_id_to_db_id_path = index_id_to_db_id_path
        self.n_docs = n_docs
        self.device = torch.device(f"cuda:{device}") if device is not None else torch.device("cpu")

        config = AutoConfig.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, config=config)
        self.model = DPRQuestionEncoder(AutoModel.from_pretrained(model_path, config=config))

        self.model.to(self.device)
        self.model.eval()
        self.batch_size = batch_size
        self.n_queries_to_parallelize = n_queries_to_parallelize
        self.max_seq_len = max_seq_len

    def _load_corpus(self):
        logger.info("Loading Corpus if not already loaded...")
        self.corpus = _load_corpus(self.corpus_path) if self.corpus is None else self.corpus
        logger.info("Loading Faiss index if not already loaded...")
        self.index = faiss.read_index(self.index_path) if self.index is None else self.index
        if self.index_id_to_db_id is None:
            with open(self.index_id_to_db_id_path, 'rb') as f:
                self.index_id_to_db_id = pickle.load(f)

    def retrieve_documents(self, qa_pairs):
        self._load_corpus()
        examples = []
        for ci in range(0, len(qa_pairs), self.n_queries_to_parallelize):
            chunk_examples = qa_pairs[ci: ci + self.n_queries_to_parallelize]
            queries = embed(self.model, self.tokenizer, chunk_examples, bsz=self.batch_size)
            top_indices, _ = mips(self.index, queries, self.n_docs, self.n_queries_to_parallelize)
            for ati, d in zip(top_indices, chunk_examples):
                ctxs = [self.corpus[self.index_id_to_db_id[ati[j]]] for j in range(self.n_docs)]
                examples.append({'question': d['question'], 'answers': [d['answer']], 'ctxs': ctxs, 'metadata': d})
        return examples


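class CompatableEncoderWrapper(torch.nn.Module):
    """Patched version of fid.model.EncoderWrapper to make it compatible with our version of transformers.

    Note: `n_passages` is not set in `__init__`; it is expected to be assigned on this wrapper
    externally before `forward` is called.
    """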

    def __init__(self, encoder, use_checkpoint=False):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids=None, attention_mask=None, **kwargs, ):
        # total_length = n_passages * passage_length
        bsz, total_length = input_ids.shape
        passage_length = total_length // self.n_passages
        input_ids = input_ids.view(bsz * self.n_passages, passage_length)
        attention_mask = attention_mask.view(bsz * self.n_passages, passage_length)
        outputs = self.encoder(input_ids, attention_mask, **kwargs)
        outputs.last_hidden_state = outputs.last_hidden_state.view(bsz, self.n_passages * passage_length, -1)
        return outputs


class FIDReader:
    """FID Filterer"""
    name = "filtering/fid_reader"

    def __init__(self,
                 model_path: str,
                 batch_size: int = 10,
                 device: int = 0,
                 max_seq_len: int = 200,
                 n_docs: int = 50,
                 ):
        self.device = torch.device(f"cuda:{device}") if device is not None else torch.device("cpu")
        self.tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base', return_dict=False)
        self.model = src.model.FiDT5.from_pretrained(model_path)

        self.model.to(self.device)
        self.model.eval()
        self.model.encoder = CompatableEncoderWrapper(self.model.encoder.encoder)  # hack to make FID compatible with newer transformers version
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.n_docs = n_docs
        self.collator = src.data.Collator(self.max_seq_len, self.tokenizer)

    def _get_dataloader_for_examples(self, examples):
        for k, example in enumerate(examples):
            example['id'] = k
            for c in example['ctxs']:
                c['score'] = 1.0 / (k + 1)

        eval_dataset = src.data.Dataset(examples, self.n_docs)

        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.batch_size,
            num_workers=20,
            collate_fn=self.collator
        )
        return eval_dataset, eval_dataloader

    def generate_answers(self, examples):

        eval_dataset, eval_dataloader = self._get_dataloader_for_examples(examples)
        total = 0
        exactmatch = []
        with torch.no_grad():
            for i, batch in enumerate(eval_dataloader):
                (idx, _, _, context_ids, context_mask) = batch

                outputs = self.model.generate(
                    input_ids=context_ids.to(self.device),
                    attention_mask=context_mask.to(self.device),
                    max_length=10,
                )
                for k, o in enumerate(outputs):
                    ans = self.tokenizer.decode(o, skip_special_tokens=True)
                    example = eval_dataset.data[idx[k]]
                    score = src.evaluation.ems(ans, example['answers'])
                    exactmatch.append(score)
                    example['consistent'] = score
                    example['filter_answer'] = ans
                    total += 1

                if (i + 1) % 10 == 0:
                    logger.info(f'FID filtering: {i+1} / {len(eval_dataloader)} | ave = {np.mean(exactmatch):.3f}')
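        # Final summary; skipped when there were no examples (`i` would be undefined in that case).
        if exactmatch:
            logger.info(f'FID filtering: {i+1} / {len(eval_dataloader)} | ave = {np.mean(exactmatch):.3f}')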
        output = _get_reader_output_format(examples)
        return output


class DummyReader:
    """Dummy Reader, always returns consistent"""
    name = "filtering/dummy_reader"

    def generate_answers(self, examples):
        for example in examples:
            example['consistent'] = True
            example['filter_answer'] = "DUMMY_READER_ANSWER"
        output = _get_reader_output_format(examples)
        return output


def _get_reader_output_format(dataset):
    out = []
    for e in dataset:
        o = e['metadata']
        o['metadata'] = o.get('metadata', {})
        o['metadata']['consistent'] = e['consistent']
        o['metadata']['filter_answer'] = e['filter_answer']
        out.append(o)
    return out


def load_reader(config):
    READER_MAP = {m.name: m for m in [DummyReader, FIDReader]}
    reader = READER_MAP[config['name']](**config['config'])
    return reader


def load_retriever(config):
    RETRIEVER_MAP = {m.name: m for m in [LocalFilteringRetriever, GlobalFilteringRetriever, DummyFilteringRetriever]}
    retriever = RETRIEVER_MAP[config['name']](**config['config'])
    return retriever



================================================
FILE: paq/generation/generate_qa_pairs.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
from collections import defaultdict
import logging
import argparse

from paq.paq_utils import load_jsonl, dump_jsonl, get_submitit_executor
from paq.generation.passage_scorer.score_passages import score_passages_and_write_to_file
from paq.generation.answer_extractor.extract_answers import extract_answers_and_write_to_file
from paq.generation.question_generator.generate_questions import generate_questions_and_write_to_file
from paq.generation.filtering.filter_questions import filter_generated_questions_and_write_to_file


logger = logging.getLogger(__name__)

CONFIG_FILE = "config.json"

FINAL_OUTPUT = "final_qas.jsonl"
FINAL_DONE = "FINAL_DONE"


def touch(path):
    """Create an empty file. Update the mtime if it exists."""
    with open(path, 'a'):
        os.utime(path, None)


def _run_pipeline_step(config, input_file, output_file, done_indicator, verbose, fun):
    if not os.path.exists(done_indicator):
        fun(config, input_file, output_file, verbose)
        touch(done_indicator)
    return output_file


def run_passage_scoring(config, input_file, output_dir, verbose=False):
    output_file = os.path.join(output_dir, "ps.jsonl")
    done_path = os.path.join(output_dir, "PS_DONE")
    func = score_passages_and_write_to_file
    return _run_pipeline_step(config['passage_scorer'], input_file, output_file, done_path, verbose, func)


def run_answer_extraction(config, input_file, output_dir, verbose=False):
    output_file = os.path.join(output_dir, "ae.jsonl")
    done_path = os.path.join(output_dir, "AE_DONE")
    func = extract_answers_and_write_to_file
    return _run_pipeline_step(config['answer_extractor'], input_file, output_file, done_path, verbose, func)


def run_question_generation(config, input_file, output_dir, verbose=False):
    output_file = os.path.join(output_dir, "qg.jsonl")
    done_path = os.path.join(output_dir, "QG_DONE")
    func = generate_questions_and_write_to_file
    return _run_pipeline_step(config['question_generator'], input_file, output_file, done_path, verbose, func)


def run_filtering(config, input_file, output_dir, verbose=False):
    output_file = os.path.join(output_dir, "filterd_qg.jsonl")
    done_path = os.path.join(output_dir, "FILTERED_DONE")
    func = filter_generated_questions_and_write_to_file
    return _run_pipeline_step(config['filterer'], input_file, output_file, done_path, verbose, func)


def combine_generated_files(document_ranker_file,
                            question_generation_file,
                            output_file
                            ):
    # Write final generated QA-pairs to an output file

    def _get_passage_score_map(doc_ranker_file):
        passage_scores = {}
        with open(doc_ranker_file, "r") as f:
            for line in f.readlines():
                row = json.loads(line)
                passage_scores[row["passage_id"]] = row["metadata"].get("ps_score", None)
        return passage_scores

    def _add_passage_metadata(questions_fi, passage_scores):
        generated_qas = load_jsonl(questions_fi)
        qas_dict = defaultdict(list)
        for qas in generated_qas:
            question, answer, passage_id = qas["question"], qas["answer"], qas["passage_id"]
            metadata = {"passage_id": passage_id, "ps_score": passage_scores[passage_id], 'answer': answer}
            metadata.update(qas["metadata"])
            qas_dict[question].append((answer, metadata))
        return qas_dict

    def _get_output_format(qas_dict):
        final_qas = []
        for question, answers_meta in qas_dict.items():
            answers, metadata_list = zip(*answers_meta)
            final_qa = {"question": question, "answer": answers, "metadata": metadata_list}
            final_qas.append(final_qa)
        return final_qas

    passage_score_map = _get_passage_score_map(document_ranker_file)
    qas_with_meta = _add_passage_metadata(question_generation_file, passage_score_map)
    final_qas = _get_output_format(qas_with_meta)
    dump_jsonl(final_qas, output_file)


def run_paq_generation_pipeline(config: dict, input_file: str, output_dir: str, verbose: bool = False):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Save the config
    config["source"], config['output_dir'] = input_file, output_dir
    with open(os.path.join(output_dir, CONFIG_FILE), "w") as cf:
        json.dump(config, cf, indent=2)

    # Run the pipeline:
    passages_fi = run_passage_scoring(config, input_file, output_dir, verbose=verbose)
    answers_fi = run_answer_extraction(config, passages_fi, output_dir, verbose=verbose)
    questions_fi = run_question_generation(config, answers_fi, output_dir, verbose=verbose)
    filtered_questions_fi = run_filtering(config, questions_fi, output_dir, verbose=verbose)

    # Write final generated QA-pairs to an output file
    output_fi = os.path.join(output_dir, FINAL_OUTPUT)
    logging.info(f"Writing generated QA pairs to {output_fi}...")

    final_indicator = os.path.join(output_dir, FINAL_DONE)
    if not os.path.exists(final_indicator):
        output_fi = os.path.join(output_dir, FINAL_OUTPUT)
        combine_generated_files(passages_fi, filtered_questions_fi, output_fi)
        touch(final_indicator)


def _is_job_finished(job_number, output_dir):
    done_path = os.path.join(output_dir, FINAL_DONE)
    if os.path.exists(done_path):
        print(f"not launching job {job_number} as it's already finished: ", done_path)
        return True
    return False


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--path_to_config", help="Path of config file")
    parser.add_argument("--passage_files_to_generate", help="comma separated list of files to generate QA pairs from")
    parser.add_argument("--output_dirs", help="comma separated list of directories to write the generated QA pairs to")
    parser.add_argument('--n_jobs', type=int, required=True, help='how many parallel jobs to use in slurm (n_jobs=-1 will run locally)')
    parser.add_argument('--slurm_partition', type=str, default="learnfair", help='If using submitit to run slurm jobs, define cluster partition here')
    parser.add_argument('--slurm_comment', type=str, default="", help='If using submitit to run slurm jobs, define job comment here')
    parser.add_argument('-v', '--verbose', action="store_true")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    with open(args.path_to_config) as f:
        config = json.load(f)

    input_files = args.passage_files_to_generate.split(',')
    output_dirs = args.output_dirs.split(',')

    if args.n_jobs == -1:
        # Run locally
        for i, (inf, out_dir) in enumerate(zip(input_files, output_dirs)):
            if not _is_job_finished(i, out_dir):
                logging.info(f'Running generation job {i}:\ninput file: {inf} \nSaving results to: {out_dir}')
                run_paq_generation_pipeline(config, inf, out_dir, args.verbose)
    else:
        # Run with submitit
        executor = get_submitit_executor(n_jobs=args.n_jobs, comment=args.slurm_comment, partition=args.slurm_partition)
        jobs = []
        with executor.batch():
            for i, (inf, out_dir) in enumerate(zip(input_files, output_dirs)):
                if not _is_job_finished(i, out_dir):
                    job = executor.submit(run_paq_generation_pipeline, config, inf, out_dir, args.verbose)
                    jobs.append((job, inf, out_dir))

        logging.info('Launching the following jobs:')
        for job, inf, out_dir in jobs:
            logging.info(f'{job.job_id} {inf} -> {out_dir}')


================================================
FILE: paq/generation/passage_scorer/__init__.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

================================================
FILE: paq/generation/passage_scorer/score_passages.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
from paq.paq_utils import load_jsonl, dump_jsonl, load_dpr_tsv
from paq.generation.passage_scorer.scorer import load_passage_scorer
import logging
import argparse

logger = logging.getLogger(__name__)


def load_passages(path):
    try:
        return load_jsonl(path)
    except Exception:  # not valid jsonl; fall back to the DPR tsv format
        return load_dpr_tsv(path)


def score_passages(config, input_file, verbose):
    passage_scorer = load_passage_scorer(config)
    passages = load_passages(input_file)
    logger.info("Running Passage Scorer...")
    annotations = passage_scorer.score_passages(passages, disable_tqdm=not verbose)
    return annotations


def score_passages_and_write_to_file(config, input_path, output_path, verbose):
    annotations = score_passages(config, input_path, verbose)
    logger.info('writing scored passages to file...')
    dump_jsonl(annotations, output_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Score passages")
    parser.add_argument('--passages_to_score', type=str, required=True, help='path to passages to score (jsonl or DPR tsv format)')
    parser.add_argument('--output_path', type=str, required=True, help='Path to dump results to')
    parser.add_argument('--path_to_config', type=str, required=True, help='path to passage scorer config file')
    parser.add_argument('-v', '--verbose', action="store_true")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    with open(args.path_to_config) as f:
        config = json.load(f)

    if 'passage_scorer' in config:
        config = config['passage_scorer']

    score_passages_and_write_to_file(config, args.passages_to_score, args.output_path, args.verbose)


================================================
FILE: paq/generation/passage_scorer/scorer.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Dict, Union
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
import torch


class DummyPassageScorer:
    """
    Dummy scorer that will always return the same score for any passage.
    """
    name = "passage_scorer/dummy"

    def __init__(self, default_score=0.0):
        self.default_score = default_score

    def score_passage(self, passage: Dict) -> float:
        return self.default_score

    def score_passages(self, passages_to_label, disable_tqdm=False):
        for passage in tqdm(passages_to_label, disable=disable_tqdm):
            score = self.score_passage(passage)
            passage['metadata']['ps_score'] = score
        return passages_to_label


class LookupPassageScorer:
    """
    Lookup scorer that will return the score from a file of precomputed passage scores for passages, or if not present, return a default score.
    """
    name = "passage_scorer/lookup"

    def __init__(self, scores_file, default_score=-10000.0):
        self._load_passage_scores(scores_file)
        self.default_score = default_score

    def _load_passage_scores(self, scores_file):
        self.passage_scores = {}
        with open(scores_file) as f:
            for line in f:
                k, v = line.strip('\n').split('\t')
                self.passage_scores[k] = float(v)  # store as float to match `default_score`

    def score_passage(self, passage: Dict) -> float:
        return self.passage_scores.get(passage['passage_id'], self.default_score)

    def score_passages(self, passages_to_label, disable_tqdm=False):
        for passage in tqdm(passages_to_label, disable=disable_tqdm):
            score = self.score_passage(passage)
            passage['metadata']['ps_score'] = score
        return passages_to_label


class LearntPassageScorer:
    """Learnt scorer"""
    name = "passage_scorer/learnt"

    def __init__(self,
                 model_path: str,
                 config_path: str,
                 tokenizer_path: str = None,
                 batch_size: int = 10,
                 device: int = 0,
                 max_seq_len: int = 256):
        self.device = torch.device(f"cuda:{device}") if device is not None else torch.device("cpu")

        config = AutoConfig.from_pretrained(config_path if config_path is not None else model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path if tokenizer_path is not None else model_path, config=config
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path, config=config)

        self.model.to(self.device)
        self.model.eval()
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len

    def _tokenize(self, texts):
        input_features = self.tokenizer.batch_encode_plus(
            texts, return_tensors='pt', padding=True, add_special_tokens=True, max_length=self.max_seq_len, truncation=True
        )
        input_features = {k: v.to(self.device) for k, v in input_features.items()}
        return input_features

    def score_passages(self, passages_to_label, disable_tqdm=False):

        def _run_batch(batch):
            inputs = self._tokenize([b['passage'] for b in batch])
            with torch.no_grad():  # inference only; avoid building the autograd graph
                scores = self.model(**inputs)
            log_probs = torch.log_softmax(scores.logits, dim=-1)[:, 1].cpu().tolist()
            for s, b in zip(log_probs, batch):
                b['metadata']['ps_score'] = float(s)
            return scores

        batch, outputs = [], []
        for passage in tqdm(passages_to_label, disable=disable_tqdm):
            batch.append(passage)

            if len(batch) == self.batch_size:
                _run_batch(batch)
                outputs += batch
                batch = []

        if len(batch) != 0:
            _run_batch(batch)
            outputs += batch

        return outputs


def load_passage_scorer(config):
    PASSAGE_SCORER_MAP = {m.name: m for m in [LearntPassageScorer, DummyPassageScorer, LookupPassageScorer]}
    passage_scorer = PASSAGE_SCORER_MAP[config['name']](**config['config'])
    return passage_scorer


================================================
FILE: paq/generation/question_generator/__init__.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

================================================
FILE: paq/generation/question_generator/generate_questions.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
from paq.paq_utils import load_jsonl, dump_jsonl
from paq.generation.question_generator.generator import load_question_generator
import logging
import argparse

logger = logging.getLogger(__name__)


def generate_questions(config, input_file, verbose):
    question_generator = load_question_generator(config)
    passage_answer_pairs = load_jsonl(input_file)
    logger.info("Running Question Generation...")
    annotations = question_generator.generate_questions_from_passage_answer_pairs(passage_answer_pairs, disable_tqdm=not verbose)
    return annotations


def generate_questions_and_write_to_file(config, input_path, output_path, verbose):
    annotations = generate_questions(config, input_path, verbose)
    logger.info('writing generated questions to file...')
    dump_jsonl(annotations, output_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Generate questions from passage-answer pairs")
    parser.add_argument('--passage_answer_pairs_to_generate_from',
                        type=str,
                        required=True,
                        help='path to generate from (in jsonl format, produced by `answer_extractor`)')
    parser.add_argument('--output_path', type=str, required=True, help='Path to dump results to')
    parser.add_argument('--path_to_config', type=str, required=True, help='path to question generator config file')
    parser.add_argument('-v', '--verbose', action="store_true")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    with open(args.path_to_config) as f:
        config = json.load(f)

    if 'question_generator' in config:
        config = config['question_generator']

    generate_questions_and_write_to_file(config, args.passage_answer_pairs_to_generate_from, args.output_path, args.verbose)


================================================
FILE: paq/generation/question_generator/generator.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import List, Union, Set
from tqdm.auto import tqdm
import warnings
import torch

warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.pipelines import Text2TextGenerationPipeline
from paq.paq_utils import to_fp16


logger = logging.getLogger(__name__)


def _batch_iterator(context_answer_pairs,
                    batch_size,
                    include_title: bool = True,
                    ):

    def _answer_context_pair_2_text(answer, context):
        answer_start, answer_end, answer_text = answer["start"], answer['end'], answer['text']
        return context[:answer_start] + "** " + context[answer_start:answer_end] + " **" + context[answer_end:]

    def _create_input_text(context, answer, title=None) -> str:
        text = _answer_context_pair_2_text(answer, context)

        if title is not None:
            output = f"answer: {answer['text']} | title: {title} | context: {text}"
        else:
            output = f"answer: {answer['text']} | context: {text}"
        return output

    iter_batch = []
    for context_answer_pair in context_answer_pairs:

        passage_id = context_answer_pair["passage_id"]
        context = context_answer_pair["passage"]
        answers = context_answer_pair["answers"]
        title = context_answer_pair["metadata"]["title"] if include_title else None

        for answer in answers:
            input_text = _create_input_text(context, answer, title)
            iter_batch.append((passage_id, answer, input_text))

            if len(iter_batch) >= batch_size:
                yield iter_batch
                iter_batch = []

    if len(iter_batch) > 0:
        yield iter_batch


class QuestionGenerator:
    name = "question_generator/standard"

    def __init__(
        self,
        model_path: str,
        config_path: str = None,
        tokenizer_path: str = None,
        include_title: bool = True,
        num_beams: int = None,
        num_return_sequences: int = 1,
        max_question_len: int = 30,
        batch_size: int = 1,
        device: int = 0,
        **kwargs
    ):
        assert model_path is not None

        super().__init__()

        config = AutoConfig.from_pretrained(config_path if config_path is not None else model_path)
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path if tokenizer_path is not None else model_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, config=config)

        if kwargs.get('fp16', False):
            model = model.cuda()
            model = to_fp16(model)

        self.pipeline = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer, task="question-generation",
                                                    device=device)

        self.include_title = include_title  # include title in the source sequence
        self.num_beams = num_beams
        self.num_return_sequences = num_return_sequences
        self.max_question_len = max_question_len
        logger.info(
            f"Generate {self.num_return_sequences} questions for each passage with beam size {self.num_beams}.")

        self.batch_size = batch_size

        self.kwargs = kwargs

    def generate_question(self, data: Union[str, List[str]]):
        """
        Generate question for a single input sequence or a batch of input sequences.
        """
        if isinstance(data, str):
            data = [data]

        all_records = self.pipeline(
            data,
            return_text=True,
            # return_scores=True,
            clean_up_tokenization_spaces=True,
            max_length=self.max_question_len,
            min_length=3,
            num_beams=self.num_beams,
            num_return_sequences=self.num_return_sequences,
            **self.kwargs
        )

        assert len(all_records) == len(data) * self.num_return_sequences

        generated_questions = [r["generated_text"].strip() for r in all_records]
        scores = [r.get("score", None) for r in all_records]

        batched_questions = [
            generated_questions[i:i + self.num_return_sequences]
            for i in range(0, len(generated_questions), self.num_return_sequences)
        ]
        batched_scores = [
            scores[i:i + self.num_return_sequences]
            for i in range(0, len(scores), self.num_return_sequences)
        ]

        return batched_questions, batched_scores

    def generate_questions_from_passage_answer_pairs(self, passage_answer_pairs, disable_tqdm=False):
        outputs = []
        for batch in tqdm(
            _batch_iterator(passage_answer_pairs, self.batch_size, include_title=self.include_title),
            disable=disable_tqdm,
            total=len(passage_answer_pairs) // self.batch_size
        ):
            # try:
            batch_ids, batch_answers, batch_inputs = zip(*batch)
            batch_questions, batch_scores = self.generate_question(list(batch_inputs))
            # except Exception as e:
            #     logging.info('skipping Broken batch')
            #     continue

            for passage_id, answer, questions, scores in zip(batch_ids, batch_answers, batch_questions,
                                                             batch_scores):
                for question, score in zip(questions, scores):
                    output = {
                        "passage_id": passage_id,
                        "answer": answer["text"],
                        "question": question,
                        "metadata": {
                            "answer_start": answer["start"],
                            "answer_end": answer["end"],
                            "ae_score": answer["score"],
                            "qg_score": score,
                        },
                    }
                    outputs.append(output)
        return outputs


def load_question_generator(config):
    return QuestionGenerator(**config['config'])


================================================
FILE: paq/paq_utils.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import logging
import torch
import glob
import os
import csv
try:
    import submitit
    _has_submitit = True
except ImportError:
    _has_submitit = False
try:
    import apex
    from apex import amp
    apex.amp.register_half_function(torch, "einsum")
    _has_apex = True
except ImportError:
    _has_apex = False
try:
    import spacy
    from spacy.util import minibatch, compounding
    spacy.prefer_gpu()

    _has_spacy = True
except (ImportError, AttributeError):
    _has_spacy = False


logger = logging.getLogger(__name__)


def is_spacy_available():
    return _has_spacy


def is_submitit_available():
    return _has_submitit


def is_apex_available():
    return _has_apex


def to_fp16(model):
    if is_apex_available():
        model = amp.initialize(model, opt_level="O1")
    else:
        model = model.half()
    return model


def load_jsonl_memory_friendly(fi):
    logging.info(f'Loading {fi}')

    results = []
    # Stream the file line by line so only the parsed items are held in memory
    with open(fi) as f:
        for ln, line in enumerate(f):
            results.append(json.loads(line))
            if ln % 1000000 == 0:
                logging.info(f'Loaded {ln + 1} Items from {fi}')

    logging.info(f'Loaded {len(results)} Items from {fi}')
    return results


def load_jsonl_fast(fi):
    logging.info(f'Loading {fi}')

    results = []
    with open(fi) as f:
        txt = f.read()
        logging.info(f'{fi} Loaded, splitting into lines...')
        lines = [t for t in txt.split('\n') if t.strip()!='']
        logging.info(f'Parsing {len(lines)} items from jsonl:')

    for ln, line in enumerate(lines):
        results.append(json.loads(line))
        logging.info(f'Loaded {ln + 1} Items from {fi}') if ln % 1000000 == 0 else None

    logging.info(f'Loaded {ln + 1} Items from {fi}')
    return results


def load_jsonl(fi, memory_friendly=False):
    if memory_friendly:
        return load_jsonl_memory_friendly(fi)
    else:
        return load_jsonl_fast(fi)


def dump_jsonl(items, fi):
    logging.info(f'Dumping {len(items)} items into {fi}')
    k = 0

    with open(fi, 'w') as f:
        for k, item in enumerate(items):
            f.write(json.dumps(item) + '\n')
            logging.info(f'Written {k + 1} / {len(items)} items') if k % 10000 == 0 else None

    logging.info(f'Written {k + 1} / {len(items)} items')


def load_dpr_tsv(fi):
    items = []
    with open(fi) as ifile:
        reader = csv.reader(ifile, delimiter='\t')
        for spl in reader:
            idd, text, title = spl
            items.append({'passage_id': idd, "passage": text, "metadata": {'title': title}})
    return items


def get_vectors_file_paths_in_vector_directory(embeddings_dir):
    paths = glob.glob(os.path.abspath(embeddings_dir) + '/*')
    np = len(paths)
    template = '.'.join(paths[0].split('.')[:-1])
    return [template + f'.{j}' for j in range(np)]


def parse_vectors_from_directory_chunks(embeddings_dir, half):
    paths = get_vectors_file_paths_in_vector_directory(embeddings_dir)
    for j, p in enumerate(paths):
        logger.info(f'Loading vectors from {p} ({j+1} / {len(paths)})')
        m = torch.load(p)
        assert int(p.split('.')[-1]) == j, (p, j)

        if half:
            m = m if m.dtype == torch.float16 else m.half()
        else:
            m = m if m.dtype == torch.float32 else m.float()
        yield m


def parse_vectors_from_directory_fast(embeddings_dir, half=False):
    ms = []
    # `half` is a required argument of parse_vectors_from_directory_chunks
    for m in parse_vectors_from_directory_chunks(embeddings_dir, half):
        ms.append(m)

    out = torch.cat(ms)
    logger.info(f'loaded index of shape {out.shape}')
    return out


def parse_vectors_from_directory_memory_friendly(embeddings_dir, size=None):
    paths = get_vectors_file_paths_in_vector_directory(embeddings_dir)
    if size is None:
        size = 0
        for j, p in enumerate(paths):
            logger.info(f'Loading vectors from {p} ({j+1} / {len(paths)}) to find total num vectors')
            m = torch.load(p)
            size += m.shape[0]

    out = None
    offset = 0
    for j, p in enumerate(paths):
        logger.info(f'Loading vectors from {p} ({j+1} / {len(paths)})')
        m = torch.load(p)

        assert int(p.split('.')[-1]) == j, (p, j)
        if out is None:
            out = torch.zeros(size, m.shape[1])
        out[offset: offset + m.shape[0]] = m
        offset += m.shape[0]
    assert offset == size
    logger.info(f'loaded index of shape {out.shape}')

    return out


def parse_vectors_from_directory(fi, memory_friendly=False, size=None, as_chunks=False, half=False):
    assert os.path.isdir(fi), f"Vectors directory {fi} doesn't exist, or is not a directory of pytorch vectors"
    if as_chunks:
        return parse_vectors_from_directory_chunks(fi, half)

    if memory_friendly:
        out = parse_vectors_from_directory_memory_friendly(fi, size=size)
    else:
        out = parse_vectors_from_directory_fast(fi)

    if half:
        out = out if out.dtype == torch.float16 else out.half()
    else:
        out = out if out.dtype == torch.float32 else out.float()

    return out


def get_submitit_executor(n_jobs=10, comment="", partition='learnfair'):
    if not is_submitit_available():
        raise Exception('Submitit Not installed')
    executor = submitit.AutoExecutor(folder='PAQ_embedding_jobs')
    executor.update_parameters(timeout_min=120,
                               slurm_partition=partition,
                               slurm_nodes=1,
                               slurm_ntasks_per_node=1,
                               slurm_cpus_per_task=10,
                               slurm_constraint='volta32gb',
                               slurm_gpus_per_node='volta:1',
                               slurm_array_parallelism=n_jobs,
                               slurm_comment=comment,
                               slurm_mem='64G')
    return executor



================================================
FILE: paq/rerankers/__init__.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

================================================
FILE: paq/rerankers/rerank.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
import logging
import os
import time
from paq.paq_utils import is_apex_available, load_jsonl, dump_jsonl, get_submitit_executor, to_fp16
from transformers import AutoConfig, AutoTokenizer, AutoModelForMultipleChoice
if is_apex_available():
    import apex
    from apex import amp
    apex.amp.register_half_function(torch, "einsum")

logger = logging.getLogger(__name__)
CUDA = torch.cuda.is_available()


def load_reranker(model_name_or_path):
    logger.info(f'Loading model from: {model_name_or_path}')
    config = AutoConfig.from_pretrained(model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, do_lower_case=True)
    model = AutoModelForMultipleChoice.from_pretrained(
        model_name_or_path,
        from_tf=bool(".ckpt" in model_name_or_path),
        config=config,
    )
    model = model.eval()
    return model, tokenizer


def get_output_format(qas, prediction_indices, prediction_scores):
    assert len(qas) == len(prediction_indices)
    return [
        {
             'question': q['input_qa']['question'],
             'prediction': q['retrieved_qas'][p]['answer'][0],
             'score':s, 'index': int(p)
        }
        for q, p, s in zip(qas, prediction_indices, prediction_scores)
    ]


def tokenize(tokenizer, batch_qas, cuda, top_k):
    input_as, input_bs = [], []

    for item in batch_qas:
        question_a = item['input_qa']['question'] + '?'
        question_bs = [q['question'] + '? ' + q['answer'][0] for q in item['retrieved_qas']]
        question_bs = question_bs[:top_k]
        input_as += [question_a for _ in range(len(question_bs))]
        input_bs += question_bs

    inputs = tokenizer.batch_encode_plus(
        list(zip(input_as, input_bs)), return_tensors='pt', padding='longest', add_special_tokens=True
    )
    inputs = {k: v.reshape(len(batch_qas), v.shape[0]//len(batch_qas), -1) for k,v in inputs.items()}
    return {k: v.cuda() for k, v in inputs.items()} if cuda else inputs


def predict(model, tokenizer, qas, cuda=CUDA, bsz=16, fp16=False, top_k=30):

    if cuda:
        model = model.cuda()
        model = to_fp16(model) if fp16 else model

    t = time.time()

    def log_progress(j, outputs):
        t2 = time.time()
        logger.info(
            f'Reranked {j + 1} / {len(list(range(0, len(qas), bsz)))} batches in {t2 - t:0.2f} seconds '
            f'({len(outputs) / (t2 - t): 0.4f} QAs per second)')

    def forward(inputs):
        logits = model(**inputs)[0]
        scores, inds = logits.topk(1, dim=1)
        scores, inds = scores.squeeze().tolist(), inds.squeeze().tolist()
        if padded_batch:
            scores, inds = scores[:-1], inds[:-1]
        return scores, inds

    outputs = []
    output_scores = []
    logger.info(f'Reranking {len(qas)} inputs in {len(list(range(0, len(qas), bsz)))} batches:')
    with torch.no_grad():
        for j, batch_start in enumerate(range(0, len(qas), bsz)):

            batch = qas[batch_start: batch_start + bsz]
            padded_batch = len(batch) == 1
            if padded_batch: # hack for batch size 1 issues
                batch = [batch[0],batch[0]]

            inputs = tokenize(tokenizer, batch, cuda, top_k)
            scores, inds = forward(inputs)

            outputs.extend(inds)
            output_scores.extend(scores)

            log_progress(j, outputs) if j % 1 == 0 else None

    log_progress(j, outputs)

    return get_output_format(qas, outputs, output_scores)


def run_predictions(qas_to_rerank_file, output_file, model_name_or_path, batch_size, fp16, top_k):
    qas_to_rerank = load_jsonl(qas_to_rerank_file)
    reranker_model, reranker_tokenizer = load_reranker(model_name_or_path)

    predictions = predict(
        reranker_model,
        reranker_tokenizer,
        qas_to_rerank,
        bsz=batch_size,
        fp16=fp16,
        top_k=top_k
    )
    dump_jsonl(predictions, output_file)


def parse_files(args):
    infis, outfis = args.qas_to_rerank.split(','), args.output_files.split(',')
    assert len(infis) == len(outfis)
    pairs = []
    for in_fi, out_fi in zip(infis, outfis):
        if os.path.exists(out_fi):
            logging.info(f'skipping inference on {out_fi}, file exists')
        else:
            pairs.append((in_fi, out_fi))
    return pairs


if __name__ == '__main__':
    parser = argparse.ArgumentParser("Perform RePAQ Reranking. This program will rerank retrieval results from retrieve.py.")
    parser.add_argument('--model_name_or_path', type=str,)
    parser.add_argument('--qas_to_rerank', type=str, help='comma separated list of files produced by retrieve.py to rerank')
    parser.add_argument('--output_files', type=str, help='comma separated list of filenames to write, one for each filename in --qas_to_rerank')
    parser.add_argument('--top_k', type=int, default=50, help='top k to rerank')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--n_jobs', type=int, required=True, help='how many parallel jobs to use in slurm (n_jobs=-1 will run locally)')
    parser.add_argument('--slurm_partition', type=str, default="learnfair", help='If using submitit to run slurm jobs, define cluster partition here')
    parser.add_argument('--slurm_comment', type=str, default="", help='If using submitit to run slurm jobs, define job comment here')
    parser.add_argument('-v', '--verbose', action="store_true")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    pairs = parse_files(args)

    if args.n_jobs != -1:
        executor = get_submitit_executor(n_jobs=args.n_jobs, comment=args.slurm_comment, partition=args.slurm_partition)
        with executor.batch():
            jobs = [
                executor.submit(run_predictions, infi, outfi, args.model_name_or_path, args.batch_size, args.fp16, args.top_k)
                for infi, outfi in pairs
            ]
        logger.info('launched the following jobs:')
        [logger.info(job.job_id) for job in jobs]
    else:
        for infi, outfi in pairs:
            run_predictions(infi, outfi, args.model_name_or_path, args.batch_size, args.fp16, args.top_k)


================================================
FILE: paq/retrievers/__init__.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

================================================
FILE: paq/retrievers/build_index.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
import logging
import faiss
import os
import random
from paq.paq_utils import parse_vectors_from_directory

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def get_vector_sample(cached_embeddings_path, sample_fraction):
    samples = []
    max_phi = -1
    N = 0
    vectors = parse_vectors_from_directory(cached_embeddings_path, as_chunks=True)
    for chunk in vectors:
        phis = (chunk ** 2).sum(1)
        max_phi = max(max_phi, phis.max())
        N += chunk.shape[0]
        if sample_fraction == 1.0:
            chunk_sample = chunk
        else:
            chunk_sample = chunk[random.sample(range(0, len(chunk)), int(len(chunk) * sample_fraction))]
        samples.append(chunk_sample)

    del vectors
    vector_sample = torch.cat(samples)
    return vector_sample, max_phi, N


def get_vectors_dim(cached_embeddings_path):
    vectors = parse_vectors_from_directory(cached_embeddings_path, as_chunks=True)
    vector_size = next(vectors).shape[1]
    del(vectors)
    return vector_size


def augment_vectors(vectors, max_phi):
    phis = (vectors ** 2).sum(1)
    aux_dim = torch.sqrt(max_phi - phis)
    vectors = torch.cat([vectors, aux_dim.unsqueeze(-1)], -1)
    return vectors


def build_index_streaming(cached_embeddings_path,
                          output_path,
                          hnsw=False,
                          sq8_quantization=False,
                          fp16_quantization=False,
                          store_n=256,
                          ef_search=32,
                          ef_construction=80,
                          sample_fraction=0.1,
                          indexing_batch_size=5000000,
                          ):

    vector_size = get_vectors_dim(cached_embeddings_path)

    if hnsw:
        if sq8_quantization:
            index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_8bit, store_n)
        elif fp16_quantization:
            index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_fp16, store_n)
        else:
            index = faiss.IndexHNSWFlat(vector_size + 1, store_n)

        index.hnsw.efSearch = ef_search
        index.hnsw.efConstruction = ef_construction
    else:
        if sq8_quantization:
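            # Non-HNSW indexes are searched without the auxiliary dimension, so they must use inner product for MIPS
            index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_INNER_PRODUCT)
        elif fp16_quantization:
            index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_fp16, faiss.METRIC_INNER_PRODUCT)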
        else:
            # Exact inner product index; vectors are not augmented in the non-HNSW case
            index = faiss.IndexFlatIP(vector_size)

    vector_sample, max_phi, N = get_vector_sample(cached_embeddings_path, sample_fraction)
    if hnsw:
        vector_sample = augment_vectors(vector_sample, max_phi)

    if sq8_quantization or fp16_quantization: # index requires training
        vs = vector_sample.numpy()
        logging.info(f'Training Quantizer with matrix of shape {vs.shape}')
        index.train(vs)
        del vs
    del vector_sample

    chunks_to_add = []
    added = 0
    for vector_chunk in parse_vectors_from_directory(cached_embeddings_path, as_chunks=True):
        if hnsw:
            vector_chunk = augment_vectors(vector_chunk, max_phi)

        chunks_to_add.append(vector_chunk)

        if sum(c.shape[0] for c in chunks_to_add) > indexing_batch_size:
            # Build the batch before logging it, and hand faiss a numpy array
            to_add = torch.cat(chunks_to_add).numpy()
            chunks_to_add = []
            logging.info(f'Adding Vectors {added} -> {added + to_add.shape[0]} of {N}')
            index.add(to_add)
            added += to_add.shape[0]

    if len(chunks_to_add) > 0:
        to_add = torch.cat(chunks_to_add).numpy()
        logging.info(f'Adding Vectors {added} -> {added + to_add.shape[0]} of {N}')
        index.add(to_add)
        added += to_add.shape[0]

    logger.info(f'Index Built, writing index to {output_path}')
    faiss.write_index(index, output_path)
    logger.info(f'Index dumped')
    return index


if __name__ == '__main__':
    parser = argparse.ArgumentParser("Build a FAISS index from precomputed vector files from embed.py. "
                                     "Provides functionality to build either flat indexes (slow but exact)"
                                     " or HNSW indexes (much faster, but approximate). "
                                     "Optional application of 8bit or 16bit quantization is also available."
                                     " Many more indexes are possible with Faiss, consult the Faiss repository"
                                     " if you want to build more advanced indexes.")
    parser.add_argument('--embeddings_dir', type=str, help='path to directory containing vectors to build index from')
    parser.add_argument('--output_path', type=str, help='path to write results to')
    parser.add_argument('--hnsw', action='store_true', help='Build an HNSW index rather than Flat')
    parser.add_argument('--SQ8', action='store_true', help='use SQ8 quantization on index to save memory')
    parser.add_argument('--fp16', action='store_true', help='use fp16 quantization on index to save memory')
    parser.add_argument('--store_n', type=int, default=32, help='hnsw store_n parameter')
    parser.add_argument('--ef_construction', type=int, default=128, help='hnsw ef_construction parameter')
    parser.add_argument('--ef_search', type=int, default=128, help='hnsw ef_search parameter')
    parser.add_argument('--sample_fraction', type=float, default=1.0,
                        help='If memory is limited, specify a fraction (0.0->1.0) of the '
                             'data to sample for training the quantizer')
    parser.add_argument('--indexing_batch_size', type=int, default=None,
                        help='If memory is limited, specify the approximate number '
                             'of vectors to add to the index at once')
    parser.add_argument('-v', '--verbose', action="store_true")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    assert not (args.SQ8 and args.fp16), 'cant use both sq8 and fp16 Quantization'
    assert not os.path.exists(args.output_path), "Faiss index with name specified in --output_path already exists"

    args.indexing_batch_size = 10000000000000 if args.indexing_batch_size is None else args.indexing_batch_size
    assert 0 < args.sample_fraction <= 1.0

    if args.sample_fraction:
        build_index_streaming(
            args.embeddings_dir,
            args.output_path,
            args.hnsw,
            sq8_quantization=args.SQ8,
            fp16_quantization=args.fp16,
            store_n=args.store_n,
            ef_construction=args.ef_construction,
            ef_search=args.ef_search,
            sample_fraction=args.sample_fraction,
            indexing_batch_size=args.indexing_batch_size,
        )


================================================
FILE: paq/retrievers/embed.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
import logging
import time
import math
import os
from paq.paq_utils import is_apex_available, load_jsonl, get_submitit_executor, to_fp16
from paq.retrievers.retriever_utils import load_retriever


logger = logging.getLogger(__name__)
CUDA = torch.cuda.is_available()


def embed(model, tokenizer, qas, bsz=256, cuda=CUDA, fp16=False):

    def normalize_q(question: str) -> str:
        return question.strip().strip('?').lower().strip()

    def tokenize(batch_qas):
        input_qs = [normalize_q(q['question']) for q in batch_qas]
        inputs = tokenizer.batch_encode_plus(
            input_qs, return_tensors='pt', padding=True, add_special_tokens=True
        )
        return {k: v.cuda() for k, v in inputs.items()} if cuda else inputs

    if cuda:
        model = model.cuda()
        model = to_fp16(model) if fp16 else model

    t = time.time()

    def log_progress(j, outputs):
        t2 = time.time()
        logger.info(
            f'Embedded {j + 1} / {len(list(range(0, len(qas), bsz)))} batches in {t2 - t:0.2f} seconds '
            f'({sum([len(o) for o in outputs]) / (t2 - t): 0.4f} QAs per second)')

    outputs = []
    with torch.no_grad():
        for j, batch_start in enumerate(range(0, len(qas), bsz)):
            batch_qas = qas[batch_start: batch_start + bsz]
            inputs = tokenize(batch_qas)
            batch_outputs = model(**inputs)
            outputs.append(batch_outputs.cpu())
            if j % 10 == 0:
                log_progress(j, outputs)

    log_progress(j, outputs)

    return torch.cat(outputs, dim=0).cpu()


def embed_job(qas_to_embed_path, model_name_or_path, output_file_name, n_jobs, job_num, batch_size, fp16, memory_friendly_parsing):
    os.makedirs(os.path.dirname(output_file_name), exist_ok=True)

    qas_to_embed = load_jsonl(qas_to_embed_path, memory_friendly=memory_friendly_parsing)
    chunk_size = math.ceil(len(qas_to_embed) / n_jobs)

    qas_to_embed_this_job = qas_to_embed[job_num * chunk_size: (job_num + 1) * chunk_size]
    logger.info(f'Embedding Job {job_num}: Embedding {len(qas_to_embed_this_job)} inputs in {math.ceil(len(qas_to_embed_this_job) / batch_size)} batches:')

    model, tokenizer = load_retriever(model_name_or_path)
    mat = embed(model, tokenizer, qas_to_embed_this_job, bsz=batch_size, fp16=fp16)
    torch.save(mat.half(), output_file_name + f'.{job_num}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, required=True, help='path to HF model dir')
    parser.add_argument('--qas_to_embed', type=str,required=True, help='Path to questions to embed in jsonl format')
    parser.add_argument('--n_jobs', type=int, required=True, help='how many jobs to embed with (n_jobs=-1 will run a single job locally)')
    parser.add_argument('--output_dir', type=str, required=True, help='path to write vectors to')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--memory_friendly_parsing', action='store_true', help='Pass this to load jsonl files more slowly, but save memory')
    parser.add_argument('--slurm_partition', type=str, default="learnfair", help='If using submitit to run slurm jobs, define cluster partition here')
    parser.add_argument('--slurm_comment', type=str, default="", help='If using submitit to run slurm jobs, define job comment here')
    parser.add_argument('-v', '--verbose', action="store_true")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    if args.fp16 and not CUDA:
        raise Exception('Cant use --fp16 without a gpu, CUDA not found')

    output_path = os.path.join(args.output_dir, 'embeddings.pt')

    if args.n_jobs == -1:
        embed_job(
            args.qas_to_embed,
            args.model_name_or_path,
            output_path,
            n_jobs=1,
            job_num=0,
            batch_size=args.batch_size,
            fp16=args.fp16,
            memory_friendly_parsing=args.memory_friendly_parsing
        )
    else:
        executor = get_submitit_executor(n_jobs=args.n_jobs, comment=args.slurm_comment, partition=args.slurm_partition)
        jobs = []
        with executor.batch():
            for jn in range(args.n_jobs):
                job = executor.submit(
                    embed_job,
                    args.qas_to_embed,
                    args.model_name_or_path,
                    output_path,
                    args.n_jobs,
                    jn,
                    args.batch_size,
                    args.fp16,
                    args.memory_friendly_parsing
                )
                jobs.append(job)

        logger.info('launched the following jobs:')
        for job in jobs:
            logger.info(job.job_id)


================================================
FILE: paq/retrievers/retrieve.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
import logging
import time
import faiss
import numpy as np
from paq.retrievers.retriever_utils import load_retriever
from paq.paq_utils import load_jsonl, dump_jsonl, parse_vectors_from_directory
from paq.retrievers.embed import embed
from copy import deepcopy

logger = logging.getLogger(__name__)

CUDA = torch.cuda.is_available()


def get_output_format(qas_to_answer, qas_to_retrieve_from, top_indices, top_scores):
    results = []
    for qa_ind, qa in enumerate(qas_to_answer):
        res = []
        for score_ind, ind in enumerate(top_indices[qa_ind]):
            score = top_scores[qa_ind][score_ind]
            ret_qa = deepcopy(qas_to_retrieve_from[ind])
            ret_qa['score'] = float(score)
            res.append(ret_qa)
        results.append(res)

    return [{'input_qa': in_qa, 'retrieved_qas': ret_qas} for in_qa, ret_qas in zip(qas_to_answer, results)]


def _torch_mips(index, query_batch, top_k):
    sims = torch.matmul(query_batch, index.t())
    return sims.topk(top_k)


def _flat_index_mips(index, query_batch, top_k):
    return index.search(query_batch.numpy(), top_k)


def _aux_dim_index_mips(index, query_batch, top_k):
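    # Query a faiss Euclidean-distance (L2) index for MIPS; used with HNSW indexes.
    # The index is expected to hold vectors augmented with one auxiliary dimension
    # (see `augment_vectors` in build_index.py), so each query is padded with a zero
    # in that dimension to match the index dimensionality.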
    aux_dim = query_batch.new(query_batch.shape[0]).fill_(0)
    aux_query_batch = torch.cat([query_batch, aux_dim.unsqueeze(-1)], -1)
    return index.search(aux_query_batch.numpy(), top_k)


def _get_mips_function(index):
    if isinstance(index, torch.Tensor):
        return _torch_mips
    elif 'hnsw' in str(type(index)).lower():
        return _aux_dim_index_mips
    else:
        return _flat_index_mips


def mips(index, queries, top_k, n_queries_to_parallelize=256):
    t = time.time()
    all_top_indices = None
    all_top_scores = None

    _mips = _get_mips_function(index)

    for mb in range(0, len(queries), n_queries_to_parallelize):
        query_batch = queries[mb:mb + n_queries_to_parallelize].float()
        scores, top_indices = _mips(index, query_batch, top_k)

        all_top_indices = top_indices if all_top_indices is None else np.concatenate([all_top_indices, top_indices])
        all_top_scores = scores if all_top_scores is None else np.concatenate([all_top_scores, scores])

        delta = time.time() - t
        logger.info(
            f'{len(all_top_indices)} / {len(queries)} queries searched in {delta:.4f} '
            f'seconds ({len(all_top_indices) / delta} per second)')

    assert len(all_top_indices) == len(queries)

    delta = time.time() - t
    logger.info(f'Index searched in {delta:.4f} seconds ({len(queries) / delta} per second)')
    return all_top_indices, all_top_scores


def run_queries(model, tokenizer, qas_to_retrieve_from, qas_to_answer, top_k, index=None,
                batch_size=128, fp16=False, n_queries_to_parallelize=2048):

    if index is None:
        index = embed(model, tokenizer, qas_to_retrieve_from, bsz=batch_size, fp16=fp16).float()

    logger.info('Embedding QAs to answer:')
    embedded_qas_to_answer = embed(model, tokenizer, qas_to_answer, bsz=batch_size, fp16=fp16)
    logger.info('Running MIPS search:')
    top_indices, top_scores = mips(index, embedded_qas_to_answer, top_k, n_queries_to_parallelize=n_queries_to_parallelize)

    return get_output_format(qas_to_answer, qas_to_retrieve_from, top_indices, top_scores)


def _load_index_if_exists(faiss_index_path, precomputed_embeddings_dir, n_vectors_to_load=None, memory_friendly=False, efsearch=128):
    index = None
    if faiss_index_path is not None:
        assert precomputed_embeddings_dir is None, "Do not specify both a --faiss_index_path and --precomputed_embeddings_dir"
        logger.info('Loading Faiss index:')
        index = faiss.read_index(faiss_index_path)
        if hasattr(index, 'hnsw'):
            index.hnsw.efSearch = efsearch

    elif precomputed_embeddings_dir is not None:
        logger.info('Loading vectors index from file:')
        index = parse_vectors_from_directory(
            precomputed_embeddings_dir,
            memory_friendly=memory_friendly,
            size=n_vectors_to_load
        ).float()

    if index is not None:
        logger.info('Index loaded')
    return index


if __name__ == '__main__':
    # Pass the text as `description`: the first positional argument of ArgumentParser is `prog`.
    parser = argparse.ArgumentParser(
        description="Perform REPAQ QA-pair retrieval. This program embeds a file of questions that need"
        " answering, passed as `--qas_to_answer`. They are answered by retrieving QA-pairs from a"
        " set of QA-pairs passed in as `--qas_to_retrieve_from`."
        " The program can retrieve either from a prebuilt faiss index for `qas_to_retrieve_from`,"
        " or from a directory of precomputed vectors, or, if neither is passed in,"
        " it will embed the `qas_to_retrieve_from` before performing retrieval."
    )
    parser.add_argument('--model_name_or_path', type=str, required=True, help='path to HF model dir')
    parser.add_argument('--qas_to_answer', type=str, required=True, help="path to questions to answer in jsonl format")
    parser.add_argument('--qas_to_retrieve_from', type=str, required=True,
                        help="path to QA-pairs to retrieve answers from in jsonl format")
    parser.add_argument('--top_k', type=int, default=50, help="top K QA-pairs to retrieve for each input question")
    parser.add_argument('--output_file', type=str, required=True, help='Path to write jsonl results to')
    parser.add_argument('--faiss_index_path', default=None, type=str, help="Path to faiss index, if retrieving from a faiss index")
    parser.add_argument('--precomputed_embeddings_dir', default=None, type=str, help="Path to a directory of vector embeddings, if retrieving from raw embedding vectors")
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for embedding questions for querying')
    parser.add_argument('--n_queries_to_parallelize', type=int, default=256, help="query batch size")
    parser.add_argument('-v', '--verbose', action="store_true")
    parser.add_argument('--memory_friendly_parsing', action='store_true', help='Pass this to load files more slowly, but save memory')
    parser.add_argument('--faiss_efsearch', type=int, default=128, help='efSearch search-time parameter for HNSW indexes; higher is more accurate but slower')

    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    qas_to_answer = load_jsonl(args.qas_to_answer, memory_friendly=args.memory_friendly_parsing)
    qas_to_retrieve_from = load_jsonl(args.qas_to_retrieve_from, memory_friendly=args.memory_friendly_parsing)

    index = _load_index_if_exists(
        args.faiss_index_path,
        args.precomputed_embeddings_dir,
        n_vectors_to_load=len(qas_to_retrieve_from),
        memory_friendly=args.memory_friendly_parsing,
        efsearch=args.faiss_efsearch
    )
    model, tokenizer = load_retriever(args.model_name_or_path)

    retrieved_answers = run_queries(
        model,
        tokenizer,
        qas_to_retrieve_from,
        qas_to_answer,
        args.top_k,
        index,
        args.batch_size,
        args.fp16,
        args.n_queries_to_parallelize,
    )
    dump_jsonl(retrieved_answers, args.output_file)
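
Note on `_aux_dim_index_mips`: the function pads every query with a zero so that it can be searched against an L2 (Euclidean) index. The sketch below illustrates why that can return the same neighbours as maximum inner product search. It assumes the indexed vectors were augmented with `sqrt(max_sq_norm - ||x||^2)` in the extra dimension, which is the standard reduction; `build_index.py` is not shown here, so this is an assumption about `augment_vectors`, and all function names below are hypothetical. It uses plain Python lists rather than torch or faiss.

```python
import math
import random


def augment_index_vectors(vectors):
    """Append sqrt(max_sq_norm - ||x||^2) to every index vector (assumed augmentation)."""
    sq_norms = [sum(v * v for v in vec) for vec in vectors]
    max_sq_norm = max(sq_norms)
    return [vec + [math.sqrt(max_sq_norm - sq)] for vec, sq in zip(vectors, sq_norms)]


def augment_query(query):
    """Pad the query with a zero, mirroring the padding done in _aux_dim_index_mips."""
    return query + [0.0]


def top_k_by_inner_product(vectors, query, top_k):
    """Brute-force MIPS: indices of the largest inner products, best first."""
    scores = [sum(q * v for q, v in zip(query, vec)) for vec in vectors]
    return sorted(range(len(vectors)), key=lambda i: -scores[i])[:top_k]


def top_k_by_l2(vectors, query, top_k):
    """Brute-force nearest neighbours by squared Euclidean distance, closest first."""
    dists = [sum((q - v) ** 2 for q, v in zip(query, vec)) for vec in vectors]
    return sorted(range(len(vectors)), key=lambda i: dists[i])[:top_k]


random.seed(0)
index_vectors = [[random.gauss(0, 1) for _ in range(8)] for _ in range(100)]
query = [random.gauss(0, 1) for _ in range(8)]

mips_ids = top_k_by_inner_product(index_vectors, query, 5)
l2_ids = top_k_by_l2(augment_index_vectors(index_vectors), augment_query(query), 5)
print(mips_ids == l2_ids)
```

For an augmented vector the squared distance to the padded query is `||q||^2 + max_sq_norm - 2 * (q . x)`, so ordering by increasing distance matches ordering by decreasing inner product. One consequence is that the scores returned on this path are L2 distances rather than inner products.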


================================================
FILE: paq/retrievers/retriever_utils.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
import os
import logging
from transformers import AutoConfig, AutoTokenizer, AutoModel


logger = logging.getLogger(__name__)


def _get_proj_keys_from_state_dict(state_dict):
    weight_key = [k for k in state_dict.keys() if 'encode_proj' in k and 'weight' in k]
    bias_key = [k for k in state_dict.keys() if 'encode_proj' in k and 'bias' in k]
    assert len(weight_key) == 1 == len(bias_key)
    weight_key, bias_key = weight_key[0], bias_key[0]
    return weight_key, bias_key


def _get_proj_dim_from_model_path(model_name_or_path):
    state = torch.load(os.path.join(model_name_or_path, 'pytorch_model.bin'), map_location=torch.device('cpu'))
    proj_dim = None
    if any('encode_proj' in k for k in state.keys()):
        _, bias_key = _get_proj_keys_from_state_dict(state)
        proj_dim = state[bias_key].shape[0]
    return proj_dim


def load_retriever(model_name_or_path):
    logger.info(f'Loading model from: {model_name_or_path}')
    model = RetrieverEncoder.from_pretrained(model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, do_lower_case=True)
    model.eval()
    return model, tokenizer


class RetrieverEncoder(nn.Module):
    """A wrapper for HF models, with an optional projection"""

    def __init__(self, config, proj_dim):
        super().__init__()
        # EncoderBase.__init__(self, config.hidden_size, project_dim)
        self.model = AutoModel.from_config(config)
        self.encode_proj = nn.Linear(config.hidden_size, proj_dim) if proj_dim is not None else None
        self.model.init_weights()

    @classmethod
    def from_pretrained(cls, model_name_or_path):
        config = AutoConfig.from_pretrained(model_name_or_path)
        proj_dim = _get_proj_dim_from_model_path(model_name_or_path)

        retriever = cls(config, proj_dim)
        state = torch.load(os.path.join(model_name_or_path, 'pytorch_model.bin'), map_location=torch.device('cpu'))
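        # Strip the 'albert.' prefix so the keys match the bare AutoModel; projection weights are loaded separately below.
        retriever.model.load_state_dict({k.replace('albert.', ''): v for k, v in state.items() if 'encode_proj' not in k}, strict=True)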

        if proj_dim is not None:
            weight_key, bias_key = _get_proj_keys_from_state_dict(state)
            retriever.encode_proj.load_state_dict({'weight': state[weight_key], 'bias': state[bias_key]}, strict=True)

        return retriever

    def forward(self, *args, **kwargs):
        seq_outputs = self.model(*args, **kwargs)['last_hidden_state']
        return self.encode_proj(seq_outputs[:, 0]) if self.encode_proj is not None else seq_outputs[:, 0]
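
Note on checkpoint loading: `RetrieverEncoder.from_pretrained` infers whether a projection layer exists from the key names in `pytorch_model.bin`, then loads the encoder and projection weights separately. A minimal sketch of that key handling on a toy dictionary follows. It uses nested lists in place of tensors, and the helper names and toy key names are hypothetical, chosen only for illustration.

```python
def find_projection_keys(state_dict):
    """Return the (weight_key, bias_key) of the single projection layer."""
    weight_keys = [k for k in state_dict if 'encode_proj' in k and 'weight' in k]
    bias_keys = [k for k in state_dict if 'encode_proj' in k and 'bias' in k]
    assert len(weight_keys) == 1 == len(bias_keys)
    return weight_keys[0], bias_keys[0]


def projection_dim(state_dict):
    """Output size of the projection, or None when the checkpoint has none."""
    if not any('encode_proj' in k for k in state_dict):
        return None
    _, bias_key = find_projection_keys(state_dict)
    return len(state_dict[bias_key])


def split_checkpoint(state_dict):
    """Separate encoder weights (with the 'albert.' prefix stripped) from projection weights."""
    encoder = {k.replace('albert.', ''): v for k, v in state_dict.items() if 'encode_proj' not in k}
    projection = {k: v for k, v in state_dict.items() if 'encode_proj' in k}
    return encoder, projection


# Toy checkpoint: hidden size 4, projected down to 3 dimensions (made-up keys and values).
toy_state = {
    'albert.pooler.weight': [[0.0] * 4 for _ in range(4)],
    'encode_proj.weight': [[0.0] * 4 for _ in range(3)],
    'encode_proj.bias': [0.0, 0.0, 0.0],
}
encoder_state, projection_state = split_checkpoint(toy_state)
print(projection_dim(toy_state))
print(sorted(encoder_state))
```

A checkpoint without any `encode_proj` keys yields `None` for the projection dimension, in which case `forward` returns the first-token hidden state unprojected.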


================================================
FILE: paq/server/__init__.py
================================================
#!/usr/bin/env python3
# C
SYMBOL INDEX (141 symbols across 21 files)

FILE: paq/download.py
  function untar (line 340) | def untar(tar_filename: str) -> List[str]:
  function unpack (line 349) | def unpack(gzip_file: str, out_file: str):
  function _get_root_dir (line 360) | def _get_root_dir(out_dir):
  function download_resource (line 371) | def download_resource(
  function download_file (line 414) | def download_file(s3_url: str, out_dir: str, file_name: str):
  function download (line 426) | def download(resource_key: str, out_dir: str = None):
  function main (line 484) | def main():

FILE: paq/evaluation/eval_reranker.py
  function evaluate_exact_match (line 12) | def evaluate_exact_match(preds, refs):

FILE: paq/evaluation/eval_retriever.py
  function eval_retriever (line 12) | def eval_retriever(refs, preds, hits_at_k):

FILE: paq/evaluation/eval_utils.py
  function normalize_answer (line 12) | def normalize_answer(s):
  function exact_match_score (line 31) | def exact_match_score(prediction, ground_truth):
  function metric_max_over_ground_truths (line 35) | def metric_max_over_ground_truths(metric_fn, predictions: Union[str, Lis...

FILE: paq/generation/answer_extractor/extract_answers.py
  function load_passages (line 16) | def load_passages(path):
  function extract_answers (line 23) | def extract_answers(config, input_file, verbose):
  function extract_answers_and_write_to_file (line 31) | def extract_answers_and_write_to_file(config, input_path, output_path, v...

FILE: paq/generation/answer_extractor/extractors.py
  function get_output_format (line 17) | def get_output_format(all_passages, all_answers):
  class SpacyNERExtractor (line 31) | class SpacyNERExtractor:
    method __init__ (line 37) | def __init__(self, model="en_core_web_sm"):
    method extract_from_passage (line 42) | def extract_from_passage(self, passage: str) -> List[Dict]:
    method extract_answers_from_passages (line 54) | def extract_answers_from_passages(self, passages_to_label, disable_tqd...
  class Span2DAnswerExtractor (line 72) | class Span2DAnswerExtractor:
    method __init__ (line 78) | def __init__(
    method _tokenize (line 112) | def _tokenize(self, passage: str):
    method extract_from_passage (line 127) | def extract_from_passage(self, passage: str):
    method extract_answers_from_passages (line 135) | def extract_answers_from_passages(self, passages_to_label, disable_tqd...
  function load_answer_extractor (line 148) | def load_answer_extractor(config):

FILE: paq/generation/answer_extractor/span2D_model.py
  class AnswerSpanExtractor2DModelOutput (line 21) | class AnswerSpanExtractor2DModelOutput(ModelOutput):
  class AnswerSpanExtractor2DModel (line 50) | class AnswerSpanExtractor2DModel(BertPreTrainedModel):
    method __init__ (line 53) | def __init__(self, config):
    method forward (line 78) | def forward(
  function sigmoid (line 171) | def sigmoid(x):
  function postprocess_span2d_output (line 175) | def postprocess_span2d_output(span2D_output: AnswerSpanExtractor2DModelO...

FILE: paq/generation/filtering/filter_questions.py
  function retrieve_documents_for_generated_questions (line 16) | def retrieve_documents_for_generated_questions(config, input_file, verbo...
  function generate_answers_for_generated_questions_with_retrieved_docs (line 24) | def generate_answers_for_generated_questions_with_retrieved_docs(config,...
  function filter_generated_questions_and_write_to_file (line 33) | def filter_generated_questions_and_write_to_file(config, input_path, out...

FILE: paq/generation/filtering/filterer.py
  function _load_corpus (line 32) | def _load_corpus(path):
  class DummyFilteringRetriever (line 41) | class DummyFilteringRetriever:
    method retrieve_documents (line 45) | def retrieve_documents(self, data):
  class LocalFilteringRetriever (line 49) | class LocalFilteringRetriever:
    method __init__ (line 54) | def __init__(self, corpus_path):
    method retrieve_documents (line 57) | def retrieve_documents(self, data):
  class DPRQuestionEncoder (line 70) | class DPRQuestionEncoder(nn.Module):
    method __init__ (line 73) | def __init__(self, model):
    method forward (line 77) | def forward(self, *args, **kwargs):
  class GlobalFilteringRetriever (line 82) | class GlobalFilteringRetriever:
    method __init__ (line 89) | def __init__(self,
    method _load_corpus (line 116) | def _load_corpus(self):
    method retrieve_documents (line 125) | def retrieve_documents(self, qa_pairs):
  class CompatableEncoderWrapper (line 138) | class CompatableEncoderWrapper(torch.nn.Module):
    method __init__ (line 141) | def __init__(self, encoder, use_checkpoint=False):
    method forward (line 145) | def forward(self, input_ids=None, attention_mask=None, **kwargs, ):
  class FIDReader (line 156) | class FIDReader:
    method __init__ (line 160) | def __init__(self,
    method _get_dataloader_for_examples (line 179) | def _get_dataloader_for_examples(self, examples):
    method generate_answers (line 197) | def generate_answers(self, examples):
  class DummyReader (line 227) | class DummyReader:
    method generate_answers (line 231) | def generate_answers(self, examples):
  function _get_reader_output_format (line 239) | def _get_reader_output_format(dataset):
  function load_reader (line 250) | def load_reader(config):
  function load_retriever (line 256) | def load_retriever(config):

FILE: paq/generation/generate_qa_pairs.py
  function touch (line 28) | def touch(path):
  function _run_pipeline_step (line 34) | def _run_pipeline_step(config, input_file, output_file, done_indicator, ...
  function run_passage_scoring (line 41) | def run_passage_scoring(config, input_file, output_dir, verbose=False):
  function run_answer_extraction (line 48) | def run_answer_extraction(config, input_file, output_dir, verbose=False):
  function run_question_generation (line 55) | def run_question_generation(config, input_file, output_dir, verbose=False):
  function run_filtering (line 62) | def run_filtering(config, input_file, output_dir, verbose=False):
  function combine_generated_files (line 69) | def combine_generated_files(document_ranker_file,
  function run_paq_generation_pipeline (line 107) | def run_paq_generation_pipeline(config: dict, input_file: str, output_di...
  function _is_job_finished (line 133) | def _is_job_finished(job_number, output_dir):

FILE: paq/generation/passage_scorer/score_passages.py
  function load_passages (line 16) | def load_passages(path):
  function score_passages (line 23) | def score_passages(config, input_file, verbose):
  function score_passages_and_write_to_file (line 31) | def score_passages_and_write_to_file(config, input_path, output_path, ve...

FILE: paq/generation/passage_scorer/scorer.py
  class DummyPassageScorer (line 13) | class DummyPassageScorer:
    method __init__ (line 19) | def __init__(self, default_score=0.0):
    method score_passage (line 22) | def score_passage(self, passage: Dict) -> float:
    method score_passages (line 25) | def score_passages(self, passages_to_label, disable_tqdm=False):
  class LookupPassageScorer (line 32) | class LookupPassageScorer:
    method __init__ (line 38) | def __init__(self, scores_file, default_score=-10000.0):
    method _load_passage_scores (line 42) | def _load_passage_scores(self, scores_file):
    method score_passage (line 48) | def score_passage(self, passage: Dict) -> float:
    method score_passages (line 51) | def score_passages(self, passages_to_label, disable_tqdm=False):
  class LearntPassageScorer (line 58) | class LearntPassageScorer:
    method __init__ (line 62) | def __init__(self,
    method _tokenize (line 80) | def _tokenize(self, texts):
    method score_passages (line 87) | def score_passages(self, passages_to_label, disable_tqdm=False):
  function load_passage_scorer (line 113) | def load_passage_scorer(config):

FILE: paq/generation/question_generator/generate_questions.py
  function generate_questions (line 16) | def generate_questions(config, input_file, verbose):
  function generate_questions_and_write_to_file (line 24) | def generate_questions_and_write_to_file(config, input_path, output_path...

FILE: paq/generation/question_generator/generator.py
  function _batch_iterator (line 24) | def _batch_iterator(context_answer_pairs,
  class QuestionGenerator (line 62) | class QuestionGenerator:
    method __init__ (line 65) | def __init__(
    method generate_question (line 104) | def generate_question(self, data: Union[str, List[str]]):
    method generate_questions_from_passage_answer_pairs (line 139) | def generate_questions_from_passage_answer_pairs(self, passage_answer_...
  function load_question_generator (line 171) | def load_question_generator(config):

FILE: paq/paq_utils.py
  function is_spacy_available (line 38) | def is_spacy_available():
  function is_submitit_available (line 42) | def is_submitit_available():
  function is_apex_available (line 46) | def is_apex_available():
  function to_fp16 (line 50) | def to_fp16(model):
  function load_jsonl_memory_friendly (line 58) | def load_jsonl_memory_friendly(fi):
  function load_jsonl_fast (line 70) | def load_jsonl_fast(fi):
  function load_jsonl (line 88) | def load_jsonl(fi, memory_friendly=False):
  function dump_jsonl (line 95) | def dump_jsonl(items, fi):
  function load_dpr_tsv (line 107) | def load_dpr_tsv(fi):
  function get_vectors_file_paths_in_vector_directory (line 117) | def get_vectors_file_paths_in_vector_directory(embeddings_dir):
  function parse_vectors_from_directory_chunks (line 124) | def parse_vectors_from_directory_chunks(embeddings_dir, half):
  function parse_vectors_from_directory_fast (line 138) | def parse_vectors_from_directory_fast(embeddings_dir):
  function parse_vectors_from_directory_memory_friendly (line 148) | def parse_vectors_from_directory_memory_friendly(embeddings_dir, size=No...
  function parse_vectors_from_directory (line 174) | def parse_vectors_from_directory(fi, memory_friendly=False, size=None, a...
  function get_submitit_executor (line 192) | def get_submitit_executor(n_jobs=10, comment="", partition='learnfair'):

FILE: paq/rerankers/rerank.py
  function load_reranker (line 23) | def load_reranker(model_name_or_path):
  function get_output_format (line 36) | def get_output_format(qas, prediction_indices, prediction_scores):
  function tokenize (line 48) | def tokenize(tokenizer, batch_qas, cuda, top_k):
  function predict (line 65) | def predict(model, tokenizer, qas, cuda=CUDA, bsz=16, fp16=False, top_k=...
  function run_predictions (line 111) | def run_predictions(qas_to_rerank_file, output_file, model_name_or_path,...
  function parse_files (line 126) | def parse_files(args):

FILE: paq/retrievers/build_index.py
  function get_vector_sample (line 19) | def get_vector_sample(cached_embeddings_path, sample_fraction):
  function get_vectors_dim (line 39) | def get_vectors_dim(cached_embeddings_path):
  function augment_vectors (line 46) | def augment_vectors(vectors, max_phi):
  function build_index_streaming (line 53) | def build_index_streaming(cached_embeddings_path,

FILE: paq/retrievers/embed.py
  function embed (line 21) | def embed(model, tokenizer, qas, bsz=256, cuda=CUDA, fp16=False):
  function embed_job (line 60) | def embed_job(qas_to_embed_path, model_name_or_path, output_file_name, n...

FILE: paq/retrievers/retrieve.py
  function get_output_format (line 23) | def get_output_format(qas_to_answer, qas_to_retrieve_from, top_indices, ...
  function _torch_mips (line 37) | def _torch_mips(index, query_batch, top_k):
  function _flat_index_mips (line 42) | def _flat_index_mips(index, query_batch, top_k):
  function _aux_dim_index_mips (line 46) | def _aux_dim_index_mips(index, query_batch, top_k):
  function _get_mips_function (line 53) | def _get_mips_function(index):
  function mips (line 62) | def mips(index, queries, top_k, n_queries_to_parallelize=256):
  function run_queries (line 88) | def run_queries(model, tokenizer, qas_to_retrieve_from, qas_to_answer, t...
  function _load_index_if_exists (line 102) | def _load_index_if_exists(faiss_index_path, precomputed_embeddings_dir, ...

FILE: paq/retrievers/retriever_utils.py
  function _get_proj_keys_from_state_dict (line 17) | def _get_proj_keys_from_state_dict(state_dict):
  function _get_proj_dim_from_model_path (line 25) | def _get_proj_dim_from_model_path(model_name_or_path):
  function load_retriever (line 34) | def load_retriever(model_name_or_path):
  class RetrieverEncoder (line 42) | class RetrieverEncoder(nn.Module):
    method __init__ (line 45) | def __init__(self, config, proj_dim):
    method from_pretrained (line 53) | def from_pretrained(cls, model_name_or_path):
    method forward (line 67) | def forward(self, *args, **kwargs):

FILE: paq/server/server.py
  class http_server (line 16) | class http_server:
    method __init__ (line 17) | def __init__(self, index, model, tokenizer, qas_to_retrieve_from, fp16):
  class WebServerHandler (line 28) | class WebServerHandler(BaseHTTPRequestHandler):
    method do_POST (line 31) | def do_POST(self):
  function main (line 62) | def main(args):
Condensed preview — 52 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (209K chars).
[
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 244,
    "preview": "# Code of Conduct\n\nFacebook has adopted a Code of Conduct that we expect project participants to adhere to.\nPlease read "
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 572,
    "preview": "# Contributing to this repo\n\n## Pull Requests\n\nIn order to accept your pull request, we need you to submit a CLA. You on"
  },
  {
    "path": "LICENSE",
    "chars": 19329,
    "preview": "Attribution-NonCommercial 4.0 International\n\n=======================================================================\n\nCr"
  },
  {
    "path": "README.md",
    "chars": 40489,
    "preview": "# PAQ: 65 Million Probably-Asked Questions and What You Can Do With Them\n\n\nThis repository contains code and models to s"
  },
  {
    "path": "full_models_list.md",
    "chars": 4320,
    "preview": "# Full List of Models Available for Download\n\n## BiEncoder Retrievers\n\n\n| Model  | Training data |  Architecture | Embed"
  },
  {
    "path": "generator_configs/answer_extractor_configs/learnt_answer_extractor_config.json",
    "chars": 470,
    "preview": "{\n  \"answer_extractor\": {\n    \"name\": \"answer_extractor/span2D\",\n    \"config\": {\n      \"model_path\": \"data/models/answer"
  },
  {
    "path": "generator_configs/answer_extractor_configs/named_entity_answer_extractor_config.json",
    "chars": 129,
    "preview": "{\n    \"answer_extractor\": {\n    \"name\": \"answer_extractor/spacy_ner\",\n    \"config\": {\n      \"model\": \"en_core_web_sm\"\n  "
  },
  {
    "path": "generator_configs/filterer_configs/dummy_filtering_config.json",
    "chars": 216,
    "preview": "{\n  \"filterer\": {\n    \"retriever\": {\n      \"name\": \"filtering/dummy_filtering_retriever\",\n      \"config\": {\n      }\n    "
  },
  {
    "path": "generator_configs/filterer_configs/global_filtering_config.json",
    "chars": 854,
    "preview": "{\n  \"filterer\": {\n    \"retriever\": {\n      \"name\": \"filtering/global_filtering_retriever\",\n      \"config\": {\n        \"co"
  },
  {
    "path": "generator_configs/filterer_configs/local_filtering_config.json",
    "chars": 425,
    "preview": "{\n    \"filterer\": {\n    \"retriever\": {\n      \"name\": \"filtering/local_filtering_retriever\",\n      \"config\": {\n        \"c"
  },
  {
    "path": "generator_configs/paq_L1_config.json",
    "chars": 2086,
    "preview": "{\n  \"passage_scorer\": {\n    \"name\": \"passage_scorer/learnt\",\n    \"config\": {\n      \"model_path\":\"data/models/passage_ran"
  },
  {
    "path": "generator_configs/paq_L1_with_local_filtering_config.json",
    "chars": 1657,
    "preview": "{\n  \"passage_scorer\": {\n    \"name\": \"passage_scorer/learnt\",\n    \"config\": {\n      \"model_path\":\"data/models/passage_ran"
  },
  {
    "path": "generator_configs/paq_L4_config.json",
    "chars": 2086,
    "preview": "{\n  \"passage_scorer\": {\n    \"name\": \"passage_scorer/learnt\",\n    \"config\": {\n      \"model_path\":\"data/models/passage_ran"
  },
  {
    "path": "generator_configs/paq_NE_config.json",
    "chars": 1737,
    "preview": "{\n  \"passage_scorer\": {\n    \"name\": \"passage_scorer/learnt\",\n    \"config\": {\n      \"model_path\":\"data/models/passage_ran"
  },
  {
    "path": "generator_configs/passage_ranker_configs/dummy_passage_scorer_config.json",
    "chars": 115,
    "preview": "{\n  \"passage_scorer\": {\n    \"name\": \"passage_scorer/dummy\",\n    \"config\":{\n      \"default_score\": -1000\n    }\n  }\n}"
  },
  {
    "path": "generator_configs/passage_ranker_configs/learnt_passage_scorer_config.json",
    "chars": 371,
    "preview": "{\n  \"passage_scorer\": {\n    \"name\": \"passage_scorer/learnt\",\n    \"config\": {\n      \"model_path\":\"data/models/passage_ran"
  },
  {
    "path": "generator_configs/passage_ranker_configs/lookup_passage_scorer_config.json",
    "chars": 183,
    "preview": "{\n  \"passage_scorer\": {\n    \"name\": \"passage_scorer/lookup\",\n    \"config\":{\n      \"default_score\": -1000,\n      \"scores_"
  },
  {
    "path": "generator_configs/question_generator_configs/question_generation_config.json",
    "chars": 397,
    "preview": "{\n  \"question_generator\": {\n    \"name\": \"question_generator/standard\",\n    \"config\": {\n      \"model_path\": \"data/models/"
  },
  {
    "path": "paq/__init__.py",
    "chars": 218,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/download.py",
    "chars": 21828,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/evaluation/__init__.py",
    "chars": 218,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/evaluation/eval_reranker.py",
    "chars": 1527,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/evaluation/eval_retriever.py",
    "chars": 1837,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/evaluation/eval_utils.py",
    "chars": 1356,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/__init__.py",
    "chars": 218,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/answer_extractor/__init__.py",
    "chars": 218,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/answer_extractor/extract_answers.py",
    "chars": 1986,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/answer_extractor/extractors.py",
    "chars": 5501,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/answer_extractor/span2D_model.py",
    "chars": 11519,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/filtering/__init__.py",
    "chars": 218,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/filtering/filter_questions.py",
    "chars": 2629,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/filtering/filterer.py",
    "chars": 9771,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/generate_qa_pairs.py",
    "chars": 7866,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/passage_scorer/__init__.py",
    "chars": 218,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/passage_scorer/score_passages.py",
    "chars": 1937,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/passage_scorer/scorer.py",
    "chars": 4230,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/question_generator/__init__.py",
    "chars": 218,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/question_generator/generate_questions.py",
    "chars": 2056,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/generation/question_generator/generator.py",
    "chars": 6303,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/paq_utils.py",
    "chars": 6039,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/rerankers/__init__.py",
    "chars": 218,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/rerankers/rerank.py",
    "chars": 6459,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/retrievers/__init__.py",
    "chars": 218,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/retrievers/build_index.py",
    "chars": 7046,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/retrievers/embed.py",
    "chars": 5065,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/retrievers/retrieve.py",
    "chars": 7568,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/retrievers/retriever_utils.py",
    "chars": 2791,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/server/__init__.py",
    "chars": 218,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/server/client.py",
    "chars": 399,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "paq/server/launch_server.sh",
    "chars": 744,
    "preview": "#!/usr/bin/env bash\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is li"
  },
  {
    "path": "paq/server/server.py",
    "chars": 4364,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is"
  },
  {
    "path": "requirements.txt",
    "chars": 67,
    "preview": "wget>=3.2\ntransformers==4.1.0\nsentencepiece\nprotobuf\nsubmitit\nspacy"
  }
]

