Full Code of oneal2000/PRAG for AI

main 43d7e1433830 cached
132 files
468.4 KB
122.0k tokens
352 symbols
1 requests
Download .txt
Showing preview only (510K chars total). Download the full file or copy to clipboard to get everything.
Repository: oneal2000/PRAG
Branch: main
Commit: 43d7e1433830
Files: 132
Total size: 468.4 KB

Directory structure:
gitextract_zinswnzk/

├── README.md
├── all_prompt.md
├── configs/
│   ├── 2wikimultihopqa_llama3-8b-instruct.sh
│   ├── 2wikimultihopqa_llama3.2-1b-instruct.sh
│   ├── 2wikimultihopqa_qwen2.5-1.5b-instruct.sh
│   ├── complexwebquestions_llama3-8b-instruct.sh
│   ├── complexwebquestions_llama3.2-1b-instruct.sh
│   ├── complexwebquestions_qwen2.5-1.5b-instruct.sh
│   ├── hotpotqa_llama3-8b-instruct.sh
│   ├── hotpotqa_llama3.2-1b-instruct.sh
│   ├── hotpotqa_qwen2.5-1.5b-instruct.sh
│   ├── popqa_llama3-8b-instruct.sh
│   ├── popqa_llama3.2-1b-instruct.sh
│   └── popqa_qwen2.5-1.5b-instruct.sh
├── prep_elastic.py
├── requirements.txt
└── src/
    ├── augment.py
    ├── encode.py
    ├── fewshot/
    │   ├── 2wikimultihopqa.json
    │   └── hotpotqa.json
    ├── get_warmup_data.py
    ├── inference.py
    ├── prompt_template.py
    ├── retrieve/
    │   ├── beir/
    │   │   ├── .gitignore
    │   │   ├── .gitmodules
    │   │   ├── CONTRIBUTORS.txt
    │   │   ├── LICENSE
    │   │   ├── NOTICE.txt
    │   │   ├── README.md
    │   │   ├── beir/
    │   │   │   ├── __init__.py
    │   │   │   ├── datasets/
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── data_loader.py
    │   │   │   │   └── data_loader_hf.py
    │   │   │   ├── generation/
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── generate.py
    │   │   │   │   └── models/
    │   │   │   │       ├── __init__.py
    │   │   │   │       ├── auto_model.py
    │   │   │   │       └── tilde.py
    │   │   │   ├── logging.py
    │   │   │   ├── losses/
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── bpr_loss.py
    │   │   │   │   └── margin_mse_loss.py
    │   │   │   ├── reranking/
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── models/
    │   │   │   │   │   ├── __init__.py
    │   │   │   │   │   ├── cross_encoder.py
    │   │   │   │   │   └── mono_t5.py
    │   │   │   │   └── rerank.py
    │   │   │   ├── retrieval/
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── custom_metrics.py
    │   │   │   │   ├── evaluation.py
    │   │   │   │   ├── models/
    │   │   │   │   │   ├── __init__.py
    │   │   │   │   │   ├── bpr.py
    │   │   │   │   │   ├── dpr.py
    │   │   │   │   │   ├── sentence_bert.py
    │   │   │   │   │   ├── sparta.py
    │   │   │   │   │   ├── splade.py
    │   │   │   │   │   ├── tldr.py
    │   │   │   │   │   ├── unicoil.py
    │   │   │   │   │   └── use_qa.py
    │   │   │   │   ├── search/
    │   │   │   │   │   ├── __init__.py
    │   │   │   │   │   ├── base.py
    │   │   │   │   │   ├── dense/
    │   │   │   │   │   │   ├── __init__.py
    │   │   │   │   │   │   ├── exact_search.py
    │   │   │   │   │   │   ├── exact_search_multi_gpu.py
    │   │   │   │   │   │   ├── faiss_index.py
    │   │   │   │   │   │   ├── faiss_search.py
    │   │   │   │   │   │   └── util.py
    │   │   │   │   │   ├── lexical/
    │   │   │   │   │   │   ├── __init__.py
    │   │   │   │   │   │   ├── bm25_search.py
    │   │   │   │   │   │   └── elastic_search.py
    │   │   │   │   │   └── sparse/
    │   │   │   │   │       ├── __init__.py
    │   │   │   │   │       └── sparse_search.py
    │   │   │   │   └── train.py
    │   │   │   └── util.py
    │   │   ├── examples/
    │   │   │   ├── beir-pyserini/
    │   │   │   │   ├── Dockerfile
    │   │   │   │   ├── config.py
    │   │   │   │   ├── dockerhub.sh
    │   │   │   │   └── main.py
    │   │   │   ├── benchmarking/
    │   │   │   │   ├── benchmark_bm25.py
    │   │   │   │   ├── benchmark_bm25_ce_reranking.py
    │   │   │   │   └── benchmark_sbert.py
    │   │   │   ├── dataset/
    │   │   │   │   ├── README.md
    │   │   │   │   ├── download_dataset.py
    │   │   │   │   ├── md5.csv
    │   │   │   │   └── scrape_tweets.py
    │   │   │   ├── generation/
    │   │   │   │   ├── passage_expansion_tilde.py
    │   │   │   │   ├── query_gen.py
    │   │   │   │   ├── query_gen_and_train.py
    │   │   │   │   └── query_gen_multi_gpu.py
    │   │   │   └── retrieval/
    │   │   │       ├── README.md
    │   │   │       ├── evaluation/
    │   │   │       │   ├── README.md
    │   │   │       │   ├── custom/
    │   │   │       │   │   ├── evaluate_custom_dataset.py
    │   │   │       │   │   ├── evaluate_custom_dataset_files.py
    │   │   │       │   │   ├── evaluate_custom_metrics.py
    │   │   │       │   │   └── evaluate_custom_model.py
    │   │   │       │   ├── dense/
    │   │   │       │   │   ├── evaluate_ance.py
    │   │   │       │   │   ├── evaluate_bpr.py
    │   │   │       │   │   ├── evaluate_dim_reduction.py
    │   │   │       │   │   ├── evaluate_dpr.py
    │   │   │       │   │   ├── evaluate_faiss_dense.py
    │   │   │       │   │   ├── evaluate_sbert.py
    │   │   │       │   │   ├── evaluate_sbert_hf_loader.py
    │   │   │       │   │   ├── evaluate_sbert_multi_gpu.py
    │   │   │       │   │   ├── evaluate_tldr.py
    │   │   │       │   │   └── evaluate_useqa.py
    │   │   │       │   ├── late-interaction/
    │   │   │       │   │   └── README.md
    │   │   │       │   ├── lexical/
    │   │   │       │   │   ├── evaluate_anserini_bm25.py
    │   │   │       │   │   ├── evaluate_bm25.py
    │   │   │       │   │   └── evaluate_multilingual_bm25.py
    │   │   │       │   ├── reranking/
    │   │   │       │   │   ├── README.md
    │   │   │       │   │   ├── evaluate_bm25_ce_reranking.py
    │   │   │       │   │   ├── evaluate_bm25_monot5_reranking.py
    │   │   │       │   │   └── evaluate_bm25_sbert_reranking.py
    │   │   │       │   └── sparse/
    │   │   │       │       ├── evaluate_anserini_docT5query.py
    │   │   │       │       ├── evaluate_anserini_docT5query_parallel.py
    │   │   │       │       ├── evaluate_deepct.py
    │   │   │       │       ├── evaluate_sparta.py
    │   │   │       │       ├── evaluate_splade.py
    │   │   │       │       └── evaluate_unicoil.py
    │   │   │       └── training/
    │   │   │           ├── train_msmarco_v2.py
    │   │   │           ├── train_msmarco_v3.py
    │   │   │           ├── train_msmarco_v3_bpr.py
    │   │   │           ├── train_msmarco_v3_margin_MSE.py
    │   │   │           ├── train_sbert.py
    │   │   │           └── train_sbert_BM25_hardnegs.py
    │   │   ├── setup.cfg
    │   │   └── setup.py
    │   ├── readme.md
    │   └── retriever.py
    ├── root_dir_path.py
    ├── utils.py
    └── warmup_lora.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# Parametric RAG

📢 **News: this work has been accepted at SIGIR 2025!**




![Overall Analysis](assets/overall.png)

![Overall Analysis](assets/ParametricRAG.gif)


## Overview

**Welcome to the Official Repository of Parametric Retrieval-Augmented Generation (Parametric RAG)!**

This repository contains the code, datasets, and models used in our paper:
 **"Parametric Retrieval-Augmented Generation"**.



**If you find our project interesting or helpful, we would sincerely appreciate it if you could give us a star! Your support is a tremendous encouragement to us!**



#### What is Parametric RAG?

Parametric RAG introduces a new paradigm for retrieval-augmented generation by embedding external knowledge directly into the parametric space of Large Language Models (LLMs). This approach overcomes the limitations of traditional in-context RAG methods by:

- Reducing computational overhead by avoiding large context windows.

- Deeply integrating external knowledge into the Feed-Forward Networks (FFN) of LLMs for improved reasoning and synthesis.



#### What’s Included?

- End-to-end implementation of the Parametric RAG pipeline.
- Preprocessed benchmark datasets for experiments and scripts for customizing and adding new datasets.

## Reproduce Paper Results

In this repository, we demonstrate how to test the performance of Parametric RAG on various QA datasets. Specifically, follow these steps to run Parametric RAG:

- **Run the Data Augmentation Module**: This step corresponds to Section 3.2.1 *Self-Augmentation* in the original paper, where documents are transformed into a data-augmented dataset.
- **Generate Parametric Representations of Documents**: This step corresponds to Section 3.2.2 *Additional Parameter Training* in the original paper, where additional LoRA parameters are trained.
- **Inference**: Merge the parametric representations of relevant documents, insert them into the LLM, and use the updated LLM for inference.

All the prompts used in the experiment are displayed in the `all_prompt.md` file.

### Install Environment

```
conda create -n prag python=3.10.4
conda activate prag
pip install torch==2.1.0
pip install -r requirements.txt
```

Please change the `ROOT_DIR` variable in `src/root_dir_path.py` to the path of the folder where you store PRAG.

### Self-Augmentation

You can directly use the pre-augmented data file `data_aug.tar.gz`. To extract it, run the command `tar -xzvf data_aug.tar.gz` in your terminal.

If you want to perform data augmentation yourself, please process it as follows.

#### Prepare BM25 for retrieval

1. Download the Wikipedia dump from the [DPR repository](https://github.com/facebookresearch/DPR/blob/main/dpr/data/download_data.py#L32) using the following command

```bash
mkdir -p data/dpr
wget -O data/dpr/psgs_w100.tsv.gz https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
pushd data/dpr
gzip -d psgs_w100.tsv.gz
popd
```

2. Use Elasticsearch to index the Wikipedia dump

```bash
cd data
wget -O elasticsearch-8.15.0.tar.gz https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-8.15.0-linux-x86_64.tar.gz  # download Elasticsearch
tar zxvf elasticsearch-8.15.0.tar.gz
rm elasticsearch-8.15.0.tar.gz 
cd elasticsearch-8.15.0
nohup bin/elasticsearch &  # run Elasticsearch in background
cd ../..
python prep_elastic.py --data_path data/dpr/psgs_w100.tsv --index_name wiki  # build index
```

#### Download dataset

For 2WikiMultihopQA:

Download the [2WikiMultihopQA](https://www.dropbox.com/s/ms2m13252h6xubs/data_ids_april7.zip?e=1) dataset from its repository <https://www.dropbox.com/s/ms2m13252h6xubs/data_ids_april7.zip?e=1>. Unzip it and move the folder to `data/2wikimultihopqa`.

For HotpotQA:

```bash
mkdir -p data/hotpotqa
wget -P data/hotpotqa/ http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json
```

For PopQA:

Download the [PopQA](https://github.com/AlexTMallen/adaptive-retrieval?tab=readme-ov-file#popqa) dataset from its repository <https://github.com/AlexTMallen/adaptive-retrieval/blob/main/data/popQA.tsv>, and put the file `popQA.tsv` into folder `data/popqa`.

```bash
mkdir -p data/popqa
wget -P data/popqa https://raw.githubusercontent.com/AlexTMallen/adaptive-retrieval/main/data/popQA.tsv
```

For ComplexWebQuestions:

Download the [ComplexWebQuestions](https://www.tau-nlp.sites.tau.ac.il/compwebq) dataset from its repository <https://www.dropbox.com/scl/fo/nqujvpg2gc4y0ozkw3wgr/AOzjVEsdUhv2Fx2pamfJlSw?rlkey=746t7xehfqxf1zr867nxiq8aq&e=1>, and put the file `ComplexWebQuestions_dev.json` into folder `data/complexwebquestions`.

#### Data Augmentation:

```bash
python src/augment.py \
    --model_name llama3.2-1b-instruct \
    --dataset 2wikimultihopqa \
    --data_path data/2wikimultihopqa/ \
    --sample 300  \
    --topk 3
```

| **Parameter** | **Example/Options** |
| ------------------------------ | ---------------------------------------------------- |
| `model_name` | `llama3.2-1b-instruct`, `qwen2.5-1.5b-instruct`, `llama3-8b-instruct` |
| `dataset` | `2wikimultihopqa`, `hotpotqa`, `popqa`, `complexwebquestions` |
| `data_path` | folder to the saved data, such as `data/2wikimultihopqa` |
| `sample` | Number of questions to run |
| `topk` | retrieval number |

The results of data augmentation will be stored in the file `data_aug/{dataset}/{data_type}.json`.

If you want to apply data augmentation to a new dataset, provide the data as a JSON array. Each element in the array should include both a `question` and an `answer`, as shown in the example below.

```json
[
    {
        "question": "string",
        "answer": "string or list[string]"
    }
]
```

At this point, the input parameter `dataset` refers to the name of the dataset you’ve set, and `data_path` is the path to the JSON file mentioned above. The last filename in `data_path` will be treated as the `data_type`. The output file will be saved in `data_aug/{your_dataset_name}/{data_type}.json`.





### Document Parameterizing


![Methodology](assets/method.png)

By calling the `src/encode.py` file, you will generate a parameterized representation of the documents (LoRA) for the given dataset. The parameters for this file are as follows:

| **Parameter**                  | **Example/Options**                                  |
| ------------------------------ | ---------------------------------------------------- |
| `model_name`                   | `llama3.2-1b-instruct`, `qwen2.5-1.5b-instruct`, `llama3-8b-instruct` |
| `dataset`                      | `2wikimultihopqa`, `hotpotqa`, `popqa`, `complexwebquestions` |
| `data_type`                    | If not set, the entire dataset is used; otherwise, only the specified data type is used |
| `with_cot`                     | If included, generate a CoT |
| `sample`                        | Number of questions to run |
| `augment_model`                | Model used for data augmentation. If not set, the current model will be used for augmentation |
| `per_device_train_batch_size`, `num_train_epochs`, `learning_rate` | Training parameters |
| `lora_rank`, `lora_alpha`       | LoRA parameters, dropout will be set to 0 |

When running for the first time with a specific LoRA configuration, a randomly initialized set of LoRA parameters, `base_weight`, will be created. All subsequent training starts from this `base_weight`.

All generated parameters are stored in the `offline` folder. 
The specific location of the parameter files is as follows:

```plain
offline/
├── {model_name}/
│   └── rank={lora_rank}_alpha={lora_alpha}/
│       ├── base_weight/
│       └── {dataset}/
│           └── lr={learning_rate}_epoch={num_train_epochs}/
│               └── aug_model={augment_model}/
│                   └── {data_type}/
│                       └── data_{did}/
│                           └── passage_{pid}/
│                               └── parameters
```

The running parameters of the main experiments in the paper are listed in the `configs` folder.

### Generate

By calling the `src/inference.py` file, you will run inference on the given dataset using the parameterized document representations (LoRA), in-context retrieval, or both. The parameters for this file are as follows:

| **Parameter**                  | **Example/Options**                                  |
| ------------------------------ | ---------------------------------------------------- |
| `model_name`                   | `llama3.2-1b-instruct`, `qwen2.5-1.5b-instruct`, `llama3-8b-instruct` |
| `dataset`                      | `2wikimultihopqa`, `hotpotqa`, `popqa`, `complexwebquestions` |
| `data_type`                    | If not set, the entire dataset is used; otherwise, only the specified data type is used |
| `with_cot`                     | If included, generate a CoT |
| `sample`                        | Number of questions to run |
| `augment_model`                | Model used for data augmentation. If not set, the current model will be used for augmentation |
| `per_device_train_batch_size`, `num_train_epochs`, `learning_rate` | Training parameters |
| `lora_rank`, `lora_alpha`       | LoRA parameters, dropout will be set to 0 |
| `max_new_tokens` | Maximum number of tokens to generate |
| `inference_method` | "icl" is naive RAG, "prag" is our method, and "combine" is using both methods together |

All generated results are stored in the `output` folder. The specific location of the output files is as follows:

```plain
output/
├── {model_name}/
│   └── rank={lora_rank}_alpha={lora_alpha}/
│       └── {dataset}/
│           └── lr={learning_rate}_epoch={num_train_epochs}/
│               └── aug_model={augment_model}/
│                   └── {inference_method}/
│                       └── {data_type}/
│                           ├── config.json
│                           ├── predict.json
│                           └── result.txt
```

Also, the running parameters of the main experiments in the paper are listed in the `configs` folder.

## Warm up LoRA

After calling `python src/get_warmup_data.py`, the initialization training data for finetuning will be generated from the **latter** part of the dataset. The data generation code ensures that there is no data leakage. 

Then, run the following commands to train two base LoRA weights:


```bash
# the training used 600 data points 
python src/warmup_lora.py \
    --model_name llama3.2-1b-instruct \
    --per_device_train_batch_size 1 \
    --num_train_epochs 1 \
    --learning_rate 3e-4  \
    --block_size 3000 \
    --lora_rank 2 \
    --lora_alpha 32 \
    --with_cot 

# the training used 2000 data points  
python src/warmup_lora.py \
    --model_name llama3.2-1b-instruct \
    --per_device_train_batch_size 1 \
    --num_train_epochs 1 \
    --learning_rate 3e-4  \
    --lora_rank 2 \
    --lora_alpha 32 \
    --block_size 3000  
```


================================================
FILE: all_prompt.md
================================================
# Prompt Design for Our Work

This document contains all the prompts used in our work, categorized and explained for better understanding. The prompts are organized into the following sections:

- **Prompt for Document Augmentation (Section 3.2.1):**
  Specific prompts for augmenting documents, including tasks like rewriting and generating question-answer pairs.



- **Prompt for Experimental Datasets:**
  Specific prompts for the datasets used in our experiments, including 2WikiMultihopQA, HotpotQA, PopQA, and ComplexWebQuestions.





## Prompt for Document Augmentation

### Document Rewriting
Details of the prompts used to rewrite or transform documents for augmentation purposes.

```plain
Rewrite the following passage. While keeping the entities, proper nouns, and key details such as names, locations, and terminology intact, create a new version of the text that expresses the same ideas in a different way. Make sure the revised passage is distinct from the original one, but preserves the core meaning and relevant information.
{passage}
```

### QA Generation
Explanation of the prompts used to generate question-answer pairs for document augmentation.

```plain
I will provide a passage of text, and you need to generate three different questions based on the content of this passage. Each question should be answerable using the information provided in the passage. Additionally, please provide an appropriate answer for each question derived from the passage.
You need to generate the question and answer in the following format:
[
    {
        "question": "What is the capital of France?",
        "answer": "Paris"
        "full_answer": "The capital of France is Paris."
    }, 
]
This list should have at least three elements. You only need to output this list in the above format.
Passage:
{passage}
```

## Prompt for Experimental Datasets

Following prior works such as FLARE, DRAGIN, SEAKR, and DRAD, we adopt the few-shot prompting template introduced in IR-CoT, which includes exemplars to guide the reasoning process of LLMs. However, IR-CoT only provides prompt templates with few-shot examples for 2WikiMultihopQA and HotpotQA. Consequently, to be consistent with previous works, we employ the same few-shot prompting strategy for these two datasets.

In contrast, IR-CoT does not provide predefined few-shot templates for PopQA and ComplexWebQuestions. We use zero-shot prompting for these datasets to ensure a fair comparison without introducing dataset-specific prompt engineering.

**Furthermore, we ensured that P-RAG and all the baselines share the same prompt template within each dataset, ensuring a fair evaluation.**



### 2WikiMultihopQA
Description of the prompts designed specifically for 2WikiMultihopQA dataset tasks.

```plain
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
You should reference the knowledge provided below and combine it with your own knowledge to answer the question. Please follow the format of the example I provided above.
Here are some examples about how to answer the questions.
Question: When did the director of film Hypocrite (Film) die?
Answer: The film Hypocrite was directed by Miguel Morayta. Miguel Morayta died on 19 June 2013. So the answer is 19 June 2013.

Question: Are both Kurram Garhi and Trojkrsti located in the same country?
Answer: Kurram Garhi is located in the country of Pakistan. Trojkrsti is located in the country of Republic of Macedonia. Thus, they are not in the same country. So the answer is no.

Question: Do director of film Coolie No. 1 (1995 Film) and director of film The Sensational Trial have the same nationality?
Answer: Coolie No. 1 (1995 film) was directed by David Dhawan. The Sensational Trial was directed by Karl Freund. David Dhawan's nationality is India. Karl Freund's nationality is Germany. Thus, they do not have the same nationality. So the answer is no.

Question: Who is Boraqchin (Wife Of Ögedei)'s father-in-law?
Answer: Boraqchin is married to Ögedei Khan. Ögedei Khan's father is Genghis Khan. Thus, Boraqchin's father-in-law is Genghis Khan. So the answer is Genghis Khan.

Question: Who was born first out of Martin Hodge and Ivania Martinich?
Answer: Martin Hodge was born on 4 February 1959. Ivania Martinich was born on 25 July 1995. Thus, Martin Hodge was born first. So the answer is Martin Hodge.

Question: When did the director of film Laughter In Hell die?
Answer: The film Laughter In Hell was directed by Edward L. Cahn. Edward L. Cahn died on August 25, 1963. So the answer is August 25, 1963.

Question: Which film has the director died later, The Gal Who Took the West or Twenty Plus Two?
Answer: The film Twenty Plus Two was directed by Joseph M. Newman. The Gal Who Took the West was directed by Frederick de Cordova. Joseph M. Newman died on January 23, 2006. Fred de Cordova died on September 15, 2001. Thus, the person to die later from the two is Twenty Plus Two. So the answer is Twenty Plus Two.

Question: Who is the grandchild of Krishna Shah (Nepalese Royal)?
Answer: Krishna Shah has a child named Rudra Shah. Rudra Shah has a child named Prithvipati Shah. Thus, Krishna Shah has a grandchild named Prithvipati Shah. So the answer is Prithvipati Shah.

Here are some reference.
Passage 1: Polish-Russian War (film) Polish-Russian War (Wojna polsko-ruska) is a 2009 Polish film directed by Xawery Żuławski based on the novel Polish-Russian War under the white-red flag by Dorota Masłowska. The film's events take place over several days and they are set in the present time in a large Polish city. The main character is a bandit, a Polish dres (a Polish chav) called "Strong" (Borys Szyc), who does not work or study, and who frequently gets into conflict with the law and is in love with Magda (Roma Gąsiorowska). The relationship is not going well. "Strong" is insanely jealous of
Passage 2: film was shot between May 6 and 18 June 2008 in locations of Warsaw, Wejherowo, Sopot and Gdynia outskirts. The film premiered on May 22, 2009. The budget of Polish-Russian War amounted to approx. 4 million zlotys. The creators of the music for the film are Jan Komar, Filip Kuncewicz, Liroy, Mateusz Łapot and Jarosław Karczmarczyk. The soundtrack also included the following songs: Polish-Russian War (film) Polish-Russian War (Wojna polsko-ruska) is a 2009 Polish film directed by Xawery Żuławski based on the novel Polish-Russian War under the white-red flag by Dorota Masłowska. The film's events take place over several days
Passage 3: she has been in contact with Olga's dead mother and she asks the Attorney to participate in a seance. Body (2015 Polish film) Body () is a 2015 Polish drama film directed by Małgorzata Szumowska. It was screened in the main competition section of the 65th Berlin International Film Festival where Szumowska won the Silver Bear for Best Director. The film also received the Golden Lions Award at the 2015 Gdynia Film Festival and the People's Choice Award at the 2016 European Film Awards. Olga who struggles with anorexia is sent to psychiatric hospital by the Attorney, where she is


Let's think step by step. Answer the questions in the same format as above.
Question: Who is the mother of the director of film Polish-Russian War (Film)?<|im_end|>
<|im_start|>assistant
Answer: 
```

### HotpotQA
Details about the prompts tailored for the HotpotQA dataset.

```plain
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
You should reference the knowledge provided below and combine it with your own knowledge to answer the question. Please follow the format of the example I provided above.
Here are some examples about how to answer the questions.
Question: Jeremy Theobald and Christopher Nolan share what profession?
Answer: Jeremy Theobald is an actor and producer. Christopher Nolan is a director, producer, and screenwriter. Therefore, they both share the profession of being a producer. So the answer is producer.

Question: What film directed by Brian Patrick Butler was inspired by a film directed by F.W. Murnau?
Answer: Brian Patrick Butler directed the film The Phantom Hour. The Phantom Hour was inspired by the films such as Nosferatu and The Cabinet of Dr. Caligari. Of these Nosferatu was directed by F.W. Murnau. So the answer is The Phantom Hour..

Question: How many episodes were in the South Korean television series in which Ryu Hye-young played Bo-ra?
Answer: The South Korean television series in which Ryu Hye-young played Bo-ra is Reply 1988. The number of episodes Reply 1988 has is 20. So the answer is 20.

Question: Were Lonny and Allure both founded in the 1990s?
Answer: Lonny (magazine) was founded in 2009. Allure (magazine) was founded in 1991. Thus, of the two, only Allure was founded in 1990s. So the answer is no.

Question: Vertical Limit stars which actor who also played astronaut Alan Shepard in "The Right Stuff"?
Answer: The actor who played astronaut Alan Shepard in "The Right Stuff" is Scott Glenn. The movie Vertical Limit also starred Scott Glenn. So the answer is Scott Glenn.

Question: What was the 2014 population of the city where Lake Wales Medical Center is located?
Answer: Lake Wales Medical Center is located in the city of Polk County, Florida. The population of Polk County in 2014 was 15,140. So the answer is 15,140.

Question: Who was born first? Jan de Bont or Raoul Walsh?
Answer: Jan de Bont was born on 22 October 1943. Raoul Walsh was born on March 11, 1887. Thus, Raoul Walsh was born the first. So the answer is Raoul Walsh.

Question: In what country was Lost Gravity manufactured?
Answer: The Lost Gravity (roller coaster) was manufactured by Mack Rides. Mack Rides is a German company. So the answer is Germany.

Question: Which of the following had a debut album entitled "We Have an Emergency": Hot Hot Heat or The Operation M.D.?
Answer: The debut album of the band "Hot Hot Heat" was "Make Up the Breakdown". The debut album of the band "The Operation M.D." was "We Have an Emergency". So the answer is The Operation M.D..

Question: How many awards did the "A Girl Like Me" singer win at the American Music Awards of 2012?
Answer: The singer of "A Girl Like Me" singer is Rihanna. In the American Music Awards of 2012, Rihana won one award. So the answer is one.

Question: The actor that stars as Joe Proctor on the series "Power" also played a character on "Entourage" that has what last name?
Answer: The actor that stars as Joe Proctor on the series "Power" is Jerry Ferrara. Jerry Ferrara also played a character on Entourage named Turtle Assante. Thus, Turtle Assante's last name is Assante. So the answer is Assante.

Here are some reference.
Passage 1: and has two children. Critical, public and commercial reception to films Derrickson has directed as of November 13, 2016. Scott Derrickson Scott Derrickson (born July 16, 1966) is an American director, screenwriter and producer. He lives in Los Angeles, California. Derrickson is best known for directing numerous horror films, such as "The Exorcism of Emily Rose" (2005), "Sinister" (2012), and "Deliver Us From Evil" (2014), as well as the Marvel Cinematic Universe superhero film "Doctor Strange" (2016). Derrickson grew up in Denver, Colorado. He graduated from Biola University with a B.A. in Humanities, with an emphasis on literature and philosophy,
Passage 2: Scott Derrickson Scott Derrickson (born July 16, 1966) is an American director, screenwriter and producer. He lives in Los Angeles, California. Derrickson is best known for directing numerous horror films, such as "The Exorcism of Emily Rose" (2005), "Sinister" (2012), and "Deliver Us From Evil" (2014), as well as the Marvel Cinematic Universe superhero film "Doctor Strange" (2016). Derrickson grew up in Denver, Colorado. He graduated from Biola University with a B.A. in Humanities, with an emphasis on literature and philosophy, and a B.A. in communications, with an emphasis on film, and a minor in theological studies. He earned his
Passage 3: M.A. in film production from USC School of Cinematic Arts. Derrickson co-wrote and directed the film "The Exorcism of Emily Rose" which was loosely based on a true story about Anneliese Michel. The film won the 2005 Saturn Award for Best Horror or Thriller Film and in 2006 was named in the Chicago Film Critics Association's list of the "Top 100 Scariest Films Ever Made." Theatrical box office gross for "The Exorcism of Emily Rose" was over $144 million worldwide. That same year, Derrickson wrote "Land of Plenty" for director Wim Wenders, an independent drama starring Michelle Williams. Derrickson next


Let's think step by step. Answer the questions in the same format as above.
Question: Were Scott Derrickson and Ed Wood of the same nationality?<|im_end|>
<|im_start|>assistant
Answer: 
```

### PopQA
Explanation of the prompts created for PopQA dataset experiments.

```plain
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
You should answer the question by referring to the knowledge provided below and integrating your own knowledge.
Passage 1: George Rankin Major General George James Rankin, (1 May 1887 – 28 December 1957) was an Australian soldier and politician. He served in both the House of Representatives and the Senate, representing the Country Party of Australia. Rankin was born at Bamawm, Victoria, the tenth child of Irish farmer James Rankin and Sarah, née Gallagher. He attended the local state school and became a farmer. In 1907, he joined the Militia, and was commissioned in the 9th Light Horse Regiment in 1909. He married Annie Isabella Oliver at Rochester, Victoria on 7 July 1912. In 1914, he was appointed a
Passage 2: was buried, and was survived by his wife. George Rankin Major General George James Rankin, (1 May 1887 – 28 December 1957) was an Australian soldier and politician. He served in both the House of Representatives and the Senate, representing the Country Party of Australia. Rankin was born at Bamawm, Victoria, the tenth child of Irish farmer James Rankin and Sarah, née Gallagher. He attended the local state school and became a farmer. In 1907, he joined the Militia, and was commissioned in the 9th Light Horse Regiment in 1909. He married Annie Isabella Oliver at Rochester, Victoria on 7
Passage 3: George Claus Rankin Sir George Claus Rankin PC (12 August 1877 – 8 April 1946) was a British judge in India. Rankin was born in Lamington, Lanarkshire, the son of Rev. Robert Rankin. He was educated at the University of Edinburgh and Trinity College, Cambridge. He as admitted at Lincoln's Inn and called to the bar in 1904. He served in the First World War with the Royal Garrison Artillery. He went to India in 1918 and served first as a puisne judge of the High Court of Calcutta, and then as Chief Justice, from 1926 to 1934. While in


Question: What is George Rankin's occupation?<|im_end|>
<|im_start|>assistant
The answer is 
```

### ComplexWebQuestions
Example of the prompt used in the ComplexWebQuestions experiments.

```plain
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
You should answer the question by referring to the knowledge provided below and integrating your own knowledge.
Passage 1: party such as the United Nations. The 1974 Interim Constitution Act was passed by the 48-member Azad Jammu and Kashmir unicameral assembly. Azad Jammu and Kashmir (AJK) is a self-governing state under Pakistani control, but under Pakistan's constitution the state is informally part of the country. Pakistan is administering the region as a self-governing territory rather than incorporating it in the federation since the UN-mandated ceasefire. Azad Kashmir has its own elected President, Prime Minister, Legislative Assembly, High Court, with Azam Khan as its present chief justice, and official flag. Azad Kashmir's financial matters, i.e., budget and tax affairs, are
Passage 2: and new research programmes have already been launched. However Institute of Geology is ranked 2nd in whole country based upon diverse field work and researches in Masters and Ph.D programmes. University of Azad Jammu and Kashmir The University of Azad Jammu and Kashmir is a university at Muzaffarabad, Azad Jammu and Kashmir, Pakistan. It was established in 1980, and is currently ranked at No.14 in HEC ranking of General category universities in Pakistan. The University of Azad Jammu and Kashmir is a multi-campus, multi-discipline university. The University of Azad Jammu and Kashmir has been making steady progress in both academic
Passage 3: organisations as "Pakistan administered Kashmir". Azad Kashmir is one-sixth of the size of Gilgit-Baltistan. The territory also borders Pakistan's Punjab province to the south and Khyber Pakhtunkhwa province to the west. To the east, Azad Kashmir is separated from the state of Jammu and Kashmir by the Line of Control, the "de facto" border between India and Pakistan. Azad Kashmir has a total area of , and a total population of 4,045,366 as per the 2017 Census. The territory has a parliamentary form of government modeled after the Westminster system, with its capital located at Muzaffarabad. The President is the


Question: Who was the president in 1980 of the country that has Azad Kashmir?<|im_end|>
<|im_start|>assistant
The answer is 
```

================================================
FILE: configs/2wikimultihopqa_llama3-8b-instruct.sh
================================================
#!/usr/bin/env bash
# PRAG run: 2wikimultihopqa with llama3-8b-instruct.
#   1. src/encode.py    -- trains LoRA adapters
#   2. src/inference.py -- answers the questions (--inference_method=combine)
# encode.py names its output directory after the model, dataset, LoRA
# rank/alpha, learning rate, epochs and CoT mode, so the values shared by both
# steps are defined once below (presumably inference.py reads the same path).
# `set -e` stops the run when encoding fails instead of running inference anyway.
set -e

MODEL=llama3-8b-instruct
DATASET=2wikimultihopqa
SAMPLE=300
EPOCHS=1
LR=0.0003
LORA_RANK=2
LORA_ALPHA=32

python3 src/encode.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --per_device_train_batch_size=1 \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --with_cot

python3 src/inference.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --max_new_tokens=128 \
    --inference_method=combine \
    --with_cot

================================================
FILE: configs/2wikimultihopqa_llama3.2-1b-instruct.sh
================================================
#!/usr/bin/env bash
# PRAG run: 2wikimultihopqa with llama3.2-1b-instruct.
#   1. src/encode.py    -- trains LoRA adapters
#   2. src/inference.py -- answers the questions (--inference_method=combine)
# encode.py names its output directory after the model, dataset, LoRA
# rank/alpha, learning rate, epochs and CoT mode, so the values shared by both
# steps are defined once below (presumably inference.py reads the same path).
# `set -e` stops the run when encoding fails instead of running inference anyway.
set -e

MODEL=llama3.2-1b-instruct
DATASET=2wikimultihopqa
SAMPLE=300
EPOCHS=1
LR=0.0003
LORA_RANK=2
LORA_ALPHA=32

python3 src/encode.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --per_device_train_batch_size=1 \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --with_cot

python3 src/inference.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --max_new_tokens=128 \
    --inference_method=combine \
    --with_cot

================================================
FILE: configs/2wikimultihopqa_qwen2.5-1.5b-instruct.sh
================================================
#!/usr/bin/env bash
# PRAG run: 2wikimultihopqa with qwen2.5-1.5b-instruct.
#   1. src/encode.py    -- trains LoRA adapters
#   2. src/inference.py -- answers the questions (--inference_method=combine)
# encode.py names its output directory after the model, dataset, LoRA
# rank/alpha, learning rate, epochs and CoT mode, so the values shared by both
# steps are defined once below (presumably inference.py reads the same path).
# `set -e` stops the run when encoding fails instead of running inference anyway.
set -e

MODEL=qwen2.5-1.5b-instruct
DATASET=2wikimultihopqa
SAMPLE=300
EPOCHS=1
LR=0.0003
LORA_RANK=2
LORA_ALPHA=32

python3 src/encode.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --per_device_train_batch_size=1 \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --with_cot

python3 src/inference.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --max_new_tokens=128 \
    --inference_method=combine \
    --with_cot

================================================
FILE: configs/complexwebquestions_llama3-8b-instruct.sh
================================================
#!/usr/bin/env bash
# PRAG run: complexwebquestions with llama3-8b-instruct (direct answers, no CoT).
#   1. src/encode.py    -- trains LoRA adapters
#   2. src/inference.py -- answers the questions (--inference_method=combine)
# encode.py names its output directory after the model, dataset, LoRA
# rank/alpha, learning rate, epochs and CoT mode, so the values shared by both
# steps are defined once below (presumably inference.py reads the same path).
# `set -e` stops the run when encoding fails instead of running inference anyway.
set -e

MODEL=llama3-8b-instruct
DATASET=complexwebquestions
SAMPLE=300
EPOCHS=1
LR=0.0003
LORA_RANK=2
LORA_ALPHA=32

python3 src/encode.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --per_device_train_batch_size=1 \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA"

python3 src/inference.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --max_new_tokens=20 \
    --inference_method=combine

================================================
FILE: configs/complexwebquestions_llama3.2-1b-instruct.sh
================================================
#!/usr/bin/env bash
# PRAG run: complexwebquestions with llama3.2-1b-instruct (direct answers, no CoT).
#   1. src/encode.py    -- trains LoRA adapters
#   2. src/inference.py -- answers the questions (--inference_method=combine)
# encode.py names its output directory after the model, dataset, LoRA
# rank/alpha, learning rate, epochs and CoT mode, so the values shared by both
# steps are defined once below (presumably inference.py reads the same path).
# `set -e` stops the run when encoding fails instead of running inference anyway.
set -e

MODEL=llama3.2-1b-instruct
DATASET=complexwebquestions
SAMPLE=300
EPOCHS=1
LR=0.0003
LORA_RANK=2
LORA_ALPHA=32

python3 src/encode.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --per_device_train_batch_size=1 \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA"

python3 src/inference.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --max_new_tokens=20 \
    --inference_method=combine

================================================
FILE: configs/complexwebquestions_qwen2.5-1.5b-instruct.sh
================================================
#!/usr/bin/env bash
# PRAG run: complexwebquestions with qwen2.5-1.5b-instruct (direct answers, no CoT).
#   1. src/encode.py    -- trains LoRA adapters
#   2. src/inference.py -- answers the questions (--inference_method=combine)
# encode.py names its output directory after the model, dataset, LoRA
# rank/alpha, learning rate, epochs and CoT mode, so the values shared by both
# steps are defined once below (presumably inference.py reads the same path).
# `set -e` stops the run when encoding fails instead of running inference anyway.
set -e

MODEL=qwen2.5-1.5b-instruct
DATASET=complexwebquestions
SAMPLE=300
EPOCHS=1
LR=0.0003
LORA_RANK=2
LORA_ALPHA=32

python3 src/encode.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --per_device_train_batch_size=1 \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA"

python3 src/inference.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --max_new_tokens=20 \
    --inference_method=combine

================================================
FILE: configs/hotpotqa_llama3-8b-instruct.sh
================================================
#!/usr/bin/env bash
# PRAG run: hotpotqa with llama3-8b-instruct.
#   1. src/encode.py    -- trains LoRA adapters
#   2. src/inference.py -- answers the questions (--inference_method=combine)
# encode.py names its output directory after the model, dataset, LoRA
# rank/alpha, learning rate, epochs and CoT mode, so the values shared by both
# steps are defined once below (presumably inference.py reads the same path).
# `set -e` stops the run when encoding fails instead of running inference anyway.
set -e

MODEL=llama3-8b-instruct
DATASET=hotpotqa
SAMPLE=300
EPOCHS=1
LR=0.0003
LORA_RANK=2
LORA_ALPHA=32

python3 src/encode.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --per_device_train_batch_size=1 \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --with_cot

python3 src/inference.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --max_new_tokens=128 \
    --inference_method=combine \
    --with_cot

================================================
FILE: configs/hotpotqa_llama3.2-1b-instruct.sh
================================================
#!/usr/bin/env bash
# PRAG run: hotpotqa with llama3.2-1b-instruct.
#   1. src/encode.py    -- trains LoRA adapters
#   2. src/inference.py -- answers the questions (--inference_method=combine)
# encode.py names its output directory after the model, dataset, LoRA
# rank/alpha, learning rate, epochs and CoT mode, so the values shared by both
# steps are defined once below (presumably inference.py reads the same path).
# `set -e` stops the run when encoding fails instead of running inference anyway.
set -e

MODEL=llama3.2-1b-instruct
DATASET=hotpotqa
SAMPLE=300
EPOCHS=1
LR=0.0003
LORA_RANK=2
LORA_ALPHA=32

python3 src/encode.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --per_device_train_batch_size=1 \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --with_cot

python3 src/inference.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --max_new_tokens=128 \
    --inference_method=combine \
    --with_cot

================================================
FILE: configs/hotpotqa_qwen2.5-1.5b-instruct.sh
================================================
#!/usr/bin/env bash
# PRAG run: hotpotqa with qwen2.5-1.5b-instruct.
#   1. src/encode.py    -- trains LoRA adapters
#   2. src/inference.py -- answers the questions (--inference_method=combine)
# encode.py names its output directory after the model, dataset, LoRA
# rank/alpha, learning rate, epochs and CoT mode, so the values shared by both
# steps are defined once below (presumably inference.py reads the same path).
# `set -e` stops the run when encoding fails instead of running inference anyway.
set -e

MODEL=qwen2.5-1.5b-instruct
DATASET=hotpotqa
SAMPLE=300
EPOCHS=2
LR=0.0003
LORA_RANK=2
LORA_ALPHA=32

python3 src/encode.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --per_device_train_batch_size=1 \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --with_cot

python3 src/inference.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --max_new_tokens=128 \
    --inference_method=combine \
    --with_cot

================================================
FILE: configs/popqa_llama3-8b-instruct.sh
================================================
#!/usr/bin/env bash
# PRAG run: popqa with llama3-8b-instruct (direct answers, no CoT).
#   1. src/encode.py    -- trains LoRA adapters
#   2. src/inference.py -- answers the questions (--inference_method=combine)
# encode.py names its output directory after the model, dataset, LoRA
# rank/alpha, learning rate, epochs and CoT mode, so the values shared by both
# steps are defined once below (presumably inference.py reads the same path).
# `set -e` stops the run when encoding fails instead of running inference anyway.
set -e

MODEL=llama3-8b-instruct
DATASET=popqa
SAMPLE=300
EPOCHS=2
LR=0.0003
LORA_RANK=2
LORA_ALPHA=32

python3 src/encode.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --per_device_train_batch_size=1 \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA"

python3 src/inference.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --max_new_tokens=20 \
    --inference_method=combine

================================================
FILE: configs/popqa_llama3.2-1b-instruct.sh
================================================
#!/usr/bin/env bash
# PRAG run: popqa with llama3.2-1b-instruct (direct answers, no CoT).
#   1. src/encode.py    -- trains LoRA adapters
#   2. src/inference.py -- answers the questions (--inference_method=combine)
# encode.py names its output directory after the model, dataset, LoRA
# rank/alpha, learning rate, epochs and CoT mode, so the values shared by both
# steps are defined once below (presumably inference.py reads the same path).
# `set -e` stops the run when encoding fails instead of running inference anyway.
set -e

MODEL=llama3.2-1b-instruct
DATASET=popqa
SAMPLE=300
EPOCHS=2
LR=0.0003
LORA_RANK=2
LORA_ALPHA=32

python3 src/encode.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --per_device_train_batch_size=1 \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA"

python3 src/inference.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --max_new_tokens=20 \
    --inference_method=combine

================================================
FILE: configs/popqa_qwen2.5-1.5b-instruct.sh
================================================
#!/usr/bin/env bash
# PRAG run: popqa with qwen2.5-1.5b-instruct (direct answers, no CoT).
#   1. src/encode.py    -- trains LoRA adapters
#   2. src/inference.py -- answers the questions (--inference_method=combine)
# encode.py names its output directory after the model, dataset, LoRA
# rank/alpha, learning rate, epochs and CoT mode, so the values shared by both
# steps are defined once below (presumably inference.py reads the same path).
# `set -e` stops the run when encoding fails instead of running inference anyway.
set -e

MODEL=qwen2.5-1.5b-instruct
DATASET=popqa
SAMPLE=300
EPOCHS=2
LR=0.0003
LORA_RANK=2
LORA_ALPHA=32

python3 src/encode.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --per_device_train_batch_size=1 \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA"

python3 src/inference.py \
    --model_name="$MODEL" \
    --dataset="$DATASET" \
    --sample="$SAMPLE" \
    --num_train_epochs="$EPOCHS" \
    --learning_rate="$LR" \
    --lora_rank="$LORA_RANK" \
    --lora_alpha="$LORA_ALPHA" \
    --max_new_tokens=20 \
    --inference_method=combine

================================================
FILE: prep_elastic.py
================================================
import argparse
import glob
import time
import csv
from tqdm import tqdm
from src.retrieve.beir.beir.retrieval.search.lexical.elastic_search import ElasticSearch

def build_elasticsearch(
    beir_corpus_file_pattern: str,
    index_name: str,
):
    """(Re)build an Elasticsearch index from BEIR-style TSV corpus files.

    Every file matching ``beir_corpus_file_pattern`` is read as a
    tab-separated file whose first row is a header and whose remaining rows
    are ``id <TAB> text <TAB> title``. The index ``index_name`` on
    ``http://localhost:9200`` is deleted, recreated and then bulk-filled with
    one document per row.

    Args:
        beir_corpus_file_pattern: glob pattern matching the corpus TSV file(s).
        index_name: name of the Elasticsearch index to (re)create.
    """
    # Sorted so the files are always indexed in the same order.
    beir_corpus_files = sorted(glob.glob(beir_corpus_file_pattern))
    print(f'#files {len(beir_corpus_files)}')
    config = {
        'hostname': 'http://localhost:9200',
        'index_name': index_name,
        'keys': {'title': 'title', 'body': 'txt'},
        'timeout': 100,
        'retry_on_timeout': True,
        'maxsize': 24,
        'number_of_shards': 'default',
        'language': 'english',
    }
    es = ElasticSearch(config)

    # Recreate the index from scratch; the pause presumably lets the cluster
    # finish the deletion before the index is created again.
    print(f'create index {index_name}')
    es.delete_index()
    time.sleep(5)
    es.create_index()

    def generate_actions():
        """Yield one bulk 'index' action per corpus row."""
        for beir_corpus_file in beir_corpus_files:
            # newline='' is what the csv module expects; an explicit utf-8
            # encoding avoids depending on the platform's locale.
            with open(beir_corpus_file, 'r', encoding='utf-8', newline='') as fin:
                reader = csv.reader(fin, delimiter='\t')
                # Skip the header. The default avoids StopIteration on an empty
                # file, which inside a generator turns into RuntimeError (PEP 479).
                next(reader, None)
                for row in reader:
                    _id, text, title = row[0], row[1], row[2]
                    es_doc = {
                        '_id': _id,
                        '_op_type': 'index',
                        # NOTE(review): confirm bulk_add_to_index treats 'refresh'
                        # as action metadata rather than a document field.
                        'refresh': 'wait_for',
                        config['keys']['title']: title,
                        config['keys']['body']: text,
                    }
                    yield es_doc

    # Bulk-index all rows, reporting progress per document.
    progress = tqdm(unit='docs')
    es.bulk_add_to_index(
        generate_actions=generate_actions(),
        progress=progress)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Build an Elasticsearch index from BEIR-format TSV corpus files.')
    # Both arguments are required: with the old None defaults, glob.glob(None)
    # raised a TypeError and the index name would have been None.
    parser.add_argument('--data_path', type=str, required=True,
                        help='glob pattern matching the corpus TSV file(s)')
    parser.add_argument("--index_name", type=str, required=True,
                        help="name of the Elasticsearch index to (re)create")
    args = parser.parse_args()
    build_elasticsearch(args.data_path, index_name=args.index_name)

================================================
FILE: requirements.txt
================================================
torch==1.13.1
transformers==4.44.2
elasticsearch==8.15.0
peft==0.13.2
pandas==1.5.3
numpy==1.26.4
faiss-cpu==1.8.0
termcolor

================================================
FILE: src/augment.py
================================================
import os
import json
import random
import argparse
import pandas as pd
from tqdm import tqdm

from retrieve.retriever import bm25_retrieve
from utils import get_model, model_generate
from root_dir_path import ROOT_DIR

# Seed Python's `random` module at import time so runs are repeatable.
random.seed(42)


def load_popqa(data_path):
    """Load the PopQA questions from ``<data_path>/popQA.tsv``.

    Each row becomes ``{"test_id", "question", "answer"}``. ``answer`` is the
    ``obj`` column followed by the aliases in the ``o_aliases`` column, which
    holds a Python list literal.

    Returns:
        ``{"total": [...]}``: a single split holding every row.
    """
    import ast

    data_path = os.path.join(data_path, "popQA.tsv")
    dataset = pd.read_csv(data_path, sep="\t")
    new_dataset = []
    for did in range(len(dataset)):
        data = dataset.iloc[did]
        question = data["question"]
        # literal_eval accepts only Python literals. eval() would execute any
        # code found in the data file.
        aliases = ast.literal_eval(data["o_aliases"])
        new_dataset.append({
            "test_id": did,
            "question": question,
            "answer": [data["obj"]] + aliases,
        })
    return {"total": new_dataset}


def load_complexwebquestions(data_path):
    """Load ComplexWebQuestions from ``<data_path>/ComplexWebQuestions_dev.json``.

    Each item becomes ``{"test_id", "question", "answer"}``. ``answer`` is the
    deduplicated list of every gold answer and its aliases.

    Returns:
        ``{"total": [...]}``: a single split holding every question.
    """
    data_path = os.path.join(data_path, "ComplexWebQuestions_dev.json")
    with open(data_path, "r", encoding="utf-8") as fin:
        dataset = json.load(fin)
    new_dataset = []
    for did, data in enumerate(dataset):
        answer = []
        for ans in data["answers"]:
            answer.append(ans["answer"])
            answer.extend(ans["aliases"])
        # Order-preserving dedupe. list(set(...)) gave an order that depended
        # on string hash randomisation and so changed between runs.
        answer = list(dict.fromkeys(answer))
        new_dataset.append({
            "test_id": did,
            "question": data["question"],
            "answer": answer,
        })
    return {"total": new_dataset}


def load_2wikimultihopqa(data_path):
    """Load 2WikiMultihopQA from ``<data_path>/dev.json`` and ``id_aliases.json``.

    ``id_aliases.json`` is read one JSON object per line (``Q_id`` ->
    ``aliases``). A question whose ``answer_id`` is non-empty takes that alias
    list as its answer; otherwise the raw ``answer`` is used. Golden passages
    are the context paragraphs named by ``supporting_facts``, with each
    paragraph's sentences joined by a single space. They are not
    deduplicated, so a paragraph appears once per supporting fact.

    Returns:
        ``{"total": [...all questions...], <type>: [...questions of that type...], ...}``
        The per-type lists share the same dict objects as ``"total"``.
    """
    with open(os.path.join(data_path, "dev.json"), "r") as fin:
        records = json.load(fin)
    with open(os.path.join(data_path, "id_aliases.json"), "r") as fin:
        alias_map = {entry["Q_id"]: entry["aliases"] for entry in map(json.loads, fin)}

    everything = []
    grouped = {}
    for position, record in enumerate(records):
        answer_id = record["answer_id"]
        item = {
            "qid": record["_id"],
            "test_id": position,
            "question": record["question"],
            "answer": alias_map[answer_id] if answer_id else record["answer"],
        }
        paragraphs = {title: " ".join(sentences) for title, sentences in record["context"]}
        item["golden_passages"] = [paragraphs[title] for title, _sent_id in record["supporting_facts"]]
        item["type"] = record["type"]
        everything.append(item)
        grouped.setdefault(record["type"], []).append(item)

    merged = {"total": everything}
    merged.update(grouped)
    return merged


def load_hotpotqa(data_path):
    """Load HotpotQA from ``<data_path>/hotpot_dev_distractor_v1.json``.

    Each question becomes a dict with ``qid``, ``test_id``, ``question``,
    ``answer``, ``golden_passages`` and ``type``. Golden passages are the
    context paragraphs named by ``supporting_facts``. Each paragraph's
    sentences are concatenated with no separator, and each distinct
    paragraph is kept once, in first-seen order.

    Returns:
        ``{"total": [...all questions...], <type>: [...questions of that type...], ...}``
        The per-type lists share the same dict objects as ``"total"``.
    """
    source_file = os.path.join(data_path, "hotpot_dev_distractor_v1.json")
    with open(source_file, "r") as fin:
        records = json.load(fin)

    everything = []
    grouped = {}
    for position, record in enumerate(records):
        item = {
            "qid": record["_id"],
            "test_id": position,
            "question": record["question"],
            "answer": record["answer"],
        }
        paragraphs = {title: "".join(sentences) for title, sentences in record["context"]}
        supporting = [paragraphs[title] for title, _sent_id in record["supporting_facts"]]
        # dict.fromkeys keeps the first occurrence of each paragraph, in order.
        item["golden_passages"] = list(dict.fromkeys(supporting))
        item["type"] = record["type"]
        everything.append(item)
        grouped.setdefault(record["type"], []).append(item)

    merged = {"total": everything}
    merged.update(grouped)
    return merged


def load_default_format_data(data_path):
    """Load a user-supplied JSON dataset from ``data_path``.

    The file must be a JSON list of objects, each with a string ``question``
    and an ``answer`` that is a string or a list of strings. Every item gets a
    ``test_id`` equal to its position (the items are modified in place).

    Returns:
        ``{<file name>: dataset}``: a single split named after the file.

    Raises:
        AssertionError: if the file is not ``.json`` or an item is malformed.
    """
    def _check(ok, message):
        # Raised explicitly rather than via `assert` so the validation still
        # runs under `python -O`. AssertionError keeps the original exception type.
        if not ok:
            raise AssertionError(message)

    # basename also handles the platform's own path separator, unlike split("/").
    filename = os.path.basename(data_path)
    _check(filename.endswith(".json"), f"Need json data: {data_path}")
    with open(data_path, "r", encoding="utf-8") as fin:
        dataset = json.load(fin)
    for did, data in enumerate(dataset):
        _check("question" in data, f"\"question\" not in data, {data_path}")
        question = data["question"]
        _check(isinstance(question, str), f"\"question\": {question} should be a string")
        _check("answer" in data, f"\"answer\" not in data, {data_path}")
        answer = data["answer"]
        _check(
            isinstance(answer, str)
            or (isinstance(answer, list) and all(isinstance(a, str) for a in answer)),
            f"\"answer\": {answer} should be a string or a list[str]",
        )
        data["test_id"] = did
    return {filename: dataset}


def get_rewrite(passage, model_name, model=None, tokenizer=None, generation_config=None):
    """Ask the model for a paraphrase of ``passage`` and return its output.

    The instruction asks the model to keep entities, names, locations and
    terminology intact while rewording the text. ``model_name`` is accepted
    for signature parity with ``get_qa`` but is not used here.
    """
    instruction = (
        "Rewrite the following passage. While keeping the entities, proper nouns, "
        "and key details such as names, locations, and terminology intact, create a "
        "new version of the text that expresses the same ideas in a different way. "
        "Make sure the revised passage is distinct from the original one, but "
        "preserves the core meaning and relevant information."
    )
    prompt = f"{instruction}\n{passage}"
    return model_generate(prompt, model, tokenizer, generation_config)


# Prompt that asks the model for three QA pairs about a passage, as a JSON list.
# It is filled with str.format; only `{passage}` is a placeholder, so literal
# braces are doubled. The example object has to be valid JSON because the model
# tends to copy its shape and get_qa parses the reply with json.loads.
qa_prompt_template = "I will provide a passage of text, and you need to generate three different questions based on the content of this passage. Each question should be answerable using the information provided in the passage. Additionally, please provide an appropriate answer for each question derived from the passage.\n\
You need to generate the question and answer in the following format:\n\
[\n\
    {{\n\
        \"question\": \"What is the capital of France?\",\n\
        \"answer\": \"Paris\",\n\
        \"full_answer\": \"The capital of France is Paris.\"\n\
    }}, \n\
]\n\n\
This list should have at least three elements. You only need to output this list in the above format.\n\
Passage:\n\
{passage}"

def fix_qa(qa):
    """Validate and normalise a parsed list of model-generated QA pairs.

    ``qa`` must be a list of at least three dicts, each with ``question``,
    ``answer`` and ``full_answer`` keys. The list is truncated to its first
    three items, and each ``answer`` is coerced to a string in place:
    lists are joined with ", ", numbers go through str(), and None becomes
    "Unknown".

    Returns:
        ``(ok, qa)``. ``ok`` is True when validation passed. ``qa`` is the
        truncated list when the input was a list of at least three items,
        and the input unchanged otherwise.
    """
    if not isinstance(qa, list) or len(qa) < 3:
        return False, qa
    qa = qa[:3]
    for data in qa:
        # For a non-dict item (e.g. a bare string), the `in` checks below would
        # do substring matching and the indexing would then raise TypeError.
        if not isinstance(data, dict):
            return False, qa
        if "question" not in data or "answer" not in data or "full_answer" not in data:
            return False, qa
        answer = data["answer"]
        if isinstance(answer, list):
            data["answer"] = ", ".join(str(a) for a in answer)
        elif isinstance(answer, (int, float)):
            data["answer"] = str(answer)
        elif answer is None:
            data["answer"] = "Unknown"
    return True, qa

def get_qa(passage, model_name, model=None, tokenizer=None, generation_config=None,
           max_tries=100):
    """Generate three QA pairs for ``passage``, retrying on unusable output.

    Each attempt prompts the model with ``qa_prompt_template``, trims the
    reply around the JSON list (model-specific, see ``fix_json``), parses it
    and validates it with ``fix_qa``.

    Args:
        passage: passage text the questions are generated from.
        model_name: selects the output clean-up rules in ``fix_json``.
        model, tokenizer, generation_config: forwarded to ``model_generate``.
        max_tries: maximum number of generation attempts. Every failed
            attempt counts, whether the JSON was malformed or failed validation.

    Returns:
        The validated QA list on success. Otherwise the last cleaned-up raw
        output string (None if no attempt was made); callers detect failure
        with ``fix_qa(result)[0]``.
    """

    def fix_json(output):
        """Trim model-specific text around the JSON list in ``output``."""
        if model_name == "llama3.2-1b-instruct":
            # Guard the find(): a -1 result would keep only the last character.
            if "[" in output:
                output = output[output.find("["):]
            if output.endswith(","):
                output = output[:-1]
            if not output.endswith("]"):
                output += "]"
        elif model_name == "llama3-8b-instruct":
            if "[" in output:
                output = output[output.find("["):]
            if "]" in output:
                output = output[:output.find("]")+1]
        return output

    prompt = qa_prompt_template.format(passage=passage)
    output = None
    for _attempt in range(max_tries):
        output = fix_json(model_generate(prompt, model, tokenizer, generation_config))
        try:
            ok, qa = fix_qa(json.loads(output))
        except Exception:
            # Malformed JSON or an unexpected structure: try again. Catching
            # Exception, not everything, leaves KeyboardInterrupt working.
            continue
        if ok:
            return qa
        # Well-formed JSON that fails validation also consumes an attempt, so
        # the loop always terminates.
    return output
    

def _take_sample(items, sample):
    """Return the first ``sample`` items, or all of them when ``sample`` is -1."""
    return items if sample == -1 else items[:sample]


def main(args):
    """Retrieve and augment passages for each question and write JSON files.

    The dataset is loaded with ``load_<dataset>`` when such a loader exists,
    and with ``load_default_format_data`` otherwise. For every split, the
    first ``args.sample`` questions (all of them when -1, the convention
    encode.py uses) get up to ``args.topk`` BM25 passages. Each passage is
    augmented with three generated QA pairs and a model rewrite, and passages
    whose QA generation fails are skipped. Results are written to
    ``<ROOT_DIR>/data_aug/<dataset>/<model_name>/<split>.json``.
    """
    output_dir = os.path.join(ROOT_DIR, "data_aug", args.dataset, args.model_name)
    os.makedirs(output_dir, exist_ok=True)

    print("### Loading dataset ###")
    if f"load_{args.dataset}" in globals():
        load_func = globals()[f"load_{args.dataset}"]
    else:
        load_func = globals()["load_default_format_data"]
    load_dataset = load_func(args.data_path)
    if len(load_dataset) == 1:
        solve_dataset = load_dataset
    else:
        # Multi-split datasets: augment each per-type split, and save the
        # un-augmented "total" split as-is.
        solve_dataset = {k: v for k, v in load_dataset.items() if k != "total"}
        with open(os.path.join(output_dir, "total.json"), "w") as fout:
            json.dump(_take_sample(load_dataset["total"], args.sample), fout, indent=4)

    model, tokenizer, _ = get_model(args.model_name)
    # NOTE(review): temperature/top_k only take effect if sampling is enabled
    # (do_sample=True) somewhere in model_generate. Under greedy decoding,
    # get_qa's retries would repeat the same output -- confirm.
    generation_config = dict(
        max_new_tokens=512,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0,
        temperature=0.7,
        top_k=50,
    )

    for filename, dataset in solve_dataset.items():
        print(f"### Solving {filename} ###")
        output_file = os.path.join(
            output_dir,
            filename if filename.endswith(".json") else filename + ".json"
        )
        ret = []
        dataset = _take_sample(dataset, args.sample)
        with tqdm(total=len(dataset) * args.topk) as pbar:
            for data in dataset:
                # Over-retrieve so that passages whose QA generation fails can
                # be skipped while still collecting up to topk passages.
                passages = bm25_retrieve(data["question"], topk=args.topk+10)
                final_passages = []
                data["augment"] = []
                for psg in passages:
                    # Generate QA first so a skipped passage never costs an
                    # extra rewrite generation.
                    qa = get_qa(psg, args.model_name, model, tokenizer, generation_config)
                    if not fix_qa(qa)[0]:  # skip error passage
                        continue
                    data["augment"].append({
                        "pid": len(final_passages),
                        "passage": psg,
                        f"{args.model_name}_rewrite": get_rewrite(psg, args.model_name, model, tokenizer, generation_config),
                        f"{args.model_name}_qa": qa,
                    })
                    final_passages.append(psg)
                    pbar.update(1)
                    if len(data["augment"]) == args.topk:
                        break
                data["passages"] = final_passages
                ret.append(data)
        with open(output_file, "w") as fout:
            json.dump(ret, fout, indent=4)


if __name__ == "__main__":
    # CLI entry point: every option except --topk is mandatory.
    cli = argparse.ArgumentParser()
    for flag, flag_type in (
        ("--model_name", str),
        ("--dataset", str),
        ("--data_path", str),
        ("--sample", int),
    ):
        cli.add_argument(flag, type=flag_type, required=True)
    cli.add_argument("--topk", type=int, default=3)
    args = cli.parse_args()
    print(args)
    main(args)

================================================
FILE: src/encode.py
================================================
import os
import gc
import time
import argparse
import torch
from tqdm import tqdm
from peft import TaskType, get_peft_model, LoraConfig, PeftModel
from torch.utils.data import Dataset
from transformers import DefaultDataCollator
from typing import Dict, List

import prompt_template
from root_dir_path import ROOT_DIR
from utils import get_model, load_data

import numpy as np
import random

seed = 42
# Seed the torch, NumPy and Python RNGs with the same value at import time so
# runs are repeatable.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(seed)
del _seed_fn


class TrainingData(Dataset):
    """Fixed-length causal-LM training examples built from token-id lists.

    Each prompt is truncated to ``max_length`` tokens and right-padded to
    exactly ``max_length``. ``labels`` copies the (truncated) input ids, so
    every real token is a training target. Padding positions get
    ``ignored_id`` (-100, the default ``ignore_index`` of PyTorch's
    cross-entropy loss). ``attention_mask`` is 1 for real tokens and 0 for
    padding. The caller's lists are never modified.
    """
    ignored_id = -100

    def __init__(self, prompt_ids, tokenizer, max_length=3000):
        self.max_length = max_length
        self.dataset = []
        # Fall back to id 0 for tokenizers that define no pad token.
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        for ids in prompt_ids:
            # Work on a truncated copy: padding with `+=` used to extend the
            # caller's list in place.
            input_ids = list(ids[:max_length])
            pad_len = max_length - len(input_ids)
            self.dataset.append({
                "input_ids": input_ids + [pad_token_id] * pad_len,
                "labels": input_ids + [self.ignored_id] * pad_len,
                "attention_mask": [1] * len(input_ids) + [0] * pad_len,
            })
        self.total_len = len(self.dataset)

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx) -> Dict[str, list]:
        return self.dataset[idx]


class TrainingDataCollator(DefaultDataCollator):
    """Turn a list of pre-padded examples into a batch of tensors on ``device``.

    Every example must already have the same length (``TrainingData`` pads to
    ``max_length``) because ``torch.tensor`` is applied directly to each
    field's list of rows.
    """

    def __init__(self, tokenizer, device):
        super().__init__()
        self.tokenizer = tokenizer
        self.device = device

    def __call__(self, examples: List[Dict[str, list]]) -> Dict[str, torch.Tensor]:
        fields = ("input_ids", "labels", "attention_mask")
        # Gather every column first, then convert each one to a device tensor.
        columns = {field: [example[field] for example in examples] for field in fields}
        return {field: torch.tensor(rows).to(self.device) for field, rows in columns.items()}
    

def get_train_data(aug_model, augments, tokenizer, args):
    """Build the tokenized training prompts for one question's augmented passages.

    For each augmented passage, the first half of its QA pairs (rounded up)
    is paired with the original passage and, separately, with its rewrite.
    The remaining QA pairs are used without any passage. The target is
    ``full_answer`` when ``args.with_cot`` is set and ``answer`` otherwise.

    Returns:
        A list of whatever ``prompt_template.get_prompt`` returns, one per
        training prompt.
    """
    from prompt_template import get_prompt

    target_key = "full_answer" if args.with_cot else "answer"
    prompt_ids = []
    for aug in augments:
        contexts = (aug["passage"], aug[f"{aug_model}_rewrite"])
        qas = aug[f"{aug_model}_qa"]
        with_context = (len(qas) + 1) // 2
        for idx, qa in enumerate(qas):
            if idx >= with_context:
                prompt_ids.append(get_prompt(tokenizer, qa["question"], None,
                                             qa[target_key], with_cot=args.with_cot))
                continue
            for context in contexts:
                prompt_ids.append(get_prompt(tokenizer, qa["question"], [context],
                                             qa[target_key], with_cot=args.with_cot))
    return prompt_ids


def train(question, augments, args, model, tokenizer, 
          init_adapter_path, save_path):
    """Train a LoRA adapter on one question's augmented passages and save it.

    The adapter at ``init_adapter_path`` is loaded onto ``model`` as a
    trainable ``PeftModel``. It is trained with AdamW (``args.learning_rate``)
    for ``args.num_train_epochs`` epochs over the prompts from
    ``get_train_data``, in order (no shuffling). The adapter is then saved to
    ``save_path`` and unloaded again.

    Args:
        question: not used inside this function.
        augments: augmented passages for the question (see ``get_train_data``).
        args: run options (augment_model, batch size, epochs, learning rate, with_cot).
        model, tokenizer: base model and its tokenizer.
        init_adapter_path: directory of the initial LoRA weights.
        save_path: directory the trained adapter is written to (created if missing).

    Returns:
        The result of ``PeftModel.unload()``: the base model with the LoRA
        modules removed, unmerged.
    """
    loader = torch.utils.data.DataLoader(
        TrainingData(get_train_data(args.augment_model, augments, tokenizer, args), tokenizer),
        batch_size=args.per_device_train_batch_size,
        collate_fn=TrainingDataCollator(tokenizer, model.device),
        shuffle=False,
    )
    model = PeftModel.from_pretrained(model, init_adapter_path, is_trainable=True)
    model.is_parallelizable = True
    model.model_parallel = True
    trainable = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=args.learning_rate)
    for _epoch in range(args.num_train_epochs):
        for batch in loader:
            optimizer.zero_grad()
            model(**batch).loss.backward()
            optimizer.step()
    os.makedirs(save_path, exist_ok=True)
    model.save_pretrained(save_path)
    # Detach the adapter so the caller gets the bare base model back, then
    # release cached GPU memory and collect garbage.
    model = model.unload()
    torch.cuda.empty_cache()
    gc.collect()
    return model


def main(args):
    """Train and save one LoRA adapter per augment entry of every question.

    A LoRA "base weight" (rank/alpha from ``args``) is created once if it does
    not exist yet. It is stored under
    ``ROOT_DIR/offline/<model>/rank=<r>_alpha=<a>/base_weight``, and every
    per-passage adapter starts from it. Trained adapters are written to
    ``.../<dataset>/.../<file>/data_<i>/passage_<j>``. Adapters that already
    exist are skipped, so re-running the script resumes where it stopped.
    """
    data_list = load_data(args.dataset, args.data_type, args.augment_model)
    model, tokenizer, _generation_config = get_model(args.model_name)
    if args.with_cot:
        prompt_template.get_fewshot(args.dataset)

    init_adapter_path = os.path.join(
        ROOT_DIR, 
        "offline", 
        args.model_name, 
        f"rank={args.lora_rank}_alpha={args.lora_alpha}",
        "base_weight",
    )
    if not os.path.exists(os.path.join(init_adapter_path, "adapter_model.safetensors")):
        print("No LoRA base weight, creating...")
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=['down_proj', 'gate_proj', 'up_proj'],
            inference_mode=False,
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            lora_dropout=0, # !!!
        )
        model = get_peft_model(model, peft_config)
        model.is_parallelizable = True
        model.model_parallel = True
        print(f'Save LoRA base weight to {init_adapter_path}')
        os.makedirs(init_adapter_path, exist_ok=True)
        model.save_pretrained(init_adapter_path)
        time.sleep(2)
        assert os.path.exists(os.path.join(init_adapter_path, "adapter_model.safetensors")) 
        # get_peft_model() injected LoRA layers into the base model in place.
        # train() attaches the adapter itself via PeftModel.from_pretrained(),
        # so strip these layers again. Otherwise the first trained adapter
        # would be stacked on top of an already PEFT-wrapped model.
        model = model.unload()

    cot_name = "cot" if args.with_cot else "direct"
    for filename, fulldata in data_list:
        # NOTE: inference.py derives adapter paths with the same split('.')[0];
        # keep the two in sync.
        filename = filename.split('.')[0] 
        print(f"### Solving {filename} ###")
        output_dir = os.path.join(
            ROOT_DIR, 
            "offline", 
            args.model_name, 
            f"rank={args.lora_rank}_alpha={args.lora_alpha}",
            args.dataset,
            f"lr={args.learning_rate}_epoch={args.num_train_epochs}_{cot_name}",
            f"aug_model={args.augment_model}",
            filename,
        )
        os.makedirs(output_dir, exist_ok=True)
        fulldata = fulldata if args.sample == -1 else fulldata[:args.sample]
        for did, data in tqdm(enumerate(fulldata), total=len(fulldata)):
            augment = data["augment"]
            for pid in range(len(augment)):
                save_path = os.path.join(output_dir, f"data_{did}", f"passage_{pid}")
                # Resume support: skip adapters that were already trained.
                if os.path.exists(os.path.join(save_path, "adapter_model.safetensors")):
                    continue
                model = train(data["question"], [augment[pid]], args, model, tokenizer, 
                            init_adapter_path, save_path)
                

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--data_type", type=str)
    parser.add_argument("--with_cot", action="store_true")
    parser.add_argument("--sample", type=int, default=-1) # -1 means all
    parser.add_argument("--augment_model", type=str, default=None)
    # Train
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    # LoRA
    parser.add_argument("--lora_rank", type=int, default=None)
    parser.add_argument("--lora_alpha", type=int, default=None)
    args = parser.parse_args()
    # Validate with parser.error instead of `assert`: asserts are removed under
    # `python -O`, whereas parser.error always prints usage and exits (status 2).
    if not (args.lora_rank and args.lora_alpha):
        parser.error("No config for LoRA: pass non-zero --lora_rank and --lora_alpha")
    # By default the augmented data is assumed to come from the same model.
    if args.augment_model is None:
        args.augment_model = args.model_name
    print(args)
    main(args)

================================================
FILE: src/fewshot/2wikimultihopqa.json
================================================
[
    {
        "question": "When did the director of film Hypocrite (Film) die?",
        "answer": "The film Hypocrite was directed by Miguel Morayta. Miguel Morayta died on 19 June 2013. So the answer is 19 June 2013."
    },
    {
        "question": "Are both Kurram Garhi and Trojkrsti located in the same country?",
        "answer": "Kurram Garhi is located in the country of Pakistan. Trojkrsti is located in the country of Republic of Macedonia. Thus, they are not in the same country. So the answer is no."
    },
    {
        "question": "Do director of film Coolie No. 1 (1995 Film) and director of film The Sensational Trial have the same nationality?",
        "answer": "Coolie No. 1 (1995 film) was directed by David Dhawan. The Sensational Trial was directed by Karl Freund. David Dhawan's nationality is India. Karl Freund's nationality is Germany. Thus, they do not have the same nationality. So the answer is no."
    },
    {
        "question": "Who is Boraqchin (Wife Of \u00d6gedei)'s father-in-law?",
        "answer": "Boraqchin is married to \u00d6gedei Khan. \u00d6gedei Khan's father is Genghis Khan. Thus, Boraqchin's father-in-law is Genghis Khan. So the answer is Genghis Khan."
    },
    {
        "question": "Who was born first out of Martin Hodge and Ivania Martinich?",
        "answer": "Martin Hodge was born on 4 February 1959. Ivania Martinich was born on 25 July 1995. Thus, Martin Hodge was born first. So the answer is Martin Hodge."
    },
    {
        "question": "When did the director of film Laughter In Hell die?",
        "answer": "The film Laughter In Hell was directed by Edward L. Cahn. Edward L. Cahn died on August 25, 1963. So the answer is August 25, 1963."
    },
    {
        "question": "Which film has the director died later, The Gal Who Took the West or Twenty Plus Two?",
        "answer": "The film Twenty Plus Two was directed by Joseph M. Newman. The Gal Who Took the West was directed by Frederick de Cordova. Joseph M. Newman died on January 23, 2006. Frederick de Cordova died on September 15, 2001. Thus, the film whose director died later is Twenty Plus Two. So the answer is Twenty Plus Two."
    },
    {
        "question": "Who is the grandchild of Krishna Shah (Nepalese Royal)?",
        "answer": "Krishna Shah has a child named Rudra Shah. Rudra Shah has a child named Prithvipati Shah. Thus, Krishna Shah has a grandchild named Prithvipati Shah. So the answer is Prithvipati Shah."
    }
]

================================================
FILE: src/fewshot/hotpotqa.json
================================================
[
    {
        "question": "Jeremy Theobald and Christopher Nolan share what profession?",
        "answer": "Jeremy Theobald is an actor and producer. Christopher Nolan is a director, producer, and screenwriter. Therefore, they both share the profession of being a producer. So the answer is producer."
    },
    {
        "question": "What film directed by Brian Patrick Butler was inspired by a film directed by F.W. Murnau?",
        "answer": "Brian Patrick Butler directed the film The Phantom Hour. The Phantom Hour was inspired by the films such as Nosferatu and The Cabinet of Dr. Caligari. Of these Nosferatu was directed by F.W. Murnau. So the answer is The Phantom Hour."
    },
    {
        "question": "How many episodes were in the South Korean television series in which Ryu Hye-young played Bo-ra?",
        "answer": "The South Korean television series in which Ryu Hye-young played Bo-ra is Reply 1988. The number of episodes Reply 1988 has is 20. So the answer is 20."
    },
    {
        "question": "Were Lonny and Allure both founded in the 1990s?",
        "answer": "Lonny (magazine) was founded in 2009. Allure (magazine) was founded in 1991. Thus, of the two, only Allure was founded in 1990s. So the answer is no."
    },
    {
        "question": "Vertical Limit stars which actor who also played astronaut Alan Shepard in \"The Right Stuff\"?",
        "answer": "The actor who played astronaut Alan Shepard in \"The Right Stuff\" is Scott Glenn. The movie Vertical Limit also starred Scott Glenn. So the answer is Scott Glenn."
    },
    {
        "question": "What was the 2014 population of the city where Lake Wales Medical Center is located?",
        "answer": "Lake Wales Medical Center is located in the city of Lake Wales, Polk County, Florida. The population of Lake Wales in 2014 was 15,140. So the answer is 15,140."
    },
    {
        "question": "Who was born first? Jan de Bont or Raoul Walsh?",
        "answer": "Jan de Bont was born on 22 October 1943. Raoul Walsh was born on March 11, 1887. Thus, Raoul Walsh was born the first. So the answer is Raoul Walsh."
    },
    {
        "question": "In what country was Lost Gravity manufactured?",
        "answer": "The Lost Gravity (roller coaster) was manufactured by Mack Rides. Mack Rides is a German company. So the answer is Germany."
    },
    {
        "question": "Which of the following had a debut album entitled \"We Have an Emergency\": Hot Hot Heat or The Operation M.D.?",
        "answer": "The debut album of the band \"Hot Hot Heat\" was \"Make Up the Breakdown\". The debut album of the band \"The Operation M.D.\" was \"We Have an Emergency\". So the answer is The Operation M.D.."
    },
    {
        "question": "How many awards did the \"A Girl Like Me\" singer win at the American Music Awards of 2012?",
        "answer": "The singer of \"A Girl Like Me\" is Rihanna. In the American Music Awards of 2012, Rihanna won one award. So the answer is one."
    },
    {
        "question": "The actor that stars as Joe Proctor on the series \"Power\" also played a character on \"Entourage\" that has what last name?",
        "answer": "The actor that stars as Joe Proctor on the series \"Power\" is Jerry Ferrara. Jerry Ferrara also played a character on Entourage named Turtle Assante. Thus, Turtle Assante's last name is Assante. So the answer is Assante."
    }
]

================================================
FILE: src/get_warmup_data.py
================================================
import os
import json
import pandas as pd
import random
import torch
from tqdm import tqdm

import prompt_template
from utils import get_model, evaluate
from root_dir_path import ROOT_DIR
from augment import load_complexwebquestions, load_popqa

# Fixed seed so that the random.shuffle() in create_cot() selects the same
# candidate examples on every run.
random.seed(42)


# direct, without fewshot and cot 
# for popqa and complexwebquestions 
def create_direct():
    """Write direct-answer warm-up data for PopQA and ComplexWebQuestions.

    Each record is reduced in place to its "question" and "answer" keys. A
    list-valued answer is replaced by its first element. The result goes to
    ``ROOT_DIR/warmup/data/direct/<dataset>.json``.
    """
    kept_keys = ("question", "answer")
    loaders = (("popqa", load_popqa), ("complexwebquestions", load_complexwebquestions))
    for dataset_name, loader in loaders:
        # Skip the first 1000 records to prevent data leakage: only the first
        # 300 entries were actually tested.
        records = loader(os.path.join(ROOT_DIR, "data", dataset_name))["total"][1000:]
        for record in records:
            answer = record["answer"]
            if isinstance(answer, list):
                record["answer"] = answer[0]
            for key in [k for k in record.keys() if k not in kept_keys]:
                record.pop(key)

        out_dir = os.path.join(ROOT_DIR, "warmup", "data", "direct")
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, dataset_name + ".json"), "w") as fout:
            json.dump(records, fout, indent=4)


def load_2wikimultihopqa(data_path):
    """Load the 2WikiMultihopQA dev split together with its supporting sentences.

    Reads ``dev.json`` from ``data_path``. For every example, the sentences named
    by ``supporting_facts`` are looked up in the example's ``context``. Examples
    whose supporting fact names an unknown title, or a sentence index past the
    end, are dropped. List-valued answers are reduced to their first element.

    ``id_aliases.json`` (one JSON object per line) is also parsed, but its
    contents are not used by this function.

    Returns:
        ``{"total": [ {qid, test_id, question, answer, context, type}, ... ]}``
        where ``test_id`` is the example's position in ``dev.json``.
    """
    with open(os.path.join(data_path, "dev.json"), "r") as fin:
        dataset = json.load(fin)
    with open(os.path.join(data_path, "id_aliases.json"), "r") as fin:
        aliases = {}
        for line in fin:
            entry = json.loads(line)
            aliases[entry["Q_id"]] = entry["aliases"]

    examples = []
    for idx, item in enumerate(dataset):
        title_to_sents = {ctx[0]: ctx[1] for ctx in item['context']}
        sentences = []
        for title, sent_idx in item["supporting_facts"]:
            if title not in title_to_sents or sent_idx >= len(title_to_sents[title]):
                break
            sentences.append(title_to_sents[title][sent_idx])
        else:
            raw_answer = item["answer"]
            examples.append({
                "qid": item["_id"],
                "test_id": idx,
                "question": item["question"],
                "answer": raw_answer if type(raw_answer) is str else raw_answer[0],
                "context": sentences,
                "type": item["type"],
            })
    return {"total": examples}


def load_hotpotqa(data_path):
    """Load the HotpotQA distractor dev split together with its supporting sentences.

    Reads ``hotpot_dev_distractor_v1.json`` from ``data_path``. For every
    example, the sentences named by ``supporting_facts`` are looked up in the
    example's ``context``. An example is dropped, after printing an error line,
    when any supporting fact refers to a sentence index past the end of its
    paragraph or to a title that is not in the context.

    Returns:
        ``{"total": [ {qid, test_id, question, answer, context, type}, ... ]}``
        where ``test_id`` is the example's position in the input file.
    """
    with open(os.path.join(data_path, 'hotpot_dev_distractor_v1.json'), 'r') as fin:
        dataset = json.load(fin)
    new_dataset = []
    for did, data in enumerate(dataset):
        all_ctxs = {}
        for name, text in data["context"]:
            all_ctxs[name] = text
        context = []
        flag = False
        for name, sent_id in data["supporting_facts"]:
            sents = all_ctxs.get(name)
            # Valid sentence indices are 0..len-1. Report and skip (rather than
            # crash on) facts that point past the end or at a missing title.
            if sents is None or sent_id >= len(sents):
                print("### Error supporting facts id: ", data["_id"], sent_id,
                      len(sents) if sents is not None else None)
                flag = True
                continue
            context.append(sents[sent_id])
        if flag:
            continue
        val = {
            'qid': data['_id'], 
            'test_id': did, 
            'question': data['question'], 
            'answer': data['answer'], 
            "context": context,
            "type": data["type"],
        }

        new_dataset.append(val)
    return {"total": new_dataset}


# Prompt used by create_cot() to *generate* chain-of-thought training targets.
# The model is shown the dataset's few-shot examples ({fewshot}), the numbered
# supporting sentences ({context}) and the gold answer ({answer}), and is asked
# to answer {question} in the few-shot format. This is distinct from
# prompt_template.USER_PROMPT_WITH_COT, which does not reveal the answer.
USER_PROMPT_WITH_COT = "You should use the context provided below to answer the question. Please follow the same structure as the example.\n\n\
Here are some examples of how to answer the questions:\n\
{fewshot}\n\
Here is the context for the question:\n\
{context}\n\n\
The correct answer to the given question is {answer}. Now, please answer the question in the same format as above.\n\
Question: {question}"
# Appended after the chat template's generation prompt, so the model's reply
# starts with "Answer: " like the few-shot examples.
ASSISTANT_PROMPT_WITH_COT = "Answer: "


# fewshot and cot
# for 2wikimultihopqa and hotpotqa
def create_cot(model_name="llama3-8b-instruct", aim_cnt_each_dataset=300):
    """Build chain-of-thought warm-up data for 2WikiMultihopQA and HotpotQA.

    For each dataset, examples after the (assumed) evaluation region are
    shuffled and prompted with USER_PROMPT_WITH_COT, which includes the
    few-shot examples, the supporting sentences and the gold answer. A
    generation is kept when it contains "the answer is" and scores EM against
    the gold answer. Up to ``aim_cnt_each_dataset`` examples are kept per
    dataset. Accepted examples gain a "cot" field (the generated text) and a
    "from" field (the dataset name). They are written to
    ``ROOT_DIR/warmup/data/cot/train_data.json``.

    Args:
        model_name: model identifier passed to ``get_model``.
        aim_cnt_each_dataset: number of accepted examples to collect per dataset.
    """
    # Number of leading examples per question type that are treated as the
    # evaluation region. Assumed to match the evaluation setup (first 300 per
    # type) -- TODO confirm against the evaluation scripts.
    test_cnt_each_type = 300

    output_dataset = []

    model, tokenizer, generation_config = get_model(model_name, max_new_tokens=128)
    model.eval()

    for name, func in (("2wikimultihopqa", load_2wikimultihopqa), ("hotpotqa", load_hotpotqa)):
        print(f"### solving {name} ###")
        dataset = func(os.path.join(ROOT_DIR, "data", name))["total"]
        # prevent data leakage: find the last index covered by the first
        # `test_cnt_each_type` examples of every question type, then start
        # 1000 examples after it.
        mark_idx = {}
        for did, data in enumerate(dataset):
            typ = data["type"]
            if typ not in mark_idx:
                mark_idx[typ] = {"cnt": 1, "last_idx": did}
            elif mark_idx[typ]["cnt"] < test_cnt_each_type:
                mark_idx[typ]["cnt"] += 1
                mark_idx[typ]["last_idx"] = did
        # default=-1 keeps an empty dataset from crashing max().
        last_idx = max((v["last_idx"] for v in mark_idx.values()), default=-1)
        dataset = dataset[last_idx + 1000:]
        random.shuffle(dataset)

        last_cnt = len(output_dataset)
        pbar = tqdm(total=aim_cnt_each_dataset)
        prompt_template.get_fewshot(name)
        for data in dataset:
            context = "".join(
                f"{pid+1}. {psg.strip()}\n" for pid, psg in enumerate(data["context"])
            )
            user_content = USER_PROMPT_WITH_COT.format(fewshot=prompt_template.fewshot, 
                                                 context=context,
                                                 question=data["question"], 
                                                 answer=data["answer"])
            messages = [{"role": "user", "content": user_content}]
            inputs = tokenizer.apply_chat_template(
                messages, 
                add_generation_prompt=True
            )
            # Make the reply start with "Answer: " like the few-shot examples.
            inputs += tokenizer.encode(ASSISTANT_PROMPT_WITH_COT, add_special_tokens=False)
            input_len = len(inputs)
            input_ids = torch.tensor(inputs).unsqueeze(0).to(model.device)
            with torch.no_grad():
                output = model.generate(
                    input_ids, 
                    attention_mask = torch.ones(input_ids.shape).to(model.device),
                    **generation_config)
            # Keep only the newly generated tokens.
            output = output.sequences[0][input_len:]
            text = tokenizer.decode(output, skip_special_tokens=True)
            if text is None or "the answer is" not in text:
                continue
            # Truncate at the first stop marker.
            for stop_words in ["\n\n", "Questions"]:
                if stop_words in text:
                    text = (text[:text.find(stop_words)]).strip()
            # NOTE(review): compares against the string "1"; confirm that
            # utils.evaluate returns string-valued metrics.
            if evaluate(text, data["answer"], with_cot=True)["em"] == "1":
                data["cot"] = text
                data["from"] = name
                output_dataset.append(data)
                pbar.update(1)
                if len(output_dataset) - last_cnt == aim_cnt_each_dataset:
                    break
        pbar.close()
        collected = len(output_dataset) - last_cnt
        if collected < aim_cnt_each_dataset:
            print(f"### {name}: only collected {collected} / {aim_cnt_each_dataset} examples ###")
    
    output_dir = os.path.join(ROOT_DIR, "warmup", "data", "cot")
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "train_data.json"), "w") as fout:
        json.dump(output_dataset, fout, indent=4)


if __name__ == "__main__":
    # Build the direct-answer warm-up files first, then the CoT training file.
    for build_step in (create_direct, create_cot):
        build_step()

================================================
FILE: src/inference.py
================================================
import os
import gc
import json
import argparse
import torch
from tqdm import tqdm
from peft import PeftModel

import prompt_template
from root_dir_path import ROOT_DIR
from utils import get_model, evaluate, predict, load_data, read_complete

def main(args):
    """Run inference for every data file and write predictions and metrics.

    ``args.inference_method`` selects how each question is answered:
      * ``icl``     -- passages go into the prompt; no adapters are loaded.
      * ``prag``    -- the per-passage LoRA adapters trained for this question
                       are loaded, combined with ``combination_type="cat"`` into
                       a "merge" adapter, and the question is answered WITHOUT
                       passages in the prompt.
      * ``combine`` -- same merged adapter, with the passages also in the prompt.

    ``read_complete`` supplies the already-saved predictions and the index to
    continue from (presumably for resuming -- see utils). All predictions are
    written to ``predict.json``, and averaged em/f1/prec/recall to ``result.txt``.
    """
    data_list = load_data(args.dataset, args.data_type, args.augment_model)
    model, tokenizer, generation_config = get_model(
        args.model_name,
        max_new_tokens = args.max_new_tokens,
    )
    if args.with_cot:
        prompt_template.get_fewshot(args.dataset)
    
    cot_name = "cot" if args.with_cot else "direct"
    load_adapter_path = os.path.join(
        ROOT_DIR, 
        "offline", 
        args.model_name, 
        f"rank={args.lora_rank}_alpha={args.lora_alpha}",
        args.dataset,
        f"lr={args.learning_rate}_epoch={args.num_train_epochs}_{cot_name}",
        f"aug_model={args.augment_model}",
    )
    output_root_dir = os.path.join(
        ROOT_DIR, 
        "output",
        args.model_name, 
        f"rank={args.lora_rank}_alpha={args.lora_alpha}",
        args.dataset,
        f"lr={args.learning_rate}_epoch={args.num_train_epochs}_{cot_name}",
        f"aug_model={args.augment_model}",
        args.inference_method, 
    )
    for filename, fulldata in data_list:
        filename = filename.split(".")[0]
        print(f"### Solving {filename} ###")
        output_dir = os.path.join(output_root_dir, filename)
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "config.json"), "w") as fout:
            json.dump(vars(args), fout, indent=4)

        predict_file = os.path.join(output_dir, "predict.json")
        ret, start_with = read_complete(predict_file)

        fulldata = fulldata[start_with:] if args.sample == -1 else fulldata[start_with:args.sample]
        for test_id, data in tqdm(enumerate(fulldata), total=len(fulldata)):
            test_id = test_id + start_with
            assert test_id == len(ret), f"test_id {test_id} != len(ret) {len(ret)}"

            question = data["question"]
            passages = data["passages"]
            answer = data["answer"]

            def get_pred(model, psgs):
                text = predict(model, tokenizer, generation_config, 
                                        question, with_cot=args.with_cot, 
                                        passages=psgs)
                pred = {
                    "test_id": test_id, 
                    "question": question, 
                    "answer": answer, 
                    "text": text,
                }
                pred.update(evaluate(text, answer, args.with_cot))
                return pred

            if args.inference_method == "icl":
                ret.append(get_pred(model, psgs=passages))
            else:
                # Adapter i was trained on passage i of this question (see encode.py).
                for pid in range(len(passages)):
                    adapter_path = os.path.join(load_adapter_path, filename, f"data_{test_id}", f"passage_{pid}")
                    if pid == 0:
                        model = PeftModel.from_pretrained(
                            model, 
                            adapter_path,
                            adapter_name = "0", 
                            is_trainable = False
                        )
                    else:
                        model.load_adapter(adapter_path, adapter_name = str(pid)) 
                # merge
                model.add_weighted_adapter(
                    adapters = [str(i) for i in range(len(passages))], 
                    weights = [1] * len(passages),
                    adapter_name = "merge", 
                    combination_type = "cat",
                )
                model.set_adapter("merge")
                ret.append(get_pred(model, psgs=None if args.inference_method == "prag" else passages))
                model.delete_adapter("merge")
                # Drop all LoRA layers so the next question starts from the plain model.
                model = model.unload()
                torch.cuda.empty_cache()
                gc.collect()

        with open(predict_file, "w") as fout:
            json.dump(ret, fout, indent=4)

        ##### Evaluating #####
        if not ret:
            # Nothing to average (e.g. --sample 0 or an empty data file);
            # avoid a ZeroDivisionError and move on to the next file.
            print(f"### No predictions for {filename}, skipping evaluation ###")
            continue
        metrics = ["em", "f1", "prec", "recall"]
        ret_str = ""
        for met in metrics:
            acc = sum(float(d[met]) for d in ret) / len(ret)
            acc = round(acc, 4)
            ret_str += f"{met}\t{acc}\n"
        ret_str += "\n" + json.dumps(vars(args), indent=4)
        with open(os.path.join(output_dir, "result.txt"), "w") as fout:
            fout.write(ret_str)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--max_new_tokens", type=int, required=True)
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--data_type", type=str)
    parser.add_argument("--with_cot", action="store_true")
    parser.add_argument("--sample", type=int, default=-1) # -1 means all
    parser.add_argument("--augment_model", type=str, default=None)  
    parser.add_argument("--num_train_epochs", type=int, required=True)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--inference_method", type=str, required=True, choices=["icl", "prag", "combine"])
    # LoRA
    parser.add_argument("--lora_rank", type=int)
    parser.add_argument("--lora_alpha", type=int)
    args = parser.parse_args()
    # Validate with parser.error instead of `assert`: asserts are removed under
    # `python -O`, whereas parser.error always prints usage and exits (status 2).
    if not (args.lora_rank and args.lora_alpha):
        parser.error("No Config for LoRA: pass non-zero --lora_rank and --lora_alpha")
    # By default the augmented data is assumed to come from the same model.
    if args.augment_model is None:
        args.augment_model = args.model_name
    print(args)
    main(args)

================================================
FILE: src/prompt_template.py
================================================
import os
from root_dir_path import ROOT_DIR

# Dataset whose few-shot examples are currently loaded (set by get_fewshot()).
current_dataset = None
# "Question: ...\nAnswer: ...\n\n" examples for current_dataset; None until
# get_fewshot() runs. get_prompt(..., with_cot=True) asserts it has been set.
fewshot = None
# Directory containing the per-dataset few-shot JSON files.
fewshot_path = os.path.join(ROOT_DIR, "src", "fewshot")

# User turn for direct answering. {passages} is the "Passage i: ..." block
# built by get_prompt() (an empty string when no passages are given).
USER_PROMPT = "You should answer the question by referring to the knowledge provided below and integrating your own knowledge.\n\
{passages}\n\n\
Question: {question}"

# User turn for chain-of-thought answering. {fewshot} has no trailing "\n"
# here because get_fewshot() already ends every example with "\n\n".
# NOTE(review): the wording ("example I provided above", "Here are some
# reference.") is part of the training/inference prompt; changing it changes
# model inputs, so it is left as is.
USER_PROMPT_WITH_COT = "You should reference the knowledge provided below and combine it with your own knowledge to answer the question. Please follow the format of the example I provided above.\n\
Here are some examples about how to answer the questions.\n\
{fewshot}\
Here are some reference.\n\
{passages}\n\n\
Let's think step by step. Answer the questions in the same format as above.\n\
Question: {question}"

# Assistant-turn prefixes appended after the chat template. {answer} is ""
# when no answer is supplied, leaving just the prefix for generation.
ASSISTANT_PROMPT = "The answer is {answer}"
ASSISTANT_PROMPT_WITH_COT = "Answer: {answer}"

def _get_prompt(question, passages=None, answer=None):
    question = question.strip()
    if not question.endswith('?'):
        question = question.strip() + '?'
    elif question.endswith(' ?'):
        question = (question[:-1]).strip() + '?'
     
    if passages and not isinstance(passages, list):
        passages = [passages]
    
    if answer is None:
        answer = ""
    else:
        answer = answer.strip()
        if not answer.endswith('.'):
            answer += "."
    return question, passages, answer


def get_fewshot(dataset):
    """Load the few-shot examples of ``dataset`` into the module globals.

    A ``"<name>_golden"`` dataset shares the few-shot file of ``<name>``. Sets
    ``current_dataset`` to the resolved name. Reads
    ``fewshot_path/<name>.json`` and stores every example as
    ``"Question: <q>\\nAnswer: <a>\\n\\n"``, concatenated, in ``fewshot``.
    """
    import json
    global current_dataset, fewshot
    suffix = "_golden"
    if dataset.endswith(suffix):
        dataset = dataset.split(suffix)[0]
    current_dataset = dataset
    with open(os.path.join(fewshot_path, dataset + ".json"), "r") as fin:
        examples = json.load(fin)
    fewshot = ""
    for example in examples:
        fewshot += "Question: {}\nAnswer: {}\n\n".format(example["question"], example["answer"])


def get_prompt(tokenizer, question, passages=None, answer=None, with_cot=False):
    """Build the token ids of one chat-formatted prompt, optionally with its answer.

    The question, passages and answer are normalized by ``_get_prompt``.
    Passages are listed as ``"Passage i: <text>"`` lines (1-based). The user
    turn uses USER_PROMPT, or USER_PROMPT_WITH_COT with the loaded ``fewshot``
    examples when ``with_cot`` is set, which requires get_fewshot() to have run.
    It is rendered with ``tokenizer.apply_chat_template(...,
    add_generation_prompt=True)``. The assistant prefix (with the answer, if
    any) is then appended as raw tokens without special tokens.

    Returns:
        The list of token ids.
    """
    question, passages, answer = _get_prompt(question, passages, answer)
    contexts = "".join(
        f"Passage {num}: {psg}\n" for num, psg in enumerate(passages or [], start=1)
    )
    if with_cot:
        assert fewshot is not None
        user_content = USER_PROMPT_WITH_COT.format(question=question, passages=contexts, fewshot=fewshot)
        assistant_content = ASSISTANT_PROMPT_WITH_COT.format(answer=answer)
    else:
        user_content = USER_PROMPT.format(question=question, passages=contexts)
        assistant_content = ASSISTANT_PROMPT.format(answer=answer)

    token_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_content}],
        add_generation_prompt=True,
    )
    token_ids += tokenizer.encode(assistant_content, add_special_tokens=False)
    return token_ids

================================================
FILE: src/retrieve/beir/.gitignore
================================================
# Custom added
examples/**/datasets/
examples/**/output/
examples/**/DeepCT/
examples/**/models/
examples/**/faiss-index/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


================================================
FILE: src/retrieve/beir/.gitmodules
================================================
[submodule "examples/retrieval/evaluation/late-interaction/beir-ColBERT"]
	path = examples/retrieval/evaluation/late-interaction/beir-ColBERT
	url = https://github.com/NThakur20/beir-ColBERT.git


================================================
FILE: src/retrieve/beir/CONTRIBUTORS.txt
================================================
Individual Contributors to the BEIR Repository (BEIR contributors) include:
1. Nandan Thakur
2. Nils Reimers
3. Iryna Gurevych
4. Jimmy Lin
5. Andreas Rücklé
6. Abhishek Srivastava 

================================================
FILE: src/retrieve/beir/LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2020-2023 Nandan Thakur

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: src/retrieve/beir/NOTICE.txt
================================================
-------------------------------------------------------------------------------
Copyright since 2022
University of Waterloo
-------------------------------------------------------------------------------

-------------------------------------------------------------------------------
Copyright since 2020
Ubiquitous Knowledge Processing (UKP) Lab, Technische Universität Darmstadt
-------------------------------------------------------------------------------

For individual contributors, please refer to the CONTRIBUTORS file.

================================================
FILE: src/retrieve/beir/README.md
================================================
<h1 align="center">
<img style="vertical-align:middle" width="450" height="180" src="https://raw.githubusercontent.com/benchmarkir/beir/main/images/color_logo_transparent_cropped.png" />
</h1>

<p align="center">
    <a href="https://github.com/beir-cellar/beir/releases">
        <img alt="GitHub release" src="https://img.shields.io/github/release/beir-cellar/beir.svg">
    </a>
    <a href="https://www.python.org/">
            <img alt="Build" src="https://img.shields.io/badge/Made%20with-Python-1f425f.svg?color=purple">
    </a>
    <a href="https://github.com/beir-cellar/beir/blob/master/LICENSE">
        <img alt="License" src="https://img.shields.io/github/license/beir-cellar/beir.svg?color=green">
    </a>
    <a href="https://colab.research.google.com/drive/1HfutiEhHMJLXiWGT8pcipxT5L2TpYEdt?usp=sharing">
        <img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg">
    </a>
    <a href="https://pepy.tech/project/beir">
        <img alt="Downloads" src="https://static.pepy.tech/personalized-badge/beir?period=total&units=international_system&left_color=grey&right_color=orange&left_text=Downloads">
    </a>
    <a href="https://github.com/beir-cellar/beir/">
        <img alt="Open Source" src="https://badges.frapsoft.com/os/v1/open-source.svg?v=103">
    </a>
</p>

<h4 align="center">
    <p>
        <a href="https://openreview.net/forum?id=wCu6T5xFjeJ">Paper</a> |
        <a href="#beers-installation">Installation</a> |
        <a href="#beers-quick-example">Quick Example</a> |
        <a href="#beers-available-datasets">Datasets</a> |
        <a href="https://github.com/beir-cellar/beir/wiki">Wiki</a> |
        <a href="https://huggingface.co/BeIR">Hugging Face</a>
    </p>
</h4>

<!-- > The development of BEIR benchmark is supported by: -->

<h3 align="center">
    <a href="http://www.ukp.tu-darmstadt.de"><img style="float: left; padding: 2px 7px 2px 7px;" width="220" height="100" src="./images/ukp.png" /></a>
    <a href="https://www.tu-darmstadt.de/"><img style="float: middle; padding: 2px 7px 2px 7px;" width="250" height="90" src="./images/tu-darmstadt.png" /></a>
    <a href="https://uwaterloo.ca"><img style="float: right; padding: 2px 7px 2px 7px;" width="320" height="100" src="./images/uwaterloo.png" /></a>
</h3>

<h3 align="center">
    <a href="https://huggingface.co/"><img style="float: middle; padding: 2px 7px 2px 7px;" width="400" height="80" src="./images/HF.png" /></a>
</h3>

## :beers: What is it?

**BEIR** is a **heterogeneous benchmark** containing diverse IR tasks. It also provides a **common and easy framework** for evaluation of your NLP-based retrieval models within the benchmark.

For **an overview**, check out our **new wiki** page: [https://github.com/beir-cellar/beir/wiki](https://github.com/beir-cellar/beir/wiki).

For **models and datasets**, check out our **Hugging Face (HF)** page: [https://huggingface.co/BeIR](https://huggingface.co/BeIR).

For the **Leaderboard**, check out our **Eval AI** page: [https://eval.ai/web/challenges/challenge-page/1897](https://eval.ai/web/challenges/challenge-page/1897).

For more information, check out our publications:

- [BEIR: A Heterogenous Benchmark for Zero-shot Evaluation of Information Retrieval Models](https://openreview.net/forum?id=wCu6T5xFjeJ) (NeurIPS 2021, Datasets and Benchmarks Track)
- [Resources for Brewing BEIR: Reproducible Reference Models and an Official Leaderboard](https://arxiv.org/abs/2306.07471) (Arxiv 2023)

## :beers: Installation

Install via pip:

```bash
pip install beir
```

If you want to build from source, use:

```bash
$ git clone https://github.com/beir-cellar/beir.git
$ cd beir
$ pip install -e .
```

Tested with Python versions 3.6 and 3.7.

## :beers: Features 

- Preprocess your own IR dataset or use one of the already-preprocessed 17 benchmark datasets
- Wide settings included, covers diverse benchmarks useful for both academia and industry
- Includes well-known retrieval architectures (lexical, dense, sparse and reranking-based)
- Add and evaluate your own model in an easy framework using different state-of-the-art evaluation metrics

## :beers: Quick Example

For other code examples, please refer to our **[Examples and Tutorials](https://github.com/beir-cellar/beir/wiki/Examples-and-tutorials)** Wiki page. 

```python
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

import logging
import pathlib, os

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

#### Download scifact.zip dataset and unzip the dataset
dataset = "scifact"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
data_path = util.download_and_unzip(url, out_dir)

#### Provide the data_path where scifact has been downloaded and unzipped
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

#### Load the SBERT model and retrieve using cosine-similarity
model = DRES(models.SentenceBERT("msmarco-distilbert-base-tas-b"), batch_size=16)
retriever = EvaluateRetrieval(model, score_function="dot") # or "cos_sim" for cosine similarity
results = retriever.retrieve(corpus, queries)

#### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K  where k = [1,3,5,10,100,1000] 
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
```

## :beers: Available Datasets

Command to generate an md5 hash using the terminal: ``md5sum filename.zip``.

You can view all datasets available **[here](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/)** or on **[Hugging Face](https://huggingface.co/BeIR)**.


| Dataset   | Website| BEIR-Name | Public? | Type | Queries  | Corpus | Rel D/Q | Down-load | md5 |
| -------- | -----| ---------| ------- | --------- | ----------- | ---------| ---------| :----------: | :------:|
| MSMARCO    | [Homepage](https://microsoft.github.io/msmarco/)| ``msmarco`` | ✅ | ``train``<br>``dev``<br>``test``|  6,980   |  8.84M     |    1.1 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/msmarco.zip) | ``444067daf65d982533ea17ebd59501e4`` |
| TREC-COVID |  [Homepage](https://ir.nist.gov/covidSubmit/index.html)| ``trec-covid``| ✅ | ``test``| 50|  171K| 493.5 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/trec-covid.zip) | ``ce62140cb23feb9becf6270d0d1fe6d1`` |
| NFCorpus   | [Homepage](https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/) | ``nfcorpus`` | ✅ |``train``<br>``dev``<br>``test``|  323     |  3.6K     |  38.2 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nfcorpus.zip) | ``a89dba18a62ef92f7d323ec890a0d38d`` |
| BioASQ     | [Homepage](http://bioasq.org) | ``bioasq``| ❌ | ``train``<br>``test`` | 500 |  14.91M    |  4.7 | No | [How to Reproduce?](https://github.com/beir-cellar/beir/blob/main/examples/dataset#2-bioasq) |
| NQ         | [Homepage](https://ai.google.com/research/NaturalQuestions) | ``nq``| ✅ | ``train``<br>``test``| 3,452   |  2.68M  |  1.2 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nq.zip) | ``d4d3d2e48787a744b6f6e691ff534307`` |
| HotpotQA   | [Homepage](https://hotpotqa.github.io) | ``hotpotqa``| ✅ |``train``<br>``dev``<br>``test``|  7,405   |  5.23M  |  2.0 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/hotpotqa.zip)  | ``f412724f78b0d91183a0e86805e16114`` |
| FiQA-2018  | [Homepage](https://sites.google.com/view/fiqa/) | ``fiqa`` | ✅ | ``train``<br>``dev``<br>``test``|  648     |  57K    |  2.6 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip)  | ``17918ed23cd04fb15047f73e6c3bd9d9`` |
| Signal-1M(RT) | [Homepage](https://research.signal-ai.com/datasets/signal1m-tweetir.html)| ``signal1m`` | ❌ | ``test``| 97   |  2.86M  |  19.6 | No | [How to Reproduce?](https://github.com/beir-cellar/beir/blob/main/examples/dataset#4-signal-1m) |
| TREC-NEWS  | [Homepage](https://trec.nist.gov/data/news2019.html) | ``trec-news`` | ❌ | ``test``| 57    |  595K    |  19.6 | No | [How to Reproduce?](https://github.com/beir-cellar/beir/blob/main/examples/dataset#1-trec-news) |
| Robust04 | [Homepage](https://trec.nist.gov/data/robust/04.guidelines.html) | ``robust04``| ❌ | ``test``| 249  |  528K  |  69.9 |  No  |  [How to Reproduce?](https://github.com/beir-cellar/beir/blob/main/examples/dataset#3-robust04)  |
| ArguAna    | [Homepage](http://argumentation.bplaced.net/arguana/data) | ``arguana``| ✅ |``test`` | 1,406     |  8.67K    |  1.0 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/arguana.zip)  | ``8ad3e3c2a5867cdced806d6503f29b99`` |
| Touche-2020| [Homepage](https://webis.de/events/touche-20/shared-task-1.html) | ``webis-touche2020``| ✅ | ``test``| 49     |  382K    |  19.0 |  [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/webis-touche2020.zip) | ``46f650ba5a527fc69e0a6521c5a23563`` |
| CQADupstack| [Homepage](http://nlp.cis.unimelb.edu.au/resources/cqadupstack/) | ``cqadupstack``| ✅ | ``test``| 13,145 |  457K  |  1.4 |  [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/cqadupstack.zip) | ``4e41456d7df8ee7760a7f866133bda78`` |
| Quora| [Homepage](https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs) | ``quora``| ✅ | ``dev``<br>``test``| 10,000     |  523K    |  1.6 |  [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/quora.zip) | ``18fb154900ba42a600f84b839c173167`` |
| DBPedia | [Homepage](https://github.com/iai-group/DBpedia-Entity/) | ``dbpedia-entity``| ✅ | ``dev``<br>``test``| 400    |  4.63M    |  38.2 | [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/dbpedia-entity.zip) | ``c2a39eb420a3164af735795df012ac2c`` |
| SCIDOCS| [Homepage](https://allenai.org/data/scidocs) | ``scidocs``| ✅ | ``test``| 1,000     |  25K    |  4.9 |  [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scidocs.zip) | ``38121350fc3a4d2f48850f6aff52e4a9`` |
| FEVER | [Homepage](http://fever.ai) | ``fever``| ✅ | ``train``<br>``dev``<br>``test``|  6,666     |  5.42M    |  1.2|  [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fever.zip)  | ``5a818580227bfb4b35bb6fa46d9b6c03`` |
| Climate-FEVER| [Homepage](http://climatefever.ai) | ``climate-fever``| ✅ |``test``|  1,535     |  5.42M |  3.0 |  [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/climate-fever.zip)  | ``8b66f0a9126c521bae2bde127b4dc99d`` |
| SciFact| [Homepage](https://github.com/allenai/scifact) | ``scifact``| ✅ | ``train``<br>``test``|  300     |  5K    |  1.1 |  [Link](https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip)  | ``5f7d1de60b170fc8027bb7898e2efca1`` |


## :beers: Additional Information

We also provide a variety of additional information in our **[Wiki](https://github.com/beir-cellar/beir/wiki)** page. 
Please refer to these pages for the following:


### Quick Start

- [Installing BEIR](https://github.com/beir-cellar/beir/wiki/Installing-beir)
- [Examples and Tutorials](https://github.com/beir-cellar/beir/wiki/Examples-and-tutorials)

### Datasets

- [Datasets Available](https://github.com/beir-cellar/beir/wiki/Datasets-available)
- [Multilingual Datasets](https://github.com/beir-cellar/beir/wiki/Multilingual-datasets)
- [Load your Custom Dataset](https://github.com/beir-cellar/beir/wiki/Load-your-custom-dataset)

### Models 
- [Models Available](https://github.com/beir-cellar/beir/wiki/Models-available)
- [Evaluate your Custom Model](https://github.com/beir-cellar/beir/wiki/Evaluate-your-custom-model)

### Metrics

- [Metrics Available](https://github.com/beir-cellar/beir/wiki/Metrics-available)

### Miscellaneous

- [BEIR Leaderboard](https://github.com/beir-cellar/beir/wiki/Leaderboard)
- [Course Material on IR](https://github.com/beir-cellar/beir/wiki/Course-material-on-ir)

## :beers: Disclaimer

Similar to Tensorflow [datasets](https://github.com/tensorflow/datasets) or Hugging Face's [datasets](https://github.com/huggingface/datasets) library, we just downloaded and prepared public datasets. We only distribute these datasets in a specific format, but we do not vouch for their quality or fairness, or claim that you have license to use the dataset. It remains the user's responsibility to determine whether you as a user have permission to use the dataset under the dataset's license and to cite the right owner of the dataset.

If you're a dataset owner and wish to update any part of it, or do not want your dataset to be included in this library, feel free to post an issue here or make a pull request!

If you're a dataset owner and wish to include your dataset or model in this library, feel free to post an issue here or make a pull request!

## :beers: Citing & Authors

If you find this repository helpful, feel free to cite our publication [BEIR: A Heterogenous Benchmark for Zero-shot Evaluation of Information Retrieval Models](https://arxiv.org/abs/2104.08663):

```
@inproceedings{
    thakur2021beir,
    title={{BEIR}: A Heterogeneous Benchmark for Zero-shot Evaluation of Information Retrieval Models},
    author={Nandan Thakur and Nils Reimers and Andreas R{\"u}ckl{\'e} and Abhishek Srivastava and Iryna Gurevych},
    booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track (Round 2)},
    year={2021},
    url={https://openreview.net/forum?id=wCu6T5xFjeJ}
}
```

If you use any baseline score from the BEIR leaderboard, feel free to cite our publication [Resources for Brewing BEIR: Reproducible Reference Models and an Official Leaderboard](https://arxiv.org/abs/2306.07471)
```
@misc{kamalloo2023resources,
      title={Resources for Brewing BEIR: Reproducible Reference Models and an Official Leaderboard}, 
      author={Ehsan Kamalloo and Nandan Thakur and Carlos Lassance and Xueguang Ma and Jheng-Hong Yang and Jimmy Lin},
      year={2023},
      eprint={2306.07471},
      archivePrefix={arXiv},
      primaryClass={cs.IR}
}
```

The main contributors of this repository are:
- [Nandan Thakur](https://github.com/Nthakur20), Personal Website: [nandan-thakur.com](https://nandan-thakur.com)

Contact person: Nandan Thakur, [nandant@gmail.com](mailto:nandant@gmail.com)

Don't hesitate to send us an e-mail or report an issue, if something is broken (and it shouldn't be) or if you have further questions.

> This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.

## :beers: Collaboration

The BEIR Benchmark has been made possible due to a collaborative effort of the following universities and organizations:
- [UKP Lab, Technical University of Darmstadt](http://www.ukp.tu-darmstadt.de/)
- [University of Waterloo](https://uwaterloo.ca/)
- [Hugging Face](https://huggingface.co/)

## :beers: Contributors

Thanks go to all these wonderful collaborators for their contributions towards the BEIR benchmark:

<!-- ALL-CONTRIBUTORS-LIST:START - Do not remove or modify this section -->
<!-- prettier-ignore-start -->
<!-- markdownlint-disable -->
<table>
  <tr>
    <td align="center"><a href="https://www.nandan-thakur.com"><img src="https://avatars.githubusercontent.com/u/30648040?v=4" width="100px;" alt=""/><br /><sub><b>Nandan Thakur</b></sub></a></td>
    <td align="center"><a href="https://www.nils-reimers.de/"><img src="https://avatars.githubusercontent.com/u/10706961?v=4" width="100px;" alt=""/><br /><sub><b>Nils Reimers</b></sub></a></td>
    <td align="center"><a href="https://www.informatik.tu-darmstadt.de/ukp/ukp_home/head_ukp/index.en.jsp"><img src="https://www.informatik.tu-darmstadt.de/media/ukp/pictures_1/people_1/Gurevych_Iryna_500x750_415x415.jpg" width="100px;" alt=""/><br /><sub><b>Iryna Gurevych</b></sub></a></td>
    <td align="center"><a href="https://cs.uwaterloo.ca/~jimmylin/"><img src="https://avatars.githubusercontent.com/u/313837?v=4" width="100px;" alt=""/><br /><sub><b>Jimmy Lin</b></sub></a></td>
    <td align="center"><a href="http://rueckle.net"><img src="https://i1.rgstatic.net/ii/profile.image/601126613295104-1520331161365_Q512/Andreas-Rueckle.jpg" width="100px;" alt=""/><br /><sub><b>Andreas Rücklé</b></sub></a></td>
    <td align="center"><a href="https://www.linkedin.com/in/abhesrivas"><img src="https://avatars.githubusercontent.com/u/19344566?v=4" width="100px;" alt=""/><br /><sub><b>Abhishek Srivastava</b></sub></a></td>
  </tr>
</table>

<!-- markdownlint-restore -->
<!-- prettier-ignore-end -->
<!-- ALL-CONTRIBUTORS-LIST:END -->


================================================
FILE: src/retrieve/beir/beir/__init__.py
================================================
from .logging import LoggingHandler

================================================
FILE: src/retrieve/beir/beir/datasets/__init__.py
================================================


================================================
FILE: src/retrieve/beir/beir/datasets/data_loader.py
================================================
from typing import Dict, Tuple
from tqdm.autonotebook import tqdm
import json
import os
import logging
import csv

# Module-level logger; the data loader reports progress through it
# (e.g. how many documents/queries were loaded and an example entry).
logger = logging.getLogger(__name__)

class GenericDataLoader:
    """Load a BEIR-style retrieval dataset (corpus, queries, qrels) from local files.

    File formats read by this class:

    * corpus: JSON lines; each object's ``_id`` becomes the key and its
      ``title`` / ``text`` fields are kept.
    * queries: JSON lines; ``_id`` -> ``text``.
    * qrels: tab-separated, one header row, then ``query-id``, ``corpus-id`` and an
      integer score per line.

    Corpus and queries are cached on the instance: a ``load*`` call only reads
    them when the corresponding dict is still empty.
    """
    
    def __init__(self, data_folder: str = None, prefix: str = None, corpus_file: str = "corpus.jsonl", query_file: str = "queries.jsonl", 
                 qrels_folder: str = "qrels", qrels_file: str = ""):
        """
        :param data_folder: Directory joined in front of the corpus, query and qrels
            paths. When omitted, ``corpus_file`` / ``query_file`` are used as given and
            ``qrels_folder`` is left as ``None`` (so ``load`` cannot build a qrels path).
        :param prefix: If set, ``"<prefix>-"`` is prepended to ``query_file`` and
            ``qrels_folder`` (the corpus file is not prefixed).
        :param qrels_file: Explicit qrels path, used by ``load_custom``.
        """
        self.corpus = {}
        self.queries = {}
        self.qrels = {}
        
        if prefix:
            query_file = prefix + "-" + query_file
            qrels_folder = prefix + "-" + qrels_folder

        self.corpus_file = os.path.join(data_folder, corpus_file) if data_folder else corpus_file
        self.query_file = os.path.join(data_folder, query_file) if data_folder else query_file
        self.qrels_folder = os.path.join(data_folder, qrels_folder) if data_folder else None
        self.qrels_file = qrels_file
    
    @staticmethod
    def check(fIn: str, ext: str):
        """Raise ``ValueError`` unless ``fIn`` exists and ends with ``ext``."""
        if not os.path.exists(fIn):
            raise ValueError("File {} not present! Please provide accurate file.".format(fIn))
        
        if not fIn.endswith(ext):
            raise ValueError("File {} must be present with extension {}".format(fIn, ext))

    def load_custom(self) -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:
        """Load corpus, queries and the qrels file given as ``qrels_file`` in ``__init__``.

        Queries are then restricted to those that appear in the qrels.

        :return: ``(corpus, queries, qrels)``.
        :raises ValueError: if any of the three files is missing or has the wrong extension.
        """
        self.check(fIn=self.corpus_file, ext="jsonl")
        self.check(fIn=self.query_file, ext="jsonl")
        self.check(fIn=self.qrels_file, ext="tsv")
        
        if not len(self.corpus):
            logger.info("Loading Corpus...")
            self._load_corpus()
            logger.info("Loaded %d Documents.", len(self.corpus))
            logger.info("Doc Example: %s", list(self.corpus.values())[0])
        
        if not len(self.queries):
            logger.info("Loading Queries...")
            self._load_queries()
        
        if os.path.exists(self.qrels_file):
            self._load_qrels()
            # Keep only queries that have relevance judgements.
            self.queries = {qid: self.queries[qid] for qid in self.qrels}
            logger.info("Loaded %d Queries.", len(self.queries))
            logger.info("Query Example: %s", list(self.queries.values())[0])
        
        return self.corpus, self.queries, self.qrels

    def load(self, split="test") -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:
        """Load corpus, queries and ``<qrels_folder>/<split>.tsv``.

        Queries are then restricted to those that appear in the qrels.

        :param split: Name of the qrels file (without ``.tsv``) inside ``qrels_folder``.
        :return: ``(corpus, queries, qrels)``.
        :raises ValueError: if any of the three files is missing or has the wrong extension.
        """
        self.qrels_file = os.path.join(self.qrels_folder, split + ".tsv")
        self.check(fIn=self.corpus_file, ext="jsonl")
        self.check(fIn=self.query_file, ext="jsonl")
        self.check(fIn=self.qrels_file, ext="tsv")
        
        if not len(self.corpus):
            logger.info("Loading Corpus...")
            self._load_corpus()
            logger.info("Loaded %d %s Documents.", len(self.corpus), split.upper())
            logger.info("Doc Example: %s", list(self.corpus.values())[0])
        
        if not len(self.queries):
            logger.info("Loading Queries...")
            self._load_queries()
        
        if os.path.exists(self.qrels_file):
            self._load_qrels()
            # Keep only queries that have relevance judgements.
            self.queries = {qid: self.queries[qid] for qid in self.qrels}
            logger.info("Loaded %d %s Queries.", len(self.queries), split.upper())
            logger.info("Query Example: %s", list(self.queries.values())[0])
        
        return self.corpus, self.queries, self.qrels
    
    def load_corpus(self) -> Dict[str, Dict[str, str]]:
        """Load (once) and return only the corpus."""
        self.check(fIn=self.corpus_file, ext="jsonl")

        if not len(self.corpus):
            logger.info("Loading Corpus...")
            self._load_corpus()
            logger.info("Loaded %d Documents.", len(self.corpus))
            logger.info("Doc Example: %s", list(self.corpus.values())[0])

        return self.corpus
    
    def _load_corpus(self):
        """Read the corpus JSONL into ``self.corpus`` as ``{_id: {"text", "title"}}``."""
        # Count lines up front only so the progress bar can show a total; close the
        # handle immediately instead of leaving it to the garbage collector.
        with open(self.corpus_file, 'rb') as fIn:
            num_lines = sum(1 for _ in fIn)
        with open(self.corpus_file, encoding='utf8') as fIn:
            for line in tqdm(fIn, total=num_lines):
                line = json.loads(line)
                self.corpus[line.get("_id")] = {
                    "text": line.get("text"),
                    "title": line.get("title"),
                }
    
    def _load_queries(self):
        """Read the queries JSONL into ``self.queries`` as ``{_id: text}``."""
        with open(self.query_file, encoding='utf8') as fIn:
            for line in fIn:
                line = json.loads(line)
                self.queries[line.get("_id")] = line.get("text")
        
    def _load_qrels(self):
        """Read the qrels TSV into ``self.qrels`` as ``{query_id: {corpus_id: int score}}``."""
        # newline="" is what the csv module expects for files handed to csv.reader.
        with open(self.qrels_file, encoding="utf-8", newline="") as fIn:
            reader = csv.reader(fIn, delimiter="\t", quoting=csv.QUOTE_MINIMAL)
            next(reader, None)  # skip the header row; an empty file yields no qrels
            
            for row in reader:
                if not row:
                    # Blank lines (e.g. a trailing newline) come back as empty rows.
                    continue
                query_id, corpus_id, score = row[0], row[1], int(row[2])
                self.qrels.setdefault(query_id, {})[corpus_id] = score

================================================
FILE: src/retrieve/beir/beir/datasets/data_loader_hf.py
================================================
from collections import defaultdict
from typing import Dict, Tuple
import os
import logging
from datasets import load_dataset, Value, Features

logger = logging.getLogger(__name__)


class HFDataLoader:
    """Load a BEIR-style dataset as Hugging Face ``datasets`` objects.

    Data comes either from the Hub (``hf_repo`` with ``corpus`` / ``queries``
    configs, plus a separate qrels repository) or from local JSONL/TSV files.
    After ``load``, corpus and queries are datasets with an ``id`` column (plus
    ``text``, and ``title`` for the corpus), and qrels are a plain
    ``{query_id: {corpus_id: int score}}`` mapping.
    """
    
    def __init__(self, hf_repo: str = None, hf_repo_qrels: str = None, data_folder: str = None, prefix: str = None, corpus_file: str = "corpus.jsonl", query_file: str = "queries.jsonl", 
                 qrels_folder: str = "qrels", qrels_file: str = "", streaming: bool = False, keep_in_memory: bool = False):
        """
        :param hf_repo: Hub dataset id. When given, all local-file arguments are ignored.
        :param hf_repo_qrels: Hub id of the qrels dataset; defaults to ``hf_repo + "-qrels"``.
        :param streaming: Passed through to ``load_dataset``.
        :param keep_in_memory: Passed through to ``load_dataset``.
        """
        self.corpus = {}
        self.queries = {}
        self.qrels = {}
        self.hf_repo = hf_repo
        if hf_repo:
            logger.warning("A huggingface repository is provided. This will override the data_folder, prefix and *_file arguments.")
            self.hf_repo_qrels = hf_repo_qrels if hf_repo_qrels else hf_repo + "-qrels"
        else:
            # data folder would contain these files: 
            # (1) fiqa/corpus.jsonl  (format: jsonlines)
            # (2) fiqa/queries.jsonl (format: jsonlines)
            # (3) fiqa/qrels/test.tsv (format: tsv ("\t"))
            if prefix:
                query_file = prefix + "-" + query_file
                qrels_folder = prefix + "-" + qrels_folder

            self.corpus_file = os.path.join(data_folder, corpus_file) if data_folder else corpus_file
            self.query_file = os.path.join(data_folder, query_file) if data_folder else query_file
            self.qrels_folder = os.path.join(data_folder, qrels_folder) if data_folder else None
            self.qrels_file = qrels_file
        self.streaming = streaming
        self.keep_in_memory = keep_in_memory
    
    @staticmethod
    def check(fIn: str, ext: str):
        """Raise ``ValueError`` unless ``fIn`` exists and ends with ``ext``."""
        if not os.path.exists(fIn):
            raise ValueError("File {} not present! Please provide accurate file.".format(fIn))
        
        if not fIn.endswith(ext):
            raise ValueError("File {} must be present with extension {}".format(fIn, ext))

    @staticmethod
    def _qrels_to_dict(rows) -> Dict[str, Dict[str, int]]:
        """Collect qrels rows (``query-id``, ``corpus-id``, ``score``) into nested dicts.

        Scores are converted with ``int`` (``_load_qrels`` casts them to float).
        """
        qrels_dict = defaultdict(dict)
        for row in rows:
            qrels_dict[row['query-id']][row['corpus-id']] = int(row['score'])
        return qrels_dict

    def load(self, split="test") -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:
        """Load corpus, queries and the ``split`` qrels; drop queries without qrels.

        :return: ``(corpus, queries, qrels)``.
        :raises ValueError: for local data, if a required file is missing or mis-named.
        """
        if not self.hf_repo:
            self.qrels_file = os.path.join(self.qrels_folder, split + ".tsv")
            self.check(fIn=self.corpus_file, ext="jsonl")
            self.check(fIn=self.query_file, ext="jsonl")
            self.check(fIn=self.qrels_file, ext="tsv")
        
        if not len(self.corpus):
            logger.info("Loading Corpus...")
            self._load_corpus()
            logger.info("Loaded %d %s Documents.", len(self.corpus), split.upper())
            logger.info("Doc Example: %s", self.corpus[0])
        
        if not len(self.queries):
            logger.info("Loading Queries...")
            self._load_queries()
        
        self._load_qrels(split)
        # Iterate the rows directly: Dataset.map used only for a side effect is lazy
        # on streaming (Iterable) datasets and would leave the dict empty there.
        self.qrels = self._qrels_to_dict(self.qrels)
        # filter queries with no qrels
        self.queries = self.queries.filter(lambda x: x['id'] in self.qrels)
        logger.info("Loaded %d %s Queries.", len(self.queries), split.upper())
        logger.info("Query Example: %s", self.queries[0])
        
        return self.corpus, self.queries, self.qrels
    
    def load_corpus(self) -> Dict[str, Dict[str, str]]:
        """Load (once) and return only the corpus dataset."""
        if not self.hf_repo:
            self.check(fIn=self.corpus_file, ext="jsonl")

        if not len(self.corpus):
            logger.info("Loading Corpus...")
            self._load_corpus()
            # One placeholder per argument (the old message had an unfilled %s).
            logger.info("Loaded %d Documents.", len(self.corpus))
            logger.info("Doc Example: %s", self.corpus[0])

        return self.corpus
    
    def _load_corpus(self):
        """Load the corpus, rename ``_id`` -> ``id`` (as string) and keep id/text/title."""
        if self.hf_repo:
            corpus_ds = load_dataset(self.hf_repo, 'corpus', keep_in_memory=self.keep_in_memory, streaming=self.streaming)
        else:
            corpus_ds = load_dataset('json', data_files=self.corpus_file, streaming=self.streaming, keep_in_memory=self.keep_in_memory)
        corpus_ds = next(iter(corpus_ds.values())) # get first split
        corpus_ds = corpus_ds.cast_column('_id', Value('string'))
        corpus_ds = corpus_ds.rename_column('_id', 'id')
        corpus_ds = corpus_ds.remove_columns([col for col in corpus_ds.column_names if col not in ['id', 'text', 'title']])
        self.corpus = corpus_ds
    
    def _load_queries(self):
        """Load the queries, rename ``_id`` -> ``id`` (as string) and keep id/text."""
        if self.hf_repo:
            queries_ds = load_dataset(self.hf_repo, 'queries', keep_in_memory=self.keep_in_memory, streaming=self.streaming)
        else:
            queries_ds = load_dataset('json', data_files=self.query_file, streaming=self.streaming, keep_in_memory=self.keep_in_memory)
        queries_ds = next(iter(queries_ds.values())) # get first split
        queries_ds = queries_ds.cast_column('_id', Value('string'))
        queries_ds = queries_ds.rename_column('_id', 'id')
        queries_ds = queries_ds.remove_columns([col for col in queries_ds.column_names if col not in ['id', 'text']])
        self.queries = queries_ds
        
    def _load_qrels(self, split):
        """Load the qrels rows for ``split`` into ``self.qrels`` as a single dataset."""
        if self.hf_repo:
            qrels_ds = load_dataset(self.hf_repo_qrels, keep_in_memory=self.keep_in_memory, streaming=self.streaming)[split]
        else:
            qrels_ds = load_dataset('csv', data_files=self.qrels_file, delimiter='\t', keep_in_memory=self.keep_in_memory)
            # A local load returns a DatasetDict; take the first split so both branches
            # yield a dataset of rows, as _load_corpus / _load_queries do.
            qrels_ds = next(iter(qrels_ds.values()))
        features = Features({'query-id': Value('string'), 'corpus-id': Value('string'), 'score': Value('float')})
        qrels_ds = qrels_ds.cast(features)
        self.qrels = qrels_ds

================================================
FILE: src/retrieve/beir/beir/generation/__init__.py
================================================
from .generate import QueryGenerator, PassageExpansion

================================================
FILE: src/retrieve/beir/beir/generation/generate.py
================================================
from tqdm.autonotebook import trange
from ..util import write_to_json, write_to_tsv
from typing import Dict
import logging, os

logger = logging.getLogger(__name__)

class PassageExpansion:
    """Append model-generated terms to every passage of a corpus and save the result.

    On each batch, ``model.generate`` is called with the keyword arguments
    ``corpus=``, ``max_length=`` and ``top_k=``. It is expected to return one
    expansion string per passage, in input order.
    """

    def __init__(self, model, **kwargs):
        self.model = model
        # doc_id -> {"title": ..., "text": ...}; populated by expand()
        self.corpus_exp = {}

    @staticmethod
    def save(output_dir: str, corpus: Dict[str, str], prefix: str):
        """Write ``corpus`` to ``<output_dir>/<prefix>-corpus.jsonl`` via ``write_to_json``.

        ``output_dir`` is created first if it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)
        target_path = os.path.join(output_dir, prefix + "-corpus.jsonl")
        logger.info("Saving expanded passages to {}".format(target_path))
        write_to_json(output_file=target_path, data=corpus)

    def expand(self, 
                 corpus: Dict[str, Dict[str, str]], 
                 output_dir: str, 
                 top_k: int = 200, 
                 max_length: int = 350,
                 prefix: str = "gen", 
                 batch_size: int = 32,
                 sep: str = " "):
        """Expand every passage in ``corpus`` in batches and save the expanded corpus.

        Each expanded entry keeps the original title. Its text is the original text,
        then ``sep``, then the generated expansion. Results accumulate in
        ``self.corpus_exp`` and are written with ``save(output_dir, ..., prefix)``.
        """
        logger.info("Starting to expand Passages with {} tokens chosen...".format(top_k))
        for template, value in (("Params: top_k = {}", top_k),
                                ("Params: passage max_length = {}", max_length),
                                ("Params: batch size = {}", batch_size)):
            logger.info(template.format(value))

        doc_ids = list(corpus)
        passages = [corpus[doc_id] for doc_id in doc_ids]

        for offset in trange(0, len(passages), batch_size, desc='pas'):
            generated = self.model.generate(
                corpus=passages[offset:offset + batch_size],
                max_length=max_length,
                top_k=top_k)

            for position, extra_terms in enumerate(generated):
                doc_id = doc_ids[offset + position]
                original = corpus[doc_id]
                self.corpus_exp[doc_id] = {
                    "title": original["title"],
                    "text": original["text"] + sep + extra_terms,
                }

        logger.info("Saving {} Expanded Passages...".format(len(self.corpus_exp)))
        self.save(output_dir, self.corpus_exp, prefix)


class QueryGenerator:
    """Generate synthetic queries for corpus passages and record them as qrels.

    Every distinct generated query gets the id ``genQ<n>``. Its qrel points at
    the passage it was generated from, with relevance 1.
    """
    def __init__(self, model, **kwargs):
        self.model = model
        self.qrels = {}
        self.queries = {}

    @staticmethod
    def save(output_dir: str, queries: Dict[str, str], qrels: Dict[str, Dict[str, int]], prefix: str):
        """Write ``<prefix>-queries.jsonl`` and ``<prefix>-qrels/train.tsv`` under ``output_dir``."""
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(os.path.join(output_dir, prefix + "-qrels"), exist_ok=True)
        
        query_file = os.path.join(output_dir, prefix + "-queries.jsonl")
        qrels_file = os.path.join(output_dir, prefix + "-qrels", "train.tsv")
        
        logger.info("Saving Generated Queries to {}".format(query_file))
        write_to_json(output_file=query_file, data=queries)
        
        logger.info("Saving Generated Qrels to {}".format(qrels_file))
        write_to_tsv(output_file=qrels_file, data=qrels)

    def _add_queries(self, corpus_id: str, generated: list, count: int) -> int:
        """Register the distinct, stripped strings in ``generated`` as queries for ``corpus_id``.

        Duplicates are removed in first-occurrence order. A ``set`` would make the
        id assignment depend on string-hash randomisation. Returns the updated
        running counter used to build ``genQ<count>`` ids.
        """
        for query in dict.fromkeys(q.strip() for q in generated):
            count += 1
            query_id = "genQ" + str(count)
            self.queries[query_id] = query
            self.qrels[query_id] = {corpus_id: 1}
        return count

    def generate(self, 
                 corpus: Dict[str, Dict[str, str]], 
                 output_dir: str, 
                 top_p: float = 0.95, 
                 top_k: int = 25, 
                 max_length: int = 64,
                 ques_per_passage: int = 1, 
                 prefix: str = "gen", 
                 batch_size: int = 32,
                 save: bool = True, 
                 save_after: int = 100000):
        """Generate ``ques_per_passage`` sampled questions per passage, batch by batch.

        :param save: If true (and ``save_after > 0``), write an intermediate
            checkpoint each time the number of queries reaches the next multiple of
            ``save_after``. The final save always happens.
        :param save_after: Checkpoint interval, in number of generated queries.
        """
        logger.info("Starting to Generate {} Questions Per Passage using top-p (nucleus) sampling...".format(ques_per_passage))
        logger.info("Params: top_p = {}".format(top_p))
        logger.info("Params: top_k = {}".format(top_k))
        logger.info("Params: max_length = {}".format(max_length))
        logger.info("Params: ques_per_passage = {}".format(ques_per_passage))
        logger.info("Params: batch size = {}".format(batch_size))
        
        count = 0
        corpus_ids = list(corpus.keys())
        corpus = [corpus[doc_id] for doc_id in corpus_ids]
        # Checkpoint threshold. An exact `len % save_after == 0` test would miss
        # checkpoints, because each passage adds a variable number of de-duplicated
        # queries and the count can jump over a multiple.
        next_save = save_after

        for start_idx in trange(0, len(corpus), batch_size, desc='pas'):
            batch = corpus[start_idx:start_idx + batch_size]
            queries = self.model.generate(
                corpus=batch, 
                ques_per_passage=ques_per_passage,
                max_length=max_length,
                top_p=top_p,
                top_k=top_k
                )
            
            # The model returns ques_per_passage outputs per passage, grouped by passage.
            assert len(queries) == len(batch) * ques_per_passage

            for idx in range(len(batch)):
                if save and save_after > 0 and len(self.queries) >= next_save:
                    logger.info("Saving {} Generated Queries...".format(len(self.queries)))
                    self.save(output_dir, self.queries, self.qrels, prefix)
                    next_save = (len(self.queries) // save_after + 1) * save_after

                start_id = idx * ques_per_passage
                count = self._add_queries(corpus_ids[start_idx + idx],
                                          queries[start_id:start_id + ques_per_passage], count)
        
        # Saving finally all the questions
        logger.info("Saving {} Generated Queries...".format(len(self.queries)))
        self.save(output_dir, self.queries, self.qrels, prefix)
    
    def generate_multi_process(self, 
                 corpus: Dict[str, Dict[str, str]], 
                 pool:  Dict[str, object],
                 output_dir: str, 
                 top_p: float = 0.95, 
                 top_k: int = 25, 
                 max_length: int = 64,
                 ques_per_passage: int = 1, 
                 prefix: str = "gen", 
                 batch_size: int = 32,
                 chunk_size: int = None):
        """Like ``generate`` but delegates generation to ``model.generate_multi_process`` over ``pool``."""
        logger.info("Starting to Generate {} Questions Per Passage using top-p (nucleus) sampling...".format(ques_per_passage))
        logger.info("Params: top_p = {}".format(top_p))
        logger.info("Params: top_k = {}".format(top_k))
        logger.info("Params: max_length = {}".format(max_length))
        logger.info("Params: ques_per_passage = {}".format(ques_per_passage))
        logger.info("Params: batch size = {}".format(batch_size))
        
        count = 0
        corpus_ids = list(corpus.keys())
        corpus = [corpus[doc_id] for doc_id in corpus_ids]

        queries = self.model.generate_multi_process(
                            corpus=corpus, 
                            pool=pool,
                            ques_per_passage=ques_per_passage,
                            max_length=max_length,
                            top_p=top_p,
                            top_k=top_k,
                            chunk_size=chunk_size,
                            batch_size=batch_size,
                            )

        assert len(queries) == len(corpus) * ques_per_passage

        for idx, corpus_id in enumerate(corpus_ids):
            start_id = idx * ques_per_passage
            count = self._add_queries(corpus_id, queries[start_id:start_id + ques_per_passage], count)
    
        # Saving finally all the questions
        logger.info("Saving {} Generated Queries...".format(len(self.queries)))
        self.save(output_dir, self.queries, self.qrels, prefix)

================================================
FILE: src/retrieve/beir/beir/generation/models/__init__.py
================================================
from .auto_model import QGenModel
from .tilde import TILDE

================================================
FILE: src/retrieve/beir/beir/generation/models/auto_model.py
================================================
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from tqdm.autonotebook import trange
import torch, logging, math, queue
import torch.multiprocessing as mp
from typing import List, Dict

logger = logging.getLogger(__name__)


class QGenModel:
    """Wrapper around a Hugging Face seq2seq model that samples questions for passages.

    Use ``generate`` for single-process generation. For several devices, use
    ``start_multi_process_pool``, then ``generate_multi_process``, then
    ``stop_multi_process_pool``.
    """
    def __init__(self, model_path: str, gen_prefix: str = "", use_fast: bool = True, device: str = None, **kwargs):
        """
        :param model_path: Name or path given to ``AutoTokenizer`` / ``AutoModelForSeq2SeqLM.from_pretrained``.
        :param gen_prefix: String prepended to every ``title + " " + text`` model input.
        :param device: Torch device; defaults to ``cuda`` when available, else ``cpu``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        self.gen_prefix = gen_prefix
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info("Use pytorch device: {}".format(self.device))
        self.model = self.model.to(self.device)
    
    def generate(self, corpus: List[Dict[str, str]], ques_per_passage: int, top_k: int, max_length: int, top_p: float = None, temperature: float = None) -> List[str]:
        """Sample ``ques_per_passage`` questions for each passage in ``corpus``.

        The model input is ``gen_prefix + title + " " + text``. If ``temperature``
        is falsy, top-k plus top-p (nucleus) sampling is used. Otherwise top-k
        sampling runs at that temperature, and ``top_p`` is not passed.
        See https://huggingface.co/blog/how-to-generate

        :return: Flat list of decoded strings, ``num_return_sequences=ques_per_passage``
            per input passage.
        """
        texts = [(self.gen_prefix + doc["title"] + " " + doc["text"]) for doc in corpus]
        encodings = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

        gen_kwargs = dict(
            input_ids=encodings['input_ids'].to(self.device),
            do_sample=True,
            max_length=max_length,
            top_k=top_k,
            num_return_sequences=ques_per_passage,
        )
        # The batch is padded, so pass the mask to keep padded positions out of attention.
        if 'attention_mask' in encodings:
            gen_kwargs['attention_mask'] = encodings['attention_mask'].to(self.device)
        if not temperature:
            gen_kwargs['top_p'] = top_p
        else:
            gen_kwargs['temperature'] = temperature

        with torch.no_grad():
            outs = self.model.generate(**gen_kwargs)

        return self.tokenizer.batch_decode(outs, skip_special_tokens=True)
    
    def start_multi_process_pool(self, target_devices: List[str] = None):
        """
        Start one generation worker process per target device (spawn context).
        Use together with generate_multi_process; stop with stop_multi_process_pool.
        :param target_devices: PyTorch target devices, e.g. cuda:0, cuda:1... If None, all available CUDA devices are used, or 4 CPU workers when CUDA is unavailable.
        :return: A dict with the worker processes, an input queue and an output queue.
        """
        if target_devices is None:
            if torch.cuda.is_available():
                target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())]
            else:
                logger.info("CUDA is not available. Start 4 CPU worker")
                target_devices = ['cpu']*4

        logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices))))

        ctx = mp.get_context('spawn')
        input_queue = ctx.Queue()
        output_queue = ctx.Queue()
        processes = []

        for cuda_id in target_devices:
            p = ctx.Process(target=QGenModel._generate_multi_process_worker, args=(cuda_id, self.model, self.tokenizer, input_queue, output_queue), daemon=True)
            p.start()
            processes.append(p)

        return {'input': input_queue, 'output': output_queue, 'processes': processes}
    
    @staticmethod
    def stop_multi_process_pool(pool):
        """
        Terminate and join all processes started with start_multi_process_pool, then close its queues.
        """
        for p in pool['processes']:
            p.terminate()

        for p in pool['processes']:
            p.join()
            p.close()

        pool['input'].close()
        pool['output'].close()
    
    @staticmethod
    def _generate_multi_process_worker(target_device: str, model, tokenizer, input_queue, results_queue):
        """
        Worker loop: take a chunk from ``input_queue``, generate questions for it in
        mini-batches, and put ``[chunk_id, generated_texts]`` on ``results_queue``.

        ``input_queue.get()`` blocks without a timeout, so the loop only ends when
        the process is terminated (see stop_multi_process_pool).
        """
        model = model.to(target_device)  # once, instead of on every chunk
        while True:
            chunk_id, batch_size, texts, ques_per_passage, top_p, top_k, max_length = input_queue.get()
            generated_texts = []
            
            for start_idx in trange(0, len(texts), batch_size, desc='{}'.format(target_device)):
                texts_batch = texts[start_idx:start_idx + batch_size]
                encodings = tokenizer(texts_batch, padding=True, truncation=True, return_tensors="pt")
                gen_kwargs = dict(
                    input_ids=encodings['input_ids'].to(target_device),
                    do_sample=True,
                    max_length=max_length, # 64
                    top_k=top_k, # 25
                    top_p=top_p, # 0.95
                    num_return_sequences=ques_per_passage # 1
                )
                if 'attention_mask' in encodings:
                    gen_kwargs['attention_mask'] = encodings['attention_mask'].to(target_device)
                with torch.no_grad():
                    outs = model.generate(**gen_kwargs)
                generated_texts += tokenizer.batch_decode(outs, skip_special_tokens=True)
            
            results_queue.put([chunk_id, generated_texts])
    
    def generate_multi_process(self, corpus: List[Dict[str, str]], ques_per_passage: int, top_p: float, top_k: int, max_length: int, 
                               pool: Dict[str, object], batch_size: int = 32, chunk_size: int = None):
        """
        Generate questions for ``corpus`` using the worker pool. Passages are split
        into chunks, sent to the workers, and the results are reassembled in input
        order.
        :param corpus: List of passages with "title" and "text" keys.
        :param pool: A pool of workers started with start_multi_process_pool.
        :param batch_size: Mini-batch size used inside each worker.
        :param chunk_size: Passages per chunk sent to a worker. If None, a size is derived from the corpus and pool size (at least 1, at most 5000).
        :return: Flat list of generated questions, ``ques_per_passage`` per passage, in corpus order.
        """
        texts = [(self.gen_prefix + doc["title"] + " " + doc["text"]) for doc in corpus]

        if chunk_size is None:
            chunk_size = min(math.ceil(len(texts) / len(pool["processes"]) / 10), 5000)
        # A non-positive chunk size would make every passage its own chunk.
        chunk_size = max(1, chunk_size)

        logger.info("Chunk data into packages of size {}".format(chunk_size))

        input_queue = pool['input']
        last_chunk_id = 0
        chunk = []

        for doc_text in texts:
            chunk.append(doc_text)
            if len(chunk) >= chunk_size:
                input_queue.put([last_chunk_id, batch_size, chunk, ques_per_passage, top_p, top_k, max_length])
                last_chunk_id += 1
                chunk = []

        if len(chunk) > 0:
            input_queue.put([last_chunk_id, batch_size, chunk, ques_per_passage, top_p, top_k, max_length])
            last_chunk_id += 1

        output_queue = pool['output']
        
        # Workers finish in any order; sort by chunk id to restore corpus order.
        results_list = sorted([output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0])
        queries = [result[1] for result in results_list]
        
        return [item for sublist in queries for item in sublist]

================================================
FILE: src/retrieve/beir/beir/generation/models/tilde.py
================================================
from transformers import BertLMHeadModel, BertTokenizer, DataCollatorWithPadding
from tqdm.autonotebook import trange
import torch, logging, math, queue
import torch.multiprocessing as mp
from typing import List, Dict
from nltk.corpus import stopwords
import numpy as np
import re

logger = logging.getLogger(__name__)

class TILDE:
    """Passage expansion using a BERT LM head: adds high-scoring vocabulary terms to passages."""
    def __init__(self, model_path: str, gen_prefix: str = "", use_fast: bool = True, device: str = None, **kwargs):
        """
        :param model_path: Path or name of the ``BertLMHeadModel`` checkpoint
            (the tokenizer is always ``bert-base-uncased``).
        :param gen_prefix: String prepended to every ``title + " " + text`` input.
        :param device: Torch device; defaults to ``cuda`` when available, else ``cpu``.
        """
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', use_fast=use_fast)
        self.model = BertLMHeadModel.from_pretrained(model_path)
        self.gen_prefix = gen_prefix
        _, self.bad_ids = self._clean_vocab(self.tokenizer)
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info("Use pytorch device: {}".format(self.device))
        self.model = self.model.to(self.device)
    
    def _clean_vocab(self, tokenizer, do_stopwords=True):
        """Split the tokenizer vocabulary into ids allowed / disallowed as expansion terms.

        Disallowed ("bad") ids are:

        * if ``do_stopwords``, NLTK English stopwords plus "definition" that
          tokenize to a single id;
        * every token not of the form ``#<something>`` that contains a character
          outside ``[A-Za-z0-9_-]``;
        * id 2015 (noted as ``##s``).

        :return: ``(good_ids, bad_ids)`` lists.
        """
        # Empty by default so do_stopwords=False works (previously a NameError).
        stop_words = set()
        if do_stopwords:
            stop_words = set(stopwords.words('english'))
            # keep some common words in ms marco questions
            # stop_words.difference_update(["where", "how", "what", "when", "which", "why", "who"])
            stop_words.add("definition")

        vocab = tokenizer.get_vocab()

        good_ids = []
        bad_ids = []

        for stop_word in stop_words:
            ids = tokenizer(stop_word, add_special_tokens=False)["input_ids"]
            if len(ids) == 1:
                bad_ids.append(ids[0])
        # Set for O(1) membership in the vocabulary scan. Vocab ids are unique per
        # token, so ids appended during the scan never need to be looked up again.
        stopword_ids = set(bad_ids)

        for token, token_id in vocab.items():
            if token_id in stopword_ids:
                continue

            if token.startswith('#') and len(token) > 1:
                good_ids.append(token_id)
            elif not re.match("^[A-Za-z0-9_-]*$", token):
                bad_ids.append(token_id)
            else:
                good_ids.append(token_id)
        bad_ids.append(2015)  # add ##s to stopwords
        return good_ids, bad_ids
    
    def generate(self, corpus: List[Dict[str, str]], top_k: int, max_length: int) -> List[str]:
        """Return one expansion string per passage in ``corpus``.

        Steps for each passage (``gen_prefix + title + " " + text``):

        1. Encode it to exactly ``max_length`` ids and overwrite the first id with 1.
        2. Take the ``top_k`` highest-scoring vocabulary ids from the model logits
           at position 0.
        3. Drop ids already in the encoded passage and ids in ``self.bad_ids``,
           keeping score order.
        4. Decode the remaining ids.
        """
        if not corpus:
            # Nothing to encode; an empty batch would also break the [:, 0] indexing below.
            return []

        expansions = []
        texts_batch = [(self.gen_prefix + doc["title"] + " " + doc["text"]) for doc in corpus]
        encode_texts = np.array(self.tokenizer.batch_encode_plus(
                texts_batch,
                max_length=max_length,
                truncation='only_first',
                return_attention_mask=False,
                padding='max_length')['input_ids'])
        
        # Overwrite the leading special token ([CLS] for BERT tokenizers) with id 1.
        # NOTE(review): presumably the TILDE query-likelihood marker; confirm against the TILDE reference code.
        encode_texts[:,0] = 1
        encoded_texts_gpu = torch.tensor(encode_texts).to(self.device)

        with torch.no_grad():
            logits = self.model(encoded_texts_gpu, return_dict=True).logits[:, 0]
            batch_selected = torch.topk(logits, top_k).indices.cpu().numpy()

            for idx, selected in enumerate(batch_selected):
                # assume_unique=True keeps `selected` in score order (topk ids are unique).
                expand_term_ids = np.setdiff1d(np.setdiff1d(selected, encode_texts[idx], assume_unique=True), self.bad_ids, assume_unique=True)
                expansions.append(self.tokenizer.decode(expand_term_ids))
        
        return expansions

================================================
FILE: src/retrieve/beir/beir/logging.py
================================================
import logging
import tqdm

class LoggingHandler(logging.Handler):
    """Logging handler that writes formatted records with ``tqdm.tqdm.write``.

    Going through tqdm's writer keeps log lines from colliding with active tqdm
    progress bars.
    """
    def __init__(self, level=logging.NOTSET):
        super().__init__(level)
    
    def emit(self, record):
        """Format ``record`` and write it via tqdm; report failures through ``handleError``."""
        try:
            msg = self.format(record)
            tqdm.tqdm.write(msg)
            self.flush()
        except RecursionError:
            # Same policy as logging.StreamHandler: never swallow recursion errors.
            raise
        except Exception:
            # KeyboardInterrupt / SystemExit are not Exception subclasses, so they
            # still propagate, as with the previous explicit re-raise.
            self.handleError(record)

================================================
FILE: src/retrieve/beir/beir/losses/__init__.py
================================================
from .bpr_loss import BPRLoss
from .margin_mse_loss import MarginMSELoss

================================================
FILE: src/retrieve/beir/beir/losses/bpr_loss.py
================================================
import math
import torch
from typing import Iterable, Dict
from sentence_transformers import SentenceTransformer, util

class BPRLoss(torch.nn.Module):
    """
        This loss expects as input a batch consisting of sentence triplets (a_1, p_1, n_1), (a_2, p_2, n_2)..., (a_n, p_n, n_n)
        where we assume that (a_i, p_i) are a positive pair and (a_i, p_j) for i!=j a negative pair. 
        You can also provide one or multiple hard negatives (n_1, n_2, ..) per anchor-positive pair by structuring the data like this.
        
        We define the loss function as defined in ACL2021: Efficient Passage Retrieval with Hashing for Open-domain Question Answering.
        For more information: https://arxiv.org/abs/2106.00882
        
        Parts of the code have been reused from the source code of BPR (Binary Passage Retriever): https://github.com/studio-ousia/bpr.
        
        We combine two losses for training a binary code based retriever model =>
        1. Margin Ranking Loss: https://pytorch.org/docs/stable/generated/torch.nn.MarginRankingLoss.html
        2. Cross Entropy Loss (or Multiple Negatives Ranking Loss): https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html

    """
    def __init__(self, model: SentenceTransformer, scale: float = 1.0, similarity_fct = util.dot_score, binary_ranking_loss_margin: float = 2.0, hashnet_gamma: float = 0.1):
        """
        :param model: SentenceTransformer model
        :param scale: Output of similarity function is multiplied by scale value
        :param similarity_fct: similarity function between sentence embeddings. By default, dot_score. Can also be set to cosine similarity.
        :param binary_ranking_loss_margin: margin used for binary loss. By default original authors found enhanced performance = 2.0, (Appendix D, https://arxiv.org/abs/2106.00882).
        :param hashnet_gamma: hashnet gamma function used for scaling tanh function. By default original authors found enhanced performance = 0.1, (Appendix B, https://arxiv.org/abs/2106.00882).
        """
        super(BPRLoss, self).__init__()
        # Incremented once per forward(); sharpens the tanh relaxation over training.
        self.global_step = 0
        self.model = model
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.hashnet_gamma = hashnet_gamma
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.margin_ranking_loss = torch.nn.MarginRankingLoss(margin=binary_ranking_loss_margin)
    
    def convert_to_binary(self, input_repr: torch.Tensor) -> torch.Tensor:
        """
        The paper uses the tanh function as an approximation of the sign function, because sign is incompatible with backpropagation.
        The tanh input is scaled by sqrt(1 + global_step * hashnet_gamma), so the output moves closer to +-1 as training proceeds.
        """
        scale = math.pow((1.0 + self.global_step * self.hashnet_gamma), 0.5)
        return torch.tanh(input_repr * scale)

    def forward(self, sentence_features: Iterable[Dict[str, torch.Tensor]], labels: torch.Tensor):
        """
        :param sentence_features: tokenized features; the first entry holds the anchors, the remaining
            entries hold the positives followed by any hard negatives.
        :param labels: unused. Positives are identified by position (anchor i <-> candidate i).
        :return: dense (cross-entropy) loss + binary (margin ranking) loss.
        """
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        embeddings_a = reps[0]
        # All candidates (positives first, then hard negatives), passed through the tanh relaxation.
        embeddings_b = torch.cat([self.convert_to_binary(rep) for rep in reps[1:]])

        # Dense loss (Multiple Negatives Ranking Loss), used to learn the encoder model:
        # raw anchors vs relaxed binary candidates. Candidate i is the positive for anchor i.
        scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
        target = torch.arange(scores.size(0), dtype=torch.long, device=scores.device)
        dense_loss = self.cross_entropy_loss(scores, target)

        # Binary loss (Margin Ranking Loss), used to learn the binary-coded model:
        # relaxed binary anchors vs relaxed binary candidates.
        binary_query_repr = self.convert_to_binary(embeddings_a)
        binary_query_scores = torch.matmul(binary_query_repr, embeddings_b.transpose(0, 1))
        pos_mask = binary_query_scores.new_zeros(binary_query_scores.size(), dtype=torch.bool)
        # Vectorized replacement for a per-row Python loop: mark (i, i) as each anchor's positive.
        pos_mask[target, target] = True
        pos_bin_scores = torch.masked_select(binary_query_scores, pos_mask)
        # masked_select walks row-major, so repeating each positive (num_candidates - 1) times
        # aligns it with the negatives of its own row selected below.
        pos_bin_scores = pos_bin_scores.repeat_interleave(embeddings_b.size(0) - 1)
        neg_bin_scores = torch.masked_select(binary_query_scores, torch.logical_not(pos_mask))
        bin_labels = pos_bin_scores.new_ones(pos_bin_scores.size(), dtype=torch.int64)
        binary_loss = self.margin_ranking_loss(
            pos_bin_scores, neg_bin_scores, bin_labels)

        self.global_step += 1

        return dense_loss + binary_loss


================================================
FILE: src/retrieve/beir/beir/losses/margin_mse_loss.py
================================================
from .. import util
import torch
from torch import nn, Tensor
from typing import Union, Tuple, List, Iterable, Dict
from torch.nn import functional as F


class MarginMSELoss(nn.Module):
    """
    Margin-MSE loss for distilling a teacher into a dense bi-encoder (cross-architecture knowledge distillation).

    Eq. 11 of Hofstätter et al., https://arxiv.org/abs/2010.02666:
        Loss(Q, P+, P-) = MSE(M_s(Q, P+) - M_s(Q, P-), M_t(Q, P+) - M_t(Q, P-))
    with Q the query, P+ / P- a relevant / non-relevant passage, M_s the student and M_t the teacher.

    The ``labels`` given to ``forward`` must already be the teacher margins M_t(Q, P+) - M_t(Q, P-).
    Student scores are dot products multiplied by ``scale``. ``similarity_fct`` is stored but
    ``forward`` does not consult it.
    """
    def __init__(self, model, scale: float = 1.0, similarity_fct = 'dot'):
        super(MarginMSELoss, self).__init__()
        self.model = model
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.loss_fct = nn.MSELoss()

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        # Expected order of sentence_features: query, positive passage, negative passage.
        embeddings = [self.model(features)['sentence_embedding'] for features in sentence_features]
        query_emb = embeddings[0]
        pos_emb = embeddings[1]
        neg_emb = embeddings[2]

        def student_score(passage_emb):
            # Row-wise dot product with the query, scaled.
            return (query_emb * passage_emb).sum(dim=-1) * self.scale

        student_margin = student_score(pos_emb) - student_score(neg_emb)
        return self.loss_fct(student_margin, labels)


================================================
FILE: src/retrieve/beir/beir/reranking/__init__.py
================================================
from .rerank import Rerank

================================================
FILE: src/retrieve/beir/beir/reranking/models/__init__.py
================================================
from .cross_encoder import CrossEncoder
from .mono_t5 import MonoT5

================================================
FILE: src/retrieve/beir/beir/reranking/models/cross_encoder.py
================================================
from sentence_transformers.cross_encoder import CrossEncoder as CE
import numpy as np
from typing import List, Dict, Tuple

class CrossEncoder:
    """Adapter that exposes a sentence-transformers ``CrossEncoder`` through BEIR's reranker interface."""

    def __init__(self, model_path: str, **kwargs):
        # Extra keyword arguments are handed to the sentence-transformers constructor untouched.
        self.model = CE(model_path, **kwargs)

    def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 32, show_progress_bar: bool = True) -> List[float]:
        """Score every (query, document) pair by delegating to the wrapped model's ``predict``."""
        options = dict(batch_size=batch_size, show_progress_bar=show_progress_bar)
        return self.model.predict(sentences=sentences, **options)


================================================
FILE: src/retrieve/beir/beir/reranking/models/mono_t5.py
================================================
# Majority of the code has been copied from PyGaggle MonoT5 implementation
# https://github.com/castorini/pygaggle/blob/master/pygaggle/rerank/transformer.py

from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          PreTrainedModel,
                          PreTrainedTokenizer,
                          T5ForConditionalGeneration)
from typing import List, Union, Tuple, Mapping, Optional
from dataclasses import dataclass
from tqdm.autonotebook import trange
import torch


# Type of a tokenizer call's result: a mapping from field name to a tensor,
# a flat id list, a nested id list, or nested token strings.
TokenizerReturnType = Mapping[
    str,
    Union[
        torch.Tensor,
        List[int],
        List[List[int]],
        List[List[str]],
    ],
]

@dataclass
class QueryDocumentBatch:
    """A single query together with a slice of its candidate documents.

    ``output`` holds the tokenizer result for the formatted prompts and defaults to ``None``.
    """
    query: str
    documents: List[str]
    output: Optional[TokenizerReturnType] = None

    def __len__(self):
        # The batch size is the number of documents; the query is shared by all of them.
        document_count = len(self.documents)
        return document_count

class QueryDocumentBatchTokenizer:
    """Formats (query, document) pairs with ``pattern`` and tokenizes them in fixed-size batches.

    Any extra keyword arguments are stored and forwarded to ``tokenizer.batch_encode_plus``
    on every call.
    """
    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 pattern: str = '{query} {document}',
                 **tokenizer_kwargs):
        self.pattern = pattern
        self.tokenizer = tokenizer
        self.tokenizer_kwargs = tokenizer_kwargs

    def encode(self, strings: List[str]):
        """Batch-encode ``strings`` and attach each string's token list under the ``'tokens'`` key."""
        assert self.tokenizer and (self.tokenizer_kwargs is not None), \
                'mixin used improperly'
        encoded = self.tokenizer.batch_encode_plus(strings, **self.tokenizer_kwargs)
        encoded['tokens'] = [self.tokenizer.tokenize(text) for text in strings]
        return encoded

    def traverse_query_document(
            self, batch_input: Tuple[str, List[str]], batch_size: int):
        """Yield one tokenized ``QueryDocumentBatch`` per ``batch_size`` documents of a single query."""
        query = batch_input[0]
        doc_texts = batch_input[1]
        for start in range(0, len(doc_texts), batch_size):
            chunk = doc_texts[start:start + batch_size]
            prompts = [self.pattern.format(query=query, document=doc) for doc in chunk]
            yield QueryDocumentBatch(query, chunk, self.encode(prompts))

class T5BatchTokenizer(QueryDocumentBatchTokenizer):
    """Preset of ``QueryDocumentBatchTokenizer`` for monoT5 prompts.

    The prompt pattern is always 'Query: {query} Document: {document} Relevant:'. The tokenizer
    options below are applied only when the caller did not supply them.
    """
    def __init__(self, *args, **kwargs):
        # The monoT5 prompt is mandatory and overrides any caller-supplied pattern.
        kwargs['pattern'] = 'Query: {query} Document: {document} Relevant:'
        fallback_options = (
            ('return_attention_mask', True),
            ('padding', 'longest'),
            ('truncation', True),
            ('return_tensors', 'pt'),
            ('max_length', 512),
        )
        for option, value in fallback_options:
            kwargs.setdefault(option, value)
        super().__init__(*args, **kwargs)


@torch.no_grad()
def greedy_decode(model: PreTrainedModel,
                  input_ids: torch.Tensor,
                  length: int,
                  attention_mask: torch.Tensor = None,
                  return_last_logits: bool = True) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Greedily decode ``length`` tokens from an encoder-decoder model.

    The encoder runs once. Each step feeds the whole decoded prefix back to the decoder,
    because no key/value cache is reused between steps.

    :param model: seq2seq model exposing ``config.decoder_start_token_id``, ``get_encoder()``
        and ``prepare_inputs_for_generation``.
    :param input_ids: encoder input ids; dim 0 is the batch.
    :param length: number of decoding steps.
    :param attention_mask: optional encoder attention mask.
    :param return_last_logits: when True, also return the logits of the final step.
    :return: ``decode_ids`` of shape (batch, 1 + length), or ``(decode_ids, last_logits)``.
        ``last_logits`` is ``None`` when ``length`` is 0.
    """
    decode_ids = torch.full((input_ids.size(0), 1),
                            model.config.decoder_start_token_id,
                            dtype=torch.long,
                            device=input_ids.device)
    encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask)
    next_token_logits = None
    for _ in range(length):
        # No cache argument is passed: ``past`` is not a valid kwarg on recent transformers
        # (the generic prepare_inputs_for_generation forwards unknown kwargs to forward()),
        # and the cache default is already None on every version.
        model_inputs = model.prepare_inputs_for_generation(
            decode_ids,
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            use_cache=True)
        outputs = model(**model_inputs)  # (batch_size, cur_len, vocab_size)
        next_token_logits = outputs[0][:, -1, :]  # (batch_size, vocab_size)
        decode_ids = torch.cat([decode_ids,
                                next_token_logits.max(1)[1].unsqueeze(-1)],
                               dim=-1)
    if return_last_logits:
        return decode_ids, next_token_logits
    return decode_ids


class MonoT5:
    """Pointwise monoT5 reranker.

    Each (query, document) pair is formatted by the batch tokenizer, a single decoder step is
    run with ``greedy_decode``, and the pair's score is log P(token_true) after a log-softmax
    over just the two logits {token_false, token_true}.
    """
    def __init__(self, 
                 model_path: str,
                 tokenizer: QueryDocumentBatchTokenizer = None,
                 use_amp = True,
                 token_false = None,
                 token_true  = None):
        """
        :param model_path: checkpoint name/path used for the model and for the default tokenizer.
        :param tokenizer: optional pre-built batch tokenizer; defaults to a ``T5BatchTokenizer``.
        :param use_amp: enable ``torch.cuda.amp.autocast`` while scoring.
        :param token_false: vocabulary entry meaning "not relevant" (required).
        :param token_true: vocabulary entry meaning "relevant" (required).
        :raises ValueError: if ``token_false`` or ``token_true`` is missing.
        """
        self.model = self.get_model(model_path)
        self.tokenizer = tokenizer or self.get_tokenizer(model_path)
        self.token_false_id, self.token_true_id = self.get_prediction_tokens(
                model_path, self.tokenizer, token_false, token_true)
        self.model_path = model_path
        self.device = next(self.model.parameters(), None).device
        self.use_amp = use_amp

    @staticmethod
    def get_model(model_path: str, *args, device: str = None, **kwargs) -> T5ForConditionalGeneration:
        """Load the seq2seq checkpoint in eval mode on ``device`` (CUDA when available, else CPU)."""
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        device = torch.device(device)
        return AutoModelForSeq2SeqLM.from_pretrained(model_path, *args, **kwargs).to(device).eval()

    @staticmethod
    def get_tokenizer(model_path: str, *args, **kwargs) -> T5BatchTokenizer:
        """Wrap the checkpoint's slow (``use_fast=False``) tokenizer in a ``T5BatchTokenizer``."""
        return T5BatchTokenizer(
            AutoTokenizer.from_pretrained(model_path, use_fast=False, *args, **kwargs)
        )

    @staticmethod
    def get_prediction_tokens(model_path: str, tokenizer, token_false, token_true):
        """Map the false/true answer tokens to vocabulary ids.

        :raises ValueError: if either token is not given. Previously this returned ``None``,
            which surfaced as an opaque unpacking ``TypeError`` in ``__init__``.
        :raises KeyError: if a token is not in the tokenizer vocabulary.
        """
        if not (token_false and token_true):
            raise ValueError(
                "MonoT5 needs both token_false and token_true for checkpoint "
                f"{model_path!r} (e.g. '▁false' / '▁true' for castorini monoT5 models).")
        vocab = tokenizer.tokenizer.get_vocab()
        return vocab[token_false], vocab[token_true]

    def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 32, **kwargs) -> List[float]:
        """Score (query, document) pairs.

        :return: one log-probability of relevance per input pair, in the same order as ``sentences``.
        """
        # T5 is run on one query at a time with its documents. Group pairs by query text while
        # remembering each pair's original index, so the scores can be written back in input
        # order even when a query's pairs are not contiguous.
        sentence_dict, positions, queries = {}, {}, []
        for idx, (query, doc_text) in enumerate(sentences):
            if query not in sentence_dict:
                sentence_dict[query] = []
                positions[query] = []
                queries.append(query) # Preserves order of queries
            sentence_dict[query].append(doc_text)
            positions[query].append(idx)

        scores: List[Optional[float]] = [None] * len(sentences)
        for start_idx in trange(0, len(queries), 1): # Take one query at a time
            query = queries[start_idx]
            batch_input = (query, sentence_dict[query]) # (single query, top-k docs)
            query_scores = []
            for batch in self.tokenizer.traverse_query_document(batch_input, batch_size): 
                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    input_ids = batch.output['input_ids'].to(self.device)
                    attn_mask = batch.output['attention_mask'].to(self.device)
                    _, batch_scores = greedy_decode(self.model,
                                                    input_ids,
                                                    length=1,
                                                    attention_mask=attn_mask,
                                                    return_last_logits=True)

                    # Restrict to the two answer tokens and renormalize; keep log P(true).
                    batch_scores = batch_scores[:, [self.token_false_id, self.token_true_id]]
                    batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
                    query_scores.extend(batch_scores[:, 1].tolist())

            assert len(query_scores) == len(positions[query]) # Sanity check: one score per document
            for position, score in zip(positions[query], query_scores):
                scores[position] = score

        return scores

================================================
FILE: src/retrieve/beir/beir/reranking/rerank.py
================================================
import logging
from typing import Dict, List

logger = logging.getLogger(__name__)

#Parent class for any reranking model
class Rerank:
    """Re-score first-stage retrieval results with a pairwise model (e.g. a cross-encoder).

    The model must provide ``predict(sentence_pairs, batch_size=...)`` and return one score per
    ``[query_text, document_text]`` pair.
    """

    def __init__(self, model, batch_size: int = 128, **kwargs):
        self.cross_encoder = model
        self.batch_size = batch_size
        self.rerank_results = {}

    def rerank(self, 
               corpus: Dict[str, Dict[str, str]], 
               queries: Dict[str, str],
               results: Dict[str, Dict[str, float]],
               top_k: int) -> Dict[str, Dict[str, float]]:
        """Re-score each query's top-``top_k`` documents from ``results``.

        Queries with more than ``top_k`` hits keep only their ``top_k`` highest-scoring
        documents; smaller hit lists are re-scored in full. Document text is
        ``title + " " + text`` (each defaulting to ""), stripped.

        :return: ``{query_id: {doc_id: new_score}}``, covering every query in ``results``.
            The same dict is stored on ``self.rerank_results``.
        """
        sentence_pairs, pair_ids = [], []

        for query_id, doc_scores in results.items():
            if len(doc_scores) > top_k:
                ranked = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
                doc_ids = [doc_id for doc_id, _ in ranked]
            else:
                doc_ids = list(doc_scores)

            for doc_id in doc_ids:
                pair_ids.append([query_id, doc_id])
                doc = corpus[doc_id]
                corpus_text = (doc.get("title", "") + " " + doc.get("text", "")).strip()
                sentence_pairs.append([queries[query_id], corpus_text])

        #### Starting to Rerank using cross-attention
        logger.info("Starting To Rerank Top-{}....".format(top_k))
        # Skip the model call entirely when there is nothing to score.
        raw_scores = self.cross_encoder.predict(sentence_pairs, batch_size=self.batch_size) if sentence_pairs else []
        rerank_scores = [float(score) for score in raw_scores]

        #### Reranking results
        self.rerank_results = {query_id: {} for query_id in results}
        for (query_id, doc_id), score in zip(pair_ids, rerank_scores):
            self.rerank_results[query_id][doc_id] = score

        return self.rerank_results


================================================
FILE: src/retrieve/beir/beir/retrieval/__init__.py
================================================


================================================
FILE: src/retrieve/beir/beir/retrieval/custom_metrics.py
================================================
import logging
from typing import List, Dict, Union, Tuple

def mrr(qrels: Dict[str, Dict[str, int]], 
        results: Dict[str, Dict[str, float]], 
        k_values: List[int]) -> Tuple[Dict[str, float]]:
    """Mean Reciprocal Rank at each cut-off in ``k_values``.

    For every query in ``results``, 1/rank of the first top-k document with a positive qrels
    label is accumulated. Each sum is divided by ``len(qrels)`` and rounded to 5 decimals.

    :return: dict mapping ``"MRR@k"`` to its averaged value.
    """
    MRR = dict.fromkeys([f"MRR@{k}" for k in k_values], 0.0)

    k_max = max(k_values)
    logging.info("\n")

    ranked = {
        qid: sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[:k_max]
        for qid, doc_scores in results.items()
    }

    for qid, hits in ranked.items():
        relevant = {doc_id for doc_id, label in qrels[qid].items() if label > 0}
        for k in k_values:
            first_rank = next(
                (pos for pos, (doc_id, _) in enumerate(hits[:k]) if doc_id in relevant),
                None,
            )
            if first_rank is not None:
                MRR[f"MRR@{k}"] += 1.0 / (first_rank + 1)

    for k in k_values:
        key = f"MRR@{k}"
        MRR[key] = round(MRR[key] / len(qrels), 5)
        logging.info("MRR@{}: {:.4f}".format(k, MRR[key]))

    return MRR

def recall_cap(qrels: Dict[str, Dict[str, int]], 
               results: Dict[str, Dict[str, float]], 
               k_values: List[int]) -> Dict[str, float]:
    """Capped recall R_cap@k = (#relevant docs in top-k) / min(#relevant docs, k).

    Per-query values are summed over the queries in ``results``, divided by ``len(qrels)``
    and rounded to 5 decimals. A query whose qrels contain no positively labelled document
    has an undefined capped recall and contributes 0 instead of raising ZeroDivisionError.

    :return: dict mapping ``"R_cap@k"`` to its averaged value.
    """
    capped_recall = {}

    for k in k_values:
        capped_recall[f"R_cap@{k}"] = 0.0

    k_max = max(k_values)
    logging.info("\n")

    for query_id, doc_scores in results.items():
        top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]
        query_relevant_docs = [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
        for k in k_values:
            denominator = min(len(query_relevant_docs), k)
            if denominator == 0:
                # No positive judgments (or k == 0): nothing to recall, count as 0.
                continue
            retrieved_docs = [row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0]
            capped_recall[f"R_cap@{k}"] += (len(retrieved_docs) / denominator)

    for k in k_values:
        capped_recall[f"R_cap@{k}"] = round(capped_recall[f"R_cap@{k}"]/len(qrels), 5)
        logging.info("R_cap@{}: {:.4f}".format(k, capped_recall[f"R_cap@{k}"]))

    return capped_recall


def hole(qrels: Dict[str, Dict[str, int]], 
         results: Dict[str, Dict[str, float]], 
         k_values: List[int]) -> Tuple[Dict[str, float]]:
    """Hole@k: fraction of the top-k retrieved documents that never appear in any qrels entry.

    A document counts as annotated if it appears under any query in ``qrels``, whatever its
    label. Per-query fractions are summed over ``results``, divided by ``len(qrels)`` and
    rounded to 5 decimals.

    :return: dict mapping ``"Hole@k"`` to its averaged value.
    """
    Hole = {f"Hole@{k}": 0.0 for k in k_values}

    annotated_corpus = {doc_id for judged in qrels.values() for doc_id in judged}

    k_max = max(k_values)
    logging.info("\n")

    for scores in results.values():
        ranked_ids = [
            doc_id
            for doc_id, _ in sorted(scores.items(), key=lambda item: item[1], reverse=True)[:k_max]
        ]
        for k in k_values:
            unjudged = sum(1 for doc_id in ranked_ids[:k] if doc_id not in annotated_corpus)
            Hole[f"Hole@{k}"] += unjudged / k

    for k in k_values:
        key = f"Hole@{k}"
        Hole[key] = round(Hole[key] / len(qrels), 5)
        logging.info("Hole@{}: {:.4f}".format(k, Hole[key]))

    return Hole

def top_k_accuracy(
        qrels: Dict[str, Dict[str, int]], 
        results: Dict[str, Dict[str, float]], 
        k_values: List[int]) -> Tuple[Dict[str, float]]:
    """Accuracy@k: share of queries with at least one positively labelled document in the top-k.

    Hits are summed over the queries in ``results``, divided by ``len(qrels)`` and rounded
    to 5 decimals.

    :return: dict mapping ``"Accuracy@k"`` to its averaged value.
    """
    top_k_acc = {f"Accuracy@{k}": 0.0 for k in k_values}

    k_max = max(k_values)
    logging.info("\n")

    ranked_ids = {
        qid: [doc_id for doc_id, _ in sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[:k_max]]
        for qid, doc_scores in results.items()
    }

    for qid, ranked in ranked_ids.items():
        relevant = {doc_id for doc_id, label in qrels[qid].items() if label > 0}
        for k in k_values:
            if any(doc_id in relevant for doc_id in ranked[:k]):
                top_k_acc[f"Accuracy@{k}"] += 1.0

    for k in k_values:
        key = f"Accuracy@{k}"
        top_k_acc[key] = round(top_k_acc[key] / len(qrels), 5)
        logging.info("Accuracy@{}: {:.4f}".format(k, top_k_acc[key]))

    return top_k_acc

================================================
FILE: src/retrieve/beir/beir/retrieval/evaluation.py
================================================
# import pytrec_eval
import logging
from typing import List, Dict, Tuple
from .search.base import BaseSearch
from .custom_metrics import mrr, recall_cap, hole, top_k_accuracy

logger = logging.getLogger(__name__)

class EvaluateRetrieval:
    """Driver around a retrieval backend plus BEIR's custom ranking metrics.

    ``retrieve`` and ``rerank`` delegate to ``retriever.search``; ``evaluate_custom`` dispatches
    to MRR, capped recall, hole and top-k accuracy.
    """
    
    def __init__(self, retriever: BaseSearch = None, k_values: List[int] = [1,3,5,10,100,1000], score_function: str = "cos_sim"):
        """
        :param retriever: backend with ``search(corpus, queries, top_k, score_function, **kwargs)``.
        :param k_values: evaluation cut-offs; the largest is used as the retrieval depth.
        :param score_function: similarity name forwarded to ``retriever.search``.
        """
        self.k_values = k_values
        self.top_k = max(k_values)
        self.retriever = retriever
        self.score_function = score_function

    def _require_retriever(self) -> None:
        """Raise ValueError when no retriever was configured."""
        if not self.retriever:
            raise ValueError("Model/Technique has not been provided!")

    def retrieve(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str], **kwargs) -> Dict[str, Dict[str, float]]:
        """Search ``corpus`` for ``queries`` with depth ``max(k_values)``.

        :raises ValueError: if no retriever was configured.
        """
        self._require_retriever()
        return self.retriever.search(corpus, queries, self.top_k, self.score_function, **kwargs)
    
    def rerank(self, 
            corpus: Dict[str, Dict[str, str]], 
            queries: Dict[str, str],
            results: Dict[str, Dict[str, float]],
            top_k: int) -> Dict[str, Dict[str, float]]:
        """Search again with the retriever over a sub-corpus built from earlier ``results``.

        Each query contributes its ``top_k`` highest-scoring documents (or all of them when it
        has at most ``top_k``). These are pooled into ONE sub-corpus shared by all queries,
        which is then searched with depth ``top_k``.

        :raises ValueError: if no retriever was configured.
        """
        self._require_retriever()
        new_corpus = {}

        for query_id, doc_scores in results.items():
            if len(doc_scores) > top_k:
                ranked = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
                doc_ids = [doc_id for doc_id, _ in ranked]
            else:
                doc_ids = list(doc_scores)
            for doc_id in doc_ids:
                new_corpus[doc_id] = corpus[doc_id]

        return self.retriever.search(new_corpus, queries, top_k, self.score_function)

    # pytrec_eval-based standard evaluation; disabled together with the `import pytrec_eval` line at the top of this file.
    # @staticmethod
    # def evaluate(qrels: Dict[str, Dict[str, int]], 
    #              results: Dict[str, Dict[str, float]], 
    #              k_values: List[int],
    #              ignore_identical_ids: bool=True) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]:
        
    #     if ignore_identical_ids:
    #         logger.info('For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.')
    #         popped = []
    #         for qid, rels in results.items():
    #             for pid in list(rels):
    #                 if qid == pid:
    #                     results[qid].pop(pid)
    #                     popped.append(pid)

    #     ndcg = {}
    #     _map = {}
    #     recall = {}
    #     precision = {}
        
    #     for k in k_values:
    #         ndcg[f"NDCG@{k}"] = 0.0
    #         _map[f"MAP@{k}"] = 0.0
    #         recall[f"Recall@{k}"] = 0.0
    #         precision[f"P@{k}"] = 0.0
        
    #     map_string = "map_cut." + ",".join([str(k) for k in k_values])
    #     ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
    #     recall_string = "recall." + ",".join([str(k) for k in k_values])
    #     precision_string = "P." + ",".join([str(k) for k in k_values])
    #     evaluator = pytrec_eval.RelevanceEvaluator(qrels, {map_string, ndcg_string, recall_string, precision_string})
    #     scores = evaluator.evaluate(results)
        
    #     for query_id in scores.keys():
    #         for k in k_values:
    #             ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)]
    #             _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)]
    #             recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)]
    #             precision[f"P@{k}"] += scores[query_id]["P_"+ str(k)]
        
    #     for k in k_values:
    #         ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"]/len(scores), 5)
    #         _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"]/len(scores), 5)
    #         recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"]/len(scores), 5)
    #         precision[f"P@{k}"] = round(precision[f"P@{k}"]/len(scores), 5)
        
    #     for eval in [ndcg, _map, recall, precision]:
    #         logger.info("\n")
    #         for k in eval.keys():
    #             logger.info("{}: {:.4f}".format(k, eval[k]))

    #     return ndcg, _map, recall, precision
    
    @staticmethod
    def evaluate_custom(qrels: Dict[str, Dict[str, int]], 
                 results: Dict[str, Dict[str, float]], 
                 k_values: List[int], metric: str) -> Dict[str, float]:
        """Compute one custom metric (case-insensitive name) at every cut-off in ``k_values``.

        :raises ValueError: for an unrecognised metric name (previously returned ``None`` silently).
        """
        metric_name = metric.lower()

        if metric_name in ["mrr", "mrr@k", "mrr_cut"]:
            return mrr(qrels, results, k_values)
        
        elif metric_name in ["recall_cap", "r_cap", "r_cap@k"]:
            return recall_cap(qrels, results, k_values)
        
        elif metric_name in ["hole", "hole@k"]:
            return hole(qrels, results, k_values)
        
        elif metric_name in ["acc", "top_k_acc", "accuracy", "accuracy@k", "top_k_accuracy"]:
            return top_k_accuracy(qrels, results, k_values)

        raise ValueError(f"Unknown custom metric: {metric!r}")


================================================
FILE: src/retrieve/beir/beir/retrieval/models/__init__.py
================================================
from .sentence_bert import SentenceBERT
from .use_qa import UseQA
from .sparta import SPARTA
from .dpr import DPR
from .bpr import BinarySentenceBERT
from .unicoil import UniCOIL
from .splade import SPLADE
from .tldr import TLDR


================================================
FILE: src/retrieve/beir/beir/retrieval/models/bpr.py
================================================
from sentence_transformers import SentenceTransformer
from torch import Tensor
from typing import List, Dict, Union, Tuple
import numpy as np

class BinarySentenceBERT:
    """SentenceTransformer bi-encoder with binary (bit-packed) corpus embeddings, BPR-style.

    Queries are encoded densely. Document embeddings are thresholded at ``threshold``:
    values below it become bit 0, everything else bit 1. The bits are then packed row by
    row into ``uint8`` arrays, 8 dimensions per byte.
    """
    def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ", threshold: Union[float, Tensor] = 0, **kwargs):
        """
        :param model_path: one checkpoint shared by queries and documents, or a
            ``(query_model_path, doc_model_path)`` tuple.
        :param sep: separator placed between a document's title and text.
        :param threshold: binarisation threshold for document embeddings.
        """
        self.sep = sep
        self.threshold = threshold
        
        if isinstance(model_path, str):
            self.q_model = SentenceTransformer(model_path)
            self.doc_model = self.q_model
        
        elif isinstance(model_path, tuple):
            self.q_model = SentenceTransformer(model_path[0])
            self.doc_model = SentenceTransformer(model_path[1])
    
    def _convert_embedding_to_binary_code(self, embeddings: Tensor) -> Tensor:
        """Return +1 where ``embeddings >= threshold`` and -1 elsewhere, with the input's dtype and device."""
        return embeddings.new_ones(embeddings.size()).masked_fill_(embeddings < self.threshold, -1.0)
    
    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
        return self.q_model.encode(queries, batch_size=batch_size, **kwargs)
    
    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs) -> np.ndarray:
        """Encode documents and return their packed binary codes, shape (num_docs, ceil(dim / 8)), dtype uint8.

        The document text is ``title + sep + text`` (stripped), or just ``text`` when there is no title.
        """
        sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
        embs = self.doc_model.encode(sentences, batch_size=batch_size, convert_to_tensor=True, **kwargs)
        codes = self._convert_embedding_to_binary_code(embs).cpu().numpy()
        # +1 -> True, -1 -> False. Uses the builtin bool dtype: the np.bool alias was removed in NumPy 1.24.
        bits = codes > 0
        # Pack along the embedding axis so that dimensions which are not a multiple of 8 do not
        # spill bits from one document into the next.
        return np.packbits(bits, axis=-1)

================================================
FILE: src/retrieve/beir/beir/retrieval/models/dpr.py
================================================
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizerFast
from typing import Union, List, Dict, Tuple
from tqdm.autonotebook import trange
import torch

class DPR:
    """Dense Passage Retrieval bi-encoder with separate question and context encoders.

    Embeddings are the encoders' ``pooler_output`` and are returned on ``self.device``.
    """
    def __init__(self, model_path: Union[str, Tuple] = None, device: str = None, **kwargs):
        """
        :param model_path: ``(question_encoder_path, context_encoder_path)``.
        :param device: torch device to run on; defaults to CUDA when available, else CPU.
        :raises ValueError: if ``model_path`` is a plain string or ``None``.
        """
        if model_path is None or isinstance(model_path, str):
            raise ValueError("DPR expects model_path=(question_encoder_path, context_encoder_path)")
        # Previously hard-coded to .cuda(), which fails on CPU-only machines.
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        # Query tokenizer and model
        self.q_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained(model_path[0])
        self.q_model = DPRQuestionEncoder.from_pretrained(model_path[0])
        self.q_model.to(self.device)
        self.q_model.eval()
        
        # Context tokenizer and model
        self.ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(model_path[1])
        self.ctx_model = DPRContextEncoder.from_pretrained(model_path[1])
        self.ctx_model.to(self.device)
        self.ctx_model.eval()
    
    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> torch.Tensor:
        """Encode ``queries`` in batches; returns a (num_queries, dim) tensor."""
        query_embeddings = []
        with torch.no_grad():
            for start_idx in trange(0, len(queries), batch_size):
                encoded = self.q_tokenizer(queries[start_idx:start_idx+batch_size], truncation=True, padding=True, return_tensors='pt')
                model_out = self.q_model(encoded['input_ids'].to(self.device), attention_mask=encoded['attention_mask'].to(self.device))
                query_embeddings.append(model_out.pooler_output)

        # One tensor per batch; concatenating along dim 0 gives the same result as stacking every row.
        return torch.cat(query_embeddings)
        
    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs) -> torch.Tensor:
        """Encode (title, text) documents in batches; a missing title is treated as ""."""
        corpus_embeddings = []
        with torch.no_grad():
            for start_idx in trange(0, len(corpus), batch_size):
                rows = corpus[start_idx:start_idx+batch_size]
                titles = [row.get('title', '') for row in rows]
                texts = [row['text'] for row in rows]
                encoded = self.ctx_tokenizer(titles, texts, truncation='longest_first', padding=True, return_tensors='pt')
                model_out = self.ctx_model(encoded['input_ids'].to(self.device), attention_mask=encoded['attention_mask'].to(self.device))
                corpus_embeddings.append(model_out.pooler_output.detach())
        
        return torch.cat(corpus_embeddings)

================================================
FILE: src/retrieve/beir/beir/retrieval/models/sentence_bert.py
================================================
from sentence_transformers import SentenceTransformer
from torch import Tensor
import torch.multiprocessing as mp
from typing import List, Dict, Union, Tuple
import numpy as np
import logging
from datasets import Dataset
from tqdm import tqdm

logger = logging.getLogger(__name__)


class SentenceBERT:
    """BEIR dense-retrieval wrapper around ``SentenceTransformer`` encoders.

    Args:
        model_path: a single model name/path (one encoder shared by queries and
            documents) or a ``(query_model, doc_model)`` tuple of two paths.
        sep: string placed between a document's title and its text.
    """

    def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ", **kwargs):
        self.sep = sep

        if isinstance(model_path, str):
            self.q_model = SentenceTransformer(model_path)
            self.doc_model = self.q_model

        elif isinstance(model_path, tuple):
            self.q_model = SentenceTransformer(model_path[0])
            self.doc_model = SentenceTransformer(model_path[1])
        # NOTE(review): any other model_path (e.g. None) leaves q_model/doc_model unset.

    def _corpus_to_sentences(self, corpus) -> List[str]:
        """Flatten ``corpus`` into stripped ``title + sep + text`` strings (text only when there is no title).

        A plain ``dict`` is read column-wise (``corpus["title"][i]``, ``corpus["text"][i]``);
        anything else is iterated as a sequence of per-document dicts.
        """
        if type(corpus) is dict:
            texts = corpus["text"]
            if "title" in corpus:
                titles = corpus["title"]
                return [(titles[i] + self.sep + texts[i]).strip() for i in range(len(texts))]
            return [texts[i].strip() for i in range(len(texts))]
        return [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
                for doc in corpus]

    def start_multi_process_pool(self, target_devices: List[str] = None) -> Dict[str, object]:
        """Spawn one daemon worker per device for parallel corpus encoding.

        Each worker runs ``SentenceTransformer._encode_multi_process_worker`` with
        ``self.doc_model`` and shares one input queue and one output queue.

        Args:
            target_devices: device names such as ``"cuda:0"``. When ``None``, every visible
                CUDA device is used, or 4 CPU workers when CUDA is unavailable.

        Returns:
            ``{'input': Queue, 'output': Queue, 'processes': [Process, ...]}``.
        """
        if target_devices is None:
            # Without this default, the log line below raised TypeError on map(str, None).
            import torch
            if torch.cuda.is_available():
                target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())]
            else:
                logger.info("CUDA is not available. Start 4 CPU worker")
                target_devices = ['cpu'] * 4

        logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices))))

        ctx = mp.get_context('spawn')
        input_queue = ctx.Queue()
        output_queue = ctx.Queue()
        processes = []

        for process_id, device_name in enumerate(target_devices):
            p = ctx.Process(target=SentenceTransformer._encode_multi_process_worker, args=(process_id, device_name, self.doc_model, input_queue, output_queue), daemon=True)
            p.start()
            processes.append(p)

        return {'input': input_queue, 'output': output_queue, 'processes': processes}

    def stop_multi_process_pool(self, pool: Dict[str, object]):
        """Read and discard ``len(pool['processes'])`` items from the output queue, then stop the pool via ``doc_model``."""
        output_queue = pool['output']
        for _ in range(len(pool['processes'])):
            output_queue.get()
        return self.doc_model.stop_multi_process_pool(pool)

    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
        """Encode ``queries`` with the query model; extra kwargs go to ``SentenceTransformer.encode``."""
        return self.q_model.encode(queries, batch_size=batch_size, **kwargs)

    def encode_corpus(self, corpus: Union[List[Dict[str, str]], Dict[str, List]], batch_size: int = 8, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
        """Encode documents (list of dicts or column dict) with the document model."""
        sentences = self._corpus_to_sentences(corpus)
        return self.doc_model.encode(sentences, batch_size=batch_size, **kwargs)

    ## Encoding corpus in parallel
    def encode_corpus_parallel(self, corpus: Union[List[Dict[str, str]], Dataset], pool: Dict[str, str], batch_size: int = 8, chunk_id: int = None, **kwargs):
        """Queue one corpus chunk for the worker pool.

        Puts ``[chunk_id, batch_size, sentences]`` on ``pool['input']``. When ``chunk_id``
        is at least the number of workers, one item is first taken (and discarded) from
        ``pool['output']`` -- presumably to throttle submission; verify against the caller.
        """
        sentences = self._corpus_to_sentences(corpus)

        if chunk_id is not None and chunk_id >= len(pool['processes']):
            output_queue = pool['output']
            output_queue.get()

        input_queue = pool['input']
        input_queue.put([chunk_id, batch_size, sentences])


================================================
FILE: src/retrieve/beir/beir/retrieval/models/sparta.py
================================================
from typing import List, Dict, Union, Tuple
from tqdm.autonotebook import trange
from transformers import AutoTokenizer, AutoModel
from scipy.sparse import csr_matrix
import torch
import numpy as np

class SPARTA:
    """SPARTA sparse retriever (BEIR wrapper).

    Documents become sparse vocabulary-weight vectors. Every vocabulary input
    embedding is dotted with the document's contextual token embeddings, the best
    match per vocabulary entry is kept, and ``log(relu(x) + 1)`` is applied.
    Queries are represented by their token ids.

    Args:
        model_path: HuggingFace model name/path loaded with ``AutoTokenizer``/``AutoModel``.
        sep: string placed between a document's title and text.
        sparse_vector_dim: maximum number of vocabulary entries kept per document.
        max_length: tokenizer truncation length for documents.
    """

    def __init__(self, model_path: str = None, sep: str = " ", sparse_vector_dim: int = 2000, max_length: int = 500, **kwargs):
        self.sep = sep
        self.max_length = max_length
        self.sparse_vector_dim = sparse_vector_dim
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path)
        self.initialization()
        self.bert_input_embeddings = self._bert_input_embeddings()

    def initialization(self):
        """Move the model to CUDA when available (else CPU) and switch it to eval mode."""
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)
        self.model.eval()

    def _bert_input_embeddings(self):
        """Return the input word embedding of every token id, with special tokens zeroed."""
        # No autograd graph is needed: these embeddings are only used for scoring.
        with torch.no_grad():
            bert_input_embs = self.model.embeddings.word_embeddings(
                torch.arange(len(self.tokenizer), device=self.device))

            # Set Special tokens [CLS] [MASK] etc. to zero
            for special_id in self.tokenizer.all_special_ids:
                bert_input_embs[special_id] = 0

        return bert_input_embs

    def _compute_sparse_embeddings(self, documents):
        """Encode a batch of document strings into lists of ``(token_id, weight)`` pairs.

        At most ``sparse_vector_dim`` pairs are kept per document, in descending
        weight order, and only pairs with a positive weight are kept.
        """
        sparse_embeddings = []
        with torch.no_grad():
            tokens = self.tokenizer(documents, padding=True, truncation=True, return_tensors='pt', max_length=self.max_length).to(self.device)
            document_embs = self.model(**tokens).last_hidden_state
            for document_emb in document_embs:
                scores = torch.matmul(self.bert_input_embeddings, document_emb.transpose(0, 1))
                max_scores = torch.max(scores, dim=-1).values
                scores = torch.log(torch.relu(max_scores) + 1)
                # topk raises if k exceeds the vocabulary size, so clamp it.
                k = min(self.sparse_vector_dim, scores.size(0))
                top_results = torch.topk(scores, k=k)
                tids = top_results[1].cpu().detach().tolist()
                scores = top_results[0].cpu().detach().tolist()
                passage_emb = []

                # Scores are sorted descending, so stop at the first non-positive one.
                for tid, score in zip(tids, scores):
                    if score > 0:
                        passage_emb.append((tid, score))
                    else:
                        break
                sparse_embeddings.append(passage_emb)

        return sparse_embeddings

    def encode_query(self, query: str, **kwargs):
        """Return the query's token ids, without special tokens."""
        return self.tokenizer(query, add_special_tokens=False)['input_ids']

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 16, **kwargs):
        """Encode documents into a ``(vocab_size, len(corpus))`` ``csr_matrix``.

        Entry ``[token_id, doc_index]`` holds that token's SPARTA weight for the document.
        A document without a ``"title"`` key is encoded from its text alone.
        """
        sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
                     for doc in corpus]
        sparse_idx = 0
        # Upper bound on stored entries; the unused tail is trimmed before building the matrix.
        num_elements = len(sentences) * self.sparse_vector_dim
        # np.int / np.float were removed in NumPy 1.24; use explicit widths.
        col = np.zeros(num_elements, dtype=np.int64)
        row = np.zeros(num_elements, dtype=np.int64)
        values = np.zeros(num_elements, dtype=np.float64)

        for start_idx in trange(0, len(sentences), batch_size, desc="docs"):
            doc_embs = self._compute_sparse_embeddings(sentences[start_idx: start_idx + batch_size])
            for doc_id, emb in enumerate(doc_embs):
                for tid, score in emb:
                    col[sparse_idx] = start_idx + doc_id
                    row[sparse_idx] = tid
                    values[sparse_idx] = score
                    sparse_idx += 1

        return csr_matrix((values[:sparse_idx], (row[:sparse_idx], col[:sparse_idx])),
                          shape=(len(self.bert_input_embeddings), len(sentences)), dtype=np.float64)

================================================
FILE: src/retrieve/beir/beir/retrieval/models/splade.py
================================================
import logging
from typing import List, Dict, Union
import numpy as np
import torch
from numpy import ndarray
from torch import Tensor
from tqdm.autonotebook import trange
from transformers import AutoModelForMaskedLM, AutoTokenizer
from sentence_transformers.util import batch_to_device

logger = logging.getLogger(__name__)


class SPLADE:
    """BEIR wrapper around a SPLADE masked-LM encoder (``SpladeNaver``).

    Args:
        model_path: HuggingFace model name/path for the tokenizer and masked-LM.
        sep: string placed between a document's title and its text.
        max_length: tokenizer truncation length passed as ``maxlen``.
    """

    def __init__(self, model_path: str = None, sep: str = " ", max_length: int = 256, **kwargs):
        # Previously accepted but ignored (a literal ' ' was hard-coded); the default is unchanged.
        self.sep = sep
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = SpladeNaver(model_path)
        self.model.eval()

    # Write your own encoding query function (Returns: Query embeddings as numpy array)
    def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> np.ndarray:
        """Encode queries with ``is_q=True``.

        ``batch_size`` and ``kwargs`` are not forwarded, so ``encode_sentence_bert``
        uses its own default batch size.
        """
        return self.model.encode_sentence_bert(self.tokenizer, queries, is_q=True, maxlen=self.max_length)

    # Write your own encoding corpus function (Returns: Document embeddings as numpy array)
    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs) -> np.ndarray:
        """Encode documents as stripped ``title + sep + text``, or just the text when there is no title.

        ``batch_size`` and ``kwargs`` are not forwarded (see ``encode_queries``).
        """
        sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
                     for doc in corpus]
        return self.model.encode_sentence_bert(self.tokenizer, sentences, maxlen=self.max_length)


# Chunks of this code has been taken from: https://github.com/naver/splade/blob/main/beir_evaluation/models.py
# For more details, please refer to SPLADE by Thibault Formal, Benjamin Piwowarski and Stéphane Clinchant (https://arxiv.org/abs/2107.05720)
class SpladeNaver(torch.nn.Module):
    """Masked-LM wrapper producing SPLADE sparse representations.

    ``forward`` passes the MLM logits through ``log(1 + relu(.))``, zeroes padded
    positions with the attention mask, and max-pools over the sequence.
    """

    def __init__(self, model_path):
        super().__init__()
        self.transformer = AutoModelForMaskedLM.from_pretrained(model_path)

    def forward(self, **kwargs):
        """Return max-pooled ``log(1 + relu(logits))`` over non-padded positions.

        ``kwargs`` are forwarded to the masked-LM and must contain ``attention_mask``.
        """
        out = self.transformer(**kwargs)["logits"]  # output (logits) of MLM head, shape (bs, pad_len, voc_size)
        return torch.max(torch.log(1 + torch.relu(out)) * kwargs["attention_mask"].unsqueeze(-1), dim=1).values

    def _text_length(self, text: Union[List[int], List[List[int]]]):
        """helper function to get the length for the input text. Text can be either
        a list of ints (which means a single text as input), or a tuple of list of ints
        (representing several text inputs to the model).
        """

        if isinstance(text, dict):  # {key: value} case
            return len(next(iter(text.values())))
        elif not hasattr(text, '__len__'):  # Object has no len() method
            return 1
        elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
            return len(text)
        else:
            return sum([len(t) for t in text])  # Sum of length of individual strings

    def encode_sentence_bert(self, tokenizer, sentences: Union[str, List[str], List[int]],
                             batch_size: int = 32,
                             show_progress_bar: bool = None,
                             output_value: str = 'sentence_embedding',
                             convert_to_numpy: bool = True,
                             convert_to_tensor: bool = False,
                             device: str = None,
                             normalize_embeddings: bool = False,
                             maxlen: int = 512,
                             is_q: bool = False) -> Union[List[Tensor], ndarray, Tensor]:
        """
        Computes SPLADE sentence embeddings.
        :param tokenizer: HuggingFace tokenizer used to build model inputs
        :param sentences: the sentences to embed (a single string is also accepted)
        :param batch_size: the batch size used for the computation
        :param show_progress_bar: Output a progress bar when encode sentences (default: True)
        :param output_value: only sentence embeddings are available; 'token_embeddings' raises ValueError
            because forward() already pools over the sequence
        :param convert_to_numpy: If true, the output is a numpy matrix. Else, it is a list of pytorch tensors.
        :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
        :param device: Which torch.device to use for the computation (default: cuda if available, else cpu)
        :param normalize_embeddings: If set to true, returned vectors will have length 1.
        :param maxlen: tokenizer truncation length
        :param is_q: accepted for API compatibility; not used
        :return:
           A numpy matrix by default, a stacked tensor if convert_to_tensor, otherwise a list of tensors.
           For a single string input, only that string's embedding is returned.
        :raises ValueError: if output_value == 'token_embeddings'
        """
        if output_value == 'token_embeddings':
            # forward() returns pooled vectors, so per-token embeddings (and the attention
            # mask) are not available; the former branch always crashed on string indexing.
            raise ValueError("output_value='token_embeddings' is not supported: "
                             "SpladeNaver.forward returns pooled sparse vectors")

        self.eval()
        if show_progress_bar is None:
            show_progress_bar = True

        if convert_to_tensor:
            convert_to_numpy = False

        input_was_string = False
        if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
            # Cast an individual sentence to a list with length 1
            sentences = [sentences]
            input_was_string = True

        if device is None:
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        self.to(device)

        all_embeddings = []
        # Encode longest inputs first to reduce padding; order is restored below.
        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
            sentences_batch = sentences_sorted[start_index:start_index + batch_size]
            features = tokenizer(sentences_batch,
                                 add_special_tokens=True,
                                 padding="longest",  # pad to max sequence length in batch
                                 truncation="only_first",  # truncates to maxlen
                                 max_length=maxlen,
                                 return_attention_mask=True,
                                 return_tensors="pt")
            features = batch_to_device(features, device)

            with torch.no_grad():
                embeddings = self.forward(**features).detach()
                if normalize_embeddings:
                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
                # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
                if convert_to_numpy:
                    embeddings = embeddings.cpu()
                all_embeddings.extend(embeddings)
        all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
        if convert_to_tensor:
            all_embeddings = torch.stack(all_embeddings)
        elif convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
        if input_was_string:
            all_embeddings = all_embeddings[0]
        return all_embeddings

================================================
FILE: src/retrieve/beir/beir/retrieval/models/tldr.py
================================================
from sentence_transformers import SentenceTransformer    
import torch
from torch import Tensor
from typing import List, Dict, Union, Tuple
import numpy as np
import importlib.util

if importlib.util.find_spec("tldr") is not None:
    from tldr import TLDR as NaverTLDR

class TLDR:
    """Dense retriever that projects SentenceTransformer embeddings with Naver's TLDR.

    Args:
        encoder_model: SentenceTransformer used to embed queries and documents.
        model_path: when truthy, a saved TLDR model is loaded from it; otherwise a new
            ``NaverTLDR`` is built from the remaining hyper-parameters.
        sep: string placed between a document's title and its text.
        output_folder: folder passed to ``NaverTLDR.fit``.
    """

    def __init__(self, encoder_model: SentenceTransformer, model_path: Union[str, Tuple] = None, sep: str = " ", n_components: int = 128, n_neighbors: int = 5,
                encoder: str = "linear", projector: str = "mlp-2-2048", verbose: int = 2, knn_approximation: str = None, output_folder: str = "data/", **kwargs):
        self.encoder_model = encoder_model
        self.sep = sep
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.output_folder = output_folder

        if model_path:
            self.load(model_path)
        else:
            self.model = NaverTLDR(
                n_components=n_components,
                n_neighbors=n_neighbors,
                encoder=encoder,
                projector=projector,
                device=self.device,
                verbose=verbose,
                knn_approximation=knn_approximation,
            )

    def _corpus_to_sentences(self, corpus: List[Dict[str, str]]) -> List[str]:
        """Stripped ``title + sep + text`` per document (text only when there is no title)."""
        return [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]

    def fit(self, corpus: List[Dict[str, str]], batch_size: int = 8, epochs: int = 100, warmup_epochs: int = 10,
            train_batch_size: int = 1024, print_every: int = 100, **kwargs):
        """Embed ``corpus`` with the encoder and train the TLDR projection on those embeddings.

        Args:
            batch_size: encoder batch size used while embedding the corpus.
            train_batch_size: batch size used for TLDR training.
            kwargs: forwarded to ``encoder_model.encode``.
        """
        sentences = self._corpus_to_sentences(corpus)
        embeddings = self.encoder_model.encode(sentences, batch_size=batch_size, **kwargs)
        # train_batch_size used to be ignored: training ran with the (small) encoding batch_size.
        self.model.fit(embeddings,
                 epochs=epochs,
                 warmup_epochs=warmup_epochs,
                 batch_size=train_batch_size,
                 output_folder=self.output_folder,
                 print_every=print_every)

    def save(self, model_path: str, knn_path: str = None):
        """Save the TLDR model and, when ``knn_path`` is given, its k-NN graph."""
        self.model.save(model_path)
        if knn_path:
            self.model.save_knn(knn_path)

    def load(self, model_path: str):
        """Replace ``self.model`` with a TLDR model loaded from ``model_path``."""
        self.model = NaverTLDR()
        self.model.load(model_path, init=True)

    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
        """Encode queries, then project and L2-normalise them with TLDR."""
        return self.model.transform(self.encoder_model.encode(queries, batch_size=batch_size, **kwargs), l2_norm=True)

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
        """Encode documents, then project and L2-normalise them with TLDR."""
        sentences = self._corpus_to_sentences(corpus)
        return self.model.transform(self.encoder_model.encode(sentences, batch_size=batch_size, **kwargs), l2_norm=True)

================================================
FILE: src/retrieve/beir/beir/retrieval/models/unicoil.py
================================================
from typing import Optional, List, Dict, Union, Tuple
from transformers import BertConfig, BertModel, BertTokenizer, PreTrainedModel
import numpy as np
import torch
from tqdm.autonotebook import trange
from scipy.sparse import csr_matrix

class UniCOIL:
    """uniCOIL sparse encoder: one learned weight per input token, scattered into a vocabulary-sized vector.

    Args:
        model_path: path/name for ``UniCoilEncoder`` and its ``BertTokenizer``.
        sep: string placed between a document's title and its text.
        query_max_length: truncation length for queries.
        doc_max_length: truncation length for documents.
    """

    def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ", query_max_length: int = 128,
                doc_max_length: int = 500, **kwargs):
        self.sep = sep
        self.model = UniCoilEncoder.from_pretrained(model_path)
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        # Vocabulary size: the width of every sparse vector.
        self.bert_input_emb = len(self.tokenizer.get_vocab())
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.query_max_length = query_max_length
        self.doc_max_length = doc_max_length
        self.model.to(self.device)
        self.model.eval()

    def encode_query(self, query: str, batch_size: int = 16, **kwargs):
        """Return a dense ``float64`` vector of length ``bert_input_emb`` holding each query token's weight.

        ``np.put`` is used, so for repeated token ids the last position's weight is kept.
        """
        # np.float was removed in NumPy 1.24.
        embedding = np.zeros(self.bert_input_emb, dtype=np.float64)
        input_ids = self.tokenizer(query, max_length=self.query_max_length, padding='longest',
                                        truncation=True, add_special_tokens=True,
                                        return_tensors='pt').to(self.device)["input_ids"]

        with torch.no_grad():
            batch_weights = self.model(input_ids).cpu().detach().numpy()
            batch_token_ids = input_ids.cpu().detach().numpy()
            np.put(embedding, batch_token_ids, batch_weights.flatten())

        return embedding

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs):
        """Encode documents as stripped ``title + sep + text`` (text only when there is no title)."""
        sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
        return self.encode(sentences, batch_size=batch_size, max_length=self.doc_max_length)

    def encode(
        self,
        sentences: Union[str, List[str], List[int]],
        batch_size: int = 32,
        max_length: int = 512) -> np.ndarray:
        """Encode ``sentences`` into a sparse document-by-vocabulary matrix.

        Returns:
            A ``scipy.sparse.csr_matrix`` (despite the ``np.ndarray`` annotation) of shape
            ``(len(sentences), bert_input_emb)``. Every token position, padding included,
            contributes its weight at ``[doc_index, token_id]``; repeated ``(doc, token)``
            pairs are summed when the matrix is built.
        """
        doc_parts, tok_parts, weight_parts = [], [], []

        for start_idx in trange(0, len(sentences), batch_size, desc="docs"):
            documents = sentences[start_idx: start_idx + batch_size]
            input_ids = self.tokenizer(documents, max_length=max_length, padding='longest',
                                        truncation=True, add_special_tokens=True,
                                        return_tensors='pt').to(self.device)["input_ids"]

            with torch.no_grad():
                batch_weights = self.model(input_ids).cpu().detach().numpy()
                batch_token_ids = input_ids.cpu().detach().numpy()

            for offset in range(len(batch_token_ids)):
                token_ids = batch_token_ids[offset]
                weights = batch_weights[offset].flatten()
                n = min(len(token_ids), len(weights))  # pair ids with weights exactly like zip()
                doc_parts.append(np.full(n, start_idx + offset, dtype=np.int64))
                tok_parts.append(np.asarray(token_ids[:n], dtype=np.int64))
                weight_parts.append(np.asarray(weights[:n], dtype=np.float64))

        # np.int / np.float were removed in NumPy 1.24; use explicit widths.
        if doc_parts:
            doc_idx = np.concatenate(doc_parts)
            tok_idx = np.concatenate(tok_parts)
            values = np.concatenate(weight_parts)
        else:
            doc_idx = np.zeros(0, dtype=np.int64)
            tok_idx = np.zeros(0, dtype=np.int64)
            values = np.zeros(0, dtype=np.float64)

        return csr_matrix((values, (doc_idx, tok_idx)), shape=(len(sentences), self.bert_input_emb), dtype=np.float64)

# class UniCOIL:
#     def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ", **kwargs):
#         self.sep = sep
#         self.model = UniCoilEncoder.from_pretrained(model_path)
#         self.tokenizer = BertTokenizer.from_pretrained(model_path)
#         self.sparse_vector_dim = len(self.tokenizer.get_vocab())
#         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
#         self.model.to(self.device)
#         self.model.eval()

#     def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs):
#         max_length = 128  # hardcode for now
#         return self.encode(queries, batch_size=batch_size, max_length=max_length)

#     def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs):
#         max_length = 500
#         sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
#         return self.encode(sentences, batch_size=batch_size, max_length=max_length)
    
#     def encode(
#         self,
#         sentences: Union[str, List[str], List[int]],
#         batch_size: int = 32,
#         max_length: int = 512) -> np.ndarray:

#         embeddings = np.zeros((len(sentences), self.sparse_vector_dim), dtype=np.float)
        
#         for start_idx in trange(0, len(sentences), batch_size, desc="docs"):
#             documents = sentences[start_idx: start_idx + batch_size]
#             input_ids = self.tokenizer(documents, max_length=max_length, padding='longest',
#                                         truncation=True, add_special_tokens=True,
#                                         return_tensors='pt').to(self.device)["input_ids"]

#             with torch.no_grad():
#                 batch_weights = self.model(input_ids).cpu().detach().numpy()
#                 batch_token_ids = input_ids.cpu().detach().numpy()
            
#             for idx in range(len(batch_token_ids)):
#                 np.put(embeddings[start_idx + idx], batch_token_ids[idx], batch_weights[idx].flatten())

#         return embeddings
#         # return csr_matrix((values, (row, col)), shape=(len(sentences), self.sparse_vector_dim), dtype=np.float).toarray()


# Chunks of this code has been taken from: https://github.com/castorini/pyserini/blob/master/pyserini/encode/_unicoil.py
# For more details, please refer to uniCOIL by Jimmy Lin and Xueguang Ma (https://arxiv.org/abs/2106.14807)
class UniCoilEncoder(PreTrainedModel):
    config_class = BertConfig
    base_model_prefix = 'coil_encoder'
    load_tf_weights = None

    def __init__(self, config: BertConfig):
        """Build a BERT encoder plus a linear layer mapping each hidden state to one scalar weight.

        Args:
            config: configuration for the underlying ``BertModel``; ``config.hidden_size``
                sets the projection's input size.
        """
        super().__init__(config)
        self.config = config
        self.bert = BertModel(config)
        # One output feature: a single weight per token position.
        self.tok_proj = torch.nn.Linear(config.hidden_size, 1)
        # Initialise bert and tok_proj via this class's init_weights.
        self.init_weights()

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialise ``module``'s parameters in place, BERT-style.

        LayerNorm: weight set to 1, bias to 0. Linear / Embedding: weight drawn from
        N(0, config.initializer_range). Linear layers that have a bias get it zeroed.
        """
        if isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

        linear_with_bias = isinstance(module, torch.nn.Linear) and module.bias is not None
        if linear_with_bias:
            module.bias.data.zero_()

    def init_weights(self):
        """Initialise all weights.

        Delegates to ``BertModel.init_weights`` for ``self.bert``, then applies
        ``self._init_weights`` to ``tok_proj`` (``Module.apply`` visits every submodule
        and the module itself). The call order fixes the order of random-number consumption.
        """
        self.bert.init_weights()
        self.tok_proj.apply(self._init_weights)

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
    ):
        input_shape = input_ids.size()
        device = input_ids.device
        if attention_mask is None:
            attention_mask = (
                torch.ones(input_shape, device=device)
                if input_ids is None
                else (input_ids != self.bert.config.pad_token_id)
            )
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        tok_weights = self.tok_proj(sequence_output)
        tok_weights
Download .txt
gitextract_zinswnzk/

├── README.md
├── all_prompt.md
├── configs/
│   ├── 2wikimultihopqa_llama3-8b-instruct.sh
│   ├── 2wikimultihopqa_llama3.2-1b-instruct.sh
│   ├── 2wikimultihopqa_qwen2.5-1.5b-instruct.sh
│   ├── complexwebquestions_llama3-8b-instruct.sh
│   ├── complexwebquestions_llama3.2-1b-instruct.sh
│   ├── complexwebquestions_qwen2.5-1.5b-instruct.sh
│   ├── hotpotqa_llama3-8b-instruct.sh
│   ├── hotpotqa_llama3.2-1b-instruct.sh
│   ├── hotpotqa_qwen2.5-1.5b-instruct.sh
│   ├── popqa_llama3-8b-instruct.sh
│   ├── popqa_llama3.2-1b-instruct.sh
│   └── popqa_qwen2.5-1.5b-instruct.sh
├── prep_elastic.py
├── requirements.txt
└── src/
    ├── augment.py
    ├── encode.py
    ├── fewshot/
    │   ├── 2wikimultihopqa.json
    │   └── hotpotqa.json
    ├── get_warmup_data.py
    ├── inference.py
    ├── prompt_template.py
    ├── retrieve/
    │   ├── beir/
    │   │   ├── .gitignore
    │   │   ├── .gitmodules
    │   │   ├── CONTRIBUTORS.txt
    │   │   ├── LICENSE
    │   │   ├── NOTICE.txt
    │   │   ├── README.md
    │   │   ├── beir/
    │   │   │   ├── __init__.py
    │   │   │   ├── datasets/
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── data_loader.py
    │   │   │   │   └── data_loader_hf.py
    │   │   │   ├── generation/
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── generate.py
    │   │   │   │   └── models/
    │   │   │   │       ├── __init__.py
    │   │   │   │       ├── auto_model.py
    │   │   │   │       └── tilde.py
    │   │   │   ├── logging.py
    │   │   │   ├── losses/
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── bpr_loss.py
    │   │   │   │   └── margin_mse_loss.py
    │   │   │   ├── reranking/
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── models/
    │   │   │   │   │   ├── __init__.py
    │   │   │   │   │   ├── cross_encoder.py
    │   │   │   │   │   └── mono_t5.py
    │   │   │   │   └── rerank.py
    │   │   │   ├── retrieval/
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── custom_metrics.py
    │   │   │   │   ├── evaluation.py
    │   │   │   │   ├── models/
    │   │   │   │   │   ├── __init__.py
    │   │   │   │   │   ├── bpr.py
    │   │   │   │   │   ├── dpr.py
    │   │   │   │   │   ├── sentence_bert.py
    │   │   │   │   │   ├── sparta.py
    │   │   │   │   │   ├── splade.py
    │   │   │   │   │   ├── tldr.py
    │   │   │   │   │   ├── unicoil.py
    │   │   │   │   │   └── use_qa.py
    │   │   │   │   ├── search/
    │   │   │   │   │   ├── __init__.py
    │   │   │   │   │   ├── base.py
    │   │   │   │   │   ├── dense/
    │   │   │   │   │   │   ├── __init__.py
    │   │   │   │   │   │   ├── exact_search.py
    │   │   │   │   │   │   ├── exact_search_multi_gpu.py
    │   │   │   │   │   │   ├── faiss_index.py
    │   │   │   │   │   │   ├── faiss_search.py
    │   │   │   │   │   │   └── util.py
    │   │   │   │   │   ├── lexical/
    │   │   │   │   │   │   ├── __init__.py
    │   │   │   │   │   │   ├── bm25_search.py
    │   │   │   │   │   │   └── elastic_search.py
    │   │   │   │   │   └── sparse/
    │   │   │   │   │       ├── __init__.py
    │   │   │   │   │       └── sparse_search.py
    │   │   │   │   └── train.py
    │   │   │   └── util.py
    │   │   ├── examples/
    │   │   │   ├── beir-pyserini/
    │   │   │   │   ├── Dockerfile
    │   │   │   │   ├── config.py
    │   │   │   │   ├── dockerhub.sh
    │   │   │   │   └── main.py
    │   │   │   ├── benchmarking/
    │   │   │   │   ├── benchmark_bm25.py
    │   │   │   │   ├── benchmark_bm25_ce_reranking.py
    │   │   │   │   └── benchmark_sbert.py
    │   │   │   ├── dataset/
    │   │   │   │   ├── README.md
    │   │   │   │   ├── download_dataset.py
    │   │   │   │   ├── md5.csv
    │   │   │   │   └── scrape_tweets.py
    │   │   │   ├── generation/
    │   │   │   │   ├── passage_expansion_tilde.py
    │   │   │   │   ├── query_gen.py
    │   │   │   │   ├── query_gen_and_train.py
    │   │   │   │   └── query_gen_multi_gpu.py
    │   │   │   └── retrieval/
    │   │   │       ├── README.md
    │   │   │       ├── evaluation/
    │   │   │       │   ├── README.md
    │   │   │       │   ├── custom/
    │   │   │       │   │   ├── evaluate_custom_dataset.py
    │   │   │       │   │   ├── evaluate_custom_dataset_files.py
    │   │   │       │   │   ├── evaluate_custom_metrics.py
    │   │   │       │   │   └── evaluate_custom_model.py
    │   │   │       │   ├── dense/
    │   │   │       │   │   ├── evaluate_ance.py
    │   │   │       │   │   ├── evaluate_bpr.py
    │   │   │       │   │   ├── evaluate_dim_reduction.py
    │   │   │       │   │   ├── evaluate_dpr.py
    │   │   │       │   │   ├── evaluate_faiss_dense.py
    │   │   │       │   │   ├── evaluate_sbert.py
    │   │   │       │   │   ├── evaluate_sbert_hf_loader.py
    │   │   │       │   │   ├── evaluate_sbert_multi_gpu.py
    │   │   │       │   │   ├── evaluate_tldr.py
    │   │   │       │   │   └── evaluate_useqa.py
    │   │   │       │   ├── late-interaction/
    │   │   │       │   │   └── README.md
    │   │   │       │   ├── lexical/
    │   │   │       │   │   ├── evaluate_anserini_bm25.py
    │   │   │       │   │   ├── evaluate_bm25.py
    │   │   │       │   │   └── evaluate_multilingual_bm25.py
    │   │   │       │   ├── reranking/
    │   │   │       │   │   ├── README.md
    │   │   │       │   │   ├── evaluate_bm25_ce_reranking.py
    │   │   │       │   │   ├── evaluate_bm25_monot5_reranking.py
    │   │   │       │   │   └── evaluate_bm25_sbert_reranking.py
    │   │   │       │   └── sparse/
    │   │   │       │       ├── evaluate_anserini_docT5query.py
    │   │   │       │       ├── evaluate_anserini_docT5query_parallel.py
    │   │   │       │       ├── evaluate_deepct.py
    │   │   │       │       ├── evaluate_sparta.py
    │   │   │       │       ├── evaluate_splade.py
    │   │   │       │       └── evaluate_unicoil.py
    │   │   │       └── training/
    │   │   │           ├── train_msmarco_v2.py
    │   │   │           ├── train_msmarco_v3.py
    │   │   │           ├── train_msmarco_v3_bpr.py
    │   │   │           ├── train_msmarco_v3_margin_MSE.py
    │   │   │           ├── train_sbert.py
    │   │   │           └── train_sbert_BM25_hardnegs.py
    │   │   ├── setup.cfg
    │   │   └── setup.py
    │   ├── readme.md
    │   └── retriever.py
    ├── root_dir_path.py
    ├── utils.py
    └── warmup_lora.py
Download .txt
SYMBOL INDEX (352 symbols across 50 files)

FILE: prep_elastic.py
  function build_elasticsearch (line 8) | def build_elasticsearch(

FILE: src/augment.py
  function load_popqa (line 15) | def load_popqa(data_path):
  function load_complexwebquestions (line 32) | def load_complexwebquestions(data_path):
  function load_2wikimultihopqa (line 54) | def load_2wikimultihopqa(data_path):
  function load_hotpotqa (line 88) | def load_hotpotqa(data_path):
  function load_default_format_data (line 121) | def load_default_format_data(data_path):
  function get_rewrite (line 139) | def get_rewrite(passage, model_name, model=None, tokenizer=None, generat...
  function fix_qa (line 157) | def fix_qa(qa):
  function get_qa (line 173) | def get_qa(passage, model_name, model=None, tokenizer=None, generation_c...
  function main (line 205) | def main(args):

FILE: src/encode.py
  class TrainingData (line 25) | class TrainingData(Dataset):
    method __init__ (line 28) | def __init__(self, prompt_ids, tokenizer, max_length=3000):
    method __len__ (line 47) | def __len__(self):
    method __getitem__ (line 50) | def __getitem__(self, idx) -> Dict[str, list]:
  class TrainingDataCollator (line 54) | class TrainingDataCollator(DefaultDataCollator):
    method __init__ (line 55) | def __init__(self, tokenizer, device):
    method __call__ (line 60) | def __call__(self, examples: List[Dict[str, list]]) -> Dict[str, torch...
  function get_train_data (line 71) | def get_train_data(aug_model, augments, tokenizer, args):
  function train (line 94) | def train(question, augments, args, model, tokenizer,
  function main (line 124) | def main(args):

FILE: src/get_warmup_data.py
  function create_direct (line 18) | def create_direct():
  function load_2wikimultihopqa (line 35) | def load_2wikimultihopqa(data_path):
  function load_hotpotqa (line 71) | def load_hotpotqa(data_path):
  function create_cot (line 114) | def create_cot():

FILE: src/inference.py
  function main (line 13) | def main(args):

FILE: src/prompt_template.py
  function _get_prompt (line 23) | def _get_prompt(question, passages=None, answer=None):
  function get_fewshot (line 42) | def get_fewshot(dataset):
  function get_prompt (line 59) | def get_prompt(tokenizer, question, passages=None, answer=None, with_cot...

FILE: src/retrieve/beir/beir/datasets/data_loader.py
  class GenericDataLoader (line 10) | class GenericDataLoader:
    method __init__ (line 12) | def __init__(self, data_folder: str = None, prefix: str = None, corpus...
    method check (line 28) | def check(fIn: str, ext: str):
    method load_custom (line 35) | def load_custom(self) -> Tuple[Dict[str, Dict[str, str]], Dict[str, st...
    method load (line 59) | def load(self, split="test") -> Tuple[Dict[str, Dict[str, str]], Dict[...
    method load_corpus (line 84) | def load_corpus(self) -> Dict[str, Dict[str, str]]:
    method _load_corpus (line 96) | def _load_corpus(self):
    method _load_queries (line 107) | def _load_queries(self):
    method _load_qrels (line 114) | def _load_qrels(self):

FILE: src/retrieve/beir/beir/datasets/data_loader_hf.py
  class HFDataLoader (line 10) | class HFDataLoader:
    method __init__ (line 12) | def __init__(self, hf_repo: str = None, hf_repo_qrels: str = None, dat...
    method check (line 38) | def check(fIn: str, ext: str):
    method load (line 45) | def load(self, split="test") -> Tuple[Dict[str, Dict[str, str]], Dict[...
    method load_corpus (line 77) | def load_corpus(self) -> Dict[str, Dict[str, str]]:
    method _load_corpus (line 89) | def _load_corpus(self):
    method _load_queries (line 100) | def _load_queries(self):
    method _load_qrels (line 111) | def _load_qrels(self, split):

FILE: src/retrieve/beir/beir/generation/generate.py
  class PassageExpansion (line 8) | class PassageExpansion:
    method __init__ (line 9) | def __init__(self, model, **kwargs):
    method save (line 14) | def save(output_dir: str, corpus: Dict[str, str], prefix: str):
    method expand (line 22) | def expand(self,
  class QueryGenerator (line 57) | class QueryGenerator:
    method __init__ (line 58) | def __init__(self, model, **kwargs):
    method save (line 64) | def save(output_dir: str, queries: Dict[str, str], qrels: Dict[str, Di...
    method generate (line 78) | def generate(self,
    method generate_multi_process (line 135) | def generate_multi_process(self,

FILE: src/retrieve/beir/beir/generation/models/auto_model.py
  class QGenModel (line 10) | class QGenModel:
    method __init__ (line 11) | def __init__(self, model_path: str, gen_prefix: str = "", use_fast: bo...
    method generate (line 19) | def generate(self, corpus: List[Dict[str, str]], ques_per_passage: int...
    method start_multi_process_pool (line 48) | def start_multi_process_pool(self, target_devices: List[str] = None):
    method stop_multi_process_pool (line 78) | def stop_multi_process_pool(pool):
    method _generate_multi_process_worker (line 93) | def _generate_multi_process_worker(target_device: str, model, tokenize...
    method generate_multi_process (line 121) | def generate_multi_process(self, corpus: List[Dict[str, str]], ques_pe...

FILE: src/retrieve/beir/beir/generation/models/tilde.py
  class TILDE (line 12) | class TILDE:
    method __init__ (line 13) | def __init__(self, model_path: str, gen_prefix: str = "", use_fast: bo...
    method _clean_vocab (line 22) | def _clean_vocab(self, tokenizer, do_stopwords=True):
    method generate (line 55) | def generate(self, corpus: List[Dict[str, str]], top_k: int, max_lengt...

FILE: src/retrieve/beir/beir/logging.py
  class LoggingHandler (line 4) | class LoggingHandler(logging.Handler):
    method __init__ (line 5) | def __init__(self, level=logging.NOTSET):
    method emit (line 8) | def emit(self, record):

FILE: src/retrieve/beir/beir/losses/bpr_loss.py
  class BPRLoss (line 6) | class BPRLoss(torch.nn.Module):
    method __init__ (line 22) | def __init__(self, model: SentenceTransformer, scale: float = 1.0, sim...
    method convert_to_binary (line 39) | def convert_to_binary(self, input_repr: torch.Tensor) -> torch.Tensor:
    method forward (line 46) | def forward(self, sentence_features: Iterable[Dict[str, torch.Tensor]]...

FILE: src/retrieve/beir/beir/losses/margin_mse_loss.py
  class MarginMSELoss (line 8) | class MarginMSELoss(nn.Module):
    method __init__ (line 19) | def __init__(self, model, scale: float = 1.0, similarity_fct = 'dot'):
    method forward (line 26) | def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labe...

FILE: src/retrieve/beir/beir/reranking/models/cross_encoder.py
  class CrossEncoder (line 5) | class CrossEncoder:
    method __init__ (line 6) | def __init__(self, model_path: str, **kwargs):
    method predict (line 9) | def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 3...

FILE: src/retrieve/beir/beir/reranking/models/mono_t5.py
  class QueryDocumentBatch (line 20) | class QueryDocumentBatch:
    method __len__ (line 25) | def __len__(self):
  class QueryDocumentBatchTokenizer (line 28) | class QueryDocumentBatchTokenizer:
    method __init__ (line 29) | def __init__(self,
    method encode (line 37) | def encode(self, strings: List[str]):
    method traverse_query_document (line 45) | def traverse_query_document(
  class T5BatchTokenizer (line 55) | class T5BatchTokenizer(QueryDocumentBatchTokenizer):
    method __init__ (line 56) | def __init__(self, *args, **kwargs):
  function greedy_decode (line 72) | def greedy_decode(model: PreTrainedModel,
  class MonoT5 (line 99) | class MonoT5:
    method __init__ (line 100) | def __init__(self,
    method get_model (line 115) | def get_model(model_path: str, *args, device: str = None, **kwargs) ->...
    method get_tokenizer (line 121) | def get_tokenizer(model_path: str, *args, **kwargs) -> T5BatchTokenizer:
    method get_prediction_tokens (line 127) | def get_prediction_tokens(model_path: str, tokenizer, token_false, tok...
    method predict (line 133) | def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 3...

FILE: src/retrieve/beir/beir/reranking/rerank.py
  class Rerank (line 7) | class Rerank:
    method __init__ (line 9) | def __init__(self, model, batch_size: int = 128, **kwargs):
    method rerank (line 14) | def rerank(self,

FILE: src/retrieve/beir/beir/retrieval/custom_metrics.py
  function mrr (line 4) | def mrr(qrels: Dict[str, Dict[str, int]],
  function recall_cap (line 33) | def recall_cap(qrels: Dict[str, Dict[str, int]],
  function hole (line 60) | def hole(qrels: Dict[str, Dict[str, int]],
  function top_k_accuracy (line 89) | def top_k_accuracy(

FILE: src/retrieve/beir/beir/retrieval/evaluation.py
  class EvaluateRetrieval (line 9) | class EvaluateRetrieval:
    method __init__ (line 11) | def __init__(self, retriever: BaseSearch = None, k_values: List[int] =...
    method retrieve (line 17) | def retrieve(self, corpus: Dict[str, Dict[str, str]], queries: Dict[st...
    method rerank (line 22) | def rerank(self,
    method evaluate_custom (line 94) | def evaluate_custom(qrels: Dict[str, Dict[str, int]],

FILE: src/retrieve/beir/beir/retrieval/models/bpr.py
  class BinarySentenceBERT (line 6) | class BinarySentenceBERT:
    method __init__ (line 7) | def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ...
    method _convert_embedding_to_binary_code (line 19) | def _convert_embedding_to_binary_code(self, embeddings: List[Tensor]) ...
    method encode_queries (line 22) | def encode_queries(self, queries: List[str], batch_size: int = 16, **k...
    method encode_corpus (line 25) | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int ...

FILE: src/retrieve/beir/beir/retrieval/models/dpr.py
  class DPR (line 7) | class DPR:
    method __init__ (line 8) | def __init__(self, model_path: Union[str, Tuple] = None, **kwargs):
    method encode_queries (line 21) | def encode_queries(self, queries: List[str], batch_size: int = 16, **k...
    method encode_corpus (line 31) | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int ...

FILE: src/retrieve/beir/beir/retrieval/models/sentence_bert.py
  class SentenceBERT (line 13) | class SentenceBERT:
    method __init__ (line 14) | def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ...
    method start_multi_process_pool (line 25) | def start_multi_process_pool(self, target_devices: List[str] = None) -...
    method stop_multi_process_pool (line 40) | def stop_multi_process_pool(self, pool: Dict[str, object]):
    method encode_queries (line 45) | def encode_queries(self, queries: List[str], batch_size: int = 16, **k...
    method encode_corpus (line 48) | def encode_corpus(self, corpus: Union[List[Dict[str, str]], Dict[str, ...
    method encode_corpus_parallel (line 56) | def encode_corpus_parallel(self, corpus: Union[List[Dict[str, str]], D...

FILE: src/retrieve/beir/beir/retrieval/models/sparta.py
  class SPARTA (line 8) | class SPARTA:
    method __init__ (line 9) | def __init__(self, model_path: str = None, sep: str = " ", sparse_vect...
    method initialization (line 18) | def initialization(self):
    method _bert_input_embeddings (line 23) | def _bert_input_embeddings(self):
    method _compute_sparse_embeddings (line 33) | def _compute_sparse_embeddings(self, documents):
    method encode_query (line 56) | def encode_query(self, query: str, **kwargs):
    method encode_corpus (line 59) | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int ...

FILE: src/retrieve/beir/beir/retrieval/models/splade.py
  class SPLADE (line 14) | class SPLADE:
    method __init__ (line 15) | def __init__(self, model_path: str = None, sep: str = " ", max_length:...
    method encode_queries (line 22) | def encode_queries(self, queries: List[str], batch_size: int, **kwargs...
    method encode_corpus (line 26) | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int,...
  class SpladeNaver (line 33) | class SpladeNaver(torch.nn.Module):
    method __init__ (line 34) | def __init__(self, model_path):
    method forward (line 38) | def forward(self, **kwargs):
    method _text_length (line 42) | def _text_length(self, text: Union[List[int], List[List[int]]]):
    method encode_sentence_bert (line 57) | def encode_sentence_bert(self, tokenizer, sentences: Union[str, List[s...

FILE: src/retrieve/beir/beir/retrieval/models/tldr.py
  class TLDR (line 11) | class TLDR:
    method __init__ (line 12) | def __init__(self, encoder_model: SentenceTransformer, model_path: Uni...
    method fit (line 32) | def fit(self, corpus: List[Dict[str, str]], batch_size: int = 8, epoch...
    method save (line 43) | def save(self, model_path: str, knn_path: str = None):
    method load (line 47) | def load(self, model_path: str):
    method encode_queries (line 51) | def encode_queries(self, queries: List[str], batch_size: int = 16, **k...
    method encode_corpus (line 54) | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int ...

FILE: src/retrieve/beir/beir/retrieval/models/unicoil.py
  class UniCOIL (line 8) | class UniCOIL:
    method __init__ (line 9) | def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ...
    method encode_query (line 21) | def encode_query(self, query: str, batch_size: int = 16, **kwargs):
    method encode_corpus (line 34) | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int ...
    method encode (line 38) | def encode(
  class UniCoilEncoder (line 122) | class UniCoilEncoder(PreTrainedModel):
    method __init__ (line 127) | def __init__(self, config: BertConfig):
    method _init_weights (line 135) | def _init_weights(self, module):
    method init_weights (line 147) | def init_weights(self):
    method forward (line 151) | def forward(

FILE: src/retrieve/beir/beir/retrieval/models/use_qa.py
  class UseQA (line 13) | class UseQA:
    method __init__ (line 14) | def __init__(self, hub_url=None, **kwargs):
    method initialisation (line 19) | def initialisation():
    method encode_queries (line 29) | def encode_queries(self, queries: List[str], batch_size: int = 16, **k...
    method encode_corpus (line 39) | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int ...

FILE: src/retrieve/beir/beir/retrieval/search/base.py
  class BaseSearch (line 4) | class BaseSearch(ABC):
    method search (line 7) | def search(self,

FILE: src/retrieve/beir/beir/retrieval/search/dense/exact_search.py
  class DenseRetrievalExactSearch (line 12) | class DenseRetrievalExactSearch(BaseSearch):
    method __init__ (line 14) | def __init__(self, model, batch_size: int = 128, corpus_chunk_size: in...
    method search (line 25) | def search(self,

FILE: src/retrieve/beir/beir/retrieval/search/dense/exact_search_multi_gpu.py
  class DummyMetric (line 28) | class DummyMetric(EvaluationModule):
    method _info (line 31) | def _info(self):
    method _compute (line 40) | def _compute(self, cos_scores_top_k_values, cos_scores_top_k_idx, batc...
    method warmup (line 49) | def warmup(self):
  class DenseRetrievalParallelExactSearch (line 56) | class DenseRetrievalParallelExactSearch(BaseSearch):
    method __init__ (line 58) | def __init__(self, model, batch_size: int = 128, corpus_chunk_size: in...
    method search (line 82) | def search(self,
    method _encode_multi_process_worker (line 171) | def _encode_multi_process_worker(self, process_id, device, model, inpu...

FILE: src/retrieve/beir/beir/retrieval/search/dense/faiss_index.py
  class FaissIndex (line 13) | class FaissIndex:
    method __init__ (line 14) | def __init__(self, index: faiss.Index, passage_ids: List[int] = None):
    method search (line 20) | def search(self, query_embeddings: np.ndarray, k: int, **kwargs) -> Tu...
    method save (line 28) | def save(self, fname: str):
    method build (line 32) | def build(
    method to_gpu (line 46) | def to_gpu(self):
  class FaissHNSWIndex (line 58) | class FaissHNSWIndex(FaissIndex):
    method search (line 59) | def search(self, query_embeddings: np.ndarray, k: int, **kwargs) -> Tu...
    method save (line 63) | def save(self, output_path: str):
    method build (line 67) | def build(
  class FaissTrainIndex (line 80) | class FaissTrainIndex(FaissIndex):
    method search (line 81) | def search(self, query_embeddings: np.ndarray, k: int, **kwargs) -> Tu...
    method save (line 84) | def save(self, output_path: str):
    method build (line 88) | def build(
  class FaissBinaryIndex (line 98) | class FaissBinaryIndex(FaissIndex):
    method __init__ (line 99) | def __init__(self, index: faiss.Index, passage_ids: List[int] = None, ...
    method search (line 109) | def search(self, query_embeddings: np.ndarray, k: int, binary_k: int =...
    method save (line 158) | def save(self, fname: str):
    method build (line 162) | def build(

FILE: src/retrieve/beir/beir/retrieval/search/dense/faiss_search.py
  class DenseRetrievalFaissSearch (line 14) | class DenseRetrievalFaissSearch(BaseSearch):
    method __init__ (line 16) | def __init__(self, model, batch_size: int = 128, corpus_chunk_size: in...
    method _create_mapping_ids (line 30) | def _create_mapping_ids(self, corpus_ids):
    method _load (line 36) | def _load(self, input_dir: str, prefix: str, ext: str):
    method save (line 51) | def save(self, output_dir: str, prefix: str, ext: str):
    method _index (line 64) | def _index(self, corpus: Dict[str, Dict[str, str]], score_function: st...
    method search (line 102) | def search(self,
  class BinaryFaissSearch (line 134) | class BinaryFaissSearch(DenseRetrievalFaissSearch):
    method load (line 136) | def load(self, input_dir: str, prefix: str = "my-index", ext: str = "b...
    method index (line 146) | def index(self, corpus: Dict[str, Dict[str, str]], score_function: str...
    method save (line 153) | def save(self, output_dir: str, prefix: str = "my-index", ext: str = "...
    method search (line 156) | def search(self,
    method get_index_name (line 164) | def get_index_name(self):
  class PQFaissSearch (line 168) | class PQFaissSearch(DenseRetrievalFaissSearch):
    method __init__ (line 169) | def __init__(self, model, batch_size: int = 128, corpus_chunk_size: in...
    method load (line 177) | def load(self, input_dir: str, prefix: str = "my-index", ext: str = "p...
    method index (line 187) | def index(self, corpus: Dict[str, Dict[str, str]], score_function: str...
    method save (line 212) | def save(self, output_dir: str, prefix: str = "my-index", ext: str = "...
    method search (line 215) | def search(self,
    method get_index_name (line 223) | def get_index_name(self):
  class HNSWFaissSearch (line 227) | class HNSWFaissSearch(DenseRetrievalFaissSearch):
    method __init__ (line 228) | def __init__(self, model, batch_size: int = 128, corpus_chunk_size: in...
    method load (line 236) | def load(self, input_dir: str, prefix: str = "my-index", ext: str = "h...
    method index (line 247) | def index(self, corpus: Dict[str, Dict[str, str]], score_function: str...
    method save (line 265) | def save(self, output_dir: str, prefix: str = "my-index", ext: str = "...
    method search (line 268) | def search(self,
    method get_index_name (line 276) | def get_index_name(self):
  class HNSWSQFaissSearch (line 279) | class HNSWSQFaissSearch(DenseRetrievalFaissSearch):
    method __init__ (line 280) | def __init__(self, model, batch_size: int = 128, corpus_chunk_size: in...
    method load (line 290) | def load(self, input_dir: str, prefix: str = "my-index", ext: str = "h...
    method index (line 295) | def index(self, corpus: Dict[str, Dict[str, str]], score_function: str...
    method save (line 310) | def save(self, output_dir: str, prefix: str = "my-index", ext: str = "...
    method search (line 313) | def search(self,
    method get_index_name (line 321) | def get_index_name(self):
  class FlatIPFaissSearch (line 324) | class FlatIPFaissSearch(DenseRetrievalFaissSearch):
    method load (line 325) | def load(self, input_dir: str, prefix: str = "my-index", ext: str = "f...
    method index (line 335) | def index(self, corpus: Dict[str, Dict[str, str]], score_function: str...
    method save (line 345) | def save(self, output_dir: str, prefix: str = "my-index", ext: str = "...
    method search (line 348) | def search(self,
    method get_index_name (line 356) | def get_index_name(self):
  class PCAFaissSearch (line 359) | class PCAFaissSearch(DenseRetrievalFaissSearch):
    method __init__ (line 360) | def __init__(self, model, base_index: faiss.Index, output_dimension: i...
    method load (line 370) | def load(self, input_dir: str, prefix: str = "my-index", ext: str = "p...
    method index (line 381) | def index(self, corpus: Dict[str, Dict[str, str]], score_function: str...
    method save (line 401) | def save(self, output_dir: str, prefix: str = "my-index", ext: str = "...
    method search (line 404) | def search(self,
    method get_index_name (line 412) | def get_index_name(self):
  class SQFaissSearch (line 415) | class SQFaissSearch(DenseRetrievalFaissSearch):
    method __init__ (line 416) | def __init__(self, model, batch_size: int = 128, corpus_chunk_size: in...
    method load (line 422) | def load(self, input_dir: str, prefix: str = "my-index", ext: str = "s...
    method index (line 432) | def index(self, corpus: Dict[str, Dict[str, str]], score_function: str...
    method save (line 447) | def save(self, output_dir: str, prefix: str = "my-index", ext: str = "...
    method search (line 450) | def search(self,
    method get_index_name (line 458) | def get_index_name(self):

FILE: src/retrieve/beir/beir/retrieval/search/dense/util.py
  function cos_sim (line 5) | def cos_sim(a: torch.Tensor, b: torch.Tensor):
  function dot_score (line 26) | def dot_score(a: torch.Tensor, b: torch.Tensor):
  function normalize (line 45) | def normalize(a: np.ndarray) -> np.ndarray:
  function save_dict_to_tsv (line 48) | def save_dict_to_tsv(_dict, output_path, keys=[]):
  function load_tsv_to_dict (line 56) | def load_tsv_to_dict(input_path, header=True):

FILE: src/retrieve/beir/beir/retrieval/search/lexical/bm25_search.py
  function sleep (line 7) | def sleep(seconds):
  class BM25Search (line 10) | class BM25Search(BaseSearch):
    method __init__ (line 11) | def __init__(self, index_name: str, hostname: str = "localhost", keys:...
    method initialise (line 32) | def initialise(self):
    method search (line 37) | def search(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str,...
    method index (line 66) | def index(self, corpus: Dict[str, Dict[str, str]]):

FILE: src/retrieve/beir/beir/retrieval/search/lexical/elastic_search.py
  class ElasticSearch (line 11) | class ElasticSearch(object):
    method __init__ (line 13) | def __init__(self, es_credentials: Dict[str, object]):
    method check_language_supported (line 40) | def check_language_supported(self):
    method check_index_name (line 47) | def check_index_name(self):
    method create_index (line 68) | def create_index(self):
    method delete_index (line 96) | def delete_index(self):
    method bulk_add_to_index (line 105) | def bulk_add_to_index(self, generate_actions, progress):
    method lexical_search (line 119) | def lexical_search(self, text: str, top_hits: int, ids: List[str] = No...
    method lexical_multisearch (line 157) | def lexical_multisearch(self, texts: List[str], top_hits: int, skip: i...
    method generate_actions (line 203) | def generate_actions(self, dictionary: Dict[str, Dict[str, str]], upda...
    method hit_template (line 229) | def hit_template(self, es_res: Dict[str, object], hits: List[Tuple[str...

FILE: src/retrieve/beir/beir/retrieval/search/sparse/sparse_search.py
  class SparseSearch (line 9) | class SparseSearch(BaseSearch):
    method __init__ (line 11) | def __init__(self, model, batch_size: int = 16, **kwargs):
    method search (line 17) | def search(self,

FILE: src/retrieve/beir/beir/retrieval/train.py
  class TrainRetriever (line 16) | class TrainRetriever:
    method __init__ (line 18) | def __init__(self, model: SentenceTransformer, batch_size: int = 64):
    method load_train (line 22) | def load_train(self, corpus: Dict[str, Dict[str, str]], queries: Dict[...
    method load_train_triplets (line 43) | def load_train_triplets(self, triplets: List[Tuple[str, str, str]]) ->...
    method prepare_train (line 56) | def prepare_train(self, train_dataset: List[InputExample], shuffle: bo...
    method prepare_train_triplets (line 64) | def prepare_train_triplets(self, train_dataset: List[InputExample]) ->...
    method load_ir_evaluator (line 69) | def load_ir_evaluator(self, corpus: Dict[str, Dict[str, str]], queries...
    method load_dummy_evaluator (line 110) | def load_dummy_evaluator(self) -> SentenceEvaluator:
    method fit (line 113) | def fit(self,

FILE: src/retrieve/beir/beir/util.py
  function dot_score (line 13) | def dot_score(a: torch.Tensor, b: torch.Tensor):
  function cos_sim (line 32) | def cos_sim(a: torch.Tensor, b: torch.Tensor):
  function download_url (line 53) | def download_url(url: str, save_path: str, chunk_size: int = 1024):
  function unzip (line 75) | def unzip(zip_file: str, out_dir: str):
  function download_and_unzip (line 80) | def download_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -...
  function write_to_json (line 96) | def write_to_json(output_file: str, data: Dict[str, str]):
  function write_to_tsv (line 115) | def write_to_tsv(output_file: str, data: Dict[str, str]):

FILE: src/retrieve/beir/examples/beir-pyserini/config.py
  class IndexSettings (line 3) | class IndexSettings(BaseSettings):
  function hit_template (line 7) | def hit_template(hits):

FILE: src/retrieve/beir/examples/beir-pyserini/main.py
  function upload (line 12) | async def upload(file: UploadFile = File(...)):
  function index (line 22) | def index(index_name: str, threads: Optional[int] = 8):
  function search (line 35) | def search(q: str,
  function batch_search (line 51) | def batch_search(queries: List[str],
  function batch_search_rm3 (line 65) | def batch_search_rm3(queries: List[str],

FILE: src/retrieve/beir/examples/dataset/download_dataset.py
  function main (line 5) | def main():

FILE: src/retrieve/beir/examples/dataset/scrape_tweets.py
  function chunks (line 30) | def chunks(lst, n):
  function de_emojify (line 35) | def de_emojify(text):
  function preprocessing (line 44) | def preprocessing(text):
  function update_tweet_dict (line 47) | def update_tweet_dict(tweets, tweet_dict):
  function write_dict_to_file (line 58) | def write_dict_to_file(filename, dic):

FILE: src/retrieve/beir/examples/retrieval/evaluation/custom/evaluate_custom_model.py
  class YourCustomModel (line 12) | class YourCustomModel:
    method __init__ (line 13) | def __init__(self, model_path=None, **kwargs):
    method encode_queries (line 19) | def encode_queries(self, queries: List[str], batch_size: int = 16, **k...
    method encode_corpus (line 25) | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int ...

FILE: src/retrieve/beir/examples/retrieval/evaluation/sparse/evaluate_anserini_docT5query_parallel.py
  function init_process (line 55) | def init_process(device, model_id):
  function _decide_device (line 70) | def _decide_device(cpu_procs):
  function _download_dataset (line 84) | def _download_dataset(dataset):
  function _generate_query (line 92) | def _generate_query(corpus_list):
  function _add_generated_queries_to_corpus (line 110) | def _add_generated_queries_to_corpus(num_procs, device, model_id, corpus):
  function _write_pyserini_corpus (line 124) | def _write_pyserini_corpus(pyserini_index_file, corpus):
  function _index_pyserini (line 139) | def _index_pyserini(pyserini_index_file, dataset):
  function _search_pyserini (line 148) | def _search_pyserini(queries, k):
  function _print_retrieval_examples (line 164) | def _print_retrieval_examples(corpus, queries, results):
  function main (line 177) | def main():

FILE: src/retrieve/beir/examples/retrieval/training/train_msmarco_v3.py
  class MSMARCODataset (line 106) | class MSMARCODataset(Dataset):
    method __init__ (line 107) | def __init__(self, queries, corpus):
    method __getitem__ (line 117) | def __getitem__(self, item):
    method __len__ (line 131) | def __len__(self):

FILE: src/retrieve/beir/examples/retrieval/training/train_msmarco_v3_bpr.py
  class MSMARCODataset (line 110) | class MSMARCODataset(Dataset):
    method __init__ (line 111) | def __init__(self, queries, corpus):
    method __getitem__ (line 121) | def __getitem__(self, item):
    method __len__ (line 135) | def __len__(self):

FILE: src/retrieve/beir/examples/retrieval/training/train_msmarco_v3_margin_MSE.py
  class MSMARCODataset (line 106) | class MSMARCODataset(Dataset):
    method __init__ (line 107) | def __init__(self, queries, corpus):
    method __getitem__ (line 117) | def __getitem__(self, item):
    method __len__ (line 133) | def __len__(self):

FILE: src/retrieve/retriever.py
  function get_random_doc_id (line 22) | def get_random_doc_id():
  class BM25 (line 25) | class BM25:
    method __init__ (line 26) | def __init__(
    method retrieve (line 42) | def retrieve(
  function bm25search_search (line 97) | def bm25search_search(self, corpus: Dict[str, Dict[str, str]], queries: ...
  function elasticsearch_lexical_multisearch (line 126) | def elasticsearch_lexical_multisearch(self, texts: List[str], top_hits: ...
  function elasticsearch_hit_template (line 173) | def elasticsearch_hit_template(self, es_res: Dict[str, object], hits: Li...
  function bm25_retrieve (line 204) | def bm25_retrieve(question, topk):

FILE: src/utils.py
  class BaseDataset (line 16) | class BaseDataset:
    method normalize_answer (line 18) | def normalize_answer(cls, s):
    method exact_match_score (line 31) | def exact_match_score(
    method f1_score (line 45) | def f1_score(
  function load_data (line 78) | def load_data(data_name, data_type, model_name):
  function get_model_path (line 122) | def get_model_path(model_name):
  function get_model (line 133) | def get_model(model_name, max_new_tokens=20):
  function model_generate (line 154) | def model_generate(prompt, model, tokenizer, generation_config):
  function read_complete (line 176) | def read_complete(filepath):
  function evaluate (line 185) | def evaluate(pred, ground_truth, with_cot=False):
  function predict (line 221) | def predict(model, tokenizer, generation_config, question, with_cot, pas...

FILE: src/warmup_lora.py
  class TrainingData (line 25) | class TrainingData(Dataset):
    method __init__ (line 28) | def __init__(self, origin_dataset, tokenizer, args):
    method __len__ (line 65) | def __len__(self):
    method __getitem__ (line 68) | def __getitem__(self, idx) -> Dict[str, list]:
  class TrainingDataCollator (line 72) | class TrainingDataCollator(DefaultDataCollator):
    method __init__ (line 73) | def __init__(self, tokenizer, device):
    method __call__ (line 78) | def __call__(self, examples: List[Dict[str, list]]) -> Dict[str, torch...
  function main (line 89) | def main(args):
Condensed preview — 132 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (506K chars).
[
  {
    "path": "README.md",
    "chars": 10824,
    "preview": "# Parametric RAG\n\n📢 **News: this work has been accepted at the SIGIR 2025!**\n\n\n\n\n![Overall Analysis](assets/overall.png)"
  },
  {
    "path": "all_prompt.md",
    "chars": 17495,
    "preview": "# Prompt Design for Our Work\n\nThis repository contains all the prompts involved in our work, categorized and explained f"
  },
  {
    "path": "configs/2wikimultihopqa_llama3-8b-instruct.sh",
    "chars": 553,
    "preview": "python3 src/encode.py \\\n    --model_name=llama3-8b-instruct \\\n    --dataset=2wikimultihopqa \\\n    --sample=300 \\\n    --p"
  },
  {
    "path": "configs/2wikimultihopqa_llama3.2-1b-instruct.sh",
    "chars": 557,
    "preview": "python3 src/encode.py \\\n    --model_name=llama3.2-1b-instruct \\\n    --dataset=2wikimultihopqa \\\n    --sample=300 \\\n    -"
  },
  {
    "path": "configs/2wikimultihopqa_qwen2.5-1.5b-instruct.sh",
    "chars": 559,
    "preview": "python3 src/encode.py \\\n    --model_name=qwen2.5-1.5b-instruct \\\n    --dataset=2wikimultihopqa \\\n    --sample=300 \\\n    "
  },
  {
    "path": "configs/complexwebquestions_llama3-8b-instruct.sh",
    "chars": 528,
    "preview": "python3 src/encode.py \\\n    --model_name=llama3-8b-instruct \\\n    --dataset=complexwebquestions \\\n    --sample=300 \\\n   "
  },
  {
    "path": "configs/complexwebquestions_llama3.2-1b-instruct.sh",
    "chars": 532,
    "preview": "python3 src/encode.py \\\n    --model_name=llama3.2-1b-instruct \\\n    --dataset=complexwebquestions \\\n    --sample=300 \\\n "
  },
  {
    "path": "configs/complexwebquestions_qwen2.5-1.5b-instruct.sh",
    "chars": 534,
    "preview": "python3 src/encode.py \\\n    --model_name=qwen2.5-1.5b-instruct \\\n    --dataset=complexwebquestions \\\n    --sample=300 \\\n"
  },
  {
    "path": "configs/hotpotqa_llama3-8b-instruct.sh",
    "chars": 539,
    "preview": "python3 src/encode.py \\\n    --model_name=llama3-8b-instruct \\\n    --dataset=hotpotqa \\\n    --sample=300 \\\n    --per_devi"
  },
  {
    "path": "configs/hotpotqa_llama3.2-1b-instruct.sh",
    "chars": 543,
    "preview": "python3 src/encode.py \\\n    --model_name=llama3.2-1b-instruct \\\n    --dataset=hotpotqa \\\n    --sample=300 \\\n    --per_de"
  },
  {
    "path": "configs/hotpotqa_qwen2.5-1.5b-instruct.sh",
    "chars": 545,
    "preview": "python3 src/encode.py \\\n    --model_name=qwen2.5-1.5b-instruct \\\n    --dataset=hotpotqa \\\n    --sample=300 \\\n    --per_d"
  },
  {
    "path": "configs/popqa_llama3-8b-instruct.sh",
    "chars": 500,
    "preview": "python3 src/encode.py \\\n    --model_name=llama3-8b-instruct \\\n    --dataset=popqa \\\n    --sample=300 \\\n    --per_device_"
  },
  {
    "path": "configs/popqa_llama3.2-1b-instruct.sh",
    "chars": 504,
    "preview": "python3 src/encode.py \\\n    --model_name=llama3.2-1b-instruct \\\n    --dataset=popqa \\\n    --sample=300 \\\n    --per_devic"
  },
  {
    "path": "configs/popqa_qwen2.5-1.5b-instruct.sh",
    "chars": 506,
    "preview": "python3 src/encode.py \\\n    --model_name=qwen2.5-1.5b-instruct \\\n    --dataset=popqa \\\n    --sample=300 \\\n    --per_devi"
  },
  {
    "path": "prep_elastic.py",
    "chars": 1963,
    "preview": "import argparse\nimport glob\nimport time\nimport csv\nfrom tqdm import tqdm\nfrom src.retrieve.beir.beir.retrieval.search.le"
  },
  {
    "path": "requirements.txt",
    "chars": 124,
    "preview": "torch==1.13.1\ntransformers==4.44.2\nelasticsearch==8.15.0\npeft==0.13.2\npandas==1.5.3\nnumpy==1.26.4\nfaiss-cpu==1.8.0\ntermc"
  },
  {
    "path": "src/augment.py",
    "chars": 10350,
    "preview": "import os\nimport json\nimport random\nimport argparse\nimport pandas as pd\nfrom tqdm import tqdm\n\nfrom retrieve.retriever i"
  },
  {
    "path": "src/encode.py",
    "chars": 7947,
    "preview": "import os\nimport gc\nimport time\nimport argparse\nimport torch\nfrom tqdm import tqdm\nfrom peft import TaskType, get_peft_m"
  },
  {
    "path": "src/fewshot/2wikimultihopqa.json",
    "chars": 2474,
    "preview": "[\n    {\n        \"question\": \"When did the director of film Hypocrite (Film) die?\",\n        \"answer\": \"The film Hypocrite"
  },
  {
    "path": "src/fewshot/hotpotqa.json",
    "chars": 3389,
    "preview": "[\n    {\n        \"question\": \"Jeremy Theobald and Christopher Nolan share what profession?\",\n        \"answer\": \"Jeremy Th"
  },
  {
    "path": "src/get_warmup_data.py",
    "chars": 7191,
    "preview": "import os\nimport json\nimport pandas as pd\nimport random\nimport torch\nfrom tqdm import tqdm\n\nimport prompt_template\nfrom "
  },
  {
    "path": "src/inference.py",
    "chars": 5564,
    "preview": "import os\nimport gc\nimport json\nimport argparse\nimport torch\nfrom tqdm import tqdm\nfrom peft import PeftModel\n\nimport pr"
  },
  {
    "path": "src/prompt_template.py",
    "chars": 2792,
    "preview": "import os\nfrom root_dir_path import ROOT_DIR\n\ncurrent_dataset = None\nfewshot = None\nfewshot_path = os.path.join(ROOT_DIR"
  },
  {
    "path": "src/retrieve/beir/.gitignore",
    "chars": 1922,
    "preview": "# Custom added\nexamples/**/datasets/\nexamples/**/output/\nexamples/**/DeepCT/\nexamples/**/models/\nexamples/**/faiss-index"
  },
  {
    "path": "src/retrieve/beir/.gitmodules",
    "chars": 195,
    "preview": "[submodule \"examples/retrieval/evaluation/late-interaction/beir-ColBERT\"]\n\tpath = examples/retrieval/evaluation/late-int"
  },
  {
    "path": "src/retrieve/beir/CONTRIBUTORS.txt",
    "chars": 181,
    "preview": "Individual Contributors to the BEIR Repository (BEIR contributors) include:\n1. Nandan Thakur\n2. Nils Reimers\n3. Iryna Gu"
  },
  {
    "path": "src/retrieve/beir/LICENSE",
    "chars": 11348,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "src/retrieve/beir/NOTICE.txt",
    "chars": 530,
    "preview": "-------------------------------------------------------------------------------\nCopyright since 2022\nUniversity of Water"
  },
  {
    "path": "src/retrieve/beir/README.md",
    "chars": 17160,
    "preview": "<h1 align=\"center\">\n<img style=\"vertical-align:middle\" width=\"450\" height=\"180\" src=\"https://raw.githubusercontent.com/b"
  },
  {
    "path": "src/retrieve/beir/beir/__init__.py",
    "chars": 35,
    "preview": "from .logging import LoggingHandler"
  },
  {
    "path": "src/retrieve/beir/beir/datasets/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/retrieve/beir/beir/datasets/data_loader.py",
    "chars": 4961,
    "preview": "from typing import Dict, Tuple\nfrom tqdm.autonotebook import tqdm\nimport json\nimport os\nimport logging\nimport csv\n\nlogge"
  },
  {
    "path": "src/retrieve/beir/beir/datasets/data_loader_hf.py",
    "chars": 5641,
    "preview": "from collections import defaultdict\nfrom typing import Dict, Tuple\nimport os\nimport logging\nfrom datasets import load_da"
  },
  {
    "path": "src/retrieve/beir/beir/generation/__init__.py",
    "chars": 54,
    "preview": "from .generate import QueryGenerator, PassageExpansion"
  },
  {
    "path": "src/retrieve/beir/beir/generation/generate.py",
    "chars": 7748,
    "preview": "from tqdm.autonotebook import trange\nfrom ..util import write_to_json, write_to_tsv\nfrom typing import Dict\nimport loggi"
  },
  {
    "path": "src/retrieve/beir/beir/generation/models/__init__.py",
    "chars": 58,
    "preview": "from .auto_model import QGenModel\nfrom .tilde import TILDE"
  },
  {
    "path": "src/retrieve/beir/beir/generation/models/auto_model.py",
    "chars": 7413,
    "preview": "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\nfrom tqdm.autonotebook import trange\nimport torch, logging"
  },
  {
    "path": "src/retrieve/beir/beir/generation/models/tilde.py",
    "chars": 3156,
    "preview": "from transformers import BertLMHeadModel, BertTokenizer, DataCollatorWithPadding\nfrom tqdm.autonotebook import trange\nim"
  },
  {
    "path": "src/retrieve/beir/beir/logging.py",
    "chars": 405,
    "preview": "import logging\nimport tqdm\n\nclass LoggingHandler(logging.Handler):\n    def __init__(self, level=logging.NOTSET):\n       "
  },
  {
    "path": "src/retrieve/beir/beir/losses/__init__.py",
    "chars": 72,
    "preview": "from .bpr_loss import BPRLoss\nfrom .margin_mse_loss import MarginMSELoss"
  },
  {
    "path": "src/retrieve/beir/beir/losses/bpr_loss.py",
    "chars": 4479,
    "preview": "import math\nimport torch\nfrom typing import Iterable, Dict\nfrom sentence_transformers import SentenceTransformer, util\n\n"
  },
  {
    "path": "src/retrieve/beir/beir/losses/margin_mse_loss.py",
    "chars": 1648,
    "preview": "from .. import util\nimport torch\nfrom torch import nn, Tensor\nfrom typing import Union, Tuple, List, Iterable, Dict\nfrom"
  },
  {
    "path": "src/retrieve/beir/beir/reranking/__init__.py",
    "chars": 26,
    "preview": "from .rerank import Rerank"
  },
  {
    "path": "src/retrieve/beir/beir/reranking/models/__init__.py",
    "chars": 67,
    "preview": "from .cross_encoder import CrossEncoder\nfrom .mono_t5 import MonoT5"
  },
  {
    "path": "src/retrieve/beir/beir/reranking/models/cross_encoder.py",
    "chars": 525,
    "preview": "from sentence_transformers.cross_encoder import CrossEncoder as CE\nimport numpy as np\nfrom typing import List, Dict, Tup"
  },
  {
    "path": "src/retrieve/beir/beir/reranking/models/mono_t5.py",
    "chars": 7319,
    "preview": "# Majority of the code has been copied from PyGaggle MonoT5 implementation\n# https://github.com/castorini/pygaggle/blob/"
  },
  {
    "path": "src/retrieve/beir/beir/reranking/rerank.py",
    "chars": 1923,
    "preview": "import logging\nfrom typing import Dict, List\n\nlogger = logging.getLogger(__name__)\n\n#Parent class for any reranking mode"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/custom_metrics.py",
    "chars": 4129,
    "preview": "import logging\nfrom typing import List, Dict, Union, Tuple\n\ndef mrr(qrels: Dict[str, Dict[str, int]], \n        results: "
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/evaluation.py",
    "chars": 4831,
    "preview": "# import pytrec_eval\nimport logging\nfrom typing import List, Dict, Tuple\nfrom .search.base import BaseSearch\nfrom .custo"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/models/__init__.py",
    "chars": 229,
    "preview": "from .sentence_bert import SentenceBERT\nfrom .use_qa import UseQA\nfrom .sparta import SPARTA\nfrom .dpr import DPR\nfrom ."
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/models/bpr.py",
    "chars": 1613,
    "preview": "from sentence_transformers import SentenceTransformer\nfrom torch import Tensor\nfrom typing import List, Dict, Union, Tup"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/models/dpr.py",
    "chars": 2244,
    "preview": "from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast\nfrom transformers import DPRQuestionEncoder, "
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/models/sentence_bert.py",
    "chars": 3296,
    "preview": "from sentence_transformers import SentenceTransformer\nfrom torch import Tensor\nimport torch.multiprocessing as mp\nfrom t"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/models/sparta.py",
    "chars": 3571,
    "preview": "from typing import List, Dict, Union, Tuple\nfrom tqdm.autonotebook import trange\nfrom transformers import AutoTokenizer,"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/models/splade.py",
    "chars": 7671,
    "preview": "import logging\nfrom typing import List, Dict, Union\nimport numpy as np\nimport torch\nfrom numpy import ndarray\nfrom torch"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/models/tldr.py",
    "chars": 2740,
    "preview": "from sentence_transformers import SentenceTransformer    \nimport torch\nfrom torch import Tensor\nfrom typing import List,"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/models/unicoil.py",
    "chars": 7900,
    "preview": "from typing import Optional, List, Dict, Union, Tuple\nfrom transformers import BertConfig, BertModel, BertTokenizer, Pre"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/models/use_qa.py",
    "chars": 2245,
    "preview": "import numpy as np\nimport importlib.util\nfrom typing import List, Dict\nfrom tqdm.autonotebook import trange\n\nif importli"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/search/__init__.py",
    "chars": 28,
    "preview": "from .base import BaseSearch"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/search/base.py",
    "chars": 316,
    "preview": "from abc import ABC, abstractmethod\nfrom typing import Dict\n\nclass BaseSearch(ABC):\n\n    @abstractmethod\n    def search("
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/search/dense/__init__.py",
    "chars": 294,
    "preview": "from .exact_search import DenseRetrievalExactSearch \nfrom .exact_search_multi_gpu import DenseRetrievalParallelExactSear"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/search/dense/exact_search.py",
    "chars": 4810,
    "preview": "from .. import BaseSearch\nfrom .util import cos_sim, dot_score\nimport logging\nimport torch\nfrom typing import Dict\nimpor"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/search/dense/exact_search_multi_gpu.py",
    "chars": 10703,
    "preview": "from .. import BaseSearch\nfrom .util import cos_sim, dot_score\nfrom sentence_transformers import SentenceTransformer\nfro"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/search/dense/faiss_index.py",
    "chars": 7058,
    "preview": "from .util import normalize\nfrom typing import List, Optional, Tuple, Union\nfrom tqdm.autonotebook import trange\nimport "
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/search/dense/faiss_search.py",
    "chars": 22824,
    "preview": "from .. import BaseSearch\nfrom .util import save_dict_to_tsv, load_tsv_to_dict\nfrom .faiss_index import FaissBinaryIndex"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/search/dense/util.py",
    "chars": 1890,
    "preview": "import torch\nimport numpy as np\nimport csv\n\ndef cos_sim(a: torch.Tensor, b: torch.Tensor):\n    \"\"\"\n    Computes the cosi"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/search/lexical/__init__.py",
    "chars": 35,
    "preview": "from .bm25_search import BM25Search"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/search/lexical/bm25_search.py",
    "chars": 3234,
    "preview": "from .. import BaseSearch\nfrom .elastic_search import ElasticSearch\nimport tqdm\nimport time\nfrom typing import List, Dic"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/search/lexical/elastic_search.py",
    "chars": 9859,
    "preview": "from elasticsearch import Elasticsearch\nfrom elasticsearch.helpers import streaming_bulk\nfrom typing import Dict, List, "
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/search/sparse/__init__.py",
    "chars": 39,
    "preview": "from .sparse_search import SparseSearch"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/search/sparse/sparse_search.py",
    "chars": 1777,
    "preview": "from .. import BaseSearch\nfrom tqdm.autonotebook import trange\nfrom typing import Dict\nimport logging\nimport numpy as np"
  },
  {
    "path": "src/retrieve/beir/beir/retrieval/train.py",
    "chars": 6625,
    "preview": "from sentence_transformers import SentenceTransformer, SentencesDataset, datasets\nfrom sentence_transformers.evaluation "
  },
  {
    "path": "src/retrieve/beir/beir/util.py",
    "chars": 3808,
    "preview": "from typing import Dict\nfrom tqdm.autonotebook import tqdm\nimport csv\nimport torch\nimport json\nimport logging\nimport os\n"
  },
  {
    "path": "src/retrieve/beir/examples/beir-pyserini/Dockerfile",
    "chars": 829,
    "preview": "FROM python:3.6-slim\n\n# Install Java first, to better take advantage of layer caching.\n#\n# Note (1): first mkdir line fi"
  },
  {
    "path": "src/retrieve/beir/examples/beir-pyserini/config.py",
    "chars": 366,
    "preview": "from pydantic import BaseSettings\n\nclass IndexSettings(BaseSettings):\n    index_name: str = \"beir/test\"\n    data_folder:"
  },
  {
    "path": "src/retrieve/beir/examples/beir-pyserini/dockerhub.sh",
    "chars": 284,
    "preview": "#!/bin/sh\n#This tagname build the docker hub containers\n\n# TAGNAME=\"1.0\"\n\n# docker build --no-cache -t beir/pyserini-fas"
  },
  {
    "path": "src/retrieve/beir/examples/beir-pyserini/main.py",
    "chars": 3076,
    "preview": "import sys, os\nimport config\n\nfrom fastapi import FastAPI, File, UploadFile\nfrom pyserini.search import SimpleSearcher\nf"
  },
  {
    "path": "src/retrieve/beir/examples/benchmarking/benchmark_bm25.py",
    "chars": 2670,
    "preview": "from beir import util, LoggingHandler\nfrom beir.datasets.data_loader import GenericDataLoader\nfrom beir.retrieval.evalua"
  },
  {
    "path": "src/retrieve/beir/examples/benchmarking/benchmark_bm25_ce_reranking.py",
    "chars": 3403,
    "preview": "from beir import util, LoggingHandler\nfrom beir.datasets.data_loader import GenericDataLoader\nfrom beir.retrieval.evalua"
  },
  {
    "path": "src/retrieve/beir/examples/benchmarking/benchmark_sbert.py",
    "chars": 3886,
    "preview": "from beir import util, LoggingHandler\nfrom beir.retrieval import models\nfrom beir.datasets.data_loader import GenericDat"
  },
  {
    "path": "src/retrieve/beir/examples/dataset/README.md",
    "chars": 2879,
    "preview": "# Dataset Information\n\nGenerally, all public datasets can be easily downloaded using the zip folder.\n\nBelow we mention h"
  },
  {
    "path": "src/retrieve/beir/examples/dataset/download_dataset.py",
    "chars": 929,
    "preview": "import os\nimport pathlib\nfrom beir import util\n\ndef main():\n    \n    out_dir = pathlib.Path(__file__).parent.absolute()\n"
  },
  {
    "path": "src/retrieve/beir/examples/dataset/md5.csv",
    "chars": 754,
    "preview": "dataset,md5\nmsmarco.zip,444067daf65d982533ea17ebd59501e4\ntrec-covid.zip,ce62140cb23feb9becf6270d0d1fe6d1\nnfcorpus.zip,a8"
  },
  {
    "path": "src/retrieve/beir/examples/dataset/scrape_tweets.py",
    "chars": 3413,
    "preview": "'''\nThe following is a basic twitter scraper code using tweepy.\nWe preprocess the text - 1. Remove Emojis 2. Remove urls"
  },
  {
    "path": "src/retrieve/beir/examples/generation/passage_expansion_tilde.py",
    "chars": 2093,
    "preview": "from beir import util, LoggingHandler\nfrom beir.datasets.data_loader import GenericDataLoader\nfrom beir.generation impor"
  },
  {
    "path": "src/retrieve/beir/examples/generation/query_gen.py",
    "chars": 1954,
    "preview": "from beir import util, LoggingHandler\nfrom beir.datasets.data_loader import GenericDataLoader\nfrom beir.generation impor"
  },
  {
    "path": "src/retrieve/beir/examples/generation/query_gen_and_train.py",
    "chars": 4007,
    "preview": "from beir import util, LoggingHandler\nfrom beir.datasets.data_loader import GenericDataLoader\nfrom beir.generation impor"
  },
  {
    "path": "src/retrieve/beir/examples/generation/query_gen_multi_gpu.py",
    "chars": 3002,
    "preview": "\"\"\"\nThis code shows how to generate using parallel GPU's for very long corpus.\nMultiple GPU's can be used to generate fa"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/README.md",
    "chars": 134,
    "preview": "# Retrieval\n\nThis folder contains various examples to evaluate, train retriever models for datasets in BEIR.\n\n## Overall"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/README.md",
    "chars": 76,
    "preview": "## Deep Dive into Evaluation of Retrieval Models\n\n### Leaderboard Overall \n\n"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/custom/evaluate_custom_dataset.py",
    "chars": 2924,
    "preview": "from beir import LoggingHandler\nfrom beir.retrieval import models\nfrom beir.datasets.data_loader import GenericDataLoade"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/custom/evaluate_custom_dataset_files.py",
    "chars": 2626,
    "preview": "from beir import LoggingHandler\nfrom beir.retrieval import models\nfrom beir.datasets.data_loader import GenericDataLoade"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/custom/evaluate_custom_metrics.py",
    "chars": 2911,
    "preview": "from beir import util, LoggingHandler\nfrom beir.retrieval import models\nfrom beir.datasets.data_loader import GenericDat"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/custom/evaluate_custom_model.py",
    "chars": 2446,
    "preview": "from beir import util, LoggingHandler\nfrom beir.datasets.data_loader import GenericDataLoader\nfrom beir.retrieval.evalua"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/dense/evaluate_ance.py",
    "chars": 2449,
    "preview": "from beir import util, LoggingHandler\nfrom beir.retrieval import models\nfrom beir.datasets.data_loader import GenericDat"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/dense/evaluate_bpr.py",
    "chars": 8359,
    "preview": "\"\"\"\nThe pre-trained models produce embeddings of size 512 - 1024. However, when storing a large\nnumber of embeddings, th"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/dense/evaluate_dim_reduction.py",
    "chars": 6227,
    "preview": "\"\"\"\nThe pre-trained models produce embeddings of size 512 - 1024. However, when storing a large\nnumber of embeddings, th"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/dense/evaluate_dpr.py",
    "chars": 3645,
    "preview": "from beir import util, LoggingHandler\nfrom beir.retrieval import models\nfrom beir.datasets.data_loader import GenericDat"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/dense/evaluate_faiss_dense.py",
    "chars": 6377,
    "preview": "\"\"\"\nIn this example, we show how to utilize different faiss indexes for evaluation in BEIR. We currently support \nIndexF"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/dense/evaluate_sbert.py",
    "chars": 2893,
    "preview": "from time import time\nfrom beir import util, LoggingHandler\nfrom beir.retrieval import models\nfrom beir.datasets.data_lo"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/dense/evaluate_sbert_hf_loader.py",
    "chars": 3720,
    "preview": "from collections import defaultdict\nfrom beir import util, LoggingHandler\nfrom beir.retrieval import models\nfrom beir.da"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/dense/evaluate_sbert_multi_gpu.py",
    "chars": 3991,
    "preview": "'''\nThis sample python shows how to evaluate BEIR dataset quickly using Mutliple GPU for evaluation (for large datasets)"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/dense/evaluate_tldr.py",
    "chars": 4879,
    "preview": "'''\nIn this example, we show how to evaluate TLDR: Twin Learning Dimensionality Reduction using the BEIR Benchmark.\nTLDR"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/dense/evaluate_useqa.py",
    "chars": 2574,
    "preview": "from beir import util, LoggingHandler\nfrom beir.retrieval import models\nfrom beir.datasets.data_loader import GenericDat"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/late-interaction/README.md",
    "chars": 5851,
    "preview": "# BEIR Evaluation with ColBERT\n\nIn this example, we show how to evaluate the ColBERT zero-shot model on the BEIR Benchma"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/lexical/evaluate_anserini_bm25.py",
    "chars": 4009,
    "preview": "\"\"\"\nThis example shows how to evaluate Anserini-BM25 in BEIR.\nSince Anserini uses Java-11, we would advise you to use do"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/lexical/evaluate_bm25.py",
    "chars": 3699,
    "preview": "\"\"\"\nThis example show how to evaluate BM25 model (Elasticsearch) in BEIR.\nTo be able to run Elasticsearch, you should ha"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/lexical/evaluate_multilingual_bm25.py",
    "chars": 4202,
    "preview": "\"\"\"\nThis example show how to evaluate BM25 model (Elasticsearch) in BEIR for German.\nThis script can be used to any eval"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/reranking/README.md",
    "chars": 3167,
    "preview": "### Re-ranking BM25 top-100 using Cross-Encoder (Leaderboard)\n\nIn table below, we evaluate various different reranking a"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/reranking/evaluate_bm25_ce_reranking.py",
    "chars": 3233,
    "preview": "from beir import util, LoggingHandler\nfrom beir.datasets.data_loader import GenericDataLoader\nfrom beir.retrieval.evalua"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/reranking/evaluate_bm25_monot5_reranking.py",
    "chars": 4329,
    "preview": "from beir import util, LoggingHandler\nfrom beir.datasets.data_loader import GenericDataLoader\nfrom beir.retrieval.evalua"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/reranking/evaluate_bm25_sbert_reranking.py",
    "chars": 2512,
    "preview": "from beir import util, LoggingHandler\nfrom beir.datasets.data_loader import GenericDataLoader\nfrom beir.retrieval.evalua"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/sparse/evaluate_anserini_docT5query.py",
    "chars": 5450,
    "preview": "\"\"\"\nThis example shows how to evaluate DocTTTTTquery in BEIR.\n\nSince Anserini uses Java-11, we would advise you to use d"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/sparse/evaluate_anserini_docT5query_parallel.py",
    "chars": 7972,
    "preview": "\"\"\"\nThis example shows how to evaluate docTTTTTquery in BEIR.\n\nSince Anserini uses Java 11, we would advise you to use d"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/sparse/evaluate_deepct.py",
    "chars": 7779,
    "preview": "\"\"\"\nThis example shows how to evaluate DeepCT (using Anserini) in BEIR.\nFor more details on DeepCT, refer here: https://"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/sparse/evaluate_sparta.py",
    "chars": 2268,
    "preview": "from beir import util, LoggingHandler\nfrom beir.retrieval import models\nfrom beir.datasets.data_loader import GenericDat"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/sparse/evaluate_splade.py",
    "chars": 3071,
    "preview": "from beir import util, LoggingHandler\nfrom beir.retrieval import models\nfrom beir.datasets.data_loader import GenericDat"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/evaluation/sparse/evaluate_unicoil.py",
    "chars": 2942,
    "preview": "from beir import util, LoggingHandler\nfrom beir.retrieval import models\nfrom beir.datasets.data_loader import GenericDat"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/training/train_msmarco_v2.py",
    "chars": 4998,
    "preview": "'''\n\"\"\"\nThis examples show how to train a Bi-Encoder for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Pass"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/training/train_msmarco_v3.py",
    "chars": 7841,
    "preview": "'''\nThis example shows how to train a SOTA Bi-Encoder for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Pas"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/training/train_msmarco_v3_bpr.py",
    "chars": 8120,
    "preview": "'''\nThis example shows how to train a Binary-Code (Binary Passage Retriever) based Bi-Encoder for the MS Marco dataset ("
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/training/train_msmarco_v3_margin_MSE.py",
    "chars": 8142,
    "preview": "'''\nThis example shows how to train a SOTA Bi-Encoder with Margin-MSE loss for the MS Marco dataset (https://github.com/"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/training/train_sbert.py",
    "chars": 3620,
    "preview": "'''\nThis examples show how to train a basic Bi-Encoder for any BEIR dataset without any mined hard negatives or triplets"
  },
  {
    "path": "src/retrieve/beir/examples/retrieval/training/train_sbert_BM25_hardnegs.py",
    "chars": 5627,
    "preview": "'''\nThis examples show how to train a Bi-Encoder for any BEIR dataset.\n\nThe queries and passages are passed independentl"
  },
  {
    "path": "src/retrieve/beir/setup.cfg",
    "chars": 39,
    "preview": "[metadata]\ndescription-file = README.md"
  },
  {
    "path": "src/retrieve/beir/setup.py",
    "chars": 1304,
    "preview": "from setuptools import setup, find_packages\n\nwith open(\"README.md\", mode=\"r\", encoding=\"utf-8\") as readme_file:\n    read"
  },
  {
    "path": "src/retrieve/readme.md",
    "chars": 305,
    "preview": "This repository uses [beir2.0.0](https://github.com/beir-cellar/beir/releases/tag/v2.0.0) for BM25.\n\nDue to server limit"
  },
  {
    "path": "src/retrieve/retriever.py",
    "chars": 7211,
    "preview": "from typing import List, Dict, Tuple\nimport os\nimport time\nimport tqdm\nimport uuid\nimport numpy as np\nimport torch\nimpor"
  },
  {
    "path": "src/root_dir_path.py",
    "chars": 25,
    "preview": "ROOT_DIR = \"path_to_PRAG\""
  },
  {
    "path": "src/utils.py",
    "chars": 8607,
    "preview": "import os\nimport re\nimport json\nimport torch\nimport string\nimport numpy as np\nfrom collections import Counter\nfrom typin"
  },
  {
    "path": "src/warmup_lora.py",
    "chars": 6445,
    "preview": "import os\nimport gc\nimport json\nimport numpy as np\nimport random\nimport argparse\nimport torch\nfrom tqdm import tqdm\nfrom"
  }
]

About this extraction

This page contains the full source code of the oneal2000/PRAG GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 132 files (468.4 KB), approximately 122.0k tokens, and a symbol index with 352 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!