If you prefer the non-vLLM script:
First, configure `accelerate` via `accelerate config` if you haven't already. A reference configuration is available at `cceval_config.yaml`.
The following command demonstrates how to run greedy evaluation with codegen-350M on Python with cross-file context.
```bash
export model_type=codelm_cfc # or codelm for no cross-file context eval
export model_name=Salesforce/codegen-350M-mono
export language=python
export ts_lib=./build/${language}-lang-parser.so
export dtype=bf16 # or fp16
export prompt_file=./data/crosscodeeval_data/${language}/line_completion_rg1_unixcoder_cosine_sim.jsonl # or other options in the dir, which corresponds to different retrieval methods and/or retrieval settings
export max_seq_length=2048
export cfc_seq_length=512
export batch_size=16 # reduce for larger models
export output_dir=./tmp/crosscodeeval_testrun/
accelerate launch eval.py \
--model_type $model_type \
--model_name_or_path $model_name \
--cfc_seq_length $cfc_seq_length \
--prompt_file $prompt_file \
--gen_length 50 \
--max_seq_length $max_seq_length \
--batch_size $batch_size \
--output_dir $output_dir \
--dtype $dtype \
--num_return_sequences 1 \
--overwrite_cache True \
--ts_lib $ts_lib \
--language $language
```
You may run sampling via the following (additional) args:
```bash
--do_sample \
--top_p 0.95 \
--temperature 0.2 \
--num_return_sequences 5 \
```
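To make the two sampling knobs above concrete, here is a minimal sketch of temperature scaling followed by nucleus (top-p) filtering over a single logit vector. It is an illustration only: the helper names `top_p_filter` and `sample_token` are hypothetical, and this is not the code path used by `eval.py` or by `transformers`.

```python
import math
import random


def top_p_filter(logits, temperature=0.2, top_p=0.95):
    """Return {token_index: probability} after temperature scaling and nucleus filtering."""
    # Lower temperatures sharpen the distribution before the softmax
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the smallest set of most likely tokens whose cumulative mass reaches top_p
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalize over the kept tokens
    norm = sum(probs[i] for i in kept)
    return {i: probs[i] / norm for i in kept}


def sample_token(logits, temperature=0.2, top_p=0.95, rng=random):
    """Draw one token index from the filtered distribution."""
    dist = top_p_filter(logits, temperature, top_p)
    tokens = list(dist)
    weights = [dist[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]
```

With `temperature=0.2`, a small gap between logits is enough for one token to carry almost all of the probability mass, so the samples stay close to the greedy output while still allowing some diversity across the 5 returned sequences.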
#### OpenAI models
OpenAI models are accessible through an API. You may use the following script:
```bash
export model=gpt-3.5-turbo-0125
export language=python
export task=line_completion_rg1_unixcoder_cosine_sim
export output_dir=./tmp/crosscodeeval_openai_testrun/
python scripts/openai_inference.py \
--task $task \
--language $language \
--model $model \
--output_dir $output_dir \
--use_crossfile_context
```
### Metrics Calculation
After obtaining the generations, we can calculate the final metrics:
```bash
export language=python
export ts_lib=./build/${language}-lang-parser.so;
export task=line_completion_oracle_unixcoder_cosine_sim
export prompt_file=./data/${language}/${task}.jsonl
export output_dir=./tmp/crosscodeeval_testrun/;
python scripts/eval.py \
--prompt_file $prompt_file \
--output_dir $output_dir \
--ts_lib $ts_lib \
--language $language \
--only_compute_metric
```
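For intuition about what such a metrics pass involves, the sketch below scores predictions against references with exact match and a string-similarity ratio. It is a simplified stand-in and not the repository's scoring code: it uses the standard-library `difflib` ratio rather than `fuzzywuzzy`, skips any post-processing of the generations, and the function names are hypothetical.

```python
from difflib import SequenceMatcher


def exact_match(prediction: str, reference: str) -> bool:
    """True when the prediction equals the reference after trimming whitespace."""
    return prediction.strip() == reference.strip()


def edit_similarity(prediction: str, reference: str) -> float:
    """Similarity in [0, 1]; difflib's ratio is used here as a stand-in for an edit-distance score."""
    return SequenceMatcher(None, prediction.strip(), reference.strip()).ratio()


def score(pairs):
    """Average both metrics over a list of (prediction, reference) pairs."""
    if not pairs:
        return {"em": 0.0, "es": 0.0}
    n = len(pairs)
    em = sum(exact_match(p, r) for p, r in pairs) / n
    es = sum(edit_similarity(p, r) for p, r in pairs) / n
    return {"em": em, "es": es}


if __name__ == "__main__":
    demo = [("x = foo(a)", "x = foo(a)"), ("x = bar(a)", "x = foo(a)")]
    print(score(demo))
```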
## Citation
```
@inproceedings{ding2023crosscodeeval,
title={CrossCodeEval: A Diverse and Multilingual Benchmark for Cross-File Code Completion},
author={Yangruibo Ding and Zijian Wang and Wasi Uddin Ahmad and Hantian Ding and Ming Tan and Nihal Jain and Murali Krishna Ramanathan and Ramesh Nallapati and Parminder Bhatia and Dan Roth and Bing Xiang},
year={2023},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
url={https://arxiv.org/pdf/2310.11248.pdf}
}
```
## Questions
Please feel free to email us. You may also submit an issue in this repo.
## Security
See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
## License
This project is licensed under the Apache-2.0 License.
================================================
FILE: THIRD_PARTY_LICENSES
================================================
The CrossCodeEval repository includes the following third-party software/licensing:
The keywordlist.py was from https://github.com/microsoft/dpu-utils/blob/master/python/dpu_utils/codeutils/keywords/keywordlist.py with license
MIT License
Copyright (c) Microsoft Corporation. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: cceval_config.yaml
================================================
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero3_save_16bit_model: false
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
================================================
FILE: data/crosscodeeval_data.tar.xz
================================================
[File too large to display: 40.1 MB]
================================================
FILE: prompt_builder/README.md
================================================
## Retrieval Augmented Prompting
We can generate the retrieval-augmented prompts in the following three steps.
1. Please email us to get the raw software repositories.
2. Set the `repository_root` variable with the root directory which contains the raw software repositories.
3. Run the `run.sh` bash script.
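
The core idea behind these steps can be sketched in a few lines: split the other files of a repository into fixed-size chunks, use the last lines of the prompt as a query, and keep the best-scoring chunks. The sketch below is a simplified illustration under stated assumptions, not the actual pipeline: it uses a plain regex tokenizer instead of NLTK, only Jaccard similarity, and the hypothetical helper names `chunk_file` and `retrieve`. The constants mirror `CHUNK_SIZE` and `QUERY_LENGTH` in `augment_with_cfc.py`.

```python
import re

CHUNK_SIZE = 10    # lines per chunk, as in augment_with_cfc.py
QUERY_LENGTH = 10  # the last N non-empty prompt lines form the query


def tokenize(text):
    """Split text into word-like tokens (a stand-in for the NLTK-based tokenizer)."""
    return re.findall(r"\w+", text)


def chunk_file(content, chunk_size=CHUNK_SIZE):
    """Drop empty lines and cut the file into non-overlapping chunks."""
    lines = [line for line in content.split("\n") if line.strip()]
    return ["\n".join(lines[i:i + chunk_size]) for i in range(0, len(lines), chunk_size)]


def jaccard(tokens_a, tokens_b):
    """Jaccard similarity between two token lists."""
    set_a, set_b = set(tokens_a), set(tokens_b)
    union = set_a | set_b
    return len(set_a & set_b) / len(union) if union else 0.0


def retrieve(prompt, files, top_k=5):
    """Return the top_k (score, filename, chunk) triples for the given prompt."""
    query_lines = [line for line in prompt.split("\n") if line.strip()][-QUERY_LENGTH:]
    query = tokenize("\n".join(query_lines))
    scored = []
    for name, content in files.items():
        for chunk in chunk_file(content):
            scored.append((jaccard(query, tokenize(chunk)), name, chunk))
    # Stable sort by score only, so ties keep their original file order
    scored.sort(key=lambda item: item[0], reverse=True)
    return scored[:top_k]
```

In the real script the retrieved chunks are then commented out and prepended to the prompt as cross-file context.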
================================================
FILE: prompt_builder/augment_with_cfc.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import time
import glob
import argparse
import multiprocessing as mp
from tqdm import tqdm
from functools import partial
from rerank_utils import lexical_ranking, SemanticReranking
from utils import str2bool, file_distance, tokenize_nltk
CHUNK_SIZE = 10
SLIDING_WINDOW_SIZE = 10 # non-overlapping chunks if SLIDING_WINDOW_SIZE=CHUNK_SIZE
QUERY_LENGTH = 10 # last N lines from prompt will be query
repository_root = "/PATH/TO/REPOS" # get the data from authors
input_files = {
"python": "../data/crosscodeeval_data/python/line_completion.jsonl",
"java": "../data/crosscodeeval_data/java/line_completion.jsonl",
"typescript": "../data/crosscodeeval_data/typescript/line_completion.jsonl",
"csharp": "../data/crosscodeeval_data/csharp/line_completion.jsonl"
}
file_ext = {"python": "py", "java": "java", "typescript": "ts", "csharp": "cs"}
def get_crossfile_context_from_chunks(
args,
prompt,
code_chunks,
code_chunk_ids,
groundtruth,
semantic_ranker
):
assert len(code_chunks) != 0
candidate_code_chunks = code_chunks[:args.maximum_chunk_to_rerank]
candidate_code_chunk_ids = code_chunk_ids[:args.maximum_chunk_to_rerank]
ranking_scores = None
meta_data = {}
if args.rerank:
if args.query_type == "groundtruth":
# oracle experiment
prompt_lines = [pl for pl in prompt.split("\n") if pl.strip()]
groundtruth_lines = [gt for gt in groundtruth.split("\n") if gt.strip()]
code_lines = prompt_lines + groundtruth_lines
query = "\n".join(code_lines[-QUERY_LENGTH:])
elif args.query_type == "last_n_lines":
prompt_lines = [pl for pl in prompt.split("\n") if pl.strip()]
query = "\n".join(prompt_lines[-QUERY_LENGTH:])
else:
raise NotImplementedError
meta_data["query"] = query
start = time.time()
if args.ranking_fn == "cosine_sim":
gpu_id = int(mp.current_process().name.split('-')[-1]) - 1
candidate_code_chunks, candidate_code_chunk_ids, ranking_scores = semantic_ranker.rerank(
query,
candidate_code_chunks,
candidate_code_chunk_ids,
gpu_id,
score_threshold=None
)
else:
candidate_code_chunks, candidate_code_chunk_ids, ranking_scores = lexical_ranking(
query,
candidate_code_chunks,
args.ranking_fn,
candidate_code_chunk_ids,
score_threshold=None
)
meta_data["latency"] = time.time() - start
meta_data["num_candidates"] = len(candidate_code_chunks)
top_k = min(args.maximum_cross_file_chunk, len(candidate_code_chunk_ids))
if top_k == 0:
return [], "", meta_data  # the caller unpacks (cross_file_context, cfc_text, meta_data)
selected_chunks = []
selected_chunks_filename = []
selected_chunks_scores = []
if args.use_next_chunk_as_cfc:
# prepare an id2idx map
assert len(candidate_code_chunks) == len(candidate_code_chunk_ids)
id2idx = dict()
for j, cci in enumerate(code_chunk_ids):
id2idx[cci] = j
total_added = 0
for cidx, _id in enumerate(candidate_code_chunk_ids):
fname, c_id = _id.rsplit("|", 1)
next_id = f"{fname}|{int(c_id) + 1}"
if next_id not in id2idx:
to_add = code_chunks[id2idx[_id]]
else:
to_add = code_chunks[id2idx[next_id]]
if to_add not in selected_chunks:
selected_chunks.append(to_add)
selected_chunks_filename.append(fname)
if args.rerank:
selected_chunks_scores.append(ranking_scores[cidx])
total_added += 1
if total_added == top_k:
break
else:
selected_chunks = candidate_code_chunks[:top_k]
selected_chunks_filename = [_id.rsplit("|", 1)[0] for _id in candidate_code_chunk_ids[:top_k]]
if args.rerank:
selected_chunks_scores = ranking_scores[:top_k]
cross_file_context = []
for idx in range(len(selected_chunks)):
cross_file_context.append({
"retrieved_chunk": selected_chunks[idx],
"filename": selected_chunks_filename[idx],
"score": selected_chunks_scores[idx] if args.rerank else None
})
line_start_sym = "#" if args.language == "python" else "//"
cfc_text = f"{line_start_sym} Here are some relevant code fragments from other files of the repo:\n\n"
for sc, scf in zip(selected_chunks, selected_chunks_filename):
cfc_text += f"{line_start_sym} the below code fragment can be found in:\n{line_start_sym} {scf}" + "\n"
cfc_text += "\n".join([f"{line_start_sym} {cl}" for cl in sc.strip('\n').splitlines()]) + "\n\n"
return cross_file_context, cfc_text, meta_data
def read_project_files(repo_name, lang):
    """Read every source file of a repository into a {relative_path: content} dict."""
    project_context = {}
    root_dir = os.path.join(repository_root, lang, repo_name)
    if not os.path.isdir(root_dir):
        print(f"Repository not found: {root_dir}")
        return project_context

    if lang == "typescript":
        # TypeScript sources are only collected from the src/ directory
        src_files = []
        src_files += glob.glob(os.path.join(root_dir, 'src/**/*.ts'), recursive=True)
        src_files += glob.glob(os.path.join(root_dir, 'src/**/*.tsx'), recursive=True)
    else:
        src_files = glob.glob(os.path.join(root_dir, f'**/*.{file_ext[lang]}'), recursive=True)

    for filename in src_files:
        # glob can return directories or broken symlinks; keep regular files only
        if not os.path.isfile(filename):
            continue
        try:
            with open(filename, "r") as file:
                file_content = file.read()
        except UnicodeDecodeError:
            # fall back to a lossy decode for files that are not valid text
            with open(filename, "rb") as file:
                file_content = file.read().decode(errors='replace')
        fileid = os.path.relpath(filename, root_dir)
        project_context[fileid] = file_content

    return project_context
def find_files_within_distance_k(current_file_path, filelist, k):
list_of_modules = []
module_weight = []
for filepath in filelist:
if filepath != current_file_path:
dist = file_distance(filepath, current_file_path)
if dist == -1:
continue
elif dist <= k:
list_of_modules.append(filepath)
module_weight.append(dist)
# sorting in ascending order
list_of_modules = [x for _, x in sorted(zip(module_weight, list_of_modules))]
return list_of_modules
def get_cfc(example, args, semantic_ranker, repositories):
project_context = repositories[example["metadata"]["repository"]]
status = None
current_filepath = example["metadata"]["file"]
if len(project_context) == 0:
example["crossfile_context"] = ""
status = "project_not_found"
else:
current_filecontent = None
for filepath, filecontent in project_context.items():
if filepath == current_filepath:
current_filecontent = filecontent
break
if current_filecontent is None:
example["crossfile_context"] = {}
print(current_filepath)
status = "file_not_found_in_project"
else:
pyfiles = find_files_within_distance_k(
example["metadata"]["file"],
list(project_context.keys()),
k=args.crossfile_distance
)
pyfiles = pyfiles[:args.maximum_cross_files]
code_chunks = []
code_chunk_ids = []
for pyfile in pyfiles:
lines = project_context[pyfile].split("\n")
lines = [l for l in lines if l.strip()] # removing empty lines
c_id = 0
for i in range(0, len(lines), SLIDING_WINDOW_SIZE):
c = "\n".join(lines[i:i + CHUNK_SIZE])
tokenized_c = tokenize_nltk(c)
if len(tokenized_c) > 0:
code_chunks.append(c)
code_chunk_ids.append(f"{pyfile}|{c_id}")
c_id += 1
if len(code_chunks) == 0:
example["crossfile_context"] = {}
status = "no_crossfile_context"
else:
cfc, cfc_text, meta_data = get_crossfile_context_from_chunks(
args=args,
prompt=example["prompt"],
code_chunks=code_chunks,
code_chunk_ids=code_chunk_ids,
groundtruth=example["groundtruth"],
semantic_ranker=semantic_ranker
)
example["crossfile_context"] = {}
example["crossfile_context"]["text"] = cfc_text
example["crossfile_context"]["list"] = cfc
return example, status
def attach_data(args, srcfile):
empty_cfc = 0
error_freq = {
"project_not_found": 0,
"file_not_found_in_project": 0,
"no_crossfile_context": 0
}
output_examples = []
examples = []
repositories = dict()
with open(srcfile) as f:
for line in f:
ex = json.loads(line)
repo_name = ex["metadata"]["repository"]
if repo_name not in repositories:
repositories[repo_name] = read_project_files(repo_name, args.language)
examples.append(ex)
semantic_ranker = None
if args.ranking_fn == "cosine_sim":
semantic_ranker = SemanticReranking(
args.ranker,
max_sequence_length=256
)
pool = mp.Pool(args.num_processes)
worker = partial(get_cfc, args=args, semantic_ranker=semantic_ranker, repositories=repositories)
with tqdm(total=len(examples)) as pbar:
for (d, stat) in pool.imap_unordered(worker, examples):
if stat in error_freq:
error_freq[stat] += 1
if len(d["crossfile_context"]) == 0:
empty_cfc += 1
if not args.skip_if_no_cfc:
output_examples.append(d)
else:
output_examples.append(d)
pbar.update()
print("Total examples with empty CFC: ", empty_cfc)
print(error_freq)
return output_examples
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--rerank",
type=str2bool,
default=True,
help="rerank the functions"
)
parser.add_argument(
"--ranker",
type=str,
default="sparse",
choices=["sparse", "unixcoder"],
help="retriever type: sparse (lexical) or unixcoder (embedding-based)"
)
parser.add_argument(
"--ranking_fn",
type=str,
default="bm25",
choices=["tfidf", "bm25", "jaccard_sim", "cosine_sim"],
help="ranking function"
)
parser.add_argument(
"--query_type",
type=str,
default="last_n_lines",
choices=["last_n_lines", "groundtruth"],
help="how to form query from prompt"
)
parser.add_argument(
"--crossfile_distance",
type=int,
default=100,
help="max distance to search for crossfile"
)
parser.add_argument(
"--maximum_chunk_to_rerank",
type=int,
default=1000,
help="max chunks to consider to rank via BM25"
)
parser.add_argument(
"--maximum_cross_files",
type=int,
default=1000,
help="max cross files to consider when collecting code chunks"
)
parser.add_argument(
"--maximum_cross_file_chunk",
type=int,
default=50,
help="max chunks to return as cfc"
)
parser.add_argument(
"--use_next_chunk_as_cfc",
type=str2bool,
default=True,
help="use next code chunk as context"
)
parser.add_argument(
"--skip_if_no_cfc",
type=str2bool,
default=True,
help="skip adding examples if there is no crossfile context"
)
parser.add_argument(
"--output_file_suffix",
type=str,
default=None,
help="add a suffix string to the output file"
)
parser.add_argument(
"--language",
type=str,
required=True,
choices=["java", "python", "typescript", "csharp"],
help="language name"
)
args = parser.parse_args()
args.output_file_suffix = "" if args.output_file_suffix is None else f"_{args.output_file_suffix}"
if args.use_next_chunk_as_cfc:
assert args.rerank
assert args.query_type != "groundtruth"
tgtfile_suffix = ""
if args.rerank:
tgtfile_suffix += f"_{args.ranking_fn}"
args.num_processes = 60
if args.ranking_fn == "cosine_sim":
num_gpus = 8
args.num_processes = num_gpus
mp.set_start_method('spawn')
input_file = input_files[args.language]
output_path = os.path.dirname(input_file)
output_filename = os.path.splitext(os.path.basename(input_file))[0]
output_filename = output_filename + args.output_file_suffix + tgtfile_suffix + ".jsonl"
output_file = os.path.join(output_path, output_filename)
output_examples = attach_data(args, input_file)
with open(output_file, "w") as fw:
for ex in output_examples:
fw.write(json.dumps(ex))
fw.write("\n")
================================================
FILE: prompt_builder/rerank_utils.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from rank_bm25 import BM25Okapi
from typing import List
from multiprocessing import Pool, cpu_count
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from utils import tokenize_nltk
from transformers import AutoModel, AutoTokenizer, AutoConfig
def jaccard_similarity(tokenized_query, tokenized_doc, containment=False):
set1 = set(tokenized_query)
set2 = set(tokenized_doc)
intersection = len(set1.intersection(set2))
union = len(set1) if containment else len(set1.union(set2))
return float(intersection) / union
def tokenize_corpus(corpus, tokenizer_fn):
    # use a context manager so that the worker processes are released afterwards
    with Pool(cpu_count()) as pool:
        tokenized_corpus = pool.map(tokenizer_fn, corpus)
    return tokenized_corpus
def tokenize_query_and_docs(query, docs):
tokenized_query = tokenize_nltk(query)
tokenized_docs = [tokenize_nltk(d) for d in docs]
return tokenized_query, tokenized_docs
def lexical_ranking(
        query,
        docs,
        ranking_fn,
        doc_ids=None,
        score_threshold=None,
):
    if ranking_fn == "bm25":
        tokenized_query, tokenized_docs = tokenize_query_and_docs(query, docs)
        bm25 = BM25Okapi(tokenized_docs)
        scores = bm25.get_scores(tokenized_query)
    elif ranking_fn == "tfidf":
        tfidf_vectorizer = TfidfVectorizer(tokenizer=tokenize_nltk)
        X = tfidf_vectorizer.fit_transform(docs).toarray()  # (n_fn, n_features)
        y = tfidf_vectorizer.transform([query]).toarray()  # (1, n_features)
        # flatten (n_fn, 1) to (n_fn,) so that each score is a plain number
        scores = cosine_similarity(X, y).ravel().tolist()
    elif ranking_fn == "jaccard_sim":
        tokenized_query, tokenized_docs = tokenize_query_and_docs(query, docs)
        scores = [jaccard_similarity(tokenized_query, d, containment=False) for d in tokenized_docs]
    else:
        raise NotImplementedError

    if score_threshold:
        # drop every document scoring below the threshold
        keep = [idx for idx, s in enumerate(scores) if s >= score_threshold]
        scores = [scores[idx] for idx in keep]
        docs = [docs[idx] for idx in keep]
        if doc_ids is not None:
            doc_ids = [doc_ids[idx] for idx in keep]

    if len(docs) == 0:
        return docs, doc_ids, scores

    # sort once by score so docs, doc_ids, and scores stay aligned even when scores tie
    order = sorted(range(len(docs)), key=lambda idx: scores[idx], reverse=True)
    if doc_ids is not None:
        doc_ids = [doc_ids[idx] for idx in order]
    docs = [docs[idx] for idx in order]
    scores = [scores[idx] for idx in order]

    return docs, doc_ids, scores
class SemanticReranking:
def __init__(self, model_type="unixcoder", **kwargs):
self.model_type = model_type
if model_type == "unixcoder":
self.tokenizer = AutoTokenizer.from_pretrained('microsoft/unixcoder-base')
self.model = AutoModel.from_pretrained('microsoft/unixcoder-base')
else:
raise NotImplementedError
# maximum sequence length for query and documents
self.max_sequence_length = kwargs.get("max_sequence_length", 256)
    def text_to_tensor(
            self,
            text: str,
            pad_to_max: bool = True,
    ):
        text = text.strip()
        # the tokenizer's automatic padding is disabled because of its inconsistent behavior;
        # sequences are padded manually below instead
        token_ids = self.tokenizer.encode(
            text,
            add_special_tokens=False,
            max_length=self.max_sequence_length,
            padding=False,
            truncation=True
        )
        if pad_to_max and len(token_ids) < self.max_sequence_length:
            token_ids = token_ids + [self.tokenizer.pad_token_id] * (self.max_sequence_length - len(token_ids))
        if len(token_ids) > self.max_sequence_length:
            token_ids = token_ids[0:self.max_sequence_length]
        return torch.tensor(token_ids)
def get_pad_id(self):
return self.tokenizer.pad_token_id
def get_attn_mask(self, tokens_tensor):
return tokens_tensor != self.get_pad_id()
    def get_representations(self, list_input_ids, gpu_id):
        device = torch.device('cuda', gpu_id)
        self.model = self.model.to(device=device, dtype=torch.float16)
        self.model.eval()

        batch_size = 64
        sequence_outputs = []
        pooled_outputs = []

        for idx in range(0, len(list_input_ids), batch_size):
            start, end = idx, min(idx + batch_size, len(list_input_ids))
            input_ids = torch.stack(list_input_ids[start:end], dim=0).to(device=device)
            attention_mask = self.get_attn_mask(input_ids)
            # only UniXcoder is supported (see __init__), so read the encoder's last hidden state
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
            token_embeddings = output.last_hidden_state  # bsz x seq_len x hid_dim
            # mean-pool the token embeddings, ignoring padding positions
            mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * mask_expanded, 1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            sequence_embeddings = sum_embeddings / sum_mask  # bsz x hid_dim
            sequence_outputs.append(token_embeddings)
            pooled_outputs.append(sequence_embeddings)

        sequence_output = torch.cat(sequence_outputs)
        pooled_output = torch.cat(pooled_outputs)
        return sequence_output, pooled_output
    def rerank(self, query: str, docs: List[str], doc_ids: List[str] = None, gpu_id=0, score_threshold=None):
        with torch.no_grad():
            batch_queries = [self.text_to_tensor(query)]
            batch_candidates = [self.text_to_tensor(d) for d in docs]
            _, query_rep = self.get_representations(batch_queries, gpu_id)  # 1 x hidden_size
            _, candi_rep = self.get_representations(batch_candidates, gpu_id)  # num_cand x hidden_size
            scores = torch.nn.functional.cosine_similarity(query_rep, candi_rep).tolist()  # num_cand

        if score_threshold:
            # drop every document scoring below the threshold
            keep = [idx for idx, s in enumerate(scores) if s >= score_threshold]
            scores = [scores[idx] for idx in keep]
            docs = [docs[idx] for idx in keep]
            if doc_ids is not None:
                doc_ids = [doc_ids[idx] for idx in keep]

        if len(docs) == 0:
            return docs, doc_ids, scores

        # sort once by score so docs, doc_ids, and scores stay aligned even when scores tie
        order = sorted(range(len(docs)), key=lambda idx: scores[idx], reverse=True)
        if doc_ids is not None:
            doc_ids = [doc_ids[idx] for idx in order]
        docs = [docs[idx] for idx in order]
        scores = [scores[idx] for idx in order]

        return docs, doc_ids, scores
================================================
FILE: prompt_builder/run.sh
================================================
#!/usr/bin/env bash
export PYTHONIOENCODING=utf-8
function generate_data() {
lang=$1
ranker=$2
ranking_fn=$3
echo "$lang, $ranker, $ranking_fn"
output_file_suffix=""
if [[ $ranker != "sparse" ]]; then
output_file_suffix="_${ranker}"
fi
# for RG-1
python augment_with_cfc.py \
--language $lang \
--rerank True \
--ranker $ranker \
--ranking_fn $ranking_fn \
--query_type last_n_lines \
--crossfile_distance 100 \
--maximum_chunk_to_rerank 1000 \
--maximum_cross_files 1000 \
--maximum_cross_file_chunk 5 \
--use_next_chunk_as_cfc True \
--skip_if_no_cfc False \
--output_file_suffix "rg1${output_file_suffix}"
# for oracle experiment
python augment_with_cfc.py \
--language $lang \
--rerank True \
--ranker $ranker \
--ranking_fn $ranking_fn \
--query_type groundtruth \
--crossfile_distance 100 \
--maximum_chunk_to_rerank 1000 \
--maximum_cross_files 1000 \
--maximum_cross_file_chunk 5 \
--use_next_chunk_as_cfc False \
--skip_if_no_cfc False \
--output_file_suffix "oracle${output_file_suffix}"
}
generate_data python sparse bm25
generate_data java sparse bm25
generate_data typescript sparse bm25
generate_data csharp sparse bm25
generate_data python sparse jaccard_sim
generate_data java sparse jaccard_sim
generate_data typescript sparse jaccard_sim
generate_data csharp sparse jaccard_sim
generate_data python unixcoder cosine_sim
generate_data java unixcoder cosine_sim
generate_data typescript unixcoder cosine_sim
generate_data csharp unixcoder cosine_sim
================================================
FILE: prompt_builder/utils.py
================================================
# Copyright Amazon.com, Inc. or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import os
import argparse
from typing import List
from nltk.tokenize import word_tokenize


def tokenize_nltk(text):
    """Tokenize with NLTK, then keep only the word-like pieces of each token."""
    words = word_tokenize(text)
    output_list = []
    for w in words:
        w_list = re.findall(r'\w+', w)
        output_list.extend(w_list)
    return output_list


def file_distance(src_file, dest_file):
    """Count the directory hops between two paths; return -1 if they share no common path."""
    distance = -1
    try:
        commonpath = os.path.commonpath([src_file, dest_file])
        rel_file1_path = os.path.relpath(src_file, commonpath)
        rel_file2_path = os.path.relpath(dest_file, commonpath)
        distance = rel_file1_path.count(os.sep) + rel_file2_path.count(os.sep)
    except Exception:
        # the paths cannot be compared; keep the -1 sentinel
        pass
    return distance


def str2bool(v):
    """Parse a command-line string such as 'True' or 'no' into a bool for argparse."""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
================================================
FILE: requirements.txt
================================================
torch
transformers
datasets
tree-sitter
timeout-decorator
bitsandbytes
accelerate
scikit-learn
rank-bm25
fuzzywuzzy
nltk
sacrebleu
deepspeed
tiktoken
vllm>=0.3.3
================================================
FILE: scripts/build_treesitter.sh
================================================
mkdir ts_package;
cd ts_package;
# Download the tree-sitter package
git clone https://github.com/tree-sitter/tree-sitter-python.git;
git clone https://github.com/tree-sitter/tree-sitter-java.git;
git clone https://github.com/tree-sitter/tree-sitter-c-sharp.git;
git clone https://github.com/tree-sitter/tree-sitter-typescript.git;
cd ..;
# Build tree-sitter
python scripts/build_ts_lib.py
================================================
FILE: scripts/build_ts_lib.py
================================================
#!/usr/bin/env python
# coding=utf-8
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
from tree_sitter import Language


def build_language_lib():
    for lang in ["java", "python", "typescript", "csharp"]:
        ts_lang = "c-sharp" if lang == "csharp" else lang
        if lang == "typescript":
            # the TypeScript grammar lives in a subdirectory of its repository
            git_dir = f"ts_package/tree-sitter-{ts_lang}/{lang}"
        else:
            git_dir = f"ts_package/tree-sitter-{ts_lang}"
        Language.build_library(f'build/{lang}-lang-parser.so', [git_dir])


if __name__ == "__main__":
    build_language_lib()
================================================
FILE: scripts/custom_generate.py
================================================
# Modifications Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
# Modified from https://github.com/huggingface/transformers/blob/v4.29.2/src/transformers/generation/utils.py
##############################################################################
import copy
import inspect
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch import nn
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from transformers.models.auto import (
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from transformers.utils import ModelOutput, logging
from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper,
EtaLogitsWarper,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
ForceTokensLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
MinNewTokensLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
)
from transformers.generation.stopping_criteria import (
MaxLengthCriteria,
MaxTimeCriteria,
StoppingCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
)
from transformers.generation.utils import GenerateOutput, SampleOutput, SampleDecoderOnlyOutput
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.streamers import BaseStreamer
logger = logging.get_logger(__name__)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head.