Repository: amazon-science/cceval
Branch: main
Commit: 40c68d2b7ca2
Files: 29
Total size: 40.2 MB

Directory structure:
gitextract_65e_y3xi/
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── NOTICE
├── README.md
├── THIRD_PARTY_LICENSES
├── cceval_config.yaml
├── data/
│   └── crosscodeeval_data.tar.xz
├── prompt_builder/
│   ├── README.md
│   ├── augment_with_cfc.py
│   ├── rerank_utils.py
│   ├── run.sh
│   └── utils.py
├── requirements.txt
└── scripts/
    ├── build_treesitter.sh
    ├── build_ts_lib.py
    ├── custom_generate.py
    ├── eval.py
    ├── eval_metric.py
    ├── eval_utils.py
    ├── keywords/
    │   ├── __init__.py
    │   ├── csharp.txt
    │   ├── java.txt
    │   ├── javascript.txt
    │   ├── keywordlist.py
    │   └── typescript.txt
    ├── openai_inference.py
    └── vllm_inference.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.idea/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
.DS_Store
ts_package/
build/


================================================
FILE: CODE_OF_CONDUCT.md
================================================
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact opensource-codeofconduct@amazon.com with any additional questions or comments. ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing Guidelines Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional documentation, we greatly value feedback and contributions from our community. Please read through this document before submitting any issues or pull requests to ensure we have all the necessary information to effectively respond to your bug report or contribution. ## Reporting Bugs/Feature Requests We welcome you to use the GitHub issue tracker to report bugs or suggest features. When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: * A reproducible test case or series of steps * The version of our code being used * Any modifications you've made relevant to the bug * Anything unusual about your environment or deployment ## Contributing via Pull Requests Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 1. You are working against the latest source on the *main* branch. 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. To send us a pull request, please: 1. Fork the repository. 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 3. Ensure local tests pass. 4. 
Commit to your fork using clear commit messages. 5. Send us a pull request, answering any default questions in the pull request interface. 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). ## Finding contributions to work on Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. ## Code of Conduct This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact opensource-codeofconduct@amazon.com with any additional questions or comments. ## Security issue notifications If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. ## Licensing See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
================================================
FILE: NOTICE
================================================
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.


================================================
FILE: README.md
================================================
# CrossCodeEval: A Diverse and Multilingual Benchmark for Cross-File Code Completion

This repository contains the data and inference code of the NeurIPS 2023 (Datasets and Benchmarks track) paper "[CrossCodeEval: A Diverse and Multilingual Benchmark for Cross-File Code Completion](https://arxiv.org/abs/2310.11248)."

## Requirements

- Uncompress the CrossCodeEval data via `tar -xvJf data/crosscodeeval_data.tar.xz -C data/`
    - The data contains {baseline, retrieval, retrieval w/ ref.} setting x {bm25, UniXCoder, OpenAI Ada} retriever.
    - **Please [email us](mailto:y.robin.ding@gmail.com) if you need the raw data.**
- Install dependencies via `pip install -r requirements.txt`
- Build tree sitter via `bash scripts/build_treesitter.sh`

## Evaluation on CrossCodeEval

Our evaluation consists of two steps: generation and metrics calculation.

### Generation

#### Publicly Available Models

For publicly available models like StarCoder, DeepSeek-Coder, etc., we recommend using [vLLM](https://github.com/vllm-project/vllm) for fast and distributed inference on CrossCodeEval.

```bash
export gpus=2
export model=bigcode/starcoder2-3b
export language=python
export task=line_completion_rg1_unixcoder_cosine_sim
export output_dir=./tmp/crosscodeeval_testrun/
python scripts/vllm_inference.py \
  --tp $gpus \
  --task $task \
  --language $language \
  --model $model \
  --output_dir $output_dir \
  --use_crossfile_context
```

For additional args, e.g., cross-file context length and sampling top_p, please see `python scripts/vllm_inference.py --help`.
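Before launching inference, it can help to sanity-check a task file. The following is a minimal sketch, not part of the repository: it assumes each line of a task `.jsonl` file is a JSON object with the `prompt`, `groundtruth`, `metadata` and `crossfile_context` fields that `prompt_builder/augment_with_cfc.py` reads and writes. The helper names (`summarize_record`, `read_jsonl`) and the demo record are hypothetical, and real records may carry additional fields.

```python
import json
import tempfile
from pathlib import Path


def summarize_record(record: dict) -> dict:
    """Return a compact view of one prompt record (hypothetical helper)."""
    # crossfile_context may be empty ("" or {}) when no context was retrieved.
    cfc = record.get("crossfile_context") or {}
    return {
        "repository": record["metadata"]["repository"],
        "file": record["metadata"]["file"],
        "prompt_lines": len(record["prompt"].splitlines()),
        "groundtruth": record["groundtruth"],
        "num_cfc_chunks": len(cfc.get("list", [])),
    }


def read_jsonl(path):
    """Read one JSON object per non-empty line."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Synthetic record so the sketch runs without the dataset; all values are made up.
demo_record = {
    "prompt": "import os\nroot = os.path.",
    "groundtruth": "join('a', 'b')",
    "metadata": {"repository": "demo/repo", "file": "src/main.py"},
    "crossfile_context": {
        "text": "# Here are some relevant code fragments from other files of the repo:\n\n",
        "list": [
            {"retrieved_chunk": "def helper():\n    return 1",
             "filename": "src/util.py",
             "score": 0.5}
        ],
    },
}

with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo.jsonl"
    demo_path.write_text(json.dumps(demo_record) + "\n", encoding="utf-8")
    summaries = [summarize_record(r) for r in read_jsonl(demo_path)]

print(summaries[0])
```

To inspect real data, point `read_jsonl` at one of the extracted task files instead of the temporary demo file.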
If you prefer the non-vLLM script:
First, configure `accelerate` via `accelerate config` if you haven't already. A reference configuration is available at `cceval_config.yaml`.

The following command demonstrates how to run greedy eval using codegen-350M on Python with cross-file context.

```bash
export model_type=codelm_cfc # or codelm for no cross-file context eval
export model_name=Salesforce/codegen-350M-mono
export language=python
export ts_lib=./build/${language}-lang-parser.so
export dtype=bf16 # or fp16
export prompt_file=./data/crosscodeeval_data/${language}/line_completion_rg1_unixcoder_cosine_sim.jsonl # or other options in the dir, which correspond to different retrieval methods and/or retrieval settings
export max_seq_length=2048
export cfc_seq_length=512
export batch_size=16 # reduce for larger models
export output_dir=./tmp/crosscodeeval_testrun/

accelerate launch scripts/eval.py \
  --model_type $model_type \
  --model_name_or_path $model_name \
  --cfc_seq_length $cfc_seq_length \
  --prompt_file $prompt_file \
  --gen_length 50 \
  --max_seq_length $max_seq_length \
  --batch_size $batch_size \
  --output_dir $output_dir \
  --dtype $dtype \
  --num_return_sequences 1 \
  --overwrite_cache True \
  --ts_lib $ts_lib \
  --language $language
```

You may run sampling via the following (additional) args:

```bash
  --do_sample \
  --top_p 0.95 \
  --temperature 0.2 \
  --num_return_sequences 5 \
```
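For intuition about what `--temperature` and `--top_p` control, here is a small, self-contained sketch of temperature scaling followed by nucleus (top-p) filtering over a toy logit vector. It is only an illustration of the general technique: the function names and toy logits are assumptions, it is not the code used by the scripts in this repository, and generation libraries may handle the cutoff boundary slightly differently.

```python
import math
import random


def top_p_filter(logits, temperature=0.2, top_p=0.95):
    """Return {token_index: probability} after temperature scaling and top-p filtering."""
    # Lower temperature sharpens the distribution before the softmax.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # subtract max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the most likely tokens until their cumulative mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # Renormalize over the kept tokens.
    return {i: probs[i] / mass for i in kept}


def sample(logits, rng, temperature=0.2, top_p=0.95):
    """Draw one token index from the filtered distribution."""
    dist = top_p_filter(logits, temperature, top_p)
    tokens = list(dist)
    return rng.choices(tokens, weights=[dist[t] for t in tokens], k=1)[0]


logits = [2.0, 1.5, 0.1, -1.0]  # toy values, made up for illustration
rng = random.Random(0)

low_t = top_p_filter(logits, temperature=0.2, top_p=0.95)
high_t = top_p_filter(logits, temperature=1.0, top_p=0.95)

# Five independent draws, analogous in spirit to --num_return_sequences 5.
samples = [sample(logits, rng) for _ in range(5)]

print(sorted(low_t), sorted(high_t), samples)
```

With these toy logits, the low temperature concentrates almost all probability on the top two tokens, while a temperature of 1.0 lets a third token into the nucleus.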
#### OpenAI models

OpenAI models are accessible through an API. You may use the following script:

```bash
export model=gpt-3.5-turbo-0125
export language=python
export task=line_completion_rg1_unixcoder_cosine_sim
export output_dir=./tmp/crosscodeeval_openai_testrun/
python scripts/openai_inference.py \
  --task $task \
  --language $language \
  --model $model \
  --output_dir $output_dir \
  --use_crossfile_context
```

### Metrics Calculation

After obtaining the generations, we can calculate the final metrics:

```bash
export language=python
export ts_lib=./build/${language}-lang-parser.so
export task=line_completion_oracle_unixcoder_cosine_sim
export prompt_file=./data/crosscodeeval_data/${language}/${task}.jsonl
export output_dir=./tmp/crosscodeeval_testrun/
python scripts/eval.py \
  --prompt_file $prompt_file \
  --output_dir $output_dir \
  --ts_lib $ts_lib \
  --language $language \
  --only_compute_metric
```

## Citation

```
@inproceedings{ding2023crosscodeeval,
  title={CrossCodeEval: A Diverse and Multilingual Benchmark for Cross-File Code Completion},
  author={Yangruibo Ding and Zijian Wang and Wasi Uddin Ahmad and Hantian Ding and Ming Tan and Nihal Jain and Murali Krishna Ramanathan and Ramesh Nallapati and Parminder Bhatia and Dan Roth and Bing Xiang},
  year={2023},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
  url={https://arxiv.org/pdf/2310.11248.pdf}
}
```

## Questions

Please feel free to email us. You may also submit an issue in this repo.

## Security

See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.

## License

This project is licensed under the Apache-2.0 License.
================================================ FILE: THIRD_PARTY_LICENSES ================================================ The CrossCodeEval repository includes the following third-party software/licensing: The keywordlist.py was from https://github.com/microsoft/dpu-utils/blob/master/python/dpu_utils/codeutils/keywords/keywordlist.py with license MIT License Copyright (c) Microsoft Corporation. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE


================================================
FILE: cceval_config.yaml
================================================
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero3_save_16bit_model: false
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false


================================================
FILE: data/crosscodeeval_data.tar.xz
================================================
[File too large to display: 40.1 MB]

================================================
FILE: prompt_builder/README.md
================================================
## Retrieval Augmented Prompting

We can generate the retrieval-augmented prompts in the following 3 steps.

1. Please email us to get the raw software repositories.
2. Set the `repository_root` variable in `augment_with_cfc.py` to the root directory that contains the raw software repositories.
3. Run the `run.sh` bash script.


================================================
FILE: prompt_builder/augment_with_cfc.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import time
import glob
import argparse
import multiprocessing as mp
from tqdm import tqdm
from functools import partial
from rerank_utils import lexical_ranking, SemanticReranking
from utils import str2bool, file_distance, tokenize_nltk

CHUNK_SIZE = 10
SLIDING_WINDOW_SIZE = 10  # non-overlapping chunks if SLIDING_WINDOW_SIZE=CHUNK_SIZE
QUERY_LENGTH = 10  # last N lines from prompt will be query

repository_root = "/PATH/TO/REPOS"  # get the data from authors

input_files = {
    "python": "../data/crosscodeeval_data/python/line_completion.jsonl",
    "java": "../data/crosscodeeval_data/java/line_completion.jsonl",
    "typescript": "../data/crosscodeeval_data/typescript/line_completion.jsonl",
    "csharp": "../data/crosscodeeval_data/csharp/line_completion.jsonl"
}

file_ext = {"python": "py", "java": "java", "typescript": "ts", "csharp": "cs"}


def get_crossfile_context_from_chunks(
        args,
        prompt,
        code_chunks,
        code_chunk_ids,
        groundtruth,
        semantic_ranker
):
    assert len(code_chunks) != 0
    candidate_code_chunks = code_chunks[:args.maximum_chunk_to_rerank]
    candidate_code_chunk_ids = code_chunk_ids[:args.maximum_chunk_to_rerank]

    ranking_scores = None
    meta_data = {}

    if args.rerank:
        if args.query_type == "groundtruth":
            # oracle experiment
            prompt_lines = [pl for pl in prompt.split("\n") if pl.strip()]
            groundtruth_lines = [gt for gt in groundtruth.split("\n") if gt.strip()]
            code_lines = prompt_lines + groundtruth_lines
            query = "\n".join(code_lines[-QUERY_LENGTH:])
        elif args.query_type == "last_n_lines":
            prompt_lines = [pl for pl in prompt.split("\n") if pl.strip()]
            query = "\n".join(prompt_lines[-QUERY_LENGTH:])
        else:
            raise NotImplementedError

        meta_data["query"] = query
        start = time.time()

        if args.ranking_fn == "cosine_sim":
            # worker processes are named "...-<n>" (1-based); map each one to a GPU id
            gpu_id = int(mp.current_process().name.split('-')[-1]) - 1
            candidate_code_chunks, candidate_code_chunk_ids, ranking_scores = semantic_ranker.rerank(
                query,
                candidate_code_chunks,
                candidate_code_chunk_ids,
                gpu_id,
                score_threshold=None
            )
        else:
            candidate_code_chunks, candidate_code_chunk_ids, ranking_scores = lexical_ranking(
                query,
                candidate_code_chunks,
                args.ranking_fn,
                candidate_code_chunk_ids,
                score_threshold=None
            )

        meta_data["latency"] = time.time() - start
        meta_data["num_candidates"] = len(candidate_code_chunks)

    top_k = min(args.maximum_cross_file_chunk, len(candidate_code_chunk_ids))
    if top_k == 0:
        # return three values so the caller can always unpack (cfc, cfc_text, meta_data)
        return [], "", meta_data

    selected_chunks = []
    selected_chunks_filename = []
    selected_chunks_scores = []

    if args.use_next_chunk_as_cfc:
        # prepare an id2idx map
        assert len(candidate_code_chunks) == len(candidate_code_chunk_ids)
        id2idx = dict()
        for j, cci in enumerate(code_chunk_ids):
            id2idx[cci] = j

        total_added = 0
        for cidx, _id in enumerate(candidate_code_chunk_ids):
            fname, c_id = _id.rsplit("|", 1)
            next_id = f"{fname}|{int(c_id) + 1}"
            # fall back to the matched chunk itself if it is the last chunk of its file
            if next_id not in id2idx:
                to_add = code_chunks[id2idx[_id]]
            else:
                to_add = code_chunks[id2idx[next_id]]

            if to_add not in selected_chunks:
                selected_chunks.append(to_add)
                selected_chunks_filename.append(fname)
                if args.rerank:
                    selected_chunks_scores.append(ranking_scores[cidx])
                total_added += 1
                if total_added == top_k:
                    break
    else:
        selected_chunks = candidate_code_chunks[:top_k]
        selected_chunks_filename = [_id.rsplit("|", 1)[0] for _id in candidate_code_chunk_ids[:top_k]]
        if args.rerank:
            selected_chunks_scores = ranking_scores[:top_k]

    cross_file_context = []
    for idx in range(len(selected_chunks)):
        cross_file_context.append({
            "retrieved_chunk": selected_chunks[idx],
            "filename": selected_chunks_filename[idx],
            "score": selected_chunks_scores[idx] if args.rerank else None
        })

    # render the retrieved chunks as comments in the target language
    line_start_sym = "#" if args.language == "python" else "//"
    cfc_text = f"{line_start_sym} Here are some relevant code fragments from other files of the repo:\n\n"
    for sc, scf in zip(selected_chunks, selected_chunks_filename):
        cfc_text += f"{line_start_sym} the below code fragment can be found in:\n{line_start_sym} {scf}" + "\n"
        cfc_text += "\n".join([f"{line_start_sym} {cl}" for cl in sc.strip('\n').splitlines()]) + "\n\n"

    return cross_file_context, cfc_text, meta_data


def read_project_files(repo_name, lang):
    # root_dir needs a trailing slash (i.e. /root/dir/)
    project_context = {}
    root_dir = os.path.join(repository_root, lang, repo_name)
    if not os.path.isdir(root_dir):
        print(f"Repository not found: {root_dir}")
        return project_context

    if lang == "typescript":
        src_files = []
        src_files += glob.glob(os.path.join(root_dir, 'src/**/*.ts'), recursive=True)
        src_files += glob.glob(os.path.join(root_dir, 'src/**/*.tsx'), recursive=True)
    else:
        src_files = glob.glob(os.path.join(root_dir, f'**/*.{file_ext[lang]}'), recursive=True)

    if len(src_files) == 0:
        return project_context

    for filename in src_files:
        if os.path.exists(filename):  # weird but some files cannot be opened to read
            if os.path.isfile(filename):
                try:
                    with open(filename, "r") as file:
                        file_content = file.read()
                except Exception:
                    # fall back to a lossy decode when the file is not valid text
                    with open(filename, "rb") as file:
                        file_content = file.read().decode(errors='replace')

                fileid = os.path.relpath(filename, root_dir)
                project_context[fileid] = file_content
        else:
            pass
            # print(f"File not found: {filename}")

    return project_context


def find_files_within_distance_k(current_file_path, filelist, k):
    list_of_modules = []
    module_weight = []
    for filepath in filelist:
        if filepath != current_file_path:
            dist = file_distance(filepath, current_file_path)
            if dist == -1:
                continue
            elif dist <= k:
                list_of_modules.append(filepath)
                module_weight.append(dist)

    # sorting in ascending order
    list_of_modules = [x for _, x in sorted(zip(module_weight, list_of_modules))]
    return list_of_modules


def get_cfc(example, args, semantic_ranker, repositories):
    project_context = repositories[example["metadata"]["repository"]]
    status = None
    current_filepath = example["metadata"]["file"]
    if len(project_context) == 0:
        example["crossfile_context"] = ""
        status = "project_not_found"
    else:
        current_filecontent = None
        for filepath, filecontent in project_context.items():
            if filepath == current_filepath:
                current_filecontent = filecontent
                break

        if current_filecontent is None:
            example["crossfile_context"] = {}
            print(current_filepath)
            status = "file_not_found_in_project"
        else:
            pyfiles = find_files_within_distance_k(
                example["metadata"]["file"],
                list(project_context.keys()),
                k=args.crossfile_distance
            )
            pyfiles = pyfiles[:args.maximum_cross_files]

            code_chunks = []
            code_chunk_ids = []
            for pyfile in pyfiles:
                lines = project_context[pyfile].split("\n")
                lines = [l for l in lines if l.strip()]  # removing empty lines
                c_id = 0
                for i in range(0, len(lines), SLIDING_WINDOW_SIZE):
                    c = "\n".join(lines[i:i + CHUNK_SIZE])
                    tokenized_c = tokenize_nltk(c)
                    if len(tokenized_c) > 0:
                        code_chunks.append(c)
                        code_chunk_ids.append(f"{pyfile}|{c_id}")
                        c_id += 1

            if len(code_chunks) == 0:
                example["crossfile_context"] = {}
                status = "no_crossfile_context"
            else:
                cfc, cfc_text, meta_data = get_crossfile_context_from_chunks(
                    args=args,
                    prompt=example["prompt"],
                    code_chunks=code_chunks,
                    code_chunk_ids=code_chunk_ids,
                    groundtruth=example["groundtruth"],
                    semantic_ranker=semantic_ranker
                )
                example["crossfile_context"] = {}
                example["crossfile_context"]["text"] = cfc_text
                example["crossfile_context"]["list"] = cfc

    return example, status


def attach_data(args, srcfile):
    empty_cfc = 0
    error_freq = {
        "project_not_found": 0,
        "file_not_found_in_project": 0,
        "no_crossfile_context": 0
    }
    output_examples = []
    examples = []
    repositories = dict()
    with open(srcfile) as f:
        for line in f:
            ex = json.loads(line)
            repo_name = ex["metadata"]["repository"]
            if repo_name not in repositories:
                repositories[repo_name] = read_project_files(repo_name, args.language)
            examples.append(ex)

    semantic_ranker = None
    if args.ranking_fn == "cosine_sim":
        semantic_ranker = SemanticReranking(
            args.ranker,
            max_sequence_length=256
        )

    pool = mp.Pool(args.num_processes)
    worker = partial(get_cfc, args=args, semantic_ranker=semantic_ranker, repositories=repositories)

    with tqdm(total=len(examples)) as pbar:
        for (d, stat) in pool.imap_unordered(worker, examples):
            if stat in error_freq:
                error_freq[stat] += 1
            if len(d["crossfile_context"]) == 0:
                empty_cfc += 1
                if not args.skip_if_no_cfc:
                    output_examples.append(d)
            else:
                output_examples.append(d)
            pbar.update()

    print("Total examples with empty CFC: ", empty_cfc)
    print(error_freq)
    return output_examples


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--rerank",
        type=str2bool,
        default=True,
        help="rerank the functions"
    )
    parser.add_argument(
        "--ranker",
        type=str,
        default="sparse",
        choices=["sparse", "unixcoder"],
        help="ranking function"
    )
    parser.add_argument(
        "--ranking_fn",
        type=str,
        default="bm25",
        choices=["tfidf", "bm25", "jaccard_sim", "cosine_sim"],
        help="ranking function"
    )
    parser.add_argument(
        "--query_type",
        type=str,
        default="last_n_lines",
        choices=["last_n_lines", "groundtruth"],
        help="how to form query from prompt"
    )
    parser.add_argument(
        "--crossfile_distance",
        type=int,
        default=100,
        help="max distance to search for crossfile"
    )
    parser.add_argument(
        "--maximum_chunk_to_rerank",
        type=int,
        default=1000,
        help="max chunks to consider to rank via BM25"
    )
    parser.add_argument(
        "--maximum_cross_files",
        type=int,
        default=1000,
        help="max cross files to consider"
    )
    parser.add_argument(
        "--maximum_cross_file_chunk",
        type=int,
        default=50,
        help="max chunks to return as cfc"
    )
    parser.add_argument(
        "--use_next_chunk_as_cfc",
        type=str2bool,
        default=True,
        help="use next code chunk as context"
    )
    parser.add_argument(
        "--skip_if_no_cfc",
        type=str2bool,
        default=True,
        help="skip adding examples if there is no crossfile context"
    )
    parser.add_argument(
        "--output_file_suffix",
        type=str,
        default=None,
        help="add a suffix string to the output file"
    )
    parser.add_argument(
        "--language",
        type=str,
        required=True,
        choices=["java", "python", "typescript", "csharp"],
        help="language name"
    )
    args = parser.parse_args()

    args.output_file_suffix = "" if args.output_file_suffix is None else f"_{args.output_file_suffix}"
    if args.use_next_chunk_as_cfc:
        assert args.rerank
        assert args.query_type != "groundtruth"

    tgtfile_suffix = ""
    if args.rerank:
        tgtfile_suffix += f"_{args.ranking_fn}"

    args.num_processes = 60
    if args.ranking_fn == "cosine_sim":
        num_gpus = 8
        args.num_processes = num_gpus
        mp.set_start_method('spawn')

    input_file = input_files[args.language]
    output_path = os.path.dirname(input_file)
    output_filename = os.path.splitext(os.path.basename(input_file))[0]
    output_filename = output_filename + args.output_file_suffix + tgtfile_suffix + ".jsonl"
    output_file = os.path.join(output_path, output_filename)

    output_examples = attach_data(args, input_file)
    with open(output_file, "w") as fw:
        for ex in output_examples:
            fw.write(json.dumps(ex))
            fw.write("\n")


================================================
FILE: prompt_builder/rerank_utils.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from rank_bm25 import BM25Okapi
from typing import List
from multiprocessing import Pool, cpu_count
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from utils import tokenize_nltk
from transformers import AutoModel, AutoTokenizer, AutoConfig


def jaccard_similarity(tokenized_query, tokenized_doc, containment=False):
    set1 = set(tokenized_query)
    set2 = set(tokenized_doc)
    intersection = len(set1.intersection(set2))
    # containment normalizes by the query size instead of the union size
    union = len(set1) if containment else len(set1.union(set2))
    return float(intersection) / union


def tokenize_corpus(corpus, tokenizer_fn):
    # the context manager releases the worker processes once map() returns
    with Pool(cpu_count()) as pool:
        tokenized_corpus = pool.map(tokenizer_fn, corpus)
    return tokenized_corpus


def tokenize_query_and_docs(query, docs):
    tokenized_query = tokenize_nltk(query)
    tokenized_docs = [tokenize_nltk(d) for d in docs]
    return tokenized_query, tokenized_docs


def lexical_ranking(
        query,
        docs,
        ranking_fn,
        doc_ids=None,
        score_threshold=None,
):
    if ranking_fn == "bm25":
        tokenized_query, tokenized_docs = tokenize_query_and_docs(query, docs)
        bm25 = BM25Okapi(tokenized_docs)
        scores = bm25.get_scores(tokenized_query)
    elif ranking_fn == "tfidf":
        tfidf_vectorizer = TfidfVectorizer(tokenizer=tokenize_nltk)
        X = tfidf_vectorizer.fit_transform(docs).toarray()  # (n_fn, n_features)
        y = tfidf_vectorizer.transform([query]).toarray()  # (1, n_features)
        scores = cosine_similarity(X, y).tolist()  # (n_fn, 1)
    elif ranking_fn == "jaccard_sim":
        tokenized_query, tokenized_docs = tokenize_query_and_docs(query, docs)
        scores = [jaccard_similarity(tokenized_query, d, containment=False) for d in tokenized_docs]
    else:
        raise NotImplementedError

    if score_threshold:
        skip_ids = [idx for idx, s in enumerate(scores) if s < score_threshold]
        scores = [s for idx, s in enumerate(scores) if idx not in skip_ids]
        docs = [d for idx, d in enumerate(docs) if idx not in skip_ids]
        if doc_ids is not None:
            doc_ids = [doc_id for idx, doc_id in enumerate(doc_ids) if idx not in skip_ids]
        if len(docs) == 0:
            return docs, doc_ids, scores

    if doc_ids is not None:
        doc_ids = [x for _, x in sorted(zip(scores, doc_ids), reverse=True)]
    docs_scores = [(x, s) for s, x in sorted(zip(scores, docs), reverse=True)]
    docs = [item[0] for item in docs_scores]
    scores = [item[1] for item in docs_scores]
    return docs, doc_ids, scores


# NOTE: `CODE_SAGE_MODELS` is referenced in `get_representations` but is not defined in this file.
# The empty mapping below is a placeholder (an assumption, not the authors' definition) that avoids
# a NameError; with it, every supported model takes the `last_hidden_state` branch.
CODE_SAGE_MODELS = {}


class SemanticReranking:

    def __init__(self, model_type="unixcoder", **kwargs):
        self.model_type = model_type
        if model_type == "unixcoder":
            self.tokenizer = AutoTokenizer.from_pretrained('microsoft/unixcoder-base')
            self.model = AutoModel.from_pretrained('microsoft/unixcoder-base')
        else:
            raise NotImplementedError

        # maximum sequence length for query and documents
        self.max_sequence_length = kwargs.get("max_sequence_length", 256)

    def text_to_tensor(
            self,
            text: str,
            pad_to_max: bool = True,
    ):
        text = text.strip()
        # tokenizer automatic padding is explicitly disabled because of its inconsistent behavior
        token_ids = self.tokenizer.encode(
            text,
            add_special_tokens=False,
            max_length=self.max_sequence_length,
            pad_to_max_length=False,
            truncation=True
        )

        # pad manually up to the maximum sequence length
        if pad_to_max and len(token_ids) < self.max_sequence_length:
            token_ids = token_ids + [self.tokenizer.pad_token_id] * (self.max_sequence_length - len(token_ids))
        if len(token_ids) > self.max_sequence_length:
            token_ids = token_ids[0:self.max_sequence_length]

        return torch.tensor(token_ids)

    def get_pad_id(self):
        return self.tokenizer.pad_token_id

    def get_attn_mask(self, tokens_tensor):
        return tokens_tensor != self.get_pad_id()

    def get_representations(self, list_input_ids, gpu_id):
        device = torch.device('cuda', gpu_id)
        self.model = self.model.to(device=device, dtype=torch.float16)
        self.model.eval()

        batch_size = 64
        sequence_outputs = []
        pooled_outputs = []

        for idx in range(0, len(list_input_ids), batch_size):
            start, end = idx, min(idx + batch_size, len(list_input_ids))
            input_ids = torch.stack(list_input_ids[start:end], dim=0).to(device=device)
            attention_mask = self.get_attn_mask(input_ids)
            if self.model_type in CODE_SAGE_MODELS.keys():
                output = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True
                )
                token_embeddings = output.hidden_states[-1]  # bsz x seq_len x hid_dim
            else:
                output = self.model(input_ids, attention_mask)
                token_embeddings = output.last_hidden_state  # bsz x seq_len x hid_dim

            # mean-pool the token embeddings over non-padding positions
            mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * mask_expanded, 1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            sequence_embeddings = sum_embeddings / sum_mask  # bsz x hid_dim

            sequence_outputs.append(token_embeddings)
            pooled_outputs.append(sequence_embeddings)

        sequence_output = torch.cat(sequence_outputs)
        pooled_output = torch.cat(pooled_outputs)
        return sequence_output, pooled_output

    def rerank(self, query: str, docs: List[str], doc_ids: List[str] = None, gpu_id=0, score_threshold=None):
        with torch.no_grad():
            batch_queries = [self.text_to_tensor(query)]
            batch_candidates = [self.text_to_tensor(d) for d in docs]
            _, query_rep = self.get_representations(batch_queries, gpu_id)  # 1 x hidden_size
            _, candi_rep = self.get_representations(batch_candidates, gpu_id)  # num_cand x hidden_size
            scores = torch.nn.functional.cosine_similarity(query_rep, candi_rep).tolist()  # num_cand

        if score_threshold:
            skip_ids = [idx for idx, s in enumerate(scores) if s < score_threshold]
            scores = [s for idx, s in enumerate(scores) if idx not in skip_ids]
            docs = [d for idx, d in enumerate(docs) if idx not in skip_ids]
            if doc_ids is not None:
                doc_ids = [doc_id for idx, doc_id in enumerate(doc_ids) if idx not in skip_ids]
            if len(docs) == 0:
                return docs, doc_ids, scores

        if doc_ids is not None:
            doc_ids = [x for _, x in sorted(zip(scores, doc_ids), reverse=True)]
        docs_scores = [(x, s) for s, x in sorted(zip(scores, docs), reverse=True)]
        docs = [item[0] for item in docs_scores]
        scores = [item[1] for item in docs_scores]
        return docs, doc_ids, scores

================================================
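NOTE (illustrative, not part of the repository): a minimal sketch of the chunk-then-rank flow that augment_with_cfc.py and rerank_utils.py follow when `--ranking_fn jaccard_sim` is selected. Several things here are assumptions: the regex-based `tokenize` is a stand-in for `tokenize_nltk`, the helper names `chunk_lines` and `rank_chunks` are hypothetical, and the chunk size and stride in the usage example are illustrative values rather than the repository's `CHUNK_SIZE` and `SLIDING_WINDOW_SIZE`.

```python
import re
from typing import List, Tuple


def tokenize(text: str) -> List[str]:
    # Stand-in for tokenize_nltk: keep only runs of word characters.
    return re.findall(r"\w+", text)


def chunk_lines(source: str, chunk_size: int, stride: int) -> List[str]:
    # Drop empty lines, then cut windows of `chunk_size` lines every `stride` lines.
    lines = [line for line in source.split("\n") if line.strip()]
    chunks = []
    for i in range(0, len(lines), stride):
        chunk = "\n".join(lines[i:i + chunk_size])
        if tokenize(chunk):  # skip chunks that contain no tokens
            chunks.append(chunk)
    return chunks


def jaccard(query_tokens: List[str], doc_tokens: List[str]) -> float:
    # |A intersect B| / |A union B| over the token sets; 0.0 when both are empty.
    a, b = set(query_tokens), set(doc_tokens)
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def rank_chunks(query: str, chunks: List[str]) -> List[Tuple[str, float]]:
    # Score every chunk against the query and sort by descending score.
    query_tokens = tokenize(query)
    scored = [(chunk, jaccard(query_tokens, tokenize(chunk))) for chunk in chunks]
    return sorted(scored, key=lambda item: item[1], reverse=True)


if __name__ == "__main__":
    source = "def add(a, b):\n    return a + b\n\ndef mul(a, b):\n    return a * b\n"
    chunks = chunk_lines(source, chunk_size=2, stride=2)
    ranked = rank_chunks("result = mul(x, y)", chunks)
    print(ranked[0][0])
```

When `stride` equals `chunk_size` the windows do not overlap; a smaller stride yields overlapping chunks at the cost of more candidates to rank.

================================================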
FILE: prompt_builder/run.sh ================================================ #!/usr/bin/env bash export PYTHONIOENCODING=utf-8 function generate_data() { lang=$1 ranker=$2 ranking_fn=$3 echo "$lang, $ranker, $ranking_fn" output_file_suffix="" if [[ $ranker != "sparse" ]]; then output_file_suffix="_${ranker}" fi # for RG-1 python augment_with_cfc.py \ --language $lang \ --rerank True \ --ranker $ranker \ --ranking_fn $ranking_fn \ --query_type last_n_lines \ --crossfile_distance 100 \ --maximum_chunk_to_rerank 1000 \ --maximum_cross_files 1000 \ --maximum_cross_file_chunk 5 \ --use_next_chunk_as_cfc True \ --skip_if_no_cfc False \ --output_file_suffix "rg1${output_file_suffix}" # for oracle experiment python augment_with_cfc.py \ --language $lang \ --rerank True \ --ranker $ranker \ --ranking_fn $ranking_fn \ --query_type groundtruth \ --crossfile_distance 100 \ --maximum_chunk_to_rerank 1000 \ --maximum_cross_files 1000 \ --maximum_cross_file_chunk 5 \ --use_next_chunk_as_cfc False \ --skip_if_no_cfc False \ --output_file_suffix "oracle${output_file_suffix}" } generate_data python sparse bm25 generate_data java sparse bm25 generate_data typescript sparse bm25 generate_data csharp sparse bm25 generate_data python sparse jaccard_sim generate_data java sparse jaccard_sim generate_data typescript sparse jaccard_sim generate_data csharp sparse jaccard_sim generate_data python unixcoder cosine_sim generate_data java unixcoder cosine_sim generate_data typescript unixcoder cosine_sim generate_data csharp unixcoder cosine_sim ================================================ FILE: prompt_builder/utils.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse  # needed by str2bool, which raises argparse.ArgumentTypeError
import re
import os
from typing import List
from nltk.tokenize import word_tokenize


def tokenize_nltk(text):
    # split with NLTK, then keep only the word-character runs of each token
    words = word_tokenize(text)
    output_list = []
    for w in words:
        w_list = re.findall(r'\w+', w)
        output_list.extend(w_list)
    return output_list


def file_distance(src_file, dest_file):
    # number of directory levels separating the two files; -1 if it cannot be computed
    distance = -1
    try:
        commonpath = os.path.commonpath([src_file, dest_file])
        rel_file1_path = os.path.relpath(src_file, commonpath)
        rel_file2_path = os.path.relpath(dest_file, commonpath)
        distance = rel_file1_path.count(os.sep) + rel_file2_path.count(os.sep)
    except Exception as e:
        # print(e, src_file, dest_file)
        pass

    return distance


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

================================================
FILE: requirements.txt
================================================
torch
transformers
datasets
tree-sitter
timeout-decorator
bitsandbytes
accelerate
scikit-learn
rank-bm25
fuzzywuzzy
nltk
sacrebleu
deepspeed
tiktoken
vllm>=0.3.3

================================================
FILE: scripts/build_treesitter.sh
================================================
mkdir ts_package;
cd ts_package;
# Download the tree-sitter package
git clone https://github.com/tree-sitter/tree-sitter-python.git;
git clone https://github.com/tree-sitter/tree-sitter-java.git;
git clone https://github.com/tree-sitter/tree-sitter-c-sharp.git;
git clone
https://github.com/tree-sitter/tree-sitter-typescript.git; cd ..; # Build tree-sitter python scripts/build_ts_lib.py ================================================ FILE: scripts/build_ts_lib.py ================================================ #!/usr/bin/env python # coding=utf-8 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. from tree_sitter import Language def build_language_lib(): for lang in ["java", "python", "typescript", "csharp"]: ts_lang = "c-sharp" if lang == "csharp" else lang if lang == "typescript": git_dir = f"ts_package/tree-sitter-{ts_lang}/{lang}" else: git_dir = f"ts_package/tree-sitter-{ts_lang}" Language.build_library(f'build/{lang}-lang-parser.so', [git_dir]) if __name__ == "__main__": build_language_lib() ================================================ FILE: scripts/custom_generate.py ================================================ # Modifications Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
############################################################################## # Modified from https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/utils.py ############################################################################## import copy import inspect import warnings from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist from torch import nn from transformers.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from transformers.models.auto import ( MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING, ) from transformers.utils import ModelOutput, logging from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from transformers.generation.configuration_utils import GenerationConfig from transformers.generation.logits_process import ( EncoderNoRepeatNGramLogitsProcessor, EncoderRepetitionPenaltyLogitsProcessor, EpsilonLogitsWarper, EtaLogitsWarper, ExponentialDecayLengthPenalty, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, ForceTokensLogitsProcessor, HammingDiversityLogitsProcessor, InfNanRemoveLogitsProcessor, LogitNormalization, LogitsProcessorList, MinLengthLogitsProcessor, MinNewTokensLengthLogitsProcessor, NoBadWordsLogitsProcessor, NoRepeatNGramLogitsProcessor, PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, SuppressTokensAtBeginLogitsProcessor, SuppressTokensLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper, ) from transformers.generation.stopping_criteria import ( MaxLengthCriteria, 
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from transformers.generation.utils import GenerateOutput, SampleOutput, SampleDecoderOnlyOutput

if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel
    from transformers.generation.streamers import BaseStreamer

logger = logging.get_logger(__name__)


@torch.no_grad()
def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    r"""

    Generates sequences of token ids for models with a language modeling head.

    Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
    model's default generation configuration. You can override any `generation_config` by passing the corresponding
    parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.

    For an overview of generation strategies and code examples, check out the [following
    guide](../generation_strategies).

    Parameters:
        inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
            The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
            method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
            should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
            `input_ids`, `input_values`, `input_features`, or `pixel_values`.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call.
            `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If
            `generation_config` is not provided, the default will be used, which has the following loading
            priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
            configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
            default values, whose documentation should be checked to parameterize generation.
        logits_processor (`LogitsProcessorList`, *optional*):
            Custom logits processors that complement the default logits processors built from arguments and
            generation config. If a logit processor is passed that is already created with the arguments or a
            generation config an error is thrown. This feature is intended for advanced users.
        stopping_criteria (`StoppingCriteriaList`, *optional*):
            Custom stopping criteria that complement the default stopping criteria built from arguments and a
            generation config. If a stopping criteria is passed that is already created with the arguments or a
            generation config an error is thrown. This feature is intended for advanced users.
        prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
            If provided, this function constrains the beam search to allowed tokens only at each step. If not
            provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
            `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
            on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
            for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
            Retrieval](https://arxiv.org/abs/2010.00904).
        synced_gpus (`bool`, *optional*):
            Whether to continue running the while loop until max_length.
            Unless overridden this flag will be set to `True` under DeepSpeed ZeRO Stage 3 multiple GPUs
            environment to avoid hanging if one GPU finished generating before other GPUs. Otherwise it'll be set
            to `False`.
        assistant_model (`PreTrainedModel`, *optional*):
            An assistant model that can be used to accelerate generation. The assistant model must have the exact
            same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant
            model is much faster than running generation with the model you're calling generate from. As such, the
            assistant model should be much smaller.
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        kwargs:
            Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
            forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
            specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

    Return:
        [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
        or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible [`~utils.ModelOutput`] types are: - [`~generation.GreedySearchDecoderOnlyOutput`], - [`~generation.SampleDecoderOnlyOutput`], - [`~generation.BeamSearchDecoderOnlyOutput`], - [`~generation.BeamSampleDecoderOnlyOutput`] If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible [`~utils.ModelOutput`] types are: - [`~generation.GreedySearchEncoderDecoderOutput`], - [`~generation.SampleEncoderDecoderOutput`], - [`~generation.BeamSearchEncoderDecoderOutput`], - [`~generation.BeamSampleEncoderDecoderOutput`] """ if synced_gpus is None: if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1: synced_gpus = True else: synced_gpus = False # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() # priority: `generation_config` argument > `model.generation_config` (the default generation config) if generation_config is None: # legacy: users may modify the model configuration to control generation -- update the generation config # model attribute accordingly, if it was created from the model config if self.generation_config._from_model_config: new_generation_config = GenerationConfig.from_model_config(self.config) if new_generation_config != self.generation_config: warnings.warn( "You have modified the pretrained model configuration to control generation. This is a" " deprecated strategy to control generation and will be removed soon, in a future version." 
" Please use a generation configuration file (see" " https://huggingface.co/docs/transformers/main_classes/text_generation)" ) self.generation_config = new_generation_config generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: if model_kwargs.get("attention_mask", None) is None: logger.warning( "The attention mask and the pad token id were not set. As a consequence, you may observe " "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." ) eos_token_id = generation_config.eos_token_id if isinstance(eos_token_id, list): eos_token_id = eos_token_id[0] logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") generation_config.pad_token_id = eos_token_id # 3. Define model inputs # inputs_tensor has to be defined # model_input_name is defined if model-specific keyword input is passed # otherwise model_input_name is None # all model-specific keyword inputs are removed from `model_kwargs` inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] # 4. 
Define other model kwargs model_kwargs["output_attentions"] = generation_config.output_attentions model_kwargs["output_hidden_states"] = generation_config.output_hidden_states model_kwargs["use_cache"] = generation_config.use_cache accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id ) # decoder-only models should use left-padding for generation if not self.config.is_encoder_decoder: if ( generation_config.pad_token_id is not None and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 ): logger.warning( "A decoder-only architecture is being used, but right-padding was detected! For correct " "generation results, please set `padding_side='left'` when initializing the tokenizer." ) if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: # if model is encoder decoder encoder_outputs are created # and added to `model_kwargs` model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( inputs_tensor, model_kwargs, model_input_name ) # 5. Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( batch_size=batch_size, model_input_name=model_input_name, model_kwargs=model_kwargs, decoder_start_token_id=generation_config.decoder_start_token_id, bos_token_id=generation_config.bos_token_id, device=inputs_tensor.device, ) else: input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") if streamer is not None: streamer.put(input_ids.cpu()) # 6. Prepare `max_length` depending on other stopping criteria. 
input_ids_seq_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" " recommend using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) elif generation_config.max_new_tokens is not None: if not has_default_max_length: logger.warning( f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " "Please refer to the documentation for more information. " "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" f" the maximum length ({generation_config.max_length})" ) if input_ids_seq_length >= generation_config.max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" logger.warning( f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" " increasing `max_new_tokens`." ) # 7. 
determine generation mode is_constraint_gen_mode = ( generation_config.constraints is not None or generation_config.force_words_ids is not None ) is_contrastive_search_gen_mode = ( (generation_config.num_beams == 1) and generation_config.top_k is not None and generation_config.top_k > 1 and generation_config.do_sample is False and generation_config.penalty_alpha is not None and generation_config.penalty_alpha > 0 ) is_greedy_gen_mode = ( (generation_config.num_beams == 1) and (generation_config.num_beam_groups == 1) and generation_config.do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_sample_gen_mode = ( (generation_config.num_beams == 1) and (generation_config.num_beam_groups == 1) and generation_config.do_sample is True and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_beam_gen_mode = ( (generation_config.num_beams > 1) and (generation_config.num_beam_groups == 1) and generation_config.do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_beam_sample_gen_mode = ( (generation_config.num_beams > 1) and (generation_config.num_beam_groups == 1) and generation_config.do_sample is True and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_group_beam_gen_mode = ( (generation_config.num_beams > 1) and (generation_config.num_beam_groups > 1) and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_assisted_gen_mode = False if assistant_model is not None: if not (is_greedy_gen_mode or is_sample_gen_mode): raise ValueError( "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate " "is only supported with Greedy Search and Sample." 
) is_assisted_gen_mode = True if generation_config.num_beam_groups > generation_config.num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") if is_group_beam_gen_mode and generation_config.do_sample is True: raise ValueError( "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." ) if streamer is not None and (generation_config.num_beams > 1): raise ValueError( "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." ) if self.device.type != input_ids.device.type: warnings.warn( "You are calling .generate() with the `input_ids` being on a device type different" f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." " Please make sure that you have put `input_ids` to the" f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" " running `.generate()`.", UserWarning, ) # 8. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_seq_length, encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, logits_processor=logits_processor, ) # 9. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria ) # 10. go into different generation modes if is_assisted_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( "num_return_sequences has to be 1 when doing assisted generate, " f"but is {generation_config.num_return_sequences}." ) if batch_size > 1: raise ValueError("assisted generate is only supported for batch_size = 1") if not model_kwargs["use_cache"]: raise ValueError("assisted generate requires `use_cache=True`") # 11. 
If the assistant model is an encoder-decoder, prepare its encoder outputs if assistant_model.config.is_encoder_decoder: assistant_model_kwargs = copy.deepcopy(model_kwargs) inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs( inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs ) assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( inputs_tensor, assistant_model_kwargs, model_input_name ) model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"] # 12. run assisted generate return self.assisted_decoding( input_ids, assistant_model=assistant_model, do_sample=generation_config.do_sample, logits_processor=logits_processor, logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None, stopping_criteria=stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, **model_kwargs, ) if is_greedy_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( "num_return_sequences has to be 1 when doing greedy search, " f"but is {generation_config.num_return_sequences}." ) # 11. run greedy search return self.greedy_search( input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, **model_kwargs, ) elif is_contrastive_search_gen_mode: if generation_config.num_return_sequences > 1: raise ValueError( "num_return_sequences has to be 1 when doing contrastive search, " f"but is {generation_config.num_return_sequences}." 
) if not model_kwargs["use_cache"]: raise ValueError("Contrastive search requires `use_cache=True`") return self.contrastive_search( input_ids, top_k=generation_config.top_k, penalty_alpha=generation_config.penalty_alpha, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, **model_kwargs, ) elif is_sample_gen_mode: # 11. prepare logits warper logits_warper = self._get_logits_warper(generation_config) # 12. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 13. run sample return sample( self, input_ids, logits_processor=logits_processor, logits_warper=logits_warper, stopping_criteria=stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, **model_kwargs, ) elif is_beam_gen_mode: if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") # 11. 
prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=generation_config.num_beams, device=inputs_tensor.device, length_penalty=generation_config.length_penalty, do_early_stopping=generation_config.early_stopping, num_beam_hyps_to_keep=generation_config.num_return_sequences, max_length=generation_config.max_length, ) # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 13. run beam search return self.beam_search( input_ids, beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_beam_sample_gen_mode: # 11. prepare logits warper logits_warper = self._get_logits_warper(generation_config) if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") # 12. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size * generation_config.num_return_sequences, num_beams=generation_config.num_beams, device=inputs_tensor.device, length_penalty=generation_config.length_penalty, do_early_stopping=generation_config.early_stopping, max_length=generation_config.max_length, ) # 13. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams * generation_config.num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 14. 
run beam sample return self.beam_sample( input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, stopping_criteria=stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_group_beam_gen_mode: if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if generation_config.num_beams % generation_config.num_beam_groups != 0: raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 if not has_default_typical_p: raise ValueError("Decoder argument `typical_p` is not supported with beam groups.") # 11. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=generation_config.num_beams, device=inputs_tensor.device, length_penalty=generation_config.length_penalty, do_early_stopping=generation_config.early_stopping, num_beam_hyps_to_keep=generation_config.num_return_sequences, num_beam_groups=generation_config.num_beam_groups, max_length=generation_config.max_length, ) # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 13. 
run beam search return self.group_beam_search( input_ids, beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_constraint_gen_mode: if generation_config.num_return_sequences > generation_config.num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") if generation_config.num_beams <= 1: raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.") if generation_config.do_sample: raise ValueError("`do_sample` needs to be false for constrained generation.") if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1: raise ValueError("`num_beam_groups` not supported yet for constrained generation.") final_constraints = [] if generation_config.constraints is not None: final_constraints = generation_config.constraints if generation_config.force_words_ids is not None: def typeerror(): raise ValueError( "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" f"of positive integers, but is {generation_config.force_words_ids}." 
) if ( not isinstance(generation_config.force_words_ids, list) or len(generation_config.force_words_ids) == 0 ): typeerror() for word_ids in generation_config.force_words_ids: if isinstance(word_ids[0], list): if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() if any(not isinstance(token_ids, list) for token_ids in word_ids): typeerror() if any( any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) for token_ids in word_ids ): typeerror() constraint = DisjunctiveConstraint(word_ids) else: if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): typeerror() constraint = PhrasalConstraint(word_ids) final_constraints.append(constraint) # 11. prepare beam search scorer constrained_beam_scorer = ConstrainedBeamSearchScorer( constraints=final_constraints, batch_size=batch_size, num_beams=generation_config.num_beams, device=inputs_tensor.device, length_penalty=generation_config.length_penalty, do_early_stopping=generation_config.early_stopping, num_beam_hyps_to_keep=generation_config.num_return_sequences, max_length=generation_config.max_length, ) # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 13. 
run beam search return self.constrained_beam_search( input_ids, constrained_beam_scorer=constrained_beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) def sample( self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[Union[int, List[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, **model_kwargs, ) -> Union[SampleOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. 
List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. logits_warper (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used to warp the prediction score distribution of the language modeling head applied before multinomial sampling at each generation step. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`Union[int, List[int]]`, *optional*): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. 
If the model is an encoder-decoder model, the kwargs should include `encoder_outputs`. Return: [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: ```python >>> from transformers import ( ... AutoTokenizer, ... AutoModelForCausalLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, ... TopKLogitsWarper, ... TemperatureLogitsWarper, ... StoppingCriteriaList, ... MaxLengthCriteria, ... ) >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token >>> model.config.pad_token_id = model.config.eos_token_id >>> model.generation_config.pad_token_id = model.config.eos_token_id >>> input_prompt = "Today is a beautiful day, and" >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList( ... [ ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), ... ] ... ) >>> # instantiate logits warpers >>> logits_warper = LogitsProcessorList( ... [ ... TopKLogitsWarper(50), ... TemperatureLogitsWarper(0.7), ... ] ... ) >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT >>> outputs = model.sample( ... input_ids, ... logits_processor=logits_processor, ... logits_warper=logits_warper, ... stopping_criteria=stopping_criteria, ... 
) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Today is a beautiful day, and we must do everything possible to make it a day of celebration.'] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model 
is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) this_peer_finished = False # used by synced_gpus only # auto-regressive generation while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? 
the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need next_token_logits = outputs.logits[:, -1, :] # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) next_token_scores = logits_warper(input_ids, next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: # scores += (next_token_scores,) scores += (next_token_logits,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # sample probs = nn.functional.softmax(next_token_scores, dim=-1) if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)): probs = torch.nan_to_num(probs) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, 
is_encoder_decoder=self.config.is_encoder_decoder ) # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) # stop when each sentence is finished if unfinished_sequences.max() == 0: this_peer_finished = True # stop if we exceed the maximum length if stopping_criteria(input_ids, scores): this_peer_finished = True if this_peer_finished and not synced_gpus: break if streamer is not None: streamer.end() if return_dict_in_generate: if self.config.is_encoder_decoder: return SampleEncoderDecoderOutput( sequences=input_ids, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return SampleDecoderOnlyOutput( sequences=input_ids, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return input_ids ================================================ FILE: scripts/eval.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import argparse import json import logging import os import numpy as np import torch from accelerate import Accelerator from accelerate.utils import set_seed from datasets import load_dataset from torch.utils.data import DataLoader, SequentialSampler from tqdm import tqdm from transformers import ( AutoTokenizer, AutoModelForCausalLM ) import custom_generate from eval_metric import compute_metric_stmt from eval_utils import compute_mean_logp logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger = logging.getLogger(__name__) COMMENT_SYMBOL = { "python": "#", "java": "//", "csharp": "//", "typescript": "//" } def custom_data_collator(features): first = features[0] batch = {} for k, v in first.items(): if v is not None and not isinstance(v, str): if isinstance(v, torch.Tensor): batch[k] = torch.stack([f[k] for f in features]) elif isinstance(v, np.ndarray): batch[k] = torch.tensor(np.stack([f[k] for f in features])) else: batch[k] = torch.tensor([f[k] for f in features]) if v is not None and isinstance(v, str): batch[k] = [f[k] for f in features] return batch def build_datasets(args, tokenizer): # Initialize the model and tokenizer # when generating, we will use the logits of right-most token to predict the next token # so the padding should be on the left tokenizer.padding_side = "left" tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else tokenizer.bos_token # load the files into Dataset raw_datasets = load_dataset("json", data_files=args.prompt_file, cache_dir=args.cache_dir) raw_datasets = raw_datasets["train"] raw_datasets = raw_datasets.map(lambda example, idx: {'index': idx, **example}, with_indices=True) index2taskid = {idx: md["task_id"] for idx, md in zip(raw_datasets["index"], raw_datasets["metadata"])} column_names = raw_datasets.column_names # Prompt composition def prepare_features(examples): tokenizer.truncation_side = "left" tokenized_inputs = 
tokenizer( examples["prompt"], padding="max_length", truncation=True, max_length=args.max_seq_length - args.gen_length ) features = {k: t for k, t in tokenized_inputs.items()} features["index"] = examples["index"] return features def prepare_features_cfc(examples): max_prompt_length = args.max_seq_length - args.gen_length use_key = "list" crossfile_context = [] if use_key == "text": crossfile_context = [ex["text"] for ex in examples["crossfile_context"]] else: ls_sym = COMMENT_SYMBOL[args.language] num_chunk_inc_prompt = [] augmented_prompt = 0 for cfc_chunks in examples["crossfile_context"]: cfc_chunks = cfc_chunks["list"] # a list of dict cfc_text = "" if cfc_chunks: # at least 1 relevant cfc_chunk found init_cfc_text = f"{ls_sym} Here are some relevant code fragments from other files of the repo:\n\n" cfc_length = len(tokenizer.tokenize(init_cfc_text)) num_chunk_inc = 0 for cfc_idx, cfc_chunk in enumerate(cfc_chunks): if cfc_chunk["score"] > args.min_cfc_score: add_text = f"{ls_sym} the below code fragment is found in {cfc_chunk['filename']}" + "\n" cfc_lines = cfc_chunk["retrieved_chunk"].split('\n') add_text += "\n".join([f"{ls_sym} {cl}" for cl in cfc_lines if cl]) + "\n\n" # check if adding chunk exceeds max length budget for CFC add_text_len = len(tokenizer.tokenize(add_text)) if cfc_length + add_text_len <= args.cfc_seq_length: cfc_text += add_text cfc_length += add_text_len num_chunk_inc += 1 else: break num_chunk_inc_prompt.append(num_chunk_inc) if num_chunk_inc > 0: cfc_text = init_cfc_text + cfc_text augmented_prompt += 1 crossfile_context.append(cfc_text) logger.info( f"{augmented_prompt} out of {len(examples['crossfile_context'])} prompts are augmented with cross-file context.") tokenizer.truncation_side = "right" crossfile_features = tokenizer( crossfile_context, truncation=True, max_length=args.cfc_seq_length ) features = {"input_ids": [], "attention_mask": []} tokenizer.truncation_side = "left" for idx, prompt in enumerate(examples["prompt"]): 
allowed_prompt_length = max_prompt_length - len(crossfile_features["input_ids"][idx]) prompt_feats = tokenizer( [prompt], truncation=True, max_length=allowed_prompt_length ) for k, v in prompt_feats.items(): features[k].append(crossfile_features[k][idx] + prompt_feats[k][0]) # left-pad to the prompt budget (max_seq_length - gen_length) tokenizer.padding_side = "left" features = tokenizer.pad(features, padding="max_length", max_length=args.max_seq_length - args.gen_length) features["index"] = examples["index"] return features if args.model_type in ["codelm", "seq2seqlm"]: tokenized_datasets = raw_datasets.map( prepare_features, batched=True, num_proc=args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not args.overwrite_cache, desc="Running tokenizer on dataset", ) elif args.model_type == "codelm_cfc": tokenized_datasets = raw_datasets.map( prepare_features_cfc, batched=True, num_proc=args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not args.overwrite_cache, desc="Running tokenizer on dataset", ) else: raise NotImplementedError("prepare feature functions not implemented for new model type") return tokenized_datasets, index2taskid def model_inference(tokenized_datasets, index2taskid, tokenizer): if args.dtype == 'fp16': dtype = torch.float16 elif args.dtype == 'fp32': dtype = torch.float32 elif args.dtype == 'bf16': dtype = torch.bfloat16 elif args.dtype == 'int8': dtype = torch.int8 else: raise NotImplementedError(f'{args.dtype=} not implemented') if args.model_type in ["codelm", "codelm_cfc"]: model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, torch_dtype=dtype, trust_remote_code=True, revision="main" ) else: raise ValueError("Unknown model type") total_samples_cnt = len(tokenized_datasets) logger.info(f"total samples: {total_samples_cnt}") data_sampler = SequentialSampler(tokenized_datasets) dataloader = DataLoader( tokenized_datasets, sampler=data_sampler, collate_fn=custom_data_collator, batch_size=args.batch_size ) model =
accelerator.prepare_model(model) dataloader = accelerator.prepare_data_loader(dataloader) os.makedirs(args.output_dir, exist_ok=True) tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else tokenizer.bos_token prompt_length = args.max_seq_length - args.gen_length @torch.no_grad() def generate_completions(batch): output_dict = custom_generate.generate( accelerator.unwrap_model(model), input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], max_length=args.max_seq_length, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, do_sample=args.do_sample, num_beams=args.num_beams, num_return_sequences=1, pad_token_id=tokenizer.pad_token_id, return_dict_in_generate=True, output_scores=True ) batch_task_id = batch["index"] batch_pred = accelerator.pad_across_processes( output_dict.sequences, dim=1, pad_index=tokenizer.pad_token_id ) scores = torch.stack(output_dict.scores, dim=1) batch_scores = accelerator.pad_across_processes( scores, dim=1, pad_index=tokenizer.pad_token_id ) # batch_scores.shape = (batch_size x num_gpus x num_return_sequences, max_length) batch_task_id, batch_pred, batch_scores = accelerator.gather((batch_task_id, batch_pred, batch_scores)) batch_pred = batch_pred[:, prompt_length:] generated_texts = tokenizer.batch_decode(batch_pred, skip_special_tokens=True) mean_logp = compute_mean_logp(batch_scores, batch_pred, tokenizer.pad_token_id) return batch_task_id.tolist(), generated_texts, mean_logp all_preds = [] all_task_ids = [] with torch.no_grad(): for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)): completions = None completion_scores = None for seq_idx in range(args.num_return_sequences): batch_task_id, generated_texts, mean_logp = generate_completions(batch) if seq_idx == 0: all_task_ids.extend(batch_task_id) batch_size = len(batch_task_id) completions = [[] for _ in range(batch_size)] completion_scores = [[] for _ in range(batch_size)] for j in range(batch_size):
completions[j].append(generated_texts[j]) completion_scores[j].append(mean_logp[j]) if args.num_return_sequences == 1: all_preds.extend([c[0] for c in completions]) else: for c, cs in zip(completions, completion_scores): max_score = max(cs) max_index = cs.index(max_score) all_preds.append(c[max_index]) with open(f"{args.output_dir}/prediction.jsonl", "w", encoding="utf-8") as f_pred: id_processed = set() for idx, p in zip(all_task_ids, all_preds): if index2taskid[idx] not in id_processed: f_pred.write(json.dumps({"task_id": index2taskid[idx], "pred": p}) + "\n") id_processed.add(index2taskid[idx]) if __name__ == "__main__": parser = argparse.ArgumentParser() # model inference args parser.add_argument("--language", type=str, required=True, help="language name") parser.add_argument("--model_name_or_path", default=None, type=str, help="Pre-trained Model Path") parser.add_argument( "--model_type", type=str, default="codelm", choices=["codelm", "codelm_cfc"], help="Model type to be loaded" ) parser.add_argument("--prompt_file", type=str, default=None, help="file with a list of prompts") parser.add_argument("--gen_length", type=int, default=50, help="max length of generated token sequence") parser.add_argument("--max_seq_length", type=int, default=2048, help="max total sequence length (prompt plus generated tokens)") parser.add_argument( "--cfc_seq_length", type=int, default=512, help="For model_type=codelm_cfc: Text sequence length corresponding to the retrieved nodes" ) parser.add_argument( "--min_cfc_score", type=float, default=float('-inf'), help="For model_type=codelm_cfc: min score of a chunk to be considered as CFC chunk" ) parser.add_argument("--batch_size", type=int, default=32, help="batch size for code completion") parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") parser.add_argument("--cache_dir", type=str, default=None) parser.add_argument( "--temperature", type=float, default=0.2, help="temperature of 1.0 has no effect; lower values tend toward greedy decoding" )
parser.add_argument("--output_dir", type=str, default="output_dir", help="output directory to save predictions") parser.add_argument("--top_k", type=int, default=0) parser.add_argument("--top_p", type=float, default=0.95) parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") parser.add_argument("--repetition_penalty", type=float, default=1.0, help="The parameter for repetition penalty.") parser.add_argument( "--preprocessing_num_workers", type=int, default=1, help="The number of processes to use for the preprocessing." ) # argparse's type=bool treats any non-empty string (including "False") as True, so parse the value explicitly parser.add_argument( "--overwrite_cache", type=lambda s: s.lower() in ("true", "1", "yes"), default=False, help="Overwrite the cached training and evaluation sets (true/false)" ) parser.add_argument("--dtype", type=str, default='bf16') parser.add_argument("--do_sample", action="store_true", help="whether we do sampling or greedy/beam-search") parser.add_argument("--num_beams", type=int, default=1, help="number of beams for beam search") # compute metric args parser.add_argument( "--ts_lib", type=str, default="build/python-lang-parser.so", help="tree-sitter library used to tokenize code" ) # only compute metric parser.add_argument("--only_compute_metric", action="store_true", help="only compute metric") args = parser.parse_args() set_seed(args.seed, device_specific=False) if args.num_return_sequences > 1: assert args.do_sample, "sampling must be set to True when num_return_sequences > 1" accelerator = Accelerator() if not args.only_compute_metric: tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True) tokenized_datasets, index2taskid = build_datasets(args, tokenizer) model_inference(tokenized_datasets, index2taskid, tokenizer) # check if the process is the main process if accelerator.is_main_process:
compute_metric_stmt(args) ================================================ FILE: scripts/eval_metric.py ================================================ import json from functools import partial import torch.multiprocessing as mp from tqdm import tqdm from tree_sitter import Language, Parser from eval_utils import ( postprocess_code_lines, extract_identifiers, cal_edit_sim, remove_comments ) import os parser = None def compute_id_match(pred_ids, target_ids): pred_ids = list(set(pred_ids)) target_ids = list(set(target_ids)) tp = 0 fp = 0 fn = 0 for pid in pred_ids: if pid in target_ids: tp += 1 else: fp += 1 for tid in target_ids: if tid not in pred_ids: fn += 1 return tp, fp, fn def compute_edit_sim(samples): refs, hyps = [], [] for s in samples: refs.append(s["target"]) hyps.append(s["pred"]) return cal_edit_sim(refs, hyps) def process_examples(lang, args): sample, ex = args global parser prediction = postprocess_code_lines(ex["prompt"], sample["pred"], parser, lang) prediction = remove_comments(prediction) target = ex["groundtruth"] target = remove_comments(target) pred_lines = [l.strip() for l in prediction.split("\n") if l.strip()] gt_lines = [l.strip() for l in target.split("\n") if l.strip()] em_label = int(pred_lines == gt_lines) pred_ids = extract_identifiers(prediction, lang) target_ids = extract_identifiers(target, lang) trunc_s = { "task_id": sample["task_id"], "pred": prediction, "target": target, "pred_ids": pred_ids, "target_ids": target_ids } return trunc_s, em_label def compute_metric_stmt(args): with open(os.path.join(args.output_dir, "prediction.jsonl"), "r") as f_pred: samples = [] for l in f_pred.readlines(): samples.append(json.loads(l)) examples = {} with open(args.prompt_file, "r") as f_in: for l in f_in.readlines(): ex = json.loads(l) examples[ex["metadata"]["task_id"]] = { "prompt": ex["prompt"], "groundtruth": ex["groundtruth"] } assert len(samples) == len(examples), f"{len(samples)} != {len(examples)}" global parser ts_lang = "c_sharp" if 
args.language == "csharp" else args.language language = Language(args.ts_lib, ts_lang) parser = Parser() parser.set_language(language) truncated_samples = [] em_labels = [] print("post-processing samples ...") pool = mp.Pool(mp.cpu_count() - 1) worker = partial(process_examples, args.language) with tqdm(total=len(samples)) as pbar: for output in pool.imap_unordered(worker, zip(samples, [examples[s["task_id"]] for s in samples])): trunc_s, em_label = output em_labels.append(em_label) truncated_samples.append(trunc_s) pbar.update() exact_match = 0 with open(os.path.join(args.output_dir, "prediction_truncated.jsonl"), 'w', encoding="utf-8") as pt, \ open(f"{args.output_dir}/exact_match_idx.jsonl", 'w') as em: for trunc_s, em_label in zip(truncated_samples, em_labels): pt.write(json.dumps(trunc_s) + "\n") if em_label == 1: em.write(f'{trunc_s["task_id"]}\n') exact_match += 1 ### Score calculation id_em = [] edit_similarities = [] detailed_results = [] for idx, trunc_s in enumerate(truncated_samples): identifier_em = int(trunc_s["pred_ids"] == trunc_s["target_ids"]) es = cal_edit_sim([trunc_s["target"]], [trunc_s["pred"]]) id_tp, id_fp, id_fn = compute_id_match(trunc_s["pred_ids"], trunc_s["target_ids"]) id_em.append(identifier_em) edit_similarities.append(es) detailed_results.append({ "task_id": trunc_s["task_id"], "em": em_labels[idx], "es": es, "id_em": identifier_em, "id_precision": id_tp / (id_tp + id_fp) if (id_tp + id_fp) != 0 else 0, "id_recall": id_tp / (id_tp + id_fn) if (id_tp + id_fn) != 0 else 0, "id_f1": 2 * id_tp / (2 * id_tp + id_fp + id_fn) if (2 * id_tp + id_fp + id_fn) != 0 else 0, }) em_ratio = round(exact_match / len(samples) * 100, 2) edit_sim = round(sum(edit_similarities) / len(edit_similarities), 2) id_em_ratio = round( sum(detailed_results[idx]['id_em'] for idx in range(len(detailed_results))) / len(detailed_results) * 100, 2) id_precision = round(sum(detailed_results[idx]['id_precision'] for idx in range(len(detailed_results))) / len( 
detailed_results) * 100, 2) id_recall = round( sum(detailed_results[idx]['id_recall'] for idx in range(len(detailed_results))) / len(detailed_results) * 100, 2) id_f1 = round( sum(detailed_results[idx]['id_f1'] for idx in range(len(detailed_results))) / len(detailed_results) * 100, 2) print( f"Code Matching: " f"EM {em_ratio:.2f}, " f"ES {edit_sim:.2f}" ) print( f"ID matching: " f"EM {id_em_ratio}, " #f"Precision {id_precision}, " #f"Recall {id_recall}, " f"F1 {id_f1}" ) with open(os.path.join(args.output_dir, "detailed_results.json"), 'w') as f: for dr in detailed_results: f.write(json.dumps(dr) + "\n") # write the results to a file print(f'writing results to {os.path.join(args.output_dir, "results.json")}') with open(os.path.join(args.output_dir, "results.json"), 'w') as f: res = { "em": em_ratio, "es": edit_sim, "id_em": id_em_ratio, "id_precision": id_precision, "id_recall": id_recall, "id_f1": id_f1, "total": len(truncated_samples) } f.write(json.dumps(res, indent=2)) ================================================ FILE: scripts/eval_utils.py ================================================ # Copyright Amazon.com, Inc. or its affiliates. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import ast import re from functools import lru_cache from typing import List import timeout_decorator import torch from fuzzywuzzy import fuzz from nltk.tokenize import RegexpTokenizer from sacrebleu.tokenizers.tokenizer_intl import TokenizerV14International from keywords.keywordlist import get_language_keywords IDENTIFIER_REGEX = re.compile('[_a-zA-Z][_a-zA-Z0-9]*') REGEX_TEXT = ("(?<=[a-z0-9])(?=[A-Z])|" "(?<=[A-Z0-9])(?=[A-Z][a-z])|" "(?<=[0-9])(?=[a-zA-Z])|" "(?<=[A-Za-z])(?=[0-9])|" "(?<=[@$.'\"])(?=[a-zA-Z0-9])|" "(?<=[a-zA-Z0-9])(?=[@$.'\"])|" "_|\\s+") string_pattern = r'"([^"\\]*(\\.[^"\\]*)*)"|\'([^\'\\]*(\\.[^\'\\]*)*)\'' SPLIT_REGEX = re.compile(REGEX_TEXT) str_tokenizer = TokenizerV14International() code_tokenizer = RegexpTokenizer(r'\w+') def cal_edit_sim(references, hypotheses): total = len(references) edit_sim = 0.0 for pred, gt in zip(hypotheses, references): pred = pred.strip() gt = gt.strip() edit_sim += fuzz.ratio(pred, gt) return edit_sim / total @lru_cache(maxsize=5000) def split_identifier_into_parts(identifier: str) -> List[str]: """ Split a single identifier into parts on snake_case and camelCase """ identifier_parts = list(s for s in SPLIT_REGEX.split(identifier) if len(s) > 0) if len(identifier_parts) == 0: return [identifier] if "_" in identifier: # We consider "_" as part of identifier and add it back in between each semantic part # if snake_case, we only split identifiers based on "_", ignore the mixed camelCase or other special symbols # this helps us avoid splitting identifiers like "get_2d_array" into ["get", "2", "d", "array"] # also avoid many other corner cases identifier_parts = identifier.split("_") tmp = [identifier_parts[0]] for i in identifier_parts[1:]: tmp.append("_") tmp.append(i) identifier_parts = tmp return identifier_parts def is_identifier(token, lang=None): return True if IDENTIFIER_REGEX.match(token) \ and (lang is None or token not in get_language_keywords(lang)) \ else False def extract_identifiers(source_code, 
lang):
    # the main idea is to remove String from a source code
    # then, tokenize the code to get all words and match with identifier regular expression
    # check if it is a language specific keyword; if not, then it is an identifier
    source_code_without_strings = re.sub(string_pattern, '', source_code)
    _ids = [t for t in code_tokenizer.tokenize(source_code_without_strings) if is_identifier(t, lang)]
    return _ids


def tokenize_string(input_str):
    return str_tokenizer(input_str)


def get_bracket_lang_statement(completion):
    # cut the completion at the first statement terminator or brace
    end_idx = None
    for i in range(len(completion)):
        if completion[i] in [";", "}", "{"]:
            end_idx = i
            break
    # NOTE: `if end_idx` is falsy for index 0, so a completion that starts
    # with a terminator is returned unchanged
    return completion[:end_idx + 1] if end_idx else completion


@timeout_decorator.timeout(5)
def get_ast(parser, code):
    assert isinstance(code, str) or isinstance(code, bytes)
    if isinstance(code, str):
        code = bytes(code, "utf8")
    try:
        tree = parser.parse(code)
        return tree
    except Exception as e:
        return None


def remove_comments(code):
    code = re.sub(r'#.*', '', code)
    code = re.sub(r'//.*', '', code)
    return code


def is_parse_valid(parser, code):
    def syntax_error(node):
        if node.type == "ERROR":
            return True
        try:
            for child in node.children:
                if syntax_error(child):
                    return True
        except RecursionError as err:
            return True

        return False

    tree = get_ast(parser, code)
    if tree is not None:
        return not syntax_error(tree.root_node)
    return False


def is_code_parseable(code):
    try:
        ast.parse(code)
        return True
    except SyntaxError:
        return False


def get_python_one_statement(prompt, completion, parser):
    for i in range(len(completion)):
        code = prompt + completion[:i + 1]
        if not is_parse_valid(parser, code):
            continue
        # stop at the first parseable prefix that is followed by a newline;
        # the bounds check avoids an IndexError on the last character
        if i + 1 < len(completion) and completion[i + 1] == "\n":
            return completion[:i + 1].rstrip()

    return completion


def postprocess_code_lines(prompt, completion, parser, lang):
    try:
        if lang in ["java", "csharp", "typescript"]:
            return get_bracket_lang_statement(completion)
        elif lang == "python":
            return get_python_one_statement(prompt, completion, parser)
    except Exception as e:
        # fall back to the raw completion if parsing fails or times out
        return completion


def
compute_mean_logp(scores, sequences, pad_token_id): assert scores.shape[0] == sequences.shape[0] assert scores.shape[1] == sequences.shape[1] with torch.no_grad(): logp_vocab = torch.nn.functional.log_softmax(scores, dim=-1) indices = torch.unsqueeze(sequences, dim=-1) logp = torch.gather(logp_vocab, dim=-1, index=indices).squeeze(-1) sum_logp = torch.cumsum(logp, dim=1) # batch_size, seq_len denom = torch.arange(1, sum_logp.shape[1] + 1).reshape(1, -1).to(device=sum_logp.device) # 1, seq_len mean_logp = (sum_logp / denom).tolist() # batch_size, seq_len sequence_lengths = (sequences != pad_token_id).sum(1).tolist() # batch_size mean_logp = [mean_logp[idx][l - 1] for idx, l in enumerate(sequence_lengths)] return mean_logp ================================================ FILE: scripts/keywords/__init__.py ================================================ ================================================ FILE: scripts/keywords/csharp.txt ================================================ abstract as base bool break byte case catch char checked class const continue decimal default delegate do double else enum event explicit extern finally fixed float for foreach goto if implicit in int interface internal is lock long namespace new null object operator out override params private protected public readonly ref return sbyte sealed short sizeof stackalloc static string struct switch this throw try typeof uint ulong unchecked unsafe ushort using using static virtual void volatile while ================================================ FILE: scripts/keywords/java.txt ================================================ abstract assert boolean break byte case catch char class continue default do double else enum extends final finally float for if implements import instanceof int interface long native new package private protected public return short static strictfp super switch synchronized this throw throws transient try void volatile while var const goto 
================================================ FILE: scripts/keywords/javascript.txt ================================================ break case catch class const continue debugger default delete do else export extends finally for function if import in instanceof new return super switch this throw try typeof var void while with yield enum implements interface let package private protected public static ================================================ FILE: scripts/keywords/keywordlist.py ================================================ # Original Copyright 2021 Microsoft under MIT License. # From https://github.com/microsoft/dpu-utils/blob/master/python/dpu_utils/codeutils/keywords/keywordlist.py import os import keyword from functools import lru_cache from typing import FrozenSet __all__ = ['get_language_keywords'] _LANGUAGE_TO_FILENAME = { 'c': 'c.txt', 'cpp': 'cpp.txt', 'c++': 'cpp.txt', 'csharp': 'csharp.txt', 'c_sharp': 'csharp.txt', 'c#': 'csharp.txt', 'go': 'go.txt', 'java': 'java.txt', 'javascript': 'javascript.txt', 'js': 'javascript.txt', 'php': 'php.txt', 'ruby': 'ruby.txt', 'typescript': 'typescript.txt', 'ts': 'typescript.txt', } @lru_cache() def get_language_keywords(language: str) -> FrozenSet[str]: """ Returns the keywords of a programming language. There are some inconsistencies across languages wrt to what is considered a keyword. For example, the true/false literals are considered keywords in many languages. However, we exclude them here for consistency. We also exclude special functions-like keywords, such as `die()` in PHP. """ language = language.lower() if language == 'python': return frozenset(k for k in keyword.kwlist if k != 'True' and k != 'False') elif language in _LANGUAGE_TO_FILENAME: name = _LANGUAGE_TO_FILENAME[language] with open(os.path.join(os.path.dirname(__file__), name)) as f: return frozenset(l.strip() for l in f if len(l.strip()) > 0) else: raise Exception('Language keywords `%s` not supported yet. 
Consider contributing it to dpu-utils.' % language)

================================================
FILE: scripts/keywords/typescript.txt
================================================
break
case
catch
class
const
continue
debugger
default
delete
do
else
export
extends
finally
for
function
if
import
in
instanceof
new
return
super
switch
this
throw
try
typeof
var
void
while
with
yield
enum
implements
interface
let
package
private
protected
public
static

================================================
FILE: scripts/openai_inference.py
================================================
"""
Script to query an OpenAI API to generate code.
Set environment variable OPENAI_API_KEY with your API key before running this script.
"""
import argparse
import json
import os
import time
from typing import Dict, List, Tuple

import numpy as np
import openai
import tiktoken
from openai import OpenAI
from openai.types.chat import ChatCompletion
from tqdm import tqdm

SLEEP_SECOND = 2.8  # minimum time to sleep with API errors
MAX_SLEEP_SECOND = 120  # maximum time to sleep when retrying with exponential backoff
BUFFER = 100  # estimated tokens used by OpenAI + some more buffer
SYS_PROMPT = 'You are Codex, a code completion language model. Continue the code presented to you.'

openai_api_key = os.environ.get("OPENAI_API_KEY")
assert openai_api_key is not None, "Please set the OPENAI_API_KEY environment variable with your API key"
client = OpenAI()


def query(
        args,
        prompt: str,
) -> ChatCompletion:
    """
    This function queries an OpenAI API to generate code based on the given prompt.

    Args:
        args: parsed command-line arguments, which provide:
            model: str, the name of the model to query
            temperature: float, the value used to modulate the next token probabilities
            generation_max_tokens: int, the maximum number of tokens to generate
            top_p: float, the cumulative probability for top-p filtering
        prompt: str, the prompt to generate code from

    Returns:
        ChatCompletion object, the response from the OpenAI chat completions API
    """
    return client.chat.completions.create(model=args.model,
                                          messages=[
                                              {"role": "system", "content": SYS_PROMPT},
                                              {"role": "user", "content": prompt}
                                          ],
                                          temperature=args.temperature,
                                          max_tokens=args.generation_max_tokens,
                                          top_p=args.top_p,
                                          )


def query_with_retry(
        args,
        prompt: str,
) -> ChatCompletion | None:
    """
    This function queries an OpenAI API to generate code based on the given prompt,
    retrying with exponential backoff whenever the API raises an error.

    Args:
        args: parsed command-line arguments (see `query` for the fields that are used)
        prompt: str, the prompt to generate code from

    Returns:
        ChatCompletion object, the response from the OpenAI chat completions API.
        Rate-limit and other OpenAI errors are retried until a request succeeds.

    Reference:
        https://github.com/Leolty/repobench/blob/c24b7a80465957e75107eafd23c66d369fa9e755/model/codex.py
    """
    error_sleep_second = SLEEP_SECOND

    def _upd_error_sleep_time(error_sleep_second):
        # double the sleep time, capped at MAX_SLEEP_SECOND seconds
        return min(error_sleep_second * 2, MAX_SLEEP_SECOND)

    while True:
        try:
            response = query(args, prompt)
            time.sleep(SLEEP_SECOND + np.random.rand())
            return response
        except openai.RateLimitError as e:
            print(f'RateLimitError: {e}')
            print(f'Retrying after {error_sleep_second} seconds')
            time.sleep(error_sleep_second)
            error_sleep_second = _upd_error_sleep_time(error_sleep_second)
        except openai.OpenAIError as e:
            print(f'OpenAIError: {e}')
            print(f'Retrying after {error_sleep_second} seconds')
            time.sleep(error_sleep_second)
            error_sleep_second = _upd_error_sleep_time(error_sleep_second)


def truncate(prompt: str, max_num_tokens: int, tokenizer, side: str) -> str:
    """Truncate prompt from side given the token budget"""
    # use the tiktoken tokenizer to count the number of tokens
    tokens = tokenizer.encode(prompt, disallowed_special=())
    num_tokens = len(tokens)

    if num_tokens > max_num_tokens:
        if side == 'left':
            prompt_tokens = tokens[num_tokens - max_num_tokens:]
        elif side == 'right':
            prompt_tokens = tokens[:max_num_tokens]
        else:
            assert False, 'Invalid side'
        # decode and encode again as a sanity check
        prompt = tokenizer.decode(prompt_tokens)
        new_len = len(tokenizer.encode(prompt, disallowed_special=()))
        assert new_len <= max_num_tokens
    return prompt


def prepare_prompt(
        prompt: str,
        cross_file_context: str,
        cross_file_budget: int,
        prompt_budget: int,
        tokenizer
) -> str:
    """Create an augmented prompt according to budget specs"""
    # left truncate original prompt
    prompt = truncate(prompt, prompt_budget, tokenizer, 'left')

    if cross_file_context is not None:
        # right truncate cross file context string
        cross_file_context = truncate(cross_file_context, cross_file_budget, tokenizer, 'right')
    else:
        cross_file_context = ''
    # return the cross file context and the prompt joined by a newline
    return cross_file_context + '\n' + prompt


def get_openai_response(
        sample: Dict,
        tokenizer,
        args
) -> Tuple[str, Dict]:
    """Get OpenAI response for a single sample.
    Returns the prompt used to infer and the response of the API."""
    if args.use_crossfile_context:
        prompt = prepare_prompt(
            sample['prompt'],
            sample['crossfile_context']['text'],
            args.crossfile_max_tokens,
            args.model_max_tokens - args.generation_max_tokens - args.crossfile_max_tokens - BUFFER,
            tokenizer
        )
    else:
        prompt = prepare_prompt(
            sample['prompt'],
            None,
            0,
            args.model_max_tokens - args.generation_max_tokens - BUFFER,
            tokenizer
        )
    response = query_with_retry(args, prompt)
    return prompt, response


def get_openai_responses(
        args,
        data,
        out_path
) -> List[str]:
    """Get OpenAI responses to all samples in data, store in out_path,
    and return list of task ids that were skipped due to some errors"""
    tokenizer = tiktoken.encoding_for_model(args.model)
    skipped = []
    with open(out_path, 'w') as f:
        for d in tqdm(data):
            try:
                prompt, response = get_openai_response(
                    d,
                    tokenizer,
                    args
                )
            except Exception as e:
                print('Unknown error', e)
                raise
            if response is not None:
                d['pred_raw'] = response.choices[0].message.content
                # key compatible with eval script
                d['pred'] = '\n'.join(d['pred_raw'].split('\n')[1:]).strip('`') if d['pred_raw'].startswith('```') else d['pred_raw']  # newer chatgpt may output ```[lang_tag]``` at beginning
                # d['api_response'] = str(response)
                d['prompt_used'] = prompt  # records the augmented prompt
                d['task_id'] = d['metadata']['task_id']  # adding for compatibility with eval script
                print(json.dumps(d), file=f, flush=True)
            else:
                skipped.append(d['metadata']['task_id'])
                print(f'Skipped {d["metadata"]["task_id"]}')

    return skipped


def main():
    # get config for current run
    parser = argparse.ArgumentParser()
    parser.add_argument('--temperature', type=float, default=0.2)
    parser.add_argument('--top_p', type=float, default=0.95)
    parser.add_argument(
        '--task',
        type=str,
        required=True,
    )
    parser.add_argument(
        '--language',
        type=str,
        required=True,
        choices=['csharp', 'python', 'java', 'typescript']
    )
    parser.add_argument(
        '--data_root_dir',
        type=str,
        default='data/',
        help='path to
directory where data is organized in lang/task.jsonl format'
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        required=True,
        help='path to directory where to store outputs'
    )
    parser.add_argument(
        '--model',
        type=str,
        required=True,
        help='OpenAI-supported model'
    )
    parser.add_argument(
        '--model_max_tokens',
        type=int,
        default=16384,
        help='maximum number of tokens of the model'
    )
    parser.add_argument(
        '--crossfile_max_tokens',
        type=int,
        default=12800,
        help='maximum number of tokens for cross file context'
    )
    parser.add_argument(
        '--use_crossfile_context',
        action='store_true',
        help='whether to use cross file context'
    )
    parser.add_argument(
        '--generation_max_tokens',
        type=int,
        default=50,
        help='maximum number of tokens to generate'
    )
    args = parser.parse_args()
    print(json.dumps(vars(args), indent=4))

    # setup paths
    if not os.path.isdir(args.output_dir):
        print(f'==== Output dir does not exist. Creating: {args.output_dir} ====')
        os.makedirs(args.output_dir)
    data_path = os.path.join(args.data_root_dir, args.language, args.task + '.jsonl')
    data = [json.loads(l) for l in open(data_path, 'r').readlines()]
    out_path = os.path.join(args.output_dir, 'prediction.jsonl')

    # start OpenAI inference
    skipped_tasks = get_openai_responses(
        args,
        data,
        out_path
    )

    # save list of skipped tasks
    with open(out_path.replace('.jsonl', '_skipped_tasks.json'), 'w') as f:
        f.write(json.dumps(skipped_tasks))


if __name__ == '__main__':
    main()

================================================
FILE: scripts/vllm_inference.py
================================================
"""
Script to run vllm-based inference. See README for an example.
"""
import argparse
import json
import os

from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.utils import logging
from vllm import LLM, SamplingParams

logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# add a small buffer to take care of non-lossless tokenizers
BUFFER = 100


def truncate(prompt: str, max_num_tokens: int, side: str, tokenizer) -> str:
    """Truncate prompt from side given the token budget"""

    tokens = tokenizer.tokenize(prompt)
    num_tokens = len(tokens)

    if num_tokens > max_num_tokens:
        if side == 'left':
            prompt_tokens = tokens[num_tokens - max_num_tokens:]
        elif side == 'right':
            prompt_tokens = tokens[:max_num_tokens]
        else:
            # fail loudly instead of hitting an unbound prompt_tokens below
            raise ValueError(f'Invalid side: {side}')
        prompt = tokenizer.convert_tokens_to_string(prompt_tokens)
        new_len = len(tokenizer.tokenize(prompt))
        if new_len > max_num_tokens:
            logger.warning(
                f'Number of tokens after truncation is greater than max tokens allowed: {new_len=} {num_tokens=}')
    return prompt


def prepare_prompt(
        prompt: str,
        cross_file_context: str,
        cross_file_budget: int,
        prompt_budget: int,
        tokenizer
) -> str:
    """Create an augmented prompt according to budget specs"""
    # print(f'{cross_file_budget=} {prompt_budget=}')
    # left truncate original prompt
    prompt = truncate(prompt, prompt_budget, 'left', tokenizer)

    if cross_file_context is not None:
        # right truncate cross file context string
        cross_file_context = truncate(cross_file_context, cross_file_budget, 'right', tokenizer)
    else:
        cross_file_context = ''

    return cross_file_context + '\n' + prompt


def cceval_generate(
        args,
        data,
        tokenizer,
        sampling_params,
        llm
) -> None:
    """Generate one completion per sample and write them to prediction.jsonl in args.output_dir"""
    prompts = []
    for d in data:
        if args.use_crossfile_context:
            prompt = prepare_prompt(
                d['prompt'],
                d['crossfile_context']['text'],
                args.crossfile_max_tokens,
                args.model_max_tokens - args.generation_max_tokens - args.crossfile_max_tokens - BUFFER,
                tokenizer
            )
        else:
            prompt = prepare_prompt(
                d['prompt'],
                None,
                0,
                args.model_max_tokens - args.generation_max_tokens - BUFFER,
                tokenizer
            )
        prompts.append(prompt)

    outputs = llm.generate(prompts, sampling_params)

    out_path = os.path.join(args.output_dir, 'prediction.jsonl')
    with open(out_path, 'w') as f:
        for d, response in tqdm(zip(data, outputs)):
            d['pred'] = response.outputs[0].text
            d['task_id'] = d['metadata']['task_id']
            print(json.dumps(d), file=f, flush=True)

    return


def main():
    # set the OpenAI key
    # openai.api_key = os.environ.get('OPENAI_KEY', None)
    # if openai.api_key is None:
    #     raise ValueError('OPENAI_KEY environment variable not set')

    # get config for current run
    parser = argparse.ArgumentParser()
    parser.add_argument('--temperature', type=float, default=0.2)
    parser.add_argument('--top_p', type=float, default=0.95)
    parser.add_argument(
        '--task',
        type=str,
        required=True,
    )
    parser.add_argument(
        '--language',
        type=str,
        required=True,
        choices=['csharp', 'python', 'java', 'typescript']
    )
    parser.add_argument(
        '--data_root_dir',
        type=str,
        default='data/',
        help='path to directory where data is organized in lang/task.jsonl format'
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        required=True,
        help='path to directory where to store outputs'
    )
    parser.add_argument(
        '--model',
        type=str,
        required=True,
        help='vLLM-supported model'
    )
    parser.add_argument(
        '--tp_size',
        type=int,
        default=1,
        help='tensor parallel size'
    )
    parser.add_argument(
        '--model_max_tokens',
        type=int,
        default=16384,
        help='maximum number of tokens of the model'
    )
    parser.add_argument(
        '--crossfile_max_tokens',
        type=int,
        default=12800,
        help='maximum number of tokens for cross file context'
    )
    parser.add_argument(
        '--use_crossfile_context',
        action='store_true',
        help='whether to use cross file context'
    )
    parser.add_argument(
        '--generation_max_tokens',
        type=int,
        default=50,
        help='maximum number of tokens to generate'
    )
    args = parser.parse_args()
    print(json.dumps(vars(args), indent=4))

    # load model
    llm = LLM(model=args.model, tensor_parallel_size=args.tp_size, max_model_len=args.model_max_tokens)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    sampling_params = SamplingParams(temperature=args.temperature, top_p=args.top_p,
                                     max_tokens=args.generation_max_tokens)

    # setup paths
    if not os.path.isdir(args.output_dir):
        print(f'==== Output dir does not exist. Creating: {args.output_dir} ====')
        os.makedirs(args.output_dir)
    data_path = os.path.join(args.data_root_dir, args.language, args.task + '.jsonl')
    data = [json.loads(l) for l in open(data_path, 'r').readlines()]

    # generation
    cceval_generate(args, data, tokenizer, sampling_params, llm)


if __name__ == '__main__':
    main()
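
The token-budget arithmetic shared by `prepare_prompt` and `cceval_generate` above can be tried out in isolation. The following is a minimal sketch, not the repository's code: `prompt_budget`, `truncate_tokens`, and `build_prompt` are hypothetical names introduced for illustration, and whitespace splitting stands in for a real tokenizer, so actual token counts from `AutoTokenizer` or `tiktoken` will differ.

```python
from typing import List, Optional

BUFFER = 100  # same safety margin as in the scripts above


def prompt_budget(model_max_tokens: int, generation_max_tokens: int,
                  crossfile_max_tokens: int = 0) -> int:
    """Tokens left for the in-file prompt after reserving the other budgets."""
    return model_max_tokens - generation_max_tokens - crossfile_max_tokens - BUFFER


def truncate_tokens(tokens: List[str], max_num_tokens: int, side: str) -> List[str]:
    """Keep at most max_num_tokens tokens, dropping the excess from `side`."""
    if side not in ("left", "right"):
        raise ValueError(f"Invalid side: {side!r}")
    max_num_tokens = max(max_num_tokens, 0)
    if len(tokens) <= max_num_tokens:
        return tokens
    if side == "left":
        # drop the beginning, keep the code closest to the completion point
        return tokens[len(tokens) - max_num_tokens:]
    # drop the end, keep the beginning of the cross-file context
    return tokens[:max_num_tokens]


def build_prompt(prompt: str, cross_file_context: Optional[str],
                 prompt_tokens: int, crossfile_tokens: int) -> str:
    """Join right-truncated cross-file context and left-truncated prompt."""
    # ASSUMPTION: str.split() is a stand-in tokenizer, for illustration only
    kept_prompt = " ".join(truncate_tokens(prompt.split(), prompt_tokens, "left"))
    if cross_file_context is None:
        kept_cfc = ""
    else:
        kept_cfc = " ".join(
            truncate_tokens(cross_file_context.split(), crossfile_tokens, "right"))
    return kept_cfc + "\n" + kept_prompt


if __name__ == "__main__":
    # default values taken from the argument parsers above
    print(prompt_budget(16384, 50, 12800))  # 3434 tokens with cross-file context
    print(prompt_budget(16384, 50))         # 16234 tokens without it
    print(repr(build_prompt("a b c d", "x y z", 2, 1)))  # 'x\nc d'
```

With the default arguments, enabling cross-file context leaves 3434 tokens for the in-file prompt, so long files are cut from the left and only the code nearest the completion point survives.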