Showing preview only (820K chars total). Download the full file or copy to clipboard to get everything.
Repository: FederatedAI/FATE-LLM
Branch: main
Commit: 0c63377e468f
Files: 172
Total size: 765.2 KB
Directory structure:
gitextract_qkv2xwam/
├── LICENSE
├── README.md
├── RELEASE.md
├── doc/
│ ├── fate_llm_evaluate.md
│ ├── standalone_deploy.md
│ └── tutorial/
│ ├── fdkt/
│ │ ├── README.md
│ │ └── fdkt.ipynb
│ ├── fedcot/
│ │ ├── README.md
│ │ ├── encoder_decoder_tutorial.ipynb
│ │ └── fedcot_tutorial.ipynb
│ ├── fedkseed/
│ │ ├── README.md
│ │ └── fedkseed-example.ipynb
│ ├── fedmkt/
│ │ ├── README.md
│ │ └── fedmkt.ipynb
│ ├── inferdpt/
│ │ └── inferdpt_tutorial.ipynb
│ ├── offsite_tuning/
│ │ ├── Offsite_tuning_tutorial.ipynb
│ │ └── README.md
│ └── pellm/
│ ├── ChatGLM3-6B_ds.ipynb
│ └── builtin_pellm_models.md
├── examples/
│ ├── fedmkt/
│ │ ├── __init__.py
│ │ ├── fedmkt.py
│ │ ├── fedmkt_config.yaml
│ │ └── test_fedmkt_llmsuit.yaml
│ ├── offsite_tuning/
│ │ ├── __init__.py
│ │ ├── offsite_tuning.py
│ │ ├── offsite_tuning_config.yaml
│ │ └── test_offsite_tuning_llmsuite.yaml
│ └── pellm/
│ ├── __init__.py
│ ├── bloom_lora_config.yaml
│ ├── test_bloom_lora.py
│ └── test_pellm_llmsuite.yaml
└── python/
├── MANIFEST.in
├── fate_llm/
│ ├── __init__.py
│ ├── algo/
│ │ ├── __init__.py
│ │ ├── dp/
│ │ │ ├── __init__.py
│ │ │ ├── dp_trainer.py
│ │ │ └── opacus_compatibility/
│ │ │ ├── __init__.py
│ │ │ ├── grad_sample/
│ │ │ │ ├── __init__.py
│ │ │ │ └── embedding.py
│ │ │ ├── optimizers/
│ │ │ │ ├── __init__.py
│ │ │ │ └── optimizer.py
│ │ │ └── transformers_compate.py
│ │ ├── fdkt/
│ │ │ ├── __init__.py
│ │ │ ├── cluster/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── cluster.py
│ │ │ │ └── cluster_method.py
│ │ │ ├── fdkt_data_aug.py
│ │ │ ├── inference_inst.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── dp_loss.py
│ │ │ ├── invalid_data_filter.py
│ │ │ └── text_generate.py
│ │ ├── fedavg/
│ │ │ ├── __init__.py
│ │ │ └── fedavg.py
│ │ ├── fedcollm/
│ │ │ ├── __init__.py
│ │ │ ├── fedcollm.py
│ │ │ ├── fedcollm_trainer.py
│ │ │ └── fedcollm_training_args.py
│ │ ├── fedcot/
│ │ │ ├── __init__.py
│ │ │ ├── encoder_decoder/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── init/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── default_init.py
│ │ │ │ └── slm_encoder_decoder.py
│ │ │ ├── fedcot_trainer.py
│ │ │ └── slm_encoder_decoder_trainer.py
│ │ ├── fedkseed/
│ │ │ ├── __init__.py
│ │ │ ├── args.py
│ │ │ ├── fedkseed.py
│ │ │ ├── optimizer.py
│ │ │ ├── pytorch_utils.py
│ │ │ ├── trainer.py
│ │ │ └── zo_utils.py
│ │ ├── fedmkt/
│ │ │ ├── __init__.py
│ │ │ ├── fedmkt.py
│ │ │ ├── fedmkt_data_collator.py
│ │ │ ├── fedmkt_trainer.py
│ │ │ ├── token_alignment/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── spectal_token_mapping.py
│ │ │ │ ├── token_align.py
│ │ │ │ └── vocab_mapping.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── dataset_sync_util.py
│ │ │ ├── generate_logit_utils.py
│ │ │ ├── tokenizer_tool.py
│ │ │ └── vars_define.py
│ │ ├── inferdpt/
│ │ │ ├── __init__.py
│ │ │ ├── _encode_decode.py
│ │ │ ├── inferdpt.py
│ │ │ ├── init/
│ │ │ │ ├── _init.py
│ │ │ │ └── default_init.py
│ │ │ └── utils.py
│ │ ├── offsite_tuning/
│ │ │ ├── __init__.py
│ │ │ └── offsite_tuning.py
│ │ └── ppc-gpt/
│ │ └── __init__.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── data_collator/
│ │ │ ├── __init__.py
│ │ │ ├── cust_data_collator.py
│ │ │ └── fedcot_collator.py
│ │ └── tokenizers/
│ │ ├── __init__.py
│ │ └── cust_tokenizer.py
│ ├── dataset/
│ │ ├── __init__.py
│ │ ├── data_config/
│ │ │ ├── __init__.py
│ │ │ ├── default_ag_news.yaml
│ │ │ └── default_yelp_review.yaml
│ │ ├── fedcot_dataset.py
│ │ ├── flex_dataset.py
│ │ ├── hf_dataset.py
│ │ ├── input_output_dataset.py
│ │ ├── prompt_dataset.py
│ │ ├── qa_dataset.py
│ │ └── seq_cls_dataset.py
│ ├── evaluate/
│ │ ├── __init__.py
│ │ ├── scripts/
│ │ │ ├── __init__.py
│ │ │ ├── _options.py
│ │ │ ├── config_cli.py
│ │ │ ├── data_cli.py
│ │ │ ├── eval_cli.py
│ │ │ └── fate_llm_cli.py
│ │ ├── tasks/
│ │ │ ├── __init__.py
│ │ │ ├── advertise_gen/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── advertise_utils.py
│ │ │ │ └── default_advertise_gen.yaml
│ │ │ └── dolly_15k/
│ │ │ ├── __init__.py
│ │ │ ├── default_dolly_15k.yaml
│ │ │ └── dolly_utils.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── _io.py
│ │ ├── _parser.py
│ │ ├── config.py
│ │ ├── data_tools.py
│ │ ├── llm_evaluator.py
│ │ └── model_tools.py
│ ├── inference/
│ │ ├── __init__.py
│ │ ├── api.py
│ │ ├── hf_qw.py
│ │ ├── inference_base.py
│ │ └── vllm.py
│ ├── model_zoo/
│ │ ├── __init__.py
│ │ ├── embedding_transformer/
│ │ │ ├── __init__.py
│ │ │ └── st_model.py
│ │ ├── hf_model.py
│ │ ├── offsite_tuning/
│ │ │ ├── __init__.py
│ │ │ ├── bloom.py
│ │ │ ├── gpt2.py
│ │ │ ├── llama.py
│ │ │ └── offsite_tuning_model.py
│ │ └── pellm/
│ │ ├── __init__.py
│ │ ├── albert.py
│ │ ├── bart.py
│ │ ├── bert.py
│ │ ├── bloom.py
│ │ ├── chatglm.py
│ │ ├── deberta.py
│ │ ├── distilbert.py
│ │ ├── gpt2.py
│ │ ├── llama.py
│ │ ├── opt.py
│ │ ├── parameter_efficient_llm.py
│ │ ├── qwen.py
│ │ └── roberta.py
│ ├── runner/
│ │ ├── __init__.py
│ │ ├── fdkt_runner.py
│ │ ├── fedcot_runner.py
│ │ ├── fedkseed_runner.py
│ │ ├── fedmkt_runner.py
│ │ ├── homo_seq2seq_runner.py
│ │ ├── inferdpt_runner.py
│ │ └── offsite_tuning_runner.py
│ └── trainer/
│ ├── __init__.py
│ └── seq2seq_trainer.py
├── requirements.txt
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# FATE-LLM
FATE-LLM is a framework to support federated learning for large language models(LLMs) and small language models(SLMs).
<div align="center">
<img src="./doc/images/fate-llm-show.png" height="300">
</div>
## Design Principle
- Federated learning for large language models(LLMs) and small language models(SLMs).
- Promote training efficiency of federated LLMs using Parameter-Efficient methods.
- Protect the IP of LLMs using FedIPR.
- Protect data privacy during training and inference through privacy preserving mechanisms.
<div align="center">
<img src="./doc/images/fate-llm-plan.png">
</div>
### Standalone deployment
* To deploy FATE-LLM v2.2.0 or higher, the following ways are provided; please refer to the [deploy tutorial](./doc/standalone_deploy.md) for more details:
* deploy with FATE only from pypi then using Launcher to run tasks
* deploy with FATE, FATE-Flow, and FATE-Client from PyPI; users can then run tasks with Pipeline
* To deploy lower versions: please refer to [FATE-Standalone deployment](https://github.com/FederatedAI/FATE#standalone-deployment).
* To deploy FATE-LLM v2.0.* - FATE-LLM v2.1.*, deploy FATE-Standalone with version >= 2.1, then make a new directory `{fate_install}/fate_llm` and clone the code into it, install the python requirements, and add `{fate_install}/fate_llm/python` to `PYTHONPATH`
* To deploy FATE-LLM v1.x, deploy FATE-Standalone with 1.11.3 <= version < 2.0, then copy directory `python/fate_llm` to `{fate_install}/fate/python/fate_llm`
### Cluster deployment
Use [FATE-LLM deployment packages](https://github.com/FederatedAI/FATE/wiki/Download#llm%E9%83%A8%E7%BD%B2%E5%8C%85) to deploy, refer to [FATE-Cluster deployment](https://github.com/FederatedAI/FATE#cluster-deployment) for more deployment details.
## Quick Start
- [Federated ChatGLM3-6B Training](doc/tutorial/pellm/ChatGLM3-6B_ds.ipynb)
- [Builtin Models In PELLM](doc/tutorial/pellm/builtin_pellm_models.md)
- [FedMKT: Federated Mutual Knowledge Transfer for Large and Small
Language Models](./doc/tutorial/fedmkt/)
- [FedCoT: Federated Chain-of-Thought Distillation for Large Language Models](./doc/tutorial/fedcot)
- [PPC-GPT: Federated Task-Specific Compression of Large Language
Models via Pruning and Chain-of-Thought Distillation](https://aclanthology.org/2025.emnlp-main.747.pdf)
- [FDKT: Federated Domain-Specific Knowledge Transfer on Large Language Models Using Synthetic Data](./doc/tutorial/fdkt)
- [Offsite Tuning: Transfer Learning without Full Model](./doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb)
- [FedKSeed: Federated Full-Parameter Tuning of Billion-Sized Language Models
with Communication Cost under 18 Kilobytes](./doc/tutorial/fedkseed/)
- [InferDPT: Privacy-preserving Inference for Black-box Large Language Models](./doc/tutorial/inferdpt/inferdpt_tutorial.ipynb)
## FATE-LLM Evaluate
- [Python SDK & CLI Usage Guide](./doc/fate_llm_evaluate.md)
## Citation
If you publish work that uses FATE-LLM, please cite FATE-LLM as follows:
```
@article{fan2023fate,
title={Fate-llm: A industrial grade federated learning framework for large language models},
author={Fan, Tao and Kang, Yan and Ma, Guoqiang and Chen, Weijing and Wei, Wenbin and Fan, Lixin and Yang, Qiang},
journal={Symposium on Advances and Open Problems in Large Language Models (LLM@IJCAI'23)},
year={2023}
}
```
================================================
FILE: RELEASE.md
================================================
## Release 2.2.0
### Major Features and Improvements
* Integrate the FedCoT (Federated Chain-of-Thought) algorithm, a novel framework that enhances local small language models (SLMs) using differentially private protected Chain of Thoughts (Cot) generated by remote LLMs:
* Implement InferDPT for privacy-preserving Cot generation.
* Support an encoder-decoder mechanism for privacy-preserving Cot generation.
* Add prefix trainers for step-by-step distillation and text encoder-decoder training.
* Integrate the FDKT algorithm, a framework that enables domain-specific knowledge transfer from LLMs to SLMs while preserving SLM data privacy
* Deployment Optimization: support installation of FATE-LLM by PyPi
## Release 2.1.0
### Major Features and Improvements
* New FedMKT Federated Tuning Algorithms: Federated Mutual Knowledge Transfer for Large and Small Language Models
* Support three distinct scenarios: Heterogeneous, Homogeneous and One-to-One
* Support LLM to SLM one-way knowledge transfer
* Introduce the InferDPT algorithm, which leverages differential privacy (DP) to facilitate privacy-preserving inference for large language models.
* Introduce FATE-LLM Evaluate: evaluate FATE-LLM models in a few lines with the Python SDK or simple CLI commands (`fate_llm evaluate`), built-in cases included
## Release 2.0.0
### Major Features and Improvements
* Adapt to fate-v2.0 framework:
* Migrate parameter-efficient fine-tuning training methods and models.
* Migrate Standard Offsite-Tuning and Extended Offsite-Tuning(Federated Offsite-Tuning+)
* New trainer, dataset, and data_processing function design
* New FedKSeed Federated Tuning Algorithm: train large language models in a federated learning setting with extremely low communication cost
## Release 1.3.0
### Major Features and Improvements
* FTL-LLM (Federated Learning + Transfer Learning + LLM)
* Standard Offsite-Tuning and Extended Offsite-Tuning (Federated Offsite-Tuning+) now supported
* Framework available for Emulator and Adapter development
* New Offsite-Tuning Trainer introduced
* Includes built-in models such as GPT-2 family, Llama7b, and Bloom family
* FedIPR
* Introduced WatermarkDataset as the foundational dataset class for backdoor-based watermarks
* Added SignConv and SignLayerNorm blocks for feature-based watermark models
* New FedIPR Trainer available
* Built-in models with feature-based watermarks include Alexnet, Resnet18, DistilBert, and GPT2
* More models support parameter-efficient fine-tuning: ChatGLM2-6B and Bloom-7B1
## Release 1.2.0
### Major Features and Improvements
* Support Federated Training of LLaMA-7B with parameter-efficient fine-tuning.
## Release 1.1.0
### Major Features and Improvements
* Support Federated Training of ChatGLM-6B with parameter-efficient fine-tuning adapters: like Lora and P-Tuning V2 etc.
* Integration of `peft`, which supports many parameter-efficient adapters.
================================================
FILE: doc/fate_llm_evaluate.md
================================================
## FATE-LLM Python SDK
FATE-LLM Python SDK provides simple API for evaluating large language models.
Built on [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/), our evaluation tool may be used on pre-trained models from Huggingface, local-built models, as well as FATE-LLM models.
[Built-in datasets](#built-in-tasks) currently include Dolly-15k and Advertise Generation.
The following shows how to evaluate a given LLM in a few lines. For quick single-model evaluation, the steps below should suffice; however, if comparative evaluation among multiple models is desired, the CLI is recommended.
```python
from lm_eval.models.huggingface import HFLM
from fate_llm.evaluate.utils import llm_evaluator
# download data for built-in tasks if running fate-llm evaluation for the first time
# alternatively, use CLI `fate_llm data download` to download data
llm_evaluator.download_task("dolly-15k")
# set paths of built-in tasks
llm_evaluator.init_tasks()
# load model
bloom_lm = HFLM(pretrained='bloom-560')
# if loading local model, specify peft storage location
# bloom_lm = HFLM(pretrained='bloom-560m', peft_path_format="path/to/peft")
# run evaluation
llm_evaluator.evaluate(model=bloom_lm, tasks="dolly-15k", show_result=True)
```
When network allows, or if already cached, tasks from lm-evaluation may be provided for evaluation in similar style.
```python
from lm_eval.models.huggingface import HFLM
from fate_llm.evaluate.utils import llm_evaluator
# load model
bloom_lm = HFLM(pretrained='bloom-560')
# if loading local model, specify peft storage location
# bloom_lm = HFLM(pretrained='bloom-560m', peft_path_format="path/to/peft")
# run evaluation
llm_evaluator.evaluate(model=bloom_lm, tasks="ceval", show_result=True)
```
## FATE-LLM Command Line Interface
FATE LLM provides built-in tasks for comparing evaluation results of different llm models.
Alternatively, user may provide arbitrary tasks for evaluation.
### install
```bash
cd {path_to_fate_llm}/python
pip install -e .
```
### command options
```bash
fate_llm --help
```
#### evaluate:
1. in:
```bash
fate_llm evaluate -i <path1 to *.yaml>
```
will run llm testsuites in *path1*
2. eval-config:
```bash
fate_llm evaluate -i <path1 to *.yaml> -c <path2>
```
will run llm testsuites in *path1* with evaluation configuration set to *path2*
3. result-output:
```bash
fate_llm evaluate -i <path1 contains *.yaml> -o <path2>
```
will run llm testsuites in *path1* with evaluation result output stored in *path2*
### config
```bash
fate_llm config --help
```
1. new:
```bash
fate_llm config new
```
will create a new evaluation configuration file in current directory
2. show:
```bash
fate_llm config show
```
will show current evaluation configuration
3. edit:
```bash
fate_llm config edit
```
will edit evaluation configuration
### data
```bash
fate_llm data --help
```
1. download:
```bash
fate_llm data download -t <task1> -t <task2> ...
```
will download corresponding data for given tasks
### FATE-LLM Eval job configuration
Configuration of jobs should be specified in a yaml file.
A FATE-LLM testsuite includes the following elements:
- job group: each group includes arbitrary number of jobs with paths
to corresponding script and configuration
- job: name of evaluation job to be run, must be unique within each group
list
- pretrained: path to pretrained model, should be either a model name from Hugging Face or a relative path to
testsuite
- peft: path to peft file, should be relative to testsuite,
optional
- tasks: list of tasks to be evaluated, optional for jobs skipping evaluation
- include_path: should be specified if tasks are user-defined
- eval_conf: path to evaluation configuration file, should be
relative to testsuite; if not provided, will use default conf
```yaml
bloom_lora:
pretrained: "bloom-560m"
peft_path_format: "{{fate_base}}/fate_flow/model/{{job_id}}/guest/{{party_id}}/{{model_task_name}}/0/output/output_model/model_directory"
tasks:
- "dolly-15k"
```
- llm suite
```yaml
bloom_suite:
bloom_zero_shot:
pretrained: "bloom-560m"
tasks:
- "dolly-15k"
```
## Built-in Tasks
Currently, we include the following tasks in FATE-LLM Evaluate:
| Task Name | Alias | Task Type | Metric | source |
|:---------:|:-------------:|:----------:|:-------:|:-------------------------------------------------------------------------:|
| Dolly-15k | dolly-15k | generation | rouge-L | [link](https://huggingface.co/datasets/databricks/databricks-dolly-15k) |
| ADGEN | advertise-gen | generation | rouge-L | [link](https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/README_en.md#instructions) |
Use corresponding alias to reference tasks in the system.
================================================
FILE: doc/standalone_deploy.md
================================================
# FATE-LLM Single-Node Deployment Guide
## 1. Introduction
**Server Configuration:**
- **Quantity:** 1
- **Configuration:** 8 cores / 16GB memory / 500GB hard disk / GPU Machine
- **Operating System:** CentOS Linux release 7
- **User:** User: app owner:apps
The single-node version provides the following deployment methods, which can be selected based on your needs:
- Install FATE-LLM from PyPI With FATE
- Install FATE-LLM from PyPI with FATE, FATE-Flow, FATE-Client
## 2. Install FATE-LLM from PyPI With FATE
In this way, users can run tasks with Launcher, a convenient way to run quick experiments.
### 2.1 Installing Python Environment
- Prepare and install [conda](https://docs.conda.io/projects/miniconda/en/latest/) environment.
- Create a virtual environment:
```shell
# FATE-LLM requires Python >= 3.10
conda create -n fate_env python=3.10
conda activate fate_env
```
### 2.2 Installing FATE-LLM
This section introduces how to install FATE-LLM from pypi with FATE, execute the following command to install FATE-LLM.
```shell
pip install fate_llm[fate]==2.2.0
```
### 2.3 Usage
After installing successfully, please refer to the [tutorials](../README.md#quick-start) to run tasks; all tasks described in the tutorials that run with Launcher are supported.
## 3. Install FATE-LLM from PyPI with FATE, FATE-Flow, FATE-Client
In this way, user can run tasks with Pipeline or Launcher.
### 3.1 Installing Python Environment
Please refer to section-2.1
### 3.2 Installing FATE-LLM with FATE, FATE-Flow, FATE-Client
```shell
pip install fate_llm[fate,fate_flow,fate_client]==2.2.0
```
### 3.3 Service Initialization
```shell
mkdir fate_workspace
fate_flow init --ip 127.0.0.1 --port 9380 --home $(pwd)/fate_workspace
pipeline init --ip 127.0.0.1 --port 9380
```
- `ip`: The IP address where the service runs.
- `port`: The HTTP port the service runs on.
- `home`: The data storage directory, including data, models, logs, job configurations, and SQLite databases.
### 3.4 Start Fate-Flow Service
```shell
fate_flow start
fate_flow status # make sure fate_flow service is started
```
FATE-Flow also provides other instructions like stop and restart, use only if users want to stop/restart fate_flow services.
```shell
# Warning: normal installing process does not need to execute stop/restart instructions.
fate_flow stop
fate_flow restart
```
### 3.5 Usage
Please refer to the [tutorials](../README.md#quick-start) for more usage guides; all tasks described in the tutorials that run with Pipeline or Launcher are supported.
================================================
FILE: doc/tutorial/fdkt/README.md
================================================
# FATE-LLM: FDKT
The algorithm is based on paper [Federated Domain-Specific Knowledge Transfer on Large Language Models Using Synthetic Data](https://arxiv.org/pdf/2405.14212),
a novel framework that enables domain-specific knowledge transfer from LLMs to SLMs while preserving SLM data privacy.
## Citation
If you publish work that uses FDKT, please cite FDKT as follows:
```
@article{li2024federated,
title={Federated Domain-Specific Knowledge Transfer on Large Language Models Using Synthetic Data},
author={Li, Haoran and Zhao, Xinyuan and Guo, Dadi and Gu, Hanlin and Zeng, Ziqian and Han, Yuxing and Song, Yangqiu and Fan, Lixin and Yang, Qiang},
journal={arXiv preprint arXiv:2405.14212},
year={2024}
}
```
================================================
FILE: doc/tutorial/fdkt/fdkt.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Synthesize Data With FDKT"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we will demonstrate how to synthesize data using the FATE-LLM framework. In FATE-LLM, we introduce the \"FDKT\" module, specifically designed for domain-specific knowledge transfer on large language models using synthetic data. The FDKT algorithm is based on the paper [Federated Domain-Specific Knowledge Transfer on\n",
"Large Language Models Using Synthetic Data](https://arxiv.org/pdf/2405.14212); we integrate its code into the FATE-LLM framework. \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset: Yelp\n",
"We process and sample data from the 'Health' subdomain of the [Yelp dataset](https://arxiv.org/abs/1509.01626); the dataset can be downloaded from [here](https://www.yelp.com/dataset). \n",
"Once the dataset has been downloaded, execute the following command to untar the downloaded dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"```shell\n",
"tar -xvf yelp_dataset.tar\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following code will sample 5000 data lines (up to 1000 per star rating) from the 'Health' subdomain, and the training data will be generated at './processed_data/Health/train.json'"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import sys\n",
"import random\n",
"from pathlib import Path\n",
"random.seed(42)\n",
"\n",
"\n",
"base_dir = \"./\"\n",
"business_data_path = os.path.join(base_dir, 'yelp_academic_dataset_business.json')\n",
"review_data_path = os.path.join(base_dir, 'yelp_academic_dataset_review.json')\n",
"\n",
"business_data_file = open(business_data_path, 'r')\n",
"review_data_file = open(review_data_path, 'r')\n",
"\n",
"categories_list = ['Restaurants', 'Shopping', 'Arts', 'Health']\n",
"business_dic = {}\n",
"data_dict = {}\n",
"for category in categories_list:\n",
" business_dic[category] = set()\n",
" data_dict[category] = []\n",
"\n",
"\n",
"def get_categories(categories):\n",
" return_list = []\n",
" for category in categories_list:\n",
" if category in categories:\n",
" return_list.append(category)\n",
" return return_list\n",
"\n",
"\n",
"for line in business_data_file.readlines():\n",
" dic = json.loads(line)\n",
" if 'categories' in dic.keys() and dic['categories'] is not None:\n",
" category = get_categories(dic['categories'])\n",
" if len(category) == 1:\n",
" business_dic[category[0]].add(dic['business_id'])\n",
"\n",
"# for category in categories_list:\n",
"for line in review_data_file.readlines():\n",
" dic = json.loads(line)\n",
" if 'business_id' in dic.keys() and dic['business_id'] is not None:\n",
" for category in categories_list:\n",
" if dic['business_id'] in business_dic[category]:\n",
" if dic['text'] is not None and dic['stars'] is not None:\n",
" data_dict[category].append({'text': dic['text'], 'stars': dic['stars']})\n",
" break\n",
"\n",
"train_data_path = os.path.join('processed_data', \"Health\", 'train.json')\n",
"os.makedirs(Path(train_data_path).parent, exist_ok=True)\n",
"train_data_file = open(train_data_path, 'w')\n",
"data_list = data_dict[\"Health\"]\n",
"\n",
"sample_data_dict = dict()\n",
"\n",
"for data in data_list:\n",
" star = int(data[\"stars\"])\n",
" if star not in sample_data_dict:\n",
" sample_data_dict[star] = []\n",
"\n",
" sample_data_dict[star].append(data)\n",
"\n",
"data_list = []\n",
"star_keys = list(sample_data_dict.keys())\n",
"for star in star_keys:\n",
" sample_data = sample_data_dict[star][:1000]\n",
" random.shuffle(sample_data)\n",
" data_list.extend(sample_data)\n",
"\n",
"random.shuffle(data_list)\n",
"json.dump(data_list, train_data_file, indent=4)\n",
"train_data_file.close()\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Models Use\n",
"Please download the following models, these models are used for data augmentation process.\n",
"\n",
"LLM: [Qwen1.5-7B-Chat](https://huggingface.co/Qwen/Qwen1.5-7B-Chat) \n",
"SLM: [gpt2-xl](https://huggingface.co/openai-community/gpt2-xl)\n",
"\n",
"Meanwhile, 'all-mpnet-base-v2' is used to generate embedding vectors on the LLM side.\n",
"\n",
"Embedding Model: [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running FDKT Data Synthesis Process With Launcher (Experimental Use)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SLM Setting\n",
"\n",
"In this section, we will introduce some key configurations on the SLM side."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1. loading model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import transformers\n",
"from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
"\n",
"\n",
"slm_pretrained_path = \"gpt2-xl\" # modify this to the local model directory\n",
"slm = transformers.AutoModelForCausalLM.from_pretrained(slm_pretrained_path, torch_dtype=torch.bfloat16)\n",
"tokenizer = get_tokenizer(slm_pretrained_path)\n",
"tokenizer.pad_token_id = tokenizer.eos_token_id\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 2. Initialize SLM Training Arguments"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.algo.fdkt.fdkt_data_aug import FDKTTrainingArguments\n",
"\n",
"\n",
"training_args = FDKTTrainingArguments(\n",
" use_cpu=False, # use gpu to do dp(differential privacy) training process\n",
" device_id=0, # the device number of gpu\n",
" num_train_epochs=1, # dp training epochs\n",
" per_device_train_batch_size=2, # batch size of dp training\n",
" slm_generation_batch_size=32, # batch_size to generate data in slm side\n",
" seq_num_for_single_category=300, # data num for each category(label)\n",
" slm_generation_config=dict(\n",
" max_new_tokens=256,\n",
" temperature=1.0,\n",
" top_k=50,\n",
" top_p=0.9,\n",
" repetition_penalty=1.0,\n",
" pad_token_id=tokenizer.eos_token_id\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3. Initialize Dataset Instance\n",
"\n",
"We provide default templates for the datasets \"Yelp\" and \"AGNews\"; users can refer [here](https://github.com/FederatedAI/FATE-LLM/tree/dev-2.2.0/python/fate_llm/dataset/data_config) for more details. If you want to use your own dataset, please provide the fields label_key/text_key/augment_format/filter_format/tokenize_format/sub_domain/label_list/few_shot_format/text_with_label_format like the two default templates and pass them as an argument."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.dataset.flex_dataset import FlexDataset\n",
"\n",
"\n",
"ds = FlexDataset(\n",
" tokenizer_name_or_path=slm_pretrained_path,\n",
" load_from=\"json\",\n",
" data_part=\"train\",\n",
" dataset_name=\"yelp_review\", # use default template\n",
" # config=dict/template_path # if dataset_name not equals to \"yelp_review\" or \"ag_news\"\n",
" need_preprocess=True,\n",
" select_num=2000, # use data_num=2000 to train, default is None, None means using all data\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LLM Setting\n",
"\n",
"In this section, we will introduce some key configurations on the LLM side."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1. Deploy VLLM Server And Use OpenAI API Protocol To SpeedUp LLM Inference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Please copy the following code to a local file named create_and_start_vllm.sh, then run it by executing \"bash create_and_start_vllm.sh\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# create_and_start_vllm.sh\n",
"# create vllm environment\n",
"\n",
"python -m venv vllm_venv\n",
"source vllm_venv/bin/activate\n",
"pip install vllm==0.4.3\n",
"pip install numpy==1.26.4 # numpy >= 2.0.0 will raise error, so reinstall numpy<2.0.0\n",
"\n",
"# please modify Qwen1.5-7B-Chat to local llm model saving path\n",
"export CUDA_VISIBLE_DEVICES=1,2\n",
"nohup python -m vllm.entrypoints.openai.api_server --host 127.0.0.1 --port 9999 --model Qwen1.5-7B-Chat --dtype=half --enforce-eager --api-key demo --device cuda -tp 2 &"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 2. Initialize LLM Training Arguments"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.algo.fdkt.fdkt_data_aug import FDKTTrainingArguments\n",
"\n",
"\n",
"training_args = FDKTTrainingArguments(\n",
" sample_num_per_cluster=4, # use this to estimate the number of clusters, n_clusters=(len(dataset) + sample_num_per_cluster - 1) // sample_num_per_cluster\n",
" filter_prompt_max_length=2**16,\n",
" filter_generation_config=dict(\n",
" max_tokens=512,\n",
" ),\n",
" aug_generation_config=dict(\n",
" max_tokens=4096,\n",
" temperature=0.8,\n",
" top_p=0.9,\n",
" ),\n",
" aug_prompt_num=20000, # prompts use for data augmentation\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3. Initialize Embedding Generated Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.model_zoo.embedding_transformer.st_model import SentenceTransformerModel\n",
"\n",
"\n",
"embedding_lm = SentenceTransformerModel(model_name_or_path=\"all-mpnet-base-v2\").load() # modify model_name_or_path to the local path where the model is saved"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 4. Initialize OpenAI API For Inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.algo.fdkt.inference_inst import api_init\n",
"\n",
"\n",
"inference_inst = api_init(\n",
" api_url=\"http://127.0.0.1:9999/v1/\",\n",
" model_name=\"Qwen1.5-7B-Chat\", # modify model_name to the local path where Qwen1.5-7B-Chat is saved\n",
" api_key=\"demo\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Complete Code \n",
"\n",
"Please paste the complete code below into \"run_fdkt_by_launcher.py\" and execute it with the following command. Once the process is finished, the augmented data will be saved in the current directory in a file named aug_data_result.json."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"python run_fdkt_by_launcher.py --parties guest:9999 arbiter:10000 --log_level INFO"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"\n",
"import torch\n",
"from fate.arch import Context\n",
"from fate.arch.launchers.multiprocess_launcher import launch\n",
"\n",
"# please replace the following four variables to local paths\n",
"llm_pretrained_path = \"Qwen1.5-7B-Chat\"\n",
"embedding_model_path = \"all-mpnet-base-v2\"\n",
"slm_pretrained_path = \"gpt2-xl\"\n",
"slm_data_path = \"./processed_data/Health/train.json\"\n",
"\n",
"\n",
"def get_optimizer(model, optimizer=\"adam\", lr=1e-4):\n",
" if optimizer == \"adam\":\n",
" optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)\n",
" elif optimizer == \"adamw\":\n",
" optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)\n",
" else:\n",
" raise NotImplementedError(\"Given optimizer type is not supported\")\n",
" return optimizer\n",
"\n",
"\n",
"def train_slm(ctx):\n",
" import transformers\n",
" from fate_llm.algo.fdkt.fdkt_data_aug import (\n",
" FDKTSLM,\n",
" FDKTTrainingArguments\n",
" )\n",
" from fate_llm.dataset.flex_dataset import FlexDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers.data import DataCollatorForSeq2Seq\n",
"\n",
" slm = transformers.AutoModelForCausalLM.from_pretrained(slm_pretrained_path, torch_dtype=torch.bfloat16)\n",
" tokenizer = get_tokenizer(slm_pretrained_path)\n",
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
" training_args = FDKTTrainingArguments(\n",
" use_cpu=False,\n",
" device_id=0,\n",
" num_train_epochs=1,\n",
" per_device_train_batch_size=2,\n",
" slm_generation_batch_size=32,\n",
" seq_num_for_single_category=2000,\n",
" slm_generation_config=dict(\n",
" max_new_tokens=256,\n",
" do_sample=True,\n",
" temperature=1.0,\n",
" top_k=50,\n",
" top_p=0.9,\n",
" repetition_penalty=1.0,\n",
" pad_token_id=tokenizer.eos_token_id\n",
" ),\n",
" # inference_method=\"vllm\",\n",
" )\n",
"\n",
" ds = FlexDataset(\n",
" tokenizer_name_or_path=slm_pretrained_path,\n",
" load_from=\"json\",\n",
" data_part=\"train\",\n",
" dataset_name=\"yelp_review\",\n",
" need_preprocess=True,\n",
" select_num=2000, # use 2000 data to train, default is None, using all data\n",
" )\n",
" ds.load(slm_data_path)\n",
"\n",
" fdkt_runner = FDKTSLM(\n",
" ctx=ctx,\n",
" model=slm,\n",
" training_args=training_args,\n",
" tokenizer=tokenizer,\n",
" train_set=ds,\n",
" optimizer=get_optimizer(slm),\n",
" data_collator=DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=tokenizer.pad_token_id)\n",
" )\n",
"\n",
" aug_data = fdkt_runner.aug_data()\n",
" with open(\"./aug_data_result.json\", \"w\") as fout:\n",
" fout.write(json.dumps(aug_data, indent=4))\n",
"\n",
"\n",
"def train_llm(ctx):\n",
" from fate_llm.algo.fdkt.fdkt_data_aug import (\n",
" FDKTLLM,\n",
" FDKTTrainingArguments\n",
" )\n",
" from fate_llm.model_zoo.embedding_transformer.st_model import SentenceTransformerModel\n",
" from fate_llm.dataset.flex_dataset import FlexDataset\n",
" from fate_llm.algo.fdkt.inference_inst import api_init, vllm_init\n",
"\n",
" embedding_lm = SentenceTransformerModel(model_name_or_path=embedding_model_path).load()\n",
" training_args = FDKTTrainingArguments(\n",
" sample_num_per_cluster=4,\n",
" filter_prompt_max_length=2**14,\n",
" filter_generation_config=dict(\n",
" max_tokens=4096,\n",
" ),\n",
" use_cpu=False,\n",
" aug_generation_config=dict(\n",
" max_tokens=4096,\n",
" temperature=0.8,\n",
" top_p=0.9,\n",
" ),\n",
" aug_prompt_num=20000,\n",
" )\n",
"\n",
" ds = FlexDataset(\n",
" tokenizer_name_or_path=llm_pretrained_path,\n",
" load_from=\"json\",\n",
" data_part=\"train\",\n",
" dataset_name=\"yelp_review\",\n",
" need_preprocess=True,\n",
" few_shot_num_per_label=1,\n",
" )\n",
"\n",
" inference_inst = api_init(\n",
" api_url=\"http://127.0.0.1:9999/v1/\",\n",
" model_name=llm_pretrained_path,\n",
" api_key=\"demo\"\n",
" )\n",
"\n",
" fdkt_runner = FDKTLLM(\n",
" ctx=ctx,\n",
" embedding_model=embedding_lm,\n",
" training_args=training_args,\n",
" dataset=ds,\n",
" inference_inst=inference_inst,\n",
" )\n",
"\n",
" fdkt_runner.aug_data()\n",
"\n",
"\n",
"def run(ctx: Context):\n",
" if ctx.is_on_arbiter:\n",
" train_llm(ctx)\n",
" else:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
" train_slm(ctx)\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" launch(run)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running FDKT with Pipeline (Industrial Use)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Please make sure that FATE and FATE-Flow have been deployed, paste the following code into test_fdkt_by_pipeline.py, then execute \"python test_fdkt_by_pipeline.py\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_fdkt_runner\n",
"from fate_client.pipeline.components.fate.nn.algo_params import FDKTTrainingArguments\n",
"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\n",
"from fate_client.pipeline import FateFlowPipeline\n",
"from fate_client.pipeline.components.fate.reader import Reader\n",
"from fate_client.pipeline.components.fate.nn.torch import nn, optim\n",
"\n",
"\n",
"guest = '9999'# replace this party id with the actual guest party id in your environment\n",
"arbiter = '9999'# replace this party id with the actual arbiter party id in your environment\n",
"\n",
"# please replace the following four variables to local paths\n",
"llm_pretrained_path = \"Qwen1.5-7B-Chat\"\n",
"embedding_model_path = \"all-mpnet-base-v2/\"\n",
"slm_pretrained_path = \"gpt2-xl\"\n",
"slm_data_path = \"./processed_data/Health/train.json\" # should be absolute path\n",
"\n",
"\n",
"def get_llm_conf():\n",
" embedding_model = LLMModelLoader(\n",
" \"embedding_transformer.st_model\",\n",
" \"SentenceTransformerModel\",\n",
" model_name_or_path=embedding_model_path\n",
" )\n",
"\n",
" dataset = LLMDatasetLoader(\n",
" \"flex_dataset\",\n",
" \"FlexDataset\",\n",
" tokenizer_name_or_path=llm_pretrained_path,\n",
" need_preprocess=True,\n",
" dataset_name=\"yelp_review\",\n",
" data_part=\"train\",\n",
" load_from=\"json\",\n",
" few_shot_num_per_label=1,\n",
" )\n",
"\n",
" training_args = FDKTTrainingArguments(\n",
" sample_num_per_cluster=4,\n",
" filter_prompt_max_length=2 ** 14,\n",
" filter_generation_config=dict(\n",
" max_tokens=4096,\n",
" ),\n",
" use_cpu=False,\n",
" aug_generation_config=dict(\n",
" max_tokens=4096,\n",
" temperature=0.8,\n",
" top_p=0.9,\n",
" ),\n",
" aug_prompt_num=20000,\n",
" )\n",
"\n",
" inference_inst_conf = dict(\n",
" module_name=\"fate_llm.algo.fdkt.inference_inst\",\n",
" item_name=\"api_init\",\n",
" kwargs=dict(\n",
" api_url=\"http://127.0.0.1:9999/v1/\",\n",
" model_name=llm_pretrained_path,\n",
" api_key=\"demo\"\n",
" )\n",
" )\n",
"\n",
" return get_config_of_fdkt_runner(\n",
" training_args=training_args,\n",
" embedding_model=embedding_model,\n",
" dataset=dataset,\n",
" inference_inst_conf=inference_inst_conf,\n",
" )\n",
"\n",
"\n",
"def get_slm_conf():\n",
" slm_model = LLMModelLoader(\n",
" \"hf_model\",\n",
" \"HFAutoModelForCausalLM\",\n",
" pretrained_model_name_or_path=slm_pretrained_path,\n",
" torch_dtype=\"bfloat16\",\n",
" )\n",
"\n",
" tokenizer = LLMDataFuncLoader(\n",
" \"tokenizers.cust_tokenizer\",\n",
" \"get_tokenizer\",\n",
" tokenizer_name_or_path=slm_pretrained_path,\n",
" pad_token_id=50256\n",
" )\n",
"\n",
" training_args = FDKTTrainingArguments(\n",
" use_cpu=False,\n",
" device_id=1,\n",
" num_train_epochs=1,\n",
" per_device_train_batch_size=2,\n",
" slm_generation_batch_size=32,\n",
" seq_num_for_single_category=2000,\n",
" slm_generation_config=dict(\n",
" max_new_tokens=256,\n",
" do_sample=True,\n",
" temperature=1.0,\n",
" top_k=50,\n",
" top_p=0.9,\n",
" repetition_penalty=1.0,\n",
" pad_token_id=50256\n",
" ),\n",
" )\n",
"\n",
" dataset = LLMDatasetLoader(\n",
" \"flex_dataset\",\n",
" \"FlexDataset\",\n",
" tokenizer_name_or_path=slm_pretrained_path,\n",
" need_preprocess=True,\n",
" dataset_name=\"yelp_review\",\n",
" data_part=\"train\",\n",
" load_from=\"json\",\n",
" select_num=2000,\n",
" few_shot_num_per_label=1,\n",
" )\n",
"\n",
" optimizer = optim.Adam(lr=0.01)\n",
"\n",
" return get_config_of_fdkt_runner(\n",
" model=slm_model,\n",
" tokenizer=tokenizer,\n",
" training_args=training_args,\n",
" dataset=dataset,\n",
" optimizer=optimizer,\n",
" data_collator=LLMDataFuncLoader(\n",
" \"data_collator.cust_data_collator\",\n",
" \"get_seq2seq_data_collator\",\n",
" label_pad_token_id=50256,\n",
" tokenizer_name_or_path=slm_pretrained_path,\n",
" pad_token_id=50256,\n",
" ),\n",
" )\n",
"\n",
"\n",
"pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\n",
"pipeline.bind_local_path(path=slm_data_path, namespace=\"experiment\", name=\"slm_train\")\n",
"\n",
"\n",
"reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest))\n",
"reader_0.guest.task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"slm_train\"\n",
")\n",
"\n",
"\n",
"homo_nn_0 = HomoNN(\n",
" 'homo_nn_0',\n",
" train_data=reader_0.outputs[\"output_data\"],\n",
" runner_module=\"fdkt_runner\",\n",
" runner_class=\"FDKTRunner\",\n",
")\n",
"\n",
"homo_nn_0.arbiter.task_parameters(\n",
" runner_conf=get_llm_conf()\n",
")\n",
"\n",
"homo_nn_0.guest.task_parameters(\n",
" runner_conf=get_slm_conf()\n",
")\n",
"\n",
"pipeline.add_tasks([reader_0, homo_nn_0])\n",
"pipeline.conf.set(\"task\", dict(engine_run={\"cores\": 1}))\n",
"\n",
"pipeline.compile()\n",
"pipeline.fit()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.15"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
================================================
FILE: doc/tutorial/fedcot/README.md
================================================
# FATE-LLM: FedCoT
The algorithm is based on paper ["FedCoT: Federated Chain-of-Thought Distillation for Large Language Models"](https://aclanthology.org/anthology-files/anthology-files/pdf/findings/2025.findings-emnlp.454.pdf), We integrate its code into the FATE-LLM framework.
## Citation
If you publish work that uses FedCoT, please cite FedCoT as follows:
```
@inproceedings{fan2025fedcot,
title={FedCoT: Federated Chain-of-Thought Distillation for Large Language Models},
author={Fan, Tao and Chen, Weijing and Kang, Yan and Ma, Guoqiang and Gu, Hanlin and Song, Yuanfeng and Fan, Lixin and Yang, Qiang},
booktitle={Findings of the Association for Computational Linguistics: EMNLP 2025},
pages={8546--8557},
year={2025}
}
```
================================================
FILE: doc/tutorial/fedcot/encoder_decoder_tutorial.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "a163d9c2-f9d6-4c61-a8e8-76a3f66c38ae",
"metadata": {},
"source": [
"# FedCoT - Train a SLM Encoder Decoder"
]
},
{
"cell_type": "markdown",
"id": "f2b56772-26d5-44fe-9c51-7bc662478b98",
"metadata": {},
"source": [
"FedCoT is an innovative framework designed to distill knowledge from large language models (LLMs) to small language models (SLMs) while ensuring data privacy. This method involves a strategy that trains a small language model (SLM) to learn from perturbed and recovered texts. The SLM can then encode raw text, produce results similar to differential privacy mechanisms, and return higher quality recovered text.\n",
"\n",
"In this tutorial, we will introduce how to train an SLM using the built-in trainer."
]
},
{
"cell_type": "markdown",
"id": "62c6d18a-cc91-4cf5-9cfd-0f97095f7041",
"metadata": {},
"source": [
"## Prepare Data\n",
"\n",
"Several steps need to be done to prepare data for training a SLM encoder-decoder model:\n",
"- Sample data from original dataset(For example 50%)\n",
"- Organize raw text and get a direct rationale reply from a remote LLM\n",
"- Perturb doc using InferDPTKit to get perturbed docs\n",
"- Get perturbed replies from a remote LLM\n",
"- Organize training data\n",
"\n",
"### Sample data\n",
"Here we will use the arc-easy data as an example, and take first 50% of the original dataset\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "40cc1bb8-a17c-4abc-9279-0849e98ca116",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset, load_from_disk\n",
"ds = load_dataset('arc_easy')['train']\n",
"ds = [ds[i] for i in range(len(ds)//2)]"
]
},
{
"cell_type": "markdown",
"id": "0caff897-5b2b-4409-8601-10f973133b10",
"metadata": {},
"source": [
"### Get Direct Replies from A Remote LLM\n",
"\n",
"We use the inference class to create an API for remote LLMs, or you can implement this part on your own."
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "cf128b46-dea2-4eb4-bf31-568e56b9b78e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"from fate_llm.inference.api import APICompletionInference\n",
"from jinja2 import Template\n",
"from transformers import AutoTokenizer\n",
"\n",
"# We are using a Qwen 14B model as the remote model\n",
"# You can change the setting\n",
"api = APICompletionInference(\n",
" api_url='http://172.21.140.2:8081/v1',\n",
" api_key='EMPTY',\n",
" model_name='/data/cephfs/llm/models/Qwen1.5-14B-Chat'\n",
")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained('/data/cephfs/llm/models/Qwen1.5-0.5B-Chat/')\n",
"\n",
"arc_e_template_r = \"\"\"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n",
"Use <end> to finish your rationle.\n",
"\n",
"Example(s):\n",
"Question:Which factor will most likely cause a person to develop a fever?\n",
"Choices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\n",
"Rationale:A bacterial infection in the bloodstream triggers the immune system to respond, therefore often causing a fever as the body tries to fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'\n",
"\n",
"Please explain:\n",
"Question:{{question}}\n",
"Choices:{{choices.text}}\n",
"Rationale:\n",
"\"\"\"\n",
"\n",
"template = Template(arc_e_template_r)\n",
"docs_to_infer = [tokenizer.apply_chat_template([{'role':'system', 'content': 'you are a helpful assistant'}, {'role':'user', 'content': template.render(i)}], add_generation_prompt=True, tokenize=False) for i in ds]\n",
"results = api.inference(docs_to_infer, {\n",
" 'stop': ['<|im_end|>', '<end>', '<end>\\n', '<end>\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
"})\n",
"\n",
"for i, r in zip(ds, results):\n",
" i['rationale'] = r"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "212822ab-9f64-49a2-bb95-ef8ee2de8e49",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A fever is a response to an infection, typically caused by bacteria or viruses. So, the answer is 'a bacterial population in the bloodstream' because it indicates an immune response to a foreign invader. 'Several viral particles on the skin' could also lead to a fever if they enter the body, but bloodstream presence is more direct. The other choices are unrelated to fever development.\n"
]
}
],
"source": [
"print(results[0])"
]
},
{
"cell_type": "markdown",
"id": "0f6a0039-1530-4b87-a098-fd2eb01805c2",
"metadata": {},
"source": [
"### Perturb Docs & Replies\n",
"\n",
"You can refer to the InferDPT tutorial for guidance on using the InferDPTKit to generate perturbed documents: [InferDPT Document](../inferdpt/inferdpt_tutorial.ipynb)\n",
"We can produce perturbed doc using InferDPTKit:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "39249747-bfaa-43bf-8b66-896568941ab8",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.algo.inferdpt.utils import InferDPTKit\n",
"path_to_kit = '/data/projects/inferdpt/test_fate_llm/'\n",
"kit = InferDPTKit.load_from_path(path_to_kit)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "39b9cefa-dfdb-4bac-b313-4ca3bc118aee",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"tmp_ds = copy.deepcopy(ds)\n",
"\n",
"q_doc = [kit.perturb(i, epsilon=1.0) for i in [Template(\"\"\"{{question}}\"\"\").render(i) for i in tmp_ds]]\n",
"c_doc = [kit.perturb(i, epsilon=1.0) for i in [Template(\"\"\"{{choices.text}}\"\"\").render(i) for i in tmp_ds]]\n",
"for i,q,c in zip(tmp_ds,q_doc,c_doc):\n",
" i['question'] = q\n",
" i['choices']['text'] = c"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "61b30886-746c-43c5-889a-a6583dc939d0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'id': 'Mercury_7179953',\n",
" 'question': 'stuff two alpha Rogers are today chap in Department?',\n",
" 'choices': {'text': \"['muscular and skeletal', 'digestive and muscular', 'skeletal and pasteiratory', 'respiratory and exhibive']\",\n",
" 'label': ['A', 'B', 'C', 'D']},\n",
" 'answerKey': 'A',\n",
" 'rationale': {...}}"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tmp_ds[6]"
]
},
{
"cell_type": "markdown",
"id": "fed90297-9957-4f8b-a53c-37a03d516c78",
"metadata": {},
"source": [
"And then send formatted docs to remote LLM for perturbed responses:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "5b8bd833-fb0f-418b-bd9b-6452e8ae4d6c",
"metadata": {},
"outputs": [],
"source": [
"template = Template(arc_e_template_r)\n",
"docs_to_infer = [tokenizer.apply_chat_template([{'role':'system', 'content': 'you are a helpful assistant'}, {'role':'user', 'content': template.render(i)}], add_generation_prompt=True, tokenize=False) for i in tmp_ds]\n",
"p_results = api.inference(docs_to_infer, {\n",
" 'stop': ['<|im_end|>', '<end>', '<end>\\n', '<end>\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "187361fa-8b73-4a01-9039-f52ec98a5791",
"metadata": {},
"outputs": [],
"source": [
"for i, r in zip(ds, p_results):\n",
" i['p_rationale'] = r\n",
"\n",
"for i,q,c in zip(ds, q_doc, c_doc):\n",
" i['p_question'] = q\n",
" i['p_choice'] = c"
]
},
{
"cell_type": "markdown",
"id": "927b2265-4e87-4275-98dc-7f33d405e19a",
"metadata": {},
"source": [
"### Organize Training Data\n",
"\n",
"As described in the original paper, we need to train the encoder and decoder in one model.\n",
"We can organize the training data using templates below:"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "9292ad25-12c7-418a-9e77-b433b95f57ac",
"metadata": {},
"outputs": [],
"source": [
"train_data = []\n",
"\n",
"encoder_prompt = Template(\"\"\"Disrupt the main words in the original text so that it becomes difficult to recognize, but at the same time, try to maintain the original meaning as much as possible. Use <end> to end your reply.\n",
"Origin Doc: \n",
"Question:{{question}}\n",
"Choices:{{choices.text}}\n",
"\n",
"Perturbed Doc:\n",
"\"\"\")\n",
"\n",
"encoder_out = Template(\"\"\"\n",
"Question:{{p_question}}\n",
"Choices:{{p_choice}}<end>\n",
"\"\"\")\n",
"\n",
"decoder_in = Template(\"\"\"This is a perturbed question and its corresponding answer(rationale). And following is the original question. Try to recover the correct rationale from docs provided.\n",
"\n",
"Perturbed doc and rationale:\n",
"Question:{{p_question}}\n",
"Choices:{{p_choice}}\n",
"Rationale:{{p_rationale}}\n",
"\n",
"Original Doc:\n",
"Question:{{question}}\n",
"Choices:{{choices.text}}\n",
"\n",
"Recover Rationale:\n",
"\"\"\")\n",
"\n",
"decoder_out = Template(\"\"\"{{rationale}}<end>\"\"\")\n",
"\n",
"\n",
"for i in ds:\n",
" a = {}\n",
" a['encoder_in'] = encoder_prompt.render(i)\n",
" a['encoder_out'] = encoder_out.render(i)\n",
" a['decoder_in'] = decoder_in.render(i)\n",
" a['decoder_out'] = decoder_out.render(i)\n",
" train_data.append(a)\n",
"\n",
"import torch\n",
"torch.save(train_data, './slm_ed_train_data.pkl')"
]
},
{
"cell_type": "markdown",
"id": "dd73db44-4e73-4c1e-8f27-755522587636",
"metadata": {},
"source": [
"## Train Script\n",
"\n",
"The key step, preparing the data, is now done. Next we can train an SLM using the training data. You can use the following dataset and trainer classes to train an encoder-decoder SLM. Here we use Qwen-0.5B as the example."
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "eb01c591-3c04-4317-8bb0-f55846fb1b66",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "f0da4e10-af80-4216-8ff8-5816dabc8526",
"metadata": {},
"outputs": [],
"source": [
"model = AutoModelForCausalLM.from_pretrained('/data/cephfs/llm/models/Qwen1.5-0.5B/').half().cuda()"
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "634fc973-29c8-499e-a99e-d50b7ee54124",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import Dataset\n",
"\n",
"class EDDataset(Dataset):\n",
"\n",
" def __init__(self, tokenizer, train_data, max_input_length=64, max_target_length=64):\n",
" self.tokenizer = tokenizer\n",
" self.dataset = train_data\n",
" self.max_input_length = max_input_length\n",
" self.max_target_length = max_target_length\n",
" self.max_seq_length = max_input_length + max_target_length + 1\n",
"\n",
" def get_str_item(self, i) -> dict:\n",
"\n",
" data_item = self.dataset[i]\n",
" ret_dict = {\n",
" 'encoder':{\n",
" 'input': data_item['encoder_in'],\n",
" 'output': data_item['encoder_out']\n",
" },\n",
" 'decoder':{\n",
" 'input': data_item['decoder_in'],\n",
" 'output': data_item['decoder_out']\n",
" }\n",
" }\n",
" return ret_dict\n",
"\n",
" def _process_item(self, data_item):\n",
"\n",
" a_ids = self.tokenizer.encode(text=data_item['input'], add_special_tokens=True, truncation=True,\n",
" max_length=self.max_input_length)\n",
" b_ids = self.tokenizer.encode(text=data_item['output'], add_special_tokens=False, truncation=True,\n",
" max_length=self.max_target_length)\n",
" context_length = len(a_ids)\n",
" input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id]\n",
" labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id]\n",
" pad_len = self.max_seq_length - len(input_ids)\n",
" input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len\n",
" labels = labels + [self.tokenizer.pad_token_id] * pad_len\n",
" labels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels]\n",
"\n",
" assert len(input_ids) == len(labels), f\"length mismatch: {len(input_ids)} vs {len(labels)}\"\n",
"\n",
" return {\n",
" \"input_ids\": input_ids,\n",
" \"labels\": labels\n",
" }\n",
"\n",
" def get_tokenized_item(self, i) -> dict: \n",
"\n",
" str_item = self.get_str_item(i)\n",
" ret_dict = {\n",
" 'encoder': self._process_item(str_item['encoder']),\n",
" 'docoder': self._process_item(str_item['decoder'])\n",
" }\n",
" return ret_dict\n",
"\n",
" def __getitem__(self, i) -> dict:\n",
" item = self.get_tokenized_item(i)\n",
" return item"
]
},
{
"cell_type": "code",
"execution_count": 76,
"id": "5f914b1f-cf14-4bdc-acc9-ae1b73cf857c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"train_ds = EDDataset(AutoTokenizer.from_pretrained('/data/cephfs/llm/models/Qwen1.5-0.5B/'), train_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "817084b2-2439-45d8-aa1b-da0b1a8a2846",
"metadata": {},
"outputs": [],
"source": [
"print(train_ds.get_str_item(0))\n",
"print(train_ds[0])"
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "303bcb23-d54b-4375-bad2-bf5450c14f28",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.algo.fedcot.slm_encoder_decoder_trainer import EncoderDecoderPrefixTrainer, EDPrefixDataCollator"
]
},
{
"cell_type": "markdown",
"id": "aa5a0b4f-cd03-4867-8753-fc5bcb036c69",
"metadata": {},
"source": [
"After completing the setup, you can utilize the EncoderDecoderPrefixTrainer, EDPrefixDataCollator, and the training dataset to train an SLM encoder-decoder model following the Huggingface approach! "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: doc/tutorial/fedcot/fedcot_tutorial.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "9234355d-389f-484f-9fc2-7b17563b3390",
"metadata": {},
"source": [
"# FedCoT Tutorial\n",
"\n",
"## Introduction to FedCoT\n",
"\n",
"FedCoT (Federated Chain-of-Thought) is a novel framework designed to distill knowledge from large language models (LLMs) to small language models (SLMs) while ensuring data privacy. The framework addresses two major challenges faced by LLM deployment in real-world applications: the privacy of domain-specific knowledge and resource constraints.\n",
"\n",
"FedCoT adopts a server-client architecture where the client sends perturbed prompts to the server-side LLM for inference, generating perturbed rationales. The client then decodes these rationales and uses them to enrich the training of its task-specific SLM, ultimately enhancing its performance.\n",
"\n",
"FedCoT introduces two privacy protection strategies: \n",
"- **the Exponential Mechanism Strategy**\n",
"- **the Encoder-Decoder Strategy**\n",
" \n",
"The Exponential Mechanism Strategy utilizes a DP(differential privacy) based exponential mechanism to obfuscate user prompts, while the Encoder-Decoder Strategy employs a specialized Encoder-Decoder SLM to encode and decode perturbed prompts and rationales. These strategies effectively balance user privacy and the usability of rationales, allowing for secure and enhanced training of the client's SLM without compromising on privacy concerns.\n",
"\n",
"Through experiments on various text generation tasks, FedCoT demonstrates its effectiveness in training task-specific SLMs with enhanced performance, significantly improving the SLM's capabilities while prioritizing data privacy protection. For more details, please refer to the paper: [FedCoT: Federated Chain-of-Thought Distillation for Large Language Models](https://arxiv.org/pdf/2406.12403).\n",
"\n",
"**Before reading this tutorial, we strongly recommend that you first read [the InferDPT](./) tutorial.**\n",
"\n",
"## Use the Infer Client & Server\n",
"\n",
"In this section, we are going to introduce the inference part, which is the key part of FedCoT that generates useful rationales with privacy-preserving. You can use InferDPT(which utilize the Exponential Mechanism Strategy) or specifically trained SLM as the text encoder & decoder. In this section, we retrieve a sample from the arc-easy dataset as an example:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c443c920-31ff-446a-801f-d7a02409a8c0",
"metadata": {},
"outputs": [],
"source": [
"test_example = {'id': 'Mercury_7220990',\n",
"'question': 'Which factor will most likely cause a person to develop a fever?',\n",
"'choices': {'text': ['a leg muscle relaxing after exercise',\n",
"'a bacterial population in the bloodstream',\n",
"'several viral particles on the skin',\n",
"'carbohydrates being digested in the stomach'],\n",
"'label': ['A', 'B', 'C', 'D']},\n",
"'answerKey': 'B'}"
]
},
{
"cell_type": "markdown",
"id": "46646b18-46bb-476d-8b1d-1ef661446929",
"metadata": {},
"source": [
"### Fate Context\n",
"\n",
"We need to create fate context to enable the communication between client and server. Then, we can initialize infer client(who will encodes the raw prompt and decodes the perturbed response) and server(who deploys the LLM) to enable secure inference."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0cc8e8f8-88d7-45ab-a988-5ead06356418",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"arbiter = (\"arbiter\", 10000)\n",
"guest = (\"guest\", 10000)\n",
"host = (\"host\", 9999)\n",
"name = \"fed1\"\n",
"\n",
"def create_ctx(local):\n",
" from fate.arch import Context\n",
" from fate.arch.computing.backends.standalone import CSession\n",
" from fate.arch.federation.backends.standalone import StandaloneFederation\n",
" import logging\n",
"\n",
" logger = logging.getLogger()\n",
" logger.setLevel(logging.INFO)\n",
"\n",
" console_handler = logging.StreamHandler()\n",
" console_handler.setLevel(logging.INFO)\n",
"\n",
" formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n",
" console_handler.setFormatter(formatter)\n",
"\n",
" logger.addHandler(console_handler)\n",
" computing = CSession(data_dir=\"./session_dir\")\n",
" return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))"
]
},
{
"cell_type": "markdown",
"id": "c75dbcda-1a40-421d-ab1b-92eca5600866",
"metadata": {},
"source": [
"### The DP based Strategy(InferDPT)\n",
"\n",
"As outlined in the [InferDPT tutorial](./), you can initialize the InferDPT client and server to facilitate secure and private inference. Prior to executing the InferDPT component, it is recommended to generate the InferDPT kit by following the step-by-step instructions provided in the tutorial.\n",
"\n",
"#### Client-Side Code\n",
"\n",
"On the client side, we load the pre-computed inferdpt-kit and deploy a local SLM as the decoding model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff0f317f-414f-4b9f-84e6-b992b31350cb",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.inference.api import APICompletionInference\n",
"from fate_llm.algo.inferdpt import inferdpt\n",
"from fate_llm.algo.inferdpt.utils import InferDPTKit\n",
"import sys\n",
"\n",
"arbiter = (\"arbiter\", 10000)\n",
"guest = (\"guest\", 10000)\n",
"host = (\"host\", 9999)\n",
"name = \"fed1\"\n",
"\n",
"def create_ctx(local):\n",
" from fate.arch import Context\n",
" from fate.arch.computing.backends.standalone import CSession\n",
" from fate.arch.federation.backends.standalone import StandaloneFederation\n",
" import logging\n",
"\n",
" logger = logging.getLogger()\n",
" logger.setLevel(logging.INFO)\n",
"\n",
" console_handler = logging.StreamHandler()\n",
" console_handler.setLevel(logging.INFO)\n",
"\n",
" formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n",
" console_handler.setFormatter(formatter)\n",
"\n",
" logger.addHandler(console_handler)\n",
" computing = CSession(data_dir=\"./session_dir\")\n",
" return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n",
"\n",
"ctx = create_ctx(guest)\n",
"save_kit_path = 'your path'\n",
"kit = InferDPTKit.load_from_path(save_kit_path)\n",
"# local deployed small model as decoding model\n",
"inference = APICompletionInference(api_url=\"http://127.0.0.1:8887/v1\", model_name='./Qwen1.5-0.5B', api_key='EMPTY')\n",
"\n",
"test_example = {'id': 'Mercury_7220990',\n",
"'question': 'Which factor will most likely cause a person to develop a fever?',\n",
"'choices': {'text': ['a leg muscle relaxing after exercise',\n",
"'a bacterial population in the bloodstream',\n",
"'several viral particles on the skin',\n",
"'carbohydrates being digested in the stomach'],\n",
"'label': ['A', 'B', 'C', 'D']},\n",
"'answerKey': 'B'}\n",
"\n",
"\n",
"doc_template = \"\"\"{{question}} \n",
"Choices:{{choices.text}}\n",
"\"\"\"\n",
"\n",
"instruction_template=\"\"\"\n",
"<s>[INST]\n",
"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n",
"Use <end> to finish your rationle.\"\n",
"\n",
"Example(s):\n",
"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n",
"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n",
"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.<end>\n",
"\n",
"Please explain:\n",
"Question:{{perturbed_doc}}\n",
"Rationale:\n",
"[/INST]\n",
"\"\"\"\n",
"\n",
"decode_template = \"\"\"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n",
"Use <end> to finish your rationle.\"\n",
"\n",
"Example(s):\n",
"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n",
"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n",
"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.<end>\n",
"\n",
"Question:{{perturbed_doc}}\n",
"Rationale:{{perturbed_response | replace('\\n', '')}}<end>\n",
"\n",
"Please explain:\n",
"Question:{{question}} \n",
"Choices:{{choices.text}}\n",
"Rationale:\n",
"\"\"\"\n",
"\n",
"inferdpt_client = inferdpt.InferDPTClient(ctx, kit, inference, epsilon=3.0)\n",
"result = inferdpt_client.inference([test_example], doc_template, instruction_template, decode_template, \\\n",
" remote_inference_kwargs={\n",
" 'stop': ['<\\s>'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
" },\n",
" local_inference_kwargs={\n",
" 'stop': ['<|im_end|>', '<end>', '<end>\\n', '<end>\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
" })\n",
"print('result is {}'.format(result[0]['inferdpt_result']))"
]
},
{
"cell_type": "markdown",
"id": "96fbcb01-6907-432f-8393-ae1746559c3a",
"metadata": {},
"source": [
"#### Server Side Code"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "960a476c-50a5-40fb-847d-02101cea27ae",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.algo.inferdpt.inferdpt import InferDPTServer\n",
"import sys\n",
"from fate_llm.inference.api import APICompletionInference\n",
"\n",
"arbiter = (\"arbiter\", 10000)\n",
"guest = (\"guest\", 10000)\n",
"host = (\"host\", 9999)\n",
"name = \"fed1\"\n",
"\n",
"def create_ctx(local):\n",
" from fate.arch import Context\n",
" from fate.arch.computing.backends.standalone import CSession\n",
" from fate.arch.federation.backends.standalone import StandaloneFederation\n",
" import logging\n",
"\n",
" logger = logging.getLogger()\n",
" logger.setLevel(logging.INFO)\n",
"\n",
" console_handler = logging.StreamHandler()\n",
" console_handler.setLevel(logging.INFO)\n",
"\n",
" formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n",
" console_handler.setFormatter(formatter)\n",
"\n",
" logger.addHandler(console_handler)\n",
" computing = CSession(data_dir=\"./session_dir\")\n",
" return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n",
"\n",
"ctx = create_ctx(arbiter)\n",
"# Api to a LLM\n",
"inference_server = APICompletionInference(api_url=\"http://127.0.0.1:8888/v1\", model_name='./Mistral-7B-Instruct-v0.2', api_key='EMPTY')\n",
"inferdpt_server = InferDPTServer(ctx, inference_server)\n",
"inferdpt_server.inference()"
]
},
{
"cell_type": "markdown",
"id": "16f908a7-9187-461a-93db-9945456d502d",
"metadata": {},
"source": [
"Start two terminal and launch client&server scripts simultaneously. On the client side we can get the answer:\n",
"\n",
"```\n",
"The given question asks which factor will most likely cause a person to develop a fever. The factors mentioned are a leg muscle relaxing after exercise, a bacterial population in the bloodstream, several viral particles on the skin, and carbohydrates being digested in the stomach. The question is asking which factor is most likely to cause a person to develop a fever. The factors are all related to the body's internal environment, but the most likely factor is a bacterial population in the bloodstream. This is because bacteria can cause a fever, and the body's immune system responds to the infection by producing antibodies that can fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "fb36a485-2fa8-4629-a2cf-2d53fdbbcc5f",
"metadata": {},
"source": [
"### The Encoder-Decoder Model Strategy\n",
"\n",
"Similar to the InferDPT, we can initialize SLMEncoderDecoderClient and SLMEncoderDecoderServer to enable secure inference.\n",
"The client will encode the raw prompt using local slm model and then decoded it with the same model\n",
"\n",
"#### Client Side Code"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "cd174244-8640-4cb2-8609-ac6468f5a6f5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from fate_llm.inference.api import APICompletionInference\n",
"from fate_llm.algo.fedcot.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient\n",
"\n",
"arbiter = (\"arbiter\", 10000)\n",
"guest = (\"guest\", 10000)\n",
"host = (\"host\", 9999)\n",
"name = \"fed1\"\n",
"\n",
"def create_ctx(local):\n",
" from fate.arch import Context\n",
" from fate.arch.computing.backends.standalone import CSession\n",
" from fate.arch.federation.backends.standalone import StandaloneFederation\n",
" import logging\n",
"\n",
" logger = logging.getLogger()\n",
" logger.setLevel(logging.INFO)\n",
"\n",
" console_handler = logging.StreamHandler()\n",
" console_handler.setLevel(logging.INFO)\n",
"\n",
" formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n",
" console_handler.setFormatter(formatter)\n",
"\n",
" logger.addHandler(console_handler)\n",
" computing = CSession(data_dir=\"./session_dir\")\n",
" return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n",
"\n",
"\n",
"test_example = {'id': 'Mercury_7220990',\n",
"'question': 'Which factor will most likely cause a person to develop a fever?',\n",
"'choices': {'text': ['a leg muscle relaxing after exercise',\n",
"'a bacterial population in the bloodstream',\n",
"'several viral particles on the skin',\n",
"'carbohydrates being digested in the stomach'],\n",
"'label': ['A', 'B', 'C', 'D']},\n",
"'answerKey': 'B'\n",
"}\n",
"\n",
"\n",
"encode_prompt = \"\"\"Disrupt the main words in the original text so that it becomes difficult to recognize, but at the same time, try to maintain the original meaning as much as possible. Use <end> to end your reply.\n",
"Origin Doc:Question:{{question}}\n",
"Choices:{{choices.text}}\n",
"Perturb Doc: \n",
"\"\"\"\n",
"\n",
"decode_prompt = \"\"\"This is a perturbed question and its corresponding answer(rationale). And following is the original question. Try to recover the correct rationale from docs provided.\n",
"\n",
"Perturbed doc and rationale:\n",
"{{perturbed_doc}}\n",
"Rationale:{{perturbed_response}}\n",
"\n",
"Original Doc:\n",
"Question:{{question}}\n",
"Choices:{{choices.text}}\n",
"\n",
"Recover Rationale:\n",
"\"\"\"\n",
"\n",
"instruction_template = \"\"\"<|im_start|>system\n",
"You are a helpful assistant<|im_end|>\n",
"<|im_start|>user\n",
"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n",
"Use <end> to finish your rationle.\n",
"\n",
"Example(s):\n",
"Question:Which factor will most likely cause a person to develop a fever?\n",
"Choices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\n",
"Rationale:A bacterial infection in the bloodstream triggers the immune system to respond, therefore often causing a fever as the body tries to fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'\n",
"\n",
"Please explain:\n",
"{{perturbed_doc}}\n",
"Rationale:\n",
"<|im_end|>\n",
"<|im_start|>assistant\n",
"\"\"\"\n",
"\n",
"ctx = create_ctx(guest)\n",
"model_name = 'Deploy your encoder decoder model'\n",
"# api_url to your locally deployed encoder decoder\n",
"api = APICompletionInference(api_url='http://127.0.0.1:8887/v1', api_key='EMPTY', model_name=model_name)\n",
"client = SLMEncoderDecoderClient(ctx, api)\n",
"result = client.inference([test_example], encode_prompt, instruction_template, decode_prompt, \\\n",
" remote_inference_kwargs={\n",
" 'stop': ['<\\s>'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
" },\n",
" local_inference_kwargs={\n",
" 'stop': ['<|im_end|>', '<end>', '<end>\\n', '<end>\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
" })\n",
"print('result is {}'.format(result[0]['inferdpt_result']))"
]
},
{
"cell_type": "markdown",
"id": "1a865536-7814-40a2-a814-d00e46f2787f",
"metadata": {},
"source": [
"#### Server Side Code"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "cced44b0-0dcb-4427-8efe-a04135b246ac",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.inference.api import APICompletionInference\n",
"from fate_llm.algo.fedcot.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderServer\n",
"\n",
"arbiter = (\"arbiter\", 10000)\n",
"guest = (\"guest\", 10000)\n",
"host = (\"host\", 9999)\n",
"name = \"fed1\"\n",
"\n",
"def create_ctx(local):\n",
" from fate.arch import Context\n",
" from fate.arch.computing.backends.standalone import CSession\n",
" from fate.arch.federation.backends.standalone import StandaloneFederation\n",
" import logging\n",
"\n",
" logger = logging.getLogger()\n",
" logger.setLevel(logging.INFO)\n",
"\n",
" console_handler = logging.StreamHandler()\n",
" console_handler.setLevel(logging.INFO)\n",
"\n",
" formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n",
" console_handler.setFormatter(formatter)\n",
"\n",
" logger.addHandler(console_handler)\n",
" computing = CSession(data_dir=\"./session_dir\")\n",
" return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n",
"\n",
"ctx = create_ctx(arbiter)\n",
"# api url&name are depolyed LLM\n",
"model_name = '/data/cephfs/llm/models/Qwen1.5-14B-Chat/'\n",
"api = APICompletionInference(api_url='http://127.0.0.1:8888/v1', api_key='EMPTY', model_name=model_name)\n",
"server = SLMEncoderDecoderServer(ctx, api)\n",
"server.inference()"
]
},
{
"cell_type": "markdown",
"id": "c38ed7a6-2eb2-4f46-b59c-eaafcc9a5b7a",
"metadata": {},
"source": [
"Start two terminal and launch client&server scripts simultaneously. On the client side we can get the answer:\n",
"\n",
"```\n",
"A fever is typically caused by a bacterial population in the bloodstream, as it is a response to an infection. So the answer is 'a bacterial population in the bloodstream'.\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "41fbbefd-e931-4e95-9d28-9675ff7865a3",
"metadata": {},
"source": [
"## Prefix Dataset & FedCoT Trainer\n",
"\n",
"Now that we can carry out privacy-preserving inference and acquire rationales, the next step is to train a new task-specific model, enhanced by the rationales generated by the LLMs.\n",
"\n",
"In this section, we will introduce the PrefixDataset and FedCoTTrainer, which facilitate training tasks with the added benefit of supplementary rationales. The PrefixDataset allows you to assign various text prefixes, guiding the model to produce different text targets. With FedCoTTrainer, the model is trained to generate both text labels and text rationales at each update step, ultimately leading to superior performance compared to training on the raw dataset alone.\n",
"\n",
"### Prepare dataset\n",
"In this tutorial, we will use the arc-easy dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e25377d0-1a7e-4e8c-aa9f-3bcb03ae0c45",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"dataset = load_dataset(\"arc_easy\")\n",
"dataset.save_to_disk('path_to_save/arce')"
]
},
{
"cell_type": "markdown",
"id": "9166110f-bf67-4bf1-9da8-04c16bd79423",
"metadata": {},
"source": [
"Let’s proceed with testing the PrefixDataset. We can utilize Jinja2 templates to structure the text and append prefixes or suffixes to our training data.\n",
"\n",
"Please note that at this stage, the dataset does not contain rationales. In the 'rationale_output_template', the key used for the inference results is ‘infer_result’. We can perform secure inference using the FedCoTTrainer and then integrate the rationale results, keyed as ‘infer_result’, into the PrefixDataset."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "fdbd93d6-45f3-404f-813e-9ca1fd6def04",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"from fate_llm.dataset.fedcot_dataset import PrefixDataset\n",
"\n",
"pds = PrefixDataset(\n",
" tokenizer_path='/data/cephfs/llm/models/Qwen1.5-0.5B/',\n",
" predict_input_template=\"\"\"Predict:\n",
"Question:{{question}}\n",
"Choices:{{choices.text}}\n",
"Answer:\n",
" \"\"\",\n",
" predict_output_template=\"\"\"{{choices.text[choices.label.index(answerKey)]}}<end>\"\"\",\n",
" rationale_input_template=\"\"\"Explain:\n",
"Question:{{question}}\n",
"Choices:{{choices.text}}\n",
"Rationale:\n",
" \"\"\",\n",
" rationale_output_template=\"\"\"{{infer_result}}<end>\"\"\",\n",
" max_input_length=128,\n",
" max_target_length=128,\n",
" split_key='train'\n",
" )\n",
"\n",
"\n",
"pds.load('path_to_save/arce')"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "100eeb69-8bd2-4e66-b1cc-667f95e47f23",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'id': 'Mercury_7220990',\n",
" 'question': 'Which factor will most likely cause a person to develop a fever?',\n",
" 'choices': {'text': ['a leg muscle relaxing after exercise',\n",
" 'a bacterial population in the bloodstream',\n",
" 'several viral particles on the skin',\n",
" 'carbohydrates being digested in the stomach'],\n",
" 'label': ['A', 'B', 'C', 'D']},\n",
" 'answerKey': 'B'}"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pds.dataset[0] # the structure is the same as hf dataset"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "6f0356ef-f94b-41db-ab66-b1d0eb862eca",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'predict': {'input': \"Predict:\\nQuestion:Which factor will most likely cause a person to develop a fever?\\nChoices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\\nAnswer:\\n \",\n",
" 'output': 'a bacterial population in the bloodstream<end>'},\n",
" 'rationale': {'input': \"Explain:\\nQuestion:Which factor will most likely cause a person to develop a fever?\\nChoices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\\nRationale:\\n \",\n",
" 'output': '<end>\\n '}}"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pds.get_str_item(0) # we can see that the output of rationale term is empty"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "6a227af7-f24a-46bd-9af7-78584a381b33",
"metadata": {},
"outputs": [],
"source": [
"print(pds[0]) # show tokenized, for the sake of breif we dont show it in this tutorial doc"
]
},
{
"cell_type": "markdown",
"id": "e0382a33-7a45-43a3-8ed3-58ed1d1b07d8",
"metadata": {},
"source": [
"### The FedCoTTrainer\n",
"\n",
"Here we introduce the FedCoTTrainer which is develop based on Huggingface trainer and supports collaboratively training a task with raw labels and additional rationales. Here show how the compute loss function is realized:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b40b7d99-9ef8-43f9-8e28-db96d96af62a",
"metadata": {},
"outputs": [],
"source": [
"def compute_loss(self, model, inputs, return_outputs=False):\n",
"\n",
" label_outputs = model(**inputs['predict'])\n",
" cot_outputs = model(**inputs['rationale'])\n",
" loss = self.alpha * cot_outputs.loss + (1. - self.alpha) * label_outputs.loss\n",
" return (loss, {'rationale_loss': cot_outputs, 'predict_loss': label_outputs}) if return_outputs else loss"
]
},
{
"cell_type": "markdown",
"id": "ff1cee5d-68e1-4caf-96b9-132b27b46dca",
"metadata": {},
"source": [
"You have the option to choose from three distinct modes: ‘infer_only’, ‘train_only’, and ‘infer_and_train’, to meet your specific requirements.\n",
"- infer_only: Only generate the rationales and they will be saved to the output_dir\n",
"- train_only: Local training only\n",
"- infer_and_train: Generate rationales, and then load them into PrefixDataset and start training\n",
" \n",
"In this instance, we will opt for the ‘infer_and_train’ mode to initially generate rationales with the assistance of the remote LLM. To activate the inference process, it is necessary to initialize the infer client and server for both the client-side and server-side trainers, as demonstrated in the preceding sections.\n",
"\n",
"Below is an FedCoT example. We ran this example on a machine equipped with 4 V100-32G GPUs. We launch the client script using deepspeed. LLM is depolyed on another machine."
]
},
{
"cell_type": "markdown",
"id": "c559341a-d133-4a24-8f1a-35cd6d2a26d3",
"metadata": {},
"source": [
"## FedCoT Example\n",
"\n",
"### Client Script(deepspeed_run.py)\n",
"\n",
"This script show how to setup a fedcot task on the client side."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4710fda-904a-4e90-bc65-beec7594703f",
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"import os\n",
"import sys\n",
"from transformers import (\n",
" AutoTokenizer,\n",
" HfArgumentParser,\n",
" Seq2SeqTrainingArguments,\n",
")\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from typing import List\n",
"from fate_llm.algo.inferdpt.utils import InferDPTKit\n",
"from fate_llm.dataset.fedcot_dataset import PrefixDataset\n",
"from fate_llm.algo.fedcot.fedcot_trainer import FedCoTTrainerClient\n",
"from fate_llm.data.data_collator.fedcot_collator import PrefixDataCollator\n",
"from fate_llm.algo.inferdpt import inferdpt\n",
"\n",
"\n",
"arbiter = (\"arbiter\", 10000)\n",
"guest = (\"guest\", 10000)\n",
"host = (\"host\", 9999)\n",
"name = \"fed1\"\n",
"\n",
"\n",
"def create_ctx(local):\n",
" from fate.arch import Context\n",
" from fate.arch.computing.backends.standalone import CSession\n",
" from fate.arch.federation.backends.standalone import StandaloneFederation\n",
" import logging\n",
"\n",
" logger = logging.getLogger()\n",
" logger.setLevel(logging.INFO)\n",
"\n",
" console_handler = logging.StreamHandler()\n",
" console_handler.setLevel(logging.INFO)\n",
"\n",
" formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n",
" console_handler.setFormatter(formatter)\n",
"\n",
" logger.addHandler(console_handler)\n",
" computing = CSession(data_dir=\"./session_dir\")\n",
" return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n",
"\n",
"\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"\n",
"doc_template = \"\"\"{{question}} \n",
"Choices:{{choices.text}}\n",
"\"\"\"\n",
"\n",
"instruction_template=\"\"\"\n",
"<s>[INST]\n",
"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n",
"Use <end> to finish your rationle.\"\n",
"\n",
"Example(s):\n",
"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n",
"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n",
"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.<end>\n",
"\n",
"Please explain:\n",
"Question:{{perturbed_doc}}\n",
"Rationale:\n",
"[/INST]\n",
"\"\"\"\n",
"\n",
"decode_template = \"\"\"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n",
"Use <end> to finish your rationle.\"\n",
"\n",
"Example(s):\n",
"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n",
"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n",
"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.<end>\n",
"\n",
"Question:{{perturbed_doc}}\n",
"Rationale:{{perturbed_response | replace('\\n', '')}}<end>\n",
"\n",
"Please explain:\n",
"Question:{{question}} \n",
"Choices:{{choices.text}}\n",
"Rationale:\n",
"\"\"\"\n",
" \n",
"\n",
"if __name__ == \"__main__\":\n",
" \n",
" parser = HfArgumentParser(Seq2SeqTrainingArguments)\n",
" if len(sys.argv) == 2 and sys.argv[1].endswith(\".json\"):\n",
" training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]\n",
" else:\n",
" training_args = parser.parse_args_into_dataclasses()[0]\n",
"\n",
" model_path = '/data/cephfs/llm/models/Qwen1.5-0.5B/'\n",
" pds = PrefixDataset(\n",
" tokenizer_path=model_path,\n",
" predict_input_template=\"\"\"Predict:\n",
"Question:{{question}}\n",
"Choices:{{choices.text}}\n",
"Answer:\n",
" \"\"\",\n",
" predict_output_template=\"\"\"{{choices.text[choices.label.index(answerKey)]}}<end>\"\"\",\n",
" rationale_input_template=\"\"\"Explain:\n",
"Question:{{question}}\n",
"Choices:{{choices.text}}\n",
"Rationale:\n",
" \"\"\",\n",
" rationale_output_template=\"\"\"{{infer_result}}<end>\n",
" \"\"\",\n",
" max_input_length=128,\n",
" max_target_length=128,\n",
" split_key='train'\n",
" )\n",
" pds.load('/data/cephfs/llm/datasets/arce/')\n",
" \n",
" model = AutoModelForCausalLM.from_pretrained(model_path).half().cuda()\n",
" tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
" model.gradient_checkpointing_enable()\n",
" model.enable_input_require_grads()\n",
"\n",
" ctx = create_ctx(guest)\n",
" if training_args.local_rank == 0:\n",
" # only rank 0 need to load infer instance\n",
" save_kit_path = 'your path'\n",
" kit = InferDPTKit.load_from_path(save_kit_path)\n",
" # local deployed small model as decoding model\n",
" from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n",
" inference = APICompletionInference(api_url=\"http://xxxx/v1\", model_name='./Qwen1.5-0.5B', api_key='EMPTY')\n",
" client = inferdpt.InferDPTClient(ctx, kit, inference, epsilon=3.0)\n",
" else:\n",
" client = None\n",
" \n",
" trainer = FedCoTTrainerClient(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" tokenizer=tokenizer, \n",
" train_set=pds,\n",
" data_collator=PrefixDataCollator(tokenizer),\n",
" mode='infer_and_train',\n",
" infer_client=client,\n",
" encode_template=doc_template,\n",
" decode_template=decode_template,\n",
" instruction_template=instruction_template,\n",
" remote_inference_kwargs={\n",
" 'stop': ['<\\s>'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
" },\n",
" local_inference_kwargs={\n",
" 'stop': ['<|im_end|>', '<end>', '<end>\\n', '<end>\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
" }\n",
" )\n",
"\n",
" trainer.train()\n",
"\n",
" if training_args.local_rank == 0:\n",
" model.save_pretrained(training_args.output_dir)\n",
" tokenizer.save_pretrained(training_args.output_dir)"
]
},
{
"cell_type": "markdown",
"id": "962dd399-1dec-4164-bd86-15aa8550c50b",
"metadata": {},
"source": [
"### Server Script(server.py)\n",
"\n",
"This script shows how to set up a FedCoT task on the server side."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91b42972-5308-4ccf-a768-f7dfa087313e",
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.algo.inferdpt.inferdpt import InferDPTServer\n",
"from fate_llm.algo.fedcot.fedcot_trainer import FedCoTTraineServer\n",
"import sys\n",
"\n",
"\n",
"arbiter = (\"arbiter\", 10000)\n",
"guest = (\"guest\", 10000)\n",
"host = (\"host\", 9999)\n",
"name = \"fed1\"\n",
"\n",
"\n",
"def create_ctx(local):\n",
" from fate.arch import Context\n",
" from fate.arch.computing.backends.standalone import CSession\n",
" from fate.arch.federation.backends.standalone import StandaloneFederation\n",
" import logging\n",
"\n",
" logger = logging.getLogger()\n",
" logger.setLevel(logging.INFO)\n",
"\n",
" console_handler = logging.StreamHandler()\n",
" console_handler.setLevel(logging.INFO)\n",
"\n",
" formatter = logging.Formatter(\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\")\n",
" console_handler.setFormatter(formatter)\n",
"\n",
" logger.addHandler(console_handler)\n",
" computing = CSession(data_dir=\"./session_dir\")\n",
" return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\n",
"\n",
"\n",
"from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n",
"api = APICompletionInference(api_url='http://xxxx:8080/v1', api_key='EMPTY', model_name='/data/cephfs/llm/models/Qwen1.5-14B-Chat')\n",
"\n",
"ctx = create_ctx(arbiter)\n",
"server_api = InferDPTServer(ctx, api)\n",
"server = FedCoTTraineServer(ctx, server_api)\n",
"server.train()"
]
},
{
"cell_type": "markdown",
"id": "125dd68e-c7d4-41aa-9972-4881b1330fb6",
"metadata": {},
"source": [
"### Start script\n",
"\n",
"You can launch client-side training with the following script:\n",
"\n",
"```\n",
"deepspeed --num_nodes 1 --num_gpus 4 deepspeed_run.py \\\n",
" --output_dir \"./\" \\\n",
" --per_device_train_batch_size \"1\" \\\n",
" --gradient_accumulation_steps \"8\" \\\n",
" --max_steps \"750\" \\\n",
" --fp16 \\\n",
" --logging_steps 10 \\\n",
" --save_only_model \\\n",
" --deepspeed \"./ds_config.json\" \n",
"```"
]
},
{
"cell_type": "markdown",
"id": "0b506c1c-51f4-448d-9b0b-adf1a71cc7cf",
"metadata": {},
"source": [
"and the ds_config.json is\n",
"```\n",
"{ \n",
" \"train_micro_batch_size_per_gpu\": 1,\n",
" \"gradient_accumulation_steps\": 8,\n",
" \"optimizer\": {\n",
" \"type\": \"AdamW\",\n",
" \"params\": {\n",
" \"lr\": 5e-5\n",
" }\n",
" },\n",
" \"fp16\": {\n",
" \"enabled\": true\n",
" },\n",
" \"zero_optimization\": {\n",
" \"stage\": 0\n",
" }\n",
"}\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "613fbfb6-ac9e-485b-8587-ffef1e2361c1",
"metadata": {},
"source": [
"And server side:"
]
},
{
"cell_type": "markdown",
"id": "5b50adf0-8f9c-40e5-9a7d-40a70e30a420",
"metadata": {},
"source": [
"```python server.py```"
]
},
{
"cell_type": "markdown",
"id": "28a5de71-25fd-4042-a6b7-0ec2c505eaee",
"metadata": {},
"source": [
"## FedCoT Pipeline Example\n",
"\n",
"You can also submit a FedCoT task through the FATE pipeline. With the necessary settings configured appropriately, FedCoT can be run in a production environment."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52f1e19b-da8e-4977-adb1-42fb84dee407",
"metadata": {},
"outputs": [],
"source": [
"from fate_client.pipeline.components.fate.nn.loader import Loader\n",
"import argparse\n",
"from fate_client.pipeline.utils import test_utils\n",
"from fate_client.pipeline.components.fate.reader import Reader\n",
"from fate_client.pipeline import FateFlowPipeline\n",
"\n",
"\n",
"def main(config=\"../../config.yaml\", namespace=\"\"):\n",
" # obtain config\n",
" if isinstance(config, str):\n",
" config = test_utils.load_job_config(config)\n",
" parties = config.parties\n",
" guest = '9999'\n",
" host = parties.host[0]\n",
" arbiter = '10000'\n",
"\n",
" pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\n",
"\n",
" reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest))\n",
" reader_0.guest.task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"arc_e_example\"\n",
" )\n",
"\n",
" model_conf = Loader(module_name='fate_llm.model_zoo.hf_model', item_name='HFAutoModelForCausalLM', \n",
" pretrained_model_name_or_path='/data/cephfs/llm/models/Qwen1.5-0.5B/').to_dict()\n",
" data_collator_conf = Loader(module_name='fate_llm.data.data_collator.fedcot_collator', item_name='get_prefix_data_collator', tokenizer_name_or_path='/data/cephfs/llm/models/Qwen1.5-0.5B/').to_dict()\n",
"\n",
" infer_init_conf_client = {\n",
" 'module_name': 'fate_llm.algo.inferdpt.init.default_init',\n",
" 'item_name': 'InferDPTAPIClientInit'\n",
" }\n",
"\n",
" infer_init_conf_server = {\n",
" 'module_name': 'fate_llm.algo.inferdpt.init.default_init',\n",
" 'item_name': 'InferDPTAPIServerInit'\n",
" }\n",
"\n",
" dataset_conf = {\n",
" 'module_name': 'fate_llm.dataset.fedcot_dataset',\n",
" 'item_name': 'PrefixDataset',\n",
" 'kwargs':dict(\n",
" tokenizer_path='/data/cephfs/llm/models/Qwen1.5-0.5B/',\n",
" predict_input_template=\"\"\"Predict:\n",
" Question:{{question}}\n",
" Choices:{{choices.text}}\n",
" \"\"\",\n",
" predict_output_template=\"\"\"{{choices.text[choices.label.index(answerKey)]}}<end>\"\"\",\n",
" rationale_input_template=\"\"\"Explain:\n",
" Question:{{question}}\n",
" Choices:{{choices.text}}\n",
" \"\"\",\n",
" rationale_output_template=\"\"\"{{infer_result}}<end>\n",
" \"\"\",\n",
" max_input_length=128,\n",
" max_target_length=128,\n",
" split_key='train'\n",
" )\n",
" }\n",
"\n",
" encoder_prompt = \"\"\"{{question}}\n",
"Choices:{{choices.text}}\n",
"\"\"\"\n",
"\n",
" decoder_prompt = \"\"\"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.Use <end> to finish your rationle.\n",
"\n",
"Example(s):\n",
"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\n",
"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\n",
"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction. Therefore, the answer is 'dry palms'.<end>\n",
"\n",
"Question:{{perturbed_doc}}\n",
"Rationale:{{perturbed_response | replace('\\n', '')}}<end>\n",
"\n",
"Please explain:\n",
"Question:{{question}} \n",
"Choices:{{choices.text}}\n",
" \"\"\"\n",
"\n",
" instruction_prompt = \"\"\"<|im_start|>system\n",
"You are a helpful assistant<|im_end|>\n",
"<|im_start|>user\n",
"Select Answer from Choices and explain it in \"Rationale\" with few words. Please refer to the example to write the rationale.\n",
"Use <end> to finish your rationle.\n",
"\n",
"Example(s):\n",
"Question:Which factor will most likely cause a person to develop a fever?\n",
"Choices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\n",
"Rationale:A bacterial infection in the bloodstream triggers the immune system to respond, therefore often causing a fever as the body tries to fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'\n",
"\n",
"Please explain:\n",
"Question:{{perturbed_doc}}\n",
"Rationale:\n",
"<|im_end|>\n",
"<|im_start|>assistant\n",
" \"\"\"\n",
"\n",
" remote_inference_kwargs={\n",
" 'stop': ['<|im_end|>', '<end>', '<end>\\n', '<end>\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
" }\n",
"\n",
" local_inference_kwargs={\n",
" 'stop': ['<|im_end|>', '<end>', '<end>\\n', '<end>\\n\\n', '.\\n\\n\\n\\n\\n', '<|end_of_text|>', '>\\n\\n\\n'],\n",
" 'temperature': 0.01,\n",
" 'max_tokens': 256\n",
" }\n",
"\n",
" ds_config = { \n",
" \"train_micro_batch_size_per_gpu\": 1,\n",
" \"gradient_accumulation_steps\": 8,\n",
" \"optimizer\": {\n",
" \"type\": \"AdamW\",\n",
" \"params\": {\n",
" \"lr\": 5e-5\n",
" }\n",
" },\n",
" \"fp16\": {\n",
" \"enabled\": True\n",
" },\n",
" \"zero_optimization\": {\n",
" \"stage\": 0\n",
" }\n",
" }\n",
"\n",
" training_args_dict = dict(\n",
" per_device_train_batch_size=1, \n",
" gradient_accumulation_steps=8,\n",
" logging_steps=10,\n",
" max_steps=30,\n",
" fp16=True,\n",
" log_level='debug'\n",
" )\n",
"\n",
" mode = 'infer_and_train'\n",
"\n",
" client_conf = dict(\n",
" model_conf=model_conf,\n",
" dataset_conf=dataset_conf,\n",
" training_args_conf=training_args_dict,\n",
" data_collator_conf=data_collator_conf,\n",
" mode=mode,\n",
" infer_inst_init_conf=infer_init_conf_client,\n",
" encode_template=encoder_prompt,\n",
" instruction_template=instruction_prompt,\n",
" decode_template=decoder_prompt,\n",
" remote_inference_kwargs=remote_inference_kwargs,\n",
" local_inference_kwargs=local_inference_kwargs,\n",
" perturb_doc_key='perturbed_doc',\n",
" perturbed_response_key='perturbed_response',\n",
" result_key='infer_result'\n",
" )\n",
"\n",
" server_conf = dict(\n",
" infer_inst_init_conf=infer_init_conf_server,\n",
" mode=mode\n",
" )\n",
"\n",
" homo_nn_0 = HomoNN(\n",
" 'nn_0',\n",
" train_data=reader_0.outputs[\"output_data\"],\n",
" runner_module=\"fedcot_runner\",\n",
" runner_class=\"FedCoTRunner\"\n",
" )\n",
"\n",
" homo_nn_0.guest.task_parameters(runner_conf=client_conf)\n",
" homo_nn_0.arbiter.task_parameters(runner_conf=server_conf)\n",
"\n",
" homo_nn_0.guest.conf.set(\"launcher_name\", \"deepspeed\")\n",
"\n",
" pipeline.add_tasks([reader_0, homo_nn_0])\n",
" pipeline.conf.set(\"task\", dict(engine_run={\"cores\": 4}))\n",
" pipeline.compile()\n",
" pipeline.fit()\n",
"\n",
"if __name__ == \"__main__\":\n",
" parser = argparse.ArgumentParser(\"PIPELINE DEMO\")\n",
" parser.add_argument(\"--config\", type=str, default=\"../config.yaml\",\n",
" help=\"config file\")\n",
" parser.add_argument(\"--namespace\", type=str, default=\"\",\n",
" help=\"namespace for data stored in FATE\")\n",
" args = parser.parse_args()\n",
" main(config=args.config, namespace=args.namespace)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: doc/tutorial/fedkseed/README.md
================================================
## FedKSeed
The algorithm is based on the paper [Federated Full-Parameter Tuning of Billion-Sized Language Models
with Communication Cost under 18 Kilobytes](https://arxiv.org/pdf/2312.06353.pdf), and the code is adapted
from https://github.com/alibaba/FederatedScope/tree/FedKSeed.
We refactored the code to make it more compatible with the transformers/PyTorch framework
and integrated it into the FATE-LLM framework.
The main contributions include:
1. A KSeedZerothOrderOptimizer class that can be used to optimize a model along a given direction generated from a random seed.
2. A KSeedZOExtendedTrainer subclass of Trainer from transformers that can be used to train large language models with KSeedZerothOrderOptimizer.
3. Trainers for federated learning with large language models.
================================================
FILE: doc/tutorial/fedkseed/fedkseed-example.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Federated Tuning with FedKSeed methods in FATE-LLM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we will demonstrate how to efficiently train federated large language models using the FATE-LLM framework. In FATE-LLM, we introduce the \"FedKSeed\" module, specifically designed for federated learning with large language models. The idea of FedKSeed is to use a zeroth-order optimizer to optimize the model along a given direction that is generated from a random seed. This method can be used to train large language models in a federated learning setting with extremely low communication cost.\n",
"\n",
"The algorithm is based on the paper [Federated Full-Parameter Tuning of Billion-Sized Language Models\n",
"with Communication Cost under 18 Kilobytes](https://arxiv.org/pdf/2312.06353.pdf), and the code is modified from https://github.com/alibaba/FederatedScope/tree/FedKSeed. We refactored the code to make it more compatible with the transformers/PyTorch framework and integrated it into the FATE-LLM framework.\n",
"\n",
"The main contributions include:\n",
"1. A KSeedZerothOrderOptimizer class that can be used to optimize a model along a given direction generated from a random seed.\n",
"2. A KSeedZOExtendedTrainer subclass of Trainer from transformers that can be used to train large language models with KSeedZerothOrderOptimizer.\n",
"3. Trainers for federated learning with large language models.\n",
"\n",
"In this tutorial, we will demonstrate how to use the FedKSeed method to train a large language model in a federated learning setting. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model: datajuicer/LLaMA-1B-dj-refine-150B\n",
"\n",
"This is the introduction from the Huggingface model hub: [datajuicer/LLaMA-1B-dj-refine-150B](https://huggingface.co/datajuicer/LLaMA-1B-dj-refine-150B)\n",
"\n",
"> The model architecture is LLaMA-1.3B and we adopt the OpenLLaMA implementation. The model is pre-trained on 150B tokens of Data-Juicer's refined RedPajama and Pile. It achieves an average score of 34.21 over 16 HELM tasks, beating Falcon-1.3B (trained on 350B tokens from RefinedWeb), Pythia-1.4B (trained on 300B tokens from original Pile) and Open-LLaMA-1.3B (trained on 150B tokens from original RedPajama and Pile).\n",
"\n",
"> For more details, please refer to our [paper](https://arxiv.org/abs/2309.02033).\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-29T09:27:23.512735Z",
"start_time": "2024-02-29T09:27:23.508790Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"# model_name_or_path = \"datajuicer/LLaMA-1B-dj-refine-150B\"\n",
"model_name_or_path = \"gpt2\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset: databricks/databricks-dolly-15k\n",
"\n",
"This is the introduction from the Huggingface dataset hub: [databricks/databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k)\n",
"\n",
"> databricks-dolly-15k is a corpus of more than 15,000 records generated by thousands of Databricks employees to enable large language models to exhibit the magical interactivity of ChatGPT. Databricks employees were invited to create prompt / response pairs in each of eight different instruction categories, including the seven outlined in the InstructGPT paper, as well as an open-ended free-form category. The contributors were instructed to avoid using information from any source on the web with the exception of Wikipedia (for particular subsets of instruction categories), and explicitly instructed to avoid using generative AI in formulating instructions or responses. Examples of each behavior were provided to motivate the types of questions and instructions appropriate to each category\n",
"\n",
"To use this dataset, you first need to download it from the Huggingface dataset hub:\n",
"\n",
"```bash\n",
"mkdir -p ../../../examples/data/dolly && cd ../../../examples/data/dolly && wget https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl\\?download\\=true -O databricks-dolly-15k.jsonl\n",
"```\n",
"\n",
"### Check Dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-29T09:27:26.987779Z",
"start_time": "2024-02-29T09:27:24.706218Z"
}
},
"outputs": [],
"source": [
"from fate_llm.dataset.hf_dataset import Dolly15K\n",
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path)\n",
"special_tokens = tokenizer.special_tokens_map\n",
"if \"pad_token\" not in tokenizer.special_tokens_map:\n",
" special_tokens[\"pad_token\"] = special_tokens[\"eos_token\"]\n",
"\n",
"tokenizer.pad_token = tokenizer.eos_token\n",
"ds = Dolly15K(split=\"train\", tokenizer_params={\"pretrained_model_name_or_path\": model_name_or_path, **special_tokens},\n",
" tokenizer_apply_params=dict(truncation=True, max_length=tokenizer.model_max_length, padding=\"max_length\", return_tensors=\"pt\"))\n",
"ds = ds.load('../../../examples/data/dolly')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-29T09:27:27.875025Z",
"start_time": "2024-02-29T09:27:27.867839Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['instruction', 'context', 'response', 'category', 'text', 'input_ids', 'attention_mask'],\n",
" num_rows: 15011\n",
"})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ds"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For more details on FATE-LLM dataset settings, we recommend that you read through these tutorials first: [NN Dataset Customization](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Homo-NN-Customize-your-Dataset.ipynb) and [Some Built-In Dataset](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Introduce-Built-In-Dataset.ipynb)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check local training\n",
"\n",
"Before submitting a federated learning task, we will demonstrate how to perform local testing to ensure the proper functionality of your custom dataset and model."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-29T09:38:33.175079Z",
"start_time": "2024-02-29T09:38:33.168844Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"from transformers import AutoModelForCausalLM, TrainingArguments, DataCollatorForLanguageModeling\n",
"from fate_llm.algo.fedkseed.trainer import KSeedZOExtendedTrainer, KSeedTrainingArguments\n",
"from fate_llm.algo.fedkseed.zo_utils import build_seed_candidates, get_even_seed_probabilities\n",
"\n",
"def test_training(zo_mode=True):\n",
" tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **special_tokens)\n",
" data_collector = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n",
" model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_name_or_path)\n",
"\n",
" training_args = TrainingArguments(output_dir='./',\n",
" dataloader_num_workers=1,\n",
" dataloader_prefetch_factor=1,\n",
" remove_unused_columns=True,\n",
" learning_rate=1e-5,\n",
" per_device_train_batch_size=1,\n",
" num_train_epochs=0.01,\n",
" )\n",
" kseed_args = KSeedTrainingArguments(zo_optim=zo_mode)\n",
" trainer = KSeedZOExtendedTrainer(model=model, train_dataset=ds, training_args=training_args, kseed_args=kseed_args,\n",
" tokenizer=tokenizer, data_collator=data_collector)\n",
" if zo_mode:\n",
" seed_candidates = build_seed_candidates(k=kseed_args.k)\n",
" seed_probabilities = get_even_seed_probabilities(k=kseed_args.k)\n",
" trainer.configure_seed_candidates(seed_candidates, seed_probabilities)\n",
" return trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-29T09:39:37.602070Z",
"start_time": "2024-02-29T09:38:34.024223Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='151' max='151' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [151/151 00:59, Epoch 0/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=151, training_loss=1.2660519429390005, metrics={'train_runtime': 61.8249, 'train_samples_per_second': 2.428, 'train_steps_per_second': 2.442, 'total_flos': 78910193664000.0, 'train_loss': 1.2660519429390005, 'epoch': 0.01})"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_training(zo_mode=True)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-29T09:41:28.949449Z",
"start_time": "2024-02-29T09:39:54.802705Z"
},
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='151' max='151' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [151/151 01:29, Epoch 0/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=151, training_loss=0.6093456950408733, metrics={'train_runtime': 92.6158, 'train_samples_per_second': 1.621, 'train_steps_per_second': 1.63, 'total_flos': 78910193664000.0, 'train_loss': 0.6093456950408733, 'epoch': 0.01})"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_training(zo_mode=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"You can see that the zeroth-order optimizer performs much worse than AdamW (a training loss of about 1.27 versus 0.61 in the two runs above); that is the price we pay for the low communication cost."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Submit Federated Task\n",
"Once you have successfully completed local testing, you can submit a task to FATE. **Please note that this tutorial runs on a standalone version; if you are using a cluster version, you need to bind the data with the corresponding name & namespace on each machine.**\n",
"\n",
"In this example we load pretrained weights for gpt2 model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"from fate_client.pipeline.components.fate.reader import Reader\n",
"from fate_client.pipeline import FateFlowPipeline\n",
"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_seq2seq_runner\n",
"from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments\n",
"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\n",
"\n",
"guest = '10000'\n",
"host = '10000'\n",
"arbiter = '10000'\n",
"\n",
"epochs = 0.01\n",
"batch_size = 1\n",
"lr = 1e-5\n",
"\n",
"pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\n",
"pipeline.bind_local_path(path=\"/data/projects/fate/examples/data/dolly\", namespace=\"experiment\",\n",
" name=\"dolly\")\n",
"time.sleep(5)\n",
"\n",
"reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest, host=host))\n",
"reader_0.guest.task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"dolly\"\n",
")\n",
"reader_0.hosts[0].task_parameters(\n",
" namespace=\"experiment\",\n",
" name=\"dolly\"\n",
")\n",
"\n",
"tokenizer_params = dict(\n",
" pretrained_model_name_or_path=\"gpt2\",\n",
" trust_remote_code=True,\n",
")\n",
"conf = get_config_of_seq2seq_runner(\n",
" algo='fedkseed',\n",
" model=LLMModelLoader(\n",
" \"hf_model\",\n",
" \"HFAutoModelForCausalLM\",\n",
" # pretrained_model_name_or_path=\"datajuicer/LLaMA-1B-dj-refine-150B\",\n",
" pretrained_model_name_or_path=\"gpt2\",\n",
" trust_remote_code=True\n",
" ),\n",
" dataset=LLMDatasetLoader(\n",
" \"hf_dataset\",\n",
" \"Dolly15K\",\n",
" split=\"train\",\n",
" tokenizer_params=tokenizer_params,\n",
" tokenizer_apply_params=dict(\n",
" truncation=True,\n",
" max_length=1024,\n",
" )),\n",
" data_collator=LLMDataFuncLoader(\n",
" \"cust_func.cust_data_collator\",\n",
" \"get_seq2seq_tokenizer\",\n",
" tokenizer_params=tokenizer_params,\n",
" ),\n",
" training_args=TrainingArguments(\n",
" num_train_epochs=0.01,\n",
" per_device_train_batch_size=batch_size,\n",
" remove_unused_columns=True,\n",
" learning_rate=lr,\n",
" fp16=False,\n",
" use_cpu=False,\n",
" disable_tqdm=False,\n",
" ),\n",
" fed_args=FedAVGArguments(),\n",
" task_type='causal_lm',\n",
" save_trainable_weights_only=True,\n",
")\n",
"\n",
"conf[\"fed_args_conf\"] = {}\n",
"\n",
"homo_nn_0 = HomoNN(\n",
" 'nn_0',\n",
" runner_conf=conf,\n",
" train_data=reader_0.outputs[\"output_data\"],\n",
" runner_module=\"fedkseed_runner\",\n",
" runner_class=\"FedKSeedRunner\",\n",
")\n",
"\n",
"pipeline.add_tasks([reader_0, homo_nn_0])\n",
"pipeline.conf.set(\"task\", dict(engine_run={\"cores\": 1}))\n",
"\n",
"pipeline.compile()\n",
"pipeline.fit()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can use this script to submit the task, but training takes a long time and generates a long log, so we won't run it here."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: doc/tutorial/fedmkt/README.md
================================================
# FATE-LLM: FedMKT
The algorithm is based on the paper ["FedMKT: Federated Mutual Knowledge Transfer for Large and Small Language Models"](https://aclanthology.org/2025.coling-main.17.pdf). We integrate its code into the FATE-LLM framework.
## Citation
If you publish work that uses FedMKT, please cite FedMKT as follows:
```
@inproceedings{fan2025fedmkt,
title={Fedmkt: Federated mutual knowledge transfer for large and small language models},
author={Fan, Tao and Ma, Guoqiang and Kang, Yan and Gu, Hanlin and Song, Yuanfeng and Fan, Lixin and Chen, Kai and Yang, Qiang},
booktitle={Proceedings of the 31st International Conference on Computational Linguistics},
pages={243--255},
year={2025}
}
```
================================================
FILE: doc/tutorial/fedmkt/fedmkt.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Federated Tuning With FedMKT methods in FATE-LLM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we will demonstrate how to efficiently train federated large language models using the FATE-LLM framework. In FATE-LLM, we introduce the \"FedMKT\" module, specifically designed for federated learning with large language models. FedMKT introduces a novel\n",
"federated mutual knowledge transfer framework that enables effective knowledge transfer between a large language model (LLM) deployed on the server and small language models (SLMs) residing on clients.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The algorithm is based on the paper [\"FedMKT: Federated Mutual Knowledge Transfer for Large and Small Language Models\"](https://arxiv.org/pdf/2406.02224). We integrate its code into the FATE-LLM framework. \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Experiments\n",
"\n",
"Chapter List: \n",
"* Settings\n",
"    1. Dataset: ARC-Challenge\n",
"    2. Models Used in the \"FEDMKT\" Paper\n",
"    3. Prepare Optimal Vocabulary Mapping Tables\n",
"    4. Training LLMs with Lora\n",
"* Experiment examples:\n",
"    1. Running FEDMKT With Launcher (Experimental Use): 4-SLMs\n",
"    2. Running FEDMKT With Launcher (Experimental Use): 1-SLM (One To One)\n",
"    3. Running FEDMKT With Launcher (Experimental Use): 1-SLM And SLM Trains Only (LLM2SLM)\n",
"    4. Running FEDMKT With Launcher (Experimental Use): 4-SLMs Homogeneous SFT\n",
"    5. Running FEDMKT with Pipeline (Industrial Use)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dataset: ARC-Challenge\n",
"\n",
"ARC-Challenge is a dataset of 7,787 genuine grade-school level, multiple-choice science questions, assembled to encourage research in advanced question-answering. \n",
"\n",
"You can refer to the following link for more details about [ARC-Challenge](https://huggingface.co/datasets/allenai/ai2_arc).\n",
"\n",
"In this section, we will download the ARC-Challenge dataset from Hugging Face and split its training set into five parts: the \"common\" part serves as the public dataset, and the other four parts are the private training data of the SLMs (opt, gpt2, llama, bloom). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import datasets\n",
"\n",
"\n",
"# Download ARC-Challenge and take its training split out of the dataset dict;\n",
"# the client/common parts created below are added back into `data`.\n",
"data = datasets.load_dataset(\"ai2_arc\", \"ARC-Challenge\", download_mode=\"force_redownload\", ignore_verifications=True)\n",
"train_data = data.pop(\"train\")\n",
"\n",
"seed = 123\n",
"n = train_data.shape[0]\n",
"client_num = 4\n",
"process_data_output_dir = \"\" # directory to save the processed data in; it must be specified and is reused in later cells.\n",
"\n",
"# One equally sized part per client plus one public (\"common\") part.\n",
"client_data_num = n // (client_num + 1)\n",
"\n",
"# Carve `client_data_num` samples off the remaining data for each client in turn.\n",
"for i in range(client_num):\n",
"    splits = train_data.train_test_split(train_size=client_data_num, shuffle=True, seed=seed)\n",
"    client_name = f\"client_{i}\"\n",
"    data[client_name] = splits[\"train\"]\n",
"    train_data = splits[\"test\"]\n",
"\n",
"# Whatever is left becomes the public dataset, trimmed to the same size as a client part.\n",
"if train_data.shape[0] == client_data_num:\n",
"    data[\"common\"] = train_data\n",
"else:\n",
"    # Use the notebook-level `seed`; there is no `args` object in this notebook.\n",
"    data[\"common\"] = train_data.train_test_split(\n",
"        train_size=client_data_num, shuffle=True, seed=seed\n",
"    )[\"train\"]\n",
"\n",
"data.save_to_disk(process_data_output_dir)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Models Used In the \"FEDMKT\" Paper\n",
"\n",
"LLM: [Llama-2-7B](https://huggingface.co/meta-llama/Llama-2-7b-hf) \n",
"SLM-0: [opt-1.3b](https://huggingface.co/facebook/opt-1.3b) \n",
"SLM-1: [gpt2-xlarge](https://huggingface.co/openai-community/gpt2-xl) \n",
"SLM-2: [Llama-1.3b](https://huggingface.co/princeton-nlp/Sheared-LLaMA-1.3B) \n",
"SLM-3: [bloom-1.1B](https://huggingface.co/bigscience/bloom-1b1)\n",
"\n",
"Users should download these models from Hugging Face and save them to local directories before running the following steps: the models are large, and re-downloading them every time would take too long.\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# replace the model names below with the local directories where the models are saved\n",
"llm_pretrained_path = \"llama-2-7b-hf\"\n",
"slm_0_pretrained_path = \"opt-1.3b\"\n",
"slm_1_pretrained_path = \"gpt2-xl\"\n",
"slm_2_pretrained_path = \"Sheared-LLaMA-1.3B\"\n",
"slm_3_pretrained_path = \"bloom-1b1\"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare Optimal Vocabulary Mapping Tables\n",
"\n",
"To use \"FEDMKT\" for federated knowledge transfer, we need to build optimal vocabulary mapping tables first.\n",
"The \"FEDMKT\" paper uses one LLM and four SLMs, so we need to build eight optimal vocabulary mapping tables: for each (LLM, SLM) pair, two tables (SLM-to-LLM and LLM-to-SLM) must be built, because co-training is needed.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fate_llm.algo.fedmkt.token_alignment.vocab_mapping import get_vocab_mappings\n",
"\n",
"\n",
"# One (LLM, SLM) pair per small model.\n",
"llm_slm_pairs = [\n",
"    (llm_pretrained_path, slm_0_pretrained_path),\n",
"    (llm_pretrained_path, slm_1_pretrained_path),\n",
"    (llm_pretrained_path, slm_2_pretrained_path),\n",
"    (llm_pretrained_path, slm_3_pretrained_path)\n",
"]\n",
"\n",
"vocab_mapping_directory = \"\" # replace this with the actual directory for the mapping tables\n",
"\n",
"slm_to_llm_vocab_mapping_paths = [\"opt_to_llama.json\", \"gpt2_to_llama.json\", \"llama_small_to_llama.json\", \"bloom_to_llama.json\"]\n",
"llm_to_slm_vocab_mapping_paths = [\"llama_to_opt.json\", \"llama_to_gpt2.json\", \"llama_to_llama_small\", \"llama_to_bloom.json\"]\n",
"\n",
"# Prefix every file name with the output directory.\n",
"for idx in range(4):\n",
"    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + slm_to_llm_vocab_mapping_paths[idx]\n",
"    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + llm_to_slm_vocab_mapping_paths[idx]\n",
"\n",
"# Build both tables for every pair: one saved to the SLM-to-LLM path, one to the LLM-to-SLM path.\n",
"for idx, (llm_pretrained, slm_pretrained) in enumerate(llm_slm_pairs):\n",
"    slm_to_llm_vocab_mapping_path = slm_to_llm_vocab_mapping_paths[idx]\n",
"    llm_to_slm_vocab_mapping_path = llm_to_slm_vocab_mapping_paths[idx]\n",
"    _ = get_vocab_mappings(slm_pretrained, llm_pretrained, slm_to_llm_vocab_mapping_path, num_processors=16)\n",
"    _ = get_vocab_mappings(llm_pretrained, slm_pretrained, llm_to_slm_vocab_mapping_path, num_processors=16)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training LLMs with Lora\n",
"\n",
"In this section, we will introduce the LoRA configs used for the five models listed in the paper: one LLM (Llama-2-7B) and four SLMs (opt-1.3B, gpt2-xlarge, Llama-1.3B, bloom-1.1B).\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The LLM models with PEFT support are located in fate_llm/model_zoo; the following sections show how to use them. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Init LLM Llama-2-7B's Lora Config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Imported here so this cell runs on its own; no earlier cell imports peft.\n",
"from peft import LoraConfig, TaskType\n",
"\n",
"# LoRA settings for the LLM: adapters are attached to the q/k/v/o projection modules.\n",
"lora_config = LoraConfig(\n",
"    task_type=TaskType.CAUSAL_LM,\n",
"    inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
"    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Init SLMs Lora Config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Imported here so this cell runs on its own; no earlier cell imports peft.\n",
"from peft import LoraConfig, TaskType\n",
"\n",
"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\n",
"# LoRA target modules per SLM, in the same order as `slm_pretrained_paths`.\n",
"slm_lora_target_modules = [\n",
"    [\"q_proj\", \"v_proj\"],\n",
"    [\"c_attn\"],\n",
"    ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
"    [\"query_key_value\"]\n",
"]\n",
"\n",
"def get_slm_conf(slm_idx):\n",
"    \"\"\"Return (pretrained_path, lora_config) for the SLM at index `slm_idx`.\"\"\"\n",
"    slm_pretrained_path = slm_pretrained_paths[slm_idx]\n",
"    lora_config = LoraConfig(\n",
"        task_type=TaskType.CAUSAL_LM,\n",
"        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
"        target_modules=slm_lora_target_modules[slm_idx]\n",
"    )\n",
"    # Previously the function built these values and implicitly returned None.\n",
"    return slm_pretrained_path, lora_config"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Running FEDMKT With Launcher (Experimental Use): 4-SLMs\n",
"\n",
"Starting the job with the launcher is mainly intended for experiments. Before running this section, make sure that [FATE-LLM Standalone](https://github.com/FederatedAI/FATE-LLM?tab=readme-ov-file#standalone-deployment) has been deployed."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Global Settings"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"process_data_output_dir = \"\"\n",
"llm_pretrained_path = \"Llama-2-7b-hf\"\n",
"slm_0_pretrained_path = \"opt-1.3b\"\n",
"slm_1_pretrained_path = \"gpt2-xl\"\n",
"slm_2_pretrained_path = \"Sheared-LLaMa-1.3B\"\n",
"slm_3_pretrained_path = \"bloom-1b1\"\n",
"llm_slm_pairs = [\n",
" (llm_pretrained_path, slm_0_pretrained_path),\n",
" (llm_pretrained_path, slm_1_pretrained_path),\n",
" (llm_pretrained_path, slm_2_pretrained_path),\n",
" (llm_pretrained_path, slm_3_pretrained_path)\n",
"]\n",
"\n",
"vocab_mapping_directory = \"\"\n",
"\n",
"slm_to_llm_vocab_mapping_paths = [\"opt_to_llama.json\", \"gpt2_to_llama.json\", \"llama_small_to_llama.json\", \"bloom_to_llama.json\"]\n",
"llm_to_slm_vocab_mapping_paths = [\"llama_to_opt.json\", \"llama_to_gpt2.json\", \"llama_to_llama_small\", \"llama_to_bloom.json\"]\n",
"\n",
"for idx in range(4):\n",
" slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + slm_to_llm_vocab_mapping_paths[idx]\n",
" llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + llm_to_slm_vocab_mapping_paths[idx]\n",
"\n",
"#### all variables has been defined above\n",
"\n",
"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\n",
"slm_lora_target_modules = [\n",
" [\"q_proj\", \"v_proj\"],\n",
" [\"c_attn\"],\n",
" ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
" [\"query_key_value\"]\n",
"]\n",
"\n",
"global_epochs = 1\n",
"batch_size=4\n",
"llm_lr = 3e-5\n",
"slm_lrs = [3e-5, 3e-4, 3e-5, 3e-5, 3e-5]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Init FEDMKTLLM Runner"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, we will introduce how to initialize the \"FEDMKTLLM\" object."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Step1: Initialize LLM With LoraConfig"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from peft import LoraConfig, TaskType\n",
"from fate_llm.model_zoo.pellm.llama import LLaMa\n",
"from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\n",
"from fate.ml.nn.homo.fedavg import FedAVGArguments\n",
"from fate_llm.dataset.qa_dataset import QaDataset\n",
"from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
"from transformers import AutoConfig\n",
"\n",
"lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\n",
")\n",
"\n",
"model = LLaMa(\n",
" pretrained_path=llm_pretrained_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\" \n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Step2: Specify Public Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
"pub_data.load(process_data_output_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Step3: Initialize FEDMKT Training Args"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=llm_lr,\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size, # pay attention to this, \n",
" # vocab_size must be specified to avoid dimension mismatch \n",
" # of tokenizer's vocab_size\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Step4: Initialize Other Variables"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# `json` is needed below and is not imported by any earlier cell.\n",
"import json\n",
"\n",
"# Aggregate once per epoch.\n",
"fed_args = FedAVGArguments(\n",
"    aggregate_strategy='epoch',\n",
"    aggregate_freq=1\n",
")\n",
"\n",
"# Load one SLM-to-LLM vocabulary mapping table per SLM, in SLM order.\n",
"slm_to_llm_vocab_mapping = []\n",
"for path in slm_to_llm_vocab_mapping_paths:\n",
"    with open(path, \"r\") as fin:\n",
"        vocab_mapping = json.loads(fin.read())\n",
"        slm_to_llm_vocab_mapping.append(vocab_mapping)\n",
"\n",
"slm_tokenizers = [get_tokenizer(slm_pretrained_path) for slm_pretrained_path in slm_pretrained_paths]\n",
"tokenizer = get_tokenizer(llm_pretrained_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Step5: New FEDMKTLLM Object"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer = FedMKTLLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" fed_args=fed_args,\n",
" train_set=pub_data,\n",
" tokenizer=tokenizer,\n",
" slm_tokenizers=slm_tokenizers,\n",
" slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\n",
" save_trainable_weights_only=True, # save lora weights only\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Step6: Training And Save Results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.train()\n",
"trainer.save_model(output_dir=\"fill the path to save llm finetuning result\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Init FEDMKTSLM Runner\n",
"\n",
"FEDMKTSLM Runner differs only slightly from FEDMKTLLM Runner, so we only introduce the variables that are different."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Import the SLMs you need to run; here we choose the four SLMs used in the original paper."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import transformers\n",
"from peft import LoraConfig, TaskType \n",
"from fate_llm.model_zoo.pellm.llama import LLaMa\n",
"from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM\n",
"from fate_llm.model_zoo.pellm.opt import OPT\n",
"from fate_llm.model_zoo.pellm.bloom import Bloom\n",
"from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\n",
"from fate_llm.dataset.qa_dataset import QaDataset\n",
"from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
"from transformers import AutoConfig\n",
"\n",
"slm_idx = 0\n",
"\n",
"slm_model_class = [\n",
" OPT,\n",
" GPT2CLM,\n",
" LLaMa,\n",
" Bloom\n",
"]\n",
" \n",
"lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=slm_lora_target_modules[slm_idx]\n",
")\n",
"\n",
"model = slm_model_class[slm_idx](\n",
" pretrained_path=slm_pretrained_paths[slm_idx],\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Specify Private Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=f\"client_{slm_idx}\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
"priv_data.load(process_data_output_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Other Variables "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\n",
"\n",
"import json\n",
"with open(llm_to_slm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### New FEDMKTSLM Object"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer = FedMKTSLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" fed_args=fed_args,\n",
" pub_train_set=pub_data,\n",
" priv_train_set=priv_data,\n",
" tokenizer=tokenizer,\n",
" save_trainable_weights_only=True, # save lora weights only\n",
" llm_tokenizer=get_tokenizer(llm_pretrained_path), # different with LLM setting\n",
" llm_to_slm_vocab_mapping=vocab_mapping, # different with LLM setting\n",
" data_collator=transformers.DataCollatorForSeq2Seq(tokenizer) # use to train private dataset\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Complete Code To Do SFT With 4 SLMs\n",
"\n",
"Please paste the code into \"fedmkt_4_slms.py\" and execute it with the following command:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```bash\n",
"python fedmkt_4_slms.py --parties guest:9999 host:9999 host:10000 host:10001 arbiter:9999 --log_level INFO\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fedmkt_4_slms.py\n",
"\n",
"import os\n",
"\n",
"from fate.arch import Context\n",
"from fate.arch.launchers.multiprocess_launcher import launch\n",
"import json\n",
"\n",
"process_data_output_dir = \"\"\n",
"llm_pretrained_path = \"Llama-2-7b-hf\"\n",
"slm_0_pretrained_path = \"opt-1.3b\"\n",
"slm_1_pretrained_path = \"gpt2-xl\"\n",
"slm_2_pretrained_path = \"Sheared-LLaMa-1.3B\"\n",
"slm_3_pretrained_path = \"bloom-1b1\"\n",
"llm_slm_pairs = [\n",
" (llm_pretrained_path, slm_0_pretrained_path),\n",
" (llm_pretrained_path, slm_1_pretrained_path),\n",
" (llm_pretrained_path, slm_2_pretrained_path),\n",
" (llm_pretrained_path, slm_3_pretrained_path)\n",
"]\n",
"\n",
"vocab_mapping_directory = \"\"\n",
"\n",
"slm_to_llm_vocab_mapping_paths = [\"opt_to_llama.json\", \"gpt2_to_llama.json\", \"llama_small_to_llama.json\", \"bloom_to_llama.json\"]\n",
"llm_to_slm_vocab_mapping_paths = [\"llama_to_opt.json\", \"llama_to_gpt2.json\", \"llama_to_llama_small\", \"llama_to_bloom.json\"]\n",
"\n",
"for idx in range(4):\n",
" slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + slm_to_llm_vocab_mapping_paths[idx]\n",
" llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + llm_to_slm_vocab_mapping_paths[idx]\n",
"\n",
"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\n",
"slm_lora_target_modules = [\n",
" [\"q_proj\", \"v_proj\"],\n",
" [\"c_attn\"],\n",
" ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
" [\"query_key_value\"]\n",
"]\n",
"\n",
"global_epochs = 5\n",
"batch_size=4\n",
"llm_lr = 3e-5\n",
"slm_lrs = [3e-5, 3e-4, 3e-5, 3e-5, 3e-5]\n",
"\n",
"llm_model_saved_directory = \"./models/fedmkt_4_slms_llm_model\"\n",
"slm_models_saved_directory = [\n",
" \"./models/fedmkt_4_slms_slm_0\", \n",
" \"./models/fedmkt_4_slms_slm_1\", \n",
" \"./models/fedmkt_4_slms_slm_2\", \n",
" \"./models/fedmkt_4_slms_slm_3\"\n",
"]\n",
"\n",
"\n",
"def train_llm(ctx):\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\n",
" )\n",
"\n",
" model = LLaMa(\n",
" pretrained_path=llm_pretrained_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=llm_lr,\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\n",
" )\n",
"\n",
" slm_to_llm_vocab_mapping = []\n",
" for path in slm_to_llm_vocab_mapping_paths:\n",
" with open(path, \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
" slm_to_llm_vocab_mapping.append(vocab_mapping)\n",
"\n",
" slm_tokenizers = [get_tokenizer(slm_pretrained_path) for slm_pretrained_path in slm_pretrained_paths]\n",
"\n",
" tokenizer = get_tokenizer(llm_pretrained_path)\n",
" trainer = FedMKTLLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" train_set=pub_data,\n",
" tokenizer=tokenizer,\n",
" slm_tokenizers=slm_tokenizers,\n",
" slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\n",
" save_trainable_weights_only=True,\n",
" )\n",
"\n",
" trainer.train()\n",
" trainer.save_model(llm_model_saved_directory)\n",
"\n",
"\n",
"def train_slm(ctx, slm_idx):\n",
" import transformers\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM\n",
" from fate_llm.model_zoo.pellm.opt import OPT\n",
" from fate_llm.model_zoo.pellm.bloom import Bloom\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" slm_model_class = [\n",
" OPT,\n",
" GPT2CLM,\n",
" LLaMa,\n",
" Bloom\n",
" ]\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=slm_lora_target_modules[slm_idx]\n",
" )\n",
"\n",
" model = slm_model_class[slm_idx](\n",
" pretrained_path=slm_pretrained_paths[slm_idx],\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=f\"client_{slm_idx}\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" priv_data.load(process_data_output_dir)\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=slm_lrs[slm_idx],\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,\n",
" )\n",
"\n",
" tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\n",
"\n",
" import json\n",
" with open(llm_to_slm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
"\n",
" trainer = FedMKTSLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" pub_train_set=pub_data,\n",
" priv_train_set=priv_data,\n",
" tokenizer=tokenizer,\n",
" save_trainable_weights_only=True,\n",
" llm_tokenizer=get_tokenizer(llm_pretrained_path),\n",
" llm_to_slm_vocab_mapping=vocab_mapping,\n",
" data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)\n",
" )\n",
"\n",
" trainer.train()\n",
" trainer.save_model(slm_models_saved_directory[slm_idx])\n",
"\n",
"\n",
"def run(ctx: Context):\n",
" if ctx.is_on_arbiter:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
" train_llm(ctx)\n",
" elif ctx.is_on_guest:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
" train_slm(ctx, slm_idx=0)\n",
" else:\n",
" if ctx.local.party[1] == \"9999\":\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
" slm_idx = 1\n",
" elif ctx.local.party[1] == \"10000\":\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
" slm_idx = 2\n",
" elif ctx.local.party[1] == \"10001\":\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"4\"\n",
" slm_idx = 3\n",
" else:\n",
" raise ValueError(f\"party_id={ctx.local.party[1]} is illegal\")\n",
"\n",
" train_slm(ctx, slm_idx=slm_idx)\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" launch(run)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Running FEDMKT With Launcher (Experimental Use): 1-SLM (One To One)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Only slight modifications to the 4-SLMs code are needed to do SFT with a single client. They are listed in the sections below; we take SLM-0 (OPT) as an example."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Only Use Single Optimal Vocabulary Mapping Tables"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"slm_idx = 0\n",
"slm_to_llm_vocab_mapping = []\n",
"with open(slm_to_llm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
" slm_to_llm_vocab_mapping.append(vocab_mapping)\n",
"\n",
"slm_tokenizers = [get_tokenizer(slm_pretrained_paths[slm_idx])]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Complete Code To Do SFT With 1 SLM\n",
"\n",
"Please paste the code into \"fedmkt_1_slm.py\" and execute it with the following command:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```bash\n",
"python fedmkt_1_slm.py --parties guest:9999 arbiter:9999 --log_level INFO\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fedmkt_1_slm.py\n",
"\n",
"import os\n",
"\n",
"from fate.arch import Context\n",
"from fate.arch.launchers.multiprocess_launcher import launch\n",
"import json\n",
"\n",
"process_data_output_dir = \"\"\n",
"llm_pretrained_path = \"Llama-2-7b-hf\"\n",
"slm_0_pretrained_path = \"opt-1.3b\"\n",
"slm_1_pretrained_path = \"gpt2-xl\"\n",
"slm_2_pretrained_path = \"Sheared-LLaMa-1.3B\"\n",
"slm_3_pretrained_path = \"bloom-1b1\"\n",
"llm_slm_pairs = [\n",
" (llm_pretrained_path, slm_0_pretrained_path),\n",
" (llm_pretrained_path, slm_1_pretrained_path),\n",
" (llm_pretrained_path, slm_2_pretrained_path),\n",
" (llm_pretrained_path, slm_3_pretrained_path)\n",
"]\n",
"\n",
"vocab_mapping_directory = \"\"\n",
"\n",
"slm_to_llm_vocab_mapping_paths = [\"opt_to_llama.json\", \"gpt2_to_llama.json\", \"llama_small_to_llama.json\", \"bloom_to_llama.json\"]\n",
"llm_to_slm_vocab_mapping_paths = [\"llama_to_opt.json\", \"llama_to_gpt2.json\", \"llama_to_llama_small\", \"llama_to_bloom.json\"]\n",
"\n",
"for idx in range(4):\n",
" slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + slm_to_llm_vocab_mapping_paths[idx]\n",
" llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + llm_to_slm_vocab_mapping_paths[idx]\n",
"\n",
"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\n",
"slm_lora_target_modules = [\n",
" [\"q_proj\", \"v_proj\"],\n",
" [\"c_attn\"],\n",
" ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
" [\"query_key_value\"]\n",
"]\n",
"\n",
"global_epochs = 5\n",
"batch_size = 4\n",
"llm_lr = 3e-5\n",
"slm_lrs = [3e-5]\n",
"\n",
"llm_model_saved_directory = \"./models/fedmkt_single_slm_llm\"\n",
"slm_models_saved_directory = [\n",
" \"./models/fedmkt_single_slm_opt\",\n",
"]\n",
"\n",
"\n",
"def train_llm(ctx, slm_idx):\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\n",
" )\n",
"\n",
" model = LLaMa(\n",
" pretrained_path=llm_pretrained_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=llm_lr,\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\n",
" )\n",
"\n",
" slm_to_llm_vocab_mapping = []\n",
" with open(slm_to_llm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
" slm_to_llm_vocab_mapping.append(vocab_mapping)\n",
"\n",
" slm_tokenizers = [get_tokenizer(slm_pretrained_paths[slm_idx])]\n",
"\n",
" tokenizer = get_tokenizer(llm_pretrained_path)\n",
" trainer = FedMKTLLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" train_set=pub_data,\n",
" tokenizer=tokenizer,\n",
" slm_tokenizers=slm_tokenizers,\n",
" slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\n",
" save_trainable_weights_only=True,\n",
" )\n",
"\n",
" trainer.train()\n",
" trainer.save_model(llm_model_saved_directory)\n",
"\n",
"\n",
"def train_slm(ctx, slm_idx):\n",
" import transformers\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM\n",
" from fate_llm.model_zoo.pellm.opt import OPT\n",
" from fate_llm.model_zoo.pellm.bloom import Bloom\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" slm_model_class = [\n",
" OPT,\n",
" GPT2CLM,\n",
" LLaMa,\n",
" Bloom\n",
" ]\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=slm_lora_target_modules[slm_idx]\n",
" )\n",
"\n",
" model = slm_model_class[slm_idx](\n",
" pretrained_path=slm_pretrained_paths[slm_idx],\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=f\"client_{slm_idx}\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" priv_data.load(process_data_output_dir)\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=slm_lrs[slm_idx],\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,\n",
" )\n",
"\n",
" tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\n",
"\n",
" import json\n",
" with open(llm_to_slm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
"\n",
" trainer = FedMKTSLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" pub_train_set=pub_data,\n",
" priv_train_set=priv_data,\n",
" tokenizer=tokenizer,\n",
" save_trainable_weights_only=True,\n",
" llm_tokenizer=get_tokenizer(llm_pretrained_path),\n",
" llm_to_slm_vocab_mapping=vocab_mapping,\n",
" data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)\n",
" )\n",
"\n",
" trainer.train()\n",
" trainer.save_model(slm_models_saved_directory[slm_idx])\n",
"\n",
"\n",
"def run(ctx: Context):\n",
" if ctx.is_on_arbiter:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
" train_llm(ctx, slm_idx=0)\n",
" else:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
" train_slm(ctx, slm_idx=0)\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" launch(run)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Running FEDMKT With Launcher (Experimential Using): 1-SLM And SLM Trains Only (LLM2SLM)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, we show how to do SFT with the FEDMKT algorithm when only a single SLM is trained and the LLM itself is not trained: the SLM distills knowledge from the LLM only, so there is no co-training."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Difference With Section \"Running FEDMKT With Launcher (Experimential Using): 1-SLMs\"\n",
"\n",
"Adding llm_training=False to the FedMKTTrainingArguments (training_args) of both the LLM and the SLM is enough!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Complete Code To DO SFT With 1 SLM And SLM Trains Only\n",
"\n",
"Please paste the code into \"fedmkt_llm_to_slm.py\" and execute it with the following command:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```bash\n",
"python fedmkt_llm_to_slm.py --parties guest:9999 arbiter:9999 --log_level INFO\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fedmkt_llm_to_slm.py\n",
"\n",
"import os\n",
"\n",
"from fate.arch import Context\n",
"from fate.arch.launchers.multiprocess_launcher import launch\n",
"import json\n",
"\n",
"process_data_output_dir = \"\"\n",
"llm_pretrained_path = \"Llama-2-7b-hf\"\n",
"slm_0_pretrained_path = \"opt-1.3b\"\n",
"slm_1_pretrained_path = \"gpt2-xl\"\n",
"slm_2_pretrained_path = \"Sheared-LLaMa-1.3B\"\n",
"slm_3_pretrained_path = \"bloom-1b1\"\n",
"llm_slm_pairs = [\n",
" (llm_pretrained_path, slm_0_pretrained_path),\n",
" (llm_pretrained_path, slm_1_pretrained_path),\n",
" (llm_pretrained_path, slm_2_pretrained_path),\n",
" (llm_pretrained_path, slm_3_pretrained_path)\n",
"]\n",
"\n",
"vocab_mapping_directory = \"\"\n",
"\n",
"slm_to_llm_vocab_mapping_paths = [\"opt_to_llama.json\", \"gpt2_to_llama.json\", \"llama_small_to_llama.json\", \"bloom_to_llama.json\"]\n",
"llm_to_slm_vocab_mapping_paths = [\"llama_to_opt.json\", \"llama_to_gpt2.json\", \"llama_to_llama_small\", \"llama_to_bloom.json\"]\n",
"\n",
"for idx in range(4):\n",
" slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + slm_to_llm_vocab_mapping_paths[idx]\n",
" llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + llm_to_slm_vocab_mapping_paths[idx]\n",
"\n",
"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\n",
"slm_lora_target_modules = [\n",
" [\"q_proj\", \"v_proj\"],\n",
" [\"c_attn\"],\n",
" ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\n",
" [\"query_key_value\"]\n",
"]\n",
"\n",
"global_epochs = 5\n",
"batch_size = 4\n",
"llm_lr = 3e-5\n",
"slm_lrs = [3e-5]\n",
"\n",
"slm_models_saved_directory = [\n",
" \"./models/fedmkt_llm_to_slm_opt\",\n",
"]\n",
"\n",
"\n",
"def train_llm(ctx, slm_idx):\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\n",
" )\n",
"\n",
" model = LLaMa(\n",
" pretrained_path=llm_pretrained_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=llm_lr,\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\n",
" llm_training=False\n",
" )\n",
"\n",
" slm_to_llm_vocab_mapping = []\n",
" with open(slm_to_llm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
" slm_to_llm_vocab_mapping.append(vocab_mapping)\n",
"\n",
" slm_tokenizers = [get_tokenizer(slm_pretrained_paths[slm_idx])]\n",
"\n",
" tokenizer = get_tokenizer(llm_pretrained_path)\n",
" trainer = FedMKTLLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" train_set=pub_data,\n",
" tokenizer=tokenizer,\n",
" slm_tokenizers=slm_tokenizers,\n",
" slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\n",
" save_trainable_weights_only=True,\n",
" )\n",
"\n",
" trainer.train()\n",
"\n",
"\n",
"def train_slm(ctx, slm_idx):\n",
" import transformers\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM\n",
" from fate_llm.model_zoo.pellm.opt import OPT\n",
" from fate_llm.model_zoo.pellm.bloom import Bloom\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" slm_model_class = [\n",
" OPT,\n",
" GPT2CLM,\n",
" LLaMa,\n",
" Bloom\n",
" ]\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1,\n",
" target_modules=slm_lora_target_modules[slm_idx]\n",
" )\n",
"\n",
" model = slm_model_class[slm_idx](\n",
" pretrained_path=slm_pretrained_paths[slm_idx],\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=f\"client_{slm_idx}\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" priv_data.load(process_data_output_dir)\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=slm_lrs[slm_idx],\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,\n",
" llm_training=False\n",
" )\n",
"\n",
" tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\n",
"\n",
" import json\n",
" with open(llm_to_slm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
"\n",
" trainer = FedMKTSLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" pub_train_set=pub_data,\n",
" priv_train_set=priv_data,\n",
" tokenizer=tokenizer,\n",
" save_trainable_weights_only=True,\n",
" llm_tokenizer=get_tokenizer(llm_pretrained_path),\n",
" llm_to_slm_vocab_mapping=vocab_mapping,\n",
" data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)\n",
" )\n",
"\n",
" trainer.train()\n",
" trainer.save_model(slm_models_saved_directory[slm_idx])\n",
"\n",
"\n",
"def run(ctx: Context):\n",
" if ctx.is_on_arbiter:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
" train_llm(ctx, slm_idx=0)\n",
" else:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
" train_slm(ctx, slm_idx=0)\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" launch(run)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Running FEDMKT With Launcher (Experimential Using): 4-SLMs Homogeneous SFT\n",
"\n",
"To run homogeneous experiments, two changes are needed:\n",
"1. Add post_fedavg=True to the FedMKTTrainingArguments (training_args) of both the LLM and the SLMs.\n",
"2. Pass fed_args to FedMKTLLM/FedMKTSLM."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# initialize fed args\n",
"from fate.ml.nn.homo.fedavg import FedAVGArguments\n",
"\n",
"fed_args = FedAVGArguments(\n",
" aggregate_strategy='epoch',\n",
" aggregate_freq=1\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Complete Code To DO SFT With 4-SLMs Homogeneous SFT\n",
"\n",
"Please paste the code into \"fedmkt_4_slms_homo.py\" and execute it with the following command:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```bash\n",
"python fedmkt_4_slms_homo.py --parties guest:9999 host:9999 host:10000 host:10001 arbiter:9999 --log_level INFO\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fedmkt_4_slms_homo.py\n",
"\n",
"import os\n",
"\n",
"from fate.arch import Context\n",
"from fate.arch.launchers.multiprocess_launcher import launch\n",
"import json\n",
"\n",
"process_data_output_dir = \"\"\n",
"llm_pretrained_path = \"Llama-2-7b-hf\"\n",
"slm_0_pretrained_path = \"opt-1.3b\"\n",
"slm_1_pretrained_path = \"opt-1.3b\"\n",
"slm_2_pretrained_path = \"opt-1.3b\"\n",
"slm_3_pretrained_path = \"opt-1.3b\"\n",
"llm_slm_pairs = [\n",
" (llm_pretrained_path, slm_0_pretrained_path),\n",
" (llm_pretrained_path, slm_1_pretrained_path),\n",
" (llm_pretrained_path, slm_2_pretrained_path),\n",
" (llm_pretrained_path, slm_3_pretrained_path)\n",
"]\n",
"\n",
"vocab_mapping_directory = \"\"\n",
"\n",
"slm_to_llm_vocab_mapping_paths = [\"opt_to_llama.json\"] * 4\n",
"llm_to_slm_vocab_mapping_paths = [\"llama_to_opt.json\"] * 4\n",
"\n",
"for idx in range(4):\n",
" slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + slm_to_llm_vocab_mapping_paths[idx]\n",
" llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \"/\" + llm_to_slm_vocab_mapping_paths[idx]\n",
"\n",
"slm_pretrained_paths = [slm_0_pretrained_path] * 4\n",
"slm_lora_target_modules = [[\"q_proj\", \"v_proj\"]] * 4\n",
"\n",
"global_epochs = 5\n",
"batch_size = 4\n",
"llm_lr = 3e-5\n",
"slm_lrs = [3e-5, 3e-5, 3e-5, 3e-5, 3e-5]\n",
"\n",
"llm_model_saved_directory = \"./models/fedmkt_homo_4_slms_llm_model\"\n",
"slm_models_saved_directory = [\n",
" \"./models/fedmkt_homo_4_slms_slm_0\",\n",
"]\n",
"\n",
"\n",
"def train_llm(ctx):\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.llama import LLaMa\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\n",
" from fate.ml.nn.homo.fedavg import FedAVGArguments\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\n",
" )\n",
"\n",
" model = LLaMa(\n",
" pretrained_path=llm_pretrained_path,\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=llm_lr,\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\n",
" post_fedavg=True, # difference\n",
" )\n",
"\n",
" # difference\n",
" fed_args = FedAVGArguments(\n",
" aggregate_strategy='epoch',\n",
" aggregate_freq=1\n",
" )\n",
"\n",
" slm_to_llm_vocab_mapping = []\n",
" for path in slm_to_llm_vocab_mapping_paths:\n",
" with open(path, \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
" slm_to_llm_vocab_mapping.append(vocab_mapping)\n",
"\n",
" slm_tokenizers = [get_tokenizer(slm_pretrained_path) for slm_pretrained_path in slm_pretrained_paths]\n",
"\n",
" tokenizer = get_tokenizer(llm_pretrained_path)\n",
" trainer = FedMKTLLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args,\n",
" fed_args=fed_args, # difference\n",
" train_set=pub_data,\n",
" tokenizer=tokenizer,\n",
" slm_tokenizers=slm_tokenizers,\n",
" slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\n",
" save_trainable_weights_only=True,\n",
" )\n",
"\n",
" trainer.train()\n",
" trainer.save_model(llm_model_saved_directory)\n",
"\n",
"\n",
"def train_slm(ctx, slm_idx):\n",
" import transformers\n",
" from peft import LoraConfig, TaskType\n",
" from fate_llm.model_zoo.pellm.opt import OPT\n",
" from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\n",
" from fate.ml.nn.homo.fedavg import FedAVGArguments\n",
" from fate_llm.dataset.qa_dataset import QaDataset\n",
" from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\n",
" from transformers import AutoConfig\n",
"\n",
" slm_model_class = [OPT] * 4\n",
"\n",
" lora_config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\n",
" target_modules=slm_lora_target_modules[slm_idx]\n",
" )\n",
"\n",
" model = slm_model_class[slm_idx](\n",
" pretrained_path=slm_pretrained_paths[slm_idx],\n",
" peft_type=\"LoraConfig\",\n",
" peft_config=lora_config.to_dict(),\n",
" torch_dtype=\"bfloat16\"\n",
" )\n",
"\n",
" priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=f\"client_{slm_idx}\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" priv_data.load(process_data_output_dir)\n",
"\n",
" pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\n",
" dataset_name=\"arc_challenge\",\n",
" data_part=\"common\",\n",
" seq_max_len=512,\n",
" need_preprocess=True)\n",
" pub_data.load(process_data_output_dir)\n",
"\n",
" training_args = FedMKTTrainingArguments(\n",
" global_epochs=global_epochs,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=batch_size,\n",
" learning_rate=slm_lrs[slm_idx],\n",
" output_dir=\"./\",\n",
" dataloader_num_workers=4,\n",
" remove_unused_columns=False,\n",
" warmup_ratio=0.008,\n",
" lr_scheduler_type=\"cosine\",\n",
" optim=\"adamw_torch\",\n",
" adam_beta1=0.9,\n",
" adam_beta2=0.95,\n",
" weight_decay=0.1,\n",
" max_grad_norm=1.0,\n",
" use_cpu=False,\n",
" vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,\n",
" post_fedavg=True, # difference\n",
" )\n",
"\n",
" # difference\n",
" fed_args = FedAVGArguments(\n",
" aggregate_strategy='epoch',\n",
" aggregate_freq=1\n",
" )\n",
"\n",
" tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\n",
"\n",
" import json\n",
" with open(llm_to_slm_vocab_mapping_paths[slm_idx], \"r\") as fin:\n",
" vocab_mapping = json.loads(fin.read())\n",
"\n",
" trainer = FedMKTSLM(\n",
" ctx=ctx,\n",
" model=model,\n",
" training_args=training_args, \n",
" fed_args=fed_args, # difference\n",
" pub_train_set=pub_data,\n",
" priv_train_set=priv_data,\n",
" tokenizer=tokenizer,\n",
" save_trainable_weights_only=True,\n",
" llm_tokenizer=get_tokenizer(llm_pretrained_path),\n",
" llm_to_slm_vocab_mapping=vocab_mapping,\n",
" data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)\n",
" )\n",
"\n",
" trainer.train()\n",
" if slm_idx == 0:\n",
" trainer.save_model(slm_models_saved_directory[slm_idx])\n",
"\n",
"\n",
"def run(ctx: Context):\n",
" if ctx.is_on_arbiter:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
" train_llm(ctx)\n",
" elif ctx.is_on_guest:\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
" train_slm(ctx, slm_idx=0)\n",
" else:\n",
" if ctx.local.party[1] == \"9999\":\n",
" os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
" slm_idx = 1\n",
" elif ctx.local.party[1]
gitextract_qkv2xwam/
├── LICENSE
├── README.md
├── RELEASE.md
├── doc/
│ ├── fate_llm_evaluate.md
│ ├── standalone_deploy.md
│ └── tutorial/
│ ├── fdkt/
│ │ ├── README.md
│ │ └── fdkt.ipynb
│ ├── fedcot/
│ │ ├── README.md
│ │ ├── encoder_decoder_tutorial.ipynb
│ │ └── fedcot_tutorial.ipynb
│ ├── fedkseed/
│ │ ├── README.md
│ │ └── fedkseed-example.ipynb
│ ├── fedmkt/
│ │ ├── README.md
│ │ └── fedmkt.ipynb
│ ├── inferdpt/
│ │ └── inferdpt_tutorial.ipynb
│ ├── offsite_tuning/
│ │ ├── Offsite_tuning_tutorial.ipynb
│ │ └── README.md
│ └── pellm/
│ ├── ChatGLM3-6B_ds.ipynb
│ └── builtin_pellm_models.md
├── examples/
│ ├── fedmkt/
│ │ ├── __init__.py
│ │ ├── fedmkt.py
│ │ ├── fedmkt_config.yaml
│ │ └── test_fedmkt_llmsuit.yaml
│ ├── offsite_tuning/
│ │ ├── __init__.py
│ │ ├── offsite_tuning.py
│ │ ├── offsite_tuning_config.yaml
│ │ └── test_offsite_tuning_llmsuite.yaml
│ └── pellm/
│ ├── __init__.py
│ ├── bloom_lora_config.yaml
│ ├── test_bloom_lora.py
│ └── test_pellm_llmsuite.yaml
└── python/
├── MANIFEST.in
├── fate_llm/
│ ├── __init__.py
│ ├── algo/
│ │ ├── __init__.py
│ │ ├── dp/
│ │ │ ├── __init__.py
│ │ │ ├── dp_trainer.py
│ │ │ └── opacus_compatibility/
│ │ │ ├── __init__.py
│ │ │ ├── grad_sample/
│ │ │ │ ├── __init__.py
│ │ │ │ └── embedding.py
│ │ │ ├── optimizers/
│ │ │ │ ├── __init__.py
│ │ │ │ └── optimizer.py
│ │ │ └── transformers_compate.py
│ │ ├── fdkt/
│ │ │ ├── __init__.py
│ │ │ ├── cluster/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── cluster.py
│ │ │ │ └── cluster_method.py
│ │ │ ├── fdkt_data_aug.py
│ │ │ ├── inference_inst.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── dp_loss.py
│ │ │ ├── invalid_data_filter.py
│ │ │ └── text_generate.py
│ │ ├── fedavg/
│ │ │ ├── __init__.py
│ │ │ └── fedavg.py
│ │ ├── fedcollm/
│ │ │ ├── __init__.py
│ │ │ ├── fedcollm.py
│ │ │ ├── fedcollm_trainer.py
│ │ │ └── fedcollm_training_args.py
│ │ ├── fedcot/
│ │ │ ├── __init__.py
│ │ │ ├── encoder_decoder/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── init/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── default_init.py
│ │ │ │ └── slm_encoder_decoder.py
│ │ │ ├── fedcot_trainer.py
│ │ │ └── slm_encoder_decoder_trainer.py
│ │ ├── fedkseed/
│ │ │ ├── __init__.py
│ │ │ ├── args.py
│ │ │ ├── fedkseed.py
│ │ │ ├── optimizer.py
│ │ │ ├── pytorch_utils.py
│ │ │ ├── trainer.py
│ │ │ └── zo_utils.py
│ │ ├── fedmkt/
│ │ │ ├── __init__.py
│ │ │ ├── fedmkt.py
│ │ │ ├── fedmkt_data_collator.py
│ │ │ ├── fedmkt_trainer.py
│ │ │ ├── token_alignment/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── spectal_token_mapping.py
│ │ │ │ ├── token_align.py
│ │ │ │ └── vocab_mapping.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── dataset_sync_util.py
│ │ │ ├── generate_logit_utils.py
│ │ │ ├── tokenizer_tool.py
│ │ │ └── vars_define.py
│ │ ├── inferdpt/
│ │ │ ├── __init__.py
│ │ │ ├── _encode_decode.py
│ │ │ ├── inferdpt.py
│ │ │ ├── init/
│ │ │ │ ├── _init.py
│ │ │ │ └── default_init.py
│ │ │ └── utils.py
│ │ ├── offsite_tuning/
│ │ │ ├── __init__.py
│ │ │ └── offsite_tuning.py
│ │ └── ppc-gpt/
│ │ └── __init__.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── data_collator/
│ │ │ ├── __init__.py
│ │ │ ├── cust_data_collator.py
│ │ │ └── fedcot_collator.py
│ │ └── tokenizers/
│ │ ├── __init__.py
│ │ └── cust_tokenizer.py
│ ├── dataset/
│ │ ├── __init__.py
│ │ ├── data_config/
│ │ │ ├── __init__.py
│ │ │ ├── default_ag_news.yaml
│ │ │ └── default_yelp_review.yaml
│ │ ├── fedcot_dataset.py
│ │ ├── flex_dataset.py
│ │ ├── hf_dataset.py
│ │ ├── input_output_dataset.py
│ │ ├── prompt_dataset.py
│ │ ├── qa_dataset.py
│ │ └── seq_cls_dataset.py
│ ├── evaluate/
│ │ ├── __init__.py
│ │ ├── scripts/
│ │ │ ├── __init__.py
│ │ │ ├── _options.py
│ │ │ ├── config_cli.py
│ │ │ ├── data_cli.py
│ │ │ ├── eval_cli.py
│ │ │ └── fate_llm_cli.py
│ │ ├── tasks/
│ │ │ ├── __init__.py
│ │ │ ├── advertise_gen/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── advertise_utils.py
│ │ │ │ └── default_advertise_gen.yaml
│ │ │ └── dolly_15k/
│ │ │ ├── __init__.py
│ │ │ ├── default_dolly_15k.yaml
│ │ │ └── dolly_utils.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── _io.py
│ │ ├── _parser.py
│ │ ├── config.py
│ │ ├── data_tools.py
│ │ ├── llm_evaluator.py
│ │ └── model_tools.py
│ ├── inference/
│ │ ├── __init__.py
│ │ ├── api.py
│ │ ├── hf_qw.py
│ │ ├── inference_base.py
│ │ └── vllm.py
│ ├── model_zoo/
│ │ ├── __init__.py
│ │ ├── embedding_transformer/
│ │ │ ├── __init__.py
│ │ │ └── st_model.py
│ │ ├── hf_model.py
│ │ ├── offsite_tuning/
│ │ │ ├── __init__.py
│ │ │ ├── bloom.py
│ │ │ ├── gpt2.py
│ │ │ ├── llama.py
│ │ │ └── offsite_tuning_model.py
│ │ └── pellm/
│ │ ├── __init__.py
│ │ ├── albert.py
│ │ ├── bart.py
│ │ ├── bert.py
│ │ ├── bloom.py
│ │ ├── chatglm.py
│ │ ├── deberta.py
│ │ ├── distilbert.py
│ │ ├── gpt2.py
│ │ ├── llama.py
│ │ ├── opt.py
│ │ ├── parameter_efficient_llm.py
│ │ ├── qwen.py
│ │ └── roberta.py
│ ├── runner/
│ │ ├── __init__.py
│ │ ├── fdkt_runner.py
│ │ ├── fedcot_runner.py
│ │ ├── fedkseed_runner.py
│ │ ├── fedmkt_runner.py
│ │ ├── homo_seq2seq_runner.py
│ │ ├── inferdpt_runner.py
│ │ └── offsite_tuning_runner.py
│ └── trainer/
│ ├── __init__.py
│ └── seq2seq_trainer.py
├── requirements.txt
└── setup.py
SYMBOL INDEX (637 symbols across 98 files)
FILE: examples/fedmkt/fedmkt.py
function main (line 12) | def main(config="./config.yaml", param: Union[Dict, str] = None):
FILE: examples/offsite_tuning/offsite_tuning.py
function load_params (line 11) | def load_params(file_path):
function setup_pipeline (line 17) | def setup_pipeline(params):
function main (line 112) | def main(config_file, param_file):
FILE: examples/pellm/test_bloom_lora.py
function main (line 14) | def main(config="../../config.yaml", param: Union[Dict, str] = None, nam...
FILE: python/fate_llm/algo/dp/dp_trainer.py
class DPTrainingArguments (line 32) | class DPTrainingArguments(Seq2SeqTrainingArguments):
class DPTrainer (line 39) | class DPTrainer(object):
method __init__ (line 40) | def __init__(
method _init_dp_model (line 78) | def _init_dp_model(self):
method train (line 97) | def train(self):
method _train_an_epoch (line 103) | def _train_an_epoch(self):
method _prepare_batch_input (line 141) | def _prepare_batch_input(self, input_ids) -> dict:
method freeze_model_embedding (line 148) | def freeze_model_embedding(self):
method save_model (line 151) | def save_model(
FILE: python/fate_llm/algo/dp/opacus_compatibility/__init__.py
function add_layer_compatibility (line 20) | def add_layer_compatibility(opacus):
function add_optimizer_compatibility (line 30) | def add_optimizer_compatibility(optimizer):
FILE: python/fate_llm/algo/dp/opacus_compatibility/grad_sample/embedding.py
function compute_embedding_grad_sample (line 24) | def compute_embedding_grad_sample(
FILE: python/fate_llm/algo/dp/opacus_compatibility/optimizers/optimizer.py
function add_noise (line 27) | def add_noise(self):
function add_noise_wrapper (line 47) | def add_noise_wrapper(optimizer):
FILE: python/fate_llm/algo/dp/opacus_compatibility/transformers_compate.py
function get_model_class (line 22) | def get_model_class(model):
function prepare_position_ids (line 31) | def prepare_position_ids(model, input_ids):
function _get_position_ids_for_gpt2 (line 38) | def _get_position_ids_for_gpt2(input_ids):
FILE: python/fate_llm/algo/fdkt/cluster/cluster.py
class SentenceCluster (line 20) | class SentenceCluster(object):
method __init__ (line 21) | def __init__(self, model, cluster_method="kmeans", n_clusters=8, **oth...
method get_embeddings (line 27) | def get_embeddings(self, sentences: List[str]):
method cluster (line 30) | def cluster(self, sentences):
FILE: python/fate_llm/algo/fdkt/cluster/cluster_method.py
class KMeansRunner (line 19) | class KMeansRunner(object):
method __init__ (line 20) | def __init__(self, n_clusters, **other_cluster_args):
method fit (line 24) | def fit(self, x):
function get_cluster_runner (line 31) | def get_cluster_runner(method, n_clusters, **other_cluster_args):
FILE: python/fate_llm/algo/fdkt/fdkt_data_aug.py
class FDKTTrainingArguments (line 38) | class FDKTTrainingArguments(Seq2SeqTrainingArguments):
method to_dict (line 76) | def to_dict(self):
class FDKTSLM (line 91) | class FDKTSLM(object):
method __init__ (line 92) | def __init__(
method aug_data (line 117) | def aug_data(self):
method dp_train (line 151) | def dp_train(self):
method _create_inference_inst (line 182) | def _create_inference_inst(self):
method _destroy_inference_inst (line 199) | def _destroy_inference_inst(self):
method sync_synthetic_dataset (line 205) | def sync_synthetic_dataset(self, data):
method sync_aug_data (line 208) | def sync_aug_data(self):
method save_model (line 211) | def save_model(
class FDKTLLM (line 223) | class FDKTLLM(object):
method __init__ (line 224) | def __init__(
method sync_synthetic_data (line 248) | def sync_synthetic_data(self):
method sync_aug_data (line 251) | def sync_aug_data(self, aug_data):
method aug_data (line 254) | def aug_data(self):
method _aug (line 273) | def _aug(self, aug_prompts):
method filter_data (line 289) | def filter_data(self, slm_data):
method cluster_data (line 314) | def cluster_data(self, slm_data):
FILE: python/fate_llm/algo/fdkt/inference_inst.py
function api_init (line 16) | def api_init(api_url: str, model_name: str, api_key: str = 'EMPTY', api_...
function vllm_init (line 26) | def vllm_init(model_path: str, num_gpu=1, dtype='float16', gpu_memory_ut...
FILE: python/fate_llm/algo/fdkt/utils/dp_loss.py
class SequenceCrossEntropyLoss (line 25) | class SequenceCrossEntropyLoss(nn.Module):
method __init__ (line 26) | def __init__(self, model_type, label_smoothing=-1, reduce=None):
method forward (line 32) | def forward(self, logits, targets, mask):
function sequence_cross_entropy_with_logits (line 36) | def sequence_cross_entropy_with_logits(logits, targets, mask, label_smoo...
FILE: python/fate_llm/algo/fdkt/utils/invalid_data_filter.py
function filter_invalid_data (line 20) | def filter_invalid_data(data_dict):
FILE: python/fate_llm/algo/fdkt/utils/text_generate.py
function slm_text_generate (line 20) | def slm_text_generate(
function general_text_generate (line 70) | def general_text_generate(
FILE: python/fate_llm/algo/fedavg/fedavg.py
class Seq2SeqFedAVGClient (line 39) | class Seq2SeqFedAVGClient(HomoSeq2SeqTrainerClient):
method __init__ (line 41) | def __init__(
method init_aggregator (line 83) | def init_aggregator(self, ctx: Context, fed_args: FedArguments):
method on_federation (line 91) | def on_federation(
FILE: python/fate_llm/algo/fedcollm/fedcollm.py
class FedCoLLMBase (line 38) | class FedCoLLMBase(object):
method update_model (line 40) | def update_model(model, updated_params):
class SLM (line 45) | class SLM(FedCoLLMBase):
method __init__ (line 46) | def __init__(
method train (line 85) | def train(self):
method _sync_slm_updated_params (line 111) | def _sync_slm_updated_params(self, iter_ctx):
method _get_slm_training_args (line 115) | def _get_slm_training_args(self):
method _init_aggregator (line 118) | def _init_aggregator(self, ctx: Context, fed_args: FedArguments):
class LLM (line 128) | class LLM(FedCoLLMBase):
method __init__ (line 129) | def __init__(
method _init_aggregator (line 178) | def _init_aggregator(self, ctx: Context):
method _get_logits (line 181) | def _get_logits(self, model):
method on_epoch_begin (line 198) | def on_epoch_begin(self, iter_ctx):
method _sync_slm_updated_params (line 203) | def _sync_slm_updated_params(self, iter_ctx):
method _train_slm (line 209) | def _train_slm(self, iter_ctx, llm_pub_logits, epoch_idx):
method _train_llm (line 242) | def _train_llm(self, slm_pub_logits, epoch_idx):
method train (line 273) | def train(self):
FILE: python/fate_llm/algo/fedcollm/fedcollm_trainer.py
function computing_kd_loss (line 34) | def computing_kd_loss(src_logits, dst_logits, loss_mask):
function recovery_logits (line 45) | def recovery_logits(
class FedCoLLMTrainer (line 73) | class FedCoLLMTrainer(Seq2SeqTrainer):
method __init__ (line 82) | def __init__(self, **kwargs):
method compute_loss (line 98) | def compute_loss(self, model, inputs, return_outputs=False):
FILE: python/fate_llm/algo/fedcollm/fedcollm_training_args.py
class FedCoLLMTrainingArguments (line 21) | class FedCoLLMTrainingArguments(Seq2SeqTrainingArguments):
method to_dict (line 44) | def to_dict(self):
method _pop_extra (line 58) | def _pop_extra(self):
method to_slm_seq_training_args (line 65) | def to_slm_seq_training_args(self):
method to_fedco_slm_training_args (line 71) | def to_fedco_slm_training_args(self):
method to_fedco_llm_training_args (line 77) | def to_fedco_llm_training_args(self):
FILE: python/fate_llm/algo/fedcot/encoder_decoder/init/default_init.py
class FedCoTEDAPIClientInit (line 22) | class FedCoTEDAPIClientInit(InferInit):
method __init__ (line 28) | def __init__(self, ctx):
method get_inst (line 32) | def get_inst(self):
class FedCoTEDAPIServerInit (line 38) | class FedCoTEDAPIServerInit(InferInit):
method __init__ (line 44) | def __init__(self, ctx):
method get_inst (line 48) | def get_inst(self):
FILE: python/fate_llm/algo/fedcot/encoder_decoder/slm_encoder_decoder.py
class SLMEncoderDecoderClient (line 33) | class SLMEncoderDecoderClient(InferDPTClient):
method __init__ (line 35) | def __init__(self, ctx: Context, local_inference_inst: Inference) -> N...
method encode (line 41) | def encode(self, docs: List[Dict[str, str]], format_template: str = No...
method decode (line 57) | def decode(self, p_docs: List[Dict[str, str]], instruction_template: s...
method inference (line 62) | def inference(self, docs: Union[List[Dict[str, str]], HuggingfaceDatas...
class SLMEncoderDecoderServer (line 78) | class SLMEncoderDecoderServer(InferDPTServer):
FILE: python/fate_llm/algo/fedcot/fedcot_trainer.py
function save_to (line 45) | def save_to(obj, filepath, filename='tmp.pkl'):
function load (line 55) | def load(filepath, filename='tmp.pkl'):
class DSSTrainerClient (line 71) | class DSSTrainerClient(Seq2SeqTrainer):
method __init__ (line 73) | def __init__(self,
method compute_loss (line 103) | def compute_loss(self, model, inputs, return_outputs=False):
class FedCoTTrainerClient (line 111) | class FedCoTTrainerClient(DSSTrainerClient):
method __init__ (line 113) | def __init__(self,
method infer (line 182) | def infer(self) -> List[str]:
method train (line 207) | def train(self):
method get_infer_result (line 220) | def get_infer_result(self):
class FedCoTTraineServer (line 224) | class FedCoTTraineServer(object):
method __init__ (line 226) | def __init__(self, ctx: Context, infer_server: Union[SLMEncoderDecoder...
method train (line 231) | def train(self):
FILE: python/fate_llm/algo/fedcot/slm_encoder_decoder_trainer.py
class EDPrefixDataCollator (line 7) | class EDPrefixDataCollator(DataCollatorForSeq2Seq):
method __call__ (line 8) | def __call__(self, features, return_tensors=None):
class EncoderDecoderPrefixTrainer (line 19) | class EncoderDecoderPrefixTrainer(Seq2SeqTrainer):
method __init__ (line 21) | def __init__(self, alpha=0.5, *args, **kwargs):
method compute_loss (line 25) | def compute_loss(self, model, inputs, return_outputs=False):
FILE: python/fate_llm/algo/fedkseed/args.py
class KSeedTrainingArguments (line 5) | class KSeedTrainingArguments:
FILE: python/fate_llm/algo/fedkseed/fedkseed.py
class Trainer (line 17) | class Trainer:
method __init__ (line 18) | def __init__(
method get_clients (line 30) | def get_clients(ctx: Context):
method load_model (line 38) | def load_model(self):
method train (line 41) | def train(self):
method should_stop (line 87) | def should_stop(self):
method evaluate (line 90) | def evaluate(self):
class ClientTrainer (line 94) | class ClientTrainer:
method __init__ (line 95) | def __init__(self, ctx: Context, model, fedkseed_args, training_args, ...
method train (line 108) | def train(self):
method train_once (line 130) | def train_once(self, seed_candidates, seed_probabilities, direction_de...
class FedKSeedTrainingArguments (line 162) | class FedKSeedTrainingArguments(KSeedTrainingArguments):
FILE: python/fate_llm/algo/fedkseed/optimizer.py
class RandomWalkOptimizer (line 55) | class RandomWalkOptimizer(Optimizer):
method __init__ (line 62) | def __init__(self, params, lr, weight_decay, grad_clip, defaults=None):
method from_model (line 74) | def from_model(cls, model, lr, weight_decay, grad_clip, **kwargs):
method directional_derivative_step (line 81) | def directional_derivative_step(
method step (line 95) | def step(self, closure: Optional[Callable[[], float]] = None) -> Optio...
class ZerothOrderOptimizer (line 102) | class ZerothOrderOptimizer(RandomWalkOptimizer):
method __init__ (line 103) | def __init__(self, params, lr, eps, weight_decay, grad_clip):
method zeroth_order_step (line 108) | def zeroth_order_step(
method random_perturb_parameters (line 152) | def random_perturb_parameters(self, directional_derivative_seed: int, ...
class KSeedZerothOrderOptimizer (line 176) | class KSeedZerothOrderOptimizer(ZerothOrderOptimizer):
method __init__ (line 177) | def __init__(
method sample (line 193) | def sample(self) -> int:
method step (line 201) | def step(self, closure: Callable[[], torch.FloatTensor] = None) -> tor...
method kseed_zeroth_order_step (line 210) | def kseed_zeroth_order_step(self, closure: Callable[[], torch.FloatTen...
FILE: python/fate_llm/algo/fedkseed/pytorch_utils.py
function get_decay_parameter_names (line 7) | def get_decay_parameter_names(model) -> List[str]:
function get_optimizer_parameters_grouped_with_decay (line 34) | def get_optimizer_parameters_grouped_with_decay(model, weight_decay: flo...
FILE: python/fate_llm/algo/fedkseed/trainer.py
class KSeedZOExtendedTrainer (line 20) | class KSeedZOExtendedTrainer(Trainer):
method __init__ (line 21) | def __init__(
method configure_seed_candidates (line 55) | def configure_seed_candidates(self, seed_candidates: torch.LongTensor,...
method get_directional_derivative_history (line 59) | def get_directional_derivative_history(self):
method k_seed_zo_mode (line 71) | def k_seed_zo_mode(args):
method training_step (line 74) | def training_step(self, model: nn.Module, inputs: Dict[str, Union[torc...
method create_optimizer_and_scheduler (line 101) | def create_optimizer_and_scheduler(self, num_training_steps: int):
FILE: python/fate_llm/algo/fedkseed/zo_utils.py
function probability_from_amps (line 6) | def probability_from_amps(amps: List[List[float]], clip):
function directional_derivative_step (line 23) | def directional_derivative_step(
function build_seed_candidates (line 57) | def build_seed_candidates(k, low=0, high=2**32):
function get_even_seed_probabilities (line 64) | def get_even_seed_probabilities(k):
FILE: python/fate_llm/algo/fedmkt/fedmkt.py
class FedMKTTrainingArguments (line 46) | class FedMKTTrainingArguments(Seq2SeqTrainingArguments):
method to_dict (line 88) | def to_dict(self):
method to_dict_without_extra_args (line 102) | def to_dict_without_extra_args(self):
method to_dict_with_client_priv_training_args (line 128) | def to_dict_with_client_priv_training_args(self):
method to_dict_with_client_kd_args (line 135) | def to_dict_with_client_kd_args(self):
method to_dict_with_server_kd_args (line 142) | def to_dict_with_server_kd_args(self):
class FedMKTBase (line 149) | class FedMKTBase(object):
method __init__ (line 150) | def __init__(self, *args, **kwargs):
method save_model (line 154) | def save_model(
class FedMKTSLM (line 175) | class FedMKTSLM(FedMKTBase):
method __init__ (line 176) | def __init__(
method train (line 229) | def train(self):
method _init_trainer_for_distill (line 314) | def _init_trainer_for_distill(self, train_set):
method _get_priv_data_training_args (line 339) | def _get_priv_data_training_args(self):
method _get_pub_data_kd_training_args (line 345) | def _get_pub_data_kd_training_args(self):
method _init_aggregator (line 351) | def _init_aggregator(self, ctx: Context, fed_args: FedArguments):
class FedMKTLLM (line 364) | class FedMKTLLM(FedMKTBase):
method __init__ (line 365) | def __init__(
method _init_aggregator (line 409) | def _init_aggregator(self, ctx: Context):
method generate_pub_data_logits (line 414) | def generate_pub_data_logits(self, first_epoch=False):
method on_epoch_begin (line 430) | def on_epoch_begin(self, iter_ctx, epoch_idx, previous_pub_dataset):
method on_epoch_end (line 484) | def on_epoch_end(self, iter_ctx, epoch_idx):
method _get_pub_data_kd_training_args (line 510) | def _get_pub_data_kd_training_args(self):
method train (line 516) | def train(self):
FILE: python/fate_llm/algo/fedmkt/fedmkt_data_collator.py
class DataCollatorForFedMKT (line 40) | class DataCollatorForFedMKT(DataCollatorForSeq2Seq):
method __init__ (line 54) | def __init__(self, *args, **kwargs):
method __call__ (line 66) | def __call__(self, features, return_tensors=None):
FILE: python/fate_llm/algo/fedmkt/fedmkt_trainer.py
class FedMKTTrainer (line 36) | class FedMKTTrainer(Seq2SeqTrainer):
method __init__ (line 45) | def __init__(self, *args, **kwargs):
method compute_loss (line 56) | def compute_loss(self, model, inputs, return_outputs=False):
FILE: python/fate_llm/algo/fedmkt/token_alignment/token_align.py
function dtw (line 40) | def dtw(series_1, series_2, norm_func=np.linalg.norm):
function greedy_dynamic_matching (line 84) | def greedy_dynamic_matching(base_model_tokens, blending_model_tokens, ba...
function align_blending_model_logits_with_base_model_logits (line 179) | def align_blending_model_logits_with_base_model_logits(base_examples,
function transform_step_logits (line 231) | def transform_step_logits(base_model_tokenizer: transformers.tokenizatio...
function token_align (line 353) | def token_align(
FILE: python/fate_llm/algo/fedmkt/token_alignment/vocab_mapping.py
function find_best_mapping (line 32) | def find_best_mapping(x, base_tokens, blending_model_special_token, base...
function get_vocab_mappings (line 47) | def get_vocab_mappings(model_name_or_path, candidate_model_name_or_path,...
FILE: python/fate_llm/algo/fedmkt/utils/dataset_sync_util.py
function sync_dataset (line 29) | def sync_dataset(dataset, local_rank, world_size, device):
FILE: python/fate_llm/algo/fedmkt/utils/generate_logit_utils.py
class Metric (line 26) | class Metric(object):
method cal_metric (line 28) | def cal_metric(cls, logits, input_ids, attention_mask, labels, trainin...
method cal_ce (line 35) | def cal_ce(cls, logits, input_ids, attention_mask, labels, training_ar...
class LogitsSelection (line 44) | class LogitsSelection(object):
method select_logits (line 46) | def select_logits(cls, logits, training_args):
method select_highest (line 53) | def select_highest(cls, logits, top_k_logits_keep):
function generate_pub_data_logits (line 60) | def generate_pub_data_logits(inputs, model, training_args, data_collator):
FILE: python/fate_llm/algo/fedmkt/utils/tokenizer_tool.py
function get_vocab_size (line 19) | def get_vocab_size(tokenizer_name_or_path):
FILE: python/fate_llm/algo/inferdpt/_encode_decode.py
class EncoderDecoder (line 25) | class EncoderDecoder(object):
method __init__ (line 27) | def __init__(self, ctx: Context) -> None:
method encode (line 30) | def encode(self, docs: List[Dict[str, str]], format_template: str):
method decode (line 33) | def decode(self, docs: List[Dict[str, str]], format_template: str ):
method inference (line 36) | def inference(self, docs: List[Dict[str, str]], inference_kwargs: dict...
FILE: python/fate_llm/algo/inferdpt/inferdpt.py
class InferDPTClient (line 34) | class InferDPTClient(EncoderDecoder):
method __init__ (line 36) | def __init__(self, ctx: Context, inferdpt_pertub_kit: InferDPTKit, loc...
method encode (line 44) | def encode(self, docs: List[Dict[str, str]], format_template: str = No...
method _remote_inference (line 64) | def _remote_inference(self, docs: List[Dict[str, str]],
method decode (line 97) | def decode(self, p_docs: List[Dict[str, str]], instruction_template: s...
method inference (line 115) | def inference(self, docs: Union[List[Dict[str, str]], HuggingfaceDatas...
class InferDPTServer (line 149) | class InferDPTServer(object):
method __init__ (line 151) | def __init__(self, ctx: Context, inference_inst: Inference) -> None:
method inference (line 157) | def inference(self, verbose=False):
method predict (line 170) | def predict(self):
FILE: python/fate_llm/algo/inferdpt/init/_init.py
class InferInit (line 21) | class InferInit(object):
method __init__ (line 23) | def __init__(self, ctx: Context):
method get_inst (line 26) | def get_inst(self):
FILE: python/fate_llm/algo/inferdpt/init/default_init.py
class InferDPTAPIClientInit (line 23) | class InferDPTAPIClientInit(InferInit):
method __init__ (line 31) | def __init__(self, ctx):
method get_inst (line 35) | def get_inst(self)-> InferDPTClient:
class InferDPTAPIServerInit (line 42) | class InferDPTAPIServerInit(InferInit):
method __init__ (line 48) | def __init__(self, ctx):
method get_inst (line 52) | def get_inst(self)-> InferDPTServer:
FILE: python/fate_llm/algo/inferdpt/utils.py
class NumpyEncoder (line 33) | class NumpyEncoder(json.JSONEncoder):
method default (line 36) | def default(self, obj):
function save_jsonl (line 46) | def save_jsonl(filename, data):
function create_sensitivity_of_embeddings (line 53) | def create_sensitivity_of_embeddings(all_embedding_matrix):
function create_sorted_embedding_matrix (line 64) | def create_sorted_embedding_matrix(token_list, similarity_matrix):
function cosine_similarity_vectors (line 74) | def cosine_similarity_vectors(A, B):
class InferDPTKit (line 82) | class InferDPTKit(object):
method __init__ (line 84) | def __init__(self, token_to_vector_dict, sorted_similarities, delta_f,...
method save_to_path (line 92) | def save_to_path(self, path):
method make_inferdpt_kit_param (line 111) | def make_inferdpt_kit_param(embedding_matrix: np.ndarray, token_list: ...
method load_from_path (line 132) | def load_from_path(path):
method perturb (line 144) | def perturb(self, doc: str, epsilon: float) -> str:
function cosine_similarity_vectors (line 184) | def cosine_similarity_vectors(A, B):
function add_laplace_noise_to_vector (line 192) | def add_laplace_noise_to_vector(vector, epsilon, delta_f_new):
function perturb_sentence (line 204) | def perturb_sentence(sent,
FILE: python/fate_llm/algo/offsite_tuning/offsite_tuning.py
class OffsiteTuningTrainerClient (line 22) | class OffsiteTuningTrainerClient(Seq2SeqFedAVGClient):
method __init__ (line 24) | def __init__(
method _share_model (line 70) | def _share_model(self, model, args: Seq2SeqTrainingArguments, sync_tra...
method on_train_begin (line 82) | def on_train_begin(self, ctx: Context, aggregator: Aggregator, fed_arg...
method on_federation (line 102) | def on_federation(
method on_train_end (line 120) | def on_train_end(self, ctx: Context, aggregator: Aggregator, fed_args:...
method init_aggregator (line 132) | def init_aggregator(self, ctx: Context, fed_args: FedArguments):
class OffsiteTuningTrainerServer (line 139) | class OffsiteTuningTrainerServer(Seq2SeqFedAVGServer):
method __init__ (line 141) | def __init__(self, ctx: Context, model: OffsiteTuningBaseModel, aggreg...
method on_train_begin (line 147) | def on_train_begin(self, ctx: Context, aggregator: Aggregator):
method on_train_end (line 154) | def on_train_end(self, ctx: Context, aggregator: Aggregator):
method on_federation (line 159) | def on_federation(self, ctx: Context, aggregator, agg_iter_idx: int):
method init_aggregator (line 165) | def init_aggregator(self, ctx):
method train (line 171) | def train(self):
method save_model (line 182) | def save_model(
FILE: python/fate_llm/data/data_collator/cust_data_collator.py
function get_data_collator (line 20) | def get_data_collator(data_collator_name,
function get_seq2seq_data_collator (line 45) | def get_seq2seq_data_collator(tokenizer_name_or_path, **kwargs):
FILE: python/fate_llm/data/data_collator/fedcot_collator.py
class PrefixDataCollator (line 5) | class PrefixDataCollator(DataCollatorForSeq2Seq):
method __call__ (line 6) | def __call__(self, features, return_tensors=None):
function get_prefix_data_collator (line 17) | def get_prefix_data_collator(tokenizer_name_or_path):
FILE: python/fate_llm/data/tokenizers/cust_tokenizer.py
function get_tokenizer (line 19) | def get_tokenizer(
FILE: python/fate_llm/dataset/fedcot_dataset.py
class PrefixDataset (line 12) | class PrefixDataset(InputOutputDataset):
method __init__ (line 14) | def __init__(self,
method load_rationale (line 30) | def load_rationale(self, result_list, key='rationale'):
method get_str_item (line 34) | def get_str_item(self, i) -> dict:
method get_tokenized_item (line 53) | def get_tokenized_item(self, i) -> dict:
FILE: python/fate_llm/dataset/flex_dataset.py
function get_jinjax_placeholders (line 38) | def get_jinjax_placeholders(jinjax_text, placeholder_count=2):
function regex_replace (line 45) | def regex_replace(string, pattern, repl, count: int = 0):
function apply_template (line 62) | def apply_template(template, data):
function tokenize_flex_dataset (line 77) | def tokenize_flex_dataset(raw_datasets, tokenizer, sub_domain, tokenize_...
class FlexDataset (line 116) | class FlexDataset(Dataset):
method __init__ (line 117) | def __init__(self,
method parse_config (line 161) | def parse_config(self, config=None):
method get_generate_prompt (line 176) | def get_generate_prompt(self, tokenize=True, return_tensors="pt"):
method construct_prompt_list (line 187) | def construct_prompt_list(samples_dict, num_shot_per_label, prompt_num...
method group_text_label_list (line 221) | def group_text_label_list(text_list, label_list):
method prepare_few_shot (line 225) | def prepare_few_shot(self, text_list, label_list, aug_prompt_num):
method prepare_augment (line 240) | def prepare_augment(self, text_list, label_list, aug_prompt_num):
method abstract_from_augmented (line 250) | def abstract_from_augmented(self, sample_list):
method prepare_query_to_filter_clustered (line 275) | def prepare_query_to_filter_clustered(self, clustered_sentences_list, ...
method parse_clustered_response (line 288) | def parse_clustered_response(self, clustered_sentence, clustered_label...
method group_data_list (line 315) | def group_data_list(data_list, text_key, label_key):
method load (line 321) | def load(self, path):
method apply_chat_template (line 341) | def apply_chat_template(self, query):
method get_raw_dataset (line 358) | def get_raw_dataset(self):
method __len__ (line 361) | def __len__(self):
method get_item (line 364) | def get_item(self, i):
method get_item_dict (line 367) | def get_item_dict(self, i):
method __getitem__ (line 371) | def __getitem__(self, i) -> dict:
FILE: python/fate_llm/dataset/hf_dataset.py
class HuggingfaceDataset (line 28) | class HuggingfaceDataset(Dataset):
method __init__ (line 33) | def __init__(
method load (line 92) | def load(self, file_path):
method __getitem__ (line 113) | def __getitem__(self, idx):
method __len__ (line 118) | def __len__(self):
class Dolly15K (line 124) | class Dolly15K(HuggingfaceDataset):
method __init__ (line 174) | def __init__(self, *args, **kwargs):
method load (line 178) | def load(self, file_path):
method _post_process (line 182) | def _post_process(self, dataset):
FILE: python/fate_llm/dataset/input_output_dataset.py
class InputOutputDataset (line 12) | class InputOutputDataset(Dataset):
method __init__ (line 14) | def __init__(self,
method load (line 37) | def load(self, path):
method get_raw_dataset (line 66) | def get_raw_dataset(self):
method __len__ (line 69) | def __len__(self):
method get_str_item (line 72) | def get_str_item(self, i) -> dict:
method _process_item (line 82) | def _process_item(self, data_item):
method get_tokenized_item (line 105) | def get_tokenized_item(self, i) -> dict:
method __getitem__ (line 111) | def __getitem__(self, i) -> dict:
FILE: python/fate_llm/dataset/prompt_dataset.py
class PromptDataset (line 28) | class PromptDataset(Dataset):
method __init__ (line 29) | def __init__(self,
method load (line 75) | def load(self, file_path):
method _process_data (line 121) | def _process_data(examples, tokenizer, prompt_template, prompt_column,
method _pad_to_max_length (line 172) | def _pad_to_max_length(examples, tokenizer, max_length):
method get_vocab_size (line 196) | def get_vocab_size(self):
method __getitem__ (line 199) | def __getitem__(self, item):
method __len__ (line 202) | def __len__(self):
method __repr__ (line 205) | def __repr__(self):
FILE: python/fate_llm/dataset/qa_dataset.py
class PIQA (line 25) | class PIQA:
method __init__ (line 26) | def __init__(self):
method get_context (line 29) | def get_context(self, examples):
method get_target (line 33) | def get_target(self, examples):
class SciQ (line 42) | class SciQ:
method __init__ (line 43) | def __init__(self):
method get_context (line 46) | def get_context(self, examples):
method get_target (line 51) | def get_target(self, examples):
class OpenBookQA (line 55) | class OpenBookQA:
method get_context (line 56) | def get_context(self, examples):
method get_target (line 59) | def get_target(self, examples):
class ARC (line 69) | class ARC:
method __init__ (line 70) | def __init__(self):
method get_context (line 73) | def get_context(self, examples):
method get_target (line 77) | def get_target(self, examples):
class WIC (line 88) | class WIC:
method __init__ (line 89) | def __init__(self):
method get_context (line 93) | def get_context(self, examples):
method get_target (line 107) | def get_target(self, examples):
class BoolQ (line 116) | class BoolQ:
method __init__ (line 117) | def __init__(self):
method get_context (line 120) | def get_context(self, examples):
method get_target (line 127) | def get_target(self, examples):
class CommonsenseQA (line 130) | class CommonsenseQA:
method get_context (line 131) | def get_context(self, examples):
method get_target (line 134) | def get_target(self, examples):
class RTE (line 144) | class RTE:
method __init__ (line 145) | def __init__(self):
method get_context (line 148) | def get_context(self, examples):
method get_target (line 159) | def get_target(self, examples):
function tokenize_qa_dataset (line 177) | def tokenize_qa_dataset(dataset_name, tokenizer, save_path=None, seq_max...
class QaDataset (line 268) | class QaDataset(Dataset):
method __init__ (line 270) | def __init__(self,
method load (line 296) | def load(self, path):
method set_return_with_idx (line 317) | def set_return_with_idx(self):
method reset_return_with_idx (line 320) | def reset_return_with_idx(self):
method __len__ (line 323) | def __len__(self):
method __getitem__ (line 326) | def __getitem__(self, idx):
FILE: python/fate_llm/dataset/seq_cls_dataset.py
class SeqCLSDataset (line 27) | class SeqCLSDataset(Dataset):
method __init__ (line 45) | def __init__(
method load (line 74) | def load(self, file_path):
method get_classes (line 97) | def get_classes(self):
method get_vocab_size (line 100) | def get_vocab_size(self):
method get_sample_ids (line 103) | def get_sample_ids(self):
method __getitem__ (line 106) | def __getitem__(self, item):
method __len__ (line 118) | def __len__(self):
method __repr__ (line 121) | def __repr__(self):
FILE: python/fate_llm/evaluate/scripts/_options.py
function parse_custom_type (line 9) | def parse_custom_type(value):
class LlmSharedOptions (line 19) | class LlmSharedOptions(object):
method __init__ (line 31) | def __init__(self):
method __getitem__ (line 34) | def __getitem__(self, item):
method get (line 37) | def get(self, k, default=None):
method update (line 43) | def update(self, **kwargs):
method post_process (line 48) | def post_process(self):
method get_shared_options (line 61) | def get_shared_options(cls, hidden=False):
FILE: python/fate_llm/evaluate/scripts/config_cli.py
function eval_config_group (line 26) | def eval_config_group():
function _new (line 34) | def _new():
function _edit (line 45) | def _edit(ctx, **kwargs):
function _show (line 56) | def _show():
FILE: python/fate_llm/evaluate/scripts/data_cli.py
function download_data (line 34) | def download_data(ctx, tasks, **kwargs):
FILE: python/fate_llm/evaluate/scripts/eval_cli.py
function run_evaluate (line 41) | def run_evaluate(ctx, include, eval_config, result_output, **kwargs):
function run_job_eval (line 70) | def run_job_eval(job, eval_conf):
function run_suite_eval (line 102) | def run_suite_eval(suite, eval_conf, output_path=None):
FILE: python/fate_llm/evaluate/scripts/fate_llm_cli.py
class FATELlmCLI (line 35) | class FATELlmCLI(click.MultiCommand):
method list_commands (line 37) | def list_commands(self, ctx):
method get_command (line 40) | def get_command(self, ctx, name):
function fate_llm_cli (line 51) | def fate_llm_cli(ctx, **kwargs):
FILE: python/fate_llm/evaluate/tasks/__init__.py
function local_fn_constructor (line 21) | def local_fn_constructor(loader, node):
function local_fn_representer (line 25) | def local_fn_representer(dumper, data):
function dump_yaml (line 29) | def dump_yaml(dict, path):
class Task (line 34) | class Task:
method task_name (line 42) | def task_name(self):
method task_template (line 46) | def task_template(self):
method task_scr_dir (line 53) | def task_scr_dir(self):
method task_conf_path (line 57) | def task_conf_path(self):
method task_source_url (line 61) | def task_source_url(self):
method download_from_source (line 64) | def download_from_source(self):
class Dolly (line 68) | class Dolly(Task):
method download_from_source (line 73) | def download_from_source(self):
class AdvertiseGen (line 85) | class AdvertiseGen(Task):
method download_from_source (line 92) | def download_from_source(self):
FILE: python/fate_llm/evaluate/tasks/advertise_gen/advertise_utils.py
function rouge_l (line 8) | def rouge_l(predictions, references, use_stemmer=False):
FILE: python/fate_llm/evaluate/tasks/dolly_15k/dolly_utils.py
function rouge_l (line 7) | def rouge_l(predictions, references, use_stemmer=False):
function doc_to_text (line 17) | def doc_to_text(doc):
FILE: python/fate_llm/evaluate/utils/_io.py
class echo (line 21) | class echo(object):
method set_file (line 25) | def set_file(cls, file):
method echo (line 29) | def echo(cls, message, **kwargs):
method sep_line (line 34) | def sep_line(cls):
method file (line 38) | def file(cls, message, **kwargs):
method stdout (line 42) | def stdout(cls, message, **kwargs):
method stdout_newline (line 46) | def stdout_newline(cls):
method welcome (line 50) | def welcome(cls):
method flush (line 55) | def flush(cls):
function set_logger (line 60) | def set_logger(name):
FILE: python/fate_llm/evaluate/utils/_parser.py
class LlmJob (line 23) | class LlmJob(object):
method __init__ (line 24) | def __init__(self, job_name: str, script_path: Path=None, conf_path: P...
class LlmPair (line 43) | class LlmPair(object):
method __init__ (line 44) | def __init__(
class LlmSuite (line 51) | class LlmSuite(object):
method __init__ (line 52) | def __init__(
method load (line 61) | def load(path: Path):
method update_status (line 131) | def update_status(
method get_final_status (line 139) | def get_final_status(self):
FILE: python/fate_llm/evaluate/utils/config.py
function create_eval_config (line 44) | def create_eval_config(path: Path, override=False):
function default_eval_config (line 52) | def default_eval_config():
class Config (line 58) | class Config(object):
method __init__ (line 59) | def __init__(self, config):
method update_conf (line 62) | def update_conf(self, **kwargs):
method load (line 67) | def load(path: typing.Union[str, Path], **kwargs):
method load_from_file (line 79) | def load_from_file(path: typing.Union[str, Path]):
function parse_config (line 104) | def parse_config(config):
function _set_namespace (line 112) | def _set_namespace(namespace):
FILE: python/fate_llm/evaluate/utils/data_tools.py
function download_data (line 18) | def download_data(data_dir, data_url, is_tar=True):
FILE: python/fate_llm/evaluate/utils/llm_evaluator.py
function evaluate (line 32) | def evaluate(tasks, model="hf", model_args=None, include_path=None, task...
function aggregate_table (line 83) | def aggregate_table(results):
function get_task_template (line 151) | def get_task_template(task):
function export_config (line 159) | def export_config(config, task, export_dir=None, export_sub_dir=None):
function copy_directory_to_dst (line 176) | def copy_directory_to_dst(src_dir, dst_dir, target_conf_file, new_conf: ...
function contains_subdirectory (line 199) | def contains_subdirectory(path, subdirectories):
function delete_config (line 211) | def delete_config(target_dir, force=False):
function set_environ_fate_llm_base (line 221) | def set_environ_fate_llm_base(path):
function set_environ_fate_llm_task_base (line 226) | def set_environ_fate_llm_task_base(path):
function init_tasks (line 231) | def init_tasks(root_path=None):
function download_task (line 274) | def download_task(tasks=None):
FILE: python/fate_llm/evaluate/utils/model_tools.py
function load_model_from_path (line 22) | def load_model_from_path(model_path, peft_path=None, peft_config=None, m...
function load_model (line 44) | def load_model(model_path, peft_path=None, model_args=None):
function load_by_loader (line 49) | def load_by_loader(loader_name=None, loader_conf_path=None, peft_path=No...
FILE: python/fate_llm/inference/api.py
class APICompletionInference (line 23) | class APICompletionInference(Inference):
method __init__ (line 25) | def __init__(self, api_url: str, model_name: str, api_key: str = 'EMPT...
method inference (line 34) | def inference(self, docs: List[str], inference_kwargs: dict = {}) -> L...
FILE: python/fate_llm/inference/hf_qw.py
class QwenHFCompletionInference (line 23) | class QwenHFCompletionInference(Inference):
method __init__ (line 25) | def __init__(self, model, tokenizer):
method inference (line 29) | def inference(self, docs: List[str], inference_kwargs: dict = {}) -> L...
FILE: python/fate_llm/inference/inference_base.py
class Inference (line 20) | class Inference(object):
method __init__ (line 22) | def __init__(self):
method inference (line 25) | def inference(self, docs: List[str], inference_kwargs: dict = {}) -> L...
FILE: python/fate_llm/inference/vllm.py
class VLLMInference (line 27) | class VLLMInference(Inference):
method __init__ (line 29) | def __init__(self, model_path, num_gpu=1, dtype='float16', gpu_memory_...
method inference (line 34) | def inference(self, docs: List[str], inference_kwargs: dict = {}) -> L...
FILE: python/fate_llm/model_zoo/embedding_transformer/st_model.py
class SentenceTransformerModel (line 20) | class SentenceTransformerModel(object):
method __init__ (line 21) | def __init__(
method load (line 53) | def load(self):
FILE: python/fate_llm/model_zoo/hf_model.py
class HFAutoModelForCausalLM (line 20) | class HFAutoModelForCausalLM:
method __init__ (line 22) | def __init__(self, pretrained_model_name_or_path, *model_args, **kwarg...
method load (line 30) | def load(self):
FILE: python/fate_llm/model_zoo/offsite_tuning/bloom.py
class BloomMainModel (line 23) | class BloomMainModel(OffsiteTuningMainModel):
method __init__ (line 25) | def __init__(
method get_base_model (line 38) | def get_base_model(self):
method get_model_transformer_blocks (line 41) | def get_model_transformer_blocks(self, model: BloomForCausalLM):
method get_additional_param_state_dict (line 44) | def get_additional_param_state_dict(self):
method load_additional_param_state_dict (line 60) | def load_additional_param_state_dict(self, submodel_weights: dict):
method forward (line 80) | def forward(
class BloomSubModel (line 110) | class BloomSubModel(OffsiteTuningSubModel):
method __init__ (line 112) | def __init__(
method get_base_model (line 132) | def get_base_model(self):
method get_model_transformer_blocks (line 140) | def get_model_transformer_blocks(self, model: BloomForCausalLM):
method get_additional_param_state_dict (line 143) | def get_additional_param_state_dict(self):
method load_additional_param_state_dict (line 159) | def load_additional_param_state_dict(self, submodel_weights: dict):
method forward (line 179) | def forward(
method parameters (line 208) | def parameters(self, recurse=True):
FILE: python/fate_llm/model_zoo/offsite_tuning/gpt2.py
class GPT2LMHeadMainModel (line 22) | class GPT2LMHeadMainModel(OffsiteTuningMainModel):
method __init__ (line 24) | def __init__(
method get_base_model (line 37) | def get_base_model(self):
method get_model_transformer_blocks (line 40) | def get_model_transformer_blocks(self, model: GPT2LMHeadModel):
method forward (line 43) | def forward(self,
method get_additional_param_state_dict (line 75) | def get_additional_param_state_dict(self):
method load_additional_param_state_dict (line 94) | def load_additional_param_state_dict(self, submodel_weights: dict):
class GPT2LMHeadSubModel (line 119) | class GPT2LMHeadSubModel(OffsiteTuningSubModel):
method __init__ (line 121) | def __init__(
method get_base_model (line 141) | def get_base_model(self):
method get_model_transformer_blocks (line 149) | def get_model_transformer_blocks(self, model: GPT2LMHeadModel):
method get_additional_param_state_dict (line 152) | def get_additional_param_state_dict(self):
method load_additional_param_state_dict (line 171) | def load_additional_param_state_dict(self, submodel_weights: dict):
method forward (line 195) | def forward(self,
method parameters (line 227) | def parameters(self, recurse=True):
FILE: python/fate_llm/model_zoo/offsite_tuning/llama.py
class LlamaMainModel (line 20) | class LlamaMainModel(OffsiteTuningMainModel):
method __init__ (line 22) | def __init__(
method get_base_model (line 35) | def get_base_model(self):
method get_model_transformer_blocks (line 38) | def get_model_transformer_blocks(self, model: LlamaForCausalLM):
method get_additional_param_state_dict (line 41) | def get_additional_param_state_dict(self):
method load_additional_param_state_dict (line 56) | def load_additional_param_state_dict(self, submodel_weights: dict):
method forward (line 74) | def forward(self, **kwargs):
class LlamaSubModel (line 78) | class LlamaSubModel(OffsiteTuningSubModel):
method __init__ (line 80) | def __init__(
method get_base_model (line 100) | def get_base_model(self):
method get_model_transformer_blocks (line 108) | def get_model_transformer_blocks(self, model: LlamaForCausalLM):
method get_additional_param_state_dict (line 111) | def get_additional_param_state_dict(self):
method load_additional_param_state_dict (line 126) | def load_additional_param_state_dict(self, submodel_weights: dict):
method forward (line 144) | def forward(self, **kwargs):
method parameters (line 147) | def parameters(self, recurse=True):
FILE: python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py
function get_dropout_emulator_and_adapters (line 25) | def get_dropout_emulator_and_adapters(
function split_numpy_array (line 61) | def split_numpy_array(embedding_matrix, n, suffix):
function recover_numpy_array (line 74) | def recover_numpy_array(slices_dict, suffix=""):
class OffsiteTuningBaseModel (line 81) | class OffsiteTuningBaseModel(t.nn.Module):
method __init__ (line 83) | def __init__(self, emulator_layer_num: int, adapter_top_layer_num: int...
method initialize_model (line 97) | def initialize_model(self):
method post_initialization (line 103) | def post_initialization(self):
method get_adapter_top (line 106) | def get_adapter_top(self):
method get_adapter_bottom (line 109) | def get_adapter_bottom(self):
method get_emulator (line 112) | def get_emulator(self):
method get_additional_param_state_dict (line 115) | def get_additional_param_state_dict(self):
method load_additional_param_state_dict (line 119) | def load_additional_param_state_dict(self, submodel_weights: dict):
method _get_numpy_arr (line 123) | def _get_numpy_arr(self, v):
method load_numpy_state_dict (line 133) | def load_numpy_state_dict(self, module_dict, state_dict):
method get_numpy_state_dict (line 144) | def get_numpy_state_dict(self, module_dict):
method get_submodel_weights (line 153) | def get_submodel_weights(self, with_emulator=True) -> dict:
method load_submodel_weights (line 177) | def load_submodel_weights(self, submodel_weights: dict, with_emulator=...
method forward (line 201) | def forward(self, **kwargs):
method get_base_model (line 204) | def get_base_model(self):
method get_model_transformer_blocks (line 207) | def get_model_transformer_blocks(self, model: t.nn.Module):
class OffsiteTuningMainModel (line 211) | class OffsiteTuningMainModel(OffsiteTuningBaseModel):
method post_initialization (line 213) | def post_initialization(self):
class OffsiteTuningSubModel (line 217) | class OffsiteTuningSubModel(OffsiteTuningBaseModel):
method post_initialization (line 219) | def post_initialization(self):
FILE: python/fate_llm/model_zoo/pellm/albert.py
class Albert (line 21) | class Albert(PELLM):
method __init__ (line 26) | def __init__(self, config: dict = None,
method check_config (line 44) | def check_config(self, pretain_path):
FILE: python/fate_llm/model_zoo/pellm/bart.py
class Bart (line 21) | class Bart(PELLM):
method __init__ (line 25) | def __init__(self, config: dict = None,
method check_config (line 42) | def check_config(self, pretrain_path):
FILE: python/fate_llm/model_zoo/pellm/bert.py
class Bert (line 21) | class Bert(PELLM):
method __init__ (line 25) | def __init__(self, config: dict = None,
method check_config (line 42) | def check_config(self, pretrain_path):
FILE: python/fate_llm/model_zoo/pellm/bloom.py
class Bloom (line 21) | class Bloom(PELLM):
method __init__ (line 26) | def __init__(self, config: dict = None,
FILE: python/fate_llm/model_zoo/pellm/chatglm.py
class ChatGLM (line 20) | class ChatGLM(PELLM):
method __init__ (line 21) | def __init__(self,
method init_config (line 38) | def init_config(self):
method add_peft (line 44) | def add_peft(self):
FILE: python/fate_llm/model_zoo/pellm/deberta.py
class Deberta (line 21) | class Deberta(PELLM):
method __init__ (line 26) | def __init__(self, config: dict = None,
method check_config (line 43) | def check_config(self, pretrain_path):
FILE: python/fate_llm/model_zoo/pellm/distilbert.py
class DistilBert (line 21) | class DistilBert(PELLM):
method __init__ (line 25) | def __init__(self, config: dict = None,
method check_config (line 42) | def check_config(self, pretrain_path):
FILE: python/fate_llm/model_zoo/pellm/gpt2.py
class GPT2 (line 21) | class GPT2(PELLM):
method __init__ (line 25) | def __init__(self,
method check_config (line 43) | def check_config(self, pretrain_path):
class GPT2CLM (line 50) | class GPT2CLM(GPT2):
FILE: python/fate_llm/model_zoo/pellm/llama.py
class LLaMa (line 22) | class LLaMa(PELLM):
method __init__ (line 25) | def __init__(self,
method init_base_lm (line 36) | def init_base_lm(self, **kwargs):
method check_config (line 48) | def check_config(self, pretrain_path):
FILE: python/fate_llm/model_zoo/pellm/opt.py
class OPT (line 21) | class OPT(PELLM):
method __init__ (line 26) | def __init__(self, config: dict = None,
FILE: python/fate_llm/model_zoo/pellm/parameter_efficient_llm.py
class PELLM (line 36) | class PELLM(torch.nn.Module):
method __init__ (line 41) | def __init__(self,
method _init_pelm (line 64) | def _init_pelm(self, **kwargs):
method init_lm_with_peft (line 68) | def init_lm_with_peft(self, **kwargs):
method init_config (line 73) | def init_config(self, **kwargs):
method init_base_lm (line 86) | def init_base_lm(self, **kwargs):
method add_peft (line 102) | def add_peft(self):
method model_summary (line 116) | def model_summary(self):
method forward (line 121) | def forward(self, *args, **kwargs):
method save_trainable (line 129) | def save_trainable(self, output_path):
class AutoPELLM (line 133) | class AutoPELLM(PELLM):
method __init__ (line 135) | def __init__(self, **kwargs) -> None:
FILE: python/fate_llm/model_zoo/pellm/qwen.py
class Qwen (line 21) | class Qwen(PELLM):
method __init__ (line 26) | def __init__(self, config: dict = None,
FILE: python/fate_llm/model_zoo/pellm/roberta.py
class Roberta (line 21) | class Roberta(PELLM):
method __init__ (line 25) | def __init__(self, config: dict = None,
method check_config (line 42) | def check_config(self, pretrain_path):
FILE: python/fate_llm/runner/fdkt_runner.py
class FDKTRunner (line 37) | class FDKTRunner(DefaultRunner):
method __init__ (line 38) | def __init__(
method common_setup (line 73) | def common_setup(self, saved_model=None, output_dir=None):
method llm_setup (line 112) | def llm_setup(self, train_set=None, validate_set=None, output_dir=None...
method slm_setup (line 140) | def slm_setup(self, train_set=None, validate_set=None, output_dir=None...
method train (line 169) | def train(
method predict (line 195) | def predict(self, *args, **kwargs):
FILE: python/fate_llm/runner/fedcot_runner.py
function _check_instances (line 43) | def _check_instances(
class FedCoTRunner (line 68) | class FedCoTRunner(NNRunner):
method __init__ (line 69) | def __init__(
method _get_infer_inst (line 111) | def _get_infer_inst(self, init_conf):
method _prepare_data (line 121) | def _prepare_data(self, data, data_name):
method client_setup (line 153) | def client_setup(self, train_set=None, validate_set=None, output_dir=N...
method server_setup (line 225) | def server_setup(self, stage="train"):
method train (line 232) | def train(
method predict (line 267) | def predict(self, test_data: Union[str], saved_model_path: str = None)...
FILE: python/fate_llm/runner/fedkseed_runner.py
class FedKSeedRunner (line 38) | class FedKSeedRunner(DefaultRunner):
method __init__ (line 39) | def __init__(
method client_setup (line 77) | def client_setup(self, train_set=None, validate_set=None, output_dir=N...
method server_setup (line 113) | def server_setup(self, stage="train"):
function maybe_loader_load_from_conf (line 127) | def maybe_loader_load_from_conf(conf):
FILE: python/fate_llm/runner/fedmkt_runner.py
class FedMKTRunner (line 35) | class FedMKTRunner(DefaultRunner):
method __init__ (line 37) | def __init__(
method common_setup (line 83) | def common_setup(self, saved_model=None, output_dir=None):
method llm_setup (line 134) | def llm_setup(self, train_set=None, validate_set=None, output_dir=None...
method slm_setup (line 169) | def slm_setup(self, train_set=None, validate_set=None, output_dir=None...
method train (line 206) | def train(
method predict (line 230) | def predict(self, *args, **kwargs):
FILE: python/fate_llm/runner/homo_seq2seq_runner.py
function _check_instances (line 44) | def _check_instances(
class Seq2SeqRunner (line 81) | class Seq2SeqRunner(DefaultRunner):
method __init__ (line 82) | def __init__(
method client_setup (line 120) | def client_setup(self, train_set=None, validate_set=None, output_dir=N...
method server_setup (line 192) | def server_setup(self, stage="train"):
method predict (line 204) | def predict(self, test_data: Union[str, DataFrame], saved_model_path: ...
FILE: python/fate_llm/runner/inferdpt_runner.py
class InferDPTRunner (line 45) | class InferDPTRunner(NNRunner):
method __init__ (line 47) | def __init__(
method _get_inst (line 71) | def _get_inst(self):
method client_setup (line 79) | def client_setup(self):
method server_setup (line 84) | def server_setup(self):
method _prepare_data (line 89) | def _prepare_data(self, data, data_name):
method train (line 115) | def train(
method predict (line 145) | def predict(
FILE: python/fate_llm/runner/offsite_tuning_runner.py
class OTRunner (line 41) | class OTRunner(Seq2SeqRunner):
method __init__ (line 43) | def __init__(
method setup (line 65) | def setup(self, train_set=None, validate_set=None, output_dir=None, sa...
method server_setup (line 144) | def server_setup(self, stage="train"):
method train (line 157) | def train(
FILE: python/fate_llm/trainer/seq2seq_trainer.py
class _S2STrainingArguments (line 43) | class _S2STrainingArguments(_hf_Seq2SeqTrainingArguments):
method __post_init__ (line 60) | def __post_init__(self):
class Seq2SeqTrainingArguments (line 75) | class Seq2SeqTrainingArguments(_S2STrainingArguments):
method to_dict (line 78) | def to_dict(self):
class HomoSeq2SeqTrainerClient (line 88) | class HomoSeq2SeqTrainerClient(Seq2SeqTrainer, HomoTrainerMixin):
method __init__ (line 90) | def __init__(
method _save (line 150) | def _save(
Condensed preview — 172 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (843K chars).
[
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 3361,
"preview": "# FATE-LLM\nFATE-LLM is a framework to support federated learning for large language models(LLMs) and small language mode"
},
{
"path": "RELEASE.md",
"chars": 2933,
"preview": "## Release 2.2.0\n### Major Features and Improvements\n* Integrate the FedCoT (Federated Chain-of-Thought) algorithm, a no"
},
{
"path": "doc/fate_llm_evaluate.md",
"chars": 5254,
"preview": "## FATE-LLM Python SDK\n\nFATE-LLM Python SDK provides simple API for evaluating large language models.\nBuilt on [lm-evalu"
},
{
"path": "doc/standalone_deploy.md",
"chars": 2533,
"preview": "# FATE-LLM Single-Node Deployment Guide\n\n## 1. Introduction\n\n**Server Configuration:**\n\n- **Quantity:** 1\n- **Configurat"
},
{
"path": "doc/tutorial/fdkt/README.md",
"chars": 724,
"preview": "# FATE-LLM: FDKT\nThe algorithm is based on paper [Federated Domain-Specific Knowledge Transfer on Large Language Models "
},
{
"path": "doc/tutorial/fdkt/fdkt.ipynb",
"chars": 24991,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# Synthesize Data With FDKT\"\n ]\n "
},
{
"path": "doc/tutorial/fedcot/README.md",
"chars": 746,
"preview": "# FATE-LLM: FedCoT\n\nThe algorithm is based on paper [\"FedCoT: Federated Chain-of-Thought Distillation for Large Language"
},
{
"path": "doc/tutorial/fedcot/encoder_decoder_tutorial.ipynb",
"chars": 16168,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"a163d9c2-f9d6-4c61-a8e8-76a3f66c38ae\",\n \"metadata\": {},\n \"so"
},
{
"path": "doc/tutorial/fedcot/fedcot_tutorial.ipynb",
"chars": 49368,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"9234355d-389f-484f-9fc2-7b17563b3390\",\n \"metadata\": {},\n \"so"
},
{
"path": "doc/tutorial/fedkseed/README.md",
"chars": 783,
"preview": "## FedKSeed\n\nThe Algorithm is based on the paper: [Federated Full-Parameter Tuning of Billion-Sized Language Models\nwith"
},
{
"path": "doc/tutorial/fedkseed/fedkseed-example.ipynb",
"chars": 16856,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# Federated Tuning with FedKSeed m"
},
{
"path": "doc/tutorial/fedmkt/README.md",
"chars": 710,
"preview": "# FATE-LLM: FedMKT\n\nThe algorithm is based on paper [\"FedMKT: Federated Mutual Knowledge Transfer for Large and SmallLan"
},
{
"path": "doc/tutorial/fedmkt/fedmkt.ipynb",
"chars": 75222,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# Federated Tuning With FedMKT meth"
},
{
"path": "doc/tutorial/inferdpt/inferdpt_tutorial.ipynb",
"chars": 40220,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"341aeb6e-9e25-4a0e-9664-a32ab11293fa\",\n \"metadata\": {},\n \"so"
},
{
"path": "doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb",
"chars": 42728,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"c2345e19-83eb-4196-9606-74658c8fbdc5\",\n \"metadata\": {},\n \"so"
},
{
"path": "doc/tutorial/offsite_tuning/README.md",
"chars": 1921,
"preview": "\n# Offsite-Tuning\n\n## Standard Offsite-tuning\n\nOffsite-Tuning is designed for the efficient adaptation of large foundati"
},
{
"path": "doc/tutorial/pellm/ChatGLM3-6B_ds.ipynb",
"chars": 16761,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"### Federated ChatGLM3 Tuning with "
},
{
"path": "doc/tutorial/pellm/builtin_pellm_models.md",
"chars": 2031,
"preview": "## Builtin PELLM Models\nFATE-LLM provide some builtin pellm models, users can use them simply to efficiently train their"
},
{
"path": "examples/fedmkt/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "examples/fedmkt/fedmkt.py",
"chars": 10912,
"preview": "from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_fedmkt_runner\nfrom fate_client.pipeline.c"
},
{
"path": "examples/fedmkt/fedmkt_config.yaml",
"chars": 2144,
"preview": "# fedmkt_config.yaml\n\n# Configuration for Lora\nlora_config:\n llm:\n r: 8\n lora_alpha: 16\n lora_dropout: 0.05\n "
},
{
"path": "examples/fedmkt/test_fedmkt_llmsuit.yaml",
"chars": 296,
"preview": "data:\n - file: \n table_name: arc_challenge\n namespace: experiment\n role: guest_0\n - file: \n table_name: ar"
},
{
"path": "examples/offsite_tuning/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "examples/offsite_tuning/offsite_tuning.py",
"chars": 4914,
"preview": "import argparse\nimport yaml\nfrom fate_client.pipeline.components.fate.reader import Reader\nfrom fate_client.pipeline imp"
},
{
"path": "examples/offsite_tuning/offsite_tuning_config.yaml",
"chars": 1411,
"preview": "# params.yaml\n\npaths:\n pretrained_model_path: 'gpt2'\n\npipeline:\n guest: '9999'\n arbiter: '9999'\n namespace: 'experim"
},
{
"path": "examples/offsite_tuning/test_offsite_tuning_llmsuite.yaml",
"chars": 290,
"preview": "data:\n - file: \n table_name: sciq\n namespace: experiment\n role: guest_0\n - file: \n table_name: sciq\n na"
},
{
"path": "examples/pellm/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "examples/pellm/bloom_lora_config.yaml",
"chars": 1178,
"preview": "data:\n guest:\n namespace: experiment\n name: ad\n host:\n namespace: experiment\n name: ad\nepoch: 1\nbatch_size"
},
{
"path": "examples/pellm/test_bloom_lora.py",
"chars": 3975,
"preview": "import time\nfrom fate_client.pipeline.components.fate.reader import Reader\nfrom fate_client.pipeline import FateFlowPipe"
},
{
"path": "examples/pellm/test_pellm_llmsuite.yaml",
"chars": 626,
"preview": "data:\n - file: examples/data/AdvertiseGen/train.json\n table_name: ad\n namespace: experiment\n role: guest_0\n -"
},
{
"path": "python/MANIFEST.in",
"chars": 89,
"preview": "include fate_llm/dataset/data_config/*yaml\ninclude python/fate_llm/evaluate/tasks/*/*yaml"
},
{
"path": "python/fate_llm/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/algo/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/algo/dp/__init__.py",
"chars": 742,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/dp/dp_trainer.py",
"chars": 5727,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/dp/opacus_compatibility/__init__.py",
"chars": 1144,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/dp/opacus_compatibility/grad_sample/__init__.py",
"chars": 616,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/dp/opacus_compatibility/grad_sample/embedding.py",
"chars": 2142,
"preview": "#\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Li"
},
{
"path": "python/fate_llm/algo/dp/opacus_compatibility/optimizers/__init__.py",
"chars": 616,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/dp/opacus_compatibility/optimizers/optimizer.py",
"chars": 1617,
"preview": "#\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Li"
},
{
"path": "python/fate_llm/algo/dp/opacus_compatibility/transformers_compate.py",
"chars": 1563,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fdkt/__init__.py",
"chars": 772,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fdkt/cluster/__init__.py",
"chars": 616,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fdkt/cluster/cluster.py",
"chars": 1458,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fdkt/cluster/cluster_method.py",
"chars": 1228,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fdkt/fdkt_data_aug.py",
"chars": 12020,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fdkt/inference_inst.py",
"chars": 1226,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fdkt/utils/__init__.py",
"chars": 616,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fdkt/utils/dp_loss.py",
"chars": 2554,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fdkt/utils/invalid_data_filter.py",
"chars": 1174,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fdkt/utils/text_generate.py",
"chars": 3872,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedavg/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/algo/fedavg/fedavg.py",
"chars": 3773,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedcollm/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/algo/fedcollm/fedcollm.py",
"chars": 12090,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedcollm/fedcollm_trainer.py",
"chars": 4477,
"preview": "#\n# NOTE: The implementations of FedMKTTrainer is modified from FuseAI/FuseLLM\n# Copyright FuseAI\n#\n#\n# Copyright 2019 "
},
{
"path": "python/fate_llm/algo/fedcollm/fedcollm_training_args.py",
"chars": 2840,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedcot/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/algo/fedcot/encoder_decoder/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/algo/fedcot/encoder_decoder/init/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/algo/fedcot/encoder_decoder/init/default_init.py",
"chars": 1659,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedcot/encoder_decoder/slm_encoder_decoder.py",
"chars": 3512,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedcot/fedcot_trainer.py",
"chars": 9507,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedcot/slm_encoder_decoder_trainer.py",
"chars": 1033,
"preview": "from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainer\nfrom transformers import DataCollatorForSeq2Seq\nfrom transfo"
},
{
"path": "python/fate_llm/algo/fedkseed/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/algo/fedkseed/args.py",
"chars": 1077,
"preview": "from dataclasses import dataclass, field\n\n\n@dataclass\nclass KSeedTrainingArguments:\n \"\"\"\n TrainingArguments is the"
},
{
"path": "python/fate_llm/algo/fedkseed/fedkseed.py",
"chars": 6903,
"preview": "import copy\nimport logging\nfrom dataclasses import dataclass, field\nfrom typing import List, Mapping\n\nimport torch\nfrom "
},
{
"path": "python/fate_llm/algo/fedkseed/optimizer.py",
"chars": 10205,
"preview": "\"\"\"\nThe implementations of ZerothOrderOptimizer and KSeedZerothOrderOptimizer is\nadapted from https://github.com/princet"
},
{
"path": "python/fate_llm/algo/fedkseed/pytorch_utils.py",
"chars": 2074,
"preview": "from typing import List\n\nfrom transformers.pytorch_utils import ALL_LAYERNORM_LAYERS\nfrom transformers.trainer_pt_utils "
},
{
"path": "python/fate_llm/algo/fedkseed/trainer.py",
"chars": 5922,
"preview": "import logging\nfrom typing import Dict, Union, Any, Tuple\nfrom typing import Optional, List, Callable\n\nimport torch\nfrom"
},
{
"path": "python/fate_llm/algo/fedkseed/zo_utils.py",
"chars": 2367,
"preview": "from typing import List\n\nimport torch\n\n\ndef probability_from_amps(amps: List[List[float]], clip):\n \"\"\"\n Get the pr"
},
{
"path": "python/fate_llm/algo/fedmkt/__init__.py",
"chars": 798,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedmkt/fedmkt.py",
"chars": 22753,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedmkt/fedmkt_data_collator.py",
"chars": 5314,
"preview": "#\n# NOTE: The implementations of DataCollatorForFedMKT is modified from FuseAI/FuseLLM\n# Copyright FuseAI/FuseLLM\n#\n#\n# "
},
{
"path": "python/fate_llm/algo/fedmkt/fedmkt_trainer.py",
"chars": 6522,
"preview": "#\n# NOTE: The implementations of FedMKTTrainer is modified from FuseAI/FuseLLM\n# Copyright FuseAI\n#\n#\n# Copyright 2019 "
},
{
"path": "python/fate_llm/algo/fedmkt/token_alignment/__init__.py",
"chars": 616,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedmkt/token_alignment/spectal_token_mapping.py",
"chars": 915,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedmkt/token_alignment/token_align.py",
"chars": 19335,
"preview": "#\n# NOTE: The dtw function is copied from FuseAI/FuseLLM\n# and the align_blending_model_logits_with_base_model_log"
},
{
"path": "python/fate_llm/algo/fedmkt/token_alignment/vocab_mapping.py",
"chars": 2772,
"preview": "#\n# NOTE: The find_best_mapping function is copied from FuseAI/FuseLLM\n# Copyright FuseAI/FuseLLM\n#\n#\n# Copyright 2019 "
},
{
"path": "python/fate_llm/algo/fedmkt/utils/__init__.py",
"chars": 616,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedmkt/utils/dataset_sync_util.py",
"chars": 2786,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedmkt/utils/generate_logit_utils.py",
"chars": 4124,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedmkt/utils/tokenizer_tool.py",
"chars": 807,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/fedmkt/utils/vars_define.py",
"chars": 998,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/inferdpt/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/algo/inferdpt/_encode_decode.py",
"chars": 1132,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/inferdpt/inferdpt.py",
"chars": 6778,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/inferdpt/init/_init.py",
"chars": 800,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/inferdpt/init/default_init.py",
"chars": 1949,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/inferdpt/utils.py",
"chars": 9469,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/algo/offsite_tuning/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/algo/offsite_tuning/offsite_tuning.py",
"chars": 7798,
"preview": "from fate.ml.aggregator.base import Aggregator\nfrom fate_llm.algo.fedavg.fedavg import Seq2SeqFedAVGClient, Seq2SeqFedAV"
},
{
"path": "python/fate_llm/algo/ppc-gpt/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/data/data_collator/__init__.py",
"chars": 616,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/data/data_collator/cust_data_collator.py",
"chars": 2124,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/data/data_collator/fedcot_collator.py",
"chars": 705,
"preview": "from transformers import DataCollatorForSeq2Seq \nfrom transformers import AutoTokenizer\nimport pandas as pd\n\nclass Prefi"
},
{
"path": "python/fate_llm/data/tokenizers/__init__.py",
"chars": 616,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/data/tokenizers/cust_tokenizer.py",
"chars": 1840,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/dataset/__init__.py",
"chars": 616,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/dataset/data_config/__init__.py",
"chars": 283,
"preview": "import os\n# absolute path to current directory\nparent_dir = os.path.dirname(os.path.realpath(__file__))\n\nDATA_CONFIG_TEM"
},
{
"path": "python/fate_llm/dataset/data_config/default_ag_news.yaml",
"chars": 1669,
"preview": "dataset_kwargs:\n data_files: ag_news_review/AGnews/train.json\ndataset_path: json\ndoc_to_target: '{{label}}'\nmetric_list"
},
{
"path": "python/fate_llm/dataset/data_config/default_yelp_review.yaml",
"chars": 2039,
"preview": "dataset_kwargs:\n data_files: yelp_review/Health/train.json\ndataset_path: json\ndoc_to_target: '{{label}}'\nmetric_list:\n-"
},
{
"path": "python/fate_llm/dataset/fedcot_dataset.py",
"chars": 2094,
"preview": "from fate_llm.dataset.input_output_dataset import InputOutputDataset\nfrom transformers.trainer_pt_utils import LabelSmoo"
},
{
"path": "python/fate_llm/dataset/flex_dataset.py",
"chars": 15141,
"preview": "#\n# Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/dataset/hf_dataset.py",
"chars": 8200,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/dataset/input_output_dataset.py",
"chars": 4203,
"preview": "from fate.ml.nn.dataset.base import Dataset\nfrom transformers.trainer_pt_utils import LabelSmoother\nfrom typing import L"
},
{
"path": "python/fate_llm/dataset/prompt_dataset.py",
"chars": 8007,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/dataset/qa_dataset.py",
"chars": 11338,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/dataset/seq_cls_dataset.py",
"chars": 4128,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/evaluate/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/evaluate/scripts/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/evaluate/scripts/_options.py",
"chars": 2214,
"preview": "import time\n\nimport click\n\nfrom ..utils.config import parse_config, default_eval_config\nfrom ..utils.config import _set_"
},
{
"path": "python/fate_llm/evaluate/scripts/config_cli.py",
"chars": 1704,
"preview": "#\n# Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/evaluate/scripts/data_cli.py",
"chars": 1527,
"preview": "#\n# Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/evaluate/scripts/eval_cli.py",
"chars": 4871,
"preview": "#\n# Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/evaluate/scripts/fate_llm_cli.py",
"chars": 1691,
"preview": "#\n# Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/evaluate/tasks/__init__.py",
"chars": 3267,
"preview": "#\n# Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/evaluate/tasks/advertise_gen/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/evaluate/tasks/advertise_gen/advertise_utils.py",
"chars": 519,
"preview": "# adopted from https://github.com/huggingface/datasets/blob/main/metrics/rouge/rouge.py\n\n\nfrom rouge_score import rouge_"
},
{
"path": "python/fate_llm/evaluate/tasks/advertise_gen/default_advertise_gen.yaml",
"chars": 333,
"preview": "dataset_kwargs:\n data_files:\n train: train.json\n validation: dev.json\ndataset_path: json\ndoc_to_target: '{{summar"
},
{
"path": "python/fate_llm/evaluate/tasks/dolly_15k/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/evaluate/tasks/dolly_15k/default_dolly_15k.yaml",
"chars": 323,
"preview": "dataset_kwargs:\n data_files: databricks-dolly-15k.jsonl\ndataset_path: json\ndoc_to_target: '{{response}}'\ndoc_to_text: !"
},
{
"path": "python/fate_llm/evaluate/tasks/dolly_15k/dolly_utils.py",
"chars": 1648,
"preview": "# adopted from https://github.com/huggingface/datasets/blob/main/metrics/rouge/rouge.py\n\n\nfrom rouge_score import rouge_"
},
{
"path": "python/fate_llm/evaluate/utils/__init__.py",
"chars": 46,
"preview": "from ._parser import LlmJob, LlmPair, LlmSuite"
},
{
"path": "python/fate_llm/evaluate/utils/_io.py",
"chars": 1656,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/evaluate/utils/_parser.py",
"chars": 5960,
"preview": "#\n# Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/evaluate/utils/config.py",
"chars": 3471,
"preview": "#\n# Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/evaluate/utils/data_tools.py",
"chars": 1822,
"preview": "#\n# Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/evaluate/utils/llm_evaluator.py",
"chars": 10350,
"preview": "#\n# Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/evaluate/utils/model_tools.py",
"chars": 2003,
"preview": "#\n# Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/inference/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/inference/api.py",
"chars": 1443,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/inference/hf_qw.py",
"chars": 1520,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/inference/inference_base.py",
"chars": 826,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/inference/vllm.py",
"chars": 1724,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/__init__.py",
"chars": 616,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/embedding_transformer/__init__.py",
"chars": 616,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/embedding_transformer/st_model.py",
"chars": 2789,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/hf_model.py",
"chars": 1317,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/offsite_tuning/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/model_zoo/offsite_tuning/bloom.py",
"chars": 8300,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/offsite_tuning/gpt2.py",
"chars": 9526,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/offsite_tuning/llama.py",
"chars": 5965,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py",
"chars": 8095,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/pellm/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/model_zoo/pellm/albert.py",
"chars": 1809,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/pellm/bart.py",
"chars": 1748,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/pellm/bert.py",
"chars": 1748,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/pellm/bloom.py",
"chars": 1364,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/pellm/chatglm.py",
"chars": 1778,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/pellm/deberta.py",
"chars": 1773,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/pellm/distilbert.py",
"chars": 1796,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/pellm/gpt2.py",
"chars": 1850,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/pellm/llama.py",
"chars": 2159,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/pellm/opt.py",
"chars": 1352,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/pellm/parameter_efficient_llm.py",
"chars": 4854,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/pellm/qwen.py",
"chars": 1363,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/model_zoo/pellm/roberta.py",
"chars": 1772,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/runner/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/runner/fdkt_runner.py",
"chars": 7225,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/runner/fedcot_runner.py",
"chars": 11314,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/runner/fedkseed_runner.py",
"chars": 5099,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/runner/fedmkt_runner.py",
"chars": 8941,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/runner/homo_seq2seq_runner.py",
"chars": 9760,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/runner/inferdpt_runner.py",
"chars": 6431,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/runner/offsite_tuning_runner.py",
"chars": 6848,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/fate_llm/trainer/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "python/fate_llm/trainer/seq2seq_trainer.py",
"chars": 6428,
"preview": "#\n# Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "python/requirements.txt",
"chars": 227,
"preview": "accelerate==0.27.2\ndeepspeed==0.13.3\npeft==0.8.2\nsentencepiece==0.2.0\nlm_eval==0.4.2\nrouge-score==0.1.2\ndatasets==2.18.0"
},
{
"path": "python/setup.py",
"chars": 2199,
"preview": "# -*- coding: utf-8 -*-\n# \n# Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n# Licensed under the Apache Licen"
}
]
About this extraction
This page contains the full source code of the FederatedAI/FATE-LLM GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 172 files (765.2 KB), approximately 201.7k tokens, and a symbol index with 637 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.