Repository: amazon-science/cceval
Branch: main
Commit: 40c68d2b7ca2
Files: 29
Total size: 40.2 MB

Directory structure:
gitextract_65e_y3xi/

├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── NOTICE
├── README.md
├── THIRD_PARTY_LICENSES
├── cceval_config.yaml
├── data/
│   └── crosscodeeval_data.tar.xz
├── prompt_builder/
│   ├── README.md
│   ├── augment_with_cfc.py
│   ├── rerank_utils.py
│   ├── run.sh
│   └── utils.py
├── requirements.txt
└── scripts/
    ├── build_treesitter.sh
    ├── build_ts_lib.py
    ├── custom_generate.py
    ├── eval.py
    ├── eval_metric.py
    ├── eval_utils.py
    ├── keywords/
    │   ├── __init__.py
    │   ├── csharp.txt
    │   ├── java.txt
    │   ├── javascript.txt
    │   ├── keywordlist.py
    │   └── typescript.txt
    ├── openai_inference.py
    └── vllm_inference.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.idea/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
.DS_Store
ts_package/
build/



================================================
FILE: CODE_OF_CONDUCT.md
================================================
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing Guidelines

Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
documentation, we greatly value feedback and contributions from our community.

Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
information to effectively respond to your bug report or contribution.


## Reporting Bugs/Feature Requests

We welcome you to use the GitHub issue tracker to report bugs or suggest features.

When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:

* A reproducible test case or series of steps
* The version of our code being used
* Any modifications you've made relevant to the bug
* Anything unusual about your environment or deployment


## Contributing via Pull Requests
Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:

1. You are working against the latest source on the *main* branch.
2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
3. You open an issue to discuss any significant work - we would hate for your time to be wasted.

To send us a pull request, please:

1. Fork the repository.
2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
3. Ensure local tests pass.
4. Commit to your fork using clear commit messages.
5. Send us a pull request, answering any default questions in the pull request interface.
6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.

GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
[creating a pull request](https://help.github.com/articles/creating-a-pull-request/).


## Finding contributions to work on
Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.


## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.


## Security issue notifications
If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.


## Licensing

See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.


================================================
FILE: LICENSE
================================================

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.


================================================
FILE: NOTICE
================================================
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.


================================================
FILE: README.md
================================================
# CrossCodeEval: A Diverse and Multilingual Benchmark for Cross-File Code Completion

This repository contains the data and inference code of the NeurIPS 2023 (Datasets and Benchmarks track)
paper "[CrossCodeEval: A Diverse and Multilingual Benchmark for Cross-File Code Completion](https://arxiv.org/abs/2310.11248)."

## Requirements

- Uncompress the CrossCodeEval data via `tar -xvJf data/crosscodeeval_data.tar.xz -C data/`
    - The data includes the {baseline, retrieval, retrieval w/ ref.} settings crossed with the {bm25, UniXCoder, OpenAI Ada} retrievers.
    - **Please [email us](mailto:y.robin.ding@gmail.com) if you need the raw data.**
- Install dependencies via `pip install -r requirements.txt`
- Build the tree-sitter parsers via `bash scripts/build_treesitter.sh`
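
Once the archive is uncompressed, each task is a JSON Lines file. The snippet below is a minimal sketch for peeking at one. The `metadata.repository` and `metadata.file` keys are the ones `prompt_builder/augment_with_cfc.py` reads; `prompt` and `groundtruth` are assumed key names, and `load_examples`/`summarize` are hypothetical helpers, so check them against your copy of the data.

```python
import json
from pathlib import Path


def load_examples(path, limit=None):
    """Read a JSON Lines file into a list of dicts, skipping blank lines."""
    examples = []
    with Path(path).open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            examples.append(json.loads(line))
            if limit is not None and len(examples) >= limit:
                break
    return examples


def summarize(example):
    """Pull out a few fields for a quick look at one example."""
    # metadata.repository / metadata.file are read by augment_with_cfc.py;
    # "prompt" and "groundtruth" are ASSUMED key names.
    meta = example.get("metadata", {})
    return {
        "repository": meta.get("repository"),
        "file": meta.get("file"),
        "prompt_lines": len(example.get("prompt", "").splitlines()),
        "groundtruth": example.get("groundtruth"),
    }
```

For example, `load_examples("data/crosscodeeval_data/python/line_completion_rg1_unixcoder_cosine_sim.jsonl", limit=3)` returns the first three examples of that task (the path used by the generation commands in this README).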


## Evaluation on CrossCodeEval
Our evaluation consists of two steps: generation and metrics calculation.


### Generation

#### Publicly Available Models
For publicly available models like StarCoder, DeepSeek-Coder, etc., we recommend using [vLLM](https://github.com/vllm-project/vllm) for fast and distributed inference on CrossCodeEval.

```bash
export gpus=2
export model=bigcode/starcoder2-3b
export language=python
export task=line_completion_rg1_unixcoder_cosine_sim
export output_dir=./tmp/crosscodeeval_testrun/
python scripts/vllm_inference.py \
  --tp $gpus \
  --task $task \
  --language $language \
  --model $model \
  --output_dir $output_dir \
  --use_crossfile_context 
```
For additional arguments (e.g., cross-file context length and sampling `top_p`), see `python scripts/vllm_inference.py --help`.

<details><summary> If you prefer non-vLLM script <i>:: click to expand ::</i></summary>
<div>

First, configure `accelerate` via `accelerate config` if you haven't already. A reference configuration is available at `cceval_config.yaml`.

The following command runs greedy evaluation of codegen-350M on Python with cross-file context.

```bash
export model_type=codelm_cfc # or codelm for no cross-file context eval
export model_name=Salesforce/codegen-350M-mono
export language=python
export ts_lib=./build/${language}-lang-parser.so
export dtype=bf16 # or fp16
export prompt_file=./data/crosscodeeval_data/${language}/line_completion_rg1_unixcoder_cosine_sim.jsonl # other files in this directory correspond to different retrieval methods and/or retrieval settings
export max_seq_length=2048
export cfc_seq_length=512 
export batch_size=16 # reduce for larger models
export output_dir=./tmp/crosscodeeval_testrun/

accelerate launch scripts/eval.py \
        --model_type $model_type \
        --model_name_or_path $model_name \
        --cfc_seq_length $cfc_seq_length \
        --prompt_file $prompt_file \
        --gen_length 50 \
        --max_seq_length $max_seq_length \
        --batch_size $batch_size \
        --output_dir $output_dir \
        --dtype $dtype \
        --num_return_sequences 1 \
        --overwrite_cache True \
        --ts_lib $ts_lib \
        --language $language
```

To sample instead of decoding greedily, add the following arguments:

```bash
        --do_sample \
        --top_p 0.95 \
        --temperature 0.2 \
        --num_return_sequences 5 \
```
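
For intuition about `--temperature` and `--top_p`, the sketch below implements generic temperature scaling followed by nucleus (top-p) filtering over one next-token distribution. It illustrates the standard technique only; it is not the code path used by `eval.py`, which relies on the underlying generation library, and `sample_top_p` is a hypothetical name.

```python
import math
import random


def sample_top_p(logits, temperature=0.2, top_p=0.95, rng=None):
    """Sample a token index using temperature scaling plus nucleus (top-p) filtering."""
    rng = rng or random.Random()
    # Temperature scaling: a lower temperature sharpens the distribution.
    scaled = [x / temperature for x in logits]
    # Numerically stable softmax.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the smallest set of most-likely tokens whose mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    # Renormalize over the kept tokens and draw one of them.
    kept_mass = sum(probs[i] for i in kept)
    r = rng.random() * kept_mass
    acc = 0.0
    for i in kept:
        acc += probs[i]
        if r < acc:
            return i
    return kept[-1]  # guard against floating-point round-off
```

Pass `rng=random.Random(seed)` for reproducible draws. With `temperature=0.2` most of the mass concentrates on the top tokens, so `top_p=0.95` usually leaves only a handful of candidates.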


</div>
</details>

#### OpenAI models
OpenAI models are accessible through an API. You may use the following script:
```bash
export model=gpt-3.5-turbo-0125 
export language=python
export task=line_completion_rg1_unixcoder_cosine_sim
export output_dir=./tmp/crosscodeeval_openai_testrun/
python scripts/openai_inference.py \
  --task $task \
  --language $language \
  --model $model \
  --output_dir $output_dir \
  --use_crossfile_context
```


### Metrics Calculation
After generation finishes, compute the final metrics:
```bash
export language=python
export ts_lib=./build/${language}-lang-parser.so
export task=line_completion_oracle_unixcoder_cosine_sim
export prompt_file=./data/crosscodeeval_data/${language}/${task}.jsonl
export output_dir=./tmp/crosscodeeval_testrun/
python scripts/eval.py \
  --prompt_file $prompt_file \
  --output_dir $output_dir \
  --ts_lib $ts_lib \
  --language $language \
  --only_compute_metric
```
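
Two common line-completion metrics are exact match (EM) and edit similarity (ES). The sketch below shows one way to compute both for a single prediction, under our own assumptions (whitespace-stripped comparison, character-level Levenshtein distance scaled to 0-100). `scripts/eval_metric.py` may normalize or tokenize differently, so treat this as an illustration rather than the benchmark's scorer.

```python
def levenshtein(a: str, b: str) -> int:
    """Character-level edit distance via dynamic programming (two rows)."""
    if len(a) < len(b):
        a, b = b, a
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(min(
                prev[j] + 1,               # deletion
                cur[j - 1] + 1,            # insertion
                prev[j - 1] + (ca != cb),  # substitution (free if chars match)
            ))
        prev = cur
    return prev[-1]


def exact_match(pred: str, target: str) -> bool:
    # Compare after stripping surrounding whitespace (an assumed normalization).
    return pred.strip() == target.strip()


def edit_similarity(pred: str, target: str) -> float:
    # 100 * (1 - distance / longer length); identical strings score 100.
    pred, target = pred.strip(), target.strip()
    longest = max(len(pred), len(target))
    if longest == 0:
        return 100.0
    return 100.0 * (1 - levenshtein(pred, target) / longest)
```

Averaging `exact_match` and `edit_similarity` over all examples of a task gives corpus-level EM and ES under these assumptions.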







## Citation

```bibtex
@inproceedings{ding2023crosscodeeval,
    title={CrossCodeEval: A Diverse and Multilingual Benchmark for Cross-File Code Completion}, 
    author={Yangruibo Ding and Zijian Wang and Wasi Uddin Ahmad and Hantian Ding and Ming Tan and Nihal Jain and Murali Krishna Ramanathan and Ramesh Nallapati and Parminder Bhatia and Dan Roth and Bing Xiang},
    year={2023},
    booktitle={Thirty-seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
    url={https://arxiv.org/pdf/2310.11248.pdf}
}
```
## Questions
Please feel free to email us. You may also submit an issue in this repo.

## Security

See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.

## License

This project is licensed under the Apache-2.0 License.


================================================
FILE: THIRD_PARTY_LICENSES
================================================
The CrossCodeEval repository includes the following third-party software/licensing:

keywordlist.py is taken from https://github.com/microsoft/dpu-utils/blob/master/python/dpu_utils/codeutils/keywords/keywordlist.py under the following license:

    MIT License

    Copyright (c) Microsoft Corporation. All rights reserved.

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.


================================================
FILE: cceval_config.yaml
================================================
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero3_save_16bit_model: false
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false


================================================
FILE: data/crosscodeeval_data.tar.xz
================================================
[File too large to display: 40.1 MB]

================================================
FILE: prompt_builder/README.md
================================================
## Retrieval Augmented Prompting

The retrieval-augmented prompts can be generated in three steps:

1. Email us to get the raw software repositories.
2. Set the `repository_root` variable in `augment_with_cfc.py` to the root directory that contains the raw software repositories.
3. Run the `run.sh` bash script.
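
`augment_with_cfc.py` builds its query from the last `QUERY_LENGTH` non-empty prompt lines, splits candidate files into line chunks governed by `CHUNK_SIZE` and `SLIDING_WINDOW_SIZE`, ranks the chunks, and prepends the top ones to the prompt as comments. The sketch below mirrors the query construction, the `filename|index` chunk ids, and the comment layout visible in that script. The chunking loop is our reconstruction, and the Jaccard token-overlap score is a stand-in for `lexical_ranking`/`SemanticReranking`, not their implementation; all function names are hypothetical.

```python
import re

CHUNK_SIZE = 10
SLIDING_WINDOW_SIZE = 10  # equal to CHUNK_SIZE -> non-overlapping chunks
QUERY_LENGTH = 10


def build_query(prompt: str) -> str:
    # Last QUERY_LENGTH non-empty prompt lines ("last_n_lines" query type).
    lines = [pl for pl in prompt.split("\n") if pl.strip()]
    return "\n".join(lines[-QUERY_LENGTH:])


def chunk_file(filename: str, content: str):
    # ASSUMED chunking: windows of CHUNK_SIZE non-empty lines, stepping by
    # SLIDING_WINDOW_SIZE; ids use the "filename|index" format the script parses.
    lines = [ln for ln in content.split("\n") if ln.strip()]
    chunks, ids = [], []
    for idx, start in enumerate(range(0, len(lines), SLIDING_WINDOW_SIZE)):
        chunks.append("\n".join(lines[start:start + CHUNK_SIZE]))
        ids.append(f"{filename}|{idx}")
    return chunks, ids


def jaccard(a: str, b: str) -> float:
    # Stand-in lexical score: Jaccard overlap of identifier-like tokens.
    ta, tb = set(re.findall(r"\w+", a)), set(re.findall(r"\w+", b))
    if not ta or not tb:
        return 0.0
    return len(ta & tb) / len(ta | tb)


def format_cfc(chunks, filenames, language="python"):
    # Same comment layout as get_crossfile_context_from_chunks.
    sym = "#" if language == "python" else "//"
    text = f"{sym} Here are some relevant code fragments from other files of the repo:\n\n"
    for chunk, fname in zip(chunks, filenames):
        text += f"{sym} the below code fragment can be found in:\n{sym} {fname}\n"
        text += "\n".join(f"{sym} {cl}" for cl in chunk.strip("\n").splitlines()) + "\n\n"
    return text


def retrieve(prompt, files, top_k=2, language="python"):
    """Score every chunk of every file against the query and format the top_k."""
    query = build_query(prompt)
    scored = []
    for fname, content in files.items():
        for chunk, cid in zip(*chunk_file(fname, content)):
            scored.append((jaccard(query, chunk), cid, chunk))
    scored.sort(key=lambda t: t[0], reverse=True)  # stable: ties keep input order
    top = scored[:top_k]
    return format_cfc([c for _, _, c in top],
                      [cid.rsplit("|", 1)[0] for _, cid, _ in top], language)
```

In the real pipeline, candidate files are first restricted by `find_files_within_distance_k`, and ranking is done by the configured `ranking_fn`.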

================================================
FILE: prompt_builder/augment_with_cfc.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import time
import glob
import argparse
import multiprocessing as mp
from tqdm import tqdm
from functools import partial
from rerank_utils import lexical_ranking, SemanticReranking
from utils import str2bool, file_distance, tokenize_nltk

CHUNK_SIZE = 10
SLIDING_WINDOW_SIZE = 10  # non-overlapping chunks if SLIDING_WINDOW_SIZE=CHUNK_SIZE
QUERY_LENGTH = 10  # last N lines from prompt will be query

repository_root = "/PATH/TO/REPOS"  # get the data from authors

input_files = {
    "python": "../data/crosscodeeval_data/python/line_completion.jsonl",
    "java": "../data/crosscodeeval_data/java/line_completion.jsonl",
    "typescript": "../data/crosscodeeval_data/typescript/line_completion.jsonl",
    "csharp": "../data/crosscodeeval_data/csharp/line_completion.jsonl"
}

file_ext = {"python": "py", "java": "java", "typescript": "ts", "csharp": "cs"}


def get_crossfile_context_from_chunks(
        args,
        prompt,
        code_chunks,
        code_chunk_ids,
        groundtruth,
        semantic_ranker
):
    assert len(code_chunks) != 0
    candidate_code_chunks = code_chunks[:args.maximum_chunk_to_rerank]
    candidate_code_chunk_ids = code_chunk_ids[:args.maximum_chunk_to_rerank]

    ranking_scores = None
    meta_data = {}

    if args.rerank:
        if args.query_type == "groundtruth":
            # oracle experiment
            prompt_lines = [pl for pl in prompt.split("\n") if pl.strip()]
            groundtruth_lines = [gt for gt in groundtruth.split("\n") if gt.strip()]
            code_lines = prompt_lines + groundtruth_lines
            query = "\n".join(code_lines[-QUERY_LENGTH:])
        elif args.query_type == "last_n_lines":
            prompt_lines = [pl for pl in prompt.split("\n") if pl.strip()]
            query = "\n".join(prompt_lines[-QUERY_LENGTH:])
        else:
            raise NotImplementedError

        meta_data["query"] = query
        start = time.time()

        if args.ranking_fn == "cosine_sim":
            gpu_id = int(mp.current_process().name.split('-')[-1]) - 1
            candidate_code_chunks, candidate_code_chunk_ids, ranking_scores = semantic_ranker.rerank(
                query,
                candidate_code_chunks,
                candidate_code_chunk_ids,
                gpu_id,
                score_threshold=None
            )
        else:
            candidate_code_chunks, candidate_code_chunk_ids, ranking_scores = lexical_ranking(
                query,
                candidate_code_chunks,
                args.ranking_fn,
                candidate_code_chunk_ids,
                score_threshold=None
            )

        meta_data["latency"] = time.time() - start
        meta_data["num_candidates"] = len(candidate_code_chunks)

    top_k = min(args.maximum_cross_file_chunk, len(candidate_code_chunk_ids))
    if top_k == 0:
        # keep the same 3-tuple shape as the normal return below
        return [], "", meta_data

    selected_chunks = []
    selected_chunks_filename = []
    selected_chunks_scores = []

    if args.use_next_chunk_as_cfc:
        # prepare an id2idx map
        assert len(candidate_code_chunks) == len(candidate_code_chunk_ids)
        id2idx = dict()
        for j, cci in enumerate(code_chunk_ids):
            id2idx[cci] = j

        total_added = 0
        for cidx, _id in enumerate(candidate_code_chunk_ids):
            fname, c_id = _id.rsplit("|", 1)
            next_id = f"{fname}|{int(c_id) + 1}"
            if next_id not in id2idx:
                to_add = code_chunks[id2idx[_id]]
            else:
                to_add = code_chunks[id2idx[next_id]]

            if to_add not in selected_chunks:
                selected_chunks.append(to_add)
                selected_chunks_filename.append(fname)
                if args.rerank:
                    selected_chunks_scores.append(ranking_scores[cidx])
                total_added += 1
                if total_added == top_k:
                    break
    else:
        selected_chunks = candidate_code_chunks[:top_k]
        selected_chunks_filename = [_id.rsplit("|", 1)[0] for _id in candidate_code_chunk_ids[:top_k]]
        if args.rerank:
            selected_chunks_scores = ranking_scores[:top_k]

    cross_file_context = []
    for idx in range(len(selected_chunks)):
        cross_file_context.append({
            "retrieved_chunk": selected_chunks[idx],
            "filename": selected_chunks_filename[idx],
            "score": selected_chunks_scores[idx] if args.rerank else None
        })

    line_start_sym = "#" if args.language == "python" else "//"
    cfc_text = f"{line_start_sym} Here are some relevant code fragments from other files of the repo:\n\n"
    for sc, scf in zip(selected_chunks, selected_chunks_filename):
        cfc_text += f"{line_start_sym} the below code fragment can be found in:\n{line_start_sym} {scf}" + "\n"
        cfc_text += "\n".join([f"{line_start_sym} {cl}" for cl in sc.strip('\n').splitlines()]) + "\n\n"

    return cross_file_context, cfc_text, meta_data


def read_project_files(repo_name, lang):
    # root_dir is <repository_root>/<lang>/<repo_name>; os.path.join handles the separators
    project_context = {}
    root_dir = os.path.join(repository_root, lang, repo_name)
    if not os.path.isdir(root_dir):
        print(f"Repository not found: {root_dir}")
        return project_context

    if lang == "typescript":
        src_files = []
        src_files += glob.glob(os.path.join(root_dir, f'src/**/*.ts'), recursive=True)
        src_files += glob.glob(os.path.join(root_dir, f'src/**/*.tsx'), recursive=True)
    else:
        src_files = glob.glob(os.path.join(root_dir, f'**/*.{file_ext[lang]}'), recursive=True)

    if len(src_files) == 0:
        return project_context

    for filename in src_files:
        if os.path.exists(filename):  # skip paths that cannot be resolved (e.g. broken symlinks)
            if os.path.isfile(filename):
                try:
                    with open(filename, "r") as file:
                        file_content = file.read()
                except UnicodeDecodeError:
                    # fall back to a lossy UTF-8 decode when the default encoding fails
                    with open(filename, "rb") as file:
                        file_content = file.read().decode(errors='replace')

                fileid = os.path.relpath(filename, root_dir)
                project_context[fileid] = file_content
        else:
            pass
            # print(f"File not found: {filename}")

    return project_context


def find_files_within_distance_k(current_file_path, filelist, k):
    list_of_modules = []
    module_weight = []
    for filepath in filelist:
        if filepath != current_file_path:
            dist = file_distance(filepath, current_file_path)
            if dist == -1:
                continue
            elif dist <= k:
                list_of_modules.append(filepath)
                module_weight.append(dist)

    # sorting in ascending order
    list_of_modules = [x for _, x in sorted(zip(module_weight, list_of_modules))]
    return list_of_modules


def get_cfc(example, args, semantic_ranker, repositories):
    project_context = repositories[example["metadata"]["repository"]]
    status = None
    current_filepath = example["metadata"]["file"]
    if len(project_context) == 0:
        example["crossfile_context"] = ""
        status = "project_not_found"
    else:
        current_filecontent = project_context.get(current_filepath)

        if current_filecontent is None:
            example["crossfile_context"] = {}
            print(f"File not found in project: {current_filepath}")
            status = "file_not_found_in_project"

        else:
            pyfiles = find_files_within_distance_k(
                example["metadata"]["file"],
                list(project_context.keys()),
                k=args.crossfile_distance
            )
            pyfiles = pyfiles[:args.maximum_cross_files]

            code_chunks = []
            code_chunk_ids = []
            for pyfile in pyfiles:
                lines = project_context[pyfile].split("\n")
                lines = [l for l in lines if l.strip()]  # removing empty lines
                c_id = 0
                for i in range(0, len(lines), SLIDING_WINDOW_SIZE):
                    c = "\n".join(lines[i:i + CHUNK_SIZE])
                    tokenized_c = tokenize_nltk(c)
                    if len(tokenized_c) > 0:
                        code_chunks.append(c)
                        code_chunk_ids.append(f"{pyfile}|{c_id}")
                        c_id += 1

            if len(code_chunks) == 0:
                example["crossfile_context"] = {}
                status = "no_crossfile_context"

            else:
                cfc, cfc_text, meta_data = get_crossfile_context_from_chunks(
                    args=args,
                    prompt=example["prompt"],
                    code_chunks=code_chunks,
                    code_chunk_ids=code_chunk_ids,
                    groundtruth=example["groundtruth"],
                    semantic_ranker=semantic_ranker
                )
                example["crossfile_context"] = {}
                example["crossfile_context"]["text"] = cfc_text
                example["crossfile_context"]["list"] = cfc

    return example, status
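

# Illustrative sketch (hypothetical helper, not called anywhere in this script):
# get_cfc() splits each candidate file into overlapping windows of non-empty
# lines. A window of `chunk_size` lines starts every `stride` lines, so
# consecutive chunks share `chunk_size - stride` lines when stride < chunk_size.
# The real loop uses the module constants CHUNK_SIZE and SLIDING_WINDOW_SIZE
# and drops chunks that tokenize_nltk() turns into zero tokens. This sketch
# approximates that filter with a regex search for word characters.
def _sketch_sliding_window_chunks(text, chunk_size, stride):
    import re
    # Drop blank lines first, as get_cfc() does.
    lines = [line for line in text.split("\n") if line.strip()]
    chunks = []
    for start in range(0, len(lines), stride):
        chunk = "\n".join(lines[start:start + chunk_size])
        # Keep only chunks that contain at least one word character.
        if re.search(r"\w", chunk):
            chunks.append(chunk)
    return chunks


# Example: _sketch_sliding_window_chunks("a\nb\n\nc\nd\ne", chunk_size=3, stride=2)
# yields ["a\nb\nc", "c\nd\ne", "e"]. The last window can be shorter than chunk_size.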


def attach_data(args, srcfile):
    empty_cfc = 0
    error_freq = {
        "project_not_found": 0,
        "file_not_found_in_project": 0,
        "no_crossfile_context": 0
    }
    output_examples = []

    examples = []
    repositories = dict()
    with open(srcfile) as f:
        for line in f:
            ex = json.loads(line)
            repo_name = ex["metadata"]["repository"]
            if repo_name not in repositories:
                repositories[repo_name] = read_project_files(repo_name, args.language)
            examples.append(ex)

    semantic_ranker = None
    if args.ranking_fn == "cosine_sim":
        semantic_ranker = SemanticReranking(
            args.ranker,
            max_sequence_length=256
        )

    worker = partial(get_cfc, args=args, semantic_ranker=semantic_ranker, repositories=repositories)

    # imap_unordered yields results as workers finish, so the output order can
    # differ from the input order. The context manager shuts the pool down.
    with mp.Pool(args.num_processes) as pool, tqdm(total=len(examples)) as pbar:
        for (d, stat) in pool.imap_unordered(worker, examples):
            if stat in error_freq:
                error_freq[stat] += 1
            if len(d["crossfile_context"]) == 0:
                empty_cfc += 1
                if not args.skip_if_no_cfc:
                    output_examples.append(d)
            else:
                output_examples.append(d)
            pbar.update()

    print("Total examples with empty CFC: ", empty_cfc)
    print(error_freq)
    return output_examples


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--rerank",
        type=str2bool,
        default=True,
        help="rerank the candidate code chunks"
    )
    parser.add_argument(
        "--ranker",
        type=str,
        default="sparse",
        choices=["sparse", "unixcoder"],
        help="ranker type: sparse (lexical) or unixcoder (dense embeddings)"
    )
    parser.add_argument(
        "--ranking_fn",
        type=str,
        default="bm25",
        choices=["tfidf", "bm25", "jaccard_sim", "cosine_sim"],
        help="ranking function"
    )
    parser.add_argument(
        "--query_type",
        type=str,
        default="last_n_lines",
        choices=["last_n_lines", "groundtruth"],
        help="how to form query from prompt"
    )
    parser.add_argument(
        "--crossfile_distance",
        type=int,
        default=100,
        help="max directory distance from the current file for candidate cross files"
    )
    parser.add_argument(
        "--maximum_chunk_to_rerank",
        type=int,
        default=1000,
        help="max number of chunks to rank with the selected ranking function"
    )
    parser.add_argument(
        "--maximum_cross_files",
        type=int,
        default=1000,
        help="max number of cross files to consider"
    )
    parser.add_argument(
        "--maximum_cross_file_chunk",
        type=int,
        default=50,
        help="max chunks to return as cfc"
    )
    parser.add_argument(
        "--use_next_chunk_as_cfc",
        type=str2bool,
        default=True,
        help="use next code chunk as context"
    )
    parser.add_argument(
        "--skip_if_no_cfc",
        type=str2bool,
        default=True,
        help="skip adding examples if there is no crossfile context"
    )
    parser.add_argument(
        "--output_file_suffix",
        type=str,
        default=None,
        help="add a suffix string to the output file"
    )
    parser.add_argument(
        "--language",
        type=str,
        required=True,
        choices=["java", "python", "typescript", "csharp"],
        help="language name"
    )
    args = parser.parse_args()

    args.output_file_suffix = "" if args.output_file_suffix is None else f"_{args.output_file_suffix}"
    if args.use_next_chunk_as_cfc:
        assert args.rerank
        assert args.query_type != "groundtruth"

    tgtfile_suffix = ""
    if args.rerank:
        tgtfile_suffix += f"_{args.ranking_fn}"

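    # Worker counts are hard-coded (60 CPU processes, or one process per GPU
    # for cosine_sim); adjust them to the available hardware.
    args.num_processes = 60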
    if args.ranking_fn == "cosine_sim":
        num_gpus = 8
        args.num_processes = num_gpus
        mp.set_start_method('spawn')

    input_file = input_files[args.language]
    output_path = os.path.dirname(input_file)
    output_filename = os.path.splitext(os.path.basename(input_file))[0]
    output_filename = output_filename + args.output_file_suffix + tgtfile_suffix + ".jsonl"
    output_file = os.path.join(output_path, output_filename)
    output_examples = attach_data(args, input_file)
    with open(output_file, "w") as fw:
        for ex in output_examples:
            fw.write(json.dumps(ex))
            fw.write("\n")


================================================
FILE: prompt_builder/rerank_utils.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from rank_bm25 import BM25Okapi
from typing import List
from multiprocessing import Pool, cpu_count
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from utils import tokenize_nltk
from transformers import AutoModel, AutoTokenizer, AutoConfig


def jaccard_similarity(tokenized_query, tokenized_doc, containment=False):
    set1 = set(tokenized_query)
    set2 = set(tokenized_doc)
    intersection = len(set1.intersection(set2))
    union = len(set1) if containment else len(set1.union(set2))
    return float(intersection) / union
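

# Illustrative sketch (hypothetical helper, not used by the ranking code):
# jaccard_similarity() divides the token-set overlap either by the size of the
# union (classic Jaccard) or, with containment=True, by the size of the query
# set (the fraction of query tokens that also occur in the document). This
# helper returns both values side by side so they can be compared, and it
# returns 0.0 when a denominator would be zero.
def _sketch_jaccard_and_containment(query_tokens, doc_tokens):
    query_set, doc_set = set(query_tokens), set(doc_tokens)
    overlap = len(query_set & doc_set)
    union = len(query_set | doc_set)
    jaccard = overlap / union if union else 0.0
    containment = overlap / len(query_set) if query_set else 0.0
    return jaccard, containment


# Example: query ["a", "b"] vs. document ["b", "c", "d"] gives
# Jaccard 1/4 = 0.25 and containment 1/2 = 0.5.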


def tokenize_corpus(corpus, tokenizer_fn):
    # The context manager shuts the worker pool down once the corpus is tokenized.
    with Pool(cpu_count()) as pool:
        tokenized_corpus = pool.map(tokenizer_fn, corpus)
    return tokenized_corpus


def tokenize_query_and_docs(query, docs):
    tokenized_query = tokenize_nltk(query)
    tokenized_docs = [tokenize_nltk(d) for d in docs]
    return tokenized_query, tokenized_docs


def lexical_ranking(
        query,
        docs,
        ranking_fn,
        doc_ids=None,
        score_threshold=None,
):
    if ranking_fn == "bm25":
        tokenized_query, tokenized_docs = tokenize_query_and_docs(query, docs)
        bm25 = BM25Okapi(tokenized_docs)
        scores = bm25.get_scores(tokenized_query)
    elif ranking_fn == "tfidf":
        tfidf_vectorizer = TfidfVectorizer(tokenizer=tokenize_nltk)
        X = tfidf_vectorizer.fit_transform(docs).toarray()  # (n_fn, n_features)
        y = tfidf_vectorizer.transform([query]).toarray()  # (1, n_features)
        scores = cosine_similarity(X, y)[:, 0].tolist()  # (n_fn,): one float per document
    elif ranking_fn == "jaccard_sim":
        tokenized_query, tokenized_docs = tokenize_query_and_docs(query, docs)
        scores = [jaccard_similarity(tokenized_query, d, containment=False) for d in tokenized_docs]
    else:
        raise NotImplementedError

    if score_threshold:
        skip_ids = [idx for idx, s in enumerate(scores) if s < score_threshold]
        scores = [s for idx, s in enumerate(scores) if idx not in skip_ids]
        docs = [d for idx, d in enumerate(docs) if idx not in skip_ids]
        if doc_ids is not None:
            doc_ids = [doc_id for idx, doc_id in enumerate(doc_ids) if idx not in skip_ids]

    if len(docs) == 0:
        return docs, doc_ids, scores

    # Sort by a single index permutation so docs, doc_ids, and scores stay
    # aligned even when several documents share the same score.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    docs = [docs[i] for i in order]
    scores = [scores[i] for i in order]
    if doc_ids is not None:
        doc_ids = [doc_ids[i] for i in order]

    return docs, doc_ids, scores
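

# Illustrative sketch (hypothetical helper, not used by lexical_ranking()):
# a pure-Python version of the textbook Okapi BM25 score that the "bm25"
# branch delegates to rank_bm25.BM25Okapi. Each matching query token adds
# IDF * tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avg_doc_len)).
# Assumption: this uses the always-positive IDF variant
# ln(1 + (N - n + 0.5) / (n + 0.5)). BM25Okapi computes IDF differently
# (it floors negative values), so absolute scores can differ from the library.
def _sketch_bm25_scores(tokenized_query, tokenized_docs, k1=1.5, b=0.75):
    import math
    from collections import Counter

    if not tokenized_docs:
        return []
    n_docs = len(tokenized_docs)
    avg_len = sum(len(d) for d in tokenized_docs) / n_docs
    # Document frequency: number of documents that contain each token.
    doc_freq = Counter(tok for d in tokenized_docs for tok in set(d))
    scores = []
    for doc in tokenized_docs:
        tf = Counter(doc)
        score = 0.0
        for tok in tokenized_query:
            if tf[tok] == 0:
                continue
            idf = math.log(1 + (n_docs - doc_freq[tok] + 0.5) / (doc_freq[tok] + 0.5))
            # Term-frequency saturation (k1) with document-length normalisation (b).
            norm = tf[tok] + k1 * (1 - b + b * len(doc) / avg_len)
            score += idf * tf[tok] * (k1 + 1) / norm
        scores.append(score)
    return scores


# Example: for the query ["load", "config"], a chunk containing both tokens
# outranks a chunk that repeats only "load", and a chunk with neither scores 0.0.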


class SemanticReranking:

    def __init__(self, model_type="unixcoder", **kwargs):
        self.model_type = model_type
        if model_type == "unixcoder":
            self.tokenizer = AutoTokenizer.from_pretrained('microsoft/unixcoder-base')
            self.model = AutoModel.from_pretrained('microsoft/unixcoder-base')
        else:
            raise NotImplementedError

        # maximum sequence length for query and documents
        self.max_sequence_length = kwargs.get("max_sequence_length", 256)

    def text_to_tensor(
            self,
            text: str,
            pad_to_max: bool = True,
    ):
        text = text.strip()

        # Tokenizer padding is disabled here because its behavior is inconsistent;
        # padding to max_sequence_length is applied manually below.
        token_ids = self.tokenizer.encode(
            text,
            add_special_tokens=False,
            max_length=self.max_sequence_length,
            padding=False,
            truncation=True
        )
        )

        if pad_to_max and len(token_ids) < self.max_sequence_length:
            token_ids = token_ids + [self.tokenizer.pad_token_id] * (self.max_sequence_length - len(token_ids))
        if len(token_ids) > self.max_sequence_length:
            token_ids = token_ids[0:self.max_sequence_length]

        return torch.tensor(token_ids)

    def get_pad_id(self):
        return self.tokenizer.pad_token_id

    def get_attn_mask(self, tokens_tensor):
        return tokens_tensor != self.get_pad_id()

    def get_representations(self, list_input_ids, gpu_id):
        device = torch.device('cuda', gpu_id)
        self.model = self.model.to(device=device, dtype=torch.float16)
        self.model.eval()

        batch_size = 64
        sequence_outputs = []
        pooled_outputs = []

        for idx in range(0, len(list_input_ids), batch_size):
            start, end = idx, min(idx + batch_size, len(list_input_ids))
            input_ids = torch.stack(list_input_ids[start:end], dim=0).to(device=device)
            attention_mask = self.get_attn_mask(input_ids)

            # Only "unixcoder" is accepted in __init__, so the encoder's last hidden
            # layer is used directly as the token embeddings.
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
            token_embeddings = output.last_hidden_state  # bsz x seq_len x hid_dim

            mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * mask_expanded, 1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            sequence_embeddings = sum_embeddings / sum_mask  # bsz x hid_dim

            sequence_outputs.append(token_embeddings)
            pooled_outputs.append(sequence_embeddings)

        sequence_output = torch.cat(sequence_outputs)
        pooled_output = torch.cat(pooled_outputs)

        return sequence_output, pooled_output

    def rerank(self, query: str, docs: List[str], doc_ids: List[str] = None, gpu_id=0, score_threshold=None):
        with torch.no_grad():
            batch_queries = [self.text_to_tensor(query)]
            batch_candidates = [self.text_to_tensor(d) for d in docs]

            _, query_rep = self.get_representations(batch_queries, gpu_id)  # 1 x hidden_size
            _, candi_rep = self.get_representations(batch_candidates, gpu_id)  # num_cand x hidden_size
            scores = torch.nn.functional.cosine_similarity(query_rep, candi_rep).tolist()  # num_cand

        if score_threshold:
            skip_ids = [idx for idx, s in enumerate(scores) if s < score_threshold]
            scores = [s for idx, s in enumerate(scores) if idx not in skip_ids]
            docs = [d for idx, d in enumerate(docs) if idx not in skip_ids]
            if doc_ids is not None:
                doc_ids = [doc_id for idx, doc_id in enumerate(doc_ids) if idx not in skip_ids]

        if len(docs) == 0:
            return docs, doc_ids, scores

        # Sort by a single index permutation so docs, doc_ids, and scores stay
        # aligned even when several candidates share the same score.
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        docs = [docs[i] for i in order]
        scores = [scores[i] for i in order]
        if doc_ids is not None:
            doc_ids = [doc_ids[i] for i in order]

        return docs, doc_ids, scores
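

# Illustrative sketch (hypothetical helpers, pure Python instead of torch):
# SemanticReranking scores candidates in two steps. get_representations()
# mean-pools the last hidden states over non-padding positions (the attention
# mask is 0 wherever the token id equals the pad id), and rerank() takes the
# cosine similarity between the pooled query vector and each pooled candidate
# vector. The toy vectors in the example stand in for UniXcoder hidden states.
def _sketch_masked_mean_pool(token_vectors, mask):
    dim = len(token_vectors[0])
    summed = [0.0] * dim
    for vec, keep in zip(token_vectors, mask):
        if keep:
            for i, value in enumerate(vec):
                summed[i] += value
    # Clamp the count, like torch.clamp(..., min=1e-9), to avoid dividing by 0.
    count = max(sum(1 for keep in mask if keep), 1e-9)
    return [value / count for value in summed]


def _sketch_cosine(u, v):
    import math
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


# Example: pooling [[1, 3], [3, 1], [100, 100]] with mask [1, 1, 0] ignores the
# padded third position and gives [2.0, 2.0].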


================================================
FILE: prompt_builder/run.sh
================================================
#!/usr/bin/env bash

export PYTHONIOENCODING=utf-8

function generate_data() {
    lang=$1
    ranker=$2
    ranking_fn=$3

    echo "$lang, $ranker, $ranking_fn"

    output_file_suffix=""
    if [[ $ranker != "sparse" ]]; then
        output_file_suffix="_${ranker}"
    fi

    # for RG-1
    python augment_with_cfc.py \
        --language $lang \
        --rerank True \
        --ranker $ranker \
        --ranking_fn $ranking_fn \
        --query_type last_n_lines \
        --crossfile_distance 100 \
        --maximum_chunk_to_rerank 1000 \
        --maximum_cross_files 1000 \
        --maximum_cross_file_chunk 5 \
        --use_next_chunk_as_cfc True \
        --skip_if_no_cfc False \
        --output_file_suffix "rg1${output_file_suffix}"

    # for oracle experiment
    python augment_with_cfc.py \
        --language $lang \
        --rerank True \
        --ranker $ranker \
        --ranking_fn $ranking_fn \
        --query_type groundtruth \
        --crossfile_distance 100 \
        --maximum_chunk_to_rerank 1000 \
        --maximum_cross_files 1000 \
        --maximum_cross_file_chunk 5 \
        --use_next_chunk_as_cfc False \
        --skip_if_no_cfc False \
        --output_file_suffix "oracle${output_file_suffix}"
}

generate_data python sparse bm25
generate_data java sparse bm25
generate_data typescript sparse bm25
generate_data csharp sparse bm25

generate_data python sparse jaccard_sim
generate_data java sparse jaccard_sim
generate_data typescript sparse jaccard_sim
generate_data csharp sparse jaccard_sim

generate_data python unixcoder cosine_sim
generate_data java unixcoder cosine_sim
generate_data typescript unixcoder cosine_sim
generate_data csharp unixcoder cosine_sim


================================================
FILE: prompt_builder/utils.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import re
import os
from typing import List
from nltk.tokenize import word_tokenize


def tokenize_nltk(text):
    words = word_tokenize(text)
    output_list = []
    for w in words:
        w_list = re.findall(r'\w+', w)
        output_list.extend(w_list)
    return output_list


def file_distance(src_file, dest_file):
    distance = -1
    try:
        commonpath = os.path.commonpath([src_file, dest_file])
        rel_file1_path = os.path.relpath(src_file, commonpath)
        rel_file2_path = os.path.relpath(dest_file, commonpath)
        distance = rel_file1_path.count(os.sep) + rel_file2_path.count(os.sep)
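    except ValueError:
        # os.path.commonpath raises ValueError for mixed absolute/relative paths
        # (or paths on different drives); such pairs keep distance -1.
        pass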

    return distance
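

# Illustrative sketch (hypothetical helper, not used by the pipeline):
# file_distance() counts directory hops from the deepest common ancestor of the
# two paths down to each file, so files in the same directory are at distance 0
# and every extra directory level on either side adds 1.
# find_files_within_distance_k() in augment_with_cfc.py keeps files with
# distance <= k and sorts them nearest first. This helper reproduces that
# ordering, assuming POSIX-style relative paths (ties are broken by path).
def _sketch_rank_by_directory_distance(current_file, candidates, k):
    import os
    ranked = []
    for path in candidates:
        if path == current_file:
            continue
        common = os.path.commonpath([path, current_file])
        dist = (os.path.relpath(path, common).count(os.sep)
                + os.path.relpath(current_file, common).count(os.sep))
        if dist <= k:
            ranked.append((dist, path))
    return [path for _, path in sorted(ranked)]


# Example: relative to "src/train.py", "src/model.py" is at distance 0,
# "src/util/io.py" at 1, and "tests/test_model.py" at 2.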


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


================================================
FILE: requirements.txt
================================================
torch
transformers
datasets
tree-sitter<0.22
timeout-decorator
bitsandbytes
accelerate
scikit-learn
rank-bm25
fuzzywuzzy
nltk
sacrebleu
deepspeed
tiktoken
vllm>=0.3.3


================================================
FILE: scripts/build_treesitter.sh
================================================
mkdir ts_package;
cd ts_package;
# Download the tree-sitter package
git clone https://github.com/tree-sitter/tree-sitter-python.git;
git clone https://github.com/tree-sitter/tree-sitter-java.git;
git clone https://github.com/tree-sitter/tree-sitter-c-sharp.git;
git clone https://github.com/tree-sitter/tree-sitter-typescript.git;
cd ..;
# Build tree-sitter
python scripts/build_ts_lib.py


================================================
FILE: scripts/build_ts_lib.py
================================================
#!/usr/bin/env python
# coding=utf-8
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from tree_sitter import Language

def build_language_lib():
    for lang in ["java", "python", "typescript", "csharp"]:
        ts_lang = "c-sharp" if lang == "csharp" else lang
        if lang == "typescript":
            git_dir = f"ts_package/tree-sitter-{ts_lang}/{lang}"
        else:
            git_dir = f"ts_package/tree-sitter-{ts_lang}"
        Language.build_library(f'build/{lang}-lang-parser.so', [git_dir])


if __name__ == "__main__":
    build_language_lib()


================================================
FILE: scripts/custom_generate.py
================================================
# Modifications Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

##############################################################################
# Modified from https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/utils.py
##############################################################################

import copy
import inspect
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from transformers.models.auto import (
    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
    MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from transformers.utils import ModelOutput, logging
from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    EncoderRepetitionPenaltyLogitsProcessor,
    EpsilonLogitsWarper,
    EtaLogitsWarper,
    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    ForceTokensLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitNormalization,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    MinNewTokensLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    SuppressTokensAtBeginLogitsProcessor,
    SuppressTokensLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)
from transformers.generation.stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from transformers.generation.utils import GenerateOutput, SampleOutput, SampleDecoderOnlyOutput

if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel
    from transformers.generation.streamers import BaseStreamer

logger = logging.get_logger(__name__)


@torch.no_grad()
def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    r"""

    Generates sequences of token ids for models with a language modeling head.

    <Tip warning={true}>

    Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
    model's default generation configuration. You can override any `generation_config` by passing the corresponding
    parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.

    For an overview of generation strategies and code examples, check out the [following
    guide](../generation_strategies).

    </Tip>

    Parameters:
        inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
            The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
            method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
            should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
            `input_ids`, `input_values`, `input_features`, or `pixel_values`.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call. `**kwargs`
            passed to generate matching the attributes of `generation_config` will override them. If
            `generation_config` is not provided, the default will be used, which has the following loading
            priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
            configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
            default values, whose documentation should be checked to parameterize generation.
        logits_processor (`LogitsProcessorList`, *optional*):
            Custom logits processors that complement the default logits processors built from arguments and
            generation config. If a logit processor is passed that is already created with the arguments or a
            generation config an error is thrown. This feature is intended for advanced users.
        stopping_criteria (`StoppingCriteriaList`, *optional*):
            Custom stopping criteria that complement the default stopping criteria built from arguments and a
            generation config. If a stopping criteria is passed that is already created with the arguments or a
            generation config an error is thrown. This feature is intended for advanced users.
        prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
            If provided, this function constrains the beam search to allowed tokens only at each step. If not
            provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
            `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
            on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
            for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
            Retrieval](https://arxiv.org/abs/2010.00904).
        synced_gpus (`bool`, *optional*):
            Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
            `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
            generating before other GPUs. Otherwise it'll be set to `False`.
        assistant_model (`PreTrainedModel`, *optional*):
            An assistant model that can be used to accelerate generation. The assistant model must have the exact
            same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
            is much faster than running generation with the model you're calling generate from. As such, the
            assistant model should be much smaller.
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        kwargs:
            Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
            forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
            specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

    Return:
        [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
        or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

            If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
            [`~utils.ModelOutput`] types are:

                - [`~generation.GreedySearchDecoderOnlyOutput`],
                - [`~generation.SampleDecoderOnlyOutput`],
                - [`~generation.BeamSearchDecoderOnlyOutput`],
                - [`~generation.BeamSampleDecoderOnlyOutput`]

            If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
            [`~utils.ModelOutput`] types are:

                - [`~generation.GreedySearchEncoderDecoderOutput`],
                - [`~generation.SampleEncoderDecoderOutput`],
                - [`~generation.BeamSearchEncoderDecoderOutput`],
                - [`~generation.BeamSampleEncoderDecoderOutput`]
    """

    if synced_gpus is None:
        if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
            synced_gpus = True
        else:
            synced_gpus = False

    # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
    self._validate_model_class()

    # priority: `generation_config` argument > `model.generation_config` (the default generation config)
    if generation_config is None:
        # legacy: users may modify the model configuration to control generation -- update the generation config
        # model attribute accordingly, if it was created from the model config
        if self.generation_config._from_model_config:
            new_generation_config = GenerationConfig.from_model_config(self.config)
            if new_generation_config != self.generation_config:
                warnings.warn(
                    "You have modified the pretrained model configuration to control generation. This is a"
                    " deprecated strategy to control generation and will be removed soon, in a future version."
                    " Please use a generation configuration file (see"
                    " https://huggingface.co/docs/transformers/main_classes/text_generation)"
                )
                self.generation_config = new_generation_config
        generation_config = self.generation_config

    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
    generation_config.validate()
    self._validate_model_kwargs(model_kwargs.copy())

    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
        if model_kwargs.get("attention_mask", None) is None:
            logger.warning(
                "The attention mask and the pad token id were not set. As a consequence, you may observe "
                "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
            )
        eos_token_id = generation_config.eos_token_id
        if isinstance(eos_token_id, list):
            eos_token_id = eos_token_id[0]
        logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
        generation_config.pad_token_id = eos_token_id

    # 3. Define model inputs
    # inputs_tensor has to be defined
    # model_input_name is defined if model-specific keyword input is passed
    # otherwise model_input_name is None
    # all model-specific keyword inputs are removed from `model_kwargs`
    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )
    batch_size = inputs_tensor.shape[0]

    # 4. Define other model kwargs
    model_kwargs["output_attentions"] = generation_config.output_attentions
    model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
    model_kwargs["use_cache"] = generation_config.use_cache

    accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
    requires_attention_mask = "encoder_outputs" not in model_kwargs

    if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
            inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
        )

    # decoder-only models should use left-padding for generation
    if not self.config.is_encoder_decoder:
        if (
                generation_config.pad_token_id is not None
                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
        ):
            logger.warning(
                "A decoder-only architecture is being used, but right-padding was detected! For correct "
                "generation results, please set `padding_side='left'` when initializing the tokenizer."
            )

    if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
        # if the model is an encoder-decoder, `encoder_outputs` are created
        # and added to `model_kwargs`
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
            inputs_tensor, model_kwargs, model_input_name
        )

    # 5. Prepare `input_ids` which will be used for auto-regressive generation
    if self.config.is_encoder_decoder:
        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
            batch_size=batch_size,
            model_input_name=model_input_name,
            model_kwargs=model_kwargs,
            decoder_start_token_id=generation_config.decoder_start_token_id,
            bos_token_id=generation_config.bos_token_id,
            device=inputs_tensor.device,
        )
    else:
        input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

    if streamer is not None:
        streamer.put(input_ids.cpu())

    # 6. Prepare `max_length` depending on other stopping criteria.
    input_ids_seq_length = input_ids.shape[-1]
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
            " recommend using `max_new_tokens` to control the maximum length of the generation.",
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        if not has_default_max_length:
            logger.warning(
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length

    if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
        raise ValueError(
            f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
            f" the maximum length ({generation_config.max_length})"
        )
    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`."
        )

    # 7. determine generation mode
    is_constraint_gen_mode = (
            generation_config.constraints is not None or generation_config.force_words_ids is not None
    )

    is_contrastive_search_gen_mode = (
            (generation_config.num_beams == 1)
            and generation_config.top_k is not None
            and generation_config.top_k > 1
            and generation_config.do_sample is False
            and generation_config.penalty_alpha is not None
            and generation_config.penalty_alpha > 0
    )

    is_greedy_gen_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is False
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
    )
    is_sample_gen_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is True
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
    )
    is_beam_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is False
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
    )
    is_beam_sample_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is True
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
    )
    is_group_beam_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups > 1)
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
    )
    is_assisted_gen_mode = False
    if assistant_model is not None:
        if not (is_greedy_gen_mode or is_sample_gen_mode):
            raise ValueError(
                "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
                "is only supported with Greedy Search and Sample."
            )
        is_assisted_gen_mode = True

    if generation_config.num_beam_groups > generation_config.num_beams:
        raise ValueError("`num_beam_groups` has to be smaller than or equal to `num_beams`")
    if is_group_beam_gen_mode and generation_config.do_sample is True:
        raise ValueError(
            "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
        )

    if streamer is not None and (generation_config.num_beams > 1):
        raise ValueError(
            "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
        )

    if self.device.type != input_ids.device.type:
        warnings.warn(
            "You are calling .generate() with the `input_ids` being on a device type different"
            f" from your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
            f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
            " Please make sure that you have moved `input_ids` to the"
            f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
            " running `.generate()`.",
            UserWarning,
        )

    # 8. prepare distribution pre_processing samplers
    logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=inputs_tensor,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    # 9. prepare stopping criteria
    stopping_criteria = self._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    # 10. go into different generation modes
    if is_assisted_gen_mode:
        if generation_config.num_return_sequences > 1:
            raise ValueError(
                "num_return_sequences has to be 1 when doing assisted generate, "
                f"but is {generation_config.num_return_sequences}."
            )
        if batch_size > 1:
            raise ValueError("assisted generate is only supported for batch_size = 1")
        if not model_kwargs["use_cache"]:
            raise ValueError("assisted generate requires `use_cache=True`")

        # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
        if assistant_model.config.is_encoder_decoder:
            assistant_model_kwargs = copy.deepcopy(model_kwargs)
            inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs(
                inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs
            )
            assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, assistant_model_kwargs, model_input_name
            )
            model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]

        # 12. run assisted generate
        return self.assisted_decoding(
            input_ids,
            assistant_model=assistant_model,
            do_sample=generation_config.do_sample,
            logits_processor=logits_processor,
            logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
            stopping_criteria=stopping_criteria,
            pad_token_id=generation_config.pad_token_id,
            eos_token_id=generation_config.eos_token_id,
            output_scores=generation_config.output_scores,
            return_dict_in_generate=generation_config.return_dict_in_generate,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )
    if is_greedy_gen_mode:
        if generation_config.num_return_sequences > 1:
            raise ValueError(
                "num_return_sequences has to be 1 when doing greedy search, "
                f"but is {generation_config.num_return_sequences}."
            )

        # 11. run greedy search
        return self.greedy_search(
            input_ids,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=generation_config.pad_token_id,
            eos_token_id=generation_config.eos_token_id,
            output_scores=generation_config.output_scores,
            return_dict_in_generate=generation_config.return_dict_in_generate,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )

    elif is_contrastive_search_gen_mode:
        if generation_config.num_return_sequences > 1:
            raise ValueError(
                "num_return_sequences has to be 1 when doing contrastive search, "
                f"but is {generation_config.num_return_sequences}."
            )
        if not model_kwargs["use_cache"]:
            raise ValueError("Contrastive search requires `use_cache=True`")

        return self.contrastive_search(
            input_ids,
            top_k=generation_config.top_k,
            penalty_alpha=generation_config.penalty_alpha,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=generation_config.pad_token_id,
            eos_token_id=generation_config.eos_token_id,
            output_scores=generation_config.output_scores,
            return_dict_in_generate=generation_config.return_dict_in_generate,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )

    elif is_sample_gen_mode:
        # 11. prepare logits warper
        logits_warper = self._get_logits_warper(generation_config)

        # 12. expand input_ids with `num_return_sequences` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        # 13. run sample
        return sample(
            self,
            input_ids,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            stopping_criteria=stopping_criteria,
            pad_token_id=generation_config.pad_token_id,
            eos_token_id=generation_config.eos_token_id,
            output_scores=generation_config.output_scores,
            return_dict_in_generate=generation_config.return_dict_in_generate,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )

    elif is_beam_gen_mode:
        if generation_config.num_return_sequences > generation_config.num_beams:
            raise ValueError("`num_return_sequences` has to be smaller than or equal to `num_beams`.")

        if stopping_criteria.max_length is None:
            raise ValueError("`max_length` needs to be a stopping_criteria for now.")

        # 11. prepare beam search scorer
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=inputs_tensor.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        # 12. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # 13. run beam search
        return self.beam_search(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=generation_config.pad_token_id,
            eos_token_id=generation_config.eos_token_id,
            output_scores=generation_config.output_scores,
            return_dict_in_generate=generation_config.return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif is_beam_sample_gen_mode:
        # 11. prepare logits warper
        logits_warper = self._get_logits_warper(generation_config)

        if stopping_criteria.max_length is None:
            raise ValueError("`max_length` needs to be a stopping_criteria for now.")
        # 12. prepare beam search scorer
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size * generation_config.num_return_sequences,
            num_beams=generation_config.num_beams,
            device=inputs_tensor.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            max_length=generation_config.max_length,
        )

        # 13. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams * generation_config.num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        # 14. run beam sample
        return self.beam_sample(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            stopping_criteria=stopping_criteria,
            pad_token_id=generation_config.pad_token_id,
            eos_token_id=generation_config.eos_token_id,
            output_scores=generation_config.output_scores,
            return_dict_in_generate=generation_config.return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif is_group_beam_gen_mode:
        if generation_config.num_return_sequences > generation_config.num_beams:
            raise ValueError("`num_return_sequences` has to be smaller than or equal to `num_beams`.")

        if generation_config.num_beams % generation_config.num_beam_groups != 0:
            raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

        if stopping_criteria.max_length is None:
            raise ValueError("`max_length` needs to be a stopping_criteria for now.")

        has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
        if not has_default_typical_p:
            raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")

        # 11. prepare beam search scorer
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=inputs_tensor.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            num_beam_groups=generation_config.num_beam_groups,
            max_length=generation_config.max_length,
        )
        # 12. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # 13. run group (diverse) beam search
        return self.group_beam_search(
            input_ids,
            beam_scorer,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=generation_config.pad_token_id,
            eos_token_id=generation_config.eos_token_id,
            output_scores=generation_config.output_scores,
            return_dict_in_generate=generation_config.return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif is_constraint_gen_mode:
        if generation_config.num_return_sequences > generation_config.num_beams:
            raise ValueError("`num_return_sequences` has to be smaller than or equal to `num_beams`.")

        if stopping_criteria.max_length is None:
            raise ValueError("`max_length` needs to be a stopping_criteria for now.")

        if generation_config.num_beams <= 1:
            raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")

        if generation_config.do_sample:
            raise ValueError("`do_sample` needs to be false for constrained generation.")

        if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
            raise ValueError("`num_beam_groups` not supported yet for constrained generation.")

        final_constraints = []
        if generation_config.constraints is not None:
            final_constraints = generation_config.constraints

        if generation_config.force_words_ids is not None:

            def typeerror():
                raise ValueError(
                    "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
                    f" of non-negative integers, but is {generation_config.force_words_ids}."
                )

            if (
                    not isinstance(generation_config.force_words_ids, list)
                    or len(generation_config.force_words_ids) == 0
            ):
                typeerror()

            for word_ids in generation_config.force_words_ids:
                # validate `word_ids` before indexing into it, so that an empty or non-list entry
                # raises the descriptive error above instead of an IndexError/TypeError
                if not isinstance(word_ids, list) or len(word_ids) == 0:
                    typeerror()
                if isinstance(word_ids[0], list):
                    # disjunctive constraint: a list of alternative token-id sequences
                    if any(not isinstance(token_ids, list) for token_ids in word_ids):
                        typeerror()
                    if any(
                            any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                            for token_ids in word_ids
                    ):
                        typeerror()

                    constraint = DisjunctiveConstraint(word_ids)
                else:
                    # phrasal constraint: a single token-id sequence that must appear as given
                    if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                        typeerror()

                    constraint = PhrasalConstraint(word_ids)
                final_constraints.append(constraint)

        # 11. prepare beam search scorer
        constrained_beam_scorer = ConstrainedBeamSearchScorer(
            constraints=final_constraints,
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=inputs_tensor.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        # 12. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # 13. run constrained beam search
        return self.constrained_beam_search(
            input_ids,
            constrained_beam_scorer=constrained_beam_scorer,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=generation_config.pad_token_id,
            eos_token_id=generation_config.eos_token_id,
            output_scores=generation_config.output_scores,
            return_dict_in_generate=generation_config.return_dict_in_generate,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )
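

# Illustrative sketch, not called anywhere in this module: the decision table that step 7 of
# `generate` uses to pick a decoding strategy, restated as a pure function over plain values
# so it can be read and tested without a model. The function name, its keyword arguments and
# the returned mode labels are hypothetical; only the conditions and their order mirror `generate`.
def _sketch_select_generation_mode(
    num_beams=1,
    num_beam_groups=1,
    do_sample=False,
    top_k=None,
    penalty_alpha=None,
    has_constraints=False,  # stands in for `constraints` or `force_words_ids` being set
    use_assistant=False,  # stands in for `assistant_model is not None`
):
    constrained = has_constraints
    contrastive = (
        num_beams == 1
        and top_k is not None
        and top_k > 1
        and not do_sample
        and penalty_alpha is not None
        and penalty_alpha > 0
    )
    # the five "plain" modes below all exclude constrained and contrastive generation
    plain = not constrained and not contrastive
    greedy = num_beams == 1 and num_beam_groups == 1 and not do_sample and plain
    sampling = num_beams == 1 and num_beam_groups == 1 and do_sample and plain
    beam = num_beams > 1 and num_beam_groups == 1 and not do_sample and plain
    beam_sample = num_beams > 1 and num_beam_groups == 1 and do_sample and plain
    group_beam = num_beams > 1 and num_beam_groups > 1 and plain

    # the same validation `generate` performs before dispatching
    if use_assistant and not (greedy or sampling):
        raise ValueError("assisted generation is only supported with greedy search and sampling")
    if num_beam_groups > num_beams:
        raise ValueError("`num_beam_groups` has to be smaller than or equal to `num_beams`")
    if group_beam and do_sample:
        raise ValueError("diverse (group) beam search cannot be used with `do_sample=True`")

    # branches are checked in the same order as the if/elif chain in `generate`; the first match wins
    if use_assistant:
        return "assisted"
    for mode, selected in (
        ("greedy", greedy),
        ("contrastive", contrastive),
        ("sample", sampling),
        ("beam", beam),
        ("beam_sample", beam_sample),
        ("group_beam", group_beam),
        ("constrained", constrained),
    ):
        if selected:
            return mode
    # nothing matched (e.g. num_beam_groups=0); `generate` would then fall through and return None
    return None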


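# Illustrative sketch, not used by `sample` below: one multinomial sampling step over a single
# row of scores, applying a top-k filter and then temperature scaling, i.e. roughly what a warper
# list such as [TopKLogitsWarper(k), TemperatureLogitsWarper(t)] followed by softmax and
# `torch.multinomial` does per sequence. Written in plain Python for readability; the function
# name and signature are hypothetical.
def _sketch_top_k_temperature_sample(logits, top_k, temperature, rng=None):
    import math
    import random

    if top_k < 1 or temperature <= 0:
        raise ValueError("`top_k` must be >= 1 and `temperature` must be > 0")
    rng = rng if rng is not None else random.Random()

    # top-k: keep every score >= the k-th largest (ties can keep a few extra), mask the rest with -inf
    kth_largest = sorted(logits, reverse=True)[min(top_k, len(logits)) - 1]
    # temperature: divide the surviving scores; t < 1 sharpens the distribution, t > 1 flattens it
    warped = [score / temperature if score >= kth_largest else -math.inf for score in logits]

    # numerically stable softmax; exp(-inf) is 0.0, so masked tokens get zero probability
    peak = max(warped)
    weights = [math.exp(score - peak) for score in warped]
    total = sum(weights)
    probs = [w / total for w in weights]

    # draw one token id according to `probs`
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]
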
def sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
    can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    <Tip warning={true}>

    In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
    For an overview of generation strategies and code examples, check the [following
    guide](../generation_strategies).

    </Tip>

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`, *optional*):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        logits_warper (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
            to warp the prediction score distribution of the language modeling head applied before multinomial
            sampling at each generation step.
        max_length (`int`, *optional*, defaults to 20):
            **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
            tokens. The maximum length of the sequence to be generated.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token.
        eos_token_id (`Union[int, List[int]]`, *optional*):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
        output_attentions (`bool`, *optional*, defaults to `False`):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more details.
        output_hidden_states (`bool`, *optional*, defaults to `False`):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more details.
        output_scores (`bool`, *optional*, defaults to `False`):
            Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
        return_dict_in_generate (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        synced_gpus (`bool`, *optional*, defaults to `False`):
            Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`:
        A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True`, or a [`~generation.SampleEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True` and `return_dict_in_generate=True`.

    Examples:

    ```python
    >>> from transformers import (
    ...     AutoTokenizer,
    ...     AutoModelForCausalLM,
    ...     LogitsProcessorList,
    ...     MinLengthLogitsProcessor,
    ...     TopKLogitsWarper,
    ...     TemperatureLogitsWarper,
    ...     StoppingCriteriaList,
    ...     MaxLengthCriteria,
    ... )
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")

    >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
    >>> model.config.pad_token_id = model.config.eos_token_id
    >>> model.generation_config.pad_token_id = model.config.eos_token_id

    >>> input_prompt = "Today is a beautiful day, and"
    >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

    >>> # instantiate logits processors
    >>> logits_processor = LogitsProcessorList(
    ...     [
    ...         MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
    ...     ]
    ... )
    >>> # instantiate logits warpers
    >>> logits_warper = LogitsProcessorList(
    ...     [
    ...         TopKLogitsWarper(50),
    ...         TemperatureLogitsWarper(0.7),
    ...     ]
    ... )

    >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

    >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
    >>> outputs = model.sample(
    ...     input_ids,
    ...     logits_processor=logits_processor,
    ...     logits_warper=logits_warper,
    ...     stopping_criteria=stopping_criteria,
    ... )

    >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
    ['Today is a beautiful day, and we must do everything possible to make it a day of celebration.']
    ```"""
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

    this_peer_finished = False  # used by synced_gpus only
    # auto-regressive generation
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
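                # store the raw logits instead of the processed `next_token_scores`, so that
                # downstream log-probabilities are computed from the model's unwarped distribution
                scores += (next_token_logits,)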
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
            probs = torch.nan_to_num(probs)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )

            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        # stop if any stopping criterion (e.g. the maximum length) is met
        if stopping_criteria(input_ids, scores):
            this_peer_finished = True

        if this_peer_finished and not synced_gpus:
            break

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return SampleEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return SampleDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return input_ids
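
# Illustrative sketch (not part of the original module; nothing calls it): a plain-Python
# analogue of how `sample` handles finished rows at each step. Rows that are already finished
# emit `pad_token_id`, and a row stays unfinished only while its new token differs from every
# id in `eos_token_id`. The helper name `_mask_finished_step` is hypothetical; the tensor code
# above does the same with `unfinished_sequences`, `tile`, `ne` and `prod`.
def _mask_finished_step(next_tokens, unfinished, eos_ids, pad_token_id):
    # finished rows (unfinished == 0) get pad_token_id instead of the sampled token
    masked = [t * u + pad_token_id * (1 - u) for t, u in zip(next_tokens, unfinished)]
    # a row remains open (1) only if it was open and its token matches no EOS id
    still_open = [u * int(all(t != e for e in eos_ids)) for t, u in zip(masked, unfinished)]
    return masked, still_open

# e.g. _mask_finished_step([5, 2, 7], [1, 1, 0], [2], 0) returns ([5, 2, 0], [1, 0, 0])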


================================================
FILE: scripts/eval.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import json
import logging
import os

import numpy as np
import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from datasets import load_dataset
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM
)

import custom_generate
from eval_metric import compute_metric_stmt
from eval_utils import compute_mean_logp

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

logger = logging.getLogger(__name__)

COMMENT_SYMBOL = {
    "python": "#",
    "java": "//",
    "csharp": "//",
    "typescript": "//"
}


def custom_data_collator(features):
    # collate a list of feature dicts into a batch: numeric fields become tensors,
    # string fields are kept as Python lists, and None-valued fields are dropped
    first = features[0]
    batch = {}
    for k, v in first.items():
        if v is None:
            continue
        if isinstance(v, str):
            batch[k] = [f[k] for f in features]
        elif isinstance(v, torch.Tensor):
            batch[k] = torch.stack([f[k] for f in features])
        elif isinstance(v, np.ndarray):
            batch[k] = torch.tensor(np.stack([f[k] for f in features]))
        else:
            batch[k] = torch.tensor([f[k] for f in features])

    return batch


def build_datasets(args, tokenizer):
    # Configure the tokenizer for generation: the next token is predicted from the logits
    # of the right-most position, so padding must go on the left
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else tokenizer.bos_token

    # load the files into Dataset
    raw_datasets = load_dataset("json", data_files=args.prompt_file, cache_dir=args.cache_dir)
    raw_datasets = raw_datasets["train"]
    raw_datasets = raw_datasets.map(lambda example, idx: {'index': idx, **example}, with_indices=True)
    index2taskid = {idx: md["task_id"] for idx, md in zip(raw_datasets["index"], raw_datasets["metadata"])}
    column_names = raw_datasets.column_names

    # Prompt composition
    def prepare_features(examples):
        tokenizer.truncation_side = "left"
        tokenized_inputs = tokenizer(
            examples["prompt"],
            padding="max_length",
            truncation=True,
            max_length=args.max_seq_length - args.gen_length
        )

        features = {k: t for k, t in tokenized_inputs.items()}
        features["index"] = examples["index"]
        return features

    def prepare_features_cfc(examples):
        max_prompt_length = args.max_seq_length - args.gen_length
        use_key = "list"

        crossfile_context = []
        if use_key == "text":
            crossfile_context = [ex["text"] for ex in examples["crossfile_context"]]
        else:
            ls_sym = COMMENT_SYMBOL[args.language]
            num_chunk_inc_prompt = []
            augmented_prompt = 0
            for cfc_chunks in examples["crossfile_context"]:
                cfc_chunks = cfc_chunks["list"]  # a list of dict
                cfc_text = ""
                if cfc_chunks:
                    # at least 1 relevant cfc_chunk found
                    init_cfc_text = f"{ls_sym} Here are some relevant code fragments from other files of the repo:\n\n"
                    cfc_length = len(tokenizer.tokenize(init_cfc_text))
                    num_chunk_inc = 0
                    for cfc_idx, cfc_chunk in enumerate(cfc_chunks):
                        if cfc_chunk["score"] > args.min_cfc_score:
                            add_text = f"{ls_sym} the below code fragment is found in {cfc_chunk['filename']}" + "\n"
                            cfc_lines = cfc_chunk["retrieved_chunk"].split('\n')
                            add_text += "\n".join([f"{ls_sym} {cl}" for cl in cfc_lines if cl]) + "\n\n"
                            # check if adding chunk exceeds max length budget for CFC
                            add_text_len = len(tokenizer.tokenize(add_text))
                            if cfc_length + add_text_len <= args.cfc_seq_length:
                                cfc_text += add_text
                                cfc_length += add_text_len
                                num_chunk_inc += 1
                            else:
                                break
                    num_chunk_inc_prompt.append(num_chunk_inc)
                    if num_chunk_inc > 0:
                        cfc_text = init_cfc_text + cfc_text
                        augmented_prompt += 1
                crossfile_context.append(cfc_text)

            logger.info(
                f"{augmented_prompt} out of {len(examples['crossfile_context'])} prompts are augmented with cross-file context.")

        tokenizer.truncation_side = "right"
        crossfile_features = tokenizer(
            crossfile_context,
            truncation=True,
            max_length=args.cfc_seq_length
        )

        features = {"input_ids": [], "attention_mask": []}
        tokenizer.truncation_side = "left"
        for idx, prompt in enumerate(examples["prompt"]):
            allowed_prompt_length = max_prompt_length - len(crossfile_features["input_ids"][idx])
            prompt_feats = tokenizer(
                [prompt],
                truncation=True,
                max_length=allowed_prompt_length
            )
            for k, v in prompt_feats.items():
                features[k].append(crossfile_features[k][idx] + prompt_feats[k][0])

        # pad to max_seq_length
        tokenizer.padding_side = "left"
        features = tokenizer.pad(features, padding="max_length", max_length=args.max_seq_length - args.gen_length)
        features["index"] = examples["index"]
        return features

    if args.model_type in ["codelm", "seq2seqlm"]:
        tokenized_datasets = raw_datasets.map(
            prepare_features,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
    elif args.model_type == "codelm_cfc":
        tokenized_datasets = raw_datasets.map(
            prepare_features_cfc,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
    else:
        raise NotImplementedError("prepare feature functions not implemented for new model type")

    return tokenized_datasets, index2taskid
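
# Illustrative sketch (hypothetical helper, not used by the script): the token budget that
# `prepare_features_cfc` applies, restated over plain lists of token ids. Cross-file context is
# right-truncated to `cfc_seq_length`, the in-file prompt is left-truncated to whatever is left
# of `max_seq_length - gen_length`, and the concatenation is left-padded to that prompt budget.
# It ignores tokenizer details such as special tokens added by `tokenizer(...)`.
def _budget_cfc_prompt(cfc_ids, prompt_ids, max_seq_length, gen_length, cfc_seq_length, pad_id):
    max_prompt_length = max_seq_length - gen_length
    cfc_ids = cfc_ids[:cfc_seq_length]  # truncation_side = "right": keep the start of the CFC
    allowed = max_prompt_length - len(cfc_ids)
    # truncation_side = "left": keep the end of the prompt, closest to the completion point
    prompt_ids = prompt_ids[-allowed:] if allowed > 0 else []
    ids = cfc_ids + prompt_ids
    return [pad_id] * (max_prompt_length - len(ids)) + ids  # padding_side = "left"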


def model_inference(tokenized_datasets, index2taskid, tokenizer):
    if args.dtype == 'fp16':
        dtype = torch.float16
    elif args.dtype == 'fp32':
        dtype = torch.float32
    elif args.dtype == 'bf16':
        dtype = torch.bfloat16
    elif args.dtype == 'int8':
        dtype = torch.int8
    else:
        raise ValueError(f"{args.dtype=} not implemented")

    if args.model_type in ["codelm", "codelm_cfc"]:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            torch_dtype=dtype,
            trust_remote_code=True,
            revision="main"
        )
    else:
        raise ValueError("Unknown model type")

    total_samples_cnt = len(tokenized_datasets)
    logger.info(f"total samples: {total_samples_cnt}")

    data_sampler = SequentialSampler(tokenized_datasets)
    dataloader = DataLoader(
        tokenized_datasets,
        sampler=data_sampler,
        collate_fn=custom_data_collator,
        batch_size=args.batch_size
    )

    model = accelerator.prepare_model(model)
    dataloader = accelerator.prepare_data_loader(dataloader)

    # exist_ok avoids a race when several processes create the directory at the same time
    os.makedirs(args.output_dir, exist_ok=True)

    tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else tokenizer.bos_token
    prompt_length = args.max_seq_length - args.gen_length

    @torch.no_grad()
    def generate_completions(batch):
        output_dict = custom_generate.generate(
            accelerator.unwrap_model(model),
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_length=args.max_seq_length,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            do_sample=args.do_sample,
            num_beams=args.num_beams,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            return_dict_in_generate=True,
            output_scores=True
        )
        batch_task_id = batch["index"]
        batch_pred = accelerator.pad_across_processes(
            output_dict.sequences, dim=1, pad_index=tokenizer.pad_token_id
        )
        scores = torch.stack(output_dict.scores, dim=1)
        batch_scores = accelerator.pad_across_processes(
            scores, dim=1, pad_index=tokenizer.pad_token_id
        )
        # after gathering, batch_scores.shape = (batch_size * num_processes, num_generated_steps, vocab_size)
        batch_task_id, batch_pred, batch_scores = accelerator.gather((batch_task_id, batch_pred, batch_scores))

        batch_pred = batch_pred[:, prompt_length:]
        generated_texts = tokenizer.batch_decode(batch_pred, skip_special_tokens=True)

        mean_logp = compute_mean_logp(batch_scores, batch_pred, tokenizer.pad_token_id)
        return batch_task_id.tolist(), generated_texts, mean_logp

    all_preds = []
    all_task_ids = []
    with torch.no_grad():
        for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
            completions = None
            completion_scores = None
            for seq_idx in range(args.num_return_sequences):
                batch_task_id, generated_texts, mean_logp = generate_completions(batch)
                if seq_idx == 0:
                    all_task_ids.extend(batch_task_id)
                    batch_size = len(batch_task_id)
                    completions = [[] for _ in range(batch_size)]
                    completion_scores = [[] for _ in range(batch_size)]

                for j in range(batch_size):
                    completions[j].append(generated_texts[j])
                    completion_scores[j].append(mean_logp[j])

            if args.num_return_sequences == 1:
                all_preds.extend([c[0] for c in completions])
            else:
                for c, cs in zip(completions, completion_scores):
                    max_score = max(cs)
                    max_index = cs.index(max_score)
                    all_preds.append(c[max_index])

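    # after `accelerator.gather` every process holds the same predictions, so only the main
    # process writes the file; the prepared dataloader may repeat samples to even out the last
    # batch across processes, hence the de-duplication by task id
    if accelerator.is_main_process:
        with open(f"{args.output_dir}/prediction.jsonl", "w", encoding="utf-8") as f_pred:
            id_processed = set()
            for idx, p in zip(all_task_ids, all_preds):
                if index2taskid[idx] not in id_processed:
                    f_pred.write(json.dumps({"task_id": index2taskid[idx], "pred": p}) + "\n")
                    id_processed.add(index2taskid[idx])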


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # model inference args
    parser.add_argument("--language", type=str, required=True, help="language name")
    parser.add_argument("--model_name_or_path", default=None, type=str, help="Pre-trained Model Path")
    parser.add_argument(
        "--model_type",
        type=str,
        default="codelm",
        choices=["codelm", "codelm_cfc"],
        help="Model type to be loaded"
    )
    parser.add_argument("--prompt_file", type=str, default=None, help="file with a list of prompts")
    parser.add_argument("--gen_length", type=int, default=50, help="max length of generated token sequence")
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=2048,
        help="max total sequence length (prompt + generated tokens)"
    )
    parser.add_argument(
        "--cfc_seq_length",
        type=int,
        default=512,
        help="For model_type=codelm_cfc: Text sequence length corresponding to the retrieved nodes"
    )
    parser.add_argument(
        "--min_cfc_score",
        type=float,
        default=float('-inf'),
        help="For model_type=codelm_cfc: min score of a chunk to be considered as CFC chunk"
    )
    parser.add_argument("--batch_size", type=int, default=32, help="batch size for code completion")
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.2,
        help="sampling temperature; 1.0 has no effect, lower values tend toward greedy sampling"
    )
    parser.add_argument("--output_dir", type=str, default="output_dir", help="output directory to save predictions")
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="The parameter for repetition penalty.")
    parser.add_argument(
        "--preprocessing_num_workers",
        type=int,
        default=1,
        help="The number of processes to use for the preprocessing."
    )
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default='bf16',
        choices=["fp16", "fp32", "bf16", "int8"],
        help="dtype used to load the model weights"
    )
    parser.add_argument("--do_sample", action="store_true", help="whether we do sampling or greedy/beam-search")
    parser.add_argument("--num_beams", type=int, default=1, help="num of beam for beam-search")
    # compute metric args
    parser.add_argument(
        "--ts_lib",
        type=str,
        default="build/python-lang-parser.so",
        help="tree-sitter library used to parse code"
    )
    # only compute metric
    parser.add_argument("--only_compute_metric", action="store_true", help="only compute metric")
    args = parser.parse_args()
    set_seed(args.seed, device_specific=False)

    if args.num_return_sequences > 1:
        assert args.do_sample, "sampling must be set to True when num_return_sequences > 1"

    accelerator = Accelerator()
    if not args.only_compute_metric:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        tokenized_datasets, index2taskid = build_datasets(args, tokenizer)
        model_inference(tokenized_datasets, index2taskid, tokenizer)

    # check if the process is the main process
    if accelerator.is_main_process:
        compute_metric_stmt(args)


================================================
FILE: scripts/eval_metric.py
================================================
import json
import os
from functools import partial

import torch.multiprocessing as mp
from tqdm import tqdm
from tree_sitter import Language, Parser

from eval_utils import (
    postprocess_code_lines,
    extract_identifiers,
    cal_edit_sim,
    remove_comments
)

parser = None


def compute_id_match(pred_ids, target_ids):
    # count identifier matches over the sets of predicted and target identifiers
    pred_ids = set(pred_ids)
    target_ids = set(target_ids)
    tp = len(pred_ids & target_ids)  # in both prediction and target
    fp = len(pred_ids - target_ids)  # predicted but not in the target
    fn = len(target_ids - pred_ids)  # in the target but not predicted
    return tp, fp, fn


def compute_edit_sim(samples):
    refs, hyps = [], []
    for s in samples:
        refs.append(s["target"])
        hyps.append(s["pred"])
    return cal_edit_sim(refs, hyps)


def process_examples(lang, args):
    sample, ex = args
    global parser

    prediction = postprocess_code_lines(ex["prompt"], sample["pred"], parser, lang)
    prediction = remove_comments(prediction)
    target = ex["groundtruth"]
    target = remove_comments(target)

    pred_lines = [l.strip() for l in prediction.split("\n") if l.strip()]
    gt_lines = [l.strip() for l in target.split("\n") if l.strip()]
    em_label = int(pred_lines == gt_lines)

    pred_ids = extract_identifiers(prediction, lang)
    target_ids = extract_identifiers(target, lang)

    trunc_s = {
        "task_id": sample["task_id"],
        "pred": prediction,
        "target": target,
        "pred_ids": pred_ids,
        "target_ids": target_ids
    }
    return trunc_s, em_label


def compute_metric_stmt(args):
    with open(os.path.join(args.output_dir, "prediction.jsonl"), "r") as f_pred:
        samples = []
        for l in f_pred.readlines():
            samples.append(json.loads(l))

    examples = {}
    with open(args.prompt_file, "r") as f_in:
        for l in f_in.readlines():
            ex = json.loads(l)
            examples[ex["metadata"]["task_id"]] = {
                "prompt": ex["prompt"],
                "groundtruth": ex["groundtruth"]
            }

    assert len(samples) == len(examples), f"{len(samples)} != {len(examples)}"

    global parser
    ts_lang = "c_sharp" if args.language == "csharp" else args.language
    language = Language(args.ts_lib, ts_lang)
    parser = Parser()
    parser.set_language(language)

    truncated_samples = []
    em_labels = []

    print("post-processing samples ...")
    # keep at least one worker so the pool also works on single-CPU machines
    pool = mp.Pool(max(1, mp.cpu_count() - 1))
    worker = partial(process_examples, args.language)

    with tqdm(total=len(samples)) as pbar:
        for output in pool.imap_unordered(worker, zip(samples, [examples[s["task_id"]] for s in samples])):
            trunc_s, em_label = output
            em_labels.append(em_label)
            truncated_samples.append(trunc_s)
            pbar.update()
    pool.close()
    pool.join()

    exact_match = 0
    with open(os.path.join(args.output_dir, "prediction_truncated.jsonl"), 'w', encoding="utf-8") as pt, \
            open(f"{args.output_dir}/exact_match_idx.jsonl", 'w') as em:
        for trunc_s, em_label in zip(truncated_samples, em_labels):
            pt.write(json.dumps(trunc_s) + "\n")
            if em_label == 1:
                em.write(f'{trunc_s["task_id"]}\n')
                exact_match += 1

    ### Score calculation

    id_em = []
    edit_similarities = []
    detailed_results = []

    for idx, trunc_s in enumerate(truncated_samples):
        identifier_em = int(trunc_s["pred_ids"] == trunc_s["target_ids"])
        es = cal_edit_sim([trunc_s["target"]], [trunc_s["pred"]])
        id_tp, id_fp, id_fn = compute_id_match(trunc_s["pred_ids"], trunc_s["target_ids"])
        id_em.append(identifier_em)
        edit_similarities.append(es)

        detailed_results.append({
            "task_id": trunc_s["task_id"],
            "em": em_labels[idx],
            "es": es,
            "id_em": identifier_em,
            "id_precision": id_tp / (id_tp + id_fp) if (id_tp + id_fp) != 0 else 0,
            "id_recall": id_tp / (id_tp + id_fn) if (id_tp + id_fn) != 0 else 0,
            "id_f1": 2 * id_tp / (2 * id_tp + id_fp + id_fn) if (2 * id_tp + id_fp + id_fn) != 0 else 0,
        })

    em_ratio = round(exact_match / len(samples) * 100, 2)
    edit_sim = round(sum(edit_similarities) / len(edit_similarities), 2)

    def macro_avg_pct(key):
        # average a per-sample metric over all samples and express it as a percentage
        return round(sum(dr[key] for dr in detailed_results) / len(detailed_results) * 100, 2)

    id_em_ratio = macro_avg_pct('id_em')
    id_precision = macro_avg_pct('id_precision')
    id_recall = macro_avg_pct('id_recall')
    id_f1 = macro_avg_pct('id_f1')

    print(
        f"Code Matching: "
        f"EM {em_ratio:.2f}, "
        f"ES {edit_sim:.2f}"
    )

    print(
        f"ID matching: "
        f"EM {id_em_ratio}, "
        #f"Precision {id_precision}, "
        #f"Recall {id_recall}, "
        f"F1 {id_f1}"
    )

    with open(os.path.join(args.output_dir, "detailed_results.json"), 'w') as f:
        for dr in detailed_results:
            f.write(json.dumps(dr) + "\n")

    # write the results to a file
    print(f'writing results to {os.path.join(args.output_dir, "results.json")}')
    with open(os.path.join(args.output_dir, "results.json"), 'w') as f:
        res = {
            "em": em_ratio,
            "es": edit_sim,
            "id_em": id_em_ratio,
            "id_precision": id_precision,
            "id_recall": id_recall,
            "id_f1": id_f1,
            "total": len(truncated_samples)
        }
        f.write(json.dumps(res, indent=2))
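
# Illustrative sketch (hypothetical helper, not called by the script): the per-sample
# identifier precision/recall/F1 that `compute_metric_stmt` stores in `detailed_results`.
# They are computed over the *sets* of identifiers, so duplicates and order do not matter
# (unlike `id_em`, which compares the ordered lists). The reported id_precision, id_recall and
# id_f1 are macro averages of these per-sample values, scaled to percentages.
def _identifier_prf(pred_ids, target_ids):
    pred, target = set(pred_ids), set(target_ids)
    tp = len(pred & target)
    fp = len(pred - target)
    fn = len(target - pred)
    precision = tp / (tp + fp) if tp + fp else 0
    recall = tp / (tp + fn) if tp + fn else 0
    f1 = 2 * tp / (2 * tp + fp + fn) if 2 * tp + fp + fn else 0
    return precision, recall, f1

# e.g. _identifier_prf(["a", "b", "b"], ["b", "c"]) returns (0.5, 0.5, 0.5)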


================================================
FILE: scripts/eval_utils.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import ast
import re
from functools import lru_cache
from typing import List

import timeout_decorator
import torch
from fuzzywuzzy import fuzz
from nltk.tokenize import RegexpTokenizer
from sacrebleu.tokenizers.tokenizer_intl import TokenizerV14International

from keywords.keywordlist import get_language_keywords

IDENTIFIER_REGEX = re.compile('[_a-zA-Z][_a-zA-Z0-9]*')
REGEX_TEXT = ("(?<=[a-z0-9])(?=[A-Z])|"
              "(?<=[A-Z0-9])(?=[A-Z][a-z])|"
              "(?<=[0-9])(?=[a-zA-Z])|"
              "(?<=[A-Za-z])(?=[0-9])|"
              "(?<=[@$.'\"])(?=[a-zA-Z0-9])|"
              "(?<=[a-zA-Z0-9])(?=[@$.'\"])|"
              "_|\\s+")
string_pattern = r'"([^"\\]*(\\.[^"\\]*)*)"|\'([^\'\\]*(\\.[^\'\\]*)*)\''

SPLIT_REGEX = re.compile(REGEX_TEXT)

str_tokenizer = TokenizerV14International()
code_tokenizer = RegexpTokenizer(r'\w+')


def cal_edit_sim(references, hypotheses):
    total = len(references)
    edit_sim = 0.0
    for pred, gt in zip(hypotheses, references):
        pred = pred.strip()
        gt = gt.strip()
        edit_sim += fuzz.ratio(pred, gt)
    return edit_sim / total


@lru_cache(maxsize=5000)
def split_identifier_into_parts(identifier: str) -> List[str]:
    """
    Split a single identifier into parts on snake_case and camelCase
    """
    identifier_parts = list(s for s in SPLIT_REGEX.split(identifier) if len(s) > 0)

    if len(identifier_parts) == 0:
        return [identifier]
    if "_" in identifier:  # We consider "_" as part of identifier and add it back in between each semantic part
        # if snake_case, we only split identifiers based on "_", ignore the mixed camelCase or other special symbols
        # this helps us avoid splitting identifiers like "get_2d_array" into ["get", "2", "d", "array"]
        # also avoid many other corner cases
        identifier_parts = identifier.split("_")
        tmp = [identifier_parts[0]]
        for i in identifier_parts[1:]:
            tmp.append("_")
            tmp.append(i)
        identifier_parts = tmp

    return identifier_parts


def is_identifier(token, lang=None):
    # a token is an identifier if it matches the identifier regex and is not a keyword of `lang`
    return bool(IDENTIFIER_REGEX.match(token)) and (
        lang is None or token not in get_language_keywords(lang)
    )


def extract_identifiers(source_code, lang):
    # the main idea is to remove string literals from the source code,
    # then tokenize the code into words and match them against the identifier regular expression;
    # a word that is not a language-specific keyword is treated as an identifier
    source_code_without_strings = re.sub(string_pattern, '', source_code)
    _ids = [t for t in code_tokenizer.tokenize(source_code_without_strings) if is_identifier(t, lang)]
    return _ids


def tokenize_string(input_str):
    return str_tokenizer(input_str)


def get_bracket_lang_statement(completion):
    # keep the completion up to and including the first ";", "{" or "}" (even at position 0)
    for i, ch in enumerate(completion):
        if ch in ";{}":
            return completion[:i + 1]
    return completion


@timeout_decorator.timeout(5)
def get_ast(parser, code):
    assert isinstance(code, str) or isinstance(code, bytes)
    if isinstance(code, str):
        code = bytes(code, "utf8")
    try:
        tree = parser.parse(code)
        return tree
    except Exception as e:
        return None


def remove_comments(code):
    code = re.sub(r'#.*', '', code)
    code = re.sub(r'//.*', '', code)
    return code


def is_parse_valid(parser, code):
    def syntax_error(node):
        if node.type == "ERROR":
            return True
        try:
            for child in node.children:
                if syntax_error(child):
                    return True
        except RecursionError as err:
            return True

        return False

    tree = get_ast(parser, code)
    if tree is not None:
        return not syntax_error(tree.root_node)
    return False


def is_code_parseable(code):
    try:
        ast.parse(code)
        return True
    except SyntaxError:
        return False


def get_python_one_statement(prompt, completion, parser):
    for i in range(len(completion)):
        code = prompt + completion[:i + 1]
        if not is_parse_valid(parser, code):
            continue
        if i + 1 < len(completion) and completion[i + 1] == "\n":
            return completion[:i + 1].rstrip()

    return completion
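

# Illustrative sketch (hypothetical helper): the same "shortest parseable
# prefix that is followed by a newline" idea as get_python_one_statement,
# but using Python's built-in `ast` module instead of a tree-sitter parser.
# This swaps in a different parser for illustration only; the evaluation
# itself uses tree-sitter via is_parse_valid().
import ast


def sketch_python_one_statement(prompt, completion):
    for i in range(len(completion)):
        try:
            ast.parse(prompt + completion[:i + 1])
        except SyntaxError:
            continue  # prefix is not valid Python yet; keep extending it
        # stop at the first valid prefix that ends a line
        if i + 1 < len(completion) and completion[i + 1] == "\n":
            return completion[:i + 1].rstrip()
    return completion

# Example: sketch_python_one_statement("x = 1\ny = ", "2\nz = 3\n") returns "2"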


def postprocess_code_lines(prompt, completion, parser, lang):
    try:
        if lang in ["java", "csharp", "typescript"]:
            return get_bracket_lang_statement(completion)
        elif lang == "python":
            return get_python_one_statement(prompt, completion, parser)
    except Exception:
        # fall back to the raw completion if post-processing fails
        return completion
    # other languages are returned unchanged
    return completion


def compute_mean_logp(scores, sequences, pad_token_id):
    assert scores.shape[0] == sequences.shape[0]
    assert scores.shape[1] == sequences.shape[1]
    with torch.no_grad():
        logp_vocab = torch.nn.functional.log_softmax(scores, dim=-1)
        indices = torch.unsqueeze(sequences, dim=-1)
        logp = torch.gather(logp_vocab, dim=-1, index=indices).squeeze(-1)
        sum_logp = torch.cumsum(logp, dim=1)  # batch_size, seq_len
        denom = torch.arange(1, sum_logp.shape[1] + 1).reshape(1, -1).to(device=sum_logp.device)  # 1, seq_len
        mean_logp = (sum_logp / denom).tolist()  # batch_size, seq_len
        # number of non-pad tokens per row; assumes padding only at the end of each sequence
        sequence_lengths = (sequences != pad_token_id).sum(1).tolist()  # batch_size
        mean_logp = [mean_logp[idx][l - 1] for idx, l in enumerate(sequence_lengths)]
    return mean_logp
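

# Illustrative sketch (hypothetical helper, plain Python instead of torch):
# what compute_mean_logp computes for a single sequence. Each step's logits
# go through a log-softmax, the log-probability of the generated token is
# picked out, and the values are averaged over the non-pad positions. Like the
# torch version, it assumes padding appears only at the end of the sequence;
# it additionally assumes at least one non-pad token.
import math


def sketch_mean_logp(step_logits, token_ids, pad_token_id):
    length = sum(1 for t in token_ids if t != pad_token_id)
    if length == 0:
        raise ValueError("sequence contains only padding")
    total = 0.0
    for logits, tok in zip(step_logits[:length], token_ids[:length]):
        # log-softmax of the chosen token, shifted by the max for numerical stability
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        total += logits[tok] - log_z
    return total / length

# Example: with two equally likely tokens at every step,
# sketch_mean_logp([[0.0, 0.0], [0.0, 0.0]], [1, 0], pad_token_id=0) is approximately math.log(0.5)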


================================================
FILE: scripts/keywords/__init__.py
================================================


================================================
FILE: scripts/keywords/csharp.txt
================================================
abstract
as
base
bool
break
byte
case
catch
char
checked
class
const
continue
decimal
default
delegate
do
double
else
enum
event
explicit
extern
finally
fixed
float
for
foreach
goto
if
implicit
in
int
interface
internal
is
lock
long
namespace
new
null
object
operator
out
override
params
private
protected
public
readonly
ref
return
sbyte
sealed
short
sizeof
stackalloc
static
string
struct
switch
this
throw
try
typeof
uint
ulong
unchecked
unsafe
ushort
using
virtual
void
volatile
while

================================================
FILE: scripts/keywords/java.txt
================================================
abstract
assert
boolean
break
byte
case
catch
char
class
continue
default
do
double
else
enum
extends
final
finally
float
for
if
implements
import
instanceof
int
interface
long
native
new
package
private
protected
public
return
short
static
strictfp
super
switch
synchronized
this
throw
throws
transient
try
void
volatile
while
var
const
goto


================================================
FILE: scripts/keywords/javascript.txt
================================================
break
case
catch
class
const
continue
debugger
default
delete
do
else
export
extends
finally
for
function
if
import
in
instanceof
new
return
super
switch
this
throw
try
typeof
var
void
while
with
yield
enum
implements
interface
let
package
private
protected
public
static

================================================
FILE: scripts/keywords/keywordlist.py
================================================
# Original Copyright 2021 Microsoft under MIT License.
# From https://github.com/microsoft/dpu-utils/blob/master/python/dpu_utils/codeutils/keywords/keywordlist.py

import os
import keyword
from functools import lru_cache
from typing import FrozenSet

__all__ = ['get_language_keywords']

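# Note: only csharp.txt, java.txt, javascript.txt and typescript.txt ship with this
# repository; the other entries are inherited from dpu-utils and their files are missing here.
_LANGUAGE_TO_FILENAME = {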
    'c': 'c.txt',
    'cpp': 'cpp.txt',
    'c++': 'cpp.txt',
    'csharp': 'csharp.txt',
    'c_sharp': 'csharp.txt',
    'c#': 'csharp.txt',
    'go': 'go.txt',
    'java': 'java.txt',
    'javascript': 'javascript.txt',
    'js': 'javascript.txt',
    'php': 'php.txt',
    'ruby': 'ruby.txt',
    'typescript': 'typescript.txt',
    'ts': 'typescript.txt',
}


@lru_cache()
def get_language_keywords(language: str) -> FrozenSet[str]:
    """
    Returns the keywords of a programming language.

    There are some inconsistencies across languages with respect to
    what is considered a keyword. For example, the true/false
    literals are considered keywords in many languages. However,
    we exclude them here for consistency. We also exclude special
    function-like keywords, such as `die()` in PHP.
    """
    language = language.lower()
    if language == 'python':
        return frozenset(k for k in keyword.kwlist if k != 'True' and k != 'False')
    elif language in _LANGUAGE_TO_FILENAME:
        name = _LANGUAGE_TO_FILENAME[language]
        with open(os.path.join(os.path.dirname(__file__), name)) as f:
            return frozenset(l.strip() for l in f if len(l.strip()) > 0)
    else:
        raise Exception('Language keywords `%s` not supported yet. Consider contributing it to dpu-utils.' % language)


================================================
FILE: scripts/keywords/typescript.txt
================================================
break
case
catch
class
const
continue
debugger
default
delete
do
else
export
extends
finally
for
function
if
import
in
instanceof
new
return
super
switch
this
throw
try
typeof
var
void
while
with
yield
enum
implements
interface
let
package
private
protected
public
static

================================================
FILE: scripts/openai_inference.py
================================================
"""
Script to query an OpenAI API to generate code.
Set the environment variable OPENAI_API_KEY to your API key
before running this script.
"""

import argparse
import json
import os
import time
from typing import Dict, List, Tuple

import numpy as np
import openai
import tiktoken
from openai import OpenAI
from openai.types.chat import ChatCompletion
from tqdm import tqdm

SLEEP_SECOND = 2.8  # base sleep time: waited after each request and used as the initial backoff on API errors
MAX_SLEEP_SECOND = 120  # maximum sleep time for the exponential backoff
BUFFER = 100  # estimated tokens used by OpenAI + some more buffer
SYS_PROMPT = 'You are Codex, a code completion language model. Continue the code presented to you.'

openai_api_key = os.environ.get("OPENAI_API_KEY")
assert openai_api_key is not None, "Please set the OPENAI_API_KEY environment variable to your API key"
client = OpenAI()


def query(
        args,
        prompt: str,
) -> ChatCompletion:
    """
    This function queries an OpenAI API to generate code based on the given prompt.

    Args:
    args: argparse.Namespace, the run config; this function uses
        model: str, the OpenAI model name
        temperature: float, the value used to modulate the next token probabilities
        generation_max_tokens: int, the maximum number of tokens to generate
        top_p: float, the cumulative probability for top-p filtering
    prompt: str, the prompt to generate code from

    Returns:
    ChatCompletion, the response from the OpenAI chat completions API
    """
    return client.chat.completions.create(model=args.model,
                                          messages=[
                                              {"role": "system", "content": SYS_PROMPT},
                                              {"role": "user", "content": prompt}
                                          ],
                                          temperature=args.temperature,
                                          max_tokens=args.generation_max_tokens,
                                          top_p=args.top_p,
                                          )


def query_with_retry(
        args,
        prompt: str,
) -> ChatCompletion:
    """
    This function queries an OpenAI API to generate code based on the given prompt,
    retrying on API errors.

    Args:
    args: argparse.Namespace, the run config forwarded to query()
        (model, temperature, generation_max_tokens, top_p)
    prompt: str, the prompt to generate code from

    On openai.RateLimitError or any other openai.OpenAIError, the request is retried
    after error_sleep_second seconds; the wait starts at SLEEP_SECOND and is increased
    by _upd_error_sleep_time after every failure. After each successful request the
    function sleeps SLEEP_SECOND plus a random fraction of a second to throttle requests.

    Returns:
    ChatCompletion, the response from the OpenAI API. Errors are retried indefinitely,
    so the function does not return None in practice.

    Reference:
    https://github.com/Leolty/repobench/blob/c24b7a80465957e75107eafd23c66d369fa9e755/model/codex.py
    """

    error_sleep_second = SLEEP_SECOND

    def _upd_error_sleep_time(error_sleep_second):
        # double the sleep time, but never wait longer than MAX_SLEEP_SECOND seconds
        return min(error_sleep_second * 2, MAX_SLEEP_SECOND)

    while True:
        try:
            response = query(args, prompt)
            time.sleep(SLEEP_SECOND + np.random.rand())
            return response
        except openai.RateLimitError as e:
            print(f'RateLimitError: {e}')
            print(f'Retrying after {error_sleep_second} seconds')
            time.sleep(error_sleep_second)
            error_sleep_second = _upd_error_sleep_time(error_sleep_second)
        except openai.OpenAIError as e:
            print(f'OpenAIError: {e}')
            print(f'Retrying after {error_sleep_second} seconds')
            time.sleep(error_sleep_second)
            error_sleep_second = _upd_error_sleep_time(error_sleep_second)
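

# Illustrative sketch (hypothetical helper): the sleep times produced by a
# capped exponential backoff of the kind query_with_retry uses, where each
# failure doubles the wait but the wait never exceeds a ceiling. It only
# computes the schedule; it does not sleep or call the API.
def sketch_backoff_schedule(initial, cap, num_failures):
    delays = []
    delay = initial
    for _ in range(num_failures):
        delays.append(delay)
        delay = min(delay * 2, cap)  # double, then clamp to the cap
    return delays

# Example: sketch_backoff_schedule(1, 10, 6) returns [1, 2, 4, 8, 10, 10]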


def truncate(prompt: str, max_num_tokens: int, tokenizer, side: str) -> str:
    """Truncate prompt from side given the token budget"""

    # use the tiktoken tokenizer to count the number of tokens
    tokens = tokenizer.encode(prompt, disallowed_special=())
    num_tokens = len(tokens)

    if num_tokens > max_num_tokens:
        if side == 'left':
            prompt_tokens = tokens[num_tokens - max_num_tokens:]
        elif side == 'right':
            prompt_tokens = tokens[:max_num_tokens]
        else:
            raise ValueError(f'Invalid side: {side}')
        # decode and encode again as a sanity check
        prompt = tokenizer.decode(prompt_tokens)
        new_len = len(tokenizer.encode(prompt, disallowed_special=()))
        assert new_len <= max_num_tokens
    return prompt
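

# Illustrative sketch (hypothetical helper): the two truncation modes of
# truncate(), shown on a plain list of tokens instead of tiktoken ids.
# "left" keeps the last max_num_tokens tokens (the code right before the
# completion point); "right" keeps the first ones (the beginning of the
# cross-file context).
def sketch_truncate_tokens(tokens, max_num_tokens, side):
    if len(tokens) <= max_num_tokens:
        return list(tokens)
    if side == 'left':
        return tokens[len(tokens) - max_num_tokens:]
    if side == 'right':
        return tokens[:max_num_tokens]
    raise ValueError(f'Invalid side: {side}')

# Example: sketch_truncate_tokens(list("abcdef"), 3, "left") returns ['d', 'e', 'f']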


def prepare_prompt(
        prompt: str,
        cross_file_context: str,
        cross_file_budget: int,
        prompt_budget: int,
        tokenizer
) -> str:
    """Create an augmented prompt according to budget specs"""

    # left truncate original prompt
    prompt = truncate(prompt, prompt_budget, tokenizer, 'left')

    if cross_file_context is not None:
        # right truncate cross file context string
        cross_file_context = truncate(cross_file_context, cross_file_budget, tokenizer, 'right')
    else:
        cross_file_context = ''

    # return <CFC>\n<PROMPT>
    return cross_file_context + '\n' + prompt


def get_openai_response(
        sample: Dict,
        tokenizer,
        args
) -> Tuple[str, Dict]:
    """Get OpenAI response for a single sample. Returns the prompt used to
    infer and the response of the API."""
    if args.use_crossfile_context:
        prompt = prepare_prompt(
            sample['prompt'], sample['crossfile_context']['text'],
            args.crossfile_max_tokens,
            args.model_max_tokens - args.generation_max_tokens - args.crossfile_max_tokens - BUFFER,
            tokenizer
        )
    else:
        prompt = prepare_prompt(
            sample['prompt'], None,
            0,
            args.model_max_tokens - args.generation_max_tokens - BUFFER,
            tokenizer
        )

    response = query_with_retry(args, prompt)
    return prompt, response
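

# Illustrative sketch (hypothetical helper): the token-budget arithmetic used
# by get_openai_response (and by cceval_generate in vllm_inference.py). The
# in-file prompt gets what is left of the model window after reserving room
# for generation, the BUFFER and, when enabled, the cross-file context.
def sketch_prompt_budget(model_max_tokens, generation_max_tokens,
                         crossfile_max_tokens, buffer, use_crossfile_context):
    budget = model_max_tokens - generation_max_tokens - buffer
    if use_crossfile_context:
        budget -= crossfile_max_tokens
    return budget

# With the script defaults (16384, 50, 12800) and BUFFER = 100:
# sketch_prompt_budget(16384, 50, 12800, 100, True) returns 3434
# sketch_prompt_budget(16384, 50, 12800, 100, False) returns 16234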


def get_openai_responses(
        args, data, out_path
) -> List[str]:
    """Get OpenAI responses to all samples in data, store in out_path,
    and return list of task ids that were skipped due to some errors"""
    tokenizer = tiktoken.encoding_for_model(args.model)
    skipped = []
    with open(out_path, 'w') as f:
        for d in tqdm(data):
            try:
                prompt, response = get_openai_response(
                    d, tokenizer, args
                )
            except Exception as e:
                print('Unknown error', e)
                raise

            if response is not None:
                d['pred_raw'] = response.choices[0].message.content  # key compatible with eval script
                # newer ChatGPT models may output a ```[lang_tag] fence at the beginning; drop it
                if d['pred_raw'].startswith('```'):
                    d['pred'] = '\n'.join(d['pred_raw'].split('\n')[1:]).strip('`')
                else:
                    d['pred'] = d['pred_raw']
                # d['api_response'] = str(response)
                d['prompt_used'] = prompt  # records the augmented prompt
                d['task_id'] = d['metadata']['task_id']  # adding for compatibility with eval script
                print(json.dumps(d), file=f, flush=True)
            else:
                skipped.append(d['metadata']['task_id'])
                print(f'Skipped {d["metadata"]["task_id"]}')

    return skipped
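

# Illustrative sketch (hypothetical helper): the fence-stripping rule that
# get_openai_responses applies to chat replies. If a reply starts with a
# Markdown fence such as ```python, the first line is dropped and leading or
# trailing backticks are removed; otherwise the reply is kept unchanged.
def sketch_strip_code_fence(reply):
    if reply.startswith('```'):
        return '\n'.join(reply.split('\n')[1:]).strip('`')
    return reply

# Example: sketch_strip_code_fence("```python\nx = 1\n```") returns "x = 1\n"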


def main():
    # get config for current run
    parser = argparse.ArgumentParser()
    parser.add_argument('--temperature', type=float, default=0.2)
    parser.add_argument('--top_p', type=float, default=0.95)
    parser.add_argument(
        '--task', type=str, required=True,
    )
    parser.add_argument(
        '--language', type=str, required=True,
        choices=['csharp', 'python', 'java', 'typescript']
    )
    parser.add_argument(
        '--data_root_dir', type=str, default='data/',
        help='path to directory where data is organized in lang/task.jsonl format'
    )
    parser.add_argument(
        '--output_dir', type=str, required=True,
        help='path to directory where to store outputs'
    )
    parser.add_argument(
        '--model', type=str, required=True,
        help='openAI-supported model'
    )
    parser.add_argument(
        '--model_max_tokens', type=int, default=16384,
        help='maximum number of tokens of the model'
    )
    parser.add_argument(
        '--crossfile_max_tokens', type=int, default=12800,
        help='maximum number of tokens for cross file context'
    )
    parser.add_argument(
        '--use_crossfile_context', action='store_true',
        help='whether to use cross-file context'
    )
    parser.add_argument(
        '--generation_max_tokens', type=int, default=50,
        help='maximum number of tokens to generate'
    )
    args = parser.parse_args()
    print(json.dumps(vars(args), indent=4))

    # setup paths
    if not os.path.isdir(args.output_dir):
        print(f'==== Output dir does not exist. Creating: {args.output_dir} ====')
        os.makedirs(args.output_dir)
    data_path = os.path.join(args.data_root_dir, args.language, args.task + '.jsonl')
    with open(data_path, 'r') as f:
        data = [json.loads(l) for l in f]

    out_path = os.path.join(args.output_dir, 'prediction.jsonl')
    # start OpenAI inference
    skipped_tasks = get_openai_responses(
        args, data, out_path
    )

    # save list of skipped tasks
    with open(out_path.replace('.jsonl', '_skipped_tasks.json'), 'w') as f:
        f.write(json.dumps(skipped_tasks))


if __name__ == '__main__':
    main()


================================================
FILE: scripts/vllm_inference.py
================================================
"""
Script to run vllm-based inference. See README for an example.
"""

import argparse
import json
import os
from typing import List

from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.utils import logging
from vllm import LLM, SamplingParams

logging.set_verbosity_info()
logger = logging.get_logger(__name__)
# add a small buffer to take care of non-lossless tokenizers
BUFFER = 100


def truncate(prompt: str, max_num_tokens: int, side: str, tokenizer) -> str:
    """Truncate prompt from side given the token budget"""

    tokens = tokenizer.tokenize(prompt)
    num_tokens = len(tokens)

    if num_tokens > max_num_tokens:
        if side == 'left':
            prompt_tokens = tokens[num_tokens - max_num_tokens:]
        elif side == 'right':
            prompt_tokens = tokens[:max_num_tokens]
        else:
            raise ValueError(f'Invalid side: {side}')
        prompt = tokenizer.convert_tokens_to_string(prompt_tokens)
        new_len = len(tokenizer.tokenize(prompt))
        if new_len > max_num_tokens:
            logger.warning(
                f'Number of tokens after truncation is greater than max tokens allowed: {new_len=} {num_tokens=}')
    return prompt


def prepare_prompt(
        prompt: str,
        cross_file_context: str,
        cross_file_budget: int,
        prompt_budget: int,
        tokenizer
) -> str:
    """Create an augmented prompt according to budget specs"""

    # print(f'{cross_file_budget=} {prompt_budget=}')
    # left truncate original prompt
    prompt = truncate(prompt, prompt_budget, 'left', tokenizer)

    if cross_file_context is not None:
        # right truncate cross file context string
        cross_file_context = truncate(cross_file_context, cross_file_budget, 'right', tokenizer)
    else:
        cross_file_context = ''

    return cross_file_context + '\n' + prompt


def cceval_generate(
        args,
        data,
        tokenizer,
        sampling_params,
        llm
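) -> None:
    """Build the (optionally cross-file augmented) prompt for every sample, generate
    completions with a single llm.generate call, and write them to
    <output_dir>/prediction.jsonl."""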
    prompts = []
    for d in data:
        if args.use_crossfile_context:
            prompt = prepare_prompt(
                d['prompt'], d['crossfile_context']['text'],
                args.crossfile_max_tokens,
                args.model_max_tokens - args.generation_max_tokens - args.crossfile_max_tokens - BUFFER,
                tokenizer
            )
        else:
            prompt = prepare_prompt(
                d['prompt'], None,
                0,
                args.model_max_tokens - args.generation_max_tokens - BUFFER,
                tokenizer
            )
        prompts.append(prompt)

    outputs = llm.generate(prompts, sampling_params)

    out_path = os.path.join(args.output_dir, 'prediction.jsonl')
    with open(out_path, 'w') as f:
        for d, response in tqdm(zip(data, outputs)):
            d['pred'] = response.outputs[0].text
            d['task_id'] = d['metadata']['task_id']
            print(json.dumps(d), file=f, flush=True)

    return


def main():
    # set the OpenAI key
    # openai.api_key = os.environ.get('OPENAI_KEY', None)
    # if openai.api_key is None:
    #    raise ValueError('OPENAI_KEY environment variable not set')

    # get config for current run
    parser = argparse.ArgumentParser()
    parser.add_argument('--temperature', type=float, default=0.2)

    parser.add_argument('--top_p', type=float, default=0.95)
    parser.add_argument(
        '--task', type=str, required=True,
    )
    parser.add_argument(
        '--language', type=str, required=True,
        choices=['csharp', 'python', 'java', 'typescript']
    )
    parser.add_argument(
        '--data_root_dir', type=str, default='data/',
        help='path to directory where data is organized in lang/task.jsonl format'
    )
    parser.add_argument(
        '--output_dir', type=str, required=True,
        help='path to directory where to store outputs'
    )
    parser.add_argument(
        '--model', type=str, required=True,
        help='vLLM-supported model'
    )
    parser.add_argument(
        '--tp_size', type=int, default=1,
        help='tensor parallel size'
    )
    parser.add_argument(
        '--model_max_tokens', type=int, default=16384,
        help='maximum number of tokens of the model'
    )
    parser.add_argument(
        '--crossfile_max_tokens', type=int, default=12800,
        help='maximum number of tokens for cross file context'
    )
    parser.add_argument(
        '--use_crossfile_context', action='store_true',
        help='whether to use cross-file context'
    )
    parser.add_argument(
        '--generation_max_tokens', type=int, default=50,
        help='maximum number of tokens to generate'
    )

    args = parser.parse_args()
    print(json.dumps(vars(args), indent=4))

    # load model
    llm = LLM(model=args.model, tensor_parallel_size=args.tp_size, max_model_len=args.model_max_tokens)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    sampling_params = SamplingParams(temperature=args.temperature, top_p=args.top_p, max_tokens=args.generation_max_tokens)

    # setup paths
    if not os.path.isdir(args.output_dir):
        print(f'==== Output dir does not exist. Creating: {args.output_dir} ====')
        os.makedirs(args.output_dir)
    data_path = os.path.join(args.data_root_dir, args.language, args.task + '.jsonl')
    with open(data_path, 'r') as f:
        data = [json.loads(l) for l in f]

    # generation
    cceval_generate(args, data, tokenizer, sampling_params, llm)


if __name__ == '__main__':
    main()