[
  {
    "path": ".gitignore",
    "content": "*.idea/\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n.DS_Store\nts_package/\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "## Code of Conduct\nThis project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).\nFor more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact\nopensource-codeofconduct@amazon.com with any additional questions or comments.\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing Guidelines\n\nThank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional\ndocumentation, we greatly value feedback and contributions from our community.\n\nPlease read through this document before submitting any issues or pull requests to ensure we have all the necessary\ninformation to effectively respond to your bug report or contribution.\n\n\n## Reporting Bugs/Feature Requests\n\nWe welcome you to use the GitHub issue tracker to report bugs or suggest features.\n\nWhen filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already\nreported the issue. Please try to include as much information as you can. Details like these are incredibly useful:\n\n* A reproducible test case or series of steps\n* The version of our code being used\n* Any modifications you've made relevant to the bug\n* Anything unusual about your environment or deployment\n\n\n## Contributing via Pull Requests\nContributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:\n\n1. You are working against the latest source on the *main* branch.\n2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.\n3. You open an issue to discuss any significant work - we would hate for your time to be wasted.\n\nTo send us a pull request, please:\n\n1. Fork the repository.\n2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.\n3. Ensure local tests pass.\n4. Commit to your fork using clear commit messages.\n5. Send us a pull request, answering any default questions in the pull request interface.\n6. 
Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.\n\nGitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and\n[creating a pull request](https://help.github.com/articles/creating-a-pull-request/).\n\n\n## Finding contributions to work on\nLooking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.\n\n\n## Code of Conduct\nThis project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).\nFor more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact\nopensource-codeofconduct@amazon.com with any additional questions or comments.\n\n\n## Security issue notifications\nIf you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.\n\n\n## Licensing\n\nSee the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.\n"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n"
  },
  {
    "path": "NOTICE",
    "content": "Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n"
  },
  {
    "path": "README.md",
    "content": "# CrossCodeEval: A Diverse and Multilingual Benchmark for Cross-File Code Completion\n\nThis repository contains the data and inference code of the NeurIPS 2023 (Datasets and Benchmarks track)\npaper \"[CrossCodeEval: A Diverse and Multilingual Benchmark for Cross-File Code Completion](https://arxiv.org/abs/2310.11248).\"\n\n## Requirements\n\n- Uncompress the CrossCodeEval data via `tar -xvJf data/crosscodeeval_data.tar.xz -C data/`\n    - The data covers the {baseline, retrieval, retrieval w/ ref.} settings x {BM25, UniXCoder, OpenAI Ada} retrievers.\n    - **Please [email us](mailto:y.robin.ding@gmail.com) if you need the raw data.**\n- Install dependencies via `pip install -r requirements.txt`\n- Build the tree-sitter parsers via `bash scripts/build_treesitter.sh`\n\n\n## Evaluation on CrossCodeEval\nOur evaluation consists of two steps: generation and metrics calculation.\n\n\n### Generation\n\n#### Publicly Available Models\nFor publicly available models such as StarCoder and DeepSeek-Coder, we recommend using [vLLM](https://github.com/vllm-project/vllm) for fast, distributed inference on CrossCodeEval.\n\n```bash\nexport gpus=2\nexport model=bigcode/starcoder2-3b\nexport language=python\nexport task=line_completion_rg1_unixcoder_cosine_sim\nexport output_dir=./tmp/crosscodeeval_testrun/\npython scripts/vllm_inference.py \\\n  --tp $gpus \\\n  --task $task \\\n  --language $language \\\n  --model $model \\\n  --output_dir $output_dir \\\n  --use_crossfile_context\n```\nFor additional args, e.g., cross-file context length and sampling top_p, see `python scripts/vllm_inference.py --help`.\n\n<details><summary> If you prefer a non-vLLM script <i>:: click to expand ::</i></summary>\n<div>\n\nFirst, configure `accelerate` via `accelerate config` if you haven't already. A reference configuration is available at `cceval_config.yaml`.\n\nThe following command runs greedy evaluation with codegen-350M on Python with cross-file context.\n\n```bash\nexport model_type=codelm_cfc # or codelm for no cross-file context eval\nexport model_name=Salesforce/codegen-350M-mono\nexport language=python\nexport ts_lib=./build/${language}-lang-parser.so\nexport dtype=bf16 # or fp16\nexport prompt_file=./data/crosscodeeval_data/${language}/line_completion_rg1_unixcoder_cosine_sim.jsonl # or other files in this dir, which correspond to different retrieval methods and/or retrieval settings\nexport max_seq_length=2048\nexport cfc_seq_length=512\nexport batch_size=16 # reduce for larger models\nexport output_dir=./tmp/crosscodeeval_testrun/\n\naccelerate launch scripts/eval.py \\\n        --model_type $model_type \\\n        --model_name_or_path $model_name \\\n        --cfc_seq_length $cfc_seq_length \\\n        --prompt_file $prompt_file \\\n        --gen_length 50 \\\n        --max_seq_length $max_seq_length \\\n        --batch_size $batch_size \\\n        --output_dir $output_dir \\\n        --dtype $dtype \\\n        --num_return_sequences 1 \\\n        --overwrite_cache True \\\n        --ts_lib $ts_lib \\\n        --language $language\n```\n\nTo sample instead of decoding greedily, add the following args (they replace `--num_return_sequences 1`):\n\n```bash\n        --do_sample \\\n        --top_p 0.95 \\\n        --temperature 0.2 \\\n        --num_return_sequences 5 \\\n```\n\n\n</div>\n</details>\n\n#### OpenAI models\nOpenAI models are accessible through an API. You may use the following script:\n```bash\nexport model=gpt-3.5-turbo-0125\nexport language=python\nexport task=line_completion_rg1_unixcoder_cosine_sim\nexport output_dir=./tmp/crosscodeeval_openai_testrun/\npython scripts/openai_inference.py \\\n  --task $task \\\n  --language $language \\\n  --model $model \\\n  --output_dir $output_dir \\\n  --use_crossfile_context\n```\n\n\n### Metrics Calculation\nAfter generation, compute the final metrics. Set `task` and `output_dir` to the values used during generation:\n```bash\nexport language=python\nexport ts_lib=./build/${language}-lang-parser.so\nexport task=line_completion_rg1_unixcoder_cosine_sim\nexport prompt_file=./data/crosscodeeval_data/${language}/${task}.jsonl\nexport output_dir=./tmp/crosscodeeval_testrun/\npython scripts/eval.py \\\n  --prompt_file $prompt_file \\\n  --output_dir $output_dir \\\n  --ts_lib $ts_lib \\\n  --language $language \\\n  --only_compute_metric\n```\n\n## Citation\n\n```\n@inproceedings{ding2023crosscodeeval,\n    title={CrossCodeEval: A Diverse and Multilingual Benchmark for Cross-File Code Completion},\n    author={Yangruibo Ding and Zijian Wang and Wasi Uddin Ahmad and Hantian Ding and Ming Tan and Nihal Jain and Murali Krishna Ramanathan and Ramesh Nallapati and Parminder Bhatia and Dan Roth and Bing Xiang},\n    year={2023},\n    booktitle={Thirty-seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track},\n    url={https://arxiv.org/pdf/2310.11248.pdf}\n}\n```\n\n## Questions\nPlease feel free to email us. You may also submit an issue in this repo.\n\n## Security\n\nSee [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.\n\n## License\n\nThis project is licensed under the Apache-2.0 License.\n"
  },
  {
    "path": "THIRD_PARTY_LICENSES",
    "content": "The CrossCodeEval repository includes the following third-party software/licensing:\n\nThe file keywordlist.py was taken from https://github.com/microsoft/dpu-utils/blob/master/python/dpu_utils/codeutils/keywords/keywordlist.py and is distributed under the following license:\n\n    MIT License\n\n    Copyright (c) Microsoft Corporation. All rights reserved.\n\n    Permission is hereby granted, free of charge, to any person obtaining a copy\n    of this software and associated documentation files (the \"Software\"), to deal\n    in the Software without restriction, including without limitation the rights\n    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n    copies of the Software, and to permit persons to whom the Software is\n    furnished to do so, subject to the following conditions:\n\n    The above copyright notice and this permission notice shall be included in all\n    copies or substantial portions of the Software.\n\n    THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n    SOFTWARE.\n"
  },
  {
    "path": "cceval_config.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: false\n  zero3_save_16bit_model: false\n  zero_stage: 3\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_training_function: main\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "prompt_builder/README.md",
    "content": "## Retrieval Augmented Prompting\n\nThe retrieval-augmented prompts can be generated in the following three steps.\n\n1. Email us to get the raw software repositories.\n2. Set the `repository_root` variable in `augment_with_cfc.py` to the root directory that contains the raw software repositories (organized as `<repository_root>/<language>/<repository>`).\n3. Run the `run.sh` bash script."
  },
  {
    "path": "prompt_builder/augment_with_cfc.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport json\nimport time\nimport glob\nimport argparse\nimport multiprocessing as mp\nfrom tqdm import tqdm\nfrom functools import partial\nfrom rerank_utils import lexical_ranking, SemanticReranking\nfrom utils import str2bool, file_distance, tokenize_nltk\n\nCHUNK_SIZE = 10\nSLIDING_WINDOW_SIZE = 10  # non-overlapping chunks if SLIDING_WINDOW_SIZE=CHUNK_SIZE\nQUERY_LENGTH = 10  # last N lines from prompt will be query\n\nrepository_root = \"/PATH/TO/REPOS\"  # get the data from authors\n\ninput_files = {\n    \"python\": \"../data/crosscodeeval_data/python/line_completion.jsonl\",\n    \"java\": \"../data/crosscodeeval_data/java/line_completion.jsonl\",\n    \"typescript\": \"../data/crosscodeeval_data/typescript/line_completion.jsonl\",\n    \"csharp\": \"../data/crosscodeeval_data/csharp/line_completion.jsonl\"\n}\n\nfile_ext = {\"python\": \"py\", \"java\": \"java\", \"typescript\": \"ts\", \"csharp\": \"cs\"}\n\n\ndef get_crossfile_context_from_chunks(\n        args,\n        prompt,\n        code_chunks,\n        code_chunk_ids,\n        groundtruth,\n        semantic_ranker\n):\n    assert len(code_chunks) != 0\n    candidate_code_chunks = code_chunks[:args.maximum_chunk_to_rerank]\n    candidate_code_chunk_ids = code_chunk_ids[:args.maximum_chunk_to_rerank]\n\n    ranking_scores = None\n    meta_data = {}\n\n   
 if args.rerank:\n        if args.query_type == \"groundtruth\":\n            # oracle experiment\n            prompt_lines = [pl for pl in prompt.split(\"\\n\") if pl.strip()]\n            groundtruth_lines = [gt for gt in groundtruth.split(\"\\n\") if gt.strip()]\n            code_lines = prompt_lines + groundtruth_lines\n            query = \"\\n\".join(code_lines[-QUERY_LENGTH:])\n        elif args.query_type == \"last_n_lines\":\n            prompt_lines = [pl for pl in prompt.split(\"\\n\") if pl.strip()]\n            query = \"\\n\".join(prompt_lines[-QUERY_LENGTH:])\n        else:\n            raise NotImplementedError\n\n        meta_data[\"query\"] = query\n        start = time.time()\n\n        if args.ranking_fn == \"cosine_sim\":\n            gpu_id = int(mp.current_process().name.split('-')[-1]) - 1\n            candidate_code_chunks, candidate_code_chunk_ids, ranking_scores = semantic_ranker.rerank(\n                query,\n                candidate_code_chunks,\n                candidate_code_chunk_ids,\n                gpu_id,\n                score_threshold=None\n            )\n        else:\n            candidate_code_chunks, candidate_code_chunk_ids, ranking_scores = lexical_ranking(\n                query,\n                candidate_code_chunks,\n                args.ranking_fn,\n                candidate_code_chunk_ids,\n                score_threshold=None\n            )\n\n        meta_data[\"latency\"] = time.time() - start\n        meta_data[\"num_candidates\"] = len(candidate_code_chunks)\n\n    top_k = min(args.maximum_cross_file_chunk, len(candidate_code_chunk_ids))\n    if top_k == 0:\n        # keep the (cross_file_context, cfc_text, meta_data) shape of the normal return; get_cfc unpacks three values\n        return [], \"\", meta_data\n\n    selected_chunks = []\n    selected_chunks_filename = []\n    selected_chunks_scores = []\n\n    if args.use_next_chunk_as_cfc:\n        # map each chunk id to its index in code_chunks so the following chunk can be looked up\n        assert len(candidate_code_chunks) == len(candidate_code_chunk_ids)\n        id2idx = dict()\n        for j, cci in enumerate(code_chunk_ids):\n            id2idx[cci] = j\n\n        total_added = 0\n        for cidx, _id in enumerate(candidate_code_chunk_ids):\n            fname, c_id = _id.rsplit(\"|\", 1)\n            next_id = f\"{fname}|{int(c_id) + 1}\"\n            if next_id not in id2idx:\n                to_add = code_chunks[id2idx[_id]]\n            else:\n                to_add = code_chunks[id2idx[next_id]]\n\n            if to_add not in selected_chunks:\n                selected_chunks.append(to_add)\n                selected_chunks_filename.append(fname)\n                if args.rerank:\n                    selected_chunks_scores.append(ranking_scores[cidx])\n                total_added += 1\n                if total_added == top_k:\n                    break\n    else:\n        selected_chunks = candidate_code_chunks[:top_k]\n        selected_chunks_filename = [_id.rsplit(\"|\", 1)[0] for _id in candidate_code_chunk_ids[:top_k]]\n        if args.rerank:\n            selected_chunks_scores = ranking_scores[:top_k]\n\n    cross_file_context = []\n    for idx in range(len(selected_chunks)):\n        cross_file_context.append({\n            \"retrieved_chunk\": selected_chunks[idx],\n            \"filename\": selected_chunks_filename[idx],\n            \"score\": selected_chunks_scores[idx] if args.rerank else None\n        })\n\n    line_start_sym = \"#\" if args.language == \"python\" else \"//\"\n    cfc_text = f\"{line_start_sym} Here are some relevant code fragments from other files of the repo:\\n\\n\"\n    for sc, scf in zip(selected_chunks, selected_chunks_filename):\n        cfc_text += f\"{line_start_sym} the below code fragment can be found in:\\n{line_start_sym} {scf}\" + \"\\n\"\n        cfc_text += \"\\n\".join([f\"{line_start_sym} {cl}\" for cl in sc.strip('\\n').splitlines()]) + \"\\n\\n\"\n\n    return cross_file_context, cfc_text, meta_data\n\n\ndef read_project_files(repo_name, lang):\n    # root_dir needs a trailing slash (i.e. 
/root/dir/)\n    project_context = {}\n    root_dir = os.path.join(repository_root, lang, repo_name)\n    if not os.path.isdir(root_dir):\n        print(f\"Repository not found: {root_dir}\")\n        return project_context\n\n    if lang == \"typescript\":\n        src_files = []\n        src_files += glob.glob(os.path.join(root_dir, 'src/**/*.ts'), recursive=True)\n        src_files += glob.glob(os.path.join(root_dir, 'src/**/*.tsx'), recursive=True)\n    else:\n        src_files = glob.glob(os.path.join(root_dir, f'**/*.{file_ext[lang]}'), recursive=True)\n\n    if len(src_files) == 0:\n        return project_context\n\n    for filename in src_files:\n        if os.path.exists(filename):  # weird but some files cannot be opened to read\n            if os.path.isfile(filename):\n                try:\n                    with open(filename, \"r\") as file:\n                        file_content = file.read()\n                except UnicodeDecodeError:\n                    # fall back to a lossy decode for files that are not valid text in the default encoding\n                    with open(filename, \"rb\") as file:\n                        file_content = file.read().decode(errors='replace')\n\n                fileid = os.path.relpath(filename, root_dir)\n                project_context[fileid] = file_content\n        else:\n            pass\n            # print(f\"File not found: {filename}\")\n\n    return project_context\n\n\ndef find_files_within_distance_k(current_file_path, filelist, k):\n    list_of_modules = []\n    module_weight = []\n    for filepath in filelist:\n        if filepath != current_file_path:\n            dist = file_distance(filepath, current_file_path)\n            if dist == -1:\n                continue\n            elif dist <= k:\n                list_of_modules.append(filepath)\n                module_weight.append(dist)\n\n    # sorting in ascending order\n    list_of_modules = [x for _, x in sorted(zip(module_weight, list_of_modules))]\n    return list_of_modules\n\n\ndef get_cfc(example, args, semantic_ranker, repositories):\n    project_context = 
repositories[example[\"metadata\"][\"repository\"]]\n    status = None\n    current_filepath = example[\"metadata\"][\"file\"]\n    if len(project_context) == 0:\n        example[\"crossfile_context\"] = {}\n        status = \"project_not_found\"\n    else:\n        current_filecontent = project_context.get(current_filepath)\n\n        if current_filecontent is None:\n            example[\"crossfile_context\"] = {}\n            print(f\"File not found in project: {current_filepath}\")\n            status = \"file_not_found_in_project\"\n\n        else:\n            pyfiles = find_files_within_distance_k(\n                example[\"metadata\"][\"file\"],\n                list(project_context.keys()),\n                k=args.crossfile_distance\n            )\n            pyfiles = pyfiles[:args.maximum_cross_files]\n\n            code_chunks = []\n            code_chunk_ids = []\n            for pyfile in pyfiles:\n                lines = project_context[pyfile].split(\"\\n\")\n                lines = [l for l in lines if l.strip()]  # removing empty lines\n                c_id = 0\n                for i in range(0, len(lines), SLIDING_WINDOW_SIZE):\n                    c = \"\\n\".join(lines[i:i + CHUNK_SIZE])\n                    tokenized_c = tokenize_nltk(c)\n                    if len(tokenized_c) > 0:\n                        code_chunks.append(c)\n                        code_chunk_ids.append(f\"{pyfile}|{c_id}\")\n                        c_id += 1\n\n            if len(code_chunks) == 0:\n                example[\"crossfile_context\"] = {}\n                status = \"no_crossfile_context\"\n\n            else:\n                cfc, cfc_text, meta_data = get_crossfile_context_from_chunks(\n                    args=args,\n                    prompt=example[\"prompt\"],\n                    code_chunks=code_chunks,\n                    
code_chunk_ids=code_chunk_ids,\n                    groundtruth=example[\"groundtruth\"],\n                    semantic_ranker=semantic_ranker\n                )\n                example[\"crossfile_context\"] = {}\n                example[\"crossfile_context\"][\"text\"] = cfc_text\n                example[\"crossfile_context\"][\"list\"] = cfc\n\n    return example, status\n\n\ndef attach_data(args, srcfile):\n    empty_cfc = 0\n    error_freq = {\n        \"project_not_found\": 0,\n        \"file_not_found_in_project\": 0,\n        \"no_crossfile_context\": 0\n    }\n    output_examples = []\n\n    examples = []\n    repositories = dict()\n    with open(srcfile) as f:\n        for line in f:\n            ex = json.loads(line)\n            repo_name = ex[\"metadata\"][\"repository\"]\n            if repo_name not in repositories:\n                repositories[repo_name] = read_project_files(repo_name, args.language)\n            examples.append(ex)\n\n    semantic_ranker = None\n    if args.ranking_fn == \"cosine_sim\":\n        semantic_ranker = SemanticReranking(\n            args.ranker,\n            max_sequence_length=256\n        )\n\n    pool = mp.Pool(args.num_processes)\n    worker = partial(get_cfc, args=args, semantic_ranker=semantic_ranker, repositories=repositories)\n\n    with tqdm(total=len(examples)) as pbar:\n        for (d, stat) in pool.imap_unordered(worker, examples):\n            if stat in error_freq:\n                error_freq[stat] += 1\n            if len(d[\"crossfile_context\"]) == 0:\n                empty_cfc += 1\n                if not args.skip_if_no_cfc:\n                    output_examples.append(d)\n            else:\n                output_examples.append(d)\n            pbar.update()\n\n    print(\"Total examples with empty CFC: \", empty_cfc)\n    print(error_freq)\n    return output_examples\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--rerank\",\n        
type=str2bool,\n        default=True,\n        help=\"rerank the cross-file code chunks\"\n    )\n    parser.add_argument(\n        \"--ranker\",\n        type=str,\n        default=\"sparse\",\n        choices=[\"sparse\", \"unixcoder\"],\n        help=\"ranker type: sparse (lexical) or unixcoder (dense embeddings)\"\n    )\n    parser.add_argument(\n        \"--ranking_fn\",\n        type=str,\n        default=\"bm25\",\n        choices=[\"tfidf\", \"bm25\", \"jaccard_sim\", \"cosine_sim\"],\n        help=\"ranking function\"\n    )\n    parser.add_argument(\n        \"--query_type\",\n        type=str,\n        default=\"last_n_lines\",\n        choices=[\"last_n_lines\", \"groundtruth\"],\n        help=\"how to form query from prompt\"\n    )\n    parser.add_argument(\n        \"--crossfile_distance\",\n        type=int,\n        default=100,\n        help=\"max directory distance from the current file when searching for cross files\"\n    )\n    parser.add_argument(\n        \"--maximum_chunk_to_rerank\",\n        type=int,\n        default=1000,\n        help=\"max chunks to consider for reranking\"\n    )\n    parser.add_argument(\n        \"--maximum_cross_files\",\n        type=int,\n        default=1000,\n        help=\"max number of cross files to collect chunks from\"\n    )\n    parser.add_argument(\n        \"--maximum_cross_file_chunk\",\n        type=int,\n        default=50,\n        help=\"max chunks to return as cfc\"\n    )\n    parser.add_argument(\n        \"--use_next_chunk_as_cfc\",\n        type=str2bool,\n        default=True,\n        help=\"use next code chunk as context\"\n    )\n    parser.add_argument(\n        \"--skip_if_no_cfc\",\n        type=str2bool,\n        default=True,\n        help=\"skip adding examples if there is no crossfile context\"\n    )\n    parser.add_argument(\n        \"--output_file_suffix\",\n        type=str,\n        default=None,\n        help=\"add a suffix string to the output file\"\n    )\n    parser.add_argument(\n        \"--language\",\n        type=str,\n        required=True,\n        
choices=[\"java\", \"python\", \"typescript\", \"csharp\"],\n        help=\"language name\"\n    )\n    args = parser.parse_args()\n\n    args.output_file_suffix = \"\" if args.output_file_suffix is None else f\"_{args.output_file_suffix}\"\n    if args.use_next_chunk_as_cfc:\n        assert args.rerank\n        assert args.query_type != \"groundtruth\"\n\n    tgtfile_suffix = \"\"\n    if args.rerank:\n        tgtfile_suffix += f\"_{args.ranking_fn}\"\n\n    args.num_processes = 60\n    if args.ranking_fn == \"cosine_sim\":\n        num_gpus = 8\n        args.num_processes = num_gpus\n        mp.set_start_method('spawn')\n\n    input_file = input_files[args.language]\n    output_path = os.path.dirname(input_file)\n    output_filename = os.path.splitext(os.path.basename(input_file))[0]\n    output_filename = output_filename + args.output_file_suffix + tgtfile_suffix + \".jsonl\"\n    output_file = os.path.join(output_path, output_filename)\n    output_examples = attach_data(args, input_file)\n    with open(output_file, \"w\") as fw:\n        for ex in output_examples:\n            fw.write(json.dumps(ex))\n            fw.write(\"\\n\")\n"
  },
  {
    "path": "prompt_builder/rerank_utils.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom rank_bm25 import BM25Okapi\nfrom typing import List\nfrom multiprocessing import Pool, cpu_count\nfrom sklearn.feature_extraction.text import TfidfVectorizer\nfrom sklearn.metrics.pairwise import cosine_similarity\nfrom utils import tokenize_nltk\nfrom transformers import AutoModel, AutoTokenizer, AutoConfig\n\n\ndef jaccard_similarity(tokenized_query, tokenized_doc, containment=False):\n    set1 = set(tokenized_query)\n    set2 = set(tokenized_doc)\n    intersection = len(set1.intersection(set2))\n    union = len(set1) if containment else len(set1.union(set2))\n    return float(intersection) / union\n\n\ndef tokenize_corpus(corpus, tokenizer_fn):\n    pool = Pool(cpu_count())\n    tokenized_corpus = pool.map(tokenizer_fn, corpus)\n    return tokenized_corpus\n\n\ndef tokenize_query_and_docs(query, docs):\n    tokenized_query = tokenize_nltk(query)\n    tokenized_docs = [tokenize_nltk(d) for d in docs]\n    return tokenized_query, tokenized_docs\n\n\ndef lexical_ranking(\n        query,\n        docs,\n        ranking_fn,\n        doc_ids=None,\n        score_threshold=None,\n):\n    if ranking_fn == \"bm25\":\n        tokenized_query, tokenized_docs = tokenize_query_and_docs(query, docs)\n        bm25 = BM25Okapi(tokenized_docs)\n        scores = bm25.get_scores(tokenized_query)\n    elif ranking_fn == 
\"tfidf\":\n        tfidf_vectorizer = TfidfVectorizer(tokenizer=tokenize_nltk)\n        X = tfidf_vectorizer.fit_transform(docs).toarray()  # (n_fn, n_features)\n        y = tfidf_vectorizer.transform([query]).toarray()  # (1, n_features)\n        scores = cosine_similarity(X, y).tolist()  # (n_fn, 1)\n    elif ranking_fn == \"jaccard_sim\":\n        tokenized_query, tokenized_docs = tokenize_query_and_docs(query, docs)\n        scores = [jaccard_similarity(tokenized_query, d, containment=False) for d in tokenized_docs]\n    else:\n        raise NotImplementedError\n\n    if score_threshold:\n        skip_ids = [idx for idx, s in enumerate(scores) if s < score_threshold]\n        scores = [s for idx, s in enumerate(scores) if idx not in skip_ids]\n        docs = [d for idx, d in enumerate(docs) if idx not in skip_ids]\n        if doc_ids is not None:\n            doc_ids = [doc_id for idx, doc_id in enumerate(doc_ids) if idx not in skip_ids]\n\n    if len(docs) == 0:\n        return docs, doc_ids, scores\n\n    if doc_ids is not None:\n        doc_ids = [x for _, x in sorted(zip(scores, doc_ids), reverse=True)]\n    docs_scores = [(x, s) for s, x in sorted(zip(scores, docs), reverse=True)]\n    docs = [item[0] for item in docs_scores]\n    scores = [item[1] for item in docs_scores]\n\n    return docs, doc_ids, scores\n\n\nclass SemanticReranking:\n\n    def __init__(self, model_type=\"unixcoder\", **kwargs):\n        self.model_type = model_type\n        if model_type == \"unixcoder\":\n            self.tokenizer = AutoTokenizer.from_pretrained('microsoft/unixcoder-base')\n            self.model = AutoModel.from_pretrained('microsoft/unixcoder-base')\n        else:\n            raise NotImplementedError\n\n        # maximum sequence length for query and documents\n        self.max_sequence_length = kwargs.get(\"max_sequence_length\", 256)\n\n    def text_to_tensor(\n            self,\n            text: str,\n            pad_to_max: bool = True,\n    ):\n        
text = text.strip()\n\n        # tokenizer automatic padding is explicitly disabled because of its inconsistent behavior;\n        # padding to max_sequence_length is applied manually below\n        token_ids = self.tokenizer.encode(\n            text,\n            add_special_tokens=False,\n            max_length=self.max_sequence_length,\n            padding=False,\n            truncation=True\n        )\n\n        if pad_to_max and len(token_ids) < self.max_sequence_length:\n            token_ids = token_ids + [self.tokenizer.pad_token_id] * (self.max_sequence_length - len(token_ids))\n        if len(token_ids) > self.max_sequence_length:\n            token_ids = token_ids[0:self.max_sequence_length]\n\n        return torch.tensor(token_ids)\n\n    def get_pad_id(self):\n        return self.tokenizer.pad_token_id\n\n    def get_attn_mask(self, tokens_tensor):\n        return tokens_tensor != self.get_pad_id()\n\n    def get_representations(self, list_input_ids, gpu_id):\n        device = torch.device('cuda', gpu_id)\n        self.model = self.model.to(device=device, dtype=torch.float16)\n        self.model.eval()\n\n        batch_size = 64\n        sequence_outputs = []\n        pooled_outputs = []\n\n        for idx in range(0, len(list_input_ids), batch_size):\n            start, end = idx, min(idx + batch_size, len(list_input_ids))\n            input_ids = torch.stack(list_input_ids[start:end], dim=0).to(device=device)\n            attention_mask = self.get_attn_mask(input_ids)\n\n            # only UniXcoder is supported (see __init__); its last hidden layer is mean-pooled below\n            output = self.model(input_ids=input_ids, attention_mask=attention_mask)\n            token_embeddings = output.last_hidden_state  # bsz x seq_len x hid_dim\n\n            mask_expanded = 
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n            sum_embeddings = torch.sum(token_embeddings * mask_expanded, 1)\n            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)\n            sequence_embeddings = sum_embeddings / sum_mask  # bsz x hid_dim\n\n            sequence_outputs.append(token_embeddings)\n            pooled_outputs.append(sequence_embeddings)\n\n        sequence_output = torch.cat(sequence_outputs)\n        pooled_output = torch.cat(pooled_outputs)\n\n        return sequence_output, pooled_output\n\n    def rerank(self, query: str, docs: List[str], doc_ids: List[str] = None, gpu_id=0, score_threshold=None):\n        with torch.no_grad():\n            batch_queries = [self.text_to_tensor(query)]\n            batch_candidates = [self.text_to_tensor(d) for d in docs]\n\n            _, query_rep = self.get_representations(batch_queries, gpu_id)  # 1 x hidden_size\n            _, candi_rep = self.get_representations(batch_candidates, gpu_id)  # num_cand x hidden_size\n            scores = torch.nn.functional.cosine_similarity(query_rep, candi_rep).tolist()  # num_cand\n\n        if score_threshold:\n            skip_ids = {idx for idx, s in enumerate(scores) if s < score_threshold}\n            scores = [s for idx, s in enumerate(scores) if idx not in skip_ids]\n            docs = [d for idx, d in enumerate(docs) if idx not in skip_ids]\n            if doc_ids is not None:\n                doc_ids = [doc_id for idx, doc_id in enumerate(doc_ids) if idx not in skip_ids]\n\n        if len(docs) == 0:\n            return docs, doc_ids, scores\n\n        # sort by descending score with a single permutation so docs, doc_ids and scores stay aligned\n        order = sorted(range(len(docs)), key=lambda i: scores[i], reverse=True)\n        docs = [docs[i] for i in order]\n        scores = [scores[i] for i in order]\n        if doc_ids is not None:\n            doc_ids = [doc_ids[i] for i in order]\n\n        return docs, doc_ids, scores\n"
  },
  {
    "path": "prompt_builder/run.sh",
    "content": "#!/usr/bin/env bash\n\nexport PYTHONIOENCODING=utf-8\n\nfunction generate_data() {\n    lang=$1\n    ranker=$2\n    ranking_fn=$3\n\n    echo \"$lang, $ranker, $ranking_fn\"\n\n    output_file_suffix=\"\"\n    if [[ $ranker != \"sparse\" ]]; then\n        output_file_suffix=\"_${ranker}\"\n    fi\n\n    # for RG-1\n    python augment_with_cfc.py \\\n        --language $lang \\\n        --rerank True \\\n        --ranker $ranker \\\n        --ranking_fn $ranking_fn \\\n        --query_type last_n_lines \\\n        --crossfile_distance 100 \\\n        --maximum_chunk_to_rerank 1000 \\\n        --maximum_cross_files 1000 \\\n        --maximum_cross_file_chunk 5 \\\n        --use_next_chunk_as_cfc True \\\n        --skip_if_no_cfc False \\\n        --output_file_suffix \"rg1${output_file_suffix}\"\n\n    # for oracle experiment\n    python augment_with_cfc.py \\\n        --language $lang \\\n        --rerank True \\\n        --ranker $ranker \\\n        --ranking_fn $ranking_fn \\\n        --query_type groundtruth \\\n        --crossfile_distance 100 \\\n        --maximum_chunk_to_rerank 1000 \\\n        --maximum_cross_files 1000 \\\n        --maximum_cross_file_chunk 5 \\\n        --use_next_chunk_as_cfc False \\\n        --skip_if_no_cfc False \\\n        --output_file_suffix \"oracle${output_file_suffix}\"\n}\n\ngenerate_data python sparse bm25\ngenerate_data java sparse bm25\ngenerate_data typescript sparse bm25\ngenerate_data csharp sparse bm25\n\ngenerate_data python sparse jaccard_sim\ngenerate_data java sparse jaccard_sim\ngenerate_data typescript sparse jaccard_sim\ngenerate_data csharp sparse jaccard_sim\n\ngenerate_data python unixcoder cosine_sim\ngenerate_data java unixcoder cosine_sim\ngenerate_data typescript unixcoder cosine_sim\ngenerate_data csharp unixcoder cosine_sim\n"
  },
  {
    "path": "prompt_builder/utils.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport re\nimport os\nfrom typing import List\nfrom nltk.tokenize import word_tokenize\n\n\ndef tokenize_nltk(text):\n    words = word_tokenize(text)\n    output_list = []\n    for w in words:\n        w_list = re.findall(r'\\w+', w)\n        output_list.extend(w_list)\n    return output_list\n\n\ndef file_distance(src_file, dest_file):\n    distance = -1\n    try:\n        commonpath = os.path.commonpath([src_file, dest_file])\n        rel_file1_path = os.path.relpath(src_file, commonpath)\n        rel_file2_path = os.path.relpath(dest_file, commonpath)\n        distance = rel_file1_path.count(os.sep) + rel_file2_path.count(os.sep)\n    except Exception as e:\n        # print(e, src_file, dest_file)\n        pass\n\n    return distance\n\n\ndef str2bool(v):\n    if isinstance(v, bool):\n        return v\n    if v.lower() in ('yes', 'true', 't', 'y', '1'):\n        return True\n    elif v.lower() in ('no', 'false', 'f', 'n', '0'):\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value expected.')\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch\ntransformers\ndatasets\ntree-sitter\ntimeout-decorator\nbitsandbytes\naccelerate\nscikit-learn\nrank-bm25\nfuzzywuzzy\nnltk\nsacrebleu\ndeepspeed\ntiktoken\nvllm>=0.3.3\n"
  },
  {
    "path": "scripts/build_treesitter.sh",
    "content": "mkdir ts_package;\ncd ts_package;\n# Download the tree-sitter package\ngit clone https://github.com/tree-sitter/tree-sitter-python.git;\ngit clone https://github.com/tree-sitter/tree-sitter-java.git;\ngit clone https://github.com/tree-sitter/tree-sitter-c-sharp.git;\ngit clone https://github.com/tree-sitter/tree-sitter-typescript.git;\ncd ..;\n# Build tree-sitter\npython scripts/build_ts_lib.py\n"
  },
  {
    "path": "scripts/build_ts_lib.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n\nfrom tree_sitter import Language\n\ndef build_language_lib():\n    for lang in [\"java\", \"python\", \"typescript\", \"csharp\"]:\n        ts_lang = \"c-sharp\" if lang == \"csharp\" else lang\n        if lang == \"typescript\":\n            git_dir = f\"ts_package/tree-sitter-{ts_lang}/{lang}\"\n        else:\n            git_dir = f\"ts_package/tree-sitter-{ts_lang}\"\n        Language.build_library(f'build/{lang}-lang-parser.so', [git_dir])\n\n\nif __name__ == \"__main__\":\n    build_language_lib()\n"
  },
  {
    "path": "scripts/custom_generate.py",
    "content": "# Modifications Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n##############################################################################\n# Modified from https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/utils.py\n##############################################################################\n\nimport copy\nimport inspect\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.distributed as dist\nfrom torch import nn\n\nfrom transformers.deepspeed import is_deepspeed_zero3_enabled\nfrom transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput\nfrom transformers.models.auto import (\n    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,\n    MODEL_FOR_CAUSAL_LM_MAPPING,\n    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,\n    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,\n    MODEL_FOR_VISION_2_SEQ_MAPPING,\n)\nfrom transformers.utils import ModelOutput, logging\nfrom transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint\nfrom transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer\nfrom transformers.generation.configuration_utils import GenerationConfig\nfrom 
transformers.generation.logits_process import (\n    EncoderNoRepeatNGramLogitsProcessor,\n    EncoderRepetitionPenaltyLogitsProcessor,\n    EpsilonLogitsWarper,\n    EtaLogitsWarper,\n    ExponentialDecayLengthPenalty,\n    ForcedBOSTokenLogitsProcessor,\n    ForcedEOSTokenLogitsProcessor,\n    ForceTokensLogitsProcessor,\n    HammingDiversityLogitsProcessor,\n    InfNanRemoveLogitsProcessor,\n    LogitNormalization,\n    LogitsProcessorList,\n    MinLengthLogitsProcessor,\n    MinNewTokensLengthLogitsProcessor,\n    NoBadWordsLogitsProcessor,\n    NoRepeatNGramLogitsProcessor,\n    PrefixConstrainedLogitsProcessor,\n    RepetitionPenaltyLogitsProcessor,\n    SuppressTokensAtBeginLogitsProcessor,\n    SuppressTokensLogitsProcessor,\n    TemperatureLogitsWarper,\n    TopKLogitsWarper,\n    TopPLogitsWarper,\n    TypicalLogitsWarper,\n)\nfrom transformers.generation.stopping_criteria import (\n    MaxLengthCriteria,\n    MaxTimeCriteria,\n    StoppingCriteria,\n    StoppingCriteriaList,\n    validate_stopping_criteria,\n)\nfrom transformers.generation.utils import GenerateOutput, SampleOutput, SampleDecoderOnlyOutput\n\nif TYPE_CHECKING:\n    from transformers.modeling_utils import PreTrainedModel\n    from transformers.generation.streamers import BaseStreamer\n\nlogger = logging.get_logger(__name__)\n\n\n@torch.no_grad()\ndef generate(\n        self,\n        inputs: Optional[torch.Tensor] = None,\n        generation_config: Optional[GenerationConfig] = None,\n        logits_processor: Optional[LogitsProcessorList] = None,\n        stopping_criteria: Optional[StoppingCriteriaList] = None,\n        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,\n        synced_gpus: Optional[bool] = None,\n        assistant_model: Optional[\"PreTrainedModel\"] = None,\n        streamer: Optional[\"BaseStreamer\"] = None,\n        **kwargs,\n) -> Union[GenerateOutput, torch.LongTensor]:\n    r\"\"\"\n\n    Generates sequences of token ids for 
models with a language modeling head.\n\n    <Tip warning={true}>\n\n    Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the\n    model's default generation configuration. You can override any `generation_config` by passing the corresponding\n    parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.\n\n    For an overview of generation strategies and code examples, check out the [following\n    guide](../generation_strategies).\n\n    </Tip>\n\n    Parameters:\n        inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):\n            The sequence used as a prompt for the generation or as model inputs to the encoder. If `None`, the\n            method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models, `inputs`\n            should be in the format of `input_ids`. For encoder-decoder models, *inputs* can represent any of\n            `input_ids`, `input_values`, `input_features`, or `pixel_values`.\n        generation_config (`~generation.GenerationConfig`, *optional*):\n            The generation configuration to be used as base parametrization for the generation call. `**kwargs`\n            passed to generate matching the attributes of `generation_config` will override them. If\n            `generation_config` is not provided, the default will be used, which has the following loading\n            priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model\n            configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s\n            default values, whose documentation should be checked to parameterize generation.\n        logits_processor (`LogitsProcessorList`, *optional*):\n            Custom logits processors that complement the default logits processors built from arguments and\n            generation config. 
If a logit processor is passed that is already created with the arguments or a\n            generation config, an error is thrown. This feature is intended for advanced users.\n        stopping_criteria (`StoppingCriteriaList`, *optional*):\n            Custom stopping criteria that complement the default stopping criteria built from arguments and a\n            generation config. If a stopping criterion is passed that is already created with the arguments or a\n            generation config, an error is thrown. This feature is intended for advanced users.\n        prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):\n            If provided, this function constrains the beam search to allowed tokens only at each step. If not\n            provided, no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and\n            `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned\n            on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful\n            for constrained generation conditioned on the prefix, as described in [Autoregressive Entity\n            Retrieval](https://arxiv.org/abs/2010.00904).\n        synced_gpus (`bool`, *optional*):\n            Whether to continue running the while loop until max_length. Unless overridden, this flag will be set to\n            `True` under a DeepSpeed ZeRO Stage 3 multi-GPU environment to avoid hanging if one GPU finishes\n            generating before the other GPUs. Otherwise it'll be set to `False`.\n        assistant_model (`PreTrainedModel`, *optional*):\n            An assistant model that can be used to accelerate generation. The assistant model must have the exact\n            same tokenizer. 
The acceleration is achieved when forecasting candidate tokens with the assistant model\n            is much faster than running generation with the model you're calling generate from. As such, the\n            assistant model should be much smaller.\n        streamer (`BaseStreamer`, *optional*):\n            Streamer object that will be used to stream the generated sequences. Generated tokens are passed\n            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.\n        kwargs:\n            Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be\n            forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder-specific\n            kwargs should not be prefixed and decoder-specific kwargs should be prefixed with *decoder_*.\n\n    Return:\n        [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`\n        or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.\n\n            If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible\n            [`~utils.ModelOutput`] types are:\n\n                - [`~generation.GreedySearchDecoderOnlyOutput`],\n                - [`~generation.SampleDecoderOnlyOutput`],\n                - [`~generation.BeamSearchDecoderOnlyOutput`],\n                - [`~generation.BeamSampleDecoderOnlyOutput`]\n\n            If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible\n            [`~utils.ModelOutput`] types are:\n\n                - [`~generation.GreedySearchEncoderDecoderOutput`],\n                - [`~generation.SampleEncoderDecoderOutput`],\n                - [`~generation.BeamSearchEncoderDecoderOutput`],\n                - [`~generation.BeamSampleEncoderDecoderOutput`]\n    \"\"\"\n\n    if synced_gpus is None:\n        if 
is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:\n            synced_gpus = True\n        else:\n            synced_gpus = False\n\n    # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call\n    self._validate_model_class()\n\n    # priority: `generation_config` argument > `model.generation_config` (the default generation config)\n    if generation_config is None:\n        # legacy: users may modify the model configuration to control generation -- update the generation config\n        # model attribute accordingly, if it was created from the model config\n        if self.generation_config._from_model_config:\n            new_generation_config = GenerationConfig.from_model_config(self.config)\n            if new_generation_config != self.generation_config:\n                warnings.warn(\n                    \"You have modified the pretrained model configuration to control generation. This is a\"\n                    \" deprecated strategy to control generation and will be removed soon, in a future version.\"\n                    \" Please use a generation configuration file (see\"\n                    \" https://huggingface.co/docs/transformers/main_classes/text_generation)\"\n                )\n                self.generation_config = new_generation_config\n        generation_config = self.generation_config\n\n    generation_config = copy.deepcopy(generation_config)\n    model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs\n    generation_config.validate()\n    self._validate_model_kwargs(model_kwargs.copy())\n\n    # 2. 
Set generation parameters if not already defined\n    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()\n    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()\n\n    if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:\n        if model_kwargs.get(\"attention_mask\", None) is None:\n            logger.warning(\n                \"The attention mask and the pad token id were not set. As a consequence, you may observe \"\n                \"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\"\n            )\n        eos_token_id = generation_config.eos_token_id\n        if isinstance(eos_token_id, list):\n            eos_token_id = eos_token_id[0]\n        logger.warning(f\"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.\")\n        generation_config.pad_token_id = eos_token_id\n\n    # 3. Define model inputs\n    # inputs_tensor has to be defined\n    # model_input_name is defined if model-specific keyword input is passed\n    # otherwise model_input_name is None\n    # all model-specific keyword inputs are removed from `model_kwargs`\n    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(\n        inputs, generation_config.bos_token_id, model_kwargs\n    )\n    batch_size = inputs_tensor.shape[0]\n\n    # 4. 
Define other model kwargs\n    model_kwargs[\"output_attentions\"] = generation_config.output_attentions\n    model_kwargs[\"output_hidden_states\"] = generation_config.output_hidden_states\n    model_kwargs[\"use_cache\"] = generation_config.use_cache\n\n    accepts_attention_mask = \"attention_mask\" in set(inspect.signature(self.forward).parameters.keys())\n    requires_attention_mask = \"encoder_outputs\" not in model_kwargs\n\n    if model_kwargs.get(\"attention_mask\", None) is None and requires_attention_mask and accepts_attention_mask:\n        model_kwargs[\"attention_mask\"] = self._prepare_attention_mask_for_generation(\n            inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id\n        )\n\n    # decoder-only models should use left-padding for generation\n    if not self.config.is_encoder_decoder:\n        if (\n                generation_config.pad_token_id is not None\n                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0\n        ):\n            logger.warning(\n                \"A decoder-only architecture is being used, but right-padding was detected! For correct \"\n                \"generation results, please set `padding_side='left'` when initializing the tokenizer.\"\n            )\n\n    if self.config.is_encoder_decoder and \"encoder_outputs\" not in model_kwargs:\n        # if model is encoder decoder encoder_outputs are created\n        # and added to `model_kwargs`\n        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(\n            inputs_tensor, model_kwargs, model_input_name\n        )\n\n    # 5. 
Prepare `input_ids` which will be used for auto-regressive generation\n    if self.config.is_encoder_decoder:\n        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(\n            batch_size=batch_size,\n            model_input_name=model_input_name,\n            model_kwargs=model_kwargs,\n            decoder_start_token_id=generation_config.decoder_start_token_id,\n            bos_token_id=generation_config.bos_token_id,\n            device=inputs_tensor.device,\n        )\n    else:\n        input_ids = inputs_tensor if model_input_name == \"input_ids\" else model_kwargs.pop(\"input_ids\")\n\n    if streamer is not None:\n        streamer.put(input_ids.cpu())\n\n    # 6. Prepare `max_length` depending on other stopping criteria.\n    input_ids_seq_length = input_ids.shape[-1]\n    has_default_max_length = kwargs.get(\"max_length\") is None and generation_config.max_length is not None\n    if has_default_max_length and generation_config.max_new_tokens is None:\n        warnings.warn(\n            f\"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. \"\n            \"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we\"\n            \" recommend using `max_new_tokens` to control the maximum length of the generation.\",\n            UserWarning,\n        )\n    elif generation_config.max_new_tokens is not None:\n        if not has_default_max_length:\n            logger.warning(\n                f\"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=\"\n                f\"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. \"\n                \"Please refer to the documentation for more information. 
\"\n                \"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)\"\n            )\n        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length\n\n    if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:\n        raise ValueError(\n            f\"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than\"\n            f\" the maximum length ({generation_config.max_length})\"\n        )\n    if input_ids_seq_length >= generation_config.max_length:\n        input_ids_string = \"decoder_input_ids\" if self.config.is_encoder_decoder else \"input_ids\"\n        logger.warning(\n            f\"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to\"\n            f\" {generation_config.max_length}. This can lead to unexpected behavior. You should consider\"\n            \" increasing `max_new_tokens`.\"\n        )\n\n    # 7. 
determine generation mode\n    is_constraint_gen_mode = (\n            generation_config.constraints is not None or generation_config.force_words_ids is not None\n    )\n\n    is_contrastive_search_gen_mode = (\n            (generation_config.num_beams == 1)\n            and generation_config.top_k is not None\n            and generation_config.top_k > 1\n            and generation_config.do_sample is False\n            and generation_config.penalty_alpha is not None\n            and generation_config.penalty_alpha > 0\n    )\n\n    is_greedy_gen_mode = (\n            (generation_config.num_beams == 1)\n            and (generation_config.num_beam_groups == 1)\n            and generation_config.do_sample is False\n            and not is_constraint_gen_mode\n            and not is_contrastive_search_gen_mode\n    )\n    is_sample_gen_mode = (\n            (generation_config.num_beams == 1)\n            and (generation_config.num_beam_groups == 1)\n            and generation_config.do_sample is True\n            and not is_constraint_gen_mode\n            and not is_contrastive_search_gen_mode\n    )\n    is_beam_gen_mode = (\n            (generation_config.num_beams > 1)\n            and (generation_config.num_beam_groups == 1)\n            and generation_config.do_sample is False\n            and not is_constraint_gen_mode\n            and not is_contrastive_search_gen_mode\n    )\n    is_beam_sample_gen_mode = (\n            (generation_config.num_beams > 1)\n            and (generation_config.num_beam_groups == 1)\n            and generation_config.do_sample is True\n            and not is_constraint_gen_mode\n            and not is_contrastive_search_gen_mode\n    )\n    is_group_beam_gen_mode = (\n            (generation_config.num_beams > 1)\n            and (generation_config.num_beam_groups > 1)\n            and not is_constraint_gen_mode\n            and not is_contrastive_search_gen_mode\n    )\n    is_assisted_gen_mode = False\n    if assistant_model is 
not None:\n        if not (is_greedy_gen_mode or is_sample_gen_mode):\n            raise ValueError(\n                \"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate \"\n                \"is only supported with Greedy Search and Sample.\"\n            )\n        is_assisted_gen_mode = True\n\n    if generation_config.num_beam_groups > generation_config.num_beams:\n        raise ValueError(\"`num_beam_groups` has to be smaller or equal to `num_beams`\")\n    if is_group_beam_gen_mode and generation_config.do_sample is True:\n        raise ValueError(\n            \"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`.\"\n        )\n\n    if streamer is not None and (generation_config.num_beams > 1):\n        raise ValueError(\n            \"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1.\"\n        )\n\n    if self.device.type != input_ids.device.type:\n        warnings.warn(\n            \"You are calling .generate() with the `input_ids` being on a device type different\"\n            f\" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model\"\n            f\" is on {self.device.type}. You may experience unexpected behaviors or slower generation.\"\n            \" Please make sure that you have put `input_ids` to the\"\n            f\" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before\"\n            \" running `.generate()`.\",\n            UserWarning,\n        )\n\n    # 8. prepare distribution pre_processing samplers\n    logits_processor = self._get_logits_processor(\n        generation_config=generation_config,\n        input_ids_seq_length=input_ids_seq_length,\n        encoder_input_ids=inputs_tensor,\n        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,\n        logits_processor=logits_processor,\n    )\n\n    # 9. 
prepare stopping criteria\n    stopping_criteria = self._get_stopping_criteria(\n        generation_config=generation_config, stopping_criteria=stopping_criteria\n    )\n    # 10. go into different generation modes\n    if is_assisted_gen_mode:\n        if generation_config.num_return_sequences > 1:\n            raise ValueError(\n                \"num_return_sequences has to be 1 when doing assisted generate, \"\n                f\"but is {generation_config.num_return_sequences}.\"\n            )\n        if batch_size > 1:\n            raise ValueError(\"assisted generate is only supported for batch_size = 1\")\n        if not model_kwargs[\"use_cache\"]:\n            raise ValueError(\"assisted generate requires `use_cache=True`\")\n\n        # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs\n        if assistant_model.config.is_encoder_decoder:\n            assistant_model_kwargs = copy.deepcopy(model_kwargs)\n            inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs(\n                inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs\n            )\n            assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(\n                inputs_tensor, assistant_model_kwargs, model_input_name\n            )\n            model_kwargs[\"assistant_encoder_outputs\"] = assistant_model_kwargs[\"encoder_outputs\"]\n\n        # 12. 
run assisted generate\n        return self.assisted_decoding(\n            input_ids,\n            assistant_model=assistant_model,\n            do_sample=generation_config.do_sample,\n            logits_processor=logits_processor,\n            logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,\n            stopping_criteria=stopping_criteria,\n            pad_token_id=generation_config.pad_token_id,\n            eos_token_id=generation_config.eos_token_id,\n            output_scores=generation_config.output_scores,\n            return_dict_in_generate=generation_config.return_dict_in_generate,\n            synced_gpus=synced_gpus,\n            streamer=streamer,\n            **model_kwargs,\n        )\n    if is_greedy_gen_mode:\n        if generation_config.num_return_sequences > 1:\n            raise ValueError(\n                \"num_return_sequences has to be 1 when doing greedy search, \"\n                f\"but is {generation_config.num_return_sequences}.\"\n            )\n\n        # 11. 
run greedy search\n        return self.greedy_search(\n            input_ids,\n            logits_processor=logits_processor,\n            stopping_criteria=stopping_criteria,\n            pad_token_id=generation_config.pad_token_id,\n            eos_token_id=generation_config.eos_token_id,\n            output_scores=generation_config.output_scores,\n            return_dict_in_generate=generation_config.return_dict_in_generate,\n            synced_gpus=synced_gpus,\n            streamer=streamer,\n            **model_kwargs,\n        )\n\n    elif is_contrastive_search_gen_mode:\n        if generation_config.num_return_sequences > 1:\n            raise ValueError(\n                \"num_return_sequences has to be 1 when doing contrastive search, \"\n                f\"but is {generation_config.num_return_sequences}.\"\n            )\n        if not model_kwargs[\"use_cache\"]:\n            raise ValueError(\"Contrastive search requires `use_cache=True`\")\n\n        return self.contrastive_search(\n            input_ids,\n            top_k=generation_config.top_k,\n            penalty_alpha=generation_config.penalty_alpha,\n            logits_processor=logits_processor,\n            stopping_criteria=stopping_criteria,\n            pad_token_id=generation_config.pad_token_id,\n            eos_token_id=generation_config.eos_token_id,\n            output_scores=generation_config.output_scores,\n            return_dict_in_generate=generation_config.return_dict_in_generate,\n            synced_gpus=synced_gpus,\n            streamer=streamer,\n            **model_kwargs,\n        )\n\n    elif is_sample_gen_mode:\n        # 11. prepare logits warper\n        logits_warper = self._get_logits_warper(generation_config)\n\n        # 12. 
expand input_ids with `num_return_sequences` additional sequences per batch\n        input_ids, model_kwargs = self._expand_inputs_for_generation(\n            input_ids=input_ids,\n            expand_size=generation_config.num_return_sequences,\n            is_encoder_decoder=self.config.is_encoder_decoder,\n            **model_kwargs,\n        )\n\n        # 13. run sample\n        return sample(\n            self,\n            input_ids,\n            logits_processor=logits_processor,\n            logits_warper=logits_warper,\n            stopping_criteria=stopping_criteria,\n            pad_token_id=generation_config.pad_token_id,\n            eos_token_id=generation_config.eos_token_id,\n            output_scores=generation_config.output_scores,\n            return_dict_in_generate=generation_config.return_dict_in_generate,\n            synced_gpus=synced_gpus,\n            streamer=streamer,\n            **model_kwargs,\n        )\n\n    elif is_beam_gen_mode:\n        if generation_config.num_return_sequences > generation_config.num_beams:\n            raise ValueError(\"`num_return_sequences` has to be smaller or equal to `num_beams`.\")\n\n        if stopping_criteria.max_length is None:\n            raise ValueError(\"`max_length` needs to be a stopping_criteria for now.\")\n\n        # 11. prepare beam search scorer\n        beam_scorer = BeamSearchScorer(\n            batch_size=batch_size,\n            num_beams=generation_config.num_beams,\n            device=inputs_tensor.device,\n            length_penalty=generation_config.length_penalty,\n            do_early_stopping=generation_config.early_stopping,\n            num_beam_hyps_to_keep=generation_config.num_return_sequences,\n            max_length=generation_config.max_length,\n        )\n        # 12. 
interleave input_ids with `num_beams` additional sequences per batch\n        input_ids, model_kwargs = self._expand_inputs_for_generation(\n            input_ids=input_ids,\n            expand_size=generation_config.num_beams,\n            is_encoder_decoder=self.config.is_encoder_decoder,\n            **model_kwargs,\n        )\n        # 13. run beam search\n        return self.beam_search(\n            input_ids,\n            beam_scorer,\n            logits_processor=logits_processor,\n            stopping_criteria=stopping_criteria,\n            pad_token_id=generation_config.pad_token_id,\n            eos_token_id=generation_config.eos_token_id,\n            output_scores=generation_config.output_scores,\n            return_dict_in_generate=generation_config.return_dict_in_generate,\n            synced_gpus=synced_gpus,\n            **model_kwargs,\n        )\n\n    elif is_beam_sample_gen_mode:\n        # 11. prepare logits warper\n        logits_warper = self._get_logits_warper(generation_config)\n\n        if stopping_criteria.max_length is None:\n            raise ValueError(\"`max_length` needs to be a stopping_criteria for now.\")\n        # 12. prepare beam search scorer\n        beam_scorer = BeamSearchScorer(\n            batch_size=batch_size * generation_config.num_return_sequences,\n            num_beams=generation_config.num_beams,\n            device=inputs_tensor.device,\n            length_penalty=generation_config.length_penalty,\n            do_early_stopping=generation_config.early_stopping,\n            max_length=generation_config.max_length,\n        )\n\n        # 13. 
interleave input_ids with `num_beams` additional sequences per batch\n        input_ids, model_kwargs = self._expand_inputs_for_generation(\n            input_ids=input_ids,\n            expand_size=generation_config.num_beams * generation_config.num_return_sequences,\n            is_encoder_decoder=self.config.is_encoder_decoder,\n            **model_kwargs,\n        )\n\n        # 14. run beam sample\n        return self.beam_sample(\n            input_ids,\n            beam_scorer,\n            logits_processor=logits_processor,\n            logits_warper=logits_warper,\n            stopping_criteria=stopping_criteria,\n            pad_token_id=generation_config.pad_token_id,\n            eos_token_id=generation_config.eos_token_id,\n            output_scores=generation_config.output_scores,\n            return_dict_in_generate=generation_config.return_dict_in_generate,\n            synced_gpus=synced_gpus,\n            **model_kwargs,\n        )\n\n    elif is_group_beam_gen_mode:\n        if generation_config.num_return_sequences > generation_config.num_beams:\n            raise ValueError(\"`num_return_sequences` has to be smaller or equal to `num_beams`.\")\n\n        if generation_config.num_beams % generation_config.num_beam_groups != 0:\n            raise ValueError(\"`num_beams` should be divisible by `num_beam_groups` for group beam search.\")\n\n        if stopping_criteria.max_length is None:\n            raise ValueError(\"`max_length` needs to be a stopping_criteria for now.\")\n\n        has_default_typical_p = kwargs.get(\"typical_p\") is None and generation_config.typical_p == 1.0\n        if not has_default_typical_p:\n            raise ValueError(\"Decoder argument `typical_p` is not supported with beam groups.\")\n\n        # 11. 
prepare beam search scorer\n        beam_scorer = BeamSearchScorer(\n            batch_size=batch_size,\n            num_beams=generation_config.num_beams,\n            device=inputs_tensor.device,\n            length_penalty=generation_config.length_penalty,\n            do_early_stopping=generation_config.early_stopping,\n            num_beam_hyps_to_keep=generation_config.num_return_sequences,\n            num_beam_groups=generation_config.num_beam_groups,\n            max_length=generation_config.max_length,\n        )\n        # 12. interleave input_ids with `num_beams` additional sequences per batch\n        input_ids, model_kwargs = self._expand_inputs_for_generation(\n            input_ids=input_ids,\n            expand_size=generation_config.num_beams,\n            is_encoder_decoder=self.config.is_encoder_decoder,\n            **model_kwargs,\n        )\n        # 13. run beam search\n        return self.group_beam_search(\n            input_ids,\n            beam_scorer,\n            logits_processor=logits_processor,\n            stopping_criteria=stopping_criteria,\n            pad_token_id=generation_config.pad_token_id,\n            eos_token_id=generation_config.eos_token_id,\n            output_scores=generation_config.output_scores,\n            return_dict_in_generate=generation_config.return_dict_in_generate,\n            synced_gpus=synced_gpus,\n            **model_kwargs,\n        )\n\n    elif is_constraint_gen_mode:\n        if generation_config.num_return_sequences > generation_config.num_beams:\n            raise ValueError(\"`num_return_sequences` has to be smaller or equal to `num_beams`.\")\n\n        if stopping_criteria.max_length is None:\n            raise ValueError(\"`max_length` needs to be a stopping_criteria for now.\")\n\n        if generation_config.num_beams <= 1:\n            raise ValueError(\"`num_beams` needs to be greater than 1 for constrained generation.\")\n\n        if generation_config.do_sample:\n            
raise ValueError(\"`do_sample` needs to be false for constrained generation.\")\n\n        if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:\n            raise ValueError(\"`num_beam_groups` not supported yet for constrained generation.\")\n\n        final_constraints = []\n        if generation_config.constraints is not None:\n            # copy so that appending the forced-word constraints below does not mutate the generation config\n            final_constraints = list(generation_config.constraints)\n\n        if generation_config.force_words_ids is not None:\n\n            def typeerror():\n                raise ValueError(\n                    \"`force_words_ids` has to be either a `List[List[List[int]]]` or a `List[List[int]]` \"\n                    f\"of non-negative integers, but is {generation_config.force_words_ids}.\"\n                )\n\n            if (\n                    not isinstance(generation_config.force_words_ids, list)\n                    or len(generation_config.force_words_ids) == 0\n            ):\n                typeerror()\n\n            for word_ids in generation_config.force_words_ids:\n                # validate the container before indexing into it, so malformed input raises the\n                # intended ValueError instead of an IndexError or TypeError\n                if not isinstance(word_ids, list) or len(word_ids) == 0:\n                    typeerror()\n                if isinstance(word_ids[0], list):\n                    # List[List[int]]: a disjunction, fulfilled by any one of these token sequences\n                    if any(not isinstance(token_ids, list) for token_ids in word_ids):\n                        typeerror()\n                    if any(\n                            any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)\n                            for token_ids in word_ids\n                    ):\n                        typeerror()\n\n                    constraint = 
DisjunctiveConstraint(word_ids)\n                else:\n                    # List[int]: a single phrase whose tokens must appear in order\n                    if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):\n                        typeerror()\n\n                    constraint = 
PhrasalConstraint(word_ids)\n                final_constraints.append(constraint)\n\n        # 11. prepare beam search scorer\n        constrained_beam_scorer = ConstrainedBeamSearchScorer(\n            constraints=final_constraints,\n            batch_size=batch_size,\n            num_beams=generation_config.num_beams,\n            device=inputs_tensor.device,\n            length_penalty=generation_config.length_penalty,\n            do_early_stopping=generation_config.early_stopping,\n            num_beam_hyps_to_keep=generation_config.num_return_sequences,\n            max_length=generation_config.max_length,\n        )\n        # 12. interleave input_ids with `num_beams` additional sequences per batch\n        input_ids, model_kwargs = self._expand_inputs_for_generation(\n            input_ids=input_ids,\n            expand_size=generation_config.num_beams,\n            is_encoder_decoder=self.config.is_encoder_decoder,\n            **model_kwargs,\n        )\n        # 13. run beam search\n        return self.constrained_beam_search(\n            input_ids,\n            constrained_beam_scorer=constrained_beam_scorer,\n            logits_processor=logits_processor,\n            stopping_criteria=stopping_criteria,\n            pad_token_id=generation_config.pad_token_id,\n            eos_token_id=generation_config.eos_token_id,\n            output_scores=generation_config.output_scores,\n            return_dict_in_generate=generation_config.return_dict_in_generate,\n            synced_gpus=synced_gpus,\n            **model_kwargs,\n        )\n\n\ndef sample(\n        self,\n        input_ids: torch.LongTensor,\n        logits_processor: Optional[LogitsProcessorList] = None,\n        stopping_criteria: Optional[StoppingCriteriaList] = None,\n        logits_warper: Optional[LogitsProcessorList] = None,\n        max_length: Optional[int] = None,\n        pad_token_id: Optional[int] = None,\n        eos_token_id: Optional[Union[int, List[int]]] = None,\n        
output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_scores: Optional[bool] = None,\n        return_dict_in_generate: Optional[bool] = None,\n        synced_gpus: bool = False,\n        streamer: Optional[\"BaseStreamer\"] = None,\n        **model_kwargs,\n) -> Union[SampleOutput, torch.LongTensor]:\n    r\"\"\"\n    Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and\n    can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.\n\n    <Tip warning={true}>\n\n    In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.\n    For an overview of generation strategies and code examples, check the [following\n    guide](../generation_strategies).\n\n    </Tip>\n\n    Parameters:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            The sequence used as a prompt for the generation.\n        logits_processor (`LogitsProcessorList`, *optional*):\n            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]\n            used to modify the prediction scores of the language modeling head applied at each generation step.\n        stopping_criteria (`StoppingCriteriaList`, *optional*):\n            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]\n            used to tell if the generation loop should stop.\n        logits_warper (`LogitsProcessorList`, *optional*):\n            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used\n            to warp the prediction score distribution of the language modeling head applied before multinomial\n            sampling at each generation step.\n        max_length (`int`, *optional*, defaults to 20):\n            **DEPRECATED**. 
Use `logits_processor` or `stopping_criteria` directly to cap the number of generated\n            tokens. The maximum length of the sequence to be generated.\n        pad_token_id (`int`, *optional*):\n            The id of the *padding* token.\n        eos_token_id (`Union[int, List[int]]`, *optional*):\n            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.\n        output_attentions (`bool`, *optional*, defaults to `False`):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n            returned tensors for more details.\n        output_hidden_states (`bool`, *optional*, defaults to `False`):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors\n            for more details.\n        output_scores (`bool`, *optional*, defaults to `False`):\n            Whether or not to return the prediction scores. See `scores` under returned tensors for more details.\n        return_dict_in_generate (`bool`, *optional*, defaults to `False`):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        synced_gpus (`bool`, *optional*, defaults to `False`):\n            Whether to continue running the while loop until max_length (needed for ZeRO stage 3)\n        streamer (`BaseStreamer`, *optional*):\n            Streamer object that will be used to stream the generated sequences. Generated tokens are passed\n            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.\n        model_kwargs:\n            Additional model specific kwargs will be forwarded to the `forward` function of the model. 
If the model is\n            an encoder-decoder model, the kwargs should include `encoder_outputs`.\n\n    Return:\n        [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`:\n        A `torch.LongTensor` containing the generated tokens (default behaviour), or a\n        [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and\n        `return_dict_in_generate=True`, or a [`~generation.SampleEncoderDecoderOutput`] if\n        `model.config.is_encoder_decoder=True` and `return_dict_in_generate=True`.\n\n    Examples:\n\n    ```python\n    >>> from transformers import (\n    ...     AutoTokenizer,\n    ...     AutoModelForCausalLM,\n    ...     LogitsProcessorList,\n    ...     MinLengthLogitsProcessor,\n    ...     TopKLogitsWarper,\n    ...     TemperatureLogitsWarper,\n    ...     StoppingCriteriaList,\n    ...     MaxLengthCriteria,\n    ... )\n    >>> import torch\n\n    >>> tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n    >>> model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n\n    >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token\n    >>> model.config.pad_token_id = model.config.eos_token_id\n    >>> model.generation_config.pad_token_id = model.config.eos_token_id\n\n    >>> input_prompt = \"Today is a beautiful day, and\"\n    >>> input_ids = tokenizer(input_prompt, return_tensors=\"pt\").input_ids\n\n    >>> # instantiate logits processors\n    >>> logits_processor = LogitsProcessorList(\n    ...     [\n    ...         MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),\n    ...     ]\n    ... )\n    >>> # instantiate logits warpers\n    >>> logits_warper = LogitsProcessorList(\n    ...     [\n    ...         TopKLogitsWarper(50),\n    ...         TemperatureLogitsWarper(0.7),\n    ...     ]\n    ... 
)\n\n    >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])\n\n    >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT\n    >>> outputs = model.sample(\n    ...     input_ids,\n    ...     logits_processor=logits_processor,\n    ...     logits_warper=logits_warper,\n    ...     stopping_criteria=stopping_criteria,\n    ... )\n\n    >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)\n    ['Today is a beautiful day, and we must do everything possible to make it a day of celebration.']\n    ```\"\"\"\n    # init values\n    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()\n    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()\n    if max_length is not None:\n        warnings.warn(\n            \"`max_length` is deprecated in this function, use\"\n            \" `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.\",\n            UserWarning,\n        )\n        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)\n    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()\n    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id\n    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id\n    # normalize eos_token_id to a list so that one or several end-of-sequence ids are handled uniformly\n    if isinstance(eos_token_id, int):\n        eos_token_id = [eos_token_id]\n    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None\n    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores\n    output_attentions = (\n        output_attentions if output_attentions is not None else self.generation_config.output_attentions\n    )\n    output_hidden_states = (\n 
       output_hidden_states if output_hidden_states is not None else 
self.generation_config.output_hidden_states\n    )\n    return_dict_in_generate = (\n        return_dict_in_generate\n        if return_dict_in_generate is not None\n        else self.generation_config.return_dict_in_generate\n    )\n\n    # init attention / hidden states / scores tuples\n    scores = () if (return_dict_in_generate and output_scores) else None\n    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None\n    cross_attentions = () if (return_dict_in_generate and output_attentions) else None\n    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None\n\n    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states\n    if return_dict_in_generate and self.config.is_encoder_decoder:\n        encoder_attentions = model_kwargs[\"encoder_outputs\"].get(\"attentions\") if output_attentions else None\n        encoder_hidden_states = (\n            model_kwargs[\"encoder_outputs\"].get(\"hidden_states\") if output_hidden_states else None\n        )\n\n    # keep track of which sequences are already finished\n    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)\n\n    this_peer_finished = False  # used by synced_gpus only\n    # auto-regressive generation\n    while True:\n        if synced_gpus:\n            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.\n            # The following logic allows an early break if all peers finished generating their sequence\n            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)\n            # send 0.0 if we finished, 1.0 otherwise\n            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)\n            # did all peers finish? 
the reduced sum will be 0.0 then\n            if this_peer_finished_flag.item() == 0.0:\n                break\n\n        # prepare model inputs\n        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)\n\n        # forward pass to get next token\n        outputs = self(\n            **model_inputs,\n            return_dict=True,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n        )\n\n        if synced_gpus and this_peer_finished:\n            continue  # don't waste resources running the code we don't need\n\n        next_token_logits = outputs.logits[:, -1, :]\n\n        # pre-process distribution\n        next_token_scores = logits_processor(input_ids, next_token_logits)\n        next_token_scores = logits_warper(input_ids, next_token_scores)\n\n        # Store scores, attentions and hidden_states when required\n        if return_dict_in_generate:\n            if output_scores:\n                # note: this stores the raw last-position logits, not the processed and warped\n                # `next_token_scores` (use `scores += (next_token_scores,)` to store those instead)\n                scores += (next_token_logits,)\n            if output_attentions:\n                decoder_attentions += (\n                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)\n                )\n                if self.config.is_encoder_decoder:\n                    cross_attentions += (outputs.cross_attentions,)\n\n            if output_hidden_states:\n                decoder_hidden_states += (\n                    (outputs.decoder_hidden_states,)\n                    if self.config.is_encoder_decoder\n                    else (outputs.hidden_states,)\n                )\n\n        # sample the next token from the processed and warped distribution\n        probs = nn.functional.softmax(next_token_scores, dim=-1)\n        # replace NaN/inf 
probabilities (e.g. produced by a row of all -inf scores) with finite values\n        if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):\n            probs = torch.nan_to_num(probs)\n        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)\n\n        # finished sentences should have 
their next token be a padding token\n        if eos_token_id is not None:\n            if pad_token_id is None:\n                raise ValueError(\"If `eos_token_id` is defined, make sure that `pad_token_id` is defined.\")\n            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)\n\n        # update generated ids, model inputs, and length for next step\n        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n        if streamer is not None:\n            streamer.put(next_tokens.cpu())\n        model_kwargs = self._update_model_kwargs_for_generation(\n            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder\n        )\n\n        # if eos_token was found in one sentence, set sentence to finished\n        if eos_token_id_tensor is not None:\n            unfinished_sequences = unfinished_sequences.mul(\n                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)\n            )\n\n            # stop when each sentence is finished\n            if unfinished_sequences.max() == 0:\n                this_peer_finished = True\n\n        # stop if we exceed the maximum length\n        if stopping_criteria(input_ids, scores):\n            this_peer_finished = True\n\n        if this_peer_finished and not synced_gpus:\n            break\n\n    if streamer is not None:\n        streamer.end()\n\n    if return_dict_in_generate:\n        if self.config.is_encoder_decoder:\n            return SampleEncoderDecoderOutput(\n                sequences=input_ids,\n                scores=scores,\n                encoder_attentions=encoder_attentions,\n                encoder_hidden_states=encoder_hidden_states,\n                decoder_attentions=decoder_attentions,\n                cross_attentions=cross_attentions,\n                decoder_hidden_states=decoder_hidden_states,\n            )\n        else:\n            return 
SampleDecoderOnlyOutput(\n                sequences=input_ids,\n                scores=scores,\n                attentions=decoder_attentions,\n                hidden_states=decoder_hidden_states,\n            )\n    else:\n        return input_ids\n"
  },
  {
    "path": "scripts/eval.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport argparse\nimport json\nimport logging\nimport os\n\nimport numpy as np\nimport torch\nfrom accelerate import Accelerator\nfrom accelerate.utils import set_seed\nfrom datasets import load_dataset\nfrom torch.utils.data import DataLoader, SequentialSampler\nfrom tqdm import tqdm\nfrom transformers import (\n    AutoTokenizer,\n    AutoModelForCausalLM\n)\n\nimport custom_generate\nfrom eval_metric import compute_metric_stmt\nfrom eval_utils import compute_mean_logp\n\nlogging.basicConfig(\n    format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n    datefmt=\"%m/%d/%Y %H:%M:%S\",\n    level=logging.INFO,\n)\n\nlogger = logging.getLogger(__name__)\n\nCOMMENT_SYMBOL = {\n    \"python\": \"#\",\n    \"java\": \"//\",\n    \"csharp\": \"//\",\n    \"typescript\": \"//\"\n}\n\n\ndef custom_data_collator(features):\n    first = features[0]\n    batch = {}\n    for k, v in first.items():\n        if v is not None and not isinstance(v, str):\n            if isinstance(v, torch.Tensor):\n                batch[k] = torch.stack([f[k] for f in features])\n            elif isinstance(v, np.ndarray):\n                batch[k] = torch.tensor(np.stack([f[k] for f in features]))\n            else:\n                batch[k] = torch.tensor([f[k] for f in features])\n        if v is not None and isinstance(v, str):\n            
batch[k] = [f[k] for f in features]\n\n    return batch\n\n\ndef build_datasets(args, tokenizer):\n    # Configure the tokenizer for generation: the logits of the right-most token\n    # are used to predict the next token, so padding must go on the left\n    tokenizer.padding_side = \"left\"\n    tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else tokenizer.bos_token\n\n    # load the prompt file into a Dataset\n    raw_datasets = load_dataset(\"json\", data_files=args.prompt_file, cache_dir=args.cache_dir)\n    raw_datasets = raw_datasets[\"train\"]\n    raw_datasets = raw_datasets.map(lambda example, idx: {'index': idx, **example}, with_indices=True)\n    index2taskid = {idx: md[\"task_id\"] for idx, md in zip(raw_datasets[\"index\"], raw_datasets[\"metadata\"])}\n    column_names = raw_datasets.column_names\n\n    # Prompt composition\n    def prepare_features(examples):\n        tokenizer.truncation_side = \"left\"\n        tokenized_inputs = tokenizer(\n            examples[\"prompt\"],\n            padding=\"max_length\",\n            truncation=True,\n            max_length=args.max_seq_length - args.gen_length\n        )\n\n        features = {k: t for k, t in tokenized_inputs.items()}\n        features[\"index\"] = examples[\"index\"]\n        return features\n\n    def prepare_features_cfc(examples):\n        max_prompt_length = args.max_seq_length - args.gen_length\n        use_key = \"list\"\n\n        crossfile_context = []\n        if use_key == \"text\":\n            crossfile_context = [ex[\"text\"] for ex in examples[\"crossfile_context\"]]\n        else:\n            ls_sym = COMMENT_SYMBOL[args.language]\n            num_chunk_inc_prompt = []\n            augmented_prompt = 0\n            for cfc_chunks in examples[\"crossfile_context\"]:\n                cfc_chunks = cfc_chunks[\"list\"]  # a list of dict\n                cfc_text = \"\"\n                if cfc_chunks:\n                    # at least 1 relevant 
cfc_chunk found\n                    init_cfc_text = f\"{ls_sym} Here are some relevant code fragments from other files of the repo:\\n\\n\"\n                    cfc_length = len(tokenizer.tokenize(init_cfc_text))\n                    num_chunk_inc = 0\n                    for cfc_idx, cfc_chunk in enumerate(cfc_chunks):\n                        if cfc_chunk[\"score\"] > args.min_cfc_score:\n                            add_text = f\"{ls_sym} the below code fragment is found in {cfc_chunk['filename']}\" + \"\\n\"\n                            cfc_lines = cfc_chunk[\"retrieved_chunk\"].split('\\n')\n                            add_text += \"\\n\".join([f\"{ls_sym} {cl}\" for cl in cfc_lines if cl]) + \"\\n\\n\"\n                            # check if adding chunk exceeds max length budget for CFC\n                            add_text_len = len(tokenizer.tokenize(add_text))\n                            if cfc_length + add_text_len <= args.cfc_seq_length:\n                                cfc_text += add_text\n                                cfc_length += add_text_len\n                                num_chunk_inc += 1\n                            else:\n                                break\n                    num_chunk_inc_prompt.append(num_chunk_inc)\n                    if num_chunk_inc > 0:\n                        cfc_text = init_cfc_text + cfc_text\n                        augmented_prompt += 1\n                crossfile_context.append(cfc_text)\n\n            logger.info(\n                f\"{augmented_prompt} out of {len(examples['crossfile_context'])} prompts are augmented with cross-file context.\")\n\n        tokenizer.truncation_side = \"right\"\n        crossfile_features = tokenizer(\n            crossfile_context,\n            truncation=True,\n            max_length=args.cfc_seq_length\n        )\n\n        features = {\"input_ids\": [], \"attention_mask\": []}\n        tokenizer.truncation_side = \"left\"\n        for idx, prompt in 
enumerate(examples[\"prompt\"]):\n            allowed_prompt_length = max_prompt_length - len(crossfile_features[\"input_ids\"][idx])\n            prompt_feats = tokenizer(\n                [prompt],\n                truncation=True,\n                max_length=allowed_prompt_length\n            )\n            for k, v in prompt_feats.items():\n                features[k].append(crossfile_features[k][idx] + prompt_feats[k][0])\n\n        # pad to max_seq_length\n        tokenizer.padding_side = \"left\"\n        features = tokenizer.pad(features, padding=\"max_length\", max_length=args.max_seq_length - args.gen_length)\n        features[\"index\"] = examples[\"index\"]\n        return features\n\n    if args.model_type in [\"codelm\", \"seq2seqlm\"]:\n        tokenized_datasets = raw_datasets.map(\n            prepare_features,\n            batched=True,\n            num_proc=args.preprocessing_num_workers,\n            remove_columns=column_names,\n            load_from_cache_file=not args.overwrite_cache,\n            desc=\"Running tokenizer on dataset\",\n        )\n    elif args.model_type == \"codelm_cfc\":\n        tokenized_datasets = raw_datasets.map(\n            prepare_features_cfc,\n            batched=True,\n            num_proc=args.preprocessing_num_workers,\n            remove_columns=column_names,\n            load_from_cache_file=not args.overwrite_cache,\n            desc=\"Running tokenizer on dataset\",\n        )\n    else:\n        raise NotImplementedError(\"prepare feature functions not implemented for new model type\")\n\n    return tokenized_datasets, index2taskid\n\n\ndef model_inference(tokenized_datasets, index2taskid, tokenizer):\n    if args.dtype == 'fp16':\n        dtype = torch.float16\n    elif args.dtype == 'fp32':\n        dtype = torch.float32\n    elif args.dtype == 'bf16':\n        dtype = torch.bfloat16\n    elif args.dtype == 'int8':\n        dtype = torch.int8\n    else:\n        assert False, f'{args.dtype=} not 
implemented'\n\n    if args.model_type in [\"codelm\", \"codelm_cfc\"]:\n        model = AutoModelForCausalLM.from_pretrained(\n            args.model_name_or_path,\n            torch_dtype=dtype,\n            trust_remote_code=True,\n            revision=\"main\"\n        )\n    else:\n        raise ValueError(\"Unknown model type\")\n\n    total_samples_cnt = len(tokenized_datasets)\n    logger.info(f\"total samples: {total_samples_cnt}\")\n\n    data_sampler = SequentialSampler(tokenized_datasets)\n    dataloader = DataLoader(\n        tokenized_datasets,\n        sampler=data_sampler,\n        collate_fn=custom_data_collator,\n        batch_size=args.batch_size\n    )\n\n    model = accelerator.prepare_model(model)\n    dataloader = accelerator.prepare_data_loader(dataloader)\n\n    # exist_ok avoids a race when several processes create the directory at the same time\n    os.makedirs(args.output_dir, exist_ok=True)\n\n    tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else tokenizer.bos_token\n    prompt_length = args.max_seq_length - args.gen_length\n\n    @torch.no_grad()\n    def generate_completions(batch):\n        output_dict = custom_generate.generate(\n            accelerator.unwrap_model(model),\n            input_ids=batch[\"input_ids\"],\n            attention_mask=batch[\"attention_mask\"],\n            max_length=args.max_seq_length,\n            temperature=args.temperature,\n            top_k=args.top_k,\n            top_p=args.top_p,\n            do_sample=args.do_sample,\n            num_beams=args.num_beams,\n            num_return_sequences=1,\n            pad_token_id=tokenizer.pad_token_id,\n            return_dict_in_generate=True,\n            output_scores=True\n        )\n        batch_task_id = batch[\"index\"]\n        batch_pred = accelerator.pad_across_processes(\n            output_dict.sequences, dim=1, pad_index=tokenizer.pad_token_id\n        )\n        scores = torch.stack(output_dict.scores, dim=1)\n        batch_scores = accelerator.pad_across_processes(\n            
scores, dim=1, pad_index=tokenizer.pad_token_id\n        )\n        # after gather: batch_scores.shape = (batch_size x num_processes, num_generated_tokens, vocab_size)\n        batch_task_id, batch_pred, batch_scores = accelerator.gather((batch_task_id, batch_pred, batch_scores))\n\n        batch_pred = batch_pred[:, prompt_length:]\n        generated_texts = tokenizer.batch_decode(batch_pred, skip_special_tokens=True)\n\n        mean_logp = compute_mean_logp(batch_scores, batch_pred, tokenizer.pad_token_id)\n        return batch_task_id.tolist(), generated_texts, mean_logp\n\n    all_preds = []\n    all_task_ids = []\n    with torch.no_grad():\n        for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)):\n            completions = None\n            completion_scores = None\n            for seq_idx in range(args.num_return_sequences):\n                batch_task_id, generated_texts, mean_logp = generate_completions(batch)\n                if seq_idx == 0:\n                    all_task_ids.extend(batch_task_id)\n                    batch_size = len(batch_task_id)\n                    completions = [[] for _ in range(batch_size)]\n                    completion_scores = [[] for _ in range(batch_size)]\n\n                for j in range(batch_size):\n                    completions[j].append(generated_texts[j])\n                    completion_scores[j].append(mean_logp[j])\n\n            if args.num_return_sequences == 1:\n                all_preds.extend([c[0] for c in completions])\n            else:\n                for c, cs in zip(completions, completion_scores):\n                    max_score = max(cs)\n                    max_index = cs.index(max_score)\n                    all_preds.append(c[max_index])\n\n    with open(f\"{args.output_dir}/prediction.jsonl\", \"w\", encoding=\"utf-8\") as f_pred:\n        id_processed = set()\n        for idx, p in zip(all_task_ids, all_preds):\n            if index2taskid[idx] not in id_processed:\n                
f_pred.write(json.dumps({\"task_id\": index2taskid[idx], \"pred\": p}) + \"\\n\")\n                id_processed.add(index2taskid[idx])\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    # model inference args\n    parser.add_argument(\"--language\", type=str, required=True, help=\"language name\")\n    parser.add_argument(\"--model_name_or_path\", default=None, type=str, help=\"Pre-trained Model Path\")\n    parser.add_argument(\n        \"--model_type\",\n        type=str,\n        default=\"codelm\",\n        choices=[\"codelm\", \"codelm_cfc\"],\n        help=\"Model type to be loaded\"\n    )\n    parser.add_argument(\"--prompt_file\", type=str, default=None, help=\"file with a list of prompts\")\n    parser.add_argument(\"--gen_length\", type=int, default=50, help=\"max length of generated token sequence\")\n    parser.add_argument(\"--max_seq_length\", type=int, default=2048, help=\"max total length of prompt plus generated tokens\")\n    parser.add_argument(\n        \"--cfc_seq_length\",\n        type=int,\n        default=512,\n        help=\"For model_type=codelm_cfc: max token length of the retrieved cross-file context\"\n    )\n    parser.add_argument(\n        \"--min_cfc_score\",\n        type=float,\n        default=float('-inf'),\n        help=\"For model_type=codelm_cfc: a chunk must score above this value to be used as cross-file context\"\n    )\n    parser.add_argument(\"--batch_size\", type=int, default=32, help=\"batch size for code completion\")\n    parser.add_argument(\"--stop_token\", type=str, default=None, help=\"Token at which text generation is stopped\")\n    parser.add_argument(\"--cache_dir\", type=str, default=None)\n    parser.add_argument(\n        \"--temperature\",\n        type=float,\n        default=0.2,\n        help=\"a temperature of 1.0 has no effect; lower values tend toward greedy sampling\"\n    )\n    parser.add_argument(\"--output_dir\", type=str, default=\"output_dir\", help=\"output directory to save predictions\")\n    
parser.add_argument(\"--top_k\", type=int, default=0)\n    parser.add_argument(\"--top_p\", type=float, default=0.95)\n    parser.add_argument(\"--seed\", type=int, default=42, help=\"random seed for initialization\")\n    parser.add_argument(\"--no_cuda\", action=\"store_true\", help=\"Avoid using CUDA when available\")\n    parser.add_argument(\"--num_return_sequences\", type=int, default=1, help=\"The number of samples to generate.\")\n    parser.add_argument(\"--repetition_penalty\", type=float, default=1.0, help=\"The parameter for repetition penalty.\")\n    parser.add_argument(\n        \"--preprocessing_num_workers\",\n        type=int,\n        default=1,\n        help=\"The number of processes to use for the preprocessing.\"\n    )\n    parser.add_argument(\n        \"--overwrite_cache\",\n        action=\"store_true\",\n        help=\"Overwrite the cached tokenized datasets\"\n    )\n    parser.add_argument(\"--dtype\", type=str, default='bf16')\n    parser.add_argument(\"--do_sample\", action=\"store_true\", help=\"sample instead of using greedy/beam search\")\n    parser.add_argument(\"--num_beams\", type=int, default=1, help=\"number of beams for beam search\")\n    # compute metric args\n    parser.add_argument(\n        \"--ts_lib\",\n        type=str,\n        default=\"build/python-lang-parser.so\",\n        help=\"tree-sitter library used to parse code\"\n    )\n    # only compute metric\n    parser.add_argument(\"--only_compute_metric\", action=\"store_true\", help=\"only compute metric\")\n    args = parser.parse_args()\n    set_seed(args.seed, device_specific=False)\n\n    if args.num_return_sequences > 1:\n        assert args.do_sample, \"sampling must be set to True when num_return_sequences > 1\"\n\n    accelerator = Accelerator()\n    if not args.only_compute_metric:\n        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)\n        tokenized_datasets, index2taskid = 
build_datasets(args, tokenizer)\n        model_inference(tokenized_datasets, index2taskid, tokenizer)\n\n    # check if the process is the main process\n    if accelerator.is_main_process:\n        compute_metric_stmt(args)\n"
  },
  {
    "path": "scripts/eval_metric.py",
    "content": "import json\nfrom functools import partial\n\nimport torch.multiprocessing as mp\nfrom tqdm import tqdm\nfrom tree_sitter import Language, Parser\n\nfrom eval_utils import (\n    postprocess_code_lines,\n    extract_identifiers,\n    cal_edit_sim,\n    remove_comments\n)\nimport os\n\nparser = None\n\n\ndef compute_id_match(pred_ids, target_ids):\n    pred_ids = list(set(pred_ids))\n    target_ids = list(set(target_ids))\n    tp = 0\n    fp = 0\n    fn = 0\n    for pid in pred_ids:\n        if pid in target_ids:\n            tp += 1\n        else:\n            fp += 1\n    for tid in target_ids:\n        if tid not in pred_ids:\n            fn += 1\n    return tp, fp, fn\n\n\ndef compute_edit_sim(samples):\n    refs, hyps = [], []\n    for s in samples:\n        refs.append(s[\"target\"])\n        hyps.append(s[\"pred\"])\n    return cal_edit_sim(refs, hyps)\n\n\ndef process_examples(lang, args):\n    sample, ex = args\n    global parser\n\n    prediction = postprocess_code_lines(ex[\"prompt\"], sample[\"pred\"], parser, lang)\n    prediction = remove_comments(prediction)\n    target = ex[\"groundtruth\"]\n    target = remove_comments(target)\n\n    pred_lines = [l.strip() for l in prediction.split(\"\\n\") if l.strip()]\n    gt_lines = [l.strip() for l in target.split(\"\\n\") if l.strip()]\n    em_label = int(pred_lines == gt_lines)\n\n    pred_ids = extract_identifiers(prediction, lang)\n    target_ids = extract_identifiers(target, lang)\n\n    trunc_s = {\n        \"task_id\": sample[\"task_id\"],\n        \"pred\": prediction,\n        \"target\": target,\n        \"pred_ids\": pred_ids,\n        \"target_ids\": target_ids\n    }\n    return trunc_s, em_label\n\n\ndef compute_metric_stmt(args):\n    with open(os.path.join(args.output_dir, \"prediction.jsonl\"), \"r\") as f_pred:\n        samples = []\n        for l in f_pred.readlines():\n            samples.append(json.loads(l))\n\n    examples = {}\n    with open(args.prompt_file, \"r\") as 
f_in:\n        for l in f_in.readlines():\n            ex = json.loads(l)\n            examples[ex[\"metadata\"][\"task_id\"]] = {\n                \"prompt\": ex[\"prompt\"],\n                \"groundtruth\": ex[\"groundtruth\"]\n            }\n\n    assert len(samples) == len(examples), f\"{len(samples)} != {len(examples)}\"\n\n    global parser\n    ts_lang = \"c_sharp\" if args.language == \"csharp\" else args.language\n    language = Language(args.ts_lib, ts_lang)\n    parser = Parser()\n    parser.set_language(language)\n\n    truncated_samples = []\n    em_labels = []\n\n    print(\"post-processing samples ...\")\n    # keep one CPU free, but always start at least one worker\n    pool = mp.Pool(max(1, mp.cpu_count() - 1))\n    worker = partial(process_examples, args.language)\n\n    with tqdm(total=len(samples)) as pbar:\n        for output in pool.imap_unordered(worker, zip(samples, [examples[s[\"task_id\"]] for s in samples])):\n            trunc_s, em_label = output\n            em_labels.append(em_label)\n            truncated_samples.append(trunc_s)\n            pbar.update()\n    pool.close()\n    pool.join()\n\n    exact_match = 0\n    with open(os.path.join(args.output_dir, \"prediction_truncated.jsonl\"), 'w', encoding=\"utf-8\") as pt, \\\n            open(f\"{args.output_dir}/exact_match_idx.jsonl\", 'w') as em:\n        for trunc_s, em_label in zip(truncated_samples, em_labels):\n            pt.write(json.dumps(trunc_s) + \"\\n\")\n            if em_label == 1:\n                em.write(f'{trunc_s[\"task_id\"]}\\n')\n                exact_match += 1\n\n    ### Score calculation\n\n    id_em = []\n    edit_similarities = []\n    detailed_results = []\n\n    for idx, trunc_s in enumerate(truncated_samples):\n        identifier_em = int(trunc_s[\"pred_ids\"] == trunc_s[\"target_ids\"])\n        es = cal_edit_sim([trunc_s[\"target\"]], [trunc_s[\"pred\"]])\n        id_tp, id_fp, id_fn = compute_id_match(trunc_s[\"pred_ids\"], trunc_s[\"target_ids\"])\n        id_em.append(identifier_em)\n        edit_similarities.append(es)\n\n        
detailed_results.append({\n            \"task_id\": trunc_s[\"task_id\"],\n            \"em\": em_labels[idx],\n            \"es\": es,\n            \"id_em\": identifier_em,\n            \"id_precision\": id_tp / (id_tp + id_fp) if (id_tp + id_fp) != 0 else 0,\n            \"id_recall\": id_tp / (id_tp + id_fn) if (id_tp + id_fn) != 0 else 0,\n            \"id_f1\": 2 * id_tp / (2 * id_tp + id_fp + id_fn) if (2 * id_tp + id_fp + id_fn) != 0 else 0,\n        })\n\n    em_ratio = round(exact_match / len(samples) * 100, 2)\n    edit_sim = round(sum(edit_similarities) / len(edit_similarities), 2)\n\n    id_em_ratio = round(\n        sum(detailed_results[idx]['id_em'] for idx in range(len(detailed_results))) / len(detailed_results) * 100, 2)\n    id_precision = round(sum(detailed_results[idx]['id_precision'] for idx in range(len(detailed_results))) / len(\n        detailed_results) * 100, 2)\n    id_recall = round(\n        sum(detailed_results[idx]['id_recall'] for idx in range(len(detailed_results))) / len(detailed_results) * 100,\n        2)\n    id_f1 = round(\n        sum(detailed_results[idx]['id_f1'] for idx in range(len(detailed_results))) / len(detailed_results) * 100, 2)\n\n    print(\n        f\"Code Matching: \"\n        f\"EM {em_ratio:.2f}, \"\n        f\"ES {edit_sim:.2f}\"\n    )\n\n    print(\n        f\"ID matching: \"\n        f\"EM {id_em_ratio}, \"\n        #f\"Precision {id_precision}, \"\n        #f\"Recall {id_recall}, \"\n        f\"F1 {id_f1}\"\n    )\n\n    with open(os.path.join(args.output_dir, \"detailed_results.json\"), 'w') as f:\n        for dr in detailed_results:\n            f.write(json.dumps(dr) + \"\\n\")\n\n    # write the results to a file\n    print(f'writing results to {os.path.join(args.output_dir, \"results.json\")}')\n    with open(os.path.join(args.output_dir, \"results.json\"), 'w') as f:\n        res = {\n            \"em\": em_ratio,\n            \"es\": edit_sim,\n            \"id_em\": id_em_ratio,\n            
\"id_precision\": id_precision,\n            \"id_recall\": id_recall,\n            \"id_f1\": id_f1,\n            \"total\": len(truncated_samples)\n        }\n        f.write(json.dumps(res, indent=2))\n"
  },
  {
    "path": "scripts/eval_utils.py",
    "content": "# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport ast\nimport re\nfrom functools import lru_cache\nfrom typing import List\n\nimport timeout_decorator\nimport torch\nfrom fuzzywuzzy import fuzz\nfrom nltk.tokenize import RegexpTokenizer\nfrom sacrebleu.tokenizers.tokenizer_intl import TokenizerV14International\n\nfrom keywords.keywordlist import get_language_keywords\n\nIDENTIFIER_REGEX = re.compile('[_a-zA-Z][_a-zA-Z0-9]*')\nREGEX_TEXT = (\"(?<=[a-z0-9])(?=[A-Z])|\"\n              \"(?<=[A-Z0-9])(?=[A-Z][a-z])|\"\n              \"(?<=[0-9])(?=[a-zA-Z])|\"\n              \"(?<=[A-Za-z])(?=[0-9])|\"\n              \"(?<=[@$.'\\\"])(?=[a-zA-Z0-9])|\"\n              \"(?<=[a-zA-Z0-9])(?=[@$.'\\\"])|\"\n              \"_|\\\\s+\")\nstring_pattern = r'\"([^\"\\\\]*(\\\\.[^\"\\\\]*)*)\"|\\'([^\\'\\\\]*(\\\\.[^\\'\\\\]*)*)\\''\n\nSPLIT_REGEX = re.compile(REGEX_TEXT)\n\nstr_tokenizer = TokenizerV14International()\ncode_tokenizer = RegexpTokenizer(r'\\w+')\n\n\ndef cal_edit_sim(references, hypotheses):\n    total = len(references)\n    edit_sim = 0.0\n    for pred, gt in zip(hypotheses, references):\n        pred = pred.strip()\n        gt = gt.strip()\n        edit_sim += fuzz.ratio(pred, gt)\n    return edit_sim / total\n\n\n@lru_cache(maxsize=5000)\ndef split_identifier_into_parts(identifier: str) -> List[str]:\n    \"\"\"\n    Split a single identifier into parts on 
snake_case and camelCase\n    \"\"\"\n    identifier_parts = list(s for s in SPLIT_REGEX.split(identifier) if len(s) > 0)\n\n    if len(identifier_parts) == 0:\n        return [identifier]\n    if \"_\" in identifier:  # We consider \"_\" as part of identifier and add it back in between each semantic part\n        # if snake_case, we only split identifiers based on \"_\", ignore the mixed camelCase or other special symbols\n        # this helps us avoid splitting identifiers like \"get_2d_array\" into [\"get\", \"2\", \"d\", \"array\"]\n        # and avoids many other corner cases\n        identifier_parts = identifier.split(\"_\")\n        tmp = [identifier_parts[0]]\n        for i in identifier_parts[1:]:\n            tmp.append(\"_\")\n            tmp.append(i)\n        identifier_parts = tmp\n\n    return identifier_parts\n\n\ndef is_identifier(token, lang=None):\n    return True if IDENTIFIER_REGEX.match(token) \\\n                   and (lang is None or token not in get_language_keywords(lang)) \\\n        else False\n\n\ndef extract_identifiers(source_code, lang):\n    # the main idea is to remove string literals from the source code,\n    # then tokenize the code into words and match each word against the identifier regex;\n    # a word that is not a language-specific keyword is an identifier\n    source_code_without_strings = re.sub(string_pattern, '', source_code)\n    _ids = [t for t in code_tokenizer.tokenize(source_code_without_strings) if is_identifier(t, lang)]\n    return _ids\n\n\ndef tokenize_string(input_str):\n    return str_tokenizer(input_str)\n\n\ndef get_bracket_lang_statement(completion):\n    # truncate the completion after the first ';', '{' or '}'\n    end_idx = None\n    for i in range(len(completion)):\n        if completion[i] in [\";\", \"}\", \"{\"]:\n            end_idx = i\n            break\n    # compare against None: end_idx == 0 is a valid position\n    return completion[:end_idx + 1] if end_idx is not None else completion\n\n\n@timeout_decorator.timeout(5)\ndef get_ast(parser, code):\n    assert isinstance(code, str) or isinstance(code, bytes)\n    if 
isinstance(code, str):\n        code = bytes(code, \"utf8\")\n    try:\n        tree = parser.parse(code)\n        return tree\n    except Exception:\n        return None\n\n\ndef remove_comments(code):\n    # strips everything after '#' or '//' on each line, for every language;\n    # note this also cuts Python floor division (a // b) and '#' inside string literals\n    code = re.sub(r'#.*', '', code)\n    code = re.sub(r'//.*', '', code)\n    return code\n\n\ndef is_parse_valid(parser, code):\n    def syntax_error(node):\n        if node.type == \"ERROR\":\n            return True\n        try:\n            for child in node.children:\n                if syntax_error(child):\n                    return True\n        except RecursionError:\n            return True\n\n        return False\n\n    tree = get_ast(parser, code)\n    if tree is not None:\n        return not syntax_error(tree.root_node)\n    return False\n\n\ndef is_code_parseable(code):\n    try:\n        ast.parse(code)\n        return True\n    except SyntaxError:\n        return False\n\n\ndef get_python_one_statement(prompt, completion, parser):\n    # return the shortest prefix of the completion that parses and is followed by a newline\n    for i in range(len(completion)):\n        code = prompt + completion[:i + 1]\n        if not is_parse_valid(parser, code):\n            continue\n        if i + 1 < len(completion) and completion[i + 1] == \"\\n\":\n            return completion[:i + 1].rstrip()\n\n    return completion\n\n\ndef postprocess_code_lines(prompt, completion, parser, lang):\n    try:\n        if lang in [\"java\", \"csharp\", \"typescript\"]:\n            return get_bracket_lang_statement(completion)\n        elif lang == \"python\":\n            return get_python_one_statement(prompt, completion, parser)\n    except Exception:\n        return completion\n    # unsupported language: return the completion unchanged\n    return completion\n\n\ndef compute_mean_logp(scores, sequences, pad_token_id):\n    assert scores.shape[0] == sequences.shape[0]\n    assert scores.shape[1] == sequences.shape[1]\n    with torch.no_grad():\n        logp_vocab = torch.nn.functional.log_softmax(scores, dim=-1)\n        indices = torch.unsqueeze(sequences, dim=-1)\n        logp = torch.gather(logp_vocab, dim=-1, index=indices).squeeze(-1)\n        
sum_logp = torch.cumsum(logp, dim=1)  # batch_size, seq_len\n        denom = torch.arange(1, sum_logp.shape[1] + 1).reshape(1, -1).to(device=sum_logp.device)  # 1, seq_len\n        mean_logp = (sum_logp / denom).tolist()  # batch_size, seq_len\n        sequence_lengths = (sequences != pad_token_id).sum(1).tolist()  # batch_size\n        mean_logp = [mean_logp[idx][l - 1] for idx, l in enumerate(sequence_lengths)]\n    return mean_logp\n"
  },
  {
    "path": "scripts/keywords/__init__.py",
    "content": ""
  },
  {
    "path": "scripts/keywords/csharp.txt",
    "content": "abstract\nas\nbase\nbool\nbreak\nbyte\ncase\ncatch\nchar\nchecked\nclass\nconst\ncontinue\ndecimal\ndefault\ndelegate\ndo\ndouble\nelse\nenum\nevent\nexplicit\nextern\nfinally\nfixed\nfloat\nfor\nforeach\ngoto\nif\nimplicit\nin\nint\ninterface\ninternal\nis\nlock\nlong\nnamespace\nnew\nnull\nobject\noperator\nout\noverride\nparams\nprivate\nprotected\npublic\nreadonly\nref\nreturn\nsbyte\nsealed\nshort\nsizeof\nstackalloc\nstatic\nstring\nstruct\nswitch\nthis\nthrow\ntry\ntypeof\nuint\nulong\nunchecked\nunsafe\nushort\nusing\nvirtual\nvoid\nvolatile\nwhile"
  },
  {
    "path": "scripts/keywords/java.txt",
    "content": "abstract\nassert\nboolean\nbreak\nbyte\ncase\ncatch\nchar\nclass\ncontinue\ndefault\ndo\ndouble\nelse\nenum\nextends\nfinal\nfinally\nfloat\nfor\nif\nimplements\nimport\ninstanceof\nint\ninterface\nlong\nnative\nnew\npackage\nprivate\nprotected\npublic\nreturn\nshort\nstatic\nstrictfp\nsuper\nswitch\nsynchronized\nthis\nthrow\nthrows\ntransient\ntry\nvoid\nvolatile\nwhile\nvar\nconst\ngoto\n"
  },
  {
    "path": "scripts/keywords/javascript.txt",
    "content": "break\ncase\ncatch\nclass\nconst\ncontinue\ndebugger\ndefault\ndelete\ndo\nelse\nexport\nextends\nfinally\nfor\nfunction\nif\nimport\nin\ninstanceof\nnew\nreturn\nsuper\nswitch\nthis\nthrow\ntry\ntypeof\nvar\nvoid\nwhile\nwith\nyield\nenum\nimplements\ninterface\nlet\npackage\nprivate\nprotected\npublic\nstatic"
  },
  {
    "path": "scripts/keywords/keywordlist.py",
    "content": "# Original Copyright 2021 Microsoft under MIT License.\n# From https://github.com/microsoft/dpu-utils/blob/master/python/dpu_utils/codeutils/keywords/keywordlist.py\n\nimport os\nimport keyword\nfrom functools import lru_cache\nfrom typing import FrozenSet\n\n__all__ = ['get_language_keywords']\n\n_LANGUAGE_TO_FILENAME = {\n    'c': 'c.txt',\n    'cpp': 'cpp.txt',\n    'c++': 'cpp.txt',\n    'csharp': 'csharp.txt',\n    'c_sharp': 'csharp.txt',\n    'c#': 'csharp.txt',\n    'go': 'go.txt',\n    'java': 'java.txt',\n    'javascript': 'javascript.txt',\n    'js': 'javascript.txt',\n    'php': 'php.txt',\n    'ruby': 'ruby.txt',\n    'typescript': 'typescript.txt',\n    'ts': 'typescript.txt',\n}\n\n\n@lru_cache()\ndef get_language_keywords(language: str) -> FrozenSet[str]:\n    \"\"\"\n    Returns the keywords of a programming language.\n\n    There are some inconsistencies across languages with respect to\n    what is considered a keyword. For example, the true/false\n    literals are considered keywords in many languages. However,\n    we exclude them here for consistency. We also exclude special\n    function-like keywords, such as `die()` in PHP.\n    \"\"\"\n    language = language.lower()\n    if language == 'python':\n        return frozenset(k for k in keyword.kwlist if k != 'True' and k != 'False')\n    elif language in _LANGUAGE_TO_FILENAME:\n        name = _LANGUAGE_TO_FILENAME[language]\n        with open(os.path.join(os.path.dirname(__file__), name)) as f:\n            return frozenset(l.strip() for l in f if len(l.strip()) > 0)\n    else:\n        raise Exception('Language keywords `%s` not supported yet. Consider contributing it to dpu-utils.' % language)\n"
  },
  {
    "path": "scripts/keywords/typescript.txt",
    "content": "break\ncase\ncatch\nclass\nconst\ncontinue\ndebugger\ndefault\ndelete\ndo\nelse\nexport\nextends\nfinally\nfor\nfunction\nif\nimport\nin\ninstanceof\nnew\nreturn\nsuper\nswitch\nthis\nthrow\ntry\ntypeof\nvar\nvoid\nwhile\nwith\nyield\nenum\nimplements\ninterface\nlet\npackage\nprivate\nprotected\npublic\nstatic"
  },
  {
    "path": "scripts/openai_inference.py",
    "content": "\"\"\"\nScript to query the OpenAI API to generate code.\nSet the environment variable OPENAI_API_KEY to your API key\nbefore running this script.\n\"\"\"\n\nimport argparse\nimport json\nimport os\nimport time\nfrom typing import Dict, List, Tuple\n\nimport numpy as np\nimport openai\nimport tiktoken\nfrom openai import OpenAI\nfrom openai.types.chat import ChatCompletion\nfrom tqdm import tqdm\n\nSLEEP_SECOND = 2.8  # base sleep time (seconds) between requests and after API errors\nMAX_SLEEP_SECOND = 120  # maximum sleep time (seconds) for exponential backoff\nBUFFER = 100  # estimated tokens used by the system prompt and chat formatting, plus some slack\nSYS_PROMPT = 'You are Codex, a code completion language model. Continue the code presented to you.'\n\nopenai_api_key = os.environ.get(\"OPENAI_API_KEY\")\nassert openai_api_key is not None, \"Please set the OPENAI_API_KEY environment variable to your API key\"\nclient = OpenAI()\n\n\ndef query(\n        args,\n        prompt: str,\n) -> ChatCompletion:\n    \"\"\"\n    Query the OpenAI chat completions API to continue the given code prompt.\n\n    Args:\n    args: argparse.Namespace, run config providing model, temperature,\n        generation_max_tokens (maximum number of tokens to generate) and\n        top_p (cumulative probability for nucleus sampling)\n    prompt: str, the prompt to generate code from\n\n    Returns:\n    ChatCompletion, the response from the OpenAI API\n    \"\"\"\n    return client.chat.completions.create(model=args.model,\n                                          messages=[\n                                              {\"role\": \"system\", \"content\": SYS_PROMPT},\n                                              {\"role\": \"user\", \"content\": prompt}\n                                          ],\n                                          temperature=args.temperature,\n                                          max_tokens=args.generation_max_tokens,\n                                          top_p=args.top_p,\n
                                          )\n\n\ndef query_with_retry(\n        args,\n        prompt: str,\n) -> ChatCompletion | None:\n    \"\"\"\n    Query the OpenAI API via `query`, retrying with exponential backoff on\n    rate-limit and other API errors.\n\n    Args:\n    args: argparse.Namespace, run config forwarded to `query`\n    prompt: str, the prompt to generate code from\n\n    Returns:\n    ChatCompletion, the response from the OpenAI API if the request succeeds;\n    None if the API rejects the request as invalid (e.g. the prompt is too long),\n    since retrying would fail the same way\n\n    Reference:\n    https://github.com/Leolty/repobench/blob/c24b7a80465957e75107eafd23c66d369fa9e755/model/codex.py\n    \"\"\"\n\n    error_sleep_second = SLEEP_SECOND\n\n    def _upd_error_sleep_time(error_sleep_second):\n        # double the sleep time, capped at MAX_SLEEP_SECOND seconds\n        return min(error_sleep_second * 2, MAX_SLEEP_SECOND)\n\n    while True:\n        try:\n            response = query(args, prompt)\n            # throttle successful requests too, with random jitter\n            time.sleep(SLEEP_SECOND + np.random.rand())\n            return response\n        except openai.RateLimitError as e:\n            print(f'RateLimitError: {e}')\n            print(f'Retrying after {error_sleep_second} seconds')\n            time.sleep(error_sleep_second)\n            error_sleep_second = _upd_error_sleep_time(error_sleep_second)\n        except openai.BadRequestError as e:\n            # invalid requests fail identically on every retry, so give up on this sample\n            print(f'BadRequestError: {e}')\n            return None\n        except openai.OpenAIError as e:\n            print(f'OpenAIError: {e}')\n            print(f'Retrying after {error_sleep_second} seconds')\n            time.sleep(error_sleep_second)\n            error_sleep_second = 
_upd_error_sleep_time(error_sleep_second)\n\n\ndef truncate(prompt: str, max_num_tokens: int, tokenizer, side: str) -> str:\n    \"\"\"Truncate prompt from side given the token budget\"\"\"\n\n    # count tokens with the tiktoken encoding\n    tokens = tokenizer.encode(prompt, disallowed_special=())\n    num_tokens = len(tokens)\n\n    if num_tokens > max_num_tokens:\n        if side == 'left':\n            prompt_tokens = tokens[num_tokens - max_num_tokens:]\n        elif side == 'right':\n            prompt_tokens = tokens[:max_num_tokens]\n        else:\n            raise ValueError(f'Invalid side: {side}')\n        # decode and encode again as a sanity check\n        prompt = tokenizer.decode(prompt_tokens)\n        new_len = len(tokenizer.encode(prompt, disallowed_special=()))\n        assert new_len <= max_num_tokens\n    return prompt\n\n\n
def prepare_prompt(\n        prompt: str,\n        cross_file_context: str,\n        cross_file_budget: int,\n        prompt_budget: int,\n        tokenizer\n) -> str:\n    \"\"\"Create an augmented prompt according to budget specs\"\"\"\n\n    # left truncate original prompt\n    prompt = truncate(prompt, prompt_budget, tokenizer, 'left')\n\n    if cross_file_context is not None:\n        # right truncate cross file context string\n        cross_file_context = truncate(cross_file_context, cross_file_budget, tokenizer, 'right')\n    else:\n        cross_file_context = ''\n\n    # return <CFC>\\n<PROMPT>\n    return cross_file_context + '\\n' + prompt\n\n\n
def get_openai_response(\n        sample: Dict,\n        tokenizer,\n        args\n) -> Tuple[str, ChatCompletion | None]:\n    \"\"\"Get the OpenAI response for a single sample. Returns the prompt used for\n    inference and the API response (None if the query failed).\"\"\"\n    if args.use_crossfile_context:\n        prompt = prepare_prompt(\n            sample['prompt'], sample['crossfile_context']['text'],\n            args.crossfile_max_tokens,\n            args.model_max_tokens - args.generation_max_tokens - args.crossfile_max_tokens - BUFFER,\n            tokenizer\n        )\n    else:\n        prompt = prepare_prompt(\n            sample['prompt'], None,\n            0,\n            args.model_max_tokens - args.generation_max_tokens - BUFFER,\n            tokenizer\n        )\n\n    response = query_with_retry(args, prompt)\n    return prompt, response\n\n\n
def get_openai_responses(\n        args, data, out_path\n) -> List[str]:\n    \"\"\"Get OpenAI responses to all samples in data, store them in out_path,\n    and return the list of task ids that were skipped due to errors\"\"\"\n    tokenizer = tiktoken.encoding_for_model(args.model)\n    skipped = []\n    with open(out_path, 'w') as f:\n        for d in tqdm(data):\n            try:\n                prompt, response = get_openai_response(\n                    d, tokenizer, args\n                )\n            except Exception as e:\n                print('Unknown error', e)\n                raise\n\n            if response is not None:\n                pred_raw = response.choices[0].message.content or ''\n                d['pred_raw'] = pred_raw  # key compatible with eval script\n                # newer ChatGPT models may prepend a ```[lang_tag] fence line: drop that line\n                # and strip backticks from both ends of the remaining text\n                if pred_raw.startswith('```'):\n                    d['pred'] = '\\n'.join(pred_raw.split('\\n')[1:]).strip('`')\n                else:\n                    d['pred'] = pred_raw\n                # d['api_response'] = str(response)\n                d['prompt_used'] = prompt  # records the augmented prompt\n                d['task_id'] = d['metadata']['task_id']  # adding for compatibility with eval script\n                print(json.dumps(d), file=f, flush=True)\n            else:\n                skipped.append(d['metadata']['task_id'])\n                print(f'Skipped {d[\"metadata\"][\"task_id\"]}')\n\n    return skipped\n\n\n
def main():\n    # get config for current run\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--temperature', type=float, default=0.2)\n    parser.add_argument('--top_p', type=float, default=0.95)\n    parser.add_argument(\n        '--task', type=str, required=True,\n        help='task name; data is read from <data_root_dir>/<language>/<task>.jsonl'\n    )\n    parser.add_argument(\n        '--language', type=str, required=True,\n        choices=['csharp', 'python', 'java', 'typescript']\n    )\n    parser.add_argument(\n        '--data_root_dir', type=str, default='data/',\n        help='path to directory where data is organized in lang/task.jsonl format'\n    )\n    parser.add_argument(\n        '--output_dir', type=str, required=True,\n        help='path to directory where outputs are stored'\n    )\n    parser.add_argument(\n        '--model', type=str, required=True,\n        help='OpenAI-supported model'\n    )\n    parser.add_argument(\n        '--model_max_tokens', type=int, default=16384,\n        help='maximum number of tokens of the model'\n    )\n    parser.add_argument(\n        '--crossfile_max_tokens', type=int, default=12800,\n        help='maximum number of tokens for cross-file context'\n    )\n    parser.add_argument(\n        '--use_crossfile_context', action='store_true',\n        help='whether to use cross-file context'\n    )\n    parser.add_argument(\n        '--generation_max_tokens', type=int, default=50,\n        help='maximum number of tokens to generate'\n    )\n    args = parser.parse_args()\n    print(json.dumps(vars(args), indent=4))\n\n    # setup paths\n    if not os.path.isdir(args.output_dir):\n        print(f'==== Output dir does not exist. Creating: {args.output_dir} ====')\n        os.makedirs(args.output_dir)\n    data_path = os.path.join(args.data_root_dir, args.language, args.task + '.jsonl')\n    with open(data_path, 'r') as f:\n        data = [json.loads(l) for l in f]\n\n    out_path = os.path.join(args.output_dir, 'prediction.jsonl')\n    # start OpenAI inference\n    skipped_tasks = get_openai_responses(\n        args, data, out_path\n    )\n\n    # save list of skipped tasks\n    with open(out_path.replace('.jsonl', '_skipped_tasks.json'), 'w') as f:\n        f.write(json.dumps(skipped_tasks))\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "scripts/vllm_inference.py",
    "content": "\"\"\"\nScript to run vLLM-based inference. See README for an example.\n\"\"\"\n\nimport argparse\nimport json\nimport os\nfrom typing import List\n\nfrom tqdm import tqdm\nfrom transformers import AutoTokenizer\nfrom transformers.utils import logging\nfrom vllm import LLM, SamplingParams\n\nlogging.set_verbosity_info()\nlogger = logging.get_logger(__name__)\n# add a small buffer for tokenizers whose tokenize/detokenize round trip is not lossless\nBUFFER = 100\n\n\ndef truncate(prompt: str, max_num_tokens: int, side: str, tokenizer) -> str:\n    \"\"\"Truncate prompt from side given the token budget\"\"\"\n\n    tokens = tokenizer.tokenize(prompt)\n    num_tokens = len(tokens)\n\n    if num_tokens > max_num_tokens:\n        if side == 'left':\n            prompt_tokens = tokens[num_tokens - max_num_tokens:]\n        elif side == 'right':\n            prompt_tokens = tokens[:max_num_tokens]\n        else:\n            raise ValueError(f'Invalid side: {side}')\n        prompt = tokenizer.convert_tokens_to_string(prompt_tokens)\n        new_len = len(tokenizer.tokenize(prompt))\n        if new_len > max_num_tokens:\n            logger.warning(\n                f'Number of tokens after truncation is greater than max tokens allowed: {new_len=} {num_tokens=}')\n    return prompt\n\n\ndef prepare_prompt(\n        prompt: str,\n        cross_file_context: str,\n        cross_file_budget: int,\n        prompt_budget: int,\n        tokenizer\n) -> str:\n    \"\"\"Create an augmented prompt according to budget specs\"\"\"\n\n    # left truncate original prompt\n    prompt = truncate(prompt, prompt_budget, 'left', tokenizer)\n\n    if cross_file_context is not None:\n        # right truncate cross file context string\n        cross_file_context = truncate(cross_file_context, cross_file_budget, 'right', tokenizer)\n    else:\n        cross_file_context = ''\n\n    return cross_file_context + '\\n' + prompt\n\n\ndef cceval_generate(\n        args,\n        data,\n        tokenizer,\n        
sampling_params,\n        llm\n) -> None:\n    \"\"\"Build budgeted prompts for all samples, generate with vLLM and write predictions to <output_dir>/prediction.jsonl\"\"\"\n    prompts = []\n    for d in data:\n        if args.use_crossfile_context:\n            prompt = prepare_prompt(\n                d['prompt'], d['crossfile_context']['text'],\n                args.crossfile_max_tokens,\n                args.model_max_tokens - args.generation_max_tokens - args.crossfile_max_tokens - BUFFER,\n                tokenizer\n            )\n        else:\n            prompt = prepare_prompt(\n                d['prompt'], None,\n                0,\n                args.model_max_tokens - args.generation_max_tokens - BUFFER,\n                tokenizer\n            )\n        prompts.append(prompt)\n\n    # batched generation; vLLM returns outputs in the same order as the prompts\n    outputs = llm.generate(prompts, sampling_params)\n\n    out_path = os.path.join(args.output_dir, 'prediction.jsonl')\n    with open(out_path, 'w') as f:\n        for d, response in tqdm(zip(data, outputs), total=len(data)):\n            d['pred'] = response.outputs[0].text\n            d['task_id'] = d['metadata']['task_id']\n            print(json.dumps(d), file=f, flush=True)\n\n\n
def main():\n    # get config for current run\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--temperature', type=float, default=0.2)\n    parser.add_argument('--top_p', type=float, default=0.95)\n    parser.add_argument(\n        '--task', type=str, required=True,\n        help='task name; data is read from <data_root_dir>/<language>/<task>.jsonl'\n    )\n    parser.add_argument(\n        '--language', type=str, required=True,\n        choices=['csharp', 'python', 'java', 'typescript']\n    )\n    parser.add_argument(\n        '--data_root_dir', type=str, default='data/',\n        help='path to directory where data is organized in lang/task.jsonl format'\n    )\n    parser.add_argument(\n        '--output_dir', type=str, required=True,\n        help='path to directory where outputs are stored'\n    )\n    parser.add_argument(\n        '--model', type=str, required=True,\n        help='vLLM-supported model'\n    )\n    parser.add_argument(\n        '--tp_size', type=int, default=1,\n        help='tensor parallel size'\n    )\n    parser.add_argument(\n        '--model_max_tokens', type=int, default=16384,\n        help='maximum number of tokens of the model'\n    )\n    parser.add_argument(\n        '--crossfile_max_tokens', type=int, default=12800,\n        help='maximum number of tokens for cross-file context'\n    )\n    parser.add_argument(\n        '--use_crossfile_context', action='store_true',\n        help='whether to use cross-file context'\n    )\n    parser.add_argument(\n        '--generation_max_tokens', type=int, default=50,\n        help='maximum number of tokens to generate'\n    )\n\n    args = parser.parse_args()\n    print(json.dumps(vars(args), indent=4))\n\n    # load model\n    llm = LLM(model=args.model, tensor_parallel_size=args.tp_size, max_model_len=args.model_max_tokens)\n    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)\n    sampling_params = SamplingParams(temperature=args.temperature, top_p=args.top_p, max_tokens=args.generation_max_tokens)\n\n    # setup paths\n    if not os.path.isdir(args.output_dir):\n        print(f'==== Output dir does not exist. Creating: {args.output_dir} ====')\n        os.makedirs(args.output_dir)\n    data_path = os.path.join(args.data_root_dir, args.language, args.task + '.jsonl')\n    with open(data_path, 'r') as f:\n        data = [json.loads(l) for l in f]\n\n    # generation\n    cceval_generate(args, data, tokenizer, sampling_params, llm)\n\n\nif __name__ == '__main__':\n    main()\n"
  }
]