[
  {
    "path": "CC-BY-NC-SA-4.0.txt",
    "content": "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International\n\nCreative Commons Corporation (\"Creative Commons\") is not a law firm and\ndoes not provide legal services or legal advice. Distribution of\nCreative Commons public licenses does not create a lawyer-client or\nother relationship. Creative Commons makes its licenses and related\ninformation available on an \"as-is\" basis. Creative Commons gives no\nwarranties regarding its licenses, any material licensed under their\nterms and conditions, or any related information. Creative Commons\ndisclaims all liability for damages resulting from their use to the\nfullest extent possible.\n\nUsing Creative Commons Public Licenses\n\nCreative Commons public licenses provide a standard set of terms and\nconditions that creators and other rights holders may use to share\noriginal works of authorship and other material subject to copyright and\ncertain other rights specified in the public license below. The\nfollowing considerations are for informational purposes only, are not\nexhaustive, and do not form part of our licenses.\n\nConsiderations for licensors: Our public licenses are intended for use\nby those authorized to give the public permission to use material in\nways otherwise restricted by copyright and certain other rights. Our\nlicenses are irrevocable. Licensors should read and understand the terms\nand conditions of the license they choose before applying it. Licensors\nshould also secure all rights necessary before applying our licenses so\nthat the public can reuse the material as expected. Licensors should\nclearly mark any material not subject to the license. This includes\nother CC-licensed material, or material used under an exception or\nlimitation to copyright. 
More considerations for licensors :\nwiki.creativecommons.org/Considerations_for_licensors\n\nConsiderations for the public: By using one of our public licenses, a\nlicensor grants the public permission to use the licensed material under\nspecified terms and conditions. If the licensor's permission is not\nnecessary for any reason–for example, because of any applicable\nexception or limitation to copyright–then that use is not regulated by\nthe license. Our licenses grant only permissions under copyright and\ncertain other rights that a licensor has authority to grant. Use of the\nlicensed material may still be restricted for other reasons, including\nbecause others have copyright or other rights in the material. A\nlicensor may make special requests, such as asking that all changes be\nmarked or described. Although not required by our licenses, you are\nencouraged to respect those requests where reasonable. More\nconsiderations for the public :\nwiki.creativecommons.org/Considerations_for_licensees\n\nCreative Commons Attribution-NonCommercial-ShareAlike 4.0 International\nPublic License\n\nBy exercising the Licensed Rights (defined below), You accept and agree\nto be bound by the terms and conditions of this Creative Commons\nAttribution-NonCommercial-ShareAlike 4.0 International Public License\n(\"Public License\"). To the extent this Public License may be interpreted\nas a contract, You are granted the Licensed Rights in consideration of\nYour acceptance of these terms and conditions, and the Licensor grants\nYou such rights in consideration of benefits the Licensor receives from\nmaking the Licensed Material available under these terms and conditions.\n\nSection 1 – Definitions.\n\n-   a. 
Adapted Material means material subject to Copyright and Similar\n    Rights that is derived from or based upon the Licensed Material and\n    in which the Licensed Material is translated, altered, arranged,\n    transformed, or otherwise modified in a manner requiring permission\n    under the Copyright and Similar Rights held by the Licensor. For\n    purposes of this Public License, where the Licensed Material is a\n    musical work, performance, or sound recording, Adapted Material is\n    always produced where the Licensed Material is synched in timed\n    relation with a moving image.\n-   b. Adapter's License means the license You apply to Your Copyright\n    and Similar Rights in Your contributions to Adapted Material in\n    accordance with the terms and conditions of this Public License.\n-   c. BY-NC-SA Compatible License means a license listed at\n    creativecommons.org/compatiblelicenses, approved by Creative Commons\n    as essentially the equivalent of this Public License.\n-   d. Copyright and Similar Rights means copyright and/or similar\n    rights closely related to copyright including, without limitation,\n    performance, broadcast, sound recording, and Sui Generis Database\n    Rights, without regard to how the rights are labeled or categorized.\n    For purposes of this Public License, the rights specified in Section\n    2(b)(1)-(2) are not Copyright and Similar Rights.\n-   e. Effective Technological Measures means those measures that, in\n    the absence of proper authority, may not be circumvented under laws\n    fulfilling obligations under Article 11 of the WIPO Copyright Treaty\n    adopted on December 20, 1996, and/or similar international\n    agreements.\n-   f. Exceptions and Limitations means fair use, fair dealing, and/or\n    any other exception or limitation to Copyright and Similar Rights\n    that applies to Your use of the Licensed Material.\n-   g. 
License Elements means the license attributes listed in the name\n    of a Creative Commons Public License. The License Elements of this\n    Public License are Attribution, NonCommercial, and ShareAlike.\n-   h. Licensed Material means the artistic or literary work, database,\n    or other material to which the Licensor applied this Public License.\n-   i. Licensed Rights means the rights granted to You subject to the\n    terms and conditions of this Public License, which are limited to\n    all Copyright and Similar Rights that apply to Your use of the\n    Licensed Material and that the Licensor has authority to license.\n-   j. Licensor means the individual(s) or entity(ies) granting rights\n    under this Public License.\n-   k. NonCommercial means not primarily intended for or directed\n    towards commercial advantage or monetary compensation. For purposes\n    of this Public License, the exchange of the Licensed Material for\n    other material subject to Copyright and Similar Rights by digital\n    file-sharing or similar means is NonCommercial provided there is no\n    payment of monetary compensation in connection with the exchange.\n-   l. Share means to provide material to the public by any means or\n    process that requires permission under the Licensed Rights, such as\n    reproduction, public display, public performance, distribution,\n    dissemination, communication, or importation, and to make material\n    available to the public including in ways that members of the public\n    may access the material from a place and at a time individually\n    chosen by them.\n-   m. Sui Generis Database Rights means rights other than copyright\n    resulting from Directive 96/9/EC of the European Parliament and of\n    the Council of 11 March 1996 on the legal protection of databases,\n    as amended and/or succeeded, as well as other essentially equivalent\n    rights anywhere in the world.\n-   n. 
You means the individual or entity exercising the Licensed Rights\n    under this Public License. Your has a corresponding meaning.\n\nSection 2 – Scope.\n\n-   a. License grant.\n    -   1. Subject to the terms and conditions of this Public License,\n        the Licensor hereby grants You a worldwide, royalty-free,\n        non-sublicensable, non-exclusive, irrevocable license to\n        exercise the Licensed Rights in the Licensed Material to:\n        -   A. reproduce and Share the Licensed Material, in whole or in\n            part, for NonCommercial purposes only; and\n        -   B. produce, reproduce, and Share Adapted Material for\n            NonCommercial purposes only.\n    -   2. Exceptions and Limitations. For the avoidance of doubt, where\n        Exceptions and Limitations apply to Your use, this Public\n        License does not apply, and You do not need to comply with its\n        terms and conditions.\n    -   3. Term. The term of this Public License is specified in Section\n        6(a).\n    -   4. Media and formats; technical modifications allowed. The\n        Licensor authorizes You to exercise the Licensed Rights in all\n        media and formats whether now known or hereafter created, and to\n        make technical modifications necessary to do so. The Licensor\n        waives and/or agrees not to assert any right or authority to\n        forbid You from making technical modifications necessary to\n        exercise the Licensed Rights, including technical modifications\n        necessary to circumvent Effective Technological Measures. For\n        purposes of this Public License, simply making modifications\n        authorized by this Section 2(a)(4) never produces Adapted\n        Material.\n    -   5. Downstream recipients.\n        -   A. Offer from the Licensor – Licensed Material. 
Every\n            recipient of the Licensed Material automatically receives an\n            offer from the Licensor to exercise the Licensed Rights\n            under the terms and conditions of this Public License.\n        -   B. Additional offer from the Licensor – Adapted Material.\n            Every recipient of Adapted Material from You automatically\n            receives an offer from the Licensor to exercise the Licensed\n            Rights in the Adapted Material under the conditions of the\n            Adapter's License You apply.\n        -   C. No downstream restrictions. You may not offer or impose\n            any additional or different terms or conditions on, or apply\n            any Effective Technological Measures to, the Licensed\n            Material if doing so restricts exercise of the Licensed\n            Rights by any recipient of the Licensed Material.\n    -   6. No endorsement. Nothing in this Public License constitutes or\n        may be construed as permission to assert or imply that You are,\n        or that Your use of the Licensed Material is, connected with, or\n        sponsored, endorsed, or granted official status by, the Licensor\n        or others designated to receive attribution as provided in\n        Section 3(a)(1)(A)(i).\n-   b. Other rights.\n    -   1. Moral rights, such as the right of integrity, are not\n        licensed under this Public License, nor are publicity, privacy,\n        and/or other similar personality rights; however, to the extent\n        possible, the Licensor waives and/or agrees not to assert any\n        such rights held by the Licensor to the limited extent necessary\n        to allow You to exercise the Licensed Rights, but not otherwise.\n    -   2. Patent and trademark rights are not licensed under this\n        Public License.\n    -   3. 
To the extent possible, the Licensor waives any right to\n        collect royalties from You for the exercise of the Licensed\n        Rights, whether directly or through a collecting society under\n        any voluntary or waivable statutory or compulsory licensing\n        scheme. In all other cases the Licensor expressly reserves any\n        right to collect such royalties, including when the Licensed\n        Material is used other than for NonCommercial purposes.\n\nSection 3 – License Conditions.\n\nYour exercise of the Licensed Rights is expressly made subject to the\nfollowing conditions.\n\n-   a. Attribution.\n    -   1. If You Share the Licensed Material (including in modified\n        form), You must:\n        -   A. retain the following if it is supplied by the Licensor\n            with the Licensed Material:\n            -   i. identification of the creator(s) of the Licensed\n                Material and any others designated to receive\n                attribution, in any reasonable manner requested by the\n                Licensor (including by pseudonym if designated);\n            -   ii. a copyright notice;\n            -   iii. a notice that refers to this Public License;\n            -   iv. a notice that refers to the disclaimer of\n                warranties;\n            -   v. a URI or hyperlink to the Licensed Material to the\n                extent reasonably practicable;\n\n        -   B. indicate if You modified the Licensed Material and retain\n            an indication of any previous modifications; and\n        -   C. indicate the Licensed Material is licensed under this\n            Public License, and include the text of, or the URI or\n            hyperlink to, this Public License.\n    -   2. You may satisfy the conditions in Section 3(a)(1) in any\n        reasonable manner based on the medium, means, and context in\n        which You Share the Licensed Material. 
For example, it may be\n        reasonable to satisfy the conditions by providing a URI or\n        hyperlink to a resource that includes the required information.\n    -   3. If requested by the Licensor, You must remove any of the\n        information required by Section 3(a)(1)(A) to the extent\n        reasonably practicable.\n-   b. ShareAlike.In addition to the conditions in Section 3(a), if You\n    Share Adapted Material You produce, the following conditions also\n    apply.\n    -   1. The Adapter's License You apply must be a Creative Commons\n        license with the same License Elements, this version or later,\n        or a BY-NC-SA Compatible License.\n    -   2. You must include the text of, or the URI or hyperlink to, the\n        Adapter's License You apply. You may satisfy this condition in\n        any reasonable manner based on the medium, means, and context in\n        which You Share Adapted Material.\n    -   3. You may not offer or impose any additional or different terms\n        or conditions on, or apply any Effective Technological Measures\n        to, Adapted Material that restrict exercise of the rights\n        granted under the Adapter's License You apply.\n\nSection 4 – Sui Generis Database Rights.\n\nWhere the Licensed Rights include Sui Generis Database Rights that apply\nto Your use of the Licensed Material:\n\n-   a. for the avoidance of doubt, Section 2(a)(1) grants You the right\n    to extract, reuse, reproduce, and Share all or a substantial portion\n    of the contents of the database for NonCommercial purposes only;\n-   b. if You include all or a substantial portion of the database\n    contents in a database in which You have Sui Generis Database\n    Rights, then the database in which You have Sui Generis Database\n    Rights (but not its individual contents) is Adapted Material,\n    including for purposes of Section 3(b); and\n-   c. 
You must comply with the conditions in Section 3(a) if You Share\n    all or a substantial portion of the contents of the database.\n    For the avoidance of doubt, this Section 4 supplements and does not\n    replace Your obligations under this Public License where the\n    Licensed Rights include other Copyright and Similar Rights.\n\nSection 5 – Disclaimer of Warranties and Limitation of Liability.\n\n-   a. Unless otherwise separately undertaken by the Licensor, to the\n    extent possible, the Licensor offers the Licensed Material as-is and\n    as-available, and makes no representations or warranties of any kind\n    concerning the Licensed Material, whether express, implied,\n    statutory, or other. This includes, without limitation, warranties\n    of title, merchantability, fitness for a particular purpose,\n    non-infringement, absence of latent or other defects, accuracy, or\n    the presence or absence of errors, whether or not known or\n    discoverable. Where disclaimers of warranties are not allowed in\n    full or in part, this disclaimer may not apply to You.\n-   b. To the extent possible, in no event will the Licensor be liable\n    to You on any legal theory (including, without limitation,\n    negligence) or otherwise for any direct, special, indirect,\n    incidental, consequential, punitive, exemplary, or other losses,\n    costs, expenses, or damages arising out of this Public License or\n    use of the Licensed Material, even if the Licensor has been advised\n    of the possibility of such losses, costs, expenses, or damages.\n    Where a limitation of liability is not allowed in full or in part,\n    this limitation may not apply to You.\n-   c. The disclaimer of warranties and limitation of liability provided\n    above shall be interpreted in a manner that, to the extent possible,\n    most closely approximates an absolute disclaimer and waiver of all\n    liability.\n\nSection 6 – Term and Termination.\n\n-   a. 
This Public License applies for the term of the Copyright and\n    Similar Rights licensed here. However, if You fail to comply with\n    this Public License, then Your rights under this Public License\n    terminate automatically.\n-   b. Where Your right to use the Licensed Material has terminated\n    under Section 6(a), it reinstates:\n\n    -   1. automatically as of the date the violation is cured, provided\n        it is cured within 30 days of Your discovery of the violation;\n        or\n    -   2. upon express reinstatement by the Licensor.\n\n    For the avoidance of doubt, this Section 6(b) does not affect any\n    right the Licensor may have to seek remedies for Your violations of\n    this Public License.\n\n-   c. For the avoidance of doubt, the Licensor may also offer the\n    Licensed Material under separate terms or conditions or stop\n    distributing the Licensed Material at any time; however, doing so\n    will not terminate this Public License.\n-   d. Sections 1, 5, 6, 7, and 8 survive termination of this Public\n    License.\n\nSection 7 – Other Terms and Conditions.\n\n-   a. The Licensor shall not be bound by any additional or different\n    terms or conditions communicated by You unless expressly agreed.\n-   b. Any arrangements, understandings, or agreements regarding the\n    Licensed Material not stated herein are separate from and\n    independent of the terms and conditions of this Public License.\n\nSection 8 – Interpretation.\n\n-   a. For the avoidance of doubt, this Public License does not, and\n    shall not be interpreted to, reduce, limit, restrict, or impose\n    conditions on any use of the Licensed Material that could lawfully\n    be made without permission under this Public License.\n-   b. To the extent possible, if any provision of this Public License\n    is deemed unenforceable, it shall be automatically reformed to the\n    minimum extent necessary to make it enforceable. 
If the provision\n    cannot be reformed, it shall be severed from this Public License\n    without affecting the enforceability of the remaining terms and\n    conditions.\n-   c. No term or condition of this Public License will be waived and no\n    failure to comply consented to unless expressly agreed to by the\n    Licensor.\n-   d. Nothing in this Public License constitutes or may be interpreted\n    as a limitation upon, or waiver of, any privileges and immunities\n    that apply to the Licensor or You, including from the legal\n    processes of any jurisdiction or authority.\n\nCreative Commons is not a party to its public licenses. Notwithstanding,\nCreative Commons may elect to apply one of its public licenses to\nmaterial it publishes and in those instances will be considered the\n\"Licensor.\" The text of the Creative Commons public licenses is\ndedicated to the public domain under the CC0 Public Domain Dedication.\nExcept for the limited purpose of indicating that material is shared\nunder a Creative Commons public license or as otherwise permitted by the\nCreative Commons policies published at creativecommons.org/policies,\nCreative Commons does not authorize the use of the trademark \"Creative\nCommons\" or any other trademark or logo of Creative Commons without its\nprior written consent including, without limitation, in connection with\nany unauthorized modifications to any of its public licenses or any\nother arrangements, understandings, or agreements concerning use of\nlicensed material. For the avoidance of doubt, this paragraph does not\nform part of the public licenses.\n\nCreative Commons may be contacted at creativecommons.org.\n"
  },
  {
    "path": "LaMP/data/datasets.py",
    "content": "from torch.utils.data import Dataset\nimport json\nimport datasets\nimport torch\n\ndef get_all_labels(task):\n    if task == \"LaMP-1\":\n        return [\"[1]\",\"[2]\"]\n    elif task == \"LaMP-2\":\n        return ['sci-fi', 'based on a book', 'comedy', 'action', 'twist ending', 'dystopia', 'dark comedy', 'classic', 'psychology', 'fantasy', 'romance', 'thought-provoking', 'social commentary', 'violence', 'true story']\n    elif task == \"LaMP-3\":\n        return [\"1\", \"2\", \"3\", \"4\", \"5\"]\n    elif task == \"LaMP-4\":\n        return []\n    elif task == \"LaMP-5\":\n        return []\n    elif task == \"LaMP-6\":\n        return []\n    elif task == \"LaMP-7\":\n        return []\n\ndef create_preprocessor(tokenizer, max_length):\n    def preprocess_dataset(examples):\n        inputs = [example for example in examples[\"source\"]]\n        targets = [example for example in examples[\"target\"]]\n        model_inputs = tokenizer(inputs, text_target=targets, max_length=max_length, truncation=True)\n        return model_inputs\n    return preprocess_dataset\n\ndef create_preprocessor_scores(tokenizer, max_length):\n    def preprocess_dataset(examples):\n        inputs = [example for example in examples[\"source\"]]\n        targets = [example for example in examples[\"target\"]]\n        model_inputs = tokenizer(inputs, text_target=targets, max_length=max_length, truncation=True)\n        model_inputs['id_1'] = examples['id_1']\n        model_inputs['id_2'] = examples['id_2']\n        return model_inputs\n    return preprocess_dataset\n\ndef create_preprocessor_scores_seq(tokenizer, max_length):\n    def preprocess_dataset(examples):\n        inputs = [example for example in examples[\"source\"]]\n        targets = [example for example in examples[\"target\"]]\n        model_inputs = tokenizer(inputs, text_target=targets, max_length=max_length, truncation=True)\n        model_inputs['id'] = examples['id']\n        return model_inputs\n   
 return preprocess_dataset\n\ndef convert_to_hf_dataset(dataset, cache_dir):\n    def gen():\n        for idx in range(len(dataset)):\n            yield dataset[idx]\n    return datasets.Dataset.from_generator(gen, cache_dir = cache_dir)\n\nclass GeneralSeq2SeqDataset(Dataset):\n\n    def __init__(self, data_addr, use_profile, task, create_prompt = None) -> None:\n        super().__init__()\n        with open(data_addr) as file:\n            self.data = json.load(file)\n        self.use_profile = use_profile\n        self.task = task\n        assert not (use_profile ^ (create_prompt != None)), \"You should provide a prompt maker function when you use profile\"\n        self.create_prompt = create_prompt\n\n    def __getitem__(self, index):\n        if self.use_profile:\n            return {\n                \"id\" : self.data[index]['id'],\n                \"source\" : self.create_prompt(self.data[index]['input'], self.data[index]['profile'], self.task),\n                \"target\" : self.data[index]['output']\n            }\n        else:\n            return {\n                \"id\" : self.data[index]['id'],\n                \"source\" : self.data[index]['input'],\n                \"target\" : self.data[index]['output']\n            }\n    \n    def __len__(self):\n        return len(self.data)\n\nclass GeneralSeq2SeqForScoreGenerationDataset(Dataset):\n\n    def __init__(self, data_addr, use_profile, task, create_prompt = None, max_prof_size = -1) -> None:\n        super().__init__()\n        with open(data_addr) as file:\n            self.data = json.load(file)\n        self.use_profile = use_profile\n        self.task = task\n        assert not (use_profile ^ (create_prompt != None)), \"You should provide a prompt maker function when you use profile\"\n        self.create_prompt = create_prompt\n        self.max_prof_size = max_prof_size\n        self.size = 0\n        self.index_dict = dict()\n        for i, x in enumerate(self.data):\n            for j, y in 
enumerate(x['profile']):\n                if max_prof_size == -1 or j < self.max_prof_size:\n                    self.index_dict[self.size] = (i, j)\n                    self.size += 1\n\n    def __getitem__(self, index):\n        # Each flat index maps to one (query, profile item) pair, so the prompt\n        # is always built from that single profile item.\n        i, j = self.index_dict[index]\n        return {\n            \"source\" : self.create_prompt(self.data[i]['input'], [self.data[i]['profile'][j]], self.task),\n            \"target\" : self.data[i]['output'],\n            \"id_1\" : self.data[i]['id'],\n            \"id_2\" : self.data[i]['profile'][j]['id']\n        }\n\n    def __len__(self):\n        return self.size"
  },
  {
    "path": "LaMP/evaluate_llm.py",
    "content": "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, AutoModelForCausalLM\n# from transformers.models.llama import LlamaTokenizer\nfrom transformers.data.data_collator import DataCollatorForSeq2Seq\nimport argparse\nfrom metrics.classification_metrics import create_metric_f1_accuracy, create_metric_mae_rmse\nfrom metrics.generation_metrics import create_metric_bleu_rouge_meteor\nfrom data.datasets import get_all_labels, GeneralSeq2SeqDataset, create_preprocessor, convert_to_hf_dataset\nfrom prompts.prompts import create_prompt_generator\nimport json\nimport os\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--validation_data\", required = True)\nparser.add_argument(\"--model_addr\", required = True)\nparser.add_argument(\"--task\", required = True)\nparser.add_argument(\"--output_dir\", required = True)\nparser.add_argument(\"--use_profile\", action = \"store_true\")\nparser.add_argument(\"--max_length\", type = int, default = 256)\nparser.add_argument(\"--generation_max_length\", type = int, default = 128)\nparser.add_argument(\"--per_device_batch_size\", type = int, default = 16)\nparser.add_argument(\"--generation_num_beams\", type = int, default = 4)\nparser.add_argument(\"--num_retrieved\", type = int, default = 1)\nparser.add_argument(\"--retriever\", default = \"bm25\")\nparser.add_argument(\"--is_ranked\", action = \"store_true\")\nparser.add_argument(\"--cache_dir\", default = \"./cache\")\n\n\nif __name__ == \"__main__\":\n\n    opts = parser.parse_args()\n    model = AutoModelForSeq2SeqLM.from_pretrained(opts.model_addr, cache_dir=opts.cache_dir)\n    tokenizer = AutoTokenizer.from_pretrained(opts.model_addr, cache_dir=opts.cache_dir)\n    collator = DataCollatorForSeq2Seq(tokenizer = tokenizer, model = model, max_length = opts.max_length)\n\n    task = opts.task\n    if opts.use_profile:\n        prompt_generator, contriver = create_prompt_generator(opts.num_retrieved, 
opts.retriever, opts.is_ranked, opts.max_length, tokenizer)\n    else:\n        prompt_generator, contriver = None, None\n\n    # Classification tasks (LaMP-1, LaMP-2) use accuracy/F1, rating prediction (LaMP-3) uses MAE/RMSE,\n    # and the text generation tasks (LaMP-4 to LaMP-7) use BLEU/ROUGE/METEOR.\n    if task in (\"LaMP-1\", \"LaMP-2\"):\n        labels = get_all_labels(task)\n        compute_metrics = create_metric_f1_accuracy(tokenizer = tokenizer, all_labels = labels)\n    elif task == \"LaMP-3\":\n        labels = get_all_labels(task)\n        compute_metrics = create_metric_mae_rmse(tokenizer = tokenizer, all_labels = labels)\n    elif task in (\"LaMP-4\", \"LaMP-5\", \"LaMP-6\", \"LaMP-7\"):\n        compute_metrics = create_metric_bleu_rouge_meteor(tokenizer = tokenizer)\n    else:\n        raise ValueError(f\"Unknown task: {task}\")\n\n    eval_dataset = GeneralSeq2SeqDataset(opts.validation_data, opts.use_profile, task, prompt_generator)\n    eval_dataset = convert_to_hf_dataset(eval_dataset, cache_dir = opts.cache_dir).map(create_preprocessor(tokenizer = tokenizer, max_length = opts.max_length), batched=True)
\n\n    # The retriever is only needed to build the prompts above, so it is moved to the CPU before evaluation.\n    if contriver:\n        contriver = contriver.to(\"cpu\")\n\n    training_args = Seq2SeqTrainingArguments(\n        output_dir = opts.output_dir,\n        do_eval = True,\n        per_device_eval_batch_size = opts.per_device_batch_size,\n        generation_num_beams = opts.generation_num_beams,\n        predict_with_generate = True,\n        eval_accumulation_steps = 1,\n        generation_max_length = opts.generation_max_length\n    )\n\n    trainer = Seq2SeqTrainer(\n        model = model,\n        args = training_args,\n        data_collator = collator,\n        eval_dataset = eval_dataset,\n        tokenizer = tokenizer,\n        compute_metrics = compute_metrics\n    )\n    results = trainer.evaluate(eval_dataset)\n    print(results)\n\n    os.makedirs(opts.output_dir, exist_ok = True)\n    with open(os.path.join(opts.output_dir, 'results_output.json'), 'w') as file:\n        json.dump(results, file, indent = 4)"
  },
  {
    "path": "LaMP/metrics/classification_metrics.py",
    "content": "import numpy as np\nimport evaluate\n\ndef postprocess_text(preds, labels):\n    preds = [pred.strip() for pred in preds]\n    labels = [label.strip() for label in labels]\n\n    return preds, labels\n\ndef create_metric_f1_accuracy(tokenizer, all_labels):\n    f1_metric = evaluate.load(\"f1\")\n    accuracy_metric = evaluate.load(\"accuracy\")\n    def create_mapping(x):\n        try:\n            return all_labels.index(x)\n        except ValueError:\n            # Outputs outside the label set are logged and mapped to -1, which is not a valid label index.\n            print(x)\n            return -1\n    def compute_metrics(eval_preds):\n        preds, labels = eval_preds\n        if isinstance(preds, tuple):\n            preds = preds[0]\n        # The Trainer pads generated sequences with -100 when gathering batches; replace it before decoding.\n        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)\n        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)\n        decoded_preds = [create_mapping(x) for x in decoded_preds]\n        decoded_labels = [create_mapping(x) for x in decoded_labels]\n        result_acc = accuracy_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        result_f1 = f1_metric.compute(predictions=decoded_preds, references=decoded_labels, labels=list(range(len(all_labels))), average = \"macro\")\n        result = {\"accuracy\" : result_acc[\"accuracy\"], \"f1\" : result_f1[\"f1\"]}\n        return result\n    return compute_metrics\n\ndef create_metric_f1_accuracy_bert(all_labels):\n    f1_metric = evaluate.load(\"f1\")\n    accuracy_metric = evaluate.load(\"accuracy\")\n    def compute_metrics(eval_preds):\n        preds, labels = eval_preds\n        preds = np.argmax(preds, axis=1)\n        result_acc = accuracy_metric.compute(predictions=preds, references=labels)\n        result_f1 = f1_metric.compute(predictions=preds, references=labels, labels=list(range(len(all_labels))), average = \"macro\")\n        result = {\"accuracy\" : result_acc[\"accuracy\"], \"f1\" : result_f1[\"f1\"]}
\n        return result\n    return compute_metrics\n\ndef create_metric_mae_rmse_bert(all_labels):\n    mse_metric = evaluate.load(\"mse\")\n    mae_metric = evaluate.load(\"mae\")\n    def compute_metrics(eval_preds):\n        preds, labels = eval_preds\n        preds = np.argmax(preds, axis=1)\n        result_mae = mae_metric.compute(predictions=preds, references=labels)\n        result_rmse = mse_metric.compute(predictions=preds, references=labels, squared = False)\n        result = {\"mae\" : result_mae[\"mae\"], \"rmse\" : result_rmse[\"mse\"]}\n        return result\n    return compute_metrics\n\ndef create_metric_mae_rmse(tokenizer, all_labels):\n    mse_metric = evaluate.load(\"mse\")\n    mae_metric = evaluate.load(\"mae\")\n    def create_mapping(x, y):\n        try:\n            return float(x)\n        except ValueError:\n            # Unparseable prediction: log it and return whichever extreme rating (1 or 5)\n            # is farther from the reference y, i.e. the worst-case error.\n            print(x)\n            y = float(y)\n            if abs(1 - y) > abs(5 - y):\n                return 1.0\n            else:\n                return 5.0\n    def compute_metrics(eval_preds):\n        preds, labels = eval_preds\n        if isinstance(preds, tuple):\n            preds = preds[0]\n        # The Trainer pads generated sequences with -100 when gathering batches; replace it before decoding.\n        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)\n        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)\n        decoded_preds = [create_mapping(x,y) for x,y in zip(decoded_preds, decoded_labels)]\n        decoded_labels = [create_mapping(x,x) for x in decoded_labels]\n        result_mae = mae_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        # With squared=False, the \"mse\" metric returns the root mean squared error.\n        result_rmse = mse_metric.compute(predictions=decoded_preds, references=decoded_labels, squared = False)\n        result = {\"mae\" : result_mae[\"mae\"], \"rmse\" : result_rmse[\"mse\"]}\n     
   return result\n    return compute_metrics\n\n\ndef create_metric_f1_accuracy_chatgpt(all_labels):\n    f1_metric = evaluate.load(\"f1\")\n    accuracy_metric = evaluate.load(\"accuracy\")\n    def create_mapping(x):\n        try:\n            return all_labels.index(x)\n        except:\n            print(x)\n            return -1\n    def compute_metrics(decoded_preds, decoded_labels):\n        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)\n        decoded_preds = [create_mapping(x) for x in decoded_preds]\n        decoded_labels = [create_mapping(x) for x in decoded_labels]\n        result_acc = accuracy_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        result_f1 = f1_metric.compute(predictions=decoded_preds, references=decoded_labels, labels=list(range(len(all_labels))), average = \"macro\")\n        result = {\"accuracy\" : result_acc[\"accuracy\"], \"f1\" : result_f1[\"f1\"]}\n        return result\n    return compute_metrics\n\ndef create_metric_mae_rmse_chatgpt(all_labels):\n    mse_metric = evaluate.load(\"mse\")\n    mae_metric = evaluate.load(\"mae\")\n    def create_mapping(x, y):\n        try:\n            return float(x)\n        except:\n            print(x)\n            y = float(y)\n            if abs(1 - y) > abs(5 - y):\n                return 1.0\n            else:\n                return 5.0\n    def compute_metrics(decoded_preds, decoded_labels):\n        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)\n        decoded_preds = [create_mapping(x,y) for x,y in zip(decoded_preds, decoded_labels)]\n        decoded_labels = [create_mapping(x,x) for x in decoded_labels]\n        result_mae = mae_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        result_rmse = mse_metric.compute(predictions=decoded_preds, references=decoded_labels, squared = False)\n        result = {\"mae\" : result_mae[\"mae\"], \"rmse\" : 
result_rmse[\"mse\"]}\n        return result\n    return compute_metrics"
  },
  {
    "path": "LaMP/metrics/generation_metrics.py",
    "content": "import numpy as np\nimport evaluate\n\ndef postprocess_text(preds, labels):\n    preds = [pred.strip() for pred in preds]\n    labels = [[label.strip()] for label in labels]\n\n    return preds, labels\n\ndef create_metric_bleu_rouge_meteor(tokenizer):\n    bleu_metric = evaluate.load(\"sacrebleu\")\n    rouge_metric = evaluate.load('rouge')\n    meteor_metric = evaluate.load('meteor')\n    def compute_metrics(eval_preds):\n        preds, labels = eval_preds\n        if isinstance(preds, tuple):\n            preds = preds[0]\n        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)\n        result_bleu = bleu_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        result_rouge = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        result_meteor = meteor_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        result = {\"bleu\" : result_bleu[\"score\"], \"rouge-1\" : result_rouge[\"rouge1\"], \"rouge-2\" : result_rouge[\"rouge2\"], \"rouge-L\" : result_rouge[\"rougeL\"], \"rouge-LSum\" : result_rouge[\"rougeLsum\"], \"meteor\" : result_meteor['meteor']}\n        return result\n    return compute_metrics\n\ndef create_metric_bleu_rouge_meteor_chatgpt():\n    bleu_metric = evaluate.load(\"sacrebleu\")\n    rouge_metric = evaluate.load('rouge')\n    meteor_metric = evaluate.load('meteor')\n    def compute_metrics(decoded_preds, decoded_labels):\n        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)\n        result_bleu = bleu_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        result_rouge = rouge_metric.compute(predictions=decoded_preds, 
references=decoded_labels)\n        result_meteor = meteor_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        result = {\"bleu\" : result_bleu[\"score\"], \"rouge-1\" : result_rouge[\"rouge1\"], \"rouge-2\" : result_rouge[\"rouge2\"], \"rouge-L\" : result_rouge[\"rougeL\"], \"rouge-LSum\" : result_rouge[\"rougeLsum\"], \"meteor\" : result_meteor['meteor']}\n        return result\n    return compute_metrics\n"
  },
  {
    "path": "LaMP/profile_item_utilization_scorer.py",
    "content": "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, AutoModelForCausalLM\n# from transformers.models.llama import LlamaTokenizer\nfrom transformers.data.data_collator import DataCollatorForSeq2Seq\nimport argparse\nfrom metrics.classification_metrics import create_metric_f1_accuracy, create_metric_mae_rmse\nfrom metrics.generation_metrics import create_metric_bleu_rouge_meteor\nfrom data.datasets import get_all_labels, GeneralSeq2SeqForScoreGenerationDataset, create_preprocessor_scores, convert_to_hf_dataset\nfrom prompts.prompts import create_prompt_generator\nimport tqdm\nimport datasets\nimport os\nimport json\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--train_data\", required = True)\nparser.add_argument(\"--model_name\", required = True)\nparser.add_argument(\"--task\", required = True)\nparser.add_argument(\"--output_dir\", required = True)\nparser.add_argument(\"--max_length\", type = int, default = 512)\nparser.add_argument(\"--generation_max_length\", type = int, default = 128)\nparser.add_argument(\"--per_device_batch_size\", type = int, default = 16)\nparser.add_argument(\"--generation_num_beams\", type = int, default = 4)\nparser.add_argument(\"--cache_dir\", default = \"./cache\")\nparser.add_argument(\"--start_index\", type = int, default=0)\nparser.add_argument(\"--end_index\", type = int, default=-1)\nparser.add_argument(\"--profile_size\", type = int,required = True)\n\n\nif __name__ == \"__main__\":\n\n    opts = parser.parse_args()\n    model = AutoModelForSeq2SeqLM.from_pretrained(opts.model_name, cache_dir=opts.cache_dir)\n    tokenizer = AutoTokenizer.from_pretrained(opts.model_name, cache_dir=opts.cache_dir)\n    collator = DataCollatorForSeq2Seq(tokenizer = tokenizer, model = model, max_length = opts.max_length)\n\n    task = opts.task\n    prompt_generator, contriver = create_prompt_generator(1, \"bm25\", True, opts.max_length, tokenizer)\n    \n    if 
task == \"LaMP-1\":\n        labels = get_all_labels(task)\n        eval_dataset = GeneralSeq2SeqForScoreGenerationDataset(opts.train_data, True, task, prompt_generator, opts.profile_size)\n        compute_metrics = create_metric_f1_accuracy(tokenizer = tokenizer, all_labels = labels)\n    elif task == \"LaMP-2\":\n        labels = get_all_labels(task)\n        eval_dataset = GeneralSeq2SeqForScoreGenerationDataset(opts.train_data, True, task, prompt_generator, opts.profile_size)\n        compute_metrics = create_metric_f1_accuracy(tokenizer = tokenizer, all_labels = labels)\n    elif task == \"LaMP-3\":\n        labels = get_all_labels(task)\n        eval_dataset = GeneralSeq2SeqForScoreGenerationDataset(opts.train_data, True, task, prompt_generator, opts.profile_size)\n        compute_metrics = create_metric_mae_rmse(tokenizer = tokenizer, all_labels = labels)\n    elif task == \"LaMP-4\":\n        eval_dataset = GeneralSeq2SeqForScoreGenerationDataset(opts.train_data, True, task, prompt_generator, opts.profile_size)\n        compute_metrics = create_metric_bleu_rouge_meteor(tokenizer = tokenizer)\n    elif task == \"LaMP-5\":\n        eval_dataset = GeneralSeq2SeqForScoreGenerationDataset(opts.train_data, True, task, prompt_generator, opts.profile_size)\n        compute_metrics = create_metric_bleu_rouge_meteor(tokenizer = tokenizer)\n    elif task == \"LaMP-7\":\n        eval_dataset = GeneralSeq2SeqForScoreGenerationDataset(opts.train_data, True, task, prompt_generator, opts.profile_size)\n        compute_metrics = create_metric_bleu_rouge_meteor(tokenizer = tokenizer)\n    elif task == \"LaMP-6\":\n        eval_dataset = GeneralSeq2SeqForScoreGenerationDataset(opts.train_data, True, task, prompt_generator, opts.profile_size)\n        compute_metrics = create_metric_bleu_rouge_meteor(tokenizer = tokenizer)\n    \n    eval_dataset = convert_to_hf_dataset(eval_dataset, opts.cache_dir).map(create_preprocessor_scores(tokenizer = tokenizer, max_length = 
opts.max_length), batched=True)\n\n    if contriver:\n        contriver = contriver.to(\"cpu\")\n\n    training_args = Seq2SeqTrainingArguments(\n        output_dir = opts.output_dir,\n        do_eval = True,\n        per_device_eval_batch_size = 1,\n        generation_num_beams = opts.generation_num_beams,\n        predict_with_generate = True,\n        eval_accumulation_steps = 1,\n        generation_max_length = opts.generation_max_length\n    )\n\n    trainer = Seq2SeqTrainer(\n        model = model,\n        args = training_args,\n        data_collator = collator,\n        eval_dataset = eval_dataset,\n        tokenizer = tokenizer,\n        compute_metrics = compute_metrics\n    )\n\n    results_dict = dict()\n\n    for i, x in enumerate(tqdm.tqdm(eval_dataset)):\n        if i < opts.start_index:\n            continue\n        if i >= opts.end_index and opts.end_index != -1:\n            break\n        metrics = trainer.predict(datasets.Dataset.from_list([x])).metrics\n        results_dict[f\"{x['id_1']}-{x['id_2']}\"] = {k.replace(\"test_\", '') : v for k,v in metrics.items()}\n    \n    with open(os.path.join(opts.output_dir, f\"scores_{opts.start_index}_{opts.end_index}.json\"), \"w\") as file:\n        json.dump(results_dict, file, indent = 4)"
  },
  {
    "path": "LaMP/prompts/contriever_retriever.py",
    "content": "import torch\nfrom prompts.utils import batchify\n\ndef mean_pooling(token_embeddings, mask):\n    # Zero out padded positions, then average the remaining token embeddings.\n    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)\n    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]\n    return sentence_embeddings\n\n@torch.no_grad()  # retrieval is inference only, so skip building the autograd graph\ndef retrieve_top_k_with_contriever(contriver, tokenizer, corpus, profile, query, k):\n    # Embed the query once with mean pooling.\n    query_tokens = tokenizer([query], padding=True, truncation=True, return_tensors='pt').to(\"cuda:0\")\n    output_query = contriver(**query_tokens)\n    output_query = mean_pooling(output_query.last_hidden_state, query_tokens['attention_mask'])\n    # Score the corpus in small batches by dot product against the query embedding.\n    batch_size = 4\n    scores = []\n    batched_corpus = batchify(corpus, batch_size)\n    for batch in batched_corpus:\n        tokens_batch = tokenizer(batch, padding=True, truncation=True, return_tensors='pt').to(\"cuda:0\")\n        outputs_batch = contriver(**tokens_batch)\n        outputs_batch = mean_pooling(outputs_batch.last_hidden_state, tokens_batch['attention_mask'])\n        temp_scores = output_query.squeeze() @ outputs_batch.T\n        scores.extend(temp_scores.tolist())\n    # torch.topk raises an error if k exceeds the number of profile items.\n    k = min(k, len(scores))\n    topk_values, topk_indices = torch.topk(torch.tensor(scores), k)\n    return [profile[m] for m in topk_indices.tolist()]\n"
  },
  {
    "path": "LaMP/prompts/prompts.py",
    "content": "from rank_bm25 import BM25Okapi\nfrom transformers import AutoTokenizer, AutoModel\nfrom prompts.utils import extract_strings_between_quotes, extract_after_article, extract_after_review, extract_after_paper, add_string_after_title, extract_after_colon, extract_after_description, extract_after_abstract\nfrom prompts.contriever_retriever import retrieve_top_k_with_contriever\nimport random\n\ndef classification_citation_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"title\"]} {x[\"abstract\"]}' for x in profile]\n    extracted = extract_strings_between_quotes(inp)\n    query = f'{extracted[1]} {extracted[2]}'\n    return corpus, query\n\ndef classification_news_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"title\"]} {x[\"text\"]}' for x in profile]\n    query = extract_after_article(inp)\n    return corpus, query\n\ndef classification_movies_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"description\"]}' for x in profile]\n    query = extract_after_description(inp)\n    return corpus, query\n\ndef classification_review_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"text\"]}' for x in profile]\n    query = extract_after_review(inp)\n    return corpus, query\n\ndef generation_news_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"title\"]} {x[\"text\"]}' for x in profile]\n    query = extract_after_article(inp)\n    return corpus, query\n\ndef generation_paper_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"title\"]} {x[\"abstract\"]}' for x in profile]\n    query = extract_after_paper(inp)\n    return corpus, query\n\ndef generation_paper_long_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"title\"]} {x[\"abstract\"]}' for x in profile]\n    query = extract_after_abstract(inp)\n    return corpus, query\n\n\ndef parphrase_tweet_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"text\"]}' for x in profile]\n    query = extract_after_colon(inp)\n    return corpus, query\n\ndef 
generation_avocado_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"text\"]}' for x in profile]\n    query = extract_after_colon(inp)\n    return corpus, query\n\ndef generation_avocado_long_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"text\"]} {x[\"title\"]}' for x in profile]\n    query = extract_after_colon(inp)\n    return corpus, query\n\ndef create_classification_citation_prompt(inp, profile, max_length, tokenizer):\n    prompts = []\n    per_p_max_length = (max_length - 2 * (len(profile) - 1)) // len(profile)\n    saved_tokens = 0\n    for p in profile:\n        tokens = tokenizer(p[\"title\"], max_length=per_p_max_length + saved_tokens - 2, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - 2\n        new_title = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{new_title}\"'\n        prompts.append(prompt)\n    return add_string_after_title(inp, \", and \".join(prompts))\n\ndef create_classification_news_prompt(inp, profile, max_length, tokenizer): # good\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1)) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'the category for the article: \" \" is \"{p[\"category\"]}\" ')['input_ids'])\n        tokens = tokenizer(p[\"text\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_text = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'the category for the article: \"{new_text}\" is \"{p[\"category\"]}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. 
{inp}'\n\ndef create_classification_movies_prompt(inp, profile, max_length, tokenizer): # good\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1)) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'the tag for the movie: \" \" is \"{p[\"tag\"]}\" ')['input_ids'])\n        tokens = tokenizer(p[\"description\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_text = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'the tag for the movie: \"{new_text}\" is \"{p[\"tag\"]}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. {inp}'\n\ndef create_classification_review_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1)) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'{p[\"score\"]} is the score for \" \" ')['input_ids'])\n        tokens = tokenizer(p[\"text\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_text = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'{p[\"score\"]} is the score for \"{new_text}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. 
{inp}'\n\ndef create_generation_news_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1)) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'\"{p[\"title\"]}\" is the title for \" \" ')['input_ids'])\n        tokens = tokenizer(p[\"text\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_text = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{p[\"title\"]}\" is the title for \"{new_text}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. {inp}'\n\ndef create_generation_paper_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1) - len(tokenizer(\"Following the given patterns\")['input_ids'])) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'\"{p[\"title\"]}\" is a title for \" \" ')['input_ids'])\n        tokens = tokenizer(p[\"abstract\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_abstract = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{p[\"title\"]}\" is a title for \"{new_abstract}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. 
Following the given patterns {inp}'\n\ndef create_generation_paper_long_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1) - len(tokenizer(\"Following the given patterns\")['input_ids'])) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'\"{p[\"title\"]}\" is the title for \" \" ')['input_ids'])\n        tokens = tokenizer(p[\"abstract\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_abstract = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{p[\"title\"]}\" is the title for \"{new_abstract}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. Following the given patterns {inp}'\n\n\ndef create_parphrase_tweet_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1) - len(tokenizer(\"are written by a person. Following the given patterns\")['input_ids'])) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'\"\" ')['input_ids'])\n        tokens = tokenizer(p[\"text\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_text = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{new_text}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)} are written by a person. 
Following the given patterns {inp}'\n\ndef create_generation_avocado_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1)) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'\"{p[\"title\"]}\" is the title for \" \" ')['input_ids'])\n        tokens = tokenizer(p[\"text\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_text = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{p[\"title\"]}\" is the title for \"{new_text}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. {inp}'\n\ndef create_generation_avocado_long_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1) - len(tokenizer(\"are written by user. Following the given patterns\")['input_ids'])) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'\"{p[\"title\"]}\" is the title for \" \" ')['input_ids'])\n        tokens = tokenizer(p[\"text\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_text = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{p[\"title\"]}\" is the title for \"{new_text}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. 
Following the given patterns {inp}'\n\ndef create_prompt_generator(num_retrieve, ret_type = \"bm25\", is_ranked = False, max_length = 512, tokenizer = None):\n    contriever = None\n    if ret_type == \"contriever\" and not is_ranked:\n        tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')\n        contriever = AutoModel.from_pretrained('facebook/contriever').to(\"cuda:0\")\n        contriever.eval()\n\n    def prompt(inp, profile, task):\n        if task == \"LaMP-1\":\n            corpus, query = classification_citation_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-2-old\":\n            corpus, query = classification_news_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-2\":\n            corpus, query = classification_movies_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-3\":\n            corpus, query = classification_review_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-4\":\n            corpus, query = generation_news_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-5\":\n            corpus, query = generation_paper_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-7\":\n            corpus, query = parphrase_tweet_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-6\":\n            corpus, query = generation_avocado_query_corpus_maker(inp, profile)\n        \n        if not is_ranked:\n            if ret_type == \"bm25\":\n                tokenized_corpus = [x.split() for x in corpus]\n                bm25 = BM25Okapi(tokenized_corpus)\n                tokenized_query = query.split()\n                selected_profs = bm25.get_top_n(tokenized_query, profile, n=num_retrieve)\n            elif ret_type == \"contriever\":\n                selected_profs = retrieve_top_k_with_contriever(contriever, tokenizer, corpus, profile, query, num_retrieve)\n            elif ret_type == \"random\":\n                selected_profs = random.choices(profile, k = 
num_retrieve)\n            elif ret_type == \"recency\":\n                profile = sorted(profile, key=lambda x: tuple(map(int, str(x['date']).split(\"-\"))))\n                selected_profs = profile[-num_retrieve:][::-1]\n        else:\n            if ret_type == \"recency_contriever\":\n                selected_profs_cont = profile[:num_retrieve // 2]\n                profile = sorted(profile, key=lambda x: tuple(map(int, str(x['date']).split(\"-\"))))\n                selected_profs_rec = profile[-(num_retrieve // 2):][::-1]\n                selected_profs = selected_profs_cont + selected_profs_rec\n            else:\n                selected_profs_cont = profile[:num_retrieve]\n                selected_profs = selected_profs_cont\n        factor = 0.6\n        while True:\n            try:\n                max_len_prompt = max_length - min(len(tokenizer(inp)['input_ids']), int(factor * max_length))\n                if task == \"LaMP-1\":\n                    return create_classification_citation_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n                elif task == \"LaMP-2-old\":\n                    return create_classification_news_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n                elif task == \"LaMP-2\":\n                    return create_classification_movies_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n                elif task == \"LaMP-3\":\n                    return create_classification_review_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n                elif task == \"LaMP-4\":\n                    return create_generation_news_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n                elif task == \"LaMP-5\":\n                    return create_generation_paper_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n                elif task == \"LaMP-7\":\n                    return create_parphrase_tweet_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n                
elif task == \"LaMP-6\":\n                    return create_generation_avocado_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n            except:\n                factor -= 0.1\n                if factor < 0:\n                    print(\"not possible\")\n                    return inp\n    return prompt, contriever"
  },
  {
    "path": "LaMP/prompts/utils.py",
    "content": "def extract_strings_between_quotes(input_string):\n    output_list = []\n    inside_quotes = False\n    current_string = ''\n    \n    for char in input_string:\n        if char == '\"' and not inside_quotes:\n            inside_quotes = True\n        elif char == '\"' and inside_quotes:\n            inside_quotes = False\n            output_list.append(current_string)\n            current_string = ''\n        elif inside_quotes:\n            current_string += char\n    \n    return output_list\n\ndef extract_after_article(input_string):\n    article_index = input_string.find('article:')\n    if article_index == -1:\n        return None\n    return input_string[article_index + len('article:'):].strip()\n\ndef extract_after_description(input_string):\n    article_index = input_string.find('description:')\n    if article_index == -1:\n        return None\n    return input_string[article_index + len('description:'):].strip()\n\n\ndef extract_after_review(input_string):\n    article_index = input_string.find('review:')\n    if article_index == -1:\n        return None\n    return input_string[article_index + len('review:'):].strip()\n\ndef extract_after_paper(input_string):\n    article_index = input_string.find('paper:')\n    if article_index == -1:\n        return None\n    return input_string[article_index + len('paper:'):].strip()\n\ndef extract_after_abstract(input_string):\n    article_index = input_string.find('abstract:')\n    if article_index == -1:\n        return None\n    return input_string[article_index + len('abstract:'):].strip()\n\ndef extract_after_colon(input_string):\n    article_index = input_string.find(':')\n    if article_index == -1:\n        return None\n    return input_string[article_index + len(':'):].strip()\n\n\ndef add_string_after_title(original_string, string_to_add):\n    title_index = original_string.find(\"title\")\n    \n    if title_index == -1:\n        return original_string\n    \n    return 
original_string[:title_index+5] + \", and \" + string_to_add + original_string[title_index+5:]\n\ndef batchify(lst, batch_size):\n    return [lst[i:i+batch_size] for i in range(0, len(lst), batch_size)]"
  },
  {
    "path": "LaMP/rank_profiles.py",
    "content": "import torch\nfrom prompts.utils import batchify\nfrom transformers import AutoModel, AutoTokenizer\nimport json\nimport tqdm\nfrom prompts.utils import extract_strings_between_quotes, extract_after_article, extract_after_review, extract_after_paper, add_string_after_title, extract_after_colon, extract_after_abstract, extract_after_description\nfrom rank_bm25 import BM25Okapi\nimport argparse\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--input_data_addr\", required = True)\nparser.add_argument(\"--output_ranking_addr\", required = True)\nparser.add_argument(\"--task\", required = True)\nparser.add_argument(\"--ranker\", required = True)\nparser.add_argument(\"--batch_size\", type = int, default=16)\nparser.add_argument(\"--use_date\", action='store_true')\nparser.add_argument(\"--contriever_checkpoint\", default=\"facebook/contriever\")\n\n\ndef mean_pooling(token_embeddings, mask):\n    # Zero out padded positions, then average the remaining token embeddings.\n    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)\n    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]\n    return sentence_embeddings\n\n@torch.no_grad()  # ranking is inference only, so skip building the autograd graph\ndef retrieve_top_k_with_contriver(contriver, tokenizer, corpus, profile, query, k, batch_size = 16):\n    query_tokens = tokenizer([query], padding=True, truncation=True, return_tensors='pt').to(\"cuda:0\")\n    output_query = contriver(**query_tokens)\n    output_query = mean_pooling(output_query.last_hidden_state, query_tokens['attention_mask'])\n    scores = []\n    batched_corpus = batchify(corpus, batch_size)\n    for batch in batched_corpus:\n        tokens_batch = tokenizer(batch, padding=True, truncation=True, return_tensors='pt').to(\"cuda:0\")\n        outputs_batch = contriver(**tokens_batch)\n        outputs_batch = mean_pooling(outputs_batch.last_hidden_state, tokens_batch['attention_mask'])\n        temp_scores = output_query.squeeze() @ outputs_batch.T\n        scores.extend(temp_scores.tolist())\n    topk_values, topk_indices = 
torch.topk(torch.tensor(scores), k)\n    return [profile[m] for m in topk_indices.tolist()]\n\ndef retrieve_top_k_with_bm25(corpus, profile, query, k):\n    tokenized_corpus = [x.split() for x in corpus]\n    bm25 = BM25Okapi(tokenized_corpus)\n    tokenized_query = query.split()\n    selected_profs = bm25.get_top_n(tokenized_query, profile, n=k)\n    return selected_profs\n\ndef classification_citation_query_corpus_maker(inp, profile, use_date):\n    if use_date:\n        corpus = [f'{x[\"title\"]} {x[\"abstract\"]} date: {x[\"date\"]}' for x in profile]\n    else:\n        corpus = [f'{x[\"title\"]} {x[\"abstract\"]}' for x in profile]\n    ids = [x['id'] for x in profile]\n    extracted = extract_strings_between_quotes(inp)\n    query = f'{extracted[1]} {extracted[2]}'\n    return corpus, query, ids\n\ndef classification_review_query_corpus_maker(inp, profile, use_date):\n    if use_date:\n        corpus = [f'{x[\"text\"]} date: {x[\"date\"]}' for x in profile]\n    else:\n        corpus = [f'{x[\"text\"]}' for x in profile]\n    ids = [x['id'] for x in profile]\n    query = extract_after_review(inp)\n    return corpus, query, ids\n\ndef generation_news_query_corpus_maker(inp, profile, use_date):\n    if use_date:\n        corpus = [f'{x[\"title\"]} {x[\"text\"]} date: {x[\"date\"]}' for x in profile]\n    else:\n        corpus = [f'{x[\"title\"]} {x[\"text\"]}' for x in profile]\n    ids = [x['id'] for x in profile]\n    query = extract_after_article(inp)\n    return corpus, query, ids\n\ndef generation_paper_query_corpus_maker(inp, profile, use_date):\n    if use_date:\n        corpus = [f'{x[\"title\"]} {x[\"abstract\"]} date: {x[\"date\"]}' for x in profile]\n    else:\n        corpus = [f'{x[\"title\"]} {x[\"abstract\"]}' for x in profile]\n    ids = [x['id'] for x in profile]\n    query = extract_after_colon(inp)\n    return corpus, query, ids\n\ndef parphrase_tweet_query_corpus_maker(inp, profile, use_date):\n    if use_date:\n        corpus = 
[f'{x[\"text\"]} date: {x[\"date\"]}' for x in profile]\n    else:\n        corpus = [f'{x[\"text\"]}' for x in profile]\n    query = extract_after_colon(inp)\n    ids = [x['id'] for x in profile]\n    return corpus, query, ids\n\ndef generation_avocado_query_corpus_maker(inp, profile, use_date):\n    if use_date:\n        corpus = [f'{x[\"text\"]} date: {x[\"date\"]}' for x in profile]\n    else:\n        corpus = [f'{x[\"text\"]}' for x in profile]\n    ids = [x['id'] for x in profile]\n    query = extract_after_colon(inp)\n    return corpus, query, ids\n\ndef classification_movies_query_corpus_maker(inp, profile, use_date):\n    if use_date:\n        corpus = [f'{x[\"description\"]} date: {x[\"date\"]}' for x in profile]\n    else:\n        corpus = [f'{x[\"description\"]}' for x in profile]\n    query = extract_after_description(inp)\n    ids = [x['id'] for x in profile]\n    return corpus, query, ids\n\n\nif __name__ == \"__main__\":\n\n    opts = parser.parse_args()\n    task = opts.task\n    ranker = opts.ranker\n\n    with open(opts.input_data_addr) as file:\n        dataset = json.load(file)\n    \n    rank_dict = dict()\n\n    if ranker == \"contriever\":\n        # load the Contriever model once instead of re-loading it for every example\n        tokenizer = AutoTokenizer.from_pretrained(opts.contriever_checkpoint)\n        contriver = AutoModel.from_pretrained(opts.contriever_checkpoint).to(\"cuda:0\")\n        contriver.eval()\n\n    for data in tqdm.tqdm(dataset):\n        inp = data['input']\n        profile = data['profile']\n        if task == \"LaMP-1\":\n            corpus, query, ids = classification_citation_query_corpus_maker(inp, profile, opts.use_date)\n        elif task == \"LaMP-3\":\n            corpus, query, ids = classification_review_query_corpus_maker(inp, profile, opts.use_date)\n        elif task == \"LaMP-2\":\n            corpus, query, ids = classification_movies_query_corpus_maker(inp, profile, opts.use_date)\n        elif task == \"LaMP-4\":\n            corpus, query, ids = generation_news_query_corpus_maker(inp, profile, opts.use_date)\n        elif task == \"LaMP-5\":\n            corpus, query, ids = generation_paper_query_corpus_maker(inp, profile, opts.use_date)\n        elif task == \"LaMP-7\":\n            corpus, query, ids = 
parphrase_tweet_query_corpus_maker(inp, profile, opts.use_date)\n        elif task == \"LaMP-6\":\n            corpus, query, ids = generation_avocado_query_corpus_maker(inp, profile, opts.use_date)\n        else:\n            raise ValueError(f\"unsupported task: {task}\")\n\n        if ranker == \"contriever\":\n            # inference only: disable autograd to save GPU memory\n            with torch.no_grad():\n                ranked_profile = retrieve_top_k_with_contriver(contriver, tokenizer, corpus, profile, query, len(profile), opts.batch_size)\n        elif ranker == \"bm25\":\n            ranked_profile = retrieve_top_k_with_bm25(corpus, profile, query, len(profile))\n        elif ranker == \"recency\":\n            profile = sorted(profile, key=lambda x: tuple(map(int, str(x['date']).split(\"-\"))))\n            ranked_profile = profile[::-1]\n        else:\n            raise ValueError(f\"unsupported ranker: {ranker}\")\n\n        data['profile'] = ranked_profile\n\n        rank_dict[data['id']] = [x['id'] for x in ranked_profile]\n\n    \n    with open(opts.output_ranking_addr, \"w\") as file:\n        json.dump(rank_dict, file)"
  },
  {
    "path": "LaMP/requirements.txt",
    "content": "mail_parser==3.15.0\nnumpy==1.24.2\nrank_bm25==0.2.2\ntorch==2.0.0\ntqdm==4.65.0\ntransformers==4.27.1\n"
  },
  {
    "path": "LaMP/retriever_utilization_scorer.py",
    "content": "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, AutoModelForCausalLM\n# from transformers.models.llama import LlamaTokenizer\nfrom transformers.data.data_collator import DataCollatorForSeq2Seq\nimport argparse\nfrom metrics.classification_metrics import create_metric_f1_accuracy, create_metric_mae_rmse\nfrom metrics.generation_metrics import create_metric_bleu_rouge_meteor\nfrom data.datasets import get_all_labels, GeneralSeq2SeqDataset, create_preprocessor_scores_seq, convert_to_hf_dataset\nfrom prompts.prompts import create_prompt_generator\nimport tqdm\nimport datasets\nimport os\nimport json\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--data_addr\", required = True)\nparser.add_argument(\"--model_name\", required = True)\nparser.add_argument(\"--task\", required = True)\nparser.add_argument(\"--output_dir\", required = True)\nparser.add_argument(\"--use_profile\", action = \"store_true\")\nparser.add_argument(\"--max_length\", type = int, default = 256)\nparser.add_argument(\"--generation_max_length\", type = int, default = 128)\nparser.add_argument(\"--generation_num_beams\", type = int, default = 4)\nparser.add_argument(\"--num_retrieved\", type = int, default = 4)\nparser.add_argument(\"--retriever\", default = \"bm25\")\nparser.add_argument(\"--is_ranked\", action = \"store_true\")\nparser.add_argument(\"--cache_dir\", default = \"./cache\")\n\n\n\nif __name__ == \"__main__\":\n\n    opts = parser.parse_args()\n    model = AutoModelForSeq2SeqLM.from_pretrained(opts.model_name, cache_dir=opts.cache_dir)\n    tokenizer = AutoTokenizer.from_pretrained(opts.model_name, cache_dir=opts.cache_dir)\n    collator = DataCollatorForSeq2Seq(tokenizer = tokenizer, model = model, max_length = opts.max_length)\n\n    task = opts.task\n    if opts.use_profile:\n        prompt_generator, contriver = create_prompt_generator(opts.num_retrieved, opts.retriever, opts.is_ranked, 
opts.max_length, tokenizer)\n    else:\n        prompt_generator, contriver = None, None\n\n    if task == \"LaMP-1\":\n        labels = get_all_labels(task)\n        eval_dataset = GeneralSeq2SeqDataset(opts.data_addr, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_f1_accuracy(tokenizer = tokenizer, all_labels = labels)\n    elif task == \"LaMP-2\":\n        labels = get_all_labels(task)\n        eval_dataset = GeneralSeq2SeqDataset(opts.data_addr, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_f1_accuracy(tokenizer = tokenizer, all_labels = labels)\n    elif task == \"LaMP-3\":\n        labels = get_all_labels(task)\n        eval_dataset = GeneralSeq2SeqDataset(opts.data_addr, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_mae_rmse(tokenizer = tokenizer, all_labels = labels)\n    elif task == \"LaMP-4\":\n        eval_dataset = GeneralSeq2SeqDataset(opts.data_addr, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_bleu_rouge_meteor(tokenizer = tokenizer)\n    elif task == \"LaMP-5\":\n        eval_dataset = GeneralSeq2SeqDataset(opts.data_addr, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_bleu_rouge_meteor(tokenizer = tokenizer)\n    elif task == \"LaMP-7\":\n        eval_dataset = GeneralSeq2SeqDataset(opts.data_addr, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_bleu_rouge_meteor(tokenizer = tokenizer)\n    elif task == \"LaMP-6\":\n        eval_dataset = GeneralSeq2SeqDataset(opts.data_addr, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_bleu_rouge_meteor(tokenizer = tokenizer)\n\n    eval_dataset = convert_to_hf_dataset(eval_dataset, cache_dir = opts.cache_dir).map(create_preprocessor_scores_seq(tokenizer = tokenizer, max_length = opts.max_length), batched=True)\n\n    if contriver:\n        contriver = 
contriver.to(\"cpu\")\n\n    training_args = Seq2SeqTrainingArguments(\n        output_dir = opts.output_dir,\n        do_eval = True,\n        per_device_eval_batch_size = 1,\n        generation_num_beams = opts.generation_num_beams,\n        predict_with_generate = True,\n        eval_accumulation_steps = 1,\n        generation_max_length = opts.generation_max_length\n    )\n\n    trainer = Seq2SeqTrainer(\n        model = model,\n        args = training_args,\n        data_collator = collator,\n        eval_dataset = eval_dataset,\n        tokenizer = tokenizer,\n        compute_metrics = compute_metrics\n    )\n\n    results_dict = dict()\n  \n    for i, x in enumerate(tqdm.tqdm(eval_dataset)):\n        preds = trainer.predict(datasets.Dataset.from_list([x]))\n        metrics = preds.metrics\n        output = tokenizer.batch_decode(preds.predictions, skip_special_tokens=True)[0].strip()\n        results_dict[f\"{x['id']}\"] = {\n            \"metric\" : {k.replace(\"test_\", '') : v for k,v in metrics.items()}, \n            \"output\" : output, \n            \"input\":x['source']\n        }\n\n    with open(os.path.join(opts.output_dir, f\"scores.json\"), \"w\") as file:\n        json.dump(results_dict, file, indent = 4)"
  },
  {
    "path": "LaMP/train_llm.py",
    "content": "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments\nfrom transformers.data.data_collator import DataCollatorForSeq2Seq\nimport argparse\nfrom metrics.classification_metrics import create_metric_f1_accuracy, create_metric_mae_rmse\nfrom metrics.generation_metrics import create_metric_bleu_rouge_meteor\nfrom data.datasets import get_all_labels, GeneralSeq2SeqDataset, create_preprocessor, convert_to_hf_dataset\nfrom prompts.prompts import create_prompt_generator\nimport os\nimport json\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--train_data\", required = True)\nparser.add_argument(\"--validation_data\", required = True)\nparser.add_argument(\"--test_data\", default=\"\")\nparser.add_argument(\"--model_name\", required = True)\nparser.add_argument(\"--task\", required = True)\nparser.add_argument(\"--output_dir\", required = True)\nparser.add_argument(\"--retriever\", default = \"bm25\")\nparser.add_argument(\"--use_profile\", action = \"store_true\")\nparser.add_argument(\"--is_ranked\", action = \"store_true\")\nparser.add_argument(\"--max_length\", type = int, default = 256)\nparser.add_argument(\"--generation_max_length\", type = int, default = 128)\nparser.add_argument(\"--per_device_batch_size\", type = int, default = 16)\nparser.add_argument(\"--learning_rate\", type = float, default = 5e-5)\nparser.add_argument(\"--weight_decay\", type = float, default = 0.0001)\nparser.add_argument(\"--num_train_epochs\", type = int, default = 10)\nparser.add_argument(\"--lr_scheduler_type\", default = \"linear\")\nparser.add_argument(\"--warmup_ratio\", type = float, default = 0.05)\nparser.add_argument(\"--generation_num_beams\", type = int, default = 4)\nparser.add_argument(\"--num_retrieved\", type = int, required=True)\nparser.add_argument(\"--gradient_accumulation_steps\", type = int, default = 1)\nparser.add_argument(\"--cache_dir\", default = \"./cache\")\n\n\nif __name__ == 
\"__main__\":\n\n    opts = parser.parse_args()\n    \n    model = AutoModelForSeq2SeqLM.from_pretrained(opts.model_name, cache_dir=opts.cache_dir)\n    tokenizer = AutoTokenizer.from_pretrained(opts.model_name, cache_dir=opts.cache_dir)\n    collator = DataCollatorForSeq2Seq(tokenizer = tokenizer, model = model, max_length = opts.max_length)\n\n    task = opts.task\n    if opts.use_profile:\n        prompt_generator, contriver = create_prompt_generator(opts.num_retrieved, opts.retriever, opts.is_ranked, opts.max_length, tokenizer)\n    else:\n        prompt_generator, contriver = None, None\n\n    greater_is_better = True\n    if task == \"LaMP-1\":\n        train_dataset, labels = GeneralSeq2SeqDataset(opts.train_data, opts.use_profile, task, prompt_generator), get_all_labels(task)\n        eval_dataset = GeneralSeq2SeqDataset(opts.validation_data, opts.use_profile, task, prompt_generator)\n        if opts.test_data:\n            test_dataset = GeneralSeq2SeqDataset(opts.test_data, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_f1_accuracy(tokenizer = tokenizer, all_labels = labels)\n        best_metric = \"accuracy\"\n    elif task == \"LaMP-2-old\":\n        train_dataset, labels = GeneralSeq2SeqDataset(opts.train_data, opts.use_profile, task, prompt_generator), get_all_labels(task)\n        eval_dataset = GeneralSeq2SeqDataset(opts.validation_data, opts.use_profile, task, prompt_generator)\n        if opts.test_data:\n            test_dataset = GeneralSeq2SeqDataset(opts.test_data, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_f1_accuracy(tokenizer = tokenizer, all_labels = labels)\n        best_metric = \"accuracy\"\n    elif task == \"LaMP-2\":\n        train_dataset, labels = GeneralSeq2SeqDataset(opts.train_data, opts.use_profile, task, prompt_generator), get_all_labels(task)\n        eval_dataset = GeneralSeq2SeqDataset(opts.validation_data, opts.use_profile, task, 
prompt_generator)\n        if opts.test_data:\n            test_dataset = GeneralSeq2SeqDataset(opts.test_data, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_f1_accuracy(tokenizer = tokenizer, all_labels = labels)\n        best_metric = \"accuracy\"\n    elif task == \"LaMP-3\":\n        train_dataset, labels = GeneralSeq2SeqDataset(opts.train_data, opts.use_profile, task, prompt_generator), get_all_labels(task)\n        eval_dataset = GeneralSeq2SeqDataset(opts.validation_data, opts.use_profile, task, prompt_generator)\n        if opts.test_data:\n            test_dataset = GeneralSeq2SeqDataset(opts.test_data, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_mae_rmse(tokenizer = tokenizer, all_labels = labels)\n        best_metric = \"mae\"\n        greater_is_better = False\n    elif task == \"LaMP-4\":\n        train_dataset = GeneralSeq2SeqDataset(opts.train_data, opts.use_profile, task, prompt_generator)\n        eval_dataset = GeneralSeq2SeqDataset(opts.validation_data, opts.use_profile, task, prompt_generator)\n        if opts.test_data:\n            test_dataset = GeneralSeq2SeqDataset(opts.test_data, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_bleu_rouge_meteor(tokenizer = tokenizer)\n        best_metric = \"rouge-1\"\n    elif task == \"LaMP-5\":\n        train_dataset = GeneralSeq2SeqDataset(opts.train_data, opts.use_profile, task, prompt_generator)\n        eval_dataset = GeneralSeq2SeqDataset(opts.validation_data, opts.use_profile, task, prompt_generator)\n        if opts.test_data:\n            test_dataset = GeneralSeq2SeqDataset(opts.test_data, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_bleu_rouge_meteor(tokenizer = tokenizer)\n        best_metric = \"rouge-1\"\n    elif task == \"LaMP-7\":\n        train_dataset = GeneralSeq2SeqDataset(opts.train_data, opts.use_profile, task, prompt_generator)\n      
  eval_dataset = GeneralSeq2SeqDataset(opts.validation_data, opts.use_profile, task, prompt_generator)\n        if opts.test_data:\n            test_dataset = GeneralSeq2SeqDataset(opts.test_data, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_bleu_rouge_meteor(tokenizer = tokenizer)\n        best_metric = \"rouge-1\"\n    elif task == \"LaMP-6\":\n        train_dataset = GeneralSeq2SeqDataset(opts.train_data, opts.use_profile, task, prompt_generator)\n        eval_dataset = GeneralSeq2SeqDataset(opts.validation_data, opts.use_profile, task, prompt_generator)\n        if opts.test_data:\n            test_dataset = GeneralSeq2SeqDataset(opts.test_data, opts.use_profile, task, prompt_generator)\n        compute_metrics = create_metric_bleu_rouge_meteor(tokenizer = tokenizer)\n        best_metric = \"rouge-1\"\n    \n    train_dataset = convert_to_hf_dataset(train_dataset, cache_dir = opts.cache_dir).map(create_preprocessor(tokenizer = tokenizer, max_length = opts.max_length), batched=True)\n    eval_dataset = convert_to_hf_dataset(eval_dataset, cache_dir = opts.cache_dir).map(create_preprocessor(tokenizer = tokenizer, max_length = opts.max_length), batched=True)\n    if opts.test_data:\n        test_dataset = convert_to_hf_dataset(test_dataset, cache_dir = opts.cache_dir).map(create_preprocessor(tokenizer = tokenizer, max_length = opts.max_length), batched=True)\n\n    if contriver:\n        contriver = contriver.to(\"cpu\")\n\n    training_args = Seq2SeqTrainingArguments(\n        output_dir = opts.output_dir,\n        do_train = True,\n        do_eval = True,\n        evaluation_strategy = \"epoch\",\n        per_device_train_batch_size = opts.per_device_batch_size,\n        per_device_eval_batch_size = opts.per_device_batch_size,\n        gradient_accumulation_steps = opts.gradient_accumulation_steps,\n        learning_rate = opts.learning_rate,\n        weight_decay = opts.weight_decay,\n        num_train_epochs = 
opts.num_train_epochs,\n        lr_scheduler_type = opts.lr_scheduler_type,\n        warmup_ratio = opts.warmup_ratio,\n        generation_num_beams = opts.generation_num_beams,\n        predict_with_generate = True,\n        save_strategy = \"epoch\",\n        logging_steps = 50,\n        eval_accumulation_steps = 1,\n        generation_max_length = opts.generation_max_length,\n        load_best_model_at_end = True,\n        metric_for_best_model = best_metric,\n        greater_is_better = greater_is_better\n    )\n\n    trainer = Seq2SeqTrainer(\n        model = model,\n        args = training_args,\n        data_collator = collator,\n        train_dataset = train_dataset,\n        eval_dataset = eval_dataset,\n        tokenizer = tokenizer,\n        compute_metrics = compute_metrics\n    )\n\n    trainer.train()\n\n    if opts.test_data:\n        results = trainer.evaluate(test_dataset)\n        print(results)\n\n        with open(os.path.join(opts.output_dir, 'results_output.json'), 'w') as file:\n            json.dump(results, file, indent = 4)"
  },
  {
    "path": "LaMP/utils/merge_with_rank.py",
    "content": "import json\nimport argparse\n\ndef merge(inps, outs, ranks):\n    # index outputs by example id: O(1) lookups, and a missing id raises KeyError\n    # instead of silently reusing the previous example's output\n    outputs_by_id = {o['id']: o['output'] for o in outs}\n    for inp in inps:\n        profile_by_id = {y['id']: y for y in inp['profile']}\n        # reorder the profile to follow the ranking, skipping ids absent from the profile\n        inp['profile'] = [profile_by_id[x] for x in ranks[inp['id']] if x in profile_by_id]\n        inp['output'] = outputs_by_id[inp['id']]\n    return inps\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--lamp_questions_addr\", required = True)\nparser.add_argument(\"--lamp_output_addr\", required = True)\nparser.add_argument(\"--merged_output_addr\", required = True)\nparser.add_argument(\"--profile_ranking_addr\", default=\"\")\n\nif __name__ == \"__main__\":\n    opts = parser.parse_args()\n    q_addr = opts.lamp_questions_addr\n    o_addr = opts.lamp_output_addr\n    rank_addr = opts.profile_ranking_addr\n    res_addr = opts.merged_output_addr\n\n    with open(q_addr) as qfile:\n        inp = json.load(qfile)\n    with open(o_addr) as ofile:\n        out = json.load(ofile)\n    if rank_addr:\n        with open(rank_addr) as rfile:\n            rank = json.load(rfile)\n    else:\n        rank = dict()\n        for data in inp:\n            rank[data['id']] = []\n            for item in data['profile']:\n                rank[data['id']].append(item['id'])\n\n    with open(res_addr, \"w\") as resfile:\n        res = merge(inp, out, rank)\n        json.dump(res, resfile, indent=4)\n\n\n"
  },
  {
    "path": "PEFT/data/datasets.py",
    "content": "from torch.utils.data import Dataset\nimport json\nimport datasets\nimport torch\nimport random\nfrom itertools import combinations\n\ndef sublists_between_2_and_k(lst, k):\n    sublists = []\n    for size in range(2, k+1):  # Iterate through sizes from 2 to k\n        for comb in combinations(lst, size):\n            sublists.append(list(comb))\n    return sublists\n\ndef sample_sublists(lst, k, num_samples):\n    sublists = []\n    for i in range(k+1, len(lst)):\n        sub = list(random.sample(lst[:i], k-1))\n        sub.sort(key=lambda x: x['date'])\n        sub += [lst[i]]\n        sublists.append(sub)\n    while len(sublists) < num_samples:\n        idx = random.randint(k+1, len(lst) - 1)\n        sub = list(random.sample(lst[:idx], k-1))\n        sub.sort(key=lambda x: x['date'])\n        sub += [lst[idx]]\n        sublists.append(sub)\n    return sublists\n\ndef get_all_labels(task):\n    if task == \"classification_citation\":\n        return [\"[1]\",\"[2]\"]\n    elif task == \"classification_news\":\n        return [\"food & drink\", \"sports\", \"education\", \"parents\", \"religion\", \"travel\", \"business\", \"crime\", \"science & technology\", \"culture & arts\", \"entertainment\", \"politics\", \"women\", \"style & beauty\", \"healthy living\"]\n    elif task == \"classification_movies\":\n        return ['sci-fi', 'based on a book', 'comedy', 'action', 'twist ending', 'dystopia', 'dark comedy', 'classic', 'psychology', 'fantasy', 'romance', 'thought-provoking', 'social commentary', 'violence', 'true story']\n    elif task == \"classification_review\":\n        return [\"1\", \"2\", \"3\", \"4\", \"5\"]\n    elif task == \"generation_news\":\n        return []\n    elif task == \"generation_paper\":\n        return []\n    elif task == \"paraphrase_paper\":\n        return []\n\ndef create_preprocessor(tokenizer, max_length):\n    def preprocess_dataset(examples):\n        inputs = [example for example in examples[\"source\"]]\n   
     targets = [example for example in examples[\"target\"]]\n        model_inputs = tokenizer(inputs, text_target=targets, max_length=max_length, truncation=True)\n        return model_inputs\n    return preprocess_dataset\n\ndef convert_to_hf_dataset(dataset, cache_dir):\n    def gen():\n        for idx in range(len(dataset)):\n            yield dataset[idx]\n    return datasets.Dataset.from_generator(gen, cache_dir = cache_dir)\n\ndef create_input_output_gen_func(task):\n    if task == \"LaMP-1\":\n        def func(item):\n            inp = f\"Write an abstract for this title: {item['title']}\"\n            out = f'{item[\"abstract\"]}'\n            return inp, out\n    elif task == \"LaMP-2\":\n        def func(item):\n            inp = f\"Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation. tags: [sci-fi, based on a book, comedy, action, twist ending, dystopia, dark comedy, classic, psychology, fantasy, romance, thought-provoking, social commentary, violence, true story] description: {item['description']}\"\n            out = f'{item[\"tag\"]}'\n            return inp, out\n    elif task == \"LaMP-3\":\n        def func(item):\n            inp = f\"What is the score of the following review on a scale of 1 to 5? just answer with 1, 2, 3, 4, or 5 without further explanation. 
review: {item['text']}\"\n            out = f'{item[\"score\"]}'\n            return inp, out\n    elif task == \"LaMP-4\":\n        def func(item):\n            inp = f\"Generate a headline for the following article: {item['text']}\"\n            out = f'{item[\"title\"]}'\n            return inp, out\n    elif task == \"LaMP-5\":\n        def func(item):\n            inp = f\"Generate a title for the following abstract of a paper: {item['abstract']}\"\n            out = f'{item[\"title\"]}'\n            return inp, out\n    elif task == \"LaMP-6\":\n        def func(item):\n            inp = f\"Generate a subject for the following email: {item['text']}\"\n            out = f'{item[\"title\"]}'\n            return inp, out\n    elif task == \"LaMP-7\":\n        def func(item):\n            percent = random.uniform(0.1, 0.25)\n            tweet_words = item['text'].split()\n            index = int(len(tweet_words) * percent)\n            in_inp = \" \".join(tweet_words[:index])\n            in_out = \" \".join(tweet_words[index:])\n            inp = f\"Complete the following tweet: {in_inp}\"\n            out = f'{in_out}'\n            return inp, out\n    return func\n\n\ndef create_per_user_dataset(data_addr, user_ids, task, cache_dir):\n    with open(data_addr) as file:\n        orig_dataset = json.load(file)\n    seen_users = set()\n    datasets = dict()\n    input_output_gen_func = create_input_output_gen_func(task)\n    for data in orig_dataset:\n        uid = str(data['user_id'])\n        if user_ids is not None and uid not in user_ids:\n            continue\n        if uid in seen_users:\n            continue\n        else:\n            seen_users.add(uid)\n        cur_dataset = []\n        for i, item in enumerate(data['profile']):\n            id = f'{uid}-{data[\"id\"]}-{i}'\n            inp, out = input_output_gen_func(item)\n            cur_dataset.append(\n                {\n                    \"id\" : id,\n                    \"input\" : inp,\n      
              \"output\" : out\n                }\n            )\n        datasets[uid] = convert_to_hf_dataset(GeneralSeq2SeqDataset(cur_dataset), cache_dir)\n    return datasets\n\ndef create_per_user_dataset_test(data_addr, user_ids, task, cache_dir):\n    with open(data_addr) as file:\n        orig_dataset = json.load(file)\n    seen_users = set()\n    datasets = dict()\n    for data in orig_dataset:\n        uid = str(data['user_id'])\n        if user_ids is not None and uid not in user_ids:\n            continue\n        elif uid not in seen_users:\n            seen_users.add(uid)\n            datasets[uid] = []\n        \n        datasets[uid].append(\n            {\n                \"id\" : data[\"id\"],\n                \"input\" : data[\"input\"],\n                \"output\" : data[\"output\"]\n            }\n        )\n    \n    for key, value in datasets.items():\n        datasets[key] = convert_to_hf_dataset(GeneralSeq2SeqDataset(value), cache_dir)\n    return datasets\n\nclass GeneralSeq2SeqDataset(Dataset):\n\n    def __init__(self, data) -> None:\n        super().__init__()\n        self.data = data\n\n    def __getitem__(self, index):\n        return {\n            \"id\" : self.data[index]['id'],\n            \"source\" : self.data[index]['input'],\n            \"target\" : self.data[index]['output']\n        }\n    \n    def __len__(self):\n        return len(self.data)"
  },
  {
    "path": "PEFT/evaluate_llm.py",
    "content": "import argparse\nfrom transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments\nfrom transformers.data.data_collator import DataCollatorForSeq2Seq\nfrom data.datasets import create_per_user_dataset_test, create_preprocessor\nimport json\nfrom peft import get_peft_config, get_peft_model, LoraConfig, TaskType\nimport os\nimport torch\nimport glob\nimport numpy as np\n\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--test_data\", required = True)\nparser.add_argument(\"--user_ids\", default = \"\")\nparser.add_argument(\"--golds_addr\", required = True)\nparser.add_argument(\"--task\", required = True)\nparser.add_argument(\"--user_checkpoints\", required = True)\nparser.add_argument(\"--output_dir\", required = True)\nparser.add_argument(\"--max_length\", type = int, default = 512)\nparser.add_argument(\"--num_shards\", type = int, default = 1)\nparser.add_argument(\"--shard_id\", type = int, default = 0)\nparser.add_argument(\"--generation_max_length\", type = int, default = 128)\nparser.add_argument(\"--per_device_batch_size\", type = int, default = 16)\nparser.add_argument(\"--generation_num_beams\", type = int, default = 4)\nparser.add_argument(\"--cache\", default=\"./cache\")\n\nif __name__ == \"__main__\":\n    opts = parser.parse_args()\n    print(opts)\n    if opts.user_ids:\n        with open(opts.user_ids) as file:\n            all_user_ids = [str(x) for x in json.load(file)]\n            shard_size = len(all_user_ids) // opts.num_shards + 1\n            user_ids = all_user_ids[int(opts.shard_id * shard_size):int((opts.shard_id + 1) * shard_size)]\n    else:\n        # no user list given: evaluate every user in the test data\n        user_ids = None\n    \n    user_datasets = create_per_user_dataset_test(opts.test_data, user_ids, opts.task, opts.cache)\n\n    model_name_or_path = \"google/flan-t5-xxl\"\n    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, cache_dir = opts.cache)\n    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir = opts.cache)  \n    processor = 
create_preprocessor(tokenizer = tokenizer, max_length = opts.max_length)\n    collator = DataCollatorForSeq2Seq(tokenizer = tokenizer, model = model, max_length = opts.max_length)\n    peft_config = LoraConfig(\n        task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1, target_modules=['q','v','k']\n    )\n    model = get_peft_model(model, peft_config)\n\n    final_outputs = []\n\n    for key, dataset in user_datasets.items():\n        model.unload()\n        checkpoints = glob.glob(os.path.join(opts.user_checkpoints, 'adaptors', key, '*'))\n        if len(checkpoints) > 0:\n            checkpoint_addr = checkpoints[0]\n            print(checkpoint_addr)\n            model.load_adapter(checkpoint_addr, key)\n            model.set_adapter(key)\n        encoded_dataset = dataset.map(processor, batched=True)\n        \n        training_args = Seq2SeqTrainingArguments(\n            output_dir = opts.output_dir,\n            do_train = False,\n            do_eval = True,\n            per_device_eval_batch_size = opts.per_device_batch_size,\n            generation_max_length = opts.generation_max_length,\n            generation_num_beams = opts.generation_num_beams,\n            predict_with_generate=True,\n            eval_accumulation_steps = 1\n        )\n\n        trainer = Seq2SeqTrainer(\n            model = model,\n            args = training_args,\n            data_collator = collator,\n            train_dataset = encoded_dataset,\n            tokenizer = tokenizer\n        )\n\n        preds = trainer.predict(encoded_dataset).predictions\n        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)\n        preds = tokenizer.batch_decode(preds, skip_special_tokens = True)\n        for data, pred in zip(dataset, preds):\n            final_outputs.append(\n                {\n                    \"id\" : data['id'],\n                    \"output\" : pred\n                }\n            )\n    prediction_addr = 
os.path.join(opts.output_dir, 'predictions.json')\n\n    with open(prediction_addr, 'w') as file:\n        json.dump(\n            {\n                \"task\" : opts.task,\n                \"golds\" : final_outputs\n            },\n            file,\n            indent=4\n        )"
  },
  {
    "path": "PEFT/requirements.txt",
    "content": "datasets==2.8.0\nregex==2022.10.31\nsentencepiece==0.1.97\ntokenizers==0.11.1\ntorch==2.0.1\ntqdm==4.64.1\ntransformers==4.28.0\nevaluate\nabsl-py\nrouge-score\npeft"
  },
  {
    "path": "PEFT/train_peft.py",
    "content": "import argparse\nfrom transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments\nfrom transformers.data.data_collator import DataCollatorForSeq2Seq\nfrom data.datasets import create_per_user_dataset, create_preprocessor\nimport json\nfrom peft import get_peft_config, get_peft_model, LoraConfig, TaskType\nimport os\nimport torch\n\n\ndef is_directory_empty(path):\n    if os.path.exists(path):\n        if not os.listdir(path):\n            return True\n        else:\n            return False\n    else:\n        return True\n\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--train_data\", required = True)\nparser.add_argument(\"--user_ids\", default=\"\")\nparser.add_argument(\"--task\", required = True)\nparser.add_argument(\"--output_dir\", required = True)\nparser.add_argument(\"--max_length\", type = int, default = 512)\nparser.add_argument(\"--num_shards\", type = int, default = 1)\nparser.add_argument(\"--shard_id\", type = int, default = 0)\nparser.add_argument(\"--generation_max_length\", type = int, default = 512)\nparser.add_argument(\"--per_device_batch_size\", type = int, default = 16)\nparser.add_argument(\"--learning_rate\", type = float, default = 5e-5)\nparser.add_argument(\"--weight_decay\", type = float, default = 0.0001)\nparser.add_argument(\"--num_train_epochs\", type = int, default = 30)\nparser.add_argument(\"--lora_r\", type = int, default = 8)\nparser.add_argument(\"--lr_scheduler_type\", default = \"linear\")\nparser.add_argument(\"--warmup_ratio\", type = float, default = 0.05)\nparser.add_argument(\"--gradient_accumulation_steps\", type = int, default = 1)\nparser.add_argument(\"--cache\", default=\"./cache\")\n\n\n\nif __name__ == \"__main__\":\n\n    opts = parser.parse_args()\n    print(opts)\n    if opts.user_ids:\n        with open(opts.user_ids) as file:\n            all_user_ids = [str(x) for x in json.load(file)]\n            shard_size = len(all_user_ids) // 
opts.num_shards + 1\n            user_ids = all_user_ids[int(opts.shard_id * shard_size):int((opts.shard_id + 1) * shard_size)]\n    else:\n        user_ids = None\n    user_datasets = create_per_user_dataset(opts.train_data, user_ids, opts.task, opts.cache)\n\n    model_name_or_path = \"google/flan-t5-xxl\"\n    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, cache_dir = opts.cache)\n    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir = opts.cache)  \n    processor = create_preprocessor(tokenizer = tokenizer, max_length = opts.max_length)\n    collator = DataCollatorForSeq2Seq(tokenizer = tokenizer, model = model, max_length = opts.max_length)\n    for key, dataset in user_datasets.items():\n        print(key)\n        if not is_directory_empty(os.path.join(opts.output_dir, 'adaptors', key)):\n            continue\n        peft_config = LoraConfig(\n            task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=opts.lora_r, lora_alpha=32, lora_dropout=0.1, target_modules=['q','v','k']\n        )\n        model = get_peft_model(model, peft_config)\n        encoded_dataset = dataset.map(processor, batched=True)\n        training_args = Seq2SeqTrainingArguments(\n            output_dir = os.path.join(opts.output_dir, 'adaptors', key),\n            do_train = True,\n            do_eval = False,\n            per_device_train_batch_size = opts.per_device_batch_size,\n            gradient_accumulation_steps = opts.gradient_accumulation_steps,\n            learning_rate = opts.learning_rate,\n            weight_decay = opts.weight_decay,\n            num_train_epochs = opts.num_train_epochs,\n            lr_scheduler_type = opts.lr_scheduler_type,\n            warmup_ratio = opts.warmup_ratio,\n            save_strategy = \"epoch\",\n            save_total_limit=1,\n            logging_steps = 10,\n            generation_max_length = opts.generation_max_length,\n            save_only_model = True\n        )\n\n        
trainer = Seq2SeqTrainer(\n            model = model,\n            args = training_args,\n            data_collator = collator,\n            train_dataset = encoded_dataset,\n            tokenizer = tokenizer\n        )\n\n        trainer.train()\n        model = model.unload()\n        trainer = None\n        torch.cuda.empty_cache()\n        \n        "
  },
  {
    "path": "README.md",
    "content": "# Codes for papers on Large Language Models Personalization (LaMP)\n\n[LaMP: When Large Language Models Meet Personalization](https://arxiv.org/abs/2304.11406)\n\nThis paper highlights the importance of personalization in the current state of natural language understanding and generation and introduces the LaMP benchmark --- a novel benchmark for training and evaluating language models for producing personalized outputs. LaMP offers a comprehensive evaluation framework with diverse language tasks and multiple entries for each user profile. It consists of seven personalized tasks, spanning across three classification and four text generation tasks. We further propose a retrieval augmentation approach that retrieves personalized items from user profiles to construct personalized prompts for large language models. The experiments conducted to establish fine-tuned and zero-shot baseline results for the benchmark conclude that LMs utilizing profile augmentation outperform their counterparts that do not factor in profile information.\n\n```\n@misc{salemi2023lamp,\n      title={La{MP}: When Large Language Models Meet Personalization}, \n      author={Alireza Salemi and Sheshera Mysore and Michael Bendersky and Hamed Zamani},\n      year={2023},\n      eprint={2304.11406},\n      archivePrefix={arXiv},\n      primaryClass={cs.CL}\n}\n```\n\n[Optimization Methods for Personalizing Large Language Models through Retrieval Augmentation](https://arxiv.org/abs/2404.05970)\n\nThis paper studies retrieval-augmented approaches for personalizing large language models (LLMs), which potentially have a substantial impact on various applications and domains. We propose the first attempt to optimize the retrieval models that deliver a limited number of personal documents to large language models for the purpose of personalized generation. 
We develop two optimization algorithms that solicit feedback from the downstream personalized generation tasks for retrieval optimization--one based on reinforcement learning whose reward function is defined using any arbitrary metric for personalized generation and another based on knowledge distillation from the downstream LLM to the retrieval model. This paper also introduces a pre- and post-generation retriever selection model that decides what retriever to choose for each LLM input. Extensive experiments on diverse tasks from the language model personalization (LaMP) benchmark reveal statistically significant improvements in six out of seven datasets.\n\n```\n@misc{salemi2024optimization,\n      title={Optimization Methods for Personalizing Large Language Models through Retrieval Augmentation}, \n      author={Alireza Salemi and Surya Kallumadi and Hamed Zamani},\n      year={2024},\n      eprint={2404.05970},\n      archivePrefix={arXiv},\n      primaryClass={cs.CL}\n}\n```\n\n[Comparing Retrieval-Augmentation and Parameter-Efficient Fine-Tuning for Privacy-Preserving Personalization of Large Language Models](https://arxiv.org/abs/2409.09510)\n\nPrivacy-preserving methods for personalizing large language models (LLMs) are relatively under-explored. There are two schools of thought on this topic: (1) generating personalized outputs by personalizing the input prompt through retrieval augmentation from the user's personal information (RAG-based methods), and (2) parameter-efficient fine-tuning of LLMs per user that considers efficiency and space limitations (PEFT-based methods). This paper presents the first systematic comparison between two approaches on a wide range of personalization tasks using seven diverse datasets. Our results indicate that RAG-based and PEFT-based personalization methods on average yield 14.92% and 1.07% improvements over the non-personalized LLM, respectively. We find that combining RAG with PEFT elevates these improvements to 15.98%. 
Additionally, we identify a positive correlation between the amount of user data and PEFT's effectiveness, indicating that RAG is a better choice for cold-start users (i.e., users with limited personal data).\n\n```\n@misc{salemi2024comparingretrievalaugmentationparameterefficientfinetuning,\n      title={Comparing Retrieval-Augmentation and Parameter-Efficient Fine-Tuning for Privacy-Preserving Personalization of Large Language Models}, \n      author={Alireza Salemi and Hamed Zamani},\n      year={2024},\n      eprint={2409.09510},\n      archivePrefix={arXiv},\n      primaryClass={cs.CL},\n      url={https://arxiv.org/abs/2409.09510}, \n}\n```\n\n## Data\n\nYou can download all the datasets from the links provided [here](https://lamp-benchmark.github.io/download). The exception is the Personalized Email Subject Generation dataset, which is not publicly accessible; for it, we provide only the minimal sample IDs needed to regenerate the dataset with our code. Follow the next section to generate that dataset.\n\n### LaMP 6: Personalized Email Subject Generation (Avocado dataset)\n\nThe [Avocado](https://catalog.ldc.upenn.edu/LDC2015T03) dataset is not publicly accessible. However, we provide the sample IDs and the code we used to generate our dataset. 
Therefore, if you get access to the dataset, you can generate it in the same format as the other LaMP datasets using the following code:\n\n```\npython data/avocado/create_avocado_dataset.py \\\n    --avocado_files_dir /*Address to the directory containing zip files for avocado dataset 'avocado-1.0.2/data/text'*/ \\\n    --extract_addr /*A temp dir to extract the files for creating dataset*/ \\\n    --output_dir /*The directory to generate the final dataset*/ \\\n    --input_question_file_train /*The address to the train_questions.json file we provided in LaMP*/ \\\n    --input_question_file_dev /*The address to the dev_questions.json file we provided in LaMP*/ \\\n    --input_question_file_test /*The address to the test_questions.json file we provided in LaMP*/\n```\n\n## Evaluation\n\nThe instructions for evaluating your results on the test set are provided [here](https://lamp-benchmark.github.io/leaderboard). To evaluate your results on the dev set, you can use the evaluation scripts we provide:\n\nEvaluate all tasks together:\n\n```\npython eval/eval_all.py \\\n    --golds_zip /*Address to all gold labels for all tasks zipped in a file*/ \\\n    --preds_zip /*Address to all predictions for all tasks zipped in a file*/ \\\n    --temp_dir /*Address to a temp dir for extracting files*/ \\\n    --output_file /*Address to the results file*/\n```\n\nEvaluate one task:\n\n```\npython eval/eval_task.py \\\n    --golds_json /*Address to gold labels for the task as a json file*/ \\\n    --preds_json /*Address to predictions for the task as a json file*/ \\\n    --task_name /*Name of the task [LaMP_1, LaMP_2, LaMP_3, LaMP_4, LaMP_5, LaMP_6, LaMP_7]*/ \\\n    --output_file /*Address to the results file*/\n```\n\nThe prediction files should follow exactly the same format as the gold files:\n\n```\n{\n    \"task\" : \"/*task name*/\",\n    \"golds\" : [\n        {\n            \"id\" : \"/*sample 1 id*/\",\n            
\"output\" : \"/*output of the model for the first sample*/\"\n        },\n        ...,\n        {\n            \"id\" : \"/*sample n id*/\",\n            \"output\" : \"/*output of the model for the n'th sample*/\"\n        }\n    ]\n}\n```\n\n## Personalizing LLMs with RAG (LaMP)\n\nYou first need to create an environment for this using the following script:\n\n```\npython3 -m venv lamp_venv\nsource lamp_venv/bin/activate\npip install -r LaMP/requirements.txt\n```\n\n### Ranking Profiles based on the Input\n\nThe first step is to sort the items in each user profile based on the input for the task:\n\n```\ncd LaMP\npython rank_profiles.py \\\n    --input_data_addr /*input questions for one of the LaMP tasks*/ \\\n    --output_ranking_addr /*output address for the generated ranking file*/ \\\n    --task /*name of the task [LaMP-1, LaMP-2, ..., LaMP-7]*/ \\\n    --ranker /*the ranking model to be used [bm25, contriever, recency]*/ \\\n    [optional] --batch_size /*the batch size for ranking*/ \\\n    [optional] --use_date \\ /*if used, it adds time to the text of each profile item*/\n    [optional] --contriever_checkpoint /*address to the Contriever checkpoint to be used*/\n\n```\n\nAfter that, use the following script to sort the profiles in the dataset based on the ranking file:\n\n```\ncd LaMP\npython utils/merge_with_rank.py \\\n    --lamp_questions_addr /*address to the LaMP task inputs file*/ \\\n    --lamp_output_addr /*address to the LaMP task outputs file*/ \\\n    --profile_ranking_addr /*address to the generated ranking file from the previous script*/ \\\n    --merged_output_addr /*address to the sorted dataset using the provided ranking file*/\n\n```\n\n### Training LLM with RAG\n\nThe next step is to train the LLM on a LaMP task:\n\n```\ncd LaMP\npython train_llm.py \\\n    --train_data /*address to sorted training data using the previous step*/ \\\n    --validation_data /*address to sorted validation data using the previous step*/ \\\n    [optional] 
--test_data /*address to sorted test data using the previous step*/ \\\n    --model_name /*address to the model that should be used for initialization of the LLM*/ \\\n    --task /*name of the task [LaMP-1, LaMP-2, ..., LaMP-7]*/ \\\n    --output_dir /*output directory to save results and checkpoints*/ \\\n    --retriever /*the ranking model to be used [bm25, contriever, recency]*/ \\\n    --use_profile \\ /*used to perform personalization with RAG */\n    --is_ranked \\ /*used if you pre-ranked the profiles based on the provided retrieval model*/\n    --num_retrieved /*number of items to be retrieved from the user profile*/\n```\n\n### Zero-shot Evaluation of LLM with RAG\n\nYou can also evaluate the LLMs with the following script:\n\n```\ncd LaMP\npython evaluate_llm.py \\\n    --validation_data /*address to sorted validation data using the previous step*/ \\\n    --model_addr /*address to the model that should be used for initialization of the LLM*/ \\\n    --task /*name of the task [LaMP-1, LaMP-2, ..., LaMP-7]*/ \\\n    --output_dir /*output directory to save results */ \\\n    --use_profile \\ /*used to perform personalization with RAG */\n    --retriever /*the ranking model to be used [bm25, contriever, recency]*/ \\\n    --is_ranked \\ /*used if you pre-ranked the profiles based on the provided retrieval model*/\n    --num_retrieved /*number of items to be retrieved from the user profile*/\n```\n\n## Optimizing Retrieval Model for Personalizing LLMs (ROPG)\n\nThis code uses feedback from the LLM to train a retrieval model for personalizing the LLM. 
You first need to create an environment for this using the following script:\n\n```\npython3 -m venv ropg_venv\nsource ropg_venv/bin/activate\npip install -r ROPG/requirements.txt\n```\n\n### Feedback Generation using LLM for Items in the User Profile\n\nThe first step is to collect feedback from the LLM using the following script:\n\n```\ncd LaMP\npython profile_item_utilization_scorer.py \\\n    --train_data /*address to sorted training data using the previous steps*/ \\\n    --model_name /*address to the model that should be used for feedback generation*/ \\\n    --task /*name of the task [LaMP-1, LaMP-2, ..., LaMP-7]*/ \\\n    --output_dir /*output directory to save results */ \\\n    --profile_size /*number of top-k items from the user profile to collect feedback for*/\n```\n\n### Optimizing Retrieval Model\n\nYou can use the following scripts to train a retrieval model based on the feedback generated in the previous step.\n\nFor training with ROPG-KD, which uses knowledge distillation, use the following script:\n\n```\ncd ROPG\nNGPU=/*Number of GPUs*/ python -m torch.distributed.launch --nproc_per_node=/*Number of GPUs*/ train_kd.py \\\n    --train_data /*address to sorted training data using the previous steps*/ \\\n    --do_train \\\n    --scores_path /*address to the feedback file generated in the previous step*/ \\\n    --name /*output directory*/ \\\n    --ctx_size /*number of documents to be used for training the retrieval model for each query*/ \\\n    --task /*name of the task [LaMP-1, LaMP-2, ..., LaMP-7]*/ \\\n    --temperature /*temperature for distillation*/\n```\n\nFor training with ROPG-RL, which uses reinforcement learning, use the following script:\n\n```\ncd ROPG\nNGPU=/*Number of GPUs*/ python -m torch.distributed.launch --nproc_per_node=/*Number of GPUs*/ train_rl.py \\\n    --train_data /*address to sorted training data using the previous steps*/ \\\n    --do_train \\\n    --scores_path /*address to the feedback file generated in the previous 
step*/ \\\n    --name /*output directory*/ \\\n    --ctx_size /*number of documents to be used for training the retrieval model for each query*/ \\\n    --task /*name of the task [LaMP-1, LaMP-2, ..., LaMP-7]*/\n```\n\n## Retrieval Model Selection for Personalizing LLMs (RSPG)\n\nThis section uses feedback from the LLM on the performance of different retrieval models to train a retrieval model selector. You first need to create an environment for this using the following script:\n\n```\npython3 -m venv rspg_venv\nsource rspg_venv/bin/activate\npip install -r RSPG/requirements.txt\n```\n\n### Feedback Generation using LLM for each Retrieval Model\n\nUse the following script to get feedback for each retrieval model in the retrieval model pool:\n\n```\ncd LaMP\npython retriever_utilization_scorer.py \\\n    --data_addr /*address to sorted task data using the previous steps*/ \\\n    --model_name /*address to the model that should be used for feedback generation*/ \\\n    --task /*name of the task [LaMP-1, LaMP-2, ..., LaMP-7]*/ \\\n    --output_dir /*output directory to save results */ \\\n    --use_profile \\ /*use only if you want feedback from the RAG approach; omit it when getting feedback from an LLM without RAG*/\n    --num_retrieved /*number of items to be retrieved from the user profile*/ \\\n    --retriever /*the retriever model that should be used to get feedback for*/ \\\n    --is_ranked \\ /*used if you pre-ranked the profiles based on the provided retrieval model*/\n```\n\nYou should run this script for every retrieval model in your retrieval model pool. 
In our paper, we used Contriever, ROPG-RL, ROPG-KD, Recency, BM25, and no retrieval (no RAG).\n\n### Optimizing Retrieval Model Selector\n\nThe first step is to combine all the feedback collected in the previous step into training and validation sets:\n\n```\ncd RSPG\npython utils/create_data.py \\\n    --retrivers_data_addr \"/*address to feedback 1*/\" \"/*address to feedback 2*/\" ... \"/*address to feedback n*/\" \\\n    --task_inputs_addr /*input questions for one of the LaMP tasks*/ \\\n    --task_outputs_addr /*outputs for one of the LaMP tasks*/ \\\n    --output_dataset_addr /*address to save the created dataset*/ \\\n    --metric /*the metric name that should be used as feedback [accuracy, rouge-1, rouge-l]*/\n```\n\nAfter this, you can use the following script to train the retrieval selector model (RSPG):\n\n```\ncd RSPG\nNGPU=/*Number of GPUs*/ python -m torch.distributed.launch --nproc_per_node=/*Number of GPUs*/ rspg.py \\\n    --train_data /*address to the training data created in the previous step*/ \\\n    --val_data /*address to the validation data created in the previous step*/ \\\n    --rspg_type /*retrieval selection mode: [Pre, Post]*/ \\\n    --val_lamp_golds /*address to a LaMP task output file for the validation set*/ \\\n    --task /*name of the task [LaMP-1, LaMP-2, ..., LaMP-7]*/ \\\n    --do_train \\\n    --name /*output directory*/ \\\n    --temperature /*distillation temperature*/\n```\n\n### Inference with Retrieval Model Selector\n\nTo run inference with RSPG, you can use the following script:\n\n```\ncd RSPG\nNGPU=/*Number of GPUs*/ python -m torch.distributed.launch --nproc_per_node=/*Number of GPUs*/ rspg.py \\\n    --train_data /*address to the training data created in the previous step*/ \\\n    --val_data /*address to the validation data created in the previous step*/ \\\n    --rspg_type /*retrieval selection mode: [Pre, Post]*/ \\\n    --val_lamp_golds /*address to a LaMP task output file for the validation set*/ \\\n    
--task /*name of the task [LaMP-1, LaMP-2, ..., LaMP-7]*/ \\\n    --do_validation \\\n    --name /*output directory*/ \\\n    --model_path /*address to the checkpoint to be evaluated*/\n```\n\n## PEFT for Personalizing LLMs\n\nThis section personalizes the LLM by training a separate LoRA adaptor on each user's personal data. You first need to create an environment for this using the following script:\n\n```\npython3 -m venv peft_venv\nsource peft_venv/bin/activate\npip install -r PEFT/requirements.txt\n```\n\n### Training LLM using LoRA on personal data\nTo train a LoRA adaptor per user, you can use the following script:\n\n```\ncd PEFT\npython train_peft.py \\\n      --train_data /*address to sorted training data using the previous step*/ \\\n      --task /*name of the task [LaMP-1, LaMP-2, ..., LaMP-7]*/ \\\n      --output_dir /*output directory to save per user checkpoints*/ \\\n      --lora_r /*lora r parameter*/\n```\n\n### Inference using per-user LLMs\nTo run inference with the per-user LoRA adaptors, you can use the following script:\n\n```\ncd PEFT\npython evaluate_llm.py \\\n      --test_data /*address to sorted test/validation data using the previous step*/ \\\n      --user_ids /*JSON file containing the list of user ids to run inference for*/ \\\n      --golds_addr /*address to the gold outputs file (required by the script)*/ \\\n      --task /*name of the task [LaMP-1, LaMP-2, ..., LaMP-7]*/ \\\n      --output_dir /*output directory to save outputs*/ \\\n      --user_checkpoints /*directory containing per user checkpoints*/\n```\n\n## Reference\n\nIf you find this repository helpful, please cite the following works!\n\n[LaMP: When Large Language Models Meet Personalization](https://arxiv.org/abs/2304.11406)\n\n```\n@misc{salemi2023lamp,\n      title={La{MP}: When Large Language Models Meet Personalization}, \n      author={Alireza Salemi and Sheshera Mysore and Michael Bendersky and Hamed Zamani},\n      year={2023},\n      eprint={2304.11406},\n      archivePrefix={arXiv},\n      primaryClass={cs.CL}\n}\n```\n\n[Optimization Methods for Personalizing Large Language Models through Retrieval 
Augmentation](https://arxiv.org/abs/2404.05970)\n\n```\n@misc{salemi2024optimization,\n      title={Optimization Methods for Personalizing Large Language Models through Retrieval Augmentation}, \n      author={Alireza Salemi and Surya Kallumadi and Hamed Zamani},\n      year={2024},\n      eprint={2404.05970},\n      archivePrefix={arXiv},\n      primaryClass={cs.CL}\n}\n```\n\n[Comparing Retrieval-Augmentation and Parameter-Efficient Fine-Tuning for Privacy-Preserving Personalization of Large Language Models](https://arxiv.org/abs/2409.09510)\n\n```\n@misc{salemi2024comparingretrievalaugmentationparameterefficientfinetuning,\n      title={Comparing Retrieval-Augmentation and Parameter-Efficient Fine-Tuning for Privacy-Preserving Personalization of Large Language Models}, \n      author={Alireza Salemi and Hamed Zamani},\n      year={2024},\n      eprint={2409.09510},\n      archivePrefix={arXiv},\n      primaryClass={cs.CL},\n      url={https://arxiv.org/abs/2409.09510}, \n}\n```\n\n## License\nLaMP (code and data creation methods) is licensed under the Attribution-NonCommercial-ShareAlike 4.0 International license (CC BY-NC-SA 4.0). See the [CC-BY-NC-SA-4.0.txt](CC-BY-NC-SA-4.0.txt) file for details. For the datasets in this benchmark, you should follow their respective licenses.\n\n## Acknowledgments\n\nThis work was supported in part by the Center for Intelligent Information Retrieval, in part by NSF grant #2143434, in part by the Office of Naval Research contract number N000142212688, in part by Lowe's, in part by an Amazon Research Award, Fall 2022 CFP, in part by an award from Google, and in part by an award from Microsoft. Any opinions, findings and conclusions or recommendations expressed in this material are those of the authors and do not necessarily reflect those of the sponsors.\n"
  },
  {
    "path": "ROPG/data/collators.py",
    "content": "from typing import Any, List\nimport numpy as np\nimport torch\nimport json\n\n\nclass ReaderToRetreieverCollator:\n\n    def __init__(self, tokenizer, query_max_lenght, document_max_length, number_of_ctx, scores_addr = \"\") -> None:\n        self.tokenizer = tokenizer\n        self.query_max_lenght = query_max_lenght\n        self.document_max_length = document_max_length\n        self.number_of_ctx = number_of_ctx\n        self.scores_addr = scores_addr\n        if scores_addr:\n            with open(scores_addr) as file:\n                self.scores = json.load(file)\n    \n    def __call__(self, examples: List[dict]):\n        query_tokens = self.tokenizer([x['query'] for x in examples], max_length = self.query_max_lenght, padding = True, return_tensors = 'pt', truncation = True)\n        docs = []\n        for x in examples:\n            temp_docs = []\n            for y in x['documents'][:self.number_of_ctx]:\n                temp_docs.append(y)\n            while len(temp_docs) < self.number_of_ctx:\n                temp_docs.append(\"\")\n            docs.append(temp_docs)\n        \n        profiles = []\n        for x in examples:\n            temp_docs = []\n            for y in x['profile'][:self.number_of_ctx]:\n                temp_docs.append(y)\n            profiles.append(temp_docs)\n        \n        documents_tokens = self.tokenizer([y for x in docs for y in x], max_length = self.document_max_length, padding = True, return_tensors = 'pt', truncation = True)\n        documents_tokens['input_ids'] = documents_tokens['input_ids'].view(len(examples), self.number_of_ctx, -1)\n        documents_tokens['token_type_ids'] = documents_tokens['token_type_ids'].view(len(examples), self.number_of_ctx, -1)\n        documents_tokens['attention_mask'] = documents_tokens['attention_mask'].view(len(examples), self.number_of_ctx, -1)\n        \n        scores_batch = []\n        for x in examples:\n            scores_sample = []\n            for 
prof in x['profile'][:self.number_of_ctx]:\n                score = self.scores[f'{x[\"qid\"]}-{prof[\"id\"]}']\n                scores_sample.append(score)\n            scores_batch.append(scores_sample)\n        # print(scores_batch)\n        target_txt = [x['target'] for x in examples]       \n        ctxs = torch.tensor([[len(x['documents'][:self.number_of_ctx])] for x in examples])\n        return {\n            \"query_input_ids\" : query_tokens['input_ids'],\n            \"query_token_type_ids\" : query_tokens['token_type_ids'],\n            \"query_attention_mask\" : query_tokens['attention_mask'],\n            \"documents_input_ids\" : documents_tokens['input_ids'],\n            \"documents_token_type_ids\" : documents_tokens['token_type_ids'],\n            \"documents_attention_mask\" : documents_tokens['attention_mask'],\n            \"documents_ctxs\" : ctxs,\n            \"batch_docs_text\" : profiles,\n            \"batch_questions_text\" : [x['query_raw'] for x in examples],\n            \"target_txt\" : target_txt,\n            \"scores_gold\" : scores_batch\n        }\n"
  },
  {
    "path": "ROPG/data/datasets.py",
    "content": "from torch.utils.data import Dataset\nimport json\nimport datasets\n\ndef get_all_labels(task):\n    if task == \"LaMP-1\":\n        return [\"[1]\",\"[2]\"]\n    elif task == \"LaMP-2\":\n        return ['sci-fi', 'based on a book', 'comedy', 'action', 'twist ending', 'dystopia', 'dark comedy', 'classic', 'psychology', 'fantasy', 'romance', 'thought-provoking', 'social commentary', 'violence', 'true story']\n    elif task == \"LaMP-3\":\n        return [\"1\", \"2\", \"3\", \"4\", \"5\"]\n    else:\n        return []\n\ndef create_preprocessor(tokenizer, max_length):\n    def preprocess_dataset(examples):\n        inputs = [example for example in examples[\"source\"]]\n        targets = [example for example in examples[\"target\"]]\n        model_inputs = tokenizer(inputs, text_target=targets, max_length=max_length, truncation=True)\n        return model_inputs\n    return preprocess_dataset\n\ndef create_preprocessor_chatgpt(tokenizer, max_length):\n    def preprocess_dataset(examples):\n        inputs = [example for example in examples[\"source\"]]\n        targets = [example for example in examples[\"target\"]]\n        model_inputs = tokenizer(inputs, text_target=targets, max_length=max_length, truncation=True)\n        model_inputs = tokenizer.batch_decode(model_inputs['input_ids'], skip_special_tokens=True)\n        return {\"chatgpt_inputs\" : model_inputs}\n    return preprocess_dataset\n\ndef convert_to_hf_dataset(dataset):\n    def gen():\n        for idx in range(len(dataset)):\n            yield dataset[idx]\n    return datasets.Dataset.from_generator(gen)\n\nclass GeneralSeq2SeqDataset(Dataset):\n\n    def __init__(self, data_addr, use_profile, task, create_prompt = None) -> None:\n        super().__init__()\n        with open(data_addr) as file:\n            self.data = json.load(file)\n        self.use_profile = use_profile\n        self.task = task\n        assert not (use_profile ^ (create_prompt != None)), \"You should provide a 
prompt maker function when you use profile\"\n        self.create_prompt = create_prompt\n\n    def __getitem__(self, index):\n        if self.use_profile:\n            return {\n                \"source\" : self.create_prompt(self.data[index]['input'], self.data[index]['profile'], self.task),\n                \"target\" : self.data[index]['output']\n            }\n        else:\n            return {\n                \"source\" : self.data[index]['input'],\n                \"target\" : self.data[index]['output']\n            }\n    \n    def __len__(self):\n        return len(self.data)\n\nclass ReaderToRetrieverDataset(Dataset):\n\n    def __init__(self, data_addr, task, create_query_corpus, is_llama = False) -> None:\n        super().__init__()\n        with open(data_addr) as file:\n            self.data = json.load(file)\n        self.task = task\n        self.create_query_corpus = create_query_corpus\n        self.is_llama = is_llama\n\n    def __getitem__(self, index):\n        query, corpus = self.create_query_corpus(self.data[index]['input'], self.data[index]['profile'])\n        return {\n            \"qid\" : self.data[index]['id'],\n            \"query_raw\" : self.data[index]['input'] + \" answer:\" if self.is_llama else \"\",\n            \"query\" : query,\n            \"documents\" : corpus,\n            \"profile\" : self.data[index]['profile'],\n            \"target\" : self.data[index]['output']\n        }\n    \n    def __len__(self):\n        return len(self.data)"
  },
  {
    "path": "ROPG/models/optim.py",
    "content": "import torch\n\nclass WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR):\n    def __init__(self, optimizer, warmup_steps, scheduler_steps, min_ratio, fixed_lr, last_epoch=-1):\n        self.warmup_steps = warmup_steps\n        self.scheduler_steps = scheduler_steps\n        self.min_ratio = min_ratio\n        self.fixed_lr = fixed_lr\n        super(WarmupLinearScheduler, self).__init__(\n            optimizer, self.lr_lambda, last_epoch=last_epoch\n        )\n\n    def lr_lambda(self, step):\n        if step < self.warmup_steps:\n            return (1 - self.min_ratio)*step/float(max(1, self.warmup_steps)) + self.min_ratio\n\n        if self.fixed_lr:\n            return 1.0\n\n        return max(0.0,\n            1.0 + (self.min_ratio - 1) * (step - self.warmup_steps)/float(max(1.0, self.scheduler_steps - self.warmup_steps)),\n        )\n\n\nclass FixedScheduler(torch.optim.lr_scheduler.LambdaLR):\n    \n    def __init__(self, optimizer, last_epoch=-1):\n        super(FixedScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)\n   \n    def lr_lambda(self, step):\n        return 1.0\n\n\ndef set_optim(opt, model):\n    if opt.optim == 'adam':\n        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)\n    elif opt.optim == 'adamw':\n        optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)\n    if opt.scheduler == 'fixed':\n        scheduler = FixedScheduler(optimizer)\n    elif opt.scheduler == 'linear':\n        if opt.scheduler_steps is None:\n            scheduler_steps = opt.total_steps\n        else:\n            scheduler_steps = opt.scheduler_steps\n        scheduler = WarmupLinearScheduler(optimizer, warmup_steps=opt.warmup_steps, scheduler_steps=scheduler_steps, min_ratio=0., fixed_lr=opt.fixed_lr)\n    return optimizer, scheduler"
  },
  {
    "path": "ROPG/models/retriever.py",
    "content": "from typing import Any\nfrom transformers import BertModel\nfrom transformers.configuration_utils import PretrainedConfig\nimport torch\n\nclass Contriever(BertModel):\n    def __init__(self, config, pooling=\"average\", **kwargs):\n        super().__init__(config, add_pooling_layer=False)\n        if not hasattr(config, \"pooling\"):\n            self.config.pooling = pooling\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        normalize=False,\n    ):\n\n        model_output = super().forward(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n        )\n\n        last_hidden = model_output[\"last_hidden_state\"]\n        last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)\n\n        if self.config.pooling == \"average\":\n            emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]\n        elif self.config.pooling == \"cls\":\n            emb = last_hidden[:, 0]\n\n        if normalize:\n            emb = torch.nn.functional.normalize(emb, dim=-1)\n        return emb\n\n\n\n"
  },
  {
    "path": "ROPG/prompts/contriever_retriever.py",
    "content": "import torch\nfrom prompts.utils import batchify\n\ndef mean_pooling(token_embeddings, mask):\n    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)\n    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]\n    return sentence_embeddings\n\n@torch.no_grad()  # retrieval is inference only, so skip building the autograd graph\ndef retrieve_top_k_with_contriever(contriver, tokenizer, corpus, profile, query, k):\n    query_tokens = tokenizer([query], padding=True, truncation=True, return_tensors='pt').to(\"cuda:0\")\n    output_query = contriver(**query_tokens)\n    output_query = mean_pooling(output_query.last_hidden_state, query_tokens['attention_mask'])\n    batch_size = 4\n    scores = []\n    batched_corpus = batchify(corpus, batch_size)\n    for batch in batched_corpus:\n        tokens_batch = tokenizer(batch, padding=True, truncation=True, return_tensors='pt').to(\"cuda:0\")\n        outputs_batch = contriver(**tokens_batch)\n        outputs_batch = mean_pooling(outputs_batch.last_hidden_state, tokens_batch['attention_mask'])\n        temp_scores = output_query.squeeze() @ outputs_batch.T\n        scores.extend(temp_scores.tolist())\n    # torch.topk raises if k exceeds the number of profile items\n    topk_values, topk_indices = torch.topk(torch.tensor(scores), min(k, len(scores)))\n    return [profile[m] for m in topk_indices.tolist()]\n"
  },
  {
    "path": "ROPG/prompts/prompts.py",
    "content": "from rank_bm25 import BM25Okapi\nfrom transformers import AutoTokenizer, AutoModel\nfrom prompts.utils import extract_strings_between_quotes, extract_after_article, extract_after_review, extract_after_paper, add_string_after_title, extract_after_colon, extract_after_abstract, extract_after_description\nfrom prompts.contriever_retriever import retrieve_top_k_with_contriever\nimport random\n\ndef classification_citation_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"title\"]} {x[\"abstract\"]} date: {x[\"date\"]}' for x in profile]\n    extracted = extract_strings_between_quotes(inp)\n    query = f'{extracted[1]} {extracted[2]}'\n    return corpus, query\n\ndef classification_news_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"title\"]} {x[\"text\"]} date: {x[\"date\"]}' for x in profile]\n    query = extract_after_article(inp)\n    return corpus, query\n\ndef classification_movies_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"description\"]} date: {x[\"date\"]}' for x in profile]\n    query = extract_after_description(inp)\n    return corpus, query\n\ndef classification_review_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"text\"]} date: {x[\"date\"]}' for x in profile]\n    query = extract_after_review(inp)\n    return corpus, query\n\ndef generation_news_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"title\"]} {x[\"text\"]} date: {x[\"date\"]}' for x in profile]\n    query = extract_after_article(inp)\n    return corpus, query\n\ndef generation_paper_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"title\"]} {x[\"abstract\"]} date: {x[\"date\"]}' for x in profile]\n    query = extract_after_paper(inp)\n    return corpus, query\n\ndef generation_paper_long_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"title\"]} {x[\"abstract\"]} date: {x[\"date\"]}' for x in profile]\n    query = extract_after_abstract(inp)\n    return corpus, query\n\n\ndef parphrase_tweet_query_corpus_maker(inp, 
profile):\n    corpus = [f'{x[\"text\"]} date: {x[\"date\"]}' for x in profile]\n    query = extract_after_colon(inp)\n    return corpus, query\n\ndef generation_avocado_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"text\"]} date: {x[\"date\"]}' for x in profile]\n    query = extract_after_colon(inp)\n    return corpus, query\n\ndef generation_avocado_long_query_corpus_maker(inp, profile):\n    corpus = [f'{x[\"text\"]} {x[\"title\"]} date: {x[\"date\"]}' for x in profile]\n    query = extract_after_colon(inp)\n    return corpus, query\n\ndef create_classification_citation_prompt(inp, profile, max_length, tokenizer):\n    prompts = []\n    per_p_max_length = (max_length - 2 * (len(profile) - 1)) // len(profile)\n    saved_tokens = 0\n    for p in profile:\n        tokens = tokenizer(p[\"title\"], max_length=per_p_max_length + saved_tokens - 2, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - 2\n        new_title = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{new_title}\"'\n        prompts.append(prompt)\n    return add_string_after_title(inp, \", and \".join(prompts))\n\ndef create_classification_news_prompt(inp, profile, max_length, tokenizer): # good\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1)) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'the category for the article: \" \" is \"{p[\"category\"]}\" ')['input_ids'])\n        tokens = tokenizer(p[\"text\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_text = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'the category for the article: \"{new_text}\" is \"{p[\"category\"]}\" '\n        prompts.append(prompt)\n    return f'{\", and 
\".join(prompts)}. {inp}'\n\ndef create_classification_review_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1)) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'{p[\"score\"]} is the score for \" \" ')['input_ids'])\n        tokens = tokenizer(p[\"text\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_text = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'{p[\"score\"]} is the score for \"{new_text}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. {inp}'\n\ndef create_generation_news_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1)) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'\"{p[\"title\"]}\" is the title for \" \" ')['input_ids'])\n        tokens = tokenizer(p[\"text\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_text = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{p[\"title\"]}\" is the title for \"{new_text}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. 
{inp}'\n\ndef create_generation_paper_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1) - len(tokenizer(\"Following the given patterns\")['input_ids'])) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'\"{p[\"title\"]}\" is a title for \" \" ')['input_ids'])\n        tokens = tokenizer(p[\"abstract\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_abstract = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{p[\"title\"]}\" is a title for \"{new_abstract}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. Following the given patterns {inp}'\n\ndef create_generation_paper_long_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1) - len(tokenizer(\"Following the given patterns\")['input_ids'])) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'\"{p[\"title\"]}\" is the title for \" \" ')['input_ids'])\n        tokens = tokenizer(p[\"abstract\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_abstract = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{p[\"title\"]}\" is the title for \"{new_abstract}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. Following the given patterns {inp}'\n\ndef create_parphrase_tweet_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1) - len(tokenizer(\"are written by a person. 
Following the given patterns\")['input_ids'])) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'\"\" ')['input_ids'])\n        tokens = tokenizer(p[\"text\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_asbtract = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{new_asbtract}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)} are written by a person. Following the given patterns {inp}'\n\ndef create_generation_avocado_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1)) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'\"{p[\"title\"]}\" is the title for \" \" ')['input_ids'])\n        tokens = tokenizer(p[\"text\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_text = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{p[\"title\"]}\" is the title for \"{new_text}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. {inp}'\n\ndef create_generation_avocado_long_prompt(inp, profile, max_length, tokenizer):\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1) - len(tokenizer(\"are written by user. 
Following the given patterns\")['input_ids'])) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'\"{p[\"title\"]}\" is the title for \" \" ')['input_ids'])\n        tokens = tokenizer(p[\"text\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_text = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'\"{p[\"title\"]}\" is the title for \"{new_text}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. Following the given patterns {inp}'\n\ndef create_classification_movies_prompt(inp, profile, max_length, tokenizer): # good\n    per_p_max_length = (max_length - 1 - 2 * (len(profile) - 1)) // len(profile)\n    saved_tokens = 0\n    prompts = []\n    for p in profile:\n        needed_part_len = len(tokenizer(f'the tag for the movie: \" \" is \"{p[\"tag\"]}\" ')['input_ids'])\n        tokens = tokenizer(p[\"description\"], max_length=per_p_max_length + saved_tokens - needed_part_len, truncation=True)\n        saved_tokens += per_p_max_length - len(tokens['input_ids']) - needed_part_len\n        new_text = tokenizer.batch_decode([tokens['input_ids']], skip_special_tokens=True)[0]\n        prompt = f'the tag for the movie: \"{new_text}\" is \"{p[\"tag\"]}\" '\n        prompts.append(prompt)\n    return f'{\", and \".join(prompts)}. 
{inp}'\n\ndef create_query_corpus_generator(task):\n    def create_query_corpus(inp, profile):\n        if task == \"LaMP-1\":\n            corpus, query = classification_citation_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-2\":\n            corpus, query = classification_movies_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-3\":\n            corpus, query = classification_review_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-4\":\n            corpus, query = generation_news_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-5\":\n            corpus, query = generation_paper_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-7\":\n            corpus, query = parphrase_tweet_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-6\":\n            corpus, query = generation_avocado_query_corpus_maker(inp, profile)\n        return query, corpus\n    return create_query_corpus\n\ndef create_prompt_generator(num_retrieve, ret_type = \"bm25\", is_ranked = False, max_length = 512, tokenizer = None):\n    contriever = None\n    contriever_tokenizer = None\n    if ret_type == \"contriever\" and not is_ranked:\n        # keep the retriever tokenizer separate; prompt budgets are measured with the reader tokenizer\n        contriever_tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')\n        contriever = AutoModel.from_pretrained('facebook/contriever').to(\"cuda:0\")\n        contriever.eval()\n        if tokenizer is None:\n            tokenizer = contriever_tokenizer\n\n    def prompt(inp, profile, task):\n        if task == \"LaMP-1\":\n            corpus, query = classification_citation_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-2\":\n            corpus, query = classification_movies_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-3\":\n            corpus, query = classification_review_query_corpus_maker(inp, profile)\n        elif task == 
\"LaMP-4\":\n            corpus, query = generation_news_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-5\":\n            corpus, query = generation_paper_query_corpus_maker(inp, profile)\n        elif task == 
\"LaMP-7\":\n            corpus, query = parphrase_tweet_query_corpus_maker(inp, profile)\n        elif task == \"LaMP-6\":\n            corpus, query = generation_avocado_query_corpus_maker(inp, profile)\n        \n        if not is_ranked:\n            if ret_type == \"bm25\":\n                tokenized_corpus = [x.split() for x in corpus]\n                bm25 = BM25Okapi(tokenized_corpus)\n                tokenized_query = query.split()\n                selected_profs = bm25.get_top_n(tokenized_query, profile, n=num_retrieve)\n            elif ret_type == \"contriever\":\n                selected_profs = retrieve_top_k_with_contriever(contriever, contriever_tokenizer, corpus, profile, query, num_retrieve)\n            elif ret_type == \"random\":\n                selected_profs = random.choices(profile, k = num_retrieve)\n            elif ret_type == \"rec\":\n                selected_profs = profile[-num_retrieve:][::-1]\n        else:\n            selected_profs = profile[:num_retrieve]\n        factor = 0.6\n        while True:\n            try:\n                max_len_prompt = max_length - min(len(tokenizer(inp)['input_ids']), int(factor * max_length))\n                if task == \"LaMP-1\":\n                    return create_classification_citation_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n                elif task == \"LaMP-3\":\n                    return create_classification_review_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n                elif task == \"LaMP-2\":\n                    return create_classification_movies_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n                elif task == \"LaMP-4\":\n                    return create_generation_news_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n                elif task == \"LaMP-5\":\n                    return create_generation_paper_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n                elif task == \"LaMP-7\":\n                    return 
create_parphrase_tweet_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n                elif task == \"LaMP-6\":\n                    return create_generation_avocado_prompt(inp, selected_profs, max_len_prompt, tokenizer)\n            except Exception:\n                # the profile did not fit the budget: cap the input's share lower and retry\n                factor -= 0.1\n                if factor < 0:\n                    print(len(profile))\n                    print(len(selected_profs))\n                    return inp\n                    # raise RuntimeError(\"not possible\")\n    return prompt, contriever"
  },
  {
    "path": "ROPG/prompts/utils.py",
    "content": "import re\n\ndef extract_strings_between_quotes(input_string):\n    pattern = r'\"(.*?)\"'\n    titles = re.findall(pattern, input_string)\n    return titles\n\ndef extract_after_article(input_string):\n    article_index = input_string.find('article:')\n    if article_index == -1:\n        return None\n    return input_string[article_index + len('article:'):].strip()\n\ndef extract_after_description(input_string):\n    article_index = input_string.find('description:')\n    if article_index == -1:\n        return None\n    return input_string[article_index + len('description:'):].strip()\n\ndef extract_after_review(input_string):\n    article_index = input_string.find('review:')\n    if article_index == -1:\n        return None\n    return input_string[article_index + len('review:'):].strip()\n\ndef extract_after_paper(input_string):\n    article_index = input_string.find('paper:')\n    if article_index == -1:\n        return None\n    return input_string[article_index + len('paper:'):].strip()\n\ndef extract_after_colon(input_string):\n    article_index = input_string.find(':')\n    if article_index == -1:\n        return None\n    return input_string[article_index + len(':'):].strip()\n\ndef extract_after_abstract(input_string):\n    article_index = input_string.find('abstract:')\n    if article_index == -1:\n        return None\n    return input_string[article_index + len('abstract:'):].strip()\n\n\ndef add_string_after_title(original_string, string_to_add):\n    title_index = original_string.find(\"title\")\n    \n    if title_index == -1:\n        return original_string\n    \n    return original_string[:title_index+5] + \", and \" + string_to_add + original_string[title_index+5:]\n\ndef batchify(lst, batch_size):\n    return [lst[i:i+batch_size] for i in range(0, len(lst), batch_size)]"
  },
  {
    "path": "ROPG/requirements.txt",
    "content": "evaluate==0.4.0\nnumpy==1.24.3\nrank_bm25==0.2.2\ntorch==2.0.1\ntransformers==4.29.2\nrouge_score"
  },
  {
    "path": "ROPG/train_kd.py",
    "content": "from pathlib import Path\nfrom utils.distributed import init_distributed_mode, init_signal_handler\nimport torch\nimport numpy as np\nimport os\nfrom torch.utils.data import DataLoader, DistributedSampler, SequentialSampler, RandomSampler\nimport tqdm\nimport argparse\nfrom transformers import AutoModel, AutoTokenizer, AutoModelForSeq2SeqLM\nfrom data.collators import ReaderToRetreieverCollator\nfrom data.datasets import ReaderToRetrieverDataset, get_all_labels\nfrom prompts.prompts import create_query_corpus_generator\nfrom trainers.trainer import KDReaderToRetrieverTrainer\nfrom models.retriever import Contriever\nfrom utils.util import average_main\nfrom utils.log import init_logger\nfrom models.optim import set_optim\nfrom utils.util import save_checkpoint, load_checkpoint\n\n\ndef train(opts, model, optimizer, scheduler, step, dataset, collator, checkpoint_path):\n    \n    if opts.is_main:\n        try:\n            tb_logger = torch.utils.tensorboard.SummaryWriter(Path(opts.checkpoint_dir)/opts.name)\n        except:\n           tb_logger = None\n           logger.warning('Tensorboard is not available.')\n\n    torch.manual_seed(opts.global_rank + opts.seed)\n    if opts.is_distributed:\n        train_sampler = DistributedSampler(dataset, num_replicas=opts.n_gpu_per_node, rank=opts.local_rank)\n    else:\n        train_sampler = RandomSampler(dataset)\n    bar = tqdm.tqdm(total=opts.total_steps)\n    train_dataloader = DataLoader(\n        dataset,\n        sampler = train_sampler,\n        batch_size = opts.per_gpu_batch_size,\n        drop_last = True,\n        num_workers = 10,\n        collate_fn = collator,\n    )\n\n    loss, curr_loss = 0.0, 0.0\n    epoch = 1\n    model.train()\n    temp_step = 0\n    while step < opts.total_steps:\n        epoch += 1\n        for i, batch in enumerate(train_dataloader):\n            temp_step += 1\n            batch = {k:v.cuda() if type(v) != list else v for k, v in batch.items()}\n            
train_loss, scores, gold_scores = model(**batch)\n\n            train_loss.backward()\n            if temp_step % opts.accumulation_steps == 0:\n                step += 1\n                temp_step = 0\n                torch.nn.utils.clip_grad_norm_(model.parameters(), opts.clip)\n                optimizer.step()\n                scheduler.step()\n                model.zero_grad()\n                if opts.is_main:\n                    bar.update(1)\n                    # save only right after an optimizer step, so nothing is written at step 0 or repeated per micro-batch\n                    if step % opts.save_freq == 0 and step < opts.total_steps:\n                        save_checkpoint(model, optimizer, scheduler, step, opts, checkpoint_path, f\"step-{step}\")\n\n            train_loss = average_main(train_loss, opts)\n            curr_loss += train_loss.item()\n            if step >= opts.total_steps:\n                if opts.is_main:\n                    save_checkpoint(model, optimizer, scheduler, step, opts, checkpoint_path, f\"step-{step}\")\n                break\n    \nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--train_data\", required = True, help=\"training data\")\nparser.add_argument(\"--do_train\", action='store_true', help=\"perform training\")\nparser.add_argument(\"--scores_path\", required=True, help=\"path to the pre-computed profile item scores\")\n\nparser.add_argument(\"--max_length_query\", type = int, default = 512, help=\"max length query\")\nparser.add_argument(\"--max_length_document\", type = int, default = 512, help=\"max length document\")\n\nparser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment')\nparser.add_argument('--checkpoint_dir', type=str, default='./checkpoint/', help='models are saved here')\n\nparser.add_argument(\"--per_gpu_batch_size\", default=1, type=int, \n           help=\"Batch size per GPU/CPU for training.\")\nparser.add_argument(\"--local-rank\", type=int, default=-1,\n           help=\"For distributed 
training: local_rank\")\nparser.add_argument(\"--main_port\", type=int, default=-1,\n           help=\"Main port (for multi-node SLURM 
jobs)\")\nparser.add_argument('--seed', type=int, default=0, help=\"random seed for initialization\")\nparser.add_argument('--save_freq', type=int, default=5000,\n           help='save model every <save_freq> steps during training')\nparser.add_argument('--warmup_steps', type=int, default=1000, help=\"number of warmup steps\")\nparser.add_argument('--total_steps', type=int, default=1000, help=\"number of training steps\")\nparser.add_argument('--scheduler_steps', type=int, default=None, \n           help='total number of scheduler steps; if None, defaults to total_steps')\nparser.add_argument('--accumulation_steps', type=int, default=1, help=\"number of gradient accumulation steps\")\nparser.add_argument('--dropout', type=float, default=0.1, help='dropout rate')\nparser.add_argument('--lr', type=float, default=1e-5, help='learning rate')\nparser.add_argument('--clip', type=float, default=1., help='max gradient norm for clipping')\nparser.add_argument('--optim', type=str, default='adam', help=\"optimizer used for training (adam or adamw)\")\nparser.add_argument('--scheduler', type=str, default='fixed', help=\"learning-rate scheduler used for training (fixed or linear)\")\nparser.add_argument('--weight_decay', type=float, default=0.0, help=\"weight decay rate\")\nparser.add_argument('--fixed_lr', action='store_true', help=\"keep the lr constant after warmup (linear scheduler only)\")\n\nparser.add_argument('--ctx_size', type=int, default=20, help=\"number of docs per query for training\")\n\nparser.add_argument(\"--task\", required = True, help=\"task name\")\nparser.add_argument(\"--model_path\", default=\"\", help=\"path to a checkpoint to load\")\n\nparser.add_argument('--temperature', type=float, default=1.0, help=\"temperature for distillation\")\nparser.add_argument('--cache_dir', default=\"cache\")\n\n\nif __name__ == \"__main__\":\n    opts = parser.parse_args()\n\n    torch.manual_seed(opts.seed)\n    init_distributed_mode(opts)\n    
init_signal_handler()\n\n    checkpoint_path = 
Path(opts.checkpoint_dir)/opts.name\n    checkpoint_exists = checkpoint_path.exists()\n    \n    if opts.is_distributed:\n        torch.distributed.barrier()\n    \n    checkpoint_path.mkdir(parents = True, exist_ok = True)\n    opts.output_dir = checkpoint_path\n\n\n    logger = init_logger(\n        opts.is_main,\n        opts.is_distributed,\n        checkpoint_path / 'run.log'\n    )\n\n    logger.info(opts)\n\n    model = Contriever.from_pretrained('facebook/contriever', cache_dir = opts.cache_dir)\n    tokenizer = AutoTokenizer.from_pretrained('facebook/contriever', cache_dir = opts.cache_dir)\n    collator = ReaderToRetreieverCollator(tokenizer = tokenizer, query_max_lenght = opts.max_length_query, document_max_length = opts.max_length_document, number_of_ctx = opts.ctx_size, scores_addr = opts.scores_path)\n    \n    task = opts.task\n    \n    reader_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base', cache_dir = opts.cache_dir)\n    query_corpus_generator = create_query_corpus_generator(task)\n\n    greater_is_better = True\n\n    if task == \"LaMP-1\":\n        train_dataset, labels = ReaderToRetrieverDataset(opts.train_data, task, query_corpus_generator), get_all_labels(task)\n        best_metric_generation = \"accuracy\"\n    elif task == \"LaMP-2\":\n        train_dataset, labels = ReaderToRetrieverDataset(opts.train_data, task, query_corpus_generator), get_all_labels(task)\n        best_metric_generation = \"accuracy\"\n    elif task == \"LaMP-3\":\n        train_dataset, labels = ReaderToRetrieverDataset(opts.train_data, task, query_corpus_generator), get_all_labels(task)\n        best_metric_generation = \"mae\"\n        greater_is_better = False\n    elif task == \"LaMP-4\":\n        train_dataset = ReaderToRetrieverDataset(opts.train_data, task, query_corpus_generator)\n        best_metric_generation = \"rouge-1\"\n    elif task == \"LaMP-5\":\n        train_dataset = ReaderToRetrieverDataset(opts.train_data, task, 
query_corpus_generator)\n        best_metric_generation = \"rouge-1\"\n    elif task == \"LaMP-7\":\n        train_dataset = ReaderToRetrieverDataset(opts.train_data, task, query_corpus_generator)\n        best_metric_generation = \"rouge-1\"\n    elif task == \"LaMP-6\":\n        train_dataset = ReaderToRetrieverDataset(opts.train_data, task, query_corpus_generator)\n        best_metric_generation = \"rouge-1\"\n    \n    opts.greater_is_better = greater_is_better\n    opts.reader_gold_metric = best_metric_generation\n\n    if not checkpoint_exists and not opts.model_path:\n        model = KDReaderToRetrieverTrainer(model = model, args = opts)\n        model = model.to(opts.local_rank)\n        optimizer, scheduler = set_optim(opts, model)\n        step = 0\n    elif checkpoint_exists and opts.model_path and opts.do_train:\n        model, optimizer, scheduler, opt_checkpoint, step = load_checkpoint(Contriever, opts.model_path, opts)\n        model = KDReaderToRetrieverTrainer(model = model, args = opts)\n    \n    if opts.is_distributed:\n        model = torch.nn.parallel.DistributedDataParallel(\n            model,\n            device_ids=[opts.local_rank],\n            output_device=opts.local_rank,\n            find_unused_parameters=True,\n        )\n\n    if opts.do_train:\n        train(opts, model, optimizer, scheduler, step, train_dataset, collator, checkpoint_path)"
  },
  {
    "path": "ROPG/train_rl.py",
    "content": "from pathlib import Path\nfrom utils.distributed import init_distributed_mode, init_signal_handler\nimport torch\nimport numpy as np\nimport os\nfrom torch.utils.data import DataLoader, DistributedSampler, SequentialSampler, RandomSampler\nimport tqdm\nimport argparse\nfrom transformers import AutoModel, AutoTokenizer, AutoModelForSeq2SeqLM\nfrom data.collators import ReaderToRetreieverCollator\nfrom data.datasets import ReaderToRetrieverDataset, get_all_labels\nfrom prompts.prompts import create_prompt_generator, create_query_corpus_generator\nfrom trainers.trainer import RLReaderToRetrieverTrainer\nfrom models.retriever import Contriever\nfrom utils.util import average_main\nfrom utils.log import init_logger\nfrom models.optim import set_optim\nfrom utils.util import save_checkpoint, load_checkpoint\n\n\ndef train(opts, model, optimizer, scheduler, step, dataset, collator, checkpoint_path):\n    \n    if opts.is_main:\n        try:\n            tb_logger = torch.utils.tensorboard.SummaryWriter(Path(opts.checkpoint_dir)/opts.name)\n        except:\n           tb_logger = None\n           logger.warning('Tensorboard is not available.')\n\n    torch.manual_seed(opts.global_rank + opts.seed)\n    if opts.is_distributed:\n        train_sampler = DistributedSampler(dataset, num_replicas=opts.n_gpu_per_node, rank=opts.local_rank)\n    else:\n        train_sampler = RandomSampler(dataset)\n    bar = tqdm.tqdm(total=opts.total_steps)\n    train_dataloader = DataLoader(\n        dataset,\n        sampler = train_sampler,\n        batch_size = opts.per_gpu_batch_size,\n        drop_last = True,\n        num_workers = 10,\n        collate_fn = collator,\n    )\n\n    loss, curr_loss = 0.0, 0.0\n    epoch = 1\n    model.train()\n    temp_step = 0\n    while step < opts.total_steps:\n        epoch += 1\n        for i, batch in enumerate(train_dataloader):\n            temp_step += 1\n            batch = {k:v.cuda() if type(v) != list else v for k, v in 
batch.items()}\n            train_loss, scores, gold_scores = model(**batch)\n\n            train_loss.backward()\n            if temp_step % opts.accumulation_steps == 0:\n                step += 1\n                temp_step = 0\n                torch.nn.utils.clip_grad_norm_(model.parameters(), opts.clip)\n                optimizer.step()\n                scheduler.step()\n                model.zero_grad()\n                if opts.is_main:\n                    bar.update(1)\n                    # save only right after an optimizer step, so nothing is written at step 0 or repeated per micro-batch\n                    if step % opts.save_freq == 0 and step < opts.total_steps:\n                        save_checkpoint(model, optimizer, scheduler, step, opts, checkpoint_path, f\"step-{step}\")\n\n            train_loss = average_main(train_loss, opts)\n            curr_loss += train_loss.item()\n            if step >= opts.total_steps:\n                if opts.is_main:\n                    save_checkpoint(model, optimizer, scheduler, step, opts, checkpoint_path, f\"step-{step}\")\n                break\n    \nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--train_data\", required = True, help=\"training data\")\nparser.add_argument(\"--do_train\", action='store_true', help=\"perform training\")\nparser.add_argument(\"--scores_path\", required=True, help=\"path to the pre-computed profile item scores\")\n\nparser.add_argument(\"--max_length_query\", type = int, default = 512, help=\"max length query\")\nparser.add_argument(\"--max_length_document\", type = int, default = 512, help=\"max length document\")\n\nparser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment')\nparser.add_argument('--checkpoint_dir', type=str, default='./checkpoint/', help='models are saved here')\n\nparser.add_argument(\"--per_gpu_batch_size\", default=1, type=int, \n           help=\"Batch size per GPU/CPU for training.\")\nparser.add_argument(\"--local-rank\", 
type=int, default=-1,\n           help=\"For distributed training: local_rank\")\nparser.add_argument(\"--main_port\", type=int, default=-1,\n           help=\"Main port (for multi-node SLURM 
jobs)\")\nparser.add_argument('--seed', type=int, default=0, help=\"random seed for initialization\")\nparser.add_argument('--save_freq', type=int, default=5000,\n           help='save model every <save_freq> steps during training')\nparser.add_argument('--warmup_steps', type=int, default=1000, help=\"number of warmup steps\")\nparser.add_argument('--total_steps', type=int, default=1000, help=\"number of training steps\")\nparser.add_argument('--scheduler_steps', type=int, default=None, \n           help='total number of step for the scheduler, if None then scheduler_total_step = total_step')\nparser.add_argument('--accumulation_steps', type=int, default=1, help=\"number of gradient accumulation steps\")\nparser.add_argument('--dropout', type=float, default=0.1, help='dropout rate')\nparser.add_argument('--lr', type=float, default=1e-5, help='learning rate')\nparser.add_argument('--clip', type=float, default=1., help='gradient clipping')\nparser.add_argument('--optim', type=str, default='adam', help=\"optimizer which is used for training\")\nparser.add_argument('--scheduler', type=str, default='fixed', help=\"scheduler which is used for training\")\nparser.add_argument('--weight_decay', type=float, default=0.0, help=\"weight decay rate\")\nparser.add_argument('--fixed_lr', action='store_true', help=\"use a fixed lr\")\n\nparser.add_argument('--ctx_size', type=int, default=20, help=\"number of docs per query for training\")\n\nparser.add_argument(\"--task\", required = True, help=\"task name\")\nparser.add_argument(\"--model_path\", default=\"\", help=\"address to a checkpoint to be load\")\nparser.add_argument('--cache_dir', default=\"cache\")\n\n\nif __name__ == \"__main__\":\n    opts = parser.parse_args()\n\n    torch.manual_seed(opts.seed)\n    init_distributed_mode(opts)\n    init_signal_handler()\n\n    checkpoint_path = Path(opts.checkpoint_dir)/opts.name\n    checkpoint_exists = checkpoint_path.exists()\n    \n    if opts.is_distributed:\n        
torch.distributed.barrier()\n    \n    checkpoint_path.mkdir(parents = True, exist_ok = True)\n    opts.output_dir = checkpoint_path\n\n\n    logger = init_logger(\n        opts.is_main,\n        opts.is_distributed,\n        checkpoint_path / 'run.log'\n    )\n\n    logger.info(opts)\n\n    model = Contriever.from_pretrained('facebook/contriever', cache_dir = opts.cache_dir)\n    tokenizer = AutoTokenizer.from_pretrained('facebook/contriever', cache_dir = opts.cache_dir)\n    collator = ReaderToRetreieverCollator(tokenizer = tokenizer, query_max_lenght = opts.max_length_query, document_max_length = opts.max_length_document, number_of_ctx = opts.ctx_size, scores_addr = opts.scores_path)\n    \n    task = opts.task\n    \n    query_corpus_generator = create_query_corpus_generator(task)\n\n    greater_is_better = True\n\n    if task == \"LaMP-1\":\n        train_dataset, labels = ReaderToRetrieverDataset(opts.train_data, task, query_corpus_generator), get_all_labels(task)\n        best_metric_generation = \"accuracy\"\n    elif task == \"LaMP-2\":\n        train_dataset, labels = ReaderToRetrieverDataset(opts.train_data, task, query_corpus_generator), get_all_labels(task)\n        best_metric_generation = \"accuracy\"\n    elif task == \"LaMP-3\":\n        train_dataset, labels = ReaderToRetrieverDataset(opts.train_data, task, query_corpus_generator), get_all_labels(task)\n        best_metric_generation = \"mae\"\n        greater_is_better = False\n    elif task == \"LaMP-4\":\n        train_dataset = 
ReaderToRetrieverDataset(opts.train_data, task, query_corpus_generator)\n        best_metric_generation = \"rouge-1\"\n    elif task == \"LaMP-5\":\n        train_dataset = ReaderToRetrieverDataset(opts.train_data, task, query_corpus_generator)\n        best_metric_generation = \"rouge-1\"\n    elif task == \"LaMP-7\":\n        train_dataset = ReaderToRetrieverDataset(opts.train_data, task, query_corpus_generator)\n        best_metric_generation = \"rouge-1\"\n    elif task == \"LaMP-6\":\n        train_dataset = ReaderToRetrieverDataset(opts.train_data, task, query_corpus_generator)\n        best_metric_generation = \"rouge-1\"\n    else:\n        raise ValueError(f\"Unsupported task: {task}\")\n    \n    opts.greater_is_better = greater_is_better\n    opts.reader_gold_metric = best_metric_generation\n\n    if not opts.model_path:\n        # fresh run: wrap the pre-trained Contriever and create a new optimizer and scheduler\n        model = RLReaderToRetrieverTrainer(model = model, args = opts)\n        model = model.to(opts.local_rank)\n        optimizer, scheduler = set_optim(opts, model)\n        step = 0\n    else:\n        # resume: restore weights, optimizer, scheduler and step from opts.model_path\n        model, optimizer, scheduler, opt_checkpoint, step = load_checkpoint(Contriever, opts.model_path, opts)\n        model = RLReaderToRetrieverTrainer(model = model, args = opts)\n    \n    \n    if opts.is_distributed:\n        model = torch.nn.parallel.DistributedDataParallel(\n            model,\n            device_ids=[opts.local_rank],\n            output_device=opts.local_rank,\n            find_unused_parameters=True,\n        )\n\n    if opts.do_train:\n        train(opts, model, optimizer, scheduler, step, train_dataset, collator, checkpoint_path)"
  },
  {
    "path": "ROPG/trainers/trainer.py",
    "content": "from dataclasses import dataclass\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\nfrom torch.nn import DataParallel\nimport torch\nfrom torch import nn\nfrom torch.utils.data import Dataset\nfrom transformers import Trainer, TrainingArguments\nimport math\n\n\ndef select_elements(tensor, k):\n    B = tensor.size(0)\n    selected_rows = []\n    for i in range(B):\n        start_idx = i * k\n        end_idx = (i + 1) * k\n        selected_rows.append(tensor[i, start_idx:end_idx])\n    selected_tensor = torch.stack(selected_rows)\n    return selected_tensor\n\ndef loss_fn_reinforce(probs, rewards):\n    x = torch.mul(probs, rewards)\n    x = torch.sum(x) / (probs.shape[0] * probs.shape[1])\n    return -x\n\nclass KDReaderToRetrieverTrainer(nn.Module):\n\n    def __init__(self, model, args) -> None:\n        super().__init__()\n        self.model = model\n        self.args = args\n        \n    def forward(\n            self, \n            query_input_ids = None,\n            query_token_type_ids = None,\n            query_attention_mask = None,\n            documents_input_ids = None,\n            documents_token_type_ids = None,\n            documents_attention_mask = None,\n            documents_ctxs = None,\n            scores_gold = None,\n            target_txt = None,\n            **kws\n        ):\n        return self._forward(\n            query_input_ids,\n            query_token_type_ids,\n            query_attention_mask,\n            documents_input_ids,\n            documents_token_type_ids,\n            documents_attention_mask,\n            documents_ctxs,\n            target_txt,\n            scores_gold\n        )\n    \n    def _forward(\n            self, \n            query_input_ids,\n            query_token_type_ids,\n            query_attention_mask,\n            documents_input_ids,\n            documents_token_type_ids,\n            documents_attention_mask,\n            documents_ctxs,\n            
target_txt,\n            scores_gold,\n            **kws\n        ):\n        B = documents_token_type_ids.shape[0]\n        ctx_size = documents_token_type_ids.shape[1]\n\n        query_reps = self.model(input_ids = query_input_ids, token_type_ids = query_token_type_ids, attention_mask = query_attention_mask)\n        docs_reps = self.model(input_ids = documents_input_ids.view(B * ctx_size, -1), token_type_ids = documents_token_type_ids.view(B * ctx_size, -1), attention_mask = documents_attention_mask.view(B * ctx_size, -1))\n        docs_reps = docs_reps.view(B, ctx_size, -1)\n        scores = torch.einsum(\"ij,ikj->ik\", query_reps, docs_reps)\n        logits = scores / self.args.temperature\n        \n        # gold label creation\n        if self.args.greater_is_better:\n            gold_scores = torch.zeros_like(scores, device = scores.device)\n            for i, sample in enumerate(scores_gold):\n                for j, score in enumerate(sample):\n                    gold_scores[i, j] = (score[self.args.reader_gold_metric])\n        else:\n            target_numerical = [int(x) for x in target_txt]\n            worst_score_array = [max(abs(x-1), abs(x-5)) for x in target_numerical]\n            gold_scores = torch.tensor([[float(x) for y in range(ctx_size)] for x in worst_score_array], device = scores.device)\n            for i, sample in enumerate(scores_gold):\n                for j, score in enumerate(sample):\n                    gold_scores[i, j] = (worst_score_array[i] - score[self.args.reader_gold_metric]) / worst_score_array[i]\n        \n        # nn.CrossEntropyLoss applies log_softmax internally, so it must receive the\n        # temperature-scaled logits, not probabilities (that would apply softmax twice)\n        loss_fn = nn.CrossEntropyLoss()\n        loss = loss_fn(logits, torch.softmax(gold_scores / self.args.temperature, dim = -1))\n        \n        \n        return loss, torch.softmax(scores, dim = -1), gold_scores \n    \nclass RLReaderToRetrieverTrainer(nn.Module):\n\n    def __init__(self, model, args) -> None:\n        super().__init__()\n        self.model = model\n        
self.args = args\n        \n    def forward(\n            self, \n            query_input_ids = None,\n            query_token_type_ids = None,\n            query_attention_mask = None,\n            documents_input_ids = None,\n            documents_token_type_ids = None,\n            documents_attention_mask = None,\n            documents_ctxs = None,\n            scores_gold = None,\n            target_txt = None,\n            **kws\n        ):\n        return self._forward(\n            query_input_ids,\n            query_token_type_ids,\n            query_attention_mask,\n            documents_input_ids,\n            documents_token_type_ids,\n            documents_attention_mask,\n            documents_ctxs,\n            target_txt,\n            scores_gold\n        )\n        \n    def _forward(\n            self, \n            query_input_ids,\n            query_token_type_ids,\n            query_attention_mask,\n            documents_input_ids,\n            documents_token_type_ids,\n            documents_attention_mask,\n            documents_ctxs,\n            target_txt,\n            scores_gold,\n            **kws\n        ):\n        B = documents_token_type_ids.shape[0]\n        ctx_size = documents_token_type_ids.shape[1]\n\n        query_reps = self.model(input_ids = query_input_ids, token_type_ids = query_token_type_ids, attention_mask = query_attention_mask)\n        docs_reps = self.model(input_ids = documents_input_ids.view(B * ctx_size, -1), token_type_ids = documents_token_type_ids.view(B * ctx_size, -1), attention_mask = documents_attention_mask.view(B * ctx_size, -1))\n        docs_reps = docs_reps.view(B, ctx_size, -1)\n        scores = torch.einsum(\"ij,ikj->ik\", query_reps, docs_reps)\n        probs = torch.softmax(scores, dim = -1)\n        \n        sample_idx = probs.multinomial(1)\n        \n        # gold label creation\n        if self.args.greater_is_better:\n            gold_scores = torch.zeros_like(scores, device = 
scores.device)\n            for i, sample in enumerate(scores_gold):\n                for j, score in enumerate(sample):\n                    gold_scores[i, j] = (score[self.args.reader_gold_metric] - sample[0][self.args.reader_gold_metric])\n        else:\n            target_numerical = [int(x) for x in target_txt]\n            worst_score_array = [max(abs(x-1), abs(x-5)) for x in target_numerical]\n            gold_scores = torch.tensor([[float(x) for y in range(ctx_size)] for x in worst_score_array], device = scores.device)\n            for i, sample in enumerate(scores_gold):\n                for j, score in enumerate(sample):\n                    gold_scores[i, j] = (sample[0][self.args.reader_gold_metric] - score[self.args.reader_gold_metric]) / worst_score_array[i]\n        \n        probs = torch.gather(probs, 1, sample_idx)\n        gold_scores = torch.gather(gold_scores, 1, sample_idx)\n        loss = loss_fn_reinforce(torch.log(probs), gold_scores)\n        \n        return loss, torch.softmax(scores, dim = -1), gold_scores "
  },
  {
    "path": "ROPG/utils/distributed.py",
    "content": "from logging import getLogger\nimport os\nimport sys\nimport torch\nimport socket\nimport signal\nimport subprocess\nimport datetime\n\n\nlogger = getLogger()\n\ndef sig_handler(signum, frame):\n    logger.warning(\"Signal handler called with signal \" + str(signum))\n    prod_id = int(os.environ['SLURM_PROCID'])\n    logger.warning(\"Host: %s - Global rank: %i\" % (socket.gethostname(), prod_id))\n    if prod_id == 0:\n        logger.warning(\"Requeuing job \" + os.environ['SLURM_JOB_ID'])\n        os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID'])\n    else:\n        logger.warning(\"Not the main process, no need to requeue.\")\n    sys.exit(-1)\n\n\ndef term_handler(signum, frame):\n    logger.warning(\"Signal handler called with signal \" + str(signum))\n    logger.warning(\"Bypassing SIGTERM.\")\n\n\ndef init_signal_handler():\n    signal.signal(signal.SIGUSR1, sig_handler)\n    signal.signal(signal.SIGTERM, term_handler)\n\n\ndef init_distributed_mode(params):\n    \n    has_local_rank = hasattr(params, 'local_rank')\n    if has_local_rank:\n        params.local_rank = params.local_rank\n    \n    if has_local_rank and params.local_rank != -1:\n\n        assert params.main_port == -1\n\n        # read environment variables\n        params.global_rank = int(os.environ['RANK'])\n        params.world_size = int(os.environ['WORLD_SIZE'])\n        params.n_gpu_per_node = int(os.environ['NGPU'])\n\n        # number of nodes / node ID\n        params.n_nodes = params.world_size // params.n_gpu_per_node\n        params.node_id = params.global_rank // params.n_gpu_per_node\n        params.is_distributed = True\n    else:\n        n_gpu = torch.cuda.device_count()\n        params.n_nodes = 1\n        params.node_id = 0\n        params.local_rank = 0\n        params.global_rank = 0\n        params.world_size = n_gpu\n        params.n_gpu_per_node = n_gpu\n        params.is_distributed = False\n    \n    # define whether this is the master 
process / if we are in distributed mode\n    params.is_main = params.node_id == 0 and params.local_rank == 0\n    params.multi_node = params.n_nodes > 1\n    params.multi_gpu = params.world_size > 1\n\n    # set GPU device\n    if params.is_distributed:\n        torch.cuda.set_device(params.local_rank)\n        device = torch.device(\"cuda\", params.local_rank)\n    else:\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    params.device = device\n\n    # initialize multi-GPU\n    if params.is_distributed:\n        torch.distributed.init_process_group(\n            init_method='env://',\n            backend='nccl',\n            timeout = datetime.timedelta(seconds=36000)\n        )"
  },
  {
    "path": "ROPG/utils/log.py",
    "content": "import logging\nimport torch\nimport sys\n\nlogger = logging.getLogger(__name__)\n\ndef init_logger(is_main=True, is_distributed=False, filename=None):\n    if is_distributed:\n        torch.distributed.barrier()\n    handlers = [logging.StreamHandler(sys.stdout)]\n    if filename is not None:\n        handlers.append(logging.FileHandler(filename = filename))\n    logging.basicConfig(\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO if is_main else logging.WARN,\n        format=\"[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s\",\n        handlers=handlers,\n    )\n    logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR)\n    logging.getLogger('transformers.tokenization_utils_base').setLevel(logging.ERROR)\n    return logger"
  },
  {
    "path": "ROPG/utils/util.py",
    "content": "import os\nfrom logging import getLogger\nimport torch\nfrom models.optim import set_optim\nimport torch.distributed as dist\nimport errno\n\ndef load_checkpoint(model_class, dir_path, opt, reset_params=False):\n    epoch_path = dir_path\n    optimizer_path = os.path.join(epoch_path, \"optimizer.pth.tar\")\n    logger = getLogger()\n    logger.info(\"Loading %s\" % epoch_path)\n    model = model_class.from_pretrained(epoch_path, local_files_only=True)\n    model = model.to(opt.device)\n    logger.info(\"loading checkpoint %s\" %optimizer_path)\n    checkpoint = torch.load(optimizer_path, map_location=opt.device)\n    opt_checkpoint = checkpoint[\"opt\"]\n    step = checkpoint[\"step\"]\n    if not reset_params:\n        optimizer, scheduler = set_optim(opt_checkpoint, model)\n        scheduler.load_state_dict(checkpoint[\"scheduler\"])\n        optimizer.load_state_dict(checkpoint[\"optimizer\"])\n    else:\n        optimizer, scheduler = set_optim(opt, model)\n\n    return model, optimizer, scheduler, opt_checkpoint, step\n\ndef average_main(x, opt):\n    if not opt.is_distributed:\n        return x\n    if opt.world_size > 1:\n        dist.reduce(x, 0, op=dist.ReduceOp.SUM)\n        if opt.is_main:\n            x = x / opt.world_size\n    return x\n\ndef symlink_force(target, link_name):\n    try:\n        os.symlink(target, link_name)\n    except OSError as e:\n        if e.errno == errno.EEXIST:\n            os.remove(link_name)\n            os.symlink(target, link_name)\n        else:\n            raise e\n\ndef save_checkpoint(model, optimizer, scheduler, step, opt, dir_path, name):\n    model_to_save = model.module.model if hasattr(model, \"module\") else model.model\n    path = os.path.join(dir_path, \"checkpoint\")\n    epoch_path = os.path.join(path, name) #\"step-%s\" % step)\n    os.makedirs(epoch_path, exist_ok=True)\n    model_to_save.save_pretrained(epoch_path)\n    cp = os.path.join(path, \"latest\")\n    fp = 
os.path.join(epoch_path, \"optimizer.pth.tar\")\n    checkpoint = {\n        \"step\": step,\n        \"optimizer\": optimizer.state_dict(),\n        \"scheduler\": scheduler.state_dict(),\n        \"opt\": opt,\n    }\n    torch.save(checkpoint, fp)\n    symlink_force(epoch_path, cp)\n"
  },
  {
    "path": "RSPG/data/collators.py",
    "content": "import json\nimport datasets\nimport torch\n\nclass RSPGPreCollator(object):\n\n    def __init__(self, tokenizer, max_length) -> None:\n        self.tokenizer = tokenizer\n        self.max_len = max_length\n        \n    def __call__(self, batch):\n\n        inps = [x for ex in batch for x in ex['inputs']]\n        inps = self.tokenizer.batch_encode_plus(\n            inps,\n            max_length=self.max_len,\n            padding=True,\n            return_tensors='pt',\n            truncation=True,\n        )\n        labels = torch.tensor([x['labels'] for x in batch])\n        \n\n        return {\n            \"id\" : [x['id'] for x in batch],\n            \"input_ids\" : inps['input_ids'],\n            \"attention_mask\" : inps['attention_mask'],\n            \"labels\" : labels,\n            \"outputs\" : [x['outputs'] for x in batch],\n            \"gold\" : [x['gold'] for x in batch]\n        }\n\nclass RSPGPostCollator(object):\n\n    def __init__(self, tokenizer, max_length) -> None:\n        self.tokenizer = tokenizer\n        self.max_len = max_length\n        \n    def __call__(self, batch):\n\n        inps = [f\"{x} {self.tokenizer.sep_token} {y}\" for ex in batch for x, y in zip(ex['inputs'], ex['outputs'])]\n        inps = self.tokenizer.batch_encode_plus(\n            inps,\n            max_length=self.max_len,\n            padding=True,\n            return_tensors='pt',\n            truncation=True,\n        )\n        labels = torch.tensor([x['labels'] for x in batch])\n        \n\n        return {\n            \"id\" : [x['id'] for x in batch],\n            \"input_ids\" : inps['input_ids'],\n            \"attention_mask\" : inps['attention_mask'],\n            \"labels\" : labels,\n            \"outputs\" : [x['outputs'] for x in batch],\n            \"gold\" : [x['gold'] for x in batch]\n        }"
  },
  {
    "path": "RSPG/data/dataset.py",
    "content": "from torch.utils.data import Dataset\nimport json\nimport datasets\nimport numpy as np\nimport random\n\nclass RSPGDataset(Dataset):\n\n    def __init__(self, data_addr, smaller_is_better = False) -> None:\n        super().__init__()\n        with open(data_addr) as file:\n            data = json.load(file)\n        self.data = data\n        self.smaller_is_better = smaller_is_better\n\n    def __getitem__(self, index):\n        data = self.data[index]\n        did = data['id']\n        if self.smaller_is_better:\n            worst_score = max(abs(int(data['gold'])-1), abs(int(data['gold'])-5))\n            labels = [(worst_score - x) / worst_score for x in self.data[index]['labels']]\n        else:\n            labels = self.data[index]['labels']\n        outputs = self.data[index]['outputs']\n\n        return {\n            \"id\" : did,\n            \"inputs\" : [x.lower() for x in data['inputs']],\n            \"labels\" : labels,\n            \"outputs\" : outputs,\n            \"gold\" : data['gold']\n        }\n    \n    def __len__(self):\n        return len(self.data)"
  },
  {
    "path": "RSPG/metrics/evaluation.py",
    "content": "import json\nimport zipfile\nimport glob\nimport os\nimport shutil\nimport evaluate\n\ndef postprocess_text_classification(preds, labels):\n    preds = [str(pred).strip() for pred in preds]\n    labels = [str(label).strip() for label in labels]\n    return preds, labels\n\ndef postprocess_text_generation(preds, labels):\n    preds = [pred.strip() for pred in preds]\n    labels = [[label.strip()] for label in labels]\n\n    return preds, labels\n\ndef create_metric_f1_accuracy(all_labels):\n    f1_metric = evaluate.load(\"f1\")\n    accuracy_metric = evaluate.load(\"accuracy\")\n    def create_mapping(x):\n        try:\n            return all_labels.index(x)\n        except:\n            return -1\n    def compute_metrics(decoded_preds, decoded_labels):\n        decoded_preds, decoded_labels = postprocess_text_classification(decoded_preds, decoded_labels)\n        decoded_preds = [create_mapping(x) for x in decoded_preds]\n        decoded_labels = [create_mapping(x) for x in decoded_labels]\n        result_acc = accuracy_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        result_f1 = f1_metric.compute(predictions=decoded_preds, references=decoded_labels, labels=list(range(len(all_labels))), average = \"macro\")\n        result = {\"accuracy\" : result_acc[\"accuracy\"], \"f1\" : result_f1[\"f1\"]}\n        return result\n    return compute_metrics\n\ndef create_metric_f1_accuracy_sigtest(all_labels):\n    f1_metric = evaluate.load(\"f1\")\n    accuracy_metric = evaluate.load(\"accuracy\")\n    def create_mapping(x):\n        try:\n            return all_labels.index(x)\n        except:\n            return -1\n    def compute_metrics(decoded_preds, decoded_labels):\n        decoded_preds, decoded_labels = postprocess_text_classification(decoded_preds, decoded_labels)\n        decoded_preds = [create_mapping(x) for x in decoded_preds]\n        decoded_labels = [create_mapping(x) for x in decoded_labels]\n        results_acc = 
[]\n        results_f1 = []\n        for pred, gold in zip(decoded_preds, decoded_labels):\n            result_acc = accuracy_metric.compute(predictions=[pred], references=[gold])\n            result_f1 = f1_metric.compute(predictions=[pred], references=[gold], labels=list(range(len(all_labels))), average = \"macro\", pos_label = gold)\n            results_acc.append(result_acc[\"accuracy\"])\n            results_f1.append(result_f1[\"f1\"])\n        result = {\"accuracy\" : results_acc, \"f1\" : results_f1}\n        return result\n    return compute_metrics\n\ndef create_metric_mae_rmse():\n    mse_metric = evaluate.load(\"mse\")\n    mae_metric = evaluate.load(\"mae\")\n    def create_mapping(x, y):\n        try:\n            return float(x)\n        except:\n            print(x)\n            y = float(y)\n            if abs(1 - y) > abs(5 - y):\n                return 1.0\n            else:\n                return 5.0\n    def compute_metrics(decoded_preds, decoded_labels):\n        decoded_preds, decoded_labels = postprocess_text_classification(decoded_preds, decoded_labels)\n        decoded_preds = [create_mapping(x,y) for x,y in zip(decoded_preds, decoded_labels)]\n        decoded_labels = [create_mapping(x,x) for x in decoded_labels]\n        result_mae = mae_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        result_rmse = mse_metric.compute(predictions=decoded_preds, references=decoded_labels, squared = False)\n        result = {\"MAE\" : result_mae[\"mae\"], \"RMSE\" : result_rmse[\"mse\"]}\n        return result\n    return compute_metrics\n\ndef create_metric_mae_rmse_sigtest():\n    mse_metric = evaluate.load(\"mse\")\n    mae_metric = evaluate.load(\"mae\")\n    def create_mapping(x, y):\n        try:\n            return float(x)\n        except:\n            print(x)\n            y = float(y)\n            if abs(1 - y) > abs(5 - y):\n                return 1.0\n            else:\n                return 5.0\n    def 
compute_metrics(decoded_preds, decoded_labels):\n        decoded_preds, decoded_labels = postprocess_text_classification(decoded_preds, decoded_labels)\n        decoded_preds = [create_mapping(x,y) for x,y in zip(decoded_preds, decoded_labels)]\n        decoded_labels = [create_mapping(x,x) for x in decoded_labels]\n        results_mae = []\n        results_rmse = []\n        for pred, gold in zip(decoded_preds, decoded_labels):\n            result_mae = mae_metric.compute(predictions=[pred], references=[gold])\n            result_rmse = mse_metric.compute(predictions=[pred], references=[gold], squared = False)\n            results_mae.append(result_mae[\"mae\"])\n            results_rmse.append(result_rmse[\"mse\"])\n        result = {\"MAE\" : results_mae, \"RMSE\" : results_rmse}\n        return result\n    return compute_metrics\n\ndef create_metric_rouge():\n    rouge_metric = evaluate.load('rouge')\n    def compute_metrics(decoded_preds, decoded_labels):\n        decoded_preds, decoded_labels = postprocess_text_generation(decoded_preds, decoded_labels)\n        result_rouge = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        result = {\"rouge-1\" : result_rouge[\"rouge1\"], \"rouge-L\" : result_rouge[\"rougeL\"]}\n        return result\n    return compute_metrics\n\ndef create_metric_rouge_sigtest():\n    rouge_metric = evaluate.load('rouge')\n    def compute_metrics(decoded_preds, decoded_labels):\n        decoded_preds, decoded_labels = postprocess_text_generation(decoded_preds, decoded_labels)\n        result_rouge = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_aggregator = False)\n        result = {\"rouge-1\" : result_rouge[\"rouge1\"], \"rouge-L\" : result_rouge[\"rougeL\"]}\n        return result\n    return compute_metrics\n\nclass LaMPEvaluation(object):\n    \n    def __init__(self, all_golds_zip_file_addr = None, single_gold_json_file_addr = None, extract_addr = \"./tmp\") -> 
None:\n        assert all_golds_zip_file_addr or single_gold_json_file_addr, \"The golds should be provided for all datasets or at least one.\"\n        assert not (all_golds_zip_file_addr and single_gold_json_file_addr), \"The golds should be provided using zip file or json file not both.\"\n        self.tasks_golds = dict()\n        self.extract_addr = extract_addr\n        self.evaluate_all_is_possible = False\n        if all_golds_zip_file_addr:\n            os.makedirs(self.extract_addr, exist_ok=True)\n            with zipfile.ZipFile(all_golds_zip_file_addr, 'r') as zobj:\n                zobj.extractall(path = extract_addr)\n            for file_addr in glob.glob(os.path.join(self.extract_addr, \"**/*.json\"), recursive=True):\n                with open(file_addr) as file:\n                    task = json.load(file)\n                    self.tasks_golds[task['task']] = task['golds']\n            self._empty_dir(self.extract_addr)\n            self.evaluate_all_is_possible = True\n        if single_gold_json_file_addr:\n            with open(single_gold_json_file_addr) as file:\n                    task = json.load(file)\n                    self.tasks_golds[task['task']] = task['golds']\n    \n    def _empty_dir(self, directory_path):\n        for filename in os.listdir(directory_path):\n            file_path = os.path.join(directory_path, filename)\n            try:\n                if os.path.isfile(file_path):\n                    os.unlink(file_path)\n                elif os.path.isdir(file_path):\n                    shutil.rmtree(file_path)\n            except Exception as e:\n                print(f'Failed to delete {file_path}. 
Reason: {e}')\n\n    def _get_all_gold_ids(self, task_name):\n        return set([sample['id'] for sample in self.tasks_golds[task_name]])\n    \n    def _get_all_ids(self, input):\n        return set([sample['id'] for sample in input])\n    \n    def evaluate_all(self, predicts_zipfile_addr):\n        assert self.evaluate_all_is_possible, \"You did not provide golds for all tasks.\"\n        with zipfile.ZipFile(predicts_zipfile_addr, 'r') as zobj:\n            zobj.extractall(path = self.extract_addr)\n        results_raw = dict()\n        all_task_names = set()\n        for file_addr in glob.glob(os.path.join(self.extract_addr, \"**/*.json\"), recursive=True):\n            with open(file_addr) as file:\n                preds = json.load(file)\n            all_task_names.add(preds['task'])\n            results_raw[preds['task']] = self._evaluate_task(preds['golds'], preds['task'])\n        self._empty_dir(self.extract_addr)\n        assert len(all_task_names) == 7, \"The provided results do not cover all the tasks in the benchmark.\"\n        return results_raw\n\n    def evaluate_task(self, predicts_json_addr, task_name):\n        with open(predicts_json_addr) as file:\n            preds = json.load(file)\n        assert preds['task'] == task_name or preds['task'].replace(\"-\",\"_\") == task_name, \"The provided task_name and the results do not match.\"\n        assert preds['task'] in self.tasks_golds.keys() or preds['task'].replace(\"-\",\"_\") in self.tasks_golds.keys(), \"The provided golds cannot be used to evaluate this task.\"\n        return self._evaluate_task(preds['golds'], task_name)\n        \n    def _evaluate_task(self, predictions, task_name):\n        golds_dict = {y['id']:y['output'] for y in self.tasks_golds[task_name]}\n        preds_dict = {x['id']:x['output'] for x in predictions}\n        \n        gold_ids = self._get_all_gold_ids(task_name)\n        pred_ids = self._get_all_ids(predictions)\n        print(gold_ids - pred_ids)\n        
assert gold_ids == pred_ids, \"Predictions ids and gold ids do not match.\"\n\n        if task_name in [\"LaMP_1\", \"LaMP_2\"]:\n            metric = create_metric_f1_accuracy(self._get_labels(task_name))\n        elif task_name == \"LaMP_3\":\n            metric = create_metric_mae_rmse()\n        else:\n            metric = create_metric_rouge()\n        \n        gold_ids = list(gold_ids)\n        golds = [golds_dict[id] for id in gold_ids]\n        preds = [preds_dict[id] for id in gold_ids]\n        return metric(preds, golds)\n    \n    def _evaluate_task_per_sample(self, predictions, task_name):\n        golds_dict = {y['id']:y['output'] for y in self.tasks_golds[task_name]}\n        preds_dict = {x['id']:x['output'] for x in predictions}\n        \n        gold_ids = self._get_all_gold_ids(task_name)\n        pred_ids = self._get_all_ids(predictions)\n        print(gold_ids - pred_ids)\n        assert gold_ids == pred_ids, \"Predictions ids and gold ids do not match.\"\n\n        if task_name in [\"LaMP_1\", \"LaMP_2\"]:\n            metric = create_metric_f1_accuracy_sigtest(self._get_labels(task_name))\n        elif task_name == \"LaMP_3\":\n            metric = create_metric_mae_rmse_sigtest()\n        else:\n            metric = create_metric_rouge_sigtest()\n        \n        gold_ids = list(gold_ids)\n        golds = [golds_dict[id] for id in gold_ids]\n        preds = [preds_dict[id] for id in gold_ids]\n        return metric(preds, golds)\n    \n    def _get_labels(self, task_name):\n        if task_name == \"LaMP_1\":\n            return [\"[1]\", \"[2]\"]\n        elif task_name == \"LaMP_2\":\n            return ['sci-fi', 'based on a book', 'comedy', 'action', 'twist ending', 'dystopia', 'dark comedy', 'classic', 'psychology', 'fantasy', 'romance', 'thought-provoking', 'social commentary', 'violence', 'true story']\n        elif task_name == \"LaMP_3\":\n            return [\"1\", \"2\", \"3\", \"4\", \"5\"]\n        else:\n            raise 
ValueError(\"Invalid task_name\")"
  },
  {
    "path": "RSPG/modeling/__init__.py",
    "content": ""
  },
  {
    "path": "RSPG/modeling/modeling.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport types\nimport torch\nfrom transformers import PreTrainedModel, AutoModel\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss\nimport numpy as np\n\nclass RSPG(PreTrainedModel):\n    def __init__(self, config, **kwargs):\n        super().__init__(config, **kwargs)\n        self.model = AutoModel.from_pretrained(config.init_model)\n        self.classifier = nn.Linear(self.model.config.hidden_size, 1)\n    \n    def forward(self, input_ids = None, attention_mask = None, token_type_ids = None, **kwargs):\n        output = self.model(\n            input_ids = input_ids,\n            attention_mask = attention_mask,\n            token_type_ids = token_type_ids\n        )\n        # one scalar score per input sequence, grouped into rows of num_labels candidates\n        return self.classifier(output.pooler_output).view(-1, self.config.num_labels)\n\nclass Trainer(nn.Module):\n\n    def __init__(self, model, temperature = 1.0):\n        super().__init__()\n        self.model = model\n        self.loss_fn_kl = nn.KLDivLoss(reduction=\"batchmean\")\n        self.temperature = temperature\n    \n    def forward(\n            self,\n            input_ids = None,\n            attention_mask = None,\n            token_type_ids = None,\n            labels = None\n        ):\n\n        outputs = self.model(\n            input_ids = input_ids,\n            attention_mask = attention_mask,\n            token_type_ids = token_type_ids\n        )\n\n        # KLDivLoss expects log-probabilities as input: scale the logits by the temperature\n        # *before* the log-softmax so the input remains a valid log-distribution\n        log_probs = F.log_softmax(outputs / self.temperature, dim = -1)\n        target_probs = F.softmax(labels, dim = -1)\n        loss = self.loss_fn_kl(log_probs, target_probs)\n        \n        return loss, outputs"
  },
  {
    "path": "RSPG/modeling/optim.py",
    "content": "import torch\n\nclass WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR):\n    def __init__(self, optimizer, warmup_steps, scheduler_steps, min_ratio, fixed_lr, last_epoch=-1):\n        self.warmup_steps = warmup_steps\n        self.scheduler_steps = scheduler_steps\n        self.min_ratio = min_ratio\n        self.fixed_lr = fixed_lr\n        super(WarmupLinearScheduler, self).__init__(\n            optimizer, self.lr_lambda, last_epoch=last_epoch\n        )\n\n    def lr_lambda(self, step):\n        if step < self.warmup_steps:\n            return (1 - self.min_ratio)*step/float(max(1, self.warmup_steps)) + self.min_ratio\n\n        if self.fixed_lr:\n            return 1.0\n\n        return max(0.0,\n            1.0 + (self.min_ratio - 1) * (step - self.warmup_steps)/float(max(1.0, self.scheduler_steps - self.warmup_steps)),\n        )\n\n\nclass FixedScheduler(torch.optim.lr_scheduler.LambdaLR):\n    \n    def __init__(self, optimizer, last_epoch=-1):\n        super(FixedScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)\n   \n    def lr_lambda(self, step):\n        return 1.0\n\n\ndef set_optim(opt, model):\n    if opt.optim == 'adam':\n        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)\n    elif opt.optim == 'adamw':\n        optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)\n    else:\n        raise ValueError(f\"Unsupported optimizer: {opt.optim}\")\n    if opt.scheduler == 'fixed':\n        scheduler = FixedScheduler(optimizer)\n    elif opt.scheduler == 'linear':\n        if opt.scheduler_steps is None:\n            scheduler_steps = opt.total_steps\n        else:\n            scheduler_steps = opt.scheduler_steps\n        scheduler = WarmupLinearScheduler(optimizer, warmup_steps=opt.warmup_steps, scheduler_steps=scheduler_steps, min_ratio=0., fixed_lr=opt.fixed_lr)\n    else:\n        raise ValueError(f\"Unsupported scheduler: {opt.scheduler}\")\n    return optimizer, scheduler"
  },
  {
    "path": "RSPG/modeling/utils.py",
    "content": "import os\nfrom logging import getLogger\nimport torch\nfrom modeling.optim import set_optim\nimport torch.distributed as dist\nimport errno\n\ndef load_checkpoint(model_class, dir_path, opt, reset_params=False):\n    epoch_path = os.path.realpath(dir_path)\n    optimizer_path = os.path.join(epoch_path, \"optimizer.pth.tar\")\n    logger = getLogger()\n    logger.info(\"Loading %s\" % epoch_path)\n    model = model_class.from_pretrained(epoch_path, local_files_only=True)\n    model = model.to(opt.device)\n    logger.info(\"loading checkpoint %s\" %optimizer_path)\n    checkpoint = torch.load(optimizer_path, map_location=opt.device)\n    opt_checkpoint = checkpoint[\"opt\"]\n    step = checkpoint[\"step\"]\n    if not reset_params:\n        optimizer, scheduler = set_optim(opt_checkpoint, model)\n        scheduler.load_state_dict(checkpoint[\"scheduler\"])\n        optimizer.load_state_dict(checkpoint[\"optimizer\"])\n    else:\n        optimizer, scheduler = set_optim(opt, model)\n\n    return model, optimizer, scheduler, opt_checkpoint, step\n\ndef average_main(x, opt):\n    if not opt.is_distributed:\n        return x\n    if opt.world_size > 1:\n        dist.reduce(x, 0, op=dist.ReduceOp.SUM)\n        if opt.is_main:\n            x = x / opt.world_size\n    return x\n\ndef symlink_force(target, link_name):\n    try:\n        os.symlink(target, link_name)\n    except OSError as e:\n        if e.errno == errno.EEXIST:\n            os.remove(link_name)\n            os.symlink(target, link_name)\n        else:\n            raise e\n\ndef save_checkpoint(model, optimizer, scheduler, step, opt, dir_path, name):\n    model_to_save = model.module if hasattr(model, \"module\") else model\n    path = os.path.join(dir_path, \"checkpoint\")\n    epoch_path = os.path.join(path, name) #\"step-%s\" % step)\n    os.makedirs(epoch_path, exist_ok=True)\n    model_to_save.save_pretrained(epoch_path)\n    cp = os.path.join(path, \"latest\")\n    fp = 
os.path.join(epoch_path, \"optimizer.pth.tar\")\n    checkpoint = {\n        \"step\": step,\n        \"optimizer\": optimizer.state_dict(),\n        \"scheduler\": scheduler.state_dict(),\n        \"opt\": opt,\n    }\n    torch.save(checkpoint, fp)\n    symlink_force(epoch_path, cp)\n\n"
  },
  {
    "path": "RSPG/requirements.txt",
    "content": "datasets==2.8.0\nregex==2022.10.31\nsentencepiece==0.1.97\ntokenizers==0.11.1\ntorch==2.0.1\ntqdm==4.64.1\ntransformers==4.28.0\nevaluate\nabsl-py\nrouge-score\n"
  },
  {
    "path": "RSPG/rspg.py",
    "content": "import argparse\nimport torch\nimport json\nimport os\nfrom pathlib import Path\nfrom utils.log import init_logger\nfrom utils.distributed import init_distributed_mode, init_signal_handler\nfrom modeling import optim\nfrom data.dataset import RSPGDataset\nfrom data.collators import RSPGPostCollator, RSPGPreCollator\nfrom modeling.modeling import RSPG, Trainer\nfrom modeling.utils import load_checkpoint, average_main, save_checkpoint\nfrom torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler\nimport tqdm\nfrom transformers import AutoTokenizer, AutoModel, PretrainedConfig\nfrom metrics.evaluation import LaMPEvaluation\nimport numpy as np\nimport glob\n\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--train_data\", required = True, help=\"the training data\")\nparser.add_argument(\"--val_data\", required = True, help=\"the validation data\")\nparser.add_argument(\"--rspg_type\", required = True, help=\"RSPG type: [Pre, Post]\")\n\nparser.add_argument(\"--val_lamp_golds\", required = True, help=\"the LaMP-format gold outputs (JSON) for the validation data\")\nparser.add_argument(\"--do_filtering\", action='store_true')\n\nparser.add_argument(\"--task\", required = True, help=\"task name, e.g. LaMP-3\")\nparser.add_argument(\"--do_train\", action='store_true', help=\"perform training\")\nparser.add_argument(\"--do_validation\", action='store_true', help=\"perform validation\")\nparser.add_argument(\"--max_length_input\", type = int, default = 512, help=\"maximum input length\")\nparser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment')\nparser.add_argument('--checkpoint_dir', type=str, default='./checkpoint/', help='models are saved here')\nparser.add_argument('--model_path', type=str, default='none', help='path to a trained checkpoint to load when running with --do_validation')\nparser.add_argument(\"--per_gpu_batch_size\", default=1, type=int, \n           help=\"Batch size per GPU/CPU for training.\")\nparser.add_argument(\"--local-rank\", type=int, default=-1,\n           help=\"For distributed training: local_rank\")\nparser.add_argument(\"--main_port\", type=int, default=-1,\n           help=\"Main port (for multi-node SLURM jobs)\")\nparser.add_argument('--seed', type=int, default=0, help=\"random seed for initialization\")\nparser.add_argument('--eval_freq', type=int, default=500,\n           help='evaluate model every <eval_freq> steps during training')\nparser.add_argument('--save_freq', type=int, default=5000,\n           help='save model every <save_freq> steps during training')\nparser.add_argument('--eval_print_freq', type=int, default=1000,\n           help='print intermediate results of evaluation every <eval_print_freq> steps')\nparser.add_argument('--warmup_steps', type=int, default=1000, help=\"number of warmup steps\")\nparser.add_argument('--total_steps', type=int, default=1000, help=\"number of training steps\")\nparser.add_argument('--scheduler_steps', type=int, default=None, \n           help='total number of scheduler steps; if None, defaults to total_steps')\nparser.add_argument('--accumulation_steps', type=int, default=1, help=\"number of gradient accumulation steps\")\nparser.add_argument('--dropout', type=float, default=0.1, help='dropout rate')\nparser.add_argument('--lr', type=float, default=1e-5, help='learning rate')\nparser.add_argument('--clip', type=float, default=1., help='maximum gradient norm for clipping')\nparser.add_argument('--optim', type=str, default='adam', help=\"optimizer used for training: [adam, adamw]\")\nparser.add_argument('--scheduler', type=str, default='fixed', help=\"learning-rate scheduler used for training: [fixed, linear]\")\nparser.add_argument('--weight_decay', type=float, default=0.0, help=\"weight decay rate\")\nparser.add_argument('--temperature', type=float, default=1.0, help=\"distillation temperature\")\nparser.add_argument('--fixed_lr', action='store_true', help=\"keep the lr fixed after warmup (linear scheduler only)\")\n\n\ndef train(opts, model, 
optimizer, scheduler, step, dataset, collator, checkpoint_path, test_dataset, logger, compute_metrics):\n    \n    if opts.is_main:\n        try:\n            # torch.utils.tensorboard must be imported explicitly; `import torch` alone does not load it\n            from torch.utils.tensorboard import SummaryWriter\n            tb_logger = SummaryWriter(Path(opts.checkpoint_dir)/opts.name)\n        except Exception:\n            tb_logger = None\n            logger.warning('Tensorboard is not available.')\n\n    torch.manual_seed(opts.global_rank + opts.seed) #different seed for different sampling depending on global_rank\n    train_sampler = DistributedSampler(dataset, num_replicas=opts.n_gpu_per_node, rank=opts.local_rank)\n    train_dataloader = DataLoader(\n        dataset,\n        sampler = train_sampler,\n        batch_size = opts.per_gpu_batch_size,\n        drop_last = True,\n        num_workers = 10,\n        collate_fn = collator\n    )\n\n    loss, curr_loss = 0.0, 0.0\n    epoch = 1\n    model.train()\n    bar = tqdm.tqdm(total=opts.total_steps)\n    temp_step = 0\n    while step < opts.total_steps:\n        epoch += 1\n        # give DistributedSampler a new seed so that each epoch is shuffled differently\n        train_sampler.set_epoch(epoch)\n        for i, batch in enumerate(train_dataloader):\n            temp_step += 1\n            train_loss = model(\n                input_ids = batch['input_ids'].cuda(),\n                attention_mask = batch['attention_mask'].cuda(),\n                labels = batch['labels'].cuda()\n            )[0]\n            \n            train_loss.backward()\n\n            if temp_step % opts.accumulation_steps == 0:\n                step += 1\n                temp_step = 0\n                bar.update(1)\n                torch.nn.utils.clip_grad_norm_(model.parameters(), opts.clip)\n                optimizer.step()\n                scheduler.step()\n                model.zero_grad()\n\n            train_loss = average_main(train_loss, opts)\n            curr_loss += train_loss.item()\n            \n            if opts.is_main and step % opts.eval_freq == 0 and temp_step == 0 and step != 0:\n                metrics = evaluate(model.module, test_dataset, collator, opts, step, logger, compute_metrics)\n                if opts.is_main:\n                    log = f\"{step} / {opts.total_steps} |\"\n                    log += f\"train: {curr_loss/opts.eval_freq:.4f} | {metrics}\"\n                    logger.info(log)\n                    if tb_logger is not None:\n                        tb_logger.add_scalar(\"Training\", curr_loss / (opts.eval_freq), step)\n                    curr_loss = 0.\n                    save_checkpoint(model.module.model, optimizer, scheduler, step, opts, checkpoint_path, f\"step-{step}\")\n                    model.train()\n            if opts.is_main and step % opts.save_freq == 0 and temp_step == 0:\n                save_checkpoint(model.module.model, optimizer, scheduler, step, opts, checkpoint_path, f\"step-{step}\")\n            # stop as soon as total_steps optimizer steps have been taken\n            if step >= opts.total_steps:\n                break\n    save_checkpoint(model.module.model, optimizer, scheduler, step, opts, checkpoint_path, f\"step-{step}\")\n\ndef evaluate(model, dataset, collator, opt, step, logger, evaluator, test_eval = False):\n    sampler = SequentialSampler(dataset)\n    dataloader = DataLoader(dataset,\n        sampler=sampler,\n        batch_size=opt.per_gpu_batch_size,\n        drop_last=False,\n        num_workers=10,\n        collate_fn=collator\n    )\n    model.eval()\n    with torch.no_grad():\n        preds = []\n        golds = []\n        ids = []\n        indices = []\n        \n        logger.info(\"Evaluation Started\")\n        for i, batch in enumerate(tqdm.tqdm(dataloader)):\n            if test_eval:\n                outputs = model(\n                    input_ids=batch['input_ids'].cuda(), \n                    attention_mask = batch['attention_mask'].cuda(),\n                )\n            else:\n                outputs = model.model(\n                    input_ids=batch['input_ids'].cuda(), \n                    attention_mask = batch['attention_mask'].cuda(),\n                )\n            indices_max = torch.argmax(outputs, dim = -1)\n            ids.extend(batch['id'])\n            # pick, for each sample, the candidate output with the highest score\n            preds.extend([z[k] for k, z in zip(indices_max.tolist(), batch['outputs'])])\n            golds.extend(batch['gold'])\n            indices.extend(indices_max.tolist())\n    \n    checkpoint_path = Path(opt.checkpoint_dir) / opt.name / \"predictions\" / str(step)\n    checkpoint_path.mkdir(parents = True, exist_ok = True)\n    \n    with open(os.path.join(checkpoint_path, f'{str(opt.local_rank)}.json'), \"w\") as file:\n        json.dump({\"preds\" : preds, \"golds\" : golds, \"ids\" : ids, \"indices\" : indices}, file)\n\n    if opt.is_main:\n        # merge the per-rank prediction files, skipping a final_preds.json left over from an earlier run\n        results = [r for r in glob.glob(os.path.join(checkpoint_path, '*.json')) if os.path.basename(r) != 'final_preds.json']\n        preds, golds, ids, indices = [], [], [], []\n        for addr in results:\n            with open(addr) as file:\n                temp = json.load(file)\n                preds.extend(temp['preds'])\n                golds.extend(temp['golds'])\n                ids.extend(temp['ids'])\n                indices.extend(temp['indices'])\n        final_preds = {\n            \"task\" : opt.task.replace(\"-\", \"_\"),\n            \"golds\" : [{\"id\" : id, \"output\" : out, \"index\":ind} for id, out, ind in zip(ids, preds, indices)]\n        }\n        final_preds_addr = os.path.join(checkpoint_path, 'final_preds.json')\n        with open(final_preds_addr, \"w\") as file:\n            json.dump(final_preds, file, indent=4)\n        return evaluator.evaluate_task(final_preds_addr, opt.task.replace(\"-\", \"_\"))\n\nif __name__ == \"__main__\":\n    opts = parser.parse_args()\n\n    torch.manual_seed(opts.seed)\n    init_distributed_mode(opts)\n    init_signal_handler()\n\n    checkpoint_path = Path(opts.checkpoint_dir) / opts.name\n    checkpoint_exists = checkpoint_path.exists()\n    \n    if opts.is_distributed:\n        torch.distributed.barrier()\n    \n    checkpoint_path.mkdir(parents = True, exist_ok = True)\n\n\n    logger = init_logger(\n        opts.is_main,\n        opts.is_distributed,\n        checkpoint_path 
/ 'run.log'\n    )\n\n    logger.info(opts)\n\n    tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-base-4096')\n\n    task = opts.task\n    if task == \"LaMP-3\":\n        smaller_is_better = True\n    else:\n        smaller_is_better = False\n    train_dataset = RSPGDataset(opts.train_data, smaller_is_better)\n    val_dataset = RSPGDataset(opts.val_data, smaller_is_better)\n    if opts.rspg_type == \"Post\":\n        collator = RSPGPostCollator(tokenizer, opts.max_length_input)\n    else:\n        collator = RSPGPreCollator(tokenizer, opts.max_length_input)\n    \n    compute_metrics = LaMPEvaluation(\n        single_gold_json_file_addr=opts.val_lamp_golds\n    )\n\n    if checkpoint_exists and opts.do_train:\n        model, optimizer, scheduler, checkpoint_opts, step = load_checkpoint(RSPG, os.path.join(checkpoint_path, \"checkpoint\", \"latest\"), opts)\n        # checkpoints hold the bare RSPG model; wrap it so forward() returns (loss, logits) as train() expects\n        model = Trainer(model, opts.temperature)\n    elif opts.do_train:\n        model = AutoModel.from_pretrained('allenai/longformer-base-4096')\n        model.config.num_labels = 6\n        model.config.init_model = 'allenai/longformer-base-4096'\n        model = RSPG(model.config)\n        model = Trainer(model, opts.temperature)\n        optimizer, scheduler = optim.set_optim(opts, model)\n        step = 0\n    elif opts.do_validation:\n        config = PretrainedConfig.from_pretrained(opts.model_path)\n        model = RSPG.from_pretrained(opts.model_path, config = config)\n    else:\n        raise ValueError(\"Set --do_train and/or --do_validation.\")\n    \n    model = model.to(opts.local_rank)\n    \n    if opts.is_distributed:\n        model = torch.nn.parallel.DistributedDataParallel(\n            model,\n            device_ids=[opts.local_rank],\n            output_device=opts.local_rank,\n            find_unused_parameters=True,\n        )\n    \n    if opts.do_train:\n        train(opts, model, optimizer, scheduler, step, train_dataset, collator, checkpoint_path, val_dataset, logger, compute_metrics)\n    \n    if opts.do_validation and opts.is_main:\n        metrics = evaluate(model, val_dataset, collator, 
opts, \"validation\", logger, compute_metrics, True)\n        if opts.is_main:\n            log = f\"test: {metrics}\"\n            logger.info(log)\n\n"
  },
  {
    "path": "RSPG/utils/__init__.py",
    "content": ""
  },
  {
    "path": "RSPG/utils/create_data.py",
    "content": "import argparse\nimport json\nimport os\n\ndef merge(inps, outs, label_files, input_files, score_name):\n    # index the gold outputs by id once; a missing id now raises KeyError instead of silently reusing the previous output\n    golds = {o['id']: o['output'] for o in outs}\n    for inp in inps:\n        del inp['profile']\n        inp['gold'] = golds[inp['id']]\n        labels = []\n        outputs = []\n        inputs = []\n        # one metric score and one generated output per retriever\n        for label_file in label_files:\n            labels.append(label_file[inp['id']]['metric'][score_name])\n            outputs.append(label_file[inp['id']]['output'])\n        for input_file in input_files:\n            inputs.append(input_file[inp['id']]['input'])\n        inp['labels'] = labels\n        inp['outputs'] = outputs\n        inp['inputs'] = inputs\n    return inps\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--retrivers_data_addr\", '--names-list', nargs='+', required=True)\nparser.add_argument(\"--task_inputs_addr\", required=True)\nparser.add_argument(\"--task_outputs_addr\", required=True)\nparser.add_argument(\"--output_dataset_addr\", required=True)\nparser.add_argument(\"--metric\", required=True)\n\nif __name__ == \"__main__\":\n    opts = parser.parse_args()\n\n    score_name = opts.metric\n\n    q_addr = opts.task_inputs_addr\n    o_addr = opts.task_outputs_addr\n    res_addr = opts.output_dataset_addr\n    retrivers_data_addrs = opts.retrivers_data_addr\n\n    with open(q_addr) as qfile, open(o_addr) as ofile, open(res_addr, \"w\") as resfile:\n        inp = json.load(qfile)\n        out = json.load(ofile)\n        scores_file = []\n        input_file = []\n        for x in retrivers_data_addrs:\n            with open(os.path.join(x, \"scores.json\")) as sfile:\n                scores_file.append(json.load(sfile))\n            with open(os.path.join(x, \"data.json\")) as sfile:\n                input_file.append(json.load(sfile))\n        res = merge(inp, out['golds'], scores_file, input_file, 
score_name)\n        json.dump(res, resfile, indent=4)\n"
  },
  {
    "path": "RSPG/utils/distributed.py",
    "content": "from logging import getLogger\nimport os\nimport sys\nimport torch\nimport socket\nimport signal\nimport subprocess\nimport datetime\n\n\nlogger = getLogger()\n\ndef sig_handler(signum, frame):\n    logger.warning(\"Signal handler called with signal \" + str(signum))\n    prod_id = int(os.environ['SLURM_PROCID'])\n    logger.warning(\"Host: %s - Global rank: %i\" % (socket.gethostname(), prod_id))\n    if prod_id == 0:\n        logger.warning(\"Requeuing job \" + os.environ['SLURM_JOB_ID'])\n        os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID'])\n    else:\n        logger.warning(\"Not the main process, no need to requeue.\")\n    sys.exit(-1)\n\n\ndef term_handler(signum, frame):\n    logger.warning(\"Signal handler called with signal \" + str(signum))\n    logger.warning(\"Bypassing SIGTERM.\")\n\n\ndef init_signal_handler():\n    signal.signal(signal.SIGUSR1, sig_handler)\n    signal.signal(signal.SIGTERM, term_handler)\n\n\ndef init_distributed_mode(params):\n    \n    has_local_rank = hasattr(params, 'local_rank')\n\n    if has_local_rank and params.local_rank != -1:\n\n        assert params.main_port == -1\n\n        # read environment variables\n        params.global_rank = int(os.environ['RANK'])\n        params.world_size = int(os.environ['WORLD_SIZE'])\n        params.n_gpu_per_node = int(os.environ['NGPU'])\n\n        # number of nodes / node ID\n        params.n_nodes = params.world_size // params.n_gpu_per_node\n        params.node_id = params.global_rank // params.n_gpu_per_node\n        params.is_distributed = True\n    else:\n        n_gpu = torch.cuda.device_count()\n        params.n_nodes = 1\n        params.node_id = 0\n        params.local_rank = 0\n        params.global_rank = 0\n        params.world_size = n_gpu\n        params.n_gpu_per_node = n_gpu\n        params.is_distributed = False\n    \n    # define whether this is 
the master process / if we are in distributed mode\n    params.is_main = params.node_id == 0 and params.local_rank == 0\n    params.multi_node = params.n_nodes > 1\n    params.multi_gpu = params.world_size > 1\n\n    # set GPU device\n    if params.is_distributed:\n        torch.cuda.set_device(params.local_rank)\n        device = torch.device(\"cuda\", params.local_rank)\n    else:\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    params.device = device\n\n    # initialize multi-GPU\n    if params.is_distributed:\n        torch.distributed.init_process_group(\n            init_method='env://',\n            backend='nccl',\n            timeout = datetime.timedelta(seconds=36000)\n        )\n"
  },
  {
    "path": "RSPG/utils/log.py",
    "content": "import logging\nimport torch\nimport sys\n\nlogger = logging.getLogger(__name__)\n\ndef init_logger(is_main=True, is_distributed=False, filename=None):\n    if is_distributed:\n        torch.distributed.barrier()\n    handlers = [logging.StreamHandler(sys.stdout)]\n    if filename is not None:\n        handlers.append(logging.FileHandler(filename = filename))\n    logging.basicConfig(\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO if is_main else logging.WARN,\n        format=\"[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s\",\n        handlers=handlers,\n    )\n    logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR)\n    logging.getLogger('transformers.tokenization_utils_base').setLevel(logging.ERROR)\n    return logger"
  },
  {
    "path": "data/avocado/create_avocado_dataset.py",
    "content": "import zipfile\nimport glob\nimport os\nimport shutil\nimport json\nimport tqdm\nimport mailparser\nimport argparse\n\ndef empty_dir(directory_path):\n    for filename in os.listdir(directory_path):\n        file_path = os.path.join(directory_path, filename)\n        try:\n            # if the current item is a file, remove it\n            if os.path.isfile(file_path):\n                os.unlink(file_path)\n            # if the current item is a directory, remove it recursively using shutil.rmtree()\n            elif os.path.isdir(file_path):\n                shutil.rmtree(file_path)\n        except Exception as e:\n            print(f'Failed to delete {file_path}. Reason: {e}')\n\ndef process_file(file_addr):\n    message = \"\"\n    id = os.path.basename(file_addr)\n    mail = mailparser.parse_from_file(file_addr)\n    subject = mail.subject\n    message = mail.body\n    return id, {\"subject\" : subject, \"content\" : message.strip()}\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--avocado_files_dir\", required=True, help=\"Address to the directory containing zip files for avocado dataset 'avocado-1.0.2/data/text'\")\nparser.add_argument(\"--extract_addr\", required=True, help=\"A temp dir to extract the files for creating dataset\")\nparser.add_argument(\"--output_dir\", required=True, help=\"The directory to generate the final dataset\")\nparser.add_argument(\"--input_question_file_train\", required=True, help=\"The address to the train_questions.json file\")\nparser.add_argument(\"--input_question_file_dev\", required=True, help=\"The address to the dev_questions.json file\")\nparser.add_argument(\"--input_question_file_test\", required=True, help=\"The address to the test_questions.json file\")\nparser.add_argument(\"--time_based_separation\", action=\"store_true\", help=\"Stores extra information about user and date\")\n\n\nif __name__ == \"__main__\":\n    opts = parser.parse_args()\n\n    with 
open(opts.input_question_file_train) as file:\n        input_questions_train = json.load(file)\n    with open(opts.input_question_file_dev) as file:\n        input_questions_dev = json.load(file)\n    with open(opts.input_question_file_test) as file:\n        input_questions_test = json.load(file)\n    \n    all_required_files = set()\n    for sample in input_questions_train + input_questions_dev + input_questions_test:\n        all_required_files.add(sample['input'])\n        for p in sample['profile']:\n            all_required_files.add(p['text'])\n    \n    zip_addrs = glob.glob(os.path.join(opts.avocado_files_dir, \"*\"))\n    os.makedirs(opts.extract_addr, exist_ok=True)\n    database = dict()\n    for zip_addr in tqdm.tqdm(zip_addrs):\n        with zipfile.ZipFile(zip_addr, 'r') as zobj:\n            zobj.extractall(path = opts.extract_addr)\n            extracted_files_addrs = glob.glob(os.path.join(opts.extract_addr, \"*/*\"))\n            for file_addr in extracted_files_addrs:\n                if os.path.basename(file_addr) in all_required_files:\n                    id, obj = process_file(file_addr)\n                    database[id] = obj\n        empty_dir(opts.extract_addr)\n    \n    os.makedirs(opts.output_dir, exist_ok=True)\n\n    inps_train, outs_train = [], []\n    for sample in input_questions_train:\n        id = sample['input']\n        sample['input'] = f\"Generate a subject for the following email: {database[id]['content']}\"\n        sample['output'] = database[id]['subject']\n        for p in sample['profile']:\n            pid = p['text']\n            p['text'] = database[pid]['content']\n            p['title'] = database[pid]['subject']\n        if opts.time_based_separation:\n            inps_train.append({\"id\" : sample['id'], \"input\" : sample['input'], \"profile\" : sample['profile'], \"user_id\" : sample['user_id']})\n        else:\n            inps_train.append({\"id\" : sample['id'], \"input\" : sample['input'], \"profile\" : 
sample['profile']})\n        outs_train.append({\"id\" : sample['id'], \"output\" : sample['output']})\n\n    inps_dev, outs_dev = [], []\n    for sample in input_questions_dev:\n        id = sample['input']\n        sample['input'] = f\"Generate a subject for the following email: {database[id]['content']}\"\n        sample['output'] = database[id]['subject']\n        for p in sample['profile']:\n            pid = p['text']\n            p['text'] = database[pid]['content']\n            p['title'] = database[pid]['subject']\n        if opts.time_based_separation:\n            inps_dev.append({\"id\" : sample['id'], \"input\" : sample['input'], \"profile\" : sample['profile'], \"user_id\" : sample['user_id']})\n        else:\n            inps_dev.append({\"id\" : sample['id'], \"input\" : sample['input'], \"profile\" : sample['profile']})\n        outs_dev.append({\"id\" : sample['id'], \"output\" : sample['output']})\n\n    \n    inps_test= []\n    for sample in input_questions_test:\n        id = sample['input']\n        sample['input'] = f\"Generate a subject for the following email: {database[id]['content']}\"\n        for p in sample['profile']:\n            pid = p['text']\n            p['text'] = database[pid]['content']\n            p['title'] = database[pid]['subject']\n        if opts.time_based_separation:\n            inps_test.append({\"id\" : sample['id'], \"input\" : sample['input'], \"profile\" : sample['profile'], \"user_id\" : sample['user_id']})\n        else:\n            inps_test.append({\"id\" : sample['id'], \"input\" : sample['input'], \"profile\" : sample['profile']})\n        \n    with open(os.path.join(opts.output_dir, \"train_questions.json\"), \"w\") as file:\n        json.dump(inps_train, file)\n    with open(os.path.join(opts.output_dir, \"train_outputs.json\"), \"w\") as file:\n        json.dump({\"task\":\"LaMP_6\",\"golds\":outs_train}, file)\n\n    with open(os.path.join(opts.output_dir, \"dev_questions.json\"), \"w\") as file:\n  
      json.dump(inps_dev, file)\n    with open(os.path.join(opts.output_dir, \"dev_outputs.json\"), \"w\") as file:\n        json.dump({\"task\":\"LaMP_6\",\"golds\":outs_dev}, file)\n\n    with open(os.path.join(opts.output_dir, \"test_questions.json\"), \"w\") as file:\n        # same list format as train_questions.json and dev_questions.json\n        json.dump(inps_test, file)\n"
  },
  {
    "path": "eval/eval_all.py",
    "content": "from evaluation import LaMPEvaluation\nimport argparse\nimport json\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--golds_zip\", required=True, help=\"Address to all gold labels for all tasks zipped in a file\")\nparser.add_argument(\"--preds_zip\", required=True, help=\"Address to all predictions for all tasks zipped in a file\")\nparser.add_argument(\"--temp_dir\", required=False, help=\"Address to a temp dir for extracting files\", default=\"./tmp\")\nparser.add_argument(\"--output_file\", required=True, help=\"Address to the results file\")\n\nif __name__ == \"__main__\":\n\n    opts = parser.parse_args()\n\n    evaluator = LaMPEvaluation(all_golds_zip_file_addr=opts.golds_zip, extract_addr=opts.temp_dir)\n    results = evaluator.evaluate_all(opts.preds_zip)\n    with open(opts.output_file, \"w\") as file:\n        json.dump(results, file)\n"
  },
  {
    "path": "eval/eval_task.py",
    "content": "from evaluation import LaMPEvaluation\nimport argparse\nimport json\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument(\"--golds_json\", required=True, help=\"Address to all gold labels for the task as a json file\")\nparser.add_argument(\"--preds_json\", required=True, help=\"Address to all predictions for the task as a json file\")\nparser.add_argument(\"--task_name\", required=True, help=\"[LaMP_1, LaMP_2, LaMP_3, LaMP_4, LaMP_5, LaMP_6, LaMP_7]\")\nparser.add_argument(\"--output_file\", required=True, help=\"Address to the results file\")\n\nif __name__ == \"__main__\":\n\n    opts = parser.parse_args()\n\n    evaluator = LaMPEvaluation(single_gold_json_file_addr=opts.golds_json)\n    results = evaluator.evaluate_task(opts.preds_json, opts.task_name)\n    with open(opts.output_file, \"w\") as file:\n        json.dump(results, file)\n"
  },
  {
    "path": "eval/evaluation.py",
    "content": "import json\nimport zipfile\nimport glob\nimport os\nimport shutil\nimport evaluate\n\ndef postprocess_text_classification(preds, labels):\n    preds = [str(pred).strip() for pred in preds]\n    labels = [str(label).strip() for label in labels]\n    return preds, labels\n\ndef postprocess_text_generation(preds, labels):\n    preds = [pred.strip() for pred in preds]\n    labels = [[label.strip()] for label in labels]\n\n    return preds, labels\n\ndef create_metric_f1_accuracy(all_labels):\n    f1_metric = evaluate.load(\"f1\")\n    accuracy_metric = evaluate.load(\"accuracy\")\n    def create_mapping(x):\n        # Map a label string to its index in all_labels; unknown labels map to -1.\n        try:\n            return all_labels.index(x)\n        except ValueError:\n            return -1\n    def compute_metrics(decoded_preds, decoded_labels):\n        decoded_preds, decoded_labels = postprocess_text_classification(decoded_preds, decoded_labels)\n        decoded_preds = [create_mapping(x) for x in decoded_preds]\n        decoded_labels = [create_mapping(x) for x in decoded_labels]\n        result_acc = accuracy_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        result_f1 = f1_metric.compute(predictions=decoded_preds, references=decoded_labels, labels=list(range(len(all_labels))), average = \"macro\")\n        result = {\"accuracy\" : result_acc[\"accuracy\"], \"f1\" : result_f1[\"f1\"]}\n        return result\n    return compute_metrics\n\ndef create_metric_mae_rmse():\n    mse_metric = evaluate.load(\"mse\")\n    mae_metric = evaluate.load(\"mae\")\n    def create_mapping(x, y):\n        try:\n            return float(x)\n        except ValueError:\n            # Unparseable prediction: fall back to the worst-case rating (1 or 5) relative to the gold rating y.\n            print(f\"Could not parse prediction as a number: {x!r}\")\n            y = float(y)\n            if abs(1 - y) > abs(5 - y):\n                return 1.0\n            else:\n                return 5.0\n    def compute_metrics(decoded_preds, decoded_labels):\n        decoded_preds, decoded_labels = postprocess_text_classification(decoded_preds, decoded_labels)\n        decoded_preds = [create_mapping(x,y) for x,y in zip(decoded_preds, decoded_labels)]\n        decoded_labels = [create_mapping(x,x) for x in decoded_labels]\n        result_mae = mae_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        # With squared=False, the \"mse\" key holds the root mean squared error.\n        result_rmse = mse_metric.compute(predictions=decoded_preds, references=decoded_labels, squared = False)\n        result = {\"MAE\" : result_mae[\"mae\"], \"RMSE\" : result_rmse[\"mse\"]}\n        return result\n    return compute_metrics\n\ndef create_metric_rouge():\n    rouge_metric = evaluate.load('rouge')\n    def compute_metrics(decoded_preds, decoded_labels):\n        decoded_preds, decoded_labels = postprocess_text_generation(decoded_preds, decoded_labels)\n        result_rouge = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)\n        result = {\"rouge-1\" : result_rouge[\"rouge1\"], \"rouge-L\" : result_rouge[\"rougeL\"]}\n        return result\n    return compute_metrics\n\nclass LaMPEvaluation(object):\n    \n    def __init__(self, all_golds_zip_file_addr = None, single_gold_json_file_addr = None, extract_addr = \"./tmp\") -> None:\n        assert all_golds_zip_file_addr or single_gold_json_file_addr, \"Gold labels must be provided, either for all tasks (zip file) or for a single task (json file).\"\n        assert not (all_golds_zip_file_addr and single_gold_json_file_addr), \"Provide the gold labels as either a zip file or a json file, not both.\"\n        self.tasks_golds = dict()\n        self.extract_addr = extract_addr\n        self.evaluate_all_is_possible = False\n        if all_golds_zip_file_addr:\n            os.makedirs(self.extract_addr, exist_ok=True)\n            with zipfile.ZipFile(all_golds_zip_file_addr, 'r') as zobj:\n                zobj.extractall(path = extract_addr)\n            for file_addr in glob.glob(os.path.join(self.extract_addr, \"**/*.json\"), recursive=True):\n                with open(file_addr) as file:\n                    task = json.load(file)\n                    self.tasks_golds[task['task']] = task['golds']\n            self._empty_dir(self.extract_addr)\n            self.evaluate_all_is_possible = True\n        if single_gold_json_file_addr:\n            with open(single_gold_json_file_addr) as file:\n                task = json.load(file)\n                self.tasks_golds[task['task']] = task['golds']\n    \n    def _empty_dir(self, directory_path):\n        for filename in os.listdir(directory_path):\n            file_path = os.path.join(directory_path, filename)\n            try:\n                if os.path.isfile(file_path):\n                    os.unlink(file_path)\n                elif os.path.isdir(file_path):\n                    shutil.rmtree(file_path)\n            except Exception as e:\n                print(f'Failed to delete {file_path}. Reason: {e}')\n\n    def _get_all_gold_ids(self, task_name):\n        return {sample['id'] for sample in self.tasks_golds[task_name]}\n    \n    def _get_all_ids(self, samples):\n        return {sample['id'] for sample in samples}\n    \n    def evaluate_all(self, predicts_zipfile_addr):\n        assert self.evaluate_all_is_possible, \"You did not provide golds for all tasks.\"\n        with zipfile.ZipFile(predicts_zipfile_addr, 'r') as zobj:\n            zobj.extractall(path = self.extract_addr)\n        results_raw = dict()\n        all_task_names = set()\n        try:\n            for file_addr in glob.glob(os.path.join(self.extract_addr, \"**/*.json\"), recursive=True):\n                with open(file_addr) as file:\n                    preds = json.load(file)\n                all_task_names.add(preds['task'])\n                results_raw[preds['task']] = self._evaluate_task(preds['golds'], preds['task'])\n        finally:\n            # Always clean up so that a failed run does not leave stale files for the next one.\n            self._empty_dir(self.extract_addr)\n        assert len(all_task_names) == 7, \"The provided results do not cover all the tasks in the benchmark.\"\n        return results_raw\n\n    def evaluate_task(self, predicts_json_addr, task_name):\n        with open(predicts_json_addr) as file:\n            preds = json.load(file)\n        assert preds['task'] == task_name, \"The provided task_name and the results do not match.\"\n        assert preds['task'] in self.tasks_golds, \"The provided golds cannot be used to evaluate this task.\"\n        return self._evaluate_task(preds['golds'], task_name)\n\n    def _evaluate_task(self, predictions, task_name):\n        golds_dict = {y['id']:y['output'] for y in self.tasks_golds[task_name]}\n        preds_dict = {x['id']:x['output'] for x in predictions}\n        \n        gold_ids = self._get_all_gold_ids(task_name)\n        pred_ids = self._get_all_ids(predictions)\n\n        assert gold_ids == pred_ids, \"Prediction IDs and gold IDs do not match.\"\n\n        if task_name in [\"LaMP_1\", \"LaMP_2\"]:\n            metric = create_metric_f1_accuracy(self._get_labels(task_name))\n        elif task_name == \"LaMP_3\":\n            metric = create_metric_mae_rmse()\n        else:\n            metric = create_metric_rouge()\n        \n        gold_ids = list(gold_ids)\n        golds = [golds_dict[sample_id] for sample_id in gold_ids]\n        preds = [preds_dict[sample_id] for sample_id in gold_ids]\n        return metric(preds, golds)\n    \n    def _get_labels(self, task_name):\n        if task_name == \"LaMP_1\":\n            return [\"[1]\", \"[2]\"]\n        elif task_name == \"LaMP_2\":\n            return ['sci-fi', 'based on a book', 'comedy', 'action', 'twist ending', 'dystopia', 'dark comedy', 'classic', 'psychology', 'fantasy', 'romance', 'thought-provoking', 'social commentary', 'violence', 'true story']\n        elif task_name == \"LaMP_3\":\n            return [\"1\", \"2\", \"3\", \"4\", \"5\"]\n        else:\n            raise ValueError(\"Invalid task_name\")\n"
  }
]