\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
" }\n",
"\n",
" ds_config = { \n",
" \"train_micro_batch_size_per_gpu\": 1,\n",
" \"gradient_accumulation_steps\": 8,\n",
" \"optimizer\": {\n",
" \"type\": \"AdamW\",\n",
" \"params\": {\n",
" \"lr\": 5e-5\n",
" }\n",
" },\n",
" \"fp16\": {\n",
" \"enabled\": True\n",
" },\n",
" \"zero_optimization\": {\n",
" \"stage\": 0\n",
" }\n",
" }\n",
"\n",
" training_args_dict = dict(\n",
" per_device_train_batch_size=1, \n",
" gradient_accumulation_steps=8,\n",
" logging_steps=10,\n",
" max_steps=30,\n",
" fp16=True,\n",
" log_level='debug'\n",
" )\n",
"\n",
" mode = 'infer_and_train'\n",
"\n",
" client_conf = dict(\n",
" model_conf=model_conf,\n",
" dataset_conf=dataset_conf,\n",
" training_args_conf=training_args_dict,\n",
" data_collator_conf=data_collator_conf,\n",
" mode=mode,\n",
" infer_inst_init_conf=infer_init_conf_client,\n",
" encode_template=encoder_prompt,\n",
" instruction_template=instruction_prompt,\n",
" decode_template=decoder_prompt,\n",
" remote_inference_kwargs=remote_inference_kwargs,\n",
" local_inference_kwargs=local_inference_kwargs,\n",
" perturb_doc_key='perturbed_doc',\n",
" perturbed_response_key='perturbed_response',\n",
" result_key='infer_result'\n",
" )\n",
"\n",
" server_conf = dict(\n",
" infer_inst_init_conf=infer_init_conf_server,\n",
" mode=mode\n",
" )\n",
"\n",
" homo_nn_0 = HomoNN(\n",
" 'nn_0',\n",
" train_data=reader_0.outputs[\"output_data\"],\n",
" runner_module=\"fedcot_runner\",\n",
" runner_class=\"FedCoTRunner\"\n",
" )\n",
"\n",
" homo_nn_0.guest.task_parameters(runner_conf=client_conf)\n",
" homo_nn_0.arbiter.task_parameters(runner_conf=server_conf)\n",
"\n",
" homo_nn_0.guest.conf.set(\"launcher_name\", \"deepspeed\")\n",
"\n",
" pipeline.add_tasks([reader_0, homo_nn_0])\n",
" pipeline.conf.set(\"task\", dict(engine_run={\"cores\": 4}))\n",
" pipeline.compile()\n",
" pipeline.fit()\n",
"\n",
"if __name__ == \"__main__\":\n",
" parser = argparse.ArgumentParser(\"PIPELINE DEMO\")\n",
" parser.add_argument(\"--config\", type=str, default=\"../config.yaml\",\n",
" help=\"config file\")\n",
" parser.add_argument(\"--namespace\", type=str, default=\"\",\n",
" help=\"namespace for data stored in FATE\")\n",
" args = parser.parse_args()\n",
" main(config=args.config, namespace=args.namespace)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: doc/tutorial/fedkseed/README.md
================================================
## FedKSeed
The algorithm is based on the paper [Federated Full-Parameter Tuning of Billion-Sized Language Models
with Communication Cost under 18 Kilobytes](https://arxiv.org/pdf/2312.06353.pdf), and the code is adapted
from https://github.com/alibaba/FederatedScope/tree/FedKSeed.
We refactored the code to make it more compatible with the transformers/PyTorch framework
and integrated it into the FATE-LLM framework.
The main works include:
1. A KSeedZerothOrderOptimizer class that optimizes the model along directions generated from random seeds.
2. A KSeedZOExtendedTrainer, a subclass of the transformers Trainer, that trains large language models with KSeedZerothOrderOptimizer.
3. Trainers for federated learning with large language models.
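
Below is a minimal, self-contained sketch (conceptual pseudocode, not the FATE-LLM API) of the zeroth-order step that FedKSeed builds on: the model is perturbed along a direction regenerated from a random seed, the directional derivative is estimated from two forward passes, and the update is applied along that same direction. The function and argument names (`zo_step_along_seed`, `compute_loss`, `eps`, `lr`) are illustrative only.

```python
import torch

@torch.no_grad()
def zo_step_along_seed(model, compute_loss, seed, eps=1e-3, lr=1e-5):
    """One zeroth-order step along the direction generated from `seed` (conceptual sketch)."""

    def add_direction(scale):
        # regenerate the same random direction z from the seed and add scale * z in place
        gen = torch.Generator().manual_seed(seed)
        for p in model.parameters():
            z = torch.randn(p.shape, generator=gen).to(p.device, p.dtype)
            p.add_(scale * z)

    add_direction(+eps)                      # theta + eps * z
    loss_plus = float(compute_loss(model))
    add_direction(-2 * eps)                  # theta - eps * z
    loss_minus = float(compute_loss(model))
    add_direction(+eps)                      # restore theta
    projected_grad = (loss_plus - loss_minus) / (2 * eps)

    add_direction(-lr * projected_grad)      # theta <- theta - lr * g_hat * z
    return seed, projected_grad              # all that needs to be communicated
```

Since every direction is regenerated from one of K shared candidate seeds, a client only needs to report which seeds it used and the corresponding scalar projected gradients, which is what keeps the per-round communication in the kilobyte range.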
================================================
FILE: doc/tutorial/fedkseed/fedkseed-example.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Federated Tuning with FedKSeed methods in FATE-LLM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we will demonstrate how to efficiently train federated large language models using the FATE-LLM framework. In FATE-LLM, we introduce the \"FedKSeed\" module, specifically designed for federated learning with large language models. The Idea of FedKSeed is to use Zeroth-Order-Optimizer to optimize model along given direction that generated with random seed. This method can be used to train large language models in a federated learning setting with extremely low communication cost.\n",
"\n",
"The Algorithm is based on the paper: [Federated Full-Parameter Tuning of Billion-Sized Language Models\n",
"with Communication Cost under 18 Kilobytes](https://arxiv.org/pdf/2312.06353.pdf) and the code is modified from the https://github.com/alibaba/FederatedScope/tree/FedKSeed. We refactor the code to make it more compatible with (transformers/PyTorch) framework and integrate it into the FATE-LLM framework.\n",
"\n",
"The main works include:\n",
"1. An KSeedZerothOrderOptimizer class that can be used to optimize model along given direction that generated with random seed.\n",
"2. An KSeedZOExtendedTrainer subclass of Trainer from transformers that can be used to train large language models with KSeedZerothOrderOptimizer.\n",
"3. Trainers for federated learning with large language models.\n",
"\n",
"In this tutorial, we will demonstrate how to use the FedKSeed method to train a large language model in a federated learning setting. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model: datajuicer/LLaMA-1B-dj-refine-150B\n",
"\n",
"This is the introduction from the Huggingface model hub: [datajuicer/LLaMA-1B-dj-refine-150B](https://huggingface.co/datajuicer/LLaMA-1B-dj-refine-150B)\n",
"\n",
"> The model architecture is LLaMA-1.3B and we adopt the OpenLLaMA implementation. The model is pre-trained on 150B tokens of Data-Juicer's refined RedPajama and Pile. It achieves an average score of 34.21 over 16 HELM tasks, beating Falcon-1.3B (trained on 350B tokens from RefinedWeb), Pythia-1.4B (trained on 300B tokens from original Pile) and Open-LLaMA-1.3B (trained on 150B tokens from original RedPajama and Pile).\n",
"\n",
"> For more details, please refer to our [paper](https://arxiv.org/abs/2309.02033).\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-29T09:27:23.512735Z",
"start_time": "2024-02-29T09:27:23.508790Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"# model_name_or_path = \"datajuicer/LLaMA-1B-dj-refine-150B\"\n",
"model_name_or_path = \"gpt2\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset: databricks/databricks-dolly-15k\n",
"\n",
"This is the introduction from the Huggingface dataset hub: [databricks/databricks-dolly-15k](https://huggingface.co/dataset/databricks/databricks-dolly-15k)\n",
"\n",
"> databricks-dolly-15k is a corpus of more than 15,000 records generated by thousands of Databricks employees to enable large language models to exhibit the magical interactivity of ChatGPT. Databricks employees were invited to create prompt / response pairs in each of eight different instruction categories, including the seven outlined in the InstructGPT paper, as well as an open-ended free-form category. The contributors were instructed to avoid using information from any source on the web with the exception of Wikipedia (for particular subsets of instruction categories), and explicitly instructed to avoid using generative AI in formulating instructions or responses. Examples of each behavior were provided to motivate the types of questions and instructions appropriate to each category\n",
"\n",
"To use this dataset, you first need to download it from the Huggingface dataset hub:\n",
"\n",
"```bash\n",
"mkdir -p ../../../examples/data/dolly && cd ../../../examples/data/dolly && wget wget https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl\\?download\\=true -O databricks-dolly-15k.jsonl\n",
"```\n",
"\n",
"### Check Dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-29T09:27:26.987779Z",
"start_time": "2024-02-29T09:27:24.706218Z"
}
},
"outputs": [],
"source": [
"from fate_llm.dataset.hf_dataset import Dolly15K\n",
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path)\n",
"special_tokens = tokenizer.special_tokens_map\n",
"if \"pad_token\" not in tokenizer.special_tokens_map:\n",
" special_tokens[\"pad_token\"] = special_tokens[\"eos_token\"]\n",
"\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"ds = Dolly15K(split=\"train\", tokenizer_params={\"pretrained_model_name_or_path\": model_name_or_path, **special_tokens},\n",
" tokenizer_apply_params=dict(truncation=True, max_length=tokenizer.model_max_length, padding=\"max_length\", return_tensors=\"pt\"))\n",
"ds = ds.load('../../../examples/data/dolly')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-29T09:27:27.875025Z",
"start_time": "2024-02-29T09:27:27.867839Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['instruction', 'context', 'response', 'category', 'text', 'input_ids', 'attention_mask'],\n",
" num_rows: 15011\n",
"})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For more details of FATE-LLM dataset setting, we recommend that you read through these tutorials first: [NN Dataset Customization](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Homo-NN-Customize-your-Dataset.ipynb), [Some Built-In Dataset](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Introduce-Built-In-Dataset.ipynb),"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check local training\n",
"\n",
"Before submitting a federated learning task, we will demonstrate how to perform local testing to ensure the proper functionality of your custom dataset, model. "
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-29T09:38:33.175079Z",
"start_time": "2024-02-29T09:38:33.168844Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"from transformers import AutoModelForCausalLM, TrainingArguments, DataCollatorForLanguageModeling\n",
"from fate_llm.algo.fedkseed.trainer import KSeedZOExtendedTrainer, KSeedTrainingArguments\n",
"from fate_llm.algo.fedkseed.zo_utils import build_seed_candidates, get_even_seed_probabilities\n",
"\n",
"def test_training(zo_mode=True):\n",
" tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **special_tokens)\n",
" data_collector = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n",
" model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_name_or_path)\n",
"\n",
" training_args = TrainingArguments(output_dir='./',\n",
" dataloader_num_workers=1,\n",
" dataloader_prefetch_factor=1,\n",
" remove_unused_columns=True,\n",
" learning_rate=1e-5,\n",
" per_device_train_batch_size=1,\n",
" num_train_epochs=0.01,\n",
" )\n",
" kseed_args = KSeedTrainingArguments(zo_optim=zo_mode)\n",
" trainer = KSeedZOExtendedTrainer(model=model, train_dataset=ds, training_args=training_args, kseed_args=kseed_args,\n",
" tokenizer=tokenizer, data_collator=data_collector)\n",
" if zo_mode:\n",
" seed_candidates = build_seed_candidates(k=kseed_args.k)\n",
" seed_probabilities = get_even_seed_probabilities(k=kseed_args.k)\n",
" trainer.configure_seed_candidates(seed_candidates, seed_probabilities)\n",
" return trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-29T09:39:37.602070Z",
"start_time": "2024-02-29T09:38:34.024223Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
"
\n",
" [151/151 00:59, Epoch 0/1]\n",
"
\n",
" \n",
" \n",
" \n",
" | Step | \n",
" Training Loss | \n",
"
\n",
" \n",
" \n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=151, training_loss=1.2660519429390005, metrics={'train_runtime': 61.8249, 'train_samples_per_second': 2.428, 'train_steps_per_second': 2.442, 'total_flos': 78910193664000.0, 'train_loss': 1.2660519429390005, 'epoch': 0.01})"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_training(zo_mode=True)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-29T09:41:28.949449Z",
"start_time": "2024-02-29T09:39:54.802705Z"
},
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
"
\n",
" [151/151 01:29, Epoch 0/1]\n",
"
\n",
" \n",
" \n",
" \n",
" | Step | \n",
" Training Loss | \n",
"
\n",
" \n",
" \n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=151, training_loss=0.6093456950408733, metrics={'train_runtime': 92.6158, 'train_samples_per_second': 1.621, 'train_steps_per_second': 1.63, 'total_flos': 78910193664000.0, 'train_loss': 0.6093456950408733, 'epoch': 0.01})"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_training(zo_mode=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"You can see that Zeroth-Order-Optimizer has much worse performance than AdamW, that's the price we need to pay for the low communication cost. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Submit Federated Task\n",
"Once you have successfully completed local testing, We can submit a task to FATE. Please notice that this tutorial is ran on a standalone version. **Please notice that in this tutorial we are using a standalone version, if you are using a cluster version, you need to bind the data with the corresponding name&namespace on each machine.**\n",
"\n",
"In this example we load pretrained weights for gpt2 model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"from fate_client.pipeline.components.fate.reader import Reader\n",
"from fate_client.pipeline import FateFlowPipeline\n",
"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_seq2seq_runner\n",
"from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments\n",
"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\n",
"\n",
"guest = '10000'\n",
"host = '10000'\n",
"arbiter = '10000'\n",
"\n",
"epochs = 0.01\n",
"batch_size = 1\n",
"lr = 1e-5\n",
"\n",
"pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\n",
"pipeline.bind_local_path(path=\"/data/projects/fate/examples/data/dolly\", namespace=\"experiment\",\n",
" name=\"dolly\")\n",
"time.sleep(5)\n",
"\n",
"reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest, host=host))\n",
"reader_0.guest.task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"dolly\"\n",
")\n",
"reader_0.hosts[0].task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"dolly\"\n",
")\n",
"\n",
"tokenizer_params = dict(\n",
" pretrained_model_name_or_path=\"gpt2\",\n",
" trust_remote_code=True,\n",
")\n",
"conf = get_config_of_seq2seq_runner(\n",
" algo='fedkseed',\n",
" model=LLMModelLoader(\n",
" \"hf_model\",\n",
" \"HFAutoModelForCausalLM\",\n",
" # pretrained_model_name_or_path=\"datajuicer/LLaMA-1B-dj-refine-150B\",\n",
" pretrained_model_name_or_path=\"gpt2\",\n",
" trust_remote_code=True\n",
" ),\n",
" dataset=LLMDatasetLoader(\n",
" \"hf_dataset\",\n",
" \"Dolly15K\",\n",
" split=\"train\",\n",
" tokenizer_params=tokenizer_params,\n",
" tokenizer_apply_params=dict(\n",
" truncation=True,\n",
" max_length=1024,\n",
" )),\n",
" data_collator=LLMDataFuncLoader(\n",
" \"cust_func.cust_data_collator\",\n",
" \"get_seq2seq_tokenizer\",\n",
" tokenizer_params=tokenizer_params,\n",
" ),\n",
" training_args=TrainingArguments(\n",
" num_train_epochs=0.01,\n",
" per_device_train_batch_size=batch_size,\n",
" remove_unused_columns=True,\n",
" learning_rate=lr,\n",
" fp16=False,\n",
" use_cpu=False,\n",
" disable_tqdm=False,\n",
" ),\n",
" fed_args=FedAVGArguments(),\n",
" task_type='causal_lm',\n",
" save_trainable_weights_only=True,\n",
")\n",
"\n",
"conf[\"fed_args_conf\"] = {}\n",
"\n",
"homo_nn_0 = HomoNN(\n",
" 'nn_0',\n",
" runner_conf=conf,\n",
" train_data=reader_0.outputs[\"output_data\"],\n",
" runner_module=\"fedkseed_runner\",\n",
" runner_class=\"FedKSeedRunner\",\n",
")\n",
"\n",
"pipeline.add_tasks([reader_0, homo_nn_0])\n",
"pipeline.conf.set(\"task\", dict(engine_run={\"cores\": 1}))\n",
"\n",
"pipeline.compile()\n",
"pipeline.fit()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can use this script to submit the model, but submitting the model will take a long time to train and generate a long log, so we won't do it here."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: doc/tutorial/fedmkt/README.md
================================================
# FATE-LLM: FedMKT
The algorithm is based on the paper ["FedMKT: Federated Mutual Knowledge Transfer for Large and Small Language Models"](https://aclanthology.org/2025.coling-main.17.pdf); we integrate its code into the FATE-LLM framework.
## Citation
If you publish work that uses FedMKT, please cite FedMKT as follows:
```
@inproceedings{fan2025fedmkt,
title={Fedmkt: Federated mutual knowledge transfer for large and small language models},
author={Fan, Tao and Ma, Guoqiang and Kang, Yan and Gu, Hanlin and Song, Yuanfeng and Fan, Lixin and Chen, Kai and Yang, Qiang},
booktitle={Proceedings of the 31st International Conference on Computational Linguistics},
pages={243--255},
year={2025}
}
```
================================================
FILE: doc/tutorial/fedmkt/fedmkt.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Federated Tuning With FedMKT methods in FATE-LLM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we will demonstrate how to efficiently train federated large language models using the FATE-LLM framework. In FATE-LLM, we introduce the \"FedMKT\" module, specifically designed for federated learning with large language models. FedMKT introduces a novel\n",
"federated mutual knowledge transfer framework that enables effective knowledge transfer between an LLM deployed on the server and SLMs residing on clients.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Algorithm is based on paper [\"FedMKT: Federated Mutual Knowledge Transfer for Large and SmallLanguage Models\"](https://arxiv.org/pdf/2406.02224), We integrate its code into the FATE-LLM framework. \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Experiments\n",
"\n",
"Chapter List: \n",
"* settings\n",
" 1. DataSet: ARC-Challenge\n",
" 2. Models Use in \"FEDMKT\" Paper\n",
" 3. Prepare Optimal Vocabulary Mapping Tables\n",
" 4. Training LLMs with Lora\n",
"* experiment examples:\n",
" 1. Running FEDMKT With Launcher (Experimential Using): 4-SLMs\n",
" 2. Running FEDMKT With Launcher (Experimential Using): 1-SLM (One To One)\n",
" 3. Running FEDMKT With Launcher (Experimential Using): 1-SLM And SLM Trains Only (LLM2SLM)\n",
" 4. Running FEDMKT With Launcher (Experimential Using): 4-SLMs Homogeneous SFT\n",
" 5. Running FEDMKT with Pipeline (Industrial Using)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dataset: ARC-Challenge\n",
"\n",
"ARC-Challenge is a dataset of 7,787 genuine grade-school level, multiple-choice science questions, assembled to encourage research in advanced question-answering. \n",
"\n",
"You can refer to following link for more details about [ARC-Challange](https://huggingface.co/datasets/allenai/ai2_arc)\n",
"\n",
"In this section, we will download ARC-Challenge dataset from huggingface and splits it into five parts, part \"common\" for public dataset and other parts for slms(opt2, gpt2, llama, opt)'s training. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import datasets\n",
"\n",
"\n",
"data = datasets.load_dataset(\"ai2_arc\", \"ARC-Challenge\", download_mode=\"force_redownload\", ignore_verifications=True)\n",
"train_data = data.pop(\"train\")\n",
"\n",
"seed=123\n",
"n = train_data.shape[0]\n",
"client_num = 4\n",
"process_data_output_dir = \"\" # processed data saved directory should be specified, it will be used in later.\n",
"\n",
"client_data_num = n // (client_num + 1)\n",
"\n",
"for i in range(client_num):\n",
" splits = train_data.train_test_split(train_size=client_data_num, shuffle=True, seed=seed)\n",
" client_name = f\"client_{i}\"\n",
" data[client_name] = splits[\"train\"]\n",
" train_data = splits[\"test\"]\n",
"\n",
"if train_data.shape[0] == client_data_num:\n",
" data[\"common\"] = train_data\n",
"else:\n",
" data[\"common\"] = train_data.train_test_split(\n",
" train_size=client_data_num, shuffle=True, seed=args.seed\n",
" )[\"train\"]\n",
"\n",
"data.save_to_disk(process_data_output_dir)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Models Use In \"FEDMKT\" Paper\n",
"\n",
"LLM: [Llama-2-7B](https://huggingface.co/meta-llama/Llama-2-7b-hf) \n",
"SLM-0: [opt-1.3b](https://huggingface.co/facebook/opt-1.3b) \n",
"SLM-1: [gpt2-xlarge](https://huggingface.co/openai-community/gpt2-xl) \n",
"SLM-2: [Llama-1.3b](https://huggingface.co/princeton-nlp/Sheared-LLaMA-1.3B) \n",
"SLM-3: [bloom-1.1B](https://huggingface.co/bigscience/bloom-1b1)\n",
"\n",
"Users should download the models from huggingface before the following steps and saved them in local directories, as models are too big, redownload them cost too much times.\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# replaoce the names of models to local save directories\n",
"llm_pretrained_path = \"llama-2-7b-hf\"\n",
"slm_0_pretrained_path = \"opt-1.3b\"\n",
"slm_1_pretrained_path = \"gpt2-xl\"\n",
"slm_2_pretrained_path = \"Sheared-LLaMA-1.3B\"\n",
"slm_3_pretrained_path = \"bloom-1b1\"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare Optimal Vocabulary Mapping Tables\n",
"\n",
"To use \"FEDMKT\" for federated knowledge transfer, we need to build pptimal vocabulary mapping tables first.\n",
"In paper of \"FEDMKT\", it has One LLM and four SLMs, so we need to build eight pptimal vocabulary mapping tables. For each paired of (LLM, SLM), two tables should be built as co-training are needed.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.algo.fedmkt.token_alignment.vocab_mapping import get_vocab_mappings\n",
"\n",
"\n",
"llm_slm_pairs = [\n",
" (llm_pretrained_path, slm_0_pretrained_path),\n",
" (llm_pretrained_path, slm_1_pretrained_path),\n",
" (llm_pretrained_path, slm_2_pretrained_path),\n",
" (llm_pretrained_path, slm_3_pretrained_path)\n",
"]\n",
"\n",
"vocab_mapping_directory = \"\" # replace this to actually paths\n",
"\n",
"slm_to_llm_vocab_mapping_paths = [\"opt_to_llama.json\", \"gpt2_to_llama.json\", \"llama_small_to_llama.json\", \"bloom_to_llama.json\"]\n",
"llm_to_slm_vocab_mapping_paths = [\"llama_to_opt.json\", \"llama_to_gpt2.json\", \"llama_to_llama_small\", \"llama_to_bloom.json\"]\n",
"\n",
"for idx in range(4):\n",
" slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + slm_to_llm_vocab_mapping_paths[idx]\n",
" llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + llm_to_slm_vocab_mapping_paths[idx]\n",
"\n",
"for idx, (llm_pretrained, slm_pretrained) in enumerate(llm_slm_pairs):\n",
" slm_to_llm_vocab_mapping_path = slm_to_llm_vocab_mapping_paths[idx]\n",
" llm_to_slm_vocab_mapping_path = llm_to_slm_vocab_mapping_paths[idx]\n",
" _ = get_vocab_mappings(slm_pretrained, llm_pretrained, slm_to_llm_vocab_mapping_paths[idx], num_processors=16)\n",
" _ = get_vocab_mappings(llm_pretrained, slm_pretrained, llm_to_slm_vocab_mapping_paths[idx], num_processors=16)\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training LLMs with Lora\n",
"\n",
"In this section, We will introduce the lora configs use in five models listed in paper: one LLM (Llama-2-7B), four SLMs(opt-1.3B, gpt2-xlarge, Llama-1.3B, bloom-1.1B)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"LLM models with peft is located on fate_llm/model_zoo, we will give a guide to use them. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Init LLm Llama-2-7B's Lora Config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Init SLMs Lora Config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\n",
"slm_lora_target_modules = [\n",
" [\"q_proj\", \"v_proj\"],\n",
" [\"c_attn\"],\n",
" ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
" [\"query_key_value\"]\n",
"]\n",
"\n",
"def get_slm_conf(slm_idx):\n",
" slm_pretrained_path = slm_pretrained_paths[slm_idx]\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=slm_lora_target_modules[slm_idx]\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Running FEDMKT With Launcher (Experimential Using): 4-SLMs\n",
"\n",
"Using launcher to startup is mainly for experimential. Before running this section, make sure that [FATE-LLM Standalone](https://github.com/FederatedAI/FATE-LLM?tab=readme-ov-file#standalone-deployment) has been deployed."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Global Settings"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"process_data_output_dir = \"\"\n",
"llm_pretrained_path = \"Llama-2-7b-hf\"\n",
"slm_0_pretrained_path = \"opt-1.3b\"\n",
"slm_1_pretrained_path = \"gpt2-xl\"\n",
"slm_2_pretrained_path = \"Sheared-LLaMa-1.3B\"\n",
"slm_3_pretrained_path = \"bloom-1b1\"\n",
"llm_slm_pairs = [\n",
" (llm_pretrained_path, slm_0_pretrained_path),\n",
" (llm_pretrained_path, slm_1_pretrained_path),\n",
" (llm_pretrained_path, slm_2_pretrained_path),\n",
" (llm_pretrained_path, slm_3_pretrained_path)\n",
"]\n",
"\n",
"vocab_mapping_directory = \"\"\n",
"\n",
"slm_to_llm_vocab_mapping_paths = [\"opt_to_llama.json\", \"gpt2_to_llama.json\", \"llama_small_to_llama.json\", \"bloom_to_llama.json\"]\n",
"llm_to_slm_vocab_mapping_paths = [\"llama_to_opt.json\", \"llama_to_gpt2.json\", \"llama_to_llama_small\", \"llama_to_bloom.json\"]\n",
"\n",
"for idx in range(4):\n",
" slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + slm_to_llm_vocab_mapping_paths[idx]\n",
" llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + llm_to_slm_vocab_mapping_paths[idx]\n",
"\n",
"#### all variables has been defined above\n",
"\n",
"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\n",
"slm_lora_target_modules = [\n",
" [\"q_proj\", \"v_proj\"],\n",
" [\"c_attn\"],\n",
" ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
" [\"query_key_value\"]\n",
"]\n",
"\n",
"global_epochs = 1\n",
"batch_size=4\n",
"llm_lr = 3e-5\n",
"slm_lrs = [3e-5, 3e-4, 3e-5, 3e-5, 3e-5]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Init FEDMKTLLM Runner"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"In this Section, we will introduce how to initialize \"FEDMKTLLM\" object."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Step1: Initialize LLM With LoraConfig"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from peft import LoraConfig, TaskType\n",
"from fate_llm.model_zoo.pellm.llama import LLaMa\n",
"from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\n",
"from fate.ml.nn.homo.fedavg import FedAVGArguments\n",
"from fate_llm.dataset.qa_dataset import QaDataset\n",
"from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
"from transformers import AutoConfig\n",
"\n",
"lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\n",
")\n",
"\n",
"model = LLaMa(\n",
" pretrained_path=llm_pretrained_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\" \n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Step2: Specify Public Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
"pub_data.load(process_data_output_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Step3: Initialize FEDMKT Training Args"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=llm_lr,\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size, # pay attention to this, \n",
" # vocab_size must be specified to avoid dimension mismatch \n",
" # of tokenizer's vocab_size\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Step4: Initialize Other Variables"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fed_args = FedAVGArguments(\n",
" aggregate_strategy='epoch',\n",
" aggregate_freq=1\n",
")\n",
"\n",
"slm_to_llm_vocab_mapping = []\n",
"for path in slm_to_llm_vocab_mapping_paths:\n",
" with open(path, \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
" slm_to_llm_vocab_mapping.append(vocab_mapping)\n",
"\n",
"slm_tokenizers = [get_tokenizer(slm_pretrained_path) for slm_pretrained_path in slm_pretrained_paths]\n",
"tokenizer = get_tokenizer(llm_pretrained_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Step5: New FEDMKTLLM Object"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer = FedMKTLLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" fed_args=fed_args,\n",
" train_set=pub_data,\n",
" tokenizer=tokenizer,\n",
" slm_tokenizers=slm_tokenizers,\n",
" slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\n",
" save_trainable_weights_only=True, # save lora weights only\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Step6: Training And Save Results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.train()\n",
"trainer.save_model(output_dir=\"fill the path to save llm finetuning result\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Init FEDMKTSLM Runner\n",
"\n",
"FEDMKTSLM Runner is a slightly different of FEDMKTLLM Runner, we only introduce different variables"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Import SLMs you need to run, here we choose four Slms Using In Original Paper."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import transformers\n",
"from peft import LoraConfig, TaskType \n",
"from fate_llm.model_zoo.pellm.llama import LLaMa\n",
"from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM\n",
"from fate_llm.model_zoo.pellm.opt import OPT\n",
"from fate_llm.model_zoo.pellm.bloom import Bloom\n",
"from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\n",
"from fate_llm.dataset.qa_dataset import QaDataset\n",
"from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
"from transformers import AutoConfig\n",
"\n",
"slm_idx = 0\n",
"\n",
"slm_model_class = [\n",
" OPT,\n",
" GPT2CLM,\n",
" LLaMa,\n",
" Bloom\n",
"]\n",
" \n",
"lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=slm_lora_target_modules[slm_idx]\n",
")\n",
"\n",
"model = slm_model_class[slm_idx](\n",
" pretrained_path=slm_pretrained_paths[slm_idx],\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Specify Private Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=f\"client_{slm_idx}\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
"priv_data.load(process_data_output_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Other Variables "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\n",
"\n",
"import json\n",
"with open(llm_to_slm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### New FEDMKTSLM Object"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer = FedMKTSLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" fed_args=fed_args,\n",
" pub_train_set=pub_data,\n",
" priv_train_set=priv_data,\n",
" tokenizer=tokenizer,\n",
" save_trainable_weights_only=True, # save lora weights only\n",
" llm_tokenizer=get_tokenizer(llm_pretrained_path), # different with LLM setting\n",
" llm_to_slm_vocab_mapping=vocab_mapping, # different with LLM setting\n",
" data_collator=transformers.DataCollatorForSeq2Seq(tokenizer) # use to train private dataset\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Complete Code To DO SFT With 4 SLMs\n",
"\n",
"Please paste the code in \"fedmkt_4_slms.py\" and execute it with the following command"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"python fedmkt_4_slms.py --parties guest:9999 host:9999 host:10000 host:10001 arbiter:9999 --log_level INFO\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fedmkt_4_slms.py\n",
"\n",
"import os\n",
"\n",
"from fate.arch import Context\n",
"from fate.arch.launchers.multiprocess_launcher import launch\n",
"import json\n",
"\n",
"process_data_output_dir = \"\"\n",
"llm_pretrained_path = \"Llama-2-7b-hf\"\n",
"slm_0_pretrained_path = \"opt-1.3b\"\n",
"slm_1_pretrained_path = \"gpt2-xl\"\n",
"slm_2_pretrained_path = \"Sheared-LLaMa-1.3B\"\n",
"slm_3_pretrained_path = \"bloom-1b1\"\n",
"llm_slm_pairs = [\n",
" (llm_pretrained_path, slm_0_pretrained_path),\n",
" (llm_pretrained_path, slm_1_pretrained_path),\n",
" (llm_pretrained_path, slm_2_pretrained_path),\n",
" (llm_pretrained_path, slm_3_pretrained_path)\n",
"]\n",
"\n",
"vocab_mapping_directory = \"\"\n",
"\n",
"slm_to_llm_vocab_mapping_paths = [\"opt_to_llama.json\", \"gpt2_to_llama.json\", \"llama_small_to_llama.json\", \"bloom_to_llama.json\"]\n",
"llm_to_slm_vocab_mapping_paths = [\"llama_to_opt.json\", \"llama_to_gpt2.json\", \"llama_to_llama_small\", \"llama_to_bloom.json\"]\n",
"\n",
"for idx in range(4):\n",
" slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + slm_to_llm_vocab_mapping_paths[idx]\n",
" llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + llm_to_slm_vocab_mapping_paths[idx]\n",
"\n",
"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\n",
"slm_lora_target_modules = [\n",
" [\"q_proj\", \"v_proj\"],\n",
" [\"c_attn\"],\n",
" ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
" [\"query_key_value\"]\n",
"]\n",
"\n",
"global_epochs = 5\n",
"batch_size=4\n",
"llm_lr = 3e-5\n",
"slm_lrs = [3e-5, 3e-4, 3e-5, 3e-5, 3e-5]\n",
"\n",
"llm_model_saved_directory = \"./models/fedmkt_4_slms_llm_model\"\n",
"slm_models_saved_directory = [\n",
" \"./models/fedmkt_4_slms_slm_0\", \n",
" \"./models/fedmkt_4_slms_slm_1\", \n",
" \"./models/fedmkt_4_slms_slm_2\", \n",
" \"./models/fedmkt_4_slms_slm_3\"\n",
"]\n",
"\n",
"\n",
"def train_llm(ctx):\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\n",
" )\n",
"\n",
" model = LLaMa(\n",
" pretrained_path=llm_pretrained_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=llm_lr,\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\n",
" )\n",
"\n",
" slm_to_llm_vocab_mapping = []\n",
" for path in slm_to_llm_vocab_mapping_paths:\n",
" with open(path, \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
" slm_to_llm_vocab_mapping.append(vocab_mapping)\n",
"\n",
" slm_tokenizers = [get_tokenizer(slm_pretrained_path) for slm_pretrained_path in slm_pretrained_paths]\n",
"\n",
" tokenizer = get_tokenizer(llm_pretrained_path)\n",
" trainer = FedMKTLLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" train_set=pub_data,\n",
" tokenizer=tokenizer,\n",
" slm_tokenizers=slm_tokenizers,\n",
" slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\n",
" save_trainable_weights_only=True,\n",
" )\n",
"\n",
" trainer.train()\n",
" trainer.save_model(llm_model_saved_directory)\n",
"\n",
"\n",
"def train_slm(ctx, slm_idx):\n",
" import transformers\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM\n",
" from fate_llm.model_zoo.pellm.opt import OPT\n",
" from fate_llm.model_zoo.pellm.bloom import Bloom\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" slm_model_class = [\n",
" OPT,\n",
" GPT2CLM,\n",
" LLaMa,\n",
" Bloom\n",
" ]\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=slm_lora_target_modules[slm_idx]\n",
" )\n",
"\n",
" model = slm_model_class[slm_idx](\n",
" pretrained_path=slm_pretrained_paths[slm_idx],\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=f\"client_{slm_idx}\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" priv_data.load(process_data_output_dir)\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=slm_lrs[slm_idx],\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,\n",
" )\n",
"\n",
" tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\n",
"\n",
" import json\n",
" with open(llm_to_slm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
"\n",
" trainer = FedMKTSLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" pub_train_set=pub_data,\n",
" priv_train_set=priv_data,\n",
" tokenizer=tokenizer,\n",
" save_trainable_weights_only=True,\n",
" llm_tokenizer=get_tokenizer(llm_pretrained_path),\n",
" llm_to_slm_vocab_mapping=vocab_mapping,\n",
" data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)\n",
" )\n",
"\n",
" trainer.train()\n",
" trainer.save_model(slm_models_saved_directory[slm_idx])\n",
"\n",
"\n",
"def run(ctx: Context):\n",
" if ctx.is_on_arbiter:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
" train_llm(ctx)\n",
" elif ctx.is_on_guest:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
" train_slm(ctx, slm_idx=0)\n",
" else:\n",
" if ctx.local.party[1] == \"9999\":\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
" slm_idx = 1\n",
" elif ctx.local.party[1] == \"10000\":\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
" slm_idx = 2\n",
" elif ctx.local.party[1] == \"10001\":\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"4\"\n",
" slm_idx = 3\n",
" else:\n",
" raise ValueError(f\"party_id={ctx.local.party[1]} is illegal\")\n",
"\n",
" train_slm(ctx, slm_idx=slm_idx)\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" launch(run)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Running FEDMKT With Launcher (Experimential Using): 1-SLM (One To One)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Actually, a slightly modifications from 4-SLMs running code are enough to do sft with single clients, it will be listed in below sections, we take SLM-0(OPT) as an example"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Only Use Single Optimal Vocabulary Mapping Tables"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"slm_idx = 0\n",
"slm_to_llm_vocab_mapping = []\n",
"with open(slm_to_llm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
" slm_to_llm_vocab_mapping.append(vocab_mapping)\n",
"\n",
"slm_tokenizers = [get_tokenizer(slm_pretrained_paths[slm_idx])]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Complete Code To DO SFT With 1 SLM\n",
"\n",
"Please paste the code in \"fedmkt_1_slm.py\" and execute it with the following command"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"python fedmkt_1_slm.py --parties guest:9999 arbiter:9999 --log_level INFO\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fedmkt_1_slm.py\n",
"\n",
"import os\n",
"\n",
"from fate.arch import Context\n",
"from fate.arch.launchers.multiprocess_launcher import launch\n",
"import json\n",
"\n",
"process_data_output_dir = \"\"\n",
"llm_pretrained_path = \"Llama-2-7b-hf\"\n",
"slm_0_pretrained_path = \"opt-1.3b\"\n",
"slm_1_pretrained_path = \"gpt2-xl\"\n",
"slm_2_pretrained_path = \"Sheared-LLaMa-1.3B\"\n",
"slm_3_pretrained_path = \"bloom-1b1\"\n",
"llm_slm_pairs = [\n",
" (llm_pretrained_path, slm_0_pretrained_path),\n",
" (llm_pretrained_path, slm_1_pretrained_path),\n",
" (llm_pretrained_path, slm_2_pretrained_path),\n",
" (llm_pretrained_path, slm_3_pretrained_path)\n",
"]\n",
"\n",
"vocab_mapping_directory = \"\"\n",
"\n",
"slm_to_llm_vocab_mapping_paths = [\"opt_to_llama.json\", \"gpt2_to_llama.json\", \"llama_small_to_llama.json\", \"bloom_to_llama.json\"]\n",
"llm_to_slm_vocab_mapping_paths = [\"llama_to_opt.json\", \"llama_to_gpt2.json\", \"llama_to_llama_small\", \"llama_to_bloom.json\"]\n",
"\n",
"for idx in range(4):\n",
" slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + slm_to_llm_vocab_mapping_paths[idx]\n",
" llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + llm_to_slm_vocab_mapping_paths[idx]\n",
"\n",
"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\n",
"slm_lora_target_modules = [\n",
" [\"q_proj\", \"v_proj\"],\n",
" [\"c_attn\"],\n",
" ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
" [\"query_key_value\"]\n",
"]\n",
"\n",
"global_epochs = 5\n",
"batch_size = 4\n",
"llm_lr = 3e-5\n",
"slm_lrs = [3e-5]\n",
"\n",
"llm_model_saved_directory = \"./models/fedmkt_single_slm_llm\"\n",
"slm_models_saved_directory = [\n",
" \"./models/fedmkt_single_slm_opt\",\n",
"]\n",
"\n",
"\n",
"def train_llm(ctx, slm_idx):\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\n",
" )\n",
"\n",
" model = LLaMa(\n",
" pretrained_path=llm_pretrained_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=llm_lr,\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\n",
" )\n",
"\n",
" slm_to_llm_vocab_mapping = []\n",
" with open(slm_to_llm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
" slm_to_llm_vocab_mapping.append(vocab_mapping)\n",
"\n",
" slm_tokenizers = [get_tokenizer(slm_pretrained_paths[slm_idx])]\n",
"\n",
" tokenizer = get_tokenizer(llm_pretrained_path)\n",
" trainer = FedMKTLLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" train_set=pub_data,\n",
" tokenizer=tokenizer,\n",
" slm_tokenizers=slm_tokenizers,\n",
" slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\n",
" save_trainable_weights_only=True,\n",
" )\n",
"\n",
" trainer.train()\n",
" trainer.save_model(llm_model_saved_directory)\n",
"\n",
"\n",
"def train_slm(ctx, slm_idx):\n",
" import transformers\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM\n",
" from fate_llm.model_zoo.pellm.opt import OPT\n",
" from fate_llm.model_zoo.pellm.bloom import Bloom\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" slm_model_class = [\n",
" OPT,\n",
" GPT2CLM,\n",
" LLaMa,\n",
" Bloom\n",
" ]\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=slm_lora_target_modules[slm_idx]\n",
" )\n",
"\n",
" model = slm_model_class[slm_idx](\n",
" pretrained_path=slm_pretrained_paths[slm_idx],\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=f\"client_{slm_idx}\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" priv_data.load(process_data_output_dir)\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=slm_lrs[slm_idx],\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,\n",
" )\n",
"\n",
" tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\n",
"\n",
" import json\n",
" with open(llm_to_slm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
"\n",
" trainer = FedMKTSLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" pub_train_set=pub_data,\n",
" priv_train_set=priv_data,\n",
" tokenizer=tokenizer,\n",
" save_trainable_weights_only=True,\n",
" llm_tokenizer=get_tokenizer(llm_pretrained_path),\n",
" llm_to_slm_vocab_mapping=vocab_mapping,\n",
" data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)\n",
" )\n",
"\n",
" trainer.train()\n",
" trainer.save_model(slm_models_saved_directory[slm_idx])\n",
"\n",
"\n",
"def run(ctx: Context):\n",
" if ctx.is_on_arbiter:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
" train_llm(ctx, slm_idx=0)\n",
" else:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
" train_slm(ctx, slm_idx=0)\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" launch(run)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Running FEDMKT With Launcher (Experimential Using): 1-SLM And SLM Trains Only (LLM2SLM)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, we introduce how to do SFT using FEDMKT algorithm, with only single SLM are trained, but without LLM training, means that SLM distill knowlege from LLM only, not co-training."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Difference With Section \"Running FEDMKT With Launcher (Experimential Using): 1-SLMs\"\n",
"\n",
"Add llm_training=False to fedmkt_training_args to both LLM and LLM is enough!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Complete Code To DO SFT With 1 SLM And SLM Trains Only\n",
"\n",
"Please paste the code in \"fedmkt_llm_to_slm.py\" and execute it with the following command"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"python fedmkt_llm_to_slm.py --parties guest:9999 arbiter:9999 --log_level INFO\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fedmkt_llm_to_slm.py\n",
"\n",
"import os\n",
"\n",
"from fate.arch import Context\n",
"from fate.arch.launchers.multiprocess_launcher import launch\n",
"import json\n",
"\n",
"process_data_output_dir = \"\"\n",
"llm_pretrained_path = \"Llama-2-7b-hf\"\n",
"slm_0_pretrained_path = \"opt-1.3b\"\n",
"slm_1_pretrained_path = \"gpt2-xl\"\n",
"slm_2_pretrained_path = \"Sheared-LLaMa-1.3B\"\n",
"slm_3_pretrained_path = \"bloom-1b1\"\n",
"llm_slm_pairs = [\n",
" (llm_pretrained_path, slm_0_pretrained_path),\n",
" (llm_pretrained_path, slm_1_pretrained_path),\n",
" (llm_pretrained_path, slm_2_pretrained_path),\n",
" (llm_pretrained_path, slm_3_pretrained_path)\n",
"]\n",
"\n",
"vocab_mapping_directory = \"\"\n",
"\n",
"slm_to_llm_vocab_mapping_paths = [\"opt_to_llama.json\", \"gpt2_to_llama.json\", \"llama_small_to_llama.json\", \"bloom_to_llama.json\"]\n",
"llm_to_slm_vocab_mapping_paths = [\"llama_to_opt.json\", \"llama_to_gpt2.json\", \"llama_to_llama_small\", \"llama_to_bloom.json\"]\n",
"\n",
"for idx in range(4):\n",
" slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + slm_to_llm_vocab_mapping_paths[idx]\n",
" llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + llm_to_slm_vocab_mapping_paths[idx]\n",
"\n",
"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\n",
"slm_lora_target_modules = [\n",
" [\"q_proj\", \"v_proj\"],\n",
" [\"c_attn\"],\n",
" ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
" [\"query_key_value\"]\n",
"]\n",
"\n",
"global_epochs = 5\n",
"batch_size = 4\n",
"llm_lr = 3e-5\n",
"slm_lrs = [3e-5]\n",
"\n",
"slm_models_saved_directory = [\n",
" \"./models/fedmkt_llm_to_slm_opt\",\n",
"]\n",
"\n",
"\n",
"def train_llm(ctx, slm_idx):\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\n",
" )\n",
"\n",
" model = LLaMa(\n",
" pretrained_path=llm_pretrained_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=llm_lr,\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\n",
" llm_training=False\n",
" )\n",
"\n",
" slm_to_llm_vocab_mapping = []\n",
" with open(slm_to_llm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
" slm_to_llm_vocab_mapping.append(vocab_mapping)\n",
"\n",
" slm_tokenizers = [get_tokenizer(slm_pretrained_paths[slm_idx])]\n",
"\n",
" tokenizer = get_tokenizer(llm_pretrained_path)\n",
" trainer = FedMKTLLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" train_set=pub_data,\n",
" tokenizer=tokenizer,\n",
" slm_tokenizers=slm_tokenizers,\n",
" slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\n",
" save_trainable_weights_only=True,\n",
" )\n",
"\n",
" trainer.train()\n",
"\n",
"\n",
"def train_slm(ctx, slm_idx):\n",
" import transformers\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM\n",
" from fate_llm.model_zoo.pellm.opt import OPT\n",
" from fate_llm.model_zoo.pellm.bloom import Bloom\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" slm_model_class = [\n",
" OPT,\n",
" GPT2CLM,\n",
" LLaMa,\n",
" Bloom\n",
" ]\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1,\n",
" target_modules=slm_lora_target_modules[slm_idx]\n",
" )\n",
"\n",
" model = slm_model_class[slm_idx](\n",
" pretrained_path=slm_pretrained_paths[slm_idx],\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=f\"client_{slm_idx}\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" priv_data.load(process_data_output_dir)\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=slm_lrs[slm_idx],\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,\n",
" llm_training=False\n",
" )\n",
"\n",
" tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\n",
"\n",
" import json\n",
" with open(llm_to_slm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
"\n",
" trainer = FedMKTSLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" pub_train_set=pub_data,\n",
" priv_train_set=priv_data,\n",
" tokenizer=tokenizer,\n",
" save_trainable_weights_only=True,\n",
" llm_tokenizer=get_tokenizer(llm_pretrained_path),\n",
" llm_to_slm_vocab_mapping=vocab_mapping,\n",
" data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)\n",
" )\n",
"\n",
" trainer.train()\n",
" trainer.save_model(slm_models_saved_directory[slm_idx])\n",
"\n",
"\n",
"def run(ctx: Context):\n",
" if ctx.is_on_arbiter:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
" train_llm(ctx, slm_idx=0)\n",
" else:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
" train_slm(ctx, slm_idx=0)\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" launch(run)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Running FEDMKT With Launcher (Experimential Using): 4-SLMs Homogeneous SFT\n",
"\n",
"To run homogeneous experiments, two steps are needed.\n",
"1. add post_fedavg=True to fedmkt_training_args to both LLM and LLM is enough!\n",
"2. add fed_args to FEDMKTLLM/FEDMKTSLM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# initialze fed args\n",
"from fate.ml.nn.homo.fedavg import FedAVGArguments\n",
"\n",
"fed_args = FedAVGArguments(\n",
" aggregate_strategy='epoch',\n",
" aggregate_freq=1\n",
")"
]
},
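{
"cell_type": "markdown",
"metadata": {},
"source": [
"Putting the two changes together, the relevant fragments look like this (a sketch only, shown here for the SLM side; the same two additions apply to FedMKTLLM on the LLM side, and the complete script below shows them in full context):\n",
"\n",
"```python\n",
"training_args = FedMKTTrainingArguments(\n",
"    # ... same arguments as before ...\n",
"    post_fedavg=True  # run FedAVG aggregation on top of the knowledge-transfer training\n",
")\n",
"\n",
"trainer = FedMKTSLM(\n",
"    # ... same arguments as before ...\n",
"    training_args=training_args,\n",
"    fed_args=fed_args  # the FedAVGArguments initialized above\n",
")\n",
"```"
]
},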
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Complete Code To DO SFT With 4-SLMs Homogeneous SFT\n",
"\n",
"Please paste the code in \"fedmkt_4_slms_homo.py\" and execute it with the following command"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"python fedmkt_4_slms_homo.py --parties guest:9999 host:9999 host:10000 host:10001 arbiter:9999 --log_level INFO\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fedmkt_4_slms_homo.py\n",
"\n",
"import os\n",
"\n",
"from fate.arch import Context\n",
"from fate.arch.launchers.multiprocess_launcher import launch\n",
"import json\n",
"\n",
"process_data_output_dir = \"\"\n",
"llm_pretrained_path = \"Llama-2-7b-hf\"\n",
"slm_0_pretrained_path = \"opt-1.3b\"\n",
"slm_1_pretrained_path = \"opt-1.3b\"\n",
"slm_2_pretrained_path = \"opt-1.3b\"\n",
"slm_3_pretrained_path = \"opt-1.3b\"\n",
"llm_slm_pairs = [\n",
" (llm_pretrained_path, slm_0_pretrained_path),\n",
" (llm_pretrained_path, slm_1_pretrained_path),\n",
" (llm_pretrained_path, slm_2_pretrained_path),\n",
" (llm_pretrained_path, slm_3_pretrained_path)\n",
"]\n",
"\n",
"vocab_mapping_directory = \"\"\n",
"\n",
"slm_to_llm_vocab_mapping_paths = [\"opt_to_llama.json\"] * 4\n",
"llm_to_slm_vocab_mapping_paths = [\"llama_to_opt.json\"] * 4\n",
"\n",
"for idx in range(4):\n",
" slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + slm_to_llm_vocab_mapping_paths[idx]\n",
" llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + llm_to_slm_vocab_mapping_paths[idx]\n",
"\n",
"slm_pretrained_paths = [slm_0_pretrained_path] * 4\n",
"slm_lora_target_modules = [[\"q_proj\", \"v_proj\"]] * 4\n",
"\n",
"global_epochs = 5\n",
"batch_size = 4\n",
"llm_lr = 3e-5\n",
"slm_lrs = [3e-5, 3e-5, 3e-5, 3e-5, 3e-5]\n",
"\n",
"llm_model_saved_directory = \"./models/fedmkt_homo_4_slms_llm_model\"\n",
"slm_models_saved_directory = [\n",
" \"./models/fedmkt_homo_4_slms_slm_0\",\n",
"]\n",
"\n",
"\n",
"def train_llm(ctx):\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\n",
" from fate.ml.nn.homo.fedavg import FedAVGArguments\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\n",
" )\n",
"\n",
" model = LLaMa(\n",
" pretrained_path=llm_pretrained_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=llm_lr,\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\n",
" post_fedavg=True, # difference\n",
" )\n",
"\n",
" # difference\n",
" fed_args = FedAVGArguments(\n",
" aggregate_strategy='epoch',\n",
" aggregate_freq=1\n",
" )\n",
"\n",
" slm_to_llm_vocab_mapping = []\n",
" for path in slm_to_llm_vocab_mapping_paths:\n",
" with open(path, \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
" slm_to_llm_vocab_mapping.append(vocab_mapping)\n",
"\n",
" slm_tokenizers = [get_tokenizer(slm_pretrained_path) for slm_pretrained_path in slm_pretrained_paths]\n",
"\n",
" tokenizer = get_tokenizer(llm_pretrained_path)\n",
" trainer = FedMKTLLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" fed_args=fed_args, # difference\n",
" train_set=pub_data,\n",
" tokenizer=tokenizer,\n",
" slm_tokenizers=slm_tokenizers,\n",
" slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\n",
" save_trainable_weights_only=True,\n",
" )\n",
"\n",
" trainer.train()\n",
" trainer.save_model(llm_model_saved_directory)\n",
"\n",
"\n",
"def train_slm(ctx, slm_idx):\n",
" import transformers\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.opt import OPT\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\n",
" from fate.ml.nn.homo.fedavg import FedAVGArguments\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" slm_model_class = [OPT] * 4\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\n",
" target_modules=slm_lora_target_modules[slm_idx]\n",
" )\n",
"\n",
" model = slm_model_class[slm_idx](\n",
" pretrained_path=slm_pretrained_paths[slm_idx],\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=f\"client_{slm_idx}\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" priv_data.load(process_data_output_dir)\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=slm_lrs[slm_idx],\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,\n",
" post_fedavg=True, # difference\n",
" )\n",
"\n",
" # difference\n",
" fed_args = FedAVGArguments(\n",
" aggregate_strategy='epoch',\n",
" aggregate_freq=1\n",
" )\n",
"\n",
" tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\n",
"\n",
" import json\n",
" with open(llm_to_slm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
"\n",
" trainer = FedMKTSLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args, \n",
" fed_args=fed_args, # difference\n",
" pub_train_set=pub_data,\n",
" priv_train_set=priv_data,\n",
" tokenizer=tokenizer,\n",
" save_trainable_weights_only=True,\n",
" llm_tokenizer=get_tokenizer(llm_pretrained_path),\n",
" llm_to_slm_vocab_mapping=vocab_mapping,\n",
" data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)\n",
" )\n",
"\n",
" trainer.train()\n",
" if slm_idx == 0:\n",
" trainer.save_model(slm_models_saved_directory[slm_idx])\n",
"\n",
"\n",
"def run(ctx: Context):\n",
" if ctx.is_on_arbiter:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
" train_llm(ctx)\n",
" elif ctx.is_on_guest:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
" train_slm(ctx, slm_idx=0)\n",
" else:\n",
" if ctx.local.party[1] == \"9999\":\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
" slm_idx = 1\n",
" elif ctx.local.party[1] == \"10000\":\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
" slm_idx = 2\n",
" elif ctx.local.party[1] == \"10001\":\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"4\"\n",
" slm_idx = 3\n",
" else:\n",
" raise ValueError(f\"party_id={ctx.local.party[1]} is illegal\")\n",
"\n",
" train_slm(ctx, slm_idx=slm_idx)\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" launch(run)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Running FEDMKT with Pipeline (Industrial Using)\n",
"\n",
"Please make sure that [FATE-LLM Cluster](https://github.com/FederatedAI/FATE/wiki/Download#llm%E9%83%A8%E7%BD%B2%E5%8C%85) has been deployed, ensure that multiple machines has been deployed in FATE-LLM Cluster mode, past the following code to test_fedmkt_4_slms.py, the execute \"python test_fedmkt_4_slms.py\""
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_fedmkt_runner\n",
"from fate_client.pipeline.components.fate.nn.algo_params import FedMKTTrainingArguments, FedAVGArguments\n",
"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\n",
"from peft import LoraConfig, TaskType\n",
"from fate_client.pipeline import FateFlowPipeline\n",
"from fate_client.pipeline.components.fate.reader import Reader\n",
"from transformers import AutoConfig\n",
"\n",
"guest = '9999' # replace this party id to actual guest party id in your enviroment\n",
"host = ['9999', '10000', '10001'] # replace host party ids in your enviroment\n",
"arbiter = '9999' # replace this party id to actual arbiter party id in your enviroment\n",
"\n",
"\n",
"process_data_output_dir = \"\" # replace this to actual process_data_output_dir\n",
"# replaoce the names of models to local save directories\n",
"llm_pretrained_path = \"llama-2-7b-hf\"\n",
"slm_0_pretrained_path = \"opt-1.3b\"\n",
"slm_1_pretrained_path = \"gpt2-xl\"\n",
"slm_2_pretrained_path = \"Sheared-LLaMA-1.3B\"\n",
"slm_3_pretrained_path = \"bloom-1b1\"\n",
"llm_slm_pairs = [\n",
" (llm_pretrained_path, slm_0_pretrained_path),\n",
" (llm_pretrained_path, slm_1_pretrained_path),\n",
" (llm_pretrained_path, slm_2_pretrained_path),\n",
" (llm_pretrained_path, slm_3_pretrained_path)\n",
"]\n",
"\n",
"vocab_mapping_directory = \"\" # reploace this to actual voacb_mapping_directory\n",
"\n",
"slm_to_llm_vocab_mapping_paths = [\"opt_to_llama.json\", \"gpt2_to_llama.json\", \"llama_small_to_llama.json\", \"bloom_to_llama.json\"]\n",
"llm_to_slm_vocab_mapping_paths = [\"llama_to_opt.json\", \"llama_to_gpt2.json\", \"llama_to_llama_small\", \"llama_to_bloom.json\"]\n",
"\n",
"for idx in range(4):\n",
" slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + slm_to_llm_vocab_mapping_paths[idx]\n",
" llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + llm_to_slm_vocab_mapping_paths[idx]\n",
"\n",
"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\n",
"slm_lora_target_modules = [\n",
" [\"q_proj\", \"v_proj\"],\n",
" [\"c_attn\"],\n",
" ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
" [\"query_key_value\"]\n",
"]\n",
"slm_models = [\n",
" (\"pellm.opt\", \"OPT\"),\n",
" (\"pellm.gpt2\", \"GPT2CLM\"),\n",
" (\"pellm.llama\", \"LLaMa\"),\n",
" (\"pellm.bloom\", \"Bloom\")\n",
"]\n",
"\n",
"\n",
"def get_llm_conf():\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\n",
" )\n",
" lora_config.target_modules = list(lora_config.target_modules)\n",
"\n",
" llm_model = LLMModelLoader(\n",
" \"pellm.llama\",\n",
" \"LLaMa\",\n",
" pretrained_path=llm_pretrained_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" pub_dataset = LLMDatasetLoader(\n",
" \"qa_dataset\",\n",
" \"QaDataset\",\n",
" tokenizer_name_or_path=llm_pretrained_path,\n",
" need_preprocess=True,\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512\n",
" )\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=5,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=4,\n",
" learning_rate=3e-5,\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\n",
" )\n",
"\n",
" fed_args = FedAVGArguments(\n",
" aggregate_strategy='epoch',\n",
" aggregate_freq=1\n",
" )\n",
"\n",
" tokenizer = LLMDataFuncLoader(\n",
" \"tokenizers.cust_tokenizer\",\n",
" \"get_tokenizer\",\n",
" tokenizer_name_or_path=llm_pretrained_path\n",
" )\n",
"\n",
" slm_tokenizers = list()\n",
" for slm_pretrained_path in slm_pretrained_paths:\n",
" slm_tokenizers.append(\n",
" LLMDataFuncLoader(\"tokenizers.cust_tokenizer\", \"get_tokenizer\", tokenizer_name_or_path=slm_pretrained_path)\n",
" )\n",
"\n",
" return get_config_of_fedmkt_runner(\n",
" model=llm_model,\n",
" training_args=training_args,\n",
" fed_args=fed_args,\n",
" pub_dataset=pub_dataset,\n",
" tokenizer=tokenizer,\n",
" slm_tokenizers=slm_tokenizers,\n",
" slm_to_llm_vocab_mapping_paths=slm_to_llm_vocab_mapping_paths,\n",
" pub_dataset_path=process_data_output_dir,\n",
" save_trainable_weights_only=True,\n",
" )\n",
"\n",
"\n",
"def get_slm_conf(slm_idx):\n",
" slm_pretrained_path = slm_pretrained_paths[slm_idx]\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\n",
" target_modules=slm_lora_target_modules[slm_idx]\n",
" )\n",
" lora_config.target_modules = list(lora_config.target_modules)\n",
" llm_to_slm_vocab_mapping = llm_to_slm_vocab_mapping_paths[slm_idx]\n",
"\n",
" slm_model = LLMModelLoader(\n",
" slm_models[slm_idx][0],\n",
" slm_models[slm_idx][1],\n",
" pretrained_path=slm_pretrained_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" )\n",
" vocab_size = AutoConfig.from_pretrained(slm_pretrained_path).vocab_size\n",
"\n",
" pub_dataset = LLMDatasetLoader(\n",
" \"qa_dataset\",\n",
" \"QaDataset\",\n",
" tokenizer_name_or_path=slm_pretrained_path,\n",
" need_preprocess=True,\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512\n",
" )\n",
"\n",
" priv_dataset = LLMDatasetLoader(\n",
" \"qa_dataset\",\n",
" \"QaDataset\",\n",
" tokenizer_name_or_path=slm_pretrained_path,\n",
" need_preprocess=True,\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"client_0\",\n",
" seq_max_len=512\n",
" )\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=5,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=4,\n",
" learning_rate=3e-5 if slm_idx != 1 else 3e-4\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=vocab_size,\n",
" # post_fedavg=True,\n",
" # llm_training=False,\n",
" )\n",
"\n",
" fed_args = FedAVGArguments(\n",
" aggregate_strategy='epoch',\n",
" aggregate_freq=1\n",
" )\n",
"\n",
" tokenizer = LLMDataFuncLoader(\n",
" \"tokenizers.cust_tokenizer\",\n",
" \"get_tokenizer\",\n",
" tokenizer_name_or_path=slm_pretrained_path\n",
" )\n",
"\n",
" llm_tokenizer = LLMDataFuncLoader(\n",
" \"tokenizers.cust_tokenizer\", \"get_tokenizer\", tokenizer_name_or_path=llm_pretrained_path\n",
" )\n",
"\n",
" data_collator = LLMDataFuncLoader(module_name='data_collator.cust_data_collator',\n",
" item_name='get_seq2seq_data_collator', tokenizer_name_or_path=slm_pretrained_path)\n",
"\n",
" return get_config_of_fedmkt_runner(\n",
" model=slm_model,\n",
" training_args=training_args,\n",
" fed_args=fed_args,\n",
" pub_dataset=pub_dataset,\n",
" priv_dataset=priv_dataset,\n",
" tokenizer=tokenizer,\n",
" llm_tokenizer=llm_tokenizer,\n",
" llm_to_slm_vocab_mapping_path=llm_to_slm_vocab_mapping,\n",
" pub_dataset_path=process_data_output_dir,\n",
" save_trainable_weights_only=True,\n",
" data_collator=data_collator\n",
" )\n",
"\n",
"\n",
"pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter, host=host)\n",
"pipeline.bind_local_path(path=process_data_output_dir, namespace=\"experiment\", name=\"arc_challenge\")\n",
"\n",
"\n",
"reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest, host=host))\n",
"reader_0.guest.task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"arc_challenge\"\n",
")\n",
"reader_0.hosts[[0, 1, 2]].task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"arc_challenge\"\n",
")\n",
"\n",
"\n",
"homo_nn_0 = HomoNN(\n",
" 'nn_0',\n",
" train_data=reader_0.outputs[\"output_data\"],\n",
" runner_module=\"fedmkt_runner\",\n",
" runner_class=\"FedMKTRunner\",\n",
")\n",
"\n",
"homo_nn_0.arbiter.task_parameters(\n",
" runner_conf=get_llm_conf()\n",
")\n",
"\n",
"homo_nn_0.guest.task_parameters(\n",
" runner_conf=get_slm_conf(slm_idx=0)\n",
")\n",
"\n",
"for idx in range(3):\n",
" homo_nn_0.hosts[idx].task_parameters(\n",
" runner_conf=get_slm_conf(slm_idx=idx + 1)\n",
" )\n",
"\n",
"homo_nn_0.guest.conf.set(\"launcher_name\", \"deepspeed\") # tell schedule engine to run task with deepspeed\n",
"homo_nn_0.hosts[[0, 1, 2]].conf.set(\"launcher_name\", \"deepspeed\") # tell schedule engine to run task with deepspeed\n",
"homo_nn_0.arbiter.conf.set(\"launcher_name\", \"deepspeed\") # tell schedule engine to run task with deepspeed\n",
"\n",
"pipeline.add_tasks([reader_0, homo_nn_0])\n",
"pipeline.conf.set(\"task\", dict(engine_run={\"cores\": 1})) # the number of gpus of each party\n",
"\n",
"pipeline.compile()\n",
"pipeline.fit()\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.15"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
================================================
FILE: doc/tutorial/inferdpt/inferdpt_tutorial.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "341aeb6e-9e25-4a0e-9664-a32ab11293fa",
"metadata": {},
"source": [
"# Inferdpt Tutorial"
]
},
{
"cell_type": "markdown",
"id": "0b40afd5-77b9-45c6-a761-81b9a6bddc05",
"metadata": {},
"source": [
"## Introduction of Inferdpt\n",
"\n",
"Inferdpt is an advanced algorithm framework designed for efficient and privacy-preserving text generation using large language models (LLMs). The framework addresses privacy concerns related to data leakage and unauthorized information collection in LLMs. Inferdpt implements Differential Privacy mechanisms to protect sensitive information during the inference process with black-box LLMs.\n",
"\n",
"Inferdpt comprises two key modules: the \"perturbation module\" and the \"extraction module\". The perturbation module utilizes a differentially private(DP) mechanism to generate a perturbed prompt from the raw document, facilitating privacy-preserving inference with black-box LLMs. The extraction module, inspired by knowledge distillation and retrieval-augmented generation, processes the perturbed text to produce coherent and consistent output. This ensures that the text generation quality of InferDPT is comparable to that of non-private LLMs, maintaining high utility while providing strong privacy guarantees.\n",
"\n",
"To further enhance privacy protection, Inferdpt integrates a novel mechanism called RANTEXT. RANTEXT introduces the concept of random adjacency list for token-level perturbation, addressing the vulnerability of existing differentially private mechanisms to embedding inversion attacks.\n",
"\n",
"For more details of Inferdpt, please refer to the [original paper](https://arxiv.org/pdf/2310.12214.pdf)."
]
},
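{
"cell_type": "markdown",
"id": "dp-perturb-sketch-0001",
"metadata": {},
"source": [
"To build intuition for the perturbation module, below is a conceptual, self-contained sketch of similarity-based token substitution using the exponential mechanism, which is one common way to realize such a DP mechanism. All names here (perturb_token, toy_embeddings) are illustrative only; the actual implementation lives in fate_llm.algo.inferdpt and is used later in this tutorial.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def perturb_token(token_idx, embedding_matrix, epsilon, rng=None):\n",
"    # sample a replacement token with probability proportional to exp(epsilon * similarity / 2):\n",
"    # the larger epsilon is, the more likely the sampled token is semantically close to the original\n",
"    if rng is None:\n",
"        rng = np.random.default_rng()\n",
"    emb = embedding_matrix[token_idx]\n",
"    norms = np.linalg.norm(embedding_matrix, axis=1) * np.linalg.norm(emb)\n",
"    sims = embedding_matrix @ emb / np.maximum(norms, 1e-12)  # cosine similarity to every candidate\n",
"    logits = epsilon * sims / 2.0\n",
"    probs = np.exp(logits - logits.max())\n",
"    probs /= probs.sum()\n",
"    return rng.choice(len(embedding_matrix), p=probs)\n",
"\n",
"# toy vocabulary of 1000 tokens with 16-dimensional embeddings\n",
"rng = np.random.default_rng(0)\n",
"toy_embeddings = rng.normal(size=(1000, 16)).astype(np.float32)\n",
"print(perturb_token(42, toy_embeddings, epsilon=3.0, rng=rng))\n",
"```"
]
},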
{
"cell_type": "markdown",
"id": "ac982b2d-4a71-45a5-a2b1-90259711f36b",
"metadata": {},
"source": [
"## Use InferDPT"
]
},
{
"cell_type": "markdown",
"id": "042049c5-80ce-4786-9896-88baddd59f4e",
"metadata": {},
"source": [
"In this section, we will guide you through the process of:\n",
"- Setting up the inferdpt toolkit with an existing language model.\n",
"- Creating a model inference tool using the built-in class.\n",
"- Executing a step-by-step walkthrough of an inference instance: Employing inferdpt to generate rationale responses for question-answering tasks."
]
},
{
"cell_type": "markdown",
"id": "e1938eef-106d-4cc0-a9b7-6ad8d9d281f5",
"metadata": {},
"source": [
"### Create Inferdpt Kit"
]
},
{
"cell_type": "markdown",
"id": "565aa2ed-5919-4aa0-9499-23b730434c62",
"metadata": {},
"source": [
"In alignment with the original paper, the implementation of differential privacy in inferdpt involves the random substitution of tokens in the original text with semantically similar words. To facilitate this process, it is necessary to precalculate the similarities between a subset of tokens from the vocabulary of the remote large language model. In this tutorial, we will utilize the Mistral-7B model as our remote large language model and the Qwen1.5-0.5B model as the local decoding model. For the sake of computational efficiency, we will select a subset of 11,400 tokens from the Mistral-7B vocabulary to perform the similarity calculations and use the built-in function to finally get the inferdpt-kit.\n",
"\n",
"Firstly we load the mistral model to get the embedding set:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f01a229a-52e1-4a97-af06-a2ab122b7083",
"metadata": {},
"outputs": [],
"source": [
"# load embeddings from mistral model\n",
"import numpy as np\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"model_path = '/data/cephfs/llm/models/Mistral-7B-Instruct-v0.2/'\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
"model = AutoModelForCausalLM.from_pretrained(model_path)\n",
"embeddings = tokenizer.get_vocab() # get embeddings matrix"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "3f7ec40b-1a58-4608-b2c1-3299979e699a",
"metadata": {},
"outputs": [],
"source": [
"# Get the embedding layer weights\n",
"dtype = np.float32\n",
"embedding_weights = model.get_input_embeddings().weight\n",
"# Convert the embedding layer weights to numpy\n",
"embedding_weights_np = embedding_weights.detach().numpy().astype(dtype)"
]
},
{
"cell_type": "markdown",
"id": "07261aee-b676-4a42-9098-2923fa67519c",
"metadata": {},
"source": [
"Then we select english tokens from the vocabulary. Then we can get an embedding matrix and a corresponding token list."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "dbb231f9-f0ca-4add-bb45-f4fb59429abb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32000/32000 [00:00<00:00, 663000.04it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"11400\n"
]
}
],
"source": [
"import tqdm\n",
"import re\n",
"\n",
"def contains_english_chars(string):\n",
" pattern = r'[a-zA-Z]'\n",
" match = re.search(pattern, string)\n",
" return bool(match)\n",
"\n",
"def contains_non_english_chars(string):\n",
" pattern = r'[^a-zA-Z]'\n",
" match = re.search(pattern, string)\n",
" return bool(match)\n",
"\n",
"def filter_tokens(token2index):\n",
" filtered_index2token = {}\n",
" for key, idx in tqdm.tqdm(token2index.items()):\n",
" if key.startswith('<'):\n",
" continue\n",
" if not key.startswith('▁'):\n",
" continue\n",
" val_ = key.replace(\"▁\", \"\")\n",
" if val_ == val_.upper():\n",
" continue\n",
" if contains_non_english_chars(val_):\n",
" continue\n",
" if 3 < len(val_) < 16 and contains_english_chars(val_):\n",
" filtered_index2token[idx] = key\n",
"\n",
" return filtered_index2token\n",
"\n",
"filtered_index2token = filter_tokens(embeddings)\n",
"used_num_tokens = len(filtered_index2token)\n",
"print(used_num_tokens)\n",
"for idx, token in filtered_index2token.items():\n",
" token_2_embedding[token] = embedding_weights_np[idx].tolist()\n",
"token_list = list(token_2_embedding.keys())\n",
"embedding_matrix = np.array(list(token_2_embedding.values()), dtype=dtype)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "5922a177-d752-485d-98ab-9fd6688198f8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"we got the embedding matrix:\n",
"[[-6.1035156e-04 -4.5471191e-03 -5.2795410e-03 ... -1.3656616e-03\n",
" 4.2419434e-03 -8.1634521e-04]\n",
" [ 4.8522949e-03 5.9814453e-03 1.1596680e-03 ... -2.6702881e-03\n",
" -1.7471313e-03 9.9182129e-04]\n",
" [-2.7465820e-03 4.3029785e-03 3.3874512e-03 ... -2.6092529e-03\n",
" -1.2397766e-05 -3.4027100e-03]\n",
" ...\n",
" [-6.1340332e-03 -5.3405762e-03 -1.0910034e-03 ... -9.3841553e-04\n",
" -7.4005127e-04 -7.3852539e-03]\n",
" [-4.5166016e-03 8.2015991e-04 4.8217773e-03 ... -1.1978149e-03\n",
" -1.0528564e-03 -2.1362305e-03]\n",
" [ 1.2054443e-03 1.9836426e-03 -2.8419495e-04 ... -1.5792847e-03\n",
" -2.8381348e-03 -7.1716309e-04]]\n"
]
}
],
"source": [
"print('we got the embedding matrix:')\n",
"print(embedding_matrix)"
]
},
{
"cell_type": "markdown",
"id": "20890d89-998f-4f38-968a-2a6a0648b050",
"metadata": {},
"source": [
"We can easily prepare the pre-computed data we needed for inferdpt by using the built-in function of the InferDPTKit class:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "c90c7099-d20d-4009-bb7a-aeb3b46210b2",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"11400it [00:37, 300.99it/s]\n",
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4096/4096 [00:03<00:00, 1147.93it/s]\n"
]
}
],
"source": [
"from fate_llm.algo.inferdpt.utils import InferDPTKit\n",
"param = InferDPTKit.make_inferdpt_kit_param(embedding_matrix, token_list)"
]
},
{
"cell_type": "markdown",
"id": "0fe3a722-5cdf-4393-90f3-e5d7b82051cf",
"metadata": {},
"source": [
"Great, the computation is complete! Now, let’s proceed to perturb a sentence using inferdpt with ε (epsilon) set to 3.0. We will also save the perturbed sentence to a designated folder for future reference."
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "4a6acd81-bc7c-49d3-86f4-ad0b5c329e61",
"metadata": {},
"outputs": [],
"source": [
"inferdpt_kit = InferDPTKit(*param, tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"id": "0077696c-0c10-4500-8835-6e72a084bc42",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'into the tree to the woods'"
]
},
"execution_count": 96,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inferdpt_kit.perturb('From the river to the ocean', epsilon=3.0)"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "8df2f57a-e202-4d4f-a175-21990223dc3d",
"metadata": {},
"outputs": [],
"source": [
"save_kit_path = 'your path'\n",
"inferdpt_kit.save_to_path(save_kit_path)"
]
},
{
"cell_type": "markdown",
"id": "ced7f1bf-aa49-4806-92e2-712493bb4b10",
"metadata": {},
"source": [
"### Go through Inferdpt Step by Step\n",
"\n",
"Next, we will guide you through the process of using inferdpt step by step. We will simulate the interaction between the client and server locally. Before we begin, let’s discuss model inference. Within fate-llm's inferdpt module, we provide three types of model inference classes: vllm, vllm server, and Huggingface native. You can explore these classes in the [code files](../../../python/fate_llm/algo/inferdpt/inference/) or develop your own inference tool based on your specific needs. We highly recommend using vllm server. In this case, we will use the following two commands to launch two large model services, corresponding to the server’s LLM and the local decoding small model.\n",
"\n",
"For this example, we have executed the process on a machine equipped with four V100-32G GPUs. We advise you to adjust the model path and GPU settings as necessary to accommodate the specifications of your own machine.\n",
"\n",
"Start vllm server using commands below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1b6c3c7-6ddd-4386-8700-c95f74a2bae0",
"metadata": {},
"outputs": [],
"source": [
"! python -m vllm.entrypoints.openai.api_server --host 127.0.0.1 --port 8888 --model ./Mistral-7B-Instruct-v0.2 --dtype=half --enforce-eager --tensor-parallel-size 4 --gpu-memory-utilization 0.6"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48374cb5-3a5b-456c-9d53-219c2468da63",
"metadata": {},
"outputs": [],
"source": [
"! python -m vllm.entrypoints.openai.api_server --host 127.0.0.1 --port 8887 --model ./Qwen1.5-0.5B --dtype=half --enforce-eager --tensor-parallel-size 4 --gpu-memory-utilization 0.2"
]
},
{
"cell_type": "markdown",
"id": "375e3f0d-36c7-4ab3-8e65-cccac23e93c6",
"metadata": {},
"source": [
"Next, we will initialize the inference instance, which are the parameters for both the inferdpt client and server. This includes specifying the IP address, port, and the model name of the service that has been started."
]
},
{
"cell_type": "code",
"execution_count": 130,
"id": "cd099ef4-569d-45b6-9765-502b688c3fb4",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.inference.api import APICompletionInference\n",
"# for client\n",
"inference_client = APICompletionInference(api_url=\"http://127.0.0.1:8887/v1\", model_name='./Qwen1.5-0.5B', api_key='EMPTY')\n",
"# for server\n",
"inference_server = APICompletionInference(api_url=\"http://127.0.0.1:8888/v1\", model_name='./Mistral-7B-Instruct-v0.2', api_key='EMPTY')"
]
},
{
"cell_type": "code",
"execution_count": 135,
"id": "8c430c14-2180-4f02-8f06-3f41bae1a710",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" I am a new user of this forum. I am a 20 year\n"
]
}
],
"source": [
"ret = inference_client.inference(['Hello how are you?'], inference_kwargs={\n",
" 'stop': ['<|im_end|>', '\\n'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 16\n",
"})\n",
"print(ret[0])"
]
},
{
"cell_type": "code",
"execution_count": 138,
"id": "6341eb48-e30f-46d4-aeaa-8c6fd27259b9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" I am an artificial intelligence designed to assist with information and answer questions to the best of my ability. I don't have the ability to have a personal identity or emotions. I'm here to help you with any inquiries you may have. How can I assist you today?\n"
]
}
],
"source": [
"ret = inference_server.inference(['[INST]Who are u?[/INST]'], inference_kwargs={\n",
" 'stop': [''],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 128\n",
"})\n",
"print(ret[0])"
]
},
{
"cell_type": "markdown",
"id": "90f5ce55-de3b-481a-a9cc-cb4c24edb7c2",
"metadata": {},
"source": [
"In this tutorial, we will use a question-answering (QA) task as our illustrative example. To do so, we will extract a sample from the ARC-E dataset for demonstration purposes, here is the example:"
]
},
{
"cell_type": "code",
"execution_count": 100,
"id": "f912f986-ae86-4d57-9ebc-534d6404173c",
"metadata": {},
"outputs": [],
"source": [
"test_example = {'id': 'Mercury_7220990',\n",
"'question': 'Which factor will most likely cause a person to develop a fever?',\n",
"'choices': {'text': ['a leg muscle relaxing after exercise',\n",
"'a bacterial population in the bloodstream',\n",
"'several viral particles on the skin',\n",
"'carbohydrates being digested in the stomach'],\n",
"'label': ['A', 'B', 'C', 'D']},\n",
"'answerKey': 'B'}"
]
},
{
"cell_type": "markdown",
"id": "a98c74a8-7760-438b-b4f4-33178fed8761",
"metadata": {},
"source": [
"Before initiating the inference, it's crucial to understand the sequence of steps involved. We will leverage the Jinja2 template engine to structure our documentation as follows:\n",
"\n",
"1. **Document Template Organization**: The initial step is to organize the document dictionary using the DOC TEMPLATE. This template will provide the structure for the input document.\n",
"\n",
"2. **Differential Privacy Perturbation**: Apply Differential Privacy (DP) to perturb the structured document string. This will result in a perturbed document. The perturbed document is then added to the original document under the key 'perturbed_doc'. Note that you can modify this key according to your parameter settings.\n",
"\n",
"3. **Instruction Addition**: Use the INSTRUCTION TEMPLATE to add instructions (or few-shot examples) to the perturbed document. This modified document is then sent to the server side for processing. The server's response is captured, and this perturbed response is appended to the original document under the key 'perturbed_response'. As before, this key can be adjusted as needed.\n",
"\n",
"4. **Decode Template Formatting**: Finally, employ the decode template to format the decode prompt. The resulting inference is then added to the original dictionary under the key 'inferdpt_result'. This key, like the others, can be customized to fit your specific parameters.\n",
"\n",
"By following these steps, the inferdpt framework enables a structured and privacy-preserving inference process, leading to a final output that incorporates the perturbed data and the model's response.\n",
"For more details, you can refer to the source codes:\n"
]
},
{
"cell_type": "markdown",
"id": "09d7377a-22f4-4d04-b886-88faa1384d7f",
"metadata": {},
"source": [
"The templates for this example are defined on the client side. Below is the Jinja template we use:"
]
},
{
"cell_type": "code",
"execution_count": 141,
"id": "eff74a65-f765-483f-a685-418376414ff0",
"metadata": {},
"outputs": [],
"source": [
"doc_template = \"\"\"{{question}} \n",
"Choices:{{choices.text}}\n",
"\"\"\"\n",
"\n",
"instruction_template=\"\"\"\n",
"[INST]\n",
"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n",
"Use to finish your rationle.\"\n",
"\n",
"Example(s):\n",
"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n",
"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n",
"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.\n",
"\n",
"Please explain:\n",
"Question:{{perturbed_doc}}\n",
"Rationale:\n",
"[/INST]\n",
"\"\"\"\n",
"\n",
"decode_template = \"\"\"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n",
"Use to finish your rationle.\"\n",
"\n",
"Example(s):\n",
"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n",
"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n",
"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.\n",
"\n",
"Question:{{perturbed_doc}}\n",
"Rationale:{{perturbed_response | replace('\\n', '')}}\n",
"\n",
"Please explain:\n",
"Question:{{question}} \n",
"Choices:{{choices.text}}\n",
"Rationale:\n",
"\"\"\""
]
},
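{
"cell_type": "markdown",
"id": "inferdpt-step-sketch-0001",
"metadata": {},
"source": [
"To make the four steps above concrete, here is a rough sketch of how the templates could be filled in for a single document, using jinja2 directly together with the objects defined earlier in this notebook. The actual InferDPTClient performs these steps for you, along with batching and the federated exchange of the perturbed prompts and responses.\n",
"\n",
"```python\n",
"from jinja2 import Template\n",
"\n",
"doc = dict(test_example)  # the ARC example shown above\n",
"\n",
"# Step 1: organize the raw document with the doc template\n",
"doc_str = Template(doc_template).render(**doc)\n",
"\n",
"# Step 2: apply the DP perturbation and store it under 'perturbed_doc'\n",
"doc['perturbed_doc'] = inferdpt_kit.perturb(doc_str, epsilon=3.0)\n",
"\n",
"# Step 3: render the instruction prompt and query the remote (server-side) LLM\n",
"remote_prompt = Template(instruction_template).render(**doc)\n",
"doc['perturbed_response'] = inference_server.inference(\n",
"    [remote_prompt], inference_kwargs={'temperature': 0.01, 'max_tokens': 256})[0]\n",
"\n",
"# Step 4: render the decode prompt and let the local model produce the final answer\n",
"decode_prompt = Template(decode_template).render(**doc)\n",
"doc['inferdpt_result'] = inference_client.inference(\n",
"    [decode_prompt], inference_kwargs={'temperature': 0.01, 'max_tokens': 256})[0]\n",
"```"
]
},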
{
"cell_type": "markdown",
"id": "b3c02898-91df-48b6-a0e2-9af5bd5538d8",
"metadata": {},
"source": [
"Please be aware that we have included a one-shot example in the prompt to ensure that the Large Language Model (LLM) responds as anticipated.\n",
"\n",
"Now we create two script: \n",
"- inferdpt_client.py\n",
"- inferdpt_server.py\n",
"\n",
"And run codes provided below:"
]
},
{
"cell_type": "markdown",
"id": "f0f110a3-c601-4c7b-8e89-8684d2ae266d",
"metadata": {},
"source": [
"#### Client Side: inferdpt_client.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "006612b8-da8d-402c-9b6d-b6786325fa7c",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.inference.api import APICompletionInference\n",
"from fate_llm.algo.inferdpt import inferdpt\n",
"from fate_llm.algo.inferdpt.utils import InferDPTKit\n",
"import sys\n",
"\n",
"\n",
"arbiter = (\"arbiter\", 10000)\n",
"guest = (\"guest\", 10000)\n",
"host = (\"host\", 9999)\n",
"name = \"fed1\"\n",
"\n",
"\n",
"def create_ctx(local):\n",
" from fate.arch import Context\n",
" from fate.arch.computing.backends.standalone import CSession\n",
" from fate.arch.federation.backends.standalone import StandaloneFederation\n",
" import logging\n",
"\n",
" logger = logging.getLogger()\n",
" logger.setLevel(logging.INFO)\n",
"\n",
" console_handler = logging.StreamHandler()\n",
" console_handler.setLevel(logging.INFO)\n",
"\n",
" formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n",
" console_handler.setFormatter(formatter)\n",
"\n",
" logger.addHandler(console_handler)\n",
" computing = CSession(data_dir=\"./session_dir\")\n",
" return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n",
"\n",
"\n",
"ctx = create_ctx(guest)\n",
"save_kit_path = 'your path'\n",
"kit = InferDPTKit.load_from_path(save_kit_path)\n",
"inference = APICompletionInference(api_url=\"http://127.0.0.1:8887/v1\", model_name='./Qwen1.5-0.5B', api_key='EMPTY')\n",
"\n",
"test_example = {'id': 'Mercury_7220990',\n",
"'question': 'Which factor will most likely cause a person to develop a fever?',\n",
"'choices': {'text': ['a leg muscle relaxing after exercise',\n",
"'a bacterial population in the bloodstream',\n",
"'several viral particles on the skin',\n",
"'carbohydrates being digested in the stomach'],\n",
"'label': ['A', 'B', 'C', 'D']},\n",
"'answerKey': 'B'}\n",
"\n",
"\n",
"doc_template = \"\"\"{{question}} \n",
"Choices:{{choices.text}}\n",
"\"\"\"\n",
"\n",
"instruction_template=\"\"\"\n",
"[INST]\n",
"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n",
"Use to finish your rationle.\"\n",
"\n",
"Example(s):\n",
"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n",
"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n",
"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.\n",
"\n",
"Please explain:\n",
"Question:{{perturbed_doc}}\n",
"Rationale:\n",
"[/INST]\n",
"\"\"\"\n",
"\n",
"decode_template = \"\"\"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n",
"Use to finish your rationle.\"\n",
"\n",
"Example(s):\n",
"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n",
"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n",
"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.\n",
"\n",
"Question:{{perturbed_doc}}\n",
"Rationale:{{perturbed_response | replace('\\n', '')}}\n",
"\n",
"Please explain:\n",
"Question:{{question}} \n",
"Choices:{{choices.text}}\n",
"Rationale:\n",
"\"\"\"\n",
"\n",
"inferdpt_client = inferdpt.InferDPTClient(ctx, kit, inference, epsilon=3.0)\n",
"result = inferdpt_client.inference([test_example], doc_template, instruction_template, decode_template, \\\n",
" remote_inference_kwargs={\n",
" 'stop': ['<\\s>'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
" },\n",
" local_inference_kwargs={\n",
" 'stop': ['<|im_end|>', '', '\\n', '\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
" })\n",
"print('result is {}'.format(result[0]['inferdpt_result']))"
]
},
{
"cell_type": "markdown",
"id": "e6ed3c0e-0b1f-4087-b155-def3ee957618",
"metadata": {},
"source": [
"#### Server Side: inferdpt_server.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96e3e9fa-9554-4bcf-b8bf-358c469014bf",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.algo.inferdpt.inferdpt import InferDPTServer\n",
"import sys\n",
"from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n",
"\n",
"\n",
"arbiter = (\"arbiter\", 10000)\n",
"guest = (\"guest\", 10000)\n",
"host = (\"host\", 9999)\n",
"name = \"fed1\"\n",
"\n",
"\n",
"def create_ctx(local):\n",
" from fate.arch import Context\n",
" from fate.arch.computing.backends.standalone import CSession\n",
" from fate.arch.federation.backends.standalone import StandaloneFederation\n",
" import logging\n",
"\n",
" logger = logging.getLogger()\n",
" logger.setLevel(logging.INFO)\n",
"\n",
" console_handler = logging.StreamHandler()\n",
" console_handler.setLevel(logging.INFO)\n",
"\n",
" formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n",
" console_handler.setFormatter(formatter)\n",
"\n",
" logger.addHandler(console_handler)\n",
" computing = CSession(data_dir=\"./session_dir\")\n",
" return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n",
"\n",
"\n",
"ctx = create_ctx(arbiter)\n",
"inference_server = APICompletionInference(api_url=\"http://127.0.0.1:8888/v1\", model_name='./Mistral-7B-Instruct-v0.2', api_key='EMPTY')\n",
"inferdpt_server = InferDPTServer(ctx, inference)\n",
"inferdpt_server.inference()"
]
},
{
"cell_type": "markdown",
"id": "bfef704b-179e-44cd-84dc-a40b036e7f28",
"metadata": {},
"source": [
"Start two terminal and launch client&server scripts simultaneously.\n",
"On the client side we can get the answer:\n",
"\n",
"```\n",
"The given question asks which factor will most likely cause a person to develop a fever. The factors mentioned are a leg muscle relaxing after exercise, a bacterial population in the bloodstream, several viral particles on the skin, and carbohydrates being digested in the stomach. The question is asking which factor is most likely to cause a person to develop a fever. The factors are all related to the body's internal environment, but the most likely factor is a bacterial population in the bloodstream. This is because bacteria can cause a fever, and the body's immune system responds to the infection by producing antibodies that can fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "adf80b4e-4727-4ee4-b0b1-4839bd516f4f",
"metadata": {},
"source": [
"## Use Inferdpt in FATE Pipeline"
]
},
{
"cell_type": "markdown",
"id": "b9b560e3-8db4-4828-a4fc-494320a9a3e5",
"metadata": {},
"source": [
"We can leverage the FATE pipeline to submit inference tasks for industrial applications. When operating in pipeline mode, to safeguard against privacy breaches such as API key or server path leakage, it is crucial to create initialization scripts for establishing inferdpt client instances. Alternatively, you can modify the provided scripts within the fate_llm/algo/inferdpt/init folder.\n",
"\n",
"Below, we provide an overview of the default_init.py script, which serves as an example of how to create an [initialization class](../../../python/fate_llm/algo/inferdpt/init/default_init.py). By customizing the static variables within this class, you can configure the client and server to interact with the Large Language Model (LLM) interfaces as intended."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eab49960-0541-4059-b84d-bee4bb690974",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.algo.inferdpt.init._init import InferClientInit\n",
"from fate_llm.inference.api import APICompletionInference\n",
"from fate_llm.algo.inferdpt import inferdpt\n",
"from fate_llm.algo.inferdpt.utils import InferDPTKit\n",
"from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n",
"\n",
"\n",
"class InferDPTAPIClientInit(InferClientInit):\n",
"\n",
" api_url = ''\n",
" api_model_name = ''\n",
" api_key = 'EMPTY'\n",
" inferdpt_kit_path = ''\n",
" eps = 3.0\n",
"\n",
" def __init__(self, ctx):\n",
" super().__init__(ctx)\n",
" self.ctx = ctx\n",
"\n",
" def get_inst(self)-> InferDPTClient:\n",
" inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)\n",
" kit = InferDPTKit.load_from_path(self.inferdpt_kit_path)\n",
" inferdpt_client = inferdpt.InferDPTClient(self.ctx, kit, inference, epsilon=self.eps)\n",
" return inferdpt_client\n",
"\n",
"\n",
"class InferDPTAPIServerInit(InferClientInit):\n",
"\n",
" api_url = ''\n",
" api_model_name = ''\n",
" api_key = 'EMPTY'\n",
"\n",
" def __init__(self, ctx):\n",
" super().__init__(ctx)\n",
" self.ctx = ctx\n",
"\n",
" def get_inst(self)-> InferDPTServer:\n",
" inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)\n",
" inferdpt_server = inferdpt.InferDPTServer(self.ctx,inference_inst=inference)\n",
" return inferdpt_server\n",
" "
]
},
{
"cell_type": "markdown",
"id": "0a5c9d6b-94b9-4ae3-80f7-20d1a698764c",
"metadata": {},
"source": [
"In the pipeline example, we use arc_easy dataset and our built-in huggingface dataset. Only HuggingfaceDataset is supported in the pipeline mode:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "15276057-fdda-4cc6-8678-eb1f485e4c58",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.dataset.hf_dataset import HuggingfaceDataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "31cce967-3f5f-4261-ae17-9089368b82f9",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"dataset = load_dataset('arc_easy')\n",
"dataset.save_to_disk('your_path/arc_easy')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "af9adcb4-766d-45f6-a13c-5c127df61e5b",
"metadata": {},
"outputs": [],
"source": [
"ds = HuggingfaceDataset(load_from_disk= True, data_split_key='train')\n",
"ds.load('your_path/arc_easy')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "899c410f-fe68-4f7e-936e-8b11720ff148",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'id': 'Mercury_7220990', 'question': 'Which factor will most likely cause a person to develop a fever?', 'choices': {'text': ['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach'], 'label': ['A', 'B', 'C', 'D']}, 'answerKey': 'B'}\n"
]
}
],
"source": [
"print(ds[0])"
]
},
{
"cell_type": "markdown",
"id": "5f69f8cf-f40d-418a-be2a-753d67537442",
"metadata": {},
"source": [
"After that, we can associate the dataset path with a name and namespace. By specifying the dataset configuration, the HuggingfaceDataset will be initialized and the dataset will be loaded from the specified path. \n",
"```\n",
"flow table bind --namespace experiment --name arc_e --path 'your_path/arc_easy'\n",
"```\n",
"Once these initialization scripts are in place, you can submit a pipeline task by specifying the initialization class in the configuration files. For more information, refer to the script provided below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4da1aea7-0ba2-4ebb-918f-cfcf24d4498b",
"metadata": {},
"outputs": [],
"source": [
"import argparse\n",
"from fate_client.pipeline.utils import test_utils\n",
"from fate_client.pipeline.components.fate.reader import Reader\n",
"from fate_client.pipeline import FateFlowPipeline\n",
"\n",
"\n",
"def main(config=\"../../config.yaml\", namespace=\"\"):\n",
" # obtain config\n",
" if isinstance(config, str):\n",
" config = test_utils.load_job_config(config)\n",
" parties = config.parties\n",
" guest = parties.guest[0]\n",
" arbiter = parties.arbiter[0]\n",
"\n",
" pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\n",
"\n",
" reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest))\n",
" reader_0.guest.task_parameters(\n",
" namespace=f\"experiment{namespace}\",\n",
" name=\"arc_e\"\n",
" )\n",
"\n",
" inferdpt_init_conf_client = {\n",
" 'module_name': 'fate_llm.algo.inferdpt.init.default_init',\n",
" 'item_name': 'InferDPTAPIClientInit'\n",
" }\n",
"\n",
" dataset_conf = {\n",
" 'module_name': 'fate_llm.dataset.hf_dataset',\n",
" 'item_name': 'HuggingfaceDataset',\n",
" 'kwargs':{\n",
" 'load_from_disk': True,\n",
" 'data_split_key': 'train'\n",
" }\n",
" }\n",
"\n",
" doc_template = \"\"\"{{question}} \n",
" Choices:{{choices.text}}\n",
" \"\"\"\n",
"\n",
" instruction_template=\"\"\"\n",
" <|im_start|>system\n",
" You are a helpful assistant.<|im_end|>\n",
" <|im_start|>user\n",
" Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n",
" Use to finish your rationle.\"\n",
"\n",
" Example(s):\n",
" Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n",
" Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n",
" Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.\n",
"\n",
" Please explain:\n",
" Question:{{perturbed_doc}}\n",
" Rationale:\n",
" <|im_end|>\n",
" <|im_start|>assistant\n",
" \"\"\"\n",
"\n",
" decode_template = \"\"\"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n",
" Use to finish your rationle.\"\n",
"\n",
" Example(s):\n",
" Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n",
" Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n",
" Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.\n",
"\n",
" Question:{{perturbed_doc}}\n",
" Rationale:{{perturbed_response | replace('\\n', '')}}\n",
"\n",
" Please explain:\n",
" Question:{{question}} \n",
" Choices:{{choices.text}}\n",
" Rationale:\n",
" \"\"\"\n",
"\n",
" remote_inference_kwargs={\n",
" 'stop': [['<\\s>']],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
" }\n",
"\n",
" local_inference_kwargs={\n",
" 'stop': ['<|im_end|>', '', '\\n', '\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
" }\n",
"\n",
" inferdpt_client_conf = {\n",
" 'inferdpt_init_conf': inferdpt_init_conf_client,\n",
" 'dataset_conf': dataset_conf,\n",
" 'doc_template': doc_template,\n",
" 'instruction_template': instruction_template,\n",
" 'decode_template': decode_template,\n",
" 'dataset_conf': dataset_conf,\n",
" 'remote_inference_kwargs': remote_inference_kwargs,\n",
" 'local_inference_kwargs': local_inference_kwargs\n",
" }\n",
"\n",
" inferdpt_init_conf_server = {\n",
" 'module_name': 'fate_llm.algo.inferdpt.init.default_init',\n",
" 'item_name': 'InferDPTAPIServerInit'\n",
" }\n",
"\n",
" inferdpt_server_conf = {\n",
" 'inferdpt_init_conf': inferdpt_init_conf_server\n",
" }\n",
"\n",
" homo_nn_0 = HomoNN(\n",
" 'nn_0',\n",
" runner_module='inferdpt_runner',\n",
" runner_class='InferDPTRunner',\n",
" train_data=reader_0.outputs[\"output_data\"]\n",
" )\n",
"\n",
" homo_nn_0.guest.task_parameters(runner_conf=inferdpt_client_conf)\n",
" homo_nn_0.arbiter.task_parameters(runner_conf=inferdpt_server_conf)\n",
" pipeline.add_tasks([reader_0, homo_nn_0])\n",
" pipeline.compile()\n",
" pipeline.fit()\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" parser = argparse.ArgumentParser(\"PIPELINE DEMO\")\n",
" parser.add_argument(\"--config\", type=str, default=\"../config.yaml\",\n",
" help=\"config file\")\n",
" parser.add_argument(\"--namespace\", type=str, default=\"\",\n",
" help=\"namespace for data stored in FATE\")\n",
" args = parser.parse_args()\n",
" main(config=args.config, namespace=args.namespace)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "c2345e19-83eb-4196-9606-74658c8fbdc5",
"metadata": {},
"source": [
"# Offsite-tuning Tutorial"
]
},
{
"cell_type": "markdown",
"id": "9f1d728c-09e1-418e-8d80-53dd0ec467b1",
"metadata": {},
"source": [
"In this tutorial, we'll focus on how to leverage Offsite-Tuning framework in FATE-LLM-2.0 to fine-tune your LLM. You'll learn how to:\n",
"\n",
"1. Define models, including main models(which are at server side and will offer adapters and emulators) and submodel(which are at client side and will load adapters and emulators for local fine-tuning) compatible with Offsite-Tuning framework.\n",
"2. Get hands-on experience with the Offsite-Tuning trainer.\n",
"3. Define configurations for advanced setup(Using Deepspeed, offsite-tuning + federation) through FATE-pipeline."
]
},
{
"cell_type": "markdown",
"id": "31432345-5cce-4efa-9a9b-844f997f14ad",
"metadata": {},
"source": [
"## Introduction of Offsite-tuning\n",
"\n",
"Offsite-Tuning is a novel approach designed for the efficient and privacy-preserving adaptation of large foundational models for specific downstream tasks. The framework allows data owners to fine-tune models locally without uploading sensitive data to the LLM owner's servers. Specifically, the LLM owner sends a lightweight \"Adapter\" and a lossy compressed \"Emulator\" to the data owner. Using these smaller components, the data owner can then fine-tune the model solely on their private data. The Adapter, once fine-tuned, is returned to the model owner and integrated back into the large model to enhance its performance on the specific dataset.\n",
"\n",
"Offsite-Tuning addresses the challenge of unequal distribution of computational power and data. It allows thLLMel owner to enhance the model's capabilities without direct access to private data, while also enabling data owners who may not have the resources to train a full-scale model to fine-tune a portion of it using less computational power. This mutually beneficial arrangement accommodates both parties involve.\n",
"\n",
"Beyond the standard two-party setup involving the model owner and the data ownin FATE-LLM, er, Offsite-Tunframework ing is also extendable to scenarios with multiple data owners. FATE supports multi-party Offsite-Tuning, allowing multiple data owners to fine-tune and aggregate their Adapters locally, further enhancing the flexibility and applicability of this framewrFor more details of Offsite-tuning, please refer to the [original paper](https://arxiv.org/pdf/2302.04870.pdf).\n"
]
},
{
"cell_type": "markdown",
"id": "2e7ac467-e5df-4bf3-8571-0a477ab4612d",
"metadata": {},
"source": [
"## Preliminary\n",
"\n",
"We strongly recommend you finish reading our NN tutorial to get familiar with Model and Dataset customizations: [NN Tutorials](https://github.com/FederatedAI/FATE/blob/master/doc/2.0/fate/components/pipeline_nn_cutomization_tutorial.md)\n",
"\n",
"In this tutorial, we assume that you have deploy the codes of FATE(including fateflow & fate-client) & FATE-LLM-2.0. You can add python path so that you can run codes in the notebook."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f33516e8-0d28-4c97-bc38-ba28d60acf37",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"your_path_to_fate_python = 'xxx/fate/fate/python'\n",
"sys.path.append(your_path_to_fate_python)"
]
},
{
"cell_type": "markdown",
"id": "2f2fc794",
"metadata": {},
"source": [
"If you install FATE & FATE-LLM-2.0 via pip, you can directly use the following codes."
]
},
{
"cell_type": "markdown",
"id": "7309281b-5956-4158-9256-d6db230e086d",
"metadata": {},
"source": [
"## Define Main Model and Sub Model\n",
"\n",
"Main models are at server side and will provides weights of adapters and emulators to client sides, while Sub Models are at client side and will load adapters and emulators for local fine-tuning. In this chapter we will take a standard GPT2 as the example and show you how to quickly develop main model class and sub model class for offsite-tuning.\n",
"\n",
"### Base Classes and Interfaces\n",
"\n",
"The base classes for the Main and Sub Models are OffsiteTuningMainModel and OffsiteTuningSubModel, respectively. To build your own models upon these base classes, you need to:\n",
"\n",
"1. Implement three key interfaces: get_base_model, get_model_transformer_blocks, and forward. The get_base_model interface should return the full Main or Sub Model. Meanwhile, the get_model_transformer_blocks function should return a ModuleList of all transformer blocks present in your language model, enabling the extraction of emulators and adapters from these blocks. Finally, you're required to implement the forward process for model inference.\n",
"\n",
"2. Supply the parameters emulator_layer_num, adapter_top_layer_num, and adapter_bottom_layer_num to the parent class. This allows the framework to automatically generate the top and bottom adapters as well as the dropout emulator for you. Specifically, the top adapters are taken from the top of the transformer blocks, while the bottom adapters are taken from the bottom. The emulator uses a dropout emulator consistent with the paper's specifications. Once the adapter layers are removed, the emulator is formed by selecting transformer blocks at fixed intervals and finally stack them to make a dropout emulator.\n",
"\n",
"Our framework will automatically detect the emulator and adapters of a main model, and send them to clients. Clients' models them load the weights of emulators and adapters to get trainable models.\n",
"\n",
"### Example\n",
"\n",
"Let us take a look of our built-in GPT-2 model. It will be easy for you to build main models and sub models based on the framework. Please notice that the GPT2LMHeadSubModel's base model is intialized from a GPTConfig, that is to say, it's weights are random and need to load pretrained weights from server."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8611c115-0321-458f-b190-49dcb127a653",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel\n",
"from transformers import GPT2LMHeadModel, GPT2Config\n",
"from torch import nn\n",
"import torch as t\n",
"\n",
"\n",
"class GPT2LMHeadMainModel(OffsiteTuningMainModel):\n",
"\n",
" def __init__(\n",
" self,\n",
" model_name_or_path,\n",
" emulator_layer_num: int,\n",
" adapter_top_layer_num: int = 2,\n",
" adapter_bottom_layer_num: int = 2):\n",
"\n",
" self.model_name_or_path = model_name_or_path\n",
" super().__init__(\n",
" emulator_layer_num,\n",
" adapter_top_layer_num,\n",
" adapter_bottom_layer_num)\n",
"\n",
" def get_base_model(self):\n",
" return GPT2LMHeadModel.from_pretrained(self.model_name_or_path)\n",
"\n",
" def get_model_transformer_blocks(self, model: GPT2LMHeadModel):\n",
" return model.transformer.h\n",
"\n",
" def forward(self, x):\n",
" return self.model(**x)\n",
"\n",
"class GPT2LMHeadSubModel(OffsiteTuningSubModel):\n",
"\n",
" def __init__(\n",
" self,\n",
" model_name_or_path,\n",
" emulator_layer_num: int,\n",
" adapter_top_layer_num: int = 2,\n",
" adapter_bottom_layer_num: int = 2,\n",
" fp16_mix_precision=False,\n",
" partial_weight_decay=None):\n",
"\n",
" self.model_name_or_path = model_name_or_path\n",
" self.emulator_layer_num = emulator_layer_num\n",
" self.adapter_top_layer_num = adapter_top_layer_num\n",
" self.adapter_bottom_layer_num = adapter_bottom_layer_num\n",
" super().__init__(\n",
" emulator_layer_num,\n",
" adapter_top_layer_num,\n",
" adapter_bottom_layer_num,\n",
" fp16_mix_precision)\n",
" self.partial_weight_decay = partial_weight_decay\n",
"\n",
" def get_base_model(self):\n",
" total_layer_num = self.emulator_layer_num + \\\n",
" self.adapter_top_layer_num + self.adapter_bottom_layer_num\n",
" config = GPT2Config.from_pretrained(self.model_name_or_path)\n",
" config.num_hidden_layers = total_layer_num\n",
" # initialize a model without pretrained weights\n",
" return GPT2LMHeadModel(config)\n",
"\n",
" def get_model_transformer_blocks(self, model: GPT2LMHeadModel):\n",
" return model.transformer.h\n",
" \n",
" def forward(self, x):\n",
" return self.model(**x)\n"
]
},
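  {
   "cell_type": "markdown",
   "id": "block-split-sketch-md",
   "metadata": {},
   "source": [
    "As a side note, the sketch below illustrates, in plain PyTorch and independently of the FATE-LLM internals, how bottom adapters, top adapters, and an interval-sampled dropout emulator could be picked out of a list of transformer blocks. The helper name and the exact selection rule are our own illustration under the assumptions stated in the comments, not the framework's actual implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "block-split-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "def pick_adapters_and_emulator_demo(blocks, emulator_layer_num, top_num, bottom_num):\n",
    "    # bottom/top adapters come from the two ends of the transformer stack\n",
    "    bottom = list(blocks[:bottom_num])\n",
    "    top = list(blocks[len(blocks) - top_num:])\n",
    "    # the remaining middle blocks are sampled at a fixed interval to form the emulator\n",
    "    middle = list(blocks[bottom_num:len(blocks) - top_num])\n",
    "    stride = max(1, len(middle) // emulator_layer_num)\n",
    "    emulator = middle[::stride][:emulator_layer_num]\n",
    "    return nn.ModuleList(bottom), nn.ModuleList(emulator), nn.ModuleList(top)\n",
    "\n",
    "# example: 12 blocks -> 2 bottom adapters, 4 emulator blocks, 2 top adapters\n",
    "blocks = nn.ModuleList([nn.Identity() for _ in range(12)])\n",
    "b, e, t = pick_adapters_and_emulator_demo(blocks, emulator_layer_num=4, top_num=2, bottom_num=2)\n",
    "print(len(b), len(e), len(t))  # 2 4 2"
   ]
  },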
{
"cell_type": "markdown",
"id": "abd1f63f-afa7-4f09-a67e-63812ddcd801",
"metadata": {},
"source": [
"We can define a server side model and a client side model that can work together in the offsite-tuning:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04870e76-11cc-4d79-a09e-b6fd16ed2f23",
"metadata": {},
"outputs": [],
"source": [
"model_main = GPT2LMHeadMainModel('gpt2', 4, 2, 2)\n",
"model_sub = GPT2LMHeadSubModel('gpt2', 4, 2, 2)"
]
},
{
"cell_type": "markdown",
"id": "19d34937-b4ae-436e-b4ea-1620fb80bed4",
"metadata": {},
"source": [
"### Share additional parameters with clients\n",
"\n",
"Additionally, beyond the weights of emulators and adapters, you may also want to share other model parameters, such as embedding weights, with your client partners. To achieve this, you'll need to implement two more interfaces: get_additional_param_state_dict and load_additional_param_state_dict for both the Main and Sub Models."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "189fce0e-8e4d-4368-8e14-907b30ce0a49",
"metadata": {},
"outputs": [],
"source": [
"def get_additional_param_state_dict(self):\n",
" # get parameter of additional parameter\n",
" model = self.model\n",
" param_dict = {\n",
" 'wte': model.transformer.wte,\n",
" 'wpe': model.transformer.wpe,\n",
" 'last_ln_f': model.transformer.ln_f\n",
" }\n",
"\n",
" addition_weights = self.get_numpy_state_dict(param_dict)\n",
"\n",
" wte = addition_weights.pop('wte')\n",
" wte_dict = split_numpy_array(wte, 10, 'wte')\n",
" wpe = addition_weights.pop('wpe')\n",
" wpe_dict = split_numpy_array(wpe, 10, 'wpe')\n",
" addition_weights.update(wte_dict)\n",
" addition_weights.update(wpe_dict)\n",
" return addition_weights\n",
"\n",
"def load_additional_param_state_dict(self, submodel_weights: dict):\n",
" # load additional weights:\n",
" model = self.model\n",
" param_dict = {\n",
" 'wte': model.transformer.wte,\n",
" 'wpe': model.transformer.wpe,\n",
" 'last_ln_f': model.transformer.ln_f\n",
" }\n",
"\n",
" new_submodel_weight = {}\n",
" new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']\n",
" wte_dict, wpe_dict = {}, {}\n",
" for k, v in submodel_weights.items():\n",
" if 'wte' in k:\n",
" wte_dict[k] = v\n",
" if 'wpe' in k:\n",
" wpe_dict[k] = v\n",
" wte = recover_numpy_array(wte_dict, 'wte')\n",
" wpe = recover_numpy_array(wpe_dict, 'wpe')\n",
" new_submodel_weight['wte'] = wte\n",
" new_submodel_weight['wpe'] = wpe\n",
"\n",
" self.load_numpy_state_dict(param_dict, new_submodel_weight)"
]
},
{
"cell_type": "markdown",
"id": "59d9aa6a-80e9-4130-8af1-c7d2bd0fbba3",
"metadata": {},
"source": [
"From these codes we can see that we use 'split_numpy_array, recover_numpy_array' to cut embedding weights into pieces and recover them."
]
},
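  {
   "cell_type": "markdown",
   "id": "split-recover-sketch-md",
   "metadata": {},
   "source": [
    "The cell below is a minimal, self-contained sketch of that split/recover idea. It is not the actual FATE-LLM `split_numpy_array` / `recover_numpy_array` implementation; the `_demo` helpers and the chunk key naming are assumptions made for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "split-recover-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def split_numpy_array_demo(arr, n, prefix):\n",
    "    # cut the array into n chunks along axis 0, keyed as '<prefix>_<index>'\n",
    "    chunks = np.array_split(arr, n, axis=0)\n",
    "    return {f'{prefix}_{i}': c for i, c in enumerate(chunks)}\n",
    "\n",
    "def recover_numpy_array_demo(chunk_dict, prefix):\n",
    "    # re-assemble the chunks in index order\n",
    "    keys = sorted(chunk_dict.keys(), key=lambda k: int(k.rsplit('_', 1)[-1]))\n",
    "    return np.concatenate([chunk_dict[k] for k in keys], axis=0)\n",
    "\n",
    "wpe = np.random.rand(1024, 768).astype(np.float32)  # a GPT-2 wpe-sized matrix\n",
    "pieces = split_numpy_array_demo(wpe, 10, 'wpe')\n",
    "restored = recover_numpy_array_demo(pieces, 'wpe')\n",
    "assert np.allclose(wpe, restored)"
   ]
  },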
{
"cell_type": "markdown",
"id": "dda6f5e3-d05a-4cdf-afd4-affbc162fce4",
"metadata": {},
"source": [
"## Submit a Offsite-tuning Task - A QA Task Sample with GPT2\n",
"\n",
"Now we are going to show you how to run a 2 party(server & client) offsite-tuning task using the GPT-2 model defined above. Before we submit the task we need to prepare the QA dataset.\n",
"\n",
"### Prepare QA Dataset - Sciq\n",
"\n",
"In this example, we use sciq dataset. You can use tools provided in our qa_dataset.py to tokenize the sciq dataset and save the tokenized result. **Remember to modify the save_path to your own path.** For the sake of simplicity, in this tutorial, for every party we only use this dataset to train the model."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "84f6947e-f0a3-4a42-9549-a9776a15b66d",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.dataset.qa_dataset import tokenize_qa_dataset\n",
"from transformers import AutoTokenizer\n",
"tokenizer_name_or_path = 'gpt2'\n",
"tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)\n",
"\n",
"if 'llama' in tokenizer_name_or_path:\n",
" tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, unk_token=\"\", bos_token=\"\", eos_token=\"\", add_eos_token=True) \n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"else:\n",
" tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)\n",
"if 'gpt2' in tokenizer_name_or_path:\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"import os\n",
"# bind data path to name & namespace\n",
"save_path = 'xxxx/sciq'\n",
"rs = tokenize_qa_dataset('sciq', tokenizer, save_path, seq_max_len=600) # we save the cache dataset to the fate root folder"
]
},
{
"cell_type": "markdown",
"id": "adabe89a-37be-4c64-bd83-4f8c8b80096f",
"metadata": {},
"source": [
"We can use our built-in QA dataset to load tokenized dataset, to see if everything is working correctly."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "6500c2ba-bc39-4db4-b2ea-947fb09c334e",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.dataset.qa_dataset import QaDataset\n",
"\n",
"ds = QaDataset(tokenizer_name_or_path=tokenizer_name_or_path)\n",
"ds.load(save_path)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d6f62b60-eed0-4bd0-874e-ae3feeebb120",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"11679\n",
"600\n"
]
}
],
"source": [
"print(len(ds)) # train set length\n",
"print(ds[0]['input_ids'].__len__()) # first sample length"
]
},
{
"cell_type": "markdown",
"id": "0609c63d-35a4-43bc-bd4b-f1c61adea587",
"metadata": {},
"source": [
"## Submit a Task\n",
"\n",
"Now the model and the dataset is prepared! We can submit a training task. In the FATE-2.0, you can define your pipeline in a much easier manner.\n",
"\n",
"After we submit the task below, the following process will occur: The server and client each initialize their respective models. The server extracts shared parameters and sends them to the client. The client then loads these parameters and conducts training on a miniaturized GPT-2 model composed of an emulator and adapter on SciqP \n",
"\n",
"If you are not familiar with trainer configuration, please refer to [NN Tutorials](https://github.com/FederatedAI/FATE/blob/master/doc/2.0/fate/components/pipeline_nn_cutomization_tutorial.md).\n",
"\n",
" Upon completion of the training, the client sends the adapter parameters back to the server. Since we are directly using Hugging Face's LMHeadGPT2, there's no need to supply a loss function. Simply inputting the preprocessed data and labels into the model will calculate the correct loss and proceed with gradient descent\n",
"\n",
"One thing to pay special attention to is that Offsite-Tuning differs from FedAvg within FATE. In Offsite-Tuning, the server (the arbiter role) needs to initialize the model. Therefore, please refer to the example below and set the runner conf separately for the client and the server.\n",
"\n",
"To make this a quick demo, we only select 100 samples from the origin qa datset, see 'select_num=100' in the LLMDatasetLoader."
]
},
{
"cell_type": "markdown",
"id": "261dfb43",
"metadata": {},
"source": [
"### Bind Dataset Path with Name & Namespace\n",
"\n",
"Plase execute the following code to bind the dataset path with name & namespace. Remember to modify the path to your own dataset save path."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8dc1e82b",
"metadata": {},
"outputs": [],
"source": [
"! flow table bind --namespace experiment --name sciq --path YOUR_SAVE_PATH"
]
},
{
"cell_type": "markdown",
"id": "0e8c5ff4",
"metadata": {},
"source": [
"### Pipeline codes"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "c9113d10-c3e7-4875-9502-ce46aa0b86b1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import time\n",
"from fate_client.pipeline.components.fate.reader import Reader\n",
"from fate_client.pipeline import FateFlowPipeline\n",
"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_conf_of_ot_runner\n",
"from fate_client.pipeline.components.fate.nn.algo_params import Seq2SeqTrainingArguments, FedAVGArguments\n",
"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\n",
"from fate_client.pipeline.components.fate.nn.torch.base import Sequential\n",
"from fate_client.pipeline.components.fate.nn.torch import nn\n",
"\n",
"\n",
"guest = '9999'\n",
"host = '9999'\n",
"arbiter = '9999'\n",
"\n",
"pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\n",
"pipeline.set_site_party_id('9999')\n",
"reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest))\n",
"reader_0.guest.task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"sciq\"\n",
")\n",
"\n",
"client_model = LLMModelLoader(\n",
" module_name='offsite_tuning.gpt2', item_name='GPT2LMHeadSubModel',\n",
" model_name_or_path='gpt2',\n",
" emulator_layer_num=4,\n",
" adapter_top_layer_num=1,\n",
" adapter_bottom_layer_num=1\n",
")\n",
"\n",
"server_model = LLMModelLoader(\n",
" module_name='offsite_tuning.gpt2', item_name='GPT2LMHeadMainModel',\n",
" model_name_or_path='gpt2',\n",
" emulator_layer_num=4,\n",
" adapter_top_layer_num=1,\n",
" adapter_bottom_layer_num=1 \n",
")\n",
"\n",
"train_args = Seq2SeqTrainingArguments(\n",
" per_device_train_batch_size=1,\n",
" learning_rate=5e-5,\n",
" disable_tqdm=False,\n",
" num_train_epochs=1,\n",
" logging_steps=10,\n",
" logging_strategy='steps',\n",
" use_cpu=False\n",
")\n",
"\n",
"dataset = LLMDatasetLoader(\n",
" module_name='qa_dataset', item_name='QaDataset',\n",
" tokenizer_name_or_path='gpt2',\n",
" select_num=100\n",
")\n",
"\n",
"data_collator = LLMDataFuncLoader(module_name='data_collator.cust_data_collator', item_name='get_seq2seq_data_collator', tokenizer_name_or_path='gpt2')\n",
"\n",
"client_conf = get_conf_of_ot_runner(\n",
" model=client_model,\n",
" dataset=dataset,\n",
" data_collator=data_collator,\n",
" training_args=train_args,\n",
" fed_args=FedAVGArguments(),\n",
" aggregate_model=False\n",
")\n",
"\n",
"server_conf = get_conf_of_ot_runner(\n",
" model=server_model,\n",
" dataset=dataset,\n",
" data_collator=data_collator,\n",
" training_args=train_args,\n",
" fed_args=FedAVGArguments(),\n",
" aggregate_model=False\n",
")\n",
"\n",
"homo_nn_0 = HomoNN(\n",
" 'nn_0',\n",
" train_data=reader_0.outputs[\"output_data\"],\n",
" runner_module=\"offsite_tuning_runner\",\n",
" runner_class=\"OTRunner\"\n",
")\n",
"\n",
"homo_nn_0.guest.task_parameters(runner_conf=client_conf)\n",
"homo_nn_0.arbiter.task_parameters(runner_conf=server_conf)\n",
"pipeline.add_tasks([reader_0, homo_nn_0])\n",
"pipeline.compile()"
]
},
{
"cell_type": "markdown",
"id": "e97c2823",
"metadata": {},
"source": [
"You can try to initialize your models, datasets to check if they can be loaded correctly."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "872817e5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT2LMHeadSubModel(\n",
" (model): GPT2LMHeadModel(\n",
" (transformer): GPT2Model(\n",
" (wte): Embedding(50257, 768)\n",
" (wpe): Embedding(1024, 768)\n",
" (drop): Dropout(p=0.1, inplace=False)\n",
" (h): ModuleList(\n",
" (0): GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (4): GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (5): GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
" )\n",
" (emulator): ModuleList(\n",
" (0): GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (1): GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (2): GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (3): GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (adapter_bottom): ModuleList(\n",
" (0): GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (adapter_top): ModuleList(\n",
" (0): GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
")\n",
"**********\n",
"\n",
"**********\n",
"DataCollatorForSeq2Seq(tokenizer=GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True), added_tokens_decoder={\n",
"\t50256: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
"}, model=None, padding=True, max_length=None, pad_to_multiple_of=None, label_pad_token_id=-100, return_tensors='pt')\n"
]
}
],
"source": [
"print(client_model())\n",
"print('*' * 10)\n",
"print(dataset())\n",
"print('*' * 10)\n",
"print(data_collator())"
]
},
{
"cell_type": "markdown",
"id": "898c3491",
"metadata": {},
"source": [
"Seems that everything is ready! Now we can submit the task. Submit the code below to submit your task."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "74497742-4030-4a7a-a13e-2c020da47cd1",
"metadata": {},
"outputs": [],
"source": [
"pipeline.fit()"
]
},
{
"cell_type": "markdown",
"id": "b33b2e2b-3b53-4881-8db6-a67e1293e88b",
"metadata": {},
"source": [
"## Add Deepspeed Setting\n",
"\n",
"By simply adding a ds_config, we can run our task with a deepspeed backend. If you have deployed eggroll envoironment, you can submmit the task with deepspeed to eggroll accelerate your training."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6e8f063b-263c-4ba5-b2ba-98a86ce38b94",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import time\n",
"from fate_client.pipeline.components.fate.reader import Reader\n",
"from fate_client.pipeline import FateFlowPipeline\n",
"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_conf_of_ot_runner\n",
"from fate_client.pipeline.components.fate.nn.algo_params import Seq2SeqTrainingArguments, FedAVGArguments\n",
"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\n",
"from peft import LoraConfig, TaskType\n",
"from transformers.modeling_utils import unwrap_model\n",
"\n",
"\n",
"guest = '10000'\n",
"host = '10000'\n",
"arbiter = '10000'\n",
"\n",
"# pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)\n",
"pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\n",
"\n",
"reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest))\n",
"reader_0.guest.task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"sciq\"\n",
")\n",
"\n",
"client_model = LLMModelLoader(\n",
" module_name='offsite_tuning.gpt2', item_name='GPT2LMHeadSubModel',\n",
" model_name_or_path='gpt2',\n",
" emulator_layer_num=18,\n",
" adapter_top_layer_num=2,\n",
" adapter_bottom_layer_num=2\n",
")\n",
"\n",
"server_model = LLMModelLoader(\n",
" module_name='offsite_tuning.gpt2', item_name='GPT2LMHeadMainModel',\n",
" model_name_or_path='gpt2',\n",
" emulator_layer_num=18,\n",
" adapter_top_layer_num=2,\n",
" adapter_bottom_layer_num=2 \n",
")\n",
"\n",
"dataset = LLMDatasetLoader(\n",
" module_name='qa_dataset', item_name='QaDataset',\n",
" tokenizer_name_or_path='gpt2',\n",
" select_num=100\n",
")\n",
"\n",
"data_collator = LLMDataFuncLoader(module_name='data_collator.cust_data_collator', item_name='get_seq2seq_data_collator', tokenizer_name_or_path='gpt2')\n",
"\n",
"batch_size = 1\n",
"lr = 5e-5\n",
"ds_config = {\n",
" \"train_micro_batch_size_per_gpu\": batch_size,\n",
" \"optimizer\": {\n",
" \"type\": \"Adam\",\n",
" \"params\": {\n",
" \"lr\": lr,\n",
" \"torch_adam\": True,\n",
" \"adam_w_mode\": False\n",
" }\n",
" },\n",
" \"fp16\": {\n",
" \"enabled\": True\n",
" },\n",
" \"gradient_accumulation_steps\": 1,\n",
" \"zero_optimization\": {\n",
" \"stage\": 2,\n",
" \"allgather_partitions\": True,\n",
" \"allgather_bucket_size\": 1e8,\n",
" \"overlap_comm\": True,\n",
" \"reduce_scatter\": True,\n",
" \"reduce_bucket_size\": 1e8,\n",
" \"contiguous_gradients\": True,\n",
" \"offload_optimizer\": {\n",
" \"device\": \"cpu\"\n",
" },\n",
" \"offload_param\": {\n",
" \"device\": \"cpu\"\n",
" }\n",
" }\n",
"}\n",
"\n",
"train_args = Seq2SeqTrainingArguments(\n",
" per_device_train_batch_size=1,\n",
" learning_rate=5e-5,\n",
" disable_tqdm=False,\n",
" num_train_epochs=1,\n",
" logging_steps=10,\n",
" logging_strategy='steps',\n",
" dataloader_num_workers=4,\n",
" use_cpu=False,\n",
" deepspeed=ds_config, # Add deepspeed config here\n",
" remove_unused_columns=False,\n",
" fp16=True\n",
")\n",
"\n",
"client_conf = get_conf_of_ot_runner(\n",
" model=client_model,\n",
" dataset=dataset,\n",
" data_collator=data_collator,\n",
" training_args=train_args,\n",
" fed_args=FedAVGArguments(),\n",
" aggregate_model=False,\n",
")\n",
"\n",
"server_conf = get_conf_of_ot_runner(\n",
" model=server_model,\n",
" dataset=dataset,\n",
" data_collator=data_collator,\n",
" training_args=train_args,\n",
" fed_args=FedAVGArguments(),\n",
" aggregate_model=False\n",
")\n",
"\n",
"\n",
"homo_nn_0 = HomoNN(\n",
" 'nn_0',\n",
" train_data=reader_0.outputs[\"output_data\"],\n",
" runner_module=\"offsite_tuning_runner\",\n",
" runner_class=\"OTRunner\"\n",
")\n",
"\n",
"homo_nn_0.guest.task_parameters(runner_conf=client_conf)\n",
"homo_nn_0.arbiter.task_parameters(runner_conf=server_conf)\n",
"\n",
"# if you have deployed eggroll, you can add this line to submit your job to eggroll\n",
"homo_nn_0.guest.conf.set(\"launcher_name\", \"deepspeed\")\n",
"\n",
"pipeline.add_tasks([reader_0, homo_nn_0])\n",
"pipeline.conf.set(\"task\", dict(engine_run={\"cores\": 4}))\n",
"pipeline.compile()\n",
"pipeline.fit()\n"
]
},
{
"cell_type": "markdown",
"id": "97249681-c3a3-43bd-8167-7ae3f4e1616b",
"metadata": {},
"source": [
"## Offsite-tuning + Multi Client Federation\n",
"\n",
"\n",
"The Offsite-Tuning + FedAVG federation is configured based on the standard Offsite-Tuning. In this situation, you need to add data input & configurations for all clients. And do remember to add 'aggregate_model=True' for client & server conf so that model federation will be conducted during the training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fdbdc60c-a948-4be3-bba6-519d8640b0a9",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"from fate_client.pipeline.components.fate.reader import Reader\n",
"from fate_client.pipeline import FateFlowPipeline\n",
"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_conf_of_ot_runner\n",
"from fate_client.pipeline.components.fate.nn.algo_params import Seq2SeqTrainingArguments, FedAVGArguments\n",
"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMCustFuncLoader\n",
"from peft import LoraConfig, TaskType\n",
"\n",
"\n",
"guest = '10000'\n",
"host = '10000'\n",
"arbiter = '10000'\n",
"\n",
"pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)\n",
"\n",
"reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest, host=host))\n",
"reader_0.guest.task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"sciq\"\n",
")\n",
"reader_0.hosts[0].task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"sciq\"\n",
")\n",
"\n",
"client_model = LLMModelLoader(\n",
" module_name='offsite_tuning.gpt2', item_name='GPT2LMHeadSubModel',\n",
" model_name_or_path='gpt2',\n",
" emulator_layer_num=4,\n",
" adapter_top_layer_num=1,\n",
" adapter_bottom_layer_num=1\n",
")\n",
"\n",
"server_model = LLMModelLoader(\n",
" module_name='offsite_tuning.gpt2', item_name='GPT2LMHeadMainModel',\n",
" model_name_or_path='gpt2',\n",
" emulator_layer_num=4,\n",
" adapter_top_layer_num=1,\n",
" adapter_bottom_layer_num=1 \n",
")\n",
"\n",
"dataset = LLMDatasetLoader(\n",
" module_name='qa_dataset', item_name='QaDataset',\n",
" tokenizer_name_or_path='gpt2',\n",
" select_num=100\n",
")\n",
"\n",
"data_collator = LLMCustFuncLoader(module_name='cust_data_collator', item_name='get_seq2seq_tokenizer', model_path='gpt2')\n",
"\n",
"train_args = Seq2SeqTrainingArguments(\n",
" per_device_train_batch_size=1,\n",
" learning_rate=5e-5,\n",
" disable_tqdm=False,\n",
" num_train_epochs=1,\n",
" logging_steps=10,\n",
" logging_strategy='steps',\n",
" dataloader_num_workers=4\n",
")\n",
"\n",
"client_conf = get_conf_of_ot_runner(\n",
" model=client_model,\n",
" dataset=dataset,\n",
" data_collator=data_collator,\n",
" training_args=train_args,\n",
" fed_args=FedAVGArguments(),\n",
" aggregate_model=True\n",
")\n",
"\n",
"server_conf = get_conf_of_ot_runner(\n",
" model=server_model,\n",
" dataset=dataset,\n",
" data_collator=data_collator,\n",
" training_args=train_args,\n",
" fed_args=FedAVGArguments(),\n",
" aggregate_model=True\n",
")\n",
"\n",
"homo_nn_0 = HomoNN(\n",
" 'nn_0',\n",
" train_data=reader_0.outputs[\"output_data\"],\n",
" runner_module=\"offsite_tuning_runner\",\n",
" runner_class=\"OTRunner\"\n",
")\n",
"\n",
"homo_nn_0.guest.task_parameters(runner_conf=client_conf)\n",
"homo_nn_0.hosts[0].task_parameters(runner_conf=client_conf)\n",
"homo_nn_0.arbiter.task_parameters(runner_conf=server_conf)\n",
"\n",
"pipeline.add_tasks([reader_0, homo_nn_0])\n",
"\n",
"pipeline.compile()\n",
"pipeline.fit()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: doc/tutorial/offsite_tuning/README.md
================================================
# Offsite-Tuning
## Standard Offsite-tuning
Offsite-Tuning is designed for the efficient adaptation of large foundational models for specific downstream tasks.
Through Offsite-Tuning, the model owner can enhance the capabilities of large models with the help of data providers, without having to disclose the full model weights or directly access the data providers' sensitive information. Specifically, the LLM owner sends a lightweight "Adapter" and a lossy compressed "Emulator" to the data owner. Using these smaller components, the data owner can then fine-tune the model solely on their private data. The Adapter, once fine-tuned, is returned to the model owner and integrated back into the large model to enhance its performance on the specific dataset.
In FATE-LLM 1.3, we provide these built-in models:
- GPT2 series models (e.g., GPT2, GPT2-XL, etc.)
- Bloom series models (such as Bloom7B)
- Llama-1 series models (e.g., Llama7B)
FATE-LLM v1.3 builds on v1.2 and offers the ability to easily configure multi-machine and multi-card acceleration. It also has specialized optimizations for the network transmission of adapters and emulators.
[Read the full paper](https://arxiv.org/abs/2302.04870)
## Offsite-tuning with Federated Learning
In addition to supporting standard two-party (model owner and data provider) offsite-tuning, FATE also supports offsite-tuning with multiple data providers simultaneously. Adapters can be fine-tuned locally and then aggregated with those from other data providers. Ultimately, large models can be enhanced through the secure aggregation of adapters from multiple parties. This approach can be used to address issues related to the uneven distribution of computational power and data.
As shown in the diagram below:
================================================
FILE: doc/tutorial/pellm/ChatGLM3-6B_ds.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Federated ChatGLM3 Tuning with Parameter Efficient methods in FATE-LLM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we will demonstrate how to efficiently train federated ChatGLM3-6B with deepspeed using the FATE-LLM framework. In FATE-LLM, we introduce the \"pellm\"(Parameter Efficient Large Language Model) module, specifically designed for federated learning with large language models. We enable the implementation of parameter-efficient methods in federated learning, reducing communication overhead while maintaining model performance. In this tutorial we particularlly focus on ChatGLM3-6B, and we will also emphasize the use of the Adapter mechanism for fine-tuning ChatGLM3-6B, which enables us to effectively reduce communication volume and improve overall efficiency.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## FATE-LLM: ChatGLM3-6B\n",
"\n",
"### ChatGLM-6B\n",
"ChatGLM3-6B is a large transformer-based language model with 5.977 billion parameters, it is an open bilingual language model based on General Language Model. You can download the pretrained model from [here](https://github.com/THUDM/ChatGLM3), or let the program automatically download it when you use it later.\n",
"\n",
"### Current Features\n",
"\n",
"In current version, FATE-LLM: ChatGLM-6B supports the following features:\n",
"\n",
"

\n",
"
"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Experiment Setting"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before running experiment, please make sure that [FATE-LLM Cluster](https://github.com/FederatedAI/FATE/wiki/Download#llm%E9%83%A8%E7%BD%B2%E5%8C%85) has been deployed. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dataset: Advertising Text Generation\n",
"\n",
"This is an advertising test generateion dataset, you can download dataset from the following links and place it in the examples/data folder. \n",
"- [data link 1](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view)\n",
"- [data link 2](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) \n",
"\n",
"You can refer to following link for more details about [data](https://aclanthology.org/D19-1321.pdf)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"df = pd.read_json('${fate_install}/examples/data/AdvertiseGen/train.json', lines=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ChatGLM3-6B with Adapter\n",
"\n",
"In this section, we will guide you through the process of finetuning ChatGLM-6B with adapters using the FATE-LLM framework. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ChatGLM model is located on fate_llm/model_zoo/chatglm.py, can be use directly"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"albert.py bloom.py distilbert.py parameter_efficient_llm.py\n",
"bart.py chatglm.py gpt2.py\t qwen.py\n",
"bert.py deberta.py llama.py roberta.py\n"
]
}
],
"source": [
"! ls ../../../../fate_llm/python/fate_llm/model_zoo/pellm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Adapters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can directly use adapters from the peft. See details for adapters on this page [Adapter Methods](https://huggingface.co/docs/peft/index) for more details. By specifying the adapter name and the adapter\n",
"config dict we can insert adapters into our language models:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from peft import LoraConfig, TaskType\n",
"\n",
"lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\n",
" target_modules=['query_key_value'],\n",
")\n",
"lora_config.target_modules = list(lora_config.target_modules) # this line is needed to ensure lora_config is jsonable"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Init ChatGLM3 Model "
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader\n",
"\n",
"pretrained_model_path = \"fill with pretrained model download path please\"\n",
"\n",
"model = LLMModelLoader(\n",
" \"pellm.chatglm\",\n",
" \"ChatGLM\",\n",
" pretrained_path=pretrained_model_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" trust_remote_code=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**During the training process, all weights of the pretrained language model will be frozen, and weights of adapters are traininable. Thus, FATE-LLM only train in the local training and aggregate adapters' weights in the fedederation process**\n",
"\n",
"Now available adapters are [Adapters Overview](https://huggingface.co/docs/peft/index) for details.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Specify Dataset And DataCollator To Process Data"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from fate_client.pipeline.components.fate.nn.loader import LLMDatasetLoader, LLMDataFuncLoader\n",
"\n",
"tokenizer_params = dict(\n",
" tokenizer_name_or_path=pretrained_model_path,\n",
" trust_remote_code=True,\n",
")\n",
"\n",
"dataset = LLMDatasetLoader(\n",
" \"prompt_dataset\",\n",
" \"PromptDataset\",\n",
" **tokenizer_params,\n",
")\n",
"\n",
"data_collator = LLMDataFuncLoader(\n",
" \"data_collator.cust_data_collator\",\n",
" \"get_seq2seq_data_collator\",\n",
" **tokenizer_params,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Init DeepSpeed Config"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"ds_config = {\n",
" \"train_micro_batch_size_per_gpu\": 1,\n",
" \"optimizer\": {\n",
" \"type\": \"Adam\",\n",
" \"params\": {\n",
" \"lr\": 5e-4\n",
" }\n",
" },\n",
" \"fp16\": {\n",
" \"enabled\": True\n",
" },\n",
" \"gradient_accumulation_steps\": 1,\n",
" \"zero_optimization\": {\n",
" \"stage\": 2,\n",
" \"allgather_partitions\": True,\n",
" \"allgather_bucket_size\": 1e8,\n",
" \"overlap_comm\": True,\n",
" \"reduce_scatter\": True,\n",
" \"reduce_bucket_size\": 1e8,\n",
" \"contiguous_gradients\": True,\n",
" \"offload_optimizer\": {\n",
" \"device\": \"cpu\"\n",
" },\n",
" \"offload_param\": {\n",
" \"device\": \"cpu\"\n",
" }\n",
" }\n",
"}\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Submit Federated Task\n",
"To run federated task, please make sure to ues fate>=2.1.0 and deploy it with gpu machines. To running this code, make sure training data path is already binded. The following code shoud be copy to a script and run in a command line like \"python federated_chatglm.py\""
]
},
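  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, the training data path can be bound with the flow CLI, in the same way as in the other tutorials of this repository. `YOUR_DATA_PATH` below is a placeholder for wherever you stored the AdvertiseGen data; bind it on every party that reads the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! flow table bind --namespace experiment --name ad --path YOUR_DATA_PATH"
   ]
  },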
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can use this script to submit the model, but submitting the model will take a long time to train and generate a long log, so we won't do it here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"from fate_client.pipeline.components.fate.reader import Reader\n",
"from fate_client.pipeline import FateFlowPipeline\n",
"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_seq2seq_runner\n",
"from fate_client.pipeline.components.fate.nn.algo_params import Seq2SeqTrainingArguments, FedAVGArguments\n",
"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\n",
"from peft import LoraConfig, TaskType\n",
"\n",
"\n",
"guest = '10000'\n",
"host = '10000'\n",
"arbiter = '10000'\n",
"\n",
"epochs = 1\n",
"batch_size = 1\n",
"lr = 5e-4\n",
"\n",
"ds_config = {\n",
" \"train_micro_batch_size_per_gpu\": batch_size,\n",
" \"optimizer\": {\n",
" \"type\": \"Adam\",\n",
" \"params\": {\n",
" \"lr\": lr,\n",
" \"torch_adam\": True,\n",
" \"adam_w_mode\": False\n",
" }\n",
" },\n",
" \"fp16\": {\n",
" \"enabled\": True\n",
" },\n",
" \"gradient_accumulation_steps\": 1,\n",
" \"zero_optimization\": {\n",
" \"stage\": 2,\n",
" \"allgather_partitions\": True,\n",
" \"allgather_bucket_size\": 1e8,\n",
" \"overlap_comm\": True,\n",
" \"reduce_scatter\": True,\n",
" \"reduce_bucket_size\": 1e8,\n",
" \"contiguous_gradients\": True,\n",
" \"offload_optimizer\": {\n",
" \"device\": \"cpu\"\n",
" },\n",
" \"offload_param\": {\n",
" \"device\": \"cpu\"\n",
" }\n",
" }\n",
"}\n",
"\n",
"pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)\n",
"# pipeline.bind_local_path(path=\"\", namespace=\"experiment\", name=\"ad\")\n",
"time.sleep(5)\n",
"\n",
"\n",
"reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest, host=host))\n",
"reader_0.guest.task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"ad\"\n",
")\n",
"reader_0.hosts[0].task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"ad\"\n",
")\n",
"\n",
"# define lora config\n",
"lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\n",
" target_modules=['query_key_value'],\n",
")\n",
"lora_config.target_modules = list(lora_config.target_modules)\n",
"\n",
"pretrained_model_path = \"/data/cephfs/llm/models/chatglm3-6b\"\n",
"\n",
"model = LLMModelLoader(\n",
" \"pellm.chatglm\",\n",
" \"ChatGLM\",\n",
" pretrained_path=pretrained_model_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" trust_remote_code=True\n",
")\n",
"\n",
"\n",
"tokenizer_params = dict(\n",
" tokenizer_name_or_path=pretrained_model_path,\n",
" trust_remote_code=True,\n",
")\n",
"\n",
"dataset = LLMDatasetLoader(\n",
" \"prompt_dataset\",\n",
" \"PromptDataset\",\n",
" **tokenizer_params,\n",
")\n",
"\n",
"data_collator = LLMDataFuncLoader(\n",
" \"data_collator.cust_data_collator\",\n",
" \"get_seq2seq_data_collator\",\n",
" **tokenizer_params,\n",
")\n",
"\n",
"conf = get_config_of_seq2seq_runner(\n",
" algo='fedavg',\n",
" model=model,\n",
" dataset=dataset,\n",
" data_collator=data_collator,\n",
" training_args=Seq2SeqTrainingArguments(\n",
" num_train_epochs=epochs,\n",
" per_device_train_batch_size=batch_size,\n",
" remove_unused_columns=False, \n",
" predict_with_generate=False,\n",
" deepspeed=ds_config,\n",
" learning_rate=lr,\n",
" use_cpu=False, # this must be set as we will gpu\n",
" fp16=True,\n",
" ),\n",
" fed_args=FedAVGArguments(),\n",
" task_type='causal_lm',\n",
" save_trainable_weights_only=True # only save trainable weights\n",
")\n",
"\n",
"homo_nn_0 = HomoNN(\n",
" 'nn_0',\n",
" runner_conf=conf,\n",
" train_data=reader_0.outputs[\"output_data\"],\n",
" runner_module=\"homo_seq2seq_runner\",\n",
" runner_class=\"Seq2SeqRunner\",\n",
")\n",
"\n",
"homo_nn_0.guest.conf.set(\"launcher_name\", \"deepspeed\") # tell schedule engine to run task with deepspeed\n",
"homo_nn_0.hosts[0].conf.set(\"launcher_name\", \"deepspeed\") # tell schedule engine to run task with deepspeed\n",
"\n",
"pipeline.add_tasks([reader_0, homo_nn_0])\n",
"pipeline.conf.set(\"task\", dict(engine_run={\"cores\": 1})) # the number of gpus of each party\n",
"\n",
"pipeline.compile()\n",
"pipeline.fit()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training With P-Tuning V2 Adapter"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use another adapter lke P-Tuning V2, slightly changes is needed!"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"model = LLMModelLoader(\n",
" \"pellm.chatglm\",\n",
" \"ChatGLM\",\n",
" pretrained_path=pretrained_model_path,\n",
" pre_seq_len=128,\n",
" trust_remote_code=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Inference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Models trained with FATE-LLM can be find under the directory `${fate_install}/fateflow/model/$job_id/${role}/${party_id}/$cpn_name/0/output/output_model/model_directory/adapter_model.bin}`,\n",
"The following code is an example to load trained lora adapter weights:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import sys\n",
"import torch\n",
"from peft import PeftModel, PeftConfig, LoraConfig, TaskType, get_peft_model\n",
"from transformers import AutoModel, AutoTokenizer\n",
"\n",
"\n",
"def load_model(pretrained_model_path):\n",
" _tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path, trust_remote_code=True)\n",
" _model = AutoModel.from_pretrained(pretrained_model_path, trust_remote_code=True)\n",
"\n",
" _model = _model.half()\n",
" _model = _model.eval()\n",
"\n",
" return _model, _tokenizer\n",
"\n",
"\n",
"def load_data(data_path):\n",
" with open(data_path, \"r\") as fin:\n",
" for _l in fin:\n",
" yield json.loads(_l.strip())\n",
"\n",
"\n",
"chatglm_model_path = \"\"\n",
"model, tokenizer = load_model(chatglm_model_path)\n",
"\n",
"test_data_path = \"{fate_install}/examples/data/AdvertiseGen/dev.json\"\n",
"dataset = load_data(test_data_path)\n",
"\n",
"peft_path = \"${fate_install}/fateflow/model/$job_id/${role}/${party_id}/$cpn_name/0/output/output_model/model_directory/adapter_model.bin}\"\n",
"\n",
"model = PeftModel.from_pretrained(model, peft_path)\n",
"model = model.half()\n",
"model.eval()\n",
"\n",
"for p in model.parameters():\n",
" if p.requires_grad:\n",
" print(p)\n",
"\n",
"model.cuda(\"cuda:0\")\n",
"\n",
"content = list(dataset)[0][\"content\"]\n",
"print(model.chat(tokenizer, content, do_sample=False))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.15"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
================================================
FILE: doc/tutorial/pellm/builtin_pellm_models.md
================================================
## Builtin PELLM Models
FATE-LLM provides some built-in pellm models; users can simply use them to efficiently train their language models.
To use these models, please read the [ChatGLM3-6B Training Guide](./ChatGLM3-6B_ds.ipynb).
After reading the training tutorial above, it's easy to use the other models listed in the following table by changing the `module_name`, `class_name`, and `dataset` accordingly.
| Model | ModuleName | ClassName | DataSetName |
| -------------- | ----------------- | --------------| --------------- |
| Qwen2 | pellm.qwen | Qwen | prompt_dataset |
| Bloom-7B1 | pellm.bloom | Bloom | prompt_dataset |
| OPT-6.7B | pellm.opt | OPT | prompt_dataset |
| LLaMA-2-7B | pellm.llama | LLaMa | prompt_dataset |
| LLaMA-7B | pellm.llama | LLaMa | prompt_dataset |
| ChatGLM3-6B | pellm.chatglm | ChatGLM | prompt_dataset |
| GPT-2 | pellm.gpt2 | GPT2CLM | prompt_dataset |
| GPT-2 | pellm.gpt2 | GPT2 | seq_cls_dataset |
| ALBERT | pellm.albert | Albert | seq_cls_dataset |
| BART | pellm.bart | Bart | seq_cls_dataset |
| BERT | pellm.bert | Bert | seq_cls_dataset |
| DeBERTa | pellm.deberta | Deberta | seq_cls_dataset |
| DistilBERT | pellm.distilbert | DistilBert | seq_cls_dataset |
| RoBERTa | pellm.roberta | Roberta | seq_cls_dataset |
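
For example, to switch from the ChatGLM3-6B tutorial to Bloom-7B1, only the loader arguments need to change according to the table above. The snippet below is a sketch, not a tested configuration: `/path/to/bloom-7b1` is a placeholder for your local model directory, and the LoRA `target_modules` value follows the ChatGLM tutorial's recipe ('query_key_value' also matches Bloom's fused attention projection, but verify it against your checkpoint).

```python
from peft import LoraConfig, TaskType
from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,
    target_modules=['query_key_value'],
)
lora_config.target_modules = list(lora_config.target_modules)  # keep the config jsonable

model = LLMModelLoader(
    "pellm.bloom",                         # ModuleName from the table
    "Bloom",                               # ClassName from the table
    pretrained_path="/path/to/bloom-7b1",  # placeholder: your local model directory
    peft_type="LoraConfig",
    peft_config=lora_config.to_dict(),
)

dataset = LLMDatasetLoader(
    "prompt_dataset",                      # DataSetName from the table
    "PromptDataset",
    tokenizer_name_or_path="/path/to/bloom-7b1",
)
```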
================================================
FILE: examples/fedmkt/__init__.py
================================================
================================================
FILE: examples/fedmkt/fedmkt.py
================================================
from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_fedmkt_runner
from fate_client.pipeline.components.fate.nn.algo_params import FedMKTTrainingArguments, FedAVGArguments
from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader
from peft import LoraConfig, TaskType
from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate.reader import Reader
from transformers import AutoConfig
import argparse
import yaml
from typing import Union, Dict
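# Sketch of the config.yaml layout this script expects, inferred from the keys read in
# main() below; every value here is a placeholder, not a tested configuration. The
# separate `param` dict follows the same idea for its 'lora_config' and 'training' keys.
#
#   parties:
#     guest: ["9999"]
#     host: ["10000"]
#     arbiter: ["10000"]
#   paths:
#     process_data_output_dir: /data/fedmkt/pub_dataset_cache
#     llm_pretrained_path: /models/llm
#     slm_pretrained_paths: [/models/slm_a, /models/slm_b]
#     vocab_mapping_directory: /data/fedmkt/vocab_mappings
#     slm_to_llm_vocab_mapping_paths: [slm_a_to_llm.json, slm_b_to_llm.json]
#     llm_to_slm_vocab_mapping_paths: [llm_to_slm_a.json, llm_to_slm_b.json]
#   models:
#     slm_models: [...]
#   lora_config:
#     slm_lora_target_modules: [...]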
def main(config="./config.yaml", param: Union[Dict, str] = None):
if isinstance(config, str):
with open(config, 'r') as f:
config = yaml.safe_load(f)
if isinstance(param, str):
param = yaml.safe_load(param)
guest = config['parties']['guest'][0] # replace with actual guest party ID
host = config['parties']['host'][0] # replace with actual host party ID
arbiter = config['parties']['arbiter'][0] # replace with actual arbiter party ID
process_data_output_dir = config['paths']['process_data_output_dir']
llm_pretrained_path = config['paths']['llm_pretrained_path']
slm_pretrained_paths = config['paths']['slm_pretrained_paths']
vocab_mapping_directory = config['paths']['vocab_mapping_directory']
slm_to_llm_vocab_mapping_paths = [
vocab_mapping_directory + "/" + path for path in config['paths']['slm_to_llm_vocab_mapping_paths']
]
llm_to_slm_vocab_mapping_paths = [
vocab_mapping_directory + "/" + path for path in config['paths']['llm_to_slm_vocab_mapping_paths']
]
slm_models = config['models']['slm_models']
slm_lora_target_modules = config['lora_config']['slm_lora_target_modules']
def get_llm_conf():
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=param['lora_config']['llm']['r'],
lora_alpha=param['lora_config']['llm']['lora_alpha'],
lora_dropout=param['lora_config']['llm']['lora_dropout'],
target_modules=param['lora_config']['llm']['target_modules']
)
lora_config.target_modules = list(lora_config.target_modules)
llm_model = LLMModelLoader(
"pellm.llama",
"LLaMa",
pretrained_path=llm_pretrained_path,
peft_type="LoraConfig",
peft_config=lora_config.to_dict(),
torch_dtype="bfloat16"
)
pub_dataset = LLMDatasetLoader(
"qa_dataset",
"QaDataset",
tokenizer_name_or_path=llm_pretrained_path,
need_preprocess=True,
dataset_name="arc_challenge",
data_part="common",
seq_max_len=512
)
training_args = FedMKTTrainingArguments(
global_epochs=param['training']['llm']['global_epochs'],
per_device_train_batch_size=param['training']['llm']['per_device_train_batch_size'],
gradient_accumulation_steps=param['training']['llm']['gradient_accumulation_steps'],
learning_rate=param['training']['llm']['learning_rate'],
output_dir=param['training']['llm']['output_dir'],
dataloader_num_workers=param['training']['llm']['dataloader_num_workers'],
remove_unused_columns=param['training']['llm']['remove_unused_columns'],
warmup_ratio=param['training']['llm']['warmup_ratio'],
lr_scheduler_type=param['training']['llm']['lr_scheduler_type'],
optim=param['training']['llm']['optim'],
adam_beta1=param['training']['llm']['adam_beta1'],
adam_beta2=param['training']['llm']['adam_beta2'],
weight_decay=param['training']['llm']['weight_decay'],
max_grad_norm=param['training']['llm']['max_grad_norm'],
use_cpu=param['training']['llm']['use_cpu'],
vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,
)
fed_args = FedAVGArguments(
aggregate_strategy='epoch',
aggregate_freq=1
)
tokenizer = LLMDataFuncLoader(
"tokenizers.cust_tokenizer",
"get_tokenizer",
tokenizer_name_or_path=llm_pretrained_path
)
slm_tokenizers = [
LLMDataFuncLoader("tokenizers.cust_tokenizer", "get_tokenizer", tokenizer_name_or_path=path)
for path in slm_pretrained_paths
]
return get_config_of_fedmkt_runner(
model=llm_model,
training_args=training_args,
fed_args=fed_args,
pub_dataset=pub_dataset,
tokenizer=tokenizer,
slm_tokenizers=slm_tokenizers,
slm_to_llm_vocab_mapping_paths=slm_to_llm_vocab_mapping_paths,
pub_dataset_path=process_data_output_dir,
save_trainable_weights_only=True,
)
def get_slm_conf(slm_idx):
slm_pretrained_path = slm_pretrained_paths[slm_idx]
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=param['lora_config']['slm'][slm_idx]['r'],
lora_alpha=param['lora_config']['slm'][slm_idx]['lora_alpha'],
lora_dropout=param['lora_config']['slm'][slm_idx]['lora_dropout'],
target_modules=param['lora_config']['slm'][slm_idx]['target_modules']
)
lora_config.target_modules = list(lora_config.target_modules)
llm_to_slm_vocab_mapping = llm_to_slm_vocab_mapping_paths[slm_idx]
slm_model = LLMModelLoader(
slm_models[slm_idx][0],
slm_models[slm_idx][1],
pretrained_path=slm_pretrained_path,
peft_type="LoraConfig",
peft_config=lora_config.to_dict(),
)
vocab_size = AutoConfig.from_pretrained(slm_pretrained_path).vocab_size
pub_dataset = LLMDatasetLoader(
"qa_dataset",
"QaDataset",
tokenizer_name_or_path=slm_pretrained_path,
need_preprocess=True,
dataset_name="arc_challenge",
data_part="common",
seq_max_len=512
)
priv_dataset = LLMDatasetLoader(
"qa_dataset",
"QaDataset",
tokenizer_name_or_path=slm_pretrained_path,
need_preprocess=True,
dataset_name="arc_challenge",
data_part="client_0",
seq_max_len=512
)
training_args = FedMKTTrainingArguments(
global_epochs=param['training']['slm']['global_epochs'],
per_device_train_batch_size=param['training']['slm']['per_device_train_batch_size'],
gradient_accumulation_steps=param['training']['slm']['gradient_accumulation_steps'],
learning_rate=param['training']['slm']['learning_rate'],
output_dir=param['training']['slm']['output_dir'],
dataloader_num_workers=param['training']['slm']['dataloader_num_workers'],
remove_unused_columns=param['training']['slm']['remove_unused_columns'],
warmup_ratio=param['training']['slm']['warmup_ratio'],
lr_scheduler_type=param['training']['slm']['lr_scheduler_type'],
optim=param['training']['slm']['optim'],
adam_beta1=param['training']['slm']['adam_beta1'],
adam_beta2=param['training']['slm']['adam_beta2'],
weight_decay=param['training']['slm']['weight_decay'],
max_grad_norm=param['training']['slm']['max_grad_norm'],
use_cpu=param['training']['slm']['use_cpu'],
vocab_size=vocab_size,
)
fed_args = FedAVGArguments(
aggregate_strategy='epoch',
aggregate_freq=1
)
tokenizer = LLMDataFuncLoader(
"tokenizers.cust_tokenizer",
"get_tokenizer",
tokenizer_name_or_path=slm_pretrained_path
)
llm_tokenizer = LLMDataFuncLoader(
"tokenizers.cust_tokenizer",
"get_tokenizer",
tokenizer_name_or_path=llm_pretrained_path
)
data_collator = LLMDataFuncLoader(
module_name='data_collator.cust_data_collator',
item_name='get_seq2seq_data_collator',
tokenizer_name_or_path=slm_pretrained_path
)
return get_config_of_fedmkt_runner(
model=slm_model,
training_args=training_args,
fed_args=fed_args,
pub_dataset=pub_dataset,
priv_dataset=priv_dataset,
tokenizer=tokenizer,
llm_tokenizer=llm_tokenizer,
llm_to_slm_vocab_mapping_path=llm_to_slm_vocab_mapping,
pub_dataset_path=process_data_output_dir,
save_trainable_weights_only=True,
data_collator=data_collator
)
pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter, host=host)
pipeline.bind_local_path(path=process_data_output_dir, namespace="experiment", name="arc_challenge")
reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
reader_0.guest.task_parameters(
namespace=config['data']['guest']['namespace'],
name=config['data']['guest']['name']
)
reader_0.hosts[[0, 1, 2]].task_parameters(
namespace=config['data']['host']['namespace'],
name=config['data']['host']['name']
)
homo_nn_0 = HomoNN(
'nn_0',
train_data=reader_0.outputs["output_data"],
runner_module="fedmkt_runner",
runner_class="FedMKTRunner",
)
homo_nn_0.arbiter.task_parameters(
runner_conf=get_llm_conf()
)
homo_nn_0.guest.task_parameters(
runner_conf=get_slm_conf(slm_idx=0)
)
for idx in range(1):
homo_nn_0.hosts[idx].task_parameters(
runner_conf=get_slm_conf(slm_idx=idx + 1)
)
homo_nn_0.guest.conf.set("launcher_name", "deepspeed") # tell scheduler engine to run task with deepspeed
homo_nn_0.hosts[0].conf.set("launcher_name", "deepspeed") # tell scheduler engine to run task with deepspeed
homo_nn_0.arbiter.conf.set("launcher_name", "deepspeed") # tell scheduler engine to run task with deepspeed
pipeline.add_tasks([reader_0, homo_nn_0])
pipeline.conf.set("task", dict(engine_run={"cores": 1})) # the number of gpus of each party
pipeline.compile()
pipeline.fit()
if __name__ == "__main__":
parser = argparse.ArgumentParser("LLMSUITE PIPELINE JOB")
parser.add_argument("-c", "--config", type=str, help="config file", default="./config.yaml")
parser.add_argument("-p", "--param", type=str, help="config file for params", default="./fedmkt_config.yaml")
args = parser.parse_args()
main(args.config, args.param)
================================================
FILE: examples/fedmkt/fedmkt_config.yaml
================================================
# fedmkt_config.yaml
# Configuration for Lora
lora_config:
llm:
r: 8
lora_alpha: 16
lora_dropout: 0.05
target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
slm:
- # Configuration for the first SLM model
r: 8
lora_alpha: 32
lora_dropout: 0.1
target_modules:
- q_proj
- v_proj
- # Configuration for the second SLM model
r: 8
lora_alpha: 32
lora_dropout: 0.1
target_modules:
- c_attn
# Training configuration
training:
llm:
global_epochs: 5
per_device_train_batch_size: 1
gradient_accumulation_steps: 4
learning_rate: 3e-5
output_dir: "./"
dataloader_num_workers: 4
remove_unused_columns: false
warmup_ratio: 0.008
lr_scheduler_type: "cosine"
optim: "adamw_torch"
adam_beta1: 0.9
adam_beta2: 0.95
weight_decay: 0.1
max_grad_norm: 1.0
use_cpu: false
slm:
global_epochs: 5
per_device_train_batch_size: 1
gradient_accumulation_steps: 4
learning_rate: 3e-5 # Adjust learning rate for SLM models
output_dir: "./"
dataloader_num_workers: 4
remove_unused_columns: false
warmup_ratio: 0.008
lr_scheduler_type: "cosine"
optim: "adamw_torch"
adam_beta1: 0.9
adam_beta2: 0.95
weight_decay: 0.1
max_grad_norm: 1.0
use_cpu: false
# Paths configuration
paths:
process_data_output_dir: ""
llm_pretrained_path: "Llama-2-7b-hf"
slm_pretrained_paths:
- "opt-1.3b"
- "gpt2"
vocab_mapping_directory: ""
slm_to_llm_vocab_mapping_paths:
- "opt_to_llama.json"
- "gpt2_to_llama.json"
- "llama_small_to_llama.json"
llm_to_slm_vocab_mapping_paths:
- "llama_to_opt.json"
- "llama_to_gpt2.json"
- "llama_to_llama_small"
# Models configuration
models:
slm_models:
- ["pellm.opt", "OPT"]
- ["pellm.gpt2", "GPT2CLM"]
# Data configuration
data:
guest:
namespace: "experiment"
name: "arc_challenge"
host:
namespace: "experiment"
name: "arc_challenge"
================================================
FILE: examples/fedmkt/test_fedmkt_llmsuit.yaml
================================================
data:
- file:
table_name: arc_challenge
namespace: experiment
role: guest_0
- file:
table_name: arc_challenge
namespace: experiment
role: host_0
bloom_lora_vs_zero_shot:
gpt2_fedmkt:
pretrained: "gpt2"
script: "./fedmkt.py"
conf: "./fedmkt_config.yaml"
================================================
FILE: examples/offsite_tuning/__init__.py
================================================
================================================
FILE: examples/offsite_tuning/offsite_tuning.py
================================================
import argparse
import yaml
from fate_client.pipeline.components.fate.reader import Reader
from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_conf_of_ot_runner
from fate_client.pipeline.components.fate.nn.algo_params import Seq2SeqTrainingArguments, FedAVGArguments
from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader
from fate_client.pipeline.components.fate.nn.torch.base import Sequential
from fate_client.pipeline.components.fate.nn.torch import nn
def load_params(file_path):
"""Load and parse the YAML params file."""
with open(file_path, 'r') as f:
params = yaml.safe_load(f)
return params
def setup_pipeline(params):
"""Set up the pipeline using the provided parameters."""
guest = params['pipeline']['guest']
arbiter = params['pipeline']['arbiter']
pretrained_model_path = params['paths']['pretrained_model_path']
pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)
reader = Reader("reader_0", runtime_parties=dict(guest=guest))
reader.guest.task_parameters(
namespace=params['pipeline']['namespace'],
name=params['pipeline']['name']
)
client_model = LLMModelLoader(
module_name=params['models']['client']['module_name'],
item_name=params['models']['client']['item_name'],
model_name_or_path=pretrained_model_path,
emulator_layer_num=params['models']['client']['emulator_layer_num'],
adapter_top_layer_num=params['models']['client']['adapter_top_layer_num'],
adapter_bottom_layer_num=params['models']['client']['adapter_bottom_layer_num']
)
server_model = LLMModelLoader(
module_name=params['models']['server']['module_name'],
item_name=params['models']['server']['item_name'],
model_name_or_path=pretrained_model_path,
emulator_layer_num=params['models']['server']['emulator_layer_num'],
adapter_top_layer_num=params['models']['server']['adapter_top_layer_num'],
adapter_bottom_layer_num=params['models']['server']['adapter_bottom_layer_num']
)
dataset = LLMDatasetLoader(
module_name=params['dataset']['module_name'],
item_name=params['dataset']['item_name'],
tokenizer_name_or_path=params['dataset']['tokenizer_name_or_path'],
select_num=params['dataset']['select_num']
)
data_collator = LLMDataFuncLoader(
module_name=params['data_collator']['module_name'],
item_name=params['data_collator']['item_name'],
tokenizer_name_or_path=params['data_collator']['tokenizer_name_or_path']
)
train_args = Seq2SeqTrainingArguments(
per_device_train_batch_size=params['training']['batch_size'],
learning_rate=params['training']['learning_rate'],
disable_tqdm=False,
num_train_epochs=params['training']['num_train_epochs'],
logging_steps=params['training']['logging_steps'],
logging_strategy='steps',
dataloader_num_workers=4,
use_cpu=False,
deepspeed=params['training']['deepspeed'], # Add DeepSpeed config here
remove_unused_columns=False,
fp16=True
)
client_conf = get_conf_of_ot_runner(
model=client_model,
dataset=dataset,
data_collator=data_collator,
training_args=train_args,
fed_args=FedAVGArguments(),
aggregate_model=False,
)
server_conf = get_conf_of_ot_runner(
model=server_model,
dataset=dataset,
data_collator=data_collator,
training_args=train_args,
fed_args=FedAVGArguments(),
aggregate_model=False
)
homo_nn = HomoNN(
'nn_0',
train_data=reader.outputs["output_data"],
runner_module="offsite_tuning_runner",
runner_class="OTRunner"
)
homo_nn.guest.task_parameters(runner_conf=client_conf)
homo_nn.arbiter.task_parameters(runner_conf=server_conf)
# If using Eggroll, you can add this line to submit your job with deepspeed
homo_nn.guest.conf.set("launcher_name", "deepspeed")
pipeline.add_tasks([reader, homo_nn])
pipeline.conf.set("task", dict(engine_run=params['pipeline']['engine_run']))
pipeline.compile()
pipeline.fit()
def main(config_file, param_file):
params = load_params(param_file)
setup_pipeline(params)
if __name__ == "__main__":
parser = argparse.ArgumentParser("LLMSUITE Offsite-tuning JOB")
parser.add_argument("-c", "--config", type=str,
help="Path to config file", default="./config.yaml")
parser.add_argument("-p", "--param", type=str,
help="Path to parameter file", default="./test_offsite_tuning_llmsuite.yaml")
args = parser.parse_args()
main(args.config, args.param)
================================================
FILE: examples/offsite_tuning/offsite_tuning_config.yaml
================================================
# params.yaml
paths:
pretrained_model_path: 'gpt2'
pipeline:
guest: '9999'
arbiter: '9999'
namespace: 'experiment'
name: 'sciq'
engine_run:
cores: 1
training:
batch_size: 1
learning_rate: 5e-5
num_train_epochs: 1
logging_steps: 10
deepspeed:
train_micro_batch_size_per_gpu: 1
optimizer:
type: "Adam"
params:
lr: 5e-5
torch_adam: true
adam_w_mode: false
fp16:
enabled: true
gradient_accumulation_steps: 1
zero_optimization:
stage: 2
allgather_partitions: true
allgather_bucket_size: 1e8
overlap_comm: true
reduce_scatter: true
reduce_bucket_size: 1e8
contiguous_gradients: true
offload_optimizer:
device: "cpu"
offload_param:
device: "cpu"
models:
client:
module_name: 'offsite_tuning.gpt2'
item_name: 'GPT2LMHeadSubModel'
emulator_layer_num: 11
adapter_top_layer_num: 2
adapter_bottom_layer_num: 2
server:
module_name: 'offsite_tuning.gpt2'
item_name: 'GPT2LMHeadMainModel'
emulator_layer_num: 11
adapter_top_layer_num: 2
adapter_bottom_layer_num: 2
dataset:
module_name: 'qa_dataset'
item_name: 'QaDataset'
tokenizer_name_or_path: 'gpt2'
select_num: 100
data_collator:
module_name: 'data_collator.cust_data_collator'
item_name: 'get_seq2seq_data_collator'
tokenizer_name_or_path: 'gpt2'
================================================
FILE: examples/offsite_tuning/test_offsite_tuning_llmsuite.yaml
================================================
data:
- file:
table_name: sciq
namespace: experiment
role: guest_0
- file:
table_name: sciq
namespace: experiment
role: host_0
bloom_lora_vs_zero_shot:
gpt2_ot:
pretrained: "gpt2"
script: "./offsite_tuning.py"
conf: "./offsite_tuning_config.yaml"
================================================
FILE: examples/pellm/__init__.py
================================================
================================================
FILE: examples/pellm/bloom_lora_config.yaml
================================================
data:
guest:
namespace: experiment
name: ad
host:
namespace: experiment
name: ad
epoch: 1
batch_size: 4
lr: 5e-4
pretrained_model_path: bloom-560m
peft_config:
alpha_pattern: {}
auto_mapping: null
base_model_name_or_path: null
bias: none
fan_in_fan_out: false
inference_mode: false
init_lora_weights: true
layers_pattern: null
layers_to_transform: null
loftq_config: { }
lora_alpha: 32
lora_dropout: 0.1
megatron_config: null
megatron_core: megatron.core
modules_to_save: null
peft_type: LORA
r: 8
rank_pattern: { }
revision: null
target_modules:
- query_key_value
task_type: CAUSAL_LM
use_rslora: false
ds_config:
fp16:
enabled: true
gradient_accumulation_steps: 1
optimizer:
params:
adam_w_mode: false
lr: 5e-4
torch_adam: true
type: Adam
train_micro_batch_size_per_gpu: 4
zero_optimization:
allgather_bucket_size: 100000000.0
allgather_partitions: true
contiguous_gradients: true
offload_optimizer:
device: cpu
offload_param:
device: cpu
overlap_comm: true
reduce_bucket_size: 100000000.0
reduce_scatter: true
stage: 2
================================================
FILE: examples/pellm/test_bloom_lora.py
================================================
import time
from fate_client.pipeline.components.fate.reader import Reader
from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_seq2seq_runner
from fate_client.pipeline.components.fate.nn.algo_params import Seq2SeqTrainingArguments, FedAVGArguments
from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader
from peft import LoraConfig, TaskType
from fate_client.pipeline.utils import test_utils
import argparse
import yaml
from typing import Union, Dict
def main(config="../../config.yaml", param: Union[Dict, str] = None, namespace=""):
if isinstance(config, str):
config = test_utils.load_job_config(config)
if isinstance(param, str):
param = yaml.safe_load(param)
parties = config.parties
guest = parties.guest[0]
host = parties.host[0]
arbiter = parties.arbiter[0]
pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
reader_0.guest.task_parameters(
namespace=param["data"]["guest"]["namespace"],
name=param["data"]["guest"]["name"]
)
reader_0.hosts[0].task_parameters(
namespace=param["data"]["host"]["namespace"],
name=param["data"]["host"]["name"]
)
lora_config = LoraConfig(**param["peft_config"])
lora_config.target_modules = list(lora_config.target_modules)
pretrained_model_path = param["pretrained_model_path"]
model = LLMModelLoader(
"pellm.bloom",
"Bloom",
pretrained_path=pretrained_model_path,
peft_type="LoraConfig",
peft_config=lora_config.to_dict(),
trust_remote_code=True
)
tokenizer_params = dict(
tokenizer_name_or_path=pretrained_model_path,
trust_remote_code=True,
)
dataset = LLMDatasetLoader(
"prompt_dataset",
"PromptDataset",
**tokenizer_params,
)
data_collator = LLMDataFuncLoader(
"data_collator.cust_data_collator",
"get_seq2seq_data_collator",
**tokenizer_params,
)
conf = get_config_of_seq2seq_runner(
algo='fedavg',
model=model,
dataset=dataset,
data_collator=data_collator,
training_args=Seq2SeqTrainingArguments(
num_train_epochs=param["epoch"],
per_device_train_batch_size=param["batch_size"],
remove_unused_columns=False,
predict_with_generate=False,
deepspeed=param["ds_config"],
learning_rate=param["lr"],
use_cpu=False, # this must be set as we will use gpu
fp16=True,
),
fed_args=FedAVGArguments(),
task_type='causal_lm',
save_trainable_weights_only=True # only save trainable weights
)
homo_nn_0 = HomoNN(
'nn_0',
runner_conf=conf,
train_data=reader_0.outputs["output_data"],
runner_module="homo_seq2seq_runner",
runner_class="Seq2SeqRunner",
)
homo_nn_0.guest.conf.set("launcher_name", "deepspeed") # tell scheduler engine to run task with deepspeed
homo_nn_0.hosts[0].conf.set("launcher_name", "deepspeed") # tell scheduler engine to run task with deepspeed
pipeline.add_tasks([reader_0, homo_nn_0])
pipeline.conf.set("task", dict(engine_run={"cores": 1})) # the number of gpus of each party
pipeline.compile()
pipeline.fit()
return pretrained_model_path
if __name__ == "__main__":
parser = argparse.ArgumentParser("LLMSUITE PIPELINE JOB")
parser.add_argument("-c", "--config", type=str,
help="config file", default="../../config.yaml")
parser.add_argument("-p", "--param", type=str,
help="config file for params", default="./bloom_lora_config.yaml")
args = parser.parse_args()
main(args.config, args.param)
================================================
FILE: examples/pellm/test_pellm_llmsuite.yaml
================================================
data:
- file: examples/data/AdvertiseGen/train.json
table_name: ad
namespace: experiment
role: guest_0
- file: examples/data/AdvertiseGen/train.json
table_name: ad
namespace: experiment
role: host_0
bloom_lora_vs_zero_shot:
bloom_lora:
pretrained: "bloom-560m"
script: "./test_bloom_lora.py"
conf: "./bloom_lora_config.yaml"
peft_path_format: "{{fate_base}}/fate_flow/model/{{job_id}}/guest/{{party_id}}/{{model_task_name}}/0/output/output_model/model_directory"
tasks:
- "advertise-gen"
bloom_zero_shot:
pretrained: "bloom-560m"
tasks:
- "advertise-gen"
================================================
FILE: python/MANIFEST.in
================================================
include fate_llm/dataset/data_config/*yaml
include fate_llm/evaluate/tasks/*/*yaml
================================================
FILE: python/fate_llm/__init__.py
================================================
================================================
FILE: python/fate_llm/algo/__init__.py
================================================
================================================
FILE: python/fate_llm/algo/dp/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .opacus_compatibility.transformers_compate import get_model_class
from .dp_trainer import DPTrainer, DPTrainingArguments
================================================
FILE: python/fate_llm/algo/dp/dp_trainer.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import opacus
import os
import torch
from dataclasses import dataclass, field
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Optional, Callable
from .opacus_compatibility import add_layer_compatibility, add_optimizer_compatibility
from .opacus_compatibility.transformers_compate import prepare_position_ids
logger = logging.getLogger(__name__)
@dataclass
class DPTrainingArguments(Seq2SeqTrainingArguments):
target_epsilon: float = field(default=3)
target_delta: float = field(default=1e-5)
freeze_embedding: bool = field(default=True)
device_id: int = field(default=0)
class DPTrainer(object):
def __init__(
self,
model: torch.nn.Module,
training_args: DPTrainingArguments,
train_set,
loss_fn,
optimizer: torch.optim.Optimizer = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
data_collator: Callable = None,
use_tqdm: bool = False,
):
self.module = model
self.training_args = training_args
self.ori_optimizer = optimizer
self.lr_scheduler = scheduler
self.train_set = train_set
self.data_collator = data_collator
self.loss_fn = loss_fn
self.use_tqdm = use_tqdm
self.data_loader = DataLoader(
dataset=self.train_set,
shuffle=True,
batch_size=self.training_args.per_device_train_batch_size,
collate_fn=self.data_collator
)
if not self.training_args.use_cpu:
self.module.cuda(self.training_args.device_id)
if self.training_args.freeze_embedding:
self.freeze_model_embedding()
self.dp_model = None
self.dp_optimizer = None
self.privacy_engine = None
self._init_dp_model()
def _init_dp_model(self):
self.module.train()
# add compatibility for layer hooks
add_layer_compatibility(opacus)
self.privacy_engine = opacus.PrivacyEngine(accountant="rdp")
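# wrap the model, optimizer and data loader so that per-sample gradients are clipped and noised
# to satisfy the (target_epsilon, target_delta) budget over the planned number of epochs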
self.dp_model, self.dp_optimizer, _ = self.privacy_engine.make_private_with_epsilon(
module=self.module,
optimizer=self.ori_optimizer,
data_loader=self.data_loader,
target_delta=self.training_args.target_delta,
target_epsilon=self.training_args.target_epsilon,
max_grad_norm=self.training_args.max_grad_norm,
epochs=int(self.training_args.num_train_epochs),
)
add_optimizer_compatibility(self.dp_optimizer)
def train(self):
logger.info(f"begin dp training, total epochs={self.training_args.num_train_epochs}")
for epoch in range(int(self.training_args.num_train_epochs)):
logger.info(f"dp training on epoch={epoch}")
self._train_an_epoch()
def _train_an_epoch(self):
if self.use_tqdm:
data_loader = tqdm(self.data_loader)
else:
data_loader = self.data_loader
for batch_idx, batch_data in enumerate(data_loader):
input_ids = batch_data["input_ids"]
labels = batch_data["labels"]
if "attention_mask" not in batch_data:
attention_mask = torch.ones(input_ids.shape)
else:
attention_mask = batch_data["attention_mask"]
if not self.training_args.use_cpu:
input_ids = input_ids.to(self.module.device)
labels = labels.to(self.module.device)
attention_mask = attention_mask.to(self.module.device)
inputs = self._prepare_batch_input(input_ids)
logits = self.dp_model(**inputs).logits
loss = self.loss_fn(logits, labels, attention_mask)
loss = loss.mean()
loss.backward()
if (batch_idx + 1) % self.training_args.gradient_accumulation_steps == 0 or \
batch_idx + 1 == len(self.data_loader):
self.dp_optimizer.step()
if self.lr_scheduler is not None:
self.lr_scheduler.step()
self.dp_optimizer.zero_grad()
else:
self.dp_optimizer.step()
self.dp_optimizer.zero_grad()
def _prepare_batch_input(self, input_ids) -> dict:
position_ids = prepare_position_ids(self.module, input_ids)
if not self.training_args.use_cpu:
position_ids = position_ids.to(self.module.device)
return dict(input_ids=input_ids, position_ids=position_ids)
def freeze_model_embedding(self):
self.module.get_input_embeddings().requires_grad_(False)
def save_model(
self,
output_dir="./"
):
if hasattr(self.module, "save_pretrained"):
self.module.save_pretrained(output_dir)
else:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
torch.save(self.module.state_dict(), output_dir + '/pytorch_model.bin')
================================================
FILE: python/fate_llm/algo/dp/opacus_compatibility/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .grad_sample.embedding import compute_embedding_grad_sample
from .optimizers.optimizer import add_noise_wrapper
def add_layer_compatibility(opacus):
replace_method = []
for k, v in opacus.GradSampleModule.GRAD_SAMPLERS.items():
if v.__name__ == "compute_embedding_grad_sample":
replace_method.append(k)
for k in replace_method:
opacus.GradSampleModule.GRAD_SAMPLERS[k] = compute_embedding_grad_sample
def add_optimizer_compatibility(optimizer):
add_noise_wrapper(optimizer)
================================================
FILE: python/fate_llm/algo/dp/opacus_compatibility/grad_sample/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
================================================
FILE: python/fate_llm/algo/dp/opacus_compatibility/grad_sample/embedding.py
================================================
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch.nn as nn
from typing import Dict
# the function is modified from https://github.com/pytorch/opacus/blob/main/opacus/grad_sample/embedding.py#L25,
# avoid dtype error when backprops's dtype isn't torch.float32
def compute_embedding_grad_sample(
layer: nn.Embedding, activations: torch.Tensor, backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
"""
Computes per sample gradients for ``nn.Embedding`` layer.
Args:
layer: Layer
activations: Activations
backprops: Backpropagations
"""
activations = activations[0]
ret = {}
if layer.weight.requires_grad:
saved = torch.backends.cudnn.deterministic
torch.backends.cudnn.deterministic = True
batch_size = activations.shape[0]
if batch_size == 0:
ret[layer.weight] = torch.zeros_like(layer.weight).unsqueeze(0)
return ret
index = (
activations.unsqueeze(-1)
.expand(*activations.shape, layer.embedding_dim)
.reshape(batch_size, -1, layer.embedding_dim)
)
grad_sample = torch.zeros(
batch_size, *layer.weight.shape, device=layer.weight.device, dtype=backprops.dtype
)
grad_sample.scatter_add_(
1, index, backprops.reshape(batch_size, -1, layer.embedding_dim)
)
torch.backends.cudnn.deterministic = saved
ret[layer.weight] = grad_sample
return ret
================================================
FILE: python/fate_llm/algo/dp/opacus_compatibility/optimizers/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
================================================
FILE: python/fate_llm/algo/dp/opacus_compatibility/optimizers/optimizer.py
================================================
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import types
from opacus.optimizers.optimizer import (
_check_processed_flag,
_generate_noise,
_mark_as_processed
)
# modified from https://github.com/pytorch/opacus/blob/main/opacus/optimizers/optimizer.py#L424
# avoid dtype error when summed_grad's dtype isn't torch.float32
def add_noise(self):
"""
Adds noise to clipped gradients. Stores clipped and noised result in ``p.grad``
"""
for p in self.params:
_check_processed_flag(p.summed_grad)
noise = _generate_noise(
std=self.noise_multiplier * self.max_grad_norm,
reference=p.summed_grad,
generator=self.generator,
secure_mode=self.secure_mode,
)
noise = noise.to(p.summed_grad.dtype)
p.grad = (p.summed_grad + noise).view_as(p)
_mark_as_processed(p.summed_grad)
def add_noise_wrapper(optimizer):
optimizer.add_noise = types.MethodType(add_noise, optimizer)
================================================
FILE: python/fate_llm/algo/dp/opacus_compatibility/transformers_compate.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import transformers
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
from transformers.modeling_utils import unwrap_model
def get_model_class(model):
if isinstance(model, PELLM):
model = model._pe_lm
model = unwrap_model(model)
return model.__class__
def prepare_position_ids(model, input_ids):
if get_model_class(model) == transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel:
return _get_position_ids_for_gpt2(input_ids)
else:
raise ValueError(f"Can not prepare position_ids for model_type={model.__class__}")
def _get_position_ids_for_gpt2(input_ids):
past_length = 0
position_ids = torch.arange(past_length, input_ids.shape[-1] + past_length, dtype=torch.long,
device=input_ids.device)
position_ids = position_ids.unsqueeze(0)
position_ids = position_ids.repeat(input_ids.shape[0], 1)
return position_ids
================================================
FILE: python/fate_llm/algo/fdkt/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .fdkt_data_aug import (
FDKTSLM,
FDKTLLM,
FDKTTrainingArguments
)
__all__ = [
"FDKTSLM",
"FDKTLLM",
"FDKTTrainingArguments"
]
================================================
FILE: python/fate_llm/algo/fdkt/cluster/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
================================================
FILE: python/fate_llm/algo/fdkt/cluster/cluster.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List
from .cluster_method import get_cluster_runner
class SentenceCluster(object):
def __init__(self, model, cluster_method="kmeans", n_clusters=8, **other_cluster_args):
self.model = model
self.cluster_method = cluster_method
self.n_clusters = n_clusters
self.other_cluster_args = other_cluster_args
def get_embeddings(self, sentences: List[str]):
return self.model.encode(sentences)
def cluster(self, sentences):
embeddings = self.get_embeddings(sentences)
cluster_runner = get_cluster_runner(method=self.cluster_method,
n_clusters=self.n_clusters,
**self.other_cluster_args)
cluster_rets = cluster_runner.fit(embeddings)
return cluster_rets
================================================
FILE: python/fate_llm/algo/fdkt/cluster/cluster_method.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from sklearn.cluster import KMeans
class KMeansRunner(object):
def __init__(self, n_clusters, **other_cluster_args):
self.n_clusters = n_clusters
self.other_cluster_args = other_cluster_args
def fit(self, x):
model = KMeans(n_clusters=self.n_clusters, **self.other_cluster_args)
model.fit(x)
return model.labels_
def get_cluster_runner(method, n_clusters, **other_cluster_args):
if method.lower() == "kmeans":
return KMeansRunner(n_clusters, **other_cluster_args)
else:
raise ValueError(f"cluster method={method} is not implemented")
================================================
FILE: python/fate_llm/algo/fdkt/fdkt_data_aug.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os.path
import shutil
import torch
import logging
from dataclasses import dataclass, field
from ...trainer.seq2seq_trainer import Seq2SeqTrainingArguments
from typing import Optional, Callable
from fate.arch import Context
from transformers import PreTrainedTokenizer
from .utils.invalid_data_filter import filter_invalid_data
from .utils.text_generate import slm_text_generate, general_text_generate
from .cluster.cluster import SentenceCluster
from fate_llm.inference.inference_base import Inference
logger = logging.getLogger(__name__)
SLM_SYNTHETIC_DATA = "slm_synthetic_data"
LLM_AUG_DATA = "llm_aug_data"
@dataclass
class FDKTTrainingArguments(Seq2SeqTrainingArguments):
"""
slm parameters
"""
dp_training: bool = field(default=True)
target_epsilon: float = field(default=3)
target_delta: float = field(default=1e-5)
freeze_embedding: bool = field(default=True)
device_id: int = field(default=0)
slm_generation_config: dict = field(default=None)
slm_generation_batch_size: int = field(default=None)
inference_method: str = field(default="native")
inference_inst_init_conf: dict = field(default=None)
"""
slm generation config
"""
seq_num_for_single_category: int = field(default=None)
"""
dp loss params
"""
label_smoothing_factor = 0.02
loss_reduce = True
"""
llm parameters
"""
sample_num_per_cluster: int = field(default=None)
filter_data_batch_size: int = field(default=2)
filter_prompt_max_length: int = field(default=2048)
filter_generation_config: dict = field(default=None)
aug_generation_config: dict = field(default=None)
aug_prompt_num: int = field(default=None)
aug_data_batch_size: int = field(default=2)
aug_prompt_max_length: int = field(default=2048)
def to_dict(self):
from dataclasses import fields
from enum import Enum
d = {field.name: getattr(self, field.name) for field in fields(self) if field.init}
for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
d[k] = [x.value for x in v]
if k.endswith("_token"):
d[k] = f"<{k.upper()}>"
return d
class FDKTSLM(object):
def __init__(
self,
ctx: Context,
model: torch.nn.Module,
training_args: FDKTTrainingArguments,
train_set,
optimizer: torch.optim.Optimizer = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
data_collator: Callable = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
):
super(FDKTSLM, self).__init__()
self.ctx = ctx
self.training_args = training_args
self.train_set = train_set
self.model = model
self.tokenizer = tokenizer
self.optimizer = optimizer
self.scheduler = scheduler
self.data_collator = data_collator
if not self.training_args.use_cpu:
self.model.cuda(self.training_args.device_id)
def aug_data(self):
logging.info("Start aug data process")
logging.debug(f"dp_training={self.training_args.dp_training}")
if self.training_args.dp_training:
logging.info("Start dp training")
self.dp_train()
logging.info("End dp training")
inference_inst = self._create_inference_inst()
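# build one generation prompt per label; prompts are pre-tokenized only for the native (in-process) generation path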
prefix_prompt_dict = self.train_set.get_generate_prompt(
tokenize=True if inference_inst is None else False)
generated_texts = slm_text_generate(
inference_inst,
self.model,
self.tokenizer,
prompt_dict=prefix_prompt_dict,
seq_num_for_single_category=self.training_args.seq_num_for_single_category,
batch_size=self.training_args.slm_generation_batch_size,
use_cpu=self.training_args.use_cpu,
generation_config=self.training_args.slm_generation_config
)
self._destroy_inference_inst()
if not self.training_args.use_cpu:
self.model.cpu()
torch.cuda.empty_cache()
generated_texts = filter_invalid_data(generated_texts)
self.sync_synthetic_dataset(generated_texts)
return self.sync_aug_data()
def dp_train(self):
from ..dp import DPTrainer, DPTrainingArguments, get_model_class
from .utils.dp_loss import SequenceCrossEntropyLoss
dp_training_args = DPTrainingArguments(
target_delta=self.training_args.target_delta,
target_epsilon=self.training_args.target_epsilon,
freeze_embedding=self.training_args.freeze_embedding,
device_id=self.training_args.device_id,
num_train_epochs=self.training_args.num_train_epochs,
per_device_train_batch_size=self.training_args.per_device_train_batch_size,
output_dir="/" if self.training_args.output_dir is None else self.training_args.output_dir
)
loss_fn = SequenceCrossEntropyLoss(
get_model_class(self.model).__name__,
label_smoothing=self.training_args.label_smoothing_factor,
reduce=self.training_args.loss_reduce
)
dp_trainer = DPTrainer(
model=self.model,
training_args=dp_training_args,
train_set=self.train_set,
optimizer=self.optimizer,
scheduler=self.scheduler,
data_collator=self.data_collator,
loss_fn=loss_fn
)
dp_trainer.train()
def _create_inference_inst(self):
if self.training_args.inference_method == "native":
return None
elif self.training_args.inference_method == "vllm":
from .inference_inst import vllm_init
self.model.cpu()
model_temp_path = self.training_args.output_dir + "./model_for_inference"
self.tokenizer.save_pretrained(model_temp_path)
self.model.save_pretrained(model_temp_path)
return vllm_init(model_temp_path) if self.training_args.inference_inst_init_conf is None \
else vllm_init(model_temp_path, **self.training_args.inference_inst_init_conf)
else:
raise ValueError(f"not supported inference_method={self.training_args.inference_method}")
def _destroy_inference_inst(self):
if self.training_args.inference_method == "vllm":
shutil.rmtree(self.training_args.output_dir + "./model_for_inference")
elif not self.training_args.use_cpu:
self.model.cpu()
def sync_synthetic_dataset(self, data):
self.ctx.arbiter.put(SLM_SYNTHETIC_DATA, data)
def sync_aug_data(self):
return self.ctx.arbiter.get(LLM_AUG_DATA)
def save_model(
self,
output_dir="./"
):
if hasattr(self.model, "save_pretrained"):
self.model.save_pretrained(output_dir)
else:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
torch.save(self.model.state_dict(), output_dir + '/pytorch_model.bin')
class FDKTLLM(object):
def __init__(
self,
ctx: Context,
embedding_model: torch.nn.Module,
training_args: FDKTTrainingArguments,
dataset,
model: Optional[torch.nn.Module] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
inference_inst: Optional[Inference] = None,
):
super(FDKTLLM, self).__init__()
self.ctx = ctx
self.inference_inst = inference_inst
self.embedding_model = embedding_model
self.dataset = dataset
self.training_args = training_args
self.model = model
self.tokenizer = tokenizer
if self.inference_inst is None and (self.model is None or self.tokenizer is None):
raise ValueError("Inference_inst and Model are both empty, should provided one")
if self.model is not None and self.training_args.device_id is not None and not self.training_args.use_cpu:
self.model.cuda(self.training_args.device_id)
def sync_synthetic_data(self):
return self.ctx.guest.get(SLM_SYNTHETIC_DATA)
def sync_aug_data(self, aug_data):
self.ctx.guest.put(LLM_AUG_DATA, aug_data)
def aug_data(self):
logging.info("sync slm synthetic_data")
slm_data = self.sync_synthetic_data()
logging.info("filter slm synthetic data")
filter_data = self.filter_data(slm_data)
logging.info("prepare prompts for aug")
aug_prompts = self.dataset.prepare_augment(
filter_data["inputs"],
filter_data["labels"],
aug_prompt_num=self.training_args.aug_prompt_num
)
logging.info("aug_data")
aug_data = self._aug(aug_prompts)
aug_data = filter_invalid_data(aug_data)
self.sync_aug_data(aug_data)
def _aug(self, aug_prompts):
aug_responses = general_text_generate(
inference_inst=self.inference_inst,
model=self.model,
tokenizer=self.tokenizer,
generation_config=self.training_args.aug_generation_config,
prompts=aug_prompts,
batch_size=self.training_args.aug_data_batch_size,
use_cpu=self.training_args.use_cpu,
prompt_max_length=self.training_args.aug_prompt_max_length
)
aug_data = self.dataset.abstract_from_augmented(aug_responses)
return aug_data
def filter_data(self, slm_data):
clustered_sentences, clustered_labels = self.cluster_data(slm_data)
filter_prompts = self.dataset.prepare_query_to_filter_clustered(clustered_sentences, clustered_labels)
filter_responses = general_text_generate(
inference_inst=self.inference_inst,
model=self.model,
tokenizer=self.tokenizer,
generation_config=self.training_args.filter_generation_config,
prompts=filter_prompts,
batch_size=self.training_args.filter_data_batch_size,
use_cpu=self.training_args.use_cpu,
prompt_max_length=self.training_args.filter_prompt_max_length
)
filtered_sentences, filtered_labels = self.dataset.parse_clustered_response(
clustered_sentence=clustered_sentences,
clustered_labels=clustered_labels,
response_list=filter_responses
)
return dict(
inputs=filtered_sentences,
labels=filtered_labels
)
def cluster_data(self, slm_data):
sentences = slm_data["inputs"]
labels = slm_data["labels"]
n_clusters = (len(sentences) + self.training_args.sample_num_per_cluster - 1) // self.training_args.sample_num_per_cluster
cluster_ret = SentenceCluster(model=self.embedding_model, n_clusters=n_clusters).cluster(sentences)
clustered_sentences = [[] for _ in range(n_clusters)]
clustered_labels = [[] for _ in range(n_clusters)]
for sentence_id, cluster_id in enumerate(cluster_ret):
clustered_sentences[cluster_id].append(sentences[sentence_id])
clustered_labels[cluster_id].append(labels[sentence_id])
return clustered_sentences, clustered_labels
================================================
FILE: python/fate_llm/algo/fdkt/inference_inst.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
def api_init(api_url: str, model_name: str, api_key: str = 'EMPTY', api_timeout=3600):
from fate_llm.inference.api import APICompletionInference
return APICompletionInference(
api_url=api_url,
model_name=model_name,
api_key=api_key,
api_timeout=api_timeout
)
def vllm_init(model_path: str, num_gpu=1, dtype='float16', gpu_memory_utilization=0.9):
from fate_llm.inference.vllm import VLLMInference
return VLLMInference(
model_path=model_path,
num_gpu=num_gpu,
dtype=dtype,
gpu_memory_utilization=gpu_memory_utilization
)
================================================
FILE: python/fate_llm/algo/fdkt/utils/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
================================================
FILE: python/fate_llm/algo/fdkt/utils/dp_loss.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
NUMERICAL_STABILITY_CONSTANT = 1e-13
class SequenceCrossEntropyLoss(nn.Module):
def __init__(self, model_type, label_smoothing=-1, reduce=None):
super().__init__()
self.model_type = model_type
self.label_smoothing = label_smoothing
self.reduce = reduce
def forward(self, logits, targets, mask):
return sequence_cross_entropy_with_logits(logits, targets, mask, self.label_smoothing, self.reduce, self.model_type)
def sequence_cross_entropy_with_logits(logits, targets, mask, label_smoothing, reduce, model_type):
if model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
logits = logits[:, :-1].contiguous()
targets = targets[:, 1:]
mask = torch.ones_like(targets).float()
logits_flat = logits.view(-1, logits.size(-1))
log_probs_flat = F.log_softmax(logits_flat, dim=-1)
targets_flat = targets.reshape(-1, 1).long()
if label_smoothing > 0.0:
num_classes = logits.size(-1)
smoothing_value = label_smoothing / float(num_classes)
one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(-1, targets_flat, 1.0 - label_smoothing)
smoothed_targets = one_hot_targets + smoothing_value
negative_log_likelihood_flat = -log_probs_flat * smoothed_targets
negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
else:
negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
negative_log_likelihood = negative_log_likelihood_flat.view(-1, logits.shape[1])
loss = negative_log_likelihood * mask
if reduce:
loss = loss.sum(1) / (mask.sum(1) + NUMERICAL_STABILITY_CONSTANT)
if reduce == "batch":
loss = loss.mean()
return loss
================================================
FILE: python/fate_llm/algo/fdkt/utils/invalid_data_filter.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
INVALID_CHARACTERS = "".join([' ', '-', '.', '_', '~', '/', '\\', '*', '|', '#'])
LEAST_WORDS = 10
def filter_invalid_data(data_dict):
sample_num = len(data_dict["inputs"])
new_data_dict = dict(
inputs=list(),
labels=list()
)
for idx in range(sample_num):
text = data_dict["inputs"][idx].strip(INVALID_CHARACTERS)
if len(text.split()) < LEAST_WORDS:
continue
new_data_dict["inputs"].append(text)
new_data_dict["labels"].append(data_dict["labels"][idx])
return new_data_dict
================================================
FILE: python/fate_llm/algo/fdkt/utils/text_generate.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from tqdm import tqdm
from typing import Any, Dict, List
def slm_text_generate(
inference_inst,
model,
tokenizer,
prompt_dict,
seq_num_for_single_category,
batch_size,
use_cpu,
generation_config
):
generated_ret = dict(
inputs=list(),
labels=list(),
)
if inference_inst is not None:
for label, prompt in prompt_dict.items():
generated_sequences = inference_inst.inference([prompt] * seq_num_for_single_category, generation_config)
for g in generated_sequences:
generated_ret["inputs"].append(g)
generated_ret["labels"].append(label)
else:
model.eval()
for label, prompt_ids in prompt_dict.items():
prompt_length = len(prompt_ids)
batch_num = (seq_num_for_single_category + batch_size - 1) // batch_size
for batch_idx in tqdm(range(batch_num)):
if batch_idx + 1 == batch_num:
cur_batch_size = seq_num_for_single_category - batch_idx * batch_size
else:
cur_batch_size = batch_size
input_ids = prompt_ids.repeat(cur_batch_size, 1)
if not use_cpu:
input_ids = input_ids.to(model.device)
output_sequences = model.generate(
input_ids=input_ids,
**generation_config
)
output_sequences = output_sequences[:, prompt_length:]
generated_sequences = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
for g in generated_sequences:
generated_ret["inputs"].append(g)
generated_ret["labels"].append(label)
return generated_ret
def general_text_generate(
inference_inst,
model,
tokenizer,
generation_config: Dict[Any, Any],
prompts: List[str],
batch_size,
use_cpu: bool,
prompt_max_length
):
if inference_inst is not None:
if prompt_max_length is not None:
prompts = [prompt[:prompt_max_length] for prompt in prompts]
generate_texts = inference_inst.inference(prompts, generation_config)
else:
model.eval()
generate_texts = []
batch_num = (len(prompts) + batch_size - 1) // batch_size
for batch_idx in range(batch_num):
batch_data = prompts[batch_idx * batch_size: (batch_idx + 1) * batch_size]
inputs = tokenizer(batch_data, return_tensors="pt", padding="longest", truncation=True,
max_length=prompt_max_length)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
if not use_cpu:
input_ids = input_ids.to(model.device)
attention_mask = attention_mask.to(model.device)
output = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
**generation_config
)
batch_responses = tokenizer.batch_decode(output[:, input_ids.shape[1]:], skip_special_tokens=True)
generate_texts.extend(batch_responses)
return generate_texts
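# ---------------------------------------------------------------------------
# Minimal local-generation sketch (illustrative only). "gpt2" is just a
# placeholder checkpoint and the generation_config keys are ordinary
# transformers generate() kwargs; in FATE-LLM these objects are supplied by the
# runner rather than built by hand.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer
    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # gpt2 has no pad token by default
    lm = AutoModelForCausalLM.from_pretrained("gpt2")
    texts = general_text_generate(
        inference_inst=None,  # no remote/accelerated backend: use local generation
        model=lm,
        tokenizer=tok,
        generation_config={"max_new_tokens": 16, "do_sample": False,
                           "pad_token_id": tok.eos_token_id},
        prompts=["Federated learning is", "Knowledge distillation is"],
        batch_size=2,
        use_cpu=True,
        prompt_max_length=32,
    )
    print(texts)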
================================================
FILE: python/fate_llm/algo/fedavg/__init__.py
================================================
================================================
FILE: python/fate_llm/algo/fedavg/fedavg.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from fate.ml.nn.homo.fedavg import FedAVGServer, FedAVGArguments, FedArguments
from fate.arch import Context
from fate_llm.trainer.seq2seq_trainer import HomoSeq2SeqTrainerClient, Seq2SeqTrainingArguments
from fate.ml.aggregator import AggregatorClientWrapper
import logging
from typing import List, Optional, Tuple, Callable, Dict
from fate.arch import Context
from torch.optim import Optimizer
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import _LRScheduler
from transformers.trainer_callback import TrainerCallback
from torch import nn
from torch.utils.data import DataLoader
from transformers import TrainerState, TrainerControl, PreTrainedTokenizer, EvalPrediction
logger = logging.getLogger(__name__)
Seq2SeqFedAVGServer = FedAVGServer
class Seq2SeqFedAVGClient(HomoSeq2SeqTrainerClient):
def __init__(
self,
ctx: Context,
model: nn.Module,
training_args: Seq2SeqTrainingArguments,
fed_args: FedArguments,
train_set: Dataset,
val_set: Dataset = None,
optimizer: torch.optim.Optimizer = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
data_collator: Callable = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
callbacks: Optional[List[TrainerCallback]] = [],
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
local_mode: bool = False,
save_trainable_weights_only: bool = False,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
# in case you forget to set evaluation_strategy
if val_set is not None and training_args.evaluation_strategy == "no":
training_args.evaluation_strategy = "epoch"
HomoSeq2SeqTrainerClient.__init__(
self,
ctx,
model,
training_args,
fed_args,
train_set,
val_set,
optimizer,
data_collator,
scheduler,
tokenizer,
callbacks,
compute_metrics,
local_mode,
save_trainable_weights_only,
preprocess_logits_for_metrics
)
def init_aggregator(self, ctx: Context, fed_args: FedArguments):
aggregate_type = "weighted_mean"
aggregator_name = "fedavg"
aggregator = fed_args.aggregator
return AggregatorClientWrapper(
ctx, aggregate_type, aggregator_name, aggregator, sample_num=len(self.train_dataset), args=self._args
)
def on_federation(
self,
ctx: Context,
aggregator: AggregatorClientWrapper,
fed_args: FedArguments,
args: Seq2SeqTrainingArguments,
model: Optional[nn.Module] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[_LRScheduler] = None,
dataloader: Optional[Tuple[DataLoader]] = None,
control: Optional[TrainerControl] = None,
state: Optional[TrainerState] = None,
**kwargs,
):
aggregator.model_aggregation(ctx, model)
================================================
FILE: python/fate_llm/algo/fedcollm/__init__.py
================================================
================================================
FILE: python/fate_llm/algo/fedcollm/fedcollm.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import logging
from fate_llm.algo.fedcollm.fedcollm_trainer import FedCoLLMTrainer
from typing import Dict, Optional, List, Callable, Union
from fate.arch import Context
from fate.ml.nn.trainer.trainer_base import FedArguments
from torch.utils.data import Dataset
from transformers.trainer_callback import TrainerCallback
from transformers import PreTrainedTokenizer
from transformers import Seq2SeqTrainer
from transformers.trainer_utils import EvalPrediction
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_utils import unwrap_model
from fate_llm.algo.fedmkt.utils.generate_logit_utils import generate_pub_data_logits
from fate.ml.aggregator import AggregatorClientWrapper, AggregatorServerWrapper
from fate_llm.algo.fedcollm.fedcollm_training_args import FedCoLLMTrainingArguments
from types import SimpleNamespace
logger = logging.getLogger(__name__)
class FedCoLLMBase(object):
@staticmethod
def update_model(model, updated_params):
for updated_p, p in zip(updated_params, [p for p in model.parameters() if p.requires_grad]):
p.data.copy_(torch.Tensor(updated_p))
class SLM(FedCoLLMBase):
def __init__(
self,
ctx: Context,
model: torch.nn.Module,
training_args: FedCoLLMTrainingArguments,
fed_args: FedArguments = None,
train_set=None,
val_set: Dataset = None,
optimizer: torch.optim.Optimizer = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
data_collator: Callable = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = [],
save_trainable_weights_only: bool = False,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
super(SLM, self).__init__()
self.ctx = ctx
self.training_args = training_args
self.fed_args = fed_args
self.model = model
self.tokenizer = tokenizer
self.model_init = model_init
self.callbacks = callbacks
self.compute_metrics = compute_metrics
self.save_trainable_weights_only = save_trainable_weights_only
self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
self.data_collator = data_collator
self.optimizer = optimizer
self.scheduler = scheduler
self.train_set = train_set
self.val_set = val_set
self.aggregator = self._init_aggregator(ctx, fed_args)
def train(self):
global_epochs = self.training_args.global_epochs
for i, iter_ctx in self.ctx.on_iterations.ctxs_range(global_epochs):
logger.info(f"begin {i}-th global kd process")
training_args = self._get_slm_training_args()
trainer = Seq2SeqTrainer(
model=self.model,
tokenizer=self.tokenizer,
data_collator=self.data_collator,
train_dataset=self.train_set,
args=training_args,
model_init=self.model_init if not i else None,
compute_metrics=self.compute_metrics,
callbacks=self.callbacks,
optimizers=(self.optimizer, self.scheduler),
preprocess_logits_for_metrics=self.preprocess_logits_for_metrics
)
logger.info(f"begin {i}-th private data training process")
trainer.train()
self.model = unwrap_model(trainer.model)
self.aggregator.model_aggregation(iter_ctx, self.model)
def _sync_slm_updated_params(self, iter_ctx):
updated_params = iter_ctx.arbiter.get("slm_updated_params")
self.update_model(self.model, updated_params)
def _get_slm_training_args(self):
return self.training_args.to_slm_seq_training_args()
def _init_aggregator(self, ctx: Context, fed_args: FedArguments):
aggregate_type = "weighted_mean"
aggregator_name = "fedavg"
aggregator = fed_args.aggregator
return AggregatorClientWrapper(
ctx, aggregate_type, aggregator_name, aggregator,
sample_num=len(self.train_set), args=self.training_args
)
class LLM(FedCoLLMBase):
def __init__(
self,
ctx: Context,
llm_model: torch.nn.Module,
slm_model: torch.nn.Module,
training_args: FedCoLLMTrainingArguments,
fed_args: FedArguments = None,
train_set=None,
val_set: Dataset = None,
llm_optimizer: torch.optim.Optimizer = None,
slm_optimizer: torch.optim.Optimizer = None,
llm_lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
slm_lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
data_collator: Callable = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
llm_model_init: Optional[Callable[[], PreTrainedModel]] = None,
slm_model_init: Optional[Callable[[], PreTrainedModel]] = None,
llm_compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
slm_compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
llm_callbacks: Optional[List[TrainerCallback]] = [],
slm_callbacks: Optional[List[TrainerCallback]] = [],
save_trainable_weights_only: bool = False,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
super(LLM, self).__init__()
self.ctx = ctx
self.llm_model = llm_model
self.slm_model = slm_model
self.training_args = training_args
self.fed_args = fed_args
self.train_set = train_set
self.val_set = val_set
self.llm_optimizer = llm_optimizer
self.slm_optimizer = slm_optimizer
self.llm_lr_scheduler = llm_lr_scheduler
self.slm_lr_scheduler = slm_lr_scheduler
self.data_collator = data_collator
self.tokenizer = tokenizer
self.llm_model_init = llm_model_init
self.slm_model_init = slm_model_init
self.llm_compute_metrics = llm_compute_metrics
self.slm_compute_metrics = slm_compute_metrics
self.llm_callbacks = llm_callbacks
self.slm_callbacks = slm_callbacks
self.save_trainable_weights_only = save_trainable_weights_only
self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
self.aggregator = self._init_aggregator(ctx)
def _init_aggregator(self, ctx: Context):
return AggregatorServerWrapper(ctx)
def _get_logits(self, model):
if self.training_args.device.type == "cuda":
model.cuda(self.training_args.device.type)
fn_kwargs = {"model": model,
"training_args": self.training_args,
"data_collator": self.data_collator}
return self.train_set.map(
generate_pub_data_logits,
batched=True,
batch_size=self.training_args.per_device_train_batch_size,
num_proc=None,
load_from_cache_file=True,
fn_kwargs=fn_kwargs
)
def on_epoch_begin(self, iter_ctx):
self.aggregator.model_aggregation(iter_ctx)
updated_slm_params = iter_ctx()
self.update_model(self.slm_model, updated_slm_params)
def _sync_slm_updated_params(self, iter_ctx):
updated_params = [p for p in self.slm_model.parameters() if p.requires_grad]
iter_ctx.guest.put("slm_updated_params", updated_params)
if any(p.role == 'host' for p in self.ctx.parties):
iter_ctx.hosts.put("slm_updated_params", updated_params)
def _train_slm(self, iter_ctx, llm_pub_logits, epoch_idx):
top_k_args = SimpleNamespace(
top_k_logits_keep=self.training_args.top_k_logits_keep,
top_k_strategy=self.training_args.top_k_strategy
)
self.train_set.set_return_with_idx()
trainer = FedCoLLMTrainer(
model=self.slm_model,
tokenizer=self.tokenizer,
data_collator=self.data_collator,
train_dataset=self.train_set,
args=self.training_args.to_slm_seq_training_args(),
model_init=self.slm_model_init if not epoch_idx else None,
compute_metrics=self.slm_compute_metrics,
callbacks=self.slm_callbacks,
optimizers=(self.slm_optimizer, self.slm_lr_scheduler),
preprocess_logits_for_metrics=self.preprocess_logits_for_metrics,
top_k_args=top_k_args,
distill_lambda=self.training_args.distill_lambda,
distill_temperature=self.training_args.distill_temperature,
max_length=max(len(d["input_ids"]) for d in self.train_set),
vocab_size=self.training_args.vocab_size,
dtype=next(self.slm_model.parameters()).dtype,
other_logits=llm_pub_logits
)
trainer.train()
self.slm_model = unwrap_model(trainer.model)
self.train_set.reset_return_with_idx()
self._sync_slm_updated_params(iter_ctx)
def _train_llm(self, slm_pub_logits, epoch_idx):
top_k_args = SimpleNamespace(
top_k_logits_keep=self.training_args.top_k_logits_keep,
top_k_strategy=self.training_args.top_k_strategy
)
self.train_set.set_return_with_idx()
trainer = FedCoLLMTrainer(
model=self.llm_model,
tokenizer=self.tokenizer,
data_collator=self.data_collator,
train_dataset=self.train_set,
args=self.training_args.to_llm_seq_training_args(),
model_init=self.llm_model_init if not epoch_idx else None,
compute_metrics=self.llm_compute_metrics,
callbacks=self.llm_callbacks,
optimizers=(self.llm_optimizer, self.llm_lr_scheduler),
preprocess_logits_for_metrics=self.preprocess_logits_for_metrics,
top_k_args=top_k_args,
distill_lambda=self.training_args.distill_lambda,
distill_temperature=self.training_args.distill_temperature,
max_length=max(len(d["input_ids"]) for d in self.train_set),
vocab_size=self.training_args.vocab_size,
dtype=next(self.slm_model.parameters()).dtype,
other_logits=slm_pub_logits
)
trainer.train()
self.llm_model = unwrap_model(trainer.model)
self.train_set.reset_return_with_idx()
def train(self):
global_epochs = self.training_args.global_epochs
for i, iter_ctx in self.ctx.on_iterations.ctxs_range(global_epochs):
logger.info(f"begin {i}-th global kd process")
self.on_epoch_begin(iter_ctx)
logger.info(f"get pub data logits for llm of global epoch={i}")
llm_pub_data_logits = self._get_logits(self.llm_model)
logger.info(f"train slm of global epoch={i}")
self._train_slm(iter_ctx, llm_pub_data_logits, i)
logger.info(f"get pub data logits for trained slm of global epoch={i}")
slm_pub_data_logits = self._get_logits(self.slm_model)
logger.info(f"train llm of global epoch={i}")
self._train_llm(slm_pub_data_logits, i)
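# ---------------------------------------------------------------------------
# Per-global-epoch flow of LLM.train() above: the server first aggregates the
# clients' SLM updates and refreshes its local SLM copy, then (1) computes LLM
# logits on the public dataset, (2) distills them into the server-held SLM and
# broadcasts the updated SLM parameters back to the clients, (3) computes the
# refreshed SLM's logits on the same public data, and (4) distills those back
# into the LLM.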
================================================
FILE: python/fate_llm/algo/fedcollm/fedcollm_trainer.py
================================================
#
# NOTE: The implementation of FedCoLLMTrainer is modified from FuseAI/FuseLLM
# Copyright FuseAI
#
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import torch
from torch.nn.functional import kl_div, log_softmax, softmax
from transformers import Seq2SeqTrainer
from fate_llm.algo.fedmkt.utils.generate_logit_utils import LogitsSelection
from fate_llm.algo.fedmkt.utils.vars_define import (
PER_STEP_LOGITS,
PER_STEP_INDICES,
)
from types import SimpleNamespace
logger = logging.getLogger(__name__)
def computing_kd_loss(src_logits, dst_logits, loss_mask):
src_logits = src_logits[loss_mask]
dst_logits = dst_logits[loss_mask]
return kl_div(
log_softmax(src_logits, dim=-1, dtype=torch.float32),
dst_logits,
log_target=False,
reduction="none").sum(dim=-1)
def recovery_logits(
top_k_logits,
top_k_indices,
batch_size,
max_length,
vocab_size,
dtype,
device,
pad_id,
distill_temperature
):
logits = torch.zeros(batch_size, max_length, vocab_size).to(dtype).to(device)
for i in range(batch_size):
base_seq_len = len(top_k_logits[i])
for j in range(max_length):
if j < base_seq_len:
base_logits = torch.tensor(top_k_logits[i][j], dtype=dtype)
base_prob = softmax(base_logits / distill_temperature, -1)
base_indices = torch.tensor(top_k_indices[i][j])
base_prob = base_prob.to(device)
base_indices = base_indices.to(device)
logits[i][j] = logits[i][j].scatter_(-1, base_indices, base_prob)
else: # padding position
logits[i][j][pad_id] = 1.0
return logits
class FedCoLLMTrainer(Seq2SeqTrainer):
distill_lambda: float = 1.0
distill_temperature: float = 1.0
other_logits = None
dtype: torch.dtype = torch.bfloat16
vocab_size: int = None
max_length: int = None
top_k_args: SimpleNamespace = None
def __init__(self, **kwargs):
distill_lambda = kwargs.pop("distill_lambda", 1.0)
distill_temperature = kwargs.pop("distill_temperature", 1.0)
other_logits = kwargs.pop("other_logits")
vocab_size = kwargs.pop("vocab_size")
max_length = kwargs.pop("max_length")
top_k_args = kwargs.pop("top_k_args")
super(FedCoLLMTrainer, self).__init__(**kwargs)
self.distill_lambda = distill_lambda
self.distill_temperature = distill_temperature
self.other_logits = other_logits
self.pad_id = self.tokenizer.pad_token_id
self.vocab_size = vocab_size
self.max_length = max_length
self.top_k_args = top_k_args
def compute_loss(self, model, inputs, return_outputs=False):
lm_outputs = model(**inputs['inputs'])
lm_loss = lm_outputs.loss
logits = lm_outputs.logits
other_logits = self.other_logits[inputs["indexes"]]
batch_size = logits.shape[0]
top_k_logits, top_k_indices = LogitsSelection.select_logits(logits, self.top_k_args)
dst_logits = recovery_logits(
other_logits[PER_STEP_LOGITS],
other_logits[PER_STEP_INDICES],
batch_size,
self.max_length,
self.vocab_size,
self.dtype,
logits.device,
self.pad_id,
self.distill_temperature
)
src_logits = recovery_logits(
top_k_logits,
top_k_indices,
batch_size,
self.max_length,
self.vocab_size,
self.dtype,
logits.device,
self.pad_id,
self.distill_temperature
)
loss_mask = (inputs["inputs"]["labels"] != -100)
kl_loss = computing_kd_loss(src_logits, dst_logits, loss_mask=loss_mask).sum()
loss = lm_loss + self.distill_lambda * kl_loss
return (loss, lm_outputs) if return_outputs else loss
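# ---------------------------------------------------------------------------
# Toy sketch (illustrative only) of recovery_logits: top-k logits/indices for a
# single 2-step sequence are expanded back into dense vocab-sized probability
# vectors (softened by distill_temperature); steps beyond the sequence length
# put all probability mass on pad_id. All numbers below are made up.
if __name__ == "__main__":
    top_k_logits = [[[2.0, 1.0], [0.5, 0.2]]]   # batch=1, seq_len=2, k=2
    top_k_indices = [[[3, 7], [1, 4]]]
    dense = recovery_logits(
        top_k_logits, top_k_indices,
        batch_size=1, max_length=3, vocab_size=10,
        dtype=torch.float32, device=torch.device("cpu"),
        pad_id=0, distill_temperature=1.0,
    )
    print(dense.shape)       # torch.Size([1, 3, 10])
    print(dense[0, 2, 0])    # tensor(1.) -- padding step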
================================================
FILE: python/fate_llm/algo/fedcollm/fedcollm_training_args.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dataclasses import dataclass, field
from ...trainer.seq2seq_trainer import Seq2SeqTrainingArguments
@dataclass
class FedCoLLMTrainingArguments(Seq2SeqTrainingArguments):
"""
top-k logits select params
"""
top_k_logits_keep: int = field(default=128)
top_k_strategy: str = field(default="highest")
vocab_size: int = field(default=None)
"""
distillation params
"""
distill_lambda: float = field(default=1.0)
distill_temperature: float = field(default=1.0)
server_public_data_local_epoch: int = field(default=1)
client_public_data_local_epoch: int = field(default=1)
client_priv_data_local_epoch: int = field(default=1)
global_epochs: int = field(default=1)
extra_args = ["top_k_logits_keep", "top_k_strategy", "vocab_size",
"distill_lambda", "distill_temperature", "server_public_data_local_epoch",
"client_public_data_local_epoch", "client_priv_data_local_epoch",
"global_epochs"]
def to_dict(self):
from dataclasses import fields
from enum import Enum
d = {field.name: getattr(self, field.name) for field in fields(self) if field.init}
for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
d[k] = [x.value for x in v]
if k.endswith("_token"):
d[k] = f"<{k.upper()}>"
return d
def _pop_extra(self):
args = self.to_dict()
for arg in self.extra_args:
args.pop(arg)
return args
def to_slm_seq_training_args(self):
args = self._pop_extra()
args["num_train_epochs"] = self.client_priv_data_local_epoch
return Seq2SeqTrainingArguments(**args)
def to_fedco_slm_training_args(self):
args = self._pop_extra()
args["num_train_epochs"] = self.client_pub_data_local_epoch
return Seq2SeqTrainingArguments(**args)
def to_fedco_llm_training_args(self):
args = self._pop_extra()
args["num_train_epochs"] = self.server_pub_data_local_epoch
return Seq2SeqTrainingArguments(**args)
================================================
FILE: python/fate_llm/algo/fedcot/__init__.py
================================================
================================================
FILE: python/fate_llm/algo/fedcot/encoder_decoder/__init__.py
================================================
================================================
FILE: python/fate_llm/algo/fedcot/encoder_decoder/init/__init__.py
================================================
================================================
FILE: python/fate_llm/algo/fedcot/encoder_decoder/init/default_init.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_llm.algo.inferdpt.init._init import InferInit
from fate_llm.inference.api import APICompletionInference
from fate_llm.algo.fedcot.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer
class FedCoTEDAPIClientInit(InferInit):
api_url = ''
api_model_name = ''
api_key = 'EMPTY'
def __init__(self, ctx):
super().__init__(ctx)
self.ctx = ctx
def get_inst(self):
inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)
client = SLMEncoderDecoderClient(self.ctx, inference)
return client
class FedCoTEDAPIServerInit(InferInit):
api_url = ''
api_model_name = ''
api_key = 'EMPTY'
def __init__(self, ctx):
super().__init__(ctx)
self.ctx = ctx
def get_inst(self):
inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)
return SLMEncoderDecoderServer(self.ctx, inference)
================================================
FILE: python/fate_llm/algo/fedcot/encoder_decoder/slm_encoder_decoder.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
from jinja2 import Template
from tqdm import tqdm
from fate.arch import Context
from typing import List, Dict, Union
from fate.ml.nn.dataset.base import Dataset
from fate_llm.algo.inferdpt.utils import InferDPTKit
from openai import OpenAI
import logging
from fate_llm.inference.inference_base import Inference
from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer
from fate_llm.dataset.hf_dataset import HuggingfaceDataset
logger = logging.getLogger(__name__)
class SLMEncoderDecoderClient(InferDPTClient):
def __init__(self, ctx: Context, local_inference_inst: Inference) -> None:
self.ctx = ctx
self.comm_idx = 0
self.local_inference_inst = local_inference_inst
self.local_inference_kwargs = {}
def encode(self, docs: List[Dict[str, str]], format_template: str = None, verbose=False, perturb_doc_key: str ='perturbed_doc') -> List[Dict[str, str]]:
template = Template(format_template)
copy_docs = copy.deepcopy(docs)
doc_to_infer = []
for doc in tqdm(copy_docs):
rendered_doc = template.render(**doc)
doc_to_infer.append(rendered_doc)
# perturb using local model inference
self.doc_to_infer = doc_to_infer
infer_result = self.local_inference_inst.inference(doc_to_infer, self.local_inference_kwargs)
for doc, pr in zip(copy_docs, infer_result):
doc[perturb_doc_key] = pr
self.doc_with_p = copy_docs
return copy_docs
def decode(self, p_docs: List[Dict[str, str]], instruction_template: str = None, decode_template: str = None, verbose=False,
perturbed_response_key: str = 'perturbed_response', result_key: str = 'result',
remote_inference_kwargs: dict = {}, local_inference_kwargs: dict = {}):
return super().decode(p_docs, instruction_template, decode_template, verbose, perturbed_response_key, result_key, remote_inference_kwargs, local_inference_kwargs)
def inference(self, docs: Union[List[Dict[str, str]], HuggingfaceDataset],
encode_template: str,
instruction_template: str,
decode_template: str,
verbose: bool = False,
remote_inference_kwargs: dict = {},
local_inference_kwargs: dict = {},
perturb_doc_key: str = 'perturbed_doc',
perturbed_response_key: str = 'perturbed_response',
result_key: str = 'result',
) -> List[Dict[str, str]]:
self.local_inference_kwargs = local_inference_kwargs
return super().inference(docs, encode_template, instruction_template, decode_template, verbose, remote_inference_kwargs, \
local_inference_kwargs, perturb_doc_key, perturbed_response_key, result_key)
class SLMEncoderDecoderServer(InferDPTServer):
pass
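# ---------------------------------------------------------------------------
# Minimal sketch of the encode step (illustrative only). EchoInference is a
# stand-in written just for this example; in real use a fate_llm Inference
# backend (API or local model) produces the perturbed documents, and ctx is a
# real federation Context.
if __name__ == "__main__":
    class EchoInference:
        def inference(self, docs, inference_kwargs):
            return [d.upper() for d in docs]  # pretend "perturbation"

    client = SLMEncoderDecoderClient(ctx=None, local_inference_inst=EchoInference())
    docs = [{"question": "What is federated learning?"}]
    encoded = client.encode(docs, format_template="Rewrite the question: {{ question }}")
    print(encoded[0]["perturbed_doc"])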
================================================
FILE: python/fate_llm/algo/fedcot/fedcot_trainer.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import pickle
import time
from torch import nn
from typing import List, Optional, Callable, Literal, Union
from fate.arch import Context
from torch.utils.data import DataLoader, Dataset
from transformers.trainer_callback import TrainerCallback
from transformers import PreTrainedTokenizer
import logging
import torch
import torch.distributed as dist
from fate_llm.dataset.fedcot_dataset import PrefixDataset
from transformers.modeling_utils import unwrap_model
from transformers import PreTrainedTokenizer, PreTrainedModel
from typing import Dict, Any
from transformers import Seq2SeqTrainingArguments
from transformers.trainer_utils import EvalPrediction
from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainer, Seq2SeqTrainingArguments
from fate_llm.inference.inference_base import Inference
from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer
from fate_llm.algo.fedcot.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer
logger = logging.getLogger(__name__)
_MODE = ['train_only', 'infer_only', 'infer_and_train']
# share obj between ranks in an easy way
def save_to(obj, filepath, filename='tmp.pkl'):
if not os.path.exists(filepath):
os.makedirs(filepath, exist_ok=True)
path = os.path.join(filepath, filename)
with open(path, 'wb') as f:
pickle.dump(obj, f)
dist.barrier()
os.remove(path)
def load(filepath, filename='tmp.pkl'):
path = os.path.join(filepath, filename)
while not os.path.exists(path):
time.sleep(0.1)
while True:
try:
with open(path, 'rb') as f:
d = pickle.load(f)
break
except (EOFError, pickle.UnpicklingError):
time.sleep(0.1)
dist.barrier()
return d
class DSSTrainerClient(Seq2SeqTrainer):
def __init__(self,
model: nn.Module,
training_args: Seq2SeqTrainingArguments,
train_set: Dataset,
val_set: Dataset = None,
alpha: float = 0.5,
optimizer: torch.optim.Optimizer = None,
data_collator: Callable = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
callbacks: Optional[List[TrainerCallback]] = [],
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
) -> None:
self.alpha = alpha
Seq2SeqTrainer.__init__(
self,
model=model,
args=training_args,
train_dataset=train_set,
eval_dataset=val_set,
data_collator=data_collator,
optimizers=(optimizer, scheduler),
tokenizer=tokenizer,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
compute_metrics=compute_metrics,
callbacks=callbacks,
)
def compute_loss(self, model, inputs, return_outputs=False):
label_outputs = model(**inputs['predict'])
cot_outputs = model(**inputs['rationale'])
loss = self.alpha * cot_outputs.loss + (1. - self.alpha) * label_outputs.loss
return (loss, {'rationale_loss': cot_outputs, 'predict_loss': label_outputs}) if return_outputs else loss
class FedCoTTrainerClient(DSSTrainerClient):
def __init__(self,
ctx: Context,
training_args: Seq2SeqTrainingArguments,
train_set: PrefixDataset,
val_set: Dataset = None,
model: nn.Module = None,
optimizer: torch.optim.Optimizer = None,
data_collator: Callable = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
callbacks: Optional[List[TrainerCallback]] = [],
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
alpha: float = 0.5,
mode: Literal['train_only', 'infer_only', 'infer_and_train'] = 'infer_and_train',
infer_client: Union[SLMEncoderDecoderClient, InferDPTClient] = None,
encode_template: str = None,
instruction_template: str = None,
decode_template: str = None,
result_key: str = 'infer_result',
verbose: bool = False,
remote_inference_kwargs: dict = {},
local_inference_kwargs: dict = {},
tmp_data_share_path: str = None
) -> None:
self.mode = mode
self.infer_client = infer_client
self.infer_result = None
self.infer_predict_kwargs = {
'encode_template': encode_template,
'instruction_template': instruction_template,
'decode_template': decode_template,
'result_key': result_key,
'verbose': verbose,
'remote_inference_kwargs': remote_inference_kwargs,
'local_inference_kwargs': local_inference_kwargs
}
self.infer_result = None
self.tmp_data_share_path = tmp_data_share_path
assert mode in _MODE, "mode should be one of {}".format(_MODE)
if training_args.local_rank == 0:
if mode == 'infer_only' or mode == 'infer_and_train':
if self.infer_client is None:
raise ValueError('You must provide an inference instance for remote inference')
if mode != 'infer_only':
training_args.remove_unused_columns = False  # this parameter is necessary
DSSTrainerClient.__init__(
self,
model=model,
training_args=training_args,
train_set=train_set,
val_set=val_set,
data_collator=data_collator,
optimizer=optimizer,
scheduler=scheduler,
tokenizer=tokenizer,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
compute_metrics=compute_metrics,
callbacks=callbacks,
alpha=alpha
)
else:
# skip trainer initialization because training is not needed
self.args = training_args
self.train_dataset = train_set
def infer(self) -> List[str]:
if self.args.local_rank == 0: # other rank will skip federation step
assert isinstance(self.train_dataset, PrefixDataset), "train_set should be an instance of PrefixDataset"
dict_dataset = self.train_dataset.get_raw_dataset()
infer_result = self.infer_client.inference(dict_dataset, **self.infer_predict_kwargs)
self.infer_result = infer_result
rationale_list = [i[self.infer_predict_kwargs['result_key']] for i in self.infer_result]
self.train_dataset.load_rationale(rationale_list, key=self.infer_predict_kwargs['result_key'])
logger.info('infer done')
if self.mode == 'infer_and_train':
if self.args.world_size > 1: # sync dataset with other ranks
tmp_path = self.tmp_data_share_path if self.tmp_data_share_path is not None else self.args.output_dir
logger.info('scattering obj, save to temp path {}'.format(tmp_path))
save_to(rationale_list, tmp_path)
if self.args.local_rank > 0:
if self.mode == 'infer_and_train':
# wait until infer is done
tmp_path = self.tmp_data_share_path if self.tmp_data_share_path is not None else self.args.output_dir
logger.info('waiting for obj, load from temp path {}'.format(tmp_path))
rationale_list = load(tmp_path)
self.train_dataset.load_rationale(rationale_list)
logger.info('Rationale loaded')
def train(self):
if self.mode == 'train_only':
logger.info("Train only mode")
super().train()
elif self.mode == 'infer_only':
logger.info("infer only mode, skip training")
self.infer()
elif self.mode == 'infer_and_train':
logger.info("infer and train mode")
self.infer()
super().train()
def get_infer_result(self):
return self.infer_result
class FedCoTTraineServer(object):
def __init__(self, ctx: Context, infer_server: Union[SLMEncoderDecoderServer, InferDPTServer]):
super().__init__()
self.ctx = ctx
self.infer_server = infer_server
def train(self):
logger.info('Server side start inference')
self.infer_server.inference()
logger.info('Server inference done')
if __name__ == '__main__':
pass
================================================
FILE: python/fate_llm/algo/fedcot/slm_encoder_decoder_trainer.py
================================================
from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainer
from transformers import DataCollatorForSeq2Seq
from transformers import AutoTokenizer
import pandas as pd
class EDPrefixDataCollator(DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None):
features_df = pd.DataFrame(features)
a = super().__call__(list(features_df['encoder']), return_tensors)
b = super().__call__(list(features_df['decoder']), return_tensors)
return {
'encoder': a,
'decoder': b
}
class EncoderDecoderPrefixTrainer(Seq2SeqTrainer):
def __init__(self, alpha=0.5, *args, **kwargs):
super().__init__(*args, **kwargs)
self.alpha = alpha
def compute_loss(self, model, inputs, return_outputs=False):
out_a = model(**inputs['encoder'])
out_b = model(**inputs['decoder'])
loss = self.alpha * out_a.loss + (1. - self.alpha) * out_b.loss
return (loss, {'out_a': out_a, 'out_b': out_b}) if return_outputs else loss
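# ---------------------------------------------------------------------------
# Shape sketch (illustrative only; "gpt2" is a placeholder tokenizer): every
# sample carries two tokenized views, and the collator pads them into separate
# 'encoder' and 'decoder' sub-batches that compute_loss mixes with weight alpha.
if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token
    collator = EDPrefixDataCollator(tokenizer=tok, padding="longest")
    features = [
        {"encoder": {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
         "decoder": {"input_ids": [4, 5], "labels": [4, 5]}},
        {"encoder": {"input_ids": [6], "labels": [6]},
         "decoder": {"input_ids": [7, 8, 9], "labels": [7, 8, 9]}},
    ]
    batch = collator(features)
    print(batch["encoder"]["input_ids"].shape, batch["decoder"]["input_ids"].shape)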
================================================
FILE: python/fate_llm/algo/fedkseed/__init__.py
================================================
================================================
FILE: python/fate_llm/algo/fedkseed/args.py
================================================
from dataclasses import dataclass, field
@dataclass
class KSeedTrainingArguments:
"""
KSeedTrainingArguments is the subset of the arguments we use in our example scripts; they control the
KSeed zeroth-order optimization behaviour.
Parameters:
zo_optim: optional, default is True
Whether to use KSeedZerothOrderOptimizer.
k: optional, default is 4096
The number of seed candidates to use.
eps: optional, default is 0.0005
Epsilon value for KSeedZerothOrderOptimizer.
grad_clip: optional, default is -100.0
Gradient clip value for KSeedZerothOrderOptimizer.
"""
zo_optim: bool = field(
default=True,
metadata={"help": "Whether to use KSeedZerothOrderOptimizer. This suppress `optim` argument when True."},
)
k: int = field(
default=4096,
metadata={"help": "The number of seed candidates to use. This suppress `seed_candidates` argument when > 1."},
)
eps: float = field(default=0.0005, metadata={"help": "Epsilon value for KSeedZerothOrderOptimizer."})
grad_clip: float = field(default=-100.0, metadata={"help": "Gradient clip value for KSeedZerothOrderOptimizer."})
================================================
FILE: python/fate_llm/algo/fedkseed/fedkseed.py
================================================
import copy
import logging
from dataclasses import dataclass, field
from typing import List, Mapping
import torch
from fate.arch.context import Context
from fate_llm.algo.fedkseed.pytorch_utils import get_optimizer_parameters_grouped_with_decay
from fate_llm.algo.fedkseed.trainer import KSeedZOExtendedTrainer
from fate_llm.algo.fedkseed.zo_utils import probability_from_amps, directional_derivative_step, get_even_seed_probabilities
from fate_llm.algo.fedkseed.args import KSeedTrainingArguments
logger = logging.getLogger(__name__)
class Trainer:
def __init__(
self, ctx: Context, seed_candidates: torch.LongTensor, args, fedkseed_args,
):
self.ctx = ctx
self.args = args
self.fedkseed_args = fedkseed_args
self.seed_candidates = seed_candidates
self.k = len(seed_candidates)
self.model = None
@staticmethod
def get_clients(ctx: Context):
clients = [ctx.guest]
try:
clients.extend(ctx.hosts)
except Exception:
pass
return clients
def load_model(self):
raise NotImplementedError
def train(self):
direction_derivative_history = {seed.item(): [self.fedkseed_args.grad_initial] for seed in self.seed_candidates}
direction_derivative_sum = None
seed_probabilities = None
for aggregation_iter, sub_ctx in self.ctx.ctxs_range(self.fedkseed_args.num_aggregations):
# step1: re-calculate sample probabilities for each seed
if seed_probabilities is None:
seed_probabilities = get_even_seed_probabilities(self.k)
else:
seed_probabilities = probability_from_amps(
[direction_derivative_history[seed.item()] for seed in self.seed_candidates],
self.fedkseed_args.bias_loss_clip,
)
# step2(rpc): remote call to the clients to get the directional derivative history
# proposal
for client in self.get_clients(sub_ctx):
client.put(
"train_once",
(
False,
{
"seed_candidates": self.seed_candidates,
"seed_probabilities": seed_probabilities,
"direction_derivative_sum": direction_derivative_sum,
},
),
)
if direction_derivative_sum is None:
direction_derivative_sum = {seed.item(): 0.0 for seed in self.seed_candidates}
# wait for reply and update the directional derivative history
for client in self.get_clients(sub_ctx):
client_directional_derivative_history = client.get("direction_derivative_history")
for seed, history in client_directional_derivative_history.items():
# torch.LongTensor -> int
seed = int(seed)
if seed not in direction_derivative_history:
direction_derivative_history[seed] = []
direction_derivative_history[seed].extend(history)
direction_derivative_sum[seed] += sum(history)
# step3: evaluate to get stopping condition if necessary
if self.should_stop():
break
def should_stop(self):
return False
def evaluate(self):
pass
class ClientTrainer:
def __init__(self, ctx: Context, model, fedkseed_args, training_args, train_dataset, eval_dataset, data_collator,
tokenizer):
self.ctx = ctx
self.fedkseed_args = fedkseed_args
self.training_args = training_args
self.data_collator = data_collator
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.tokenizer = tokenizer
self.weight_decay = training_args.weight_decay
self.model_0 = model
def train(self):
for i, sub_ctx in self.ctx.ctxs_range(self.fedkseed_args.num_aggregations):
# step1: wait for the server to send the seed candidates and probabilities or exit signal
logger.info(f"training loop started: {i}")
should_exit, kwargs = sub_ctx.arbiter.get("train_once")
seed_candidates = kwargs["seed_candidates"]
seed_probabilities = kwargs["seed_probabilities"]
direction_derivative_sum = kwargs["direction_derivative_sum"]
logger.info(
f"should_exit: {should_exit}, seed_candidates: {seed_candidates}, seed_probabilities: {seed_probabilities}"
)
if should_exit:
break
# step2: start the training loop
direction_derivative_history = self.train_once(
seed_candidates, seed_probabilities, direction_derivative_sum
)
# step3: send the directional derivative history to the server
sub_ctx.arbiter.put("direction_derivative_history", direction_derivative_history)
def train_once(self, seed_candidates, seed_probabilities, direction_derivative_sum) -> Mapping[int, List[float]]:
# build model
model = copy.deepcopy(self.model_0)
model.to(self.training_args.device)
if direction_derivative_sum is not None:
param_groups = get_optimizer_parameters_grouped_with_decay(model, self.weight_decay)
for seed, grad in direction_derivative_sum.items():
if grad != 0.0:
directional_derivative_step(
param_groups, seed, grad, lr=self.training_args.learning_rate,
weight_decay=self.training_args.weight_decay
)
# train
trainer = KSeedZOExtendedTrainer(
model=model,
training_args=self.training_args,
kseed_args=self.fedkseed_args,
tokenizer=self.tokenizer,
data_collator=self.data_collator,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
)
trainer.configure_seed_candidates(seed_candidates, seed_probabilities)
trainer.train()
if self.eval_dataset is not None:
logger.info(f"evaluate: {trainer.evaluate()}")
# get directional derivative history
return trainer.get_directional_derivative_history()
@dataclass
class FedKSeedTrainingArguments(KSeedTrainingArguments):
num_aggregations: int = field(default=10, metadata={"help": "The number of aggregations to perform."})
bias_loss_clip: float = field(default=1000.0, metadata={"help": "The bias loss clip value."})
grad_initial: float = field(
default=0.0, metadata={"help": "The initial value for the directional derivative history."}
)
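# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the training flow) of how the server derives
# per-seed sampling probabilities: the first round starts uniform, later rounds
# reweight seeds from the clients' directional-derivative (amplitude) histories.
# The four toy histories below are made-up values.
if __name__ == "__main__":
    print(get_even_seed_probabilities(4))                       # uniform start
    print(probability_from_amps([[0.1], [2.0], [0.5], [0.0]],   # one history per seed
                                clip=1000.0))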
================================================
FILE: python/fate_llm/algo/fedkseed/optimizer.py
================================================
"""
The implementations of ZerothOrderOptimizer and KSeedZerothOrderOptimizer is
adapted from https://github.com/princeton-nlp/MeZO (MIT License) and
https://github.com/alibaba/FederatedScope/tree/FedKSeed (Apache License 2.0)
Copyright (c) 2021 Princeton Natural Language Processing
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---
#
# Copyright 2023 The FederatedScope Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import math
from typing import Mapping, Optional, Callable, Tuple, List
import torch
from torch.optim import Optimizer
from fate_llm.algo.fedkseed.pytorch_utils import get_optimizer_parameters_grouped_with_decay
from fate_llm.algo.fedkseed.zo_utils import directional_derivative_step
class RandomWalkOptimizer(Optimizer):
"""
Random Walk Optimizer
This optimizer performs a `random` walk update for the parameters of the model.
"""
def __init__(self, params, lr, weight_decay, grad_clip, defaults=None):
self.lr = lr
self.weight_decay = weight_decay
self.grad_clip = grad_clip
if defaults is None:
defaults = dict(lr=lr, weight_decay=weight_decay)
else:
defaults = dict(defaults)
defaults.update(lr=lr, weight_decay=weight_decay)
super(RandomWalkOptimizer, self).__init__(params, defaults)
@classmethod
def from_model(cls, model, lr, weight_decay, grad_clip, **kwargs):
optimizer_grouped_parameters = get_optimizer_parameters_grouped_with_decay(model, weight_decay)
kwargs["lr"] = lr
kwargs["weight_decay"] = weight_decay
kwargs["grad_clip"] = grad_clip
return cls(optimizer_grouped_parameters, **kwargs)
def directional_derivative_step(
self, directional_derivative_seed: int, directional_derivative_value: torch.FloatTensor
) -> torch.FloatTensor:
"""
perform a step update for the parameters of the model
along the random direction z with the learning rate lr and the step size grad_projected_value
"""
if self.grad_clip > 0.0:
if abs(directional_derivative_value) > self.grad_clip:
return torch.FloatTensor([torch.nan])
directional_derivative_step(self.param_groups, directional_derivative_seed, directional_derivative_value)
return directional_derivative_value
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
raise NotImplementedError(
"use random_step instead of step for RandomWalkOptimizer \
since we need pass the `seed` and `grad_projected_value`"
)
class ZerothOrderOptimizer(RandomWalkOptimizer):
def __init__(self, params, lr, eps, weight_decay, grad_clip):
self.eps = eps
defaults = dict(eps=eps)
super(ZerothOrderOptimizer, self).__init__(params, lr, weight_decay, grad_clip, defaults)
def zeroth_order_step(
self, directional_derivative_seed: int, closure: Callable[[], torch.FloatTensor]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""
perform a step update for the parameters of the model along the
random direction z generated by the `directional_derivative_seed`
with the learning rate lr
and the step size of calculated namely `directional_derivative_value`
Input:
- directional_derivative_seed: the seed for generating the random direction z
- closure (callable, optional): A closure that reevaluates the model and returns the loss.
Output:
- directional_derivative_value: the gradient projected value
- loss_right: the loss of the model with the perturbed parameters x + eps * z
- loss_left: the loss of the model with the perturbed parameters x - eps * z
"""
# x -> x + eps * z
self.random_perturb_parameters(directional_derivative_seed, scaling_factor=1.0)
loss_right = closure()
# x + eps * z -> x - eps * z
self.random_perturb_parameters(directional_derivative_seed, scaling_factor=-2.0)
loss_left = closure()
# x - eps * z -> x
self.random_perturb_parameters(directional_derivative_seed, scaling_factor=1.0)
if torch.isnan(loss_right):
return loss_right, loss_right, loss_left
if torch.isnan(loss_left):
return loss_left, loss_right, loss_left
# ∇f(x) · z = D_z f(x) ≈ (f(x + eps * z) - f(x - eps * z)) / (2 * eps)
directional_derivative_value = (loss_right - loss_left) / (2 * self.eps)
# perform update for the random direction z * grad_projected_value
directional_derivative_value = self.directional_derivative_step(
directional_derivative_seed, directional_derivative_value
)
return directional_derivative_value, loss_right, loss_left
def random_perturb_parameters(self, directional_derivative_seed: int, scaling_factor: float):
"""
Perturb the parameters with random direction z generated by the directional_derivative_seed
for each parameter theta, the update is theta = theta + scaling_factor * z * eps
Input:
- seed: the seed for generating the random direction z
- scaling_factor: the scaling factor for the random direction z
Output:
- None
"""
torch.manual_seed(directional_derivative_seed)
for param_group in self.param_groups:
eps = param_group["eps"]
for param in param_group["params"]:
if param.requires_grad:
z = torch.normal(
mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype
)
param.data = param.data + scaling_factor * eps * z
class KSeedZerothOrderOptimizer(ZerothOrderOptimizer):
def __init__(
self,
params,
seed_candidates: torch.LongTensor,
seed_probabilities: torch.FloatTensor,
lr,
eps,
weight_decay,
grad_clip,
):
self.seed_candidate = seed_candidates
self.seed_probabilities = seed_probabilities
self.directional_derivative_history: Mapping[int, List[float]] = {seed.item(): [] for seed in seed_candidates}
self.sample_random_generator = torch.Generator()
super(KSeedZerothOrderOptimizer, self).__init__(params, lr, eps, weight_decay, grad_clip)
def sample(self) -> int:
sampled = torch.multinomial(
input=self.seed_probabilities,
num_samples=1,
generator=self.sample_random_generator,
)[0].item()
return self.seed_candidate[sampled].item()
def step(self, closure: Callable[[], torch.FloatTensor] = None) -> torch.FloatTensor:
if closure is None:
# closure is required for the zeroth_order_step, but we
# don't raise an error here to maintain compatibility with
# the third-party tools that use the `step` method without
# providing the closure in training loop, e.g., HuggingFace Transformers
return torch.FloatTensor([torch.nan])
return self.kseed_zeroth_order_step(closure)
def kseed_zeroth_order_step(self, closure: Callable[[], torch.FloatTensor]) -> torch.FloatTensor:
"""
Performs a single optimization step.
1. Sample a random seed for sampling z
2. Perturb the parameters with the random direction(-z * eps, z * eps) for evaluating the model on the batch, and compute the loss(loss1, loss2)
3. Compute the directional derivative value: grad_projected_value = (loss_right - loss_left) / (2 * eps)
4. Perform the directional derivative step update for the parameters of the model along the random direction z with the learning rate lr and the step size grad_projected_value
Input:
- closure (callable, optional): A closure that reevaluates the model and returns the loss.
"""
if closure is None:
raise ValueError("closure must not be None")
# sample the random seed for sampling z for perturbing parameters.
seed = self.sample()
directional_derivative_value, loss_right, loss_left = self.zeroth_order_step(seed, closure)
if math.isnan(directional_derivative_value):
return directional_derivative_value
# record the directional_derivative_value for the seed
self.directional_derivative_history[seed].append(directional_derivative_value.item())
return loss_right # TODO: return loss_left or loss_right or average of both?
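# ---------------------------------------------------------------------------
# Toy sketch (illustrative only): one zeroth-order step on a tiny linear model.
# Two forward passes at x + eps*z and x - eps*z give the finite-difference
# estimate D_z f(x) ~= (f(x + eps*z) - f(x - eps*z)) / (2 * eps), which is then
# applied along the seeded direction z via the module's zo_utils helper; no
# backward pass is needed. The model, data and hyper-parameters are made up.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)
    x, y = torch.randn(8, 4), torch.randn(8, 1)

    def closure() -> torch.FloatTensor:
        with torch.no_grad():
            return ((model(x) - y) ** 2).mean()

    optimizer = KSeedZerothOrderOptimizer(
        model.parameters(),
        seed_candidates=torch.LongTensor([1, 2, 3, 4]),
        seed_probabilities=torch.FloatTensor([0.25, 0.25, 0.25, 0.25]),
        lr=1e-3, eps=5e-4, weight_decay=0.0, grad_clip=-100.0,
    )
    print(optimizer.kseed_zeroth_order_step(closure))  # loss at x + eps*z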
================================================
FILE: python/fate_llm/algo/fedkseed/pytorch_utils.py
================================================
from typing import List
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
def get_decay_parameter_names(model) -> List[str]:
"""
Get all parameter names that weight decay will be applied to
Note that some models implement their own layernorm instead of calling nn.LayerNorm; weight decay could still
apply to those modules since this function only filters out instances of nn.LayerNorm
NOTE: This function is copied from transformers
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
return decay_parameters
def get_optimizer_parameters_grouped_with_decay(model, weight_decay: float) -> List[dict]:
"""
Get the parameters grouped by whether they should have weight decay applied
"""
decay_parameters = get_decay_parameter_names(model)
params_no_decay = []
params_decay = []
for n, p in model.named_parameters():
if p.requires_grad:
if n in decay_parameters:
params_decay.append(p)
else:
params_no_decay.append(p)
grouped_parameters_with_decay = [
{"params": params_no_decay, "weight_decay": 0.0},
{"params": params_decay, "weight_decay": weight_decay},
]
return grouped_parameters_with_decay
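# ---------------------------------------------------------------------------
# Quick illustrative check (not part of the library): biases and LayerNorm
# parameters land in the zero-decay group, remaining weights in the decay group.
if __name__ == "__main__":
    import torch
    toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
    groups = get_optimizer_parameters_grouped_with_decay(toy, weight_decay=0.01)
    print([(g["weight_decay"], len(g["params"])) for g in groups])
    # expected: [(0.0, 3), (0.01, 1)] -- bias + LayerNorm params vs. the Linear weight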
================================================
FILE: python/fate_llm/algo/fedkseed/trainer.py
================================================
import logging
from typing import Dict, Union, Any, Tuple
from typing import Optional, List, Callable
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import PreTrainedModel, PreTrainedTokenizerBase, EvalPrediction, DataCollator
from transformers import Trainer, TrainingArguments
from transformers.optimization import get_scheduler, SchedulerType
from transformers.trainer_callback import TrainerCallback
from fate_llm.algo.fedkseed.args import KSeedTrainingArguments
from fate_llm.algo.fedkseed.optimizer import KSeedZerothOrderOptimizer
from fate_llm.algo.fedkseed.pytorch_utils import get_optimizer_parameters_grouped_with_decay
logger = logging.getLogger(__name__)
class KSeedZOExtendedTrainer(Trainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
training_args: TrainingArguments = None,
kseed_args: "KSeedTrainingArguments" = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
super().__init__(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
self.kseed_args = kseed_args
self._kseed_optimizer = None
self._seed_candidates = None
self._seed_probabilities = None
def configure_seed_candidates(self, seed_candidates: torch.LongTensor, seed_probabilities: torch.FloatTensor):
self._seed_candidates = seed_candidates
self._seed_probabilities = seed_probabilities
def get_directional_derivative_history(self):
"""
hook to get the directional derivative history
"""
if KSeedZOExtendedTrainer.k_seed_zo_mode(self.kseed_args):
if self._kseed_optimizer is None:
raise ValueError("KSeedZerothOrderOptimizer is not configured")
return self._kseed_optimizer.directional_derivative_history
else:
raise ValueError("KSeedZerothOrderOptimizer is not configured")
@staticmethod
def k_seed_zo_mode(args):
return hasattr(args, "zo_optim") and args.zo_optim
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
"""
hook to do the step with KSeedZerothOrderOptimizer
"""
if KSeedZOExtendedTrainer.k_seed_zo_mode(self.kseed_args):
if self._kseed_optimizer is None:
raise ValueError("KSeedZerothOrderOptimizer is not configured")
model.eval()
inputs = self._prepare_inputs(inputs)
with self.compute_loss_context_manager():
# zeroth-order optimization needs two forward passes per optimization step,
# so we wrap the forward pass in a closure
def closure() -> torch.FloatTensor:
with torch.no_grad():
return self.compute_loss(model, inputs, return_outputs=False).detach()
# we don't call the step() method of KSeedZerothOrderOptimizer here because
# `Trainer` wraps any optimizer that is a subclass of `torch.optim.Optimizer`,
# and the wrapped step() returns nothing
with torch.no_grad():
loss = self._kseed_optimizer.kseed_zeroth_order_step(closure=closure)
return loss.detach()
else:
return super().training_step(model, inputs)
def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
hook to add KSeedZerothOrderOptimizer
"""
if KSeedZOExtendedTrainer.k_seed_zo_mode(self.kseed_args):
if self._seed_candidates is None or self._seed_probabilities is None:
raise ValueError("Seed candidates and probabilities are not configured.")
optimizer_grouped_parameters = get_optimizer_parameters_grouped_with_decay(
self.model, self.args.weight_decay
)
self.optimizer = KSeedZerothOrderOptimizer(
optimizer_grouped_parameters,
seed_candidates=self._seed_candidates,
seed_probabilities=self._seed_probabilities,
lr=self.args.learning_rate,
eps=self.kseed_args.eps,
weight_decay=self.args.weight_decay,
grad_clip=self.kseed_args.grad_clip,
)
# we need to keep the reference to the original optimizer to use it in training_step
self._kseed_optimizer = self.optimizer
# if we used a learning-rate scheduler other than the constant one below, we would need to preserve every per-step update instead of only the aggregated one
self.lr_scheduler = get_scheduler(
name=SchedulerType.CONSTANT,
optimizer=self.optimizer,
num_warmup_steps=self.args.warmup_steps,
num_training_steps=num_training_steps,
)
else:
super().create_optimizer_and_scheduler(num_training_steps)
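# Usage note (a sketch of the intended wiring, not executed here): the caller builds a seed
# pool, e.g. with zo_utils.build_seed_candidates / get_even_seed_probabilities, constructs
# KSeedZOExtendedTrainer with both TrainingArguments and KSeedTrainingArguments (zo_optim
# enabled), calls configure_seed_candidates(seeds, probs) before train(), and afterwards
# reads get_directional_derivative_history() to obtain the per-seed scalar updates used in
# the federated aggregation step.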
================================================
FILE: python/fate_llm/algo/fedkseed/zo_utils.py
================================================
from typing import List
import torch
def probability_from_amps(amps: List[List[float]], clip):
"""
Get the probability distribution from the amplitude history
formula: amp_i = clamp(amp_i, -clip, clip).abs().mean()
amp_i = (amp_i - min(amp)) / (max(amp) - min(amp))
prob_i = softmax(amp)_i
:param amps: list of amplitude history
:param clip: the clipping value
:return: a 1-D tensor of seed probabilities (softmax over the min-max normalised mean amplitudes)
"""
amps = [torch.Tensor(amp) for amp in amps]
amp = torch.stack([amp.clamp_(-clip, clip).abs_().mean() for amp in amps])
return (amp - amp.min()).div_(amp.max() - amp.min() + 1e-10).softmax(0)
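# Illustrative example: probability_from_amps([[0.5, -2.0], [0.1, 0.2]], clip=1.0) clamps the
# amplitudes to [-1, 1], takes mean absolute values (0.75 and 0.15), min-max normalises them
# to (1.0, 0.0) and returns softmax([1.0, 0.0]) ≈ tensor([0.7311, 0.2689]).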
def directional_derivative_step(
param_groups: List[dict],
directional_derivative_seed: int,
directional_derivative_value: torch.FloatTensor,
lr: float = None,
weight_decay: float = None,
) -> torch.FloatTensor:
"""
perform a step update for the parameters of the model
along the random direction z (re-generated from directional_derivative_seed) with learning rate lr and step size directional_derivative_value
Input:
- param_groups (List[dict]): list of parameter groups
- directional_derivative_seed (int): seed for the random direction
- directional_derivative_value (torch.FloatTensor): the step size
- lr (float, optional): learning rate
- weight_decay (float, optional): weight decay
"""
torch.manual_seed(directional_derivative_seed)
for param_group in param_groups:
weight_decay = param_group["weight_decay"] if weight_decay is None else weight_decay
lr = param_group["lr"] if lr is None else lr
for param in param_group["params"]:
z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
if weight_decay is not None:
param.data = param.data - lr * (directional_derivative_value * z + weight_decay * param.data)
else:
param.data = param.data - lr * (directional_derivative_value * z)
return directional_derivative_value
def build_seed_candidates(k, low=0, high=2**32):
"""
Build seed candidates for the random walk optimizer
"""
return torch.randint(low, high, size=(k,), dtype=torch.long)
def get_even_seed_probabilities(k):
"""
Get the even seed probabilities, i.e., 1/k for each seed
"""
return torch.ones(k) / k
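# A minimal, self-contained sketch (illustrative only): draw a seed pool with uniform
# probabilities and apply a single directional update to a toy parameter group.
if __name__ == "__main__":
    seeds = build_seed_candidates(k=4)
    probs = get_even_seed_probabilities(k=4)
    toy_param = torch.nn.Parameter(torch.zeros(3))
    directional_derivative_step(
        [{"params": [toy_param]}],
        directional_derivative_seed=int(seeds[0]),
        directional_derivative_value=torch.tensor(0.5),
        lr=0.1,
        weight_decay=0.0,
    )
    print(seeds.tolist(), probs.tolist(), toy_param.data)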
================================================
FILE: python/fate_llm/algo/fedmkt/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_llm.algo.fedmkt.fedmkt import (
FedMKTTrainingArguments,
FedMKTSLM,
FedMKTLLM
)
__all__ = [
"FedMKTSLM",
"FedMKTLLM",
"FedMKTTrainingArguments"
]
================================================
FILE: python/fate_llm/algo/fedmkt/fedmkt.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import logging
import datasets
from dataclasses import dataclass, field
import transformers
from ...trainer.seq2seq_trainer import Seq2SeqTrainingArguments
from typing import Dict, Optional, List, Callable, Union
from fate.arch import Context
from fate.ml.nn.trainer.trainer_base import FedArguments
from torch.utils.data import Dataset
from transformers.trainer_callback import TrainerCallback
from transformers import PreTrainedTokenizer
from transformers import Seq2SeqTrainer
from transformers.trainer_utils import EvalPrediction
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_utils import unwrap_model
from fate_llm.algo.fedmkt.token_alignment.token_align import token_align
from fate_llm.algo.fedmkt.utils.generate_logit_utils import generate_pub_data_logits
from fate.ml.aggregator import AggregatorClientWrapper, AggregatorServerWrapper
from fate_llm.algo.fedmkt.fedmkt_trainer import FedMKTTrainer
from fate_llm.algo.fedmkt.fedmkt_data_collator import DataCollatorForFedMKT
from fate_llm.algo.fedmkt.utils.dataset_sync_util import sync_dataset
logger = logging.getLogger(__name__)
@dataclass
class FedMKTTrainingArguments(Seq2SeqTrainingArguments):
"""
selection metric type
"""
metric_type: str = field(default="ce")
"""
top-k logits select params
"""
top_k_logits_keep: int = field(default=128)
top_k_strategy: str = field(default="highest")
"""
distillation params
"""
distill_loss_type: str = field(default="ce")
kd_alpha: float = field(default=0.9)
distill_temperature: float = field(default=1.0)
server_public_data_local_epoch: int = field(default=1)
client_public_data_local_epoch: int = field(default=1)
client_priv_data_local_epoch: int = field(default=1)
distill_strategy: str = field(default="greater")
global_epochs: int = field(default=1)
"""
token-alignment params
"""
skip_align: bool = field(default=False)
token_align_strategy: str = field(default="dtw")
vocab_mapping_paths: Union[str, List[str]] = field(default=None)
vocab_size: int = field(default=None)
"""
homo training params
"""
post_fedavg: bool = field(default=False)
"""
slm training only
"""
llm_training: bool = field(default=True)
def to_dict(self):
from dataclasses import fields
from enum import Enum
d = {field.name: getattr(self, field.name) for field in fields(self) if field.init}
for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
d[k] = [x.value for x in v]
if k.endswith("_token"):
d[k] = f"<{k.upper()}>"
return d
def to_dict_without_extra_args(self):
args_dict = self.to_dict()
args_dict.pop("metric_type")
args_dict.pop("top_k_logits_keep")
args_dict.pop("top_k_strategy")
args_dict.pop("distill_loss_type")
args_dict.pop("kd_alpha")
args_dict.pop("distill_temperature")
args_dict.pop("distill_strategy")
args_dict.pop("server_public_data_local_epoch")
args_dict.pop("client_public_data_local_epoch")
args_dict.pop("client_priv_data_local_epoch")
args_dict.pop("global_epochs")
args_dict.pop("skip_align", False)
args_dict.pop("token_align_strategy")
args_dict.pop("vocab_mapping_paths", None)
args_dict.pop("vocab_size", None)
args_dict.pop("post_fedavg")
args_dict.pop("llm_training", True)
return args_dict
def to_dict_with_client_priv_training_args(self):
args_dict = self.to_dict_without_extra_args()
args_dict["num_train_epochs"] = self.client_priv_data_local_epoch
return args_dict
def to_dict_with_client_kd_args(self):
args_dict = self.to_dict_without_extra_args()
args_dict["num_train_epochs"] = self.client_public_data_local_epoch
return args_dict
def to_dict_with_server_kd_args(self):
args_dict = self.to_dict_without_extra_args()
args_dict["num_train_epochs"] = self.server_public_data_local_epoch
return args_dict
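# Note: the to_dict_with_*_args helpers above strip the FedMKT-specific knobs (distillation,
# token-alignment and scheduling fields) and override num_train_epochs, so that a plain
# Seq2SeqTrainingArguments instance can be rebuilt for each phase: private-data training,
# client-side knowledge distillation, and server-side knowledge distillation.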
class FedMKTBase(object):
def __init__(self, *args, **kwargs):
self.model = None
self.save_trainable_weights_only = None
def save_model(
self,
output_dir: Optional[str] = None,
state_dict=None
):
if not self.save_trainable_weights_only:
torch.save(self.model.state_dict(), output_dir + '/pytorch_model.bin')
else:
model = unwrap_model(self.model)
if hasattr(model, "save_trainable"):
model.save_trainable(output_dir)
else:
state_dict = {
k: p.to("cpu") for k,
p in model.named_parameters() if p.requires_grad
}
torch.save(state_dict, output_dir + '/pytorch_model.bin')
class FedMKTSLM(FedMKTBase):
def __init__(
self,
ctx: Context,
model: torch.nn.Module,
training_args: FedMKTTrainingArguments,
fed_args: FedArguments = None,
priv_train_set=None,
pub_train_set=None,
val_set: Dataset = None,
priv_optimizer: torch.optim.Optimizer = None,
pub_optimizer: torch.optim.Optimizer = None,
priv_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
pub_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
data_collator: Callable = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
save_trainable_weights_only: bool = False,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
llm_tokenizer=None,
llm_to_slm_vocab_mapping=None,
):
super(FedMKTSLM, self).__init__()
self.ctx = ctx
self.training_args = training_args
self.fed_args = fed_args
self.model = model
self.tokenizer = tokenizer
self.model_init = model_init
self.callbacks = callbacks
self.compute_metrics = compute_metrics
self.save_trainable_weights_only = save_trainable_weights_only
self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
self.priv_data_collator = data_collator
self.priv_optimizer = priv_optimizer
self.pub_optimizer = pub_optimizer
self.priv_scheduler = priv_scheduler
self.pub_scheduler = pub_scheduler
self.priv_train_set = priv_train_set
self.pub_train_set = pub_train_set
self.llm_tokenizer = llm_tokenizer
self.llm_to_slm_vocab_mapping = llm_to_slm_vocab_mapping
self.val_set = val_set
self.aggregator = self._init_aggregator(ctx, fed_args)
if not isinstance(self.pub_train_set, datasets.Dataset):
self.pub_train_set = datasets.Dataset.from_list(list(self.pub_train_set))
def train(self):
global_epochs = self.training_args.global_epochs
llm_pub_logits = None
for i, iter_ctx in self.ctx.on_iterations.ctxs_range(global_epochs):
logger.info(f"begin {i}-th global kd process")
priv_data_training_args = self._get_priv_data_training_args()
priv_trainer = Seq2SeqTrainer(
model=self.model,
tokenizer=self.tokenizer,
data_collator=self.priv_data_collator,
train_dataset=self.priv_train_set,
args=priv_data_training_args,
model_init=self.model_init if not i else None,
compute_metrics=self.compute_metrics,
callbacks=self.callbacks,
optimizers=(self.priv_optimizer, self.priv_scheduler),
preprocess_logits_for_metrics=self.preprocess_logits_for_metrics
)
logger.info(f"begin {i}-th private data training process")
priv_trainer.train()
self.model = unwrap_model(priv_trainer.model)
logger.info(f"begin {i}-th public logits generation process")
if self.training_args.world_size <= 1 or self.training_args.local_rank == 0:
slm_pub_logits = self.pub_train_set.map(
generate_pub_data_logits,
batched=True,
batch_size=self.training_args.per_device_train_batch_size,
num_proc=None,
load_from_cache_file=True,
fn_kwargs={"model": self.model,
"training_args": self.training_args,
"data_collator": transformers.DataCollatorForSeq2Seq(self.tokenizer)}
)
if self.training_args.world_size > 1:
logger.info("sync slm_pub_logits")
sync_dataset(
slm_pub_logits, self.training_args.local_rank, self.training_args.world_size, self.training_args.device
)
if self.training_args.llm_training:
logger.debug(f"send {i}-th public logits to llm")
iter_ctx.arbiter.put("slm_pub_logits", slm_pub_logits.to_dict())
if self.training_args.llm_training or not i:
llm_pub_logits = datasets.Dataset.from_dict(iter_ctx.arbiter.get("llm_pub_logits"))
if self.training_args.world_size > 1:
logger.info("sync llm_pub_logits")
sync_dataset(llm_pub_logits, self.training_args.local_rank,
self.training_args.world_size, self.training_args.device)
else:
slm_pub_logits = sync_dataset(
None, self.training_args.local_rank, self.training_args.world_size, self.training_args.device
)
if self.training_args.llm_training or not i:
llm_pub_logits = sync_dataset(None, self.training_args.local_rank,
self.training_args.world_size, self.training_args.device)
logger.info(f"begin {i}-th token alignment process")
aligned_dataset = token_align(
base_model_logits_datasets=slm_pub_logits,
blending_model_logits_dataset=llm_pub_logits,
base_tokenizer=self.tokenizer,
blending_tokenizer=self.llm_tokenizer,
blending_to_base_mapping=self.llm_to_slm_vocab_mapping,
blending_model_index=0,
skip_align=self.training_args.skip_align,
align_strategy=self.training_args.token_align_strategy
)
logger.info(f"begin {i}-th public logits kd process")
fedmkt_trainer = self._init_trainer_for_distill(aligned_dataset)
fedmkt_trainer.train()
self.model = unwrap_model(fedmkt_trainer.model)
if self.training_args.post_fedavg and (i + 1) % self.fed_args.aggregate_freq == 0:
self.aggregator.model_aggregation(iter_ctx, self.model)
def _init_trainer_for_distill(self, train_set):
public_data_training_args = self._get_pub_data_kd_training_args()
fedmkt_trainer = FedMKTTrainer(
model=self.model,
tokenizer=self.tokenizer,
args=public_data_training_args,
train_dataset=train_set,
eval_dataset=self.val_set,
data_collator=DataCollatorForFedMKT(
self.tokenizer,
padding="max_length",
max_length=max(len(d["input_ids"]) for d in train_set),
blending_num=1,
vocab_size=self.training_args.vocab_size,
dtype=next(self.model.parameters()).dtype,
distill_temperature=self.training_args.distill_temperature
),
blending_num=1,
lm_loss_weight=self.training_args.kd_alpha,
distill_loss_type=self.training_args.distill_loss_type,
distill_strategy=self.training_args.distill_strategy
)
return fedmkt_trainer
def _get_priv_data_training_args(self):
pre_args = self.training_args.to_dict_with_client_priv_training_args()
post_args = Seq2SeqTrainingArguments(**pre_args)
return post_args
def _get_pub_data_kd_training_args(self):
pre_args = self.training_args.to_dict_with_client_kd_args()
post_args = Seq2SeqTrainingArguments(**pre_args)
return post_args
def _init_aggregator(self, ctx: Context, fed_args: FedArguments):
if not self.training_args.post_fedavg:
return None
aggregate_type = "weighted_mean"
aggregator_name = "fedavg"
aggregator = fed_args.aggregator
return AggregatorClientWrapper(
ctx, aggregate_type, aggregator_name, aggregator,
sample_num=len(self.pub_train_set), args=self.training_args
)
class FedMKTLLM(FedMKTBase):
def __init__(
self,
ctx: Context,
model: torch.nn.Module,
training_args: FedMKTTrainingArguments,
fed_args: FedArguments = None,
train_set=None,
val_set: Dataset = None,
optimizer: torch.optim.Optimizer = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
data_collator: Callable = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
save_trainable_weights_only: bool = False,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
slm_tokenizers: List = None,
slm_to_llm_vocab_mappings: List[Dict] = None,
):
super(FedMKTLLM, self).__init__()
self.ctx = ctx
self.model = model
self.training_args = training_args
self.fed_args = fed_args
self.train_set = train_set
self.val_set = val_set
self.optimizer = optimizer
self.lr_scheduler = scheduler
self.data_collator = data_collator
self.tokenizer = tokenizer
self.model_init = model_init
self.compute_metrics = compute_metrics
self.callbacks = callbacks
self.save_trainable_weights_only = save_trainable_weights_only
self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
self.slm_tokenizers = slm_tokenizers
self.slm_to_llm_vocab_mappings = slm_to_llm_vocab_mappings
self.aggregator = self._init_aggregator(ctx)
if not isinstance(self.train_set, datasets.Dataset):
self.train_set = datasets.Dataset.from_list(list(self.train_set))
def _init_aggregator(self, ctx: Context):
if not self.training_args.post_fedavg:
return None
return AggregatorServerWrapper(ctx)
def generate_pub_data_logits(self, first_epoch=False):
fn_kwargs = {"model": self.model,
"training_args": self.training_args,
"data_collator": transformers.DataCollatorForSeq2Seq(self.tokenizer)}
if first_epoch and self.training_args.device.type == "cuda":
self.model.cuda(self.training_args.device)
return self.train_set.map(
generate_pub_data_logits,
batched=True,
batch_size=self.training_args.per_device_train_batch_size,
num_proc=None,
load_from_cache_file=True,
fn_kwargs=fn_kwargs
)
def on_epoch_begin(self, iter_ctx, epoch_idx, previous_pub_dataset):
logger.info(f"on {epoch_idx}-epoch begin")
if not self.training_args.llm_training:
return
if previous_pub_dataset is None:
if self.training_args.world_size <= 1 or self.training_args.local_rank == 0:
llm_pub_logits = self.generate_pub_data_logits(first_epoch=True if not epoch_idx else False)
if self.training_args.world_size > 1:
sync_dataset(llm_pub_logits, self.training_args.local_rank,
self.training_args.world_size, self.training_args.device)
else:
llm_pub_logits = sync_dataset(None, self.training_args.local_rank,
self.training_args.world_size, self.training_args.device)
else:
llm_pub_logits = previous_pub_dataset
slm_pub_logits_list = list()
if self.training_args.world_size <= 1 or self.training_args.local_rank == 0:
slm_pub_logits_list.append(datasets.Dataset.from_dict(iter_ctx.guest.get('slm_pub_logits')))
if any(p.role == 'host' for p in self.ctx.parties):
slm_pub_logits_list.extend(
datasets.Dataset.from_dict(client_logits) for client_logits in iter_ctx.hosts.get("slm_pub_logits")
)
if self.training_args.world_size > 1:
logger.info("sync dataset to other rank")
for slm_pub_logits in slm_pub_logits_list:
sync_dataset(slm_pub_logits, self.training_args.local_rank,
self.training_args.world_size, self.training_args.device)
logger.info("end to sync")
else:
logger.info("sync dataset from rank 0")
for _ in range(len(self.slm_tokenizers)):
slm_pub_logits_list.append(
sync_dataset(None, self.training_args.local_rank,
self.training_args.world_size, self.training_args.device)
)
logger.info("end to sync dataset from rank 0")
aligned_dataset = llm_pub_logits
for idx, slm_pub_logits in enumerate(slm_pub_logits_list):
aligned_dataset = token_align(
base_model_logits_datasets=aligned_dataset,
blending_model_logits_dataset=slm_pub_logits,
base_tokenizer=self.tokenizer,
blending_tokenizer=self.slm_tokenizers[idx],
blending_to_base_mapping=self.slm_to_llm_vocab_mappings[idx],
blending_model_index=idx,
skip_align=self.training_args.skip_align,
align_strategy=self.training_args.token_align_strategy
)
return aligned_dataset
def on_epoch_end(self, iter_ctx, epoch_idx):
logger.info(f"on {epoch_idx}-epoch end")
if not self.training_args.llm_training and epoch_idx > 1:
return
llm_pub_logits = self.generate_pub_data_logits(first_epoch=True if not self.training_args.llm_training else False)
if self.training_args.world_size <= 1 or self.training_args.local_rank == 0:
iter_ctx.guest.put("llm_pub_logits", llm_pub_logits.to_dict())
if len(self.slm_tokenizers) > 1:
iter_ctx.hosts.put("llm_pub_logits", llm_pub_logits.to_dict())
if self.training_args.post_fedavg and (epoch_idx + 1) % self.fed_args.aggregate_freq == 0:
self.aggregator.model_aggregation(iter_ctx)
if self.training_args.world_size > 1:
sync_dataset(
llm_pub_logits, self.training_args.local_rank, self.training_args.world_size, self.training_args.device
)
else:
llm_pub_logits = sync_dataset(
None, self.training_args.local_rank, self.training_args.world_size, self.training_args.device
)
return llm_pub_logits
def _get_pub_data_kd_training_args(self):
pre_args = self.training_args.to_dict_with_server_kd_args()
post_args = Seq2SeqTrainingArguments(**pre_args)
return post_args
def train(self):
global_epochs = self.training_args.global_epochs
previous_pub_logits = None
for i, iter_ctx in self.ctx.on_iterations.ctxs_range(global_epochs):
logger.info(f"begin {i}-th global kd process")
aligned_train_set = self.on_epoch_begin(iter_ctx, i, previous_pub_logits)
if self.training_args.llm_training:
public_data_training_args = self._get_pub_data_kd_training_args()
fedmkt_trainer = FedMKTTrainer(
model=self.model,
tokenizer=self.tokenizer,
args=public_data_training_args,
train_dataset=aligned_train_set,
eval_dataset=self.val_set,
data_collator=DataCollatorForFedMKT(
self.tokenizer,
padding="max_length",
max_length=max(len(d["input_ids"]) for d in aligend_train_set),
blending_num=len(self.slm_tokenizers),
vocab_size=self.training_args.vocab_size,
dtype=next(self.model.parameters()).dtype,
distill_temperature=self.training_args.distill_temperature
),
blending_num=len(self.slm_tokenizers),
lm_loss_weight=self.training_args.kd_alpha,
distill_loss_type=self.training_args.distill_loss_type,
distill_strategy=self.training_args.distill_strategy
)
fedmkt_trainer.train()
self.model = unwrap_model(fedmkt_trainer.model)
previous_pub_logits = self.on_epoch_end(iter_ctx, i)
================================================
FILE: python/fate_llm/algo/fedmkt/fedmkt_data_collator.py
================================================
#
# NOTE: The implementations of DataCollatorForFedMKT is modified from FuseAI/FuseLLM
# Copyright FuseAI/FuseLLM
#
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from torch.nn.functional import softmax
from transformers import DataCollatorForSeq2Seq
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from typing import Optional, Any, Union
import logging
from fate_llm.algo.fedmkt.utils.vars_define import (
ALIGNED_OTHER_LOGITS,
ALIGNED_OTHER_INDICES,
PER_STEP_LOGITS,
PER_STEP_INDICES,
SELF_TARGET_DIST,
OTHER_TARGET_DIST
)
logger = logging.getLogger(__name__)
class DataCollatorForFedMKT(DataCollatorForSeq2Seq):
"""modified from https://github.com/fanqiwan/FuseAI/blob/main/FuseLLM/src/utils/data_collator.py#L135"""
tokenizer: PreTrainedTokenizerBase
model: Optional[Any] = None
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
return_tensors: str = "pt"
blending_num: int = 1
distill_temperature: float = 1.0
vocab_size: int = None
dtype: torch.dtype = torch.bfloat16
def __init__(self, *args, **kwargs):
blending_num = kwargs.pop("blending_num", 4)
vocab_size = kwargs.pop("vocab_size", None)
dtype = kwargs.pop("dtype", torch.dtype)
distill_temperature = kwargs.pop("distill_temperature", 1.0)
super(DataCollatorForFedMKT, self).__init__(*args, **kwargs)
self.blending_num = blending_num
self.vocab_size = vocab_size if vocab_size is not None else len(self.tokenizer.get_vocab())
self.pad_id = self.tokenizer.pad_token_id
self.dtype = dtype
self.distill_temperature = distill_temperature
def __call__(self, features, return_tensors=None):
extra_features = dict()
feature_keys = list(features[0].keys())
for f_key in feature_keys:
if f_key not in ["input_ids", "attention_mask", "labels"]:
extra_features[f_key] = []
for feature in features:
extra_features[f_key].append(feature.pop(f_key))
features = super().__call__(features=features, return_tensors=return_tensors)
features.update(extra_features)
batch_size = features["input_ids"].size(0)
base_target_dist = torch.zeros(batch_size, self.max_length, self.vocab_size).to(self.dtype)
aligned_target_dists = [torch.zeros(batch_size, self.max_length, self.vocab_size).to(self.dtype)
for _ in range(self.blending_num)]
for i in range(batch_size):
base_seq_len = len(features[PER_STEP_LOGITS][i])
for j in range(self.max_length):
if j < base_seq_len:
base_logits = torch.tensor(features[PER_STEP_LOGITS][i][j], dtype=self.dtype)
base_prob = softmax(base_logits / self.distill_temperature, -1)
base_indices = torch.tensor(features[PER_STEP_INDICES][i][j])
base_target_dist[i][j] = base_target_dist[i][j].scatter_(-1, base_indices, base_prob)
for k in range(self.blending_num):
per_step_aligned_indices_key = f"{ALIGNED_OTHER_INDICES}_{k}"
per_step_aligned_logits_key = f"{ALIGNED_OTHER_LOGITS}_{k}"
if len(features[per_step_aligned_indices_key][i][j]) > 0:
aligned_logits = torch.tensor(features[per_step_aligned_logits_key][i][j], dtype=self.dtype)
aligned_prob = softmax(aligned_logits / self.distill_temperature, -1)
aligned_indices = torch.tensor(features[per_step_aligned_indices_key][i][j])
aligned_target_dists[k][i][j] = aligned_target_dists[k][i][j].scatter_(-1, aligned_indices, aligned_prob)
else:
aligned_target_dists[k][i][j] = base_target_dist[i][j]
else: # padding position
base_target_dist[i][j][self.pad_id] = 1.0
for k in range(self.blending_num):
aligned_target_dists[k][i][j][self.pad_id] = 1.0
features.pop(PER_STEP_LOGITS)
features.pop(PER_STEP_INDICES)
for i in range(self.blending_num):
features.pop(f"{ALIGNED_OTHER_LOGITS}_{i}")
features.pop(f"{ALIGNED_OTHER_INDICES}_{i}")
features[f"{OTHER_TARGET_DIST}_{i}"] = aligned_target_dists[i]
features[SELF_TARGET_DIST] = base_target_dist
return features
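# Note on the collated batch: besides the usual seq2seq padding, the collator materialises a
# dense per-step target distribution of shape (batch, max_length, vocab_size) for the base
# model and for each blending model by scattering the stored top-k probabilities
# (softmax of logits / distill_temperature) into the vocabulary axis; positions beyond the
# real sequence length get a one-hot distribution on the pad token, and steps where a
# blending model has no aligned indices fall back to the base model's distribution.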
================================================
FILE: python/fate_llm/algo/fedmkt/fedmkt_trainer.py
================================================
#
# NOTE: The implementations of FedMKTTrainer is modified from FuseAI/FuseLLM
# Copyright FuseAI
#
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import torch
from torch.nn.functional import kl_div, log_softmax, cross_entropy
from transformers import Seq2SeqTrainer
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from fate_llm.algo.fedmkt.utils.vars_define import (
SELF_TARGET_DIST,
OTHER_TARGET_DIST,
ALIGNED_OTHER_METRIC,
METRIC,
)
logger = logging.getLogger(__name__)
class FedMKTTrainer(Seq2SeqTrainer):
"""
modified from https://github.com/fanqiwan/FuseAI/blob/main/FuseLLM/src/utils/trainer.py#L22
"""
blending_num: int = 2
distill_loss_type: str = "ce"
lm_loss_weight: float = 0.9
distill_strategy = "greater"
def __init__(self, *args, **kwargs):
blending_num = kwargs.pop("blending_num", 1)
distill_loss_type = kwargs.pop("distill_loss_type", "ce")
lm_loss_weight = kwargs.pop("lm_loss_weight", 0.9)
distill_strategy = kwargs.pop("distill_strategy", "greater")
super(FedMKTTrainer, self).__init__(*args, **kwargs)
self.blending_num = blending_num
self.distill_loss_type = distill_loss_type
self.lm_loss_weight = lm_loss_weight
self.distill_strategy = distill_strategy
def compute_loss(self, model, inputs, return_outputs=False):
if self.label_smoother is not None and "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
base_target_dist = inputs.pop(SELF_TARGET_DIST)
base_metric = inputs.pop(METRIC)
aligned_target_dists = []
aligned_metrics = []
for i in range(self.blending_num):
aligned_target_dists.append(inputs.pop(f"{OTHER_TARGET_DIST}_{i}"))
aligned_metrics.append(inputs.pop(f"{ALIGNED_OTHER_METRIC}_{i}"))
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
if labels is not None:
if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
else:
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
batch_size, seq_len, vocab_size = outputs["logits"].size(0), outputs["logits"].size(1), outputs["logits"].size(2)
aligned_rewards = []
for i in range(self.blending_num):
aligned_rewards.append((1 / torch.exp(torch.tensor(aligned_metrics[i], dtype=torch.bfloat16))).to(loss.device))
base_reward = (1 / torch.exp(torch.tensor(base_metric, dtype=torch.bfloat16))).to(loss.device)
if self.distill_strategy == "greater":
base_reward_expanded = base_reward.unsqueeze(-1).unsqueeze(-1).expand_as(base_target_dist)
aligned_rewards_expanded = [
aligned_rewards[i].unsqueeze(-1).unsqueeze(-1).expand_as(aligned_target_dists[i])
for i in range(self.blending_num)
]
target_dist_list = []
reward_list = []
if base_target_dist is not None:
target_dist_list.append(base_target_dist)
reward_list.append(base_reward_expanded)
target_dist_list.extend(aligned_target_dists)
reward_list.extend(aligned_rewards_expanded)
stacked_dists = torch.stack(target_dist_list, dim=-1)
stacked_rewards = torch.stack(reward_list, dim=-1)
max_reward_indices = torch.argmax(stacked_rewards, dim=-1, keepdim=True)
target_dist = torch.gather(stacked_dists, -1, max_reward_indices).squeeze(-1)
elif self.distill_strategy == "weighted_mean":
weights = torch.stack(
[base_reward] + aligned_rewards, dim=1
)
normalized_weights = torch.softmax(weights, dim=1)
weight_labels = normalized_weights[:, 0].unsqueeze(1).unsqueeze(2) * base_target_dist
for i in range(self.blending_num):
weight_labels += normalized_weights[:, i + 1].unsqueeze(1).unsqueeze(2) * aligned_target_dists[i]
target_dist = (
weight_labels
)
else:
raise ValueError(f"distill_strategy={self.distill_strategy}")
if self.distill_loss_type == "ce":
loss_lm = cross_entropy(
input=outputs["logits"].view(-1, vocab_size),
target=target_dist.view(-1, vocab_size),
reduction="none",
).view(batch_size, -1)
elif self.distill_loss_type == "kl":
loss_lm = kl_div(
input=log_softmax(outputs["logits"], dim=-1),
target=target_dist,
log_target=False,
reduction="none").sum(dim=-1)
else:
raise ValueError(f"Not implement distill_loss_type={self.distill_loss_type}")
loss_lm = (loss_lm * inputs["attention_mask"]).sum() / inputs["attention_mask"].sum()
loss = self.lm_loss_weight * loss + (1.0 - self.lm_loss_weight) * loss_lm
return (loss, outputs) if return_outputs else loss
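# Loss note: the returned loss combines the standard LM loss with the distillation term as
# lm_loss_weight * lm_loss + (1 - lm_loss_weight) * distill_loss. The distillation target is
# chosen per strategy: "greater" picks, position-wise, the distribution from the source
# (base or blending model) with the highest reward 1 / exp(cross-entropy), while
# "weighted_mean" mixes all sources with softmax-normalised per-sample rewards.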
================================================
FILE: python/fate_llm/algo/fedmkt/token_alignment/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
================================================
FILE: python/fate_llm/algo/fedmkt/token_alignment/spectal_token_mapping.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import transformers
TOKENIZER_TO_SPECIAL_TOKEN = {
transformers.LlamaTokenizer: '▁',
transformers.LlamaTokenizerFast: '▁',
transformers.GPTNeoXTokenizerFast: 'Ġ',
transformers.GPT2TokenizerFast: 'Ġ',
transformers.GPT2Tokenizer: 'Ġ',
transformers.BloomTokenizerFast: 'Ġ',
}
================================================
FILE: python/fate_llm/algo/fedmkt/token_alignment/token_align.py
================================================
#
# NOTE: The dtw function is copied from FuseAI/FuseLLM
# and the align_blending_model_logits_with_base_model_logits function is modified from FuseAI/FuseLLM
# Copyright FuseAI
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import transformers
import editdistance
import numpy as np
from typing import Dict, List
from fate_llm.algo.fedmkt.token_alignment.spectal_token_mapping import TOKENIZER_TO_SPECIAL_TOKEN
from fate_llm.algo.fedmkt.utils.vars_define import (
PER_STEP_LOGITS,
PER_STEP_INDICES,
ALIGNED_OTHER_LOGITS,
ALIGNED_OTHER_INDICES,
ALIGNED_OTHER_METRIC,
METRIC
)
logger = logging.getLogger(__name__)
def dtw(series_1, series_2, norm_func=np.linalg.norm):
"""code refer to: https://github.com/fanqiwan/FuseAI/blob/main/FuseLLM/src/utils/others.py#L318"""
matrix = np.zeros((len(series_1) + 1, len(series_2) + 1))
matrix[0, :] = np.inf
matrix[:, 0] = np.inf
matrix[0, 0] = 0
for i, vec1 in enumerate(series_1):
for j, vec2 in enumerate(series_2):
cost = norm_func(vec1, vec2)
matrix[i + 1, j + 1] = cost + min(matrix[i, j + 1], matrix[i + 1, j], matrix[i, j])
matrix = matrix[1:, 1:]
i = matrix.shape[0] - 1
j = matrix.shape[1] - 1
matches = []
mappings_series_1 = [list() for v in range(matrix.shape[0])]
mappings_series_2 = [list() for v in range(matrix.shape[1])]
while i > 0 or j > 0:
matches.append((i, j))
mappings_series_1[i].append(j)
mappings_series_2[j].append(i)
option_diag = matrix[i - 1, j - 1] if i > 0 and j > 0 else np.inf
option_up = matrix[i - 1, j] if i > 0 else np.inf
option_left = matrix[i, j - 1] if j > 0 else np.inf
move = np.argmin([option_diag, option_up, option_left])
if move == 0:
i -= 1
j -= 1
elif move == 1:
i -= 1
else:
j -= 1
matches.append((0, 0))
mappings_series_1[0].append(0)
mappings_series_2[0].append(0)
matches.reverse()
for mp in mappings_series_1:
mp.reverse()
for mp in mappings_series_2:
mp.reverse()
return matches, matrix[-1, -1], mappings_series_1, mappings_series_2, matrix
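# Illustrative example: with a scalar distance, dtw([1, 2, 3], [1, 2, 2, 3],
# norm_func=lambda a, b: abs(a - b)) aligns both "2" entries of the second series to the
# single "2" of the first and reports a total alignment cost of 0.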
def greedy_dynamic_matching(base_model_tokens, blending_model_tokens, base_model_sp_t, blending_model_sp_t):
l1 = len(base_model_tokens)
l2 = len(blending_model_tokens)
base_model_tokens = [token.replace(base_model_sp_t, "") for token in base_model_tokens]
blending_model_tokens = [token.replace(blending_model_sp_t, "") for token in blending_model_tokens]
dp = np.full((l1 + 1, l2 + 1), -1000000000, dtype="int32")
matched_left = np.full((l1, l2), -1, dtype="int32")
matched_right = np.full((l1, l2), -1, dtype="int32")
trans_left = np.full((l1 + 1, l2 + 1), -1, dtype="int32")
trans_right = np.full((l1 + 1, l2 + 1), -1, dtype="int32")
# this could be optimized with a suffix data structure; it is implemented naively for a quick trial and will be optimized later.
for i in range(l1):
for j in range(l2):
if base_model_tokens[i] == blending_model_tokens[j]:
matched_left[i][j] = 1
matched_right[i][j] = 1
continue
i2, j2 = i, j
t1 = ""
t2 = ""
sq_l1, sq_l2 = 0, 0
while i2 >= 0 and j2 >= 0:
if len(t1) > len(t2):
t2 = blending_model_tokens[j2] + t2
sq_l2 += 1
j2 -= 1
elif len(t1) < len(t2):
t1 = base_model_tokens[i2] + t1
sq_l1 += 1
i2 -= 1
else:
if sq_l1 == 0:
sq_l1 += 1
sq_l2 += 1
t1 += base_model_tokens[i2]
t2 += blending_model_tokens[j2]
i2 -= 1
j2 -= 1
continue
if t1 == t2:
matched_left[i][j] = sq_l1
matched_right[i][j] = sq_l2
break
"""
always shortest matching
"""
for i in range(0, l1 + 1):
dp[i][0] = 0
for j in range(0, l2 + 1):
dp[0][j] = 1
for i in range(0, l1):
for j in range(0, l2):
if matched_left[i][j] == -1:
dp[i + 1][j + 1] = max(dp[i + 1][j], dp[i][j + 1])
if dp[i + 1][j + 1] == dp[i + 1][j]:
trans_right[i + 1][j + 1] = j
else:
trans_left[i + 1][j + 1] = i
else:
l_len = matched_left[i][j]
r_len = matched_right[i][j]
dp[i + 1][j + 1] = max(max(dp[i + 1][j], dp[i][j + 1]), dp[i + 1 - l_len][j + 1 - r_len] + l_len)
if dp[i + 1][j + 1] == dp[i + 1 - l_len][j + 1 - r_len] + l_len:
trans_left[i + 1][j + 1] = i + 1 - l_len
trans_right[i + 1][j + 1] = j + 1 - r_len
assert l_len > 0 and r_len > 0
elif dp[i + 1][j + 1] == dp[i + 1][j]:
trans_right[i + 1][j + 1] = j
else:
trans_left[i + 1][j + 1] = i
i, j = l1, l2
matches = []
while i > 0 and j > 0:
if trans_left[i][j] != -1 and trans_right[i][j] != -1:
l = trans_left[i][j]
r = trans_right[i][j]
matches.append([(l, i - 1), (r, j - 1)])
i, j = l, r
elif trans_left[i][j] < 0:
j -= 1
else:
i -= 1
matches.reverse()
return matches
def align_blending_model_logits_with_base_model_logits(base_examples,
indices,
blending_examples,
blending_to_base_mapping,
base_tokenizer,
blending_tokenizer,
blending_model_index,
skip_align=False,
align_strategy="greedy_dp"):
"""modified from https://github.com/fanqiwan/FuseAI/blob/main/FuseLLM/src/utils/token_alignment.py#L101"""
base_features = [{key: base_examples[key][i] for key in base_examples} for i in
range(len(base_examples[next(iter(base_examples))]))]
blending_features = [blending_examples[idx] for idx in indices]
aligned_per_step_logits_list, aligned_per_step_indices_list = [], []
per_step_logits_list, per_step_indices_list = [], []
metric_ce_aligned = []
for base_feature, blending_feature in zip(base_features, blending_features):
base_feature[PER_STEP_LOGITS] = base_feature[PER_STEP_LOGITS][:len(base_feature['input_ids'])]
base_feature[PER_STEP_INDICES] = base_feature[PER_STEP_INDICES][:len(base_feature['input_ids'])]
blending_feature[PER_STEP_LOGITS] = blending_feature[PER_STEP_LOGITS][:len(blending_feature['input_ids'])]
blending_feature[PER_STEP_INDICES] = blending_feature[PER_STEP_INDICES][:len(blending_feature['input_ids'])]
if skip_align is True:
aligned_blending_model_per_step_logits = blending_feature[PER_STEP_LOGITS]
aligned_blending_model_per_step_indices = blending_feature[PER_STEP_INDICES]
else:
aligned_blending_model_per_step_logits, aligned_blending_model_per_step_indices = transform_step_logits(
base_model_tokenizer=base_tokenizer,
blending_model_tokenizer=blending_tokenizer,
base_model_vocab=base_tokenizer.get_vocab(),
base_model_input_ids=base_feature['input_ids'],
blending_model_input_ids=blending_feature['input_ids'],
blending_model_per_step_logits=blending_feature[PER_STEP_LOGITS],
blending_model_per_step_indices=blending_feature[PER_STEP_INDICES],
blending_to_base_mapping=blending_to_base_mapping,
align_strategy=align_strategy
)
aligned_per_step_logits_list.append(aligned_blending_model_per_step_logits)
aligned_per_step_indices_list.append(aligned_blending_model_per_step_indices)
per_step_logits_list.append(base_feature[PER_STEP_LOGITS])
per_step_indices_list.append(base_feature[PER_STEP_INDICES])
metric_ce_aligned.append(blending_feature[METRIC])
base_examples[PER_STEP_LOGITS] = per_step_logits_list
base_examples[PER_STEP_INDICES] = per_step_indices_list
base_examples[f"{ALIGNED_OTHER_LOGITS}_{blending_model_index}"] = aligned_per_step_logits_list
base_examples[f"{ALIGNED_OTHER_INDICES}_{blending_model_index}"] = aligned_per_step_indices_list
base_examples[f"{ALIGNED_OTHER_METRIC}_{blending_model_index}"] = metric_ce_aligned
return base_examples
def transform_step_logits(base_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
blending_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
base_model_vocab: Dict[str, int],
base_model_input_ids: List[int],
blending_model_input_ids: List[int],
blending_model_per_step_logits: List[List[float]],
blending_model_per_step_indices: List[List[int]],
blending_to_base_mapping: Dict[str, str] = None,
align_strategy: str = "dtw"
):
"""modified from https://github.com/fanqiwan/FuseAI/blob/main/FuseLLM/src/utils/others.py#L364"""
"""Align blending model per step logits & indices with base model."""
base_model_tokens = base_model_tokenizer.convert_ids_to_tokens(base_model_input_ids)
blending_model_tokens = blending_model_tokenizer.convert_ids_to_tokens(blending_model_input_ids)
base_model_special_token = TOKENIZER_TO_SPECIAL_TOKEN[base_model_tokenizer.__class__]
blending_model_special_token = TOKENIZER_TO_SPECIAL_TOKEN[blending_model_tokenizer.__class__]
aligned_blending_model_per_step_logits, aligned_blending_model_per_step_indices = [], []
if align_strategy == "dtw":
def dist_fn(a, b):
"""Calculate editdistance between two tokens, a is from blending model, b is from base model."""
return editdistance.eval(a.replace(blending_model_special_token, ''),
b.replace(base_model_special_token, ''))
_, _, _, base_to_blending, _ = dtw(blending_model_tokens, base_model_tokens, norm_func=dist_fn)
for i, blending_idx in enumerate(base_to_blending):
aligned_blending_model_per_step_logit = []
aligned_blending_model_per_step_index = []
if len(blending_idx) == 1:  # one base token maps to one blending token
j = blending_idx[0]
base_token = base_model_tokens[i]
blending_token = blending_model_tokens[j].replace(blending_model_special_token,
base_model_special_token)
if (
blending_model_tokenizer.__class__ == transformers.GPTNeoXTokenizerFast
or blending_model_tokenizer.__class__ == transformers.GPT2TokenizerFast) and i == 0 and base_token.startswith(
base_model_special_token) and not blending_token.startswith(base_model_special_token):
blending_token = base_model_special_token + blending_token # special case for mpt
if (base_token == blending_token) or (
blending_token in blending_to_base_mapping and base_token == blending_to_base_mapping[
blending_token]): # find the aligned mapping, use the corresponding logits
# the logits and indices at this step
for blending_logit, blending_index in zip(blending_model_per_step_logits[j],
blending_model_per_step_indices[j]):
# the token corresponds to the logit and indices
blending_t = blending_model_tokenizer.convert_ids_to_tokens([blending_index])[0].replace(
blending_model_special_token, base_model_special_token)
blending_t = blending_to_base_mapping[blending_t]
if blending_t in base_model_vocab:
aligned_index = base_model_vocab[blending_t] # the index of the token in base model vocab
if aligned_index not in aligned_blending_model_per_step_index:
aligned_blending_model_per_step_index.append(aligned_index)
aligned_blending_model_per_step_logit.append(blending_logit)
else:
logger.warning(f"blending_t: {blending_t} not in base_model_vocab!")
else:  # no valid aligned mapping was found, fall back to the one-hot logits
aligned_blending_model_per_step_index.append(base_model_vocab[base_token])
aligned_blending_model_per_step_logit.append(1.0)
else:  # one base token maps to multiple blending tokens; in this case keep only the base token and use the one-hot logits
base_token = base_model_tokens[i]
aligned_blending_model_per_step_index.append(base_model_vocab[base_token])
aligned_blending_model_per_step_logit.append(1.0)
aligned_blending_model_per_step_indices.append(aligned_blending_model_per_step_index)
aligned_blending_model_per_step_logits.append(aligned_blending_model_per_step_logit)
elif align_strategy == "greedy_dp":
matches = greedy_dynamic_matching(base_model_tokens, blending_model_tokens, base_model_special_token, blending_model_special_token)
fusion_logits = [[] for _ in range(len(matches))]
fusion_indices = [[] for _ in range(len(matches))]
match_pos = [-1] * len(base_model_tokens)
used = [False] * len(matches)
for idx, ((start_pos_1, end_pos_1), (start_pos_2, end_pos_2)) in enumerate(matches):
fusion_dict = dict()
fusion_counter_dict = dict()
for blending_pos in range(start_pos_2, end_pos_2 + 1):
for blending_logit, blending_index in zip(blending_model_per_step_logits[blending_pos],
blending_model_per_step_indices[blending_pos]):
if blending_index not in fusion_dict:
fusion_dict[blending_index] = 0
fusion_counter_dict[blending_index] = 0
fusion_dict[blending_index] += blending_logit
fusion_counter_dict[blending_index] += 1
for j in range(start_pos_1, end_pos_1 + 1):
match_pos[j] = idx
for token_index, token_logit in fusion_dict.items():
fusion_logits[idx].append(token_logit / fusion_counter_dict[token_index])
fusion_indices[idx].append(token_index)
for i in range(len(base_model_tokens)):
aligned_blending_model_per_step_logit = []
aligned_blending_model_per_step_index = []
if match_pos[i] == -1 or used[match_pos[i]]:
base_token = base_model_tokens[i]
aligned_blending_model_per_step_index.append(base_model_vocab[base_token])
aligned_blending_model_per_step_logit.append(1.0)
else:
pos = match_pos[i]
used[pos] = True
for blending_logit, blending_index in zip(fusion_logits[pos],
fusion_indices[pos]):
# the token corresponds to the logit and indices
blending_t = blending_model_tokenizer.convert_ids_to_tokens([blending_index])[0].replace(
blending_model_special_token, base_model_special_token)
blending_t = blending_to_base_mapping[blending_t]
if blending_t in base_model_vocab:
aligned_index = base_model_vocab[blending_t] # the index of the token in base model vocab
if aligned_index not in aligned_blending_model_per_step_index:
aligned_blending_model_per_step_index.append(aligned_index)
aligned_blending_model_per_step_logit.append(blending_logit)
else:
logger.warning(f"blending_t: {blending_t} not in base_model_vocab!")
aligned_blending_model_per_step_indices.append(aligned_blending_model_per_step_index)
aligned_blending_model_per_step_logits.append(aligned_blending_model_per_step_logit)
else:
raise ValueError(f"{align_strategy} not implemented yet.")
return aligned_blending_model_per_step_logits, aligned_blending_model_per_step_indices
def token_align(
base_model_logits_datasets,
blending_model_logits_dataset,
base_tokenizer,
blending_tokenizer,
blending_to_base_mapping,
blending_model_index,
batch_size=4,
preprocessing_num_workers=4,
skip_align=False,
align_strategy="dtw",
):
assert len(base_model_logits_datasets) == len(blending_model_logits_dataset)
base_model_blending_model_logits_datasets = base_model_logits_datasets.map(
align_blending_model_logits_with_base_model_logits,
batched=True,
batch_size=batch_size,
with_indices=True,
num_proc=preprocessing_num_workers,
load_from_cache_file=True,
fn_kwargs={"blending_examples": blending_model_logits_dataset,
"blending_to_base_mapping": blending_to_base_mapping,
"base_tokenizer": base_tokenizer,
"blending_tokenizer": blending_tokenizer,
"blending_model_index": blending_model_index,
"skip_align": skip_align,
"align_strategy": align_strategy},
keep_in_memory=True,
desc="Align blending model's logits with base model's logits.",
)
return base_model_blending_model_logits_datasets
================================================
FILE: python/fate_llm/algo/fedmkt/token_alignment/vocab_mapping.py
================================================
#
# NOTE: The find_best_mapping function is copied from FuseAI/FuseLLM
# Copyright FuseAI/FuseLLM
#
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import editdistance
import tqdm
import multiprocessing
import logging
from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer
from fate_llm.algo.fedmkt.token_alignment.spectal_token_mapping import TOKENIZER_TO_SPECIAL_TOKEN
logger = logging.getLogger(__name__)
def find_best_mapping(x, base_tokens, blending_model_special_token, base_model_special_token, best_one=True):
"""code refer to https://github.com/fanqiwan/FuseAI/blob/main/FuseLLM/src/utils/vocab_mapping.py#L82"""
tmp_x = x.replace(blending_model_special_token, base_model_special_token)
if tmp_x in base_tokens:
return tmp_x, tmp_x
else:
if best_one:
return tmp_x, min([(y, editdistance.eval(tmp_x, y)) for y in base_tokens], key=lambda d: d[1])[0]
else:
token_and_distance = [(y, editdistance.eval(tmp_x, y)) for y in base_tokens]
min_distance = min(item[1] for item in token_and_distance)
shortest_distance_tokens = [item[0] for item in token_and_distance if item[1] == min_distance]
return tmp_x, shortest_distance_tokens
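# Note: find_best_mapping first rewrites the token's space marker from
# blending_model_special_token to base_model_special_token; if the rewritten token exists in
# base_tokens it is returned as an exact match, otherwise the candidate with the minimum
# edit distance is returned (a single best token when best_one=True, else all tied candidates).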
def get_vocab_mappings(model_name_or_path, candidate_model_name_or_path, vocab_mapping_save_path, num_processors=8):
ori_tokenizer = get_tokenizer(model_name_or_path)
candidate_tokenizer = get_tokenizer(candidate_model_name_or_path)
ori_special_tok = TOKENIZER_TO_SPECIAL_TOKEN[ori_tokenizer.__class__]
candidate_special_tok = TOKENIZER_TO_SPECIAL_TOKEN[candidate_tokenizer.__class__]
candidate_tokens = list(candidate_tokenizer.get_vocab().keys())
with multiprocessing.Pool(num_processors) as process_pool:
func_args = [(tok, candidate_tokens, ori_special_tok, candidate_special_tok) for tok in ori_tokenizer.get_vocab()]
vocab_mappings = dict(tqdm.tqdm(process_pool.starmap(find_best_mapping, func_args),
total=len(ori_tokenizer.get_vocab())))
with open(vocab_mapping_save_path, "w") as fout:
json.dump(vocab_mappings, fout)
return vocab_mappings
================================================
FILE: python/fate_llm/algo/fedmkt/utils/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
================================================
FILE: python/fate_llm/algo/fedmkt/utils/dataset_sync_util.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import datasets
import torch
import torch.distributed as dist
from fate_llm.algo.fedmkt.utils.vars_define import (
METRIC,
PER_STEP_LOGITS,
PER_STEP_INDICES,
)
logger = logging.getLogger(__name__)
def sync_dataset(dataset, local_rank, world_size, device):
integer_keys_2d = ["input_ids", "attention_mask", "labels"]
integer_keys_3d = [PER_STEP_INDICES]
float_keys_3d = [PER_STEP_LOGITS]
float_keys_1d = [METRIC]
if local_rank == 0:
for key in integer_keys_2d + integer_keys_3d + float_keys_3d + float_keys_1d:
if key in integer_keys_2d or key in integer_keys_3d:
dtype = torch.int32
else:
dtype = torch.float64
values = dataset[key]
v_tensor = torch.tensor(values, dtype=dtype).cuda(device)
shape_tensor = torch.tensor(v_tensor.shape, dtype=torch.int32).cuda(device)
shape_tensors = [shape_tensor for _ in range(world_size)]
dist.scatter(shape_tensor, shape_tensors, async_op=False)
v_tensors = [v_tensor for _ in range(world_size)]
dist.scatter(v_tensor, v_tensors, async_op=False)
return dataset
else:
data_dict = dict()
for key in integer_keys_2d + integer_keys_3d + float_keys_3d + float_keys_1d:
if key in integer_keys_2d or key in integer_keys_3d:
dtype = torch.int32
else:
dtype = torch.float64
if key in integer_keys_2d:
shape_tensor = torch.tensor([0, 0], dtype=torch.int32).cuda(device)
elif key in float_keys_3d or key in integer_keys_3d:
shape_tensor = torch.tensor([0, 0, 0], dtype=torch.int32).cuda(device)
else:
shape_tensor = torch.tensor([0], dtype=torch.int32).cuda(device)
dist.scatter(shape_tensor, src=0, async_op=False)
v_tensor = torch.zeros(shape_tensor.tolist(), dtype=dtype).cuda(device)
dist.scatter(v_tensor, src=0, async_op=False)
data_dict[key] = v_tensor.tolist()
return datasets.Dataset.from_dict(data_dict)
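# A minimal usage sketch: sync_dataset is meant to be called on every rank of an already
# initialized torch.distributed group, with rank 0 holding the real dataset and the other
# ranks receiving a reconstructed copy. The backend and device naming below are
# illustrative assumptions.
#
#   dist.init_process_group(backend="nccl")
#   local_rank, world_size = dist.get_rank(), dist.get_world_size()
#   device = torch.device(f"cuda:{local_rank}")
#   ds = pub_dataset if local_rank == 0 else None
#   synced_ds = sync_dataset(ds, local_rank, world_size, device)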
================================================
FILE: python/fate_llm/algo/fedmkt/utils/generate_logit_utils.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch.nn.functional as F
import gc
from fate_llm.algo.fedmkt.utils.vars_define import (
PER_STEP_LOGITS,
PER_STEP_INDICES,
METRIC
)
class Metric(object):
@classmethod
def cal_metric(cls, logits, input_ids, attention_mask, labels, training_args):
if training_args.metric_type == "ce":
return cls.cal_ce(logits, input_ids, attention_mask, labels, training_args)
else:
            raise NotImplementedError(f"metric={training_args.metric_type} is not implemented yet")
@classmethod
def cal_ce(cls, logits, input_ids, attention_mask, labels, training_args):
metric = F.cross_entropy(logits[..., :-1, :].contiguous().view(-1, logits.size(-1)),
labels[..., 1:].contiguous().view(-1), reduction="none").view(logits.size(0), -1)
metric = (metric * attention_mask[..., 1:]).sum(dim=-1) / attention_mask[..., 1:].sum(dim=-1)
return metric
class LogitsSelection(object):
@classmethod
def select_logits(cls, logits, training_args):
if training_args.top_k_strategy == "highest":
return cls.select_highest(logits, training_args.top_k_logits_keep)
else:
            raise NotImplementedError(f"logits selection strategy={training_args.top_k_strategy} is not implemented")
@classmethod
def select_highest(cls, logits, top_k_logits_keep):
        # keep only the k largest logits (and their vocabulary indices) per position
        top_k_logits, top_k_indices = torch.topk(logits.cuda(), k=top_k_logits_keep)
        return top_k_logits, top_k_indices
def generate_pub_data_logits(inputs, model, training_args, data_collator):
input_keys = ["attention_mask", "input_ids", "labels"]
inputs_per_batched = [dict() for _ in range(len(inputs[input_keys[1]]))]
for key in input_keys:
if key not in inputs:
continue
for idx, _in in enumerate(inputs[key]):
inputs_per_batched[idx][key] = _in
if "attention_mask" not in inputs:
for idx in range(len(inputs_per_batched)):
inputs_per_batched[idx]["attention_mask"] = [1] * len(inputs_per_batched[idx]["input_ids"])
inputs_per_batched = data_collator(inputs_per_batched)
input_ids = inputs_per_batched["input_ids"]
attention_mask = inputs_per_batched["attention_mask"]
labels = inputs_per_batched["labels"]
device = next(model.parameters()).device
if device.type == "cuda":
input_ids = input_ids.cuda(device)
attention_mask = attention_mask.cuda(device)
labels = labels.cuda(device)
model.eval()
with torch.no_grad():
logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
metric = Metric.cal_metric(logits, input_ids, attention_mask, labels, training_args)
    # Tensor.cpu() returns a new tensor rather than moving data in place,
    # so keep only the reassigned CPU copies and drop the GPU batch tensors
    del input_ids
    del attention_mask
    del labels
    metric = metric.cpu()
    if training_args.top_k_logits_keep is None:
        raise ValueError("Please specify top_k_logits_keep; saving the full logits may exhaust memory")
    selected_logits, selected_indices = LogitsSelection.select_logits(logits=logits, training_args=training_args)
    selected_logits = selected_logits.cpu()
    selected_indices = selected_indices.cpu()
    inputs[PER_STEP_LOGITS] = selected_logits
    inputs[PER_STEP_INDICES] = selected_indices
    inputs[METRIC] = metric
del logits
gc.collect()
model.train()
return inputs
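# A minimal usage sketch: generate_pub_data_logits can be applied batch-wise over a public
# dataset, e.g. through datasets.Dataset.map. The names pub_dataset, model, training_args
# and collator below are illustrative assumptions.
#
#   pub_dataset = pub_dataset.map(
#       lambda batch: generate_pub_data_logits(batch, model, training_args, collator),
#       batched=True,
#       batch_size=training_args.per_device_eval_batch_size,
#   )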
================================================
FILE: python/fate_llm/algo/fedmkt/utils/tokenizer_tool.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import AutoConfig
def get_vocab_size(tokenizer_name_or_path):
if tokenizer_name_or_path is not None:
        return AutoConfig.from_pretrained(tokenizer_name_or_path).vocab_size
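# A minimal usage sketch; the model name is an illustrative assumption:
#   vocab_size = get_vocab_size("facebook/opt-1.3b")  # reads vocab_size from the model's AutoConfig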
================================================
FILE: python/fate_llm/algo/fedmkt/utils/vars_define.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
PER_STEP_LOGITS = "per_step_logits"
PER_STEP_INDICES = "per_step_indices"
METRIC = "metric"
ALIGNED_OTHER_LOGITS = "aligned_other_logits"
ALIGNED_OTHER_INDICES = "aligned_other_indices"
ALIGNED_OTHER_METRIC = "aligned_other_metrice"
SELF_TARGET_DIST = "llm_target_distribution"
OTHER_TARGET_DIST = "slm_target_distribution"
INPUT_KEYS = {"input_ids", "attention_mask", "labels"}
================================================
FILE: python/fate_llm/algo/inferdpt/__init__.py
================================================
================================================
FILE: python/fate_llm/algo/inferdpt/_encode_decode.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate.arch import Context
from typing import List, Dict
import logging
logger = logging.getLogger(__name__)
class EncoderDecoder(object):
def __init__(self, ctx: Context) -> None:
self.ctx = ctx
def encode(self, docs: List[Dict[str, str]], format_template: str):
pass
    def decode(self, docs: List[Dict[str, str]], format_template: str):
pass
def inference(self, docs: List[Dict[str, str]], inference_kwargs: dict = {}, format_template: str = None):
pass
================================================
FILE: python/fate_llm/algo/inferdpt/inferdpt.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
from jinja2 import Template
from tqdm import tqdm
from fate.arch import Context
from typing import List, Dict, Union
from fate.ml.nn.dataset.base import Dataset
from fate_llm.algo.inferdpt.utils import InferDPTKit
from openai import OpenAI
import logging
from fate_llm.inference.inference_base import Inference
from fate_llm.algo.inferdpt._encode_decode import EncoderDecoder
from fate_llm.dataset.hf_dataset import HuggingfaceDataset
logger = logging.getLogger(__name__)
class InferDPTClient(EncoderDecoder):
def __init__(self, ctx: Context, inferdpt_pertub_kit: InferDPTKit, local_inference_inst: Inference, epsilon: float = 3.0,) -> None:
self.ctx = ctx
self.kit = inferdpt_pertub_kit
assert epsilon > 0, 'epsilon must be a float > 0'
self.ep = epsilon
self.comm_idx = 0
self.local_inference_inst = local_inference_inst
def encode(self, docs: List[Dict[str, str]], format_template: str = None, verbose=False, perturb_doc_key: str ='perturbed_doc') -> List[Dict[str, str]]:
copy_docs = copy.deepcopy(docs)
if format_template is not None:
template = Template(format_template)
else:
template = None
for doc in tqdm(copy_docs):
if template is None:
rendered_doc = str(doc)
else:
rendered_doc = template.render(**doc)
if verbose:
logger.debug('doc to perturb {}'.format(rendered_doc))
p_doc = self.kit.perturb(rendered_doc, self.ep)
doc[perturb_doc_key] = p_doc
return copy_docs
def _remote_inference(self, docs: List[Dict[str, str]],
inference_kwargs: dict = {},
format_template: str = None,
perturbed_response_key: str = 'perturbed_response',
verbose=False
) -> List[Dict[str, str]]:
copy_docs = copy.deepcopy(docs)
if format_template is not None:
template = Template(format_template)
else:
template = None
infer_docs = []
for doc in tqdm(copy_docs):
if template is None:
rendered_doc = str(doc)
else:
rendered_doc = template.render(**doc)
if verbose:
logger.debug('inference doc {}'.format(rendered_doc))
infer_docs.append(rendered_doc)
doc['perturbed_doc_with_instrcution'] = rendered_doc
self.ctx.arbiter.put('client_data_{}'.format(self.comm_idx), (infer_docs, inference_kwargs))
perturb_resp = self.ctx.arbiter.get('pdoc_{}'.format(self.comm_idx))
self.comm_idx += 1
for pr, doc in zip(perturb_resp, copy_docs):
doc[perturbed_response_key] = pr
return copy_docs
def decode(self, p_docs: List[Dict[str, str]], instruction_template: str = None, decode_template: str = None, verbose=False,
perturbed_response_key: str = 'perturbed_response', result_key: str = 'inferdpt_result',
remote_inference_kwargs: dict = {}, local_inference_kwargs: dict = {}):
# inference using remote large models
docs_with_infer_result = self._remote_inference(p_docs, format_template=instruction_template, verbose=verbose, inference_kwargs=remote_inference_kwargs, perturbed_response_key=perturbed_response_key)
if decode_template is not None:
dt = Template(decode_template)
doc_to_decode = [dt.render(**i) for i in docs_with_infer_result]
else:
doc_to_decode = [str(i) for i in docs_with_infer_result]
# local model decode
final_result = self.local_inference_inst.inference(doc_to_decode, local_inference_kwargs)
for final_r, d in zip(final_result, docs_with_infer_result):
d[result_key] = final_r
return docs_with_infer_result
def inference(self, docs: Union[List[Dict[str, str]], HuggingfaceDataset],
encode_template: str,
instruction_template: str,
decode_template: str,
verbose: bool = False,
remote_inference_kwargs: dict = {},
local_inference_kwargs: dict = {},
perturb_doc_key: str = 'perturbed_doc',
perturbed_response_key: str = 'perturbed_response',
result_key: str = 'inferdpt_result',
) -> List[Dict[str, str]]:
assert (isinstance(docs, list) and isinstance(docs[0], dict)) or isinstance(docs, HuggingfaceDataset), 'Input doc must be a list of dict or HuggingfaceDataset'
# perturb doc
if isinstance(docs, HuggingfaceDataset):
docs = [docs[i] for i in range(len(docs))]
docs_with_p = self.encode(docs, format_template=encode_template, verbose=verbose, perturb_doc_key=perturb_doc_key)
logger.info('encode done')
# inference using perturbed doc
final_result = self.decode(
docs_with_p,
instruction_template,
decode_template,
verbose,
perturbed_response_key,
result_key,
remote_inference_kwargs,
local_inference_kwargs,
)
logger.info('decode done')
return final_result
class InferDPTServer(object):
def __init__(self, ctx: Context, inference_inst: Inference) -> None:
self.ctx = ctx
self.inference_inst = inference_inst
self.comm_idx = 0
def inference(self, verbose=False):
client_data = self.ctx.guest.get('client_data_{}'.format(self.comm_idx))
perturbed_docs, inference_kwargs = client_data
if verbose:
logger.info('got data {}'.format(client_data))
logger.info('start inference')
rs_doc = self.inference_inst.inference(perturbed_docs, inference_kwargs)
self.ctx.guest.put('pdoc_{}'.format(self.comm_idx), rs_doc)
self.comm_idx += 1
def predict(self):
self.inference()
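# A minimal usage sketch of the client/server pair; the document templates and the docs list
# below are illustrative assumptions, and each side runs in its own party.
#
#   # guest (client) side
#   results = inferdpt_client.inference(
#       docs,
#       encode_template="{{doc}}",
#       instruction_template="Continue the text: {{perturbed_doc}}",
#       decode_template="{{perturbed_response}}",
#   )
#   # arbiter (server) side, called once per client inference round
#   inferdpt_server.inference()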
================================================
FILE: python/fate_llm/algo/inferdpt/init/_init.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate.arch import Context
from typing import Union
class InferInit(object):
def __init__(self, ctx: Context):
self.ctx = ctx
def get_inst(self):
pass
================================================
FILE: python/fate_llm/algo/inferdpt/init/default_init.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_llm.algo.inferdpt.init._init import InferInit
from fate_llm.inference.api import APICompletionInference
from fate_llm.algo.inferdpt import inferdpt
from fate_llm.algo.inferdpt.utils import InferDPTKit
from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer
class InferDPTAPIClientInit(InferInit):
api_url = ''
api_model_name = ''
api_key = 'EMPTY'
inferdpt_kit_path = ''
eps = 3.0
def __init__(self, ctx):
super().__init__(ctx)
self.ctx = ctx
    def get_inst(self) -> InferDPTClient:
inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)
kit = InferDPTKit.load_from_path(self.inferdpt_kit_path)
inferdpt_client = inferdpt.InferDPTClient(self.ctx, kit, inference, epsilon=self.eps)
return inferdpt_client
class InferDPTAPIServerInit(InferInit):
api_url = ''
api_model_name = ''
api_key = 'EMPTY'
def __init__(self, ctx):
super().__init__(ctx)
self.ctx = ctx
    def get_inst(self) -> InferDPTServer:
        inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)
        inferdpt_server = inferdpt.InferDPTServer(self.ctx, inference_inst=inference)
return inferdpt_server
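# A minimal usage sketch: these init classes are intended to be subclassed so that the API
# endpoint and InferDPT kit path can be filled in; the concrete values below are
# illustrative assumptions.
#
#   class MyClientInit(InferDPTAPIClientInit):
#       api_url = "http://127.0.0.1:8000/v1"
#       api_model_name = "my-remote-model"
#       inferdpt_kit_path = "/path/to/kit_folder"
#       eps = 3.0
#
#   inferdpt_client = MyClientInit(ctx).get_inst()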
================================================
FILE: python/fate_llm/algo/inferdpt/utils.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Parts of the codes are modified from https://github.com/mengtong0110/InferDPT
"""
from decimal import getcontext
from transformers import AutoTokenizer
import numpy as np
import json
import tqdm
from typing import List
getcontext().prec = 100
class NumpyEncoder(json.JSONEncoder):
""" Special json encoder for numpy types """
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)
def save_jsonl(filename, data):
with open(filename, 'w') as file:
for item in data:
json.dump(item, file)
file.write('\n')
def create_sensitivity_of_embeddings(all_embedding_matrix):
n_dimensions = all_embedding_matrix.shape[1]
delta_f_new = np.zeros(n_dimensions)
for dim in tqdm.trange(n_dimensions):
dim_data = all_embedding_matrix[:, dim]
sorted_dim_data = np.sort(dim_data)
differences = sorted_dim_data[-1] - sorted_dim_data[0]
delta_f_new[dim] = differences
return delta_f_new
def create_sorted_embedding_matrix(token_list, similarity_matrix):
token_2_sorted_distances = dict()
token_array = np.array(token_list)
for idx, token in tqdm.tqdm(enumerate(token_list)):
similarity_array = similarity_matrix[idx]
sorted_indices = np.argsort(similarity_array)[::-1]
token_2_sorted_distances[token] = [token_array[sorted_indices].tolist(), similarity_array[sorted_indices].tolist()]
return token_2_sorted_distances
def cosine_similarity_vectors(A, B):
dot_product = np.dot(A, B)
norm_a = np.linalg.norm(A)
norm_b = np.linalg.norm(B)
similarity = dot_product / (norm_a * norm_b)
return similarity
class InferDPTKit(object):
def __init__(self, token_to_vector_dict, sorted_similarities, delta_f, tokenizer) -> None:
self.token_to_vector_dict = token_to_vector_dict
self.sorted_similarities = sorted_similarities
self.delta_f = delta_f
self.tokenizer = tokenizer
assert len(token_to_vector_dict) == len(sorted_similarities)
def save_to_path(self, path):
# make folder
import os
if not os.path.exists(path+'/inferdpt_kit'):
os.makedirs(path+'/inferdpt_kit')
with open(path+'/inferdpt_kit/token_2_vector.json', 'w', encoding='utf8') as f:
json.dump(self.token_to_vector_dict, f, ensure_ascii=False, cls=NumpyEncoder)
with open(path+'/inferdpt_kit/sorted_similarities.json', 'w') as f:
json.dump(self.sorted_similarities, f, cls=NumpyEncoder)
with open(path+'/inferdpt_kit/delta_f.json', 'w') as f:
json.dump(self.delta_f, f, cls=NumpyEncoder)
self.tokenizer.save_pretrained(path+'/inferdpt_kit/tokenizer/')
@staticmethod
def make_inferdpt_kit_param(embedding_matrix: np.ndarray, token_list: List[str]):
def cosine_simi(embedding_matrix1, embedding_matrix2):
dot_product = np.dot(embedding_matrix1, embedding_matrix2.T)
norm_matrix1 = np.linalg.norm(embedding_matrix1, axis=1)
norm_matrix2 = np.linalg.norm(embedding_matrix2, axis=1)
similarity = dot_product / (np.outer(norm_matrix1, norm_matrix2))
return similarity
assert len(embedding_matrix) == len(token_list)
similarity_matrix = cosine_simi(embedding_matrix, embedding_matrix)
token_sorted_distance_dict = create_sorted_embedding_matrix(token_list, similarity_matrix)
delta_f_new = create_sensitivity_of_embeddings(embedding_matrix)
token_2_embedding = {}
for token, embedding in zip(token_list, embedding_matrix):
token_2_embedding[token] = embedding
return token_2_embedding, token_sorted_distance_dict, delta_f_new
@staticmethod
def load_from_path(path):
with open(path+'/inferdpt_kit/token_2_vector.json', 'r', encoding='utf8') as f:
token_to_vector_dict = json.load(f)
with open(path+'/inferdpt_kit/sorted_similarities.json', 'r') as f:
sorted_similarities = json.load(f)
with open(path+'/inferdpt_kit/delta_f.json', 'r') as f:
delta_f = np.array(json.load(f))
tokenizer = AutoTokenizer.from_pretrained(path+'/inferdpt_kit/tokenizer/')
inferdpt_kit = InferDPTKit(token_to_vector_dict, sorted_similarities, delta_f, tokenizer)
return inferdpt_kit
def perturb(self, doc: str, epsilon: float) -> str:
# epsilon > 0
assert epsilon > 0, "epsilon should be greater than 0"
tokenizer = self.tokenizer
tokens = tokenizer.tokenize(doc)
new_tokens = []
Delta_u = 1.0
exp_factor = epsilon / (2 * Delta_u)
for origin_token in tokens:
if origin_token[0] == ' ':
origin_token = origin_token[1:]
origin_embed = self.token_to_vector_dict.get(origin_token, None)
if origin_embed is None:
new_tokens.append(origin_token)
continue
noise_embed = add_laplace_noise_to_vector(origin_embed, epsilon, self.delta_f)
similarity = cosine_similarity_vectors(origin_embed, noise_embed)
sorted_distances_for_token = self.sorted_similarities.get(origin_token, None)
if sorted_distances_for_token is None:
continue
token_only = sorted_distances_for_token[0]
similarity_only = sorted_distances_for_token[1]
arr = np.flip(similarity_only)
index = np.searchsorted(arr, similarity)
index = len(arr) - index
close_tokens = token_only[:index]
close_similarities = similarity_only[:index]
if len(close_tokens) == 0:
continue
unnormalized_probabilities = np.exp(exp_factor * np.array(close_similarities))
total_unnormalized_prob = np.sum(unnormalized_probabilities)
probabilities = unnormalized_probabilities / total_unnormalized_prob
selected_token = np.random.choice(close_tokens, p=probabilities)
new_tokens.append(selected_token)
token_ids = tokenizer.convert_tokens_to_ids(new_tokens)
sentence = tokenizer.decode(token_ids)
return sentence
def cosine_similarity_vectors(A, B):
dot_product = np.dot(A, B)
norm_a = np.linalg.norm(A)
norm_b = np.linalg.norm(B)
similarity = dot_product / (norm_a * norm_b)
return similarity
def add_laplace_noise_to_vector(vector, epsilon, delta_f_new):
vector = np.asarray(vector, dtype=np.longdouble)
if epsilon == 0:
beta_values = delta_f_new * 0
else:
beta_values = delta_f_new / (0.5 * epsilon)
noise = np.random.laplace(loc=0, scale=beta_values, size=len(beta_values))
noisy_vector = vector + noise
return noisy_vector
def perturb_sentence(sent,
epsilon,
tokenizer,
token_to_vector_dict,
sorted_distance_data,
delta_f_new):
tokens = tokenizer.tokenize(sent)
new_tokens = []
Delta_u = 1.0
exp_factor = epsilon / (2 * Delta_u)
for origin_token in tokens:
if origin_token[0] == ' ':
origin_token = origin_token[1:]
origin_embed = token_to_vector_dict.get(origin_token, None)
if origin_embed is None:
new_tokens.append(origin_token)
continue
noise_embed = add_laplace_noise_to_vector(origin_embed, epsilon, delta_f_new)
similarity = cosine_similarity_vectors(origin_embed, noise_embed)
sorted_distances_for_token = sorted_distance_data.get(origin_token, None)
if sorted_distances_for_token is None:
continue
token_only = sorted_distances_for_token[0]
similarity_only = sorted_distances_for_token[1]
arr = np.flip(similarity_only)
index = np.searchsorted(arr, similarity)
index = len(arr) - index
close_tokens = token_only[:index]
close_similarities = similarity_only[:index]
if len(close_tokens) == 0:
continue
unnormalized_probabilities = np.exp(exp_factor * np.array(close_similarities))
total_unnormalized_prob = np.sum(unnormalized_probabilities)
probabilities = unnormalized_probabilities / total_unnormalized_prob
selected_token = np.random.choice(close_tokens, p=probabilities)
new_tokens.append(selected_token)
token_ids = tokenizer.convert_tokens_to_ids(new_tokens)
sentence = tokenizer.decode(token_ids)
return sentence
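# A minimal usage sketch: building a kit from a model's input-embedding matrix and using it
# to perturb text. The model/tokenizer objects and paths below are illustrative assumptions.
#
#   emb_matrix = model.get_input_embeddings().weight.detach().cpu().numpy()
#   token_list = list(tokenizer.get_vocab().keys())
#   tok2vec, sorted_dist, delta_f = InferDPTKit.make_inferdpt_kit_param(emb_matrix, token_list)
#   kit = InferDPTKit(tok2vec, sorted_dist, delta_f, tokenizer)
#   kit.save_to_path("/path/to/save")
#   perturbed = kit.perturb("the patient reports mild symptoms", epsilon=3.0)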
================================================
FILE: python/fate_llm/algo/offsite_tuning/__init__.py
================================================
================================================
FILE: python/fate_llm/algo/offsite_tuning/offsite_tuning.py
================================================
from fate.ml.aggregator.base import Aggregator
from fate_llm.algo.fedavg.fedavg import Seq2SeqFedAVGClient, Seq2SeqFedAVGServer, Seq2SeqTrainingArguments
from fate.ml.nn.trainer.trainer_base import FedArguments, TrainingArguments
from typing import List, Optional, Callable, Tuple
from fate.arch import Context
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import _LRScheduler
from transformers.trainer_callback import TrainerCallback
from torch.nn import Module
from transformers import TrainerState, TrainerControl, PreTrainedTokenizer
from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningBaseModel
import logging
import torch
import torch.distributed as dist
from transformers.modeling_utils import unwrap_model
logger = logging.getLogger(__name__)
class OffsiteTuningTrainerClient(Seq2SeqFedAVGClient):
def __init__(
self,
ctx: Context,
model: OffsiteTuningBaseModel,
training_args: Seq2SeqTrainingArguments,
fed_args: FedArguments,
train_set: Dataset,
val_set: Dataset = None,
optimizer: Optimizer = None,
scheduler: _LRScheduler = None,
data_collator: Callable = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
callbacks: List[TrainerCallback] = [],
compute_metrics: Callable = None,
aggregate_model: bool = False,
save_trainable_weights_only: bool = False,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
assert isinstance(model, OffsiteTuningBaseModel), "model must be the subclass of OffsiteTuningBaseModel"
if aggregate_model == False and fed_args is None:
fed_args = FedArguments()
elif fed_args is None:
raise ValueError("fed_args must be provided when aggregate_model is True")
        local_mode = not aggregate_model
super().__init__(
ctx,
model,
training_args,
fed_args,
train_set,
val_set,
optimizer,
scheduler,
data_collator,
tokenizer,
callbacks,
compute_metrics,
local_mode,
save_trainable_weights_only,
preprocess_logits_for_metrics
)
self._aggregate_model = aggregate_model
def _share_model(self, model, args: Seq2SeqTrainingArguments, sync_trainable_only=True):
if args.local_rank == 0:
for p in model.parameters():
if (not sync_trainable_only) or (sync_trainable_only and p.requires_grad):
scatter_list = [p.data for _ in range(args.world_size)]
dist.scatter(p.data, scatter_list, async_op=False)
else:
for p in model.parameters():
if (not sync_trainable_only) or (sync_trainable_only and p.requires_grad):
dist.scatter(p.data, src=0, async_op=False)
def on_train_begin(self, ctx: Context, aggregator: Aggregator, fed_args: FedArguments,
args: TrainingArguments, model: Module = None, optimizer: Optimizer = None, scheduler: _LRScheduler = None,
dataloader: Tuple[DataLoader]= None, control: TrainerControl= None,
state: TrainerState = None, **kwargs):
if args.local_rank == 0: # master
            logger.info('receiving weights from server')
parameters_to_get = ctx.arbiter.get('sub_model_para')
model = unwrap_model(model)
model.load_submodel_weights(parameters_to_get)
            logger.info('received submodel weights from the server')
if args.world_size > 1:
self._share_model(model, args)
logger.info('sharing model parameters done')
else:
if args.world_size > 1:
model = unwrap_model(model)
self._share_model(model, args)
logger.info('sharing model parameters done')
def on_federation(
self,
ctx: Context,
aggregator,
fed_args: FedArguments,
args: TrainingArguments,
model: Optional[OffsiteTuningBaseModel] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[_LRScheduler] = None,
dataloader: Optional[Tuple[DataLoader]] = None,
control: Optional[TrainerControl] = None,
state: Optional[TrainerState] = None,
**kwargs,
):
if self._aggregate_model:
aggregator.model_aggregation(ctx, model)
def on_train_end(self, ctx: Context, aggregator: Aggregator, fed_args: FedArguments,
args: TrainingArguments, model: OffsiteTuningBaseModel = None, optimizer: Optimizer = None, scheduler: _LRScheduler = None,
dataloader: Tuple[DataLoader]= None, control: TrainerControl= None,
state: TrainerState = None, **kwargs):
if args.local_rank == 0:
if args.world_size > 1:
model = unwrap_model(model)
return_weights = model.get_submodel_weights(with_emulator=False)
ctx.arbiter.put('trained_sub_model_para', return_weights)
logger.info('weights sent back to the server')
def init_aggregator(self, ctx: Context, fed_args: FedArguments):
if self._aggregate_model:
return super().init_aggregator(ctx, fed_args)
else:
return None
class OffsiteTuningTrainerServer(Seq2SeqFedAVGServer):
def __init__(self, ctx: Context, model: OffsiteTuningBaseModel, aggregate_model=False) -> None:
self._aggregate_model = aggregate_model
super().__init__(ctx, local_mode=False)
assert isinstance(model, OffsiteTuningBaseModel), "model must be the subclass of OffsiteTuningBaseModel"
self.model = model
def on_train_begin(self, ctx: Context, aggregator: Aggregator):
logger.info('sending weights to clients')
parameters_to_send = self.model.get_submodel_weights()
ctx.guest.put('sub_model_para', parameters_to_send)
if any(p.role=='host' for p in ctx.parties):
ctx.hosts.put('sub_model_para', parameters_to_send)
def on_train_end(self, ctx: Context, aggregator: Aggregator):
parameters_to_get = ctx.guest.get('trained_sub_model_para')
self.model.load_submodel_weights(parameters_to_get, with_emulator=False)
        logger.info('received trained submodel weights from the client')
def on_federation(self, ctx: Context, aggregator, agg_iter_idx: int):
if self._aggregate_model:
aggregator.model_aggregation(ctx)
else:
logger.info('skip aggregation')
def init_aggregator(self, ctx):
if self._aggregate_model:
return super().init_aggregator(ctx)
else:
return None
def train(self):
if self._aggregate_model:
super().train()
else:
            # no local training here: just send the submodel weights to the client,
            # then collect the trained submodel weights back from the client
self.on_init_end(self.ctx, aggregator=self.aggregator)
self.on_train_begin(self.ctx, aggregator=self.aggregator)
self.on_train_end(self.ctx, aggregator=self.aggregator)
def save_model(
self,
output_dir: Optional[str] = None,
state_dict=None
):
import torch
import os
if not os.path.exists(output_dir):
os.makedirs(output_dir)
torch.save(self.model.state_dict(), output_dir + '/pytorch_model.bin')
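# A minimal usage sketch for the client side; the objects below (ctx, offsite_model,
# training_args, train_set, collator) are illustrative assumptions. With
# aggregate_model=False the client only exchanges submodel weights with the server.
#
#   trainer = OffsiteTuningTrainerClient(
#       ctx=ctx,
#       model=offsite_model,          # a subclass of OffsiteTuningBaseModel
#       training_args=training_args,  # Seq2SeqTrainingArguments
#       fed_args=None,
#       train_set=train_set,
#       data_collator=collator,
#       aggregate_model=False,
#   )
#   trainer.train()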
================================================
FILE: python/fate_llm/algo/ppc-gpt/__init__.py
================================================
================================================
FILE: python/fate_llm/data/__init__.py
================================================
================================================
FILE: python/fate_llm/data/data_collator/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
================================================
FILE: python/fate_llm/data/data_collator/cust_data_collator.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers.data import data_collator
from ..tokenizers.cust_tokenizer import get_tokenizer
def get_data_collator(data_collator_name,
tokenizer_name_or_path=None,
pad_token=None,
bos_token=None,
eos_token=None,
pad_token_id=None,
bos_token_id=None,
eos_token_id=None,
trust_remote_code=False, **kwargs):
if not hasattr(data_collator, data_collator_name):
support_collator_list = list(filter(lambda module_name: "collator" in module_name.lower(), dir(data_collator)))
        raise ValueError(f"data_collator's name={data_collator_name} is not in the supported list={support_collator_list}")
tokenizer = get_tokenizer(tokenizer_name_or_path=tokenizer_name_or_path,
pad_token=pad_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
trust_remote_code=trust_remote_code)
return getattr(data_collator, data_collator_name)(tokenizer, **kwargs)
def get_seq2seq_data_collator(tokenizer_name_or_path, **kwargs):
return get_data_collator("DataCollatorForSeq2Seq", tokenizer_name_or_path=tokenizer_name_or_path, **kwargs)
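# A minimal usage sketch: the helper resolves any collator class exported by
# transformers.data.data_collator by its class name. The tokenizer path is an
# illustrative assumption.
#
#   collator = get_seq2seq_data_collator("/path/to/tokenizer", label_pad_token_id=-100)
#   # equivalent to
#   collator = get_data_collator("DataCollatorForSeq2Seq",
#                                tokenizer_name_or_path="/path/to/tokenizer",
#                                label_pad_token_id=-100)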
================================================
FILE: python/fate_llm/data/data_collator/fedcot_collator.py
================================================
from transformers import DataCollatorForSeq2Seq
from transformers import AutoTokenizer
import pandas as pd
class PrefixDataCollator(DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None):
features_df = pd.DataFrame(features)
cot = super().__call__(list(features_df['predict']), return_tensors)
label = super().__call__(list(features_df['rationale']), return_tensors)
return {
'predict': cot,
'rationale': label
}
def get_prefix_data_collator(tokenizer_name_or_path):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
data_collator = PrefixDataCollator(tokenizer)
return data_collator
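# A minimal usage sketch: the prefix collator expects features that already carry tokenized
# 'predict' and 'rationale' items (as produced by PrefixDataset). The dataset object and
# tokenizer path below are illustrative assumptions.
#
#   collator = get_prefix_data_collator("/path/to/tokenizer")
#   batch = collator([prefix_dataset.get_tokenized_item(i) for i in range(4)])
#   # batch['predict'] and batch['rationale'] are padded Seq2Seq batches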
================================================
FILE: python/fate_llm/data/tokenizers/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
================================================
FILE: python/fate_llm/data/tokenizers/cust_tokenizer.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import AutoTokenizer
def get_tokenizer(
tokenizer_name_or_path,
trust_remote_code=False,
padding_side=None,
pad_token=None,
bos_token=None,
eos_token=None,
pad_token_id=None,
bos_token_id=None,
eos_token_id=None,
add_eos_token=True,
):
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name_or_path,
trust_remote_code=trust_remote_code,
add_eos_token=add_eos_token
)
if padding_side is not None:
tokenizer.padding_side = padding_side
if pad_token is not None:
tokenizer.add_special_tokens({'pad_token': pad_token})
if bos_token is not None:
tokenizer.add_special_tokens({'bos_token': bos_token})
if eos_token is not None:
tokenizer.add_special_tokens({"eos_token": eos_token})
if pad_token_id is not None:
tokenizer.pad_token_id = pad_token_id
if bos_token_id is not None:
tokenizer.bos_token_id = bos_token_id
if eos_token_id is not None:
tokenizer.eos_token_id = eos_token_id
if "llama" in tokenizer_name_or_path.lower() or "gpt2" in tokenizer_name_or_path.lower():
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
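# A minimal usage sketch; the tokenizer path is an illustrative assumption. For llama/gpt2
# style tokenizers the pad token automatically falls back to the eos token.
#
#   tokenizer = get_tokenizer("/path/to/gpt2", padding_side="left")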
================================================
FILE: python/fate_llm/dataset/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
================================================
FILE: python/fate_llm/dataset/data_config/__init__.py
================================================
import os
# absolute path to current directory
parent_dir = os.path.dirname(os.path.realpath(__file__))
DATA_CONFIG_TEMPLATE = {"ag_news": os.path.join(parent_dir, "default_ag_news.yaml"),
"yelp_review": os.path.join(parent_dir, "default_yelp_review.yaml"),}
================================================
FILE: python/fate_llm/dataset/data_config/default_ag_news.yaml
================================================
dataset_kwargs:
data_files: ag_news_review/AGnews/train.json
dataset_path: json
doc_to_target: '{{label}}'
metric_list:
- aggregation: mean
higher_is_better: true
metric: accuracy
output_type: generate_until
task: ag-news
validation_split: train
label_key: label
text_key: text
sub_domain: AGnews
few_shot_num_per_label: 2
tokenize_format: "Product type: {{sub_domain}} | Text Category: {{label}}"
few_shot_format: "- : {{label}}.\n- : {{text}}\n\n"
augment_format: "The news' topics belong to the following 4 categories: 0.world 1.sports 2.business 3.science and technology. Please generate news according to the following format, bearing in mind that the generated results should not resemble the examples, but should align with the specified category: \n"
text_with_label_format: "******\n {{i}}.\nNews: {{text}}\nCategory: {{label}}.\n"
filter_format: "I will give you some news samples with their categories, The news' topics belong to the following 4 categories: 0.world 1.sports 2.business 3.science and technology. the samples are delimited by '******':\n {{text_with_label}} Please filter out texts that are ambiguous, do not belong to news or do not meet the categories, and leave news texts that meet the categories.\n You should also filter out news text that are too similar to other samples and keep the most representative ones. Your answer should begin with 'The eligible samples:\n\n' and the indexes of the texts you choose, use spaces to separate the indexes and do not provide duplicate indices or indices that exceed the maximum index of samples."
label_list:
- 'world'
- 'sports'
- 'business'
- 'science and technology'
================================================
FILE: python/fate_llm/dataset/data_config/default_yelp_review.yaml
================================================
dataset_kwargs:
data_files: yelp_review/Health/train.json
dataset_path: json
doc_to_target: '{{label}}'
metric_list:
- aggregation: mean
higher_is_better: true
metric: accuracy
output_type: generate_until
task: yelp-review
label_key: stars
text_key: text
validation_split: train
sub_domain: Health
few_shot_num_per_label: 2
tokenize_format: "Product type: {{sub_domain}} | Review Score: {{label}}"
text_with_label_format: "******\n {{i}}.\nReview: {{text}}\nRating stars: {{label}}.\n"
few_shot_format: "******\n- : {{label}} stars.\n- : {{text}}\n\n"
augment_format: "The reviews are rated from 1 to 5 stars, with 1 being the worst, 3 being neutral and 5 being the best. Please generate more similar samples for each rating star about the Health domain as shown in the following format, bearing in mind that the generated results should not copy or resemble the examples, and should align with the {{sub_domain}} domain and the rating stars.\nThe examples are delimited by '******'."
filter_format: "I will give you some customer review text samples with their rating stars, these samples are indexed starting from 0, the samples are delimited by '******':\n {{text_with_label}}. These reviews gradually shift from negative to positive from 1 star to 5 stars. 1 star represents the worst, 2 stars are better than 1 star, but still indicate a negative review. 3 stars represent a neutral review. 4 stars indicate a positive review, but less positive than 5 stars. 5 stars represent perfection.\n Please filter out text that does not belong to customer reviews or does not meet the rating stars, and leave review texts that meet the labels.\n You should also filter out text that are too similar to other samples and keep the most representative ones. Your answer should begin with 'The eligible samples:\n\n' and the indexes of the texts you choose, use spaces to separate the indexes and do not provide duplicate indices or indices that exceed the maximum index of samples."
label_list:
- 1
- 2
- 3
- 4
- 5
================================================
FILE: python/fate_llm/dataset/fedcot_dataset.py
================================================
from fate_llm.dataset.input_output_dataset import InputOutputDataset
from transformers.trainer_pt_utils import LabelSmoother
from typing import List, Dict, Union, Literal
import logging
from jinja2 import Template
from transformers import AutoTokenizer
logger = logging.getLogger(__name__)
class PrefixDataset(InputOutputDataset):
def __init__(self,
tokenizer_path,
predict_input_template: str,
predict_output_template: str,
rationale_input_template: str,
rationale_output_template: str,
max_input_length: int = 256,
max_target_length: int = 256,
load_from: Literal['jsonl', 'hf_load_from_disk', 'hf_load_dataset'] = 'hf_load_from_disk',
split_key: str = None
):
super().__init__(tokenizer_path, predict_input_template, predict_output_template, max_input_length, max_target_length, load_from, split_key)
self.r_input_template = Template(rationale_input_template)
self.r_output_template = Template(rationale_output_template)
def load_rationale(self, result_list, key='rationale'):
for d, r in zip(self.dataset, result_list):
d[key] = r
def get_str_item(self, i) -> dict:
data_item = self.dataset[i]
p_in = self.input_template.render(data_item)
p_out = self.output_template.render(data_item)
r_in = self.r_input_template.render(data_item)
r_out = self.r_output_template.render(data_item)
ret_dict = {
'predict':{
'input': p_in,
'output': p_out
},
'rationale':{
'input': r_in,
'output': r_out
}
}
return ret_dict
def get_tokenized_item(self, i) -> dict:
str_item = self.get_str_item(i)
ret_dict = {
'predict': self._process_item(str_item['predict']),
'rationale': self._process_item(str_item['rationale'])
}
return ret_dict
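# A minimal usage sketch: the four Jinja templates render the 'predict' and 'rationale'
# views of each record. The template strings, field names and paths below are illustrative
# assumptions.
#
#   ds = PrefixDataset(
#       tokenizer_path="/path/to/tokenizer",
#       predict_input_template="{{question}}",
#       predict_output_template="{{answer}}",
#       rationale_input_template="{{question}} Let's think step by step.",
#       rationale_output_template="{{rationale}}",
#       load_from="jsonl",
#   )
#   ds.load("/path/to/train.jsonl")
#   item = ds.get_tokenized_item(0)  # {'predict': {...}, 'rationale': {...}}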
================================================
FILE: python/fate_llm/dataset/flex_dataset.py
================================================
#
# Copyright 2024 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import pickle
import re
from datasets import load_dataset
from fastchat.model import get_conversation_template
from jinja2 import Template
from ruamel import yaml
from transformers import AutoTokenizer
from typing import Union, Literal
from fate.ml.nn.dataset.base import Dataset
from fate_llm.dataset.data_config import DATA_CONFIG_TEMPLATE
logger = logging.getLogger(__name__)
"""
Implementation of the FDKT augmentation process, adapted from https://arxiv.org/abs/2405.14212
"""
def get_jinjax_placeholders(jinjax_text, placeholder_count=2):
pattern = r"<([^>]+)>"
matches = re.findall(pattern, jinjax_text)
return matches[:placeholder_count]
def regex_replace(string, pattern, repl, count: int = 0):
"""
    adapted from lm-evaluation-harness/lm-eval/utils.py for offline use
Parameters
----------
string
pattern
repl
count
Returns
-------
"""
return re.sub(pattern, repl, string, count=count)
def apply_template(template, data):
"""
    adapted from lm-evaluation-harness/lm-eval/utils.py for offline use
Parameters
----------
template
data
Returns
-------
"""
return Template(template).render(data)
def tokenize_flex_dataset(raw_datasets, tokenizer, sub_domain, tokenize_format, text_key, label_key, data_part="train",
save_path=None, max_prompt_len=256):
tokenizer.pad_token = tokenizer.eos_token
column_names = raw_datasets[data_part].column_names
def tokenize_function(examples):
texts = tokenizer(examples[text_key])
label_processed = [apply_template(tokenize_format,{"sub_domain": sub_domain,"label": label})
for label in examples[label_key]]
labels = tokenizer(label_processed)
input_ids = [i2 + i1 for i1, i2 in zip(texts['input_ids'], labels['input_ids'])]
attention_mask = [i2 + i1 for i1, i2 in zip(texts['attention_mask'], labels['attention_mask'])]
"""
cut off max prompt length
"""
input_ids = [t[: max_prompt_len] for t in input_ids]
attention_mask = [t[: max_prompt_len] for t in attention_mask]
out = {"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": input_ids}
return out
tokenized_datasets = raw_datasets.map(
tokenize_function,
batched=True,
num_proc=4,
remove_columns=column_names,
desc="Running tokenizer on dataset",
)
if save_path is not None:
tokenized_datasets.save_to_disk(save_path)
return tokenized_datasets
class FlexDataset(Dataset):
def __init__(self,
tokenizer_name_or_path,
dataset_name: str,
load_from: Literal['json'] = 'json',
data_part: str = None,
config: Union[dict, str] = None,
need_preprocess: bool = True,
random_state: int = None,
max_prompt_len: int = 256,
select_num: int = None,
few_shot_num_per_label: int = None
):
super().__init__()
self.tokenizer = None
self.tokenizer_name_or_path = tokenizer_name_or_path
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=True)
self.dataset_name = dataset_name
if self.dataset_name and config is None:
config = DATA_CONFIG_TEMPLATE.get(self.dataset_name, "")
self.load_from = load_from
self.data_part = data_part
self.random_state = random_state
self.need_preprocess = need_preprocess
self.max_prompt_len = max_prompt_len
self.select_num = select_num
self.dataset = None
self.ds = None
self.label_key = None
self.text_key = None
self.augment_format = None
self.filter_format = None
self.few_shot_format = None
self.tokenize_format = None
self.sub_domain = None
self.label_list = None
self.text_with_label_format = None
self.few_shot_num_per_label = few_shot_num_per_label
self.config = config
if isinstance(config, str):
with open(config, 'r') as f:
self.config = yaml.safe_load(f)
self.parse_config()
def parse_config(self, config=None):
if config is None:
config = self.config
self.label_key = config.get("label_key", None)
self.text_key = config.get("text_key", None)
self.augment_format = config.get("augment_format", None)
self.filter_format = config.get("filter_format", None)
self.tokenize_format = config.get("tokenize_format", None)
self.sub_domain = config.get("sub_domain", None)
self.label_list = config.get("label_list", None)
self.few_shot_format = config.get("few_shot_format", None)
self.text_with_label_format = config.get("text_with_label_format", None)
if self.few_shot_num_per_label is None:
self.few_shot_num_per_label = config.get("few_shot_num_per_label", 2)
def get_generate_prompt(self, tokenize=True, return_tensors="pt"):
prompt_list = [apply_template(self.tokenize_format,
{"sub_domain": self.sub_domain,
"label": label}) for label in self.label_list]
if tokenize:
tokenized_prompts = self.tokenizer(prompt_list, return_tensors=return_tensors)
prompt_list = tokenized_prompts['input_ids']
return {label: prompt for label, prompt in zip(self.label_list, prompt_list)}
@staticmethod
def construct_prompt_list(samples_dict, num_shot_per_label, prompt_num, format_template, random_state=None):
from sklearn.utils import resample
from collections import deque
label_samples = {label: deque(resample(samples,
replace=False,
n_samples=len(samples))) for label, samples in samples_dict.items()}
def get_samples_for_label(label):
samples = []
while len(samples) < num_shot_per_label:
remaining_needed = num_shot_per_label - len(samples)
if len(label_samples[label]) < remaining_needed:
batch_samples = list(label_samples[label])
samples.extend(batch_samples)
# reset to allow repetition
label_samples[label] = deque(resample(samples_dict[label],
replace=False,
n_samples=len(samples_dict[label])))
else:
batch_samples = [label_samples[label].popleft() for _ in range(remaining_needed)]
samples.extend(batch_samples)
return samples
result = []
for _ in range(prompt_num):
prompt = ''
for label in samples_dict.keys():
samples = get_samples_for_label(label)
for text in samples:
prompt += apply_template(format_template, {"text": text, "label": label})
result.append(prompt)
return result
@staticmethod
def group_text_label_list(text_list, label_list):
group_data = [{"text": text, "label": label} for text, label in zip(text_list, label_list)]
return group_data
def prepare_few_shot(self, text_list, label_list, aug_prompt_num):
from collections import defaultdict
data_dict = defaultdict(list)
for text, label in zip(text_list, label_list):
# in case extra labels are present, ignore
if label in self.label_list:
data_dict[label].append(text)
few_shot_list = FlexDataset.construct_prompt_list(samples_dict=data_dict,
num_shot_per_label=self.few_shot_num_per_label,
prompt_num=aug_prompt_num,
format_template=self.few_shot_format,
random_state=self.random_state)
return few_shot_list
def prepare_augment(self, text_list, label_list, aug_prompt_num):
few_shot_samples = self.prepare_few_shot(text_list, label_list, aug_prompt_num)
result = []
instruction = apply_template(self.augment_format, {"sub_domain": self.sub_domain})
for i, sample in enumerate(few_shot_samples):
query = instruction + '\n' + sample
formatted_query = self.apply_chat_template(query)
result.append(formatted_query)
return result
def abstract_from_augmented(self, sample_list):
label_key, text_key = get_jinjax_placeholders(self.few_shot_format, 2)
res = {'inputs': [], 'labels': []}
for sample in sample_list:
data_list = sample.split('\n\n-')
for entry in data_list:
temp = entry.split(f"<{text_key}>:")
# print(f"temp: {temp}")
if len(temp) == 2 and f"<{label_key}>" in temp[0]:
label_str, input_str = temp
label = label_str.split(f"<{label_key}>:")[1].strip()
if isinstance(self.label_list[0], int) and label[0].isdigit():
label = int(label[0])
elif isinstance(self.label_list[0], float) and re.match(r'^\d+\.\d*?$', label):
                        label = float(label)
# abstracted label value does not match the original label type
elif isinstance(self.label_list[0], int) or isinstance(self.label_list[0], float):
continue
text = input_str.replace('', '').rstrip('*')
text = text.strip()
res['inputs'].append(text)
res['labels'].append(label)
# print(f"res: {res}")
return res
def prepare_query_to_filter_clustered(self, clustered_sentences_list, clustered_labels_list):
prompt_list = []
for clustered_sentences, clustered_labels in zip(clustered_sentences_list, clustered_labels_list):
text_with_label = ''
for i in range(len(clustered_sentences)):
formatted_entry = apply_template(self.text_with_label_format, {"i": i,
"text": clustered_sentences[i],
"label": clustered_labels[i]})
text_with_label += formatted_entry
cluster_query = apply_template(self.filter_format, {"text_with_label": text_with_label})
prompt_list.append(self.apply_chat_template(cluster_query))
return prompt_list
def parse_clustered_response(self, clustered_sentence, clustered_labels, response_list):
"""
Parse the response from the clustering model and filter the data per cluster.
:param clustered_sentence: nested list of clustered sentences
:param clustered_labels: nested list of clustered labels
:param response_list: list of responses from the clustering model
"""
def parse_response(response):
pattern = r'The eligible samples:\s*((?:\b\d+\b[\s.,]*)+)'
matches = re.search(pattern, response, re.MULTILINE)
if matches:
digits = [int(i) for i in re.findall(r'\b\d+\b', matches.group())]
else:
digits = []
return list(set(digits))
filtered_text_list = []
filtered_label_list = []
for i in range(len(clustered_sentence)):
parsed_response = parse_response(response_list[i])
for idx in parsed_response:
if idx < len(clustered_sentence[i]):
filtered_label_list.append(clustered_labels[i][idx])
filtered_text_list.append(clustered_sentence[i][idx])
return filtered_text_list, filtered_label_list
@staticmethod
def group_data_list(data_list, text_key, label_key):
inputs = [entry[text_key] for entry in data_list]
labels = [entry[label_key] for entry in data_list]
data_dict = {text_key: inputs, label_key: labels}
return data_dict
def load(self, path):
local_data = load_dataset('json', data_files={self.data_part: path})
self.dataset = local_data
if not self.need_preprocess:
self.ds = local_data
else:
tokenized_ds = tokenize_flex_dataset(
raw_datasets=local_data,
tokenizer=self.tokenizer,
sub_domain=self.sub_domain,
tokenize_format=self.tokenize_format,
text_key=self.text_key,
label_key=self.label_key,
max_prompt_len=self.max_prompt_len
)
self.ds = tokenized_ds[self.data_part]
if self.select_num is not None:
self.ds = self.ds.select(range(self.select_num))
def apply_chat_template(self, query):
tokenizer = self.tokenizer
if "llama-3" in self.tokenizer_name_or_path.lower():
msg = [
{"role": "system", "content": "You are a helpful assistant. "},
{"role": "user", "content": query}
]
prompt = tokenizer.apply_chat_template(msg, add_generation_prompt=True, tokenize=False)
else:
conv = get_conversation_template(self.tokenizer_name_or_path)
conv.append_message(conv.roles[0], query)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
return prompt
def get_raw_dataset(self):
return self.dataset
def __len__(self):
return len(self.ds)
def get_item(self, i):
return self.dataset[self.data_part][i]
def get_item_dict(self, i):
return {"text": self.dataset[self.data_part][self.text_key][i],
"label": self.dataset[self.data_part][self.label_key][i]}
def __getitem__(self, i) -> dict:
return self.ds[i]
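# A minimal usage sketch of the FDKT-style flow over this dataset: load local data, build
# few-shot augmentation prompts, then parse the generated text back into (text, label)
# pairs. The paths, keys and the llm_generate call below are illustrative assumptions.
#
#   ds = FlexDataset("/path/to/tokenizer", dataset_name="yelp_review", data_part="train")
#   ds.load("/path/to/train.json")
#   raw = ds.get_raw_dataset()["train"]
#   prompts = ds.prepare_augment(raw["text"], raw["stars"], aug_prompt_num=8)
#   generated = llm_generate(prompts)  # hypothetical generation call
#   augmented = ds.abstract_from_augmented(generated)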
================================================
FILE: python/fate_llm/dataset/hf_dataset.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from typing import Optional, Union, Sequence, Mapping, Dict
from datasets import load_dataset, Features, Split, DownloadConfig, DownloadMode, VerificationMode, Version, load_from_disk
from transformers import AutoTokenizer
from fate.ml.nn.dataset.base import Dataset
# avoid tokenizer parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class HuggingfaceDataset(Dataset):
"""
A dataset class for huggingface datasets
"""
def __init__(
self,
name: Optional[str] = None,
data_dir: Optional[str] = None,
data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None,
split: Optional[Union[str, Split]] = None,
cache_dir: Optional[str] = None,
features: Optional[Features] = None,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[Union[DownloadMode, str]] = None,
verification_mode: Optional[Union[VerificationMode, str]] = None,
ignore_verifications="deprecated",
keep_in_memory: Optional[bool] = None,
save_infos: bool = False,
revision: Optional[Union[str, Version]] = None,
token: Optional[Union[bool, str]] = None,
use_auth_token="deprecated",
task="deprecated",
streaming: bool = False,
num_proc: Optional[int] = None,
storage_options: Optional[Dict] = None,
trust_remote_code: bool = None,
tokenizer_params: Optional[Dict] = None,
tokenizer_apply_params: Optional[Dict] = None,
load_from_disk: Optional[bool] = False,
inplace_load: Optional[bool] = True,
data_split_key: Optional[str] = None,
**config_kwargs,
):
self.name = name
self.data_dir = data_dir
self.data_files = data_files
self.split = split
self.cache_dir = cache_dir
self.features = features
self.download_config = download_config
self.download_mode = download_mode
self.verification_mode = verification_mode
self.ignore_verifications = ignore_verifications
self.keep_in_memory = keep_in_memory
self.save_infos = save_infos
self.revision = revision
self.token = token
self.use_auth_token = use_auth_token
self.task = task
self.streaming = streaming
self.num_proc = num_proc
self.storage_options = storage_options
self.trust_remote_code = trust_remote_code
self.tokenizer_params = tokenizer_params
self.tokenizer_apply_params = tokenizer_apply_params
self.config_kwargs = config_kwargs
self.load_from_disk = load_from_disk
self.inplace_load = inplace_load
self.data_split_key = data_split_key
self.ds = None
super(HuggingfaceDataset, self).__init__()
def load(self, file_path):
if not self.load_from_disk:
ds = load_dataset(path=file_path, name=self.name, data_dir=self.data_dir, data_files=self.data_files,
split=self.split, cache_dir=self.cache_dir, features=self.features,
download_config=self.download_config, download_mode=self.download_mode,
verification_mode=self.verification_mode, ignore_verifications=self.ignore_verifications,
keep_in_memory=self.keep_in_memory, save_infos=self.save_infos, revision=self.revision,
token=self.token, use_auth_token=self.use_auth_token, task=self.task,
streaming=self.streaming, num_proc=self.num_proc, storage_options=self.storage_options,
trust_remote_code=self.trust_remote_code, **self.config_kwargs)
else:
ds = load_from_disk(file_path)
if self.data_split_key is not None:
ds = ds[self.data_split_key]
if self.inplace_load:
self.ds = ds
else:
return ds
def __getitem__(self, idx):
if self.ds is None:
raise ValueError('Dataset is not loaded')
return self.ds[idx]
def __len__(self):
if self.ds is None:
raise ValueError('Dataset is not loaded')
return len(self.ds)
class Dolly15K(HuggingfaceDataset):
INSTRUCTION_KEY = "### Instruction:"
INPUT_KEY = "Input:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"
RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
DEFAULT_SEED = 42
INTRO_BLURB = (
"Below is an instruction that describes a task. Write a response that appropriately completes the request."
)
PROMPT_NO_INPUT_FORMAT = """{intro}
{instruction_key}
{instruction}
{response_key}
{response}
{end_key}""".format(
intro=INTRO_BLURB,
instruction_key=INSTRUCTION_KEY,
instruction="{instruction}",
response_key=RESPONSE_KEY,
response="{response}",
end_key=END_KEY,
)
# This is a training prompt that contains an input string that serves as context for the instruction. For example,
# the input might be a passage from Wikipedia and the instruction is to extract some information from it.
PROMPT_WITH_INPUT_FORMAT = """{intro}
{instruction_key}
{instruction}
{input_key}
{input}
{response_key}
{response}
{end_key}""".format(
intro=INTRO_BLURB,
instruction_key=INSTRUCTION_KEY,
instruction="{instruction}",
input_key=INPUT_KEY,
input="{input}",
response_key=RESPONSE_KEY,
response="{response}",
end_key=END_KEY,
)
def __init__(self, *args, **kwargs):
super(Dolly15K, self).__init__(*args, **kwargs)
self.inplace_load = False
def load(self, file_path):
dataset = super().load(file_path)
return self._post_process(dataset)
def _post_process(self, dataset):
def _add_text(rec):
instruction = rec["instruction"]
response = rec["response"]
context = rec.get("context")
if not instruction:
raise ValueError(f"Expected an instruction in: {rec}")
if not response:
raise ValueError(f"Expected a response in: {rec}")
# For some instructions there is an input that goes along with the instruction, providing context for the
# instruction. For example, the input might be a passage from Wikipedia and the instruction says to extract
# some piece of information from it. The response is that information to extract. In other cases there is
# no input. For example, the instruction might be open QA such as asking what year some historic figure was
# born.
if context:
rec["text"] = self.PROMPT_WITH_INPUT_FORMAT.format(instruction=instruction, response=response,
input=context)
else:
rec["text"] = self.PROMPT_NO_INPUT_FORMAT.format(instruction=instruction, response=response)
return rec
dataset = dataset.map(_add_text)
tokenizer = AutoTokenizer.from_pretrained(**self.tokenizer_params)
def tokenize_function(examples):
return tokenizer(examples["text"], **self.tokenizer_apply_params)
dataset = dataset.map(tokenize_function, batched=True)
return dataset
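# Usage sketch (hedged; the file path and tokenizer name below are illustrative, not prescribed by this module):
# ds = Dolly15K(
#     data_files="databricks-dolly-15k.jsonl",
#     split="train",
#     tokenizer_params={"pretrained_model_name_or_path": "facebook/opt-125m"},
#     tokenizer_apply_params={"truncation": True, "max_length": 512},
# )
# tokenized = ds.load("json")  # Dolly15K.load returns the prompt-formatted, tokenized dataset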
================================================
FILE: python/fate_llm/dataset/input_output_dataset.py
================================================
from fate.ml.nn.dataset.base import Dataset
from transformers.trainer_pt_utils import LabelSmoother
from typing import List, Dict, Union, Literal
import logging
from jinja2 import Template
from transformers import AutoTokenizer
logger = logging.getLogger(__name__)
class InputOutputDataset(Dataset):
def __init__(self,
tokenizer_path,
input_template: str,
output_template: str,
max_input_length: int = 256,
max_target_length: int = 256,
load_from: Literal['jsonl', 'hf_load_from_disk', 'hf_load_dataset'] = 'hf_load_from_disk',
split_key: str = None
):
super().__init__()
self.tokenizer = None
self.tokenizer_path = tokenizer_path
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True)
self.max_input_length = max_input_length
self.max_target_length = max_target_length
self.dataset = None
self.load_from = load_from
self.input_template = Template(input_template)
self.output_template = Template(output_template)
self.split_key = split_key
self.max_seq_length = max_input_length + max_target_length + 1
def load(self, path):
if self.load_from == 'hf_load_from_disk':
import datasets
self.dataset = datasets.load_from_disk(path)
if self.split_key is not None:
self.dataset = self.dataset[self.split_key]
self.dataset = [i for i in self.dataset]
elif self.load_from == 'jsonl':
import json
with open(path, 'r') as f:
json_lines = f.read().split('\n')
self.dataset = []
for i in json_lines:
try:
self.dataset.append(json.loads(i))
except json.JSONDecodeError:
print('skip line')
elif self.load_from == 'hf_load_dataset':
from datasets import load_dataset
self.dataset = load_dataset(path)
if self.split_key is not None:
self.dataset = self.dataset[self.split_key]
self.dataset = [i for i in self.dataset]
else:
raise ValueError('unknown load format')
if not isinstance(self.dataset, list) or not isinstance(self.dataset[0], dict):
logger.warning('loaded dataset is expected to be a list of dict')
def get_raw_dataset(self):
return self.dataset
def __len__(self):
return len(self.dataset)
def get_str_item(self, i) -> dict:
data_item = self.dataset[i]
in_ = self.input_template.render(**data_item)
out_ = self.output_template.render(**data_item)
return {
'input': in_,
'output': out_
}
def _process_item(self, data_item):
a_ids = self.tokenizer.encode(text=data_item['input'], add_special_tokens=True, truncation=True,
max_length=self.max_input_length)
b_ids = self.tokenizer.encode(text=data_item['output'], add_special_tokens=False, truncation=True,
max_length=self.max_target_length)
context_length = len(a_ids)
input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id]
labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id]
pad_len = self.max_seq_length - len(input_ids)
input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
labels = labels + [self.tokenizer.pad_token_id] * pad_len
labels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels]
assert len(input_ids) == len(labels), f"length mismatch: {len(input_ids)} vs {len(labels)}"
return {
"input_ids": input_ids,
"labels": labels
}
def get_tokenized_item(self, i) -> dict:
str_item = self.get_str_item(i)
ret_dict = self._process_item(str_item)
return ret_dict
def __getitem__(self, i) -> dict:
item = self.get_tokenized_item(i)
return item
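# Usage sketch (hedged; the Jinja templates, tokenizer and path below are illustrative only):
# ds = InputOutputDataset(
#     tokenizer_path="facebook/opt-125m",
#     input_template="Instruction: {{instruction}}\nResponse: ",
#     output_template="{{response}}",
#     load_from="jsonl",
# )
# ds.load("train.jsonl")
# item = ds[0]  # {"input_ids": [...], "labels": [...]}, padded to max_seq_length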
================================================
FILE: python/fate_llm/dataset/prompt_dataset.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import json
import datasets
import torch
from fate.ml.nn.dataset.base import Dataset
from ..data.tokenizers.cust_tokenizer import get_tokenizer
PROMPT_TEMPLATE = "{prompt}"
class PromptDataset(Dataset):
def __init__(self,
text_max_length=512,
tokenizer_name_or_path=None,
trust_remote_code=False,
padding=False,
padding_side='left',
pad_token=None,
pad_token_id=None,
bos_token_id=None,
eos_token_id=None,
add_eos_token=True,
prompt_template=None,
add_special_tokens=False,
prompt_column="content",
response_column="summary",
max_prompt_length=256,
file_type="jsonl",
num_proc=4,
):
super(PromptDataset, self).__init__()
self.tokenizer = None
self.tokenizer_name_or_path = tokenizer_name_or_path
self.padding = padding
self.add_special_tokens = add_special_tokens
self.max_prompt_length = max_prompt_length
self.text_max_length = text_max_length
self.tokenizer = get_tokenizer(
tokenizer_name_or_path=tokenizer_name_or_path,
trust_remote_code=trust_remote_code,
pad_token=pad_token,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
padding_side=padding_side,
add_eos_token=add_eos_token,
)
self.prompt_template = prompt_template if prompt_template else PROMPT_TEMPLATE
self.prompt_column = prompt_column
self.response_column = response_column
self.file_type = file_type
self.num_proc = num_proc
self._data = None
def load(self, file_path):
if "jsonl" in self.file_type:
prompts = []
responses = []
with open(file_path, "r") as fin:
for line in fin:
line = json.loads(line)
prompts.append(line[self.prompt_column])
responses.append(line[self.response_column])
ds = datasets.Dataset.from_dict({self.prompt_column: prompts, self.response_column: responses})
else:
ds = datasets.load_from_disk(file_path)
self._data = ds.map(
self._process_data,
fn_kwargs={"tokenizer": self.tokenizer,
"prompt_template": self.prompt_template,
"prompt_column": self.prompt_column,
"response_column": self.response_column,
"max_prompt_length": self.max_prompt_length,
"max_length": self.text_max_length
},
batched=True,
remove_columns=ds.column_names,
num_proc=self.num_proc,
)
max_length = None
for d in self._data:
if max_length is None:
max_length = len(d["input_ids"])
else:
max_length = max(max_length, len(d["input_ids"]))
self._data = self._data.map(
self._pad_to_max_length,
batched=True,
fn_kwargs={
"tokenizer": self.tokenizer,
"max_length": max_length
},
num_proc=self.num_proc
)
@staticmethod
def _process_data(examples, tokenizer, prompt_template, prompt_column,
response_column, max_prompt_length, max_length):
prompts = examples[prompt_column]
responses = examples[response_column]
processed_data = dict()
input_ids_list = []
labels_list = []
attention_mask_list = []
for _prompt, _response in zip(prompts, responses):
if isinstance(_response, list):
_response = _response[0]
_prompt = prompt_template.format_map(dict(prompt=_prompt))
prompt_encoded = tokenizer(_prompt)
if len(prompt_encoded['input_ids']) > 0 and prompt_encoded['input_ids'][-1] in tokenizer.all_special_ids:
prompt_encoded['input_ids'] = prompt_encoded['input_ids'][:-1]
prompt_encoded['attention_mask'] = prompt_encoded['attention_mask'][:-1]
target_encoded = tokenizer(_response)
if len(target_encoded['input_ids']) > 0 and target_encoded['input_ids'][-1] in tokenizer.all_special_ids:
target_encoded['input_ids'] = target_encoded['input_ids'][:-1]
target_encoded['attention_mask'] = target_encoded['attention_mask'][:-1]
prompt_ids = prompt_encoded["input_ids"][: max_prompt_length]
prompt_attention_mask = prompt_encoded["attention_mask"][:max_prompt_length]
target_ids = target_encoded["input_ids"][: max_length - len(prompt_ids) - 1]
target_attention_mask = target_encoded["attention_mask"][: max_length - len(prompt_ids) - 1]
if tokenizer.bos_token_id is not None:
seq_length = len(prompt_ids) + 1
input_ids = prompt_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
labels = [-100] * seq_length + input_ids[seq_length:]
attention_mask = prompt_attention_mask + [1] + target_attention_mask + [1]
else:
seq_length = len(prompt_ids)
input_ids = prompt_ids + target_ids + [tokenizer.eos_token_id]
labels = [-100] * seq_length + input_ids[seq_length:]
attention_mask = prompt_attention_mask + target_attention_mask + [1]
input_ids_list.append(input_ids)
labels_list.append(labels)
attention_mask_list.append(attention_mask)
processed_data["labels"] = labels_list
processed_data["input_ids"] = input_ids_list
processed_data["attention_mask"] = attention_mask_list
return processed_data
@staticmethod
def _pad_to_max_length(examples, tokenizer, max_length):
padded_input_ids = []
padded_labels = []
padded_attention_mask = []
labels_list = examples["labels"]
input_ids_list = examples["input_ids"]
attention_mask_list = examples["attention_mask"]
for input_ids, attention_mask, labels in zip(input_ids_list, attention_mask_list, labels_list):
l = len(input_ids)
input_ids = torch.LongTensor(input_ids + [tokenizer.pad_token_id] * (max_length - l))
labels = torch.LongTensor(labels + [-100] * (max_length - l))
attention_mask = torch.LongTensor(attention_mask + [0] * (max_length - l))
padded_input_ids.append(input_ids)
padded_labels.append(labels)
padded_attention_mask.append(attention_mask)
return dict(
input_ids=padded_input_ids,
attention_mask=padded_attention_mask,
labels=padded_labels
)
def get_vocab_size(self):
return self.tokenizer.vocab_size
def __getitem__(self, item):
return self._data[item]
def __len__(self):
return len(self._data)
def __repr__(self):
return self.tokenizer.__repr__()
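# Usage sketch (hedged; tokenizer name and jsonl path are illustrative only):
# ds = PromptDataset(
#     tokenizer_name_or_path="facebook/opt-125m",
#     prompt_column="content",
#     response_column="summary",
#     file_type="jsonl",
# )
# ds.load("train.jsonl")  # each line is a json object with 'content' and 'summary' fields
# sample = ds[0]          # padded input_ids / attention_mask / labels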
================================================
FILE: python/fate_llm/dataset/qa_dataset.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datasets import load_from_disk, load_dataset
from transformers import AutoTokenizer
from fate.ml.nn.dataset.base import Dataset
"""
These Data pre-processing templates are from https://github.com/mit-han-lab/offsite-tuning
"""
class PIQA:
def __init__(self):
self._template = "Question: {}\nAnswer:"
def get_context(self, examples):
ctx = examples['goal']
return [self._template.format(c) for c in ctx]
def get_target(self, examples):
if -1 in examples["label"]: # test set
return [""] * len(examples["label"])
else:
gt_tuples = [("sol{}".format(label + 1), idx)
for idx, label in enumerate(examples['label'])]
return [examples[k][i] for k, i in gt_tuples]
class SciQ:
def __init__(self):
self._template = "{}\nQuestion: {}\nAnswer:"
def get_context(self, examples):
sources = examples['support']
queries = examples['question']
return [self._template.format(s, q) for s, q in zip(sources, queries)]
def get_target(self, examples):
return examples['correct_answer']
class OpenBookQA:
def get_context(self, examples):
return examples['question_stem']
def get_target(self, examples):
choices = examples['choices']
answers = examples['answerKey']
targets = []
for choice, answer in zip(choices, answers):
answer = ord(answer.strip()) - ord('A')
targets.append(choice['text'][answer])
return targets
class ARC:
def __init__(self):
self._template = "Question: {}\nAnswer:"
def get_context(self, examples):
ctx = examples['question']
return [self._template.format(c) for c in ctx]
def get_target(self, examples):
choices = examples['choices']
answers = examples['answerKey']
num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
for idx, answer in enumerate(answers):
answer = num_to_letter.get(answer, answer)
answer = ord(answer) - ord("A")
answers[idx] = choices[idx]["text"][answer]
return answers
class WIC:
def __init__(self):
self._template = "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the" \
" two sentences above?\nAnswer:"
def get_context(self, examples):
sentences_1 = examples["sentence1"]
sentences_2 = examples["sentence2"]
starts_1 = examples["start1"]
ends_1 = examples["end1"]
contexts = []
for s1, s2, st, ed in zip(sentences_1, sentences_2, starts_1, ends_1):
contexts.append(
self._template.format(s1, s2, s1[st: ed])
)
return contexts
def get_target(self, examples):
labels = examples["label"]
targets = []
for label in labels:
targets.append(" {}".format({0: "no", 1: "yes"}[label]))
return targets
class BoolQ:
def __init__(self):
self._template = "{}\nQuestion: {}?\nAnswer:"
def get_context(self, examples):
passages = examples["passage"]
questions = examples["question"]
return [self._template.format(passage, question)
for passage, question in zip(passages, questions)
]
def get_target(self, examples):
return [" " + "yes" if label else "no" for label in examples["answer"]]
class CommonsenseQA:
def get_context(self, examples):
return examples["question"]
def get_target(self, examples):
choices = examples['choices']
answers = examples['answerKey']
targets = []
for choice, answer in zip(choices, answers):
answer = ord(answer.strip()) - ord('A')
targets.append(choice['text'][answer])
return targets
class RTE:
def __init__(self):
self._template = "{}\nQuestion: {} True or False?\nAnswer:"
def get_context(self, examples):
sentences_1 = examples["premise"]
sentences_2 = examples["hypothesis"]
contexts = []
for sentence_1, sentence_2 in zip(sentences_1, sentences_2):
contexts.append(
self._template.format(sentence_1, sentence_2)
)
return contexts
def get_target(self, examples):
labels = examples["label"]
return [" {}".format({0: "True", 1: "False"}[label]) for label in labels]
task_dict = {
"piqa": PIQA(),
"sciq": SciQ(),
"openbookqa": OpenBookQA(),
"arc_easy": ARC(),
"arc_challenge": ARC(),
"wic": WIC(),
"boolq": BoolQ(),
"commonsenseqa": CommonsenseQA(),
"rte": RTE()
}
def tokenize_qa_dataset(dataset_name, tokenizer, save_path=None, seq_max_len=1000, data_part="train", dataset=None):
max_len = seq_max_len
assert dataset_name in task_dict.keys(), f"dataset name must be one of {list(task_dict.keys())}"
if dataset is None:
raw_datasets = load_dataset(dataset_name)
else:
raw_datasets = dataset
task = task_dict[dataset_name]
column_names = raw_datasets[data_part].column_names
def tokenize_function(examples):
context = task.get_context(examples)
target = task.get_target(examples)
context = tokenizer(context)
target = tokenizer(target)
# if context is ending with special token, remove it
if len(context['input_ids'][0]) > 0 and context['input_ids'][0][-1] in tokenizer.all_special_ids:
context['input_ids'] = [i[:-1] for i in context['input_ids']]
context['attention_mask'] = [a[:-1]
for a in context['attention_mask']]
# if target is starting with special token, remove it
if len(target['input_ids'][0]) > 0 and target['input_ids'][0][0] in tokenizer.all_special_ids:
target['input_ids'] = [i[1:] for i in target['input_ids']]
target['attention_mask'] = [a[1:]
for a in target['attention_mask']]
out = {}
out['input_ids'] = [i1 + i2 for i1,
i2 in zip(context['input_ids'], target['input_ids'])]
out['attention_mask'] = [a1 + a2 for a1,
a2 in zip(context['attention_mask'], target['attention_mask'])]
# set -100 for context tokens
out["labels"] = [
[-100] * len(i1) + i2 for i1, i2 in zip(context['input_ids'], target['input_ids'])]
return out
tokenized_datasets = raw_datasets.map(
tokenize_function,
batched=True,
num_proc=4,
remove_columns=column_names,
load_from_cache_file=True,
desc="Running tokenizer on dataset",
)
# pad all instances in lm_datasets to the max length of the dataset
max_length = -1
for v in tokenized_datasets.values():
for x in v:
max_length = max(max_length, len(x['input_ids']))
# pad to the multiple of 8
max_length = (max_length // 8 + 1) * 8
block_size = max_len
max_length = min(max_length, block_size)
def pad_function(examples):
examples["input_ids"] = [i + [tokenizer.pad_token_id] *
(max_length - len(i)) for i in examples["input_ids"]]
examples["attention_mask"] = [[1] * len(i) + [0] *
(max_length - len(i)) for i in examples["attention_mask"]]
examples["labels"] = [i + [-100] *
(max_length - len(i)) for i in examples["labels"]]
# truncate to max_length
examples["input_ids"] = [i[:max_length] for i in examples["input_ids"]]
examples["attention_mask"] = [a[:max_length]
for a in examples["attention_mask"]]
examples["labels"] = [l[:max_length] for l in examples["labels"]]
return examples
tokenized_datasets = tokenized_datasets.map(
pad_function,
batched=True,
num_proc=4,
load_from_cache_file=True,
desc=f"Padding dataset to max length {max_length}",
)
if save_path is not None:
tokenized_datasets.save_to_disk(save_path)
return tokenized_datasets
class QaDataset(Dataset):
def __init__(self,
tokenizer_name_or_path,
select_num=None,
start_idx=None,
need_preprocess=False,
dataset_name=None,
data_part="train",
seq_max_len=1000
):
self.select_num = select_num
self.start_idx = start_idx
self.ds = None
self.need_preprocess = need_preprocess
self.dataset_name = dataset_name
self.data_part = data_part
self.seq_max_len = seq_max_len
self.return_with_idx = False
if 'llama' in tokenizer_name_or_path.lower():
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, unk_token="<unk>", bos_token="<s>",
eos_token="</s>", add_eos_token=True)
self.tokenizer.pad_token = self.tokenizer.eos_token
else:
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
if 'gpt2' in tokenizer_name_or_path.lower():
self.tokenizer.pad_token = self.tokenizer.eos_token
def load(self, path):
local_data = load_from_disk(path)
if not self.need_preprocess:
self.ds = local_data[self.data_part]
else:
tokenized_ds = tokenize_qa_dataset(
dataset_name=self.dataset_name,
tokenizer=self.tokenizer,
seq_max_len=self.seq_max_len,
data_part=self.data_part,
dataset=local_data
)
self.ds = tokenized_ds[self.data_part]
if self.select_num is not None:
if self.start_idx is not None:
self.ds = self.ds.select(range(self.start_idx, min(len(self.ds), self.start_idx + self.select_num)))
else:
self.ds = self.ds.select(range(self.select_num))
def set_return_with_idx(self):
self.return_with_idx = True
def reset_return_with_idx(self):
self.return_with_idx = False
def __len__(self):
return len(self.ds)
def __getitem__(self, idx):
if self.return_with_idx:
return {
"idx": idx,
"inputs": self.ds[idx]
}
else:
return self.ds[idx]
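# Usage sketch (hedged; tokenizer name and on-disk dataset path are illustrative only):
# ds = QaDataset(tokenizer_name_or_path="gpt2", dataset_name="sciq",
#                need_preprocess=True, data_part="train", seq_max_len=512)
# ds.load("/path/to/sciq")  # a datasets.DatasetDict previously saved with save_to_disk
# print(len(ds), list(ds[0].keys()))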
================================================
FILE: python/fate_llm/dataset/seq_cls_dataset.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate.ml.nn.dataset.base import Dataset
import pandas as pd
import torch as t
from transformers import AutoTokenizer
import os
import numpy as np
# avoid tokenizer parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class SeqCLSDataset(Dataset):
"""
A dataset for basic NLP tasks. This dataset automatically transforms raw text into word indices
using an AutoTokenizer from the transformers library.
Parameters
----------
truncation bool, truncate word sequences to 'text_max_length'
text_max_length int, max length of word sequences
tokenizer_name_or_path str, name of a tokenizer (see the transformers documentation for details) or path to a local
tokenizer folder
return_label bool, whether to return labels; this option is for the host dataset when running hetero-NN
padding bool, whether to pad word sequences to 'text_max_length'
padding_side str, 'left' or 'right', side on which to pad the word sequence
pad_token str, use this string as the pad token; if None, use tokenizer.pad_token
return_input_ids bool, if True, only input_ids are returned; if False, the full tokenizer output dict is returned
"""
def __init__(
self,
truncation=True,
text_max_length=128,
tokenizer_name_or_path="bert-base-uncased",
return_label=True,
padding=True,
padding_side="right",
pad_token=None,
return_input_ids=True):
super(SeqCLSDataset, self).__init__()
self.text = None
self.word_idx = None
self.label = None
self.tokenizer = None
self.sample_ids = None
self.padding = padding
self.truncation = truncation
self.max_length = text_max_length
self.with_label = return_label
self.tokenizer_name_or_path = tokenizer_name_or_path
self.tokenizer = AutoTokenizer.from_pretrained(
self.tokenizer_name_or_path)
self.tokenizer.padding_side = padding_side
self.return_input_ids = return_input_ids
if pad_token is not None:
self.tokenizer.add_special_tokens({'pad_token': pad_token})
def load(self, file_path):
tokenizer = self.tokenizer
self.text = pd.read_csv(file_path)
text_list = list(self.text.text)
self.word_idx = tokenizer(
text_list,
padding=self.padding,
return_tensors='pt',
truncation=self.truncation,
max_length=self.max_length)
if self.return_input_ids:
self.word_idx = self.word_idx['input_ids']
if self.with_label:
self.label = t.Tensor(self.text.label).detach().numpy()
self.label = self.label.reshape((len(self.text), -1))
if 'id' in self.text:
self.sample_ids = self.text['id'].values.tolist()
def get_classes(self):
return np.unique(self.label).tolist()
def get_vocab_size(self):
return self.tokenizer.vocab_size
def get_sample_ids(self):
return self.sample_ids
def __getitem__(self, item):
if self.return_input_ids:
ret = self.word_idx[item]
else:
ret = {k: v[item] for k, v in self.word_idx.items()}
if self.with_label:
return ret, self.label[item]
return ret
def __len__(self):
return len(self.text)
def __repr__(self):
return self.tokenizer.__repr__()
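# Usage sketch (hedged; csv path and tokenizer are illustrative; the csv is expected to
# contain 'text' and 'label' columns, and optionally an 'id' column):
# ds = SeqCLSDataset(tokenizer_name_or_path="bert-base-uncased", text_max_length=128)
# ds.load("train.csv")
# input_ids, label = ds[0]   # token-id tensor and label array when return_label=True
# print(ds.get_classes())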
================================================
FILE: python/fate_llm/evaluate/__init__.py
================================================
================================================
FILE: python/fate_llm/evaluate/scripts/__init__.py
================================================
================================================
FILE: python/fate_llm/evaluate/scripts/_options.py
================================================
import time
import click
from ..utils.config import parse_config, default_eval_config
from ..utils.config import _set_namespace
def parse_custom_type(value):
parts = value.split('=')
if len(parts) == 2 and parts[1].isdigit():
return parts[0], int(parts[1])
elif len(parts) == 2 and isinstance(parts[1], str):
return parts[0], parts[1]
else:
raise click.BadParameter('Invalid input format. Use "str=int" or "str=str".')
class LlmSharedOptions(object):
_options = {
"eval_config": (('-c', '--eval_config'),
dict(type=click.Path(exists=True), help=f"Manual specify config file", default=None),
default_eval_config().__str__()),
"yes": (('-y', '--yes',), dict(type=bool, is_flag=True, help="Skip double check", default=None),
False),
"namespace": (('-n', '--namespace'),
dict(type=str, help=f"Manual specify fate llm namespace", default=None),
time.strftime('%Y%m%d%H%M%S'))
}
def __init__(self):
self._options_kwargs = {}
def __getitem__(self, item):
return self._options_kwargs[item]
def get(self, k, default=None):
v = self._options_kwargs.get(k, default)
if v is None and k in self._options:
v = self._options[k][2]
return v
def update(self, **kwargs):
for k, v in kwargs.items():
if v is not None:
self._options_kwargs[k] = v
def post_process(self):
# add defaults here
for k, v in self._options.items():
if self._options_kwargs.get(k, None) is None:
self._options_kwargs[k] = v[2]
# update config
config = parse_config(self._options_kwargs['eval_config'])
self._options_kwargs['eval_config'] = config
_set_namespace(self._options_kwargs['namespace'])
@classmethod
def get_shared_options(cls, hidden=False):
def shared_options(f):
for name, option in cls._options.items():
f = click.option(*option[0], **dict(option[1], hidden=hidden))(f)
return f
return shared_options
================================================
FILE: python/fate_llm/evaluate/scripts/config_cli.py
================================================
#
# Copyright 2024 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import click
import yaml
from pathlib import Path
from ..utils.config import create_eval_config, default_eval_config
from ._options import LlmSharedOptions
from ..utils._io import echo
@click.group("eval_config", help="fate_llm evaluate config")
def eval_config_group():
"""
eval_config fate_llm
"""
pass
@eval_config_group.command(name="new")
def _new():
"""
create new fate_llm eval config from template
"""
create_eval_config(Path("llm_eval_config.yaml"))
click.echo(f"create eval_config file: llm_eval_config.yaml")
@eval_config_group.command(name="edit")
@LlmSharedOptions.get_shared_options(hidden=True)
@click.pass_context
def _edit(ctx, **kwargs):
"""
edit fate_llm eval_config file
"""
ctx.obj.update(**kwargs)
eval_config = ctx.obj.get("eval_config")
print(f"eval_config: {eval_config}")
click.edit(filename=eval_config)
@eval_config_group.command(name="show")
def _show():
"""
show fate_llm default eval_config path
"""
click.echo(f"default eval_config path is {default_eval_config()}")
================================================
FILE: python/fate_llm/evaluate/scripts/data_cli.py
================================================
#
# Copyright 2024 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import copy
import click
import yaml
import warnings
from typing import Union
from ._options import LlmSharedOptions
from ..utils.llm_evaluator import download_task
from ..utils._io import echo
@click.command('download_data')
@click.option('-t', '--tasks', required=False, type=str, multiple=True, default=None,
help='tasks whose data will be downloaded')
# @click.argument('other_args', nargs=-1)
@LlmSharedOptions.get_shared_options(hidden=True)
@click.pass_context
def download_data(ctx, tasks, **kwargs):
"""
Download data for the specified built-in evaluation tasks; if no task is given, download data for all of them.
"""
ctx.obj.update(**kwargs)
ctx.obj.post_process()
if tasks is None or len(tasks) == 0:
tasks = None
echo.echo(f"No task is given, will download data for all built-in tasks.", fg='red')
else:
echo.echo(f"given tasks: {tasks}", fg='red')
download_task(tasks)
================================================
FILE: python/fate_llm/evaluate/scripts/eval_cli.py
================================================
#
# Copyright 2024 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import copy
import click
import yaml
import warnings
from typing import Union
from ._options import LlmSharedOptions
from ..utils.config import default_eval_config
from ..utils.llm_evaluator import evaluate, init_tasks, aggregate_table
from ..utils.model_tools import load_by_loader
from ..utils._io import echo
from ..utils._parser import LlmSuite
@click.command('evaluate')
@click.option('-i', '--include', required=True, type=click.Path(exists=True),
help='Path to model and metrics conf')
@click.option('-c', '--eval-config', type=click.Path(exists=True), help='Path to FATE Llm evaluation config. '
'If not provided, use default config.')
@click.option('-o', '--result-output', type=click.Path(),
help='Path to save evaluation results.')
# @click.argument('other_args', nargs=-1)
@LlmSharedOptions.get_shared_options(hidden=True)
@click.pass_context
def run_evaluate(ctx, include, eval_config, result_output, **kwargs):
"""
Evaluate a pretrained model with specified parameters.
"""
ctx.obj.update(**kwargs)
ctx.obj.post_process()
# namespace = ctx.obj["namespace"]
yes = ctx.obj["yes"]
echo.echo(f"include: {include}", fg='red')
try:
# include = os.path.abspath(include)
suite = LlmSuite.load(include)
except Exception as e:
raise ValueError(f"Invalid include path: {include}, please check. {e}")
if not eval_config:
eval_config = default_eval_config()
if not os.path.exists(eval_config):
eval_config = None
if not yes and not click.confirm("running?"):
return
# init tasks
init_tasks()
# run_suite_eval(suite, eval_config_dict, result_output)
run_suite_eval(suite, eval_config, result_output)
def run_job_eval(job, eval_conf):
job_eval_conf = {}
if isinstance(eval_conf, dict):
job_eval_conf.update(eval_conf)
elif eval_conf is not None and os.path.exists(eval_conf):
with open(eval_conf, 'r') as f:
job_eval_conf.update(yaml.safe_load(f))
# echo.echo(f"Evaluating job: {job.job_name} with tasks: {job.tasks}")
if job.eval_conf_path:
# job-level eval conf takes priority
with open(job.eval_conf_path, 'r') as f:
job_eval_conf.update(yaml.safe_load(f))
# get loader
if job.loader:
if job.peft_path:
model = load_by_loader(loader_name=job.loader,
loader_conf_path=job.loader_conf_path,
peft_path=job.peft_path)
else:
model = load_by_loader(loader_name=job.loader,
loader_conf_path=job.loader_conf_path)
result = evaluate(model=model, tasks=job.tasks, include_path=job.include_path, **job_eval_conf)
else:
# feed in pretrained & peft path
job_eval_conf["model_args"]["pretrained"] = job.pretrained_model_path
if job.peft_path:
job_eval_conf["model_args"]["peft"] = job.peft_path
result = evaluate(tasks=job.tasks, include_path=job.include_path, **job_eval_conf)
return result
def run_suite_eval(suite, eval_conf, output_path=None):
suite_results = dict()
for pair in suite.pairs:
job_results = dict()
for job in pair.jobs:
if not job.evaluate_only:
# give warning that job will be skipped
warnings.warn(f"Job {job.job_name} will be skipped since no pretrained model is provided")
continue
echo.echo(f"Evaluating job: {job.job_name} with tasks: {job.tasks}")
result = run_job_eval(job, eval_conf)
job_results[job.job_name] = result
suite_results[pair.pair_name] = job_results
suite_writers = aggregate_table(suite_results)
for pair_name, pair_writer in suite_writers.items():
echo.sep_line()
echo.echo(f"Pair: {pair_name}")
echo.sep_line()
echo.echo(pair_writer.dumps())
echo.stdout_newline()
if output_path:
with open(output_path, 'w') as f:
for pair_name, pair_writer in suite_writers.items():
pair_writer.dumps(f)
================================================
FILE: python/fate_llm/evaluate/scripts/fate_llm_cli.py
================================================
#
# Copyright 2024 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import click
import yaml
from typing import Union
from .eval_cli import run_evaluate
from .config_cli import eval_config_group
from .data_cli import download_data
from ._options import LlmSharedOptions
commands = {
"evaluate": run_evaluate,
"config": eval_config_group,
"download": download_data
}
# alias table for commands; empty by default, consulted in FATELlmCLI.get_command below
commands_alias = {}
class FATELlmCLI(click.MultiCommand):
def list_commands(self, ctx):
return list(commands)
def get_command(self, ctx, name):
if name not in commands and name in commands_alias:
name = commands_alias[name]
if name not in commands:
ctx.fail("No such command '{}'.".format(name))
return commands[name]
@click.command(cls=FATELlmCLI, help="A collection of tools to run FATE Llm Evaluation.",
context_settings=dict(help_option_names=["-h", "--help"]))
@LlmSharedOptions.get_shared_options()
@click.pass_context
def fate_llm_cli(ctx, **kwargs):
ctx.ensure_object(LlmSharedOptions)
ctx.obj.update(**kwargs)
if __name__ == '__main__':
fate_llm_cli(obj=LlmSharedOptions())
================================================
FILE: python/fate_llm/evaluate/tasks/__init__.py
================================================
#
# Copyright 2024 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import yaml
import os
def local_fn_constructor(loader, node):
return node
def local_fn_representer(dumper, data):
return data
def dump_yaml(data, path):
yaml.add_representer(yaml.ScalarNode, local_fn_representer)
with open(path, 'w') as f:
yaml.dump(data, f)
class Task:
_task_name = ""
_task_dir = ""
_task_conf_file = ""
_task_source_url = ""
script_dir = os.path.dirname(__file__)
@property
def task_name(self):
return self._task_name
@property
def task_template(self):
yaml.add_constructor("!function", local_fn_constructor)
with open(os.path.abspath(os.path.join(self.script_dir, self._task_dir, self._task_conf_file)), "rb") as f:
task_template = yaml.full_load(f)
return task_template
@property
def task_scr_dir(self):
return os.path.abspath(os.path.join(self.script_dir, self._task_dir))
@property
def task_conf_path(self):
return os.path.abspath(os.path.join(self.script_dir, self._task_dir, self._task_conf_file))
@property
def task_source_url(self):
return self._task_source_url
def download_from_source(self):
raise NotImplementedError(f"Should not be called here.")
class Dolly(Task):
_task_name = "dolly-15k"
_task_dir = "dolly_15k"
_task_conf_file = "default_dolly_15k.yaml"
def download_from_source(self):
try:
from datasets import load_dataset
data = load_dataset("databricks/databricks-dolly-15k", split="train")
filename = os.path.join(self.task_scr_dir, "databricks-dolly-15k.jsonl")
data.to_json(filename)
return True
except Exception as e:
print(f"Failed to download data from source: {e}")
return False
class AdvertiseGen(Task):
_task_name = "advertise-gen"
_task_dir = "advertise_gen"
_task_conf_file = "default_advertise_gen.yaml"
_task_source_url = ["https://cloud.tsinghua.edu.cn/seafhttp/files/3781289a-5a60-44b1-b5f1-a04364e3eb9d/AdvertiseGen.tar.gz",
"https://docs.google.com/uc?export=download&id=13_vf0xRTQsyneRKdD1bZIr93vBGOczrk"]
def download_from_source(self):
from ..utils.data_tools import download_data
result = download_data(self.task_scr_dir, self.task_source_url[0])
if not result:
print(f"retry with address: {self.task_source_url[1]}")
return download_data(self.task_scr_dir, self.task_source_url[1])
return result
build_in_tasks = {"dolly-15k": Dolly(),
"advertise-gen": AdvertiseGen()}
================================================
FILE: python/fate_llm/evaluate/tasks/advertise_gen/__init__.py
================================================
================================================
FILE: python/fate_llm/evaluate/tasks/advertise_gen/advertise_utils.py
================================================
# adopted from https://github.com/huggingface/datasets/blob/main/metrics/rouge/rouge.py
from rouge_score import rouge_scorer
# from multiprocessing import Pool
def rouge_l(predictions, references, use_stemmer=False):
scorer = rouge_scorer.RougeScorer(rouge_types=['rougeL'], use_stemmer=use_stemmer)
scores = []
for ref, pred in zip(references, predictions):
score = scorer.score(ref, pred)
scores.append(score)
rouge_l_score = scores[0]['rougeL'].fmeasure
return rouge_l_score
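# Minimal usage sketch of the rouge_l helper above (hedged; strings are illustrative only):
# if __name__ == "__main__":
#     print(rouge_l(predictions=["a red lace dress for summer"],
#                   references=["a red dress with lace for summer"]))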
================================================
FILE: python/fate_llm/evaluate/tasks/advertise_gen/default_advertise_gen.yaml
================================================
dataset_kwargs:
data_files:
train: train.json
validation: dev.json
dataset_path: json
doc_to_target: '{{summary}}'
doc_to_text: '{{content}}'
metric_list:
- aggregation: mean
higher_is_better: true
metric: !function 'advertise_utils.rouge_l'
output_type: generate_until
task: advertise-gen
validation_split: validation
================================================
FILE: python/fate_llm/evaluate/tasks/dolly_15k/__init__.py
================================================
================================================
FILE: python/fate_llm/evaluate/tasks/dolly_15k/default_dolly_15k.yaml
================================================
dataset_kwargs:
data_files: databricks-dolly-15k.jsonl
dataset_path: json
doc_to_target: '{{response}}'
doc_to_text: !function 'dolly_utils.doc_to_text'
metric_list:
- aggregation: mean
higher_is_better: true
metric: !function 'dolly_utils.rouge_l'
output_type: generate_until
task: dolly-15k
validation_split: train
================================================
FILE: python/fate_llm/evaluate/tasks/dolly_15k/dolly_utils.py
================================================
# adopted from https://github.com/huggingface/datasets/blob/main/metrics/rouge/rouge.py
from rouge_score import rouge_scorer
def rouge_l(predictions, references, use_stemmer=False):
scorer = rouge_scorer.RougeScorer(rouge_types=['rougeL'], use_stemmer=use_stemmer)
scores = []
for ref, pred in zip(references, predictions):
score = scorer.score(ref, pred)
scores.append(score)
rouge_l_score = scores[0]['rougeL'].fmeasure
return rouge_l_score
def doc_to_text(doc):
if doc["context"]:
return f"context: {doc['context']}\ninstruction: {doc['instruction']}\nresponse:"
else:
return f"instruction: {doc['instruction']}\nresponse:"
"""
def train_load_evaluate_lm():
pipeline.fit(train_data)
lm = OTModelLoader().load(path, **args)
from fate_llm.evaluator import evaluator
# general case
evaluator.evaluate(lm, task="dolly_15k", **args)
# user modified conf
config = evaluator.get_task_template(task="dolly_15k") # return dict copy of yaml file
config['dataset_kwargs'] = {"dataset_kwargs":
{"data_files":
{"test": './dolly_15k_test.csv',
"dev": './dolly_15k_dev.csv'}}}
# may provide arbitrary export path, must be of dir, create temp dir under the given path: {$export_path}/temp_dir
new_task_dir = evaluator.export_config(config, task="dolly_15k", export_path=None)
result = evaluator.evaluate(lm, task="dolly_15k", include_path=new_task_dir, **args)
print(result) # dict
evaluator.delete_config(new_task_dir)
"""
================================================
FILE: python/fate_llm/evaluate/utils/__init__.py
================================================
from ._parser import LlmJob, LlmPair, LlmSuite
================================================
FILE: python/fate_llm/evaluate/utils/_io.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import click
import loguru
# noinspection PyPep8Naming
class echo(object):
_file = None
@classmethod
def set_file(cls, file):
cls._file = file
@classmethod
def echo(cls, message, **kwargs):
click.secho(message, **kwargs)
click.secho(message, file=cls._file, **kwargs)
@classmethod
def sep_line(cls):
click.secho("-------------------------------------------------")
@classmethod
def file(cls, message, **kwargs):
click.secho(message, file=cls._file, **kwargs)
@classmethod
def stdout(cls, message, **kwargs):
click.secho(message, **kwargs)
@classmethod
def stdout_newline(cls):
click.secho("")
@classmethod
def welcome(cls):
cls.echo("Welcome to FATE Llm Evaluator")
@classmethod
def flush(cls):
import sys
sys.stdout.flush()
def set_logger(name):
loguru.logger.remove()
loguru.logger.add(name, level='ERROR', delay=True)
return loguru.logger
LOGGER = loguru.logger
================================================
FILE: python/fate_llm/evaluate/utils/_parser.py
================================================
#
# Copyright 2024 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import yaml
import typing
from pathlib import Path
class LlmJob(object):
def __init__(self, job_name: str, script_path: Path=None, conf_path: Path=None, model_task_name: str=None,
pretrained_model_path: Path=None, peft_path: Path=None,
eval_conf_path: Path=None, loader: str=None, loader_conf_path: Path=None,
tasks: typing.List[str]=None, include_path: Path=None, peft_path_format: str=None):
self.job_name = job_name
self.script_path = script_path
self.conf_path = conf_path
self.model_task_name = model_task_name
self.pretrained_model_path = pretrained_model_path
self.peft_path = peft_path
self.loader = loader
self.loader_conf_path = loader_conf_path
self.eval_conf_path = eval_conf_path
self.tasks = tasks
self.include_path = include_path
self.evaluate_only = self.script_path is None
self.peft_path_format = peft_path_format
class LlmPair(object):
def __init__(
self, pair_name: str, jobs: typing.List[LlmJob]
):
self.pair_name = pair_name
self.jobs = jobs
class LlmSuite(object):
def __init__(
self, pairs: typing.List[LlmPair], path: Path, dataset=None
):
self.pairs = pairs
self.path = path
self.dataset = dataset
self._final_status = {}
@staticmethod
def load(path: Path):
if isinstance(path, str):
path = Path(path)
with path.open("r") as f:
testsuite_config = yaml.safe_load(f)
pairs = []
for pair_name, pair_configs in testsuite_config.items():
if pair_name == "data":
continue
jobs = []
for job_name, job_configs in pair_configs.items():
# with train
script_path = job_configs.get("script", None)
if script_path and not os.path.isabs(script_path):
script_path = path.parent.joinpath(script_path).resolve()
conf_path = job_configs.get("conf", None)
if conf_path and not os.path.isabs(conf_path):
conf_path = path.parent.joinpath(conf_path).resolve()
model_task_name = job_configs.get("model_task_name", None)
# evaluate only
pretrained_model_path = job_configs.get("pretrained", None)
if pretrained_model_path and not os.path.isabs(pretrained_model_path):
# make path absolute, else keep original pretrained model name
if "yaml" in pretrained_model_path or "/" in pretrained_model_path:
pretrained_model_path = path.parent.joinpath(pretrained_model_path).resolve()
peft_path = job_configs.get("peft", None)
if peft_path and not os.path.isabs(peft_path):
peft_path = path.parent.joinpath(peft_path).resolve()
eval_conf_path = job_configs.get("eval_conf", None)
if eval_conf_path and not os.path.isabs(eval_conf_path):
eval_conf_path = path.parent.joinpath(eval_conf_path).resolve()
loader = job_configs.get("loader", None)
if job_configs.get("loader_conf"):
loader_conf_path = path.parent.joinpath(job_configs["loader_conf"]).resolve()
else:
loader_conf_path = ""
tasks = job_configs.get("tasks", [])
include_path = job_configs.get("include_path", "")
if include_path and not os.path.isabs(include_path):
include_path = path.parent.joinpath(job_configs["include_path"]).resolve()
peft_path_format = job_configs.get("peft_path_format", "{{fate_base}}/fate_flow/model/{{job_id}}/"
"guest/{{party_id}}/{{model_task_name}}/0/"
"output/output_model/model_directory")
jobs.append(
LlmJob(
job_name=job_name, script_path=script_path, conf_path=conf_path,
model_task_name=model_task_name,
pretrained_model_path=pretrained_model_path, peft_path=peft_path, eval_conf_path=eval_conf_path,
loader=loader, loader_conf_path=loader_conf_path, tasks=tasks, include_path=include_path,
peft_path_format=peft_path_format
)
)
pairs.append(
LlmPair(
pair_name=pair_name, jobs=jobs
)
)
suite = LlmSuite(pairs=pairs, path=path)
return suite
def update_status(
self, pair_name, job_name, job_id=None, status=None, exception_id=None, time_elapsed=None, event=None
):
for k, v in locals().items():
if k != "job_name" and k != "pair_name" and v is not None:
if self._final_status.get(f"{pair_name}-{job_name}"):
setattr(self._final_status[f"{pair_name}-{job_name}"], k, v)
def get_final_status(self):
return self._final_status
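# Usage sketch (hedged; the suite yaml path is illustrative, its layout mirrors what LlmSuite.load parses above):
# suite = LlmSuite.load("llm_suite.yaml")
# for pair in suite.pairs:
#     for job in pair.jobs:
#         print(pair.pair_name, job.job_name, job.tasks, job.evaluate_only)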
================================================
FILE: python/fate_llm/evaluate/utils/config.py
================================================
#
# Copyright 2024 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import click
import yaml
import typing
from pathlib import Path
from ._io import set_logger, echo
DEFAULT_FATE_LLM_BASE_PATH = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
FATE_LLM_BASE_PATH = os.getenv("FATE_LLM_BASE_PATH") or DEFAULT_FATE_LLM_BASE_PATH
# DEFAULT_TASK_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../tasks"))
DEFAULT_FATE_LLM_TASK_PATH = os.path.abspath(os.path.join(FATE_LLM_BASE_PATH, "tasks"))
FATE_LLM_TASK_PATH = os.getenv("FATE_LLM_TASK_PATH") or DEFAULT_FATE_LLM_TASK_PATH
_default_eval_config = Path(FATE_LLM_BASE_PATH).resolve() / 'llm_eval_config.yaml'
template = """# args for evaluate
batch_size: 10
model_args:
device: cuda
dtype: auto
trust_remote_code: true
num_fewshot: 0
"""
def create_eval_config(path: Path, override=False):
if path.exists() and not override:
raise FileExistsError(f"{path} exists")
with path.open("w") as f:
f.write(template)
def default_eval_config():
if not _default_eval_config.exists():
create_eval_config(_default_eval_config)
return _default_eval_config
class Config(object):
def __init__(self, config):
self.update_conf(**config)
def update_conf(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
@staticmethod
def load(path: typing.Union[str, Path], **kwargs):
if isinstance(path, str):
path = Path(path)
config = {}
if path is not None:
with path.open("r") as f:
config.update(yaml.safe_load(f))
config.update(kwargs)
return Config(config)
@staticmethod
def load_from_file(path: typing.Union[str, Path]):
"""
Loads conf content from yaml file. Used to read in parameter configuration
Parameters
----------
path: str, path to conf file, should be absolute path
Returns
-------
dict, parameter configuration in dictionary format
"""
if isinstance(path, str):
path = Path(path)
config = {}
if path is not None:
file_type = path.suffix
with path.open("r") as f:
if file_type == ".yaml":
config.update(yaml.safe_load(f))
else:
raise ValueError(f"Cannot load conf from file type {file_type}")
return config
def parse_config(config):
try:
config_inst = Config.load(config)
except Exception as e:
raise RuntimeError(f"error parse config from {config}") from e
return config_inst
def _set_namespace(namespace):
Path(f"logs/{namespace}").mkdir(exist_ok=True, parents=True)
set_logger(f"logs/{namespace}/exception.log")
echo.set_file(click.open_file(f'logs/{namespace}/stdout', "a"))
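# Usage sketch (hedged): resolve the default eval config and parse it into a Config object.
# cfg_path = default_eval_config()   # creates llm_eval_config.yaml from the template above if absent
# cfg = parse_config(cfg_path)       # Config instance whose attributes come from the YAML keys
# print(cfg.batch_size, cfg.model_args)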
================================================
FILE: python/fate_llm/evaluate/utils/data_tools.py
================================================
#
# Copyright 2024 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
def download_data(data_dir, data_url, is_tar=True):
import os
import requests
import tarfile
import io
# Create data directory
if not os.path.exists(data_dir):
os.makedirs(data_dir)
# Download data
try:
response = requests.get(data_url)
if response.status_code == 200:
if is_tar:
# extract tar file and write to data_dir
with tarfile.open(fileobj=io.BytesIO(response.content), mode='r:gz') as tar:
for member in tar.getmembers():
# check if member is a file
if member.isreg():
member.name = os.path.join(data_dir, os.path.basename(member.name))
tar.extract(member)
else:
# write to data_dir
with open(os.path.join(data_dir, os.path.basename(data_url)), 'wb') as f:
f.write(response.content)
return True
else:
print(f"Error downloading file: {response.status_code}")
return False
except Exception as e:
print(f"Error downloading file: {e}")
return False
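
# Minimal usage sketch (not part of the original module). The URL below is a
# placeholder; any .tar.gz archive of data files is handled the same way.
if __name__ == "__main__":
    ok = download_data(
        data_dir="./data/example_task",
        data_url="https://example.com/datasets/example_task.tar.gz",  # placeholder URL
        is_tar=True,
    )
    print("download succeeded:", ok)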
================================================
FILE: python/fate_llm/evaluate/utils/llm_evaluator.py
================================================
#
# Copyright 2024 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# this file is used to evaluate the model on fate-llm built-in tasks and user-given tasks
import os
import tempfile
import yaml
import shutil
import warnings
from pytablewriter import MarkdownTableWriter
import lm_eval
from lm_eval.utils import load_yaml_config
from ..tasks import build_in_tasks, dump_yaml
from .config import FATE_LLM_BASE_PATH, FATE_LLM_TASK_PATH
def evaluate(tasks, model="hf", model_args=None, include_path=None, task_manager=None, show_result=False, **kwargs):
"""
Evaluate the model on the given tasks. Simplified interface for built-in tasks.
Parameters
----------
tasks: str or List[str], task name(s)
model: str or model object, model to be evaluated,
select from lm_eval supported types: {"hf-auto", "hf", "huggingface", "vllm"}
model_args: str or dict, arguments used to construct the model
include_path: path to task config(s) for tasks not among the built-in tasks
task_manager: lm_eval.tasks.TaskManager object
kwargs: extra keyword arguments passed through to lm_eval.simple_evaluate
Returns
-------
dict, evaluation results from lm_eval.simple_evaluate
"""
if task_manager:
if not isinstance(task_manager, lm_eval.tasks.TaskManager):
raise ValueError(f"'task_manager' must be of TaskManager type.")
elif include_path:
task_manager = lm_eval.tasks.TaskManager(include_path=str(include_path))
else:
task_manager = lm_eval.tasks.TaskManager(include_path=str(FATE_LLM_TASK_PATH))
task_names = []
if isinstance(tasks, str):
task_names.append(tasks)
elif isinstance(tasks, list):
for task in tasks:
if isinstance(task, str):
task_names.append(task)
else:
raise ValueError(f"tasks: {task} of type {type(task)} not valid, please check.")
else:
raise ValueError(f"tasks: {tasks} of type {type(tasks)} not valid, please check.")
results = lm_eval.simple_evaluate(
model=model,
model_args=model_args,
tasks=task_names,
task_manager=task_manager,
**kwargs
)
if show_result:
result_table = lm_eval.utils.make_table(results)
print(result_table)
return results
def aggregate_table(results):
"""
Aggregate results from different models evaluated on the same tasks.
Adapted from lm_eval.utils.make_table:
https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.2/lm_eval/utils.py
Parameters
----------
results: dict, results from different models, keyed by pair name then job name
Returns
-------
dict of MarkdownTableWriter objects, keyed by pair name
"""
suite_writers = dict()
for pair_name, pair_results in results.items():
# job_count = len(pair_results)
all_jobs = list(pair_results.keys())
md_writer = MarkdownTableWriter()
values = []
task_results = dict()
# print(f"pair results: {pair_results}")
for job_name, result_dict in pair_results.items():
if "results" in result_dict and result_dict["results"]:
column = "results"
else:
column = "groups"
for k, dic in result_dict[column].items():
if "alias" in dic:
# task alias
k = dic.pop("alias")
for mf, v in dic.items():
m, _, f = mf.partition(",")
if m.endswith("_stderr"):
continue
if m + "_stderr" + "," + f in dic:
se = dic[m + "_stderr" + "," + f]
if se != "N/A":
se = "%.4f" % se
v = "%.4f ± %s" % (v, se)
else:
v = "%.4f" % v
task_results.setdefault(k, {}).setdefault(job_name, {})[m] = v
# job names as columns
# print(f"task results: {task_results}")
for task_name, task_result in task_results.items():
metrics = {inner_key for inner_dict in task_result.values() for inner_key, value in inner_dict.items()}
for metric in metrics:
row = [f"{task_name}({metric})"]
for job_name in all_jobs:
if job_name in task_result:
row.append(task_result[job_name].get(metric, "N/A"))
else:
row.append("N/A")
values.append(row)
all_headers = ["Task"] + list(pair_results.keys())
md_writer.headers = all_headers
md_writer.value_matrix = values
suite_writers[pair_name] = md_writer
return suite_writers
def get_task_template(task):
if not isinstance(task, str) or task not in build_in_tasks:
raise ValueError(f"{task} not found in build in task, please check input.")
result = build_in_tasks.get(task).task_template
return result
def export_config(config, task, export_dir=None, export_sub_dir=None):
scr_dir = build_in_tasks.get(task).task_scr_dir
if export_dir is None:
export_dir = os.path.dirname(scr_dir)
if export_sub_dir is None:
temp_dir = tempfile.mkdtemp()
# make sure the relative path in new file will work
full_export_dir = os.path.join(export_dir, os.path.basename(temp_dir))
os.rename(temp_dir, full_export_dir)
else:
full_export_dir = os.path.join(export_dir, export_sub_dir)
copy_directory_to_dst(scr_dir, full_export_dir, build_in_tasks.get(task).task_conf_path, config)
return full_export_dir
def copy_directory_to_dst(src_dir, dst_dir, target_conf_file, new_conf: dict):
# copy a task directory to dst_dir, replacing target_conf_file with new_conf
os.makedirs(dst_dir, exist_ok=True)
for item in os.listdir(src_dir):
src_item = os.path.join(src_dir, item)
dst_item = os.path.join(dst_dir, item)
if os.path.isdir(src_item):
shutil.copytree(src_item, dst_item)
else:
if item == target_conf_file:
# write new conf file
dump_yaml(new_conf, dst_item)
else:
shutil.copy2(src_item, dst_item)
def contains_subdirectory(path, subdirectories):
base_name = os.path.basename(path)
if base_name in subdirectories:
return True
for root, dirs, files in os.walk(path):
for d in dirs:
if d in subdirectories:
return True
return False
def delete_config(target_dir, force=False):
if not force:
# check if target dir in any of the build in tasks, only rm dir for build in tasks if force=True
all_build_in_dir = {task.task_scr_dir for task in build_in_tasks.values()}
if contains_subdirectory(target_dir, all_build_in_dir):
warnings.warn(f"Built-in task(s) found in given target directory, please check input or set `force`=True.")
return
shutil.rmtree(target_dir)
def set_environ_fate_llm_base(path):
if path:
os.environ["FATE_LLM_BASE_PATH"] = path
def set_environ_fate_llm_task_base(path):
if path:
os.environ["FATE_LLM_TASK_PATH"] = path
def init_tasks(root_path=None):
"""
Parameters
----------
root_path: str, default None, root path for local data files of built-in tasks;
if not provided, the directory of each task's config file is used as the root
Returns
-------
"""
for task in build_in_tasks.values():
conf_path = task.task_conf_path
parent_path = os.path.dirname(conf_path)
task_template = task.task_template
data_args = task_template.get("dataset_kwargs")
if data_args:
data_files = data_args.get("data_files")
if isinstance(data_files, str):
if data_files.endswith("jsonl") or data_files.endswith("json"):
if root_path:
parent_dir = os.path.basename(parent_path)
new_conf_path = os.path.join(root_path, parent_dir, os.path.basename(conf_path))
else:
new_conf_path = os.path.join(parent_path, data_files)
task_template["dataset_kwargs"]["data_files"] = new_conf_path
elif isinstance(data_files, dict):
for k, v in data_files.items():
if root_path:
parent_dir = os.path.basename(parent_path)
new_conf_path = os.path.join(root_path, parent_dir, os.path.basename(conf_path))
else:
new_conf_path = os.path.join(parent_path, v)
task_template["dataset_kwargs"]["data_files"][k] = new_conf_path
try:
dump_yaml(task_template, conf_path)
except FileNotFoundError:
raise ValueError(f"Cannot find task config {conf_path}, please check.")
except Exception as e:
raise ValueError(f"Initialization failed for task config {conf_path}.") from e
def download_task(tasks=None):
if tasks is None:
tasks = list(build_in_tasks.keys())
i = 1
if isinstance(tasks, str):
tasks = [tasks]
n = len(tasks)
for task in tasks:
task_obj = build_in_tasks.get(task)
if task_obj is None:
print(f"Task {task} not found in built-in tasks, please check.")
continue
result = task_obj.download_from_source()
if result:
print(f"Finish downloading {i}/{n} th task data: {task}, saved to {task_obj.task_scr_dir}.\n")
else:
print(f"Failed to download {i}/{n} th task data to {task_obj.task_scr_dir}.\n")
i += 1
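
# Minimal usage sketch (not part of the original module): download a built-in
# task's data, rewrite its config for local paths, run evaluation, then print an
# aggregated markdown table. "dolly" and "pretrained=gpt2" are placeholders;
# substitute any built-in task name and any lm_eval-supported model spec.
if __name__ == "__main__":
    download_task("dolly")   # placeholder task name
    init_tasks()             # rewrite data_files paths inside the task configs
    results = evaluate(
        tasks="dolly",
        model="hf",
        model_args="pretrained=gpt2",  # placeholder model
        show_result=True,
    )
    tables = aggregate_table({"demo_pair": {"job_0": results}})
    for name, writer in tables.items():
        print(name)
        print(writer.dumps())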
================================================
FILE: python/fate_llm/evaluate/utils/model_tools.py
================================================
#
# Copyright 2024 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from transformers import AutoModel, AutoTokenizer
from lm_eval.models.huggingface import HFLM
def load_model_from_path(model_path, peft_path=None, peft_config=None, model_args=None):
model_args = model_args or {}
if peft_path is None:
if os.path.isfile(model_path):
return HFLM(pretrained=model_path, **model_args)
else:
raise ValueError(f"given model path is not valid, please check: {model_path}")
else:
import torch
from peft import PeftModel, PeftConfig, LoraConfig, TaskType, get_peft_model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
model.half()
model.eval()
peft_config = peft_config or {}
peft_config = LoraConfig(**peft_config)
model = get_peft_model(model, peft_config)
model.load_state_dict(torch.load(peft_path), strict=False)
model.model.half()
return HFLM(pretrained=model, tokenizer=tokenizer, **model_args)
def load_model(model_path, peft_path=None, model_args=None):
model_args = model_args or {}
return HFLM(pretrained=model_path, peft_path=peft_path, **model_args)
def load_by_loader(loader_name=None, loader_conf_path=None, peft_path=None):
#@todo: find loader fn & return loaded model
pass
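
# Minimal usage sketch (not part of the original module): wrap a local model
# directory into an lm_eval HFLM instance for evaluation. The path and model
# args below are placeholders.
if __name__ == "__main__":
    lm = load_model(
        model_path="/data/models/gpt2",   # placeholder model directory
        model_args={"device": "cuda", "dtype": "auto"},
    )
    print(type(lm))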
================================================
FILE: python/fate_llm/inference/__init__.py
================================================
================================================
FILE: python/fate_llm/inference/api.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_llm.inference.inference_base import Inference
from typing import List
class APICompletionInference(Inference):
def __init__(self, api_url: str, model_name: str, api_key: str = 'EMPTY', api_timeout=3600):
from openai import OpenAI
self.model_name = model_name
self.client = OpenAI(
api_key=api_key,
base_url=api_url,
timeout=api_timeout
)
def inference(self, docs: List[str], inference_kwargs: dict = {}) -> List[str]:
completion = self.client.completions.create(model=self.model_name, prompt=docs, **inference_kwargs)
rs_doc = [completion.choices[i].text for i in range(len(completion.choices))]
return rs_doc
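
# Minimal usage sketch (not part of the original module): run completion
# inference against an OpenAI-compatible endpoint (e.g. a locally served vLLM
# server). The URL, model name and sampling kwargs are placeholders.
if __name__ == "__main__":
    infer = APICompletionInference(
        api_url="http://127.0.0.1:8000/v1",   # placeholder endpoint
        model_name="Qwen-1.8B-Chat",          # placeholder model name
    )
    answers = infer.inference(
        ["What is federated learning?"],
        inference_kwargs={"max_tokens": 64, "temperature": 0.01},
    )
    print(answers[0])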
================================================
FILE: python/fate_llm/inference/hf_qw.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_llm.inference.inference_base import Inference
from typing import List
import tqdm
class QwenHFCompletionInference(Inference):
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def inference(self, docs: List[str], inference_kwargs: dict = {}) -> List[str]:
self.model = self.model.eval()
rs_list = []
for d in tqdm.tqdm(docs):
inputs = self.tokenizer(d, return_tensors='pt')
inputs = inputs.to(self.model.device)
inputs.update(inference_kwargs)
pred = self.model.generate(**inputs)
response = self.tokenizer.decode(pred.cpu()[0][len(inputs['input_ids'][0]):], skip_special_tokens=True)
rs_list.append(response)
self.model = self.model.train()
return rs_list
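
# Minimal usage sketch (not part of the original module): wrap a locally loaded
# Qwen model and tokenizer for batch completion. The model path and generation
# kwargs are placeholders.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    path = "Qwen/Qwen-1_8B-Chat"  # placeholder model path
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True, device_map="auto")
    infer = QwenHFCompletionInference(model, tokenizer)
    print(infer.inference(["What is federated learning?"], {"max_new_tokens": 64})[0])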
================================================
FILE: python/fate_llm/inference/inference_base.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List
class Inference(object):
def __init__(self):
pass
def inference(self, docs: List[str], inference_kwargs: dict = {}) -> List[str]:
raise NotImplementedError()
================================================
FILE: python/fate_llm/inference/vllm.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_llm.inference.inference_base import Inference
import logging
from typing import List
logger = logging.getLogger(__name__)
class VLLMInference(Inference):
def __init__(self, model_path, num_gpu=1, dtype='float16', gpu_memory_utilization=0.9):
from vllm import LLM
self.llm = LLM(model=model_path, trust_remote_code=True, dtype=dtype, tensor_parallel_size=num_gpu, gpu_memory_utilization=gpu_memory_utilization)
logger.info('vllm model init done, model path is {}'.format(model_path))
def inference(self, docs: List[str], inference_kwargs: dict = {}) -> List[str]:
from vllm import SamplingParams
param = SamplingParams(**inference_kwargs)
outputs = self.llm.generate(
prompts=docs,
sampling_params=param)
rs = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
rs.append(generated_text)
return rs
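
# Minimal usage sketch (not part of the original module): offline batched
# generation with vLLM. The model path and sampling parameters are placeholders.
if __name__ == "__main__":
    infer = VLLMInference(model_path="Qwen/Qwen-1_8B-Chat", num_gpu=1)  # placeholder path
    outputs = infer.inference(
        ["What is federated learning?"],
        inference_kwargs={"temperature": 0.01, "max_tokens": 64},
    )
    print(outputs[0])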
================================================
FILE: python/fate_llm/model_zoo/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
================================================
FILE: python/fate_llm/model_zoo/embedding_transformer/__init__.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
================================================
FILE: python/fate_llm/model_zoo/embedding_transformer/st_model.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from sentence_transformers import SentenceTransformer
from typing import Any, Optional, Dict, Union
class SentenceTransformerModel(object):
def __init__(
self,
model_name_or_path: Optional[str] = None,
device: Optional[str] = None,
prompts: Optional[Dict[str, str]] = None,
default_prompt_name: Optional[str] = None,
cache_folder: Optional[str] = None,
trust_remote_code: bool = False,
revision: Optional[str] = None,
local_files_only: bool = False,
token: Optional[Union[bool, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
truncate_dim: Optional[int] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
self.model_name_or_path = model_name_or_path
self.device = device
self.prompts = prompts
self.default_prompt_name = default_prompt_name
self.cache_folder = cache_folder
self.trust_remote_code = trust_remote_code
self.revision = revision
self.local_files_only = local_files_only
self.token = token
self.use_auth_token = use_auth_token
self.truncate_dim = truncate_dim
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.config_kwargs = config_kwargs
def load(self):
model = SentenceTransformer(
model_name_or_path=self.model_name_or_path,
device=self.device,
prompts=self.prompts,
default_prompt_name=self.default_prompt_name,
cache_folder=self.cache_folder,
trust_remote_code=self.trust_remote_code,
revision=self.revision,
local_files_only=self.local_files_only,
token=self.token,
use_auth_token=self.use_auth_token,
truncate_dim=self.truncate_dim,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs
)
return model
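
# Minimal usage sketch (not part of the original module): lazily construct a
# SentenceTransformer and encode a few sentences. The model name is a placeholder.
if __name__ == "__main__":
    st = SentenceTransformerModel(
        model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",  # placeholder
        device="cpu",
    )
    model = st.load()
    embeddings = model.encode(["federated learning", "large language models"])
    print(embeddings.shape)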
================================================
FILE: python/fate_llm/model_zoo/hf_model.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from transformers import AutoModelForCausalLM
class HFAutoModelForCausalLM:
def __init__(self, pretrained_model_name_or_path, *model_args, **kwargs) -> None:
self.pretrained_model_name_or_path = pretrained_model_name_or_path
self.model_args = model_args
self.kwargs = kwargs
if "torch_dtype" in self.kwargs and self.kwargs["torch_dtype"] != "auto":
dtype = self.kwargs.pop("torch_dtype")
self.kwargs["torch_dtype"] = getattr(torch, dtype)
def load(self):
model = AutoModelForCausalLM.from_pretrained(
self.pretrained_model_name_or_path, *self.model_args, **self.kwargs
)
return model
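
# Minimal usage sketch (not part of the original module): declare a causal-LM
# loader with a string torch_dtype, which is resolved to the corresponding torch
# dtype at construction time. The model name is a placeholder.
if __name__ == "__main__":
    loader = HFAutoModelForCausalLM("gpt2", torch_dtype="float16")  # placeholder model
    model = loader.load()
    print(model.config.model_type)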
================================================
FILE: python/fate_llm/model_zoo/offsite_tuning/__init__.py
================================================
================================================
FILE: python/fate_llm/model_zoo/offsite_tuning/bloom.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters, split_numpy_array, recover_numpy_array
from transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomModel, BloomConfig
from torch import nn
import torch
from typing import Optional, Tuple
class BloomMainModel(OffsiteTuningMainModel):
def __init__(
self,
model_name_or_path,
emulator_layer_num: int,
adapter_top_layer_num: int = 2,
adapter_bottom_layer_num: int = 2):
self.model_name_or_path = model_name_or_path
super().__init__(
emulator_layer_num,
adapter_top_layer_num,
adapter_bottom_layer_num)
def get_base_model(self):
return BloomForCausalLM.from_pretrained(self.model_name_or_path)
def get_model_transformer_blocks(self, model: BloomForCausalLM):
return model.transformer.h
def get_additional_param_state_dict(self):
# get parameter of additional parameter
model = self.model
param_dict = {
'wte': model.transformer.word_embeddings,
'word_ln': model.transformer.word_embeddings_layernorm,
'last_ln_f': model.transformer.ln_f
}
addition_weights = self.get_numpy_state_dict(param_dict)
wte = addition_weights.pop('wte')
wte_dict = split_numpy_array(wte, 25, 'wte')
addition_weights.update(wte_dict)
return addition_weights
def load_additional_param_state_dict(self, submodel_weights: dict):
# load additional weights:
model = self.model
param_dict = {
'wte': model.transformer.word_embeddings,
'word_ln': model.transformer.word_embeddings_layernorm,
'last_ln_f': model.transformer.ln_f
}
new_submodel_weight = {}
new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']
new_submodel_weight['word_ln'] = submodel_weights['word_ln']
wte_dict = {}
for k, v in submodel_weights.items():
if 'wte' in k:
wte_dict[k] = v
wte = recover_numpy_array(wte_dict, 'wte')
new_submodel_weight['wte'] = wte
self.load_numpy_state_dict(param_dict, new_submodel_weight)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
):
return self.model(
input_ids,
past_key_values,
attention_mask,
head_mask,
inputs_embeds,
labels,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
**deprecated_arguments,
)
class BloomSubModel(OffsiteTuningSubModel):
def __init__(
self,
model_name_or_path,
emulator_layer_num: int,
adapter_top_layer_num: int = 2,
adapter_bottom_layer_num: int = 2,
fp16_mix_precision=False,
partial_weight_decay=None):
self.model_name_or_path = model_name_or_path
self.emulator_layer_num = emulator_layer_num
self.adapter_top_layer_num = adapter_top_layer_num
self.adapter_bottom_layer_num = adapter_bottom_layer_num
super().__init__(
emulator_layer_num,
adapter_top_layer_num,
adapter_bottom_layer_num,
fp16_mix_precision)
self.partial_weight_decay = partial_weight_decay
def get_base_model(self):
total_layer_num = self.emulator_layer_num + \
self.adapter_top_layer_num + self.adapter_bottom_layer_num
config = BloomConfig.from_pretrained(self.model_name_or_path)
config.num_hidden_layers = total_layer_num
# initialize a model without pretrained weights
return BloomForCausalLM(config)
def get_model_transformer_blocks(self, model: BloomForCausalLM):
return model.transformer.h
def get_additional_param_state_dict(self):
# get parameter of additional parameter
model = self.model
param_dict = {
'wte': model.transformer.word_embeddings,
'word_ln': model.transformer.word_embeddings_layernorm,
'last_ln_f': model.transformer.ln_f
}
addition_weights = self.get_numpy_state_dict(param_dict)
wte = addition_weights.pop('wte')
wte_dict = split_numpy_array(wte, 25, 'wte')
addition_weights.update(wte_dict)
return addition_weights
def load_additional_param_state_dict(self, submodel_weights: dict):
# load additional weights:
model = self.model
param_dict = {
'wte': model.transformer.word_embeddings,
'word_ln': model.transformer.word_embeddings_layernorm,
'last_ln_f': model.transformer.ln_f
}
new_submodel_weight = {}
new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']
new_submodel_weight['word_ln'] = submodel_weights['word_ln']
wte_dict = {}
for k, v in submodel_weights.items():
if 'wte' in k:
wte_dict[k] = v
wte = recover_numpy_array(wte_dict, 'wte')
new_submodel_weight['wte'] = wte
self.load_numpy_state_dict(param_dict, new_submodel_weight)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
):
return self.model(
input_ids,
past_key_values,
attention_mask,
head_mask,
inputs_embeds,
labels,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
**deprecated_arguments,
)
def parameters(self, recurse=True):
if self.partial_weight_decay is None:
return super().parameters(recurse)
elif isinstance(self.partial_weight_decay, float):
no_decay = ["bias", "layer_norm.weight"]
return [
{
"params": [
p for n, p in self.named_parameters() if not any(
nd in n for nd in no_decay)], "weight_decay": self.partial_weight_decay}, {
"params": [
p for n, p in self.named_parameters() if any(
nd in n for nd in no_decay)], "weight_decay": 0.0}]
else:
raise ValueError(
f"partial_weight_decay should be None or float, but got {self.partial_weight_decay}")
================================================
FILE: python/fate_llm/model_zoo/offsite_tuning/gpt2.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters, split_numpy_array, recover_numpy_array
from transformers import GPT2LMHeadModel, GPT2Config
import torch
from typing import Optional, Tuple
class GPT2LMHeadMainModel(OffsiteTuningMainModel):
def __init__(
self,
model_name_or_path,
emulator_layer_num: int,
adapter_top_layer_num: int = 2,
adapter_bottom_layer_num: int = 2):
self.model_name_or_path = model_name_or_path
super().__init__(
emulator_layer_num,
adapter_top_layer_num,
adapter_bottom_layer_num)
def get_base_model(self):
return GPT2LMHeadModel.from_pretrained(self.model_name_or_path)
def get_model_transformer_blocks(self, model: GPT2LMHeadModel):
return model.transformer.h
def forward(self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,):
return self.model(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)
def get_additional_param_state_dict(self):
# get parameter of additional parameter
model = self.model
param_dict = {
'wte': model.transformer.wte,
'wpe': model.transformer.wpe,
'last_ln_f': model.transformer.ln_f
}
addition_weights = self.get_numpy_state_dict(param_dict)
wte = addition_weights.pop('wte')
wte_dict = split_numpy_array(wte, 10, 'wte')
wpe = addition_weights.pop('wpe')
wpe_dict = split_numpy_array(wpe, 10, 'wpe')
addition_weights.update(wte_dict)
addition_weights.update(wpe_dict)
return addition_weights
def load_additional_param_state_dict(self, submodel_weights: dict):
# load additional weights:
model = self.model
param_dict = {
'wte': model.transformer.wte,
'wpe': model.transformer.wpe,
'last_ln_f': model.transformer.ln_f
}
new_submodel_weight = {}
new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']
wte_dict, wpe_dict = {}, {}
for k, v in submodel_weights.items():
if 'wte' in k:
wte_dict[k] = v
if 'wpe' in k:
wpe_dict[k] = v
wte = recover_numpy_array(wte_dict, 'wte')
wpe = recover_numpy_array(wpe_dict, 'wpe')
new_submodel_weight['wte'] = wte
new_submodel_weight['wpe'] = wpe
self.load_numpy_state_dict(param_dict, new_submodel_weight)
class GPT2LMHeadSubModel(OffsiteTuningSubModel):
def __init__(
self,
model_name_or_path,
emulator_layer_num: int,
adapter_top_layer_num: int = 2,
adapter_bottom_layer_num: int = 2,
fp16_mix_precision=False,
partial_weight_decay=None):
self.model_name_or_path = model_name_or_path
self.emulator_layer_num = emulator_layer_num
self.adapter_top_layer_num = adapter_top_layer_num
self.adapter_bottom_layer_num = adapter_bottom_layer_num
super().__init__(
emulator_layer_num,
adapter_top_layer_num,
adapter_bottom_layer_num,
fp16_mix_precision)
self.partial_weight_decay = partial_weight_decay
def get_base_model(self):
total_layer_num = self.emulator_layer_num + \
self.adapter_top_layer_num + self.adapter_bottom_layer_num
config = GPT2Config.from_pretrained(self.model_name_or_path)
config.num_hidden_layers = total_layer_num
# initialize a model without pretrained weights
return GPT2LMHeadModel(config)
def get_model_transformer_blocks(self, model: GPT2LMHeadModel):
return model.transformer.h
def get_additional_param_state_dict(self):
# get parameter of additional parameter
model = self.model
param_dict = {
'wte': model.transformer.wte,
'wpe': model.transformer.wpe,
'last_ln_f': model.transformer.ln_f
}
addition_weights = self.get_numpy_state_dict(param_dict)
wte = addition_weights.pop('wte')
wte_dict = split_numpy_array(wte, 10, 'wte')
wpe = addition_weights.pop('wpe')
wpe_dict = split_numpy_array(wpe, 10, 'wpe')
addition_weights.update(wte_dict)
addition_weights.update(wpe_dict)
return addition_weights
def load_additional_param_state_dict(self, submodel_weights: dict):
# load additional weights:
model = self.model
param_dict = {
'wte': model.transformer.wte,
'wpe': model.transformer.wpe,
'last_ln_f': model.transformer.ln_f
}
new_submodel_weight = {}
new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']
wte_dict, wpe_dict = {}, {}
for k, v in submodel_weights.items():
if 'wte' in k:
wte_dict[k] = v
if 'wpe' in k:
wpe_dict[k] = v
wte = recover_numpy_array(wte_dict, 'wte')
wpe = recover_numpy_array(wpe_dict, 'wpe')
new_submodel_weight['wte'] = wte
new_submodel_weight['wpe'] = wpe
self.load_numpy_state_dict(param_dict, new_submodel_weight)
def forward(self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,):
return self.model(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)
def parameters(self, recurse=True):
if self.partial_weight_decay is None:
return super().parameters(recurse)
elif isinstance(self.partial_weight_decay, float):
no_decay = ["bias", "layer_norm.weight"]
return [
{
"params": [
p for n, p in self.named_parameters() if not any(
nd in n for nd in no_decay)], "weight_decay": self.partial_weight_decay}, {
"params": [
p for n, p in self.named_parameters() if any(
nd in n for nd in no_decay)], "weight_decay": 0.0}]
else:
raise ValueError(
f"partial_weight_decay should be None or float, but got {self.partial_weight_decay}")
================================================
FILE: python/fate_llm/model_zoo/offsite_tuning/llama.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters, split_numpy_array, recover_numpy_array
from transformers import LlamaConfig, LlamaForCausalLM
class LlamaMainModel(OffsiteTuningMainModel):
def __init__(
self,
model_name_or_path,
emulator_layer_num: int,
adapter_top_layer_num: int = 2,
adapter_bottom_layer_num: int = 2):
self.model_name_or_path = model_name_or_path
super().__init__(
emulator_layer_num,
adapter_top_layer_num,
adapter_bottom_layer_num)
def get_base_model(self):
return LlamaForCausalLM.from_pretrained(self.model_name_or_path)
def get_model_transformer_blocks(self, model: LlamaForCausalLM):
return model.model.layers
def get_additional_param_state_dict(self):
# get parameter of additional parameter
model = self.model
param_dict = {
'wte': model.model.embed_tokens,
'last_ln_f': model.model.norm
}
addition_weights = self.get_numpy_state_dict(param_dict)
wte = addition_weights.pop('wte')
wte_dict = split_numpy_array(wte, 25, 'wte')
addition_weights.update(wte_dict)
return addition_weights
def load_additional_param_state_dict(self, submodel_weights: dict):
# load additional weights:
model = self.model
param_dict = {
'wte': model.model.embed_tokens,
'last_ln_f': model.model.norm
}
new_submodel_weight = {}
new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']
wte_dict = {}
for k, v in submodel_weights.items():
if 'wte' in k:
wte_dict[k] = v
wte = recover_numpy_array(wte_dict, 'wte')
new_submodel_weight['wte'] = wte
self.load_numpy_state_dict(param_dict, new_submodel_weight)
def forward(self, **kwargs):
return self.model(**kwargs)
class LlamaSubModel(OffsiteTuningSubModel):
def __init__(
self,
model_name_or_path,
emulator_layer_num: int,
adapter_top_layer_num: int = 2,
adapter_bottom_layer_num: int = 2,
fp16_mix_precision=False,
partial_weight_decay=None):
self.model_name_or_path = model_name_or_path
self.emulator_layer_num = emulator_layer_num
self.adapter_top_layer_num = adapter_top_layer_num
self.adapter_bottom_layer_num = adapter_bottom_layer_num
super().__init__(
emulator_layer_num,
adapter_top_layer_num,
adapter_bottom_layer_num,
fp16_mix_precision)
self.partial_weight_decay = partial_weight_decay
def get_base_model(self):
total_layer_num = self.emulator_layer_num + \
self.adapter_top_layer_num + self.adapter_bottom_layer_num
config = LlamaConfig.from_pretrained(self.model_name_or_path)
config.num_hidden_layers = total_layer_num
# initialize a model without pretrained weights
return LlamaForCausalLM(config)
def get_model_transformer_blocks(self, model: LlamaForCausalLM):
return model.model.layers
def get_additional_param_state_dict(self):
# get parameter of additional parameter
model = self.model
param_dict = {
'wte': model.model.embed_tokens,
'last_ln_f': model.model.norm
}
addition_weights = self.get_numpy_state_dict(param_dict)
wte = addition_weights.pop('wte')
wte_dict = split_numpy_array(wte, 25, 'wte')
addition_weights.update(wte_dict)
return addition_weights
def load_additional_param_state_dict(self, submodel_weights: dict):
# load additional weights:
model = self.model
param_dict = {
'wte': model.model.embed_tokens,
'last_ln_f': model.model.norm
}
new_submodel_weight = {}
new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']
wte_dict = {}
for k, v in submodel_weights.items():
if 'wte' in k:
wte_dict[k] = v
wte = recover_numpy_array(wte_dict, 'wte')
new_submodel_weight['wte'] = wte
self.load_numpy_state_dict(param_dict, new_submodel_weight)
def forward(self, **kwargs):
return self.model(**kwargs)
def parameters(self, recurse=True):
if self.partial_weight_decay is None:
return super().parameters(recurse)
elif isinstance(self.partial_weight_decay, float):
no_decay = ["bias", "layer_norm.weight"]
return [
{
"params": [
p for n, p in self.named_parameters() if not any(
nd in n for nd in no_decay)], "weight_decay": self.partial_weight_decay}, {
"params": [
p for n, p in self.named_parameters() if any(
nd in n for nd in no_decay)], "weight_decay": 0.0}]
else:
raise ValueError(
f"partial_weight_decay should be None or float, but got {self.partial_weight_decay}")
================================================
FILE: python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch as t
from torch import nn
from transformers import AutoModel
import numpy as np
import logging
logger = logging.getLogger(__name__)
def get_dropout_emulator_and_adapters(
transformer_layers: nn.ModuleList,
emulator_layer_num: int,
adapter_top_layer_num: int,
adapter_bottom_layer_num: int):
assert adapter_bottom_layer_num > 0 and adapter_top_layer_num > 0, "adapter layer num must be greater than 0"
assert 0 < emulator_layer_num < len(transformer_layers), "emulator layer num must be greater than 0 and less than the number of transformer layers"
assert adapter_bottom_layer_num + adapter_top_layer_num < len(transformer_layers), "adapter layer num must be less than the number of transformer layers"
bottom_idx = adapter_bottom_layer_num
top_idx = len(transformer_layers) - adapter_top_layer_num
bottom_layers = transformer_layers[:bottom_idx]
top_layers = transformer_layers[top_idx:]
kept_layers = transformer_layers[bottom_idx:top_idx]
emulator = nn.ModuleList()
stride = (len(kept_layers) - 1) / (emulator_layer_num - 1)
layer_idx = []
for i in range(emulator_layer_num):
idx = int(round(i * stride))
layer_idx.append(idx)
emulator.append(kept_layers[idx])
logger.info(
'take layer {} of the original model as the emulator'.format(
t.Tensor(layer_idx) +
bottom_idx))
return nn.ModuleList(emulator), nn.ModuleList(
bottom_layers), nn.ModuleList(top_layers)
def split_numpy_array(embedding_matrix, n, suffix):
# Calculate the indices where the splits should occur
embedding_matrix = embedding_matrix['weight']
indices = np.linspace(0, embedding_matrix.shape[0], n+1, dtype=int)
# Split the embedding matrix at the calculated indices
slices = [embedding_matrix[indices[i]:indices[i+1]] for i in range(n)]
# Create a dictionary with the slices
result_dict = {suffix+str(i): slice for i, slice in enumerate(slices)}
return result_dict
def recover_numpy_array(slices_dict, suffix=""):
# Get the slices from the dictionary and concatenate them
slices = [slices_dict[suffix + str(i)] for i in range(len(slices_dict))]
complete_array = np.concatenate(slices, axis=0)
return {'weight': complete_array}
class OffsiteTuningBaseModel(t.nn.Module):
def __init__(self, emulator_layer_num: int, adapter_top_layer_num: int = 2,
adapter_bottom_layer_num: int = 2, fp16_mix_precision=False):
super().__init__()
self.fp16_mix_precision = fp16_mix_precision
self.model = self.get_base_model()
self.initialize_model()
self.emulator, self.adapter_bottom, self.adapter_top = get_dropout_emulator_and_adapters(
transformer_layers=self.get_model_transformer_blocks(self.model),
emulator_layer_num=emulator_layer_num,
adapter_top_layer_num=adapter_top_layer_num,
adapter_bottom_layer_num=adapter_bottom_layer_num
)
self.post_initialization()
def initialize_model(self):
if self.fp16_mix_precision:
self.model.half()
for param in self.model.parameters():
param.requires_grad = False
def post_initialization(self):
pass
def get_adapter_top(self):
return self.adapter_top
def get_adapter_bottom(self):
return self.adapter_bottom
def get_emulator(self):
return self.emulator
def get_additional_param_state_dict(self):
# get parameter of additional parameter
return {}
def load_additional_param_state_dict(self, submodel_weights: dict):
# load additional weights:
pass
def _get_numpy_arr(self, v):
if v.dtype == t.bfloat16:
# float 32
v = v.detach().cpu().float().numpy()
else:
v = v.detach().cpu().numpy()
return v
def load_numpy_state_dict(self, module_dict, state_dict):
param_dict = module_dict
for k, v in param_dict.items():
if k not in state_dict:
continue
addition_weights = {
k: t.tensor(v) for k,
v in state_dict[k].items()}
v.load_state_dict(addition_weights)
def get_numpy_state_dict(self, module_dict):
weight_dict = {}
for k, v in module_dict.items():
weight_dict[k] = {
k: self._get_numpy_arr(v) for k,
v in v.state_dict().items()}
return weight_dict
def get_submodel_weights(self, with_emulator=True) -> dict:
if with_emulator:
submodel_weights = {
"emulator": {
k: self._get_numpy_arr(v) for k,
v in self.get_emulator().state_dict().items()},
"adapter_top": {
k: self._get_numpy_arr(v) for k,
v in self.get_adapter_top().state_dict().items()},
"adapter_bottom": {
k: self._get_numpy_arr(v) for k,
v in self.get_adapter_bottom().state_dict().items()}}
else:
submodel_weights = {
"adapter_top": {
k: self._get_numpy_arr(v) for k,
v in self.get_adapter_top().state_dict().items()},
"adapter_bottom": {
k: self._get_numpy_arr(v) for k,
v in self.get_adapter_bottom().state_dict().items()}}
addition_weights = self.get_additional_param_state_dict()
submodel_weights.update(addition_weights)
return submodel_weights
def load_submodel_weights(self, submodel_weights: dict, with_emulator=True):
if with_emulator:
emulator_weights = {
k: t.tensor(v) for k,
v in submodel_weights['emulator'].items()}
emulator = self.get_emulator()
emulator.load_state_dict(emulator_weights)
adapter_top_weights = {
k: t.tensor(v) for k,
v in submodel_weights['adapter_top'].items()}
adapter_bottom_weights = {
k: t.tensor(v) for k,
v in submodel_weights['adapter_bottom'].items()}
adapter_top = self.get_adapter_top()
adapter_bottom = self.get_adapter_bottom()
adapter_top.load_state_dict(adapter_top_weights)
adapter_bottom.load_state_dict(adapter_bottom_weights)
self.load_additional_param_state_dict(submodel_weights)
def forward(self, **kwargs):
raise NotImplementedError()
def get_base_model(self):
raise NotImplementedError()
def get_model_transformer_blocks(self, model: t.nn.Module):
raise NotImplementedError()
class OffsiteTuningMainModel(OffsiteTuningBaseModel):
def post_initialization(self):
pass
class OffsiteTuningSubModel(OffsiteTuningBaseModel):
def post_initialization(self):
# mix precision model training
for param in self.adapter_top.parameters():
param.data = param.data.float()
param.requires_grad = True
for param in self.adapter_bottom.parameters():
param.data = param.data.float()
param.requires_grad = True
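
# Minimal self-check sketch (not part of the original module): the embedding
# split/recover helpers round-trip a weight dict, which is how large embedding
# tables are chunked before transmission. The array sizes below are arbitrary.
if __name__ == "__main__":
    original = {"weight": np.random.rand(103, 8).astype(np.float32)}
    chunks = split_numpy_array(original, 10, "wte")   # produces keys wte0 ... wte9
    recovered = recover_numpy_array(chunks, "wte")
    assert np.array_equal(original["weight"], recovered["weight"])
    print("round trip ok:", recovered["weight"].shape)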
================================================
FILE: python/fate_llm/model_zoo/pellm/__init__.py
================================================
================================================
FILE: python/fate_llm/model_zoo/pellm/albert.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import AlbertConfig, AutoConfig
from transformers import AlbertForSequenceClassification
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
class Albert(PELLM):
config_class = AlbertConfig
model_loader = AlbertForSequenceClassification
def __init__(self, config: dict = None,
pretrained_path: str = None,
peft_type: str = None,
peft_config: dict = None,
**kwargs
) -> None:
if pretrained_path is not None:
self.check_config(pretrain_path=pretrained_path)
if config is None and pretrained_path is None:
config = AlbertConfig().to_dict() # use default model setting
super().__init__(
config=config,
pretrained_path=pretrained_path,
peft_type=peft_type,
peft_config=peft_config,
**kwargs)
def check_config(self, pretrain_path):
config = AutoConfig.from_pretrained(pretrain_path)
assert isinstance(
config, AlbertConfig), 'The config of pretrained model must be AlbertConfig, but got {}'.format(
type(config))
================================================
FILE: python/fate_llm/model_zoo/pellm/bart.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import BartConfig, AutoConfig
from transformers import BartForSequenceClassification
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
class Bart(PELLM):
config_class = BartConfig
model_loader = BartForSequenceClassification
def __init__(self, config: dict = None,
pretrained_path: str = None,
peft_type: str = None,
peft_config: dict = None,
**kwargs) -> None:
if pretrained_path is not None:
self.check_config(pretrain_path=pretrained_path)
if config is None and pretrained_path is None:
config = BartConfig().to_dict()
super().__init__(
config=config,
pretrained_path=pretrained_path,
peft_type=peft_type,
peft_config=peft_config,
**kwargs)
def check_config(self, pretrain_path):
config = AutoConfig.from_pretrained(pretrain_path)
assert isinstance(
config, BartConfig), 'The config of pretrained model must be BartConfig, but got {}'.format(
type(config))
================================================
FILE: python/fate_llm/model_zoo/pellm/bert.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import BertConfig, AutoConfig
from transformers import BertForSequenceClassification
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
class Bert(PELLM):
config_class = BertConfig
model_loader = BertForSequenceClassification
def __init__(self, config: dict = None,
pretrained_path: str = None,
peft_type: str = None,
peft_config: dict = None,
**kwargs) -> None:
if pretrained_path is not None:
self.check_config(pretrain_path=pretrained_path)
if config is None and pretrained_path is None:
config = BertConfig().to_dict()
super().__init__(
config=config,
pretrained_path=pretrained_path,
peft_type=peft_type,
peft_config=peft_config,
**kwargs)
def check_config(self, pretrain_path):
config = AutoConfig.from_pretrained(pretrain_path)
assert isinstance(
config, BertConfig), 'The config of pretrained model must be BertConfig, but got {}'.format(
type(config))
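
# Minimal usage sketch (not part of the original module), assuming PELLM resolves
# `peft_type` to the corresponding peft config class and builds it from
# `peft_config`: wrap a pretrained Bert classifier with a LoRA adapter. The
# pretrained path and LoRA settings are placeholders.
if __name__ == "__main__":
    lora_config = dict(r=8, lora_alpha=16, lora_dropout=0.1, target_modules=["query", "value"])
    model = Bert(
        pretrained_path="bert-base-uncased",  # placeholder model path
        peft_type="LoraConfig",               # assumed to map to peft.LoraConfig
        peft_config=lora_config,
    )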
================================================
FILE: python/fate_llm/model_zoo/pellm/bloom.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import BloomConfig
from transformers import BloomForCausalLM
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
class Bloom(PELLM):
config_class = BloomConfig
model_loader = BloomForCausalLM
def __init__(self, config: dict = None,
pretrained_path: str = None,
peft_type: str = None,
peft_config: dict = None,
**kwargs
) -> None:
if config is None and pretrained_path is None:
config = BloomConfig().to_dict() # use default model setting
super().__init__(config=config, pretrained_path=pretrained_path,
peft_type=peft_type, peft_config=peft_config, **kwargs)
================================================
FILE: python/fate_llm/model_zoo/pellm/chatglm.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
from transformers import AutoConfig
class ChatGLM(PELLM):
def __init__(self,
pretrained_path: str = None,
peft_type: str = None,
peft_config: dict = None,
pre_seq_len: int = None,
prefix_projection: bool = False,
**kwargs) -> None:
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
super().__init__(pretrained_path=pretrained_path,
peft_type=peft_type,
peft_config=peft_config,
**kwargs
)
def init_config(self):
self.config = AutoConfig.from_pretrained(
self.config_path, trust_remote_code=True)
self.config.pre_seq_len = self.pre_seq_len
self.config.prefix_projection = self.prefix_projection
def add_peft(self):
if self.pre_seq_len:
self._pe_lm.half()
self._pe_lm.transformer.prefix_encoder.float()
else:
super(ChatGLM, self).add_peft()
================================================
FILE: python/fate_llm/model_zoo/pellm/deberta.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import DebertaConfig, AutoConfig
from transformers import DebertaForSequenceClassification
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
class Deberta(PELLM):
config_class = DebertaConfig
model_loader = DebertaForSequenceClassification
def __init__(self, config: dict = None,
pretrained_path: str = None,
peft_type: str = None,
peft_config: dict = None,
**kwargs) -> None:
if pretrained_path is not None:
self.check_config(pretrain_path=pretrained_path)
if config is None and pretrained_path is None:
config = DebertaConfig().to_dict()
super().__init__(
config=config,
pretrained_path=pretrained_path,
peft_type=peft_type,
peft_config=peft_config,
**kwargs)
def check_config(self, pretrain_path):
config = AutoConfig.from_pretrained(pretrain_path)
assert isinstance(
config, DebertaConfig), 'The config of pretrained model must be DebertaConfig, but got {}'.format(
type(config))
================================================
FILE: python/fate_llm/model_zoo/pellm/distilbert.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import DistilBertConfig, AutoConfig
from transformers import DistilBertForSequenceClassification
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
class DistilBert(PELLM):
config_class = DistilBertConfig
model_loader = DistilBertForSequenceClassification
def __init__(self, config: dict = None,
pretrained_path: str = None,
peft_type: str = None,
peft_config: dict = None,
**kwargs) -> None:
if pretrained_path is not None:
self.check_config(pretrain_path=pretrained_path)
if config is None and pretrained_path is None:
config = DistilBertConfig().to_dict()
super().__init__(
config=config,
pretrained_path=pretrained_path,
peft_type=peft_type,
peft_config=peft_config,
**kwargs)
def check_config(self, pretrain_path):
config = AutoConfig.from_pretrained(pretrain_path)
assert isinstance(
config, DistilBertConfig), 'The config of pretrained model must be DistilBertConfig, but got {}'.format(
type(config))
================================================
FILE: python/fate_llm/model_zoo/pellm/gpt2.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import GPT2Config, AutoConfig
from transformers import GPT2ForSequenceClassification, AutoModelForCausalLM
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
class GPT2(PELLM):
config_class = GPT2Config
model_loader = GPT2ForSequenceClassification
def __init__(self,
config: dict = None,
pretrained_path: str = None,
peft_type: str = None,
peft_config: dict = None,
**kwargs) -> None:
if pretrained_path is not None:
self.check_config(pretrain_path=pretrained_path)
if config is None and pretrained_path is None:
config = GPT2Config().to_dict()
super().__init__(
config=config,
pretrained_path=pretrained_path,
peft_type=peft_type,
peft_config=peft_config,
**kwargs)
def check_config(self, pretrain_path):
config = AutoConfig.from_pretrained(pretrain_path)
assert isinstance(
config, GPT2Config), 'The config of pretrained model must be GPT2Config, but got {}'.format(
type(config))
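# GPT2CLM reuses GPT2's config checking but loads a causal-LM head (AutoModelForCausalLM)
# for generation tasks instead of the sequence-classification head.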
class GPT2CLM(GPT2):
model_loader = AutoModelForCausalLM
================================================
FILE: python/fate_llm/model_zoo/pellm/llama.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
from transformers import AutoConfig
from transformers import LlamaConfig
from transformers import LlamaForCausalLM
class LLaMa(PELLM):
config_class = LlamaConfig
def __init__(self,
pretrained_path: str = None,
peft_type: str = None,
peft_config: dict = None,
**kwargs) -> None:
super().__init__(pretrained_path=pretrained_path,
peft_type=peft_type,
peft_config=peft_config,
**kwargs)
def init_base_lm(self, **kwargs):
if self.config is not None:
self._pe_lm = LlamaForCausalLM.from_pretrained(self.config_path,
config=self.config,
torch_dtype=self.torch_dtype,
**kwargs)
elif self.config_path is not None:
self._pe_lm = LlamaForCausalLM.from_pretrained(self.config_path, torch_dtype=self.torch_dtype, **kwargs)
else:
raise ValueError(
'config_path to pretrained model folder cannot be None')
def check_config(self, pretrain_path):
config = AutoConfig.from_pretrained(pretrain_path)
assert isinstance(
config, LlamaConfig), 'The config of pretrained model must be LlamaConfig, but got {}'.format(
type(config))
================================================
FILE: python/fate_llm/model_zoo/pellm/opt.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import OPTConfig
from transformers import OPTForCausalLM
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
class OPT(PELLM):
config_class = OPTConfig
model_loader = OPTForCausalLM
def __init__(self, config: dict = None,
pretrained_path: str = None,
peft_type: str = None,
peft_config: dict = None,
**kwargs
) -> None:
if config is None and pretrained_path is None:
config = OPTConfig().to_dict() # use default model setting
super().__init__(config=config, pretrained_path=pretrained_path,
peft_type=peft_type, peft_config=peft_config, **kwargs)
================================================
FILE: python/fate_llm/model_zoo/pellm/parameter_efficient_llm.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import peft
import torch
from collections.abc import Mapping
from peft import PeftModel, TaskType
from transformers import AutoConfig
from transformers import AutoModel
from transformers.configuration_utils import PretrainedConfig
import logging
logger = logging.getLogger(__name__)
AVAILABLE_PEFT_CONFIG = list(
filter(
lambda peft_type: peft_type.endswith("Config"), dir(peft)
)
)
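# PELLM wraps a HuggingFace pretrained model and attaches a peft adapter. peft_type must be
# the name of one of the *Config classes exported by the peft package (collected above in
# AVAILABLE_PEFT_CONFIG); subclasses pin config_class / model_loader to a concrete architecture.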
class PELLM(torch.nn.Module):
config_class: PretrainedConfig = None
model_loader = None
def __init__(self,
config: dict = None,
pretrained_path: str = None,
peft_type: str = None,
peft_config=None,
torch_dtype: str = None,
trust_remote_code: bool = False,
**kwargs
) -> None:
super().__init__()
self._pe_lm: PeftModel = None
self.config = config
self.config_path = pretrained_path
self.peft_type = peft_type
self.peft_config = peft_config
self.torch_dtype = None if not torch_dtype else getattr(torch, torch_dtype)
self.trust_remote_code = trust_remote_code
assert self.config_path is not None or self.config is not None, \
"At least one of config_path and config must be set."
self._init_pelm(**kwargs)
def _init_pelm(self, **kwargs):
self.init_lm_with_peft(**kwargs)
self.model_summary()
def init_lm_with_peft(self, **kwargs):
self.init_config(**kwargs)
self.init_base_lm()
self.add_peft()
def init_config(self, **kwargs):
if self.config_path is not None:
self.config = AutoConfig.from_pretrained(self.config_path, trust_remote_code=self.trust_remote_code)
elif self.config is not None and self.config_class is not None:
self.config = self.config_class().from_dict(self.config)
else:
raise ValueError(
'config_path to pretrained model folder and model config dict cannot be None at the same time, '
'you need to specify one of them')
if kwargs:
self.config.update(kwargs)
def init_base_lm(self, **kwargs):
model_loader = self.model_loader if self.model_loader is not None else AutoModel
if self.config is not None:
self._pe_lm = model_loader.from_pretrained(
self.config_path, config=self.config,
torch_dtype=self.torch_dtype, **kwargs,
trust_remote_code=self.trust_remote_code
)
elif self.config_path is not None:
self._pe_lm = model_loader.from_pretrained(
self.config_path, torch_dtype=self.torch_dtype,
trust_remote_code=self.trust_remote_code, **kwargs)
else:
raise ValueError(
'config_path to pretrained model folder cannot be None')
def add_peft(self):
assert self.peft_type in AVAILABLE_PEFT_CONFIG, 'peft name {} not in available config {}'.format(
self.peft_type, AVAILABLE_PEFT_CONFIG)
if self.peft_config is None:
peft_config = getattr(peft, self.peft_type)()
elif isinstance(self.peft_config, dict):
peft_config = getattr(peft, self.peft_type)(**self.peft_config)
else:
raise ValueError(f"Can not parse peft_config of {type(self.peft_config)}")
self._pe_lm = peft.get_peft_model(self._pe_lm, peft_config)
self.peft_config = peft_config
def model_summary(self):
if hasattr(self._pe_lm, "print_trainable_parameters"):
summary = self._pe_lm.print_trainable_parameters()
logger.debug(f'PELLM model summary: \n{summary}')
def forward(self, *args, **kwargs):
forward_ret = self._pe_lm.forward(*args, **kwargs)
if self.peft_config is None or self.peft_config.task_type != TaskType.SEQ_CLS:
return forward_ret
else:
return forward_ret.logits
def save_trainable(self, output_path):
self._pe_lm.save_pretrained(output_path)
class AutoPELLM(PELLM):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
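# Illustrative usage sketch (the checkpoint path and LoRA hyper-parameters below are
# assumptions for demonstration only, not values shipped with this repository):
#   from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM
#   model = GPT2CLM(pretrained_path="/path/to/gpt2",
#                   peft_type="LoraConfig",
#                   peft_config={"r": 8, "lora_alpha": 16, "target_modules": ["c_attn"]})
#   model.save_trainable("/path/to/adapter_output")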
================================================
FILE: python/fate_llm/model_zoo/pellm/qwen.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import Qwen2Config
from transformers import Qwen2ForCausalLM
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
class Qwen(PELLM):
config_class = Qwen2Config
model_loader = Qwen2ForCausalLM
def __init__(self, config: dict = None,
pretrained_path: str = None,
peft_type: str = None,
peft_config: dict = None,
**kwargs
) -> None:
if config is None and pretrained_path is None:
config = Qwen2Config().to_dict() # use default model setting
super().__init__(config=config, pretrained_path=pretrained_path,
peft_type=peft_type, peft_config=peft_config, **kwargs)
================================================
FILE: python/fate_llm/model_zoo/pellm/roberta.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import RobertaConfig, AutoConfig
from transformers import RobertaForSequenceClassification
from fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM
class Roberta(PELLM):
config_class = RobertaConfig
model_loader = RobertaForSequenceClassification
def __init__(self, config: dict = None,
pretrained_path: str = None,
peft_type: str = None,
peft_config: dict = None,
**kwargs) -> None:
if pretrained_path is not None:
self.check_config(pretrain_path=pretrained_path)
if config is None and pretrained_path is None:
config = RobertaConfig().to_dict()
super().__init__(
config=config,
pretrained_path=pretrained_path,
peft_type=peft_type,
peft_config=peft_config,
**kwargs)
def check_config(self, pretrain_path):
config = AutoConfig.from_pretrained(pretrain_path)
assert isinstance(
config, RobertaConfig), 'The config of pretrained model must be RobertaConfig, but got {}'.format(
type(config))
================================================
FILE: python/fate_llm/runner/__init__.py
================================================
================================================
FILE: python/fate_llm/runner/fdkt_runner.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import torch
from fate.components.components.nn.nn_runner import (
load_model_dict_from_path,
dir_warning,
loader_load_from_conf,
run_dataset_func,
)
from typing import Dict
from fate.components.components.nn.loader import Loader
from typing import Union, Optional, Literal
from transformers.trainer_utils import get_last_checkpoint
from fate.arch.dataframe import DataFrame
from fate.components.components.nn.runner.homo_default_runner import DefaultRunner
from fate_llm.algo.fdkt import FDKTTrainingArguments, FDKTSLM, FDKTLLM
logger = logging.getLogger(__name__)
AUG_DATA_SAVED_PATH_SUFFIX = "aug_data.pkl"
DP_MODEL_SAVED_PATH_SUFFIX = "dp_model"
class FDKTRunner(DefaultRunner):
def __init__(
self,
algo: str = "fdkt",
inference_inst_conf: Optional[Dict] = None,
model_conf: Optional[Dict] = None,
embedding_model_conf: Optional[Dict] = None,
optimizer_conf: Optional[Dict] = None,
training_args_conf: Optional[Dict] = None,
dataset_conf: Optional[Dict] = None,
data_collator_conf: Optional[Dict] = None,
tokenizer_conf: Optional[Dict] = None,
task_type: Literal["causal_lm", "others"] = "causal_lm",
save_dp_model: bool = False,
) -> None:
super(FDKTRunner, self).__init__()
self.algo = algo
self.inference_inst_conf = inference_inst_conf
self.model_conf = model_conf
self.embedding_model_conf = embedding_model_conf
self.optimizer_conf = optimizer_conf
self.training_args_conf = training_args_conf
self.dataset_conf = dataset_conf
self.data_collator_conf = data_collator_conf
self.tokenizer_conf = tokenizer_conf
self.task_type = task_type
self.save_dp_model = save_dp_model
self.training_args = None
# check param
if self.algo.lower() != "fdkt":
raise ValueError(f"algo should be fdkt")
if self.task_type not in ["causal_lm"]:
raise ValueError("task_type should be causal_lm")
def common_setup(self, saved_model=None, output_dir=None):
ctx = self.get_context()
if output_dir is None:
output_dir = "./"
if self.model_conf is not None:
model = loader_load_from_conf(self.model_conf)
else:
model = None
resume_path = None
if saved_model is not None:
model_dict = load_model_dict_from_path(saved_model)
model.load_state_dict(model_dict)
logger.info(f"loading model dict from {saved_model} to model done")
if get_last_checkpoint(saved_model) is not None:
resume_path = saved_model
logger.info(f"checkpoint detected, resume_path set to {resume_path}")
# load tokenizer if import conf provided
if self.tokenizer_conf is not None:
tokenizer = loader_load_from_conf(self.tokenizer_conf)
else:
tokenizer = None
# args
dir_warning(self.training_args_conf)
training_args = FDKTTrainingArguments(**self.training_args_conf)
# reset to default, saving to arbitrary path is not allowed in
# DefaultRunner
training_args.output_dir = output_dir
training_args.resume_from_checkpoint = resume_path # resume path
self.training_args = training_args
dataset = loader_load_from_conf(self.dataset_conf)
return ctx, model, tokenizer, training_args, dataset
def llm_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None):
ctx, model, tokenizer, training_args, dataset = self.common_setup(
output_dir=output_dir, saved_model=saved_model)
if model is not None:
model = model.load()
inference_inst = None
if self.inference_inst_conf is not None:
inference_inst = loader_load_from_conf(self.inference_inst_conf)
embedding_model = loader_load_from_conf(self.embedding_model_conf)
if embedding_model is None:
raise ValueError(f"model is None, cannot load model from conf {self.model_conf}")
embedding_model = embedding_model.load()
trainer = FDKTLLM(
ctx=ctx,
inference_inst=inference_inst,
model=model,
embedding_model=embedding_model,
training_args=training_args,
tokenizer=tokenizer,
dataset=dataset,
)
return trainer
def slm_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None):
ctx, model, tokenizer, training_args, dataset = self.common_setup(
output_dir=output_dir, saved_model=saved_model)
model = model.load()
dataset.load(train_set)
if self.data_collator_conf is not None:
data_collator = loader_load_from_conf(self.data_collator_conf)
else:
data_collator = None
optimizer_loader = Loader.from_dict(self.optimizer_conf)
optimizer_ = optimizer_loader.load_item()
optimizer_params = optimizer_loader.kwargs
optimizer = optimizer_(model.parameters(), **optimizer_params)
trainer = FDKTSLM(
ctx=ctx,
model=model,
training_args=training_args,
tokenizer=tokenizer,
train_set=dataset,
data_collator=data_collator,
optimizer=optimizer,
)
return trainer
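    # Client (SLM) side: run the local augmentation step and persist the augmented samples
    # to <output_dir>/aug_data.pkl (optionally also the trained DP model). Server (LLM) side:
    # take part in the same aug_data() exchange; no artifacts are written locally.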
def train(
self,
train_data: Optional[Union[str, DataFrame]] = None,
validate_data: Optional[Union[str, DataFrame]] = None,
output_dir: str = None,
saved_model_path: str = None,
):
if self.is_client():
trainer = self.slm_setup(train_set=train_data, validate_set=validate_data, output_dir=output_dir, saved_model=saved_model_path)
aug_data = trainer.aug_data()
data_saved_path = output_dir + '/' + AUG_DATA_SAVED_PATH_SUFFIX
logger.info('result save to path {}'.format(data_saved_path))
torch.save(aug_data, data_saved_path)
if self.save_dp_model:
model_save_dir = output_dir + "/" + DP_MODEL_SAVED_PATH_SUFFIX
trainer.save_model(model_save_dir)
else:
trainer = self.llm_setup(
train_set=train_data, validate_set=validate_data, output_dir=output_dir, saved_model=saved_model_path
)
trainer.aug_data()
def predict(self, *args, **kwargs):
pass
================================================
FILE: python/fate_llm/runner/fedcot_runner.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from fate.components.components.nn.nn_runner import (
NNRunner,
load_model_dict_from_path,
dir_warning,
loader_load_from_conf,
)
from fate_llm.model_zoo.hf_model import HFAutoModelForCausalLM
from fate.components.components.nn.loader import Loader
from fate.arch.dataframe import DataFrame
from fate.ml.nn.dataset.base import Dataset
from typing import Dict
from fate_llm.algo.fedcot.fedcot_trainer import FedCoTTrainerClient, FedCoTTraineServer
from fate_llm.algo.fedcot.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer
from fate_llm.algo.inferdpt.init._init import InferInit
import torch.nn as nn
import torch.optim as optim
from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainingArguments
from typing import Union, Type, Callable, Optional
from transformers.trainer_utils import get_last_checkpoint
from typing import Literal
import logging
logger = logging.getLogger(__name__)
def _check_instances(
model: nn.Module = None,
optimizer: optim.Optimizer = None,
train_args: Seq2SeqTrainingArguments = None,
data_collator: Callable = None,
) -> None:
if model is not None and not issubclass(type(model), nn.Module):
raise TypeError(f"SetupReturn Error: model must be a subclass of torch.nn.Module but got {type(model)}")
if optimizer is not None and not issubclass(type(optimizer), optim.Optimizer):
raise TypeError(
f"SetupReturn Error: optimizer must be a subclass of torch.optim.Optimizer but got {type(optimizer)}"
)
if train_args is not None and not isinstance(train_args, Seq2SeqTrainingArguments):
raise TypeError(
f"SetupReturn Error: train_args must be an instance of Seq2SeqTrainingArguments "
f"but got {type(train_args)}"
)
if data_collator is not None and not callable(data_collator):
raise TypeError(f"SetupReturn Error: data_collator must be callable but got {type(data_collator)}")
class FedCoTRunner(NNRunner):
def __init__(
self,
mode: Literal['train_only', 'infer_only', 'infer_and_train'],
model_conf: Optional[Dict] = None,
dataset_conf: Optional[Dict] = None,
optimizer_conf: Optional[Dict] = None,
training_args_conf: Optional[Dict] = None,
data_collator_conf: Optional[Dict] = None,
tokenizer_conf: Optional[Dict] = None,
infer_inst_init_conf: Dict = None,
encode_template: str = None,
instruction_template: str = None,
decode_template: str = None,
remote_inference_kwargs: Dict = {},
local_inference_kwargs: Dict = {},
perturb_doc_key: str = 'perturbed_doc',
perturbed_response_key: str = 'perturbed_response',
result_key: str = 'infer_result',
) -> None:
super(NNRunner, self).__init__()
self.model_conf = model_conf
self.dataset_conf = dataset_conf
self.optimizer_conf = optimizer_conf
self.training_args_conf = training_args_conf
self.data_collator_conf = data_collator_conf
self.mode = mode
self.tokenizer_conf = tokenizer_conf
self.infer_inst_init_conf = infer_inst_init_conf
self.encode_template = encode_template
self.instruction_template = instruction_template
self.decode_template = decode_template
self.remote_inference_kwargs = remote_inference_kwargs
self.local_inference_kwargs = local_inference_kwargs
self.perturb_doc_key = perturb_doc_key
self.perturbed_response_key = perturbed_response_key
self.result_key = result_key
self._temp_data_path = ''
# setup var
self.trainer = None
self.training_args = None
def _get_infer_inst(self, init_conf):
if init_conf is None:
return None
loader = Loader.from_dict(init_conf)
init_inst = loader.load_item()(self.get_context())
        assert isinstance(init_inst, InferInit), 'Need an InferInit class for initialization, but got {}'.format(type(init_inst))
infer_inst = init_inst.get_inst()
logger.info('inferdpt inst loaded')
return infer_inst
def _prepare_data(self, data, data_name):
if data is None:
return None
if isinstance(data, DataFrame) and self.dataset_conf is None:
raise RuntimeError('DataFrame format dataset is not supported, please use bind path to load your dataset')
else:
dataset = loader_load_from_conf(self.dataset_conf)
if hasattr(dataset, "load"):
logger.info("load path is {}".format(data))
import os
if os.path.exists(data) and os.path.isdir(data):
self._temp_data_path = data
load_output = dataset.load(data)
if load_output is not None:
dataset = load_output
return dataset
else:
raise RuntimeError('You must offer an existing folder path as data input, but got {}'.format(data))
else:
raise ValueError(
f"The dataset {dataset} lacks a load() method, which is required for data parsing in the DefaultRunner. \
Please implement this method in your dataset class. You can refer to the base class 'Dataset' in 'fate.ml.nn.dataset.base' \
for the necessary interfaces to implement."
)
if dataset is not None and not issubclass(type(dataset), Dataset):
raise TypeError(
f"SetupReturn Error: {data_name}_set must be a subclass of fate built-in Dataset but got {type(dataset)}, \n"
f"You can get the class via: from fate.ml.nn.dataset.table import Dataset"
)
return dataset
def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None, stage="train"):
ctx = self.get_context()
model = loader_load_from_conf(self.model_conf)
if isinstance(model, HFAutoModelForCausalLM):
model = model.load()
if model is None:
raise ValueError(f"model is None, cannot load model from conf {self.model_conf}")
if output_dir is None:
output_dir = "./"
resume_path = None
if saved_model is not None:
model_dict = load_model_dict_from_path(saved_model)
model.load_state_dict(model_dict)
logger.info(f"loading model dict from {saved_model} to model done")
if get_last_checkpoint(saved_model) is not None:
resume_path = saved_model
logger.info(f"checkpoint detected, resume_path set to {resume_path}")
# load optimizer
if self.optimizer_conf:
optimizer_loader = Loader.from_dict(self.optimizer_conf)
optimizer_ = optimizer_loader.load_item()
optimizer_params = optimizer_loader.kwargs
optimizer = optimizer_(model.parameters(), **optimizer_params)
else:
optimizer = None
# load collator func
data_collator = loader_load_from_conf(self.data_collator_conf)
# load tokenizer if import conf provided
tokenizer = loader_load_from_conf(self.tokenizer_conf)
# args
dir_warning(self.training_args_conf)
training_args = Seq2SeqTrainingArguments(**self.training_args_conf)
# reset to default, saving to arbitrary path is not allowed in
# DefaultRunner
training_args.output_dir = output_dir
training_args.resume_from_checkpoint = resume_path # resume path
self.training_args = training_args
if self.training_args.world_size > 0 and self.training_args.local_rank == 0:
infer_client = self._get_infer_inst(self.infer_inst_init_conf)
else:
            infer_client = None # only rank 0 needs to load the client
# prepare trainer
trainer = FedCoTTrainerClient(
ctx=ctx,
training_args=training_args,
train_set=train_set,
val_set=validate_set,
model=model,
tokenizer=tokenizer,
mode=self.mode,
encode_template=self.encode_template,
decode_template=self.decode_template,
instruction_template=self.instruction_template,
local_inference_kwargs=self.local_inference_kwargs,
remote_inference_kwargs=self.remote_inference_kwargs,
data_collator=data_collator,
optimizer=optimizer,
infer_client=infer_client,
tmp_data_share_path=self._temp_data_path
)
return trainer
def server_setup(self, stage="train"):
trainer = FedCoTTraineServer(
ctx=self.get_context(),
infer_server=self._get_infer_inst(self.infer_inst_init_conf)
)
return trainer
def train(
self,
train_data: Optional[Union[str]] = None,
validate_data: Optional[Union[str]] = None,
output_dir: str = None,
saved_model_path: str = None,
):
if self.is_client():
train_set = self._prepare_data(train_data, "train_data")
validate_set = self._prepare_data(validate_data, "val_data")
trainer = self.client_setup(
train_set=train_set, validate_set=validate_set, output_dir=output_dir, saved_model=saved_model_path
)
self.trainer = trainer
trainer.train()
if self.mode == 'infer_only':
# save result dataset to the output dir
saving_path = output_dir + '/' + 'inference_result.pkl'
torch.save(train_set.dataset, saving_path)
logger.info('inference result saved to {}'.format(saving_path))
else:
if output_dir is not None:
if self.training_args.deepspeed and self.training_args.local_rank != 0:
pass
else:
trainer.save_model(output_dir)
elif self.is_server():
if self.mode == 'train_only':
return
else:
trainer = self.server_setup()
trainer.train()
def predict(self, test_data: Union[str], saved_model_path: str = None) -> None:
logger.warning('The prediction mode is not supported by this algorithm in the current version. Please perform inference using locally saved models.')
return
================================================
FILE: python/fate_llm/runner/fedkseed_runner.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict
from typing import Literal
from typing import Optional
import transformers
from fate.components.components.nn.nn_runner import (
NNRunner,
dir_warning,
loader_load_from_conf,
)
from fate.components.components.nn.runner.homo_default_runner import DefaultRunner
from fate_llm.algo.fedkseed.fedkseed import Trainer, FedKSeedTrainingArguments, ClientTrainer
from fate_llm.algo.fedkseed.zo_utils import build_seed_candidates
from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainingArguments
logger = logging.getLogger(__name__)
SUPPORTED_ALGO = ["fedkseed"]
class FedKSeedRunner(DefaultRunner):
def __init__(
self,
algo: str = "fedkseed",
model_conf: Optional[Dict] = None,
dataset_conf: Optional[Dict] = None,
optimizer_conf: Optional[Dict] = None,
training_args_conf: Optional[Dict] = None,
fed_args_conf: Optional[Dict] = None,
data_collator_conf: Optional[Dict] = None,
tokenizer_conf: Optional[Dict] = None,
        task_type: Literal["causal_lm", "others"] = "causal_lm",
local_mode: bool = False,
save_trainable_weights_only: bool = False,
) -> None:
super(NNRunner, self).__init__()
self.algo = algo
self.model_conf = model_conf
self.dataset_conf = dataset_conf
self.optimizer_conf = optimizer_conf
self.training_args_conf = training_args_conf
self.fed_args_conf = fed_args_conf
self.data_collator_conf = data_collator_conf
self.local_mode = local_mode
self.tokenizer_conf = tokenizer_conf
self.task_type = task_type
self.save_trainable_weights_only = save_trainable_weights_only
# check param
if self.algo not in SUPPORTED_ALGO:
raise ValueError(f"algo should be one of {SUPPORTED_ALGO}")
if self.task_type not in ["causal_lm", "others"]:
raise ValueError("task_type should be one of [binary, multi, regression, others]")
assert isinstance(self.local_mode, bool), "local should be bool"
# setup var
self.trainer = None
self.training_args = None
def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None, stage="train"):
if self.algo != "fedkseed":
raise ValueError(f"algo {self.algo} not supported")
ctx = self.get_context()
model = maybe_loader_load_from_conf(self.model_conf)
if model is None:
raise ValueError(f"model is None, cannot load model from conf {self.model_conf}")
if output_dir is None:
output_dir = "./"
tokenizer = transformers.AutoTokenizer.from_pretrained(**self.data_collator_conf["kwargs"]["tokenizer_params"])
data_collator = transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
dir_warning(self.training_args_conf)
training_args = Seq2SeqTrainingArguments(**self.training_args_conf)
self.training_args = training_args
training_args.output_dir = output_dir
fedkseed_args = FedKSeedTrainingArguments(**self.fed_args_conf)
logger.debug(f"training_args: {training_args}")
logger.debug(f"fedkseed_args: {fedkseed_args}")
trainer = ClientTrainer(
ctx=ctx,
model=model,
training_args=training_args,
fedkseed_args=fedkseed_args,
data_collator=data_collator,
tokenizer=tokenizer,
train_dataset=train_set,
eval_dataset=validate_set,
)
return trainer
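    # Server side: draw fedkseed_args.k candidate seeds uniformly from [0, 2**32) and hand
    # them to the coordinating Trainer.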
def server_setup(self, stage="train"):
if self.algo != "fedkseed":
raise ValueError(f"algo {self.algo} not supported")
ctx = self.get_context()
fedkseed_args = FedKSeedTrainingArguments(**self.fed_args_conf)
training_args = Seq2SeqTrainingArguments(**self.training_args_conf)
seed_candidates = build_seed_candidates(fedkseed_args.k, low=0, high=2 ** 32)
trainer = Trainer(ctx=ctx, seed_candidates=seed_candidates, args=training_args, fedkseed_args=fedkseed_args)
return trainer
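# Helper: load the model from its loader conf and, for HFAutoModelForCausalLM wrappers,
# materialize the underlying transformers model via .load().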
def maybe_loader_load_from_conf(conf):
from fate_llm.model_zoo.hf_model import HFAutoModelForCausalLM
model = loader_load_from_conf(conf)
if isinstance(model, HFAutoModelForCausalLM):
model = model.load()
return model
================================================
FILE: python/fate_llm/runner/fedmkt_runner.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from fate.components.components.nn.nn_runner import (
load_model_dict_from_path,
dir_warning,
loader_load_from_conf,
run_dataset_func,
)
from typing import Dict
from fate.components.components.nn.loader import Loader
from fate.ml.nn.homo.fedavg import FedAVGArguments
from typing import Union, Optional, Literal, List
from transformers.trainer_utils import get_last_checkpoint
import logging
from fate.arch.dataframe import DataFrame
from fate.components.components.nn.runner.homo_default_runner import DefaultRunner
from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM, FedMKTLLM
logger = logging.getLogger(__name__)
class FedMKTRunner(DefaultRunner):
def __init__(
self,
algo: str = "fedmkt",
model_conf: Optional[Dict] = None,
optimizer_conf: Optional[Dict] = None,
training_args_conf: Optional[Dict] = None,
fed_args_conf: Optional[Dict] = None,
pub_dataset_conf: Optional[Dict] = None,
priv_dataset_conf: Optional[Dict] = None,
data_collator_conf: Optional[Dict] = None,
tokenizer_conf: Optional[Dict] = None,
llm_tokenizer_conf: Optional[Dict] = None,
slm_tokenizers_conf: List[Optional[Dict]] = None,
llm_to_slm_vocab_mapping_path: str = None,
slm_to_llm_vocab_mapping_paths: List[str] = None,
task_type: Literal["causal_lm", "others"] = "causal_lm",
save_trainable_weights_only: bool = False,
pub_dataset_path: str = None,
) -> None:
super(FedMKTRunner, self).__init__()
self.algo = algo
self.model_conf = model_conf
self.optimizer_conf = optimizer_conf
self.training_args_conf = training_args_conf
self.fed_args_conf = fed_args_conf
self.pub_dataset_conf = pub_dataset_conf
self.priv_dataset_conf = priv_dataset_conf
self.data_collator_conf = data_collator_conf
self.tokenizer_conf = tokenizer_conf
self.llm_tokenizer_conf = llm_tokenizer_conf
self.slm_tokenizers_conf = slm_tokenizers_conf
self.llm_to_slm_vocab_mapping_path = llm_to_slm_vocab_mapping_path
self.slm_to_llm_vocab_mapping_paths = slm_to_llm_vocab_mapping_paths
self.task_type = task_type
self.pub_dataset_path = pub_dataset_path
self.save_trainable_weights_only = save_trainable_weights_only
self.training_args = None
# check param
if self.algo.lower() != "fedmkt":
raise ValueError(f"algo should be fedmkt")
if self.task_type not in ["causal_lm"]:
raise ValueError("task_type should be causal_lm")
def common_setup(self, saved_model=None, output_dir=None):
ctx = self.get_context()
if output_dir is None:
output_dir = "./"
model = loader_load_from_conf(self.model_conf)
if model is None:
raise ValueError(f"model is None, cannot load model from conf {self.model_conf}")
resume_path = None
if saved_model is not None:
model_dict = load_model_dict_from_path(saved_model)
model.load_state_dict(model_dict)
logger.info(f"loading model dict from {saved_model} to model done")
if get_last_checkpoint(saved_model) is not None:
resume_path = saved_model
logger.info(f"checkpoint detected, resume_path set to {resume_path}")
# load optimizer
if self.optimizer_conf:
optimizer_loader = Loader.from_dict(self.optimizer_conf)
optimizer_ = optimizer_loader.load_item()
optimizer_params = optimizer_loader.kwargs
optimizer = optimizer_(model.parameters(), **optimizer_params)
else:
optimizer = None
# load tokenizer if import conf provided
tokenizer = loader_load_from_conf(self.tokenizer_conf)
# args
dir_warning(self.training_args_conf)
training_args = FedMKTTrainingArguments(**self.training_args_conf)
# reset to default, saving to arbitrary path is not allowed in
# DefaultRunner
training_args.output_dir = output_dir
training_args.resume_from_checkpoint = resume_path # resume path
self.training_args = training_args
if self.fed_args_conf is not None:
fed_args = FedAVGArguments(**self.fed_args_conf)
else:
fed_args = None
pub_dataset = loader_load_from_conf(self.pub_dataset_conf)
pub_dataset.load(self.pub_dataset_path)
return ctx, model, optimizer, tokenizer, training_args, fed_args, pub_dataset
def llm_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None):
ctx, model, optimizer, tokenizer, training_args, fed_args, pub_dataset = self.common_setup(
output_dir=output_dir, saved_model=saved_model)
if validate_set is not None:
validate_dataset = loader_load_from_conf(self.pub_dataset_conf)
validate_dataset.load(validate_set)
else:
validate_dataset = None
slm_tokenizers = None
if self.slm_tokenizers_conf:
slm_tokenizers = [loader_load_from_conf(tokenizer_conf) for tokenizer_conf in self.slm_tokenizers_conf]
slm_to_llm_vocab_mappings = []
for vocab_mapping_path in self.slm_to_llm_vocab_mapping_paths:
with open(vocab_mapping_path, "r") as fin:
vocab_mapping = json.loads(fin.read())
slm_to_llm_vocab_mappings.append(vocab_mapping)
trainer = FedMKTLLM(
ctx=ctx,
model=model,
training_args=training_args,
fed_args=fed_args,
train_set=pub_dataset,
val_set=validate_dataset,
tokenizer=tokenizer,
slm_tokenizers=slm_tokenizers,
slm_to_llm_vocab_mappings=slm_to_llm_vocab_mappings,
save_trainable_weights_only=self.save_trainable_weights_only,
)
return trainer
def slm_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None):
ctx, model, optimizer, tokenizer, training_args, fed_args, pub_dataset = self.common_setup(
output_dir=output_dir, saved_model=saved_model)
priv_dataset = loader_load_from_conf(self.priv_dataset_conf)
priv_dataset.load(train_set)
if validate_set is not None:
validate_dataset = loader_load_from_conf(self.priv_dataset_conf)
validate_dataset.load(validate_set)
else:
validate_dataset = None
llm_tokenizer = loader_load_from_conf(self.llm_tokenizer_conf)
with open(self.llm_to_slm_vocab_mapping_path, "r") as fin:
vocab_mapping = json.loads(fin.read())
priv_data_collator = loader_load_from_conf(self.data_collator_conf)
trainer = FedMKTSLM(
ctx=ctx,
model=model,
training_args=training_args,
fed_args=fed_args,
pub_train_set=pub_dataset,
priv_train_set=priv_dataset,
val_set=validate_dataset,
tokenizer=tokenizer,
save_trainable_weights_only=self.save_trainable_weights_only,
llm_tokenizer=llm_tokenizer,
llm_to_slm_vocab_mapping=vocab_mapping,
data_collator=priv_data_collator
)
return trainer
def train(
self,
train_data: Optional[Union[str, DataFrame]] = None,
validate_data: Optional[Union[str, DataFrame]] = None,
output_dir: str = None,
saved_model_path: str = None,
):
if self.is_client():
trainer = self.slm_setup(train_set=train_data, validate_set=validate_data, output_dir=output_dir, saved_model=saved_model_path)
trainer.train()
else:
trainer = self.llm_setup(
train_set=train_data, validate_set=validate_data, output_dir=output_dir, saved_model=saved_model_path
)
trainer.train()
self.trainer = trainer
if self.training_args.deepspeed and self.training_args.local_rank != 0:
pass
else:
trainer.save_model(output_dir)
def predict(self, *args, **kwargs):
pass
================================================
FILE: python/fate_llm/runner/homo_seq2seq_runner.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fate.components.components.nn.nn_runner import (
NNRunner,
load_model_dict_from_path,
dir_warning,
loader_load_from_conf,
run_dataset_func,
)
from fate.components.components.nn.runner.homo_default_runner import DefaultRunner
from fate.ml.nn.homo.fedavg import FedAVGArguments
from fate_llm.algo.fedavg.fedavg import Seq2SeqFedAVGClient, Seq2SeqFedAVGServer
from typing import Dict
from fate.components.components.nn.loader import Loader
import torch.nn as nn
import torch.optim as optim
from fate.ml.nn.trainer.trainer_base import FedArguments, HomoTrainerServer
from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainingArguments, HomoSeq2SeqTrainerClient
from typing import Union, Type, Callable, Optional
from transformers.trainer_utils import get_last_checkpoint
from typing import Literal
import logging
from fate.arch.dataframe import DataFrame
logger = logging.getLogger(__name__)
SUPPORTED_ALGO = ["fedavg", "ot"]
def _check_instances(
trainer: Union[Type[HomoSeq2SeqTrainerClient], Type[HomoTrainerServer]] = None,
fed_args: FedArguments = None,
model: nn.Module = None,
optimizer: optim.Optimizer = None,
train_args: Seq2SeqTrainingArguments = None,
data_collator: Callable = None,
) -> None:
if trainer is not None and not (
issubclass(type(trainer), HomoSeq2SeqTrainerClient) or issubclass(type(trainer), HomoTrainerServer)
):
raise TypeError(
f"SetupReturn Error: trainer must be a subclass of either "
f"HomoSeq2SeqTrainerClient or HomoSeq2SeqTrainerClient but got {type(trainer)}"
)
if fed_args is not None and not isinstance(fed_args, FedArguments):
raise TypeError(f"SetupReturn Error: fed_args must be an instance of FedArguments but got {type(fed_args)}")
if model is not None and not issubclass(type(model), nn.Module):
raise TypeError(f"SetupReturn Error: model must be a subclass of torch.nn.Module but got {type(model)}")
if optimizer is not None and not issubclass(type(optimizer), optim.Optimizer):
raise TypeError(
f"SetupReturn Error: optimizer must be a subclass of torch.optim.Optimizer but got {type(optimizer)}"
)
if train_args is not None and not isinstance(train_args, Seq2SeqTrainingArguments):
raise TypeError(
f"SetupReturn Error: train_args must be an instance of Seq2SeqTrainingArguments "
f"but got {type(train_args)}"
)
if data_collator is not None and not callable(data_collator):
raise TypeError(f"SetupReturn Error: data_collator must be callable but got {type(data_collator)}")
class Seq2SeqRunner(DefaultRunner):
def __init__(
self,
algo: str = "fedavg",
model_conf: Optional[Dict] = None,
dataset_conf: Optional[Dict] = None,
optimizer_conf: Optional[Dict] = None,
training_args_conf: Optional[Dict] = None,
fed_args_conf: Optional[Dict] = None,
data_collator_conf: Optional[Dict] = None,
tokenizer_conf: Optional[Dict] = None,
        task_type: Literal["causal_lm", "others"] = "causal_lm",
local_mode: bool = False,
save_trainable_weights_only: bool = False,
) -> None:
super(NNRunner, self).__init__()
self.algo = algo
self.model_conf = model_conf
self.dataset_conf = dataset_conf
self.optimizer_conf = optimizer_conf
self.training_args_conf = training_args_conf
self.fed_args_conf = fed_args_conf
self.data_collator_conf = data_collator_conf
self.local_mode = local_mode
self.tokenizer_conf = tokenizer_conf
self.task_type = task_type
self.save_trainable_weights_only = save_trainable_weights_only
# check param
if self.algo not in SUPPORTED_ALGO:
raise ValueError(f"algo should be one of {SUPPORTED_ALGO}")
if self.task_type not in ["causal_lm", "others"]:
raise ValueError("task_type should be one of [binary, multi, regression, others]")
assert isinstance(self.local_mode, bool), "local should be bool"
# setup var
self.trainer = None
self.training_args = None
def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None, stage="train"):
if stage == "predict":
self.local_mode = True
if self.algo == "fedavg":
client_class: Seq2SeqFedAVGClient = Seq2SeqFedAVGClient
else:
raise ValueError(f"algo {self.algo} not supported")
ctx = self.get_context()
model = loader_load_from_conf(self.model_conf)
if model is None:
raise ValueError(f"model is None, cannot load model from conf {self.model_conf}")
if output_dir is None:
output_dir = "./"
resume_path = None
if saved_model is not None:
model_dict = load_model_dict_from_path(saved_model)
model.load_state_dict(model_dict)
logger.info(f"loading model dict from {saved_model} to model done")
if get_last_checkpoint(saved_model) is not None:
resume_path = saved_model
logger.info(f"checkpoint detected, resume_path set to {resume_path}")
# load optimizer
if self.optimizer_conf:
optimizer_loader = Loader.from_dict(self.optimizer_conf)
optimizer_ = optimizer_loader.load_item()
optimizer_params = optimizer_loader.kwargs
optimizer = optimizer_(model.parameters(), **optimizer_params)
else:
optimizer = None
# load collator func
data_collator = loader_load_from_conf(self.data_collator_conf)
# load tokenizer if import conf provided
tokenizer = loader_load_from_conf(self.tokenizer_conf)
# args
dir_warning(self.training_args_conf)
training_args = Seq2SeqTrainingArguments(**self.training_args_conf)
self.training_args = training_args
# reset to default, saving to arbitrary path is not allowed in
# DefaultRunner
training_args.output_dir = output_dir
training_args.resume_from_checkpoint = resume_path # resume path
fed_args = FedAVGArguments(**self.fed_args_conf)
# prepare trainer
trainer = client_class(
ctx=ctx,
model=model,
optimizer=optimizer,
training_args=training_args,
fed_args=fed_args,
data_collator=data_collator,
tokenizer=tokenizer,
train_set=train_set,
val_set=validate_set,
local_mode=self.local_mode,
save_trainable_weights_only=self.save_trainable_weights_only,
)
_check_instances(
trainer=trainer,
model=model,
optimizer=optimizer,
train_args=training_args,
fed_args=fed_args,
data_collator=data_collator,
)
return trainer
def server_setup(self, stage="train"):
if stage == "predict":
self.local_mode = True
if self.algo == "fedavg":
server_class: Seq2SeqFedAVGServer = Seq2SeqFedAVGServer
else:
raise ValueError(f"algo {self.algo} not supported")
ctx = self.get_context()
trainer = server_class(ctx=ctx, local_mode=self.local_mode)
_check_instances(trainer)
return trainer
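    # Prediction only produces output on the client, and only when predict_with_generate is
    # enabled; under DeepSpeed, only local rank 0 assembles and returns the result DataFrame.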
def predict(self, test_data: Union[str, DataFrame], saved_model_path: str = None) -> Union[DataFrame, None]:
if self.is_client():
test_set = self._prepare_data(test_data, "test_data")
if self.trainer is not None:
trainer = self.trainer
logger.info("trainer found, skip setting up")
else:
trainer = self.client_setup(saved_model=saved_model_path, stage="predict")
classes = run_dataset_func(test_set, "get_classes")
match_ids = run_dataset_func(test_set, "get_match_ids")
sample_ids = run_dataset_func(test_set, "get_sample_ids")
match_id_name = run_dataset_func(test_set, "get_match_id_name")
sample_id_name = run_dataset_func(test_set, "get_sample_id_name")
if not self.training_args.predict_with_generate:
return
pred_rs = trainer.predict(test_set)
if self.training_args and self.training_args.deepspeed and self.training_args.local_rank != 0:
return
rs_df = self.get_nn_output_dataframe(
self.get_context(),
pred_rs.predictions,
pred_rs.label_ids if hasattr(pred_rs, "label_ids") else None,
match_ids,
sample_ids,
match_id_name=match_id_name,
sample_id_name=sample_id_name,
dataframe_format="dist_df",
task_type=self.task_type,
classes=classes,
)
return rs_df
else:
# server not predict
return
================================================
FILE: python/fate_llm/runner/inferdpt_runner.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fate.components.components.nn.nn_runner import (
NNRunner,
load_model_dict_from_path,
dir_warning,
loader_load_from_conf,
run_dataset_func,
)
import os
from datetime import datetime
from typing import Dict
from fate.components.components.nn.loader import Loader
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Union, Type, Callable, Optional
from typing import Literal
import logging
from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer
from fate_llm.algo.inferdpt.init._init import InferInit
from fate_llm.dataset.hf_dataset import HuggingfaceDataset, Dataset
from fate.arch.dataframe import DataFrame
logger = logging.getLogger(__name__)
class InferDPTRunner(NNRunner):
def __init__(
self,
inferdpt_init_conf: Dict,
encode_template: str = None,
instruction_template: str = None,
decode_template: str = None,
dataset_conf: Optional[Dict] = None,
remote_inference_kwargs: Dict = {},
local_inference_kwargs: Dict = {},
perturb_doc_key: str = 'perturbed_doc',
perturbed_response_key: str = 'perturbed_response',
result_key: str = 'inferdpt_result',
) -> None:
self.inferdpt_init_conf = inferdpt_init_conf
self.encode_template = encode_template
self.instruction_template = instruction_template
self.decode_template = decode_template
self.dataset_conf = dataset_conf
self.remote_inference_kwargs = remote_inference_kwargs
self.local_inference_kwargs = local_inference_kwargs
self.perturb_doc_key = perturb_doc_key
self.perturbed_response_key = perturbed_response_key
self.result_key = result_key
def _get_inst(self):
loader = Loader.from_dict(self.inferdpt_init_conf)
init_inst = loader.load_item()(self.get_context())
        assert isinstance(init_inst, InferInit), 'Need an InferInit class for initialization, but got {}'.format(type(init_inst))
inferdpt_inst = init_inst.get_inst()
logger.info('inferdpt inst loaded')
return inferdpt_inst
def client_setup(self):
client_inst = self._get_inst()
assert isinstance(client_inst, InferDPTClient), 'Client need to get an InferDPTClient class to run the algo'
return client_inst
def server_setup(self):
server_inst = self._get_inst()
assert isinstance(server_inst, InferDPTServer), 'Server need to get an InferDPTServer class to run the algo'
return server_inst
def _prepare_data(self, data, data_name):
if data is None:
return None
if isinstance(data, DataFrame) and self.dataset_conf is None:
raise ValueError('DataFrame format dataset is not supported, please use bind path to load your dataset')
else:
dataset = loader_load_from_conf(self.dataset_conf)
if hasattr(dataset, "load"):
logger.info("load path is {}".format(data))
load_output = dataset.load(data)
if load_output is not None:
dataset = load_output
return dataset
else:
raise ValueError(
f"The dataset {dataset} lacks a load() method, which is required for data parsing in the DefaultRunner. \
Please implement this method in your dataset class. You can refer to the base class 'Dataset' in 'fate.ml.nn.dataset.base' \
for the necessary interfaces to implement."
)
if dataset is not None and not issubclass(type(dataset), Dataset):
raise TypeError(
f"SetupReturn Error: {data_name}_set must be a subclass of fate built-in Dataset but got {type(dataset)}, \n"
f"You can get the class via: from fate.ml.nn.dataset.table import Dataset"
)
return dataset
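    # For InferDPT the "train" stage is the inference round: the client perturbs the documents,
    # queries the remote model, decodes locally and saves the result to
    # <output_dir>/inference_result.pkl, while the server answers the perturbed inference requests.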
def train(
self,
train_data: Optional[Union[str]] = None,
validate_data: Optional[Union[str]] = None,
output_dir: str = None,
saved_model_path: str = None,
) -> None:
if self.is_client():
dataset_0 = self._prepare_data(train_data, "train_data")
logger.info('dataset loaded')
if dataset_0 is None:
raise ValueError('You must provide dataset for inference')
assert isinstance(dataset_0, HuggingfaceDataset), 'Currently only support HuggingfaceDataset for inference, but got {}'.format(type(dataset_0))
logger.info('initializing inst')
client_inst = self.client_setup()
pred_rs = client_inst.inference(
dataset_0, self.encode_template, self.instruction_template, self.decode_template, \
remote_inference_kwargs=self.remote_inference_kwargs,
local_inference_kwargs=self.local_inference_kwargs
)
logger.info('predict done')
saving_path = output_dir + '/' + 'inference_result.pkl'
logger.info('result save to path {}'.format(saving_path))
torch.save(pred_rs, saving_path)
elif self.is_server():
server_inst = self.server_setup()
server_inst.inference()
else:
raise ValueError('Unknown role')
def predict(
self, test_data: Optional[Union[str]] = None, output_dir: str = None, saved_model_path: str = None
):
        logger.warning('Prediction mode is not supported by this algorithm in the current version, please use the train mode to run InferDPT inference.')
return
================================================
FILE: python/fate_llm/runner/offsite_tuning_runner.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fate.components.components.nn.nn_runner import (
load_model_dict_from_path,
dir_warning,
loader_load_from_conf,
)
from fate.ml.nn.homo.fedavg import FedAVGArguments
from fate_llm.algo.fedavg.fedavg import Seq2SeqFedAVGServer
from typing import Dict
from fate.components.components.nn.loader import Loader
from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainingArguments
from typing import Union, Optional
from transformers.trainer_utils import get_last_checkpoint
from typing import Literal
import logging
from fate.arch.dataframe import DataFrame
from fate_llm.runner.homo_seq2seq_runner import Seq2SeqRunner, _check_instances
from fate_llm.algo.offsite_tuning.offsite_tuning import OffsiteTuningTrainerClient, OffsiteTuningTrainerServer
logger = logging.getLogger(__name__)
SUPPORTED_ALGO = ["fedavg"]
class OTRunner(Seq2SeqRunner):
def __init__(
self,
model_conf: Optional[Dict] = None,
dataset_conf: Optional[Dict] = None,
optimizer_conf: Optional[Dict] = None,
training_args_conf: Optional[Dict] = None,
fed_args_conf: Optional[Dict] = None,
data_collator_conf: Optional[Dict] = None,
tokenizer_conf: Optional[Dict] = None,
task_type: Literal["causal_lm", "other"] = "causal_lm",
save_trainable_weights_only: bool = False,
aggregate_model: bool = False,
algo: str = 'ot'
) -> None:
super(OTRunner, self).__init__(
algo, model_conf, dataset_conf, optimizer_conf, training_args_conf, fed_args_conf,
data_collator_conf, tokenizer_conf, task_type, local_mode=False
)
self.aggregate_model = aggregate_model
self.save_trainable_weights_only = save_trainable_weights_only
def setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None, stage="train"):
if stage == "predict":
self.local_mode = True
ctx = self.get_context()
model = loader_load_from_conf(self.model_conf)
if model is None:
raise ValueError(f"model is None, cannot load model from conf {self.model_conf}")
if output_dir is None:
output_dir = "./"
resume_path = None
if saved_model is not None:
model_dict = load_model_dict_from_path(saved_model)
model.load_state_dict(model_dict)
logger.info(f"loading model dict from {saved_model} to model done")
if get_last_checkpoint(saved_model) is not None:
resume_path = saved_model
logger.info(f"checkpoint detected, resume_path set to {resume_path}")
# load optimizer
if self.optimizer_conf:
optimizer_loader = Loader.from_dict(self.optimizer_conf)
optimizer_ = optimizer_loader.load_item()
optimizer_params = optimizer_loader.kwargs
optimizer = optimizer_(model.parameters(), **optimizer_params)
else:
optimizer = None
# load collator func
data_collator = loader_load_from_conf(self.data_collator_conf)
# load tokenizer if its conf is provided
tokenizer = loader_load_from_conf(self.tokenizer_conf)
# args
dir_warning(self.training_args_conf)
training_args = Seq2SeqTrainingArguments(**self.training_args_conf)
self.training_args = training_args
# reset to default; saving to an arbitrary path is not allowed in
# the DefaultRunner
training_args.output_dir = output_dir
training_args.resume_from_checkpoint = resume_path # resume path
fed_args = FedAVGArguments(**self.fed_args_conf)
# prepare trainer
if self.is_client():
trainer = OffsiteTuningTrainerClient(
ctx=ctx,
model=model,
optimizer=optimizer,
training_args=training_args,
fed_args=fed_args,
data_collator=data_collator,
tokenizer=tokenizer,
train_set=train_set,
val_set=validate_set,
save_trainable_weights_only=self.save_trainable_weights_only,
aggregate_model=self.aggregate_model
)
elif self.is_server():
trainer = OffsiteTuningTrainerServer(
ctx=ctx,
model=model,
aggregate_model=self.aggregate_model
)
_check_instances(
trainer=trainer,
model=model,
optimizer=optimizer,
train_args=training_args,
fed_args=fed_args,
data_collator=data_collator,
)
return trainer
def server_setup(self, stage="train"):
if stage == "predict":
self.local_mode = True
if self.algo == "fedavg":
server_class: Seq2SeqFedAVGServer = Seq2SeqFedAVGServer
else:
raise ValueError(f"algo {self.algo} not supported")
ctx = self.get_context()
trainer = server_class(ctx=ctx, local_mode=self.local_mode)
_check_instances(trainer)
return trainer
def train(
self,
train_data: Optional[Union[str, DataFrame]] = None,
validate_data: Optional[Union[str, DataFrame]] = None,
output_dir: str = None,
saved_model_path: str = None,
):
if self.is_client():
train_set = self._prepare_data(train_data, "train_data")
validate_set = self._prepare_data(validate_data, "val_data")
trainer = self.setup(
train_set=train_set, validate_set=validate_set, output_dir=output_dir, saved_model=saved_model_path
)
self.trainer = trainer
trainer.train()
elif self.is_server():
trainer = self.setup(
train_set=None, validate_set=None, output_dir=output_dir, saved_model=saved_model_path
)
trainer.train()
if output_dir is not None:
if self.training_args.deepspeed and self.training_args.local_rank != 0:
pass
else:
trainer.save_model(output_dir)
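# Sketch (assumption): the *_conf dictionaries passed to OTRunner are expanded by setup()
# above; for example, a minimal training_args_conf could look like
#   training_args_conf = dict(per_device_train_batch_size=1, num_train_epochs=1, learning_rate=5e-5)
# which setup() turns into Seq2SeqTrainingArguments(**training_args_conf); fed_args_conf is
# expanded into FedAVGArguments(**fed_args_conf) in the same way.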
================================================
FILE: python/fate_llm/trainer/__init__.py
================================================
================================================
FILE: python/fate_llm/trainer/seq2seq_trainer.py
================================================
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from transformers import Seq2SeqTrainingArguments as _hf_Seq2SeqTrainingArguments, Seq2SeqTrainer
from dataclasses import dataclass, field
from typing import Optional
from fate.ml.nn.trainer.trainer_base import HomoTrainerMixin, FedArguments, get_ith_checkpoint
import os
import torch
import copy
from torch import nn
from typing import Any, Dict, List, Callable
from enum import Enum
from fate.arch import Context
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizer
from transformers import Trainer, EvalPrediction
from transformers.trainer_utils import has_length
from torch.utils.data import _utils
from transformers.trainer_callback import TrainerCallback
from typing import Optional
from dataclasses import dataclass, field
from transformers.modeling_utils import unwrap_model
TRAINABLE_WEIGHTS_NAME = "adapter_model.bin"
@dataclass
class _S2STrainingArguments(_hf_Seq2SeqTrainingArguments):
# in fate-2.0, we will control the output dir when using pipeline
output_dir: str = field(default="./")
disable_tqdm: bool = field(default=True)
save_strategy: str = field(default="no")
logging_strategy: str = field(default="epoch")
logging_steps: int = field(default=1)
evaluation_strategy: str = field(default="no")
logging_dir: Optional[str] = field(default=None)
checkpoint_idx: Optional[int] = field(default=None)
# by default, we use constant learning rate, the same as FATE-1.X
lr_scheduler_type: str = field(default="constant")
log_level: str = field(default="info")
deepspeed: Optional[str] = field(default=None)
save_safetensors: bool = field(default=False)
use_cpu: bool = field(default=False)
def __post_init__(self):
self.push_to_hub = False
self.hub_model_id = None
self.hub_strategy = "every_save"
self.hub_token = None
self.hub_private_repo = False
self.push_to_hub_model_id = None
self.push_to_hub_organization = None
self.push_to_hub_token = None
super().__post_init__()
DEFAULT_ARGS = _S2STrainingArguments().to_dict()
@dataclass
class Seq2SeqTrainingArguments(_S2STrainingArguments):
# To simplify the to_dict result (to_dict only returns non-default args)
def to_dict(self):
# Call the superclass's to_dict method
all_args = super().to_dict()
# Get a dict with default values for all fields
default_args = copy.deepcopy(DEFAULT_ARGS)
# Filter out args that are equal to their default values
set_args = {name: value for name, value in all_args.items() if value != default_args.get(name)}
return set_args
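# Illustrative sketch of the filtering above (not executed here): fields equal to their
# defaults are dropped, so only overridden values survive; the value shown is an assumption.
#
#   args = Seq2SeqTrainingArguments(learning_rate=1e-4)
#   args.to_dict()  # -> roughly {'learning_rate': 1e-4}, default-valued fields omitted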
class HomoSeq2SeqTrainerClient(Seq2SeqTrainer, HomoTrainerMixin):
def __init__(
self,
ctx: Context,
model: nn.Module,
training_args: Seq2SeqTrainingArguments,
fed_args: FedArguments,
train_set: Dataset,
val_set: Dataset = None,
optimizer: torch.optim.Optimizer = None,
data_collator: Callable = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
callbacks: Optional[List[TrainerCallback]] = [],
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
local_mode: bool = False,
save_trainable_weights_only: bool = False,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
# in case you forget to set evaluation_strategy
if val_set is not None and training_args.evaluation_strategy == "no":
training_args.evaluation_strategy = "epoch"
HomoTrainerMixin.__init__(
self,
ctx=ctx,
model=model,
optimizer=optimizer,
training_args=training_args,
fed_args=fed_args,
train_set=train_set,
val_set=val_set,
scheduler=scheduler,
callbacks=callbacks,
compute_metrics=compute_metrics,
local_mode=local_mode,
save_trainable_weights_only=save_trainable_weights_only,
)
# concat checkpoint path if checkpoint idx is set
if self._args.checkpoint_idx is not None:
checkpoint_path = self._args.resume_from_checkpoint
if checkpoint_path is not None and os.path.exists(checkpoint_path):
checkpoint_folder = get_ith_checkpoint(checkpoint_path, self._args.checkpoint_idx)
self._args.resume_from_checkpoint = os.path.join(checkpoint_path, checkpoint_folder)
Trainer.__init__(
self,
model=model,
args=self._args,
train_dataset=train_set,
eval_dataset=val_set,
data_collator=data_collator,
optimizers=(optimizer, scheduler),
tokenizer=tokenizer,
compute_metrics=self._compute_metrics_warp_func,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
self._add_fate_callback(self.callback_handler)
def _save(
self,
output_dir: Optional[str] = None,
state_dict=None
):
if not self._save_trainable_weights_only:
return super()._save(output_dir, state_dict)
else:
model = unwrap_model(self.model)
if hasattr(model, "save_trainable"):
model.save_trainable(output_dir)
else:
state_dict = {
    k: p.to("cpu") for k, p in model.named_parameters() if p.requires_grad
}
torch.save(state_dict, os.path.join(output_dir, TRAINABLE_WEIGHTS_NAME))
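# Sketch (assumption about downstream usage): when save_trainable_weights_only is set and the
# model does not implement save_trainable(), only the trainable parameters are written to
# TRAINABLE_WEIGHTS_NAME; they can be merged back into a freshly constructed model with:
#
#   trainable = torch.load(os.path.join(output_dir, TRAINABLE_WEIGHTS_NAME))
#   model.load_state_dict(trainable, strict=False)  # frozen weights keep their initial values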
================================================
FILE: python/requirements.txt
================================================
accelerate==0.27.2
deepspeed==0.13.3
peft==0.8.2
sentencepiece==0.2.0
lm_eval==0.4.2
rouge-score==0.1.2
datasets==2.18.0
editdistance
torch==2.3.1
transformers==4.37.2
opacus==1.4.1
fastchat
Jinja2
sentence-transformers
openai
================================================
FILE: python/setup.py
================================================
# -*- coding: utf-8 -*-
#
# Copyright 2024 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from setuptools import find_packages, setup
# Define the packages and modules
packages = find_packages(".")
package_data = {"": ["*"]}
# Define dependencies
install_requires = [
"accelerate==0.27.2",
"deepspeed==0.13.3",
"peft==0.8.2",
"sentencepiece==0.2.0",
"lm_eval==0.4.2",
"rouge-score==0.1.2",
"datasets==2.18.0",
"editdistance",
"torch==2.3.1",
"transformers==4.37.2",
"opacus==1.4.1",
"fastchat",
"Jinja2",
"sentence-transformers",
"openai"
]
# Define the entry points for command-line tools
entry_points = {
"console_scripts": [
"fate_llm = fate_llm.evaluate.scripts.fate_llm_cli:fate_llm_cli"
]
}
extras_require = {
"fate": ["pyfate==2.2.0"],
"fate_flow": ["fate_flow==2.2.0"],
"fate_client": ["fate_client==2.2.0"]
}
# Configure and call the setup function
setup_kwargs = {
"name": "fate_llm",
"version": "2.2.0",
"description": "Federated Learning for Large Language Models",
"long_description": "Federated Learning for Large Language Models (FATE-LLM) provides a framework to train and evaluate large language models in a federated manner.",
"long_description_content_type": "text/markdown",
"author": "FederatedAI",
"author_email": "contact@FedAI.org",
"url": "https://fate.fedai.org/",
"packages": packages,
"install_requires": install_requires,
"entry_points": entry_points,
"extras_require": extras_require,
"python_requires": ">=3.8",
"include_package_data": True
}
setup(**setup_kwargs)
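# Installation sketch: the extras defined above can be selected at install time, e.g.
#   pip install "fate_llm[fate_client]"   # also pulls in fate_client==2.2.0
# while a plain `pip install fate_llm` installs only install_requires.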