Repository: facebookresearch/multihop_dense_retrieval
Branch: main
Commit: 62eb2427e36a
Files: 53
Total size: 446.1 KB
Directory structure:
gitextract_529i5290/
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── mdr/
│ ├── __init__.py
│ ├── qa/
│ │ ├── __init__.py
│ │ ├── basic_tokenizer.py
│ │ ├── config.py
│ │ ├── data_utils.py
│ │ ├── hotpot_evaluate_v1.py
│ │ ├── qa_dataset.py
│ │ ├── qa_model.py
│ │ ├── qa_trainer.py
│ │ ├── train.md
│ │ ├── train_ranker.py
│ │ └── utils.py
│ └── retrieval/
│ ├── __init__.py
│ ├── config.py
│ ├── criterions.py
│ ├── decomposed_analysis.py
│ ├── fever.ipynb
│ ├── hotpot.ipynb
│ ├── interactive_retrieval.py
│ ├── mhop_trainer.py
│ ├── single_trainer.py
│ ├── train_single.py
│ └── utils/
│ ├── basic_tokenizer.py
│ ├── gen_index_id_map.py
│ ├── mhop_utils.py
│ ├── tokenizer.py
│ └── utils.py
├── requirements.txt
├── scripts/
│ ├── add_sp_label.sh
│ ├── demo.py
│ ├── download_hotpot.sh
│ ├── encode_corpus.py
│ ├── end2end.py
│ ├── end2end.sh
│ ├── eval/
│ │ ├── eval_mhop_fever.py
│ │ ├── eval_mhop_retrieval.py
│ │ ├── eval_reranked.py
│ │ ├── eval_retrieval.py
│ │ └── eval_single_fever.py
│ ├── train_mhop.py
│ ├── train_momentum.py
│ └── train_qa.py
├── setup.py
├── setup.sh
└── submitit/
├── submit_retrieval.sh
├── submitit_qa.sh
├── submitit_train.py
└── submitit_train_qa.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
data/
mdr.egg*/
apex/
models/
logs/
.DS_Store
*.pyc
*.swp
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to multihop_dense_retrieval
We want to make contributing to this project as easy and transparent as
possible.
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## License
By contributing to multihop_dense_retrieval, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
================================================
FILE: LICENSE
================================================
Attribution-NonCommercial 4.0 International
=======================================================================
Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.
Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.
Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors
Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More_considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
Creative Commons Attribution-NonCommercial 4.0 International Public
License
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
d. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
e. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
f. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
g. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
h. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
i. NonCommercial means not primarily intended for or directed towards
commercial advantage or monetary compensation. For purposes of
this Public License, the exchange of the Licensed Material for
other material subject to Copyright and Similar Rights by digital
file-sharing or similar means is NonCommercial provided there is
no payment of monetary compensation in connection with the
exchange.
j. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
k. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
l. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part, for NonCommercial purposes only; and
b. produce, reproduce, and Share Adapted Material for
NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties, including when
the Licensed Material is used other than for NonCommercial
purposes.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
4. If You Share Adapted Material You produce, the Adapter's
License You apply must not prevent recipients of the Adapted
Material from complying with this Public License.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database for NonCommercial purposes
only;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material; and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.
================================================
FILE: README.md
================================================
# <p align=center>Multi-Hop Dense Text Retrieval (`MDR`)</p>
**\*\*\*\*\* Update 3/4/2021: Adding simple demo code based on [streamlit](https://streamlit.io/) \*\*\*\*\***
`MDR` is a simple and generalized dense retrieval method which recursively retrieves supporting text passages for answering complex open-domain questions. The repo provides code and pretrained retrieval models that produce **state-of-the-art** retrieval performance on two multi-hop QA datasets (the [HotpotQA](https://hotpotqa.github.io) dataset and the multi-hop subset of the [FEVER fact extraction and verification dataset](https://fever.ai)).
More details about our approach are described in our ICLR paper [Answering Complex Open-Domain Questions with Multi-Hop Dense Retrieval](https://arxiv.org/abs/2009.12756).
<p align="center"><img width="85%" src="imgs/overview.png" /></p>
- [Use the trained models](#use-the-trained-models)
- [Evaluating retrieval](#evaluating-retrieval)
- [Evaluating QA](#evaluating-qa)
- [Demo](#end-to-end-demo)
- [Train models from scratch](#train-models-from-scratch)
- [Retriever training](#retriever-training)
- [Encoding the corpus for retrieval](#encoding-the-corpus-for-retrieval)
- [ELECTRA QA model training](#electra-qa-model-training)
## Use the trained models
1. Set up the environment
```bash
conda create --name MDR python=3.6
conda activate MDR
git clone git@github.com:facebookresearch/multihop_dense_retrieval.git
cd multihop_dense_retrieval
bash setup.sh
```
2. Download the necessary data files and pretrained retrieval models
Simplified data files with **questions** and ground-truth **supporting passages**:
```
# save pretrained models to models/ and all processed hotpotQA into data/
# models will take about 2GB, and data will take 20GB since the pre-trained Wikipedia index is included.
bash ./scripts/download_hotpot.sh
```
### Evaluating retrieval
Evaluating direct retrieval performance and saving the top-k retrieved passage chains for downstream QA. The printed statistics might not match the metric names defined in the paper:
* **PR**: whether at least one of the supporting passages is included in the retrieved passages;
* **P-EM**: whether **both** supporting passages are included in the retrieved passages;
* **Path Recall**: whether any of the top-k retrieved chains exactly matches the ground-truth supporting passages.
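These three metrics can be sketched as follows (a toy illustration, not the repo's evaluation code; representing chains as lists of passage titles and the ground truth as a title pair is an assumption for this sketch):

```python
# Toy sketch of the three retrieval metrics above (not the repo's eval code).
# Each retrieved chain is assumed to be a list of passage titles; the
# ground truth is the pair of supporting passage titles.
def retrieval_metrics(retrieved_chains, gold_pair):
    retrieved = {title for chain in retrieved_chains for title in chain}
    gold = set(gold_pair)
    return {
        "PR": int(bool(gold & retrieved)),          # any SP passage retrieved
        "P-EM": int(gold <= retrieved),             # both SP passages retrieved
        "Path Recall": int(any(set(c) == gold for c in retrieved_chains)),
    }
```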
Here's an example evaluating the top1 ranked passage chains:
```
python scripts/eval/eval_mhop_retrieval.py \
data/hotpot/hotpot_qas_val.json \
data/hotpot_index/wiki_index.npy \
data/hotpot_index/wiki_id2doc.json \
models/q_encoder.pt \
--batch-size 100 \
--beam-size 1 \
--topk 1 \
--shared-encoder \
--model-name roberta-base \
--gpu \
--save-path ${SAVE_RETRIEVAL_FOR_QA}
```
Several important options include:
* `--beam-size`: beam size at each hop;
* `--topk`: number of top-k passage chains kept from beam search;
* `--gpu`: move the dense index to GPU, resulting in much faster search.
Expected results (Top1):
```
Evaluating 7405 samples...
Avg PR: 0.8428089128966915
Avg P-EM: 0.6592842673869007
Avg 1-Recall: 0.7906819716407832
Path Recall: 0.6592842673869007
comparison Questions num: 1487
Avg PR: 0.9932750504371217
Avg P-EM: 0.9482178883658372
Avg 1-Recall: 0.9643577673167452
Path Recall: 0.9482178883658372
bridge Questions num: 5918
Avg PR: 0.805001689760054
Avg P-EM: 0.5866846907739101
Avg 1-Recall: 0.7470429199053734
Path Recall: 0.5866846907739101
```
**Note:** For more efficient retrieval on CPU, check out the `--hnsw` option in `scripts/eval/eval_mhop_retrieval.py`.
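The multi-hop beam search performed by `eval_mhop_retrieval.py` can be illustrated with a numpy-only toy (function names and the `encode_q_sp` hook are hypothetical; in the real model the hop-2 query encodes the question concatenated with the hop-1 passage):

```python
import numpy as np

def beam_search_2hop(q, encode_q_sp, index, beam_size=2, topk=1):
    """Toy two-hop beam search over a dense index of passage vectors."""
    s1 = index @ q                          # hop-1 scores for every passage
    hop1 = np.argsort(-s1)[:beam_size]
    chains, scores = [], []
    for p1 in hop1:
        q2 = encode_q_sp(q, index[p1])      # hop-2 query from (question, passage)
        s2 = index @ q2
        for p2 in np.argsort(-s2)[:beam_size]:
            chains.append((int(p1), int(p2)))
            scores.append(float(s1[p1] + s2[p2]))
    order = np.argsort(-np.array(scores))[:topk]
    return [chains[i] for i in order]
```

With a larger `--beam-size`, more hop-1 candidates are expanded at hop 2, trading search time for recall.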
### Evaluating QA
The best answer extraction model is based on the pretrained [ELECTRA](https://arxiv.org/abs/2003.10555), outperforming **BERT-large-whole-word-masking** by ~2 points answer EM/F1. We construct the training data with the pretrained MDR retriever, always including the ground-truth passage chain when MDR fails to retrieve it. Each training question is paired with the ground-truth SP passage chain and also 5 (a hyperparameter) retrieved chains which do not match the ground truth.
As the HotpotQA task requires evaluating the prediction of supporting sentences, we run sentence segmentation on the MDR retrieval results before feeding them into the answer extraction model. Follow the script [scripts/add_sp_label.sh](scripts/add_sp_label.sh) to annotate the retrieved chains for train/val data. Assuming the top-100 retrieved results are saved in `data/hotpot/dev_retrieval_top100_sp.json`:
```
python scripts/train_qa.py \
--do_predict \
--predict_batch_size 200 \
--model_name google/electra-large-discriminator \
--fp16 \
--predict_file data/hotpot/dev_retrieval_top100_sp.json \
--max_seq_len 512 \
--max_q_len 64 \
--init_checkpoint models/qa_electra.pt \
--sp-pred \
--max_ans_len 30 \
--save-prediction hotpot_val_top100.json
```
Expected results:
```
01/21/2021 17:01:49 - INFO - __main__ - evaluated 7405 questions...
01/21/2021 17:01:49 - INFO - __main__ - chain ranking em: 0.8113436866981769
01/21/2021 17:01:50 - INFO - __main__ - .......Using combination factor 0.8......
01/21/2021 17:01:50 - INFO - __main__ - answer em: 0.6233625928426739, count: 7405
01/21/2021 17:01:50 - INFO - __main__ - answer f1: 0.7504594111976622, count: 7405
01/21/2021 17:01:50 - INFO - __main__ - sp em: 0.5654287643484133, count: 7405
01/21/2021 17:01:50 - INFO - __main__ - sp f1: 0.7942837708469039, count: 7405
01/21/2021 17:01:50 - INFO - __main__ - joint em: 0.42052667116812964, count: 7405
01/21/2021 17:01:50 - INFO - __main__ - joint f1: 0.6631669237532106, count: 7405
01/21/2021 17:01:50 - INFO - __main__ - Best joint F1 from combination 0.7504594111976622
01/21/2021 17:01:51 - INFO - __main__ - test performance {'em': 0.6233625928426739, 'f1': 0.7504594111976622, 'joint_em': 0.42052667116812964, 'joint_f1': 0.6631669237532106, 'sp_em': 0.5654287643484133, 'sp_f1': 0.7942837708469039}
```
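The answer EM/F1 numbers follow the standard HotpotQA answer normalization (lowercasing, punctuation and article removal, whitespace collapse); here is a compact sketch, simplified relative to `mdr/qa/hotpot_evaluate_v1.py`:

```python
import re
import string
from collections import Counter

def normalize(s):
    """Lowercase, strip punctuation and articles, collapse whitespace."""
    s = "".join(ch for ch in s.lower() if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def answer_em(pred, gold):
    return int(normalize(pred) == normalize(gold))

def answer_f1(pred, gold):
    p, g = normalize(pred).split(), normalize(gold).split()
    overlap = sum((Counter(p) & Counter(g)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(p), overlap / len(g)
    return 2 * precision * recall / (precision + recall)
```

The joint EM/F1 reported above additionally multiplies in the supporting-sentence scores, as in the official HotpotQA evaluation.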
## End-to-end demo
A simple demo using our pretrained models:
```
streamlit run scripts/demo.py
```
<p align="center"><img width="85%" src="imgs/demo.png" /></p>
## Train models from scratch
Our experiments were mostly run on 8 GPUs; however, we observed similar performance when using a smaller number of GPUs.
### Retriever training
```
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python scripts/train_mhop.py \
--do_train \
--prefix ${RUN_ID} \
--predict_batch_size 3000 \
--model_name roberta-base \
--train_batch_size 150 \
--learning_rate 2e-5 \
--fp16 \
--train_file ${TRAIN_DATA_PATH} \
--predict_file ${DEV_DATA_PATH} \
--seed 16 \
--eval-period -1 \
--max_c_len 300 \
--max_q_len 70 \
--max_q_sp_len 350 \
--shared-encoder \
--warmup-ratio 0.1
```
Processed train/validation data for retrieval training:
* `${TRAIN_DATA_PATH}`: data/hotpot/hotpot_train_with_neg_v0.json
* `${DEV_DATA_PATH}`: data/hotpot/hotpot_dev_with_neg_v0.json
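Retriever training uses a contrastive objective in which, for each question, the other passages in the batch serve as negatives. A numpy toy of such an in-batch-negative loss (a sketch of the general technique, not the repo's exact criterion in `mdr/retrieval/criterions.py`):

```python
import numpy as np

def in_batch_nll(q_vecs, p_vecs):
    """Mean NLL of each question's own passage against in-batch negatives.

    q_vecs, p_vecs: (batch, dim) arrays; the positive for q_i is p_i.
    """
    scores = q_vecs @ p_vecs.T                          # scores[i, j] = <q_i, p_j>
    logits = scores - scores.max(axis=1, keepdims=True)  # numeric stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -float(np.mean(np.diag(log_probs)))
```

Larger train batches therefore provide more negatives per question, which is why `--train_batch_size 150` across 8 GPUs matters for retrieval quality.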
### Finetune the question encoder with a frozen memory bank
This step follows the previous training stage and reuses its checkpoint.
```
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python scripts/train_momentum.py \
--do_train \
--prefix ${RUN_ID} \
--predict_batch_size 3000 \
--model_name roberta-base \
--train_batch_size 150 \
--learning_rate 1e-5 \
--fp16 \
--train_file ${TRAIN_DATA_PATH} \
--predict_file ${DEV_DATA_PATH} \
--seed 16 \
--eval-period -1 \
--max_c_len 300 \
--max_q_len 70 \
--max_q_sp_len 350 \
--momentum \
--k 76800 \
--m 0.999 \
--temperature 1 \
--init-retriever ${CHECKPOINT_PT}
```
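The `--momentum`, `--m`, and `--k` flags correspond to a MoCo-style memory bank: the passage encoder is kept as an exponential moving average of the question encoder, and a FIFO queue of `k` passage encodings provides extra negatives. A toy sketch of the EMA update (parameter names are illustrative):

```python
import numpy as np
from collections import deque

def momentum_update(query_params, key_params, m=0.999):
    """key <- m * key + (1 - m) * query, applied parameter-wise."""
    return [m * k + (1 - m) * q for q, k in zip(query_params, key_params)]

# FIFO memory bank of passage vectors; --k 76800 sets its capacity.
memory_bank = deque(maxlen=76800)
```

With `m` close to 1 the key encoder drifts slowly, keeping the queued passage encodings consistent with the current encoder.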
### Encoding the corpus for retrieval
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/encode_corpus.py \
--do_predict \
--predict_batch_size 1000 \
--model_name roberta-base \
--predict_file ${CORPUS_PATH} \
--init_checkpoint ${MODEL_CHECKPOINT} \
--embed_save_path ${SAVE_PATH} \
--fp16 \
--max_c_len 300 \
--num_workers 20
```
* `${CORPUS_PATH}`: each line of this file should be a JSON-encoded object ({"title": str, "text": str}). For HotpotQA, check the authors' [guide](https://hotpotqa.github.io/wiki-readme.html) to get the processed Wikipedia corpus (abstract only).
* `${SAVE_PATH}`: path to save the numpy vectors and ID2DOC lookup table.
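For reference, a tiny corpus file in the expected one-JSON-object-per-line layout can be produced and read back like this (titles and the temp path are made up; only the format matters):

```python
import json, tempfile, os

# Each line of ${CORPUS_PATH} is one JSON object with "title" and "text".
docs = [
    {"title": "Barack Obama", "text": "Barack Obama served as the 44th president..."},
    {"title": "Hawaii", "text": "Hawaii is a U.S. state in the Pacific Ocean..."},
]
path = os.path.join(tempfile.mkdtemp(), "corpus.jsonl")
with open(path, "w") as f:
    for d in docs:
        f.write(json.dumps(d) + "\n")

# Reading it back, one document per line:
loaded = [json.loads(line) for line in open(path)]
print(len(loaded), loaded[0]["title"])  # -> 2 Barack Obama
```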
### ELECTRA QA model training
The ELECTRA-based QA model is sensitive to the learning rate schedule and adding a 10% warmup stage is necessary to achieve good answer extraction performance in our experiments:
```
CUDA_VISIBLE_DEVICES=0 python scripts/train_qa.py \
--do_train \
--prefix electra_large_debug_sn \
--predict_batch_size 1024 \
--model_name google/electra-large-discriminator \
--train_batch_size 12 \
--learning_rate 5e-5 \
--train_file ${QA_TRAIN_DATA} \
--predict_file ${QA_DEV_DATA} \
--seed 42 \
--eval-period 250 \
--max_seq_len 512 \
--max_q_len 64 \
--gradient_accumulation_steps 8 \
--neg-num 5 \
--fp16 \
--use-adam \
--warmup-ratio 0.1 \
--sp-weight 0.05 \
--sp-pred
```
Processed train/validation data for QA training (produced by [scripts/add_sp_label.sh](scripts/add_sp_label.sh)):
* `${QA_TRAIN_DATA}`: data/hotpot/train_retrieval_b100_k100_sp.json
* `${QA_DEV_DATA}`: data/hotpot/dev_retrieval_b50_k50_sp.json
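The 10% warmup above is the standard warmup-then-decay learning-rate multiplier (we assume linear decay here, as in the `transformers` linear schedule; the warmup fraction is what `--warmup-ratio` controls):

```python
def lr_multiplier(step, total_steps, warmup_ratio=0.1):
    """Linear warmup to 1.0 over warmup_ratio * total_steps,
    then linear decay back to 0 by the end of training."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

print(lr_multiplier(50, 1000))    # mid-warmup -> 0.5
print(lr_multiplier(1000, 1000))  # end of training -> 0.0
```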
## Cite
```
@article{xiong2020answering,
title={Answering Complex Open-Domain Questions with Multi-Hop Dense Retrieval},
author={Xiong, Wenhan and Li, Xiang Lorraine and Iyer, Srinivasan and Du, Jingfei and Lewis, Patrick and Wang, William Yang and Mehdad, Yashar and Yih, Wen-tau and Riedel, Sebastian and Kiela, Douwe and O{\u{g}}uz, Barlas},
journal={International Conference on Learning Representations},
year={2021}
}
```
## License
CC-BY-NC 4.0
================================================
FILE: mdr/__init__.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from . import qa
from . import retrieval
================================================
FILE: mdr/qa/__init__.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
================================================
FILE: mdr/qa/basic_tokenizer.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Base tokenizer/tokens classes and utilities."""
import copy
class Tokens(object):
"""A class to represent a list of tokenized text."""
TEXT = 0
TEXT_WS = 1
SPAN = 2
POS = 3
LEMMA = 4
NER = 5
def __init__(self, data, annotators, opts=None):
self.data = data
self.annotators = annotators
self.opts = opts or {}
def __len__(self):
"""The number of tokens."""
return len(self.data)
def slice(self, i=None, j=None):
"""Return a view of the list of tokens from [i, j)."""
new_tokens = copy.copy(self)
new_tokens.data = self.data[i: j]
return new_tokens
def untokenize(self):
"""Returns the original text (with whitespace reinserted)."""
return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
def words(self, uncased=False):
"""Returns a list of the text of each token
Args:
uncased: lower cases text
"""
if uncased:
return [t[self.TEXT].lower() for t in self.data]
else:
return [t[self.TEXT] for t in self.data]
def offsets(self):
"""Returns a list of [start, end) character offsets of each token."""
return [t[self.SPAN] for t in self.data]
def pos(self):
"""Returns a list of part-of-speech tags of each token.
Returns None if this annotation was not included.
"""
if 'pos' not in self.annotators:
return None
return [t[self.POS] for t in self.data]
def lemmas(self):
"""Returns a list of the lemmatized text of each token.
Returns None if this annotation was not included.
"""
if 'lemma' not in self.annotators:
return None
return [t[self.LEMMA] for t in self.data]
def entities(self):
"""Returns a list of named-entity-recognition tags of each token.
Returns None if this annotation was not included.
"""
if 'ner' not in self.annotators:
return None
return [t[self.NER] for t in self.data]
def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
"""Returns a list of all ngrams from length 1 to n.
Args:
n: upper limit of ngram length
uncased: lower cases text
filter_fn: user function that takes in an ngram list and returns
True or False to keep or not keep the ngram
as_string: return the ngram as a string vs list
"""
def _skip(gram):
if not filter_fn:
return False
return filter_fn(gram)
words = self.words(uncased)
ngrams = [(s, e + 1)
for s in range(len(words))
for e in range(s, min(s + n, len(words)))
if not _skip(words[s:e + 1])]
# Concatenate into strings
if as_strings:
ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]
return ngrams
def entity_groups(self):
"""Group consecutive entity tokens with the same NER tag."""
entities = self.entities()
if not entities:
return None
non_ent = self.opts.get('non_ent', 'O')
groups = []
idx = 0
while idx < len(entities):
ner_tag = entities[idx]
# Check for entity tag
if ner_tag != non_ent:
# Chomp the sequence
start = idx
while (idx < len(entities) and entities[idx] == ner_tag):
idx += 1
groups.append((self.slice(start, idx).untokenize(), ner_tag))
else:
idx += 1
return groups
class Tokenizer(object):
"""Base tokenizer class.
Tokenizers implement tokenize, which should return a Tokens class.
"""
def tokenize(self, text):
raise NotImplementedError
def shutdown(self):
pass
def __del__(self):
self.shutdown()
import regex
import logging
logger = logging.getLogger(__name__)
class RegexpTokenizer(Tokenizer):
DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*'
TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)'
r'\.(?=\p{Z})')
ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)'
ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++'
HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM)
NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't"
CONTRACTION1 = r"can(?=not\b)"
CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b"
START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})'
START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})'
END_DQUOTE = r'(?<!\p{Z})(\'\'|["\u0094\u201D\u00BB])'
END_SQUOTE = r'(?<!\p{Z})[\'\u0092\u2019\u203A]'
DASH = r'--|[\u0096\u0097\u2013\u2014\u2015]'
ELLIPSES = r'\.\.\.|\u2026'
PUNCT = r'\p{P}'
NON_WS = r'[^\p{Z}\p{C}]'
def __init__(self, **kwargs):
"""
Args:
annotators: None or empty set (only tokenizes).
substitutions: if true, normalizes some token types (e.g. quotes).
"""
self._regexp = regex.compile(
'(?P<digit>%s)|(?P<title>%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|'
'(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|'
'(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|'
'(?P<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' %
(self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN,
self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2,
self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE,
self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT,
self.NON_WS),
flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
)
if len(kwargs.get('annotators', {})) > 0:
logger.warning('%s only tokenizes! Skipping annotators: %s' %
(type(self).__name__, kwargs.get('annotators')))
self.annotators = set()
self.substitutions = kwargs.get('substitutions', True)
def tokenize(self, text):
data = []
matches = [m for m in self._regexp.finditer(text)]
for i in range(len(matches)):
# Get text
token = matches[i].group()
# Make normalizations for special token types
if self.substitutions:
groups = matches[i].groupdict()
if groups['sdquote']:
token = "``"
elif groups['edquote']:
token = "''"
elif groups['ssquote']:
token = "`"
elif groups['esquote']:
token = "'"
elif groups['dash']:
token = '--'
elif groups['ellipses']:
token = '...'
# Get whitespace
span = matches[i].span()
start_ws = span[0]
if i + 1 < len(matches):
end_ws = matches[i + 1].span()[0]
else:
end_ws = span[1]
# Format data
data.append((
token,
text[start_ws: end_ws],
span,
))
return Tokens(data, self.annotators)
class SimpleTokenizer(Tokenizer):
ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
NON_WS = r'[^\p{Z}\p{C}]'
def __init__(self, **kwargs):
"""
Args:
annotators: None or empty set (only tokenizes).
"""
self._regexp = regex.compile(
'(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
)
if len(kwargs.get('annotators', {})) > 0:
logger.warning('%s only tokenizes! Skipping annotators: %s' %
(type(self).__name__, kwargs.get('annotators')))
self.annotators = set()
def tokenize(self, text):
data = []
matches = [m for m in self._regexp.finditer(text)]
for i in range(len(matches)):
# Get text
token = matches[i].group()
# Get whitespace
span = matches[i].span()
start_ws = span[0]
if i + 1 < len(matches):
end_ws = matches[i + 1].span()[0]
else:
end_ws = span[1]
# Format data
data.append((
token,
text[start_ws: end_ws],
span,
))
return Tokens(data, self.annotators)
================================================
FILE: mdr/qa/config.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from ast import parse
from typing import NamedTuple
from torch.nn import parallel
class ClusterConfig(NamedTuple):
dist_backend: str
dist_url: str
def common_args():
parser = argparse.ArgumentParser()
# task
parser.add_argument("--train_file", type=str,
default="../data/nq-with-neg-train.txt")
parser.add_argument("--predict_file", type=str,
default="../data/nq-with-neg-dev.txt")
parser.add_argument("--num_workers", default=10, type=int)
parser.add_argument("--do_train", default=False,
action='store_true', help="Whether to run training.")
parser.add_argument("--do_predict", default=False,
action='store_true', help="Whether to run eval on the dev set.")
parser.add_argument("--do_test", default=False, action="store_true", help="for final test submission")
# model
parser.add_argument("--model_name",
default="bert-base-uncased", type=str)
parser.add_argument("--init_checkpoint", type=str,
help="Initial checkpoint (usually from a pre-trained BERT model).",
default="")
parser.add_argument("--max_seq_len", default=512, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--max_q_len", default=64, type=int)
parser.add_argument("--max_ans_len", default=35, type=int)
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument("--no_cuda", default=False, action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument("--local_rank", type=int, default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument("--predict_batch_size", default=256,
type=int, help="Total batch size for predictions.")
parser.add_argument("--save-prediction", default="", type=str)
parser.add_argument("--sp-pred", action="store_true", help="whether to predict sentence sp")
return parser
def train_args():
parser = common_args()
# optimization
parser.add_argument('--prefix', type=str, default="eval")
parser.add_argument("--weight_decay", default=0.0, type=float,
help="Weight decay if we apply some.")
parser.add_argument("--output_dir", default="./logs", type=str,
help="The output directory where the model checkpoints will be written.")
parser.add_argument("--train_batch_size", default=128,
type=int, help="Total batch size for training.")
parser.add_argument("--num_q_per_gpu", default=1, type=int)
parser.add_argument("--learning_rate", default=1e-5,
type=float, help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs", default=5, type=float,
help="Total number of training epochs to perform.")
parser.add_argument('--seed', type=int, default=3,
help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of update steps to accumulate before performing a backward/update pass.")
parser.add_argument('--eval-period', type=int, default=2500)
parser.add_argument("--max_grad_norm", default=2.0, type=float, help="Max gradient norm.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--neg-num", type=int, default=9, help="how many neg/distant passage chains to use")
parser.add_argument("--shared-norm", action="store_true")
parser.add_argument("--qa-drop", default=0, type=float)
parser.add_argument("--rank-drop", default=0, type=float)
parser.add_argument("--sp-drop", default=0, type=float)
parser.add_argument("--final-metric", default="joint_f1")
parser.add_argument("--use-adam", action="store_true", help="use adam or adamW")
parser.add_argument("--warmup-ratio", default=0, type=float, help="Linear warmup over warmup_steps.")
parser.add_argument("--sp-weight", default=0, type=float, help="weight of the sp loss")
return parser.parse_args()
================================================
FILE: mdr/qa/data_utils.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
from tqdm import tqdm
import numpy as np
def explore(path):
train = json.load(open(path))
neg_counts = []
for item in train:
tfidf_neg = item["tfidf_neg"]
linked_neg = item["linked_neg"]
neg_counts.append(len(tfidf_neg + linked_neg))
import pdb; pdb.set_trace()
return
def load_corpus(corpus_path="/private/home/xwhan/data/hotpot/tfidf/abstracts.txt"):
content = [json.loads(l) for l in open(corpus_path).readlines()]
title2doc = {item["title"]: item["text"] for item in content}
return title2doc
if __name__ == "__main__":
explore("/private/home/xwhan/data/hotpot/hotpot_rerank_train_2_neg_types.json")
================================================
FILE: mdr/qa/hotpot_evaluate_v1.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys
import ujson as json
import re
import string
from collections import Counter
import pickle
def normalize_answer(s):
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def f1_score(prediction, ground_truth):
normalized_prediction = normalize_answer(prediction)
normalized_ground_truth = normalize_answer(ground_truth)
ZERO_METRIC = (0, 0, 0)
if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
return ZERO_METRIC
if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
return ZERO_METRIC
prediction_tokens = normalized_prediction.split()
ground_truth_tokens = normalized_ground_truth.split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return ZERO_METRIC
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1, precision, recall
def exact_match_score(prediction, ground_truth):
return (normalize_answer(prediction) == normalize_answer(ground_truth))
def update_answer(metrics, prediction, gold):
em = exact_match_score(prediction, gold)
f1, prec, recall = f1_score(prediction, gold)
metrics['em'] += float(em)
metrics['f1'] += f1
metrics['prec'] += prec
metrics['recall'] += recall
return em, prec, recall
def update_sp(metrics, prediction, gold):
cur_sp_pred = set(map(tuple, prediction))
gold_sp_pred = set(map(tuple, gold))
tp, fp, fn = 0, 0, 0
for e in cur_sp_pred:
if e in gold_sp_pred:
tp += 1
else:
fp += 1
for e in gold_sp_pred:
if e not in cur_sp_pred:
fn += 1
prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
em = 1.0 if fp + fn == 0 else 0.0
metrics['sp_em'] += em
metrics['sp_f1'] += f1
metrics['sp_prec'] += prec
metrics['sp_recall'] += recall
return em, prec, recall
def eval(prediction_file, gold_file):
with open(prediction_file) as f:
prediction = json.load(f)
with open(gold_file) as f:
gold = json.load(f)
metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
for dp in gold:
cur_id = dp['_id']
can_eval_joint = True
if cur_id not in prediction['answer']:
print('missing answer {}'.format(cur_id))
can_eval_joint = False
else:
em, prec, recall = update_answer(
metrics, prediction['answer'][cur_id], dp['answer'])
if cur_id not in prediction['sp']:
print('missing sp fact {}'.format(cur_id))
can_eval_joint = False
else:
sp_em, sp_prec, sp_recall = update_sp(
metrics, prediction['sp'][cur_id], dp['supporting_facts'])
if can_eval_joint:
joint_prec = prec * sp_prec
joint_recall = recall * sp_recall
if joint_prec + joint_recall > 0:
joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
else:
joint_f1 = 0.
joint_em = em * sp_em
metrics['joint_em'] += joint_em
metrics['joint_f1'] += joint_f1
metrics['joint_prec'] += joint_prec
metrics['joint_recall'] += joint_recall
N = len(gold)
for k in metrics.keys():
metrics[k] /= N
print(metrics)
if __name__ == '__main__':
eval(sys.argv[1], sys.argv[2])
================================================
FILE: mdr/qa/qa_dataset.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import collections
import json
import random
import torch
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from .basic_tokenizer import SimpleTokenizer
from .utils import (find_ans_span_with_char_offsets, match_answer_span, para_has_answer, _is_whitespace)
def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
"""Convert a list of 1d tensors into a padded 2d tensor."""
if len(values[0].size()) > 1:
values = [v.view(-1) for v in values]
size = max(v.size(0) for v in values)
res = values[0].new(len(values), size).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
if move_eos_to_beginning:
assert src[-1] == eos_idx
dst[0] = eos_idx
dst[1:] = src[:-1]
else:
dst.copy_(src)
for i, v in enumerate(values):
copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
return res
def prepare(item, tokenizer, special_toks=["[SEP]", "[unused1]", "[unused2]"]):
"""
Tokenize the passage chains and add sentence-start markers for SP sentence identification.
"""
def _process_p(para):
"""
handle each para
"""
title, sents = para["title"].strip(), para["sents"]
# return "[unused1] " + title + " [unused1] " + text # mark title
# return title + " " + text
pre_sents = []
for idx, sent in enumerate(sents):
pre_sents.append("[unused1] " + sent.strip())
return title + " " + " ".join(pre_sents)
# return " ".join(pre_sents)
# mark passage boundary
contexts = []
for para in item["passages"]:
contexts.append(_process_p(para))
context = " [SEP] ".join(contexts)
doc_tokens = []
char_to_word_offset = []
prev_is_whitespace = True
context = "yes no [SEP] " + context
for c in context:
if _is_whitespace(c):
prev_is_whitespace = True
else:
if prev_is_whitespace:
doc_tokens.append(c)
else:
doc_tokens[-1] += c
prev_is_whitespace = False
char_to_word_offset.append(len(doc_tokens) - 1)
sent_starts = []
orig_to_tok_index = []
tok_to_orig_index = []
all_doc_tokens = []
for (i, token) in enumerate(doc_tokens):
orig_to_tok_index.append(len(all_doc_tokens))
if token in special_toks:
if token == "[unused1]":
sent_starts.append(len(all_doc_tokens))
sub_tokens = [token]
else:
sub_tokens = tokenizer.tokenize(token)
for sub_token in sub_tokens:
tok_to_orig_index.append(i)
all_doc_tokens.append(sub_token)
item["context_processed"] = {
"doc_tokens": doc_tokens,
"char_to_word_offset": char_to_word_offset,
"orig_to_tok_index": orig_to_tok_index,
"tok_to_orig_index": tok_to_orig_index,
"all_doc_tokens": all_doc_tokens,
"context": context,
"sent_starts": sent_starts
}
return item
class QAEvalDataset(Dataset):
def __init__(self,
tokenizer,
retrieval_results,
max_seq_len,
max_q_len,
):
retriever_outputs = retrieval_results
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
self.max_q_len = max_q_len
self.data = []
for item in retriever_outputs:
if item["question"].endswith("?"):
item["question"] = item["question"][:-1]
# for validation, add target predictions
sp_titles = None
gold_answer = item.get("answer", [])
sp_gold = []
for chain in item["candidate_chains"]:
chain_titles = [_["title"] for _ in chain]
if sp_titles:
label = int(set(chain_titles) == sp_titles)
else:
label = -1
self.data.append({
"question": item["question"],
"passages": chain,
"label": label,
"qid": item["_id"],
"gold_answer": gold_answer,
"sp_gold": sp_gold
})
print(f"Total instances size {len(self.data)}")
def __len__(self):
return len(self.data)
def __getitem__(self, index):
item = prepare(self.data[index], self.tokenizer)
context_ann = item["context_processed"]
q_toks = self.tokenizer.tokenize(item["question"])[:self.max_q_len]
para_offset = len(q_toks) + 2 # [CLS] and [SEP]
item["wp_tokens"] = context_ann["all_doc_tokens"]
assert item["wp_tokens"][0] == "yes" and item["wp_tokens"][1] == "no"
item["para_offset"] = para_offset
max_toks_for_doc = self.max_seq_len - para_offset - 1
if len(item["wp_tokens"]) > max_toks_for_doc:
item["wp_tokens"] = item["wp_tokens"][:max_toks_for_doc]
item["encodings"] = self.tokenizer.encode_plus(q_toks, text_pair=item["wp_tokens"], max_length=self.max_seq_len, return_tensors="pt", is_pretokenized=True)
item["paragraph_mask"] = torch.zeros(item["encodings"]["input_ids"].size()).view(-1)
item["paragraph_mask"][para_offset:-1] = 1
item["doc_tokens"] = context_ann["doc_tokens"]
item["tok_to_orig_index"] = context_ann["tok_to_orig_index"]
# filter sentence offsets exceeding max sequence length
sent_labels, sent_offsets = [], []
for idx, s in enumerate(item["context_processed"]["sent_starts"]):
if s >= len(item["wp_tokens"]):
break
if "sp_sent_labels" in item:
sent_labels.append(item["sp_sent_labels"][idx])
sent_offsets.append(s + para_offset)
assert item["encodings"]["input_ids"].view(-1)[s+para_offset] == self.tokenizer.convert_tokens_to_ids("[unused1]")
# supporting fact label
item["sent_offsets"] = sent_offsets
item["sent_offsets"] = torch.LongTensor(item["sent_offsets"])
item["label"] = torch.LongTensor([item["label"]])
return item
class QADataset(Dataset):
def __init__(self,
tokenizer,
data_path,
max_seq_len,
max_q_len,
train=False,
no_sent_label=False
):
retriever_outputs = [json.loads(l) for l in tqdm(open(data_path).readlines())]
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
self.max_q_len = max_q_len
self.train = train
self.no_sent_label = no_sent_label
self.simple_tok = SimpleTokenizer()
self.data = []
if train:
self.qid2gold = collections.defaultdict(list) # idx
self.qid2neg = collections.defaultdict(list)
for item in retriever_outputs:
if item["question"].endswith("?"):
item["question"] = item["question"][:-1]
sp_sent_labels = []
sp_gold = []
if not self.no_sent_label:
for sp in item["sp"]:
for _ in sp["sp_sent_ids"]:
sp_gold.append([sp["title"], _])
for idx in range(len(sp["sents"])):
sp_sent_labels.append(int(idx in sp["sp_sent_ids"]))
question_type = item["type"]
self.data.append({
"question": item["question"],
"passages": item["sp"],
"label": 1,
"qid": item["_id"],
"gold_answer": item["answer"],
"sp_sent_labels": sp_sent_labels,
"ans_covered": 1, # includes partial chains.
"sp_gold": sp_gold
})
self.qid2gold[item["_id"]].append(len(self.data) - 1)
sp_titles = set([_["title"] for _ in item["sp"]])
if question_type == "bridge":
ans_titles = set([p["title"] for p in item["sp"] if para_has_answer(item["answer"], "".join(p["sents"]), self.simple_tok)])
else:
ans_titles = set()
# top ranked negative chains
ds_count = 0 # track how many distant supervised chain to use
ds_limit = 5
for chain in item["candidate_chains"]:
chain_titles = [_["title"] for _ in chain]
if set(chain_titles) == sp_titles:
continue
if question_type == "bridge":
answer_covered = int(len(set(chain_titles) & ans_titles) > 0)
ds_count += answer_covered
else:
answer_covered = 0
self.data.append({
"question": item["question"],
"passages": chain,
"label": 0,
"qid": item["_id"],
"gold_answer": item["answer"],
"ans_covered": answer_covered,
"sp_gold": sp_gold
})
self.qid2neg[item["_id"]].append(len(self.data) - 1)
else:
for item in retriever_outputs:
if item["question"].endswith("?"):
item["question"] = item["question"][:-1]
# for validation, add target predictions
sp_titles = set([_["title"] for _ in item["sp"]]) if "sp" in item else None
gold_answer = item.get("answer", [])
sp_gold = []
if "sp" in item:
for sp in item["sp"]:
for _ in sp["sp_sent_ids"]:
sp_gold.append([sp["title"], _])
chain_seen = set()
for chain in item["candidate_chains"]:
chain_titles = [_["title"] for _ in chain]
# title_set = frozenset(chain_titles)
# if len(title_set) == 0 or title_set in chain_seen:
# continue
# chain_seen.add(title_set)
if sp_titles:
label = int(set(chain_titles) == sp_titles)
else:
label = -1
self.data.append({
"question": item["question"],
"passages": chain,
"label": label,
"qid": item["_id"],
"gold_answer": gold_answer,
"sp_gold": sp_gold
})
print(f"Data size {len(self.data)}")
def __len__(self):
return len(self.data)
def __getitem__(self, index):
item = prepare(self.data[index], self.tokenizer)
context_ann = item["context_processed"]
q_toks = self.tokenizer.tokenize(item["question"])[:self.max_q_len]
para_offset = len(q_toks) + 2 # [CLS] and [SEP]
item["wp_tokens"] = context_ann["all_doc_tokens"]
assert item["wp_tokens"][0] == "yes" and item["wp_tokens"][1] == "no"
item["para_offset"] = para_offset
max_toks_for_doc = self.max_seq_len - para_offset - 1
if len(item["wp_tokens"]) > max_toks_for_doc:
item["wp_tokens"] = item["wp_tokens"][:max_toks_for_doc]
item["encodings"] = self.tokenizer.encode_plus(q_toks, text_pair=item["wp_tokens"], max_length=self.max_seq_len, return_tensors="pt", is_pretokenized=True)
item["paragraph_mask"] = torch.zeros(item["encodings"]["input_ids"].size()).view(-1)
item["paragraph_mask"][para_offset:-1] = 1
if self.train:
# if item["label"] == 1:
if item["ans_covered"]:
if item["gold_answer"][0] == "yes":
# ans_type = 0
starts, ends= [para_offset], [para_offset]
elif item["gold_answer"][0] == "no":
# ans_type = 1
starts, ends= [para_offset + 1], [para_offset + 1]
else:
# ans_type = 2
matched_spans = match_answer_span(context_ann["context"], item["gold_answer"], self.simple_tok)
ans_starts, ans_ends= [], []
for span in matched_spans:
char_starts = [i for i in range(len(context_ann["context"])) if context_ann["context"].startswith(span, i)]
if len(char_starts) > 0:
char_ends = [start + len(span) - 1 for start in char_starts]
answer = {"text": span, "char_spans": list(zip(char_starts, char_ends))}
ans_spans = find_ans_span_with_char_offsets(
answer, context_ann["char_to_word_offset"], context_ann["doc_tokens"], context_ann["all_doc_tokens"], context_ann["orig_to_tok_index"], self.tokenizer)
for s, e in ans_spans:
ans_starts.append(s)
ans_ends.append(e)
starts, ends = [], []
for s, e in zip(ans_starts, ans_ends):
if s >= len(item["wp_tokens"]):
continue
else:
s = min(s, len(item["wp_tokens"]) - 1) + para_offset
e = min(e, len(item["wp_tokens"]) - 1) + para_offset
starts.append(s)
ends.append(e)
if len(starts) == 0:
starts, ends = [-1], [-1]
else:
starts, ends= [-1], [-1]
# ans_type = -1
item["starts"] = torch.LongTensor(starts)
item["ends"] = torch.LongTensor(ends)
# item["ans_type"] = torch.LongTensor([ans_type])
if item["label"]:
assert len(item["sp_sent_labels"]) == len(item["context_processed"]["sent_starts"])
else:
# # for answer extraction
item["doc_tokens"] = context_ann["doc_tokens"]
item["tok_to_orig_index"] = context_ann["tok_to_orig_index"]
# filter sentence offsets exceeding max sequence length
sent_labels, sent_offsets = [], []
for idx, s in enumerate(item["context_processed"]["sent_starts"]):
if s >= len(item["wp_tokens"]):
break
if "sp_sent_labels" in item:
sent_labels.append(item["sp_sent_labels"][idx])
sent_offsets.append(s + para_offset)
assert item["encodings"]["input_ids"].view(-1)[s+para_offset] == self.tokenizer.convert_tokens_to_ids("[unused1]")
# supporting fact label
item["sent_offsets"] = sent_offsets
item["sent_offsets"] = torch.LongTensor(item["sent_offsets"])
if self.train:
item["sent_labels"] = sent_labels if len(sent_labels) != 0 else [0] * len(sent_offsets)
item["sent_labels"] = torch.LongTensor(item["sent_labels"])
item["ans_covered"] = torch.LongTensor([item["ans_covered"]])
item["label"] = torch.LongTensor([item["label"]])
return item
class MhopSampler(Sampler):
"""
Shuffle QA pairs (not contexts), and ensure all data within a batch come from the same QA pair
"""
def __init__(self, data_source, num_neg=9, n_gpu=8):
# for each QA pair, sample negative paragraphs
self.qid2gold = data_source.qid2gold
self.qid2neg = data_source.qid2neg
self.neg_num = num_neg
self.n_gpu = n_gpu
self.all_qids = list(self.qid2gold.keys())
assert len(self.qid2gold) == len(self.qid2neg)
self.q_num_per_epoch = len(self.qid2gold) - len(self.qid2gold) % self.n_gpu
self._num_samples = self.q_num_per_epoch * (self.neg_num + 1)
def __len__(self):
return self._num_samples
def __iter__(self):
sample_indice = []
random.shuffle(self.all_qids)
# when use shared-normalization, passages for each question should be on the same GPU
qids_to_use = self.all_qids[:self.q_num_per_epoch]
for qid in qids_to_use:
neg_samples = self.qid2neg[qid]
random.shuffle(neg_samples)
sample_indice += self.qid2gold[qid]
sample_indice += neg_samples[:self.neg_num]
return iter(sample_indice)
def qa_collate(samples, pad_id=0):
if len(samples) == 0:
return {}
batch = {
'input_ids': collate_tokens([s["encodings"]['input_ids'] for s in samples], pad_id),
'attention_mask': collate_tokens([s["encodings"]['attention_mask'] for s in samples], 0),
'paragraph_mask': collate_tokens([s['paragraph_mask'] for s in samples], 0),
'label': collate_tokens([s["label"] for s in samples], -1),
"sent_offsets": collate_tokens([s["sent_offsets"] for s in samples], 0),
}
# training labels
if "starts" in samples[0]:
batch["starts"] = collate_tokens([s['starts'] for s in samples], -1)
batch["ends"] = collate_tokens([s['ends'] for s in samples], -1)
# batch["ans_types"] = collate_tokens([s['ans_type'] for s in samples], -1)
batch["sent_labels"] = collate_tokens([s['sent_labels'] for s in samples], 0)
batch["ans_covered"] = collate_tokens([s['ans_covered'] for s in samples], 0)
# roberta does not use token_type_ids
if "token_type_ids" in samples[0]["encodings"]:
batch["token_type_ids"] = collate_tokens([s["encodings"]['token_type_ids']for s in samples], 0)
batched = {
"qids": [s["qid"] for s in samples],
"passages": [s["passages"] for s in samples],
"gold_answer": [s["gold_answer"] for s in samples],
"sp_gold": [s["sp_gold"] for s in samples],
"para_offsets": [s["para_offset"] for s in samples],
"net_inputs": batch,
}
# for answer extraction
if "doc_tokens" in samples[0]:
batched["doc_tokens"] = [s["doc_tokens"] for s in samples]
batched["tok_to_orig_index"] = [s["tok_to_orig_index"] for s in samples]
batched["wp_tokens"] = [s["wp_tokens"] for s in samples]
return batched
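qa_collate above delegates all padding to collate_tokens (imported elsewhere in the repo). A minimal, self-contained stand-in for that right-padding behavior, assuming the inputs are 1-D LongTensors:

```python
import torch

def pad_1d(tensors, pad_value):
    """Right-pad a list of 1-D tensors to the longest length; a simplified
    stand-in for the fairseq-style collate_tokens used by qa_collate."""
    max_len = max(t.size(0) for t in tensors)
    out = tensors[0].new_full((len(tensors), max_len), pad_value)
    for i, t in enumerate(tensors):
        out[i, : t.size(0)] = t
    return out

batch = pad_1d([torch.tensor([1, 2, 3]), torch.tensor([4])], pad_value=0)
# batch -> [[1, 2, 3], [4, 0, 0]]
```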
================================================
FILE: mdr/qa/qa_model.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from transformers import AutoModel, BertModel
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch
import torch.nn.functional as F
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
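A tiny sanity check of the first-token pooling performed by BertPooler, with an assumed hidden_size of 4 and random inputs:

```python
import torch
import torch.nn as nn

hidden_size = 4  # assumed toy size, not a real model config
dense = nn.Linear(hidden_size, hidden_size)
hidden_states = torch.randn(2, 7, hidden_size)  # (batch, seq_len, hidden)

# Take the hidden state at position 0 ("[CLS]"), project, then tanh.
pooled = torch.tanh(dense(hidden_states[:, 0]))
```

Whatever the sequence length, the pooled output has shape (batch, hidden) and is bounded to (-1, 1) by the tanh.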
class QAModel(nn.Module):
def __init__(self,
config,
args
):
super().__init__()
self.model_name = args.model_name
self.sp_weight = args.sp_weight
self.sp_pred = args.sp_pred
self.encoder = AutoModel.from_pretrained(args.model_name)
if "electra" in args.model_name:
self.pooler = BertPooler(config)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
        self.rank = nn.Linear(config.hidden_size, 1)  # chain reranking score
if self.sp_pred:
self.sp = nn.Linear(config.hidden_size, 1)
self.loss_fct = CrossEntropyLoss(ignore_index=-1, reduction="none")
def forward(self, batch):
outputs = self.encoder(batch['input_ids'], batch['attention_mask'], batch.get('token_type_ids', None))
if "electra" in self.model_name:
sequence_output = outputs[0]
pooled_output = self.pooler(sequence_output)
else:
sequence_output, pooled_output = outputs[0], outputs[1]
logits = self.qa_outputs(sequence_output)
outs = [o.squeeze(-1) for o in logits.split(1, dim=-1)]
outs = [o.float().masked_fill(batch["paragraph_mask"].ne(1), float("-inf")).type_as(o) for o in outs]
start_logits, end_logits = outs[0], outs[1]
rank_score = self.rank(pooled_output)
if self.sp_pred:
gather_index = batch["sent_offsets"].unsqueeze(2).expand(-1, -1, sequence_output.size()[-1])
sent_marker_rep = torch.gather(sequence_output, 1, gather_index)
sp_score = self.sp(sent_marker_rep).squeeze(2)
else:
sp_score = None
if self.training:
rank_target = batch["label"]
if self.sp_pred:
sp_loss = F.binary_cross_entropy_with_logits(sp_score, batch["sent_labels"].float(), reduction="none")
sp_loss = (sp_loss * batch["sent_offsets"]) * batch["label"]
sp_loss = sp_loss.sum()
start_positions, end_positions = batch["starts"], batch["ends"]
rank_loss = F.binary_cross_entropy_with_logits(rank_score, rank_target.float(), reduction="sum")
start_losses = [self.loss_fct(start_logits, starts) for starts in torch.unbind(start_positions, dim=1)]
end_losses = [self.loss_fct(end_logits, ends) for ends in torch.unbind(end_positions, dim=1)]
loss_tensor = torch.cat([t.unsqueeze(1) for t in start_losses], dim=1) + torch.cat([t.unsqueeze(1) for t in end_losses], dim=1)
log_prob = - loss_tensor
log_prob = log_prob.float().masked_fill(log_prob == 0, float('-inf')).type_as(log_prob)
marginal_probs = torch.sum(torch.exp(log_prob), dim=1)
m_prob = [marginal_probs[idx] for idx in marginal_probs.nonzero()]
if len(m_prob) == 0:
span_loss = self.loss_fct(start_logits, start_logits.new_zeros(
start_logits.size(0)).long()-1).sum()
else:
span_loss = - torch.log(torch.cat(m_prob)).sum()
if self.sp_pred:
loss = rank_loss + span_loss + sp_loss * self.sp_weight
else:
loss = rank_loss + span_loss
return loss.unsqueeze(0)
return {
'start_logits': start_logits,
'end_logits': end_logits,
'rank_score': rank_score,
"sp_score": sp_score
}
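The span loss in QAModel.forward marginalizes over multiple annotated start/end positions: per-position cross-entropy losses are negative log-probabilities, so summing exp(-loss) over the annotations gives the marginal span probability. A toy sketch of the start-side computation, with made-up logits:

```python
import torch

loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction="none")
start_logits = torch.tensor([[0.0, 2.0, 0.0, 2.0]])  # (batch=1, seq=4)
start_positions = torch.tensor([[1, 3]])             # two annotated start positions

# CE loss at each annotated position is -log P(position), so exp(-loss) = P.
losses = [loss_fct(start_logits, s) for s in torch.unbind(start_positions, dim=1)]
log_prob = -torch.cat([l.unsqueeze(1) for l in losses], dim=1)
marginal = torch.exp(log_prob).sum(dim=1)            # P(pos=1) + P(pos=3)
span_loss = -torch.log(marginal)
```

With softmax([0, 2, 0, 2]) the two annotated positions each carry about 0.44 probability mass, so the marginal is about 0.88.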
================================================
FILE: mdr/qa/qa_trainer.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import os
import os.path as osp
import random
from functools import partial
from pathlib import Path
from typing import NamedTuple, Optional
import collections
from torch.optim import lr_scheduler
from tqdm import tqdm
import apex
import attr
import numpy as np
import submitit
import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim
from apex import amp
from torch.utils.tensorboard import SummaryWriter
from transformers import (AdamW, AutoConfig, AutoTokenizer,
get_linear_schedule_with_warmup)
from config import ClusterConfig
from hotpot_evaluate_v1 import exact_match_score, f1_score, update_sp
from qa_model import QAModel
from reranking_datasets import RankingDataset, rank_collate, MhopSampler
from utils import AverageMeter, move_to_cuda, get_final_text
apex.amp.register_half_function(torch, 'einsum')
@attr.s(auto_attribs=True)
class TrainerState:
"""
Contains the state of the Trainer.
It can be saved to checkpoint the training and loaded to resume it.
"""
epoch: int
model: nn.Module
optimizer: optim.Optimizer
lr_scheduler: torch.optim.lr_scheduler._LRScheduler
global_step: int
def save(self, filename: str) -> None:
data = attr.asdict(self)
# store only the state dict
data["model"] = self.model.state_dict()
data["optimizer"] = self.optimizer.state_dict()
data["lr_scheduler"] = self.lr_scheduler.state_dict()
torch.save(data, filename)
@classmethod
def load(cls, filename: str, default: "TrainerState", gpu: int) -> "TrainerState":
data = torch.load(filename, map_location=lambda storage, loc: storage.cuda(gpu))
# We need this default to load the state dict
model = default.model
model.load_state_dict(data["model"])
data["model"] = model
optimizer = default.optimizer
optimizer.load_state_dict(data["optimizer"])
data["optimizer"] = optimizer
lr_scheduler = default.lr_scheduler
lr_scheduler.load_state_dict(data["lr_scheduler"])
data["lr_scheduler"] = lr_scheduler
return cls(**data)
class Trainer:
def __init__(self, train_cfg: NamedTuple, cluster_cfg: ClusterConfig) -> None:
self._train_cfg = train_cfg
self._cluster_cfg = cluster_cfg
def __call__(self) -> Optional[float]:
"""
Called by submitit for each task.
:return: The master task return the final accuracy of the model.
"""
self._setup_process_group()
self._init_state()
final_acc = self._train()
return final_acc
def log(self, log_data: dict):
job_env = submitit.JobEnvironment()
# z = {**vars(self._train_cfg), **log_data}
save_dir = Path(self._train_cfg.output_dir)
os.makedirs(save_dir, exist_ok=True)
with open(save_dir / 'log.txt', 'a') as f:
f.write(json.dumps(log_data) + '\n')
def checkpoint(self, rm_init=True) -> submitit.helpers.DelayedSubmission:
# will be called by submitit in case of preemption
job_env = submitit.JobEnvironment()
save_dir = osp.join(self._train_cfg.output_dir, str(job_env.job_id))
os.makedirs(save_dir, exist_ok=True)
self._state.save(osp.join(save_dir, "checkpoint.pth"))
        # Trick: a requeued job reuses the same init file, but that file must
        # not exist when the process group is re-initialized, so we delete it
        # here, and only when submitit calls this method for a requeue.
        if rm_init and osp.exists(self._cluster_cfg.dist_url[7:]):
            os.remove(self._cluster_cfg.dist_url[7:])  # strip the leading "file://"
        # Returning a fresh instance drops any non-picklable parts of the Trainer.
empty_trainer = Trainer(self._train_cfg, self._cluster_cfg)
return submitit.helpers.DelayedSubmission(empty_trainer)
def _setup_process_group(self) -> None:
job_env = submitit.JobEnvironment()
torch.cuda.set_device(job_env.local_rank)
torch.distributed.init_process_group(
backend=self._cluster_cfg.dist_backend,
init_method=self._cluster_cfg.dist_url,
world_size=job_env.num_tasks,
rank=job_env.global_rank,
)
print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
def _init_state(self) -> None:
"""
Initialize the state and load it from an existing checkpoint if any
"""
job_env = submitit.JobEnvironment()
if job_env.global_rank == 0:
# config_path = Path(args.save_folder) / str(job_env.job_id) / 'config.json'
os.makedirs(self._train_cfg.output_dir, exist_ok=True)
config_path = Path(self._train_cfg.output_dir) / 'config.json'
with open(config_path, "w") as g:
g.write(json.dumps(self._train_cfg._asdict()))
print(f"Setting random seed {self._train_cfg.seed}", flush=True)
random.seed(self._train_cfg.seed)
np.random.seed(self._train_cfg.seed)
torch.manual_seed(self._train_cfg.seed)
print("Create data loaders", flush=True)
tokenizer = AutoTokenizer.from_pretrained(self._train_cfg.model_name)
collate_fc = partial(rank_collate, pad_id=tokenizer.pad_token_id)
train_set = RankingDataset(tokenizer, self._train_cfg.train_file, self._train_cfg.max_seq_len, self._train_cfg.max_q_len, train=True)
train_sampler = MhopSampler(train_set, num_neg=self._train_cfg.neg_num)
batch_size_per_gpu = (1 + self._train_cfg.neg_num) * self._train_cfg.num_q_per_gpu
n_gpu = torch.cuda.device_count()
print(f"Number of GPUs: {n_gpu}", flush=True)
print(f"Batch size per node: {batch_size_per_gpu * n_gpu}", flush=True)
self._train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size_per_gpu * n_gpu, num_workers=self._train_cfg.num_workers, collate_fn=collate_fc, sampler=train_sampler)
test_set = RankingDataset(tokenizer, self._train_cfg.predict_file, self._train_cfg.max_seq_len, self._train_cfg.max_q_len)
self._test_loader = torch.utils.data.DataLoader(
test_set,
batch_size=self._train_cfg.predict_batch_size,
num_workers=self._train_cfg.num_workers, collate_fn=collate_fc
)
print("Create model", flush=True)
print(f"Local rank {job_env.local_rank}", flush=True)
bert_config = AutoConfig.from_pretrained(self._train_cfg.model_name)
model = QAModel(bert_config, self._train_cfg)
model.cuda(job_env.local_rank)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(
nd in n for nd in no_decay)], 'weight_decay': self._train_cfg.weight_decay},
{'params': [p for n, p in model.named_parameters() if any(
nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
if self._train_cfg.use_adam:
optimizer = optim.Adam(optimizer_parameters, lr=self._train_cfg.learning_rate)
else:
optimizer = AdamW(optimizer_parameters, lr=self._train_cfg.learning_rate)
# lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
if self._train_cfg.fp16:
model, optimizer = amp.initialize(
model, optimizer, opt_level=self._train_cfg.fp16_opt_level)
t_total = len(self._train_loader) // self._train_cfg.gradient_accumulation_steps * self._train_cfg.num_train_epochs
warmup_steps = t_total * self._train_cfg.warmup_ratio
lr_scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
)
model = torch.nn.DataParallel(model)
self._state = TrainerState(
epoch=0, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, global_step=0
)
self.tb_logger = SummaryWriter(self._train_cfg.output_dir.replace("logs", "tflogs"))
checkpoint_fn = osp.join(self._train_cfg.output_dir, str(job_env.job_id), "checkpoint.pth")
# checkpoint_fn = osp.join(self._train_cfg.output_dir, "checkpoint.pth")
if os.path.isfile(checkpoint_fn):
print(f"Load existing checkpoint from {checkpoint_fn}", flush=True)
self._state = TrainerState.load(
checkpoint_fn, default=self._state, gpu=job_env.local_rank)
def _train(self) -> Optional[float]:
job_env = submitit.JobEnvironment()
batch_step = 0 # forward batch count
best_metric = 0
train_loss_meter = AverageMeter()
print(f"Start training", flush=True)
# Start from the loaded epoch
start_epoch = self._state.epoch
global_step = self._state.global_step
for epoch in range(start_epoch, self._train_cfg.num_train_epochs):
print(f"Start epoch {epoch}", flush=True)
self._state.model.train()
self._state.epoch = epoch
for batch in self._train_loader:
batch_step += 1
batch_inputs = move_to_cuda(batch["net_inputs"])
loss = self._state.model(batch_inputs)
if torch.cuda.device_count() > 1:
loss = loss.mean()
if self._train_cfg.gradient_accumulation_steps > 1:
loss = loss / self._train_cfg.gradient_accumulation_steps
if self._train_cfg.fp16:
with amp.scale_loss(loss, self._state.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
train_loss_meter.update(loss.item())
if (batch_step + 1) % self._train_cfg.gradient_accumulation_steps == 0:
if self._train_cfg.fp16:
torch.nn.utils.clip_grad_norm_(
amp.master_params(self._state.optimizer), self._train_cfg.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(
self._state.model.parameters(), self._train_cfg.max_grad_norm)
self._state.optimizer.step()
self._state.lr_scheduler.step()
self._state.model.zero_grad()
global_step += 1
self._state.global_step = global_step
self.tb_logger.add_scalar('batch_train_loss',
loss.item(), global_step)
self.tb_logger.add_scalar('smoothed_train_loss',
train_loss_meter.avg, global_step)
if job_env.global_rank == 0:
if self._train_cfg.eval_period != -1 and global_step % self._train_cfg.eval_period == 0:
metrics = self._eval()
for k, v in metrics.items():
self.tb_logger.add_scalar(k, v*100, global_step)
score = metrics[self._train_cfg.final_metric]
if best_metric < score:
print("Saving model with best %s %.2f -> em %.2f" % (self._train_cfg.final_metric, best_metric*100, score*100), flush=True)
torch.save(self._state.model.state_dict(), os.path.join(self._train_cfg.output_dir, f"checkpoint_best.pt"))
best_metric = score
# Checkpoint only on the master
if job_env.global_rank == 0:
self.checkpoint(rm_init=False)
metrics = self._eval()
for k, v in metrics.items():
self.tb_logger.add_scalar(k, v*100, global_step)
score = metrics[self._train_cfg.final_metric]
if best_metric < score:
print("Saving model with best %s %.2f -> em %.2f" % (self._train_cfg.final_metric, best_metric*100, score*100), flush=True)
torch.save(self._state.model.state_dict(), os.path.join(self._train_cfg.output_dir, f"checkpoint_best.pt"))
best_metric = score
self.log({
"best_score": best_metric,
"curr_score": score,
"smoothed_loss": train_loss_meter.avg,
"epoch": epoch
})
return best_metric
def _eval(self) -> dict:
print("Start evaluation of the model", flush=True)
job_env = submitit.JobEnvironment()
args = self._train_cfg
eval_dataloader = self._test_loader
model = self._state.model
model.eval()
id2result = collections.defaultdict(list)
id2answer = collections.defaultdict(list)
id2gold = {}
id2goldsp = {}
for batch in tqdm(eval_dataloader):
batch_to_feed = move_to_cuda(batch["net_inputs"])
batch_qids = batch["qids"]
batch_labels = batch["net_inputs"]["label"].view(-1).tolist()
with torch.no_grad():
outputs = model(batch_to_feed)
scores = outputs["rank_score"]
scores = scores.view(-1).tolist()
sp_scores = outputs["sp_score"]
sp_scores = sp_scores.float().masked_fill(batch_to_feed["sent_offsets"].eq(0), float("-inf")).type_as(sp_scores)
batch_sp_scores = sp_scores.sigmoid()
# ans_type_predicted = torch.argmax(outputs["ans_type_logits"], dim=1).view(-1).tolist()
outs = [outputs["start_logits"], outputs["end_logits"]]
for qid, label, score in zip(batch_qids, batch_labels, scores):
id2result[qid].append((label, score))
# answer prediction
span_scores = outs[0][:, :, None] + outs[1][:, None]
max_seq_len = span_scores.size(1)
span_mask = np.tril(np.triu(np.ones((max_seq_len, max_seq_len)), 0), args.max_ans_len)
span_mask = span_scores.data.new(max_seq_len, max_seq_len).copy_(torch.from_numpy(span_mask))
span_scores_masked = span_scores.float().masked_fill((1 - span_mask[None].expand_as(span_scores)).bool(), -1e10).type_as(span_scores)
start_position = span_scores_masked.max(dim=2)[0].max(dim=1)[1]
end_position = span_scores_masked.max(dim=2)[1].gather(
1, start_position.unsqueeze(1)).squeeze(1)
answer_scores = span_scores_masked.max(dim=2)[0].max(dim=1)[0].tolist()
para_offset = batch['para_offsets']
start_position_ = list(
np.array(start_position.tolist()) - np.array(para_offset))
end_position_ = list(
np.array(end_position.tolist()) - np.array(para_offset))
for idx, qid in enumerate(batch_qids):
id2gold[qid] = batch["gold_answer"][idx]
id2goldsp[qid] = batch["sp_gold"][idx]
rank_score = scores[idx]
sp_score = batch_sp_scores[idx].tolist()
start = start_position_[idx]
end = end_position_[idx]
span_score = answer_scores[idx]
tok_to_orig_index = batch['tok_to_orig_index'][idx]
doc_tokens = batch['doc_tokens'][idx]
wp_tokens = batch['wp_tokens'][idx]
orig_doc_start = tok_to_orig_index[start]
orig_doc_end = tok_to_orig_index[end]
orig_tokens = doc_tokens[orig_doc_start:(orig_doc_end + 1)]
tok_tokens = wp_tokens[start:end+1]
tok_text = " ".join(tok_tokens)
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "")
tok_text = tok_text.strip()
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
pred_str = get_final_text(tok_text, orig_text, do_lower_case=True, verbose_logging=False)
pred_sp = []
passages = batch["passages"][idx]
                for passage, sent_offset in zip(passages, [0, len(passages[0]["sents"])]):
                    for sent_idx, _ in enumerate(passage["sents"]):
                        try:
                            if sp_score[sent_idx + sent_offset] > 0.5:
                                pred_sp.append([passage["title"], sent_idx])
                        except IndexError:
                            # fewer predicted sentence scores than sentences
                            continue
id2answer[qid].append((pred_str.strip(), rank_score, span_score, pred_sp))
acc = []
for qid, res in id2result.items():
res.sort(key=lambda x: x[1], reverse=True)
acc.append(res[0][0] == 1)
print(f"evaluated {len(id2result)} questions...", flush=True)
print(f'chain ranking em: {np.mean(acc)}', flush=True)
best_em, best_f1, best_joint_em, best_joint_f1 = 0, 0, 0, 0
lambdas = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
for lambda_ in lambdas:
ems, f1s = [], []
sp_ems, sp_f1s = [], []
joint_ems, joint_f1s = [], []
for qid, res in id2result.items():
ans_res = id2answer[qid]
ans_res.sort(key=lambda x: lambda_ * x[1] + (1 - lambda_) * x[2], reverse=True)
top_pred = ans_res[0][0]
ems.append(exact_match_score(top_pred, id2gold[qid][0]))
f1, prec, recall = f1_score(top_pred, id2gold[qid][0])
f1s.append(f1)
top_pred_sp = ans_res[0][3]
metrics = {'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0}
update_sp(metrics, top_pred_sp, id2goldsp[qid])
sp_ems.append(metrics['sp_em'])
sp_f1s.append(metrics['sp_f1'])
# joint metrics
joint_prec = prec * metrics["sp_prec"]
joint_recall = recall * metrics["sp_recall"]
if joint_prec + joint_recall > 0:
joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall)
else:
joint_f1 = 0
joint_em = ems[-1] * sp_ems[-1]
joint_ems.append(joint_em)
joint_f1s.append(joint_f1)
if best_joint_f1 < np.mean(joint_f1s):
best_joint_f1 = np.mean(joint_f1s)
best_joint_em = np.mean(joint_ems)
best_f1 = np.mean(f1s)
best_em = np.mean(ems)
print(f".......Using combination factor {lambda_}......", flush=True)
print(f'answer em: {np.mean(ems)}, count: {len(ems)}', flush=True)
print(f'answer f1: {np.mean(f1s)}, count: {len(f1s)}', flush=True)
print(f'sp em: {np.mean(sp_ems)}, count: {len(sp_ems)}', flush=True)
print(f'sp f1: {np.mean(sp_f1s)}, count: {len(sp_f1s)}', flush=True)
print(f'joint em: {np.mean(joint_ems)}, count: {len(joint_ems)}', flush=True)
print(f'joint f1: {np.mean(joint_f1s)}, count: {len(joint_f1s)}', flush=True)
print(f"Best joint EM/F1 from combination {best_em}/{best_f1}", flush=True)
model.train()
return {"em": best_em, "f1": best_f1, "joint_em": best_joint_em, "joint_f1": best_joint_f1}
================================================
FILE: mdr/qa/train.md
================================================
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_qa.py \
--do_train \
--prefix qa_wwm_bert_title_mark_eval_debug \
--predict_batch_size 512 \
--model_name bert-large-uncased-whole-word-masking \
--train_batch_size 80 \
--learning_rate 3e-5 \
--fp16 \
--train_file /private/home/xwhan/data/hotpot/dense_train_b10_top20_outputs.json \
--predict_file /private/home/xwhan/data/hotpot/dense_val_outputs.json \
--seed 3 \
--eval-period 10 \
--max_seq_len 512 \
--max_q_len 100 \
--gradient_accumulation_steps 8 \
--neg-num 4
# spanbert debug, fp16 does not work
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_qa.py \
--do_train \
--prefix ranked_spanbert_debug \
--predict_batch_size 1024 \
--model_name spanbert \
--train_batch_size 48 \
--learning_rate 3e-5 \
--train_file /private/home/xwhan/data/hotpot/dense_train_b10_top20_outputs_sents.json \
--predict_file /private/home/xwhan/data/hotpot/dense_val_outputs_sents.json \
--seed 3 \
--eval-period 500 \
--max_seq_len 512 \
--max_q_len 64 \
--gradient_accumulation_steps 8 \
--neg-num 5 \
--use-adam
# test electra
CUDA_VISIBLE_DEVICES=0 python train_qa.py \
--do_train \
--prefix electra_large_debug_sn \
--predict_batch_size 1024 \
--model_name google/electra-large-discriminator \
--train_batch_size 12 \
--learning_rate 5e-5 \
--train_file /private/home/xwhan/data/hotpot/dense_train_b100_k100_sents.json \
--predict_file /private/home/xwhan/data/hotpot/dense_val_b30_k30_roberta_sents.json \
--seed 42 \
--eval-period 250 \
--max_seq_len 512 \
--max_q_len 64 \
--gradient_accumulation_steps 8 \
--neg-num 11 \
--fp16 \
--use-adam \
--warmup-ratio 0.1 \
--sp-weight 0.05 \
--sp-pred \
--shared-norm
# QA evaluation
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_qa.py \
--do_predict \
--predict_batch_size 2000 \
--model_name google/electra-large-discriminator \
--fp16 \
--predict_file /private/home/xwhan/data/hotpot/dense_val_b100_k100_roberta_best_sents.json \
--max_seq_len 512 \
--max_q_len 64 \
--init_checkpoint qa/logs/08-10-2020/electra_val_top30-epoch7-lr5e-05-seed42-rdrop0-qadrop0-decay0-qpergpu2-aggstep8-clip2-evalper250-evalbsize1024-negnum5-warmup0.1-adamTrue-spweight0.025/checkpoint_best.pt \
--sp-pred \
--max_ans_len 30 \
--save-prediction hotpot_val_top100.json
# QA evaluation with wwm
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_qa.py \
--do_predict \
--predict_batch_size 1024 \
--model_name bert-large-uncased-whole-word-masking \
--fp16 \
--predict_file /private/home/xwhan/data/hotpot/dense_hotpot_val_b250_k250_roberta_best_sents.json \
--max_seq_len 512 \
--max_q_len 64 \
--init_checkpoint qa/logs/08-17-2020/wwm_val_top50-epoch7-lr5e-05-seed42-rdrop0-qadrop0-decay0-qpergpu2-aggstep8-clip2-evalper250-evalbsize1024-negnum5-warmup0.2-adamTrue-spweight0.025-snFalse/checkpoint_best.pt \
--sp-pred \
--max_ans_len 30 \
--save-prediction hotpot_val_wwm_top250.json
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_qa.py \
--do_predict \
--predict_batch_size 1024 \
--model_name google/electra-large-discriminator \
--fp16 \
--predict_file /private/home/xwhan/data/hotpot/dense_val_b50_k50_roberta_best_sents.json \
--max_seq_len 512 \
--max_q_len 64 \
--init_checkpoint qa/logs/08-10-2020/electra_val_top30-epoch7-lr5e-05-seed42-rdrop0-qadrop0-decay0-qpergpu2-aggstep8-clip2-evalper250-evalbsize1024-negnum5-warmup0.1-adamTrue-spweight0.025/checkpoint_best.pt \
--sp-pred \
--max_ans_len 30 \
    --save-prediction hotpot_val_b5_k5.json

srun --gres=gpu:8 --partition learnfair --time=48:00:00 --mem 500G --constraint volta32gb --cpus-per-task 80 --pty /bin/bash -l
================================================
FILE: mdr/qa/train_ranker.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import collections
import json
import logging
import os
import random
from datetime import date
from functools import partial
import copy
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from tqdm import tqdm
from transformers import (AdamW, AutoConfig, AutoTokenizer,
get_linear_schedule_with_warmup)
from config import train_args
from reranking_datasets import RankingDataset, rank_collate
from reranking_model import RankModel
from utils import AverageMeter, convert_to_half, move_to_cuda
def load_saved(model, path):
state_dict = torch.load(path)
def filter(x): return x[7:] if x.startswith('module.') else x
state_dict = {filter(k): v for (k, v) in state_dict.items()}
model.load_state_dict(state_dict)
return model
def main():
args = train_args()
if args.fp16:
import apex
apex.amp.register_half_function(torch, 'einsum')
date_curr = date.today().strftime("%m-%d-%Y")
model_name = f"{args.prefix}-seed{args.seed}-bsz{args.train_batch_size}-fp16{args.fp16}-lr{args.learning_rate}-decay{args.weight_decay}"
args.output_dir = os.path.join(args.output_dir, date_curr, model_name)
tb_logger = SummaryWriter(os.path.join(args.output_dir.replace("logs","tflogs")))
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
print(
f"output directory {args.output_dir} already exists and is not empty.")
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir, exist_ok=True)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO,
handlers=[logging.FileHandler(os.path.join(args.output_dir, "log.txt")),
logging.StreamHandler()])
logger = logging.getLogger(__name__)
logger.info(args)
if args.local_rank == -1 or args.no_cuda:
device = torch.device(
"cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
else:
device = torch.device("cuda", args.local_rank)
n_gpu = 1
        # Initializes the distributed backend, which takes care of synchronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
logger.info("device %s n_gpu %d distributed training %r",
device, n_gpu, bool(args.local_rank != -1))
if args.accumulate_gradients < 1:
raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
args.accumulate_gradients))
args.train_batch_size = int(
args.train_batch_size / args.accumulate_gradients)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
bert_config = AutoConfig.from_pretrained(args.model_name)
model = RankModel(bert_config, args)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
collate_fc = partial(rank_collate, pad_id=tokenizer.pad_token_id)
if args.do_train and args.max_seq_len > bert_config.max_position_embeddings:
raise ValueError(
"Cannot use sequence length %d because the BERT model "
"was only trained up to sequence length %d" %
(args.max_seq_len, bert_config.max_position_embeddings))
eval_dataset = RankingDataset(
tokenizer, args.predict_file, args.max_seq_len, args.max_q_len)
eval_dataloader = DataLoader(
eval_dataset, batch_size=args.predict_batch_size, collate_fn=collate_fc, pin_memory=True, num_workers=args.num_workers)
logger.info(f"Num of dev batches: {len(eval_dataloader)}")
if args.init_checkpoint != "":
model = load_saved(model, args.init_checkpoint)
model.to(device)
print(f"number of trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
if args.do_train:
no_decay = ['bias', 'LayerNorm.weight']
optimizer_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(
nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
{'params': [p for n, p in model.named_parameters() if any(
nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_parameters,
lr=args.learning_rate, eps=args.adam_epsilon)
if args.fp16:
from apex import amp
model, optimizer = amp.initialize(
model, optimizer, opt_level=args.fp16_opt_level)
else:
if args.fp16:
from apex import amp
model = amp.initialize(model, opt_level=args.fp16_opt_level)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
if args.do_train:
global_step = 0 # gradient update step
batch_step = 0 # forward batch count
best_acc = 0
train_loss_meter = AverageMeter()
model.train()
train_dataset = RankingDataset(tokenizer, args.train_file, args.max_seq_len, args.max_q_len, train=True)
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, pin_memory=True, collate_fn=collate_fc, num_workers=args.num_workers, shuffle=True)
logger.info('Start training....')
for epoch in range(int(args.num_train_epochs)):
for batch in tqdm(train_dataloader):
batch_step += 1
batch_inputs = move_to_cuda(batch["net_inputs"])
loss = model(batch_inputs)
if n_gpu > 1:
loss = loss.mean()
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
train_loss_meter.update(loss.item())
tb_logger.add_scalar('batch_train_loss',
loss.item(), global_step)
tb_logger.add_scalar('smoothed_train_loss',
train_loss_meter.avg, global_step)
if (batch_step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(
amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(
model.parameters(), args.max_grad_norm)
                    optimizer.step()  # we have accumulated enough gradients
model.zero_grad()
global_step += 1
if args.eval_period != -1 and global_step % args.eval_period == 0:
acc = predict(args, model, eval_dataloader,
device, logger)
logger.info("Step %d Train loss %.2f acc %.2f on epoch=%d" % (global_step, train_loss_meter.avg, acc*100, epoch))
# save most recent model
torch.save(model.state_dict(), os.path.join(
args.output_dir, f"checkpoint_last.pt"))
if best_acc < acc:
logger.info("Saving model with best acc %.2f -> acc %.2f on epoch=%d" %
(best_acc*100, acc*100, epoch))
torch.save(model.state_dict(), os.path.join(
args.output_dir, f"checkpoint_best.pt"))
model = model.to(device)
best_acc = acc
acc = predict(args, model, eval_dataloader, device, logger)
logger.info("Step %d Train loss %.2f acc %.2f on epoch=%d" % (
global_step, train_loss_meter.avg, acc*100, epoch))
tb_logger.add_scalar('dev_acc', acc*100, epoch)
torch.save(model.state_dict(), os.path.join(args.output_dir, f"checkpoint_last.pt"))
if best_acc < acc:
logger.info("Saving model with best acc %.2f -> acc %.2f on epoch=%d" % (best_acc*100, acc*100, epoch))
torch.save(model.state_dict(), os.path.join(
args.output_dir, f"checkpoint_best.pt"))
best_acc = acc
logger.info("Training finished!")
elif args.do_predict:
acc = predict(args, model, eval_dataloader, device, logger)
logger.info(f"test performance {acc}")
def predict(args, model, eval_dataloader, device, logger):
model.eval()
id2result = collections.defaultdict(list)
for batch in tqdm(eval_dataloader):
batch_to_feed = move_to_cuda(batch["net_inputs"])
batch_qids = batch["qids"]
batch_labels = batch["net_inputs"]["label"].view(-1).tolist()
with torch.no_grad():
scores = model(batch_to_feed)
scores = scores.view(-1).tolist()
for qid, label, score in zip(batch_qids, batch_labels, scores):
id2result[qid].append((label, score))
acc = []
top_pred = {}
for qid, res in id2result.items():
res.sort(key=lambda x: x[1], reverse=True)
acc.append(res[0][0] == 1)
logger.info(f"evaluated {len(id2result)} questions...")
logger.info(f'acc: {np.mean(acc)}')
model.train()
return np.mean(acc)
if __name__ == "__main__":
main()
================================================
FILE: mdr/qa/utils.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import sqlite3
import unicodedata
import collections
import logging
import re
def set_global_logging_level(level=logging.ERROR, prefices=[""]):
"""
Override logging levels of different modules based on their name as a prefix.
It needs to be invoked after the modules have been loaded so that their loggers have been initialized.
Args:
- level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
- prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional.
Default is `[""]` to match all active loggers.
The match is a case-sensitive `module_name.startswith(prefix)`
"""
prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })')
for name in logging.root.manager.loggerDict:
if re.match(prefix_re, name):
logging.getLogger(name).setLevel(level)
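
# A standalone sketch of the prefix matching used by set_global_logging_level
# above: register a logger, then lower every logger whose name starts with a
# given prefix. The "demo_noisy" logger name is illustrative only.
def _demo_prefix_set_level():
    import logging
    import re
    logging.getLogger("demo_noisy.child")  # registers the logger
    pattern = re.compile(r'^(?:demo_noisy)')
    for name in list(logging.root.manager.loggerDict):
        if pattern.match(name):
            logging.getLogger(name).setLevel(logging.ERROR)
    return logging.getLogger("demo_noisy.child").level  # logging.ERROR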
def load_saved(model, path, exact=True):
    try:
        state_dict = torch.load(path)
    except Exception:
        # checkpoint saved on GPU but loaded on a CPU-only machine
        state_dict = torch.load(path, map_location=torch.device('cpu'))
def filter(x): return x[7:] if x.startswith('module.') else x
if exact:
state_dict = {filter(k): v for (k, v) in state_dict.items()}
else:
state_dict = {filter(k): v for (
k, v) in state_dict.items() if filter(k) in model.state_dict()}
model.load_state_dict(state_dict)
return model
def move_to_cuda(sample):
if len(sample) == 0:
return {}
def _move_to_cuda(maybe_tensor):
if torch.is_tensor(maybe_tensor):
return maybe_tensor.cuda()
elif isinstance(maybe_tensor, dict):
return {
key: _move_to_cuda(value)
for key, value in maybe_tensor.items()
}
elif isinstance(maybe_tensor, list):
return [_move_to_cuda(x) for x in maybe_tensor]
else:
return maybe_tensor
return _move_to_cuda(sample)
def convert_to_half(sample):
if len(sample) == 0:
return {}
def _convert_to_half(maybe_floatTensor):
if torch.is_tensor(maybe_floatTensor) and maybe_floatTensor.type() == "torch.FloatTensor":
return maybe_floatTensor.half()
elif isinstance(maybe_floatTensor, dict):
return {
key: _convert_to_half(value)
for key, value in maybe_floatTensor.items()
}
elif isinstance(maybe_floatTensor, list):
return [_convert_to_half(x) for x in maybe_floatTensor]
else:
return maybe_floatTensor
return _convert_to_half(sample)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
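
# Sketch of the weighted running average AverageMeter maintains:
# avg = sum(val_i * n_i) / sum(n_i). The update values below are illustrative.
def _running_avg(updates):
    total, count = 0.0, 0
    for val, n in updates:
        total += val * n
        count += n
    return total / count

# e.g. _running_avg([(2.0, 3), (4.0, 1)]) == 2.5, matching
# m = AverageMeter(); m.update(2.0, 3); m.update(4.0, 1); m.avg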
def normalize(text):
"""Resolve different type of unicode encodings."""
return unicodedata.normalize('NFD', text)
def para_has_answer(answer, para, tokenizer):
text = normalize(para)
tokens = tokenizer.tokenize(text)
text = tokens.words(uncased=True)
assert len(text) == len(tokens)
for single_answer in answer:
single_answer = normalize(single_answer)
single_answer = tokenizer.tokenize(single_answer)
single_answer = single_answer.words(uncased=True)
for i in range(0, len(text) - len(single_answer) + 1):
if single_answer == text[i: i + len(single_answer)]:
return True
return False
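
# Sketch of the sliding-window containment check used by para_has_answer,
# with plain whitespace tokens standing in for the project tokenizer
# (illustrative only).
def _contains_span(para_tokens, answer_tokens):
    n = len(answer_tokens)
    for i in range(0, len(para_tokens) - n + 1):
        if answer_tokens == para_tokens[i: i + n]:
            return True
    return False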
def match_answer_span(p, answer, tokenizer, match="string"):
# p has been normalized
if match == 'string':
tokens = tokenizer.tokenize(p)
text = tokens.words(uncased=True)
matched = set()
for single_answer in answer:
single_answer = normalize(single_answer)
single_answer = tokenizer.tokenize(single_answer)
single_answer = single_answer.words(uncased=True)
for i in range(0, len(text) - len(single_answer) + 1):
if single_answer == text[i: i + len(single_answer)]:
matched.add(tokens.slice(
i, i + len(single_answer)).untokenize())
return list(matched)
elif match == 'regex':
# Answer is a regex
single_answer = normalize(answer[0])
return regex_match(p, single_answer)

def regex_match(text, pattern):
    """Test whether a regex pattern matches inside text. Used by
    match_answer_span's "regex" mode, which otherwise has no definition of
    regex_match in this module."""
    try:
        pattern = re.compile(pattern,
                             flags=re.IGNORECASE + re.UNICODE + re.MULTILINE)
    except re.error:
        return False
    return pattern.search(text) is not None
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
orig_answer_text):
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
for new_start in range(input_start, input_end + 1):
for new_end in range(input_end, new_start - 1, -1):
text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
if text_span == tok_answer_text:
return (new_start, new_end)
return (input_start, input_end)
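
# Sketch of the span-tightening search performed by _improve_answer_span:
# shrink a coarse [start, end] token window to the smallest inner span whose
# joined text equals the target answer (illustrative, whitespace joining
# instead of the real tokenizer).
def _tighten_span(doc_tokens, start, end, answer_tokens):
    target = " ".join(answer_tokens)
    for new_start in range(start, end + 1):
        for new_end in range(end, new_start - 1, -1):
            if " ".join(doc_tokens[new_start:new_end + 1]) == target:
                return (new_start, new_end)
    return (start, end)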
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a peice of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
def find_ans_span_with_char_offsets(detected_ans, char_to_word_offset, doc_tokens, all_doc_tokens, orig_to_tok_index, tokenizer):
    # could return multiple spans for an answer string
ans_text = detected_ans["text"]
char_spans = detected_ans["char_spans"]
ans_subtok_spans = []
for char_start, char_end in char_spans:
tok_start = char_to_word_offset[char_start]
# char_end points to the last char of the answer, not one after
tok_end = char_to_word_offset[char_end]
sub_tok_start = orig_to_tok_index[tok_start]
if tok_end < len(doc_tokens) - 1:
sub_tok_end = orig_to_tok_index[tok_end + 1] - 1
else:
sub_tok_end = len(all_doc_tokens) - 1
actual_text = " ".join(doc_tokens[tok_start:(tok_end + 1)])
cleaned_answer_text = " ".join(whitespace_tokenize(ans_text))
if actual_text.find(cleaned_answer_text) == -1:
print("Could not find answer: '{}' vs. '{}'".format(
actual_text, cleaned_answer_text))
(sub_tok_start, sub_tok_end) = _improve_answer_span(
all_doc_tokens, sub_tok_start, sub_tok_end, tokenizer, ans_text)
ans_subtok_spans.append((sub_tok_start, sub_tok_end))
return ans_subtok_spans
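
# Sketch of the offset bookkeeping find_ans_span_with_char_offsets relies on:
# char_to_word_offset maps every character position to the index of the
# whitespace-delimited word containing it (illustrative; here spaces map to
# the preceding word's index).
def _char_to_word_offset(text):
    offsets, word_idx, prev_ws = [], -1, True
    for ch in text:
        if ch == " ":
            prev_ws = True
        else:
            if prev_ws:
                word_idx += 1
            prev_ws = False
        offsets.append(word_idx)
    return offsets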
import six
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
def get_final_text(pred_text, orig_text, do_lower_case=False, verbose_logging=True):
"""Project the tokenized prediction back to the original text."""
def _strip_spaces(text):
ns_chars = []
ns_to_s_map = collections.OrderedDict()
for (i, c) in enumerate(text):
if c == " ":
continue
ns_to_s_map[len(ns_chars)] = i
ns_chars.append(c)
ns_text = "".join(ns_chars)
return (ns_text, ns_to_s_map)
# We first tokenize `orig_text`, strip whitespace from the result
# and `pred_text`, and check if they are the same length. If they are
# NOT the same length, the heuristic has failed. If they are the same
# length, we assume the characters are one-to-one aligned.
tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
tok_text = " ".join(tokenizer.tokenize(orig_text))
start_position = tok_text.find(pred_text)
if start_position == -1:
if verbose_logging:
print(
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
return orig_text
end_position = start_position + len(pred_text) - 1
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
if len(orig_ns_text) != len(tok_ns_text):
if verbose_logging:
print("Length not equal after stripping spaces: '%s' vs '%s'",
orig_ns_text, tok_ns_text)
return orig_text
# We then project the characters in `pred_text` back to `orig_text` using
# the character-to-character alignment.
tok_s_to_ns_map = {}
for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
tok_s_to_ns_map[tok_index] = i
orig_start_position = None
if start_position in tok_s_to_ns_map:
ns_start_position = tok_s_to_ns_map[start_position]
if ns_start_position in orig_ns_to_s_map:
orig_start_position = orig_ns_to_s_map[ns_start_position]
if orig_start_position is None:
if verbose_logging:
print("Couldn't map start position")
return orig_text
orig_end_position = None
if end_position in tok_s_to_ns_map:
ns_end_position = tok_s_to_ns_map[end_position]
if ns_end_position in orig_ns_to_s_map:
orig_end_position = orig_ns_to_s_map[ns_end_position]
if orig_end_position is None:
if verbose_logging:
print("Couldn't map end position")
return orig_text
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
return output_text
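
# Sketch of the character alignment inside get_final_text: _strip_spaces
# removes spaces while recording where each kept character came from, so a
# span found in the space-free string can be projected back to the original
# (illustrative).
def _strip_spaces_map(text):
    chars, ns_to_s = [], {}
    for i, c in enumerate(text):
        if c == " ":
            continue
        ns_to_s[len(chars)] = i
        chars.append(c)
    return "".join(chars), ns_to_s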
================================================
FILE: mdr/retrieval/__init__.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#!/usr/bin/env python
from . import data
from . import models
from . import utils
================================================
FILE: mdr/retrieval/config.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from typing import NamedTuple
class ClusterConfig(NamedTuple):
dist_backend: str
dist_url: str
def common_args():
parser = argparse.ArgumentParser()
# task
parser.add_argument("--train_file", type=str,
default="../data/nq-with-neg-train.txt")
parser.add_argument("--predict_file", type=str,
default="../data/nq-with-neg-dev.txt")
parser.add_argument("--num_workers", default=30, type=int)
parser.add_argument("--do_train", default=False,
action='store_true', help="Whether to run training.")
parser.add_argument("--do_predict", default=False,
action='store_true', help="Whether to run eval on the dev set.")
# model
parser.add_argument("--model_name",
default="bert-base-uncased", type=str)
parser.add_argument("--init_checkpoint", type=str,
help="Initial checkpoint (usually from a pre-trained BERT model).",
default="")
parser.add_argument("--max_c_len", default=512, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--max_q_len", default=50, type=int,
help="The maximum number of tokens for the question. Questions longer than this will "
"be truncated to this length.")
parser.add_argument('--fp16', action='store_true')
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument("--no_cuda", default=False, action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument("--local_rank", type=int, default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument("--max_q_sp_len", default=50, type=int)
parser.add_argument("--sent-level", action="store_true")
parser.add_argument("--rnn-retriever", action="store_true")
parser.add_argument("--predict_batch_size", default=512,
type=int, help="Total batch size for predictions.")
parser.add_argument("--shared-encoder", action="store_true")
# multi vector scheme
parser.add_argument("--multi-vector", type=int, default=1)
parser.add_argument("--scheme", type=str, help="how to get the multivector, layerwise or tokenwise", default="none")
# momentum
parser.add_argument("--momentum", action="store_true")
parser.add_argument("--init-retriever", type=str, default="")
parser.add_argument("--k", type=int, default=38400, help="memory bank size")
parser.add_argument("--m", type=float, default=0.999, help="momentum")
# NQ multihop trial
parser.add_argument("--nq-multi", action="store_true", help="train the NQ retrieval model to recover from error cases")
return parser
def train_args():
parser = common_args()
# optimization
parser.add_argument('--prefix', type=str, default="eval")
parser.add_argument("--weight_decay", default=0.0, type=float,
help="Weight decay if we apply some.")
parser.add_argument("--temperature", default=1, type=float)
parser.add_argument("--output_dir", default="./logs", type=str,
help="The output directory where the model checkpoints will be written.")
parser.add_argument("--train_batch_size", default=128,
type=int, help="Total batch size for training.")
parser.add_argument("--learning_rate", default=1e-5,
type=float, help="The initial learning rate for Adam.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument("--num_train_epochs", default=50, type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--save_checkpoints_steps", default=20000, type=int,
help="How often to save the model checkpoint.")
parser.add_argument("--iterations_per_loop", default=1000, type=int,
help="How many steps to make in each estimator call.")
parser.add_argument("--accumulate_gradients", type=int, default=1,
help="Number of steps to accumulate gradient on (divide the batch_size and accumulate)")
parser.add_argument('--seed', type=int, default=3,
help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
parser.add_argument('--eval-period', type=int, default=2500)
parser.add_argument("--max_grad_norm", default=2.0, type=float, help="Max gradient norm.")
parser.add_argument("--stop-drop", default=0, type=float)
parser.add_argument("--use-adam", action="store_true")
parser.add_argument("--warmup-ratio", default=0, type=float, help="Linear warmup over warmup_steps.")
return parser.parse_args()
def encode_args():
parser = common_args()
parser.add_argument('--embed_save_path', type=str, default="")
parser.add_argument('--is_query_embed', action="store_true")
args = parser.parse_args()
return args
================================================
FILE: mdr/retrieval/criterions.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
# def loss_single(model, batch, momentum=False):
# outputs = model(batch)
# q = outputs['q']
# c = outputs['c']
# neg_c = outputs['neg_c']
# product_in_batch = torch.mm(q, c.t())
# product_neg = (q * neg_c).sum(-1).unsqueeze(1)
# product = torch.cat([product_in_batch, product_neg], dim=-1)
# if momentum:
# queue_c = model.module.encode_queue_ctx()
# product_queue = torch.mm(q, queue_c.t())
# product = torch.cat([product, product_queue], dim=-1)
# model.module.dequeue_and_enqueue(batch)
# target = torch.arange(product.size(0)).to(product.device)
# loss = F.cross_entropy(product, target)
# return loss
# """
# multi-hop retrieval for NQ, train the model to recover from
# """
# def loss_nq_mhop(model, batch, momentum=False):
# outputs = model(batch)
# product_in_batch = torch.mm(outputs['q'], outputs['c'].t())
# product_neg = (outputs['q'] * outputs['neg']).sum(-1).unsqueeze(1)
# # product_neg1 = (outputs['q'] * outputs['dense_neg1']).sum(-1).unsqueeze(1)
# # product_neg2 = (outputs['q'] * outputs['dense_neg2']).sum(-1).unsqueeze(1)
# scores1 = torch.cat([product_in_batch, product_neg], dim=-1)
# product_in_batch_from_error = torch.mm(outputs["q_neg1"], outputs['c'].t())
# dense_neg = torch.cat([outputs["dense_neg1"].unsqueeze(1), outputs["dense_neg2"].unsqueeze(1)], dim=1)
# product_neg_from_error = torch.bmm(outputs["q_neg1"].unsqueeze(1), dense_neg.transpose(1,2)).squeeze(1)
# scores2 = torch.cat([product_in_batch_from_error, product_neg_from_error], dim=-1)
# if momentum:
# queue_neg_scores_1 = torch.mm(outputs['q'], model.module.queue.clone().detach().t())
# queue_neg_scores_2 = torch.mm(outputs["q_neg1"], model.module.queue.clone().detach().t())
# scores1 = torch.cat([scores1, queue_neg_scores_1], dim=1)
# scores2 = torch.cat([scores2, queue_neg_scores_2], dim=1)
# model.module.dequeue_and_enqueue(outputs["c"].detach())
# # model.module.momentum_update_key_encoder()
# target = torch.arange(scores1.size(0)).to(scores1.device)
# loss = F.cross_entropy(scores1, target) + F.cross_entropy(scores2, target)
# # loss = F.cross_entropy(scores1, target)
# return loss
# def eval_nq_mhop(model, batch):
# outputs = model(batch)
# product_in_batch = torch.mm(outputs['q'], outputs['c'].t())
# product_neg = (outputs['q'] * outputs['neg']).sum(-1).unsqueeze(1)
# # product_neg1 = (outputs['q'] * outputs['dense_neg1']).sum(-1).unsqueeze(1)
# # product_neg2 = (outputs['q'] * outputs['dense_neg2']).sum(-1).unsqueeze(1)
# scores1 = torch.cat([product_in_batch, product_neg], dim=-1)
# product_in_batch_from_error = torch.mm(outputs["q_neg1"], outputs['c'].t())
# dense_neg = torch.cat([outputs["dense_neg1"].unsqueeze(1), outputs["dense_neg2"].unsqueeze(1)], dim=1)
# product_neg_from_error = torch.bmm(outputs["q_neg1"].unsqueeze(1), dense_neg.transpose(1,2)).squeeze(1)
# scores2 = torch.cat([product_in_batch_from_error, product_neg_from_error], dim=-1)
# target = torch.arange(scores1.size(0)).to(scores1.device)
# rrs, rrs_2hop = [], []
# ranked = scores1.argsort(dim=1, descending=True)
# ranked_2hop = scores2.argsort(dim=1, descending=True)
# idx2rank = ranked.argsort(dim=1)
# for idx, t in enumerate(target.tolist()):
# rrs.append(1 / (idx2rank[idx][t].item() +1))
# idx2rank2hop = ranked_2hop.argsort(dim=1)
# for idx, t in enumerate(target.tolist()):
# rrs_2hop.append(1 / (idx2rank2hop[idx][t].item() +1))
# return rrs, rrs_2hop
# def eval_vanilla(outputs):
# """
# view the two sp passages as the same, no multi-hop modeling;
# select the passages from all passages in the batch
# """
# rrs = []
# q = outputs['q']
# c1 = outputs['c1']
# c2 = outputs['c2']
# c = torch.cat([c1.unsqueeze(1), c2.unsqueeze(1)], dim=1) # B x 2 x D
# c = c.view(-1, q.size(-1)) # 2B x D
# product_in_batch = torch.mm(q, c.t()) # Bx2B
# neg_c = outputs['neg_c']
# product_neg = (q * neg_c).sum(-1).unsqueeze(1)
# product = torch.cat([product_in_batch, product_neg], dim=-1)
# target = torch.arange(product.size(0)).to(product.device).unsqueeze(1)
# target = torch.cat([target*2, target*2+1], dim=1)
# ranked = product.argsort(dim=1, descending=True)
# # MRR
# idx2rank = ranked.argsort(dim=1)
# for idx, t in enumerate(target):
# correct_idx = t.tolist()
# for _ in correct_idx:
# rrs.append(1 / (idx2rank[idx][_].item() + 1))
# return rrs
def mhop_loss(model, batch, args):
outputs = model(batch)
loss_fct = CrossEntropyLoss(ignore_index=-1)
all_ctx = torch.cat([outputs['c1'], outputs['c2']], dim=0)
neg_ctx = torch.cat([outputs["neg_1"].unsqueeze(1), outputs["neg_2"].unsqueeze(1)], dim=1) # B x 2 x M x h
scores_1_hop = torch.mm(outputs["q"], all_ctx.t())
neg_scores_1 = torch.bmm(outputs["q"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1)
scores_2_hop = torch.mm(outputs["q_sp1"], all_ctx.t())
neg_scores_2 = torch.bmm(outputs["q_sp1"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1)
# mask the 1st hop
bsize = outputs["q"].size(0)
scores_1_mask = torch.cat([torch.zeros(bsize, bsize), torch.eye(bsize)], dim=1).to(outputs["q"].device)
scores_1_hop = scores_1_hop.float().masked_fill(scores_1_mask.bool(), float('-inf')).type_as(scores_1_hop)
scores_1_hop = torch.cat([scores_1_hop, neg_scores_1], dim=1)
scores_2_hop = torch.cat([scores_2_hop, neg_scores_2], dim=1)
if args.momentum:
queue_neg_scores_1 = torch.mm(outputs["q"], model.module.queue.clone().detach().t())
queue_neg_scores_2 = torch.mm(outputs["q_sp1"], model.module.queue.clone().detach().t())
# queue_neg_scores_1 = queue_neg_scores_1 / args.temperature
# queue_neg_scores_2 = queue_neg_scores_2 / args.temperature
scores_1_hop = torch.cat([scores_1_hop, queue_neg_scores_1], dim=1)
scores_2_hop = torch.cat([scores_2_hop, queue_neg_scores_2], dim=1)
model.module.dequeue_and_enqueue(all_ctx.detach())
# model.module.momentum_update_key_encoder()
target_1_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device)
target_2_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device) + outputs["q"].size(0)
retrieve_loss = loss_fct(scores_1_hop, target_1_hop) + loss_fct(scores_2_hop, target_2_hop)
return retrieve_loss
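
# The first-hop mask built above, written out with plain Python lists: for
# batch size B the candidate list is [c1_0..c1_{B-1}, c2_0..c2_{B-1}], and
# question i must not score its own second-hop passage c2_i at hop 1, so
# column B + i is masked in row i (illustrative).
def _first_hop_mask(bsize):
    return [[1 if j == bsize + i else 0 for j in range(2 * bsize)]
            for i in range(bsize)]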
def mhop_eval(outputs, args):
all_ctx = torch.cat([outputs['c1'], outputs['c2']], dim=0)
neg_ctx = torch.cat([outputs["neg_1"].unsqueeze(1), outputs["neg_2"].unsqueeze(1)], dim=1)
scores_1_hop = torch.mm(outputs["q"], all_ctx.t())
neg_scores_1 = torch.bmm(outputs["q"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1)
scores_2_hop = torch.mm(outputs["q_sp1"], all_ctx.t())
neg_scores_2 = torch.bmm(outputs["q_sp1"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1)
bsize = outputs["q"].size(0)
scores_1_mask = torch.cat([torch.zeros(bsize, bsize), torch.eye(bsize)], dim=1).to(outputs["q"].device)
scores_1_hop = scores_1_hop.float().masked_fill(scores_1_mask.bool(), float('-inf')).type_as(scores_1_hop)
scores_1_hop = torch.cat([scores_1_hop, neg_scores_1], dim=1)
scores_2_hop = torch.cat([scores_2_hop, neg_scores_2], dim=1)
target_1_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device)
target_2_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device) + outputs["q"].size(0)
ranked_1_hop = scores_1_hop.argsort(dim=1, descending=True)
ranked_2_hop = scores_2_hop.argsort(dim=1, descending=True)
idx2ranked_1 = ranked_1_hop.argsort(dim=1)
idx2ranked_2 = ranked_2_hop.argsort(dim=1)
rrs_1, rrs_2 = [], []
for t, idx2ranked in zip(target_1_hop, idx2ranked_1):
rrs_1.append(1 / (idx2ranked[t].item() + 1))
for t, idx2ranked in zip(target_2_hop, idx2ranked_2):
rrs_2.append(1 / (idx2ranked[t].item() + 1))
return {"rrs_1": rrs_1, "rrs_2": rrs_2}
def unified_loss(model, batch, args):
outputs = model(batch)
all_ctx = torch.cat([outputs['c1'], outputs['c2']], dim=0)
neg_ctx = torch.cat([outputs["neg_1"].unsqueeze(1), outputs["neg_2"].unsqueeze(1)], dim=1)
scores_1_hop = torch.mm(outputs["q"], all_ctx.t())
neg_scores_1 = torch.bmm(outputs["q"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1)
scores_2_hop = torch.mm(outputs["q_sp1"], all_ctx.t())
neg_scores_2 = torch.bmm(outputs["q_sp1"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1)
# mask for 1st hop
bsize = outputs["q"].size(0)
scores_1_mask = torch.cat([torch.zeros(bsize, bsize), torch.eye(bsize)], dim=1).to(outputs["q"].device)
scores_1_hop = scores_1_hop.float().masked_fill(scores_1_mask.bool(), float('-inf')).type_as(scores_1_hop)
scores_1_hop = torch.cat([scores_1_hop, neg_scores_1], dim=1)
scores_2_hop = torch.cat([scores_2_hop, neg_scores_2], dim=1)
stop_loss = F.cross_entropy(outputs["stop_logits"], batch["stop_targets"].view(-1), reduction="sum")
target_1_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device)
target_2_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device) + outputs["q"].size(0)
retrieve_loss = F.cross_entropy(scores_1_hop, target_1_hop, reduction="sum") + (F.cross_entropy(scores_2_hop, target_2_hop, reduction="none") * batch["stop_targets"].view(-1)).sum()
return retrieve_loss + stop_loss
def unified_eval(outputs, batch):
all_ctx = torch.cat([outputs['c1'], outputs['c2']], dim=0)
neg_ctx = torch.cat([outputs["neg_1"].unsqueeze(1), outputs["neg_2"].unsqueeze(1)], dim=1)
scores_1_hop = torch.mm(outputs["q"], all_ctx.t())
scores_2_hop = torch.mm(outputs["q_sp1"], all_ctx.t())
neg_scores_1 = torch.bmm(outputs["q"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1)
neg_scores_2 = torch.bmm(outputs["q_sp1"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1)
bsize = outputs["q"].size(0)
scores_1_mask = torch.cat([torch.zeros(bsize, bsize), torch.eye(bsize)], dim=1).to(outputs["q"].device)
scores_1_hop = scores_1_hop.float().masked_fill(scores_1_mask.bool(), float('-inf')).type_as(scores_1_hop)
scores_1_hop = torch.cat([scores_1_hop, neg_scores_1], dim=1)
scores_2_hop = torch.cat([scores_2_hop, neg_scores_2], dim=1)
target_1_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device)
target_2_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device) + outputs["q"].size(0)
# stop accuracy
stop_pred = outputs["stop_logits"].argmax(dim=1)
stop_targets = batch["stop_targets"].view(-1)
stop_acc = (stop_pred == stop_targets).float().tolist()
ranked_1_hop = scores_1_hop.argsort(dim=1, descending=True)
ranked_2_hop = scores_2_hop.argsort(dim=1, descending=True)
idx2ranked_1 = ranked_1_hop.argsort(dim=1)
idx2ranked_2 = ranked_2_hop.argsort(dim=1)
rrs_1_mhop, rrs_2_mhop, rrs_nq = [], [], []
for t1, idx2ranked1, t2, idx2ranked2, stop in zip(target_1_hop, idx2ranked_1, target_2_hop, idx2ranked_2, stop_targets):
        if stop:  # stop target 1: multi-hop example, evaluate both hops
rrs_1_mhop.append(1 / (idx2ranked1[t1].item() + 1))
rrs_2_mhop.append(1 / (idx2ranked2[t2].item() + 1))
else:
rrs_nq.append(1 / (idx2ranked1[t1].item() + 1))
return {
"stop_acc": stop_acc,
"rrs_1_mhop": rrs_1_mhop,
"rrs_2_mhop": rrs_2_mhop,
"rrs_nq": rrs_nq
}
================================================
FILE: mdr/retrieval/decomposed_analysis.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
def decomposed_errors():
top1_pred = [json.loads(l) for l in open("/private/home/xwhan/data/hotpot/dense_val_b1_top1.json").readlines()]
analysis_folder = "/private/home/xwhan/data/hotpot/analysis"
start_errors, bridge_errors, failed = [], [], []
correct = []
for item in top1_pred:
pred_titles = [_[0] for _ in item["candidate_chains"][0]]
gold_titles = [_[0] for _ in item["sp"]]
if set(pred_titles) == set(gold_titles):
if item["type"] == "bridge":
correct.append(item)
continue
if item["type"] == "bridge":
start_title = None
for t in gold_titles:
if t != item["bridge"]:
start_title = t
assert start_title is not None
if item["bridge"] in pred_titles and start_title not in pred_titles:
start_errors.append(item)
elif item["bridge"] not in pred_titles and start_title in pred_titles:
bridge_errors.append(item)
else:
failed.append(item)
with open(analysis_folder + "/correct.json", "w") as g:
for _ in correct:
_["predicted"] = _.pop("candidate_chains")[0]
g.write(json.dumps(_) + "\n")
with open(analysis_folder + "/start_errors.json", "w") as g:
for _ in start_errors:
_["predicted"] = _.pop("candidate_chains")[0]
g.write(json.dumps(_) + "\n")
with open(analysis_folder + "/bridge_errors.json", "w") as g:
for _ in bridge_errors:
_["predicted"] = _.pop("candidate_chains")[0]
g.write(json.dumps(_) + "\n")
with open(analysis_folder + "/total_errors.json", "w") as g:
for _ in failed:
_["predicted"] = _.pop("candidate_chains")[0]
g.write(json.dumps(_) + "\n")
print(len(correct))
print(len(start_errors))
print(len(bridge_errors))
print(len(failed))
import random
def collect_gold_decomposition():
"""
interactively collect
"""
dev_qdmr = [json.loads(l) for l in open("/private/home/xwhan/data/QDMR/dev.json").readlines()]
bridge_dev = [_ for _ in dev_qdmr if _["type"] == "bridge"]
random.shuffle(bridge_dev)
idx = 0
samples_to_inspect = []
while True:
print(f"\n-----{len(samples_to_inspect)} samples collected so far-----")
sample = bridge_dev[idx]
idx += 1
print(f"Original Q: {sample['q']}")
print(f"Decomposed Q: {sample['q_decom']}")
print(f"Supporting Passages: {sample['sp']}")
subq1 = input("Type SUB Q1:")
if subq1 == "bad":
continue
elif subq1 == "stop":
break
subq2 = input("Type SUB Q2:")
samples_to_inspect.append({
"id": sample["id"],
"sp": sample["sp"],
"orig_q": sample['q'],
"subQ_1": subq1,
"subQ_2": subq2
})
print(f"{len(samples_to_inspect)} samples collected in total..")
with open("/private/home/xwhan/data/QDMR/inspect.json", "w") as g:
for _ in samples_to_inspect:
g.write(json.dumps(_) + "\n")
def qdmr_utils():
"""
change file format for decomposed and end-to-end retrieval
"""
qdmr_data = [json.loads(l) for l in open("/private/home/xwhan/data/QDMR/inspect.json").readlines()]
mhop_data, decomposed_data = [], []
for idx, item in enumerate(qdmr_data):
if idx in [65,66,67]:
continue
sp = [_["title"] for _ in item["sp"]]
question = item["orig_q"]
mhop_data.append({
"question": question,
"sp": sp,
"type": "bridge",
"_id": item["id"]
})
decomposed_data.append(item)
# with open("/private/home/xwhan/data/QDMR/qdmr_decomposed.json", "w") as g:
# for item in decomposed_data:
# g.write(json.dumps(item) + "\n")
with open("/private/home/xwhan/data/QDMR/qdmr_e2e.json", "w") as g:
for item in mhop_data:
g.write(json.dumps(item) + "\n")
def analyze_results():
decomposed_results = [json.loads(l) for l in open("/private/home/xwhan/data/QDMR/qdmr_decomposed_results.json")]
e2e_results = [json.loads(l) for l in open("/private/home/xwhan/data/QDMR/qdmr_e2e_results.json")]
better = 0
worse = 0
both = 0
for res1, res2 in zip(decomposed_results, e2e_results):
sp_titles = set([_[0] for _ in res1["sp"]])
res1_top1 = [_[0] for _ in res1["candidate_chains"][0]]
res2_top1 = [_[0] for _ in res2["candidate_chains"][0]]
assert res1["_id"] == res2["_id"]
question = res1["question"]
q_pairs = res1["q_pairs"]
if set(res2_top1) == sp_titles and set(res1_top1) != sp_titles:
# print(sp_titles)
# import pdb; pdb.set_trace()
better += 1
elif set(res2_top1) != sp_titles and set(res1_top1) == sp_titles:
worse += 1
elif set(res2_top1) == sp_titles and set(res1_top1) == sp_titles:
both += 1
    print(f"both correct: {both}")
    print(f"end-to-end better: {better}")
    print(f"end-to-end worse: {worse}")
    print(f"total: {len(decomposed_results)}")
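

# The comparison above (and several cells in the FEVER notebook of this
# module) repeatedly applies the same chain-level criterion: a question or
# claim counts as recalled only if every title of at least one gold evidence
# chain appears among the predicted titles, while precision is measured at
# the document level. A minimal, self-contained sketch of that metric; the
# function name and argument layout are illustrative, not part of this
# repo's API.
def chain_recall_f1(pred_titles, gold_chains):
    """pred_titles: predicted page titles for one question/claim.
    gold_chains: gold evidence chains, each a list of titles."""
    pred = set(pred_titles)
    # recall is all-or-nothing per chain: 1 if some chain is fully covered
    recall = int(any(set(chain) <= pred for chain in gold_chains))
    # precision: fraction of predicted titles appearing in any gold chain
    gold_docs = {t for chain in gold_chains for t in chain}
    prec = sum(t in gold_docs for t in pred) / len(pred) if pred else 0.0
    f1 = 2 * prec * recall / (prec + recall) if prec + recall else 0.0
    return prec, recall, f1
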
if __name__ == "__main__":
# collect_gold_decomposition()
# qdmr_utils()
analyze_results()
================================================
FILE: mdr/retrieval/fever.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1059\n",
"12273\n"
]
}
],
"source": [
"import json\n",
"import numpy as np\n",
"import random\n",
"\n",
"fever_path = \"/private/home/xwhan/data/fever/retrieval/\"\n",
"\n",
"dev = [json.loads(l) for l in open(fever_path + \"dev.txt\").readlines()]\n",
"multi_dev = []\n",
"single_dev = []\n",
"all_evidence_lens = [] # for multi evidence\n",
"random.shuffle(dev)\n",
"all_claim_lens = []\n",
"for item in dev:\n",
" evidence_lens = []\n",
" all_claim_lens.append(len(item[\"claim\"].split()))\n",
" \n",
" for chain in item[\"evidence\"]:\n",
" if len(chain) > 1:\n",
"# evidence_lens.append(len(chain))\n",
" chain_titles = set([p[\"title\"] for p in chain])\n",
" evidence_lens.append(len(chain_titles)) \n",
"# print(item[\"claim\"])\n",
"# print(chain)\n",
"# assert False\n",
" else:\n",
" evidence_lens.append(1)\n",
" multi_count = np.sum([int(c > 1) for c in evidence_lens])\n",
" \n",
" if multi_count == len(evidence_lens):\n",
" multi_dev.append(item)\n",
" all_evidence_lens += evidence_lens\n",
" else:\n",
" single_dev.append(item)\n",
" \n",
"print(len(multi_dev))\n",
"print(len(single_dev))\n",
"with open(\"/private/home/xwhan/data/fever/retrieval/dev_multi_evidence_compact.txt\", \"w\") as g:\n",
" for l in multi_dev:\n",
" g.write(json.dumps(l) + \"\\n\")\n",
"# with open(\"/private/home/xwhan/data/fever/retrieval/dev_single_evidence.txt\", \"w\") as g:\n",
"# for l in single_dev:\n",
"# g.write(json.dumps(l) + \"\\n\")\n",
"# with open(\"/private/home/xwhan/data/fever/retrieval/dev_all.txt\", \"w\") as g:\n",
"# for l in single_dev + multi_dev:\n",
"# g.write(json.dumps(l) + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1741\n",
"2.0\n",
"0.5835726593911545 0.5835726593911545 0.5835726593911545\n"
]
}
],
"source": [
"# baseline retrieval for single/multihop subsets\n",
"\n",
"import unicodedata\n",
"def normalize(text):\n",
" \"\"\"Resolve different type of unicode encodings.\"\"\"\n",
" return unicodedata.normalize('NFD', text)\n",
"\n",
"el_results = [json.loads(l) for l in open(\"/private/home/xwhan/data/fever/retrieval/dev.ensembles.s10.jsonl\").readlines()]\n",
"id2el_docs = {_[\"id\"]:_[\"predicted_pages\"] for _ in el_results}\n",
"\n",
"dense_multi_results = [json.loads(l) for l in open(\"/private/home/xwhan/data/fever/retrieval/dense_fever_b1_20_k20.json\").readlines()] \n",
"\n",
"single_gold = {_[\"id\"]:_ for _ in single_dev}\n",
"multi_gold = {_[\"id\"]:_ for _ in multi_dev}\n",
"all_gold = {_[\"id\"]:_ for _ in multi_dev + single_dev}\n",
"\n",
"subset = multi_gold\n",
"precs, recalls = [], []\n",
"doc_count = []\n",
"dense_docs = []\n",
"for item in dense_multi_results:\n",
" if item[\"id\"] in subset:\n",
"# pred = set(item[\"predicted_pages\"])\n",
" retrieved_chains = item[\"candidate_chains\"] \n",
" pred = []\n",
" for chain in retrieved_chains:\n",
" for p in chain:\n",
" if normalize(p[0]) not in pred:\n",
" pred.append(normalize(p[0]))\n",
" pred = pred[:2]\n",
"# pred = [_[\"title\"] for _ in item[\"topk\"][:1]]\n",
" \n",
" pred = set(pred)\n",
"# el_pred = id2el_docs[item[\"id\"]]\n",
"# el_count = 0\n",
"# for title in el_pred:\n",
"# if title not in pred:\n",
"# pred.add(title)\n",
"# el_count +=1\n",
"# if el_count == 2:\n",
"# break\n",
" pred = list(pred)\n",
" \n",
" dense_docs.append({\n",
" \"claim\": item[\"claim\"],\n",
" \"id\": item[\"id\"],\n",
" \"predicted_pages\": list(pred)\n",
" })\n",
" \n",
" doc_count.append(len(pred))\n",
" \n",
" gold_docs = set()\n",
" recall = 0\n",
" for chain in subset[item[\"id\"]][\"evidence\"]:\n",
" chain_titles = set([normalize(p[\"title\"]) for p in chain])\n",
" for t in chain_titles: gold_docs.add(t)\n",
" chain_covered = [int(t in pred) for t in chain_titles]\n",
" if np.sum(chain_covered) == len(chain_titles):\n",
" recall = 1\n",
" break\n",
" \n",
" if len(gold_docs) > 0:\n",
" if len(pred) == 0:\n",
" prec = 0\n",
" else:\n",
" prec = np.mean([int(doc in gold_docs) for doc in pred])\n",
" \n",
" precs.append(prec)\n",
" recalls.append(recall)\n",
" \n",
"print(len(precs))\n",
"print(np.mean(doc_count))\n",
"pr, rec = np.mean(precs), np.mean(recalls)\n",
"print(pr, rec, 2.0 * pr * rec / (pr + rec))\n",
"\n",
"# with open(\"/private/home/xwhan/data/fever/retrieval/dense_wiki_pages_top2.jsonl\", \"w\") as g:\n",
"# for _ in dense_docs:\n",
"# g.write(json.dumps(_) + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1741\n",
"2.9959793222286044\n",
"0.6764680633362424 1.0 0.8070157471297632\n"
]
}
],
"source": [
"# inspect the FEVER results\n",
"dev_all_results = [json.loads(l) for l in open(\"/private/home/xwhan/code/Transformer-XH/data/fever_dev_graph.json\").readlines()]\n",
"subset = multi_gold\n",
"precs, recalls = [], []\n",
"doc_count = []\n",
"\n",
"for item in dev_all_results:\n",
" if item[\"qid\"] in subset:\n",
" \n",
" pred = [_[\"name\"] for _ in item[\"node\"]]\n",
" pred = set(pred)\n",
" pred = list(pred)\n",
" \n",
" doc_count.append(len(pred))\n",
" \n",
" gold_docs = set()\n",
" recall = 0\n",
" for chain in subset[item[\"qid\"]][\"evidence\"]:\n",
" chain_titles = set([normalize(p[\"title\"]) for p in chain])\n",
" for t in chain_titles: gold_docs.add(t)\n",
" chain_covered = [int(t in pred) for t in chain_titles]\n",
" if np.sum(chain_covered) == len(chain_titles):\n",
" recall = 1\n",
" break\n",
" \n",
" if len(gold_docs) > 0:\n",
" if len(pred) == 0:\n",
" prec = 0\n",
" else:\n",
" prec = np.mean([int(doc in gold_docs) for doc in pred])\n",
" \n",
" precs.append(prec)\n",
" recalls.append(recall)\n",
" \n",
"print(len(precs))\n",
"print(np.mean(doc_count))\n",
"pr, rec = np.mean(precs), np.mean(recalls)\n",
"print(pr, rec, 2.0 * pr * rec / (pr + rec))"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"# passage retrieval evaluation \n",
"def fever_retrieval_eval(results, topk=5):\n",
" \n",
" precs, recalls = [], []\n",
" for item in results:\n",
" gold = item[\"correct_normalized\"]\n",
" pred = item[\"bm25_topk\"][:topk]\n",
" \n",
" if len(gold) > 0:\n",
" prec = np.mean([int(doc in gold) for doc in pred])\n",
" else:\n",
" prec = 1\n",
" recall = 0\n",
" for chain in item[\"evidence\"]:\n",
" chain_titles = set([normalize(p[\"title\"]) for p in chain])\n",
" chain_covered = [int(t in pred) for t in chain_titles]\n",
" if np.sum(chain_covered) == len(chain_titles):\n",
" recall = 1\n",
" break\n",
" precs.append(prec)\n",
" recalls.append(recall)\n",
" pr, rec = np.mean(precs), np.mean(recalls)\n",
" return pr, rec, 2.0 * pr * rec / (pr + rec)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.12268811028144745, 0.5020103388856979, 0.19718537768662206)\n"
]
}
],
"source": [
"tfidf_results = [json.loads(l) for l in open(\"/private/home/xwhan/data/fever/retrieval/multi_dev_tfidf.txt\").readlines()]\n",
"print(fever_retrieval_eval(tfidf_results, 10))"
]
},
{
"cell_type": "code",
"execution_count": 234,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 244,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1741\n",
"1741\n",
"(0.6223243346735593, 0.46927053417576103, 0.5350675077296432)\n"
]
}
],
"source": [
"# phrase_matching_results = [json.loads(l) for l in open(\"/private/home/xwhan/data/fever/retrieval/all_dev.json\").readlines()]\n",
"# phrase_matching_results = [json.loads(l.decode('utf-8').strip('\\r\\n')) for l in open(\"/private/home/xwhan/data/fever/retrieval/dev.ensembles.s10.jsonl\").readlines()]\n",
"phrase_matching_results = [json.loads(l) for l in open(\"/private/home/xwhan/code/Transformer-XH/data/fever_dev_graph.json\").readlines()]\n",
"# for _ in phrase_matching_results:\n",
"# _[\"id\"] = _[\"qid\"]\n",
"\n",
"phrase_matching_results = [_ for _ in phrase_matching_results if _[\"id\"] in multihop_ids]\n",
"# phrase_matching_results = [json.loads(l) for l in open(\"/private/home/xwhan/data/fever/retrieval/esim_mhop_dev.json\").readlines()]\n",
"\n",
"# json.dump(phrase_matching_results, open(\"/private/home/xwhan/data/fever/retrieval/dev_el_wiki_pages.jsonl\", \"w\"))\n",
"\n",
"print(len(phrase_matching_results))\n",
"\n",
"tfidf_results = [json.loads(l) for l in open(\"/private/home/xwhan/data/fever/retrieval/multi_dev_tfidf.txt\").readlines()]\n",
"print(len(tfidf_results))\n",
"pred_lens = []\n",
"def fever_retrieval_eval_phrase(tfidf_results, phrase_results, topk=5):\n",
" id2gold = {_[\"id\"]:_[\"correct_normalized\"] for _ in tfidf_results}\n",
" id2gold_evidence = {_[\"id\"]:_[\"evidence\"] for _ in tfidf_results}\n",
" precs, recalls = [], []\n",
" for item in phrase_results:\n",
" gold = id2gold[item[\"id\"]]\n",
"# print(gold)\n",
" retrieved_evidence = item[\"evidence\"] \n",
" pred = []\n",
" for e in retrieved_evidence:\n",
" pred.append(normalize(e[0]))\n",
" \n",
"# pred = item[\"predicted_pages\"] + item[\"wiki_results\"]\n",
"# pred = item[\"wiki_results\"]\n",
"# pred = item[\"predicted_pages\"]\n",
" \n",
"# pred = []\n",
"# for n in item[\"node\"]:\n",
"# pred.append(n[\"name\"])\n",
" \n",
" pred = list(set(pred))\n",
" pred_lens.append(len(pred))\n",
" \n",
" if len(gold) > 0:\n",
" if len(pred) == 0:\n",
" prec = 0\n",
" else:\n",
" prec = np.mean([int(doc in gold) for doc in pred])\n",
" else:\n",
" prec = 1\n",
" recall = 0\n",
" for chain in id2gold_evidence[item[\"id\"]]:\n",
" chain_titles = set([normalize(p[\"title\"]) for p in chain])\n",
" chain_covered = [int(t in pred) for t in chain_titles]\n",
" if np.sum(chain_covered) == len(chain_titles):\n",
" recall = 1\n",
" break\n",
" precs.append(prec)\n",
" recalls.append(recall)\n",
" \n",
" pr, rec = np.mean(precs), np.mean(recalls)\n",
" return pr, rec, 2.0 * pr * rec / (pr + rec)\n",
"print(fever_retrieval_eval_phrase(tfidf_results, phrase_matching_results, topk=5))"
]
},
{
"cell_type": "code",
"execution_count": 181,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.1039632395175185"
]
},
"execution_count": 181,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.mean(pred_lens)"
]
},
{
"cell_type": "code",
"execution_count": 153,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.12426242624262426"
]
},
"execution_count": 153,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_with_all = [json.loads(l) for l in open(\"/private/home/xwhan/data/fever/retrieval/all_dev.json\").readlines()]\n",
"title_count = []\n",
"for item in test_with_all:\n",
" titles = set()\n",
" for e in item[\"evidence\"]:\n",
" titles.add(e[0])\n",
" title_count.append(len(titles))\n",
"np.sum(np.array(title_count) > 7) / len(title_count)"
]
},
{
"cell_type": "code",
"execution_count": 154,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.11321132113211321"
]
},
"execution_count": 154,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_with_all = [json.loads(l) for l in open(\"/private/home/xwhan/data/fever/retrieval/all_test.json\").readlines()]\n",
"title_count = []\n",
"for item in test_with_all:\n",
" titles = set()\n",
" for e in item[\"evidence\"]:\n",
" titles.add(e[0])\n",
" title_count.append(len(titles))\n",
"np.sum(np.array(title_count) > 7) / len(title_count)"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1960\n",
"1960\n"
]
}
],
"source": [
"# build dense final prediction for evaluation\n",
"final_pred = [json.loads(l) for l in open(\"/private/home/xwhan/code/KernelGAT/kgat/output/dense_bert_dev_mhop_top4.json\").readlines()]\n",
"final_retrieval = [json.loads(l) for l in open(\"/private/home/xwhan/code/KernelGAT/data/bert_dense_top4_mhop_sents.json\").readlines()]\n",
"all_dev_gold = [json.loads(l) for l in open(\"/private/home/xwhan/data/fever/shared_task_dev.jsonl\").readlines()]\n",
"id2gold = {_[\"id\"]:_ for _ in all_dev_gold}\n",
"\n",
"print(len(final_retrieval))\n",
"final = []\n",
"for pred, retrieval in zip(final_pred, final_retrieval):\n",
" assert pred[\"id\"] == retrieval[\"id\"]\n",
" final.append({\n",
" \"id\": pred[\"id\"],\n",
" \"label\": id2gold[pred[\"id\"]][\"label\"],\n",
" \"evidence\": id2gold[pred[\"id\"]][\"evidence\"],\n",
" \"predicted_label\": pred[\"predicted_label\"],\n",
" \"predicted_evidence\": [[normalize(e[0]), int(e[1])] for e in retrieval[\"evidence\"][:5]]\n",
" })\n",
"\n",
"print(len(final))\n",
"with open(\"/private/home/xwhan/data/fever/results/dense_top4_mhop_dev.json\", \"w\") as g:\n",
" for l in final:\n",
" g.write(json.dumps(l) + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1960\n",
"1960\n"
]
}
],
"source": [
"# build EL final prediction for evaluation\n",
"final_pred = [json.loads(l) for l in open(\"/private/home/xwhan/code/KernelGAT/kgat/output/el_bert_dev_mhop.json\").readlines()]\n",
"final_retrieval = [json.loads(l) for l in open(\"/private/home/xwhan/code/KernelGAT/data/bert_dev_multi_el.json\").readlines()]\n",
"all_dev_gold = [json.loads(l) for l in open(\"/private/home/xwhan/data/fever/shared_task_dev.jsonl\").readlines()]\n",
"id2gold = {_[\"id\"]:_ for _ in all_dev_gold}\n",
"\n",
"print(len(final_retrieval))\n",
"final = []\n",
"for pred, retrieval in zip(final_pred, final_retrieval):\n",
" assert pred[\"id\"] == retrieval[\"id\"]\n",
" final.append({\n",
" \"id\": pred[\"id\"],\n",
" \"label\": id2gold[pred[\"id\"]][\"label\"],\n",
" \"evidence\": id2gold[pred[\"id\"]][\"evidence\"],\n",
" \"predicted_label\": pred[\"predicted_label\"],\n",
" \"predicted_evidence\": [[normalize(e[0]), int(e[1])] for e in retrieval[\"evidence\"][:5]]\n",
" })\n",
"\n",
"print(len(final))\n",
"with open(\"/private/home/xwhan/data/fever/results/el_mhop_dev.json\", \"w\") as g:\n",
" for l in final:\n",
" g.write(json.dumps(l) + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1960\n"
]
}
],
"source": [
"final_pred = [json.loads(l) for l in open(\"/private/home/xwhan/code/KernelGAT/kgat/output/esim_mhop_dev.json\").readlines()]\n",
"final_retrieval = [json.loads(l) for l in open(\"/private/home/xwhan/data/fever/retrieval/esim_mhop_dev.json\").readlines()]\n",
"all_dev_gold = [json.loads(l) for l in open(\"/private/home/xwhan/data/fever/shared_task_dev.jsonl\").readlines()]\n",
"id2gold = {_[\"id\"]:_ for _ in all_dev_gold}\n",
"\n",
"final = []\n",
"for pred, retrieval in zip(final_pred, final_retrieval):\n",
" assert pred[\"id\"] == retrieval[\"id\"]\n",
" final.append({\n",
" \"id\": pred[\"id\"],\n",
" \"label\": id2gold[pred[\"id\"]][\"label\"],\n",
" \"evidence\": id2gold[pred[\"id\"]][\"evidence\"],\n",
" \"predicted_label\": pred[\"predicted_label\"],\n",
" \"predicted_evidence\": [[e[0], int(e[1])] for e in retrieval[\"evidence\"][:5]]\n",
" })\n",
"\n",
"print(len(final))\n",
"with open(\"/private/home/xwhan/data/fever/results/esim_mhop_dev.json\", \"w\") as g:\n",
" for l in final:\n",
" g.write(json.dumps(l) + \"\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: mdr/retrieval/hotpot.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"test_qas = json.load(open(\"/private/home/xwhan/data/hotpot/hotpot_test_fullwiki_v1.json\"))\n",
"test_results = json.load(open(\"/private/home/xwhan/data/hotpot/results/hotpot_test_b200_k500.json\"))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(test_qas) == len(test_results[\"answer\"])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Who has been in more bands, Deron Miller or Steve Marriott?\n",
"Deron John Miller\n"
]
}
],
"source": [
"import random\n",
"qid2question = {_[\"_id\"]:_[\"question\"] for _ in test_qas}\n",
"qids = list(test_results[\"answer\"].keys())\n",
"random.shuffle(qids)\n",
"\n",
"\n",
"\n",
"print(qid2question[qids[0]])\n",
"print(test_results[\"answer\"][qids[0]])"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting seaborn\r\n",
" Using cached https://files.pythonhosted.org/packages/c7/e6/54aaaafd0b87f51dfba92ba73da94151aa3bc179e5fe88fc5dfb3038e860/seaborn-0.10.1-py3-none-any.whl\r\n",
"Collecting matplotlib>=2.1.2 (from seaborn)\r\n",
" Using cached https://files.pythonhosted.org/packages/96/a7/b6fa244fd8a8814ef9408c8a5a7e4ed0340e232a6f0ce2046b42e50672c0/matplotlib-3.3.1-cp36-cp36m-manylinux1_x86_64.whl\r\n",
"Requirement already satisfied: scipy>=1.0.1 in /public/apps/anaconda3/5.0.1/lib/python3.6/site-packages (from seaborn)\r\n",
"Requirement already satisfied: numpy>=1.13.3 in /public/apps/anaconda3/5.0.1/lib/python3.6/site-packages (from seaborn)\r\n",
"Requirement already satisfied: pandas>=0.22.0 in /public/apps/anaconda3/5.0.1/lib/python3.6/site-packages (from seaborn)\r\n",
"Requirement already satisfied: python-dateutil>=2.1 in /public/apps/anaconda3/5.0.1/lib/python3.6/site-packages (from matplotlib>=2.1.2->seaborn)\r\n",
"Collecting pillow>=6.2.0 (from matplotlib>=2.1.2->seaborn)\r\n",
" Using cached https://files.pythonhosted.org/packages/30/bf/92385b4262178ca22b34f82e0e09c2922eb351fe39f3cc7b8ba9ea555b41/Pillow-7.2.0-cp36-cp36m-manylinux1_x86_64.whl\r\n",
"Collecting certifi>=2020.06.20 (from matplotlib>=2.1.2->seaborn)\r\n",
" Using cached https://files.pythonhosted.org/packages/5e/c4/6c4fe722df5343c33226f0b4e0bb042e4dc13483228b4718baf286f86d87/certifi-2020.6.20-py2.py3-none-any.whl\r\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /public/apps/anaconda3/5.0.1/lib/python3.6/site-packages (from matplotlib>=2.1.2->seaborn)\r\n",
"Requirement already satisfied: cycler>=0.10 in /public/apps/anaconda3/5.0.1/lib/python3.6/site-packages (from matplotlib>=2.1.2->seaborn)\r\n",
"Collecting kiwisolver>=1.0.1 (from matplotlib>=2.1.2->seaborn)\r\n",
" Using cached https://files.pythonhosted.org/packages/ae/23/147de658aabbf968324551ea22c0c13a00284c4ef49a77002e91f79657b7/kiwisolver-1.2.0-cp36-cp36m-manylinux1_x86_64.whl\r\n",
"Requirement already satisfied: pytz>=2011k in /public/apps/anaconda3/5.0.1/lib/python3.6/site-packages (from pandas>=0.22.0->seaborn)\r\n",
"Requirement already satisfied: six>=1.5 in /public/apps/anaconda3/5.0.1/lib/python3.6/site-packages (from python-dateutil>=2.1->matplotlib>=2.1.2->seaborn)\r\n",
"Installing collected packages: pillow, certifi, kiwisolver, matplotlib, seaborn\r\n",
"Successfully installed certifi-2020.6.20 kiwisolver-1.2.0 matplotlib-3.3.1 pillow-7.2.0 seaborn-0.10.1\r\n",
"\u001b[33mYou are using pip version 9.0.1, however version 20.2.2 is available.\r\n",
"You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\r\n"
]
}
],
"source": [
"# Install a pip package in the current Jupyter kernel\n",
"import sys\n",
"!{sys.executable} -m pip install seaborn --user"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAESCAYAAAAMifkAAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAIABJREFUeJzs3XlcVFX/wPHPzMAM44LgArK5pIlY\nLii5ofaIuS+4Zy6ZO6W2KKZmaWpqrpVZz5M/pSxNKyXMLSrNJTOXwiVFMzUUEFQUkXVg5v7+QG5O\n7Mhi8n2/Xrxkzj333HOver5zzz33HI2iKApCCCFEHrRlXQEhhBAPPgkWQggh8iXBQgghRL4kWAgh\nhMiXBAshhBD5kmAhhBAiXxIshBBC5EuChRBCiHzZlMZBIiMjmThxovr5zp07JCYmcuTIES5dusSM\nGTOIj4/HwcGBxYsXU6dOndKolhBCiALSlMUb3AsWLMBsNjN79myeffZZBgwYgL+/P1u3bmXLli18\n+umnpV0lIYQQeSj1biiTycS2bdsYMGAAcXFxnDlzhl69egHQq1cvzpw5w82bN0u7WkIIIfJQKt1Q\n99qzZw/Ozs489thj/P777zg7O6PT6QDQ6XQ4OTlx9epVqlatarVfQkICCQkJVmkmk4krV65Qp04d\ntQwhhBB5M5vNXL9+nccffxw7O7sC7VPqwWLLli0MGDCg0PutW7eOVatWlUCNhBCifNqwYQM+Pj4F\nyluqwSI2NpajR4+yZMkSAFxcXIiNjcVsNqPT6TCbzVy7dg0XF5ds+44cOZJ+/fpZpUVFRfHss8+y\nYcMGatasWSrnIIQQ/3YxMTEMGzaMGjVqFHifUg0WX3/9NU8++SSOjo4AVKtWDS8vL7Zv346/vz/b\nt2/Hy8srWxcUgL29Pfb29jmWW7NmTdzd3Uu07kII8bApTPd9qT7g/vrrr7N1Qb355pusX7+erl27\nsn79eubOnVuaVRJCCFEApXpnERoami2tXr16fPXVV6VZDSGEEIUkb3ALIYTIlwQLIYQQ+ZJgIYQQ\nIl8SLIQQQuRLgoUQQoh8SbAQQgiRLwkWQggh8iXBQgghRL4kWAghhMhXqc86K4QQouQdGTmG9Pj4\nbOm2Dg64Lij8tEpyZyGEEA8ZRVFyDBRArun5kTsLIYQoBYqigMWCRqfDkp5O2vXrWNJMWEwmzGlp\nWEwmKtTywM7JCdPNW1w/8BMWkwnL3W2WtDScOz9Fpfr1SPzzAn+t+0zdZk7L3N7glRep0vhx4n7+\npdjrL8FCCFFuKYqS2RCbTFjSTGhtbbCtUgVFUbh96ve/G+u0zAa9Qi0PqjzWCIvJRMT6z+9prDP/\nrN7OF6f/PIkpPp5TM19Xg4ElLQ1Lejp1Rj2Lm38fUmNjCZv4Urb61HthAjW7dsF08yZ/BX2SmajV\notXr0Rn0OHg3o1L9egBY0tPR2tlhW6UKWr0ercGATeVKAFSsU7vYr5UECyHEAykjKQlzSqpVg6y1\ntaVSvUcAiPvlMKZb8Vbfvg1OTtTs8hQAf374Eaa4OKvtVRo/Tt0xowA48twY0m9Zd8k4PeXHo5Mn\nAnB6zjywWKy2u/TsQZXHGgEQE/o9OkNmI63V69Hq9ZhTUgHQ6vVUqvcIWr0BrUF/t7E3ULlBAwD0\nVavx6Csvosva9+6fdjWdAahYtw6tPv8UrV6PxsYGjUZjVY9K9evR5O0FuV47o5troa93fiRYCCHy\npVgsWNLTsaSZsLWvDEDa9euYbt7K/FZ9tzFWzBZqdGgHwI2Dh0i6ePHvbhKTCZ2dHfUCxgFwcU0Q\nCb+fvru/CYspDYOTM81WZK6keWbuAu6cO2dVj0qPPkrTZW8DcHnjFyT/FfH3Rq0WR++marBIu3aN\n9ISEzIbaaMS2ShVs7y68BlCzW1ewWP5urA16
KtxdRE2j0fD4W3PR2tqqDbnOYEBXwZh5KL2eNl9s\nyPV62VSogGfglDy2G3H6z5O5btfodNhUrJjr9rIgwUKIIsprtEnLdWtL/PhqF8o9XR3mtDQquLuh\n1etJuXqVpIuX1C4UiymzUXbp1RObCkbiDh8l7tAvf/eJ3y3jsflvojMYuLzpS65u36GWn6VtyGY0\nGg1XvtxM7Hc/WNVJa2enBou4X37hxk8/q10oWoMBQ/Xqal6bihUxONWw+mZtqFZN3e7Wrw/pd+6g\n1RvU/W0qV1a3N5o9C41Gq35z/+c38MfefCPP61dryOA8t2fdQfxb2To45PrvsygkWAhRRPmNNrGY\nTKRdv2HVp21JS6PiI3UxVKtG6rVrdxtrk9WDStdePahQy4Pbp09zZeOX/9jfhNesGVSq9wjXftjN\nn6v+m+343u+/Q4Vatbh17DcurQnKtr3Gkx2wqWAk7do1Ek6fvttVktkg64xGFLMZgIq1a1G9fTv1\nW3VWg47FAjodLj27U61N67uNtUENClkavPwiDaa8nK0LJUutZ57O8/pWa9M6z+33BhaRXV5fWCIj\nIwtdngQLUW4pioI5JQVzSirmlBQsqZl/6qtXx+hSk4zkFGK//17dbk5JxZKaSvUO7ajq0yLf8pMu\n/cXJV2dmS28w9RVqdGhHakwsfwWty0xUH2IaqO7bhgq1PDLraDZndqE4VPm7O8RoB0ClR+tTe8Qw\ntQsl8xu4Af3dRrRGh3Y4NG18tz/dYPUNHMC1d09ce/fMtf7V2rTOs8GuWKcOFevUyXW7phDrO4sH\nnwQL8a+gKApKejrm1MwHiLb29gDEnzhJRlKS2tCbU1IxerhTrVVLFEXh3NIVWFKzAkJmnhr/6UCt\nZ57GkprK4WdGZDuW++CB1B72DBaTSW3Ms/q9dUY77B9/rEB1tnNxocGUl+92s+jVBt3oUhMA+0Ze\neT7ErPLYYzRe9Fau5efXWNtWqYJtlSoFqqsQ+ZFg8ZAq6/70LOm3b5N+JzGzMb/boGttbXFo1hSA\nq7tCSYuNzWzo7243urlRZ2RmI35y+mukREZhTklRu0eqtmqJ12vTAfhj+Tuk306wOmaNJztQrVVL\nNBoNqdFX0ei06IxGDDWqo7Wzw+5uY601GKgzaiQ6ox06u8xAoLWzUxtzW/vKtPr8U3R2dkX6lmxr\nX5kaT7bPdbvWxgatjfwXFP8O8i/1IVXYtzct6el/d8fcHU8OcOfcH6RERd/91p7ZoKPRUHvYM0Dm\niJTbJ0/93VWTmore0YFm7y4H4NzSFdw+9bvVsSrWraNuv7b7R5IjIjIbbKMxc9y4w98jVqo0fpxK\n9epZbTe6uqjbvd6YhdbG5m5Df7fB1//db97s3WW5XiONVotb3z55bn/QRqQIUVYkWJRDJ6bNoPHC\n+Whtbfnrk0+J3rYDJSPj7wxaLW2Dv0Sj0RAT+j3Xdu+x2qZ3dFCDhZKRkZlWvZraoOurVlWzu/Xz\nx+mpTmoXjs5oxKZSJXV7kyUL0Whzn3Wm9vCheZ5L5UfrF/Lsi09xjzYR4kFWasEiLS2NhQsXcujQ\nIQwGA82aNWP+/Pn4+fmh1+sxGAwABAYG0r597rfuImem+HjuhJ8lIfwsCWfO5pnXpmLFzEbe1hb7\nRl5odDq0dpkNue7unygKaDTUGvYMHoMHWH1rv7dvvfaIYXkey7FF8zy35xUoHnSl2Z0nRFkrtWCx\ndOlSDAYDoaGhaDQabty4oW5buXIlDe6+2SjypygKitmM1saGhPCznF+5itToqwBobG2p3ODRPPe/\nd/x51ZZPULXlE7nmNVSrmus2IUT5USrBIikpiZCQEPbt26d+K61+z8s5Im+W9HSSLl7KvGsIP8ud\n8HBqDX2Gmt26oK/qSAV3d5w7P4V9I6/MKQZsbTnoP6Csqy2EeIiUSrC4cuUKDg4OrFq1isOHD1Ox\nYkVeeukl
fHx8gMyuJ0VRaNGiBVOmTMH+7rDIeyUkJJCQYD3qJSYmpjSqX+oykpPJuJOInbMT5pQU\njowcgyUtDQC7mjVxaN4cu7sPee2cnfGaNSNbGdKfLoQoTqUSLDIyMrhy5QqNGjVi+vTpnDhxgoCA\nAL7//ns2bNiAi4sLJpOJBQsWMG/ePJYtyz6CZd26daxatao0qlvq0uLiSDiTeceQcOYsSRERODb3\nptEbr6EzGnEf2B+jmxv2Xg3RV3XMv0CkP10IUbxKJVi4urpiY2NDr169AGjatCmOjo5cunSJxo0b\nA6DX6xk6dCjPP/98jmWMHDmSfv36WaXFxMQwbFjeD1gfNIrFQvKVSFKioqjetg2QObz0TvhZtHZ2\nVPZsgMfTg6jS+HF1H4/BA8uqukIIAZRSsKhatSqtWrXi4MGDtGvXjkuXLhEXF4eTkxN37tyhcuXK\nKIrCzp078fLyyrEMe3v7HLun/g2SIi5z6+gx9ZmDOSkJjY0Nji2aozMYqDNyBFpbWyrWrSNTJAgh\nHkilNhpq7ty5vPbaayxevBgbGxuWLFmCyWRiwoQJmM1mLBYL9erVY86cOaVVpUIryFvR6Ql3Mh9C\nnz2LWz9/bO3tufXrb0R8tgGjuzvVfdtg7+WFfaOG6stj9l4NS/U8hBCisEotWHh4ePDZZ59lSw8J\nCSmtKty3vN6K/vOD/5Jw5iwpd2dz1NjY4NjcmyqNH8e5cyecn/JT5zMSQoh/G3mDu5jcOHgIe6+G\nOHV8kspeDalUvx66uy8a2t4zB78QQvwbSbAoJq3Wf/KvfhtZCCHyIq1bAWUkJuW5XQKFEOJhJi1c\nAd089mtZV0EIIcqMBIs8KBYLyZcvA1DjyfZW6//eS96KFkI87OSZRS7SExI4/94qbp/6neYfvIeh\nRg1arf+krKslhBBlotwHi9zenUCjQaPTUXf0c+hl0kMhRDlX7oNFbu9OoCg0WbqISo88UroVEkKI\nB5A8s8iDBAohhMgkwUIIIUS+JFgIIYTIlwQLIYQQ+Sr3wUJrZ5djurw7IYQQfyvXo6FSrsaAxUK1\nNq1oOOPVsq6OEEI8sMpdsMjpvYq4Q4c5MnKMLEUqBBASFsXS0HNEx6fg6mBkWldP+nq7lXW1RBkr\nd8EirzUphCjvQsKimBl8ipR0MwBR8SnMDD4FIAGjnCv3zyyEEH9bGnpODRRZUtLNLA09V0Y1Eg8K\nCRZCCFV0fEqh0kX5IcFCCKFydTAWKl2UHxIshBCqaV09MdrqrNKMtjqmdfUsoxqJB0W5Cxa5vT9R\nnO9V+Pn58fPPP1ulBQcH88wzz6ifd+zYwaBBg2jWrBlt2rRh0KBBbNiwAUVRAJgxYwaenp6cPHlS\n3SciIgJPT/lPK0pOX283FvVvjJuDEQ3g5mBkUf/G8nBblL/RUC3XreXMWwtJu3Yd75XvlEkdgoKC\nWLNmDbNnz6Zdu3ZUrFiR8PBw1q5dy6BBg9Dr9QA4ODjw7rvvEhQUVCb1FOVTX283CQ4im1ILFmlp\naSxcuJBDhw5hMBho1qwZ8+fP59KlS8yYMYP4+HgcHBxYvHgxderUKdG6VGvTGiU9o0SPkZs7d+6w\ncuVKFi9eTNeuXdX0Ro0asXz5cqu8ffv2Zfv27Rw5coSWLVuWdlWFEEJVasFi6dKlGAwGQkND0Wg0\n3LhxA4A5c+YwdOhQ/P392bp1K7Nnz+bTTz8t0bo4d/K7r/3v56WlsLAwTCYTnTp1yjevnZ0dEyZM\n4J133mHjxo33VWchhLgfpRIskpKSCAkJYd++fWg0GgCqV69OXFwcZ86c4eOPPwagV69ezJ8/n5s3\nb1K1atUSqUtGcgoZdxIw1KiBRlv4RzYFfWlp4sSJ6HR/PyhMT0+nUaNG3Lp1C0dHR2xs/r70Q4YM\n4c8//8RkMrF27VqeeOIJq21BQUHs27evxO+4hBAiN6XygPvKlSs4ODiwat
Uq+vfvz4gRIzh27BhX\nr17F2dlZbVR1Oh1OTk5cvXo1WxkJCQlERkZa/cTExBS6LrdPnOTX8S+QeOFikc6loC8tffDBBxw7\ndkz9mTNnDpD5HOLWrVtkZPzdDbZp0yaOHTuGg4MDFovFqhy9Xs8LL7zAe++9pz78FkKI0lYqwSIj\nI4MrV67QqFEjgoODCQwMZPLkySQnJxe4jHXr1tGpUyern2HDhhW6LilRUQAY3VwLvS/c/0tL3t7e\n6PV6du/eXeBj9u/fn8TERL7//vsC7yPKt/fff5/AwMCyroa469ixY1bPKP+NSqUbytXVFRsbG3r1\n6gVA06ZNcXR0xM7OjtjYWMxmMzqdDrPZzLVr13BxcclWxsiRI+nXr59VWkxMTKEDRkpUNLaOjthU\nqFC0c3EwEpVDYCjoS0v29vZMnDiRuXPnoigK7du3x2g0cu7cOVJScg44NjY2TJo0iQULFhSpzqLs\n7dixg08++YTz589jNBpxd3enb9++DB06VO2aLUuenp4YjUY0Gg2VKlWiR48evPrqq1Zdqf9Ghw8f\nZtq0aezfvz/XPDNmzGD79u3Y2tpia2vLY489xuuvv069evUKdAxPT0++++47ateunWseHx8fQkND\nC13/B0mp3FlUrVqVVq1acfDgQQAuXbpEXFwcderUwcvLi+3btwOwfft2vLy8cnxeYW9vj7u7u9VP\nzZo1C12XlKioIt9VQPG8tDRu3DhmzJjBmjVraNu2LW3btmX27NkEBgbi7e2d4z69evWiRo0aRa63\nKDtBQUEsWLCAMWPG8NNPP/Hzzz8zd+5cfvvtN9LT03Pcx2w255hekrZu3UpYWBjr169n586dbNmy\npVSPf2/XbF5pJWHMmDGEhYWxf/9+nJ2dmTVrVrGVXVrnUOLHVUrJ5cuXleHDhyu9evVS+vbtq+zd\nu1dRFEX5888/lYEDBypdunRRBg4cqFy4cKHAZV65ckVp0KCBcuXKlQLlt1gsyi9Dn1X+/PB/RTqH\nLF//Fqm0XbRbqTN9u9J20W7l698i76s88fBKSEhQmjZtqnz77bd55ps+fboye/ZsZezYsUrTpk2V\ngwcPKj/++KPi7++veHt7Kx06dFBWrlyp5s/6t79p0ybF19dX8fX1VdauXatuX7lypfLiiy8q06ZN\nU5o1a6b06NFDOXnyZK7Hb9CggfLXX3+pn1988UXlzTfftDqPmTNnKr6+vkq7du2UFStWKBkZGer2\nL774QunWrZvSrFkzpXv37srvv/+eY7nTp09XVqxYoSiKovzyyy9K+/btlY8++khp27atEhgYmGOa\noijKnj17lD59+igtWrRQnn76aSU8PFwts2PHjsqaNWuUXr16Kc2bN1deeuklJTU1VUlKSlIaN26s\neHp6Ks2aNVOaNWumxMTE5Hjts+qkKIqyd+9epWnTplZ5vvrqK6Vbt26Kj4+PMnr0aCUyMvP//NCh\nQ5UGDRooTZs2VZo1a6bs2LEjz/PKEhMTo0yaNElp1aqV0rFjR2XdunVqeuPGjZVbt26peU+fPq20\nbNlSMZlMedYl63qvX79e6dy5s9KxY8dc/rYL33YqiqKUWrAoCYUOFmazcv3AT0rC2XMlXDMhMu3b\nt0/x8vJS0tPT88w3ffp0pXnz5sqxY8cUs9mspKamKr/88oty9uxZxWw2K+Hh4UqbNm2U77//XlGU\nv//tv/LKK0pSUpJy9uxZpVWrVsrBgwcVRckMFo8//riyd+9eJSMjQ1m2bJkyaNCgXI9/b6P+559/\nKr6+vsrHH3+sbn/++eeVN954Q0lKSlJu3LihDBgwQNm4caOiKIqyc+dOpV27dsqJEycUi8Wi/PXX\nX2oDll+w8PLyUpYsWaKkpaUpKSkpOab9/vvvSuvWrZXjx48rGRkZSnBwsNKxY0clLS1NUZTMYDFg\nwAAlJiZGuXXrltKtWzfl888/V49xby
Od27XPqlNSUpISGBio9O7dW93+/fffK0899ZTy559/Kunp\n6coHH3ygPP300zleu7zOK6seZrNZ6devn/L+++8raWlpyuXLlxU/Pz9l//79iqIoyogRI5QvvvhC\nLe/tt99W3njjjQLX5bnnnlNu3bqlpKSk5HrORQkW5Wq6D41WS/V2vlT2bFDWVRHlRG5DpX18fGjS\npAlHjx5V0zt16kSLFi3QarUYDAZatWqFp6cnWq2Whg0b0rNnT44cOWJV/sSJE6lQoQKenp70799f\n7dIFaNGiBU8++SQ6nQ5/f3/Onj2bZ1379etHs2bN6NGjBy1btmTo0KEA3Lhxg/379/Paa69RoUIF\nqlWrxnPPPceOHTsA2Lx5M2PHjqVJkyZoNBpq166Nm1vB3jvSarW8+OKL6PV67O4ucfzPtC+//JKn\nn36apk2botPp6NevH7a2thw/flwtZ8SIETg7O+Pg4EDHjh0JDw8v0PGzBAUF4ePjQ/Pmzfn1119Z\nsmSJum3Tpk2MHz+eevXqYWNjQ0BAAOHh4UTdHSxT0PPKcurUKW7evMmkSZPQ6/V4eHgwePBgdu7c\nCUDv3r3Vv0dFUdi5cye9e/cucF3Gjx+Pg4NDtuPer3I13UdyZBQZCQlU9myA5l/+4E78O9w7VDor\nYGzatAmADh06WA2V/ufAjhMnTrBs2TLOnz9Peno6JpOJbt26WeW5dx83Nzf++OMP9XP16tXV3+3s\n7EhLS7Oqxz99/fXX1KpVi127drF8+XKSk5PR6/VER0eTkZFBu3bt1LwWi0U99tWrV6lVq1ahrksW\nR0dHDAZDnmnR0dGEhISwfv16NS09PZ1r166pn+99nmc0Gq22FcTo0aN55ZVXiI6OZuzYsVy6dImG\nDRuqx1+4cCGLFy9W8yuKQmxsbK5BMafzyhIVFcW1a9fw8fFR08xms/q5a9euzJ8/n9jYWCIiItBo\nNOq2gtQlpwFCxaFcBYvY738gZue3tN60Pv/MQhSDe4dKF3bo5NSpUxk+fDhr1qzBYDCwYMECbt26\nZZXn6tWr6qid6OhonJyc7qu+Go2GHj16sHv3bj744ANmzZpFzZo10ev1/PLLLzkGGhcXFy5fvpxj\neUaj0WqU3/Xr13F2drY6Xk51+Gf5AQEBPP/880U6n8JwdXVl1qxZTJ8+nY4dO2JnZ6cev0+fPsVy\nXBcXF9zd3fnuu+9y3G5vb4+vry+7du3i4sWL9OzZUy2vIHUpqdF15aobKiUqGjuXmnJXIUrNvUOl\nv/32W5KSkrBYLISHh+c6VDpLUlISVapUwWAwcPLkSasupiwffvghKSkpnD9/nuDgYHr06FEs9R4/\nfjxffvkl169fx8nJCV9fX95++20SExOxWCxcvnxZ7RIbOHAgQUFB/P777yiKQkREhNot0rBhQ7Zv\n347ZbGb//v1W3W4FNWjQIDZt2sSJEydQFIXk5GT27t1LYmJivvtWq1aN+Ph47ty5U+Dj+fr64uTk\nxBdffAFkdhuuXr2a8+fPA5nzu+3atUvNX716da5cuVLg8ps0aUKlSpVYvXo1qampmM1m/vjjD6sZ\npnv37s3WrVsJDQ1Vu6AKUpeSVK7uLFKioqgoU2aIUjZu3DicnZ1Zs2YN06dPx2g04uHhkedQacic\nN23x4sXMmzePli1b0r17dxISEqzytGzZks6dO6MoCqNHj7bqKrofnp6ePPHEE6xdu5YZM2awZMkS\nli1bRo8ePUhKSsLDw4Nx48YB0L17d+Lj45k6dSrXrl3Dzc2NJUuW4ObmxqxZs5gxYwYbNmzgqaee\n4qmnnip0XRo3bsz8+fOZN28eERER2NnZ0bx5c6tunNzUq1ePnj178tRTT2E2m9mxY4fVnU1uxo4d\ny6JFi3jmmWfo3LkzSUlJTJkyhaioKCpXrkzbtm3p3r07AJMmTWLGjBmkpqYyb948qlWrlmfZOp2O\n//
SYMBOL INDEX (222 symbols across 29 files)
FILE: mdr/qa/basic_tokenizer.py
class Tokens (line 16) | class Tokens(object):
method __init__ (line 25) | def __init__(self, data, annotators, opts=None):
method __len__ (line 30) | def __len__(self):
method slice (line 34) | def slice(self, i=None, j=None):
method untokenize (line 40) | def untokenize(self):
method words (line 44) | def words(self, uncased=False):
method offsets (line 55) | def offsets(self):
method pos (line 59) | def pos(self):
method lemmas (line 67) | def lemmas(self):
method entities (line 75) | def entities(self):
method ngrams (line 83) | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
method entity_groups (line 110) | def entity_groups(self):
class Tokenizer (line 132) | class Tokenizer(object):
method tokenize (line 137) | def tokenize(self, text):
method shutdown (line 140) | def shutdown(self):
method __del__ (line 143) | def __del__(self):
class RegexpTokenizer (line 153) | class RegexpTokenizer(Tokenizer):
method __init__ (line 172) | def __init__(self, **kwargs):
method tokenize (line 196) | def tokenize(self, text):
class SimpleTokenizer (line 236) | class SimpleTokenizer(Tokenizer):
method __init__ (line 240) | def __init__(self, **kwargs):
method tokenize (line 254) | def tokenize(self, text):
FILE: mdr/qa/config.py
class ClusterConfig (line 12) | class ClusterConfig(NamedTuple):
function common_args (line 16) | def common_args():
function train_args (line 57) | def train_args():
FILE: mdr/qa/data_utils.py
function explore (line 11) | def explore(path):
function load_corpus (line 23) | def load_corpus(corpus_path="/private/home/xwhan/data/hotpot/tfidf/abstr...
FILE: mdr/qa/hotpot_evaluate_v1.py
function normalize_answer (line 13) | def normalize_answer(s):
function f1_score (line 31) | def f1_score(prediction, ground_truth):
function exact_match_score (line 54) | def exact_match_score(prediction, ground_truth):
function update_answer (line 57) | def update_answer(metrics, prediction, gold):
function update_sp (line 66) | def update_sp(metrics, prediction, gold):
function eval (line 88) | def eval(prediction_file, gold_file):
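The evaluation helpers indexed above (`normalize_answer`, `f1_score`, `exact_match_score`) follow the standard HotpotQA/SQuAD answer-scoring recipe: normalize both strings, then score token-level overlap. A minimal sketch of that recipe, not the repository's exact code:

```python
import re
import string
from collections import Counter

def normalize_answer(s):
    """Lowercase, strip punctuation and articles, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)

def f1_score(prediction, ground_truth):
    """Token-overlap F1 between the normalized prediction and gold answer."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
```

The repository's `update_answer` and `update_sp` aggregate these per-example scores into corpus-level metrics.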
FILE: mdr/qa/qa_dataset.py
function collate_tokens (line 17) | def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_e...
function prepare (line 38) | def prepare(item, tokenizer, special_toks=["[SEP]", "[unused1]", "[unuse...
class QAEvalDataset (line 108) | class QAEvalDataset(Dataset):
method __init__ (line 110) | def __init__(self,
method __len__ (line 150) | def __len__(self):
method __getitem__ (line 153) | def __getitem__(self, index):
class QADataset (line 188) | class QADataset(Dataset):
method __init__ (line 190) | def __init__(self,
method __len__ (line 302) | def __len__(self):
method __getitem__ (line 305) | def __getitem__(self, index):
class MhopSampler (line 391) | class MhopSampler(Sampler):
method __init__ (line 396) | def __init__(self, data_source, num_neg=9, n_gpu=8):
method __len__ (line 408) | def __len__(self):
method __iter__ (line 411) | def __iter__(self):
function qa_collate (line 424) | def qa_collate(samples, pad_id=0):
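`collate_tokens` in the index is a fairseq-style batching helper: it pads variable-length token-id sequences to a common length, on the right by default or on the left when `left_pad` is set. A plain-Python sketch of the idea, independent of the repository's tensor-based version:

```python
def collate_tokens(values, pad_idx, left_pad=False):
    """Pad a list of token-id sequences to the length of the longest one."""
    size = max(len(v) for v in values)
    batch = []
    for v in values:
        padding = [pad_idx] * (size - len(v))
        # Left-padding keeps sequence ends aligned; right-padding keeps starts aligned.
        batch.append(padding + list(v) if left_pad else list(v) + padding)
    return batch
```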
FILE: mdr/qa/qa_model.py
class BertPooler (line 13) | class BertPooler(nn.Module):
method __init__ (line 14) | def __init__(self, config):
method forward (line 19) | def forward(self, hidden_states):
class QAModel (line 27) | class QAModel(nn.Module):
method __init__ (line 29) | def __init__(self,
method forward (line 49) | def forward(self, batch):
FILE: mdr/qa/qa_trainer.py
class TrainerState (line 40) | class TrainerState:
method save (line 52) | def save(self, filename: str) -> None:
method load (line 61) | def load(cls, filename: str, default: "TrainerState", gpu: int) -> "Tr...
class Trainer (line 78) | class Trainer:
method __init__ (line 79) | def __init__(self, train_cfg: NamedTuple, cluster_cfg: ClusterConfig) ...
method __call__ (line 83) | def __call__(self) -> Optional[float]:
method log (line 93) | def log(self, log_data: dict):
method checkpoint (line 101) | def checkpoint(self, rm_init=True) -> submitit.helpers.DelayedSubmission:
method _setup_process_group (line 117) | def _setup_process_group(self) -> None:
method _init_state (line 128) | def _init_state(self) -> None:
method _train (line 209) | def _train(self) -> Optional[float]:
method _eval (line 283) | def _eval(self) -> dict:
FILE: mdr/qa/train_ranker.py
function load_saved (line 30) | def load_saved(model, path):
function main (line 37) | def main():
function predict (line 212) | def predict(args, model, eval_dataloader, device, logger):
FILE: mdr/qa/utils.py
function set_global_logging_level (line 13) | def set_global_logging_level(level=logging.ERROR, prefices=[""]):
function load_saved (line 29) | def load_saved(model, path, exact=True):
function move_to_cuda (line 45) | def move_to_cuda(sample):
function convert_to_half (line 65) | def convert_to_half(sample):
class AverageMeter (line 85) | class AverageMeter(object):
method __init__ (line 88) | def __init__(self):
method reset (line 91) | def reset(self):
method update (line 97) | def update(self, val, n=1):
function normalize (line 104) | def normalize(text):
function para_has_answer (line 109) | def para_has_answer(answer, para, tokenizer):
function match_answer_span (line 124) | def match_answer_span(p, answer, tokenizer, match="string"):
function _is_whitespace (line 145) | def _is_whitespace(char):
function _improve_answer_span (line 160) | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
function whitespace_tokenize (line 173) | def whitespace_tokenize(text):
function find_ans_span_with_char_offsets (line 182) | def find_ans_span_with_char_offsets(detected_ans, char_to_word_offset, d...
function convert_to_unicode (line 212) | def convert_to_unicode(text):
function _is_control (line 232) | def _is_control(char):
function _is_punctuation (line 243) | def _is_punctuation(char):
class BasicTokenizer (line 259) | class BasicTokenizer(object):
method __init__ (line 262) | def __init__(self, do_lower_case=True):
method tokenize (line 269) | def tokenize(self, text):
method _run_strip_accents (line 284) | def _run_strip_accents(self, text):
method _run_split_on_punc (line 295) | def _run_split_on_punc(self, text):
method _clean_text (line 315) | def _clean_text(self, text):
function get_final_text (line 329) | def get_final_text(pred_text, orig_text, do_lower_case=False, verbose_lo...
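Several files in the index (`mdr/qa/utils.py`, `mdr/retrieval/utils/utils.py`) carry an `AverageMeter` with the same `reset`/`update(val, n=1)` surface: the common running-average tracker for training statistics. A minimal sketch matching that interface, with assumed attribute names:

```python
class AverageMeter:
    """Tracks a weighted running sum/count and exposes the current average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total weight seen
        self.avg = 0.0    # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
```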
FILE: mdr/retrieval/config.py
class ClusterConfig (line 10) | class ClusterConfig(NamedTuple):
function common_args (line 14) | def common_args():
function train_args (line 71) | def train_args():
function encode_args (line 107) | def encode_args():
FILE: mdr/retrieval/criterions.py
function mhop_loss (line 114) | def mhop_loss(model, batch, args):
function mhop_eval (line 153) | def mhop_eval(outputs, args):
function unified_loss (line 185) | def unified_loss(model, batch, args):
function unified_eval (line 212) | def unified_eval(outputs, batch):
FILE: mdr/retrieval/decomposed_analysis.py
function decomposed_errors (line 9) | def decomposed_errors():
function collect_gold_decomposition (line 62) | def collect_gold_decomposition():
function qdmr_utils (line 99) | def qdmr_utils():
function analyze_results (line 128) | def analyze_results():
FILE: mdr/retrieval/mhop_trainer.py
class TrainerState (line 42) | class TrainerState:
method save (line 54) | def save(self, filename: str) -> None:
method load (line 63) | def load(cls, filename: str, default: "TrainerState", gpu: int) -> "Tr...
class Trainer (line 80) | class Trainer:
method __init__ (line 81) | def __init__(self, train_cfg: NamedTuple, cluster_cfg: ClusterConfig) ...
method __call__ (line 85) | def __call__(self) -> Optional[float]:
method log (line 95) | def log(self, log_data: dict):
method checkpoint (line 103) | def checkpoint(self, rm_init=True) -> submitit.helpers.DelayedSubmission:
method _setup_process_group (line 119) | def _setup_process_group(self) -> None:
method _init_state (line 130) | def _init_state(self) -> None:
method _train (line 204) | def _train(self) -> Optional[float]:
method _eval (line 271) | def _eval(self) -> float:
FILE: mdr/retrieval/single_trainer.py
class TrainerState (line 41) | class TrainerState:
method save (line 53) | def save(self, filename: str) -> None:
method load (line 62) | def load(cls, filename: str, default: "TrainerState", gpu: int) -> "Tr...
class Trainer (line 79) | class Trainer:
method __init__ (line 80) | def __init__(self, train_cfg: NamedTuple, cluster_cfg: ClusterConfig) ...
method __call__ (line 84) | def __call__(self) -> Optional[float]:
method log (line 94) | def log(self, log_data: dict):
method checkpoint (line 102) | def checkpoint(self, rm_init=True) -> submitit.helpers.DelayedSubmission:
method _setup_process_group (line 118) | def _setup_process_group(self) -> None:
method _init_state (line 129) | def _init_state(self) -> None:
method _train (line 203) | def _train(self) -> Optional[float]:
method _eval (line 276) | def _eval(self) -> float:
FILE: mdr/retrieval/train_single.py
function main (line 111) | def main():
function predict (line 295) | def predict(args, model, eval_dataloader, device, logger):
FILE: mdr/retrieval/utils/basic_tokenizer.py
class Tokens (line 18) | class Tokens(object):
method __init__ (line 27) | def __init__(self, data, annotators, opts=None):
method __len__ (line 32) | def __len__(self):
method slice (line 36) | def slice(self, i=None, j=None):
method untokenize (line 42) | def untokenize(self):
method words (line 46) | def words(self, uncased=False):
method offsets (line 57) | def offsets(self):
method pos (line 61) | def pos(self):
method lemmas (line 69) | def lemmas(self):
method entities (line 77) | def entities(self):
method ngrams (line 85) | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
method entity_groups (line 112) | def entity_groups(self):
class Tokenizer (line 134) | class Tokenizer(object):
method tokenize (line 139) | def tokenize(self, text):
method shutdown (line 142) | def shutdown(self):
method __del__ (line 145) | def __del__(self):
class RegexpTokenizer (line 155) | class RegexpTokenizer(Tokenizer):
method __init__ (line 174) | def __init__(self, **kwargs):
method tokenize (line 198) | def tokenize(self, text):
class SimpleTokenizer (line 238) | class SimpleTokenizer(Tokenizer):
method __init__ (line 242) | def __init__(self, **kwargs):
method tokenize (line 256) | def tokenize(self, text):
function normalize (line 303) | def normalize(text):
function filter_word (line 308) | def filter_word(text):
function filter_ngram (line 317) | def filter_ngram(gram, mode='any'):
FILE: mdr/retrieval/utils/mhop_utils.py
function pick_bridge_v0 (line 16) | def pick_bridge_v0(title2linked, title2doc, titles, q, ans):
function load_annotated (line 31) | def load_annotated(path="/private/home/xwhan/data/hotpot/tfidf/abstracts...
function normalize_answer (line 37) | def normalize_answer(s):
function hotpot_sp_data (line 55) | def hotpot_sp_data(raw_path):
function add_qid (line 106) | def add_qid(raw_path):
function add_bridge_ann (line 135) | def add_bridge_ann(raw_path):
function check_2hop (line 166) | def check_2hop(raw_path):
function add_sp_labels (line 173) | def add_sp_labels(raw_path, input_file, save_path,
function explore_QDMR (line 212) | def explore_QDMR(path="/private/home/xwhan/data/Break-dataset/QDMR-high-...
function add_sents_to_corpus_dict (line 249) | def add_sents_to_corpus_dict():
FILE: mdr/retrieval/utils/tokenizer.py
function convert_tokens_to_ids (line 32) | def convert_tokens_to_ids(vocab, tokens):
function whitespace_tokenize (line 39) | def whitespace_tokenize(text):
function convert_to_unicode (line 48) | def convert_to_unicode(text):
function _is_whitespace (line 68) | def _is_whitespace(char):
function _is_control (line 80) | def _is_control(char):
class BasicTokenizer (line 91) | class BasicTokenizer(object):
method __init__ (line 94) | def __init__(self, do_lower_case=True):
method tokenize (line 101) | def tokenize(self, text):
method _run_strip_accents (line 116) | def _run_strip_accents(self, text):
method _run_split_on_punc (line 127) | def _run_split_on_punc(self, text):
method _clean_text (line 147) | def _clean_text(self, text):
function _is_punctuation (line 161) | def _is_punctuation(char):
function process (line 177) | def process(s, tokenizer):
FILE: mdr/retrieval/utils/utils.py
function load_saved (line 10) | def load_saved(model, path, exact=True):
function move_to_cuda (line 24) | def move_to_cuda(sample):
function convert_to_half (line 43) | def convert_to_half(sample):
class AverageMeter (line 63) | class AverageMeter(object):
method __init__ (line 66) | def __init__(self):
method reset (line 69) | def reset(self):
method update (line 75) | def update(self, val, n=1):
function normalize (line 82) | def normalize(text):
class DocDB (line 87) | class DocDB(object):
method __init__ (line 93) | def __init__(self, db_path=None):
method __enter__ (line 97) | def __enter__(self):
method __exit__ (line 100) | def __exit__(self, *args):
method close (line 103) | def close(self):
method get_doc_ids (line 107) | def get_doc_ids(self):
method get_doc_text (line 115) | def get_doc_text(self, doc_id):
function para_has_answer (line 126) | def para_has_answer(answer, para, tokenizer):
function complex_ans_recall (line 142) | def complex_ans_recall():
FILE: scripts/demo.py
function init_retrieval (line 28) | def init_retrieval(args):
function init_reader (line 54) | def init_reader(args):
FILE: scripts/encode_corpus.py
function main (line 41) | def main():
function predict (line 95) | def predict(model, eval_dataloader):
FILE: scripts/end2end.py
function convert_hnsw_query (line 49) | def convert_hnsw_query(query_vectors):
FILE: scripts/eval/eval_mhop_retrieval.py
function convert_hnsw_query (line 44) | def convert_hnsw_query(query_vectors):
FILE: scripts/eval/eval_retrieval.py
function init (line 61) | def init():
function get_score (line 66) | def get_score(answer_doc, topk=20):
function add_marker_q (line 85) | def add_marker_q(tokenizer, q):
FILE: scripts/train_mhop.py
function main (line 54) | def main():
function predict (line 233) | def predict(args, model, eval_dataloader, device, logger):
FILE: scripts/train_momentum.py
function main (line 28) | def main():
function predict (line 214) | def predict(args, model, eval_dataloader, device, logger):
FILE: scripts/train_qa.py
function load_saved (line 33) | def load_saved(model, path):
function main (line 40) | def main():
function predict (line 220) | def predict(args, model, eval_dataloader, logger, fixed_thresh=None):
function eval_final (line 380) | def eval_final(args, model, eval_dataloader, weight=0.8, gpu=True):
FILE: submitit/submitit_train.py
function get_shared_folder (line 20) | def get_shared_folder() -> Path:
function get_init_file (line 23) | def get_init_file() -> Path:
function grid_parameters (line 31) | def grid_parameters(grid: Dict):
function grid_search (line 43) | def grid_search(args):
FILE: submitit/submitit_train_qa.py
function get_shared_folder (line 19) | def get_shared_folder() -> Path:
function get_init_file (line 22) | def get_init_file() -> Path:
function grid_parameters (line 30) | def grid_parameters(grid: Dict):
function grid_search (line 42) | def grid_search(args):
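Both submitit launchers expose `grid_parameters(grid: Dict)` and `grid_search(args)`; the usual pattern behind such helpers is to expand a dict of candidate-value lists into the Cartesian product of hyperparameter settings, one dict per job. A sketch of that expansion (an assumption about the pattern, not the repository's exact code):

```python
import itertools

def grid_parameters(grid):
    """Yield one config dict per combination of the per-key value lists."""
    keys = sorted(grid)  # fixed key order makes the enumeration deterministic
    for combo in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, combo))
```

Each yielded dict can then be merged into the base args and submitted as a separate training job.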
Condensed preview — 53 files, each showing path, character count, and a content snippet (478K chars of structured content in total).
[
{
"path": ".gitignore",
"chars": 58,
"preview": "data/\nmdr.egg*/\napex/\nmodels/\nlogs/\n.DS_Store\n*.pyc\n*.swp\n"
},
{
"path": "CODE_OF_CONDUCT.md",
"chars": 3349,
"preview": "# Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and"
},
{
"path": "CONTRIBUTING.md",
"chars": 1284,
"preview": "# Contributing to multihop_dense_retrieval\nWe want to make contributing to this project as easy and transparent as\npossi"
},
{
"path": "LICENSE",
"chars": 19327,
"preview": "Attribution-NonCommercial 4.0 International\n\n=======================================================================\n\nCr"
},
{
"path": "README.md",
"chars": 9886,
"preview": "\n\n\n\n# [<p align=center>Multi-Hop Dense Text Retrieval (`MDR`)</p>](#p-aligncentermulti-hop-dense-text-retrieval-mdrp)\n\n*"
},
{
"path": "mdr/__init__.py",
"chars": 239,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/qa/__init__.py",
"chars": 199,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/qa/basic_tokenizer.py",
"chars": 9125,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/qa/config.py",
"chars": 4948,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/qa/data_utils.py",
"chars": 865,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/qa/hotpot_evaluate_v1.py",
"chars": 4454,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/qa/qa_dataset.py",
"chars": 18839,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/qa/qa_model.py",
"chars": 4377,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/qa/qa_trainer.py",
"chars": 19844,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/qa/train.md",
"chars": 3902,
"preview": "\n\nCUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train_qa.py \\\n --do_train \\\n --prefix qa_wwm_bert_title_mark_eval_de"
},
{
"path": "mdr/qa/train_ranker.py",
"chars": 10232,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/qa/utils.py",
"chars": 13636,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/retrieval/__init__.py",
"chars": 466,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/retrieval/config.py",
"chars": 5934,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/retrieval/criterions.py",
"chars": 11894,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/retrieval/decomposed_analysis.py",
"chars": 5525,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/retrieval/fever.ipynb",
"chars": 19640,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 90,\n \"metadata\": {},\n \"outputs\": [\n {\n \"name\""
},
{
"path": "mdr/retrieval/hotpot.ipynb",
"chars": 98392,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 2,\n \"metadata\": {},\n \"outputs\": [],\n \"source\": [\n "
},
{
"path": "mdr/retrieval/interactive_retrieval.py",
"chars": 2527,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/retrieval/mhop_trainer.py",
"chars": 12533,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/retrieval/single_trainer.py",
"chars": 13371,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/retrieval/train_single.py",
"chars": 13227,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/retrieval/utils/basic_tokenizer.py",
"chars": 11487,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/retrieval/utils/gen_index_id_map.py",
"chars": 463,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/retrieval/utils/mhop_utils.py",
"chars": 11200,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/retrieval/utils/tokenizer.py",
"chars": 5510,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "mdr/retrieval/utils/utils.py",
"chars": 5278,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "requirements.txt",
"chars": 67,
"preview": "transformers==2.11.0\ntensorboard>=1.15.0\nnumpy\ntqdm\nujson\nstreamlit"
},
{
"path": "scripts/add_sp_label.sh",
"chars": 585,
"preview": "#!/bin/bash\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed u"
},
{
"path": "scripts/demo.py",
"chars": 7262,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "scripts/download_hotpot.sh",
"chars": 1263,
"preview": "#!/bin/bash\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed u"
},
{
"path": "scripts/encode_corpus.py",
"chars": 3707,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "scripts/end2end.py",
"chars": 8165,
"preview": "#!/usr/bin/env python\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is "
},
{
"path": "scripts/end2end.sh",
"chars": 951,
"preview": "#!/bin/bash\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed u"
},
{
"path": "scripts/eval/eval_mhop_fever.py",
"chars": 8812,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "scripts/eval/eval_mhop_retrieval.py",
"chars": 12631,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "scripts/eval/eval_reranked.py",
"chars": 2099,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "scripts/eval/eval_retrieval.py",
"chars": 9540,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "scripts/eval/eval_single_fever.py",
"chars": 5252,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "scripts/train_mhop.py",
"chars": 10731,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "scripts/train_momentum.py",
"chars": 10671,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "scripts/train_qa.py",
"chars": 21946,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "setup.py",
"chars": 774,
"preview": "#!/usr/bin/env python\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is "
},
{
"path": "setup.sh",
"chars": 479,
"preview": "#!/bin/bash\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed u"
},
{
"path": "submitit/submit_retrieval.sh",
"chars": 580,
"preview": "#!/bin/bash\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed u"
},
{
"path": "submitit/submitit_qa.sh",
"chars": 613,
"preview": "#!/bin/bash\n# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed u"
},
{
"path": "submitit/submitit_train.py",
"chars": 4031,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
},
{
"path": "submitit/submitit_train_qa.py",
"chars": 4658,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the lic"
}
]
About this extraction
This page contains the full source code of the facebookresearch/multihop_dense_retrieval GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 53 files (446.1 KB), approximately 148.0k tokens, and a symbol index with 222 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.
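The manifest above is a JSON array of objects with `path`, `chars`, and `preview` fields. A minimal sketch of working with it — the excerpt below is inlined for illustration and covers only three of the 53 entries:

```python
import json

# A small excerpt of the manifest above, inlined for illustration.
manifest_json = """[
  {"path": "requirements.txt", "chars": 67,
   "preview": "transformers==2.11.0\\ntensorboard>=1.15.0\\nnumpy\\ntqdm\\nujson\\nstreamlit"},
  {"path": "mdr/retrieval/hotpot.ipynb", "chars": 98392, "preview": "{\\n \\"cells\\": ["},
  {"path": "setup.py", "chars": 774, "preview": "#!/usr/bin/env python"}
]"""

entries = json.loads(manifest_json)

# Total extracted characters across the listed files.
total_chars = sum(e["chars"] for e in entries)

# Largest listed file by character count.
largest = max(entries, key=lambda e: e["chars"])

print(largest["path"], total_chars)
# → mdr/retrieval/hotpot.ipynb 99233
```

The same pattern (load, then aggregate or sort on `chars`) applies to the full 53-entry array if it is saved to a file and read with `json.load`.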