[
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# FATE-LLM\nFATE-LLM is a framework to support federated learning for large language models(LLMs) and small language models(SLMs).\n<div align=\"center\">\n  <img src=\"./doc/images/fate-llm-show.png\" height=\"300\">\n</div>\n\n## Design Principle\n- Federated learning for large language models(LLMs) and small language models(SLMs).\n- Promote training efficiency of federated LLMs using Parameter-Efficient methods.\n- Protect the IP of LLMs using FedIPR.\n- Protect data privacy during training and inference through privacy preserving mechanisms.\n<div align=\"center\">\n  <img src=\"./doc/images/fate-llm-plan.png\">\n</div>\n\n### Standalone deployment\n* To deploy FATE-LLM v2.2.0 or higher version, three ways are provided, please refer [deploy tutorial](./doc/standalone_deploy.md) for more details:\n  * deploy with FATE only from pypi then using Launcher to run tasks\n  * deploy with FATE、FATE-Flow、FATE-Client from pypi, user can run tasks with Pipeline  \n* To deploy lower versions: please refer to [FATE-Standalone deployment](https://github.com/FederatedAI/FATE#standalone-deployment).   
\n  * To deploy FATE-LLM v2.0.* - FATE-LLM v2.1.*, deploy FATE-Standalone with version >= 2.1, then make a new directory `{fate_install}/fate_llm` and clone the code into it, install the python requirements, and add `{fate_install}/fate_llm/python` to `PYTHONPATH` \n  * To deploy FATE-LLM v1.x, deploy FATE-Standalone with 1.11.3 <= version < 2.0, then copy directory `python/fate_llm` to `{fate_install}/fate/python/fate_llm` \n\n### Cluster deployment\nUse [FATE-LLM deployment packages](https://github.com/FederatedAI/FATE/wiki/Download#llm%E9%83%A8%E7%BD%B2%E5%8C%85) to deploy,  refer to [FATE-Cluster deployment](https://github.com/FederatedAI/FATE#cluster-deployment) for more deployment details.\n\n## Quick Start\n\n- [Federated ChatGLM3-6B Training](doc/tutorial/pellm/ChatGLM3-6B_ds.ipynb)\n- [Builtin Models In PELLM](doc/tutorial/pellm/builtin_pellm_models.md)\n- [FedMKT: Federated Mutual Knowledge Transfer for Large and Small\nLanguage Models](./doc/tutorial/fedmkt/)\n- [FedCoT: Federated Chain-of-Thought Distillation for Large Language Models](./doc/tutorial/fedcot)\n- [PPC-GPT: Federated Task-Specific Compression of Large Language\nModels via Pruning and Chain-of-Thought Distillation](https://aclanthology.org/2025.emnlp-main.747.pdf)\n- [FDKT: Federated Domain-Specific Knowledge Transfer on Large Language Models Using Synthetic Data](./doc/tutorial/fdkt)\n- [Offsite Tuning: Transfer Learning without Full Model](./doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb)\n- [FedKSeed: Federated Full-Parameter Tuning of Billion-Sized Language Models\nwith Communication Cost under 18 Kilobytes](./doc/tutorial/fedkseed/)\n- [InferDPT: Privacy-preserving Inference for Black-box Large Language Models](./doc/tutorial/inferdpt/inferdpt_tutorial.ipynb)\n\n## FATE-LLM Evaluate\n\n- [Python SDK & CLI Usage Guide](./doc/fate_llm_evaluate.md)\n\n## Citation\n\nIf you publish work that uses FATE-LLM, please cite FATE-LLM as follows:\n```\n@article{fan2023fate,\n  
title={Fate-llm: A industrial grade federated learning framework for large language models},\n  author={Fan, Tao and Kang, Yan and Ma, Guoqiang and Chen, Weijing and Wei, Wenbin and Fan, Lixin and Yang, Qiang},\n  journal={Symposium on Advances and Open Problems in Large Language Models (LLM@IJCAI'23)},\n  year={2023}\n}\n```\n"
  },
  {
    "path": "RELEASE.md",
    "content": "## Release 2.2.0\n### Major Features and Improvements\n* Integrate the FedCoT (Federated Chain-of-Thought) algorithm, a novel framework that enhances local small language models (SLMs) using differentially private protected Chain of Thoughts (Cot) generated by remote LLMs:\n  * Implement InferDPT for privacy-preserving Cot generation.\n  * Support an encoder-decoder mechanism for privacy-preserving Cot generation.\n  * Add prefix trainers for step-by-step distillation and text encoder-decoder training.\n* Integrate the FDKT algorithm, a framework that enables domain-specific knowledge transfer from LLMs to SLMs while preserving SLM data privacy\n* Deployment Optimization: support installation of FATE-LLM by PyPi\n\n\n## Release 2.1.0\n### Major Features and Improvements\n* New FedMKT Federated Tuning Algorithms: Federated Mutual Knowledge Transfer for Large and Small Language Models\n  * Support three distinct scenarios: Heterogeneous, Homogeneous and One-to-One\n  * Support LLM to SLM one-way knowledge transfer\n* Introduce the InferDPT algorithm, which leverages differential privacy (DP) to facilitate privacy-preserving inference for large language models.\n* Introduce FATE-LLM Evaluate: evaluate FATE-LLM models in few lines with Python SDK or simple CLI commands(`fate_llm evaluate`), built-in cases included\n\n\n## Release 2.0.0\n### Major Features and Improvements\n* Adapt to fate-v2.0 framework:\n  * Migrate parameter-efficient fine-tuning training methods and models. 
\n  * Migrate Standard Offsite-Tuning and Extended Offsite-Tuning（Federated Offsite-Tuning+）\n  * New trainer, dataset, data_processing function design\n* New FedKSeed Federated Tuning Algorithm: train large language models in a federated learning setting with extremely low communication cost\n\n## Release 1.3.0\n### Major Features and Improvements\n* FTL-LLM（Federated Learning + Transfer Learning + LLM）\n  * Standard Offsite-Tuning and Extended Offsite-Tuning（Federated Offsite-Tuning+）now supported\n  * Framework available for Emulator and Adapter development\n  * New Offsite-Tuning Trainer introduced\n  * Includes built-in models such as GPT-2 family, Llama7b, and Bloom family\n* FedIPR\n  * Introduced WatermarkDataset as the foundational dataset class for backdoor-based watermarks\n  * Added SignConv and SignLayerNorm blocks for feature-based watermark models\n  * New FedIPR Trainer available\n  * Built-in models with feature-based watermarks include Alexnet, Resnet18, DistilBert, and GPT2\n* More models support parameter-efficient fine-tuning: ChatGLM2-6B and Bloom-7B1\n\n\n## Release 1.2.0\n### Major Features and Improvements\n* Support Federated Training of LLaMA-7B with parameter-efficient fine-tuning.\n\n\n## Release 1.1.0\n### Major Features and Improvements\n* Support Federated Training of ChatGLM-6B with parameter-efficient fine-tuning adapters: like Lora and P-Tuning V2 etc.\n* Integration of `peft`, which supports many parameter-efficient adapters.\n"
  },
  {
    "path": "doc/fate_llm_evaluate.md",
    "content": "## FATE-LLM Python SDK\n\nFATE-LLM Python SDK provides simple API for evaluating large language models.\nBuilt on [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/), our evaluation tool may be used on pre-trained models from Huggingface, local-built models, as well as FATE-LLM models. \n[Built-in datasets](#built-in-tasks) currently include Dolly-15k and Advertise Generation.\nBelow shows how to evaluate given llm model in few lines. For quick single-model evaluation, below steps should suffice, however, if comparative evaluation among multiple models is desired, CLI is recommended.\n\n```python\n    from lm_eval.models.huggingface import HFLM\n    from fate_llm.evaluate.utils import llm_evaluator\n\n    # download data for built-in tasks if running fate-llm evaluation for the first time \n    # alternatively, use CLI `fate-llm data download` to download data\n    llm_evaluator.download_task(\"dolly-15k\")\n    # set paths of built-in tasks\n    llm_evaluator.init_tasks()\n    # load model\n    bloom_lm = HFLM(pretrained='bloom-560')\n    # if loading local model, specify peft storage location\n    # gpt2_lm = HFLM(pretrained='bloom-560m', peft_path_format=\"path/to/peft\")\n    # run evaluation\n    llm_evaluator.evaluate(model=bloom_lm, tasks=\"dolly-15k\", show_result=True)\n```\n\nWhen network allows, or if already cached, tasks from lm-evaluation may be provided for evaluation in similar style.\n\n```python\n    from lm_eval.models.huggingface import HFLM\n    from fate_llm.evaluate.utils import llm_evaluator\n    # load model\n    bloom_lm = HFLM(pretrained='bloom-560')\n    # if loading local model, specify peft storage location\n    # bloom_lm = HFLM(pretrained='bloom-560m', peft_path_format=\"path/to/peft\")\n    # run evaluation\n    llm_evaluator.evaluate(model=gpt2_lm, tasks=\"ceval\", show_result=True)\n```\n\n## FATE-LLM Command Line Interface\n\nFATE LLM provides built-in tasks for comparing evaluation 
results of different llm models. \nAlternatively, user may provide arbitrary tasks for evaluation.\n\n### install\n\n```bash\ncd {path_to_fate_llm}/python\npip install -e .\n```\n\n### command options\n\n```bash\nfate_llm --help\n```\n\n#### evaluate:\n\n\n1. in:\n\n   ```bash\n   fate_llm evaluate -i <path1 to *.yaml>\n   ```\n\n   will run llm at\n   *path1*\n\n2. eval-config:\n\n    ```bash\n    fate_llm evaluate -i <path1 to *.yaml> -c <path2>\n    ```\n  \n\n   will run llm testsuites in *path1* with evaluation configuration set to *path2*\n\n3. result-output:\n\n    ```bash\n    fate_llm evaluate -i <path1 contains *.yaml> -o <path2>\n    ```\n\n    will run llm testsuites in *path1* with evaluation result output stored in *path2*\n\n### config\n\n```bash\nfate_llm config --help\n```\n\n1. new:\n    ```bash\n    fate_llm config new\n    ```\n\n    will create a new evaluation configuration file in current directory\n\n2. show:\n\n    ```bash\n    fate_llm config show\n    ```\n\n    will show current evaluation configuration \n\n3. edit:\n\n    ```bash\n    fate_llm config edit \n    ```\n\n    will edit evaluation configuration\n\n### data\n    \n    ```bash\n    fate_llm data --help\n    ```\n1. download:\n\n    ```bash\n    fate_llm data download -t <task1> -t <task2> ...\n    ```\n\n    will download corresponding data for given tasks \n\n\n### FATE-LLM Eval job configuration\n\nConfiguration of jobs should be specified in a yaml file. 
\n\nA FATE-LLM testsuite includes the following elements:\n\n- job group: each group includes arbitrary number of jobs with paths\n  to corresponding script and configuration\n\n    - job: name of evaluation job to be run, must be unique within each group\n      list\n        - pretrained: path to pretrained model, should be either model name from Huggingface or relative path to\n          testsuite\n        - peft: path to peft file, should be relative to testsuite, \n          optional\n        - tasks: list of tasks to be evaluated, optional for jobs skipping evaluation\n        - include_path: should be specified if tasks are user-defined\n        - eval_conf: path to evaluation configuration file, should be\n          relative to testsuite; if not provided, will use default conf\n\n      ```yaml\n          bloom_lora:\n            pretrained: \"bloom-560m\"\n            peft_path_format: \"{{fate_base}}/fate_flow/model/{{job_id}}/guest/{{party_id}}/{{model_task_name}}/0/output/output_model/model_directory\"\n            tasks:\n              - \"dolly-15k\"\n\n      ```\n\n- llm suite\n\n  ```yaml\n     bloom_suite:\n      bloom_zero_shot:\n        pretrained: \"bloom-560m\"\n        tasks:\n          - \"dolly-15k\"\n  ```\n  \n## Built-in Tasks\n\nCurrently, we include the following tasks in FATE-LLM Evaluate:\n\n| Task Name |     Alias     | Task Type  | Metric  |                                  source                                   |\n|:---------:|:-------------:|:----------:|:-------:|:-------------------------------------------------------------------------:|\n| Dolly-15k |   dolly-15k   | generation | rouge-L |  [link](https://huggingface.co/datasets/databricks/databricks-dolly-15k)  |\n|   ADGEN   | advertise-gen | generation | rouge-L |                                 [link](https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/README_en.md#instructions)                                  |\n\nUse corresponding alias to reference tasks in the 
system.\n"
  },
  {
    "path": "doc/standalone_deploy.md",
    "content": "# FATE-LLM Single-Node Deployment Guide\n\n## 1. Introduction\n\n**Server Configuration:**\n\n- **Quantity:** 1\n- **Configuration:** 8 cores / 16GB memory / 500GB hard disk / GPU Machine\n- **Operating System:** CentOS Linux release 7\n- **User:** User: app owner:apps\n\nThe single-node version provides 3 deployment methods, which can be selected based on your needs:\n- Install FATE-LLM from PyPI With FATE\n- Install FATE-LLM from PyPI with FATE, FATE-Flow, FATE-Client\n\n## 2. Install FATE-LLM from PyPI With FATE\nIn this way, user can run tasks with Launcher, a convenient way for fast experimental using.\n\n### 2.1 Installing Python Environment\n- Prepare and install [conda](https://docs.conda.io/projects/miniconda/en/latest/) environment.\n- Create a virtual environment:\n\n```shell\n# FATE-LLM requires Python >= 3.10\nconda create -n fate_env python=3.10\nconda activate fate_env\n```\n\n### 2.2 Installing FATE-LLM\nThis section introduces how to install FATE-LLM from pypi with FATE, execute the following command to install FATE-LLM. \n\n```shell\npip install fate_llm[fate]==2.2.0\n```\n\n### 2.3 Usage\nAfter installing successfully, please refer to [tutorials](../README.md#quick-start) to run tasks, tasks describe in the tutorials running will Launcher are all supported.\n\n\n## 3. Install FATE-LLM from PyPI with FATE, FATE-Flow, FATE-Client\nIn this way, user can run tasks with Pipeline or Launcher. 
\n\n### 3.1 Installing Python Environment\nPlease refer to section 2.1\n\n### 3.2 Installing FATE-LLM with FATE, FATE-Flow, FATE-Client\n\n```shell\npip install fate_llm[fate,fate_flow,fate_client]==2.2.0\n```\n\n### 3.3 Service Initialization\n\n```shell\nmkdir fate_workspace\nfate_flow init --ip 127.0.0.1 --port 9380 --home $(pwd)/fate_workspace\npipeline init --ip 127.0.0.1 --port 9380\n```\n- `ip`: The IP address where the service runs.\n- `port`: The HTTP port the service runs on.\n- `home`: The data storage directory, including data, models, logs, job configurations, and SQLite databases.\n\n### 3.4 Start Fate-Flow Service\n\n```shell\nfate_flow start\nfate_flow status # make sure fate_flow service is started\n```\n\nFATE-Flow also provides other instructions like stop and restart, use only if users want to stop/restart fate_flow services.\n```shell\n# Warning: normal installing process does not need to execute stop/restart instructions.\nfate_flow stop\nfate_flow restart\n```\n\n### 3.5 Usage\nPlease refer to [tutorials](../README.md#quick-start) for more usage guides; tasks described in the tutorials that run with Pipeline or Launcher are all supported.\n"
  },
  {
    "path": "doc/tutorial/fdkt/README.md",
    "content": "# FATE-LLM: FDKT\nThe algorithm is based on paper [Federated Domain-Specific Knowledge Transfer on Large Language Models Using Synthetic Data](https://arxiv.org/pdf/2405.14212), \na novel framework that enables domain-specific knowledge transfer from LLMs to SLMs while preserving SLM data privacy.\n\n## Citation\nIf you publish work that uses FDKT, please cite FDKT as follows:\n```\n@article{li2024federated,\n  title={Federated Domain-Specific Knowledge Transfer on Large Language Models Using Synthetic Data},\n  author={Li, Haoran and Zhao, Xinyuan and Guo, Dadi and Gu, Hanlin and Zeng, Ziqian and Han, Yuxing and Song, Yangqiu and Fan, Lixin and Yang, Qiang},\n  journal={arXiv preprint arXiv:2405.14212},\n  year={2024}\n}\n```\n"
  },
  {
    "path": "doc/tutorial/fdkt/fdkt.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Synthesize Data With FDKT\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"In this tutoria, we will demonstrate how to Synthesize data using the FATE-LLM framework. In FATE-LLM, we introduce the \\\"FDKT\\\" module,  specifically designed for domain-specific knowledge transfer on large language models using synthetic data. FDKT Algorithm is based on paper [Federated Domain-Specific Knowledge Transfer on\\n\",\n    \"Large Language Models Using Synthetic Data](https://arxiv.org/pdf/2405.14212), We integrate its code into the FATE-LLM framework.  \\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Dataset: Yelp\\n\",\n    \"We processed and sample data of 'Health' subdomain from [Yelp dataset](https://arxiv.org/abs/1509.01626) , the dataset can be downloaded from [here](https://www.yelp.com/dataset). 
\\n\",\n    \"Once the dataset has been downloaded, execute the following command to untar the downloaded dataset.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"```shell\\n\",\n    \"tar -xvf yelp_dataset.tar\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The following code will sample 5000 datalines of 'Health' subdomain, and train data will generated under the folder './processed_data/Health/train.json'\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import json\\n\",\n    \"import sys\\n\",\n    \"import random\\n\",\n    \"from pathlib import Path\\n\",\n    \"random.seed(42)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"base_dir = \\\"./\\\"\\n\",\n    \"business_data_path = os.path.join(base_dir, 'yelp_academic_dataset_business.json')\\n\",\n    \"review_data_path = os.path.join(base_dir, 'yelp_academic_dataset_review.json')\\n\",\n    \"\\n\",\n    \"business_data_file = open(business_data_path, 'r')\\n\",\n    \"review_data_file = open(review_data_path, 'r')\\n\",\n    \"\\n\",\n    \"categories_list = ['Restaurants', 'Shopping', 'Arts', 'Health']\\n\",\n    \"business_dic = {}\\n\",\n    \"data_dict = {}\\n\",\n    \"for category in categories_list:\\n\",\n    \"    business_dic[category] = set()\\n\",\n    \"    data_dict[category] = []\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def get_categories(categories):\\n\",\n    \"    return_list = []\\n\",\n    \"    for category in categories_list:\\n\",\n    \"        if category in categories:\\n\",\n    \"            return_list.append(category)\\n\",\n    \"    return return_list\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"for line in business_data_file.readlines():\\n\",\n    \"    dic = json.loads(line)\\n\",\n    \"    if 
'categories' in dic.keys() and dic['categories'] is not None:\\n\",\n    \"        category = get_categories(dic['categories'])\\n\",\n    \"        if len(category) == 1:\\n\",\n    \"            business_dic[category[0]].add(dic['business_id'])\\n\",\n    \"\\n\",\n    \"# for category in categories_list:\\n\",\n    \"for line in review_data_file.readlines():\\n\",\n    \"    dic = json.loads(line)\\n\",\n    \"    if 'business_id' in dic.keys() and dic['business_id'] is not None:\\n\",\n    \"        for category in categories_list:\\n\",\n    \"            if dic['business_id'] in business_dic[category]:\\n\",\n    \"                if dic['text'] is not None and dic['stars'] is not None:\\n\",\n    \"                    data_dict[category].append({'text': dic['text'], 'stars': dic['stars']})\\n\",\n    \"                break\\n\",\n    \"\\n\",\n    \"train_data_path = os.path.join('processed_data', \\\"Health\\\", 'train.json')\\n\",\n    \"os.makedirs(Path(train_data_path).parent, exist_ok=True)\\n\",\n    \"train_data_file = open(train_data_path, 'w')\\n\",\n    \"data_list = data_dict[\\\"Health\\\"]\\n\",\n    \"\\n\",\n    \"sample_data_dict = dict()\\n\",\n    \"\\n\",\n    \"for data in data_list:\\n\",\n    \"    star = int(data[\\\"stars\\\"])\\n\",\n    \"    if star not in sample_data_dict:\\n\",\n    \"        sample_data_dict[star] = []\\n\",\n    \"\\n\",\n    \"    sample_data_dict[star].append(data)\\n\",\n    \"\\n\",\n    \"data_list = []\\n\",\n    \"star_keys = list(sample_data_dict.keys())\\n\",\n    \"for star in star_keys:\\n\",\n    \"    sample_data = sample_data_dict[star][:1000]\\n\",\n    \"    random.shuffle(sample_data)\\n\",\n    \"    data_list.extend(sample_data)\\n\",\n    \"\\n\",\n    \"random.shuffle(data_list)\\n\",\n    \"json.dump(data_list, train_data_file, indent=4)\\n\",\n    \"train_data_file.close()\\n\",\n    \"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    
\"## Models Use\\n\",\n    \"Please download the following models, these models are used for data augmentation process.\\n\",\n    \"\\n\",\n    \"LLM: [Qwen1.5-7B-Chat](https://huggingface.co/Qwen/Qwen1.5-7B-Chat)  \\n\",\n    \"SLM: [gpt2-xl](https://huggingface.co/openai-community/gpt2-xl)\\n\",\n    \"\\n\",\n    \"MeanWhile, 'all-mpnet-base-v2' is used to generate embedding vectors in LLM side.\\n\",\n    \"\\n\",\n    \"Embedding Model:  [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Running FDKT Data Synthetic Process With Launcher (Experimential Using)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### SLM Setting\\n\",\n    \"\\n\",\n    \"In this section, we will introduce some key configurations in SLM side.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 1. loading model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"import transformers\\n\",\n    \"from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"slm_pretrained_path = \\\"gpt2-xl\\\" # modity this to local directory\\n\",\n    \"slm = transformers.AutoModelForCausalLM.from_pretrained(slm_pretrained_path, torch_dtype=torch.bfloat16)\\n\",\n    \"tokenizer = get_tokenizer(slm_pretrained_path)\\n\",\n    \"tokenizer.pad_token_id = tokenizer.eos_token_id\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 2. 
Initialize SLM Training Arugments\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.algo.fdkt.fdkt_data_aug import FDKTTrainingArguments\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"training_args = FDKTTrainingArguments(\\n\",\n    \"    use_cpu=False, # use gpu to do dp(differential privacy) training process\\n\",\n    \"    device_id=0, # the device number of gpu\\n\",\n    \"    num_train_epochs=1, # dp training epochs\\n\",\n    \"    per_device_train_batch_size=2, # batch size of dp training\\n\",\n    \"    slm_generation_batch_size=32, # batch_size to generate data in slm side\\n\",\n    \"    seq_num_for_single_category=300, # data num for each category(label)\\n\",\n    \"    slm_generation_config=dict(\\n\",\n    \"        max_new_tokens=256,\\n\",\n    \"        temperature=1.0,\\n\",\n    \"        top_k=50,\\n\",\n    \"        top_p=0.9,\\n\",\n    \"        repetition_penalty=1.0,\\n\",\n    \"        pad_token_id=tokenizer.eos_token_id\\n\",\n    \"    ),\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 3. Initlaize DataSet Instance\\n\",\n    \"\\n\",\n    \"We provide default templates for dataset \\\"Yelp\\\" and \\\"AGNews\\\", user can refer [here](https://github.com/FederatedAI/FATE-LLM/tree/dev-2.2.0/python/fate_llm/dataset/data_config) for more details. 
If you want to use your own dataset, please provide fields label_key/text_key/augment_format/filter_format/tokenize_format/sub_domain/label_list/few_shot_format/text_with_label_format like the two default templates and pass it as an argument.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.dataset.flex_dataset import FlexDataset\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"ds = FlexDataset(\\n\",\n    \"    tokenizer_name_or_path=slm_pretrained_path,\\n\",\n    \"    load_from=\\\"json\\\",\\n\",\n    \"    data_part=\\\"train\\\",\\n\",\n    \"    dataset_name=\\\"yelp_review\\\", # use default template\\n\",\n    \"    # config=dict/template_path # if dataset_name not equals to \\\"yelp_review\\\" or \\\"ag_news\\\"\\n\",\n    \"    need_preprocess=True,\\n\",\n    \"    select_num=2000, # use data_num=2000 to train, default is None, None means using all data\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### LLM Setting\\n\",\n    \"\\n\",\n    \"In this section, we will introduce some key configurations in LLM side.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 1. 
Deploy VLLM Server And Use OpenAI API Protocol To SpeedUp LLM Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"please copy the following code to local file create_and_start_vllm.sh, then run the bash code by executing \\\"bash create_and_start_vllm.sh\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# create_and_start_vllm.sh\\n\",\n    \"# create vllm enviroment\\n\",\n    \"\\n\",\n    \"python -m venv vllm_venv\\n\",\n    \"source vllm_venv/bin/activate\\n\",\n    \"pip install vllm==0.4.3\\n\",\n    \"pip install numpy==1.26.4 # numpy >= 2.0.0 will raise error, so reinstall numpy<2.0.0\\n\",\n    \"\\n\",\n    \"# please modify Qwen1.5-7B-Chat to local llm model saving path\\n\",\n    \"export CUDA_VISIBLE_DEVICES=1,2\\n\",\n    \"nohup python -m vllm.entrypoints.openai.api_server --host 127.0.0.1 --port 9999 --model Qwen1.5-7B-Chat --dtype=half --enforce-eager --api-key demo --device cuda -tp 2 &\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 2. 
Initialize LLM Training Arugments\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.algo.fdkt.fdkt_data_aug import FDKTTrainingArguments\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"training_args = FDKTTrainingArguments(\\n\",\n    \"    sample_num_per_cluster=4, # use this to estimate the number of clusters, n_clusters=(len(dataset) + sample_num_per_cluster - 1) // sample_num_per_cluster\\n\",\n    \"    filter_prompt_max_length=2**16,\\n\",\n    \"    filter_generation_config=dict(\\n\",\n    \"        max_tokens=512,\\n\",\n    \"    ),\\n\",\n    \"    aug_generation_config=dict(\\n\",\n    \"        max_tokens=4096,\\n\",\n    \"        temperature=0.8,\\n\",\n    \"        top_p=0.9,\\n\",\n    \"    ),\\n\",\n    \"    aug_prompt_num=20000, # prompts use for data augmentation\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 3. Initialize Embedding Generated Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.model_zoo.embedding_transformer.st_model import SentenceTransformerModel\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"embedding_lm = SentenceTransformerModel(model_name_or_path=\\\"all-mpnet-base-v2\\\").load() # modified model_name_or_path to local model saved path\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### 4. 
Initalize OpenAI Api For Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.algo.fdkt.inference_inst import api_init\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"inference_inst = api_init(\\n\",\n    \"    api_url=\\\"http://127.0.0.1:9999/v1/\\\",\\n\",\n    \"    model_name=\\\"Qwen1.5-7B-Chat\\\", # modified model_name to local Meta-Llama-3-8B-Instruct saved path\\n\",\n    \"    api_key=\\\"demo\\\"\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Complete Code \\n\",\n    \"\\n\",\n    \"Please paste the code in \\\"run_fdkt_by_launcher.py\\\" and execute it with the following command. Once the process is finished, augmentation data will be saved in the current directory, whose filename is aug_data_result.json\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"python run_fdkt_by_launcher.py --parties guest:9999 arbiter:10000 --log_level INFO\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import json\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"import torch\\n\",\n    \"from fate.arch import Context\\n\",\n    \"from fate.arch.launchers.multiprocess_launcher import launch\\n\",\n    \"\\n\",\n    \"# please replace the following four variables to local paths\\n\",\n    \"llm_pretrained_path = \\\"Qwen1.5-7B-Chat\\\"\\n\",\n    \"embedding_model_path = \\\"all-mpnet-base-v2\\\"\\n\",\n    \"slm_pretrained_path = \\\"gpt2-xl\\\"\\n\",\n    \"slm_data_path = \\\"./processed_data/Health/train.json\\\"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def get_optimizer(model, optimizer=\\\"adam\\\", lr=1e-4):\\n\",\n    \"    if optimizer == \\\"adam\\\":\\n\",\n    \"        optimizer = 
torch.optim.Adam(params=model.parameters(), lr=lr)\\n\",\n    \"    elif optimizer == \\\"adamw\\\":\\n\",\n    \"        optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)\\n\",\n    \"    else:\\n\",\n    \"        raise NotImplementedError(\\\"Given optimizer type is not supported\\\")\\n\",\n    \"    return optimizer\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def train_slm(ctx):\\n\",\n    \"    import transformers\\n\",\n    \"    from fate_llm.algo.fdkt.fdkt_data_aug import (\\n\",\n    \"        FDKTSLM,\\n\",\n    \"        FDKTTrainingArguments\\n\",\n    \"    )\\n\",\n    \"    from fate_llm.dataset.flex_dataset import FlexDataset\\n\",\n    \"    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\\n\",\n    \"    from transformers.data import DataCollatorForSeq2Seq\\n\",\n    \"\\n\",\n    \"    slm = transformers.AutoModelForCausalLM.from_pretrained(slm_pretrained_path, torch_dtype=torch.bfloat16)\\n\",\n    \"    tokenizer = get_tokenizer(slm_pretrained_path)\\n\",\n    \"    tokenizer.pad_token_id = tokenizer.eos_token_id\\n\",\n    \"    training_args = FDKTTrainingArguments(\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        device_id=0,\\n\",\n    \"        num_train_epochs=1,\\n\",\n    \"        per_device_train_batch_size=2,\\n\",\n    \"        slm_generation_batch_size=32,\\n\",\n    \"        seq_num_for_single_category=2000,\\n\",\n    \"        slm_generation_config=dict(\\n\",\n    \"            max_new_tokens=256,\\n\",\n    \"            do_sample=True,\\n\",\n    \"            temperature=1.0,\\n\",\n    \"            top_k=50,\\n\",\n    \"            top_p=0.9,\\n\",\n    \"            repetition_penalty=1.0,\\n\",\n    \"            pad_token_id=tokenizer.eos_token_id\\n\",\n    \"        ),\\n\",\n    \"        # inference_method=\\\"vllm\\\",\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    ds = FlexDataset(\\n\",\n    \"        tokenizer_name_or_path=slm_pretrained_path,\\n\",\n    \"        
load_from=\\\"json\\\",\\n\",\n    \"        data_part=\\\"train\\\",\\n\",\n    \"        dataset_name=\\\"yelp_review\\\",\\n\",\n    \"        need_preprocess=True,\\n\",\n    \"        select_num=2000,  # use 2000 data to train, default is None, using all data\\n\",\n    \"    )\\n\",\n    \"    ds.load(slm_data_path)\\n\",\n    \"\\n\",\n    \"    fdkt_runner = FDKTSLM(\\n\",\n    \"        ctx=ctx,\\n\",\n    \"        model=slm,\\n\",\n    \"        training_args=training_args,\\n\",\n    \"        tokenizer=tokenizer,\\n\",\n    \"        train_set=ds,\\n\",\n    \"        optimizer=get_optimizer(slm),\\n\",\n    \"        data_collator=DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=tokenizer.pad_token_id)\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    aug_data = fdkt_runner.aug_data()\\n\",\n    \"    with open(\\\"./aug_data_result.json\\\", \\\"w\\\") as fout:\\n\",\n    \"        fout.write(json.dumps(aug_data, indent=4))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def train_llm(ctx):\\n\",\n    \"    from fate_llm.algo.fdkt.fdkt_data_aug import (\\n\",\n    \"        FDKTLLM,\\n\",\n    \"        FDKTTrainingArguments\\n\",\n    \"    )\\n\",\n    \"    from fate_llm.model_zoo.embedding_transformer.st_model import SentenceTransformerModel\\n\",\n    \"    from fate_llm.dataset.flex_dataset import FlexDataset\\n\",\n    \"    from fate_llm.algo.fdkt.inference_inst import api_init, vllm_init\\n\",\n    \"\\n\",\n    \"    embedding_lm = SentenceTransformerModel(model_name_or_path=embedding_model_path).load()\\n\",\n    \"    training_args = FDKTTrainingArguments(\\n\",\n    \"        sample_num_per_cluster=4,\\n\",\n    \"        filter_prompt_max_length=2**14,\\n\",\n    \"        filter_generation_config=dict(\\n\",\n    \"            max_tokens=4096,\\n\",\n    \"        ),\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        aug_generation_config=dict(\\n\",\n    \"            max_tokens=4096,\\n\",\n    \"            
temperature=0.8,\\n\",\n    \"            top_p=0.9,\\n\",\n    \"        ),\\n\",\n    \"        aug_prompt_num=20000,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    ds = FlexDataset(\\n\",\n    \"        tokenizer_name_or_path=llm_pretrained_path,\\n\",\n    \"        load_from=\\\"json\\\",\\n\",\n    \"        data_part=\\\"train\\\",\\n\",\n    \"        dataset_name=\\\"yelp_review\\\",\\n\",\n    \"        need_preprocess=True,\\n\",\n    \"        few_shot_num_per_label=1,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    inference_inst = api_init(\\n\",\n    \"        api_url=\\\"http://127.0.0.1:9999/v1/\\\",\\n\",\n    \"        model_name=llm_pretrained_path,\\n\",\n    \"        api_key=\\\"demo\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    fdkt_runner = FDKTLLM(\\n\",\n    \"        ctx=ctx,\\n\",\n    \"        embedding_model=embedding_lm,\\n\",\n    \"        training_args=training_args,\\n\",\n    \"        dataset=ds,\\n\",\n    \"        inference_inst=inference_inst,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    fdkt_runner.aug_data()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def run(ctx: Context):\\n\",\n    \"    if ctx.is_on_arbiter:\\n\",\n    \"        train_llm(ctx)\\n\",\n    \"    else:\\n\",\n    \"        os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"0\\\"\\n\",\n    \"        train_slm(ctx)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"if __name__ == \\\"__main__\\\":\\n\",\n    \"    launch(run)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Running FDKT with Pipeline (Industrial Using)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Please make sure that FATE and FATE-Flow has been deployed, paste the following code to test_fdkt_by_pipeline.py, the execute \\\"python test_fdkt_by_pipeline.py\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n  
 \"source\": [\n    \"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_fdkt_runner\\n\",\n    \"from fate_client.pipeline.components.fate.nn.algo_params import FDKTTrainingArguments\\n\",\n    \"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\\n\",\n    \"from fate_client.pipeline import FateFlowPipeline\\n\",\n    \"from fate_client.pipeline.components.fate.reader import Reader\\n\",\n    \"from fate_client.pipeline.components.fate.nn.torch import nn, optim\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"guest = '9999'# replace this party id to actual guest party id in your enviroment\\n\",\n    \"arbiter = '9999'# replace this party id to actual arbiter party id in your enviroment\\n\",\n    \"\\n\",\n    \"# please replace the following four variables to local paths\\n\",\n    \"llm_pretrained_path = \\\"Qwen1.5-7B-Chat\\\"\\n\",\n    \"embedding_model_path = \\\"all-mpnet-base-v2/\\\"\\n\",\n    \"slm_pretrained_path = \\\"gpt2-xl\\\"\\n\",\n    \"slm_data_path = \\\"./processed_data/Health/train.json\\\" # should be absolute path\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def get_llm_conf():\\n\",\n    \"    embedding_model = LLMModelLoader(\\n\",\n    \"        \\\"embedding_transformer.st_model\\\",\\n\",\n    \"        \\\"SentenceTransformerModel\\\",\\n\",\n    \"        model_name_or_path=embedding_model_path\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    dataset = LLMDatasetLoader(\\n\",\n    \"        \\\"flex_dataset\\\",\\n\",\n    \"        \\\"FlexDataset\\\",\\n\",\n    \"        tokenizer_name_or_path=llm_pretrained_path,\\n\",\n    \"        need_preprocess=True,\\n\",\n    \"        dataset_name=\\\"yelp_review\\\",\\n\",\n    \"        data_part=\\\"train\\\",\\n\",\n    \"        load_from=\\\"json\\\",\\n\",\n    \"        few_shot_num_per_label=1,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    training_args = FDKTTrainingArguments(\\n\",\n    \"        
sample_num_per_cluster=4,\\n\",\n    \"        filter_prompt_max_length=2 ** 14,\\n\",\n    \"        filter_generation_config=dict(\\n\",\n    \"            max_tokens=4096,\\n\",\n    \"        ),\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        aug_generation_config=dict(\\n\",\n    \"            max_tokens=4096,\\n\",\n    \"            temperature=0.8,\\n\",\n    \"            top_p=0.9,\\n\",\n    \"        ),\\n\",\n    \"        aug_prompt_num=20000,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    inference_inst_conf = dict(\\n\",\n    \"        module_name=\\\"fate_llm.algo.fdkt.inference_inst\\\",\\n\",\n    \"        item_name=\\\"api_init\\\",\\n\",\n    \"        kwargs=dict(\\n\",\n    \"            api_url=\\\"http://127.0.0.1:9999/v1/\\\",\\n\",\n    \"            model_name=llm_pretrained_path,\\n\",\n    \"            api_key=\\\"demo\\\"\\n\",\n    \"        )\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    return get_config_of_fdkt_runner(\\n\",\n    \"        training_args=training_args,\\n\",\n    \"        embedding_model=embedding_model,\\n\",\n    \"        dataset=dataset,\\n\",\n    \"        inference_inst_conf=inference_inst_conf,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def get_slm_conf():\\n\",\n    \"    slm_model = LLMModelLoader(\\n\",\n    \"        \\\"hf_model\\\",\\n\",\n    \"        \\\"HFAutoModelForCausalLM\\\",\\n\",\n    \"        pretrained_model_name_or_path=slm_pretrained_path,\\n\",\n    \"        torch_dtype=\\\"bfloat16\\\",\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    tokenizer = LLMDataFuncLoader(\\n\",\n    \"        \\\"tokenizers.cust_tokenizer\\\",\\n\",\n    \"        \\\"get_tokenizer\\\",\\n\",\n    \"        tokenizer_name_or_path=slm_pretrained_path,\\n\",\n    \"        pad_token_id=50256\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    training_args = FDKTTrainingArguments(\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        device_id=1,\\n\",\n    \"        
num_train_epochs=1,\\n\",\n    \"        per_device_train_batch_size=2,\\n\",\n    \"        slm_generation_batch_size=32,\\n\",\n    \"        seq_num_for_single_category=2000,\\n\",\n    \"        slm_generation_config=dict(\\n\",\n    \"            max_new_tokens=256,\\n\",\n    \"            do_sample=True,\\n\",\n    \"            temperature=1.0,\\n\",\n    \"            top_k=50,\\n\",\n    \"            top_p=0.9,\\n\",\n    \"            repetition_penalty=1.0,\\n\",\n    \"            pad_token_id=50256\\n\",\n    \"        ),\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    dataset = LLMDatasetLoader(\\n\",\n    \"        \\\"flex_dataset\\\",\\n\",\n    \"        \\\"FlexDataset\\\",\\n\",\n    \"        tokenizer_name_or_path=slm_pretrained_path,\\n\",\n    \"        need_preprocess=True,\\n\",\n    \"        dataset_name=\\\"yelp_review\\\",\\n\",\n    \"        data_part=\\\"train\\\",\\n\",\n    \"        load_from=\\\"json\\\",\\n\",\n    \"        select_num=2000,\\n\",\n    \"        few_shot_num_per_label=1,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    optimizer = optim.Adam(lr=0.01)\\n\",\n    \"\\n\",\n    \"    return get_config_of_fdkt_runner(\\n\",\n    \"        model=slm_model,\\n\",\n    \"        tokenizer=tokenizer,\\n\",\n    \"        training_args=training_args,\\n\",\n    \"        dataset=dataset,\\n\",\n    \"        optimizer=optimizer,\\n\",\n    \"        data_collator=LLMDataFuncLoader(\\n\",\n    \"            \\\"data_collator.cust_data_collator\\\",\\n\",\n    \"            \\\"get_seq2seq_data_collator\\\",\\n\",\n    \"            label_pad_token_id=50256,\\n\",\n    \"            tokenizer_name_or_path=slm_pretrained_path,\\n\",\n    \"            pad_token_id=50256,\\n\",\n    \"        ),\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\\n\",\n    \"pipeline.bind_local_path(path=slm_data_path, namespace=\\\"experiment\\\", 
name=\\\"slm_train\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"reader_0 = Reader(\\\"reader_0\\\", runtime_parties=dict(guest=guest))\\n\",\n    \"reader_0.guest.task_parameters(\\n\",\n    \"    namespace=\\\"experiment\\\",\\n\",\n    \"    name=\\\"slm_train\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"homo_nn_0 = HomoNN(\\n\",\n    \"    'homo_nn_0',\\n\",\n    \"    train_data=reader_0.outputs[\\\"output_data\\\"],\\n\",\n    \"    runner_module=\\\"fdkt_runner\\\",\\n\",\n    \"    runner_class=\\\"FDKTRunner\\\",\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"homo_nn_0.arbiter.task_parameters(\\n\",\n    \"    runner_conf=get_llm_conf()\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"homo_nn_0.guest.task_parameters(\\n\",\n    \"    runner_conf=get_slm_conf()\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"pipeline.add_tasks([reader_0, homo_nn_0])\\n\",\n    \"pipeline.conf.set(\\\"task\\\", dict(engine_run={\\\"cores\\\": 1}))\\n\",\n    \"\\n\",\n    \"pipeline.compile()\\n\",\n    \"pipeline.fit()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "doc/tutorial/fedcot/README.md",
    "content": "# FATE-LLM: FedCoT\n\nThe algorithm is based on paper [\"FedCoT: Federated Chain-of-Thought Distillation for Large Language Models\"](https://aclanthology.org/anthology-files/anthology-files/pdf/findings/2025.findings-emnlp.454.pdf). We integrate its code into the FATE-LLM framework.  \n\n## Citation\nIf you publish work that uses FedCoT, please cite FedCoT as follows:\n```\n@inproceedings{fan2025fedcot,\n  title={FedCoT: Federated Chain-of-Thought Distillation for Large Language Models},\n  author={Fan, Tao and Chen, Weijing and Kang, Yan and Ma, Guoqiang and Gu, Hanlin and Song, Yuanfeng and Fan, Lixin and Yang, Qiang},\n  booktitle={Findings of the Association for Computational Linguistics: EMNLP 2025},\n  pages={8546--8557},\n  year={2025}\n}\n```\n"
  },
  {
    "path": "doc/tutorial/fedcot/encoder_decoder_tutorial.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a163d9c2-f9d6-4c61-a8e8-76a3f66c38ae\",\n   \"metadata\": {},\n   \"source\": [\n    \"# FedCoT - Train a SLM Encoder Decoder\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f2b56772-26d5-44fe-9c51-7bc662478b98\",\n   \"metadata\": {},\n   \"source\": [\n    \"FedCoT is an innovative framework designed to distill knowledge from large language models (LLMs) to small language models (SLMs) while ensuring data privacy. This method involves a strategy that trains a small language model (SLM) to learn from perturbed and recovered texts. The SLM can then encode raw text, produce results similar to differential privacy mechanisms, and return higher quality recovered text.\\n\",\n    \"\\n\",\n    \"In this tutorial, we will introduce how to train an SLM using the built-in trainer.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"62c6d18a-cc91-4cf5-9cfd-0f97095f7041\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Prepare Data\\n\",\n    \"\\n\",\n    \"Several steps need to be done to prepare data for training a SLM encoder-decoder model:\\n\",\n    \"- Sample data from original dataset(For example 50%)\\n\",\n    \"- Organize raw text and get a direct rationale reply from a remote LLM\\n\",\n    \"- Perturb doc using InferDPTKit to get perturbed docs\\n\",\n    \"- Get perturbed replies from a remote LLM\\n\",\n    \"- Organize training data\\n\",\n    \"\\n\",\n    \"### Sample data\\n\",\n    \"Here we will use the arc-easy data as an example, and take first 50% of the original dataset\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"40cc1bb8-a17c-4abc-9279-0849e98ca116\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from datasets import load_dataset, load_from_disk\\n\",\n    \"ds = load_dataset('arc_easy')['train']\\n\",\n    \"ds = [ds[i] for i in range(len(ds)//2)]\"\n   ]\n  
},\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0caff897-5b2b-4409-8601-10f973133b10\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Get Direct Replies from A Remote LLM\\n\",\n    \"\\n\",\n    \"We use the inference class to create an API for remote LLMs, or you can implement this part on your own.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 42,\n   \"id\": \"cf128b46-dea2-4eb4-bf31-568e56b9b78e\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from fate_llm.inference.api import APICompletionInference\\n\",\n    \"from jinja2 import Template\\n\",\n    \"from transformers import AutoTokenizer\\n\",\n    \"\\n\",\n    \"# We are using a Qwen 14B model as the remote model\\n\",\n    \"# You can change the setting\\n\",\n    \"api = APICompletionInference(\\n\",\n    \"    api_url='http://172.21.140.2:8081/v1',\\n\",\n    \"    api_key='EMPTY',\\n\",\n    \"    model_name='/data/cephfs/llm/models/Qwen1.5-14B-Chat'\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained('/data/cephfs/llm/models/Qwen1.5-0.5B-Chat/')\\n\",\n    \"\\n\",\n    \"arc_e_template_r = \\\"\\\"\\\"Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. 
Please refer to the example to write the rationale.\\n\",\n    \"Use <end> to finish your rationle.\\n\",\n    \"\\n\",\n    \"Example(s):\\n\",\n    \"Question:Which factor will most likely cause a person to develop a fever?\\n\",\n    \"Choices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\\n\",\n    \"Rationale:A bacterial infection in the bloodstream triggers the immune system to respond, therefore often causing a fever as the body tries to fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'\\n\",\n    \"\\n\",\n    \"Please explain:\\n\",\n    \"Question:{{question}}\\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"Rationale:\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"template = Template(arc_e_template_r)\\n\",\n    \"docs_to_infer = [tokenizer.apply_chat_template([{'role':'system', 'content': 'you are a helpful assistant'}, {'role':'user', 'content': template.render(i)}], add_generation_prompt=True, tokenize=False) for i in ds]\\n\",\n    \"results = api.inference(docs_to_infer, {\\n\",\n    \"    'stop': ['<|im_end|>', '<end>', '<end>\\\\n', '<end>\\\\n\\\\n', '.\\\\n\\\\n\\\\n\\\\n\\\\n', '<|end_of_text|>', '>\\\\n\\\\n\\\\n'],\\n\",\n    \"    'temperature': 0.01,\\n\",\n    \"    'max_tokens': 256\\n\",\n    \"})\\n\",\n    \"\\n\",\n    \"for i, r in zip(ds, results):\\n\",\n    \"    i['rationale'] = r\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"212822ab-9f64-49a2-bb95-ef8ee2de8e49\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"A fever is a response to an infection, typically caused by bacteria or viruses. So, the answer is 'a bacterial population in the bloodstream' because it indicates an immune response to a foreign invader. 
'Several viral particles on the skin' could also lead to a fever if they enter the body, but bloodstream presence is more direct. The other choices are unrelated to fever development.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(results[0])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0f6a0039-1530-4b87-a098-fd2eb01805c2\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Perturb Docs & Replies\\n\",\n    \"\\n\",\n    \"You can refer to the InferDPT tutorial for guidance on using the InferDPTKit to generate perturbed documents: [InferDPT Document](./)\\n\",\n    \"We can produce perturbed doc using InferDPTKit:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"id\": \"39249747-bfaa-43bf-8b66-896568941ab8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.algo.inferdpt.utils import InferDPTKit\\n\",\n    \"path_to_kit = '/data/projects/inferdpt/test_fate_llm/'\\n\",\n    \"kit = InferDPTKit.load_from_path(path_to_kit)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"id\": \"39b9cefa-dfdb-4bac-b313-4ca3bc118aee\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import copy\\n\",\n    \"tmp_ds = copy.deepcopy(ds)\\n\",\n    \"\\n\",\n    \"q_doc = [kit.perturb(i, epsilon=1.0) for i in [Template(\\\"\\\"\\\"{{question}}\\\"\\\"\\\").render(i) for i in tmp_ds]]\\n\",\n    \"c_doc = [kit.perturb(i, epsilon=1.0) for i in [Template(\\\"\\\"\\\"{{choices.text}}\\\"\\\"\\\").render(i) for i in tmp_ds]]\\n\",\n    \"for i,q,c in zip(tmp_ds,q_doc,c_doc):\\n\",\n    \"    i['question'] = q\\n\",\n    \"    i['choices']['text'] = c\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"id\": \"61b30886-746c-43c5-889a-a6583dc939d0\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'id': 'Mercury_7179953',\\n\",\n       \" 'question': 'stuff two 
alpha Rogers are today chap in Department?',\\n\",\n       \" 'choices': {'text': \\\"['muscular and skeletal', 'digestive and muscular', 'skeletal and pasteiratory', 'respiratory and exhibive']\\\",\\n\",\n       \"  'label': ['A', 'B', 'C', 'D']},\\n\",\n       \" 'answerKey': 'A',\\n\",\n       \" 'rationale': {...}}\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"tmp_ds[6]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"fed90297-9957-4f8b-a53c-37a03d516c78\",\n   \"metadata\": {},\n   \"source\": [\n    \"And then send formatted docs to remote LLM for perturbed responses:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 33,\n   \"id\": \"5b8bd833-fb0f-418b-bd9b-6452e8ae4d6c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"template = Template(arc_e_template_r)\\n\",\n    \"docs_to_infer = [tokenizer.apply_chat_template([{'role':'system', 'content': 'you are a helpful assistant'}, {'role':'user', 'content': template.render(i)}], add_generation_prompt=True, tokenize=False) for i in tmp_ds]\\n\",\n    \"p_results = api.inference(docs_to_infer, {\\n\",\n    \"    'stop': ['<|im_end|>', '<end>', '<end>\\\\n', '<end>\\\\n\\\\n', '.\\\\n\\\\n\\\\n\\\\n\\\\n', '<|end_of_text|>', '>\\\\n\\\\n\\\\n'],\\n\",\n    \"    'temperature': 0.01,\\n\",\n    \"    'max_tokens': 256\\n\",\n    \"})\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 37,\n   \"id\": \"187361fa-8b73-4a01-9039-f52ec98a5791\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for i, r in zip(ds, p_results):\\n\",\n    \"    i['p_rationale'] = r\\n\",\n    \"\\n\",\n    \"for i,q,c in zip(ds, q_doc, c_doc):\\n\",\n    \"    i['p_question'] = q\\n\",\n    \"    i['p_choice'] = c\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"927b2265-4e87-4275-98dc-7f33d405e19a\",\n   
\"metadata\": {},\n   \"source\": [\n    \"### Organize Training Data\\n\",\n    \"\\n\",\n    \"As described in the original paper, we need to train the encoder and decoder in one model.\\n\",\n    \"We can organize the training data using templates below:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 47,\n   \"id\": \"9292ad25-12c7-418a-9e77-b433b95f57ac\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"train_data = []\\n\",\n    \"\\n\",\n    \"encoder_prompt = Template(\\\"\\\"\\\"Disrupt the main words in the original text so that it becomes difficult to recognize, but at the same time, try to maintain the original meaning as much as possible. Use <end> to end your reply.\\n\",\n    \"Origin Doc: \\n\",\n    \"Question:{{question}}\\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"\\n\",\n    \"Perturbed Doc:\\n\",\n    \"\\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"encoder_out = Template(\\\"\\\"\\\"\\n\",\n    \"Question:{{p_question}}\\n\",\n    \"Choices:{{p_choice}}<end>\\n\",\n    \"\\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"decoder_in = Template(\\\"\\\"\\\"This is a perturbed question and its corresponding answer(rationale). And following is the original question. 
Try to recover the correct rationale from docs provided.\\n\",\n    \"\\n\",\n    \"Perturbed doc and rationale:\\n\",\n    \"Question:{{p_question}}\\n\",\n    \"Choices:{{p_choice}}\\n\",\n    \"Rationale:{{p_rationale}}\\n\",\n    \"\\n\",\n    \"Original Doc:\\n\",\n    \"Question:{{question}}\\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"\\n\",\n    \"Recover Rationale:\\n\",\n    \"\\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"decoder_out = Template(\\\"\\\"\\\"{{rationale}}<end>\\\"\\\"\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"for i in ds:\\n\",\n    \"    a = {}\\n\",\n    \"    a['encoder_in'] = encoder_prompt.render(i)\\n\",\n    \"    a['encoder_out'] = encoder_out.render(i)\\n\",\n    \"    a['decoder_in'] = decoder_in.render(i)\\n\",\n    \"    a['decoder_out'] = decoder_out.render(i)\\n\",\n    \"    train_data.append(a)\\n\",\n    \"\\n\",\n    \"import torch\\n\",\n    \"torch.save(train_data, './slm_ed_train_data.pkl')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"dd73db44-4e73-4c1e-8f27-755522587636\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Train Script\\n\",\n    \"\\n\",\n    \"The key step: preparing data is now done. Then we can train a SLM model using the train data. You can use following dataset&trainer class to train an encoder-decoder slm model. 
Here we use Qwen-0.5B as the example.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 51,\n   \"id\": \"eb01c591-3c04-4317-8bb0-f55846fb1b66\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoModelForCausalLM, AutoTokenizer\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 52,\n   \"id\": \"f0da4e10-af80-4216-8ff8-5816dabc8526\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model = AutoModelForCausalLM.from_pretrained('/data/cephfs/llm/models/Qwen1.5-0.5B/').half().cuda()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 75,\n   \"id\": \"634fc973-29c8-499e-a99e-d50b7ee54124\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from torch.utils.data import Dataset\\n\",\n    \"\\n\",\n    \"class EDDataset(Dataset):\\n\",\n    \"\\n\",\n    \"    def __init__(self, tokenizer, train_data, max_input_length=64, max_target_length=64):\\n\",\n    \"        self.tokenizer = tokenizer\\n\",\n    \"        self.dataset = train_data\\n\",\n    \"        self.max_input_length = max_input_length\\n\",\n    \"        self.max_target_length = max_target_length\\n\",\n    \"        self.max_seq_length = max_input_length + max_target_length + 1\\n\",\n    \"\\n\",\n    \"    def get_str_item(self, i) -> dict:\\n\",\n    \"\\n\",\n    \"        data_item = self.dataset[i]\\n\",\n    \"        ret_dict = {\\n\",\n    \"            'encoder':{\\n\",\n    \"                'input': data_item['encoder_in'],\\n\",\n    \"                'output': data_item['encoder_out']\\n\",\n    \"            },\\n\",\n    \"            'decoder':{\\n\",\n    \"                'input': data_item['decoder_in'],\\n\",\n    \"                'output': data_item['decoder_out']\\n\",\n    \"            }\\n\",\n    \"        }\\n\",\n    \"        return ret_dict\\n\",\n    \"\\n\",\n    \"    def _process_item(self, data_item):\\n\",\n    
\"\\n\",\n    \"        a_ids = self.tokenizer.encode(text=data_item['input'], add_special_tokens=True, truncation=True,\\n\",\n    \"                                      max_length=self.max_input_length)\\n\",\n    \"        b_ids = self.tokenizer.encode(text=data_item['output'], add_special_tokens=False, truncation=True,\\n\",\n    \"                                      max_length=self.max_target_length)\\n\",\n    \"        context_length = len(a_ids)\\n\",\n    \"        input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id]\\n\",\n    \"        labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id]\\n\",\n    \"        pad_len = self.max_seq_length - len(input_ids)\\n\",\n    \"        input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len\\n\",\n    \"        labels = labels + [self.tokenizer.pad_token_id] * pad_len\\n\",\n    \"        labels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels]\\n\",\n    \"\\n\",\n    \"        assert len(input_ids) == len(labels), f\\\"length mismatch: {len(input_ids)} vs {len(labels)}\\\"\\n\",\n    \"\\n\",\n    \"        return {\\n\",\n    \"            \\\"input_ids\\\": input_ids,\\n\",\n    \"            \\\"labels\\\": labels\\n\",\n    \"        }\\n\",\n    \"\\n\",\n    \"    def get_tokenized_item(self, i) -> dict:   \\n\",\n    \"\\n\",\n    \"        str_item = self.get_str_item(i)\\n\",\n    \"        ret_dict = {\\n\",\n    \"            'encoder': self._process_item(str_item['encoder']),\\n\",\n    \"            'docoder': self._process_item(str_item['decoder'])\\n\",\n    \"        }\\n\",\n    \"        return ret_dict\\n\",\n    \"\\n\",\n    \"    def __getitem__(self, i) -> dict:\\n\",\n    \"        item = self.get_tokenized_item(i)\\n\",\n    \"        return item\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 76,\n   \"id\": \"5f914b1f-cf14-4bdc-acc9-ae1b73cf857c\",\n   \"metadata\": {},\n   
\"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"train_ds = EDDataset(AutoTokenizer.from_pretrained('/data/cephfs/llm/models/Qwen1.5-0.5B/'), train_data)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"817084b2-2439-45d8-aa1b-da0b1a8a2846\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(train_ds.get_str_item(0))\\n\",\n    \"print(train_ds[0])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 82,\n   \"id\": \"303bcb23-d54b-4375-bad2-bf5450c14f28\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.algo.fedcot.slm_encoder_decoder_trainer import EncoderDecoderPrefixTrainer, EDPrefixDataCollator\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"aa5a0b4f-cd03-4867-8753-fc5bcb036c69\",\n   \"metadata\": {},\n   \"source\": [\n    \"After completing the setup, you can utilize the EncoderDecoderPrefixTrainer, EDPrefixDataCollator, and the training dataset to train an SLM encoder-decoder model following the Huggingface approach! \"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.14\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}"
  },
  {
    "path": "doc/tutorial/fedcot/fedcot_tutorial.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"9234355d-389f-484f-9fc2-7b17563b3390\",\n   \"metadata\": {},\n   \"source\": [\n    \"# FedCoT Tutorial\\n\",\n    \"\\n\",\n    \"## Introduction to FedCoT\\n\",\n    \"\\n\",\n    \"FedCoT (Federated Chain-of-Thought) is a novel framework designed to distill knowledge from large language models (LLMs) to small language models (SLMs) while ensuring data privacy. The framework addresses two major challenges faced by LLM deployment in real-world applications: the privacy of domain-specific knowledge and resource constraints.\\n\",\n    \"\\n\",\n    \"FedCoT adopts a server-client architecture where the client sends perturbed prompts to the server-side LLM for inference, generating perturbed rationales. The client then decodes these rationales and uses them to enrich the training of its task-specific SLM, ultimately enhancing its performance.\\n\",\n    \"\\n\",\n    \"FedCoT introduces two privacy protection strategies: \\n\",\n    \"- **the Exponential Mechanism Strategy**\\n\",\n    \"- **the Encoder-Decoder Strategy**\\n\",\n    \"  \\n\",\n    \"The Exponential Mechanism Strategy utilizes a DP(differential privacy) based exponential mechanism to obfuscate user prompts, while the Encoder-Decoder Strategy employs a specialized Encoder-Decoder SLM to encode and decode perturbed prompts and rationales. These strategies effectively balance user privacy and the usability of rationales, allowing for secure and enhanced training of the client's SLM without compromising on privacy concerns.\\n\",\n    \"\\n\",\n    \"Through experiments on various text generation tasks, FedCoT demonstrates its effectiveness in training task-specific SLMs with enhanced performance, significantly improving the SLM's capabilities while prioritizing data privacy protection. 
For more details, please refer to the paper: [FedCoT: Federated Chain-of-Thought Distillation for Large Language Models](https://arxiv.org/pdf/2406.12403).\\n\",\n    \"\\n\",\n    \"**Before reading this tutorial, we strongly recommend that you first read [the InferDPT](./) tutorial.**\\n\",\n    \"\\n\",\n    \"## Use the Infer Client & Server\\n\",\n    \"\\n\",\n    \"In this section, we are going to introduce the inference part, which is the key part of FedCoT that generates useful rationales with privacy-preserving. You can use InferDPT(which utilize the Exponential Mechanism Strategy) or specifically trained SLM as the text encoder & decoder. In this section, we retrieve a sample from the arc-easy dataset as an example:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"c443c920-31ff-446a-801f-d7a02409a8c0\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"test_example = {'id': 'Mercury_7220990',\\n\",\n    \"'question': 'Which factor will most likely cause a person to develop a fever?',\\n\",\n    \"'choices': {'text': ['a leg muscle relaxing after exercise',\\n\",\n    \"'a bacterial population in the bloodstream',\\n\",\n    \"'several viral particles on the skin',\\n\",\n    \"'carbohydrates being digested in the stomach'],\\n\",\n    \"'label': ['A', 'B', 'C', 'D']},\\n\",\n    \"'answerKey': 'B'}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"46646b18-46bb-476d-8b1d-1ef661446929\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Fate Context\\n\",\n    \"\\n\",\n    \"We need to create fate context to enable the communication between client and server. 
Then, we can initialize infer client(who will encodes the raw prompt and decodes the perturbed response) and server(who deploys the LLM) to enable secure inference.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"0cc8e8f8-88d7-45ab-a988-5ead06356418\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"arbiter = (\\\"arbiter\\\", 10000)\\n\",\n    \"guest = (\\\"guest\\\", 10000)\\n\",\n    \"host = (\\\"host\\\", 9999)\\n\",\n    \"name = \\\"fed1\\\"\\n\",\n    \"\\n\",\n    \"def create_ctx(local):\\n\",\n    \"    from fate.arch import Context\\n\",\n    \"    from fate.arch.computing.backends.standalone import CSession\\n\",\n    \"    from fate.arch.federation.backends.standalone import StandaloneFederation\\n\",\n    \"    import logging\\n\",\n    \"\\n\",\n    \"    logger = logging.getLogger()\\n\",\n    \"    logger.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    console_handler = logging.StreamHandler()\\n\",\n    \"    console_handler.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    formatter = logging.Formatter(\\\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\\\")\\n\",\n    \"    console_handler.setFormatter(formatter)\\n\",\n    \"\\n\",\n    \"    logger.addHandler(console_handler)\\n\",\n    \"    computing = CSession(data_dir=\\\"./session_dir\\\")\\n\",\n    \"    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"c75dbcda-1a40-421d-ab1b-92eca5600866\",\n   \"metadata\": {},\n   \"source\": [\n    \"### The DP based Strategy(InferDPT)\\n\",\n    \"\\n\",\n    \"As outlined in the [InferDPT tutorial](./), you can initialize the InferDPT client and server to facilitate secure and private inference. 
Prior to executing the InferDPT component, it is recommended to generate the InferDPT kit by following the step-by-step instructions provided in the tutorial.\\n\",\n    \"\\n\",\n    \"#### Client-Side Code\\n\",\n    \"\\n\",\n    \"On the client side, we load the pre-computed inferdpt-kit and deploy a local SLM as the decoding model.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"ff0f317f-414f-4b9f-84e6-b992b31350cb\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.inference.api import APICompletionInference\\n\",\n    \"from fate_llm.algo.inferdpt import inferdpt\\n\",\n    \"from fate_llm.algo.inferdpt.utils import InferDPTKit\\n\",\n    \"import sys\\n\",\n    \"\\n\",\n    \"arbiter = (\\\"arbiter\\\", 10000)\\n\",\n    \"guest = (\\\"guest\\\", 10000)\\n\",\n    \"host = (\\\"host\\\", 9999)\\n\",\n    \"name = \\\"fed1\\\"\\n\",\n    \"\\n\",\n    \"def create_ctx(local):\\n\",\n    \"    from fate.arch import Context\\n\",\n    \"    from fate.arch.computing.backends.standalone import CSession\\n\",\n    \"    from fate.arch.federation.backends.standalone import StandaloneFederation\\n\",\n    \"    import logging\\n\",\n    \"\\n\",\n    \"    logger = logging.getLogger()\\n\",\n    \"    logger.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    console_handler = logging.StreamHandler()\\n\",\n    \"    console_handler.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    formatter = logging.Formatter(\\\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\\\")\\n\",\n    \"    console_handler.setFormatter(formatter)\\n\",\n    \"\\n\",\n    \"    logger.addHandler(console_handler)\\n\",\n    \"    computing = CSession(data_dir=\\\"./session_dir\\\")\\n\",\n    \"    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\\n\",\n    \"\\n\",\n    \"ctx = create_ctx(guest)\\n\",\n    \"save_kit_path = 'your 
path'\\n\",\n    \"kit = InferDPTKit.load_from_path(save_kit_path)\\n\",\n    \"# local deployed small model as decoding model\\n\",\n    \"inference = APICompletionInference(api_url=\\\"http://127.0.0.1:8887/v1\\\", model_name='./Qwen1.5-0.5B', api_key='EMPTY')\\n\",\n    \"\\n\",\n    \"test_example = {'id': 'Mercury_7220990',\\n\",\n    \"'question': 'Which factor will most likely cause a person to develop a fever?',\\n\",\n    \"'choices': {'text': ['a leg muscle relaxing after exercise',\\n\",\n    \"'a bacterial population in the bloodstream',\\n\",\n    \"'several viral particles on the skin',\\n\",\n    \"'carbohydrates being digested in the stomach'],\\n\",\n    \"'label': ['A', 'B', 'C', 'D']},\\n\",\n    \"'answerKey': 'B'}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"doc_template = \\\"\\\"\\\"{{question}} \\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"instruction_template=\\\"\\\"\\\"\\n\",\n    \"<s>[INST]\\n\",\n    \"Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. Please refer to the example to write the rationale.\\n\",\n    \"Use <end> to finish your rationle.\\\"\\n\",\n    \"\\n\",\n    \"Example(s):\\n\",\n    \"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\\n\",\n    \"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\\n\",\n    \"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  
Therefore, the answer is 'dry palms'.<end>\\n\",\n    \"\\n\",\n    \"Please explain:\\n\",\n    \"Question:{{perturbed_doc}}\\n\",\n    \"Rationale:\\n\",\n    \"[/INST]\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"decode_template = \\\"\\\"\\\"Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. Please refer to the example to write the rationale.\\n\",\n    \"Use <end> to finish your rationle.\\\"\\n\",\n    \"\\n\",\n    \"Example(s):\\n\",\n    \"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\\n\",\n    \"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\\n\",\n    \"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  Therefore, the answer is 'dry palms'.<end>\\n\",\n    \"\\n\",\n    \"Question:{{perturbed_doc}}\\n\",\n    \"Rationale:{{perturbed_response | replace('\\\\n', '')}}<end>\\n\",\n    \"\\n\",\n    \"Please explain:\\n\",\n    \"Question:{{question}} \\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"Rationale:\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"inferdpt_client = inferdpt.InferDPTClient(ctx, kit, inference, epsilon=3.0)\\n\",\n    \"result = inferdpt_client.inference([test_example], doc_template, instruction_template, decode_template, \\\\\\n\",\n    \"                                 remote_inference_kwargs={\\n\",\n    \"                                    'stop': ['<\\\\s>'],\\n\",\n    \"                                    'temperature': 0.01,\\n\",\n    \"                                    'max_tokens': 256\\n\",\n    \"                                 },\\n\",\n    \"                                 local_inference_kwargs={\\n\",\n    \"                                    'stop': 
['<|im_end|>', '<end>', '<end>\\\\n', '<end>\\\\n\\\\n', '.\\\\n\\\\n\\\\n\\\\n\\\\n', '<|end_of_text|>', '>\\\\n\\\\n\\\\n'],\\n\",\n    \"                                    'temperature': 0.01,\\n\",\n    \"                                    'max_tokens': 256\\n\",\n    \"                                 })\\n\",\n    \"print('result is {}'.format(result[0]['inferdpt_result']))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"96fbcb01-6907-432f-8393-ae1746559c3a\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Server Side Code\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"id\": \"960a476c-50a5-40fb-847d-02101cea27ae\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.algo.inferdpt.inferdpt import InferDPTServer\\n\",\n    \"import sys\\n\",\n    \"from fate_llm.inference.api import APICompletionInference\\n\",\n    \"\\n\",\n    \"arbiter = (\\\"arbiter\\\", 10000)\\n\",\n    \"guest = (\\\"guest\\\", 10000)\\n\",\n    \"host = (\\\"host\\\", 9999)\\n\",\n    \"name = \\\"fed1\\\"\\n\",\n    \"\\n\",\n    \"def create_ctx(local):\\n\",\n    \"    from fate.arch import Context\\n\",\n    \"    from fate.arch.computing.backends.standalone import CSession\\n\",\n    \"    from fate.arch.federation.backends.standalone import StandaloneFederation\\n\",\n    \"    import logging\\n\",\n    \"\\n\",\n    \"    logger = logging.getLogger()\\n\",\n    \"    logger.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    console_handler = logging.StreamHandler()\\n\",\n    \"    console_handler.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    formatter = logging.Formatter(\\\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\\\")\\n\",\n    \"    console_handler.setFormatter(formatter)\\n\",\n    \"\\n\",\n    \"    logger.addHandler(console_handler)\\n\",\n    \"    computing = CSession(data_dir=\\\"./session_dir\\\")\\n\",\n    \"    return Context(computing=computing, 
federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\\n\",\n    \"\\n\",\n    \"ctx = create_ctx(arbiter)\\n\",\n    \"# Api to a LLM\\n\",\n    \"inference_server = APICompletionInference(api_url=\\\"http://127.0.0.1:8888/v1\\\", model_name='./Mistral-7B-Instruct-v0.2', api_key='EMPTY')\\n\",\n    \"inferdpt_server = InferDPTServer(ctx, inference_server)\\n\",\n    \"inferdpt_server.inference()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"16f908a7-9187-461a-93db-9945456d502d\",\n   \"metadata\": {},\n   \"source\": [\n    \"Start two terminal and launch client&server scripts simultaneously. On the client side we can get the answer:\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"The given question asks which factor will most likely cause a person to develop a fever. The factors mentioned are a leg muscle relaxing after exercise, a bacterial population in the bloodstream, several viral particles on the skin, and carbohydrates being digested in the stomach. The question is asking which factor is most likely to cause a person to develop a fever. The factors are all related to the body's internal environment, but the most likely factor is a bacterial population in the bloodstream. This is because bacteria can cause a fever, and the body's immune system responds to the infection by producing antibodies that can fight off the bacteria. 
Therefore, the answer is 'a bacterial population in the bloodstream'\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"fb36a485-2fa8-4629-a2cf-2d53fdbbcc5f\",\n   \"metadata\": {},\n   \"source\": [\n    \"### The Encoder-Decoder Model Strategy\\n\",\n    \"\\n\",\n    \"Similar to the InferDPT, we can initialize SLMEncoderDecoderClient and SLMEncoderDecoderServer to enable secure inference.\\n\",\n    \"The client will encode the raw prompt using local slm model and then decoded it with the same model\\n\",\n    \"\\n\",\n    \"#### Client Side Code\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"cd174244-8640-4cb2-8609-ac6468f5a6f5\",\n   \"metadata\": {\n    \"tags\": []\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.inference.api import APICompletionInference\\n\",\n    \"from fate_llm.algo.fedcot.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient\\n\",\n    \"\\n\",\n    \"arbiter = (\\\"arbiter\\\", 10000)\\n\",\n    \"guest = (\\\"guest\\\", 10000)\\n\",\n    \"host = (\\\"host\\\", 9999)\\n\",\n    \"name = \\\"fed1\\\"\\n\",\n    \"\\n\",\n    \"def create_ctx(local):\\n\",\n    \"    from fate.arch import Context\\n\",\n    \"    from fate.arch.computing.backends.standalone import CSession\\n\",\n    \"    from fate.arch.federation.backends.standalone import StandaloneFederation\\n\",\n    \"    import logging\\n\",\n    \"\\n\",\n    \"    logger = logging.getLogger()\\n\",\n    \"    logger.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    console_handler = logging.StreamHandler()\\n\",\n    \"    console_handler.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    formatter = logging.Formatter(\\\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\\\")\\n\",\n    \"    console_handler.setFormatter(formatter)\\n\",\n    \"\\n\",\n    \"    logger.addHandler(console_handler)\\n\",\n    \"    computing = 
CSession(data_dir=\\\"./session_dir\\\")\\n\",\n    \"    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"test_example = {'id': 'Mercury_7220990',\\n\",\n    \"'question': 'Which factor will most likely cause a person to develop a fever?',\\n\",\n    \"'choices': {'text': ['a leg muscle relaxing after exercise',\\n\",\n    \"'a bacterial population in the bloodstream',\\n\",\n    \"'several viral particles on the skin',\\n\",\n    \"'carbohydrates being digested in the stomach'],\\n\",\n    \"'label': ['A', 'B', 'C', 'D']},\\n\",\n    \"'answerKey': 'B'\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"encode_prompt = \\\"\\\"\\\"Disrupt the main words in the original text so that it becomes difficult to recognize, but at the same time, try to maintain the original meaning as much as possible. Use <end> to end your reply.\\n\",\n    \"Origin Doc:Question:{{question}}\\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"Perturb Doc: \\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"decode_prompt = \\\"\\\"\\\"This is a perturbed question and its corresponding answer(rationale). And following is the original question. Try to recover the correct rationale from docs provided.\\n\",\n    \"\\n\",\n    \"Perturbed doc and rationale:\\n\",\n    \"{{perturbed_doc}}\\n\",\n    \"Rationale:{{perturbed_response}}\\n\",\n    \"\\n\",\n    \"Original Doc:\\n\",\n    \"Question:{{question}}\\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"\\n\",\n    \"Recover Rationale:\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"instruction_template = \\\"\\\"\\\"<|im_start|>system\\n\",\n    \"You are a helpful assistant<|im_end|>\\n\",\n    \"<|im_start|>user\\n\",\n    \"Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. 
Please refer to the example to write the rationale.\\n\",\n    \"Use <end> to finish your rationle.\\n\",\n    \"\\n\",\n    \"Example(s):\\n\",\n    \"Question:Which factor will most likely cause a person to develop a fever?\\n\",\n    \"Choices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\\n\",\n    \"Rationale:A bacterial infection in the bloodstream triggers the immune system to respond, therefore often causing a fever as the body tries to fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'\\n\",\n    \"\\n\",\n    \"Please explain:\\n\",\n    \"{{perturbed_doc}}\\n\",\n    \"Rationale:\\n\",\n    \"<|im_end|>\\n\",\n    \"<|im_start|>assistant\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"ctx = create_ctx(guest)\\n\",\n    \"model_name = 'Deploy your encoder decoder model'\\n\",\n    \"# api_url to your locally deployed encoder decoder\\n\",\n    \"api = APICompletionInference(api_url='http://127.0.0.1:8887/v1', api_key='EMPTY', model_name=model_name)\\n\",\n    \"client = SLMEncoderDecoderClient(ctx, api)\\n\",\n    \"result = client.inference([test_example], encode_prompt, instruction_template, decode_prompt, \\\\\\n\",\n    \"                                 remote_inference_kwargs={\\n\",\n    \"                                    'stop': ['<\\\\s>'],\\n\",\n    \"                                    'temperature': 0.01,\\n\",\n    \"                                    'max_tokens': 256\\n\",\n    \"                                 },\\n\",\n    \"                                 local_inference_kwargs={\\n\",\n    \"                                    'stop': ['<|im_end|>', '<end>', '<end>\\\\n', '<end>\\\\n\\\\n', '.\\\\n\\\\n\\\\n\\\\n\\\\n', '<|end_of_text|>', '>\\\\n\\\\n\\\\n'],\\n\",\n    \"                                    'temperature': 0.01,\\n\",\n    \"                    
                'max_tokens': 256\\n\",\n    \"                                 })\\n\",\n    \"print('result is {}'.format(result[0]['inferdpt_result']))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1a865536-7814-40a2-a814-d00e46f2787f\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Server Side Code\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"cced44b0-0dcb-4427-8efe-a04135b246ac\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.inference.api import APICompletionInference\\n\",\n    \"from fate_llm.algo.fedcot.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderServer\\n\",\n    \"\\n\",\n    \"arbiter = (\\\"arbiter\\\", 10000)\\n\",\n    \"guest = (\\\"guest\\\", 10000)\\n\",\n    \"host = (\\\"host\\\", 9999)\\n\",\n    \"name = \\\"fed1\\\"\\n\",\n    \"\\n\",\n    \"def create_ctx(local):\\n\",\n    \"    from fate.arch import Context\\n\",\n    \"    from fate.arch.computing.backends.standalone import CSession\\n\",\n    \"    from fate.arch.federation.backends.standalone import StandaloneFederation\\n\",\n    \"    import logging\\n\",\n    \"\\n\",\n    \"    logger = logging.getLogger()\\n\",\n    \"    logger.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    console_handler = logging.StreamHandler()\\n\",\n    \"    console_handler.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    formatter = logging.Formatter(\\\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\\\")\\n\",\n    \"    console_handler.setFormatter(formatter)\\n\",\n    \"\\n\",\n    \"    logger.addHandler(console_handler)\\n\",\n    \"    computing = CSession(data_dir=\\\"./session_dir\\\")\\n\",\n    \"    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\\n\",\n    \"\\n\",\n    \"ctx = create_ctx(arbiter)\\n\",\n    \"# api url&name are depolyed LLM\\n\",\n    \"model_name = 
'/data/cephfs/llm/models/Qwen1.5-14B-Chat/'\\n\",\n    \"api = APICompletionInference(api_url='http://127.0.0.1:8888/v1', api_key='EMPTY', model_name=model_name)\\n\",\n    \"server = SLMEncoderDecoderServer(ctx, api)\\n\",\n    \"server.inference()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"c38ed7a6-2eb2-4f46-b59c-eaafcc9a5b7a\",\n   \"metadata\": {},\n   \"source\": [\n    \"Start two terminal and launch client&server scripts simultaneously. On the client side we can get the answer:\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"A fever is typically caused by a bacterial population in the bloodstream, as it is a response to an infection. So the answer is 'a bacterial population in the bloodstream'.\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"41fbbefd-e931-4e95-9d28-9675ff7865a3\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Prefix Dataset & FedCoT Trainer\\n\",\n    \"\\n\",\n    \"Now that we can carry out privacy-preserving inference and acquire rationales, the next step is to train a new task-specific model, enhanced by the rationales generated by the LLMs.\\n\",\n    \"\\n\",\n    \"In this section, we will introduce the PrefixDataset and FedCoTTrainer, which facilitate training tasks with the added benefit of supplementary rationales. The PrefixDataset allows you to assign various text prefixes, guiding the model to produce different text targets. 
With FedCoTTrainer, the model is trained to generate both text labels and text rationales at each update step, ultimately leading to superior performance compared to training on the raw dataset alone.\\n\",\n    \"\\n\",\n    \"### Prepare dataset\\n\",\n    \"In this tutorial, we will use the arc-easy dataset.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e25377d0-1a7e-4e8c-aa9f-3bcb03ae0c45\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from datasets import load_dataset\\n\",\n    \"dataset = load_dataset(\\\"arc_easy\\\")\\n\",\n    \"dataset.save_to_disk('path_to_save/arce')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"9166110f-bf67-4bf1-9da8-04c16bd79423\",\n   \"metadata\": {},\n   \"source\": [\n    \"Let’s proceed with testing the PrefixDataset. We can utilize Jinja2 templates to structure the text and append prefixes or suffixes to our training data.\\n\",\n    \"\\n\",\n    \"Please note that at this stage, the dataset does not contain rationales. In the 'rationale_output_template', the key used for the inference results is ‘infer_result’. 
We can perform secure inference using the FedCoTTrainer and then integrate the rationale results, keyed as ‘infer_result’, into the PrefixDataset.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"id\": \"fdbd93d6-45f3-404f-813e-9ca1fd6def04\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from fate_llm.dataset.fedcot_dataset import PrefixDataset\\n\",\n    \"\\n\",\n    \"pds = PrefixDataset(\\n\",\n    \"        tokenizer_path='/data/cephfs/llm/models/Qwen1.5-0.5B/',\\n\",\n    \"        predict_input_template=\\\"\\\"\\\"Predict:\\n\",\n    \"Question:{{question}}\\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"Answer:\\n\",\n    \"    \\\"\\\"\\\",\\n\",\n    \"        predict_output_template=\\\"\\\"\\\"{{choices.text[choices.label.index(answerKey)]}}<end>\\\"\\\"\\\",\\n\",\n    \"        rationale_input_template=\\\"\\\"\\\"Explain:\\n\",\n    \"Question:{{question}}\\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"Rationale:\\n\",\n    \"    \\\"\\\"\\\",\\n\",\n    \"        rationale_output_template=\\\"\\\"\\\"{{infer_result}}<end>\\\"\\\"\\\",\\n\",\n    \"        max_input_length=128,\\n\",\n    \"        max_target_length=128,\\n\",\n    \"        split_key='train'\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"pds.load('path_to_save/arce')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 27,\n   \"id\": \"100eeb69-8bd2-4e66-b1cc-667f95e47f23\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'id': 'Mercury_7220990',\\n\",\n       \" 'question': 'Which factor will most likely cause a person to develop a fever?',\\n\",\n       \" 'choices': {'text': ['a leg muscle 
relaxing after exercise',\\n\",\n       \"   'a bacterial population in the bloodstream',\\n\",\n       \"   'several viral particles on the skin',\\n\",\n       \"   'carbohydrates being digested in the stomach'],\\n\",\n       \"  'label': ['A', 'B', 'C', 'D']},\\n\",\n       \" 'answerKey': 'B'}\"\n      ]\n     },\n     \"execution_count\": 27,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"pds.dataset[0] # the structure is the same as hf dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 21,\n   \"id\": \"6f0356ef-f94b-41db-ab66-b1d0eb862eca\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'predict': {'input': \\\"Predict:\\\\nQuestion:Which factor will most likely cause a person to develop a fever?\\\\nChoices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\\\\nAnswer:\\\\n    \\\",\\n\",\n       \"  'output': 'a bacterial population in the bloodstream<end>'},\\n\",\n       \" 'rationale': {'input': \\\"Explain:\\\\nQuestion:Which factor will most likely cause a person to develop a fever?\\\\nChoices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\\\\nRationale:\\\\n    \\\",\\n\",\n       \"  'output': '<end>\\\\n    '}}\"\n      ]\n     },\n     \"execution_count\": 21,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"pds.get_str_item(0)  # we can see that the output of rationale term is empty\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"id\": \"6a227af7-f24a-46bd-9af7-78584a381b33\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"print(pds[0]) # show 
tokenized, for brevity we don't show it in this tutorial doc\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e0382a33-7a45-43a3-8ed3-58ed1d1b07d8\",\n   \"metadata\": {},\n   \"source\": [\n    \"### The FedCoTTrainer\\n\",\n    \"\\n\",\n    \"Here we introduce the FedCoTTrainer which is developed based on the Huggingface trainer and supports collaboratively training a task with raw labels and additional rationales. Here is how the compute-loss function is realized:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b40b7d99-9ef8-43f9-8e28-db96d96af62a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def compute_loss(self, model, inputs, return_outputs=False):\\n\",\n    \"\\n\",\n    \"    label_outputs = model(**inputs['predict'])\\n\",\n    \"    cot_outputs = model(**inputs['rationale'])\\n\",\n    \"    loss = self.alpha * cot_outputs.loss + (1. - self.alpha) * label_outputs.loss\\n\",\n    \"    return (loss, {'rationale_loss': cot_outputs, 'predict_loss': label_outputs}) if return_outputs else loss\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ff1cee5d-68e1-4caf-96b9-132b27b46dca\",\n   \"metadata\": {},\n   \"source\": [\n    \"You have the option to choose from three distinct modes: ‘infer_only’, ‘train_only’, and ‘infer_and_train’, to meet your specific requirements.\\n\",\n    \"- infer_only: Only generate the rationales and they will be saved to the output_dir\\n\",\n    \"- train_only: Local training only\\n\",\n    \"- infer_and_train: Generate rationales, and then load them into PrefixDataset and start training\\n\",\n    \"  \\n\",\n    \"In this instance, we will opt for the ‘infer_and_train’ mode to initially generate rationales with the assistance of the remote LLM. 
To activate the inference process, it is necessary to initialize the infer client and server for both the client-side and server-side trainers, as demonstrated in the preceding sections.\\n\",\n    \"\\n\",\n    \"Below is an FedCoT example. We ran this example on a machine equipped with 4 V100-32G GPUs. We launch the client script using deepspeed. LLM is depolyed on another machine.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"c559341a-d133-4a24-8f1a-35cd6d2a26d3\",\n   \"metadata\": {},\n   \"source\": [\n    \"## FedCoT Example\\n\",\n    \"\\n\",\n    \"### Client Script(deepspeed_run.py)\\n\",\n    \"\\n\",\n    \"This script show how to setup a fedcot task on the client side.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e4710fda-904a-4e90-bc65-beec7594703f\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import logging\\n\",\n    \"import os\\n\",\n    \"import sys\\n\",\n    \"from transformers import (\\n\",\n    \"    AutoTokenizer,\\n\",\n    \"    HfArgumentParser,\\n\",\n    \"    Seq2SeqTrainingArguments,\\n\",\n    \")\\n\",\n    \"from transformers import AutoModelForCausalLM, AutoTokenizer\\n\",\n    \"from typing import List\\n\",\n    \"from fate_llm.algo.inferdpt.utils import InferDPTKit\\n\",\n    \"from fate_llm.dataset.fedcot_dataset import PrefixDataset\\n\",\n    \"from fate_llm.algo.fedcot.fedcot_trainer import FedCoTTrainerClient\\n\",\n    \"from fate_llm.data.data_collator.fedcot_collator import PrefixDataCollator\\n\",\n    \"from fate_llm.algo.inferdpt import inferdpt\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"arbiter = (\\\"arbiter\\\", 10000)\\n\",\n    \"guest = (\\\"guest\\\", 10000)\\n\",\n    \"host = (\\\"host\\\", 9999)\\n\",\n    \"name = \\\"fed1\\\"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def create_ctx(local):\\n\",\n    \"    from fate.arch import Context\\n\",\n    \"    from fate.arch.computing.backends.standalone import 
CSession\\n\",\n    \"    from fate.arch.federation.backends.standalone import StandaloneFederation\\n\",\n    \"    import logging\\n\",\n    \"\\n\",\n    \"    logger = logging.getLogger()\\n\",\n    \"    logger.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    console_handler = logging.StreamHandler()\\n\",\n    \"    console_handler.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    formatter = logging.Formatter(\\\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\\\")\\n\",\n    \"    console_handler.setFormatter(formatter)\\n\",\n    \"\\n\",\n    \"    logger.addHandler(console_handler)\\n\",\n    \"    computing = CSession(data_dir=\\\"./session_dir\\\")\\n\",\n    \"    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"logger = logging.getLogger(__name__)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"doc_template = \\\"\\\"\\\"{{question}} \\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"instruction_template=\\\"\\\"\\\"\\n\",\n    \"<s>[INST]\\n\",\n    \"Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. Please refer to the example to write the rationale.\\n\",\n    \"Use <end> to finish your rationle.\\\"\\n\",\n    \"\\n\",\n    \"Example(s):\\n\",\n    \"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\\n\",\n    \"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\\n\",\n    \"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  
Therefore, the answer is 'dry palms'.<end>\\n\",\n    \"\\n\",\n    \"Please explain:\\n\",\n    \"Question:{{perturbed_doc}}\\n\",\n    \"Rationale:\\n\",\n    \"[/INST]\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"decode_template = \\\"\\\"\\\"Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. Please refer to the example to write the rationale.\\n\",\n    \"Use <end> to finish your rationle.\\\"\\n\",\n    \"\\n\",\n    \"Example(s):\\n\",\n    \"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\\n\",\n    \"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\\n\",\n    \"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  Therefore, the answer is 'dry palms'.<end>\\n\",\n    \"\\n\",\n    \"Question:{{perturbed_doc}}\\n\",\n    \"Rationale:{{perturbed_response | replace('\\\\n', '')}}<end>\\n\",\n    \"\\n\",\n    \"Please explain:\\n\",\n    \"Question:{{question}} \\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"Rationale:\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"    \\n\",\n    \"\\n\",\n    \"if __name__ == \\\"__main__\\\":\\n\",\n    \"    \\n\",\n    \"    parser = HfArgumentParser(Seq2SeqTrainingArguments)\\n\",\n    \"    if len(sys.argv) == 2 and sys.argv[1].endswith(\\\".json\\\"):\\n\",\n    \"        training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]\\n\",\n    \"    else:\\n\",\n    \"        training_args = parser.parse_args_into_dataclasses()[0]\\n\",\n    \"\\n\",\n    \"    model_path = '/data/cephfs/llm/models/Qwen1.5-0.5B/'\\n\",\n    \"    pds = PrefixDataset(\\n\",\n    \"        tokenizer_path=model_path,\\n\",\n    \"        
predict_input_template=\\\"\\\"\\\"Predict:\\n\",\n    \"Question:{{question}}\\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"Answer:\\n\",\n    \"    \\\"\\\"\\\",\\n\",\n    \"        predict_output_template=\\\"\\\"\\\"{{choices.text[choices.label.index(answerKey)]}}<end>\\\"\\\"\\\",\\n\",\n    \"        rationale_input_template=\\\"\\\"\\\"Explain:\\n\",\n    \"Question:{{question}}\\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"Rationale:\\n\",\n    \"    \\\"\\\"\\\",\\n\",\n    \"        rationale_output_template=\\\"\\\"\\\"{{infer_result}}<end>\\n\",\n    \"    \\\"\\\"\\\",\\n\",\n    \"        max_input_length=128,\\n\",\n    \"        max_target_length=128,\\n\",\n    \"        split_key='train'\\n\",\n    \"    )\\n\",\n    \"    pds.load('/data/cephfs/llm/datasets/arce/')\\n\",\n    \"    \\n\",\n    \"    model = AutoModelForCausalLM.from_pretrained(model_path).half().cuda()\\n\",\n    \"    tokenizer = AutoTokenizer.from_pretrained(model_path)\\n\",\n    \"    model.gradient_checkpointing_enable()\\n\",\n    \"    model.enable_input_require_grads()\\n\",\n    \"\\n\",\n    \"    ctx = create_ctx(guest)\\n\",\n    \"    if training_args.local_rank == 0:\\n\",\n    \"        # only rank 0 need to load infer instance\\n\",\n    \"        save_kit_path = 'your path'\\n\",\n    \"        kit = InferDPTKit.load_from_path(save_kit_path)\\n\",\n    \"        # local deployed small model as decoding model\\n\",\n    \"        from fate_llm.algo.inferdpt.inference.api import APICompletionInference\\n\",\n    \"        inference = APICompletionInference(api_url=\\\"http://xxxx/v1\\\", model_name='./Qwen1.5-0.5B', api_key='EMPTY')\\n\",\n    \"        client = inferdpt.InferDPTClient(ctx, kit, inference, epsilon=3.0)\\n\",\n    \"    else:\\n\",\n    \"        client = None\\n\",\n    \"    \\n\",\n    \"    trainer = FedCoTTrainerClient(\\n\",\n    \"        ctx=ctx,\\n\",\n    \"        model=model,\\n\",\n    \"        
training_args=training_args,\\n\",\n    \"        tokenizer=tokenizer,    \\n\",\n    \"        train_set=pds,\\n\",\n    \"        data_collator=PrefixDataCollator(tokenizer),\\n\",\n    \"        mode='infer_and_train',\\n\",\n    \"        infer_client=client,\\n\",\n    \"        encode_template=doc_template,\\n\",\n    \"        decode_template=decode_template,\\n\",\n    \"        instruction_template=instruction_template,\\n\",\n    \"        remote_inference_kwargs={\\n\",\n    \"            'stop': ['<\\\\s>'],\\n\",\n    \"            'temperature': 0.01,\\n\",\n    \"            'max_tokens': 256\\n\",\n    \"         },\\n\",\n    \"         local_inference_kwargs={\\n\",\n    \"            'stop': ['<|im_end|>', '<end>', '<end>\\\\n', '<end>\\\\n\\\\n', '.\\\\n\\\\n\\\\n\\\\n\\\\n', '<|end_of_text|>', '>\\\\n\\\\n\\\\n'],\\n\",\n    \"            'temperature': 0.01,\\n\",\n    \"            'max_tokens': 256\\n\",\n    \"         }\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    trainer.train()\\n\",\n    \"\\n\",\n    \"    if training_args.local_rank == 0:\\n\",\n    \"        model.save_pretrained(training_args.output_dir)\\n\",\n    \"        tokenizer.save_pretrained(training_args.output_dir)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"962dd399-1dec-4164-bd86-15aa8550c50b\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Server Script(server.py)\\n\",\n    \"\\n\",\n    \"This script show how to setup a fedcot task on the server side.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"91b42972-5308-4ccf-a768-f7dfa087313e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.algo.inferdpt.inferdpt import InferDPTServer\\n\",\n    \"from fate_llm.algo.fedcot.fedcot_trainer import FedCoTTraineServer\\n\",\n    \"import sys\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"arbiter = (\\\"arbiter\\\", 10000)\\n\",\n    \"guest = (\\\"guest\\\", 10000)\\n\",\n   
 \"host = (\\\"host\\\", 9999)\\n\",\n    \"name = \\\"fed1\\\"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def create_ctx(local):\\n\",\n    \"    from fate.arch import Context\\n\",\n    \"    from fate.arch.computing.backends.standalone import CSession\\n\",\n    \"    from fate.arch.federation.backends.standalone import StandaloneFederation\\n\",\n    \"    import logging\\n\",\n    \"\\n\",\n    \"    logger = logging.getLogger()\\n\",\n    \"    logger.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    console_handler = logging.StreamHandler()\\n\",\n    \"    console_handler.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    formatter = logging.Formatter(\\\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\\\")\\n\",\n    \"    console_handler.setFormatter(formatter)\\n\",\n    \"\\n\",\n    \"    logger.addHandler(console_handler)\\n\",\n    \"    computing = CSession(data_dir=\\\"./session_dir\\\")\\n\",\n    \"    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"from fate_llm.algo.inferdpt.inference.api import APICompletionInference\\n\",\n    \"api = APICompletionInference(api_url='http://xxxx:8080/v1', api_key='EMPTY', model_name='/data/cephfs/llm/models/Qwen1.5-14B-Chat')\\n\",\n    \"\\n\",\n    \"ctx = create_ctx(arbiter)\\n\",\n    \"server_api = InferDPTServer(ctx, api)\\n\",\n    \"server = FedCoTTraineServer(ctx, server_api)\\n\",\n    \"server.train()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"125dd68e-c7d4-41aa-9972-4881b1330fb6\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Start script\\n\",\n    \"\\n\",\n    \"You can launch client side training with following script:\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"deepspeed --num_nodes 1 --num_gpus 4 deepspeed_run.py \\\\\\n\",\n    \"    --output_dir \\\"./\\\" \\\\\\n\",\n    \"    --per_device_train_batch_size \\\"1\\\" \\\\\\n\",\n    \"    
--gradient_accumulation_steps \\\"8\\\" \\\\\\n\",\n    \"    --max_steps \\\"750\\\" \\\\\\n\",\n    \"    --fp16 \\\\\\n\",\n    \"    --logging_steps 10 \\\\\\n\",\n    \"    --save_only_model \\\\\\n\",\n    \"    --deepspeed \\\"./ds_config.json\\\" \\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0b506c1c-51f4-448d-9b0b-adf1a71cc7cf\",\n   \"metadata\": {},\n   \"source\": [\n    \"and the ds_config.json is\\n\",\n    \"```\\n\",\n    \"{   \\n\",\n    \"    \\\"train_micro_batch_size_per_gpu\\\": 1,\\n\",\n    \"    \\\"gradient_accumulation_steps\\\": 8,\\n\",\n    \"    \\\"optimizer\\\": {\\n\",\n    \"        \\\"type\\\": \\\"AdamW\\\",\\n\",\n    \"        \\\"params\\\": {\\n\",\n    \"             \\\"lr\\\": 5e-5\\n\",\n    \"        }\\n\",\n    \"    },\\n\",\n    \"    \\\"fp16\\\": {\\n\",\n    \"        \\\"enabled\\\": true\\n\",\n    \"    },\\n\",\n    \"    \\\"zero_optimization\\\": {\\n\",\n    \"        \\\"stage\\\": 0\\n\",\n    \"    }\\n\",\n    \"}\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"613fbfb6-ac9e-485b-8587-ffef1e2361c1\",\n   \"metadata\": {},\n   \"source\": [\n    \"And server side:\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5b50adf0-8f9c-40e5-9a7d-40a70e30a420\",\n   \"metadata\": {},\n   \"source\": [\n    \"```python server.py```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"28a5de71-25fd-4042-a6b7-0ec2c505eaee\",\n   \"metadata\": {},\n   \"source\": [\n    \"## FedCoT Pipeline Example\\n\",\n    \"\\n\",\n    \"You have the capability to submit a FedCoT task within the FATE pipeline. 
By appropriately configuring the necessary settings, you can execute FedCoT in a production environment.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"52f1e19b-da8e-4977-adb1-42fb84dee407\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_client.pipeline.components.fate.nn.loader import Loader\\n\",\n    \"import argparse\\n\",\n    \"from fate_client.pipeline.utils import test_utils\\n\",\n    \"from fate_client.pipeline.components.fate.reader import Reader\\n\",\n    \"from fate_client.pipeline import FateFlowPipeline\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def main(config=\\\"../../config.yaml\\\", namespace=\\\"\\\"):\\n\",\n    \"    # obtain config\\n\",\n    \"    if isinstance(config, str):\\n\",\n    \"        config = test_utils.load_job_config(config)\\n\",\n    \"    parties = config.parties\\n\",\n    \"    guest = '9999'\\n\",\n    \"    host = parties.host[0]\\n\",\n    \"    arbiter = '10000'\\n\",\n    \"\\n\",\n    \"    pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\\n\",\n    \"\\n\",\n    \"    reader_0 = Reader(\\\"reader_0\\\", runtime_parties=dict(guest=guest))\\n\",\n    \"    reader_0.guest.task_parameters(\\n\",\n    \"        namespace=\\\"experiment\\\",\\n\",\n    \"        name=\\\"arc_e_example\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    model_conf = Loader(module_name='fate_llm.model_zoo.hf_model', item_name='HFAutoModelForCausalLM', \\n\",\n    \"                        pretrained_model_name_or_path='/data/cephfs/llm/models/Qwen1.5-0.5B/').to_dict()\\n\",\n    \"    data_collator_conf = Loader(module_name='fate_llm.data.data_collator.fedcot_collator', item_name='get_prefix_data_collator', tokenizer_name_or_path='/data/cephfs/llm/models/Qwen1.5-0.5B/').to_dict()\\n\",\n    \"\\n\",\n    \"    infer_init_conf_client = {\\n\",\n    \"        'module_name': 'fate_llm.algo.inferdpt.init.default_init',\\n\",\n    \"        
'item_name': 'InferDPTAPIClientInit'\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    infer_init_conf_server = {\\n\",\n    \"        'module_name': 'fate_llm.algo.inferdpt.init.default_init',\\n\",\n    \"        'item_name': 'InferDPTAPIServerInit'\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    dataset_conf = {\\n\",\n    \"        'module_name': 'fate_llm.dataset.fedcot_dataset',\\n\",\n    \"        'item_name': 'PrefixDataset',\\n\",\n    \"        'kwargs':dict(\\n\",\n    \"            tokenizer_path='/data/cephfs/llm/models/Qwen1.5-0.5B/',\\n\",\n    \"            predict_input_template=\\\"\\\"\\\"Predict:\\n\",\n    \"    Question:{{question}}\\n\",\n    \"    Choices:{{choices.text}}\\n\",\n    \"    \\\"\\\"\\\",\\n\",\n    \"            predict_output_template=\\\"\\\"\\\"{{choices.text[choices.label.index(answerKey)]}}<end>\\\"\\\"\\\",\\n\",\n    \"            rationale_input_template=\\\"\\\"\\\"Explain:\\n\",\n    \"    Question:{{question}}\\n\",\n    \"    Choices:{{choices.text}}\\n\",\n    \"    \\\"\\\"\\\",\\n\",\n    \"            rationale_output_template=\\\"\\\"\\\"{{infer_result}}<end>\\n\",\n    \"        \\\"\\\"\\\",\\n\",\n    \"            max_input_length=128,\\n\",\n    \"            max_target_length=128,\\n\",\n    \"            split_key='train'\\n\",\n    \"        )\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    encoder_prompt = \\\"\\\"\\\"{{question}}\\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"    decoder_prompt = \\\"\\\"\\\"Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. Please refer to the example to write the rationale.Use <end> to finish your rationle.\\n\",\n    \"\\n\",\n    \"Example(s):\\n\",\n    \"Question:George wants to warm his hands quickly by rubbing them. 
Which skin surface will produce the most heat?\\n\",\n    \"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\\n\",\n    \"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  Therefore, the answer is 'dry palms'.<end>\\n\",\n    \"\\n\",\n    \"Question:{{perturbed_doc}}\\n\",\n    \"Rationale:{{perturbed_response | replace('\\\\n', '')}}<end>\\n\",\n    \"\\n\",\n    \"Please explain:\\n\",\n    \"Question:{{question}} \\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"    instruction_prompt = \\\"\\\"\\\"<|im_start|>system\\n\",\n    \"You are a helpful assistant<|im_end|>\\n\",\n    \"<|im_start|>user\\n\",\n    \"Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. Please refer to the example to write the rationale.\\n\",\n    \"Use <end> to finish your rationle.\\n\",\n    \"\\n\",\n    \"Example(s):\\n\",\n    \"Question:Which factor will most likely cause a person to develop a fever?\\n\",\n    \"Choices:['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates being digested in the stomach']\\n\",\n    \"Rationale:A bacterial infection in the bloodstream triggers the immune system to respond, therefore often causing a fever as the body tries to fight off the bacteria. 
Therefore, the answer is 'a bacterial population in the bloodstream'\\n\",\n    \"\\n\",\n    \"Please explain:\\n\",\n    \"Question:{{perturbed_doc}}\\n\",\n    \"Rationale:\\n\",\n    \"<|im_end|>\\n\",\n    \"<|im_start|>assistant\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"    remote_inference_kwargs={\\n\",\n    \"        'stop': ['<|im_end|>', '<end>', '<end>\\\\n', '<end>\\\\n\\\\n', '.\\\\n\\\\n\\\\n\\\\n\\\\n', '<|end_of_text|>', '>\\\\n\\\\n\\\\n'],\\n\",\n    \"        'temperature': 0.01,\\n\",\n    \"        'max_tokens': 256\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    local_inference_kwargs={\\n\",\n    \"        'stop': ['<|im_end|>', '<end>', '<end>\\\\n', '<end>\\\\n\\\\n', '.\\\\n\\\\n\\\\n\\\\n\\\\n', '<|end_of_text|>', '>\\\\n\\\\n\\\\n'],\\n\",\n    \"        'temperature': 0.01,\\n\",\n    \"        'max_tokens': 256\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    ds_config = {   \\n\",\n    \"        \\\"train_micro_batch_size_per_gpu\\\": 1,\\n\",\n    \"        \\\"gradient_accumulation_steps\\\": 8,\\n\",\n    \"        \\\"optimizer\\\": {\\n\",\n    \"            \\\"type\\\": \\\"AdamW\\\",\\n\",\n    \"            \\\"params\\\": {\\n\",\n    \"                \\\"lr\\\": 5e-5\\n\",\n    \"            }\\n\",\n    \"        },\\n\",\n    \"        \\\"fp16\\\": {\\n\",\n    \"            \\\"enabled\\\": True\\n\",\n    \"        },\\n\",\n    \"        \\\"zero_optimization\\\": {\\n\",\n    \"            \\\"stage\\\": 0\\n\",\n    \"        }\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    training_args_dict = dict(\\n\",\n    \"        per_device_train_batch_size=1, \\n\",\n    \"        gradient_accumulation_steps=8,\\n\",\n    \"        logging_steps=10,\\n\",\n    \"        max_steps=30,\\n\",\n    \"        fp16=True,\\n\",\n    \"        log_level='debug'\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    mode = 'infer_and_train'\\n\",\n    \"\\n\",\n    \"    client_conf = dict(\\n\",\n    \"        
model_conf=model_conf,\\n\",\n    \"        dataset_conf=dataset_conf,\\n\",\n    \"        training_args_conf=training_args_dict,\\n\",\n    \"        data_collator_conf=data_collator_conf,\\n\",\n    \"        mode=mode,\\n\",\n    \"        infer_inst_init_conf=infer_init_conf_client,\\n\",\n    \"        encode_template=encoder_prompt,\\n\",\n    \"        instruction_template=instruction_prompt,\\n\",\n    \"        decode_template=decoder_prompt,\\n\",\n    \"        remote_inference_kwargs=remote_inference_kwargs,\\n\",\n    \"        local_inference_kwargs=local_inference_kwargs,\\n\",\n    \"        perturb_doc_key='perturbed_doc',\\n\",\n    \"        perturbed_response_key='perturbed_response',\\n\",\n    \"        result_key='infer_result'\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    server_conf = dict(\\n\",\n    \"        infer_inst_init_conf=infer_init_conf_server,\\n\",\n    \"        mode=mode\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    homo_nn_0 = HomoNN(\\n\",\n    \"        'nn_0',\\n\",\n    \"        train_data=reader_0.outputs[\\\"output_data\\\"],\\n\",\n    \"        runner_module=\\\"fedcot_runner\\\",\\n\",\n    \"        runner_class=\\\"FedCoTRunner\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    homo_nn_0.guest.task_parameters(runner_conf=client_conf)\\n\",\n    \"    homo_nn_0.arbiter.task_parameters(runner_conf=server_conf)\\n\",\n    \"\\n\",\n    \"    homo_nn_0.guest.conf.set(\\\"launcher_name\\\", \\\"deepspeed\\\")\\n\",\n    \"\\n\",\n    \"    pipeline.add_tasks([reader_0, homo_nn_0])\\n\",\n    \"    pipeline.conf.set(\\\"task\\\", dict(engine_run={\\\"cores\\\": 4}))\\n\",\n    \"    pipeline.compile()\\n\",\n    \"    pipeline.fit()\\n\",\n    \"\\n\",\n    \"if __name__ == \\\"__main__\\\":\\n\",\n    \"    parser = argparse.ArgumentParser(\\\"PIPELINE DEMO\\\")\\n\",\n    \"    parser.add_argument(\\\"--config\\\", type=str, default=\\\"../config.yaml\\\",\\n\",\n    \"                        
help=\\\"config file\\\")\\n\",\n    \"    parser.add_argument(\\\"--namespace\\\", type=str, default=\\\"\\\",\\n\",\n    \"                        help=\\\"namespace for data stored in FATE\\\")\\n\",\n    \"    args = parser.parse_args()\\n\",\n    \"    main(config=args.config, namespace=args.namespace)\\n\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.13\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}"
  },
  {
    "path": "doc/tutorial/fedkseed/README.md",
    "content": "## FedKSeed\n\nThe algorithm is based on the paper: [Federated Full-Parameter Tuning of Billion-Sized Language Models\nwith Communication Cost under 18 Kilobytes](https://arxiv.org/pdf/2312.06353.pdf) and the code is adapted\nfrom https://github.com/alibaba/FederatedScope/tree/FedKSeed.\nWe refactored the code to make it more compatible with the (transformers/PyTorch) framework\nand integrated it into the FATE-LLM framework.\n\nThe main works include:\n1. A KSeedZerothOrderOptimizer class that can be used to optimize the model along a given direction generated with a random seed.\n2. A KSeedZOExtendedTrainer subclass of Trainer from transformers that can be used to train large language models with KSeedZerothOrderOptimizer.\n3. Trainers for federated learning with large language models."
  },
  {
    "path": "doc/tutorial/fedkseed/fedkseed-example.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#  Federated Tuning with FedKSeed methods in FATE-LLM\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"In this tutorial, we will demonstrate how to efficiently train federated large language models using the FATE-LLM framework. In FATE-LLM, we introduce the \\\"FedKSeed\\\" module, specifically designed for federated learning with large language models. The Idea of FedKSeed is to use Zeroth-Order-Optimizer to optimize model along given direction that generated with random seed. This method can be used to train large language models in a federated learning setting with extremely low communication cost.\\n\",\n    \"\\n\",\n    \"The Algorithm is based on the paper: [Federated Full-Parameter Tuning of Billion-Sized Language Models\\n\",\n    \"with Communication Cost under 18 Kilobytes](https://arxiv.org/pdf/2312.06353.pdf) and the code is modified from the https://github.com/alibaba/FederatedScope/tree/FedKSeed. We refactor the code to make it more compatible with (transformers/PyTorch) framework and integrate it into the FATE-LLM framework.\\n\",\n    \"\\n\",\n    \"The main works include:\\n\",\n    \"1. An KSeedZerothOrderOptimizer class that can be used to optimize model along given direction that generated with random seed.\\n\",\n    \"2. An KSeedZOExtendedTrainer subclass of Trainer from transformers that can be used to train large language models with KSeedZerothOrderOptimizer.\\n\",\n    \"3. Trainers for federated learning with large language models.\\n\",\n    \"\\n\",\n    \"In this tutorial, we will demonstrate how to use the FedKSeed method to train a large language model in a federated learning setting. 
\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Model: datajuicer/LLaMA-1B-dj-refine-150B\\n\",\n    \"\\n\",\n    \"This is the introduction from the Huggingface model hub: [datajuicer/LLaMA-1B-dj-refine-150B](https://huggingface.co/datajuicer/LLaMA-1B-dj-refine-150B)\\n\",\n    \"\\n\",\n    \"> The model architecture is LLaMA-1.3B and we adopt the OpenLLaMA implementation. The model is pre-trained on 150B tokens of Data-Juicer's refined RedPajama and Pile. It achieves an average score of 34.21 over 16 HELM tasks, beating Falcon-1.3B (trained on 350B tokens from RefinedWeb), Pythia-1.4B (trained on 300B tokens from original Pile) and Open-LLaMA-1.3B (trained on 150B tokens from original RedPajama and Pile).\\n\",\n    \"\\n\",\n    \"> For more details, please refer to our [paper](https://arxiv.org/abs/2309.02033).\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {\n    \"ExecuteTime\": {\n     \"end_time\": \"2024-02-29T09:27:23.512735Z\",\n     \"start_time\": \"2024-02-29T09:27:23.508790Z\"\n    },\n    \"collapsed\": false\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# model_name_or_path = \\\"datajuicer/LLaMA-1B-dj-refine-150B\\\"\\n\",\n    \"model_name_or_path = \\\"gpt2\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Dataset: databricks/databricks-dolly-15k\\n\",\n    \"\\n\",\n    \"This is the introduction from the Huggingface dataset hub: [databricks/databricks-dolly-15k](https://huggingface.co/dataset/databricks/databricks-dolly-15k)\\n\",\n    \"\\n\",\n    \"> databricks-dolly-15k is a corpus of more than 15,000 records generated by thousands of Databricks employees to enable large language models to exhibit the magical interactivity of ChatGPT. 
Databricks employees were invited to create prompt / response pairs in each of eight different instruction categories, including the seven outlined in the InstructGPT paper, as well as an open-ended free-form category. The contributors were instructed to avoid using information from any source on the web with the exception of Wikipedia (for particular subsets of instruction categories), and explicitly instructed to avoid using generative AI in formulating instructions or responses. Examples of each behavior were provided to motivate the types of questions and instructions appropriate to each category.\n",\n    \"\\n\",\n    \"To use this dataset, you first need to download it from the Huggingface dataset hub:\\n\",\n    \"\\n\",\n    \"```bash\\n\",\n    \"mkdir -p ../../../examples/data/dolly && cd ../../../examples/data/dolly && wget https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl\\\\?download\\\\=true -O databricks-dolly-15k.jsonl\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"### Check Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {\n    \"ExecuteTime\": {\n     \"end_time\": \"2024-02-29T09:27:26.987779Z\",\n     \"start_time\": \"2024-02-29T09:27:24.706218Z\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.dataset.hf_dataset import Dolly15K\\n\",\n    \"from transformers import AutoTokenizer\\n\",\n    \"\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path)\\n\",\n    \"special_tokens = tokenizer.special_tokens_map\\n\",\n    \"if \\\"pad_token\\\" not in tokenizer.special_tokens_map:\\n\",\n    \"    special_tokens[\\\"pad_token\\\"] = special_tokens[\\\"eos_token\\\"]\\n\",\n    \"\\n\",\n    \"tokenizer.pad_token = tokenizer.eos_token\\n\",\n    \"ds = Dolly15K(split=\\\"train\\\", tokenizer_params={\\\"pretrained_model_name_or_path\\\": model_name_or_path, 
**special_tokens},\\n\",\n    \"              tokenizer_apply_params=dict(truncation=True, max_length=tokenizer.model_max_length, padding=\\\"max_length\\\", return_tensors=\\\"pt\\\"))\\n\",\n    \"ds = ds.load('../../../examples/data/dolly')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {\n    \"ExecuteTime\": {\n     \"end_time\": \"2024-02-29T09:27:27.875025Z\",\n     \"start_time\": \"2024-02-29T09:27:27.867839Z\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Dataset({\\n\",\n       \"    features: ['instruction', 'context', 'response', 'category', 'text', 'input_ids', 'attention_mask'],\\n\",\n       \"    num_rows: 15011\\n\",\n       \"})\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"ds\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"For more details of FATE-LLM dataset setting, we recommend that you read through these tutorials first: [NN Dataset Customization](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Homo-NN-Customize-your-Dataset.ipynb), [Some Built-In Dataset](https://github.com/FederatedAI/FATE/blob/master/doc/tutorial/pipeline/nn_tutorial/Introduce-Built-In-Dataset.ipynb),\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Check local training\\n\",\n    \"\\n\",\n    \"Before submitting a federated learning task, we will demonstrate how to perform local testing to ensure the proper functionality of your custom dataset, model. 
\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {\n    \"ExecuteTime\": {\n     \"end_time\": \"2024-02-29T09:38:33.175079Z\",\n     \"start_time\": \"2024-02-29T09:38:33.168844Z\"\n    },\n    \"collapsed\": false\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import AutoModelForCausalLM, TrainingArguments, DataCollatorForLanguageModeling\\n\",\n    \"from fate_llm.algo.fedkseed.trainer import KSeedZOExtendedTrainer, KSeedTrainingArguments\\n\",\n    \"from fate_llm.algo.fedkseed.zo_utils import build_seed_candidates, get_even_seed_probabilities\\n\",\n    \"\\n\",\n    \"def test_training(zo_mode=True):\\n\",\n    \"    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **special_tokens)\\n\",\n    \"    data_collector = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\\n\",\n    \"    model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_name_or_path)\\n\",\n    \"\\n\",\n    \"    training_args = TrainingArguments(output_dir='./',\\n\",\n    \"                                      dataloader_num_workers=1,\\n\",\n    \"                                      dataloader_prefetch_factor=1,\\n\",\n    \"                                      remove_unused_columns=True,\\n\",\n    \"                                      learning_rate=1e-5,\\n\",\n    \"                                      per_device_train_batch_size=1,\\n\",\n    \"                                      num_train_epochs=0.01,\\n\",\n    \"                                      )\\n\",\n    \"    kseed_args = KSeedTrainingArguments(zo_optim=zo_mode)\\n\",\n    \"    trainer = KSeedZOExtendedTrainer(model=model, train_dataset=ds, training_args=training_args, kseed_args=kseed_args,\\n\",\n    \"                                     tokenizer=tokenizer, data_collator=data_collector)\\n\",\n    \"    if zo_mode:\\n\",\n    \"        seed_candidates = 
build_seed_candidates(k=kseed_args.k)\\n\",\n    \"        seed_probabilities = get_even_seed_probabilities(k=kseed_args.k)\\n\",\n    \"        trainer.configure_seed_candidates(seed_candidates, seed_probabilities)\\n\",\n    \"    return trainer.train()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {\n    \"ExecuteTime\": {\n     \"end_time\": \"2024-02-29T09:39:37.602070Z\",\n     \"start_time\": \"2024-02-29T09:38:34.024223Z\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"\\n\",\n       \"    <div>\\n\",\n       \"      \\n\",\n       \"      <progress value='151' max='151' style='width:300px; height:20px; vertical-align: middle;'></progress>\\n\",\n       \"      [151/151 00:59, Epoch 0/1]\\n\",\n       \"    </div>\\n\",\n       \"    <table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \" <tr style=\\\"text-align: left;\\\">\\n\",\n       \"      <th>Step</th>\\n\",\n       \"      <th>Training Loss</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"  </tbody>\\n\",\n       \"</table><p>\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"TrainOutput(global_step=151, training_loss=1.2660519429390005, metrics={'train_runtime': 61.8249, 'train_samples_per_second': 2.428, 'train_steps_per_second': 2.442, 'total_flos': 78910193664000.0, 'train_loss': 1.2660519429390005, 'epoch': 0.01})\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_training(zo_mode=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {\n    \"ExecuteTime\": {\n     \"end_time\": 
\"2024-02-29T09:41:28.949449Z\",\n     \"start_time\": \"2024-02-29T09:39:54.802705Z\"\n    },\n    \"collapsed\": false\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"\\n\",\n       \"    <div>\\n\",\n       \"      \\n\",\n       \"      <progress value='151' max='151' style='width:300px; height:20px; vertical-align: middle;'></progress>\\n\",\n       \"      [151/151 01:29, Epoch 0/1]\\n\",\n       \"    </div>\\n\",\n       \"    <table border=\\\"1\\\" class=\\\"dataframe\\\">\\n\",\n       \"  <thead>\\n\",\n       \" <tr style=\\\"text-align: left;\\\">\\n\",\n       \"      <th>Step</th>\\n\",\n       \"      <th>Training Loss</th>\\n\",\n       \"    </tr>\\n\",\n       \"  </thead>\\n\",\n       \"  <tbody>\\n\",\n       \"  </tbody>\\n\",\n       \"</table><p>\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"TrainOutput(global_step=151, training_loss=0.6093456950408733, metrics={'train_runtime': 92.6158, 'train_samples_per_second': 1.621, 'train_steps_per_second': 1.63, 'total_flos': 78910193664000.0, 'train_loss': 0.6093456950408733, 'epoch': 0.01})\"\n      ]\n     },\n     \"execution_count\": 18,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"test_training(zo_mode=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"collapsed\": false\n   },\n   \"source\": [\n    \"You can see that Zeroth-Order-Optimizer has much worse performance than AdamW, that's the price we need to pay for the low communication cost. \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Submit Federated Task\\n\",\n    \"Once you have successfully completed local testing, We can submit a task to FATE. 
Please notice that this tutorial is ran on a standalone version. **Please notice that in this tutorial we are using a standalone version, if you are using a cluster version, you need to bind the data with the corresponding name&namespace on each machine.**\\n\",\n    \"\\n\",\n    \"In this example we load pretrained weights for gpt2 model.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import time\\n\",\n    \"from fate_client.pipeline.components.fate.reader import Reader\\n\",\n    \"from fate_client.pipeline import FateFlowPipeline\\n\",\n    \"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_seq2seq_runner\\n\",\n    \"from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments\\n\",\n    \"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\\n\",\n    \"\\n\",\n    \"guest = '10000'\\n\",\n    \"host = '10000'\\n\",\n    \"arbiter = '10000'\\n\",\n    \"\\n\",\n    \"epochs = 0.01\\n\",\n    \"batch_size = 1\\n\",\n    \"lr = 1e-5\\n\",\n    \"\\n\",\n    \"pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\\n\",\n    \"pipeline.bind_local_path(path=\\\"/data/projects/fate/examples/data/dolly\\\", namespace=\\\"experiment\\\",\\n\",\n    \"                         name=\\\"dolly\\\")\\n\",\n    \"time.sleep(5)\\n\",\n    \"\\n\",\n    \"reader_0 = Reader(\\\"reader_0\\\", runtime_parties=dict(guest=guest, host=host))\\n\",\n    \"reader_0.guest.task_parameters(\\n\",\n    \"    namespace=\\\"experiment\\\",\\n\",\n    \"    name=\\\"dolly\\\"\\n\",\n    \")\\n\",\n    \"reader_0.hosts[0].task_parameters(\\n\",\n    \"    namespace=\\\"experiment\\\",\\n\",\n    \"    name=\\\"dolly\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"tokenizer_params = dict(\\n\",\n    \"    
pretrained_model_name_or_path=\\\"gpt2\\\",\\n\",\n    \"    trust_remote_code=True,\\n\",\n    \")\\n\",\n    \"conf = get_config_of_seq2seq_runner(\\n\",\n    \"    algo='fedkseed',\\n\",\n    \"    model=LLMModelLoader(\\n\",\n    \"        \\\"hf_model\\\",\\n\",\n    \"        \\\"HFAutoModelForCausalLM\\\",\\n\",\n    \"        # pretrained_model_name_or_path=\\\"datajuicer/LLaMA-1B-dj-refine-150B\\\",\\n\",\n    \"        pretrained_model_name_or_path=\\\"gpt2\\\",\\n\",\n    \"        trust_remote_code=True\\n\",\n    \"    ),\\n\",\n    \"    dataset=LLMDatasetLoader(\\n\",\n    \"        \\\"hf_dataset\\\",\\n\",\n    \"        \\\"Dolly15K\\\",\\n\",\n    \"        split=\\\"train\\\",\\n\",\n    \"        tokenizer_params=tokenizer_params,\\n\",\n    \"        tokenizer_apply_params=dict(\\n\",\n    \"            truncation=True,\\n\",\n    \"            max_length=1024,\\n\",\n    \"        )),\\n\",\n    \"    data_collator=LLMDataFuncLoader(\\n\",\n    \"        \\\"cust_func.cust_data_collator\\\",\\n\",\n    \"        \\\"get_seq2seq_tokenizer\\\",\\n\",\n    \"        tokenizer_params=tokenizer_params,\\n\",\n    \"    ),\\n\",\n    \"    training_args=TrainingArguments(\\n\",\n    \"        num_train_epochs=0.01,\\n\",\n    \"        per_device_train_batch_size=batch_size,\\n\",\n    \"        remove_unused_columns=True,\\n\",\n    \"        learning_rate=lr,\\n\",\n    \"        fp16=False,\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        disable_tqdm=False,\\n\",\n    \"    ),\\n\",\n    \"    fed_args=FedAVGArguments(),\\n\",\n    \"    task_type='causal_lm',\\n\",\n    \"    save_trainable_weights_only=True,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"conf[\\\"fed_args_conf\\\"] = {}\\n\",\n    \"\\n\",\n    \"homo_nn_0 = HomoNN(\\n\",\n    \"    'nn_0',\\n\",\n    \"    runner_conf=conf,\\n\",\n    \"    train_data=reader_0.outputs[\\\"output_data\\\"],\\n\",\n    \"    runner_module=\\\"fedkseed_runner\\\",\\n\",\n    \"    
runner_class=\\\"FedKSeedRunner\\\",\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"pipeline.add_tasks([reader_0, homo_nn_0])\\n\",\n    \"pipeline.conf.set(\\\"task\\\", dict(engine_run={\\\"cores\\\": 1}))\\n\",\n    \"\\n\",\n    \"pipeline.compile()\\n\",\n    \"pipeline.fit()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"You can use this script to submit the model, but submitting the model will take a long time to train and generate a long log, so we won't do it here.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.9.0\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "doc/tutorial/fedmkt/README.md",
    "content": "# FATE-LLM: FedMKT\n\nThe algorithm is based on paper [\"FedMKT: Federated Mutual Knowledge Transfer for Large and SmallLanguage Models\"](https://aclanthology.org/2025.coling-main.17.pdf), We integrate its code into the FATE-LLM framework.  \n\n## Citation\nIf you publish work that uses FedMKT, please cite FedMKT as follows:\n```\n@inproceedings{fan2025fedmkt,\n  title={Fedmkt: Federated mutual knowledge transfer for large and small language models},\n  author={Fan, Tao and Ma, Guoqiang and Kang, Yan and Gu, Hanlin and Song, Yuanfeng and Fan, Lixin and Chen, Kai and Yang, Qiang},\n  booktitle={Proceedings of the 31st International Conference on Computational Linguistics},\n  pages={243--255},\n  year={2025}\n}\n```\n"
  },
  {
    "path": "doc/tutorial/fedmkt/fedmkt.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Federated Tuning With FedMKT methods in FATE-LLM\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"In this tutorial, we will demonstrate how to efficiently train federated large language models using the FATE-LLM framework. In FATE-LLM, we introduce the \\\"FedMKT\\\" module, specifically designed for federated learning with large language models. FedMKT introduces a novel\\n\",\n    \"federated mutual knowledge transfer framework that enables effective knowledge transfer between an LLM deployed on the server and SLMs residing on clients.\\n\",\n    \"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"The Algorithm is based on paper [\\\"FedMKT: Federated Mutual Knowledge Transfer for Large and SmallLanguage Models\\\"](https://arxiv.org/pdf/2406.02224), We integrate its code into the FATE-LLM framework.  \\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Experiments\\n\",\n    \"\\n\",\n    \"Chapter List: \\n\",\n    \"* settings\\n\",\n    \"  1. DataSet: ARC-Challenge\\n\",\n    \"  2. Models Use in \\\"FEDMKT\\\" Paper\\n\",\n    \"  3. Prepare Optimal Vocabulary Mapping Tables\\n\",\n    \"  4. Training LLMs with Lora\\n\",\n    \"* experiment examples:\\n\",\n    \"  1. Running FEDMKT With Launcher (Experimential Using): 4-SLMs\\n\",\n    \"  2. Running FEDMKT With Launcher (Experimential Using): 1-SLM (One To One)\\n\",\n    \"  3. Running FEDMKT With Launcher (Experimential Using): 1-SLM And SLM Trains Only (LLM2SLM)\\n\",\n    \"  4. Running FEDMKT With Launcher (Experimential Using): 4-SLMs Homogeneous SFT\\n\",\n    \"  5. 
Running FEDMKT with Pipeline (Industrial Using)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Dataset: ARC-Challenge\\n\",\n    \"\\n\",\n    \"ARC-Challenge is a dataset of 7,787 genuine grade-school level, multiple-choice science questions, assembled to encourage research in advanced question-answering. \\n\",\n    \"\\n\",\n    \"You can refer to following link for more details about [ARC-Challange](https://huggingface.co/datasets/allenai/ai2_arc)\\n\",\n    \"\\n\",\n    \"In this section, we will download ARC-Challenge dataset from huggingface and splits it into five parts, part \\\"common\\\" for public dataset and other parts for slms(opt2, gpt2, llama, opt)'s  training. \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import datasets\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"data = datasets.load_dataset(\\\"ai2_arc\\\", \\\"ARC-Challenge\\\", download_mode=\\\"force_redownload\\\", ignore_verifications=True)\\n\",\n    \"train_data = data.pop(\\\"train\\\")\\n\",\n    \"\\n\",\n    \"seed=123\\n\",\n    \"n = train_data.shape[0]\\n\",\n    \"client_num = 4\\n\",\n    \"process_data_output_dir = \\\"\\\" # processed data saved directory should be specified, it will be used in later.\\n\",\n    \"\\n\",\n    \"client_data_num = n // (client_num + 1)\\n\",\n    \"\\n\",\n    \"for i in range(client_num):\\n\",\n    \"    splits = train_data.train_test_split(train_size=client_data_num, shuffle=True, seed=seed)\\n\",\n    \"    client_name = f\\\"client_{i}\\\"\\n\",\n    \"    data[client_name] = splits[\\\"train\\\"]\\n\",\n    \"    train_data = splits[\\\"test\\\"]\\n\",\n    \"\\n\",\n    \"if train_data.shape[0] == client_data_num:\\n\",\n    \"    data[\\\"common\\\"] = train_data\\n\",\n    \"else:\\n\",\n    \"    data[\\\"common\\\"] = train_data.train_test_split(\\n\",\n    \"        
train_size=client_data_num, shuffle=True, seed=seed\n\",\n    \"    )[\\\"train\\\"]\\n\",\n    \"\\n\",\n    \"data.save_to_disk(process_data_output_dir)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Models Use In \\\"FEDMKT\\\" Paper\\n\",\n    \"\\n\",\n    \"LLM: [Llama-2-7B](https://huggingface.co/meta-llama/Llama-2-7b-hf)  \\n\",\n    \"SLM-0: [opt-1.3b](https://huggingface.co/facebook/opt-1.3b)  \\n\",\n    \"SLM-1: [gpt2-xlarge](https://huggingface.co/openai-community/gpt2-xl)  \\n\",\n    \"SLM-2: [Llama-1.3b](https://huggingface.co/princeton-nlp/Sheared-LLaMA-1.3B)  \\n\",\n    \"SLM-3: [bloom-1.1B](https://huggingface.co/bigscience/bloom-1b1)\\n\",\n    \"\\n\",\n    \"Users should download the models from huggingface before the following steps and save them in local directories, as the models are too big and re-downloading them costs too much time.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# replace the names of models to local save directories\\n\",\n    \"llm_pretrained_path = \\\"llama-2-7b-hf\\\"\\n\",\n    \"slm_0_pretrained_path = \\\"opt-1.3b\\\"\\n\",\n    \"slm_1_pretrained_path = \\\"gpt2-xl\\\"\\n\",\n    \"slm_2_pretrained_path = \\\"Sheared-LLaMA-1.3B\\\"\\n\",\n    \"slm_3_pretrained_path = \\\"bloom-1b1\\\"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Prepare Optimal Vocabulary Mapping Tables\\n\",\n    \"\\n\",\n    \"To use \\\"FEDMKT\\\" for federated knowledge transfer, we need to build optimal vocabulary mapping tables first.\\n\",\n    \"In paper of \\\"FEDMKT\\\", it has One LLM and four SLMs, so we need to build eight optimal vocabulary mapping tables. 
For each paired of (LLM, SLM), two tables should be built as co-training are needed.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.algo.fedmkt.token_alignment.vocab_mapping import get_vocab_mappings\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"llm_slm_pairs = [\\n\",\n    \"    (llm_pretrained_path, slm_0_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_1_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_2_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_3_pretrained_path)\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"vocab_mapping_directory = \\\"\\\" # replace this to actually paths\\n\",\n    \"\\n\",\n    \"slm_to_llm_vocab_mapping_paths = [\\\"opt_to_llama.json\\\", \\\"gpt2_to_llama.json\\\", \\\"llama_small_to_llama.json\\\", \\\"bloom_to_llama.json\\\"]\\n\",\n    \"llm_to_slm_vocab_mapping_paths = [\\\"llama_to_opt.json\\\", \\\"llama_to_gpt2.json\\\", \\\"llama_to_llama_small\\\", \\\"llama_to_bloom.json\\\"]\\n\",\n    \"\\n\",\n    \"for idx in range(4):\\n\",\n    \"    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + slm_to_llm_vocab_mapping_paths[idx]\\n\",\n    \"    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + llm_to_slm_vocab_mapping_paths[idx]\\n\",\n    \"\\n\",\n    \"for idx, (llm_pretrained, slm_pretrained) in enumerate(llm_slm_pairs):\\n\",\n    \"    slm_to_llm_vocab_mapping_path = slm_to_llm_vocab_mapping_paths[idx]\\n\",\n    \"    llm_to_slm_vocab_mapping_path = llm_to_slm_vocab_mapping_paths[idx]\\n\",\n    \"    _ = get_vocab_mappings(slm_pretrained, llm_pretrained, slm_to_llm_vocab_mapping_paths[idx], num_processors=16)\\n\",\n    \"    _ = get_vocab_mappings(llm_pretrained, slm_pretrained, llm_to_slm_vocab_mapping_paths[idx], num_processors=16)\\n\",\n    \"    \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": 
{},\n   \"source\": [\n    \"### Training LLMs with Lora\\n\",\n    \"\\n\",\n    \"In this section, We will introduce the lora configs use in five models listed in paper: one LLM (Llama-2-7B), four SLMs(opt-1.3B, gpt2-xlarge, Llama-1.3B, bloom-1.1B)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"LLM models with peft is located on fate_llm/model_zoo, we will give a guide to use them. \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Init LLm Llama-2-7B's Lora Config\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"lora_config = LoraConfig(\\n\",\n    \"    task_type=TaskType.CAUSAL_LM,\\n\",\n    \"    inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\\n\",\n    \"    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\\n\",\n    \")\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Init SLMs Lora Config\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\\n\",\n    \"slm_lora_target_modules = [\\n\",\n    \"    [\\\"q_proj\\\", \\\"v_proj\\\"],\\n\",\n    \"    [\\\"c_attn\\\"],\\n\",\n    \"    ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\\n\",\n    \"    [\\\"query_key_value\\\"]\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"def get_slm_conf(slm_idx):\\n\",\n    \"    slm_pretrained_path = slm_pretrained_paths[slm_idx]\\n\",\n    \"    lora_config = LoraConfig(\\n\",\n    \"        task_type=TaskType.CAUSAL_LM,\\n\",\n    \"        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\\n\",\n    \"        target_modules=slm_lora_target_modules[slm_idx]\\n\",\n    \"    )\"\n   ]\n 
 },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Running FEDMKT With Launcher (Experimential Using): 4-SLMs\\n\",\n    \"\\n\",\n    \"Using launcher to startup is mainly for experimential. Before running this section, make sure that [FATE-LLM Standalone](https://github.com/FederatedAI/FATE-LLM?tab=readme-ov-file#standalone-deployment) has been deployed.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Global Settings\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"process_data_output_dir = \\\"\\\"\\n\",\n    \"llm_pretrained_path = \\\"Llama-2-7b-hf\\\"\\n\",\n    \"slm_0_pretrained_path = \\\"opt-1.3b\\\"\\n\",\n    \"slm_1_pretrained_path = \\\"gpt2-xl\\\"\\n\",\n    \"slm_2_pretrained_path = \\\"Sheared-LLaMa-1.3B\\\"\\n\",\n    \"slm_3_pretrained_path = \\\"bloom-1b1\\\"\\n\",\n    \"llm_slm_pairs = [\\n\",\n    \"    (llm_pretrained_path, slm_0_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_1_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_2_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_3_pretrained_path)\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"vocab_mapping_directory = \\\"\\\"\\n\",\n    \"\\n\",\n    \"slm_to_llm_vocab_mapping_paths = [\\\"opt_to_llama.json\\\", \\\"gpt2_to_llama.json\\\", \\\"llama_small_to_llama.json\\\", \\\"bloom_to_llama.json\\\"]\\n\",\n    \"llm_to_slm_vocab_mapping_paths = [\\\"llama_to_opt.json\\\", \\\"llama_to_gpt2.json\\\", \\\"llama_to_llama_small\\\", \\\"llama_to_bloom.json\\\"]\\n\",\n    \"\\n\",\n    \"for idx in range(4):\\n\",\n    \"    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + slm_to_llm_vocab_mapping_paths[idx]\\n\",\n    \"    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + llm_to_slm_vocab_mapping_paths[idx]\\n\",\n   
 \"\\n\",\n    \"#### all variables has been defined above\\n\",\n    \"\\n\",\n    \"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\\n\",\n    \"slm_lora_target_modules = [\\n\",\n    \"    [\\\"q_proj\\\", \\\"v_proj\\\"],\\n\",\n    \"    [\\\"c_attn\\\"],\\n\",\n    \"    ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\\n\",\n    \"    [\\\"query_key_value\\\"]\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"global_epochs = 1\\n\",\n    \"batch_size=4\\n\",\n    \"llm_lr = 3e-5\\n\",\n    \"slm_lrs = [3e-5, 3e-4, 3e-5, 3e-5, 3e-5]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Init FEDMKTLLM Runner\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"In this Section, we will introduce how to initialize \\\"FEDMKTLLM\\\" object.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### Step1: Initialize LLM With LoraConfig\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from peft import LoraConfig, TaskType\\n\",\n    \"from fate_llm.model_zoo.pellm.llama import LLaMa\\n\",\n    \"from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\\n\",\n    \"from fate.ml.nn.homo.fedavg import FedAVGArguments\\n\",\n    \"from fate_llm.dataset.qa_dataset import QaDataset\\n\",\n    \"from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\\n\",\n    \"from transformers import AutoConfig\\n\",\n    \"\\n\",\n    \"lora_config = LoraConfig(\\n\",\n    \"    task_type=TaskType.CAUSAL_LM,\\n\",\n    \"    inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\\n\",\n    \"    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"model = LLaMa(\\n\",\n    \"   
 pretrained_path=llm_pretrained_path,\\n\",\n    \"    peft_type=\\\"LoraConfig\\\",\\n\",\n    \"    peft_config=lora_config.to_dict(),\\n\",\n    \"    torch_dtype=\\\"bfloat16\\\"    \\n\",\n    \")\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### Step2: Specify Public Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\\n\",\n    \"                     dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                     data_part=\\\"common\\\",\\n\",\n    \"                     seq_max_len=512,\\n\",\n    \"                     need_preprocess=True)\\n\",\n    \"pub_data.load(process_data_output_dir)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### Step3: Initialize FEDMKT Training Args\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"training_args = FedMKTTrainingArguments(\\n\",\n    \"    global_epochs=global_epochs,\\n\",\n    \"    per_device_train_batch_size=1,\\n\",\n    \"    gradient_accumulation_steps=batch_size,\\n\",\n    \"    learning_rate=llm_lr,\\n\",\n    \"    output_dir=\\\"./\\\",\\n\",\n    \"    dataloader_num_workers=4,\\n\",\n    \"    remove_unused_columns=False,\\n\",\n    \"    warmup_ratio=0.008,\\n\",\n    \"    lr_scheduler_type=\\\"cosine\\\",\\n\",\n    \"    optim=\\\"adamw_torch\\\",\\n\",\n    \"    adam_beta1=0.9,\\n\",\n    \"    adam_beta2=0.95,\\n\",\n    \"    weight_decay=0.1,\\n\",\n    \"    max_grad_norm=1.0,\\n\",\n    \"    use_cpu=False,\\n\",\n    \"    vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size, # pay attention to this, \\n\",\n    \"                                                                           # 
vocab_size must be specified to avoid dimension mismatch \\n\",\n    \"                                                                           # of tokenizer's vocab_size\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### Step4: Initialize Other Variables\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"fed_args = FedAVGArguments(\\n\",\n    \"    aggregate_strategy='epoch',\\n\",\n    \"    aggregate_freq=1\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"slm_to_llm_vocab_mapping = []\\n\",\n    \"for path in slm_to_llm_vocab_mapping_paths:\\n\",\n    \"    with open(path, \\\"r\\\") as fin:\\n\",\n    \"        vocab_mapping = json.loads(fin.read())\\n\",\n    \"        slm_to_llm_vocab_mapping.append(vocab_mapping)\\n\",\n    \"\\n\",\n    \"slm_tokenizers = [get_tokenizer(slm_pretrained_path) for slm_pretrained_path in slm_pretrained_paths]\\n\",\n    \"tokenizer = get_tokenizer(llm_pretrained_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### Step5: New FEDMKTLLM Object\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"trainer = FedMKTLLM(\\n\",\n    \"    ctx=ctx,\\n\",\n    \"    model=model,\\n\",\n    \"    training_args=training_args,\\n\",\n    \"    fed_args=fed_args,\\n\",\n    \"    train_set=pub_data,\\n\",\n    \"    tokenizer=tokenizer,\\n\",\n    \"    slm_tokenizers=slm_tokenizers,\\n\",\n    \"    slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\\n\",\n    \"    save_trainable_weights_only=True, # save lora weights only\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### Step6: Training And Save Results\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   
\"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"trainer.train()\\n\",\n    \"trainer.save_model(output_dir=\\\"fill the path to save llm finetuning result\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Init FEDMKTSLM Runner\\n\",\n    \"\\n\",\n    \"FEDMKTSLM Runner is a slightly different of FEDMKTLLM Runner, we only introduce different variables\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### Import SLMs you need to run, here we choose four Slms Using In Original Paper.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import transformers\\n\",\n    \"from peft import LoraConfig, TaskType    \\n\",\n    \"from fate_llm.model_zoo.pellm.llama import LLaMa\\n\",\n    \"from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM\\n\",\n    \"from fate_llm.model_zoo.pellm.opt import OPT\\n\",\n    \"from fate_llm.model_zoo.pellm.bloom import Bloom\\n\",\n    \"from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\\n\",\n    \"from fate_llm.dataset.qa_dataset import QaDataset\\n\",\n    \"from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\\n\",\n    \"from transformers import AutoConfig\\n\",\n    \"\\n\",\n    \"slm_idx = 0\\n\",\n    \"\\n\",\n    \"slm_model_class = [\\n\",\n    \"    OPT,\\n\",\n    \"    GPT2CLM,\\n\",\n    \"    LLaMa,\\n\",\n    \"    Bloom\\n\",\n    \"]\\n\",\n    \"    \\n\",\n    \"lora_config = LoraConfig(\\n\",\n    \"    task_type=TaskType.CAUSAL_LM,\\n\",\n    \"    inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\\n\",\n    \"    target_modules=slm_lora_target_modules[slm_idx]\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"model = slm_model_class[slm_idx](\\n\",\n    \"    pretrained_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"    
peft_type=\\\"LoraConfig\\\",\\n\",\n    \"    peft_config=lora_config.to_dict(),\\n\",\n    \"    torch_dtype=\\\"bfloat16\\\"\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### Specify Private Dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"                      dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                      data_part=f\\\"client_{slm_idx}\\\",\\n\",\n    \"                      seq_max_len=512,\\n\",\n    \"                      need_preprocess=True)\\n\",\n    \"priv_data.load(process_data_output_dir)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### Other Variables \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\\n\",\n    \"\\n\",\n    \"import json\\n\",\n    \"with open(llm_to_slm_vocab_mapping_paths[slm_idx], \\\"r\\\") as fin:\\n\",\n    \"    vocab_mapping = json.loads(fin.read())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"##### New FEDMKTSLM Object\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"trainer = FedMKTSLM(\\n\",\n    \"    ctx=ctx,\\n\",\n    \"    model=model,\\n\",\n    \"    training_args=training_args,\\n\",\n    \"    fed_args=fed_args,\\n\",\n    \"    pub_train_set=pub_data,\\n\",\n    \"    priv_train_set=priv_data,\\n\",\n    \"    tokenizer=tokenizer,\\n\",\n    \"    save_trainable_weights_only=True, # save lora weights only\\n\",\n    \"    
llm_tokenizer=get_tokenizer(llm_pretrained_path), # different with LLM setting\\n\",\n    \"    llm_to_slm_vocab_mapping=vocab_mapping, # different with LLM setting\\n\",\n    \"    data_collator=transformers.DataCollatorForSeq2Seq(tokenizer) # use to train private dataset\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Complete Code To DO SFT With 4 SLMs\\n\",\n    \"\\n\",\n    \"Please paste the code in \\\"fedmkt_4_slms.py\\\" and execute it with the following command\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"```python\\n\",\n    \"python fedmkt_4_slms.py --parties guest:9999 host:9999 host:10000 host:10001 arbiter:9999 --log_level INFO\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# fedmkt_4_slms.py\\n\",\n    \"\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"from fate.arch import Context\\n\",\n    \"from fate.arch.launchers.multiprocess_launcher import launch\\n\",\n    \"import json\\n\",\n    \"\\n\",\n    \"process_data_output_dir = \\\"\\\"\\n\",\n    \"llm_pretrained_path = \\\"Llama-2-7b-hf\\\"\\n\",\n    \"slm_0_pretrained_path = \\\"opt-1.3b\\\"\\n\",\n    \"slm_1_pretrained_path = \\\"gpt2-xl\\\"\\n\",\n    \"slm_2_pretrained_path = \\\"Sheared-LLaMa-1.3B\\\"\\n\",\n    \"slm_3_pretrained_path = \\\"bloom-1b1\\\"\\n\",\n    \"llm_slm_pairs = [\\n\",\n    \"    (llm_pretrained_path, slm_0_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_1_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_2_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_3_pretrained_path)\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"vocab_mapping_directory = \\\"\\\"\\n\",\n    \"\\n\",\n    \"slm_to_llm_vocab_mapping_paths = [\\\"opt_to_llama.json\\\", \\\"gpt2_to_llama.json\\\", 
\\\"llama_small_to_llama.json\\\", \\\"bloom_to_llama.json\\\"]\\n\",\n    \"llm_to_slm_vocab_mapping_paths = [\\\"llama_to_opt.json\\\", \\\"llama_to_gpt2.json\\\", \\\"llama_to_llama_small\\\", \\\"llama_to_bloom.json\\\"]\\n\",\n    \"\\n\",\n    \"for idx in range(4):\\n\",\n    \"    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + slm_to_llm_vocab_mapping_paths[idx]\\n\",\n    \"    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + llm_to_slm_vocab_mapping_paths[idx]\\n\",\n    \"\\n\",\n    \"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\\n\",\n    \"slm_lora_target_modules = [\\n\",\n    \"    [\\\"q_proj\\\", \\\"v_proj\\\"],\\n\",\n    \"    [\\\"c_attn\\\"],\\n\",\n    \"    ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\\n\",\n    \"    [\\\"query_key_value\\\"]\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"global_epochs = 5\\n\",\n    \"batch_size=4\\n\",\n    \"llm_lr = 3e-5\\n\",\n    \"slm_lrs = [3e-5, 3e-4, 3e-5, 3e-5, 3e-5]\\n\",\n    \"\\n\",\n    \"llm_model_saved_directory = \\\"./models/fedmkt_4_slms_llm_model\\\"\\n\",\n    \"slm_models_saved_directory = [\\n\",\n    \"    \\\"./models/fedmkt_4_slms_slm_0\\\", \\n\",\n    \"    \\\"./models/fedmkt_4_slms_slm_1\\\", \\n\",\n    \"    \\\"./models/fedmkt_4_slms_slm_2\\\", \\n\",\n    \"    \\\"./models/fedmkt_4_slms_slm_3\\\"\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def train_llm(ctx):\\n\",\n    \"    from peft import LoraConfig, TaskType\\n\",\n    \"    from fate_llm.model_zoo.pellm.llama import LLaMa\\n\",\n    \"    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\\n\",\n    \"    from fate_llm.dataset.qa_dataset import QaDataset\\n\",\n    \"    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\\n\",\n    \"    from transformers import AutoConfig\\n\",\n    \"\\n\",\n    \"    lora_config = LoraConfig(\\n\",\n    \"        
task_type=TaskType.CAUSAL_LM,\\n\",\n    \"        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\\n\",\n    \"        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    model = LLaMa(\\n\",\n    \"        pretrained_path=llm_pretrained_path,\\n\",\n    \"        peft_type=\\\"LoraConfig\\\",\\n\",\n    \"        peft_config=lora_config.to_dict(),\\n\",\n    \"        torch_dtype=\\\"bfloat16\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\\n\",\n    \"                         dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                         data_part=\\\"common\\\",\\n\",\n    \"                         seq_max_len=512,\\n\",\n    \"                         need_preprocess=True)\\n\",\n    \"    pub_data.load(process_data_output_dir)\\n\",\n    \"\\n\",\n    \"    training_args = FedMKTTrainingArguments(\\n\",\n    \"        global_epochs=global_epochs,\\n\",\n    \"        per_device_train_batch_size=1,\\n\",\n    \"        gradient_accumulation_steps=batch_size,\\n\",\n    \"        learning_rate=llm_lr,\\n\",\n    \"        output_dir=\\\"./\\\",\\n\",\n    \"        dataloader_num_workers=4,\\n\",\n    \"        remove_unused_columns=False,\\n\",\n    \"        warmup_ratio=0.008,\\n\",\n    \"        lr_scheduler_type=\\\"cosine\\\",\\n\",\n    \"        optim=\\\"adamw_torch\\\",\\n\",\n    \"        adam_beta1=0.9,\\n\",\n    \"        adam_beta2=0.95,\\n\",\n    \"        weight_decay=0.1,\\n\",\n    \"        max_grad_norm=1.0,\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    slm_to_llm_vocab_mapping = []\\n\",\n    \"    for path in slm_to_llm_vocab_mapping_paths:\\n\",\n    \"        with open(path, \\\"r\\\") as fin:\\n\",\n    \"            vocab_mapping = 
json.loads(fin.read())\\n\",\n    \"            slm_to_llm_vocab_mapping.append(vocab_mapping)\\n\",\n    \"\\n\",\n    \"    slm_tokenizers = [get_tokenizer(slm_pretrained_path) for slm_pretrained_path in slm_pretrained_paths]\\n\",\n    \"\\n\",\n    \"    tokenizer = get_tokenizer(llm_pretrained_path)\\n\",\n    \"    trainer = FedMKTLLM(\\n\",\n    \"        ctx=ctx,\\n\",\n    \"        model=model,\\n\",\n    \"        training_args=training_args,\\n\",\n    \"        train_set=pub_data,\\n\",\n    \"        tokenizer=tokenizer,\\n\",\n    \"        slm_tokenizers=slm_tokenizers,\\n\",\n    \"        slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\\n\",\n    \"        save_trainable_weights_only=True,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    trainer.train()\\n\",\n    \"    trainer.save_model(llm_model_saved_directory)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def train_slm(ctx, slm_idx):\\n\",\n    \"    import transformers\\n\",\n    \"    from peft import LoraConfig, TaskType\\n\",\n    \"    from fate_llm.model_zoo.pellm.llama import LLaMa\\n\",\n    \"    from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM\\n\",\n    \"    from fate_llm.model_zoo.pellm.opt import OPT\\n\",\n    \"    from fate_llm.model_zoo.pellm.bloom import Bloom\\n\",\n    \"    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\\n\",\n    \"    from fate_llm.dataset.qa_dataset import QaDataset\\n\",\n    \"    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\\n\",\n    \"    from transformers import AutoConfig\\n\",\n    \"\\n\",\n    \"    slm_model_class = [\\n\",\n    \"        OPT,\\n\",\n    \"        GPT2CLM,\\n\",\n    \"        LLaMa,\\n\",\n    \"        Bloom\\n\",\n    \"    ]\\n\",\n    \"\\n\",\n    \"    lora_config = LoraConfig(\\n\",\n    \"        task_type=TaskType.CAUSAL_LM,\\n\",\n    \"        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\\n\",\n    \"        
target_modules=slm_lora_target_modules[slm_idx]\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    model = slm_model_class[slm_idx](\\n\",\n    \"        pretrained_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"        peft_type=\\\"LoraConfig\\\",\\n\",\n    \"        peft_config=lora_config.to_dict(),\\n\",\n    \"        torch_dtype=\\\"bfloat16\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"                          dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                          data_part=f\\\"client_{slm_idx}\\\",\\n\",\n    \"                          seq_max_len=512,\\n\",\n    \"                          need_preprocess=True)\\n\",\n    \"    priv_data.load(process_data_output_dir)\\n\",\n    \"\\n\",\n    \"    pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"                         dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                         data_part=\\\"common\\\",\\n\",\n    \"                         seq_max_len=512,\\n\",\n    \"                         need_preprocess=True)\\n\",\n    \"    pub_data.load(process_data_output_dir)\\n\",\n    \"\\n\",\n    \"    training_args = FedMKTTrainingArguments(\\n\",\n    \"        global_epochs=global_epochs,\\n\",\n    \"        per_device_train_batch_size=1,\\n\",\n    \"        gradient_accumulation_steps=batch_size,\\n\",\n    \"        learning_rate=slm_lrs[slm_idx],\\n\",\n    \"        output_dir=\\\"./\\\",\\n\",\n    \"        dataloader_num_workers=4,\\n\",\n    \"        remove_unused_columns=False,\\n\",\n    \"        warmup_ratio=0.008,\\n\",\n    \"        lr_scheduler_type=\\\"cosine\\\",\\n\",\n    \"        optim=\\\"adamw_torch\\\",\\n\",\n    \"        adam_beta1=0.9,\\n\",\n    \"        adam_beta2=0.95,\\n\",\n    \"        weight_decay=0.1,\\n\",\n    \"        max_grad_norm=1.0,\\n\",\n    \"        use_cpu=False,\\n\",\n    \" 
       vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\\n\",\n    \"\\n\",\n    \"    import json\\n\",\n    \"    with open(llm_to_slm_vocab_mapping_paths[slm_idx], \\\"r\\\") as fin:\\n\",\n    \"        vocab_mapping = json.loads(fin.read())\\n\",\n    \"\\n\",\n    \"    trainer = FedMKTSLM(\\n\",\n    \"        ctx=ctx,\\n\",\n    \"        model=model,\\n\",\n    \"        training_args=training_args,\\n\",\n    \"        pub_train_set=pub_data,\\n\",\n    \"        priv_train_set=priv_data,\\n\",\n    \"        tokenizer=tokenizer,\\n\",\n    \"        save_trainable_weights_only=True,\\n\",\n    \"        llm_tokenizer=get_tokenizer(llm_pretrained_path),\\n\",\n    \"        llm_to_slm_vocab_mapping=vocab_mapping,\\n\",\n    \"        data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    trainer.train()\\n\",\n    \"    trainer.save_model(slm_models_saved_directory[slm_idx])\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def run(ctx: Context):\\n\",\n    \"    if ctx.is_on_arbiter:\\n\",\n    \"        os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"0\\\"\\n\",\n    \"        train_llm(ctx)\\n\",\n    \"    elif ctx.is_on_guest:\\n\",\n    \"        os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"1\\\"\\n\",\n    \"        train_slm(ctx, slm_idx=0)\\n\",\n    \"    else:\\n\",\n    \"        if ctx.local.party[1] == \\\"9999\\\":\\n\",\n    \"            os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"2\\\"\\n\",\n    \"            slm_idx = 1\\n\",\n    \"        elif ctx.local.party[1] == \\\"10000\\\":\\n\",\n    \"            os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"3\\\"\\n\",\n    \"            slm_idx = 2\\n\",\n    \"        elif ctx.local.party[1] == \\\"10001\\\":\\n\",\n    \"            os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"4\\\"\\n\",\n    \"    
        slm_idx = 3\\n\",\n    \"        else:\\n\",\n    \"            raise ValueError(f\\\"party_id={ctx.local.party[1]} is illegal\\\")\\n\",\n    \"\\n\",\n    \"        train_slm(ctx, slm_idx=slm_idx)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"if __name__ == \\\"__main__\\\":\\n\",\n    \"    launch(run)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Running FEDMKT With Launcher (Experimential Using): 1-SLM (One To One)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Actually, a slightly modifications from 4-SLMs running code are enough to do sft with single clients, it will be listed in below sections, we take SLM-0(OPT) as an example\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Only Use Single Optimal Vocabulary Mapping Tables\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"slm_idx = 0\\n\",\n    \"slm_to_llm_vocab_mapping = []\\n\",\n    \"with open(slm_to_llm_vocab_mapping_paths[slm_idx], \\\"r\\\") as fin:\\n\",\n    \"    vocab_mapping = json.loads(fin.read())\\n\",\n    \"    slm_to_llm_vocab_mapping.append(vocab_mapping)\\n\",\n    \"\\n\",\n    \"slm_tokenizers = [get_tokenizer(slm_pretrained_paths[slm_idx])]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Complete Code To DO SFT With 1 SLM\\n\",\n    \"\\n\",\n    \"Please paste the code in \\\"fedmkt_1_slm.py\\\" and execute it with the following command\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"```python\\n\",\n    \"python fedmkt_1_slm.py --parties guest:9999 arbiter:9999 --log_level INFO\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": 
[],\n   \"source\": [\n    \"# fedmkt_1_slm.py\\n\",\n    \"\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"from fate.arch import Context\\n\",\n    \"from fate.arch.launchers.multiprocess_launcher import launch\\n\",\n    \"import json\\n\",\n    \"\\n\",\n    \"process_data_output_dir = \\\"\\\"\\n\",\n    \"llm_pretrained_path = \\\"Llama-2-7b-hf\\\"\\n\",\n    \"slm_0_pretrained_path = \\\"opt-1.3b\\\"\\n\",\n    \"slm_1_pretrained_path = \\\"gpt2-xl\\\"\\n\",\n    \"slm_2_pretrained_path = \\\"Sheared-LLaMa-1.3B\\\"\\n\",\n    \"slm_3_pretrained_path = \\\"bloom-1b1\\\"\\n\",\n    \"llm_slm_pairs = [\\n\",\n    \"    (llm_pretrained_path, slm_0_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_1_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_2_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_3_pretrained_path)\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"vocab_mapping_directory = \\\"\\\"\\n\",\n    \"\\n\",\n    \"slm_to_llm_vocab_mapping_paths = [\\\"opt_to_llama.json\\\", \\\"gpt2_to_llama.json\\\", \\\"llama_small_to_llama.json\\\", \\\"bloom_to_llama.json\\\"]\\n\",\n    \"llm_to_slm_vocab_mapping_paths = [\\\"llama_to_opt.json\\\", \\\"llama_to_gpt2.json\\\", \\\"llama_to_llama_small\\\", \\\"llama_to_bloom.json\\\"]\\n\",\n    \"\\n\",\n    \"for idx in range(4):\\n\",\n    \"    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + slm_to_llm_vocab_mapping_paths[idx]\\n\",\n    \"    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + llm_to_slm_vocab_mapping_paths[idx]\\n\",\n    \"\\n\",\n    \"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\\n\",\n    \"slm_lora_target_modules = [\\n\",\n    \"    [\\\"q_proj\\\", \\\"v_proj\\\"],\\n\",\n    \"    [\\\"c_attn\\\"],\\n\",\n    \"    ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\\n\",\n    \"    [\\\"query_key_value\\\"]\\n\",\n    \"]\\n\",\n    \"\\n\",\n 
   \"global_epochs = 5\\n\",\n    \"batch_size = 4\\n\",\n    \"llm_lr = 3e-5\\n\",\n    \"slm_lrs = [3e-5]\\n\",\n    \"\\n\",\n    \"llm_model_saved_directory = \\\"./models/fedmkt_single_slm_llm\\\"\\n\",\n    \"slm_models_saved_directory = [\\n\",\n    \"    \\\"./models/fedmkt_single_slm_opt\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def train_llm(ctx, slm_idx):\\n\",\n    \"    from peft import LoraConfig, TaskType\\n\",\n    \"    from fate_llm.model_zoo.pellm.llama import LLaMa\\n\",\n    \"    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\\n\",\n    \"    from fate_llm.dataset.qa_dataset import QaDataset\\n\",\n    \"    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\\n\",\n    \"    from transformers import AutoConfig\\n\",\n    \"\\n\",\n    \"    lora_config = LoraConfig(\\n\",\n    \"        task_type=TaskType.CAUSAL_LM,\\n\",\n    \"        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\\n\",\n    \"        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    model = LLaMa(\\n\",\n    \"        pretrained_path=llm_pretrained_path,\\n\",\n    \"        peft_type=\\\"LoraConfig\\\",\\n\",\n    \"        peft_config=lora_config.to_dict(),\\n\",\n    \"        torch_dtype=\\\"bfloat16\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\\n\",\n    \"                         dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                         data_part=\\\"common\\\",\\n\",\n    \"                         seq_max_len=512,\\n\",\n    \"                         need_preprocess=True)\\n\",\n    \"    pub_data.load(process_data_output_dir)\\n\",\n    \"\\n\",\n    \"    training_args = FedMKTTrainingArguments(\\n\",\n    \"        global_epochs=global_epochs,\\n\",\n    \"        per_device_train_batch_size=1,\\n\",\n    \"        
gradient_accumulation_steps=batch_size,\\n\",\n    \"        learning_rate=llm_lr,\\n\",\n    \"        output_dir=\\\"./\\\",\\n\",\n    \"        dataloader_num_workers=4,\\n\",\n    \"        remove_unused_columns=False,\\n\",\n    \"        warmup_ratio=0.008,\\n\",\n    \"        lr_scheduler_type=\\\"cosine\\\",\\n\",\n    \"        optim=\\\"adamw_torch\\\",\\n\",\n    \"        adam_beta1=0.9,\\n\",\n    \"        adam_beta2=0.95,\\n\",\n    \"        weight_decay=0.1,\\n\",\n    \"        max_grad_norm=1.0,\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    slm_to_llm_vocab_mapping = []\\n\",\n    \"    with open(slm_to_llm_vocab_mapping_paths[slm_idx], \\\"r\\\") as fin:\\n\",\n    \"        vocab_mapping = json.loads(fin.read())\\n\",\n    \"        slm_to_llm_vocab_mapping.append(vocab_mapping)\\n\",\n    \"\\n\",\n    \"    slm_tokenizers = [get_tokenizer(slm_pretrained_paths[slm_idx])]\\n\",\n    \"\\n\",\n    \"    tokenizer = get_tokenizer(llm_pretrained_path)\\n\",\n    \"    trainer = FedMKTLLM(\\n\",\n    \"        ctx=ctx,\\n\",\n    \"        model=model,\\n\",\n    \"        training_args=training_args,\\n\",\n    \"        train_set=pub_data,\\n\",\n    \"        tokenizer=tokenizer,\\n\",\n    \"        slm_tokenizers=slm_tokenizers,\\n\",\n    \"        slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\\n\",\n    \"        save_trainable_weights_only=True,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    trainer.train()\\n\",\n    \"    trainer.save_model(llm_model_saved_directory)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def train_slm(ctx, slm_idx):\\n\",\n    \"    import transformers\\n\",\n    \"    from peft import LoraConfig, TaskType\\n\",\n    \"    from fate_llm.model_zoo.pellm.llama import LLaMa\\n\",\n    \"    from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM\\n\",\n    \"    from 
fate_llm.model_zoo.pellm.opt import OPT\\n\",\n    \"    from fate_llm.model_zoo.pellm.bloom import Bloom\\n\",\n    \"    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\\n\",\n    \"    from fate_llm.dataset.qa_dataset import QaDataset\\n\",\n    \"    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\\n\",\n    \"    from transformers import AutoConfig\\n\",\n    \"\\n\",\n    \"    slm_model_class = [\\n\",\n    \"        OPT,\\n\",\n    \"        GPT2CLM,\\n\",\n    \"        LLaMa,\\n\",\n    \"        Bloom\\n\",\n    \"    ]\\n\",\n    \"\\n\",\n    \"    lora_config = LoraConfig(\\n\",\n    \"        task_type=TaskType.CAUSAL_LM,\\n\",\n    \"        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\\n\",\n    \"        target_modules=slm_lora_target_modules[slm_idx]\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    model = slm_model_class[slm_idx](\\n\",\n    \"        pretrained_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"        peft_type=\\\"LoraConfig\\\",\\n\",\n    \"        peft_config=lora_config.to_dict(),\\n\",\n    \"        torch_dtype=\\\"bfloat16\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"                          dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                          data_part=f\\\"client_{slm_idx}\\\",\\n\",\n    \"                          seq_max_len=512,\\n\",\n    \"                          need_preprocess=True)\\n\",\n    \"    priv_data.load(process_data_output_dir)\\n\",\n    \"\\n\",\n    \"    pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"                         dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                         data_part=\\\"common\\\",\\n\",\n    \"                         seq_max_len=512,\\n\",\n    \"                         need_preprocess=True)\\n\",\n    \"    
pub_data.load(process_data_output_dir)\\n\",\n    \"\\n\",\n    \"    training_args = FedMKTTrainingArguments(\\n\",\n    \"        global_epochs=global_epochs,\\n\",\n    \"        per_device_train_batch_size=1,\\n\",\n    \"        gradient_accumulation_steps=batch_size,\\n\",\n    \"        learning_rate=slm_lrs[slm_idx],\\n\",\n    \"        output_dir=\\\"./\\\",\\n\",\n    \"        dataloader_num_workers=4,\\n\",\n    \"        remove_unused_columns=False,\\n\",\n    \"        warmup_ratio=0.008,\\n\",\n    \"        lr_scheduler_type=\\\"cosine\\\",\\n\",\n    \"        optim=\\\"adamw_torch\\\",\\n\",\n    \"        adam_beta1=0.9,\\n\",\n    \"        adam_beta2=0.95,\\n\",\n    \"        weight_decay=0.1,\\n\",\n    \"        max_grad_norm=1.0,\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\\n\",\n    \"\\n\",\n    \"    import json\\n\",\n    \"    with open(llm_to_slm_vocab_mapping_paths[slm_idx], \\\"r\\\") as fin:\\n\",\n    \"        vocab_mapping = json.loads(fin.read())\\n\",\n    \"\\n\",\n    \"    trainer = FedMKTSLM(\\n\",\n    \"        ctx=ctx,\\n\",\n    \"        model=model,\\n\",\n    \"        training_args=training_args,\\n\",\n    \"        pub_train_set=pub_data,\\n\",\n    \"        priv_train_set=priv_data,\\n\",\n    \"        tokenizer=tokenizer,\\n\",\n    \"        save_trainable_weights_only=True,\\n\",\n    \"        llm_tokenizer=get_tokenizer(llm_pretrained_path),\\n\",\n    \"        llm_to_slm_vocab_mapping=vocab_mapping,\\n\",\n    \"        data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    trainer.train()\\n\",\n    \"    trainer.save_model(slm_models_saved_directory[slm_idx])\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def run(ctx: Context):\\n\",\n    \"    if 
ctx.is_on_arbiter:\\n\",\n    \"        os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"0\\\"\\n\",\n    \"        train_llm(ctx, slm_idx=0)\\n\",\n    \"    else:\\n\",\n    \"        os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"1\\\"\\n\",\n    \"        train_slm(ctx, slm_idx=0)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"if __name__ == \\\"__main__\\\":\\n\",\n    \"    launch(run)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Running FEDMKT With Launcher (Experimential Using): 1-SLM And SLM Trains Only (LLM2SLM)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"In this section, we introduce how to do SFT using FEDMKT algorithm, with only single SLM are trained, but without LLM training, means that SLM distill knowlege from LLM only, not co-training.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Difference With Section \\\"Running FEDMKT With Launcher (Experimential Using): 1-SLMs\\\"\\n\",\n    \"\\n\",\n    \"Add llm_training=False to fedmkt_training_args to both LLM and LLM is enough!\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Complete Code To DO SFT With 1 SLM And SLM Trains Only\\n\",\n    \"\\n\",\n    \"Please paste the code in \\\"fedmkt_llm_to_slm.py\\\" and execute it with the following command\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"```python\\n\",\n    \"python fedmkt_llm_to_slm.py --parties guest:9999 arbiter:9999 --log_level INFO\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# fedmkt_llm_to_slm.py\\n\",\n    \"\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"from fate.arch import Context\\n\",\n    \"from fate.arch.launchers.multiprocess_launcher 
import launch\\n\",\n    \"import json\\n\",\n    \"\\n\",\n    \"process_data_output_dir = \\\"\\\"\\n\",\n    \"llm_pretrained_path = \\\"Llama-2-7b-hf\\\"\\n\",\n    \"slm_0_pretrained_path = \\\"opt-1.3b\\\"\\n\",\n    \"slm_1_pretrained_path = \\\"gpt2-xl\\\"\\n\",\n    \"slm_2_pretrained_path = \\\"Sheared-LLaMa-1.3B\\\"\\n\",\n    \"slm_3_pretrained_path = \\\"bloom-1b1\\\"\\n\",\n    \"llm_slm_pairs = [\\n\",\n    \"    (llm_pretrained_path, slm_0_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_1_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_2_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_3_pretrained_path)\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"vocab_mapping_directory = \\\"\\\"\\n\",\n    \"\\n\",\n    \"slm_to_llm_vocab_mapping_paths = [\\\"opt_to_llama.json\\\", \\\"gpt2_to_llama.json\\\", \\\"llama_small_to_llama.json\\\", \\\"bloom_to_llama.json\\\"]\\n\",\n    \"llm_to_slm_vocab_mapping_paths = [\\\"llama_to_opt.json\\\", \\\"llama_to_gpt2.json\\\", \\\"llama_to_llama_small\\\", \\\"llama_to_bloom.json\\\"]\\n\",\n    \"\\n\",\n    \"for idx in range(4):\\n\",\n    \"    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + slm_to_llm_vocab_mapping_paths[idx]\\n\",\n    \"    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + llm_to_slm_vocab_mapping_paths[idx]\\n\",\n    \"\\n\",\n    \"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\\n\",\n    \"slm_lora_target_modules = [\\n\",\n    \"    [\\\"q_proj\\\", \\\"v_proj\\\"],\\n\",\n    \"    [\\\"c_attn\\\"],\\n\",\n    \"    ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\\n\",\n    \"    [\\\"query_key_value\\\"]\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"global_epochs = 5\\n\",\n    \"batch_size = 4\\n\",\n    \"llm_lr = 3e-5\\n\",\n    \"slm_lrs = [3e-5]\\n\",\n    \"\\n\",\n    \"slm_models_saved_directory = [\\n\",\n    \"    
\\\"./models/fedmkt_llm_to_slm_opt\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def train_llm(ctx, slm_idx):\\n\",\n    \"    from peft import LoraConfig, TaskType\\n\",\n    \"    from fate_llm.model_zoo.pellm.llama import LLaMa\\n\",\n    \"    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\\n\",\n    \"    from fate_llm.dataset.qa_dataset import QaDataset\\n\",\n    \"    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\\n\",\n    \"    from transformers import AutoConfig\\n\",\n    \"\\n\",\n    \"    lora_config = LoraConfig(\\n\",\n    \"        task_type=TaskType.CAUSAL_LM,\\n\",\n    \"        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\\n\",\n    \"        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    model = LLaMa(\\n\",\n    \"        pretrained_path=llm_pretrained_path,\\n\",\n    \"        peft_type=\\\"LoraConfig\\\",\\n\",\n    \"        peft_config=lora_config.to_dict(),\\n\",\n    \"        torch_dtype=\\\"bfloat16\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\\n\",\n    \"                         dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                         data_part=\\\"common\\\",\\n\",\n    \"                         seq_max_len=512,\\n\",\n    \"                         need_preprocess=True)\\n\",\n    \"    pub_data.load(process_data_output_dir)\\n\",\n    \"\\n\",\n    \"    training_args = FedMKTTrainingArguments(\\n\",\n    \"        global_epochs=global_epochs,\\n\",\n    \"        per_device_train_batch_size=1,\\n\",\n    \"        gradient_accumulation_steps=batch_size,\\n\",\n    \"        learning_rate=llm_lr,\\n\",\n    \"        output_dir=\\\"./\\\",\\n\",\n    \"        dataloader_num_workers=4,\\n\",\n    \"        remove_unused_columns=False,\\n\",\n    \"        warmup_ratio=0.008,\\n\",\n    \"        
lr_scheduler_type=\\\"cosine\\\",\\n\",\n    \"        optim=\\\"adamw_torch\\\",\\n\",\n    \"        adam_beta1=0.9,\\n\",\n    \"        adam_beta2=0.95,\\n\",\n    \"        weight_decay=0.1,\\n\",\n    \"        max_grad_norm=1.0,\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\\n\",\n    \"        llm_training=False\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    slm_to_llm_vocab_mapping = []\\n\",\n    \"    with open(slm_to_llm_vocab_mapping_paths[slm_idx], \\\"r\\\") as fin:\\n\",\n    \"        vocab_mapping = json.loads(fin.read())\\n\",\n    \"        slm_to_llm_vocab_mapping.append(vocab_mapping)\\n\",\n    \"\\n\",\n    \"    slm_tokenizers = [get_tokenizer(slm_pretrained_paths[slm_idx])]\\n\",\n    \"\\n\",\n    \"    tokenizer = get_tokenizer(llm_pretrained_path)\\n\",\n    \"    trainer = FedMKTLLM(\\n\",\n    \"        ctx=ctx,\\n\",\n    \"        model=model,\\n\",\n    \"        training_args=training_args,\\n\",\n    \"        train_set=pub_data,\\n\",\n    \"        tokenizer=tokenizer,\\n\",\n    \"        slm_tokenizers=slm_tokenizers,\\n\",\n    \"        slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\\n\",\n    \"        save_trainable_weights_only=True,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    trainer.train()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def train_slm(ctx, slm_idx):\\n\",\n    \"    import transformers\\n\",\n    \"    from peft import LoraConfig, TaskType\\n\",\n    \"    from fate_llm.model_zoo.pellm.llama import LLaMa\\n\",\n    \"    from fate_llm.model_zoo.pellm.gpt2 import GPT2CLM\\n\",\n    \"    from fate_llm.model_zoo.pellm.opt import OPT\\n\",\n    \"    from fate_llm.model_zoo.pellm.bloom import Bloom\\n\",\n    \"    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\\n\",\n    \"    from fate_llm.dataset.qa_dataset import QaDataset\\n\",\n    \"    from fate_llm.data.tokenizers.cust_tokenizer import 
get_tokenizer\\n\",\n    \"    from transformers import AutoConfig\\n\",\n    \"\\n\",\n    \"    slm_model_class = [\\n\",\n    \"        OPT,\\n\",\n    \"        GPT2CLM,\\n\",\n    \"        LLaMa,\\n\",\n    \"        Bloom\\n\",\n    \"    ]\\n\",\n    \"\\n\",\n    \"    lora_config = LoraConfig(\\n\",\n    \"        task_type=TaskType.CAUSAL_LM,\\n\",\n    \"        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1,\\n\",\n    \"        target_modules=slm_lora_target_modules[slm_idx]\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    model = slm_model_class[slm_idx](\\n\",\n    \"        pretrained_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"        peft_type=\\\"LoraConfig\\\",\\n\",\n    \"        peft_config=lora_config.to_dict(),\\n\",\n    \"        torch_dtype=\\\"bfloat16\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"                          dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                          data_part=f\\\"client_{slm_idx}\\\",\\n\",\n    \"                          seq_max_len=512,\\n\",\n    \"                          need_preprocess=True)\\n\",\n    \"    priv_data.load(process_data_output_dir)\\n\",\n    \"\\n\",\n    \"    pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"                         dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                         data_part=\\\"common\\\",\\n\",\n    \"                         seq_max_len=512,\\n\",\n    \"                         need_preprocess=True)\\n\",\n    \"    pub_data.load(process_data_output_dir)\\n\",\n    \"\\n\",\n    \"    training_args = FedMKTTrainingArguments(\\n\",\n    \"        global_epochs=global_epochs,\\n\",\n    \"        per_device_train_batch_size=1,\\n\",\n    \"        gradient_accumulation_steps=batch_size,\\n\",\n    \"        learning_rate=slm_lrs[slm_idx],\\n\",\n    \"        
output_dir=\\\"./\\\",\\n\",\n    \"        dataloader_num_workers=4,\\n\",\n    \"        remove_unused_columns=False,\\n\",\n    \"        warmup_ratio=0.008,\\n\",\n    \"        lr_scheduler_type=\\\"cosine\\\",\\n\",\n    \"        optim=\\\"adamw_torch\\\",\\n\",\n    \"        adam_beta1=0.9,\\n\",\n    \"        adam_beta2=0.95,\\n\",\n    \"        weight_decay=0.1,\\n\",\n    \"        max_grad_norm=1.0,\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,\\n\",\n    \"        llm_training=False\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\\n\",\n    \"\\n\",\n    \"    import json\\n\",\n    \"    with open(llm_to_slm_vocab_mapping_paths[slm_idx], \\\"r\\\") as fin:\\n\",\n    \"        vocab_mapping = json.loads(fin.read())\\n\",\n    \"\\n\",\n    \"    trainer = FedMKTSLM(\\n\",\n    \"        ctx=ctx,\\n\",\n    \"        model=model,\\n\",\n    \"        training_args=training_args,\\n\",\n    \"        pub_train_set=pub_data,\\n\",\n    \"        priv_train_set=priv_data,\\n\",\n    \"        tokenizer=tokenizer,\\n\",\n    \"        save_trainable_weights_only=True,\\n\",\n    \"        llm_tokenizer=get_tokenizer(llm_pretrained_path),\\n\",\n    \"        llm_to_slm_vocab_mapping=vocab_mapping,\\n\",\n    \"        data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    trainer.train()\\n\",\n    \"    trainer.save_model(slm_models_saved_directory[slm_idx])\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def run(ctx: Context):\\n\",\n    \"    if ctx.is_on_arbiter:\\n\",\n    \"        os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"0\\\"\\n\",\n    \"        train_llm(ctx, slm_idx=0)\\n\",\n    \"    else:\\n\",\n    \"        os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"1\\\"\\n\",\n    \"        train_slm(ctx, slm_idx=0)\\n\",\n    \"\\n\",\n    
\"\\n\",\n    \"if __name__ == \\\"__main__\\\":\\n\",\n    \"    launch(run)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Running FEDMKT With Launcher (Experimential Using): 4-SLMs Homogeneous SFT\\n\",\n    \"\\n\",\n    \"To run homogeneous experiments, two steps are needed.\\n\",\n    \"1. add post_fedavg=True to fedmkt_training_args to both LLM and LLM is enough!\\n\",\n    \"2. add fed_args to FEDMKTLLM/FEDMKTSLM\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# initialze fed args\\n\",\n    \"from fate.ml.nn.homo.fedavg import FedAVGArguments\\n\",\n    \"\\n\",\n    \"fed_args = FedAVGArguments(\\n\",\n    \"    aggregate_strategy='epoch',\\n\",\n    \"    aggregate_freq=1\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Complete Code To DO SFT With 4-SLMs Homogeneous SFT\\n\",\n    \"\\n\",\n    \"Please paste the code in \\\"fedmkt_4_slms_homo.py\\\" and execute it with the following command\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"```python\\n\",\n    \"python fedmkt_4_slms_homo.py --parties guest:9999 host:9999 host:10000 host:10001 arbiter:9999 --log_level INFO\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# fedmkt_4_slms_homo.py\\n\",\n    \"\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"from fate.arch import Context\\n\",\n    \"from fate.arch.launchers.multiprocess_launcher import launch\\n\",\n    \"import json\\n\",\n    \"\\n\",\n    \"process_data_output_dir = \\\"\\\"\\n\",\n    \"llm_pretrained_path = \\\"Llama-2-7b-hf\\\"\\n\",\n    \"slm_0_pretrained_path = \\\"opt-1.3b\\\"\\n\",\n    \"slm_1_pretrained_path = \\\"opt-1.3b\\\"\\n\",\n    
\"slm_2_pretrained_path = \\\"opt-1.3b\\\"\\n\",\n    \"slm_3_pretrained_path = \\\"opt-1.3b\\\"\\n\",\n    \"llm_slm_pairs = [\\n\",\n    \"    (llm_pretrained_path, slm_0_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_1_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_2_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_3_pretrained_path)\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"vocab_mapping_directory = \\\"\\\"\\n\",\n    \"\\n\",\n    \"slm_to_llm_vocab_mapping_paths = [\\\"opt_to_llama.json\\\"] * 4\\n\",\n    \"llm_to_slm_vocab_mapping_paths = [\\\"llama_to_opt.json\\\"] * 4\\n\",\n    \"\\n\",\n    \"for idx in range(4):\\n\",\n    \"    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + slm_to_llm_vocab_mapping_paths[idx]\\n\",\n    \"    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + llm_to_slm_vocab_mapping_paths[idx]\\n\",\n    \"\\n\",\n    \"slm_pretrained_paths = [slm_0_pretrained_path] * 4\\n\",\n    \"slm_lora_target_modules = [[\\\"q_proj\\\", \\\"v_proj\\\"]] * 4\\n\",\n    \"\\n\",\n    \"global_epochs = 5\\n\",\n    \"batch_size = 4\\n\",\n    \"llm_lr = 3e-5\\n\",\n    \"slm_lrs = [3e-5, 3e-5, 3e-5, 3e-5, 3e-5]\\n\",\n    \"\\n\",\n    \"llm_model_saved_directory = \\\"./models/fedmkt_homo_4_slms_llm_model\\\"\\n\",\n    \"slm_models_saved_directory = [\\n\",\n    \"    \\\"./models/fedmkt_homo_4_slms_slm_0\\\",\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def train_llm(ctx):\\n\",\n    \"    from peft import LoraConfig, TaskType\\n\",\n    \"    from fate_llm.model_zoo.pellm.llama import LLaMa\\n\",\n    \"    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTLLM\\n\",\n    \"    from fate.ml.nn.homo.fedavg import FedAVGArguments\\n\",\n    \"    from fate_llm.dataset.qa_dataset import QaDataset\\n\",\n    \"    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\\n\",\n    \"    from transformers import 
AutoConfig\\n\",\n    \"\\n\",\n    \"    lora_config = LoraConfig(\\n\",\n    \"        task_type=TaskType.CAUSAL_LM,\\n\",\n    \"        inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\\n\",\n    \"        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    model = LLaMa(\\n\",\n    \"        pretrained_path=llm_pretrained_path,\\n\",\n    \"        peft_type=\\\"LoraConfig\\\",\\n\",\n    \"        peft_config=lora_config.to_dict(),\\n\",\n    \"        torch_dtype=\\\"bfloat16\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    pub_data = QaDataset(tokenizer_name_or_path=llm_pretrained_path,\\n\",\n    \"                         dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                         data_part=\\\"common\\\",\\n\",\n    \"                         seq_max_len=512,\\n\",\n    \"                         need_preprocess=True)\\n\",\n    \"    pub_data.load(process_data_output_dir)\\n\",\n    \"\\n\",\n    \"    training_args = FedMKTTrainingArguments(\\n\",\n    \"        global_epochs=global_epochs,\\n\",\n    \"        per_device_train_batch_size=1,\\n\",\n    \"        gradient_accumulation_steps=batch_size,\\n\",\n    \"        learning_rate=llm_lr,\\n\",\n    \"        output_dir=\\\"./\\\",\\n\",\n    \"        dataloader_num_workers=4,\\n\",\n    \"        remove_unused_columns=False,\\n\",\n    \"        warmup_ratio=0.008,\\n\",\n    \"        lr_scheduler_type=\\\"cosine\\\",\\n\",\n    \"        optim=\\\"adamw_torch\\\",\\n\",\n    \"        adam_beta1=0.9,\\n\",\n    \"        adam_beta2=0.95,\\n\",\n    \"        weight_decay=0.1,\\n\",\n    \"        max_grad_norm=1.0,\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\\n\",\n    \"        post_fedavg=True, # difference\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    # difference\\n\",\n    \"    fed_args = FedAVGArguments(\\n\",\n    \" 
       aggregate_strategy='epoch',\\n\",\n    \"        aggregate_freq=1\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    slm_to_llm_vocab_mapping = []\\n\",\n    \"    for path in slm_to_llm_vocab_mapping_paths:\\n\",\n    \"        with open(path, \\\"r\\\") as fin:\\n\",\n    \"            vocab_mapping = json.loads(fin.read())\\n\",\n    \"            slm_to_llm_vocab_mapping.append(vocab_mapping)\\n\",\n    \"\\n\",\n    \"    slm_tokenizers = [get_tokenizer(slm_pretrained_path) for slm_pretrained_path in slm_pretrained_paths]\\n\",\n    \"\\n\",\n    \"    tokenizer = get_tokenizer(llm_pretrained_path)\\n\",\n    \"    trainer = FedMKTLLM(\\n\",\n    \"        ctx=ctx,\\n\",\n    \"        model=model,\\n\",\n    \"        training_args=training_args,\\n\",\n    \"        fed_args=fed_args, # difference\\n\",\n    \"        train_set=pub_data,\\n\",\n    \"        tokenizer=tokenizer,\\n\",\n    \"        slm_tokenizers=slm_tokenizers,\\n\",\n    \"        slm_to_llm_vocab_mappings=slm_to_llm_vocab_mapping,\\n\",\n    \"        save_trainable_weights_only=True,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    trainer.train()\\n\",\n    \"    trainer.save_model(llm_model_saved_directory)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def train_slm(ctx, slm_idx):\\n\",\n    \"    import transformers\\n\",\n    \"    from peft import LoraConfig, TaskType\\n\",\n    \"    from fate_llm.model_zoo.pellm.opt import OPT\\n\",\n    \"    from fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM\\n\",\n    \"    from fate.ml.nn.homo.fedavg import FedAVGArguments\\n\",\n    \"    from fate_llm.dataset.qa_dataset import QaDataset\\n\",\n    \"    from fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\\n\",\n    \"    from transformers import AutoConfig\\n\",\n    \"\\n\",\n    \"    slm_model_class = [OPT] * 4\\n\",\n    \"\\n\",\n    \"    lora_config = LoraConfig(\\n\",\n    \"        task_type=TaskType.CAUSAL_LM,\\n\",\n    \"        
inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\\n\",\n    \"        target_modules=slm_lora_target_modules[slm_idx]\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    model = slm_model_class[slm_idx](\\n\",\n    \"        pretrained_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"        peft_type=\\\"LoraConfig\\\",\\n\",\n    \"        peft_config=lora_config.to_dict(),\\n\",\n    \"        torch_dtype=\\\"bfloat16\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    priv_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"                          dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                          data_part=f\\\"client_{slm_idx}\\\",\\n\",\n    \"                          seq_max_len=512,\\n\",\n    \"                          need_preprocess=True)\\n\",\n    \"    priv_data.load(process_data_output_dir)\\n\",\n    \"\\n\",\n    \"    pub_data = QaDataset(tokenizer_name_or_path=slm_pretrained_paths[slm_idx],\\n\",\n    \"                         dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"                         data_part=\\\"common\\\",\\n\",\n    \"                         seq_max_len=512,\\n\",\n    \"                         need_preprocess=True)\\n\",\n    \"    pub_data.load(process_data_output_dir)\\n\",\n    \"\\n\",\n    \"    training_args = FedMKTTrainingArguments(\\n\",\n    \"        global_epochs=global_epochs,\\n\",\n    \"        per_device_train_batch_size=1,\\n\",\n    \"        gradient_accumulation_steps=batch_size,\\n\",\n    \"        learning_rate=slm_lrs[slm_idx],\\n\",\n    \"        output_dir=\\\"./\\\",\\n\",\n    \"        dataloader_num_workers=4,\\n\",\n    \"        remove_unused_columns=False,\\n\",\n    \"        warmup_ratio=0.008,\\n\",\n    \"        lr_scheduler_type=\\\"cosine\\\",\\n\",\n    \"        optim=\\\"adamw_torch\\\",\\n\",\n    \"        adam_beta1=0.9,\\n\",\n    \"        adam_beta2=0.95,\\n\",\n    \"        weight_decay=0.1,\\n\",\n  
  \"        max_grad_norm=1.0,\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        vocab_size=AutoConfig.from_pretrained(slm_pretrained_paths[slm_idx]).vocab_size,\\n\",\n    \"        post_fedavg=True, # difference\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    # difference\\n\",\n    \"    fed_args = FedAVGArguments(\\n\",\n    \"        aggregate_strategy='epoch',\\n\",\n    \"        aggregate_freq=1\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    tokenizer = get_tokenizer(slm_pretrained_paths[slm_idx])\\n\",\n    \"\\n\",\n    \"    import json\\n\",\n    \"    with open(llm_to_slm_vocab_mapping_paths[slm_idx], \\\"r\\\") as fin:\\n\",\n    \"        vocab_mapping = json.loads(fin.read())\\n\",\n    \"\\n\",\n    \"    trainer = FedMKTSLM(\\n\",\n    \"        ctx=ctx,\\n\",\n    \"        model=model,\\n\",\n    \"        training_args=training_args, \\n\",\n    \"        fed_args=fed_args, # difference\\n\",\n    \"        pub_train_set=pub_data,\\n\",\n    \"        priv_train_set=priv_data,\\n\",\n    \"        tokenizer=tokenizer,\\n\",\n    \"        save_trainable_weights_only=True,\\n\",\n    \"        llm_tokenizer=get_tokenizer(llm_pretrained_path),\\n\",\n    \"        llm_to_slm_vocab_mapping=vocab_mapping,\\n\",\n    \"        data_collator=transformers.DataCollatorForSeq2Seq(tokenizer)\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    trainer.train()\\n\",\n    \"    if slm_idx == 0:\\n\",\n    \"        trainer.save_model(slm_models_saved_directory[slm_idx])\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def run(ctx: Context):\\n\",\n    \"    if ctx.is_on_arbiter:\\n\",\n    \"        os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"0\\\"\\n\",\n    \"        train_llm(ctx)\\n\",\n    \"    elif ctx.is_on_guest:\\n\",\n    \"        os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"1\\\"\\n\",\n    \"        train_slm(ctx, slm_idx=0)\\n\",\n    \"    else:\\n\",\n    \"        if ctx.local.party[1] == \\\"9999\\\":\\n\",\n    \"            
os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"2\\\"\\n\",\n    \"            slm_idx = 1\\n\",\n    \"        elif ctx.local.party[1] == \\\"10000\\\":\\n\",\n    \"            os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"3\\\"\\n\",\n    \"            slm_idx = 2\\n\",\n    \"        elif ctx.local.party[1] == \\\"10001\\\":\\n\",\n    \"            os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"4\\\"\\n\",\n    \"            slm_idx = 3\\n\",\n    \"        else:\\n\",\n    \"            raise ValueError(f\\\"party_id={ctx.local.party[1]} is illegal\\\")\\n\",\n    \"\\n\",\n    \"        train_slm(ctx, slm_idx=slm_idx)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"if __name__ == \\\"__main__\\\":\\n\",\n    \"    launch(run)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Running FEDMKT with Pipeline (Industrial Using)\\n\",\n    \"\\n\",\n    \"Please make sure that [FATE-LLM Cluster](https://github.com/FederatedAI/FATE/wiki/Download#llm%E9%83%A8%E7%BD%B2%E5%8C%85) has been deployed, ensure that multiple machines has been deployed in FATE-LLM Cluster mode, past the following code to test_fedmkt_4_slms.py, the execute \\\"python test_fedmkt_4_slms.py\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_fedmkt_runner\\n\",\n    \"from fate_client.pipeline.components.fate.nn.algo_params import FedMKTTrainingArguments, FedAVGArguments\\n\",\n    \"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\\n\",\n    \"from peft import LoraConfig, TaskType\\n\",\n    \"from fate_client.pipeline import FateFlowPipeline\\n\",\n    \"from fate_client.pipeline.components.fate.reader import Reader\\n\",\n    \"from transformers import AutoConfig\\n\",\n    \"\\n\",\n    \"guest = '9999' # 
replace this party id to actual guest party id in your environment\\n\",\n    \"host = ['9999', '10000', '10001'] # replace host party ids in your environment\\n\",\n    \"arbiter = '9999' # replace this party id to actual arbiter party id in your environment\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"process_data_output_dir = \\\"\\\" # replace this to actual process_data_output_dir\\n\",\n    \"# replace the names of models to local save directories\\n\",\n    \"llm_pretrained_path = \\\"llama-2-7b-hf\\\"\\n\",\n    \"slm_0_pretrained_path = \\\"opt-1.3b\\\"\\n\",\n    \"slm_1_pretrained_path = \\\"gpt2-xl\\\"\\n\",\n    \"slm_2_pretrained_path = \\\"Sheared-LLaMA-1.3B\\\"\\n\",\n    \"slm_3_pretrained_path = \\\"bloom-1b1\\\"\\n\",\n    \"llm_slm_pairs = [\\n\",\n    \"    (llm_pretrained_path, slm_0_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_1_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_2_pretrained_path),\\n\",\n    \"    (llm_pretrained_path, slm_3_pretrained_path)\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"vocab_mapping_directory = \\\"\\\" # replace this to actual vocab_mapping_directory\\n\",\n    \"\\n\",\n    \"slm_to_llm_vocab_mapping_paths = [\\\"opt_to_llama.json\\\", \\\"gpt2_to_llama.json\\\", \\\"llama_small_to_llama.json\\\", \\\"bloom_to_llama.json\\\"]\\n\",\n    \"llm_to_slm_vocab_mapping_paths = [\\\"llama_to_opt.json\\\", \\\"llama_to_gpt2.json\\\", \\\"llama_to_llama_small.json\\\", \\\"llama_to_bloom.json\\\"]\\n\",\n    \"\\n\",\n    \"for idx in range(4):\\n\",\n    \"    slm_to_llm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + slm_to_llm_vocab_mapping_paths[idx]\\n\",\n    \"    llm_to_slm_vocab_mapping_paths[idx] = vocab_mapping_directory + \\\"/\\\" + llm_to_slm_vocab_mapping_paths[idx]\\n\",\n    \"\\n\",\n    \"slm_pretrained_paths = [slm_0_pretrained_path, slm_1_pretrained_path, slm_2_pretrained_path, slm_3_pretrained_path]\\n\",\n    \"slm_lora_target_modules = [\\n\",\n    \"    
[\\\"q_proj\\\", \\\"v_proj\\\"],\\n\",\n    \"    [\\\"c_attn\\\"],\\n\",\n    \"    ['q_proj', 'k_proj', 'v_proj', 'o_proj'],\\n\",\n    \"    [\\\"query_key_value\\\"]\\n\",\n    \"]\\n\",\n    \"slm_models = [\\n\",\n    \"    (\\\"pellm.opt\\\", \\\"OPT\\\"),\\n\",\n    \"    (\\\"pellm.gpt2\\\", \\\"GPT2CLM\\\"),\\n\",\n    \"    (\\\"pellm.llama\\\", \\\"LLaMa\\\"),\\n\",\n    \"    (\\\"pellm.bloom\\\", \\\"Bloom\\\")\\n\",\n    \"]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def get_llm_conf():\\n\",\n    \"    lora_config = LoraConfig(\\n\",\n    \"        task_type=TaskType.CAUSAL_LM,\\n\",\n    \"        inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.05,\\n\",\n    \"        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']\\n\",\n    \"    )\\n\",\n    \"    lora_config.target_modules = list(lora_config.target_modules)\\n\",\n    \"\\n\",\n    \"    llm_model = LLMModelLoader(\\n\",\n    \"        \\\"pellm.llama\\\",\\n\",\n    \"        \\\"LLaMa\\\",\\n\",\n    \"        pretrained_path=llm_pretrained_path,\\n\",\n    \"        peft_type=\\\"LoraConfig\\\",\\n\",\n    \"        peft_config=lora_config.to_dict(),\\n\",\n    \"        torch_dtype=\\\"bfloat16\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    pub_dataset = LLMDatasetLoader(\\n\",\n    \"        \\\"qa_dataset\\\",\\n\",\n    \"        \\\"QaDataset\\\",\\n\",\n    \"        tokenizer_name_or_path=llm_pretrained_path,\\n\",\n    \"        need_preprocess=True,\\n\",\n    \"        dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"        data_part=\\\"common\\\",\\n\",\n    \"        seq_max_len=512\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    training_args = FedMKTTrainingArguments(\\n\",\n    \"        global_epochs=5,\\n\",\n    \"        per_device_train_batch_size=1,\\n\",\n    \"        gradient_accumulation_steps=4,\\n\",\n    \"        learning_rate=3e-5,\\n\",\n    \"        output_dir=\\\"./\\\",\\n\",\n    \"        dataloader_num_workers=4,\\n\",\n    \" 
       remove_unused_columns=False,\\n\",\n    \"        warmup_ratio=0.008,\\n\",\n    \"        lr_scheduler_type=\\\"cosine\\\",\\n\",\n    \"        optim=\\\"adamw_torch\\\",\\n\",\n    \"        adam_beta1=0.9,\\n\",\n    \"        adam_beta2=0.95,\\n\",\n    \"        weight_decay=0.1,\\n\",\n    \"        max_grad_norm=1.0,\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    fed_args = FedAVGArguments(\\n\",\n    \"        aggregate_strategy='epoch',\\n\",\n    \"        aggregate_freq=1\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    tokenizer = LLMDataFuncLoader(\\n\",\n    \"        \\\"tokenizers.cust_tokenizer\\\",\\n\",\n    \"        \\\"get_tokenizer\\\",\\n\",\n    \"        tokenizer_name_or_path=llm_pretrained_path\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    slm_tokenizers = list()\\n\",\n    \"    for slm_pretrained_path in slm_pretrained_paths:\\n\",\n    \"        slm_tokenizers.append(\\n\",\n    \"            LLMDataFuncLoader(\\\"tokenizers.cust_tokenizer\\\", \\\"get_tokenizer\\\", tokenizer_name_or_path=slm_pretrained_path)\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"    return get_config_of_fedmkt_runner(\\n\",\n    \"        model=llm_model,\\n\",\n    \"        training_args=training_args,\\n\",\n    \"        fed_args=fed_args,\\n\",\n    \"        pub_dataset=pub_dataset,\\n\",\n    \"        tokenizer=tokenizer,\\n\",\n    \"        slm_tokenizers=slm_tokenizers,\\n\",\n    \"        slm_to_llm_vocab_mapping_paths=slm_to_llm_vocab_mapping_paths,\\n\",\n    \"        pub_dataset_path=process_data_output_dir,\\n\",\n    \"        save_trainable_weights_only=True,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def get_slm_conf(slm_idx):\\n\",\n    \"    slm_pretrained_path = slm_pretrained_paths[slm_idx]\\n\",\n    \"    lora_config = LoraConfig(\\n\",\n    \"        
task_type=TaskType.CAUSAL_LM,\\n\",\n    \"        inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\\n\",\n    \"        target_modules=slm_lora_target_modules[slm_idx]\\n\",\n    \"    )\\n\",\n    \"    lora_config.target_modules = list(lora_config.target_modules)\\n\",\n    \"    llm_to_slm_vocab_mapping = llm_to_slm_vocab_mapping_paths[slm_idx]\\n\",\n    \"\\n\",\n    \"    slm_model = LLMModelLoader(\\n\",\n    \"        slm_models[slm_idx][0],\\n\",\n    \"        slm_models[slm_idx][1],\\n\",\n    \"        pretrained_path=slm_pretrained_path,\\n\",\n    \"        peft_type=\\\"LoraConfig\\\",\\n\",\n    \"        peft_config=lora_config.to_dict(),\\n\",\n    \"    )\\n\",\n    \"    vocab_size = AutoConfig.from_pretrained(slm_pretrained_path).vocab_size\\n\",\n    \"\\n\",\n    \"    pub_dataset = LLMDatasetLoader(\\n\",\n    \"        \\\"qa_dataset\\\",\\n\",\n    \"        \\\"QaDataset\\\",\\n\",\n    \"        tokenizer_name_or_path=slm_pretrained_path,\\n\",\n    \"        need_preprocess=True,\\n\",\n    \"        dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"        data_part=\\\"common\\\",\\n\",\n    \"        seq_max_len=512\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    priv_dataset = LLMDatasetLoader(\\n\",\n    \"        \\\"qa_dataset\\\",\\n\",\n    \"        \\\"QaDataset\\\",\\n\",\n    \"        tokenizer_name_or_path=slm_pretrained_path,\\n\",\n    \"        need_preprocess=True,\\n\",\n    \"        dataset_name=\\\"arc_challenge\\\",\\n\",\n    \"        data_part=\\\"client_0\\\",\\n\",\n    \"        seq_max_len=512\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    training_args = FedMKTTrainingArguments(\\n\",\n    \"        global_epochs=5,\\n\",\n    \"        per_device_train_batch_size=1,\\n\",\n    \"        gradient_accumulation_steps=4,\\n\",\n    \"        learning_rate=3e-5 if slm_idx != 1 else 3e-4,\\n\",\n    \"        output_dir=\\\"./\\\",\\n\",\n    \"        dataloader_num_workers=4,\\n\",\n    \"     
   remove_unused_columns=False,\\n\",\n    \"        warmup_ratio=0.008,\\n\",\n    \"        lr_scheduler_type=\\\"cosine\\\",\\n\",\n    \"        optim=\\\"adamw_torch\\\",\\n\",\n    \"        adam_beta1=0.9,\\n\",\n    \"        adam_beta2=0.95,\\n\",\n    \"        weight_decay=0.1,\\n\",\n    \"        max_grad_norm=1.0,\\n\",\n    \"        use_cpu=False,\\n\",\n    \"        vocab_size=vocab_size,\\n\",\n    \"        # post_fedavg=True,\\n\",\n    \"        # llm_training=False,\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    fed_args = FedAVGArguments(\\n\",\n    \"        aggregate_strategy='epoch',\\n\",\n    \"        aggregate_freq=1\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    tokenizer = LLMDataFuncLoader(\\n\",\n    \"        \\\"tokenizers.cust_tokenizer\\\",\\n\",\n    \"        \\\"get_tokenizer\\\",\\n\",\n    \"        tokenizer_name_or_path=slm_pretrained_path\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    llm_tokenizer = LLMDataFuncLoader(\\n\",\n    \"        \\\"tokenizers.cust_tokenizer\\\", \\\"get_tokenizer\\\", tokenizer_name_or_path=llm_pretrained_path\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    data_collator = LLMDataFuncLoader(module_name='data_collator.cust_data_collator',\\n\",\n    \"                                      item_name='get_seq2seq_data_collator', tokenizer_name_or_path=slm_pretrained_path)\\n\",\n    \"\\n\",\n    \"    return get_config_of_fedmkt_runner(\\n\",\n    \"        model=slm_model,\\n\",\n    \"        training_args=training_args,\\n\",\n    \"        fed_args=fed_args,\\n\",\n    \"        pub_dataset=pub_dataset,\\n\",\n    \"        priv_dataset=priv_dataset,\\n\",\n    \"        tokenizer=tokenizer,\\n\",\n    \"        llm_tokenizer=llm_tokenizer,\\n\",\n    \"        llm_to_slm_vocab_mapping_path=llm_to_slm_vocab_mapping,\\n\",\n    \"        pub_dataset_path=process_data_output_dir,\\n\",\n    \"        save_trainable_weights_only=True,\\n\",\n    \"        
data_collator=data_collator\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter, host=host)\\n\",\n    \"pipeline.bind_local_path(path=process_data_output_dir, namespace=\\\"experiment\\\", name=\\\"arc_challenge\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"reader_0 = Reader(\\\"reader_0\\\", runtime_parties=dict(guest=guest, host=host))\\n\",\n    \"reader_0.guest.task_parameters(\\n\",\n    \"    namespace=\\\"experiment\\\",\\n\",\n    \"    name=\\\"arc_challenge\\\"\\n\",\n    \")\\n\",\n    \"reader_0.hosts[[0, 1, 2]].task_parameters(\\n\",\n    \"    namespace=\\\"experiment\\\",\\n\",\n    \"    name=\\\"arc_challenge\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"homo_nn_0 = HomoNN(\\n\",\n    \"    'nn_0',\\n\",\n    \"    train_data=reader_0.outputs[\\\"output_data\\\"],\\n\",\n    \"    runner_module=\\\"fedmkt_runner\\\",\\n\",\n    \"    runner_class=\\\"FedMKTRunner\\\",\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"homo_nn_0.arbiter.task_parameters(\\n\",\n    \"    runner_conf=get_llm_conf()\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"homo_nn_0.guest.task_parameters(\\n\",\n    \"    runner_conf=get_slm_conf(slm_idx=0)\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"for idx in range(3):\\n\",\n    \"    homo_nn_0.hosts[idx].task_parameters(\\n\",\n    \"        runner_conf=get_slm_conf(slm_idx=idx + 1)\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"homo_nn_0.guest.conf.set(\\\"launcher_name\\\", \\\"deepspeed\\\") # tell schedule engine to run task with deepspeed\\n\",\n    \"homo_nn_0.hosts[[0, 1, 2]].conf.set(\\\"launcher_name\\\", \\\"deepspeed\\\") # tell schedule engine to run task with deepspeed\\n\",\n    \"homo_nn_0.arbiter.conf.set(\\\"launcher_name\\\", \\\"deepspeed\\\") # tell schedule engine to run task with deepspeed\\n\",\n    \"\\n\",\n    \"pipeline.add_tasks([reader_0, homo_nn_0])\\n\",\n    \"pipeline.conf.set(\\\"task\\\", 
dict(engine_run={\\\"cores\\\": 1})) # the number of gpus of each party\\n\",\n    \"\\n\",\n    \"pipeline.compile()\\n\",\n    \"pipeline.fit()\\n\",\n    \"\\n\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "doc/tutorial/inferdpt/inferdpt_tutorial.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"341aeb6e-9e25-4a0e-9664-a32ab11293fa\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Inferdpt Tutorial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0b40afd5-77b9-45c6-a761-81b9a6bddc05\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Introduction of Inferdpt\\n\",\n    \"\\n\",\n    \"Inferdpt is an advanced algorithm framework designed for efficient and privacy-preserving text generation using large language models (LLMs). The framework addresses privacy concerns related to data leakage and unauthorized information collection in LLMs. Inferdpt implements Differential Privacy mechanisms to protect sensitive information during the inference process with black-box LLMs.\\n\",\n    \"\\n\",\n    \"Inferdpt comprises two key modules: the \\\"perturbation module\\\" and the \\\"extraction module\\\". The perturbation module utilizes a differentially private(DP) mechanism to generate a perturbed prompt from the raw document, facilitating privacy-preserving inference with black-box LLMs. The extraction module, inspired by knowledge distillation and retrieval-augmented generation, processes the perturbed text to produce coherent and consistent output. This ensures that the text generation quality of InferDPT is comparable to that of non-private LLMs, maintaining high utility while providing strong privacy guarantees.\\n\",\n    \"\\n\",\n    \"To further enhance privacy protection, Inferdpt integrates a novel mechanism called RANTEXT. 
RANTEXT introduces the concept of random adjacency list for token-level perturbation, addressing the vulnerability of existing differentially private mechanisms to embedding inversion attacks.\\n\",\n    \"\\n\",\n    \"For more details of Inferdpt, please refer to the [original paper](https://arxiv.org/pdf/2310.12214.pdf).\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ac982b2d-4a71-45a5-a2b1-90259711f36b\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Use InferDPT\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"042049c5-80ce-4786-9896-88baddd59f4e\",\n   \"metadata\": {},\n   \"source\": [\n    \"In this section, we will guide you through the process of:\\n\",\n    \"- Setting up the inferdpt toolkit with an existing language model.\\n\",\n    \"- Creating a model inference tool using the built-in class.\\n\",\n    \"- Executing a step-by-step walkthrough of an inference instance: Employing inferdpt to generate rationale responses for question-answering tasks.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e1938eef-106d-4cc0-a9b7-6ad8d9d281f5\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Create Inferdpt Kit\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"565aa2ed-5919-4aa0-9499-23b730434c62\",\n   \"metadata\": {},\n   \"source\": [\n    \"In alignment with the original paper, the implementation of differential privacy in inferdpt involves the random substitution of tokens in the original text with semantically similar words. To facilitate this process, it is necessary to precalculate the similarities between a subset of tokens from the vocabulary of the remote large language model. In this tutorial, we will utilize the Mistral-7B model as our remote large language model and the Qwen1.5-0.5B model as the local decoding model. 
For the sake of computational efficiency, we will select a subset of 11,400 tokens from the Mistral-7B vocabulary to perform the similarity calculations and use the built-in function to finally get the inferdpt-kit.\\n\",\n    \"\\n\",\n    \"Firstly we load the mistral model to get the embedding set:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"f01a229a-52e1-4a97-af06-a2ab122b7083\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# load embeddings from mistral model\\n\",\n    \"import numpy as np\\n\",\n    \"from transformers import AutoModelForCausalLM, AutoTokenizer\\n\",\n    \"model_path = '/data/cephfs/llm/models/Mistral-7B-Instruct-v0.2/'\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(model_path)\\n\",\n    \"model = AutoModelForCausalLM.from_pretrained(model_path)\\n\",\n    \"embeddings = tokenizer.get_vocab() # get embeddings matrix\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"3f7ec40b-1a58-4608-b2c1-3299979e699a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Get the embedding layer weights\\n\",\n    \"dtype = np.float32\\n\",\n    \"embedding_weights = model.get_input_embeddings().weight\\n\",\n    \"# Convert the embedding layer weights to numpy\\n\",\n    \"embedding_weights_np = embedding_weights.detach().numpy().astype(dtype)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"07261aee-b676-4a42-9098-2923fa67519c\",\n   \"metadata\": {},\n   \"source\": [\n    \"Then we select english tokens from the vocabulary. 
Then we can get an embedding matrix and a corresponding token list.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"id\": \"dbb231f9-f0ca-4add-bb45-f4fb59429abb\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32000/32000 [00:00<00:00, 663000.04it/s]\\n\"\n     ]\n    },\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"11400\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import tqdm\\n\",\n    \"import re\\n\",\n    \"\\n\",\n    \"def contains_english_chars(string):\\n\",\n    \"    pattern = r'[a-zA-Z]'\\n\",\n    \"    match = re.search(pattern, string)\\n\",\n    \"    return bool(match)\\n\",\n    \"\\n\",\n    \"def contains_non_english_chars(string):\\n\",\n    \"    pattern = r'[^a-zA-Z]'\\n\",\n    \"    match = re.search(pattern, string)\\n\",\n    \"    return bool(match)\\n\",\n    \"\\n\",\n    \"def filter_tokens(token2index):\\n\",\n    \"    filtered_index2token = {}\\n\",\n    \"    for key, idx in tqdm.tqdm(token2index.items()):\\n\",\n    \"        if key.startswith('<'):\\n\",\n    \"            continue\\n\",\n    \"        if not key.startswith('▁'):\\n\",\n    \"            continue\\n\",\n    \"        val_ = key.replace(\\\"▁\\\", \\\"\\\")\\n\",\n    \"        if val_ == val_.upper():\\n\",\n    \"            continue\\n\",\n    \"        if contains_non_english_chars(val_):\\n\",\n    \"            continue\\n\",\n    \"        if 3 < len(val_) < 16 and contains_english_chars(val_):\\n\",\n    \"            filtered_index2token[idx] = key\\n\",\n    \"\\n\",\n    \"    return filtered_index2token\\n\",\n    \"\\n\",\n    
\"filtered_index2token = filter_tokens(embeddings)\\n\",\n    \"used_num_tokens = len(filtered_index2token)\\n\",\n    \"print(used_num_tokens)\\n\",\n    \"for idx, token in filtered_index2token.items():\\n\",\n    \"    token_2_embedding[token] = embedding_weights_np[idx].tolist()\\n\",\n    \"token_list = list(token_2_embedding.keys())\\n\",\n    \"embedding_matrix = np.array(list(token_2_embedding.values()), dtype=dtype)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 27,\n   \"id\": \"5922a177-d752-485d-98ab-9fd6688198f8\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"we got the embedding matrix:\\n\",\n      \"[[-6.1035156e-04 -4.5471191e-03 -5.2795410e-03 ... -1.3656616e-03\\n\",\n      \"   4.2419434e-03 -8.1634521e-04]\\n\",\n      \" [ 4.8522949e-03  5.9814453e-03  1.1596680e-03 ... -2.6702881e-03\\n\",\n      \"  -1.7471313e-03  9.9182129e-04]\\n\",\n      \" [-2.7465820e-03  4.3029785e-03  3.3874512e-03 ... -2.6092529e-03\\n\",\n      \"  -1.2397766e-05 -3.4027100e-03]\\n\",\n      \" ...\\n\",\n      \" [-6.1340332e-03 -5.3405762e-03 -1.0910034e-03 ... -9.3841553e-04\\n\",\n      \"  -7.4005127e-04 -7.3852539e-03]\\n\",\n      \" [-4.5166016e-03  8.2015991e-04  4.8217773e-03 ... -1.1978149e-03\\n\",\n      \"  -1.0528564e-03 -2.1362305e-03]\\n\",\n      \" [ 1.2054443e-03  1.9836426e-03 -2.8419495e-04 ... 
-1.5792847e-03\\n\",\n      \"  -2.8381348e-03 -7.1716309e-04]]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print('we got the embedding matrix:')\\n\",\n    \"print(embedding_matrix)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"20890d89-998f-4f38-968a-2a6a0648b050\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can easily prepare the pre-computed data we needed for inferdpt by using the built-in function of the InferDPTKit class:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"id\": \"c90c7099-d20d-4009-bb7a-aeb3b46210b2\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"11400it [00:37, 300.99it/s]\\n\",\n      \"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4096/4096 [00:03<00:00, 1147.93it/s]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from fate_llm.algo.inferdpt.utils import InferDPTKit\\n\",\n    \"param = InferDPTKit.make_inferdpt_kit_param(embedding_matrix, token_list)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0fe3a722-5cdf-4393-90f3-e5d7b82051cf\",\n   \"metadata\": {},\n   \"source\": [\n    \"Great, the computation is complete! Now, let’s proceed to perturb a sentence using inferdpt with ε (epsilon) set to 3.0. 
We will also save the perturbed sentence to a designated folder for future reference.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 33,\n   \"id\": \"4a6acd81-bc7c-49d3-86f4-ad0b5c329e61\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"inferdpt_kit = InferDPTKit(*param, tokenizer)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 96,\n   \"id\": \"0077696c-0c10-4500-8835-6e72a084bc42\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"'into the tree to the woods'\"\n      ]\n     },\n     \"execution_count\": 96,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"inferdpt_kit.perturb('From the river to the ocean', epsilon=3.0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 97,\n   \"id\": \"8df2f57a-e202-4d4f-a175-21990223dc3d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"save_kit_path = 'your path'\\n\",\n    \"inferdpt_kit.save_to_path(save_kit_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ced7f1bf-aa49-4806-92e2-712493bb4b10\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Go through Inferdpt Step by Step\\n\",\n    \"\\n\",\n    \"Next, we will guide you through the process of using inferdpt step by step. We will simulate the interaction between the client and server locally. Before we begin, let’s discuss model inference. Within fate-llm's inferdpt module, we provide three types of model inference classes: vllm, vllm server, and Huggingface native. You can explore these classes in the [code files](../../../python/fate_llm/algo/inferdpt/inference/) or develop your own inference tool based on your specific needs. We highly recommend using vllm server. 
In this case, we will use the following two commands to launch two large model services, corresponding to the server’s LLM and the local decoding small model.\\n\",\n    \"\\n\",\n    \"For this example, we have executed the process on a machine equipped with four V100-32G GPUs. We advise you to adjust the model path and GPU settings as necessary to accommodate the specifications of your own machine.\\n\",\n    \"\\n\",\n    \"Start vllm server using commands below:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b1b6c3c7-6ddd-4386-8700-c95f74a2bae0\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"! python -m vllm.entrypoints.openai.api_server --host 127.0.0.1 --port 8888 --model ./Mistral-7B-Instruct-v0.2  --dtype=half --enforce-eager --tensor-parallel-size 4 --gpu-memory-utilization 0.6\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"48374cb5-3a5b-456c-9d53-219c2468da63\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"! python -m vllm.entrypoints.openai.api_server --host 127.0.0.1 --port 8887 --model ./Qwen1.5-0.5B  --dtype=half --enforce-eager --tensor-parallel-size 4 --gpu-memory-utilization 0.2\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"375e3f0d-36c7-4ab3-8e65-cccac23e93c6\",\n   \"metadata\": {},\n   \"source\": [\n    \"Next, we will initialize the inference instance, which are the parameters for both the inferdpt client and server. 
This includes specifying the IP address, port, and the model name of the service that has been started.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 130,\n   \"id\": \"cd099ef4-569d-45b6-9765-502b688c3fb4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.inference.api import APICompletionInference\\n\",\n    \"# for client\\n\",\n    \"inference_client = APICompletionInference(api_url=\\\"http://127.0.0.1:8887/v1\\\", model_name='./Qwen1.5-0.5B', api_key='EMPTY')\\n\",\n    \"# for server\\n\",\n    \"inference_server = APICompletionInference(api_url=\\\"http://127.0.0.1:8888/v1\\\", model_name='./Mistral-7B-Instruct-v0.2', api_key='EMPTY')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 135,\n   \"id\": \"8c430c14-2180-4f02-8f06-3f41bae1a710\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \" I am a new user of this forum. I am a 20 year\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"ret = inference_client.inference(['Hello how are you?'], inference_kwargs={\\n\",\n    \"    'stop': ['<|im_end|>', '\\\\n'],\\n\",\n    \"    'temperature': 0.01,\\n\",\n    \"    'max_tokens': 16\\n\",\n    \"})\\n\",\n    \"print(ret[0])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 138,\n   \"id\": \"6341eb48-e30f-46d4-aeaa-8c6fd27259b9\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \" I am an artificial intelligence designed to assist with information and answer questions to the best of my ability. I don't have the ability to have a personal identity or emotions. I'm here to help you with any inquiries you may have. 
How can I assist you today?\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"ret = inference_server.inference(['<s>[INST]Who are u?[/INST]'], inference_kwargs={\\n\",\n    \"    'stop': ['</s>'],\\n\",\n    \"    'temperature': 0.01,\\n\",\n    \"    'max_tokens': 128\\n\",\n    \"})\\n\",\n    \"print(ret[0])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"90f5ce55-de3b-481a-a9cc-cb4c24edb7c2\",\n   \"metadata\": {},\n   \"source\": [\n    \"In this tutorial, we will use a question-answering (QA) task as our illustrative example. To do so, we will extract a sample from the ARC-E dataset for demonstration purposes, here is the example:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 100,\n   \"id\": \"f912f986-ae86-4d57-9ebc-534d6404173c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"test_example = {'id': 'Mercury_7220990',\\n\",\n    \"'question': 'Which factor will most likely cause a person to develop a fever?',\\n\",\n    \"'choices': {'text': ['a leg muscle relaxing after exercise',\\n\",\n    \"'a bacterial population in the bloodstream',\\n\",\n    \"'several viral particles on the skin',\\n\",\n    \"'carbohydrates being digested in the stomach'],\\n\",\n    \"'label': ['A', 'B', 'C', 'D']},\\n\",\n    \"'answerKey': 'B'}\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a98c74a8-7760-438b-b4f4-33178fed8761\",\n   \"metadata\": {},\n   \"source\": [\n    \"Before initiating the inference, it's crucial to understand the sequence of steps involved. We will leverage the Jinja2 template engine to structure our documentation as follows:\\n\",\n    \"\\n\",\n    \"1. **Document Template Organization**: The initial step is to organize the document dictionary using the DOC TEMPLATE. This template will provide the structure for the input document.\\n\",\n    \"\\n\",\n    \"2. 
**Differential Privacy Perturbation**: Apply Differential Privacy (DP) to perturb the structured document string. This will result in a perturbed document. The perturbed document is then added to the original document under the key 'perturbed_doc'. Note that you can modify this key according to your parameter settings.\\n\",\n    \"\\n\",\n    \"3. **Instruction Addition**: Use the INSTRUCTION TEMPLATE to add instructions (or few-shot examples) to the perturbed document. This modified document is then sent to the server side for processing. The server's response is captured, and this perturbed response is appended to the original document under the key 'perturbed_response'. As before, this key can be adjusted as needed.\\n\",\n    \"\\n\",\n    \"4. **Decode Template Formatting**: Finally, employ the decode template to format the decode prompt. The resulting inference is then added to the original dictionary under the key 'inferdpt_result'. This key, like the others, can be customized to fit your specific parameters.\\n\",\n    \"\\n\",\n    \"By following these steps, the inferdpt framework enables a structured and privacy-preserving inference process, leading to a final output that incorporates the perturbed data and the model's response.\\n\",\n    \"For more details, you can refer to the source codes:\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"09d7377a-22f4-4d04-b886-88faa1384d7f\",\n   \"metadata\": {},\n   \"source\": [\n    \"The templates for this example are defined on the client side. 
Below is the Jinja template we use:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 141,\n   \"id\": \"eff74a65-f765-483f-a685-418376414ff0\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"doc_template = \\\"\\\"\\\"{{question}} \\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"instruction_template=\\\"\\\"\\\"\\n\",\n    \"<s>[INST]\\n\",\n    \"Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. Please refer to the example to write the rationale.\\n\",\n    \"Use <end> to finish your rationle.\\\"\\n\",\n    \"\\n\",\n    \"Example(s):\\n\",\n    \"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\\n\",\n    \"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\\n\",\n    \"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  Therefore, the answer is 'dry palms'.<end>\\n\",\n    \"\\n\",\n    \"Please explain:\\n\",\n    \"Question:{{perturbed_doc}}\\n\",\n    \"Rationale:\\n\",\n    \"[/INST]\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"decode_template = \\\"\\\"\\\"Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. Please refer to the example to write the rationale.\\n\",\n    \"Use <end> to finish your rationle.\\\"\\n\",\n    \"\\n\",\n    \"Example(s):\\n\",\n    \"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\\n\",\n    \"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\\n\",\n    \"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. 
Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  Therefore, the answer is 'dry palms'.<end>\\n\",\n    \"\\n\",\n    \"Question:{{perturbed_doc}}\\n\",\n    \"Rationale:{{perturbed_response | replace('\\\\n', '')}}<end>\\n\",\n    \"\\n\",\n    \"Please explain:\\n\",\n    \"Question:{{question}} \\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"Rationale:\\n\",\n    \"\\\"\\\"\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b3c02898-91df-48b6-a0e2-9af5bd5538d8\",\n   \"metadata\": {},\n   \"source\": [\n    \"Please be aware that we have included a one-shot example in the prompt to ensure that the Large Language Model (LLM) responds as anticipated.\\n\",\n    \"\\n\",\n    \"Now we create two script: \\n\",\n    \"- inferdpt_client.py\\n\",\n    \"- inferdpt_server.py\\n\",\n    \"\\n\",\n    \"And run codes provided below:\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f0f110a3-c601-4c7b-8e89-8684d2ae266d\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Client Side: inferdpt_client.py\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"006612b8-da8d-402c-9b6d-b6786325fa7c\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.inference.api import APICompletionInference\\n\",\n    \"from fate_llm.algo.inferdpt import inferdpt\\n\",\n    \"from fate_llm.algo.inferdpt.utils import InferDPTKit\\n\",\n    \"import sys\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"arbiter = (\\\"arbiter\\\", 10000)\\n\",\n    \"guest = (\\\"guest\\\", 10000)\\n\",\n    \"host = (\\\"host\\\", 9999)\\n\",\n    \"name = \\\"fed1\\\"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def create_ctx(local):\\n\",\n    \"    from fate.arch import Context\\n\",\n    \"    from fate.arch.computing.backends.standalone import CSession\\n\",\n    \"    from fate.arch.federation.backends.standalone 
import StandaloneFederation\\n\",\n    \"    import logging\\n\",\n    \"\\n\",\n    \"    logger = logging.getLogger()\\n\",\n    \"    logger.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    console_handler = logging.StreamHandler()\\n\",\n    \"    console_handler.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    formatter = logging.Formatter(\\\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\\\")\\n\",\n    \"    console_handler.setFormatter(formatter)\\n\",\n    \"\\n\",\n    \"    logger.addHandler(console_handler)\\n\",\n    \"    computing = CSession(data_dir=\\\"./session_dir\\\")\\n\",\n    \"    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"ctx = create_ctx(guest)\\n\",\n    \"save_kit_path = 'your path'\\n\",\n    \"kit = InferDPTKit.load_from_path(save_kit_path)\\n\",\n    \"inference = APICompletionInference(api_url=\\\"http://127.0.0.1:8887/v1\\\", model_name='./Qwen1.5-0.5B', api_key='EMPTY')\\n\",\n    \"\\n\",\n    \"test_example = {'id': 'Mercury_7220990',\\n\",\n    \"'question': 'Which factor will most likely cause a person to develop a fever?',\\n\",\n    \"'choices': {'text': ['a leg muscle relaxing after exercise',\\n\",\n    \"'a bacterial population in the bloodstream',\\n\",\n    \"'several viral particles on the skin',\\n\",\n    \"'carbohydrates being digested in the stomach'],\\n\",\n    \"'label': ['A', 'B', 'C', 'D']},\\n\",\n    \"'answerKey': 'B'}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"doc_template = \\\"\\\"\\\"{{question}} \\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"instruction_template=\\\"\\\"\\\"\\n\",\n    \"<s>[INST]\\n\",\n    \"Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. 
Please refer to the example to write the rationale.\\n\",\n    \"Use <end> to finish your rationle.\\\"\\n\",\n    \"\\n\",\n    \"Example(s):\\n\",\n    \"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\\n\",\n    \"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\\n\",\n    \"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  Therefore, the answer is 'dry palms'.<end>\\n\",\n    \"\\n\",\n    \"Please explain:\\n\",\n    \"Question:{{perturbed_doc}}\\n\",\n    \"Rationale:\\n\",\n    \"[/INST]\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"decode_template = \\\"\\\"\\\"Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. Please refer to the example to write the rationale.\\n\",\n    \"Use <end> to finish your rationle.\\\"\\n\",\n    \"\\n\",\n    \"Example(s):\\n\",\n    \"Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\\n\",\n    \"Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\\n\",\n    \"Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  
Therefore, the answer is 'dry palms'.<end>\\n\",\n    \"\\n\",\n    \"Question:{{perturbed_doc}}\\n\",\n    \"Rationale:{{perturbed_response | replace('\\\\n', '')}}<end>\\n\",\n    \"\\n\",\n    \"Please explain:\\n\",\n    \"Question:{{question}} \\n\",\n    \"Choices:{{choices.text}}\\n\",\n    \"Rationale:\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"inferdpt_client = inferdpt.InferDPTClient(ctx, kit, inference, epsilon=3.0)\\n\",\n    \"result = inferdpt_client.inference([test_example], doc_template, instruction_template, decode_template, \\\\\\n\",\n    \"                                 remote_inference_kwargs={\\n\",\n    \"                                    'stop': ['<\\\\s>'],\\n\",\n    \"                                    'temperature': 0.01,\\n\",\n    \"                                    'max_tokens': 256\\n\",\n    \"                                 },\\n\",\n    \"                                 local_inference_kwargs={\\n\",\n    \"                                    'stop': ['<|im_end|>', '<end>', '<end>\\\\n', '<end>\\\\n\\\\n', '.\\\\n\\\\n\\\\n\\\\n\\\\n', '<|end_of_text|>', '>\\\\n\\\\n\\\\n'],\\n\",\n    \"                                    'temperature': 0.01,\\n\",\n    \"                                    'max_tokens': 256\\n\",\n    \"                                 })\\n\",\n    \"print('result is {}'.format(result[0]['inferdpt_result']))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e6ed3c0e-0b1f-4087-b155-def3ee957618\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Server Side: inferdpt_server.py\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"96e3e9fa-9554-4bcf-b8bf-358c469014bf\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.algo.inferdpt.inferdpt import InferDPTServer\\n\",\n    \"import sys\\n\",\n    \"from fate_llm.algo.inferdpt.inference.api import APICompletionInference\\n\",\n    \"\\n\",\n    \"\\n\",\n  
  \"arbiter = (\\\"arbiter\\\", 10000)\\n\",\n    \"guest = (\\\"guest\\\", 10000)\\n\",\n    \"host = (\\\"host\\\", 9999)\\n\",\n    \"name = \\\"fed1\\\"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def create_ctx(local):\\n\",\n    \"    from fate.arch import Context\\n\",\n    \"    from fate.arch.computing.backends.standalone import CSession\\n\",\n    \"    from fate.arch.federation.backends.standalone import StandaloneFederation\\n\",\n    \"    import logging\\n\",\n    \"\\n\",\n    \"    logger = logging.getLogger()\\n\",\n    \"    logger.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    console_handler = logging.StreamHandler()\\n\",\n    \"    console_handler.setLevel(logging.INFO)\\n\",\n    \"\\n\",\n    \"    formatter = logging.Formatter(\\\"%(asctime)s - %(name)s - %(levelname)s - %(message)s\\\")\\n\",\n    \"    console_handler.setFormatter(formatter)\\n\",\n    \"\\n\",\n    \"    logger.addHandler(console_handler)\\n\",\n    \"    computing = CSession(data_dir=\\\"./session_dir\\\")\\n\",\n    \"    return Context(computing=computing, federation=StandaloneFederation(computing, name, local, [guest, host, arbiter]))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"ctx = create_ctx(arbiter)\\n\",\n    \"inference_server = APICompletionInference(api_url=\\\"http://127.0.0.1:8888/v1\\\", model_name='./Mistral-7B-Instruct-v0.2', api_key='EMPTY')\\n\",\n    \"inferdpt_server = InferDPTServer(ctx, inference_server)\\n\",\n    \"inferdpt_server.inference()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"bfef704b-179e-44cd-84dc-a40b036e7f28\",\n   \"metadata\": {},\n   \"source\": [\n    \"Start two terminals and launch client&server scripts simultaneously.\\n\",\n    \"On the client side we can get the answer:\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"The given question asks which factor will most likely cause a person to develop a fever. 
The factors mentioned are a leg muscle relaxing after exercise, a bacterial population in the bloodstream, several viral particles on the skin, and carbohydrates being digested in the stomach. The question is asking which factor is most likely to cause a person to develop a fever. The factors are all related to the body's internal environment, but the most likely factor is a bacterial population in the bloodstream. This is because bacteria can cause a fever, and the body's immune system responds to the infection by producing antibodies that can fight off the bacteria. Therefore, the answer is 'a bacterial population in the bloodstream'\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"adf80b4e-4727-4ee4-b0b1-4839bd516f4f\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Use Inferdpt in FATE Pipeline\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b9b560e3-8db4-4828-a4fc-494320a9a3e5\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can leverage the FATE pipeline to submit inference tasks for industrial applications. When operating in pipeline mode, to safeguard against privacy breaches such as API key or server path leakage, it is crucial to create initialization scripts for establishing inferdpt client instances. Alternatively, you can modify the provided scripts within the fate_llm/algo/inferdpt/init folder.\\n\",\n    \"\\n\",\n    \"Below, we provide an overview of the default_init.py script, which serves as an example of how to create an [initialization class](../../../python/fate_llm/algo/inferdpt/init/default_init.py). 
By customizing the static variables within this class, you can configure the client and server to interact with the Large Language Model (LLM) interfaces as intended.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"eab49960-0541-4059-b84d-bee4bb690974\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.algo.inferdpt.init._init import InferClientInit\\n\",\n    \"from fate_llm.inference.api import APICompletionInference\\n\",\n    \"from fate_llm.algo.inferdpt import inferdpt\\n\",\n    \"from fate_llm.algo.inferdpt.utils import InferDPTKit\\n\",\n    \"from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class InferDPTAPIClientInit(InferClientInit):\\n\",\n    \"\\n\",\n    \"    api_url = ''\\n\",\n    \"    api_model_name = ''\\n\",\n    \"    api_key = 'EMPTY'\\n\",\n    \"    inferdpt_kit_path = ''\\n\",\n    \"    eps = 3.0\\n\",\n    \"\\n\",\n    \"    def __init__(self, ctx):\\n\",\n    \"        super().__init__(ctx)\\n\",\n    \"        self.ctx = ctx\\n\",\n    \"\\n\",\n    \"    def get_inst(self)-> InferDPTClient:\\n\",\n    \"        inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)\\n\",\n    \"        kit = InferDPTKit.load_from_path(self.inferdpt_kit_path)\\n\",\n    \"        inferdpt_client = inferdpt.InferDPTClient(self.ctx, kit, inference, epsilon=self.eps)\\n\",\n    \"        return inferdpt_client\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class InferDPTAPIServerInit(InferClientInit):\\n\",\n    \"\\n\",\n    \"    api_url = ''\\n\",\n    \"    api_model_name = ''\\n\",\n    \"    api_key = 'EMPTY'\\n\",\n    \"\\n\",\n    \"    def __init__(self, ctx):\\n\",\n    \"        super().__init__(ctx)\\n\",\n    \"        self.ctx = ctx\\n\",\n    \"\\n\",\n    \"    def get_inst(self)-> InferDPTServer:\\n\",\n    \"        inference = 
APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)\\n\",\n    \"        inferdpt_server = inferdpt.InferDPTServer(self.ctx,inference_inst=inference)\\n\",\n    \"        return inferdpt_server\\n\",\n    \"        \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0a5c9d6b-94b9-4ae3-80f7-20d1a698764c\",\n   \"metadata\": {},\n   \"source\": [\n    \"In the pipeline example, we use arc_easy dataset and our built-in huggingface dataset. Only HuggingfaceDataset is supported in the pipeline mode:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"15276057-fdda-4cc6-8678-eb1f485e4c58\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.dataset.hf_dataset import HuggingfaceDataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"id\": \"31cce967-3f5f-4261-ae17-9089368b82f9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from datasets import load_dataset\\n\",\n    \"dataset = load_dataset('arc_easy')\\n\",\n    \"dataset.save_to_disk('your_path/arc_easy')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"id\": \"af9adcb4-766d-45f6-a13c-5c127df61e5b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"ds = HuggingfaceDataset(load_from_disk= True, data_split_key='train')\\n\",\n    \"ds.load('your_path/arc_easy')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"899c410f-fe68-4f7e-936e-8b11720ff148\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"{'id': 'Mercury_7220990', 'question': 'Which factor will most likely cause a person to develop a fever?', 'choices': {'text': ['a leg muscle relaxing after exercise', 'a bacterial population in the bloodstream', 'several viral particles on the skin', 'carbohydrates 
being digested in the stomach'], 'label': ['A', 'B', 'C', 'D']}, 'answerKey': 'B'}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(ds[0])\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5f69f8cf-f40d-418a-be2a-753d67537442\",\n   \"metadata\": {},\n   \"source\": [\n    \"After that, we can associate the dataset path with a name and namespace. By specifying the dataset configuration, the HuggingfaceDataset will be initialized and the dataset will be loaded from the specified path. \\n\",\n    \"```\\n\",\n    \"flow table bind --namespace experiment --name arc_e --path 'your_path/arc_easy'\\n\",\n    \"```\\n\",\n    \"Once these initialization scripts are in place, you can submit a pipeline task by specifying the initialization class in the configuration files. For more information, refer to the script provided below:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"4da1aea7-0ba2-4ebb-918f-cfcf24d4498b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import argparse\\n\",\n    \"from fate_client.pipeline.utils import test_utils\\n\",\n    \"from fate_client.pipeline.components.fate.reader import Reader\\n\",\n    \"from fate_client.pipeline import FateFlowPipeline\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def main(config=\\\"../../config.yaml\\\", namespace=\\\"\\\"):\\n\",\n    \"    # obtain config\\n\",\n    \"    if isinstance(config, str):\\n\",\n    \"        config = test_utils.load_job_config(config)\\n\",\n    \"    parties = config.parties\\n\",\n    \"    guest = parties.guest[0]\\n\",\n    \"    arbiter = parties.arbiter[0]\\n\",\n    \"\\n\",\n    \"    pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\\n\",\n    \"\\n\",\n    \"    reader_0 = Reader(\\\"reader_0\\\", runtime_parties=dict(guest=guest))\\n\",\n    \"    reader_0.guest.task_parameters(\\n\",\n    \"        namespace=f\\\"experiment{namespace}\\\",\\n\",\n    \"        
name=\\\"arc_e\\\"\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    inferdpt_init_conf_client = {\\n\",\n    \"        'module_name': 'fate_llm.algo.inferdpt.init.default_init',\\n\",\n    \"        'item_name': 'InferDPTAPIClientInit'\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    dataset_conf = {\\n\",\n    \"        'module_name': 'fate_llm.dataset.hf_dataset',\\n\",\n    \"        'item_name': 'HuggingfaceDataset',\\n\",\n    \"        'kwargs':{\\n\",\n    \"            'load_from_disk': True,\\n\",\n    \"            'data_split_key': 'train'\\n\",\n    \"        }\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    doc_template = \\\"\\\"\\\"{{question}} \\n\",\n    \"    Choices:{{choices.text}}\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"    instruction_template=\\\"\\\"\\\"\\n\",\n    \"    <|im_start|>system\\n\",\n    \"    You are a helpful assistant.<|im_end|>\\n\",\n    \"    <|im_start|>user\\n\",\n    \"    Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. Please refer to the example to write the rationale.\\n\",\n    \"    Use <end> to finish your rationle.\\\"\\n\",\n    \"\\n\",\n    \"    Example(s):\\n\",\n    \"    Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\\n\",\n    \"    Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\\n\",\n    \"    Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  
Therefore, the answer is 'dry palms'.<end>\\n\",\n    \"\\n\",\n    \"    Please explain:\\n\",\n    \"    Question:{{perturbed_doc}}\\n\",\n    \"    Rationale:\\n\",\n    \"    <|im_end|>\\n\",\n    \"    <|im_start|>assistant\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"    decode_template = \\\"\\\"\\\"Select Answer from Choices and explain it in \\\"Rationale\\\" with few words. Please refer to the example to write the rationale.\\n\",\n    \"    Use <end> to finish your rationle.\\\"\\n\",\n    \"\\n\",\n    \"    Example(s):\\n\",\n    \"    Question:George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most heat?\\n\",\n    \"    Choices:['dry palms', 'wet palms', 'palms covered with oil', 'palms covered with lotion']\\n\",\n    \"    Rationale:Friction between two surfaces generates heat due to the conversion of kinetic energy into thermal energy. Dry palms produce the most heat when rubbed together as they create higher friction compared to wet or lubricated palms, which reduce friction.  
Therefore, the answer is 'dry palms'.<end>\\n\",\n    \"\\n\",\n    \"    Question:{{perturbed_doc}}\\n\",\n    \"    Rationale:{{perturbed_response | replace('\\\\n', '')}}<end>\\n\",\n    \"\\n\",\n    \"    Please explain:\\n\",\n    \"    Question:{{question}} \\n\",\n    \"    Choices:{{choices.text}}\\n\",\n    \"    Rationale:\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"    remote_inference_kwargs={\\n\",\n    \"        'stop': [['<\\\\s>']],\\n\",\n    \"        'temperature': 0.01,\\n\",\n    \"        'max_tokens': 256\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    local_inference_kwargs={\\n\",\n    \"        'stop': ['<|im_end|>', '<end>', '<end>\\\\n', '<end>\\\\n\\\\n', '.\\\\n\\\\n\\\\n\\\\n\\\\n', '<|end_of_text|>', '>\\\\n\\\\n\\\\n'],\\n\",\n    \"        'temperature': 0.01,\\n\",\n    \"        'max_tokens': 256\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    inferdpt_client_conf = {\\n\",\n    \"        'inferdpt_init_conf': inferdpt_init_conf_client,\\n\",\n    \"        'dataset_conf': dataset_conf,\\n\",\n    \"        'doc_template': doc_template,\\n\",\n    \"        'instruction_template': instruction_template,\\n\",\n    \"        'decode_template': decode_template,\\n\",\n    \"        'dataset_conf': dataset_conf,\\n\",\n    \"        'remote_inference_kwargs': remote_inference_kwargs,\\n\",\n    \"        'local_inference_kwargs': local_inference_kwargs\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    inferdpt_init_conf_server = {\\n\",\n    \"        'module_name': 'fate_llm.algo.inferdpt.init.default_init',\\n\",\n    \"        'item_name': 'InferDPTAPIServerInit'\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    inferdpt_server_conf = {\\n\",\n    \"        'inferdpt_init_conf': inferdpt_init_conf_server\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    homo_nn_0 = HomoNN(\\n\",\n    \"        'nn_0',\\n\",\n    \"        runner_module='inferdpt_runner',\\n\",\n    \"        runner_class='InferDPTRunner',\\n\",\n 
   \"        train_data=reader_0.outputs[\\\"output_data\\\"]\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    homo_nn_0.guest.task_parameters(runner_conf=inferdpt_client_conf)\\n\",\n    \"    homo_nn_0.arbiter.task_parameters(runner_conf=inferdpt_server_conf)\\n\",\n    \"    pipeline.add_tasks([reader_0, homo_nn_0])\\n\",\n    \"    pipeline.compile()\\n\",\n    \"    pipeline.fit()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"if __name__ == \\\"__main__\\\":\\n\",\n    \"    parser = argparse.ArgumentParser(\\\"PIPELINE DEMO\\\")\\n\",\n    \"    parser.add_argument(\\\"--config\\\", type=str, default=\\\"../config.yaml\\\",\\n\",\n    \"                        help=\\\"config file\\\")\\n\",\n    \"    parser.add_argument(\\\"--namespace\\\", type=str, default=\\\"\\\",\\n\",\n    \"                        help=\\\"namespace for data stored in FATE\\\")\\n\",\n    \"    args = parser.parse_args()\\n\",\n    \"    main(config=args.config, namespace=args.namespace)\\n\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.16\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "doc/tutorial/offsite_tuning/Offsite_tuning_tutorial.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"c2345e19-83eb-4196-9606-74658c8fbdc5\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Offsite-tuning Tutorial\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"9f1d728c-09e1-418e-8d80-53dd0ec467b1\",\n   \"metadata\": {},\n   \"source\": [\n    \"In this tutorial, we'll focus on how to leverage Offsite-Tuning framework in FATE-LLM-2.0 to fine-tune your LLM. You'll learn how to:\\n\",\n    \"\\n\",\n    \"1. Define models, including main models(which are at server side and will offer adapters and emulators) and submodel(which are at client side and will load adapters and emulators for local fine-tuning) compatible with Offsite-Tuning framework.\\n\",\n    \"2. Get hands-on experience with the Offsite-Tuning trainer.\\n\",\n    \"3. Define configurations for advanced setup(Using Deepspeed, offsite-tuning + federation) through FATE-pipeline.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"31432345-5cce-4efa-9a9b-844f997f14ad\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Introduction of Offsite-tuning\\n\",\n    \"\\n\",\n    \"Offsite-Tuning is a novel approach designed for the efficient and privacy-preserving adaptation of large foundational models for specific downstream tasks. The framework allows data owners to fine-tune models locally without uploading sensitive data to the LLM owner's servers. Specifically, the LLM owner sends a lightweight \\\"Adapter\\\" and a lossy compressed \\\"Emulator\\\" to the data owner. Using these smaller components, the data owner can then fine-tune the model solely on their private data. The Adapter, once fine-tuned, is returned to the model owner and integrated back into the large model to enhance its performance on the specific dataset.\\n\",\n    \"\\n\",\n    \"Offsite-Tuning addresses the challenge of unequal distribution of computational power and data. 
It allows the LLM owner to enhance the model's capabilities without direct access to private data, while also enabling data owners who may not have the resources to train a full-scale model to fine-tune a portion of it using less computational power. This mutually beneficial arrangement accommodates both parties involved.\\n\",\n    \"\\n\",\n    \"Beyond the standard two-party setup involving the model owner and the data owner, in FATE-LLM the Offsite-Tuning framework is also extendable to scenarios with multiple data owners. FATE supports multi-party Offsite-Tuning, allowing multiple data owners to fine-tune and aggregate their Adapters locally, further enhancing the flexibility and applicability of this framework. For more details of Offsite-tuning, please refer to the [original paper](https://arxiv.org/pdf/2302.04870.pdf).\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2e7ac467-e5df-4bf3-8571-0a477ab4612d\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Preliminary\\n\",\n    \"\\n\",\n    \"We strongly recommend you finish reading our NN tutorial to get familiar with Model and Dataset customizations: [NN Tutorials](https://github.com/FederatedAI/FATE/blob/master/doc/2.0/fate/components/pipeline_nn_cutomization_tutorial.md)\\n\",\n    \"\\n\",\n    \"In this tutorial, we assume that you have deployed the code of FATE(including fateflow & fate-client) & FATE-LLM-2.0. 
You can add python path so that you can run codes in the notebook.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"f33516e8-0d28-4c97-bc38-ba28d60acf37\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import sys\\n\",\n    \"your_path_to_fate_python = 'xxx/fate/fate/python'\\n\",\n    \"sys.path.append(your_path_to_fate_python)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2f2fc794\",\n   \"metadata\": {},\n   \"source\": [\n    \"If you install FATE & FATE-LLM-2.0 via pip, you can directly use the following codes.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"7309281b-5956-4158-9256-d6db230e086d\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Define Main Model and Sub Model\\n\",\n    \"\\n\",\n    \"Main models are at server side and will provides weights of adapters and emulators to client sides, while Sub Models are at client side and will load adapters and emulators for local fine-tuning. In this chapter we will take a standard GPT2 as the example and show you how to quickly develop main model class and sub model class for offsite-tuning.\\n\",\n    \"\\n\",\n    \"### Base Classes and Interfaces\\n\",\n    \"\\n\",\n    \"The base classes for the Main and Sub Models are OffsiteTuningMainModel and OffsiteTuningSubModel, respectively. To build your own models upon these base classes, you need to:\\n\",\n    \"\\n\",\n    \"1. Implement three key interfaces: get_base_model, get_model_transformer_blocks, and forward. The get_base_model interface should return the full Main or Sub Model. Meanwhile, the get_model_transformer_blocks function should return a ModuleList of all transformer blocks present in your language model, enabling the extraction of emulators and adapters from these blocks. Finally, you're required to implement the forward process for model inference.\\n\",\n    \"\\n\",\n    \"2. 
Supply the parameters emulator_layer_num, adapter_top_layer_num, and adapter_bottom_layer_num to the parent class. This allows the framework to automatically generate the top and bottom adapters as well as the dropout emulator for you. Specifically, the top adapters are taken from the top of the transformer blocks, while the bottom adapters are taken from the bottom. The emulator uses a dropout emulator consistent with the paper's specifications. Once the adapter layers are removed, the emulator is formed by selecting transformer blocks at fixed intervals and finally stacking them to make a dropout emulator.\\n\",\n    \"\\n\",\n    \"Our framework will automatically detect the emulator and adapters of a main model, and send them to clients. Clients' models then load the weights of emulators and adapters to get trainable models.\\n\",\n    \"\\n\",\n    \"### Example\\n\",\n    \"\\n\",\n    \"Let us take a look at our built-in GPT-2 model. It will be easy for you to build main models and sub models based on the framework. 
Please notice that the GPT2LMHeadSubModel's base model is initialized from a GPTConfig, that is to say, its weights are random and need to load pretrained weights from server.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"8611c115-0321-458f-b190-49dcb127a653\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel\\n\",\n    \"from transformers import GPT2LMHeadModel, GPT2Config\\n\",\n    \"from torch import nn\\n\",\n    \"import torch as t\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class GPT2LMHeadMainModel(OffsiteTuningMainModel):\\n\",\n    \"\\n\",\n    \"    def __init__(\\n\",\n    \"            self,\\n\",\n    \"            model_name_or_path,\\n\",\n    \"            emulator_layer_num: int,\\n\",\n    \"            adapter_top_layer_num: int = 2,\\n\",\n    \"            adapter_bottom_layer_num: int = 2):\\n\",\n    \"\\n\",\n    \"        self.model_name_or_path = model_name_or_path\\n\",\n    \"        super().__init__(\\n\",\n    \"            emulator_layer_num,\\n\",\n    \"            adapter_top_layer_num,\\n\",\n    \"            adapter_bottom_layer_num)\\n\",\n    \"\\n\",\n    \"    def get_base_model(self):\\n\",\n    \"        return GPT2LMHeadModel.from_pretrained(self.model_name_or_path)\\n\",\n    \"\\n\",\n    \"    def get_model_transformer_blocks(self, model: GPT2LMHeadModel):\\n\",\n    \"        return model.transformer.h\\n\",\n    \"\\n\",\n    \"    def forward(self, x):\\n\",\n    \"        return self.model(**x)\\n\",\n    \"\\n\",\n    \"class GPT2LMHeadSubModel(OffsiteTuningSubModel):\\n\",\n    \"\\n\",\n    \"    def __init__(\\n\",\n    \"            self,\\n\",\n    \"            model_name_or_path,\\n\",\n    \"            emulator_layer_num: int,\\n\",\n    \"            adapter_top_layer_num: int = 2,\\n\",\n    \"            
adapter_bottom_layer_num: int = 2,\\n\",\n    \"            fp16_mix_precision=False,\\n\",\n    \"            partial_weight_decay=None):\\n\",\n    \"\\n\",\n    \"        self.model_name_or_path = model_name_or_path\\n\",\n    \"        self.emulator_layer_num = emulator_layer_num\\n\",\n    \"        self.adapter_top_layer_num = adapter_top_layer_num\\n\",\n    \"        self.adapter_bottom_layer_num = adapter_bottom_layer_num\\n\",\n    \"        super().__init__(\\n\",\n    \"            emulator_layer_num,\\n\",\n    \"            adapter_top_layer_num,\\n\",\n    \"            adapter_bottom_layer_num,\\n\",\n    \"            fp16_mix_precision)\\n\",\n    \"        self.partial_weight_decay = partial_weight_decay\\n\",\n    \"\\n\",\n    \"    def get_base_model(self):\\n\",\n    \"        total_layer_num = self.emulator_layer_num + \\\\\\n\",\n    \"            self.adapter_top_layer_num + self.adapter_bottom_layer_num\\n\",\n    \"        config = GPT2Config.from_pretrained(self.model_name_or_path)\\n\",\n    \"        config.num_hidden_layers = total_layer_num\\n\",\n    \"        # initialize a model without pretrained weights\\n\",\n    \"        return GPT2LMHeadModel(config)\\n\",\n    \"\\n\",\n    \"    def get_model_transformer_blocks(self, model: GPT2LMHeadModel):\\n\",\n    \"        return model.transformer.h\\n\",\n    \"        \\n\",\n    \"    def forward(self, x):\\n\",\n    \"        return self.model(**x)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"abd1f63f-afa7-4f09-a67e-63812ddcd801\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can define a server side model and a client side model that can work together in the offsite-tuning:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"04870e76-11cc-4d79-a09e-b6fd16ed2f23\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model_main = GPT2LMHeadMainModel('gpt2', 4, 2, 2)\\n\",\n    \"model_sub = 
GPT2LMHeadSubModel('gpt2', 4, 2, 2)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"19d34937-b4ae-436e-b4ea-1620fb80bed4\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Share additional parameters with clients\\n\",\n    \"\\n\",\n    \"Additionally, beyond the weights of emulators and adapters, you may also want to share other model parameters, such as embedding weights, with your client partners. To achieve this, you'll need to implement two more interfaces: get_additional_param_state_dict and load_additional_param_state_dict for both the Main and Sub Models.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"189fce0e-8e4d-4368-8e14-907b30ce0a49\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def get_additional_param_state_dict(self):\\n\",\n    \"    # get parameter of additional parameter\\n\",\n    \"    model = self.model\\n\",\n    \"    param_dict = {\\n\",\n    \"        'wte': model.transformer.wte,\\n\",\n    \"        'wpe': model.transformer.wpe,\\n\",\n    \"        'last_ln_f': model.transformer.ln_f\\n\",\n    \"    }\\n\",\n    \"\\n\",\n    \"    addition_weights = self.get_numpy_state_dict(param_dict)\\n\",\n    \"\\n\",\n    \"    wte = addition_weights.pop('wte')\\n\",\n    \"    wte_dict = split_numpy_array(wte, 10, 'wte')\\n\",\n    \"    wpe = addition_weights.pop('wpe')\\n\",\n    \"    wpe_dict = split_numpy_array(wpe, 10, 'wpe')\\n\",\n    \"    addition_weights.update(wte_dict)\\n\",\n    \"    addition_weights.update(wpe_dict)\\n\",\n    \"    return addition_weights\\n\",\n    \"\\n\",\n    \"def load_additional_param_state_dict(self, submodel_weights: dict):\\n\",\n    \"    # load additional weights:\\n\",\n    \"    model = self.model\\n\",\n    \"    param_dict = {\\n\",\n    \"        'wte': model.transformer.wte,\\n\",\n    \"        'wpe': model.transformer.wpe,\\n\",\n    \"        'last_ln_f': model.transformer.ln_f\\n\",\n    \"    
}\\n\",\n    \"\\n\",\n    \"    new_submodel_weight = {}\\n\",\n    \"    new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']\\n\",\n    \"    wte_dict, wpe_dict = {}, {}\\n\",\n    \"    for k, v in submodel_weights.items():\\n\",\n    \"        if 'wte' in k:\\n\",\n    \"            wte_dict[k] = v\\n\",\n    \"        if 'wpe' in k:\\n\",\n    \"            wpe_dict[k] = v\\n\",\n    \"    wte = recover_numpy_array(wte_dict, 'wte')\\n\",\n    \"    wpe = recover_numpy_array(wpe_dict, 'wpe')\\n\",\n    \"    new_submodel_weight['wte'] = wte\\n\",\n    \"    new_submodel_weight['wpe'] = wpe\\n\",\n    \"\\n\",\n    \"    self.load_numpy_state_dict(param_dict, new_submodel_weight)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"59d9aa6a-80e9-4130-8af1-c7d2bd0fbba3\",\n   \"metadata\": {},\n   \"source\": [\n    \"From these codes we can see that we use 'split_numpy_array, recover_numpy_array' to cut embedding weights into pieces and recover them.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"dda6f5e3-d05a-4cdf-afd4-affbc162fce4\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Submit a Offsite-tuning Task - A QA Task Sample with GPT2\\n\",\n    \"\\n\",\n    \"Now we are going to show you how to run a 2 party(server & client) offsite-tuning task using the GPT-2 model defined above. Before we submit the task we need to prepare the QA dataset.\\n\",\n    \"\\n\",\n    \"### Prepare QA Dataset - Sciq\\n\",\n    \"\\n\",\n    \"In this example, we use sciq dataset. You can use tools provided in our qa_dataset.py to tokenize the sciq dataset and save the tokenized result. 
**Remember to modify the save_path to your own path.** For the sake of simplicity, in this tutorial, for every party we only use this dataset to train the model.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"84f6947e-f0a3-4a42-9549-a9776a15b66d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.dataset.qa_dataset import tokenize_qa_dataset\\n\",\n    \"from transformers import AutoTokenizer\\n\",\n    \"tokenizer_name_or_path = 'gpt2'\\n\",\n    \"tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)\\n\",\n    \"\\n\",\n    \"if 'llama' in tokenizer_name_or_path:\\n\",\n    \"    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, unk_token=\\\"<unk>\\\",  bos_token=\\\"<s>\\\", eos_token=\\\"</s>\\\", add_eos_token=True)   \\n\",\n    \"    tokenizer.pad_token = tokenizer.eos_token\\n\",\n    \"else:\\n\",\n    \"    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)\\n\",\n    \"if 'gpt2' in tokenizer_name_or_path:\\n\",\n    \"    tokenizer.pad_token = tokenizer.eos_token\\n\",\n    \"\\n\",\n    \"import os\\n\",\n    \"# bind data path to name & namespace\\n\",\n    \"save_path = 'xxxx/sciq'\\n\",\n    \"rs = tokenize_qa_dataset('sciq', tokenizer, save_path, seq_max_len=600)  # we save the cache dataset to the fate root folder\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"adabe89a-37be-4c64-bd83-4f8c8b80096f\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can use our built-in QA dataset to load tokenized dataset, to see if everything is working correctly.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"id\": \"6500c2ba-bc39-4db4-b2ea-947fb09c334e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_llm.dataset.qa_dataset import QaDataset\\n\",\n    \"\\n\",\n    \"ds = QaDataset(tokenizer_name_or_path=tokenizer_name_or_path)\\n\",\n    
\"ds.load(save_path)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"id\": \"d6f62b60-eed0-4bd0-874e-ae3feeebb120\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"11679\\n\",\n      \"600\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(len(ds))  # train set length\\n\",\n    \"print(ds[0]['input_ids'].__len__()) # first sample length\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0609c63d-35a4-43bc-bd4b-f1c61adea587\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Submit a Task\\n\",\n    \"\\n\",\n    \"Now the model and the dataset is prepared! We can submit a training task. In the FATE-2.0, you can define your pipeline in a much easier manner.\\n\",\n    \"\\n\",\n    \"After we submit the task below, the following process will occur: The server and client each initialize their respective models. The server extracts shared parameters and sends them to the client. The client then loads these parameters and conducts training on a miniaturized GPT-2 model composed of an emulator and adapter on SciqP \\n\",\n    \"\\n\",\n    \"If you are not familiar with trainer configuration, please refer to [NN Tutorials](https://github.com/FederatedAI/FATE/blob/master/doc/2.0/fate/components/pipeline_nn_cutomization_tutorial.md).\\n\",\n    \"\\n\",\n    \" Upon completion of the training, the client sends the adapter parameters back to the server. Since we are directly using Hugging Face's LMHeadGPT2, there's no need to supply a loss function. Simply inputting the preprocessed data and labels into the model will calculate the correct loss and proceed with gradient descent\\n\",\n    \"\\n\",\n    \"One thing to pay special attention to is that Offsite-Tuning differs from FedAvg within FATE. In Offsite-Tuning, the server (the arbiter role) needs to initialize the model. 
Therefore, please refer to the example below and set the runner conf separately for the client and the server.\\n\",\n    \"\\n\",\n    \"To make this a quick demo, we only select 100 samples from the original qa dataset, see 'select_num=100' in the LLMDatasetLoader.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"261dfb43\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Bind Dataset Path with Name & Namespace\\n\",\n    \"\\n\",\n    \"Please execute the following code to bind the dataset path with name & namespace. Remember to modify the path to your own dataset save path.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"8dc1e82b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"! flow table bind --namespace experiment --name sciq --path YOUR_SAVE_PATH\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0e8c5ff4\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Pipeline codes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"id\": \"c9113d10-c3e7-4875-9502-ce46aa0b86b1\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<fate_client.pipeline.pipeline.FateFlowPipeline at 0x7fc69aa33a00>\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"import time\\n\",\n    \"from fate_client.pipeline.components.fate.reader import Reader\\n\",\n    \"from fate_client.pipeline import FateFlowPipeline\\n\",\n    \"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_conf_of_ot_runner\\n\",\n    \"from fate_client.pipeline.components.fate.nn.algo_params import Seq2SeqTrainingArguments, FedAVGArguments\\n\",\n    \"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\\n\",\n    \"from 
fate_client.pipeline.components.fate.nn.torch.base import Sequential\\n\",\n    \"from fate_client.pipeline.components.fate.nn.torch import nn\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"guest = '9999'\\n\",\n    \"host = '9999'\\n\",\n    \"arbiter = '9999'\\n\",\n    \"\\n\",\n    \"pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\\n\",\n    \"pipeline.set_site_party_id('9999')\\n\",\n    \"reader_0 = Reader(\\\"reader_0\\\", runtime_parties=dict(guest=guest))\\n\",\n    \"reader_0.guest.task_parameters(\\n\",\n    \"    namespace=\\\"experiment\\\",\\n\",\n    \"    name=\\\"sciq\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"client_model = LLMModelLoader(\\n\",\n    \"    module_name='offsite_tuning.gpt2', item_name='GPT2LMHeadSubModel',\\n\",\n    \"    model_name_or_path='gpt2',\\n\",\n    \"    emulator_layer_num=4,\\n\",\n    \"    adapter_top_layer_num=1,\\n\",\n    \"    adapter_bottom_layer_num=1\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"server_model = LLMModelLoader(\\n\",\n    \"    module_name='offsite_tuning.gpt2', item_name='GPT2LMHeadMainModel',\\n\",\n    \"    model_name_or_path='gpt2',\\n\",\n    \"    emulator_layer_num=4,\\n\",\n    \"    adapter_top_layer_num=1,\\n\",\n    \"    adapter_bottom_layer_num=1  \\n\",\n    \")\\n\",\n    \"\\n\",\n    \"train_args = Seq2SeqTrainingArguments(\\n\",\n    \"    per_device_train_batch_size=1,\\n\",\n    \"    learning_rate=5e-5,\\n\",\n    \"    disable_tqdm=False,\\n\",\n    \"    num_train_epochs=1,\\n\",\n    \"    logging_steps=10,\\n\",\n    \"    logging_strategy='steps',\\n\",\n    \"    use_cpu=False\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"dataset = LLMDatasetLoader(\\n\",\n    \"    module_name='qa_dataset', item_name='QaDataset',\\n\",\n    \"    tokenizer_name_or_path='gpt2',\\n\",\n    \"    select_num=100\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"data_collator = LLMDataFuncLoader(module_name='data_collator.cust_data_collator', item_name='get_seq2seq_data_collator', 
tokenizer_name_or_path='gpt2')\\n\",\n    \"\\n\",\n    \"client_conf = get_conf_of_ot_runner(\\n\",\n    \"    model=client_model,\\n\",\n    \"    dataset=dataset,\\n\",\n    \"    data_collator=data_collator,\\n\",\n    \"    training_args=train_args,\\n\",\n    \"    fed_args=FedAVGArguments(),\\n\",\n    \"    aggregate_model=False\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"server_conf = get_conf_of_ot_runner(\\n\",\n    \"    model=server_model,\\n\",\n    \"    dataset=dataset,\\n\",\n    \"    data_collator=data_collator,\\n\",\n    \"    training_args=train_args,\\n\",\n    \"    fed_args=FedAVGArguments(),\\n\",\n    \"    aggregate_model=False\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"homo_nn_0 = HomoNN(\\n\",\n    \"    'nn_0',\\n\",\n    \"    train_data=reader_0.outputs[\\\"output_data\\\"],\\n\",\n    \"    runner_module=\\\"offsite_tuning_runner\\\",\\n\",\n    \"    runner_class=\\\"OTRunner\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"homo_nn_0.guest.task_parameters(runner_conf=client_conf)\\n\",\n    \"homo_nn_0.arbiter.task_parameters(runner_conf=server_conf)\\n\",\n    \"pipeline.add_tasks([reader_0, homo_nn_0])\\n\",\n    \"pipeline.compile()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e97c2823\",\n   \"metadata\": {},\n   \"source\": [\n    \"You can try to initialize your models, datasets to check if they can be loaded correctly.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"id\": \"872817e5\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"GPT2LMHeadSubModel(\\n\",\n      \"  (model): GPT2LMHeadModel(\\n\",\n      \"    (transformer): GPT2Model(\\n\",\n      \"      (wte): Embedding(50257, 768)\\n\",\n      \"      (wpe): Embedding(1024, 768)\\n\",\n      \"      (drop): Dropout(p=0.1, inplace=False)\\n\",\n      \"      (h): ModuleList(\\n\",\n      \"        (0): GPT2Block(\\n\",\n      \"   
       (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"          (attn): GPT2Attention(\\n\",\n      \"            (c_attn): Conv1D()\\n\",\n      \"            (c_proj): Conv1D()\\n\",\n      \"            (attn_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"            (resid_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"          )\\n\",\n      \"          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"          (mlp): GPT2MLP(\\n\",\n      \"            (c_fc): Conv1D()\\n\",\n      \"            (c_proj): Conv1D()\\n\",\n      \"            (act): NewGELUActivation()\\n\",\n      \"            (dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"          )\\n\",\n      \"        )\\n\",\n      \"        (1): GPT2Block(\\n\",\n      \"          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"          (attn): GPT2Attention(\\n\",\n      \"            (c_attn): Conv1D()\\n\",\n      \"            (c_proj): Conv1D()\\n\",\n      \"            (attn_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"            (resid_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"          )\\n\",\n      \"          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"          (mlp): GPT2MLP(\\n\",\n      \"            (c_fc): Conv1D()\\n\",\n      \"            (c_proj): Conv1D()\\n\",\n      \"            (act): NewGELUActivation()\\n\",\n      \"            (dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"          )\\n\",\n      \"        )\\n\",\n      \"        (2): GPT2Block(\\n\",\n      \"          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"          (attn): GPT2Attention(\\n\",\n      \"            (c_attn): Conv1D()\\n\",\n      \"            (c_proj): Conv1D()\\n\",\n      \"            (attn_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"            (resid_dropout): 
Dropout(p=0.1, inplace=False)\\n\",\n      \"          )\\n\",\n      \"          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"          (mlp): GPT2MLP(\\n\",\n      \"            (c_fc): Conv1D()\\n\",\n      \"            (c_proj): Conv1D()\\n\",\n      \"            (act): NewGELUActivation()\\n\",\n      \"            (dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"          )\\n\",\n      \"        )\\n\",\n      \"        (3): GPT2Block(\\n\",\n      \"          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"          (attn): GPT2Attention(\\n\",\n      \"            (c_attn): Conv1D()\\n\",\n      \"            (c_proj): Conv1D()\\n\",\n      \"            (attn_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"            (resid_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"          )\\n\",\n      \"          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"          (mlp): GPT2MLP(\\n\",\n      \"            (c_fc): Conv1D()\\n\",\n      \"            (c_proj): Conv1D()\\n\",\n      \"            (act): NewGELUActivation()\\n\",\n      \"            (dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"          )\\n\",\n      \"        )\\n\",\n      \"        (4): GPT2Block(\\n\",\n      \"          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"          (attn): GPT2Attention(\\n\",\n      \"            (c_attn): Conv1D()\\n\",\n      \"            (c_proj): Conv1D()\\n\",\n      \"            (attn_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"            (resid_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"          )\\n\",\n      \"          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"          (mlp): GPT2MLP(\\n\",\n      \"            (c_fc): Conv1D()\\n\",\n      \"            (c_proj): Conv1D()\\n\",\n      \"            (act): 
NewGELUActivation()\\n\",\n      \"            (dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"          )\\n\",\n      \"        )\\n\",\n      \"        (5): GPT2Block(\\n\",\n      \"          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"          (attn): GPT2Attention(\\n\",\n      \"            (c_attn): Conv1D()\\n\",\n      \"            (c_proj): Conv1D()\\n\",\n      \"            (attn_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"            (resid_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"          )\\n\",\n      \"          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"          (mlp): GPT2MLP(\\n\",\n      \"            (c_fc): Conv1D()\\n\",\n      \"            (c_proj): Conv1D()\\n\",\n      \"            (act): NewGELUActivation()\\n\",\n      \"            (dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"          )\\n\",\n      \"        )\\n\",\n      \"      )\\n\",\n      \"      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"    )\\n\",\n      \"    (lm_head): Linear(in_features=768, out_features=50257, bias=False)\\n\",\n      \"  )\\n\",\n      \"  (emulator): ModuleList(\\n\",\n      \"    (0): GPT2Block(\\n\",\n      \"      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"      (attn): GPT2Attention(\\n\",\n      \"        (c_attn): Conv1D()\\n\",\n      \"        (c_proj): Conv1D()\\n\",\n      \"        (attn_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"        (resid_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"      )\\n\",\n      \"      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"      (mlp): GPT2MLP(\\n\",\n      \"        (c_fc): Conv1D()\\n\",\n      \"        (c_proj): Conv1D()\\n\",\n      \"        (act): NewGELUActivation()\\n\",\n      \"        (dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"      )\\n\",\n 
     \"    )\\n\",\n      \"    (1): GPT2Block(\\n\",\n      \"      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"      (attn): GPT2Attention(\\n\",\n      \"        (c_attn): Conv1D()\\n\",\n      \"        (c_proj): Conv1D()\\n\",\n      \"        (attn_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"        (resid_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"      )\\n\",\n      \"      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"      (mlp): GPT2MLP(\\n\",\n      \"        (c_fc): Conv1D()\\n\",\n      \"        (c_proj): Conv1D()\\n\",\n      \"        (act): NewGELUActivation()\\n\",\n      \"        (dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"      )\\n\",\n      \"    )\\n\",\n      \"    (2): GPT2Block(\\n\",\n      \"      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"      (attn): GPT2Attention(\\n\",\n      \"        (c_attn): Conv1D()\\n\",\n      \"        (c_proj): Conv1D()\\n\",\n      \"        (attn_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"        (resid_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"      )\\n\",\n      \"      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"      (mlp): GPT2MLP(\\n\",\n      \"        (c_fc): Conv1D()\\n\",\n      \"        (c_proj): Conv1D()\\n\",\n      \"        (act): NewGELUActivation()\\n\",\n      \"        (dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"      )\\n\",\n      \"    )\\n\",\n      \"    (3): GPT2Block(\\n\",\n      \"      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"      (attn): GPT2Attention(\\n\",\n      \"        (c_attn): Conv1D()\\n\",\n      \"        (c_proj): Conv1D()\\n\",\n      \"        (attn_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"        (resid_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"      )\\n\",\n      \"      (ln_2): 
LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"      (mlp): GPT2MLP(\\n\",\n      \"        (c_fc): Conv1D()\\n\",\n      \"        (c_proj): Conv1D()\\n\",\n      \"        (act): NewGELUActivation()\\n\",\n      \"        (dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"      )\\n\",\n      \"    )\\n\",\n      \"  )\\n\",\n      \"  (adapter_bottom): ModuleList(\\n\",\n      \"    (0): GPT2Block(\\n\",\n      \"      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"      (attn): GPT2Attention(\\n\",\n      \"        (c_attn): Conv1D()\\n\",\n      \"        (c_proj): Conv1D()\\n\",\n      \"        (attn_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"        (resid_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"      )\\n\",\n      \"      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"      (mlp): GPT2MLP(\\n\",\n      \"        (c_fc): Conv1D()\\n\",\n      \"        (c_proj): Conv1D()\\n\",\n      \"        (act): NewGELUActivation()\\n\",\n      \"        (dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"      )\\n\",\n      \"    )\\n\",\n      \"  )\\n\",\n      \"  (adapter_top): ModuleList(\\n\",\n      \"    (0): GPT2Block(\\n\",\n      \"      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"      (attn): GPT2Attention(\\n\",\n      \"        (c_attn): Conv1D()\\n\",\n      \"        (c_proj): Conv1D()\\n\",\n      \"        (attn_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"        (resid_dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"      )\\n\",\n      \"      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\\n\",\n      \"      (mlp): GPT2MLP(\\n\",\n      \"        (c_fc): Conv1D()\\n\",\n      \"        (c_proj): Conv1D()\\n\",\n      \"        (act): NewGELUActivation()\\n\",\n      \"        (dropout): Dropout(p=0.1, inplace=False)\\n\",\n      \"      )\\n\",\n      \"    
)\\n\",\n      \"  )\\n\",\n      \")\\n\",\n      \"**********\\n\",\n      \"<fate_llm.dataset.qa_dataset.QaDataset object at 0x7fc724fdfd00>\\n\",\n      \"**********\\n\",\n      \"DataCollatorForSeq2Seq(tokenizer=GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={\\n\",\n      \"\\t50256: AddedToken(\\\"<|endoftext|>\\\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\\n\",\n      \"}, model=None, padding=True, max_length=None, pad_to_multiple_of=None, label_pad_token_id=-100, return_tensors='pt')\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"print(client_model())\\n\",\n    \"print('*' * 10)\\n\",\n    \"print(dataset())\\n\",\n    \"print('*' * 10)\\n\",\n    \"print(data_collator())\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"898c3491\",\n   \"metadata\": {},\n   \"source\": [\n    \"Seems that everything is ready! Now we can submit the task. Submit the code below to submit your task.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"74497742-4030-4a7a-a13e-2c020da47cd1\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"pipeline.fit()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b33b2e2b-3b53-4881-8db6-a67e1293e88b\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Add Deepspeed Setting\\n\",\n    \"\\n\",\n    \"By simply adding a ds_config, we can run our task with a deepspeed backend. 
If you have deployed an eggroll environment, you can submit the task with deepspeed to eggroll to accelerate your training.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"6e8f063b-263c-4ba5-b2ba-98a86ce38b94\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<pipeline.backend.pipeline.PipeLine at 0x7f8002385e50>\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"import time\\n\",\n    \"from fate_client.pipeline.components.fate.reader import Reader\\n\",\n    \"from fate_client.pipeline import FateFlowPipeline\\n\",\n    \"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_conf_of_ot_runner\\n\",\n    \"from fate_client.pipeline.components.fate.nn.algo_params import Seq2SeqTrainingArguments, FedAVGArguments\\n\",\n    \"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\\n\",\n    \"from peft import LoraConfig, TaskType\\n\",\n    \"from transformers.modeling_utils import unwrap_model\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"guest = '10000'\\n\",\n    \"host = '10000'\\n\",\n    \"arbiter = '10000'\\n\",\n    \"\\n\",\n    \"# pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)\\n\",\n    \"pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\\n\",\n    \"\\n\",\n    \"reader_0 = Reader(\\\"reader_0\\\", runtime_parties=dict(guest=guest))\\n\",\n    \"reader_0.guest.task_parameters(\\n\",\n    \"    namespace=\\\"experiment\\\",\\n\",\n    \"    name=\\\"sciq\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"client_model = LLMModelLoader(\\n\",\n    \"    module_name='offsite_tuning.gpt2', item_name='GPT2LMHeadSubModel',\\n\",\n    \"    model_name_or_path='gpt2',\\n\",\n    \"    emulator_layer_num=18,\\n\",\n    \"    
adapter_top_layer_num=2,\\n\",\n    \"    adapter_bottom_layer_num=2\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"server_model = LLMModelLoader(\\n\",\n    \"    module_name='offsite_tuning.gpt2', item_name='GPT2LMHeadMainModel',\\n\",\n    \"    model_name_or_path='gpt2',\\n\",\n    \"    emulator_layer_num=18,\\n\",\n    \"    adapter_top_layer_num=2,\\n\",\n    \"    adapter_bottom_layer_num=2  \\n\",\n    \")\\n\",\n    \"\\n\",\n    \"dataset = LLMDatasetLoader(\\n\",\n    \"    module_name='qa_dataset', item_name='QaDataset',\\n\",\n    \"    tokenizer_name_or_path='gpt2',\\n\",\n    \"    select_num=100\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"data_collator = LLMDataFuncLoader(module_name='data_collator.cust_data_collator', item_name='get_seq2seq_data_collator', tokenizer_name_or_path='gpt2')\\n\",\n    \"\\n\",\n    \"batch_size = 1\\n\",\n    \"lr = 5e-5\\n\",\n    \"ds_config = {\\n\",\n    \"    \\\"train_micro_batch_size_per_gpu\\\": batch_size,\\n\",\n    \"    \\\"optimizer\\\": {\\n\",\n    \"        \\\"type\\\": \\\"Adam\\\",\\n\",\n    \"        \\\"params\\\": {\\n\",\n    \"            \\\"lr\\\": lr,\\n\",\n    \"            \\\"torch_adam\\\": True,\\n\",\n    \"            \\\"adam_w_mode\\\": False\\n\",\n    \"        }\\n\",\n    \"    },\\n\",\n    \"    \\\"fp16\\\": {\\n\",\n    \"        \\\"enabled\\\": True\\n\",\n    \"    },\\n\",\n    \"    \\\"gradient_accumulation_steps\\\": 1,\\n\",\n    \"    \\\"zero_optimization\\\": {\\n\",\n    \"        \\\"stage\\\": 2,\\n\",\n    \"        \\\"allgather_partitions\\\": True,\\n\",\n    \"        \\\"allgather_bucket_size\\\": 1e8,\\n\",\n    \"        \\\"overlap_comm\\\": True,\\n\",\n    \"        \\\"reduce_scatter\\\": True,\\n\",\n    \"        \\\"reduce_bucket_size\\\": 1e8,\\n\",\n    \"        \\\"contiguous_gradients\\\": True,\\n\",\n    \"        \\\"offload_optimizer\\\": {\\n\",\n    \"            \\\"device\\\": \\\"cpu\\\"\\n\",\n    \"        },\\n\",\n    \"        
\\\"offload_param\\\": {\\n\",\n    \"            \\\"device\\\": \\\"cpu\\\"\\n\",\n    \"        }\\n\",\n    \"    }\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"train_args = Seq2SeqTrainingArguments(\\n\",\n    \"    per_device_train_batch_size=1,\\n\",\n    \"    learning_rate=5e-5,\\n\",\n    \"    disable_tqdm=False,\\n\",\n    \"    num_train_epochs=1,\\n\",\n    \"    logging_steps=10,\\n\",\n    \"    logging_strategy='steps',\\n\",\n    \"    dataloader_num_workers=4,\\n\",\n    \"    use_cpu=False,\\n\",\n    \"    deepspeed=ds_config,  # Add deepspeed config here\\n\",\n    \"    remove_unused_columns=False,\\n\",\n    \"    fp16=True\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"client_conf = get_conf_of_ot_runner(\\n\",\n    \"    model=client_model,\\n\",\n    \"    dataset=dataset,\\n\",\n    \"    data_collator=data_collator,\\n\",\n    \"    training_args=train_args,\\n\",\n    \"    fed_args=FedAVGArguments(),\\n\",\n    \"    aggregate_model=False,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"server_conf = get_conf_of_ot_runner(\\n\",\n    \"    model=server_model,\\n\",\n    \"    dataset=dataset,\\n\",\n    \"    data_collator=data_collator,\\n\",\n    \"    training_args=train_args,\\n\",\n    \"    fed_args=FedAVGArguments(),\\n\",\n    \"    aggregate_model=False\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"homo_nn_0 = HomoNN(\\n\",\n    \"    'nn_0',\\n\",\n    \"    train_data=reader_0.outputs[\\\"output_data\\\"],\\n\",\n    \"    runner_module=\\\"offsite_tuning_runner\\\",\\n\",\n    \"    runner_class=\\\"OTRunner\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"homo_nn_0.guest.task_parameters(runner_conf=client_conf)\\n\",\n    \"homo_nn_0.arbiter.task_parameters(runner_conf=server_conf)\\n\",\n    \"\\n\",\n    \"# if you have deployed eggroll, you can add this line to submit your job to eggroll\\n\",\n    \"homo_nn_0.guest.conf.set(\\\"launcher_name\\\", \\\"deepspeed\\\")\\n\",\n    \"\\n\",\n    \"pipeline.add_tasks([reader_0, 
homo_nn_0])\\n\",\n    \"pipeline.conf.set(\\\"task\\\", dict(engine_run={\\\"cores\\\": 4}))\\n\",\n    \"pipeline.compile()\\n\",\n    \"pipeline.fit()\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"97249681-c3a3-43bd-8167-7ae3f4e1616b\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Offsite-tuning + Multi Client Federation\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"The Offsite-Tuning + FedAVG federation is configured based on the standard Offsite-Tuning. In this situation, you need to add data input & configurations for all clients. And do remember to add 'aggregate_model=True' for client & server conf so that model federation will be conducted during the training.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"fdbdc60c-a948-4be3-bba6-519d8640b0a9\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import time\\n\",\n    \"from fate_client.pipeline.components.fate.reader import Reader\\n\",\n    \"from fate_client.pipeline import FateFlowPipeline\\n\",\n    \"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_conf_of_ot_runner\\n\",\n    \"from fate_client.pipeline.components.fate.nn.algo_params import Seq2SeqTrainingArguments, FedAVGArguments\\n\",\n    \"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMCustFuncLoader\\n\",\n    \"from peft import LoraConfig, TaskType\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"guest = '10000'\\n\",\n    \"host = '10000'\\n\",\n    \"arbiter = '10000'\\n\",\n    \"\\n\",\n    \"pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)\\n\",\n    \"\\n\",\n    \"reader_0 = Reader(\\\"reader_0\\\", runtime_parties=dict(guest=guest, host=host))\\n\",\n    \"reader_0.guest.task_parameters(\\n\",\n    \"    namespace=\\\"experiment\\\",\\n\",\n    \"    name=\\\"sciq\\\"\\n\",\n    \")\\n\",\n    \"reader_0.hosts[0].task_parameters(\\n\",\n    \"    
namespace=\\\"experiment\\\",\\n\",\n    \"    name=\\\"sciq\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"client_model = LLMModelLoader(\\n\",\n    \"    module_name='offsite_tuning.gpt2', item_name='GPT2LMHeadSubModel',\\n\",\n    \"    model_name_or_path='gpt2',\\n\",\n    \"    emulator_layer_num=4,\\n\",\n    \"    adapter_top_layer_num=1,\\n\",\n    \"    adapter_bottom_layer_num=1\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"server_model = LLMModelLoader(\\n\",\n    \"    module_name='offsite_tuning.gpt2', item_name='GPT2LMHeadMainModel',\\n\",\n    \"    model_name_or_path='gpt2',\\n\",\n    \"    emulator_layer_num=4,\\n\",\n    \"    adapter_top_layer_num=1,\\n\",\n    \"    adapter_bottom_layer_num=1  \\n\",\n    \")\\n\",\n    \"\\n\",\n    \"dataset = LLMDatasetLoader(\\n\",\n    \"    module_name='qa_dataset', item_name='QaDataset',\\n\",\n    \"    tokenizer_name_or_path='gpt2',\\n\",\n    \"    select_num=100\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"data_collator = LLMCustFuncLoader(module_name='cust_data_collator', item_name='get_seq2seq_tokenizer', model_path='gpt2')\\n\",\n    \"\\n\",\n    \"train_args = Seq2SeqTrainingArguments(\\n\",\n    \"    per_device_train_batch_size=1,\\n\",\n    \"    learning_rate=5e-5,\\n\",\n    \"    disable_tqdm=False,\\n\",\n    \"    num_train_epochs=1,\\n\",\n    \"    logging_steps=10,\\n\",\n    \"    logging_strategy='steps',\\n\",\n    \"    dataloader_num_workers=4\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"client_conf = get_conf_of_ot_runner(\\n\",\n    \"    model=client_model,\\n\",\n    \"    dataset=dataset,\\n\",\n    \"    data_collator=data_collator,\\n\",\n    \"    training_args=train_args,\\n\",\n    \"    fed_args=FedAVGArguments(),\\n\",\n    \"    aggregate_model=True\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"server_conf = get_conf_of_ot_runner(\\n\",\n    \"    model=server_model,\\n\",\n    \"    dataset=dataset,\\n\",\n    \"    data_collator=data_collator,\\n\",\n    \"    
training_args=train_args,\\n\",\n    \"    fed_args=FedAVGArguments(),\\n\",\n    \"    aggregate_model=True\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"homo_nn_0 = HomoNN(\\n\",\n    \"    'nn_0',\\n\",\n    \"    train_data=reader_0.outputs[\\\"output_data\\\"],\\n\",\n    \"    runner_module=\\\"offsite_tuning_runner\\\",\\n\",\n    \"    runner_class=\\\"OTRunner\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"homo_nn_0.guest.task_parameters(runner_conf=client_conf)\\n\",\n    \"homo_nn_0.hosts[0].task_parameters(runner_conf=client_conf)\\n\",\n    \"homo_nn_0.arbiter.task_parameters(runner_conf=server_conf)\\n\",\n    \"\\n\",\n    \"pipeline.add_tasks([reader_0, homo_nn_0])\\n\",\n    \"\\n\",\n    \"pipeline.compile()\\n\",\n    \"pipeline.fit()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.16\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "doc/tutorial/offsite_tuning/README.md",
    "content": "\n# Offsite-Tuning\n\n## Standard Offsite-tuning\n\nOffsite-Tuning is designed for the efficient adaptation of large foundational models for specific downstream tasks. \nThrough Offsite-Tuning, the model owner can enhance the capabilities of large models using data providers without having to disclose the full model weights and directly access the data providers' sensitive information. Specifically, the LLM owner sends a lightweight \"Adapter\" and a lossy compressed \"Emulator\" to the data owner. Using these smaller components, the data owner can then fine-tune the model solely on their private data. The Adapter, once fine-tuned, is returned to the model owner and integrated back into the large model to enhance its performance on the specific dataset.\n\nIn FATE-LLM 1.3, we provide these built-in models:\n\n- GPT2 series models (e.g., GPT2, GPT2-XL, etc.)\n- Bloom series models (such as Bloom7B)\n- Llama-1 series models (e.g., Llama7B)\n\nFATE-LLM v1.3 builds on v1.2 and offers the ability to easily configure multi-machine and multi-card acceleration. It also has specialized optimizations for the network transmission of adapters and emulators.\n\n\n[Read the full paper](https://arxiv.org/abs/2302.04870)\n\n<div align=\"center\">\n  <img src=\"./../../images/ot1.png\" height=\"300\">\n</div>\n\n## Offsite-tuning with Federated Learning\n\nIn addition to supporting standard two-party (model owner and data provider) offsite-tuning, FATE also supports offsite-tuning with multiple data providers simultaneously. Adapters can be fine-tuned locally and then aggregated with those from other data providers. Ultimately, large models can be enhanced through the secure aggregation of adapters from multiple parties. This approach can be used to address issues related to the uneven distribution of computational power and data.\nAs shown in the diagram below:\n\n\n<div align=\"center\">\n  <img src=\"./../../images/ot2.png\" height=\"300\">\n</div>"
  },
  {
    "path": "doc/tutorial/pellm/ChatGLM3-6B_ds.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Federated ChatGLM3 Tuning with Parameter Efficient methods in FATE-LLM\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"In this tutorial, we will demonstrate how to efficiently train federated ChatGLM3-6B with deepspeed using the FATE-LLM framework. In FATE-LLM, we introduce the \\\"pellm\\\" (Parameter Efficient Large Language Model) module, specifically designed for federated learning with large language models. We enable the implementation of parameter-efficient methods in federated learning, reducing communication overhead while maintaining model performance. In this tutorial we particularly focus on ChatGLM3-6B, and we will also emphasize the use of the Adapter mechanism for fine-tuning ChatGLM3-6B, which enables us to effectively reduce communication volume and improve overall efficiency.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## FATE-LLM: ChatGLM3-6B\\n\",\n    \"\\n\",\n    \"### ChatGLM3-6B\\n\",\n    \"ChatGLM3-6B is a large transformer-based language model with 5.977 billion parameters, it is an open bilingual language model based on General Language Model. 
You can download the pretrained model from [here](https://github.com/THUDM/ChatGLM3), or let the program automatically download it when you use it later.\\n\",\n    \"\\n\",\n    \"### Current Features\\n\",\n    \"\\n\",\n    \"In current version, FATE-LLM: ChatGLM-6B supports the following features:\\n\",\n    \"<div align=\\\"center\\\">\\n\",\n    \"  <img src=\\\"../../images/fate-llm-chatglm-6b.png\\\">\\n\",\n    \"</div>\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Experiment Setting\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Before running experiment, please make sure that [FATE-LLM Cluster](https://github.com/FederatedAI/FATE/wiki/Download#llm%E9%83%A8%E7%BD%B2%E5%8C%85) has been deployed. \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Dataset: Advertising Text Generation\\n\",\n    \"\\n\",\n    \"This is an advertising text generation dataset, you can download the dataset from the following links and place it in the examples/data folder. \\n\",\n    \"- [data link 1](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view)\\n\",\n    \"- [data link 2](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1)  \\n\",\n    \"\\n\",\n    \"You can refer to the following link for more details about [data](https://aclanthology.org/D19-1321.pdf)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import pandas as pd\\n\",\n    \"df = pd.read_json('${fate_install}/examples/data/AdvertiseGen/train.json', lines=True)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### ChatGLM3-6B with Adapter\\n\",\n    \"\\n\",\n    \"In this section, we will guide you through the process of finetuning ChatGLM-6B with adapters using the FATE-LLM framework. 
\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"ChatGLM model is located on fate_llm/model_zoo/chatglm.py, can be used directly\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"albert.py  bloom.py    distilbert.py  parameter_efficient_llm.py\\n\",\n      \"bart.py    chatglm.py  gpt2.py\\t      qwen.py\\n\",\n      \"bert.py    deberta.py  llama.py       roberta.py\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"! ls ../../../../fate_llm/python/fate_llm/model_zoo/pellm\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Adapters\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"We can directly use adapters from peft. See this page [Adapter Methods](https://huggingface.co/docs/peft/index) for more details. 
By specifying the adapter name and the adapter\\n\",\n    \"config dict we can insert adapters into our language models:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from peft import LoraConfig, TaskType\\n\",\n    \"\\n\",\n    \"lora_config = LoraConfig(\\n\",\n    \"    task_type=TaskType.CAUSAL_LM,\\n\",\n    \"    inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\\n\",\n    \"    target_modules=['query_key_value'],\\n\",\n    \")\\n\",\n    \"lora_config.target_modules = list(lora_config.target_modules) # this line is needed to ensure lora_config is jsonable\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Init ChatGLM3 Model \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader\\n\",\n    \"\\n\",\n    \"pretrained_model_path = \\\"fill with pretrained model download path please\\\"\\n\",\n    \"\\n\",\n    \"model = LLMModelLoader(\\n\",\n    \"    \\\"pellm.chatglm\\\",\\n\",\n    \"    \\\"ChatGLM\\\",\\n\",\n    \"    pretrained_path=pretrained_model_path,\\n\",\n    \"    peft_type=\\\"LoraConfig\\\",\\n\",\n    \"    peft_config=lora_config.to_dict(),\\n\",\n    \"    trust_remote_code=True\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"**During the training process, all weights of the pretrained language model will be frozen, and weights of adapters are trainable. 
Thus, FATE-LLM only trains in the local training and aggregates adapters' weights in the federation process**\\n\",\n    \"\\n\",\n    \"See [Adapters Overview](https://huggingface.co/docs/peft/index) for details on currently available adapters.\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Specify Dataset And DataCollator To Process Data\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from fate_client.pipeline.components.fate.nn.loader import LLMDatasetLoader, LLMDataFuncLoader\\n\",\n    \"\\n\",\n    \"tokenizer_params = dict(\\n\",\n    \"    tokenizer_name_or_path=pretrained_model_path,\\n\",\n    \"    trust_remote_code=True,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"dataset = LLMDatasetLoader(\\n\",\n    \"    \\\"prompt_dataset\\\",\\n\",\n    \"    \\\"PromptDataset\\\",\\n\",\n    \"    **tokenizer_params,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"data_collator = LLMDataFuncLoader(\\n\",\n    \"    \\\"data_collator.cust_data_collator\\\",\\n\",\n    \"    \\\"get_seq2seq_data_collator\\\",\\n\",\n    \"    **tokenizer_params,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Init DeepSpeed Config\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"ds_config = {\\n\",\n    \"    \\\"train_micro_batch_size_per_gpu\\\": 1,\\n\",\n    \"    \\\"optimizer\\\": {\\n\",\n    \"        \\\"type\\\": \\\"Adam\\\",\\n\",\n    \"        \\\"params\\\": {\\n\",\n    \"            \\\"lr\\\": 5e-4\\n\",\n    \"        }\\n\",\n    \"    },\\n\",\n    \"    \\\"fp16\\\": {\\n\",\n    \"        \\\"enabled\\\": True\\n\",\n    \"    },\\n\",\n    \"    \\\"gradient_accumulation_steps\\\": 1,\\n\",\n    \"    \\\"zero_optimization\\\": {\\n\",\n    \"        
\\\"stage\\\": 2,\\n\",\n    \"        \\\"allgather_partitions\\\": True,\\n\",\n    \"        \\\"allgather_bucket_size\\\": 1e8,\\n\",\n    \"        \\\"overlap_comm\\\": True,\\n\",\n    \"        \\\"reduce_scatter\\\": True,\\n\",\n    \"        \\\"reduce_bucket_size\\\": 1e8,\\n\",\n    \"        \\\"contiguous_gradients\\\": True,\\n\",\n    \"        \\\"offload_optimizer\\\": {\\n\",\n    \"            \\\"device\\\": \\\"cpu\\\"\\n\",\n    \"        },\\n\",\n    \"        \\\"offload_param\\\": {\\n\",\n    \"            \\\"device\\\": \\\"cpu\\\"\\n\",\n    \"        }\\n\",\n    \"    }\\n\",\n    \"}\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Submit Federated Task\\n\",\n    \"To run a federated task, please make sure to use fate>=2.1.0 and deploy it with gpu machines. To run this code, make sure the training data path is already bound. The following code should be copied to a script and run in a command line like \\\"python federated_chatglm.py\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"You can use this script to submit the model, but submitting the model will take a long time to train and generate a long log, so we won't do it here.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import time\\n\",\n    \"from fate_client.pipeline.components.fate.reader import Reader\\n\",\n    \"from fate_client.pipeline import FateFlowPipeline\\n\",\n    \"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_seq2seq_runner\\n\",\n    \"from fate_client.pipeline.components.fate.nn.algo_params import Seq2SeqTrainingArguments, FedAVGArguments\\n\",\n    \"from fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\\n\",\n    \"from peft import LoraConfig, TaskType\\n\",\n    
\"\\n\",\n    \"\\n\",\n    \"guest = '10000'\\n\",\n    \"host = '10000'\\n\",\n    \"arbiter = '10000'\\n\",\n    \"\\n\",\n    \"epochs = 1\\n\",\n    \"batch_size = 1\\n\",\n    \"lr = 5e-4\\n\",\n    \"\\n\",\n    \"ds_config = {\\n\",\n    \"    \\\"train_micro_batch_size_per_gpu\\\": batch_size,\\n\",\n    \"    \\\"optimizer\\\": {\\n\",\n    \"        \\\"type\\\": \\\"Adam\\\",\\n\",\n    \"        \\\"params\\\": {\\n\",\n    \"            \\\"lr\\\": lr,\\n\",\n    \"            \\\"torch_adam\\\": True,\\n\",\n    \"            \\\"adam_w_mode\\\": False\\n\",\n    \"        }\\n\",\n    \"    },\\n\",\n    \"    \\\"fp16\\\": {\\n\",\n    \"        \\\"enabled\\\": True\\n\",\n    \"    },\\n\",\n    \"    \\\"gradient_accumulation_steps\\\": 1,\\n\",\n    \"    \\\"zero_optimization\\\": {\\n\",\n    \"        \\\"stage\\\": 2,\\n\",\n    \"        \\\"allgather_partitions\\\": True,\\n\",\n    \"        \\\"allgather_bucket_size\\\": 1e8,\\n\",\n    \"        \\\"overlap_comm\\\": True,\\n\",\n    \"        \\\"reduce_scatter\\\": True,\\n\",\n    \"        \\\"reduce_bucket_size\\\": 1e8,\\n\",\n    \"        \\\"contiguous_gradients\\\": True,\\n\",\n    \"        \\\"offload_optimizer\\\": {\\n\",\n    \"            \\\"device\\\": \\\"cpu\\\"\\n\",\n    \"        },\\n\",\n    \"        \\\"offload_param\\\": {\\n\",\n    \"            \\\"device\\\": \\\"cpu\\\"\\n\",\n    \"        }\\n\",\n    \"    }\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)\\n\",\n    \"# pipeline.bind_local_path(path=\\\"\\\", namespace=\\\"experiment\\\", name=\\\"ad\\\")\\n\",\n    \"time.sleep(5)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"reader_0 = Reader(\\\"reader_0\\\", runtime_parties=dict(guest=guest, host=host))\\n\",\n    \"reader_0.guest.task_parameters(\\n\",\n    \"    namespace=\\\"experiment\\\",\\n\",\n    \"    name=\\\"ad\\\"\\n\",\n    \")\\n\",\n    
\"reader_0.hosts[0].task_parameters(\\n\",\n    \"    namespace=\\\"experiment\\\",\\n\",\n    \"    name=\\\"ad\\\"\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# define lora config\\n\",\n    \"lora_config = LoraConfig(\\n\",\n    \"    task_type=TaskType.CAUSAL_LM,\\n\",\n    \"    inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1,\\n\",\n    \"    target_modules=['query_key_value'],\\n\",\n    \")\\n\",\n    \"lora_config.target_modules = list(lora_config.target_modules)\\n\",\n    \"\\n\",\n    \"pretrained_model_path = \\\"/data/cephfs/llm/models/chatglm3-6b\\\"\\n\",\n    \"\\n\",\n    \"model = LLMModelLoader(\\n\",\n    \"    \\\"pellm.chatglm\\\",\\n\",\n    \"    \\\"ChatGLM\\\",\\n\",\n    \"    pretrained_path=pretrained_model_path,\\n\",\n    \"    peft_type=\\\"LoraConfig\\\",\\n\",\n    \"    peft_config=lora_config.to_dict(),\\n\",\n    \"    trust_remote_code=True\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"tokenizer_params = dict(\\n\",\n    \"    tokenizer_name_or_path=pretrained_model_path,\\n\",\n    \"    trust_remote_code=True,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"dataset = LLMDatasetLoader(\\n\",\n    \"    \\\"prompt_dataset\\\",\\n\",\n    \"    \\\"PromptDataset\\\",\\n\",\n    \"    **tokenizer_params,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"data_collator = LLMDataFuncLoader(\\n\",\n    \"    \\\"data_collator.cust_data_collator\\\",\\n\",\n    \"    \\\"get_seq2seq_data_collator\\\",\\n\",\n    \"    **tokenizer_params,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"conf = get_config_of_seq2seq_runner(\\n\",\n    \"    algo='fedavg',\\n\",\n    \"    model=model,\\n\",\n    \"    dataset=dataset,\\n\",\n    \"    data_collator=data_collator,\\n\",\n    \"    training_args=Seq2SeqTrainingArguments(\\n\",\n    \"        num_train_epochs=epochs,\\n\",\n    \"        per_device_train_batch_size=batch_size,\\n\",\n    \"        remove_unused_columns=False, \\n\",\n    \"        predict_with_generate=False,\\n\",\n    \"  
      deepspeed=ds_config,\\n\",\n    \"        learning_rate=lr,\\n\",\n    \"        use_cpu=False, # this must be set as we will use gpu\\n\",\n    \"        fp16=True,\\n\",\n    \"    ),\\n\",\n    \"    fed_args=FedAVGArguments(),\\n\",\n    \"    task_type='causal_lm',\\n\",\n    \"    save_trainable_weights_only=True # only save trainable weights\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"homo_nn_0 = HomoNN(\\n\",\n    \"    'nn_0',\\n\",\n    \"    runner_conf=conf,\\n\",\n    \"    train_data=reader_0.outputs[\\\"output_data\\\"],\\n\",\n    \"    runner_module=\\\"homo_seq2seq_runner\\\",\\n\",\n    \"    runner_class=\\\"Seq2SeqRunner\\\",\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"homo_nn_0.guest.conf.set(\\\"launcher_name\\\", \\\"deepspeed\\\") # tell scheduler engine to run task with deepspeed\\n\",\n    \"homo_nn_0.hosts[0].conf.set(\\\"launcher_name\\\", \\\"deepspeed\\\") # tell scheduler engine to run task with deepspeed\\n\",\n    \"\\n\",\n    \"pipeline.add_tasks([reader_0, homo_nn_0])\\n\",\n    \"pipeline.conf.set(\\\"task\\\", dict(engine_run={\\\"cores\\\": 1})) # the number of gpus of each party\\n\",\n    \"\\n\",\n    \"pipeline.compile()\\n\",\n    \"pipeline.fit()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Training With P-Tuning V2 Adapter\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"To use another adapter like P-Tuning V2, slight changes are needed!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"model = LLMModelLoader(\\n\",\n    \"    \\\"pellm.chatglm\\\",\\n\",\n    \"    \\\"ChatGLM\\\",\\n\",\n    \"    pretrained_path=pretrained_model_path,\\n\",\n    \"    pre_seq_len=128,\\n\",\n    \"    trust_remote_code=True\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    
\"### Inference\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Models trained with FATE-LLM can be found under the directory `${fate_install}/fateflow/model/$job_id/${role}/${party_id}/$cpn_name/0/output/output_model/model_directory/adapter_model.bin`,\\n\",\n    \"The following code is an example to load trained lora adapter weights:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import json\\n\",\n    \"import sys\\n\",\n    \"import torch\\n\",\n    \"from peft import PeftModel, PeftConfig, LoraConfig, TaskType, get_peft_model\\n\",\n    \"from transformers import AutoModel, AutoTokenizer\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def load_model(pretrained_model_path):\\n\",\n    \"    _tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path, trust_remote_code=True)\\n\",\n    \"    _model = AutoModel.from_pretrained(pretrained_model_path, trust_remote_code=True)\\n\",\n    \"\\n\",\n    \"    _model = _model.half()\\n\",\n    \"    _model = _model.eval()\\n\",\n    \"\\n\",\n    \"    return _model, _tokenizer\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def load_data(data_path):\\n\",\n    \"    with open(data_path, \\\"r\\\") as fin:\\n\",\n    \"        for _l in fin:\\n\",\n    \"            yield json.loads(_l.strip())\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"chatglm_model_path = \\\"\\\"\\n\",\n    \"model, tokenizer = load_model(chatglm_model_path)\\n\",\n    \"\\n\",\n    \"test_data_path = \\\"{fate_install}/examples/data/AdvertiseGen/dev.json\\\"\\n\",\n    \"dataset = load_data(test_data_path)\\n\",\n    \"\\n\",\n    \"peft_path = \\\"${fate_install}/fateflow/model/$job_id/${role}/${party_id}/$cpn_name/0/output/output_model/model_directory/adapter_model.bin}\\\"\\n\",\n    \"\\n\",\n    \"model = PeftModel.from_pretrained(model, peft_path)\\n\",\n    \"model = model.half()\\n\",\n    
\"model.eval()\\n\",\n    \"\\n\",\n    \"for p in model.parameters():\\n\",\n    \"    if p.requires_grad:\\n\",\n    \"        print(p)\\n\",\n    \"\\n\",\n    \"model.cuda(\\\"cuda:0\\\")\\n\",\n    \"\\n\",\n    \"content = list(dataset)[0][\\\"content\\\"]\\n\",\n    \"print(model.chat(tokenizer, content, do_sample=False))\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.15\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "doc/tutorial/pellm/builtin_pellm_models.md",
    "content": "## Builtin PELLM Models\nFATE-LLM provide some builtin pellm models, users can use them simply to efficiently train their language models.\nTo use these models, please read the using tutorial of [ChatGLM-6B Training Guide](./ChatGLM3-6B_ds.ipynb).   \nAfter reading the training tutorial above, it's easy to use other models listing in the following tabular by changing `module_name`, `class_name`, `dataset` list below.\n  \n  \n\n| Model          | ModuleName        | ClassName     | DataSetName     | \n| -------------- | ----------------- | --------------| --------------- |                 \n| Qwen2          | pellm.qwen        | Qwen          | prompt_dataset  |                              \n| Bloom-7B1      | pellm.bloom       | Bloom         | prompt_dataset  |                              \n| OPT-6.7B       | pellm.opt         | OPT           | prompt_dataset  |                              \n| LLaMA-2-7B     | pellm.llama       | LLaMa         | prompt_dataset  |                              \n| LLaMA-7B       | pellm.llama       | LLaMa         | prompt_dataset  |                              \n| ChatGLM3-6B    | pellm.chatglm     | ChatGLM       | prompt_dataset  |                              \n| GPT-2          | pellm.gpt2        | GPT2CLM       | prompt_dataset  |                              \n| GPT-2          | pellm.gpt2        | GPT2          | seq_cls_dataset |                              \n| ALBERT         | pellm.albert      | Albert        | seq_cls_dataset |                              \n| BART           | pellm.bart        | Bart          | seq_cls_dataset |                              \n| BERT           | pellm.bert        | Bert          | seq_cls_dataset |                              \n| DeBERTa        | pellm.deberta     | Deberta       | seq_cls_dataset |                              \n| DistilBERT     | pellm.distilbert  | DistilBert    | seq_cls_dataset |                              \n| RoBERTa        | pellm.roberta  
   | Roberta       | seq_cls_dataset |                              \n"
  },
  {
    "path": "examples/fedmkt/__init__.py",
    "content": ""
  },
  {
    "path": "examples/fedmkt/fedmkt.py",
    "content": "from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_fedmkt_runner\nfrom fate_client.pipeline.components.fate.nn.algo_params import FedMKTTrainingArguments, FedAVGArguments\nfrom fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\nfrom peft import LoraConfig, TaskType\nfrom fate_client.pipeline import FateFlowPipeline\nfrom fate_client.pipeline.components.fate.reader import Reader\nfrom transformers import AutoConfig\nimport argparse\nimport yaml\nfrom typing import Union, Dict\n\ndef main(config=\"./config.yaml\", param: Union[Dict, str] = None):\n    if isinstance(config, str):\n        with open(config, 'r') as f:\n            config = yaml.safe_load(f)\n    \n    if isinstance(param, str):\n        param = yaml.safe_load(param)\n\n    guest = config['parties']['guest'][0]  # replace with actual guest party ID\n    host = config['parties']['host'][0]    # replace with actual host party ID\n    arbiter = config['parties']['arbiter'][0]  # replace with actual arbiter party ID\n    \n    process_data_output_dir = config['paths']['process_data_output_dir']\n    llm_pretrained_path = config['paths']['llm_pretrained_path']\n    slm_pretrained_paths = config['paths']['slm_pretrained_paths']\n    vocab_mapping_directory = config['paths']['vocab_mapping_directory']\n\n    slm_to_llm_vocab_mapping_paths = [\n        vocab_mapping_directory + \"/\" + path for path in config['paths']['slm_to_llm_vocab_mapping_paths']\n    ]\n    llm_to_slm_vocab_mapping_paths = [\n        vocab_mapping_directory + \"/\" + path for path in config['paths']['llm_to_slm_vocab_mapping_paths']\n    ]\n    \n    slm_models = config['models']['slm_models']\n    slm_lora_target_modules = config['lora_config']['slm_lora_target_modules']\n    \n    def get_llm_conf():\n        lora_config = LoraConfig(\n            task_type=TaskType.CAUSAL_LM,\n            inference_mode=False,\n            
r=param['lora_config']['llm']['r'],\n            lora_alpha=param['lora_config']['llm']['lora_alpha'],\n            lora_dropout=param['lora_config']['llm']['lora_dropout'],\n            target_modules=param['lora_config']['llm']['target_modules']\n        )\n        lora_config.target_modules = list(lora_config.target_modules)\n\n        llm_model = LLMModelLoader(\n            \"pellm.llama\",\n            \"LLaMa\",\n            pretrained_path=llm_pretrained_path,\n            peft_type=\"LoraConfig\",\n            peft_config=lora_config.to_dict(),\n            torch_dtype=\"bfloat16\"\n        )\n\n        pub_dataset = LLMDatasetLoader(\n            \"qa_dataset\",\n            \"QaDataset\",\n            tokenizer_name_or_path=llm_pretrained_path,\n            need_preprocess=True,\n            dataset_name=\"arc_challenge\",\n            data_part=\"common\",\n            seq_max_len=512\n        )\n\n        training_args = FedMKTTrainingArguments(\n            global_epochs=param['training']['llm']['global_epochs'],\n            per_device_train_batch_size=param['training']['llm']['per_device_train_batch_size'],\n            gradient_accumulation_steps=param['training']['llm']['gradient_accumulation_steps'],\n            learning_rate=param['training']['llm']['learning_rate'],\n            output_dir=param['training']['llm']['output_dir'],\n            dataloader_num_workers=param['training']['llm']['dataloader_num_workers'],\n            remove_unused_columns=param['training']['llm']['remove_unused_columns'],\n            warmup_ratio=param['training']['llm']['warmup_ratio'],\n            lr_scheduler_type=param['training']['llm']['lr_scheduler_type'],\n            optim=param['training']['llm']['optim'],\n            adam_beta1=param['training']['llm']['adam_beta1'],\n            adam_beta2=param['training']['llm']['adam_beta2'],\n            weight_decay=param['training']['llm']['weight_decay'],\n            
max_grad_norm=param['training']['llm']['max_grad_norm'],\n            use_cpu=param['training']['llm']['use_cpu'],\n            vocab_size=AutoConfig.from_pretrained(llm_pretrained_path).vocab_size,\n        )\n\n        fed_args = FedAVGArguments(\n            aggregate_strategy='epoch',\n            aggregate_freq=1\n        )\n\n        tokenizer = LLMDataFuncLoader(\n            \"tokenizers.cust_tokenizer\",\n            \"get_tokenizer\",\n            tokenizer_name_or_path=llm_pretrained_path\n        )\n\n        slm_tokenizers = [\n            LLMDataFuncLoader(\"tokenizers.cust_tokenizer\", \"get_tokenizer\", tokenizer_name_or_path=path)\n            for path in slm_pretrained_paths\n        ]\n\n        return get_config_of_fedmkt_runner(\n            model=llm_model,\n            training_args=training_args,\n            fed_args=fed_args,\n            pub_dataset=pub_dataset,\n            tokenizer=tokenizer,\n            slm_tokenizers=slm_tokenizers,\n            slm_to_llm_vocab_mapping_paths=slm_to_llm_vocab_mapping_paths,\n            pub_dataset_path=process_data_output_dir,\n            save_trainable_weights_only=True,\n        )\n    \n    def get_slm_conf(slm_idx):\n        slm_pretrained_path = slm_pretrained_paths[slm_idx]\n        lora_config = LoraConfig(\n            task_type=TaskType.CAUSAL_LM,\n            inference_mode=False, \n            r=param['lora_config']['slm'][slm_idx]['r'],\n            lora_alpha=param['lora_config']['slm'][slm_idx]['lora_alpha'],\n            lora_dropout=param['lora_config']['slm'][slm_idx]['lora_dropout'],\n            target_modules=param['lora_config']['slm'][slm_idx]['target_modules']\n        )\n        lora_config.target_modules = list(lora_config.target_modules)\n        llm_to_slm_vocab_mapping = llm_to_slm_vocab_mapping_paths[slm_idx]\n\n        slm_model = LLMModelLoader(\n            slm_models[slm_idx][0],\n            slm_models[slm_idx][1],\n            
pretrained_path=slm_pretrained_path,\n            peft_type=\"LoraConfig\",\n            peft_config=lora_config.to_dict(),\n        )\n        vocab_size = AutoConfig.from_pretrained(slm_pretrained_path).vocab_size\n\n        pub_dataset = LLMDatasetLoader(\n            \"qa_dataset\",\n            \"QaDataset\",\n            tokenizer_name_or_path=slm_pretrained_path,\n            need_preprocess=True,\n            dataset_name=\"arc_challenge\",\n            data_part=\"common\",\n            seq_max_len=512\n        )\n\n        priv_dataset = LLMDatasetLoader(\n            \"qa_dataset\",\n            \"QaDataset\",\n            tokenizer_name_or_path=slm_pretrained_path,\n            need_preprocess=True,\n            dataset_name=\"arc_challenge\",\n            data_part=\"client_0\",\n            seq_max_len=512\n        )\n\n        training_args = FedMKTTrainingArguments(\n            global_epochs=param['training']['slm']['global_epochs'],\n            per_device_train_batch_size=param['training']['slm']['per_device_train_batch_size'],\n            gradient_accumulation_steps=param['training']['slm']['gradient_accumulation_steps'],\n            learning_rate=param['training']['slm']['learning_rate'] if slm_idx != 1 else 3e-4,\n            output_dir=param['training']['slm']['output_dir'],\n            dataloader_num_workers=param['training']['slm']['dataloader_num_workers'],\n            remove_unused_columns=param['training']['slm']['remove_unused_columns'],\n            warmup_ratio=param['training']['slm']['warmup_ratio'],\n            lr_scheduler_type=param['training']['slm']['lr_scheduler_type'],\n            optim=param['training']['slm']['optim'],\n            adam_beta1=param['training']['slm']['adam_beta1'],\n            adam_beta2=param['training']['slm']['adam_beta2'],\n            weight_decay=param['training']['slm']['weight_decay'],\n            max_grad_norm=param['training']['slm']['max_grad_norm'],\n            
use_cpu=param['training']['slm']['use_cpu'],\n            vocab_size=vocab_size,\n        )\n\n        fed_args = FedAVGArguments(\n            aggregate_strategy='epoch',\n            aggregate_freq=1\n        )\n\n        tokenizer = LLMDataFuncLoader(\n            \"tokenizers.cust_tokenizer\",\n            \"get_tokenizer\",\n            tokenizer_name_or_path=slm_pretrained_path\n        )\n\n        llm_tokenizer = LLMDataFuncLoader(\n            \"tokenizers.cust_tokenizer\", \n            \"get_tokenizer\", \n            tokenizer_name_or_path=llm_pretrained_path\n        )\n\n        data_collator = LLMDataFuncLoader(\n            module_name='data_collator.cust_data_collator',\n            item_name='get_seq2seq_data_collator', \n            tokenizer_name_or_path=slm_pretrained_path\n        )\n\n        return get_config_of_fedmkt_runner(\n            model=slm_model,\n            training_args=training_args,\n            fed_args=fed_args,\n            pub_dataset=pub_dataset,\n            priv_dataset=priv_dataset,\n            tokenizer=tokenizer,\n            llm_tokenizer=llm_tokenizer,\n            llm_to_slm_vocab_mapping_path=llm_to_slm_vocab_mapping,\n            pub_dataset_path=process_data_output_dir,\n            save_trainable_weights_only=True,\n            data_collator=data_collator\n        )\n    \n    pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter, host=host)\n    pipeline.bind_local_path(path=process_data_output_dir, namespace=\"experiment\", name=\"arc_challenge\")\n    \n    reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest, host=host))\n    reader_0.guest.task_parameters(\n        namespace=config['data']['guest']['namespace'],\n        name=config['data']['guest']['name']\n    )\n    reader_0.hosts[[0, 1, 2]].task_parameters(\n        namespace=config['data']['host']['namespace'],\n        name=config['data']['host']['name']\n    )\n\n    homo_nn_0 = HomoNN(\n        'nn_0',\n        
train_data=reader_0.outputs[\"output_data\"],\n        runner_module=\"fedmkt_runner\",\n        runner_class=\"FedMKTRunner\",\n    )\n    \n    homo_nn_0.arbiter.task_parameters(\n        runner_conf=get_llm_conf()\n    )\n    \n    homo_nn_0.guest.task_parameters(\n        runner_conf=get_slm_conf(slm_idx=0)\n    )\n    \n    for idx in range(1):\n        homo_nn_0.hosts[idx].task_parameters(\n            runner_conf=get_slm_conf(slm_idx=idx + 1)\n        )\n    \n    homo_nn_0.guest.conf.set(\"launcher_name\", \"deepspeed\")  # tell scheduler engine to run task with deepspeed\n    homo_nn_0.hosts[0].conf.set(\"launcher_name\", \"deepspeed\")  # tell scheduler engine to run task with deepspeed\n    homo_nn_0.arbiter.conf.set(\"launcher_name\", \"deepspeed\")  # tell scheduler engine to run task with deepspeed\n    \n    pipeline.add_tasks([reader_0, homo_nn_0])\n    pipeline.conf.set(\"task\", dict(engine_run={\"cores\": 1}))  # the number of gpus of each party\n    \n    pipeline.compile()\n    pipeline.fit()\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"LLMSUITE PIPELINE JOB\")\n    parser.add_argument(\"-c\", \"--config\", type=str, help=\"config file\", default=\"./config.yaml\")\n    parser.add_argument(\"-p\", \"--param\", type=str, help=\"config file for params\", default=\"./fedmkt_config.yaml\")\n    args = parser.parse_args()\n    main(args.config, args.param)\n"
  },
  {
    "path": "examples/fedmkt/fedmkt_config.yaml",
    "content": "# fedmkt_config.yaml\n\n# Configuration for Lora\nlora_config:\n  llm:\n    r: 8\n    lora_alpha: 16\n    lora_dropout: 0.05\n    target_modules:\n      - q_proj\n      - k_proj\n      - v_proj\n      - o_proj\n  slm:\n    - # Configuration for the first SLM model\n      r: 8\n      lora_alpha: 32\n      lora_dropout: 0.1\n      target_modules:\n        - q_proj\n        - v_proj\n    - # Configuration for the second SLM model\n      r: 8\n      lora_alpha: 32\n      lora_dropout: 0.1\n      target_modules:\n        - c_attn\n\n# Training configuration\ntraining:\n  llm:\n    global_epochs: 5\n    per_device_train_batch_size: 1\n    gradient_accumulation_steps: 4\n    learning_rate: 3e-5\n    output_dir: \"./\"\n    dataloader_num_workers: 4\n    remove_unused_columns: false\n    warmup_ratio: 0.008\n    lr_scheduler_type: \"cosine\"\n    optim: \"adamw_torch\"\n    adam_beta1: 0.9\n    adam_beta2: 0.95\n    weight_decay: 0.1\n    max_grad_norm: 1.0\n    use_cpu: false\n  slm:\n    global_epochs: 5\n    per_device_train_batch_size: 1\n    gradient_accumulation_steps: 4\n    learning_rate: 3e-5  # Adjust learning rate for SLM models\n    output_dir: \"./\"\n    dataloader_num_workers: 4\n    remove_unused_columns: false\n    warmup_ratio: 0.008\n    lr_scheduler_type: \"cosine\"\n    optim: \"adamw_torch\"\n    adam_beta1: 0.9\n    adam_beta2: 0.95\n    weight_decay: 0.1\n    max_grad_norm: 1.0\n    use_cpu: false\n\n# Paths configuration\npaths:\n  process_data_output_dir: \"\"\n  llm_pretrained_path: \"Llama-2-7b-hf\"\n  slm_pretrained_paths:\n    - \"opt-1.3b\"\n    - \"gpt2\"\n  vocab_mapping_directory: \"\"\n  slm_to_llm_vocab_mapping_paths:\n    - \"opt_to_llama.json\"\n    - \"gpt2_to_llama.json\"\n    - \"llama_small_to_llama.json\"\n  llm_to_slm_vocab_mapping_paths:\n    - \"llama_to_opt.json\"\n    - \"llama_to_gpt2.json\"\n    - \"llama_to_llama_small\"\n\n# Models configuration\nmodels:\n  slm_models:\n    - [\"pellm.opt\", \"OPT\"]\n    
- [\"pellm.gpt2\", \"GPT2CLM\"]\n\n# Data configuration\ndata:\n  guest:\n    namespace: \"experiment\"\n    name: \"arc_challenge\"\n  host:\n    namespace: \"experiment\"\n    name: \"arc_challenge\"\n\n# Example: Additional custom configuration\ncustom_config:\n  some_param: \"value\"\n  another_param: 123\n"
  },
  {
    "path": "examples/fedmkt/test_fedmkt_llmsuit.yaml",
    "content": "data:\n  - file: \n    table_name: arc_challenge\n    namespace: experiment\n    role: guest_0\n  - file: \n    table_name: arc_challenge\n    namespace: experiment\n    role: host_0\nbloom_lora_vs_zero_shot:\n  gpt2_fedmkt:\n    pretrained: \"gpt2\"\n    script: \"./fedmkt.py\"\n    conf: \"./fedmkt_config.yaml\""
  },
  {
    "path": "examples/offsite_tuning/__init__.py",
    "content": ""
  },
  {
    "path": "examples/offsite_tuning/offsite_tuning.py",
    "content": "import argparse\nimport yaml\nfrom fate_client.pipeline.components.fate.reader import Reader\nfrom fate_client.pipeline import FateFlowPipeline\nfrom fate_client.pipeline.components.fate.homo_nn import HomoNN, get_conf_of_ot_runner\nfrom fate_client.pipeline.components.fate.nn.algo_params import Seq2SeqTrainingArguments, FedAVGArguments\nfrom fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\nfrom fate_client.pipeline.components.fate.nn.torch.base import Sequential\nfrom fate_client.pipeline.components.fate.nn.torch import nn\n\ndef load_params(file_path):\n    \"\"\"Load and parse the YAML params file.\"\"\"\n    with open(file_path, 'r') as f:\n        params = yaml.safe_load(f)\n    return params\n\ndef setup_pipeline(params):\n    \"\"\"Set up the pipeline using the provided parameters.\"\"\"\n    guest = params['pipeline']['guest']\n    arbiter = params['pipeline']['arbiter']\n    pretrained_model_path = params['paths']['pretrained_model_path']\n    \n    pipeline = FateFlowPipeline().set_parties(guest=guest, arbiter=arbiter)\n    \n    reader = Reader(\"reader_0\", runtime_parties=dict(guest=guest))\n    reader.guest.task_parameters(\n        namespace=params['pipeline']['namespace'],\n        name=params['pipeline']['name']\n    )\n    \n    client_model = LLMModelLoader(\n        module_name=params['models']['client']['module_name'],\n        item_name=params['models']['client']['item_name'],\n        model_name_or_path=pretrained_model_path,\n        emulator_layer_num=params['models']['client']['emulator_layer_num'],\n        adapter_top_layer_num=params['models']['client']['adapter_top_layer_num'],\n        adapter_bottom_layer_num=params['models']['client']['adapter_bottom_layer_num']\n    )\n    \n    server_model = LLMModelLoader(\n        module_name=params['models']['server']['module_name'],\n        item_name=params['models']['server']['item_name'],\n        
model_name_or_path=pretrained_model_path,\n        emulator_layer_num=params['models']['server']['emulator_layer_num'],\n        adapter_top_layer_num=params['models']['server']['adapter_top_layer_num'],\n        adapter_bottom_layer_num=params['models']['server']['adapter_bottom_layer_num']\n    )\n    \n    dataset = LLMDatasetLoader(\n        module_name=params['dataset']['module_name'],\n        item_name=params['dataset']['item_name'],\n        tokenizer_name_or_path=params['dataset']['tokenizer_name_or_path'],\n        select_num=params['dataset']['select_num']\n    )\n    \n    data_collator = LLMDataFuncLoader(\n        module_name=params['data_collator']['module_name'],\n        item_name=params['data_collator']['item_name'],\n        tokenizer_name_or_path=params['data_collator']['tokenizer_name_or_path']\n    )\n    \n    train_args = Seq2SeqTrainingArguments(\n        per_device_train_batch_size=params['training']['batch_size'],\n        learning_rate=params['training']['learning_rate'],\n        disable_tqdm=False,\n        num_train_epochs=params['training']['num_train_epochs'],\n        logging_steps=params['training']['logging_steps'],\n        logging_strategy='steps',\n        dataloader_num_workers=4,\n        use_cpu=False,\n        deepspeed=params['training']['deepspeed'],  # Add DeepSpeed config here\n        remove_unused_columns=False,\n        fp16=True\n    )\n    \n    client_conf = get_conf_of_ot_runner(\n        model=client_model,\n        dataset=dataset,\n        data_collator=data_collator,\n        training_args=train_args,\n        fed_args=FedAVGArguments(),\n        aggregate_model=False,\n    )\n    \n    server_conf = get_conf_of_ot_runner(\n        model=server_model,\n        dataset=dataset,\n        data_collator=data_collator,\n        training_args=train_args,\n        fed_args=FedAVGArguments(),\n        aggregate_model=False\n    )\n    \n    homo_nn = HomoNN(\n        'nn_0',\n        
train_data=reader.outputs[\"output_data\"],\n        runner_module=\"offsite_tuning_runner\",\n        runner_class=\"OTRunner\"\n    )\n    \n    homo_nn.guest.task_parameters(runner_conf=client_conf)\n    homo_nn.arbiter.task_parameters(runner_conf=server_conf)\n    \n    # If using Eggroll, you can add this line to submit your job\n    homo_nn.guest.conf.set(\"launcher_name\", \"deepspeed\")\n    \n    pipeline.add_tasks([reader, homo_nn])\n    pipeline.conf.set(\"task\", dict(engine_run=params['pipeline']['engine_run']))\n    pipeline.compile()\n    pipeline.fit()\n\ndef main(config_file, param_file):\n    params = load_params(param_file)\n    setup_pipeline(params)\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"LLMSUITE Offsite-tuning JOB\")\n    parser.add_argument(\"-c\", \"--config\", type=str,\n                        help=\"Path to config file\", default=\"./config.yaml\")\n    parser.add_argument(\"-p\", \"--param\", type=str,\n                        help=\"Path to parameter file\", default=\"./test_offsite_tuning_llmsuite.yaml\")\n    args = parser.parse_args()\n    main(args.config, args.param)\n"
  },
  {
    "path": "examples/offsite_tuning/offsite_tuning_config.yaml",
    "content": "# params.yaml\n\npaths:\n  pretrained_model_path: 'gpt2'\n\npipeline:\n  guest: '9999'\n  arbiter: '9999'\n  namespace: 'experiment'\n  name: 'sciq'\n  engine_run:\n    cores: 1\n\ntraining:\n  batch_size: 1\n  learning_rate: 5e-5\n  num_train_epochs: 1\n  logging_steps: 10\n  deepspeed:\n    train_micro_batch_size_per_gpu: 1\n    optimizer:\n      type: \"Adam\"\n      params:\n        lr: 5e-5\n        torch_adam: true\n        adam_w_mode: false\n    fp16:\n      enabled: true\n    gradient_accumulation_steps: 1\n    zero_optimization:\n      stage: 2\n      allgather_partitions: true\n      allgather_bucket_size: 1e8\n      overlap_comm: true\n      reduce_scatter: true\n      reduce_bucket_size: 1e8\n      contiguous_gradients: true\n      offload_optimizer:\n        device: \"cpu\"\n      offload_param:\n        device: \"cpu\"\n\nmodels:\n  client:\n    module_name: 'offsite_tuning.gpt2'\n    item_name: 'GPT2LMHeadSubModel'\n    emulator_layer_num: 11\n    adapter_top_layer_num: 2\n    adapter_bottom_layer_num: 2\n\n  server:\n    module_name: 'offsite_tuning.gpt2'\n    item_name: 'GPT2LMHeadMainModel'\n    emulator_layer_num: 11\n    adapter_top_layer_num: 2\n    adapter_bottom_layer_num: 2\n\ndataset:\n  module_name: 'qa_dataset'\n  item_name: 'QaDataset'\n  tokenizer_name_or_path: 'gpt2'\n  select_num: 100\n\ndata_collator:\n  module_name: 'data_collator.cust_data_collator'\n  item_name: 'get_seq2seq_data_collator'\n  tokenizer_name_or_path: 'gpt2'\n"
  },
  {
    "path": "examples/offsite_tuning/test_offsite_tuning_llmsuite.yaml",
    "content": "data:\n  - file: \n    table_name: sciq\n    namespace: experiment\n    role: guest_0\n  - file: \n    table_name: sciq\n    namespace: experiment\n    role: host_0\nbloom_lora_vs_zero_shot:\n  gpt2_ot:\n    pretrained: \"gpt2\"\n    script: \"./offsite_tuning.py\"\n    conf: \"./offsite_tuning_config.yaml\""
  },
  {
    "path": "examples/pellm/__init__.py",
    "content": ""
  },
  {
    "path": "examples/pellm/bloom_lora_config.yaml",
    "content": "data:\n  guest:\n    namespace: experiment\n    name: ad\n  host:\n    namespace: experiment\n    name: ad\nepoch: 1\nbatch_size: 4\nlr: 5e-4\npretrained_model_path: bloom-560m\npeft_config:\n  alpha_pattern: {}\n  auto_mapping: null\n  base_model_name_or_path: null\n  bias: none\n  fan_in_fan_out: false\n  inference_mode: false\n  init_lora_weights: true\n  layers_pattern: null\n  layers_to_transform: null\n  loftq_config: { }\n  lora_alpha: 32\n  lora_dropout: 0.1\n  megatron_config: null\n  megatron_core: megatron.core\n  modules_to_save: null\n  peft_type: LORA\n  r: 8\n  rank_pattern: { }\n  revision: null\n  target_modules:\n    - query_key_value\n  task_type: CAUSAL_LM\n  use_rslora: false\nds_config:\n  fp16:\n    enabled: true\n  gradient_accumulation_steps: 1\n  optimizer:\n    params:\n      adam_w_mode: false\n      lr: 5e-4\n      torch_adam: true\n    type: Adam\n  train_micro_batch_size_per_gpu: 4\n  zero_optimization:\n    allgather_bucket_size: 100000000.0\n    allgather_partitions: true\n    contiguous_gradients: true\n    offload_optimizer:\n      device: cpu\n    offload_param:\n      device: cpu\n    overlap_comm: true\n    reduce_bucket_size: 100000000.0\n    reduce_scatter: true\n    stage: 2\n"
  },
  {
    "path": "examples/pellm/test_bloom_lora.py",
    "content": "import time\nfrom fate_client.pipeline.components.fate.reader import Reader\nfrom fate_client.pipeline import FateFlowPipeline\nfrom fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_seq2seq_runner\nfrom fate_client.pipeline.components.fate.nn.algo_params import Seq2SeqTrainingArguments, FedAVGArguments\nfrom fate_client.pipeline.components.fate.nn.loader import LLMModelLoader, LLMDatasetLoader, LLMDataFuncLoader\nfrom peft import LoraConfig, TaskType\nfrom fate_client.pipeline.utils import test_utils\nimport argparse\nimport yaml\nfrom typing import Union, Dict\n\n\ndef main(config=\"../../config.yaml\", param: Union[Dict, str] = None, namespace=\"\"):\n    if isinstance(config, str):\n        config = test_utils.load_job_config(config)\n    if isinstance(param, str):\n        param = yaml.safe_load(param)\n    parties = config.parties\n    guest = parties.guest[0]\n    host = parties.host[0]\n    arbiter = parties.arbiter[0]\n    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)\n\n    reader_0 = Reader(\"reader_0\", runtime_parties=dict(guest=guest, host=host))\n    reader_0.guest.task_parameters(\n        namespace=param[\"data\"][\"guest\"][\"namespace\"],\n        name=param[\"data\"][\"guest\"][\"name\"]\n    )\n    reader_0.hosts[0].task_parameters(\n        namespace=param[\"data\"][\"host\"][\"namespace\"],\n        name=param[\"data\"][\"host\"][\"name\"]\n    )\n\n    lora_config = LoraConfig(**param[\"peft_config\"])\n    lora_config.target_modules = list(lora_config.target_modules)\n\n    pretrained_model_path = param[\"pretrained_model_path\"]\n    model = LLMModelLoader(\n        \"pellm.bloom\",\n        \"Bloom\",\n        pretrained_path=pretrained_model_path,\n        peft_type=\"LoraConfig\",\n        peft_config=lora_config.to_dict(),\n        trust_remote_code=True\n    )\n\n    tokenizer_params = dict(\n        tokenizer_name_or_path=pretrained_model_path,\n        
trust_remote_code=True,\n    )\n\n    dataset = LLMDatasetLoader(\n        \"prompt_dataset\",\n        \"PromptDataset\",\n        **tokenizer_params,\n    )\n\n    data_collator = LLMDataFuncLoader(\n        \"data_collator.cust_data_collator\",\n        \"get_seq2seq_data_collator\",\n        **tokenizer_params,\n    )\n\n    conf = get_config_of_seq2seq_runner(\n        algo='fedavg',\n        model=model,\n        dataset=dataset,\n        data_collator=data_collator,\n        training_args=Seq2SeqTrainingArguments(\n            num_train_epochs=param[\"epoch\"],\n            per_device_train_batch_size=param[\"batch_size\"],\n            remove_unused_columns=False,\n            predict_with_generate=False,\n            deepspeed=param[\"ds_config\"],\n            learning_rate=param[\"lr\"],\n            use_cpu=False,  # this must be set as we will gpu\n            fp16=True,\n        ),\n        fed_args=FedAVGArguments(),\n        task_type='causal_lm',\n        save_trainable_weights_only=True  # only save trainable weights\n    )\n\n    homo_nn_0 = HomoNN(\n        'nn_0',\n        runner_conf=conf,\n        train_data=reader_0.outputs[\"output_data\"],\n        runner_module=\"homo_seq2seq_runner\",\n        runner_class=\"Seq2SeqRunner\",\n    )\n\n    homo_nn_0.guest.conf.set(\"launcher_name\", \"deepspeed\")  # tell schedule engine to run task with deepspeed\n    homo_nn_0.hosts[0].conf.set(\"launcher_name\", \"deepspeed\")  # tell schedule engine to run task with deepspeed\n\n    pipeline.add_tasks([reader_0, homo_nn_0])\n    pipeline.conf.set(\"task\", dict(engine_run={\"cores\": 1}))  # the number of gpus of each party\n\n    pipeline.compile()\n    pipeline.fit()\n\n    return pretrained_model_path\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"LLMSUITE PIPELINE JOB\")\n    parser.add_argument(\"-c\", \"--config\", type=str,\n                        help=\"config file\", default=\"../../config.yaml\")\n    
parser.add_argument(\"-p\", \"--param\", type=str,\n                        help=\"config file for params\", default=\"./bloom_lora_config.yaml\")\n    args = parser.parse_args()\n    main(args.config, args.param)\n"
  },
  {
    "path": "examples/pellm/test_pellm_llmsuite.yaml",
    "content": "data:\n  - file: examples/data/AdvertiseGen/train.json\n    table_name: ad\n    namespace: experiment\n    role: guest_0\n  - file: examples/data/AdvertiseGen/train.json\n    table_name: ad\n    namespace: experiment\n    role: host_0\nbloom_lora_vs_zero_shot:\n  bloom_lora:\n    pretrained: \"bloom-560m\"\n    script: \"./test_bloom_lora.py\"\n    conf: \"./bloom_lora_config.yaml\"\n    peft_path_format: \"{{fate_base}}/fate_flow/model/{{job_id}}/guest/{{party_id}}/{{model_task_name}}/0/output/output_model/model_directory\"\n    tasks:\n      - \"advertise-gen\"\n  bloom_zero_shot:\n    pretrained: \"bloom-560m\"\n    tasks:\n      - \"advertise-gen\""
  },
  {
    "path": "python/MANIFEST.in",
    "content": "include fate_llm/dataset/data_config/*yaml\ninclude python/fate_llm/evaluate/tasks/*/*yaml"
  },
  {
    "path": "python/fate_llm/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/algo/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/algo/dp/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom .opacus_compatibility.transformers_compate import get_model_class\nfrom .dp_trainer import DPTrainer, DPTrainingArguments\n"
  },
  {
    "path": "python/fate_llm/algo/dp/dp_trainer.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport logging\nimport opacus\nimport os\nimport torch\nfrom dataclasses import dataclass, field\nfrom transformers.training_args_seq2seq import Seq2SeqTrainingArguments\nfrom torch.utils.data import DataLoader\nfrom tqdm import tqdm\nfrom typing import Optional, Callable\nfrom .opacus_compatibility import add_layer_compatibility, add_optimizer_compatibility\nfrom .opacus_compatibility.transformers_compate import prepare_position_ids\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass DPTrainingArguments(Seq2SeqTrainingArguments):\n    target_epsilon: float = field(default=3)\n    target_delta: float = field(default=1e-5)\n    freeze_embedding: bool = field(default=True)\n    device_id: int = field(default=0)\n\n\nclass DPTrainer(object):\n    def __init__(\n        self,\n        model: torch.nn.Module,\n        training_args: DPTrainingArguments,\n        train_set,\n        loss_fn,\n        optimizer: torch.optim.Optimizer = None,\n        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,\n        data_collator: Callable = None,\n        use_tqdm: bool = False,\n    ):\n        self.module = model\n        self.training_args = training_args\n        self.ori_optimizer = optimizer\n        self.lr_scheduler = scheduler\n        self.train_set = train_set\n        self.data_collator = 
data_collator\n        self.loss_fn = loss_fn\n        self.use_tqdm = use_tqdm\n\n        self.data_loader = DataLoader(\n            dataset=self.train_set,\n            shuffle=True,\n            batch_size=self.training_args.per_device_train_batch_size,\n            collate_fn=self.data_collator\n        )\n\n        if not self.training_args.use_cpu:\n            self.module.cuda(self.training_args.device_id)\n\n        if self.training_args.freeze_embedding:\n            self.freeze_model_embedding()\n\n        self.dp_model = None\n        self.dp_optimizer = None\n        self.privacy_engine = None\n        self._init_dp_model()\n\n    def _init_dp_model(self):\n        self.module.train()\n\n        # add compatibility for layer hooks\n        add_layer_compatibility(opacus)\n\n        self.privacy_engine = opacus.PrivacyEngine(accountant=\"rdp\")\n        self.dp_model, self.dp_optimizer, _ = self.privacy_engine.make_private_with_epsilon(\n            module=self.module,\n            optimizer=self.ori_optimizer,\n            data_loader=self.data_loader,\n            target_delta=self.training_args.target_delta,\n            target_epsilon=self.training_args.target_epsilon,\n            max_grad_norm=self.training_args.max_grad_norm,\n            epochs=int(self.training_args.num_train_epochs),\n        )\n\n        add_optimizer_compatibility(self.dp_optimizer)\n\n    def train(self):\n        logger.info(f\"begin dp training, total epochs={self.training_args.num_train_epochs}\")\n        for epoch in range(int(self.training_args.num_train_epochs)):\n            logger.info(f\"dp training on epoch={epoch}\")\n            self._train_an_epoch()\n\n    def _train_an_epoch(self):\n        if self.use_tqdm:\n            data_loader = tqdm(self.data_loader)\n        else:\n            data_loader = self.data_loader\n\n        for batch_idx, batch_data in enumerate(tqdm(data_loader)):\n            input_ids = batch_data[\"input_ids\"]\n            labels = 
batch_data[\"labels\"]\n\n            if \"attention_mask\" not in batch_data:\n                attention_mask = torch.ones(input_ids.shape)\n            else:\n                attention_mask = batch_data[\"attention_mask\"]\n\n            if not self.training_args.use_cpu:\n                input_ids = input_ids.to(self.module.device)\n                labels = labels.to(self.module.device)\n                attention_mask = attention_mask.to(self.module.device)\n\n            inputs = self._prepare_batch_input(input_ids)\n            logits = self.dp_model(**inputs).logits\n\n            loss = self.loss_fn(logits, labels, attention_mask)\n\n            loss = loss.mean()\n            loss.backward()\n\n            if (batch_idx + 1) % self.training_args.gradient_accumulation_steps == 0 or \\\n                    batch_idx + 1 == len(self.data_loader):\n                self.dp_optimizer.step()\n                if self.lr_scheduler is not None:\n                    self.lr_scheduler.step()\n                self.dp_optimizer.zero_grad()\n            else:\n                self.dp_optimizer.step()\n                self.dp_optimizer.zero_grad()\n\n    def _prepare_batch_input(self, input_ids) -> dict:\n        position_ids = prepare_position_ids(self.module, input_ids)\n        if not self.training_args.use_cpu:\n            position_ids = position_ids.to(self.module.device)\n\n        return dict(input_ids=input_ids, position_ids=position_ids)\n\n    def freeze_model_embedding(self):\n        self.module.get_input_embeddings().requires_grad_(False)\n\n    def save_model(\n        self,\n        output_dir=\"./\"\n    ):\n        if hasattr(self.module, \"save_pretrained\"):\n            self.module.save_pretrained(output_dir)\n        else:\n            if not os.path.exists(output_dir):\n                os.makedirs(output_dir)\n            torch.save(self.module.state_dict(), output_dir + '/pytorch_model.bin')\n"
  },
  {
    "path": "python/fate_llm/algo/dp/opacus_compatibility/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom .grad_sample.embedding import compute_embedding_grad_sample\nfrom .optimizers.optimizer import add_noise_wrapper\n\n\ndef add_layer_compatibility(opacus):\n    replace_method = []\n    for k, v in opacus.GradSampleModule.GRAD_SAMPLERS.items():\n        if v.__name__ == \"compute_embedding_grad_sample\":\n            replace_method.append(k)\n\n    for k in replace_method:\n        opacus.GradSampleModule.GRAD_SAMPLERS[k] = compute_embedding_grad_sample\n\n\ndef add_optimizer_compatibility(optimizer):\n    add_noise_wrapper(optimizer)\n"
  },
  {
    "path": "python/fate_llm/algo/dp/opacus_compatibility/grad_sample/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n"
  },
  {
    "path": "python/fate_llm/algo/dp/opacus_compatibility/grad_sample/embedding.py",
    "content": "#\n#  Copyright (c) Meta Platforms, Inc. and affiliates.\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport torch\nimport torch.nn as nn\nfrom typing import Dict\n\n\n# the function is modified from https://github.com/pytorch/opacus/blob/main/opacus/grad_sample/embedding.py#L25,\n# avoid dtype error when backprops's dtype isn't torch.float32\ndef compute_embedding_grad_sample(\n    layer: nn.Embedding, activations: torch.Tensor, backprops: torch.Tensor\n) -> Dict[nn.Parameter, torch.Tensor]:\n    \"\"\"\n    Computes per sample gradients for ``nn.Embedding`` layer.\n\n    Args:\n        layer: Layer\n        activations: Activations\n        backprops: Backpropagations\n    \"\"\"\n    activations = activations[0]\n    ret = {}\n    if layer.weight.requires_grad:\n        saved = torch.backends.cudnn.deterministic\n        torch.backends.cudnn.deterministic = True\n\n        batch_size = activations.shape[0]\n        if batch_size == 0:\n            ret[layer.weight] = torch.zeros_like(layer.weight).unsqueeze(0)\n            return ret\n\n        index = (\n            activations.unsqueeze(-1)\n            .expand(*activations.shape, layer.embedding_dim)\n            .reshape(batch_size, -1, layer.embedding_dim)\n        )\n        grad_sample = torch.zeros(\n            batch_size, *layer.weight.shape, device=layer.weight.device, dtype=backprops.dtype\n        )\n       
 grad_sample.scatter_add_(\n            1, index, backprops.reshape(batch_size, -1, layer.embedding_dim)\n        )\n        torch.backends.cudnn.deterministic = saved\n        ret[layer.weight] = grad_sample\n\n    return ret\n"
  },
  {
    "path": "python/fate_llm/algo/dp/opacus_compatibility/optimizers/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n"
  },
  {
    "path": "python/fate_llm/algo/dp/opacus_compatibility/optimizers/optimizer.py",
    "content": "#\n#  Copyright (c) Meta Platforms, Inc. and affiliates.\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport types\nfrom opacus.optimizers.optimizer import (\n    _check_processed_flag,\n    _generate_noise,\n    _mark_as_processed\n)\n\n\n# modified from https://github.com/pytorch/opacus/blob/main/opacus/optimizers/optimizer.py#L424\n# avoid dtype error when summed_grad's dtype isn't torch.float32\ndef add_noise(self):\n    \"\"\"\n    Adds noise to clipped gradients. Stores clipped and noised result in ``p.grad``\n    \"\"\"\n\n    for p in self.params:\n        _check_processed_flag(p.summed_grad)\n\n        noise = _generate_noise(\n            std=self.noise_multiplier * self.max_grad_norm,\n            reference=p.summed_grad,\n            generator=self.generator,\n            secure_mode=self.secure_mode,\n        )\n        noise = noise.to(p.summed_grad.dtype)\n        p.grad = (p.summed_grad + noise).view_as(p)\n\n        _mark_as_processed(p.summed_grad)\n\n\ndef add_noise_wrapper(optimizer):\n    optimizer.add_noise = types.MethodType(add_noise, optimizer)\n"
  },
  {
    "path": "python/fate_llm/algo/dp/opacus_compatibility/transformers_compate.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport torch\nimport transformers\nfrom fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM\nfrom transformers.modeling_utils import unwrap_model\n\n\ndef get_model_class(model):\n    if isinstance(model, PELLM):\n        model = model._pe_lm\n\n    model = unwrap_model(model)\n\n    return model.__class__\n\n\ndef prepare_position_ids(model, input_ids):\n    if get_model_class(model) == transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel:\n        return _get_position_ids_for_gpt2(input_ids)\n    else:\n        raise ValueError(f\"Can not prepare position_ids for model_type={model.__class__}\")\n\n\ndef _get_position_ids_for_gpt2(input_ids):\n    past_length = 0\n    position_ids = torch.arange(past_length, input_ids.shape[-1] + past_length, dtype=torch.long,\n                                device=input_ids.device)\n    position_ids = position_ids.unsqueeze(0)\n    position_ids = position_ids.repeat(input_ids.shape[0], 1)\n\n    return position_ids\n"
  },
  {
    "path": "python/fate_llm/algo/fdkt/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom .fdkt_data_aug import (\n    FDKTSLM,\n    FDKTLLM,\n    FDKTTrainingArguments\n)\n\n__all__ = [\n    \"FDKTSLM\",\n    \"FDKTLLM\",\n    \"FDKTTrainingArguments\"\n]\n"
  },
  {
    "path": "python/fate_llm/algo/fdkt/cluster/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n"
  },
  {
    "path": "python/fate_llm/algo/fdkt/cluster/cluster.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom typing import List\nfrom .cluster_method import get_cluster_runner\n\n\nclass SentenceCluster(object):\n    def __init__(self, model, cluster_method=\"kmeans\", n_clusters=8, **other_cluster_args):\n        self.model = model\n        self.cluster_method = cluster_method\n        self.n_clusters = n_clusters\n        self.other_cluster_args = other_cluster_args\n\n    def get_embeddings(self, sentences: List[str]):\n        return self.model.encode(sentences)\n\n    def cluster(self, sentences):\n        embeddings = self.get_embeddings(sentences)\n\n        cluster_runner = get_cluster_runner(method=self.cluster_method,\n                                            n_clusters=self.n_clusters,\n                                            **self.other_cluster_args)\n\n        cluster_rets = cluster_runner.fit(embeddings)\n\n        return cluster_rets\n"
  },
  {
    "path": "python/fate_llm/algo/fdkt/cluster/cluster_method.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom sklearn.cluster import KMeans\n\n\nclass KMeansRunner(object):\n    def __init__(self, n_clusters, **other_cluster_args):\n        self.n_clusters = n_clusters\n        self.other_cluster_args = other_cluster_args\n\n    def fit(self, x):\n        model = KMeans(n_clusters=self.n_clusters, **self.other_cluster_args)\n        model.fit(x)\n\n        return model.labels_\n\n\ndef get_cluster_runner(method, n_clusters, **other_cluster_args):\n    if method.lower() == \"kmeans\":\n        return KMeansRunner(n_clusters, **other_cluster_args)\n    else:\n        raise ValueError(f\"cluster method={method} is not implemented\")\n"
  },
  {
    "path": "python/fate_llm/algo/fdkt/fdkt_data_aug.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport os.path\nimport shutil\n\nimport torch\nimport logging\nfrom dataclasses import dataclass, field\nfrom ...trainer.seq2seq_trainer import Seq2SeqTrainingArguments\nfrom typing import Optional, Callable\nfrom fate.arch import Context\nfrom transformers import PreTrainedTokenizer\nfrom .utils.invalid_data_filter import filter_invalid_data\nfrom .utils.text_generate import slm_text_generate, general_text_generate\nfrom .cluster.cluster import SentenceCluster\nfrom fate_llm.inference.inference_base import Inference\n\n\nlogger = logging.getLogger(__name__)\nSLM_SYNTHETIC_DATA = \"slm_synthetic_data\"\nLLM_AUG_DATA = \"llm_aug_data\"\n\n\n@dataclass\nclass FDKTTrainingArguments(Seq2SeqTrainingArguments):\n    \"\"\"\n    slm parameters\n    \"\"\"\n    dp_training: bool = field(default=True)\n    target_epsilon: float = field(default=3)\n    target_delta: float = field(default=1e-5)\n    freeze_embedding: bool = field(default=True)\n    device_id: int = field(default=0)\n    slm_generation_config: dict = field(default=None)\n    slm_generation_batch_size: dict = field(default=None)\n    inference_method: str = field(default=\"native\")\n    inference_inst_init_conf: dict = field(default=None)\n\n    \"\"\"\n    slm generation config\n    \"\"\"\n    seq_num_for_single_category: int = field(default=None)\n\n    \"\"\"\n    dp 
loss params\n    \"\"\"\n    label_smoothing_factor = 0.02\n    loss_reduce = True\n\n    \"\"\"\n    llm parameters\n    \"\"\"\n    sample_num_per_cluster: int = field(default=None)\n    filter_data_batch_size: int = field(default=2)\n    filter_prompt_max_length: int = field(default=2048)\n    filter_generation_config: dict = field(default=None)\n\n    aug_generation_config: dict = field(default=None)\n    aug_prompt_num: int = field(default=None)\n    aug_data_batch_size: int = field(default=2)\n    aug_prompt_max_length: int = field(default=2048)\n\n    def to_dict(self):\n        from dataclasses import fields\n        from enum import Enum\n        d = {field.name: getattr(self, field.name) for field in fields(self) if field.init}\n\n        for k, v in d.items():\n            if isinstance(v, Enum):\n                d[k] = v.value\n            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):\n                d[k] = [x.value for x in v]\n            if k.endswith(\"_token\"):\n                d[k] = f\"<{k.upper()}>\"\n        return d\n\n\nclass FDKTSLM(object):\n    def __init__(\n        self,\n        ctx: Context,\n        model: torch.nn.Module,\n        training_args: FDKTTrainingArguments,\n        train_set,\n        optimizer: torch.optim.Optimizer = None,\n        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,\n        data_collator: Callable = None,\n        tokenizer: Optional[PreTrainedTokenizer] = None,\n    ):\n        super(FDKTSLM, self).__init__()\n        self.ctx = ctx\n        self.training_args = training_args\n        self.train_set = train_set\n        self.model = model\n        self.tokenizer = tokenizer\n        self.optimizer = optimizer\n        self.scheduler = scheduler\n\n        self.data_collator = data_collator\n\n        if not self.training_args.use_cpu:\n            self.model.cuda(self.training_args.device_id)\n\n    def aug_data(self):\n        logging.info(\"Start aug data 
process\")\n        logging.debug(f\"dp_training={self.training_args.dp_training}\")\n        if self.training_args.dp_training:\n            logging.info(\"Start dp training\")\n            self.dp_train()\n            logging.info(\"End dp training\")\n\n        inference_inst = self._create_inference_inst()\n        prefix_prompt_dict = self.train_set.get_generate_prompt(\n            tokenize=True if inference_inst is None else False)\n\n        generated_texts = slm_text_generate(\n            inference_inst,\n            self.model,\n            self.tokenizer,\n            prompt_dict=prefix_prompt_dict,\n            seq_num_for_single_category=self.training_args.seq_num_for_single_category,\n            batch_size=self.training_args.slm_generation_batch_size,\n            use_cpu=self.training_args.use_cpu,\n            generation_config=self.training_args.slm_generation_config\n        )\n\n        self._destroy_inference_inst()\n\n        if not self.training_args.use_cpu:\n            self.model.cpu()\n            torch.cuda.empty_cache()\n\n        generated_texts = filter_invalid_data(generated_texts)\n        self.sync_synthetic_dataset(generated_texts)\n\n        return self.sync_aug_data()\n\n    def dp_train(self):\n        from ..dp import DPTrainer, DPTrainingArguments, get_model_class\n        from .utils.dp_loss import SequenceCrossEntropyLoss\n        dp_training_args = DPTrainingArguments(\n            target_delta=self.training_args.target_delta,\n            target_epsilon=self.training_args.target_epsilon,\n            freeze_embedding=self.training_args.freeze_embedding,\n            device_id=self.training_args.device_id,\n            num_train_epochs=self.training_args.num_train_epochs,\n            per_device_train_batch_size=self.training_args.per_device_train_batch_size,\n            output_dir=\"/\" if self.training_args.output_dir is None else self.training_args.output_dir\n        )\n\n        loss_fn = SequenceCrossEntropyLoss(\n 
           get_model_class(self.model).__name__,\n            label_smoothing=self.training_args.label_smoothing_factor,\n            reduce=self.training_args.loss_reduce\n        )\n\n        dp_trainer = DPTrainer(\n            model=self.model,\n            training_args=dp_training_args,\n            train_set=self.train_set,\n            optimizer=self.optimizer,\n            scheduler=self.scheduler,\n            data_collator=self.data_collator,\n            loss_fn=loss_fn\n        )\n\n        dp_trainer.train()\n\n    def _create_inference_inst(self):\n        if self.training_args.inference_method == \"native\":\n            return None\n        elif self.training_args.inference_method == \"vllm\":\n            from .inference_inst import vllm_init\n\n            self.model.cpu()\n            model_temp_path = self.training_args.output_dir + \"./model_for_inference\"\n            self.tokenizer.save_pretrained(model_temp_path)\n            self.model.save_pretrained(model_temp_path)\n\n            return vllm_init(model_temp_path) if self.training_args.inference_inst_init_conf is None \\\n                else vllm_init(model_temp_path, **self.training_args.inference_inst_init_conf)\n\n        else:\n            raise ValueError(f\"not supported inference_method={self.training_args.inference_method}\")\n\n    def _destroy_inference_inst(self):\n        if self.training_args.inference_method == \"vllm\":\n            shutil.rmtree(self.training_args.output_dir + \"./model_for_inference\")\n        elif not self.training_args.use_cpu:\n            self.model.cpu()\n\n    def sync_synthetic_dataset(self, data):\n        self.ctx.arbiter.put(SLM_SYNTHETIC_DATA, data)\n\n    def sync_aug_data(self):\n        return self.ctx.arbiter.get(LLM_AUG_DATA)\n\n    def save_model(\n        self,\n        output_dir=\"./\"\n    ):\n        if hasattr(self.model, \"save_pretrained\"):\n            self.model.save_pretrained(output_dir)\n        else:\n            if not 
os.path.exists(output_dir):\n                os.makedirs(output_dir)\n            torch.save(self.model.state_dict(), output_dir + '/pytorch_model.bin')\n\n\nclass FDKTLLM(object):\n    def __init__(\n        self,\n        ctx: Context,\n        embedding_model: torch.nn.Module,\n        training_args: FDKTTrainingArguments,\n        dataset,\n        model: Optional[torch.nn.Module] = None,\n        tokenizer: Optional[PreTrainedTokenizer] = None,\n        inference_inst: Optional[Inference] = None,\n    ):\n        super(FDKTLLM, self).__init__()\n        self.ctx = ctx\n        self.inference_inst = inference_inst\n        self.embedding_model = embedding_model\n        self.dataset = dataset\n        self.training_args = training_args\n        self.model = model\n        self.tokenizer = tokenizer\n\n        if self.inference_inst is None and (self.model is None or self.tokenizer is None):\n            raise ValueError(\"Inference_inst and Model are both empty, should provided one\")\n        if self.model is not None and self.training_args.device_id is not None and not self.training_args.use_cpu:\n            self.model.cuda(self.training_args.device_id)\n\n    def sync_synthetic_data(self):\n        return self.ctx.guest.get(SLM_SYNTHETIC_DATA)\n\n    def sync_aug_data(self, aug_data):\n        self.ctx.guest.put(LLM_AUG_DATA, aug_data)\n\n    def aug_data(self):\n        logging.info(\"sync slm synthetic_data\")\n        slm_data = self.sync_synthetic_data()\n\n        logging.info(\"filter slm synthetic data\")\n        filter_data = self.filter_data(slm_data)\n\n        logging.info(\"prepare prompts for aug\")\n        aug_prompts = self.dataset.prepare_augment(\n            filter_data[\"inputs\"],\n            filter_data[\"labels\"],\n            aug_prompt_num=self.training_args.aug_prompt_num\n        )\n\n        logging.info(\"aug_data\")\n        aug_data = self._aug(aug_prompts)\n        aug_data = filter_invalid_data(aug_data)\n        
self.sync_aug_data(aug_data)\n\n    def _aug(self, aug_prompts):\n        aug_responses = general_text_generate(\n            inference_inst=self.inference_inst,\n            model=self.model,\n            tokenizer=self.tokenizer,\n            generation_config=self.training_args.aug_generation_config,\n            prompts=aug_prompts,\n            batch_size=self.training_args.aug_data_batch_size,\n            use_cpu=self.training_args.use_cpu,\n            prompt_max_length=self.training_args.aug_prompt_max_length\n        )\n\n        aug_data = self.dataset.abstract_from_augmented(aug_responses)\n\n        return aug_data\n\n    def filter_data(self, slm_data):\n        clustered_sentences, clustered_labels = self.cluster_data(slm_data)\n        filter_prompts = self.dataset.prepare_query_to_filter_clustered(clustered_sentences, clustered_labels)\n        filter_responses = general_text_generate(\n            inference_inst=self.inference_inst,\n            model=self.model,\n            tokenizer=self.tokenizer,\n            generation_config=self.training_args.filter_generation_config,\n            prompts=filter_prompts,\n            batch_size=self.training_args.filter_data_batch_size,\n            use_cpu=self.training_args.use_cpu,\n            prompt_max_length=self.training_args.filter_prompt_max_length\n        )\n\n        filtered_sentences, filtered_labels = self.dataset.parse_clustered_response(\n            clustered_sentence=clustered_sentences,\n            clustered_labels=clustered_labels,\n            response_list=filter_responses\n        )\n\n        return dict(\n            inputs=filtered_sentences,\n            labels=filtered_labels\n        )\n\n    def cluster_data(self, slm_data):\n        sentences = slm_data[\"inputs\"]\n        labels = slm_data[\"labels\"]\n\n        n_clusters = (len(sentences) + self.training_args.sample_num_per_cluster - 1) // self.training_args.sample_num_per_cluster\n\n        cluster_ret = 
SentenceCluster(model=self.embedding_model, n_clusters=n_clusters).cluster(sentences)\n\n        clustered_sentences = [[] for _ in range(n_clusters)]\n        clustered_labels = [[] for _ in range(n_clusters)]\n\n        for sentence_id, cluster_id in enumerate(cluster_ret):\n            clustered_sentences[cluster_id].append(sentences[sentence_id])\n            clustered_labels[cluster_id].append(labels[sentence_id])\n\n        return clustered_sentences, clustered_labels\n"
  },
  {
    "path": "python/fate_llm/algo/fdkt/inference_inst.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\ndef api_init(api_url: str, model_name: str, api_key: str = 'EMPTY', api_timeout=3600):\n    from fate_llm.inference.api import APICompletionInference\n    return APICompletionInference(\n        api_url=api_url,\n        model_name=model_name,\n        api_key=api_key,\n        api_timeout=api_timeout\n    )\n\n\ndef vllm_init(model_path: str, num_gpu=1, dtype='float16', gpu_memory_utilization=0.9):\n    from fate_llm.inference.vllm import VLLMInference\n    return VLLMInference(\n        model_path=model_path,\n        num_gpu=num_gpu,\n        dtype=dtype,\n        gpu_memory_utilization=gpu_memory_utilization\n    )\n"
  },
  {
    "path": "python/fate_llm/algo/fdkt/utils/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n"
  },
  {
    "path": "python/fate_llm/algo/fdkt/utils/dp_loss.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES\n\n\nNUMERICAL_STABILITY_CONSTANT = 1e-13\n\n\nclass SequenceCrossEntropyLoss(nn.Module):\n    def __init__(self, model_type, label_smoothing=-1, reduce=None):\n        super().__init__()\n        self.model_type = model_type\n        self.label_smoothing = label_smoothing\n        self.reduce = reduce\n\n    def forward(self, logits, targets, mask):\n        return sequence_cross_entropy_with_logits(logits, targets, mask, self.label_smoothing, self.reduce, self.model_type)\n\n\ndef sequence_cross_entropy_with_logits(logits, targets, mask, label_smoothing, reduce, model_type):\n    if model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():\n        logits = logits[:, :-1].contiguous()\n        targets = targets[:, 1:]\n        mask = torch.ones_like(targets).float()\n\n    logits_flat = logits.view(-1, logits.size(-1))\n    log_probs_flat = F.log_softmax(logits_flat, dim=-1)\n    targets_flat = targets.reshape(-1, 1).long()\n\n    if label_smoothing > 0.0:\n        num_classes = logits.size(-1)\n        smoothing_value = label_smoothing / float(num_classes)\n        one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(-1, targets_flat, 1.0 - label_smoothing)\n       
 smoothed_targets = one_hot_targets + smoothing_value\n        negative_log_likelihood_flat = -log_probs_flat * smoothed_targets\n        negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)\n    else:\n        negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)\n\n    negative_log_likelihood = negative_log_likelihood_flat.view(-1, logits.shape[1])\n\n    loss = negative_log_likelihood * mask\n\n    if reduce:\n        loss = loss.sum(1) / (mask.sum(1) + NUMERICAL_STABILITY_CONSTANT)\n\n        if reduce == \"batch\":\n            loss = loss.mean()\n\n    return loss\n"
  },
  {
    "path": "python/fate_llm/algo/fdkt/utils/invalid_data_filter.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nINVALID_CHARACTERS = \"\".join([' ', '-', '.', '_', '~', '/', '\\\\', '*', '|', '#'])\nLEAST_WORDS = 10\n\n\ndef filter_invalid_data(data_dict):\n    sample_num = len(data_dict[\"inputs\"])\n    new_data_dict = dict(\n        inputs=list(),\n        labels=list()\n    )\n    for idx in range(sample_num):\n        text = data_dict[\"inputs\"][idx].strip(INVALID_CHARACTERS)\n        if len(text.split()) < LEAST_WORDS:\n            continue\n\n        new_data_dict[\"inputs\"].append(text)\n        new_data_dict[\"labels\"].append(data_dict[\"labels\"][idx])\n\n    return new_data_dict\n"
  },
  {
    "path": "python/fate_llm/algo/fdkt/utils/text_generate.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom tqdm import tqdm\nfrom typing import Any, Dict, List\n\n\ndef slm_text_generate(\n    inference_inst,\n    model,\n    tokenizer,\n    prompt_dict,\n    seq_num_for_single_category,\n    batch_size,\n    use_cpu,\n    generation_config\n):\n    generated_ret = dict(\n        inputs=list(),\n        labels=list(),\n    )\n    if inference_inst is not None:\n        for label, prompt in prompt_dict.items():\n            generated_sequences = inference_inst.inference([prompt] * seq_num_for_single_category, generation_config)\n            for g in generated_sequences:\n                generated_ret[\"inputs\"].append(g)\n                generated_ret[\"labels\"].append(label)\n    else:\n        model.eval()\n        for label, prompt_ids in prompt_dict.items():\n            prompt_length = len(prompt_ids)\n            batch_num = (seq_num_for_single_category + batch_size - 1) // batch_size\n            for batch_idx in tqdm(range(batch_num)):\n                if batch_idx + 1 == batch_num:\n                    cur_batch_size = seq_num_for_single_category - batch_idx * batch_size\n                else:\n                    cur_batch_size = batch_size\n                input_ids = prompt_ids.repeat(cur_batch_size, 1)\n\n                if not use_cpu:\n                    input_ids = input_ids.to(model.device)\n\n               
 output_sequences = model.generate(\n                    input_ids=input_ids,\n                    **generation_config\n                )\n                output_sequences = output_sequences[:, prompt_length:]\n\n                generated_sequences = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)\n\n                for g in generated_sequences:\n                    generated_ret[\"inputs\"].append(g)\n                    generated_ret[\"labels\"].append(label)\n\n    return generated_ret\n\n\ndef general_text_generate(\n    inference_inst,\n    model,\n    tokenizer,\n    generation_config: Dict[Any, Any],\n    prompts: List[str],\n    batch_size,\n    use_cpu: bool,\n    prompt_max_length\n):\n    if inference_inst is not None:\n        if prompt_max_length is not None:\n            prompts = [prompt[:prompt_max_length] for prompt in prompts]\n        generate_texts = inference_inst.inference(prompts, generation_config)\n    else:\n        model.eval()\n        generate_texts = []\n        batch_num = (len(prompts) + batch_size - 1) // batch_size\n        for batch_idx in range(batch_num):\n            batch_data = prompts[batch_idx * batch_size: (batch_idx + 1) * batch_size]\n\n            inputs = tokenizer(batch_data, return_tensors=\"pt\", padding=\"longest\", truncation=True,\n                               max_length=prompt_max_length)\n            input_ids = inputs[\"input_ids\"]\n            attention_mask = inputs[\"attention_mask\"]\n\n            if not use_cpu:\n                input_ids = input_ids.to(model.device)\n                attention_mask = attention_mask.to(model.device)\n\n            output = model.generate(\n                input_ids=input_ids,\n                attention_mask=attention_mask,\n                **generation_config\n            )\n\n            batch_responses = tokenizer.batch_decode(output[:, input_ids.shape[1]:], skip_special_tokens=True)\n\n            generate_texts.extend(batch_responses)\n\n    
return generate_texts\n"
  },
  {
    "path": "python/fate_llm/algo/fedavg/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/algo/fedavg/fedavg.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\nimport torch\nfrom fate.ml.nn.homo.fedavg import FedAVGServer, FedAVGArguments, FedArguments\nfrom fate.arch import Context\nfrom fate_llm.trainer.seq2seq_trainer import HomoSeq2SeqTrainerClient, Seq2SeqTrainingArguments\nfrom fate.ml.aggregator import AggregatorClientWrapper\nimport logging\nfrom typing import List, Optional, Tuple, Callable, Dict\nfrom fate.arch import Context\nfrom torch.optim import Optimizer\nfrom torch.utils.data import Dataset\nfrom torch.optim.lr_scheduler import _LRScheduler\nfrom transformers.trainer_callback import TrainerCallback\nfrom torch import nn\nfrom torch.utils.data import DataLoader\nfrom transformers import TrainerState, TrainerControl, PreTrainedTokenizer, EvalPrediction\n\n\n\nlogger = logging.getLogger(__name__)\n\n\nSeq2SeqFedAVGServer = FedAVGServer\n\n\nclass Seq2SeqFedAVGClient(HomoSeq2SeqTrainerClient):\n\n    def __init__(\n        self,\n        ctx: Context,\n        model: nn.Module,\n        training_args: Seq2SeqTrainingArguments,\n        fed_args: FedArguments,\n        train_set: Dataset,\n        val_set: Dataset = None,\n        optimizer: torch.optim.Optimizer = None,\n        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,\n        data_collator: Callable = None,\n        tokenizer: Optional[PreTrainedTokenizer] = None,\n        callbacks: 
Optional[List[TrainerCallback]] = [],\n        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,\n        local_mode: bool = False,\n        save_trainable_weights_only: bool = False,\n        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n    ):\n        # in case you forget to set evaluation_strategy\n        if val_set is not None and training_args.evaluation_strategy == \"no\":\n            training_args.evaluation_strategy = \"epoch\"\n\n        HomoSeq2SeqTrainerClient.__init__(\n            self,\n            ctx,\n            model,\n            training_args,\n            fed_args,\n            train_set,\n            val_set,\n            optimizer,\n            data_collator,\n            scheduler,\n            tokenizer,\n            callbacks,\n            compute_metrics,\n            local_mode,\n            save_trainable_weights_only,\n            preprocess_logits_for_metrics\n        )\n\n\n    def init_aggregator(self, ctx: Context, fed_args: FedArguments):\n        aggregate_type = \"weighted_mean\"\n        aggregator_name = \"fedavg\"\n        aggregator = fed_args.aggregator\n        return AggregatorClientWrapper(\n            ctx, aggregate_type, aggregator_name, aggregator, sample_num=len(self.train_dataset), args=self._args\n        )\n\n    def on_federation(\n        self,\n        ctx: Context,\n        aggregator: AggregatorClientWrapper,\n        fed_args: FedArguments,\n        args: Seq2SeqTrainingArguments,\n        model: Optional[nn.Module] = None,\n        optimizer: Optional[Optimizer] = None,\n        scheduler: Optional[_LRScheduler] = None,\n        dataloader: Optional[Tuple[DataLoader]] = None,\n        control: Optional[TrainerControl] = None,\n        state: Optional[TrainerState] = None,\n        **kwargs,\n    ):\n        aggregator.model_aggregation(ctx, model)\n\n"
  },
  {
    "path": "python/fate_llm/algo/fedcollm/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/algo/fedcollm/fedcollm.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport torch\nimport logging\nfrom fate_llm.algo.fedcollm.fedcollm_trainer import FedCoLLMTrainer\nfrom typing import Dict, Optional, List, Callable, Union\nfrom fate.arch import Context\nfrom fate.ml.nn.trainer.trainer_base import FedArguments\nfrom torch.utils.data import Dataset\nfrom transformers.trainer_callback import TrainerCallback\nfrom transformers import PreTrainedTokenizer\nfrom transformers import Seq2SeqTrainer\nfrom transformers.trainer_utils import EvalPrediction\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.modeling_utils import unwrap_model\nfrom fate_llm.algo.fedmkt.utils.generate_logit_utils import generate_pub_data_logits\nfrom fate.ml.aggregator import AggregatorClientWrapper, AggregatorServerWrapper\nfrom fate_llm.algo.fedcollm.fedcollm_training_args import FedCoLLMTrainingArguments\nfrom types import SimpleNamespace\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass FedCoLLMBase(object):\n    @staticmethod\n    def update_model(model, updated_params):\n        for updated_p, p in zip(updated_params, [p for p in model.parameters() if p.requires_grad]):\n            p.data.copy_(t.Tensor(updated_p))\n\n\nclass SLM(FedCoLLMBase):\n    def __init__(\n        self,\n        ctx: Context,\n        model: torch.nn.Module,\n        training_args: FedCoLLMTrainingArguments,\n      
  fed_args: FedArguments = None,\n        train_set=None,\n        val_set: Dataset = None,\n        optimizer: torch.optim.Optimizer = None,\n        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,\n        data_collator: Callable = None,\n        tokenizer: Optional[PreTrainedTokenizer] = None,\n        model_init: Optional[Callable[[], PreTrainedModel]] = None,\n        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,\n        callbacks: Optional[List[TrainerCallback]] = [],\n        save_trainable_weights_only: bool = False,\n        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n    ):\n        super(SLM, self).__init__()\n        self.ctx = ctx\n        self.training_args = training_args\n        self.fed_args = fed_args\n        self.model = model\n        self.tokenizer = tokenizer\n        self.model_init = model_init\n        self.callbacks = callbacks\n        self.compute_metrics = compute_metrics\n        self.save_trainable_weights_only = save_trainable_weights_only\n        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics\n\n        self.data_collator = data_collator\n        self.optimizer = optimizer\n        self.scheduler = scheduler\n        self.train_set = train_set\n\n        self.val_set = val_set\n\n        self.aggregator = self._init_aggregator(ctx, fed_args)\n\n    def train(self):\n        global_epochs = self.training_args.global_epochs\n\n        for i, iter_ctx in self.ctx.on_iterations.ctxs_range(global_epochs):\n            logger.info(f\"begin {i}-th global kd process\")\n            training_args = self._get_slm_training_args()\n\n            trainer = Seq2SeqTrainer(\n                model=self.model,\n                tokenizer=self.tokenizer,\n                data_collator=self.data_collator,\n                train_dataset=self.train_set,\n                args=training_args,\n                model_init=self.model_init 
if not i else None,\n                compute_metrics=self.compute_metrics,\n                callbacks=self.callbacks,\n                optimizers=(self.optimizer, self.scheduler),\n                preprocess_logits_for_metrics=self.preprocess_logits_for_metrics\n            )\n\n            logger.info(f\"begin {i}-th private data training process\")\n            trainer.train()\n\n            self.model = unwrap_model(trainer.model)\n            self.aggregator.model_aggregation(iter_ctx, self.model)\n\n    def _sync_slm_updated_params(self, iter_ctx):\n        updated_params = iter_ctx.arbiter.get(\"slm_updated_params\")\n        self.update_model(self.model, updated_params)\n\n    def _get_slm_training_args(self):\n        return self.training_args.to_slm_seq_training_args()\n\n    def _init_aggregator(self, ctx: Context, fed_args: FedArguments):\n        aggregate_type = \"weighted_mean\"\n        aggregator_name = \"fedavg\"\n        aggregator = fed_args.aggregator\n        return AggregatorClientWrapper(\n            ctx, aggregate_type, aggregator_name, aggregator,\n            sample_num=len(self.train_set), args=self.training_args\n        )\n\n\nclass LLM(FedCoLLMBase):\n    def __init__(\n        self,\n        ctx: Context,\n        llm_model: torch.nn.Module,\n        slm_model: torch.nn.Module,\n        training_args: FedCoLLMTrainingArguments,\n        fed_args: FedArguments = None,\n        train_set=None,\n        val_set: Dataset = None,\n        llm_optimizer: torch.optim.Optimizer = None,\n        slm_optimizer: torch.optim.Optimizer = None,\n        llm_lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,\n        slm_lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,\n        data_collator: Callable = None,\n        tokenizer: Optional[PreTrainedTokenizer] = None,\n        llm_model_init: Optional[Callable[[], PreTrainedModel]] = None,\n        slm_model_init: Optional[Callable[[], PreTrainedModel]] = 
None,\n        llm_compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,\n        slm_compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,\n        llm_callbacks: Optional[List[TrainerCallback]] = [],\n        slm_callbacks: Optional[List[TrainerCallback]] = [],\n        save_trainable_weights_only: bool = False,\n        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n    ):\n        super(LLM, self).__init__()\n        self.ctx = ctx\n        self.llm_model = llm_model\n        self.slm_model = slm_model\n        self.training_args = training_args\n        self.fed_args = fed_args\n        self.train_set = train_set\n        self.val_set = val_set\n        self.llm_optimizer = llm_optimizer\n        self.slm_optimizer = slm_optimizer\n        self.llm_lr_scheduler = llm_lr_scheduler\n        self.slm_lr_scheduler = slm_lr_scheduler\n        self.data_collator = data_collator\n        self.tokenizer = tokenizer\n        self.llm_model_init = llm_model_init\n        self.slm_model_init = slm_model_init\n        self.llm_compute_metrics = llm_compute_metrics\n        self.slm_compute_metrics = slm_compute_metrics\n        self.llm_callbacks = llm_callbacks\n        self.slm_callbacks = slm_callbacks\n        self.save_trainable_weights_only = save_trainable_weights_only\n        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics\n\n        self.aggregator = self._init_aggregator(ctx)\n\n    def _init_aggregator(self, ctx: Context):\n        return AggregatorServerWrapper(ctx)\n\n    def _get_logits(self, model):\n        if self.training_args.device.type == \"cuda\":\n            model.cuda(self.training_args.device.type)\n\n        fn_kwargs = {\"model\": model,\n                     \"training_args\": self.training_args,\n                     \"data_collator\": self.data_collator}\n\n        return self.train_set.map(\n            generate_pub_data_logits,\n         
   batched=True,\n            batch_size=self.training_args.per_device_train_batch_size,\n            num_proc=None,\n            load_from_cache_file=True,\n            fn_kwargs=fn_kwargs\n        )\n\n    def on_epoch_begin(self, iter_ctx):\n        self.aggregator.model_aggregation(iter_ctx)\n        updated_slm_params = iter_ctx()\n        self.update_model(self.slm_model, updated_slm_params)\n\n    def _sync_slm_updated_params(self, iter_ctx):\n        updated_params = [p for p in self.slm_model.parameters() if p.requires_grad]\n        iter_ctx.guest.put(\"slm_updated_params\", updated_params)\n        if any(p.role == 'host' for p in self.ctx.parties):\n            iter_ctx.hosts.put(\"slm_updated_params\", updated_params)\n\n    def _train_slm(self, iter_ctx, llm_pub_logits, epoch_idx):\n        top_k_args = SimpleNamespace(\n            top_k_logits_keep=self.training_args.top_k_logits_keep,\n            top_k_strategy=self.training_args.top_k_strategy\n        )\n\n        self.train_set.set_return_with_idx()\n        trainer = FedCoLLMTrainer(\n            model=self.slm_model,\n            tokenizer=self.tokenizer,\n            data_collator=self.data_collator,\n            train_dataset=self.train_set,\n            args=self.training_args.to_slm_seq_training_args(),\n            model_init=self.slm_model_init if not epoch_idx else None,\n            compute_metrics=self.slm_compute_metrics,\n            callbacks=self.slm_callbacks,\n            optimizers=(self.slm_optimizer, self.slm_lr_scheduler),\n            preprocess_logits_for_metrics=self.preprocess_logits_for_metrics,\n            top_k_args=top_k_args,\n            distill_lambda=self.training_args.distill_lambda,\n            distill_temperature=self.training_args.distill_temperature,\n            max_length=max(len(d[\"input_ids\"]) for d in self.train_set),\n            vocab_size=self.training_args.vocab_size,\n            dtype=next(self.slm_model.parameters()).dtype,\n            
other_logits=llm_pub_logits\n        )\n\n        trainer.train()\n        self.slm_model = unwrap_model(trainer.model)\n        self.train_set.reset_return_with_idx()\n\n        self._sync_slm_updated_params(iter_ctx)\n\n    def _train_llm(self, slm_pub_logits, epoch_idx):\n        top_k_args = SimpleNamespace(\n            top_k_logits_keep=self.training_args.top_k_logits_keep,\n            top_k_strategy=self.training_args.top_k_strategy\n        )\n\n        self.train_set.set_return_with_idx()\n        trainer = FedCoLLMTrainer(\n            model=self.llm_model,\n            tokenizer=self.tokenizer,\n            data_collator=self.data_collator,\n            train_dataset=self.train_set,\n            args=self.training_args.to_llm_seq_training_args(),\n            model_init=self.llm_model_init if not epoch_idx else None,\n            compute_metrics=self.llm_compute_metrics,\n            callbacks=self.llm_callbacks,\n            optimizers=(self.llm_optimizer, self.llm_lr_scheduler),\n            preprocess_logits_for_metrics=self.preprocess_logits_for_metrics,\n            top_k_args=top_k_args,\n            distill_lambda=self.training_args.distill_lambda,\n            distill_temperature=self.training_args.distill_temperature,\n            max_length=max(len(d[\"input_ids\"]) for d in self.train_set),\n            vocab_size=self.training_args.vocab_size,\n            dtype=next(self.slm_model.parameters()).dtype,\n            other_logits=slm_pub_logits\n        )\n\n        trainer.train()\n        self.llm_model = unwrap_model(trainer.model)\n        self.train_set.reset_return_with_idx()\n\n    def train(self):\n        global_epochs = self.training_args.global_epochs\n\n        for i, iter_ctx in self.ctx.on_iterations.ctxs_range(global_epochs):\n            logger.info(f\"begin {i}-th global kd process\")\n\n            self.on_epoch_begin(iter_ctx)\n            logger.info(f\"get pub data logits for llm of global epoch={i}\")\n            
llm_pub_data_logits = self._get_logits(self.llm_model)\n\n            logger.info(f\"train slm of global epoch={i}\")\n            self._train_slm(iter_ctx, llm_pub_data_logits, i)\n\n            logger.info(f\"get pub data logits for trained slm of global epoch={i}\")\n            slm_pub_data_logits = self._get_logits(self.slm_model)\n\n            logger.info(f\"train llm of global epoch={i}\")\n            self._train_llm(slm_pub_data_logits, i)\n"
  },
  {
    "path": "python/fate_llm/algo/fedcollm/fedcollm_trainer.py",
    "content": "#\n# NOTE: The implementations of FedMKTTrainer is modified from FuseAI/FuseLLM\n# Copyright FuseAI\n#\n#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport logging\nimport torch\nfrom torch.nn.functional import kl_div, log_softmax, softmax\nfrom transformers import Seq2SeqTrainer\nfrom fate_llm.algo.fedmkt.utils.generate_logit_utils import LogitsSelection\nfrom fate_llm.algo.fedmkt.utils.vars_define import (\n    PER_STEP_LOGITS,\n    PER_STEP_INDICES,\n)\nfrom types import SimpleNamespace\n\nlogger = logging.getLogger(__name__)\n\n\ndef computing_kd_loss(src_logits, dst_logits, loss_mask):\n    src_logits = src_logits[loss_mask]\n    dst_logits = dst_logits[loss_mask]\n\n    return kl_div(\n        log_softmax(src_logits, dim=-1, dtype=torch.float32),\n        dst_logits,\n        log_target=False,\n        reduction=\"none\").sum(dim=-1)\n\n\ndef recovery_logits(\n    top_k_logits,\n    top_k_indices,\n    batch_size,\n    max_length,\n    vocab_size,\n    dtype,\n    device,\n    pad_id,\n    distill_temperature\n):\n    logits = torch.zeros(batch_size, max_length, vocab_size).to(dtype).to(device)\n    for i in range(batch_size):\n        base_seq_len = len(top_k_logits[i])\n        for j in range(max_length):\n            if j < base_seq_len:\n                base_logits = torch.tensor(top_k_logits[i][j], dtype=dtype)\n                base_prob = softmax(base_logits 
 / distill_temperature, -1)\n                base_indices = torch.tensor(top_k_indices[i][j])\n                base_prob = base_prob.to(device)\n                base_indices = base_indices.to(device)\n                logits[i][j] = logits[i][j].scatter_(-1, base_indices, base_prob)\n            else:  # padding position\n                logits[i][j][pad_id] = 1.0\n\n    return logits\n\n\nclass FedCoLLMTrainer(Seq2SeqTrainer):\n    distill_lambda: float = 1.0\n    distill_temperature: float = 1.0\n    other_logits = None\n    dtype: torch.dtype = torch.bfloat16\n    vocab_size: int = None\n    max_length: int = None\n    top_k_args: SimpleNamespace = None\n\n    def __init__(self, **kwargs):\n        distill_lambda = kwargs.pop(\"distill_lambda\", 1.0)\n        distill_temperature = kwargs.pop(\"distill_temperature\", 1.0)\n        other_logits = kwargs.pop(\"other_logits\")\n        vocab_size = kwargs.pop(\"vocab_size\")\n        max_length = kwargs.pop(\"max_length\")\n        top_k_args = kwargs.pop(\"top_k_args\")\n        super(FedCoLLMTrainer, self).__init__(**kwargs)\n        self.distill_lambda = distill_lambda\n        self.distill_temperature = distill_temperature\n        self.other_logits = other_logits\n        self.pad_id = self.tokenizer.pad_token_id\n        self.vocab_size = vocab_size\n        self.max_length = max_length\n        self.top_k_args = top_k_args\n\n    def compute_loss(self,  model, inputs, return_outputs=False):\n        lm_outputs = model(**inputs['inputs'])\n        lm_loss = lm_outputs.loss\n        logits = lm_outputs.logits\n        other_logits = self.other_logits[inputs[\"indexes\"]]\n\n        batch_size = logits.shape[0]\n\n        top_k_logits, top_k_indices = LogitsSelection.select_logits(logits, self.top_k_args)\n\n        dst_logits = recovery_logits(\n            other_logits[PER_STEP_LOGITS],\n            other_logits[PER_STEP_INDICES],\n            batch_size,\n            self.max_length,\n            
self.vocab_size,\n            self.dtype,\n            logits.device,\n            self.pad_id,\n            self.distill_temperature\n        )\n\n        src_logits = recovery_logits(\n            top_k_logits,\n            top_k_indices,\n            batch_size,\n            self.max_length,\n            self.vocab_size,\n            self.dtype,\n            logits.device,\n            self.pad_id,\n            self.distill_temperature\n        )\n\n        loss_mask = (inputs[\"inputs\"][\"labels\"] != -100)\n        kl_loss = computing_kd_loss(src_logits, dst_logits, loss_mask=loss_mask).sum()\n\n        return lm_loss + self.distill_lambda * kl_loss\n"
  },
  {
    "path": "python/fate_llm/algo/fedcollm/fedcollm_training_args.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom dataclasses import dataclass, field\nfrom ...trainer.seq2seq_trainer import Seq2SeqTrainingArguments\n\n\n@dataclass\nclass FedCoLLMTrainingArguments(Seq2SeqTrainingArguments):\n    \"\"\"\n    top-k logits select params\n    \"\"\"\n    top_k_logits_keep: int = field(default=128)\n    top_k_strategy: str = field(default=\"highest\")\n    vocab_size: int = field(default=None)\n\n    \"\"\"\n    distillation params\n    \"\"\"\n    distill_lambda: float = field(default=1.0)\n    distill_temperature: float = field(default=1.0)\n    server_public_data_local_epoch: int = field(default=1)\n    client_public_data_local_epoch: int = field(default=1)\n    client_priv_data_local_epoch: int = field(default=1)\n    global_epochs: int = field(default=1)\n\n    extra_args = [\"top_k_logits_keep\", \"top_k_strategy\", \"vocab_size\",\n                  \"distill_lambda\", \"distill_temperature\", \"server_public_data_local_epoch\",\n                  \"client_public_data_local_epoch\", \"client_priv_data_local_epoch\",\n                  \"global_epochs\"]\n\n    def to_dict(self):\n        from dataclasses import fields\n        from enum import Enum\n        d = {field.name: getattr(self, field.name) for field in fields(self) if field.init}\n\n        for k, v in d.items():\n            if isinstance(v, Enum):\n                d[k] = 
 v.value\n            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):\n                d[k] = [x.value for x in v]\n            if k.endswith(\"_token\"):\n                d[k] = f\"<{k.upper()}>\"\n        return d\n\n    def _pop_extra(self):\n        args = self.to_dict()\n        for arg in self.extra_args:\n            args.pop(arg)\n\n        return args\n\n    def to_slm_seq_training_args(self):\n        args = self._pop_extra()\n        args[\"num_train_epochs\"] = self.client_priv_data_local_epoch\n\n        return Seq2SeqTrainingArguments(**args)\n\n    def to_fedco_slm_training_args(self):\n        args = self._pop_extra()\n        args[\"num_train_epochs\"] = self.client_public_data_local_epoch\n\n        return Seq2SeqTrainingArguments(**args)\n\n    def to_fedco_llm_training_args(self):\n        args = self._pop_extra()\n        args[\"num_train_epochs\"] = self.server_public_data_local_epoch\n\n        return Seq2SeqTrainingArguments(**args)\n"
  },
  {
    "path": "python/fate_llm/algo/fedcot/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/algo/fedcot/encoder_decoder/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/algo/fedcot/encoder_decoder/init/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/algo/fedcot/encoder_decoder/init/default_init.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nfrom fate_llm.algo.inferdpt.init._init import InferInit\nfrom fate_llm.inference.api import APICompletionInference\nfrom fate_llm.algo.fedcot.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer\n\n\nclass FedCoTEDAPIClientInit(InferInit):\n\n    api_url = ''\n    api_model_name = ''\n    api_key = 'EMPTY'\n\n    def __init__(self, ctx):\n        super().__init__(ctx)\n        self.ctx = ctx\n\n    def get_inst(self):\n        inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)\n        client = SLMEncoderDecoderClient(self.ctx, inference)\n        return client\n\n\nclass FedCoTEDAPIServerInit(InferInit):\n\n    api_url = ''\n    api_model_name = ''\n    api_key = 'EMPTY'\n\n    def __init__(self, ctx):\n        super().__init__(ctx)\n        self.ctx = ctx\n\n    def get_inst(self):\n        inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)\n        return SLMEncoderDecoderServer(self.ctx, inference)\n"
  },
  {
    "path": "python/fate_llm/algo/fedcot/encoder_decoder/slm_encoder_decoder.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport copy\nfrom jinja2 import Template\nfrom tqdm import tqdm\nfrom fate.arch import Context\nfrom typing import List, Dict, Union\nfrom fate.ml.nn.dataset.base import Dataset\nfrom fate_llm.algo.inferdpt.utils import InferDPTKit\nfrom openai import OpenAI\nimport logging\nfrom fate_llm.inference.inference_base import Inference\nfrom fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\nfrom fate_llm.dataset.hf_dataset import HuggingfaceDataset\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass SLMEncoderDecoderClient(InferDPTClient):\n\n    def __init__(self, ctx: Context, local_inference_inst: Inference) -> None:\n        self.ctx = ctx\n        self.comm_idx = 0\n        self.local_inference_inst = local_inference_inst\n        self.local_inference_kwargs = {}\n\n    def encode(self, docs: List[Dict[str, str]], format_template: str = None, verbose=False, perturb_doc_key: str ='perturbed_doc') -> List[Dict[str, str]]:\n        \n        template = Template(format_template)\n        copy_docs = copy.deepcopy(docs)\n        doc_to_infer = []\n        for doc in tqdm(copy_docs):\n            rendered_doc = template.render(**doc)\n            doc_to_infer.append(rendered_doc)\n        # perturb using local model inference\n        self.doc_to_infer = doc_to_infer\n        infer_result = 
self.local_inference_inst.inference(doc_to_infer, self.local_inference_kwargs)\n        for doc, pr in zip(copy_docs, infer_result):\n            doc[perturb_doc_key] = pr\n        self.doc_with_p = copy_docs\n        return copy_docs\n    \n    def decode(self, p_docs: List[Dict[str, str]], instruction_template: str = None, decode_template: str = None, verbose=False, \n            perturbed_response_key: str = 'perturbed_response', result_key: str = 'result',\n            remote_inference_kwargs: dict = {}, local_inference_kwargs: dict = {}):\n        return super().decode(p_docs, instruction_template, decode_template, verbose, perturbed_response_key, result_key, remote_inference_kwargs, local_inference_kwargs)\n\n    def inference(self, docs: Union[List[Dict[str, str]], HuggingfaceDataset],\n                encode_template: str,\n                instruction_template: str,\n                decode_template: str,\n                verbose: bool = False,\n                remote_inference_kwargs: dict = {},\n                local_inference_kwargs: dict = {},\n                perturb_doc_key: str = 'perturbed_doc',\n                perturbed_response_key: str = 'perturbed_response',\n                result_key: str = 'result',\n                ) -> List[Dict[str, str]]:\n        self.local_inference_kwargs = local_inference_kwargs\n        return super().inference(docs, encode_template, instruction_template, decode_template, verbose, remote_inference_kwargs, \\\n            local_inference_kwargs, perturb_doc_key, perturbed_response_key, result_key)\n\n\nclass SLMEncoderDecoderServer(InferDPTServer):\n    pass\n"
  },
  {
    "path": "python/fate_llm/algo/fedcot/fedcot_trainer.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport os\nimport pickle\nimport time\nfrom torch import nn\nfrom typing import List, Optional, Callable, Literal, Union\nfrom fate.arch import Context\nfrom torch.utils.data import DataLoader, Dataset\nfrom transformers.trainer_callback import TrainerCallback\nfrom transformers import PreTrainedTokenizer\nimport logging\nimport torch\nimport torch.distributed as dist\nfrom fate_llm.dataset.fedcot_dataset import PrefixDataset\nfrom transformers.modeling_utils import unwrap_model\nfrom transformers import PreTrainedTokenizer, PreTrainedModel\nfrom typing import Dict, Any\nfrom transformers import Seq2SeqTrainingArguments \nfrom transformers.trainer_utils import EvalPrediction\nfrom fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainer, Seq2SeqTrainingArguments\nfrom fate_llm.inference.inference_base import Inference\nfrom fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\nfrom fate_llm.algo.fedcot.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer\n\n\nlogger = logging.getLogger(__name__)\n_MODE = ['train_only', 'infer_only', 'infer_and_train']\n\n\n# share obj between ranks in an easy way\ndef save_to(obj, filepath, filename='tmp.pkl'):\n    if not os.path.exists(filepath):\n        os.mkdir(filepath)\n    path = filepath + filename\n    with open(path, 'wb') as 
f:\n        pickle.dump(obj, f)\n    dist.barrier()\n    os.remove(path)\n\n\ndef load(filepath, filename='tmp.pkl'):\n    path = filepath + filename\n    while not os.path.exists(path):\n        time.sleep(0.1)  \n    while True:\n        try:\n            with open(path, 'rb') as f:\n                d = pickle.load(f)\n                break\n        except (EOFError, pickle.UnpicklingError):\n            time.sleep(0.1) \n\n    dist.barrier()\n    return d\n\n\nclass DSSTrainerClient(Seq2SeqTrainer):\n\n    def __init__(self,\n                model: nn.Module,\n                training_args: Seq2SeqTrainingArguments,\n                train_set: Dataset,\n                val_set: Dataset = None,\n                alpha: float = 0.5,\n                optimizer: torch.optim.Optimizer = None,\n                data_collator: Callable = None,\n                scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,\n                tokenizer: Optional[PreTrainedTokenizer] = None,\n                callbacks: Optional[List[TrainerCallback]] = [],\n                compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,\n                preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None\n    ) -> None:\n\n        self.alpha = alpha\n        Seq2SeqTrainer.__init__(\n            self,\n            model=model,\n            args=training_args,\n            train_dataset=train_set,\n            eval_dataset=val_set,\n            data_collator=data_collator,\n            optimizers=(optimizer, scheduler),\n            tokenizer=tokenizer,\n            preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n            compute_metrics=compute_metrics,\n            callbacks=callbacks,\n        )\n\n    def compute_loss(self, model, inputs, return_outputs=False):\n\n        label_outputs = model(**inputs['predict'])\n        cot_outputs = model(**inputs['rationale'])\n        loss = self.alpha * 
cot_outputs.loss + (1. - self.alpha) * label_outputs.loss\n        return (loss, {'rationale_loss': cot_outputs, 'predict_loss': label_outputs}) if return_outputs else loss\n\n\nclass FedCoTTrainerClient(DSSTrainerClient):\n\n    def __init__(self,\n        ctx: Context,\n        training_args: Seq2SeqTrainingArguments,\n        train_set: PrefixDataset,\n        val_set: Dataset = None,\n        model: nn.Module = None,\n        optimizer: torch.optim.Optimizer = None,\n        data_collator: Callable = None,\n        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,\n        tokenizer: Optional[PreTrainedTokenizer] = None,\n        callbacks: Optional[List[TrainerCallback]] = [],\n        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,\n        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n        alpha: float = 0.5,\n        mode: Literal['train_only', 'infer_only', 'infer_and_train'] = 'infer_and_train',\n        infer_client: Union[SLMEncoderDecoderClient, InferDPTClient] = None,\n        encode_template: str = None,\n        instruction_template: str = None,\n        decode_template: str = None,\n        result_key: str = 'infer_result',\n        verbose: bool = False,\n        remote_inference_kwargs: dict = {},\n        local_inference_kwargs: dict = {},\n        tmp_data_share_path: str = None\n    ) -> None:\n        \n        self.mode = mode\n        self.infer_client = infer_client\n        self.infer_result = None\n        self.infer_predict_kwargs = {\n            'encode_template': encode_template,\n            'instruction_template': instruction_template,\n            'decode_template': decode_template,\n            'result_key': result_key,\n            'verbose': verbose,\n            'remote_inference_kwargs': remote_inference_kwargs,\n            'local_inference_kwargs': local_inference_kwargs\n        }\n        self.infer_result = None\n        
self.tmp_data_share_path = tmp_data_share_path\n\n        assert mode in _MODE, \"mode should be one of {}\".format(_MODE)\n        if training_args.local_rank == 0:\n            if mode == 'infer_only' or mode == 'infer_and_train':\n                if self.infer_client is None:\n                    raise ValueError('You must provide an inference instance for remote inference')\n\n        if mode != 'infer_only':\n            training_args.remove_unused_columns = False  # this parameter is necessary\n            DSSTrainerClient.__init__(\n                self,\n                model=model,\n                training_args=training_args,\n                train_set=train_set,\n                val_set=val_set,\n                data_collator=data_collator,\n                optimizer=optimizer,\n                scheduler=scheduler,\n                tokenizer=tokenizer,\n                preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n                compute_metrics=compute_metrics,\n                callbacks=callbacks,\n                alpha=alpha\n            )\n        else:\n            # skip trainer initialization because training is not needed\n            self.args = training_args\n            self.train_dataset = train_set\n\n    def infer(self) -> List[str]:        \n\n        if self.args.local_rank == 0:  # other rank will skip federation step\n            assert isinstance(self.train_dataset, PrefixDataset), \"train_set should be an instance of PrefixDataset\"\n            dict_dataset = self.train_dataset.get_raw_dataset()\n            infer_result = self.infer_client.inference(dict_dataset, **self.infer_predict_kwargs)\n            self.infer_result = infer_result\n            rationale_list = [i[self.infer_predict_kwargs['result_key']] for i in self.infer_result]\n            self.train_dataset.load_rationale(rationale_list, key=self.infer_predict_kwargs['result_key'])\n            logger.info('infer done')\n            if self.mode == 
'infer_and_train':\n                if self.args.world_size > 1:  # sync dataset with other ranks\n                    tmp_path = self.tmp_data_share_path if self.tmp_data_share_path is not None else self.args.output_dir\n                    logger.info('scattering obj, save to temp path {}'.format(tmp_path))\n                    save_to(rationale_list, tmp_path)\n\n        if self.args.local_rank > 0:\n            if self.mode == 'infer_and_train':\n                # wait until infer is done\n                tmp_path = self.tmp_data_share_path if self.tmp_data_share_path is not None else self.args.output_dir\n                logger.info('waiting for obj, load frm temp path {}'.format(tmp_path))\n                rationale_list = load(tmp_path)\n                self.train_dataset.load_rationale(rationale_list)\n                logger.info('Rationale loaded')\n\n    def train(self):\n\n        if self.mode == 'train_only':\n            logger.info(\"Train only mode\")\n            super().train()\n        elif self.mode == 'infer_only':\n            logger.info(\"infer only mode, skip training\")\n            self.infer()\n        elif self.mode == 'infer_and_train':\n            logger.info(\"infer and train mode\")\n            self.infer()\n            super().train() \n\n    def get_infer_result(self):\n        return self.infer_result\n\n\nclass FedCoTTraineServer(object):\n\n    def __init__(self, ctx: Context, infer_server: Union[SLMEncoderDecoderServer, InferDPTServer]):\n        super().__init__()\n        self.ctx = ctx\n        self.infer_server = infer_server\n\n    def train(self):\n        logger.info('Server side start inference')\n        self.infer_server.inference()\n        logger.info('Server inference done')\n\n\nif __name__ == '__main__':\n    pass\n"
  },
  {
    "path": "python/fate_llm/algo/fedcot/slm_encoder_decoder_trainer.py",
    "content": "from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainer\nfrom transformers import DataCollatorForSeq2Seq\nfrom transformers import AutoTokenizer\nimport pandas as pd\n\n\nclass EDPrefixDataCollator(DataCollatorForSeq2Seq):\n    def __call__(self, features, return_tensors=None):\n        features_df = pd.DataFrame(features)\n        a = super().__call__(list(features_df['encoder']), return_tensors)\n        b = super().__call__(list(features_df['decoder']), return_tensors)\n\n        return {\n            'encoder': a,\n            'decoder': b\n        }\n\n\nclass EncoderDecoderPrefixTrainer(Seq2SeqTrainer):\n\n    def __init__(self, alpha=0.5, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.alpha = alpha\n\n    def compute_loss(self, model, inputs, return_outputs=False):\n        out_a = model(**inputs['encoder'])\n        out_b = model(**inputs['decoder'])\n        loss = self.alpha * out_a.loss + (1. - self.alpha) * out_b.loss\n        return (loss, {'out_a': out_a, 'out_b': out_b}) if return_outputs else loss\n"
  },
  {
    "path": "python/fate_llm/algo/fedkseed/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/algo/fedkseed/args.py",
    "content": "from dataclasses import dataclass, field\n\n\n@dataclass\nclass KSeedTrainingArguments:\n    \"\"\"\n    TrainingArguments is the subset of the arguments we use in our example scripts, they are the arguments that\n\n    Parameters:\n        optim: optional, default is KSeedZO\n            The optimizer to use.\n        eps: optional, default is 0.0005\n            Epsilon value for KSeedZerothOrderOptimizer.\n        grad_clip: optional, default is -100.0\n            Gradient clip value for KSeedZerothOrderOptimizer.\n    \"\"\"\n\n    zo_optim: bool = field(\n        default=True,\n        metadata={\"help\": \"Whether to use KSeedZerothOrderOptimizer. This suppress `optim` argument when True.\"},\n    )\n    k: int = field(\n        default=4096,\n        metadata={\"help\": \"The number of seed candidates to use. This suppress `seed_candidates` argument when > 1.\"},\n    )\n    eps: float = field(default=0.0005, metadata={\"help\": \"Epsilon value for KSeedZerothOrderOptimizer.\"})\n    grad_clip: float = field(default=-100.0, metadata={\"help\": \"Gradient clip value for KSeedZerothOrderOptimizer.\"})\n"
  },
  {
    "path": "python/fate_llm/algo/fedkseed/fedkseed.py",
    "content": "import copy\nimport logging\nfrom dataclasses import dataclass, field\nfrom typing import List, Mapping\n\nimport torch\nfrom fate.arch.context import Context\n\nfrom fate_llm.algo.fedkseed.pytorch_utils import get_optimizer_parameters_grouped_with_decay\nfrom fate_llm.algo.fedkseed.trainer import KSeedZOExtendedTrainer\nfrom fate_llm.algo.fedkseed.zo_utils import probability_from_amps, directional_derivative_step, get_even_seed_probabilities\nfrom fate_llm.algo.fedkseed.args import KSeedTrainingArguments\n\nlogger = logging.getLogger(__name__)\n\n\nclass Trainer:\n    def __init__(\n            self, ctx: Context, seed_candidates: torch.LongTensor, args, fedkseed_args,\n    ):\n        self.ctx = ctx\n        self.args = args\n        self.fedkseed_args = fedkseed_args\n\n        self.seed_candidates = seed_candidates\n        self.k = len(seed_candidates)\n        self.model = None\n\n    @staticmethod\n    def get_clients(ctx: Context):\n        clients = [ctx.guest]\n        try:\n            clients.extend(ctx.hosts)\n        except:\n            pass\n        return clients\n\n    def load_model(self):\n        raise NotImplementedError\n\n    def train(self):\n        direction_derivative_history = {seed.item(): [self.fedkseed_args.grad_initial] for seed in self.seed_candidates}\n        direction_derivative_sum = None\n        seed_probabilities = None\n        for aggregation_iter, sub_ctx in self.ctx.ctxs_range(self.fedkseed_args.num_aggregations):\n            # step1: re-calculate sample probabilities for each seed\n            if seed_probabilities is None:\n                seed_probabilities = get_even_seed_probabilities(self.k)\n            else:\n                seed_probabilities = probability_from_amps(\n                    [direction_derivative_history[seed.item()] for seed in self.seed_candidates],\n                    self.fedkseed_args.bias_loss_clip,\n                )\n\n            # step2(rpc): remote call to the clients to 
get the directional derivative history\n            # proposal\n            for client in self.get_clients(sub_ctx):\n                client.put(\n                    \"train_once\",\n                    (\n                        False,\n                        {\n                            \"seed_candidates\": self.seed_candidates,\n                            \"seed_probabilities\": seed_probabilities,\n                            \"direction_derivative_sum\": direction_derivative_sum,\n                        },\n                    ),\n                )\n\n            if direction_derivative_sum is None:\n                direction_derivative_sum = {seed.item(): 0.0 for seed in self.seed_candidates}\n            # wait for reply and update the directional derivative history\n            for client in self.get_clients(sub_ctx):\n                client_directional_derivative_history = client.get(\"direction_derivative_history\")\n                for seed, history in client_directional_derivative_history.items():\n                    # torch.LongTensor -> int\n                    seed = int(seed)\n                    if seed not in direction_derivative_history:\n                        direction_derivative_history[seed] = []\n                    direction_derivative_history[seed].extend(history)\n                    direction_derivative_sum[seed] += sum(history)\n\n            # step3: evaluate to get stopping condition if necessary\n            if self.should_stop():\n                break\n\n    def should_stop(self):\n        return False\n\n    def evaluate(self):\n        pass\n\n\nclass ClientTrainer:\n    def __init__(self, ctx: Context, model, fedkseed_args, training_args, train_dataset, eval_dataset, data_collator,\n                 tokenizer):\n        self.ctx = ctx\n        self.fedkseed_args = fedkseed_args\n        self.training_args = training_args\n        self.data_collator = data_collator\n        self.train_dataset = train_dataset\n        
self.eval_dataset = eval_dataset\n        self.tokenizer = tokenizer\n\n        self.weight_decay = training_args.weight_decay\n        self.model_0 = model\n\n    def train(self):\n        for i, sub_ctx in self.ctx.ctxs_range(self.fedkseed_args.num_aggregations):\n            # step1: wait for the server to send the seed candidates and probabilities or exit signal\n            logger.info(f\"training loop started: {i}\")\n            should_exit, kwargs = sub_ctx.arbiter.get(\"train_once\")\n            seed_candidates = kwargs[\"seed_candidates\"]\n            seed_probabilities = kwargs[\"seed_probabilities\"]\n            direction_derivative_sum = kwargs[\"direction_derivative_sum\"]\n            logger.info(\n                f\"should_exit: {should_exit}, seed_candidates: {seed_candidates}, seed_probabilities: {seed_probabilities}\"\n            )\n            if should_exit:\n                break\n\n            # step2: start the training loop\n            direction_derivative_history = self.train_once(\n                seed_candidates, seed_probabilities, direction_derivative_sum\n            )\n\n            # step3: send the directional derivative history to the server\n            sub_ctx.arbiter.put(\"direction_derivative_history\", direction_derivative_history)\n\n    def train_once(self, seed_candidates, seed_probabilities, direction_derivative_sum) -> Mapping[int, List[float]]:\n        # build model\n        model = copy.deepcopy(self.model_0)\n        model.to(self.training_args.device)\n        if direction_derivative_sum is not None:\n            param_groups = get_optimizer_parameters_grouped_with_decay(model, self.weight_decay)\n            for seed, grad in direction_derivative_sum.items():\n                if grad != 0.0:\n                    directional_derivative_step(\n                        param_groups, seed, grad, lr=self.training_args.learning_rate,\n                        weight_decay=self.training_args.weight_decay\n              
      )\n\n        # train\n        trainer = KSeedZOExtendedTrainer(\n            model=model,\n            training_args=self.training_args,\n            kseed_args=self.fedkseed_args,\n            tokenizer=self.tokenizer,\n            data_collator=self.data_collator,\n            train_dataset=self.train_dataset,\n            eval_dataset=self.eval_dataset,\n        )\n        trainer.configure_seed_candidates(seed_candidates, seed_probabilities)\n        trainer.train()\n        if self.eval_dataset is not None:\n            logger.info(f\"evaluate: {trainer.evaluate()}\")\n        # get directional derivative history\n        return trainer.get_directional_derivative_history()\n\n\n@dataclass\nclass FedKSeedTrainingArguments(KSeedTrainingArguments):\n    num_aggregations: int = field(default=10, metadata={\"help\": \"The number of aggregations to perform.\"})\n    bias_loss_clip: float = field(default=1000.0, metadata={\"help\": \"The bias loss clip value.\"})\n    grad_initial: float = field(\n        default=0.0, metadata={\"help\": \"The initial value for the directional derivative history.\"}\n    )\n"
  },
  {
    "path": "python/fate_llm/algo/fedkseed/optimizer.py",
    "content": "\"\"\"\nThe implementations of ZerothOrderOptimizer and KSeedZerothOrderOptimizer is\nadapted from https://github.com/princeton-nlp/MeZO (MIT License) and\nhttps://github.com/alibaba/FederatedScope/tree/FedKSeed (Apache License 2.0)\n\nCopyright (c) 2021 Princeton Natural Language Processing\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n---\n#\n#  Copyright 2023 The FederatedScope Authors. 
All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\n\"\"\"\n\n\nimport math\nfrom typing import Mapping, Optional, Callable, Tuple, List\n\nimport torch\nfrom torch.optim import Optimizer\n\nfrom fate_llm.algo.fedkseed.pytorch_utils import get_optimizer_parameters_grouped_with_decay\nfrom fate_llm.algo.fedkseed.zo_utils import directional_derivative_step\n\n\nclass RandomWalkOptimizer(Optimizer):\n    \"\"\"\n    Random Walk Optimizer\n\n    This optimizer performs a `random` walk update for the parameters of the model.\n    \"\"\"\n\n    def __init__(self, params, lr, weight_decay, grad_clip, defaults=None):\n        self.lr = lr\n        self.weight_decay = weight_decay\n        self.grad_clip = grad_clip\n        if defaults is None:\n            defaults = dict(lr=lr, weight_decay=weight_decay)\n        else:\n            defaults = dict(defaults)\n            defaults.update(lr=lr, weight_decay=weight_decay)\n        super(RandomWalkOptimizer, self).__init__(params, defaults)\n\n    @classmethod\n    def from_model(cls, model, lr, weight_decay, grad_clip, **kwargs):\n        optimizer_grouped_parameters = get_optimizer_parameters_grouped_with_decay(model, weight_decay)\n        kwargs[\"lr\"] = lr\n        kwargs[\"weight_decay\"] = weight_decay\n        kwargs[\"grad_clip\"] = grad_clip\n        return cls(optimizer_grouped_parameters, **kwargs)\n\n    def directional_derivative_step(\n        self, 
directional_derivative_seed: int, directional_derivative_value: torch.FloatTensor\n    ) -> torch.FloatTensor:\n        \"\"\"\n        perform a step update for the parameters of the model\n        along the random direction z with the learning rate lr and the step size grad_projected_value\n        \"\"\"\n\n        if self.grad_clip > 0.0:\n            if abs(directional_derivative_value) > self.grad_clip:\n                return torch.FloatTensor([torch.nan])\n        directional_derivative_step(self.param_groups, directional_derivative_seed, directional_derivative_value)\n        return directional_derivative_value\n\n    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:\n        raise NotImplementedError(\n            \"use random_step instead of step for RandomWalkOptimizer \\\n            since we need pass the `seed` and `grad_projected_value`\"\n        )\n\n\nclass ZerothOrderOptimizer(RandomWalkOptimizer):\n    def __init__(self, params, lr, eps, weight_decay, grad_clip):\n        self.eps = eps\n        defaults = dict(eps=eps)\n        super(ZerothOrderOptimizer, self).__init__(params, lr, weight_decay, grad_clip, defaults)\n\n    def zeroth_order_step(\n        self, directional_derivative_seed: int, closure: Callable[[], torch.FloatTensor]\n    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:\n        \"\"\"\n        perform a step update for the parameters of the model along the\n        random direction z generated by the `directional_derivative_seed`\n        with the learning rate lr\n        and the step size of calculated namely `directional_derivative_value`\n\n        Input:\n        - directional_derivative_seed: the seed for generating the random direction z\n        - closure (callable, optional): A closure that reevaluates the model and returns the loss.\n\n        Output:\n        - directional_derivative_value: the gradient projected value\n        - loss_right: the loss of the model 
with the perturbed parameters x + eps * z\n        - loss_left: the loss of the model with the perturbed parameters x - eps * z\n        \"\"\"\n\n        # x -> x + eps * z\n        self.random_perturb_parameters(directional_derivative_seed, scaling_factor=1.0)\n        loss_right = closure()\n\n        # x + eps * z -> x - eps * z\n        self.random_perturb_parameters(directional_derivative_seed, scaling_factor=-2.0)\n        loss_left = closure()\n\n        # x - eps * z -> x\n        self.random_perturb_parameters(directional_derivative_seed, scaling_factor=1.0)\n\n        if torch.isnan(loss_right):\n            return loss_right, loss_right, loss_left\n        if torch.isnan(loss_left):\n            return loss_left, loss_right, loss_left\n\n        # ∇f(x) · z = D_z f(x) ≈ (f(x + eps * z) - f(x - eps * z)) / (2 * eps)\n        directional_derivative_value = (loss_right - loss_left) / (2 * self.eps)\n        # perform update for the random direction z * grad_projected_value\n        directional_derivative_value = self.directional_derivative_step(\n            directional_derivative_seed, directional_derivative_value\n        )\n\n        return directional_derivative_value, loss_right, loss_left\n\n    def random_perturb_parameters(self, directional_derivative_seed: int, scaling_factor: float):\n        \"\"\"\n        Perturb the parameters with random direction z generated by the directional_derivative_seed\n\n        for each parameter theta, the update is theta = theta + scaling_factor * z * eps\n\n        Input:\n        - seed: the seed for generating the random direction z\n        - scaling_factor: the scaling factor for the random direction z\n\n        Output:\n        - None\n        \"\"\"\n        torch.manual_seed(directional_derivative_seed)\n        for param_group in self.param_groups:\n            eps = param_group[\"eps\"]\n            for param in param_group[\"params\"]:\n                if param.requires_grad:\n                    z = 
torch.normal(\n                        mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype\n                    )\n                    param.data = param.data + scaling_factor * eps * z\n\n\nclass KSeedZerothOrderOptimizer(ZerothOrderOptimizer):\n    def __init__(\n        self,\n        params,\n        seed_candidates: torch.LongTensor,\n        seed_probabilities: torch.FloatTensor,\n        lr,\n        eps,\n        weight_decay,\n        grad_clip,\n    ):\n        self.seed_candidate = seed_candidates\n        self.seed_probabilities = seed_probabilities\n        self.directional_derivative_history: Mapping[int, List[float]] = {seed.item(): [] for seed in seed_candidates}\n        self.sample_random_generator = torch.Generator()\n        super(KSeedZerothOrderOptimizer, self).__init__(params, lr, eps, weight_decay, grad_clip)\n\n    def sample(self) -> int:\n        sampled = torch.multinomial(\n            input=self.seed_probabilities,\n            num_samples=1,\n            generator=self.sample_random_generator,\n        )[0].item()\n        return self.seed_candidate[sampled].item()\n\n    def step(self, closure: Callable[[], torch.FloatTensor] = None) -> torch.FloatTensor:\n        if closure is None:\n            # closure is required for the zeroth_order_step, but we\n            # don't raise an error here to maintain compatibility with\n            # the third-party tools that use the `step` method without\n            # providing the closure in training loop, e.g., HuggingFace Transformers\n            return torch.FloatTensor([torch.nan])\n        return self.kseed_zeroth_order_step(closure)\n\n    def kseed_zeroth_order_step(self, closure: Callable[[], torch.FloatTensor]) -> torch.FloatTensor:\n        \"\"\"\n        Performs a single optimization step.\n\n        1. Sample a random seed for sampling z\n        2. 
Perturb the parameters with the random direction(-z * eps, z * eps) for evaluating the model on the batch, and compute the loss(loss1, loss2)\n        3. Compute the directional derivative value: grad_projected_value = (loss_right - loss_left) / (2 * eps)\n        4. Perform the directional derivative step update for the parameters of the model along the random direction z with the learning rate lr and the step size grad_projected_value\n\n\n        Input:\n        - closure (callable, optional): A closure that reevaluates the model and returns the loss.\n        \"\"\"\n        if closure is None:\n            raise ValueError(\"closure must not be None\")\n\n        # sample the random seed for sampling z for perturbing parameters.\n        seed = self.sample()\n        directional_derivative_value, loss_right, loss_left = self.zeroth_order_step(seed, closure)\n        if math.isnan(directional_derivative_value):\n            return directional_derivative_value\n\n        # record the directional_derivative_value for the seed\n        self.directional_derivative_history[seed].append(directional_derivative_value.item())\n\n        return loss_right  # TODO: return loss_left or loss_right or average of both?\n"
  },
  {
    "path": "python/fate_llm/algo/fedkseed/pytorch_utils.py",
    "content": "from typing import List\n\nfrom transformers.pytorch_utils import ALL_LAYERNORM_LAYERS\nfrom transformers.trainer_pt_utils import get_parameter_names\n\n\ndef get_decay_parameter_names(model) -> List[str]:\n    \"\"\"\n    Get all parameter names that weight decay will be applied to\n\n    Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still\n    apply to those modules since this function only filter out instance of nn.LayerNorm\n\n    NOTE: This function is copied from transformers\n    # Copyright 2020-present the HuggingFace Inc. team.\n    #\n    # Licensed under the Apache License, Version 2.0 (the \"License\");\n    # you may not use this file except in compliance with the License.\n    # You may obtain a copy of the License at\n    #\n    #     http://www.apache.org/licenses/LICENSE-2.0\n    #\n    # Unless required by applicable law or agreed to in writing, software\n    # distributed under the License is distributed on an \"AS IS\" BASIS,\n    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n    # See the License for the specific language governing permissions and\n    # limitations under the License.\n    \"\"\"\n    decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)\n    decay_parameters = [name for name in decay_parameters if \"bias\" not in name]\n    return decay_parameters\n\n\ndef get_optimizer_parameters_grouped_with_decay(model, weight_decay: float) -> List[dict]:\n    \"\"\"\n    Get the parameters grouped by whether they should have weight decay applied\n    \"\"\"\n    decay_parameters = get_decay_parameter_names(model)\n    params_no_decay = []\n    params_decay = []\n    for n, p in model.named_parameters():\n        if p.requires_grad:\n            if n in decay_parameters:\n                params_decay.append(p)\n            else:\n                params_no_decay.append(p)\n    grouped_parameters_with_decay = [\n        
{\"params\": params_no_decay, \"weight_decay\": 0.0},\n        {\"params\": params_decay, \"weight_decay\": weight_decay},\n    ]\n    return grouped_parameters_with_decay\n"
  },
  {
    "path": "python/fate_llm/algo/fedkseed/trainer.py",
    "content": "import logging\nfrom typing import Dict, Union, Any, Tuple\nfrom typing import Optional, List, Callable\n\nimport torch\nfrom torch import nn\nfrom torch.utils.data import Dataset\nfrom transformers import PreTrainedModel, PreTrainedTokenizerBase, EvalPrediction, DataCollator\nfrom transformers import Trainer, TrainingArguments\nfrom transformers.optimization import get_scheduler, SchedulerType\nfrom transformers.trainer_callback import TrainerCallback\n\nfrom fate_llm.algo.fedkseed.args import KSeedTrainingArguments\nfrom fate_llm.algo.fedkseed.optimizer import KSeedZerothOrderOptimizer\nfrom fate_llm.algo.fedkseed.pytorch_utils import get_optimizer_parameters_grouped_with_decay\n\nlogger = logging.getLogger(__name__)\n\n\nclass KSeedZOExtendedTrainer(Trainer):\n    def __init__(\n        self,\n        model: Union[PreTrainedModel, nn.Module] = None,\n        training_args: TrainingArguments = None,\n        kseed_args: \"KSeedTrainingArguments\" = None,\n        data_collator: Optional[DataCollator] = None,\n        train_dataset: Optional[Dataset] = None,\n        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,\n        tokenizer: Optional[PreTrainedTokenizerBase] = None,\n        model_init: Optional[Callable[[], PreTrainedModel]] = None,\n        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,\n        callbacks: Optional[List[TrainerCallback]] = None,\n        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),\n        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n    ):\n        super().__init__(\n            model=model,\n            args=training_args,\n            data_collator=data_collator,\n            train_dataset=train_dataset,\n            eval_dataset=eval_dataset,\n            tokenizer=tokenizer,\n            model_init=model_init,\n            compute_metrics=compute_metrics,\n            
callbacks=callbacks,\n            optimizers=optimizers,\n            preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n        )\n        self.kseed_args = kseed_args\n        self._kseed_optimizer = None\n\n        self._seed_candidates = None\n        self._seed_probabilities = None\n\n    def configure_seed_candidates(self, seed_candidates: torch.LongTensor, seed_probabilities: torch.FloatTensor):\n        self._seed_candidates = seed_candidates\n        self._seed_probabilities = seed_probabilities\n\n    def get_directional_derivative_history(self):\n        \"\"\"\n        hook to get the directional derivative history\n        \"\"\"\n        if KSeedZOExtendedTrainer.k_seed_zo_mode(self.kseed_args):\n            if self._kseed_optimizer is None:\n                raise ValueError(\"KSeedZerothOrderOptimizer is not configured\")\n            return self._kseed_optimizer.directional_derivative_history\n        else:\n            raise ValueError(\"KSeedZerothOrderOptimizer is not configured\")\n\n    @staticmethod\n    def k_seed_zo_mode(args):\n        return hasattr(args, \"zo_optim\") and args.zo_optim\n\n    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:\n        \"\"\"\n        hook to do the step with KSeedZerothOrderOptimizer\n        \"\"\"\n        if KSeedZOExtendedTrainer.k_seed_zo_mode(self.kseed_args):\n            if self._kseed_optimizer is None:\n                raise ValueError(\"KSeedZerothOrderOptimizer is not configured\")\n\n            model.eval()\n            inputs = self._prepare_inputs(inputs)\n\n            with self.compute_loss_context_manager():\n                # zeroth order optimization needs forward pass twice in an optimization step,\n                # so we need to wrap the forward pass in a closure\n                def closure() -> torch.FloatTensor:\n                    with torch.no_grad():\n                        return self.compute_loss(model, 
inputs, return_outputs=False).detach()\n\n            # we don't use step() method of KSeedZerothOrderOptimizer here\n            # because `Trainer` wraps the optimizer that is subclass of `torch.optim.Optimizer` and\n            # returns nothing from the step method\n            with torch.no_grad():\n                loss = self._kseed_optimizer.kseed_zeroth_order_step(closure=closure)\n                return loss.detach()\n        else:\n            return super().training_step(model, inputs)\n\n    def create_optimizer_and_scheduler(self, num_training_steps: int):\n        \"\"\"\n        hook to add KSeedZerothOrderOptimizer\n        \"\"\"\n        if KSeedZOExtendedTrainer.k_seed_zo_mode(self.kseed_args):\n\n            if self._seed_candidates is None or self._seed_probabilities is None:\n                raise ValueError(\"Seed candidates and probabilities are not configured.\")\n\n            optimizer_grouped_parameters = get_optimizer_parameters_grouped_with_decay(\n                self.model, self.args.weight_decay\n            )\n            self.optimizer = KSeedZerothOrderOptimizer(\n                optimizer_grouped_parameters,\n                seed_candidates=self._seed_candidates,\n                seed_probabilities=self._seed_probabilities,\n                lr=self.args.learning_rate,\n                eps=self.kseed_args.eps,\n                weight_decay=self.args.weight_decay,\n                grad_clip=self.kseed_args.grad_clip,\n            )\n            # we need to keep the reference to the original optimizer to use it in training_step\n            self._kseed_optimizer = self.optimizer\n            # if we use learning rate scheduler, we may need to preserve all updates instead of the aggregated one\n            self.lr_scheduler = get_scheduler(\n                name=SchedulerType.CONSTANT,\n                optimizer=self.optimizer,\n                num_warmup_steps=self.args.warmup_steps,\n                
num_training_steps=num_training_steps,\n            )\n        else:\n            super().create_optimizer_and_scheduler(num_training_steps)\n"
  },
  {
    "path": "python/fate_llm/algo/fedkseed/zo_utils.py",
    "content": "from typing import List\n\nimport torch\n\n\ndef probability_from_amps(amps: List[List[float]], clip):\n    \"\"\"\n    Get the probability distribution from the amplitude history\n\n    formula: amp_i = clamp(amp_i, -clip, clip).abs().mean()\n             amp_i = (amp_i - min(amp)) / (max(amp) - min(amp))\n             prob_i = softmax(amp)_i\n\n    :param amps: list of amplitude history\n    :param clip: the clipping value\n    :return:\n    \"\"\"\n    amps = [torch.Tensor(amp) for amp in amps]\n    amp = torch.stack([amp.clamp_(-clip, clip).abs_().mean() for amp in amps])\n    return (amp - amp.min()).div_(amp.max() - amp.min() + 1e-10).softmax(0)\n\n\ndef directional_derivative_step(\n    param_groups: List[dict],\n    directional_derivative_seed: int,\n    directional_derivative_value: torch.FloatTensor,\n    lr: float = None,\n    weight_decay: float = None,\n) -> torch.FloatTensor:\n    \"\"\"\n    perform a step update for the parameters of the model\n    along the random direction z with the learning rate lr and the step size grad_projected_value\n\n    Input:\n    - param_groups (List[dict]): list of parameter groups\n    - directional_derivative_seed (int): seed for the random direction\n    - directional_derivative_value (torch.FloatTensor): the step size\n    - lr (float, optional): learning rate\n    - weight_decay (float, optional): weight decay\n    \"\"\"\n\n    torch.manual_seed(directional_derivative_seed)\n    for param_group in param_groups:\n        weight_decay = param_group[\"weight_decay\"] if weight_decay is None else weight_decay\n        lr = param_group[\"lr\"] if lr is None else lr\n        for param in param_group[\"params\"]:\n            z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)\n            if weight_decay is not None:\n                param.data = param.data - lr * (directional_derivative_value * z + weight_decay * param.data)\n\n            else:\n   
             param.data = param.data - lr * (directional_derivative_value * z)\n\n    return directional_derivative_value\n\n\ndef build_seed_candidates(k, low=0, high=2**32):\n    \"\"\"\n    Build seed candidates for the random walk optimizer\n    \"\"\"\n    return torch.randint(low, high, size=(k,), dtype=torch.long)\n\n\ndef get_even_seed_probabilities(k):\n    \"\"\"\n    Get the even seed probabilities, i.e., 1/k for each seed\n    \"\"\"\n    return torch.ones(k) / k\n"
  },
  {
    "path": "python/fate_llm/algo/fedmkt/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nfrom fate_llm.algo.fedmkt.fedmkt import (\n    FedMKTTrainingArguments,\n    FedMKTSLM,\n    FedMKTLLM\n)\n\n__all__ = [\n    \"FedMKTSLM\",\n    \"FedMKTLLM\",\n    \"FedMKTTrainingArguments\"\n]\n"
  },
  {
    "path": "python/fate_llm/algo/fedmkt/fedmkt.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport torch\nimport logging\nimport datasets\nfrom dataclasses import dataclass, field\n\nimport transformers\n\nfrom ...trainer.seq2seq_trainer import Seq2SeqTrainingArguments\nfrom typing import Dict, Optional, List, Callable, Union\nfrom fate.arch import Context\nfrom fate.ml.nn.trainer.trainer_base import FedArguments\nfrom torch.utils.data import Dataset\nfrom transformers.trainer_callback import TrainerCallback\nfrom transformers import PreTrainedTokenizer\nfrom transformers import Seq2SeqTrainer\nfrom transformers.trainer_utils import EvalPrediction\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.modeling_utils import unwrap_model\nfrom fate_llm.algo.fedmkt.token_alignment.token_align import token_align\nfrom fate_llm.algo.fedmkt.utils.generate_logit_utils import generate_pub_data_logits\nfrom fate.ml.aggregator import AggregatorClientWrapper, AggregatorServerWrapper\nfrom fate_llm.algo.fedmkt.fedmkt_trainer import FedMKTTrainer\nfrom fate_llm.algo.fedmkt.fedmkt_data_collator import DataCollatorForFedMKT\nfrom fate_llm.algo.fedmkt.utils.dataset_sync_util import sync_dataset\n\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass FedMKTTrainingArguments(Seq2SeqTrainingArguments):\n    \"\"\"\n    selection metric type\n    \"\"\"\n    metric_type: str = field(default=\"ce\")\n\n    
\"\"\"\n    top-k logits select params\n    \"\"\"\n    top_k_logits_keep: int = field(default=128)\n    top_k_strategy: str = field(default=\"highest\")\n\n    \"\"\"\n    distillation params\n    \"\"\"\n    distill_loss_type: str = field(default=\"ce\")\n    kd_alpha: float = field(default=0.9)\n    distill_temperature: float = field(default=1.0)\n    server_public_data_local_epoch: int = field(default=1)\n    client_public_data_local_epoch: int = field(default=1)\n    client_priv_data_local_epoch: int = field(default=1)\n    distill_strategy: str = field(default=\"greater\")\n    global_epochs: int = field(default=1)\n\n    \"\"\"\n    token-alignment params\n    \"\"\"\n    skip_align: bool = field(default=False)\n    token_align_strategy: str = field(default=\"dtw\")\n    vocab_mapping_paths: Union[str, List[str]] = field(default=None)\n    vocab_size: int = field(default=None)\n\n    \"\"\"\n    homo training params\n    \"\"\"\n    post_fedavg: bool = field(default=False)\n\n    \"\"\"\n    slm training only\n    \"\"\"\n    llm_training: bool = field(default=True)\n\n    def to_dict(self):\n        from dataclasses import fields\n        from enum import Enum\n        d = {field.name: getattr(self, field.name) for field in fields(self) if field.init}\n\n        for k, v in d.items():\n            if isinstance(v, Enum):\n                d[k] = v.value\n            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):\n                d[k] = [x.value for x in v]\n            if k.endswith(\"_token\"):\n                d[k] = f\"<{k.upper()}>\"\n        return d\n\n    def to_dict_without_extra_args(self):\n        args_dict = self.to_dict()\n        args_dict.pop(\"metric_type\")\n        args_dict.pop(\"top_k_logits_keep\")\n        args_dict.pop(\"top_k_strategy\")\n\n        args_dict.pop(\"distill_loss_type\")\n        args_dict.pop(\"kd_alpha\")\n        args_dict.pop(\"distill_temperature\")\n        args_dict.pop(\"distill_strategy\")\n   
     args_dict.pop(\"server_public_data_local_epoch\")\n        args_dict.pop(\"client_public_data_local_epoch\")\n        args_dict.pop(\"client_priv_data_local_epoch\")\n        args_dict.pop(\"global_epochs\")\n\n        args_dict.pop(\"skip_align\", False)\n        args_dict.pop(\"token_align_strategy\")\n        args_dict.pop(\"vocab_mapping_paths\", None)\n        args_dict.pop(\"vocab_size\", None)\n\n        args_dict.pop(\"post_fedavg\")\n\n        args_dict.pop(\"llm_training\", True)\n\n        return args_dict\n\n    def to_dict_with_client_priv_training_args(self):\n        args_dict = self.to_dict_without_extra_args()\n\n        args_dict[\"num_train_epochs\"] = self.client_priv_data_local_epoch\n\n        return args_dict\n\n    def to_dict_with_client_kd_args(self):\n        args_dict = self.to_dict_without_extra_args()\n\n        args_dict[\"num_train_epochs\"] = self.client_public_data_local_epoch\n\n        return args_dict\n\n    def to_dict_with_server_kd_args(self):\n        args_dict = self.to_dict_without_extra_args()\n        args_dict[\"num_train_epochs\"] = self.server_public_data_local_epoch\n\n        return args_dict\n\n\nclass FedMKTBase(object):\n    def __init__(self, *args, **kwargs):\n        self.model = None\n        self.save_trainable_weights_only = None\n\n    def save_model(\n        self,\n        output_dir: Optional[str] = None,\n        state_dict=None\n    ):\n        if not self.save_trainable_weights_only:\n            torch.save(self.model.state_dict(), output_dir + '/pytorch_model.bin')\n        else:\n            model = unwrap_model(self.model)\n\n            if hasattr(model, \"save_trainable\"):\n                model.save_trainable(output_dir)\n            else:\n                state_dict = {\n                    k: p.to(\"cpu\") for k,\n                                       p in model.named_parameters() if p.requires_grad\n                }\n\n                torch.save(state_dict, output_dir + 
'/pytorch_model.bin')\n\n\nclass FedMKTSLM(FedMKTBase):\n    def __init__(\n        self,\n        ctx: Context,\n        model: torch.nn.Module,\n        training_args: FedMKTTrainingArguments,\n        fed_args: FedArguments = None,\n        priv_train_set=None,\n        pub_train_set=None,\n        val_set: Dataset = None,\n        priv_optimizer: torch.optim.Optimizer = None,\n        pub_optimizer: torch.optim.Optimizer = None,\n        priv_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,\n        pub_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,\n        data_collator: Callable = None,\n        tokenizer: Optional[PreTrainedTokenizer] = None,\n        model_init: Optional[Callable[[], PreTrainedModel]] = None,\n        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,\n        callbacks: Optional[List[TrainerCallback]] = [],\n        save_trainable_weights_only: bool = False,\n        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n        llm_tokenizer=None,\n        llm_to_slm_vocab_mapping=None,\n    ):\n        super(FedMKTSLM, self).__init__()\n        self.ctx = ctx\n        self.training_args = training_args\n        self.fed_args = fed_args\n        self.model = model\n        self.tokenizer = tokenizer\n        self.model_init = model_init\n        self.callbacks = callbacks\n        self.compute_metrics = compute_metrics\n        self.save_trainable_weights_only = save_trainable_weights_only\n        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics\n\n        self.priv_data_collator = data_collator\n        self.priv_optimizer = priv_optimizer\n        self.pub_optimizer = pub_optimizer\n        self.priv_scheduler = priv_scheduler\n        self.pub_scheduler = pub_scheduler\n        self.priv_train_set = priv_train_set\n        self.pub_train_set = pub_train_set\n\n        self.llm_tokenizer = llm_tokenizer\n        
self.llm_to_slm_vocab_mapping = llm_to_slm_vocab_mapping\n\n        self.val_set = val_set\n\n        self.aggregator = self._init_aggregator(ctx, fed_args)\n\n        if not isinstance(self.pub_train_set, datasets.Dataset):\n            self.pub_train_set = datasets.Dataset.from_list(list(self.pub_train_set))\n\n    def train(self):\n        global_epochs = self.training_args.global_epochs\n\n        llm_pub_logits = None\n        for i, iter_ctx in self.ctx.on_iterations.ctxs_range(global_epochs):\n            logger.info(f\"begin {i}-th global kd process\")\n            priv_data_training_args = self._get_priv_data_training_args()\n\n            priv_trainer = Seq2SeqTrainer(\n                model=self.model,\n                tokenizer=self.tokenizer,\n                data_collator=self.priv_data_collator,\n                train_dataset=self.priv_train_set,\n                args=priv_data_training_args,\n                model_init=self.model_init if not i else None,\n                compute_metrics=self.compute_metrics,\n                callbacks=self.callbacks,\n                optimizers=(self.priv_optimizer, self.priv_scheduler),\n                preprocess_logits_for_metrics=self.preprocess_logits_for_metrics\n            )\n\n            logger.info(f\"begin {i}-th private data training process\")\n            priv_trainer.train()\n\n            self.model = unwrap_model(priv_trainer.model)\n\n            logger.info(f\"begin {i}-th public logits generation process\")\n\n            if self.training_args.world_size <= 1 or self.training_args.local_rank == 0:\n                slm_pub_logits = self.pub_train_set.map(\n                    generate_pub_data_logits,\n                    batched=True,\n                    batch_size=self.training_args.per_device_train_batch_size,\n                    num_proc=None,\n                    load_from_cache_file=True,\n                    fn_kwargs={\"model\": self.model,\n                               
\"training_args\": self.training_args,\n                               \"data_collator\": transformers.DataCollatorForSeq2Seq(self.tokenizer)}\n                )\n\n                if self.training_args.world_size > 1:\n                    logger.info(\"sync slm_pub_logits\")\n                    sync_dataset(\n                        slm_pub_logits, self.training_args.local_rank, self.training_args.world_size, self.training_args.device\n                    )\n\n                if self.training_args.llm_training:\n                    logger.debug(f\"send {i}-th public logits to llm\")\n                    iter_ctx.arbiter.put(\"slm_pub_logits\", slm_pub_logits.to_dict())\n\n                if self.training_args.llm_training or not i:\n                    llm_pub_logits = datasets.Dataset.from_dict(iter_ctx.arbiter.get(\"llm_pub_logits\"))\n                    if self.training_args.world_size > 1:\n                        logger.info(\"sync llm_pub_logits\")\n                        sync_dataset(llm_pub_logits, self.training_args.local_rank,\n                                     self.training_args.world_size, self.training_args.device)\n            else:\n                slm_pub_logits = sync_dataset(\n                    None, self.training_args.local_rank, self.training_args.world_size, self.training_args.device\n                )\n\n                if self.training_args.llm_training or not i:\n                    llm_pub_logits = sync_dataset(None, self.training_args.local_rank,\n                                                  self.training_args.world_size, self.training_args.device)\n\n            logger.info(f\"begin {i}-th token alignment process\")\n            aligned_dataset = token_align(\n                base_model_logits_datasets=slm_pub_logits,\n                blending_model_logits_dataset=llm_pub_logits,\n                base_tokenizer=self.tokenizer,\n                blending_tokenizer=self.llm_tokenizer,\n                
blending_to_base_mapping=self.llm_to_slm_vocab_mapping,\n                blending_model_index=0,\n                skip_align=self.training_args.skip_align,\n                align_strategy=self.training_args.token_align_strategy\n            )\n\n            logger.info(f\"begin {i}-th public logits kd process\")\n            fedmkt_trainer = self._init_trainer_for_distill(aligned_dataset)\n            fedmkt_trainer.train()\n            self.model = unwrap_model(fedmkt_trainer.model)\n\n            if self.training_args.post_fedavg and (i + 1) % self.fed_args.aggregate_freq == 0:\n                self.aggregator.model_aggregation(iter_ctx, self.model)\n\n    def _init_trainer_for_distill(self, train_set):\n        public_data_training_args = self._get_pub_data_kd_training_args()\n        fedmkt_trainer = FedMKTTrainer(\n            model=self.model,\n            tokenizer=self.tokenizer,\n            args=public_data_training_args,\n            train_dataset=train_set,\n            eval_dataset=self.val_set,\n            data_collator=DataCollatorForFedMKT(\n                self.tokenizer,\n                padding=\"max_length\",\n                max_length=max(len(d[\"input_ids\"]) for d in train_set),\n                blending_num=1,\n                vocab_size=self.training_args.vocab_size,\n                dtype=next(self.model.parameters()).dtype,\n                distill_temperature=self.training_args.distill_temperature\n            ),\n            blending_num=1,\n            lm_loss_weight=self.training_args.kd_alpha,\n            distill_loss_type=self.training_args.distill_loss_type,\n            distill_strategy=self.training_args.distill_strategy\n        )\n\n        return fedmkt_trainer\n\n    def _get_priv_data_training_args(self):\n        pre_args = self.training_args.to_dict_with_client_priv_training_args()\n        post_args = Seq2SeqTrainingArguments(**pre_args)\n\n        return post_args\n\n    def _get_pub_data_kd_training_args(self):\n     
   pre_args = self.training_args.to_dict_with_client_kd_args()\n        post_args = Seq2SeqTrainingArguments(**pre_args)\n\n        return post_args\n\n    def _init_aggregator(self, ctx: Context, fed_args: FedArguments):\n        if not self.training_args.post_fedavg:\n            return None\n\n        aggregate_type = \"weighted_mean\"\n        aggregator_name = \"fedavg\"\n        aggregator = fed_args.aggregator\n        return AggregatorClientWrapper(\n            ctx, aggregate_type, aggregator_name, aggregator,\n            sample_num=len(self.pub_train_set), args=self.training_args\n        )\n\n\nclass FedMKTLLM(FedMKTBase):\n    def __init__(\n        self,\n        ctx: Context,\n        model: torch.nn.Module,\n        training_args: FedMKTTrainingArguments,\n        fed_args: FedArguments = None,\n        train_set=None,\n        val_set: Dataset = None,\n        optimizer: torch.optim.Optimizer = None,\n        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,\n        data_collator: Callable = None,\n        tokenizer: Optional[PreTrainedTokenizer] = None,\n        model_init: Optional[Callable[[], PreTrainedModel]] = None,\n        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,\n        callbacks: Optional[List[TrainerCallback]] = [],\n        save_trainable_weights_only: bool = False,\n        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n        slm_tokenizers: List = None,\n        slm_to_llm_vocab_mappings: List[Dict] = None,\n    ):\n        super(FedMKTLLM, self).__init__()\n        self.ctx = ctx\n        self.model = model\n        self.training_args = training_args\n        self.fed_args = fed_args\n        self.train_set = train_set\n        self.val_set = val_set\n        self.optimizer = optimizer\n        self.lr_scheduler = scheduler\n        self.data_collator = data_collator\n        self.tokenizer = tokenizer\n        self.model_init = 
model_init\n        self.compute_metrics = compute_metrics\n        self.callbacks = callbacks\n        self.save_trainable_weights_only = save_trainable_weights_only\n        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics\n        self.slm_tokenizers = slm_tokenizers\n        self.slm_to_llm_vocab_mappings = slm_to_llm_vocab_mappings\n\n        self.aggregator = self._init_aggregator(ctx)\n\n        if not isinstance(self.train_set, datasets.Dataset):\n            self.train_set = datasets.Dataset.from_list(list(self.train_set))\n\n    def _init_aggregator(self, ctx: Context):\n        if not self.training_args.post_fedavg:\n            return None\n        return AggregatorServerWrapper(ctx)\n\n    def generate_pub_data_logits(self, first_epoch=False):\n        fn_kwargs = {\"model\": self.model,\n                     \"training_args\": self.training_args,\n                     \"data_collator\": transformers.DataCollatorForSeq2Seq(self.tokenizer)}\n        if first_epoch and self.training_args.device.type == \"cuda\":\n            self.model.cuda(self.training_args.device)\n\n        return self.train_set.map(\n            generate_pub_data_logits,\n            batched=True,\n            batch_size=self.training_args.per_device_train_batch_size,\n            num_proc=None,\n            load_from_cache_file=True,\n            fn_kwargs=fn_kwargs\n        )\n\n    def on_epoch_begin(self, iter_ctx, epoch_idx, previous_pub_dataset):\n        logger.info(f\"on {epoch_idx}-epoch begin\")\n        if not self.training_args.llm_training:\n            return\n\n        if previous_pub_dataset is None:\n            if self.training_args.world_size <= 1 or self.training_args.local_rank == 0:\n                llm_pub_logits = self.generate_pub_data_logits(first_epoch=True if not epoch_idx else False)\n                if self.training_args.world_size > 1:\n                    sync_dataset(llm_pub_logits, self.training_args.local_rank,\n                   
              self.training_args.world_size, self.training_args.device)\n            else:\n                llm_pub_logits = sync_dataset(None, self.training_args.local_rank,\n                                              self.training_args.world_size, self.training_args.device)\n        else:\n            llm_pub_logits = previous_pub_dataset\n\n        slm_pub_logits_list = list()\n        if self.training_args.world_size <= 1 or self.training_args.local_rank == 0:\n            slm_pub_logits_list.append(datasets.Dataset.from_dict(iter_ctx.guest.get('slm_pub_logits')))\n            if any(p.role == 'host' for p in self.ctx.parties):\n                slm_pub_logits_list.extend(\n                    datasets.Dataset.from_dict(client_logits) for client_logits in iter_ctx.hosts.get(\"slm_pub_logits\")\n                )\n            if self.training_args.world_size > 1:\n                logger.info(\"sync dataset to other rank\")\n                for slm_pub_logits in slm_pub_logits_list:\n                    sync_dataset(slm_pub_logits, self.training_args.local_rank,\n                                 self.training_args.world_size, self.training_args.device)\n                    logger.info(\"end to sync\")\n        else:\n            logger.info(\"sync dataset from rank 0\")\n            for _ in range(len(self.slm_tokenizers)):\n                slm_pub_logits_list.append(\n                    sync_dataset(None, self.training_args.local_rank,\n                                 self.training_args.world_size, self.training_args.device)\n                )\n            logger.info(\"end to sync dataset from rank 0\")\n\n        aligned_dataset = llm_pub_logits\n        for idx, slm_pub_logits in enumerate(slm_pub_logits_list):\n            aligned_dataset = token_align(\n                base_model_logits_datasets=aligned_dataset,\n                blending_model_logits_dataset=slm_pub_logits,\n                base_tokenizer=self.tokenizer,\n                
blending_tokenizer=self.slm_tokenizers[idx],\n                blending_to_base_mapping=self.slm_to_llm_vocab_mappings[idx],\n                blending_model_index=idx,\n                skip_align=self.training_args.skip_align,\n                align_strategy=self.training_args.token_align_strategy\n            )\n\n        return aligned_dataset\n\n    def on_epoch_end(self, iter_ctx, epoch_idx):\n        logger.info(f\"on {epoch_idx}-epoch end\")\n        if not self.training_args.llm_training and epoch_idx > 1:\n            return\n\n        llm_pub_logits = self.generate_pub_data_logits(first_epoch=True if not self.training_args.llm_training else False)\n\n        if self.training_args.world_size <= 1 or self.training_args.local_rank == 0:\n            iter_ctx.guest.put(\"llm_pub_logits\", llm_pub_logits.to_dict())\n            if len(self.slm_tokenizers) > 1:\n                iter_ctx.hosts.put(\"llm_pub_logits\", llm_pub_logits.to_dict())\n\n            if self.training_args.post_fedavg and (epoch_idx + 1) % self.fed_args.aggregate_freq == 0:\n                self.aggregator.model_aggregation(iter_ctx)\n\n            if self.training_args.world_size > 1:\n                sync_dataset(\n                    llm_pub_logits, self.training_args.local_rank, self.training_args.world_size, self.training_args.device\n                )\n        else:\n            llm_pub_logits = sync_dataset(\n                None, self.training_args.local_rank, self.training_args.world_size, self.training_args.device\n            )\n\n        return llm_pub_logits\n\n    def _get_pub_data_kd_training_args(self):\n        pre_args = self.training_args.to_dict_with_server_kd_args()\n        post_args = Seq2SeqTrainingArguments(**pre_args)\n\n        return post_args\n\n    def train(self):\n        global_epochs = self.training_args.global_epochs\n        previous_pub_logits = None\n\n        for i, iter_ctx in self.ctx.on_iterations.ctxs_range(global_epochs):\n            
logger.info(f\"begin {i}-th global kd process\")\n\n            aligend_train_set = self.on_epoch_begin(iter_ctx, i, previous_pub_logits)\n            if self.training_args.llm_training:\n\n                public_data_training_args = self._get_pub_data_kd_training_args()\n                fedmkt_trainer = FedMKTTrainer(\n                    model=self.model,\n                    tokenizer=self.tokenizer,\n                    args=public_data_training_args,\n                    train_dataset=aligend_train_set,\n                    eval_dataset=self.val_set,\n                    data_collator=DataCollatorForFedMKT(\n                        self.tokenizer,\n                        padding=\"max_length\",\n                        max_length=max(len(d[\"input_ids\"]) for d in aligend_train_set),\n                        blending_num=len(self.slm_tokenizers),\n                        vocab_size=self.training_args.vocab_size,\n                        dtype=next(self.model.parameters()).dtype,\n                        distill_temperature=self.training_args.distill_temperature\n                    ),\n                    blending_num=len(self.slm_tokenizers),\n                    lm_loss_weight=self.training_args.kd_alpha,\n                    distill_loss_type=self.training_args.distill_loss_type,\n                    distill_strategy=self.training_args.distill_strategy\n                )\n\n                fedmkt_trainer.train()\n                self.model = unwrap_model(fedmkt_trainer.model)\n\n            previous_pub_logits = self.on_epoch_end(iter_ctx, i)\n"
  },
  {
    "path": "python/fate_llm/algo/fedmkt/fedmkt_data_collator.py",
    "content": "#\n# NOTE: The implementations of DataCollatorForFedMKT is modified from FuseAI/FuseLLM\n# Copyright FuseAI/FuseLLM\n#\n#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport torch\nfrom torch.nn.functional import softmax\nfrom transformers import DataCollatorForSeq2Seq\nfrom transformers.tokenization_utils_base import PreTrainedTokenizerBase\nfrom transformers.utils import PaddingStrategy\nfrom typing import Optional, Any, Union\nimport logging\nfrom fate_llm.algo.fedmkt.utils.vars_define import (\n    ALIGNED_OTHER_LOGITS,\n    ALIGNED_OTHER_INDICES,\n    PER_STEP_LOGITS,\n    PER_STEP_INDICES,\n    SELF_TARGET_DIST,\n    OTHER_TARGET_DIST\n)\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass DataCollatorForFedMKT(DataCollatorForSeq2Seq):\n    \"\"\"modified from https://github.com/fanqiwan/FuseAI/blob/main/FuseLLM/src/utils/data_collator.py#L135\"\"\"\n    tokenizer: PreTrainedTokenizerBase\n    model: Optional[Any] = None\n    padding: Union[bool, str, PaddingStrategy] = True\n    max_length: Optional[int] = None\n    pad_to_multiple_of: Optional[int] = None\n    label_pad_token_id: int = -100\n    return_tensors: str = \"pt\"\n    blending_num: int = 1\n    distill_temperature: float = 1.0\n    vocab_size: int = None\n    dtype: torch.dtype = torch.bfloat16\n\n    def __init__(self, *args, **kwargs):\n        blending_num = kwargs.pop(\"blending_num\", 4)\n        
vocab_size = kwargs.pop(\"vocab_size\", None)\n        dtype = kwargs.pop(\"dtype\", torch.dtype)\n        distill_temperature = kwargs.pop(\"distill_temperature\", 1.0)\n        super(DataCollatorForFedMKT, self).__init__(*args, **kwargs)\n        self.blending_num = blending_num\n        self.vocab_size = vocab_size if vocab_size is not None else len(self.tokenizer.get_vocab())\n        self.pad_id = self.tokenizer.pad_token_id\n        self.dtype = dtype\n        self.distill_temperature = distill_temperature\n\n    def __call__(self, features, return_tensors=None):\n        extra_features = dict()\n        feature_keys = list(features[0].keys())\n        for f_key in feature_keys:\n            if f_key not in [\"input_ids\", \"attention_mask\", \"labels\"]:\n                extra_features[f_key] = []\n                for feature in features:\n                    extra_features[f_key].append(feature.pop(f_key))\n\n        features = super().__call__(features=features, return_tensors=return_tensors)\n\n        features.update(extra_features)\n\n        batch_size = features[\"input_ids\"].size(0)\n        base_target_dist = torch.zeros(batch_size, self.max_length, self.vocab_size).to(self.dtype)\n        aligned_target_dists = [torch.zeros(batch_size, self.max_length, self.vocab_size).to(self.dtype)\n                                for _ in range(self.blending_num)]\n\n        for i in range(batch_size):\n            base_seq_len = len(features[PER_STEP_LOGITS][i])\n            for j in range(self.max_length):\n                if j < base_seq_len:\n                    base_logits = torch.tensor(features[PER_STEP_LOGITS][i][j], dtype=self.dtype)\n                    base_prob = softmax(base_logits / self.distill_temperature, -1)\n                    base_indices = torch.tensor(features[PER_STEP_INDICES][i][j])\n                    base_target_dist[i][j] = base_target_dist[i][j].scatter_(-1, base_indices, base_prob)\n\n                    for k in 
range(self.blending_num):\n                        per_step_aligned_indices_key = f\"{ALIGNED_OTHER_INDICES}_{k}\"\n                        per_step_aligned_logits_key = f\"{ALIGNED_OTHER_LOGITS}_{k}\"\n                        if len(features[per_step_aligned_indices_key][i][j]) > 0:\n                            aligned_logits = torch.tensor(features[per_step_aligned_logits_key][i][j], dtype=self.dtype)\n                            aligned_prob = softmax(aligned_logits / self.distill_temperature, -1)\n                            aligned_indices = torch.tensor(features[per_step_aligned_indices_key][i][j])\n                            aligned_target_dists[k][i][j] = aligned_target_dists[k][i][j].scatter_(-1, aligned_indices, aligned_prob)\n                        else:\n                            aligned_target_dists[k][i][j] = base_target_dist[i][j]\n\n                else:  # padding position\n                    base_target_dist[i][j][self.pad_id] = 1.0\n                    for k in range(self.blending_num):\n                        aligned_target_dists[k][i][j][self.pad_id] = 1.0\n\n        features.pop(PER_STEP_LOGITS)\n        features.pop(PER_STEP_INDICES)\n        for i in range(self.blending_num):\n            features.pop(f\"{ALIGNED_OTHER_LOGITS}_{i}\")\n            features.pop(f\"{ALIGNED_OTHER_INDICES}_{i}\")\n            features[f\"{OTHER_TARGET_DIST}_{i}\"] = aligned_target_dists[i]\n\n        features[SELF_TARGET_DIST] = base_target_dist\n\n        return features\n"
  },
  {
    "path": "python/fate_llm/algo/fedmkt/fedmkt_trainer.py",
    "content": "#\n# NOTE: The implementations of FedMKTTrainer is modified from FuseAI/FuseLLM\n# Copyright FuseAI\n#\n#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport logging\nimport torch\nfrom torch.nn.functional import kl_div, log_softmax, cross_entropy\nfrom transformers import Seq2SeqTrainer\nfrom transformers.modeling_utils import unwrap_model\nfrom transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES\nfrom fate_llm.algo.fedmkt.utils.vars_define import (\n    SELF_TARGET_DIST,\n    OTHER_TARGET_DIST,\n    ALIGNED_OTHER_METRIC,\n    METRIC,\n)\n\nlogger = logging.getLogger(__name__)\n\n\nclass FedMKTTrainer(Seq2SeqTrainer):\n    \"\"\"\n    modified from https://github.com/fanqiwan/FuseAI/blob/main/FuseLLM/src/utils/trainer.py#L22\n    \"\"\"\n    blending_num: int = 2\n    distill_loss_type: str = \"ce\"\n    lm_loss_weight: float = 0.9\n    distill_strategy = \"greater\"\n\n    def __init__(self, *args, **kwargs):\n        blending_num = kwargs.pop(\"blending_num\", 1)\n        distill_loss_type = kwargs.pop(\"distill_loss_type\", \"ce\")\n        lm_loss_weight = kwargs.pop(\"lm_loss_weight\", 0.9)\n        distill_strategy = kwargs.pop(\"distill_strategy\", \"greater\")\n        super(FedMKTTrainer, self).__init__(*args, **kwargs)\n        self.blending_num = blending_num\n        self.distill_loss_type = distill_loss_type\n        
self.lm_loss_weight = lm_loss_weight\n        self.distill_strategy = distill_strategy\n\n    def compute_loss(self, model, inputs, return_outputs=False):\n        if self.label_smoother is not None and \"labels\" in inputs:\n            labels = inputs.pop(\"labels\")\n        else:\n            labels = None\n\n        base_target_dist = inputs.pop(SELF_TARGET_DIST)\n        base_metric = inputs.pop(METRIC)\n\n        aligned_target_dists = []\n        aligned_metrics = []\n        for i in range(self.blending_num):\n            aligned_target_dists.append(inputs.pop(f\"{OTHER_TARGET_DIST}_{i}\"))\n            aligned_metrics.append(inputs.pop(f\"{ALIGNED_OTHER_METRIC}_{i}\"))\n\n        outputs = model(**inputs)\n        # Save past state if it exists\n        # TODO: this needs to be fixed and made cleaner later.\n        if self.args.past_index >= 0:\n            self._past = outputs[self.args.past_index]\n\n        if labels is not None:\n            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():\n                loss = self.label_smoother(outputs, labels, shift_labels=True)\n            else:\n                loss = self.label_smoother(outputs, labels)\n        else:\n            if isinstance(outputs, dict) and \"loss\" not in outputs:\n                raise ValueError(\n                    \"The model did not return a loss from the inputs, only the following keys: \"\n                    f\"{','.join(outputs.keys())}. 
For reference, the inputs it received are {','.join(inputs.keys())}.\"\n                )\n            # We don't use .loss here since the model may return tuples instead of ModelOutput.\n            loss = outputs[\"loss\"] if isinstance(outputs, dict) else outputs[0]\n\n        batch_size, seq_len, vocab_size = outputs[\"logits\"].size(0), outputs[\"logits\"].size(1), outputs[\"logits\"].size(2)\n\n        aligned_rewards = []\n        for i in range(self.blending_num):\n            aligned_rewards.append((1 / torch.exp(torch.tensor(aligned_metrics[i], dtype=torch.bfloat16))).to(loss.device))\n\n        base_reward = (1 / torch.exp(torch.tensor(base_metric, dtype=torch.bfloat16))).to(loss.device)\n\n        if self.distill_strategy == \"greater\":\n            base_reward_expanded = base_reward.unsqueeze(-1).unsqueeze(-1).expand_as(base_target_dist)\n            aligned_rewards_expanded = [\n                aligned_rewards[i].unsqueeze(-1).unsqueeze(-1).expand_as(aligned_target_dists[i])\n                for i in range(self.blending_num)\n            ]\n            target_dist_list = []\n            reward_list = []\n            if base_target_dist is not None:\n                target_dist_list.append(base_target_dist)\n                reward_list.append(base_reward_expanded)\n\n            target_dist_list.extend(aligned_target_dists)\n            reward_list.extend(aligned_rewards_expanded)\n\n            stacked_dists = torch.stack(target_dist_list, dim=-1)\n            stacked_rewards = torch.stack(reward_list, dim=-1)\n            max_reward_indices = torch.argmax(stacked_rewards, dim=-1, keepdim=True)\n            target_dist = torch.gather(stacked_dists, -1, max_reward_indices).squeeze(-1)\n        elif self.distill_strategy == \"weighted_mean\":\n            weights = torch.stack(\n                [base_reward] + aligned_rewards, dim=1\n            )\n            normalized_weights = torch.softmax(weights, dim=1)\n            weight_labels = 
normalized_weights[:, 0].unsqueeze(1).unsqueeze(2) * base_target_dist\n            for i in range(self.blending_num):\n                weight_labels += normalized_weights[:, i + 1].unsqueeze(1).unsqueeze(2) * aligned_target_dists[i]\n\n            target_dist = (\n                weight_labels\n            )\n        else:\n            raise ValueError(f\"distill_strategy={self.distill_strategy}\")\n\n        if self.distill_loss_type == \"ce\":\n            loss_lm = cross_entropy(\n                input=outputs[\"logits\"].view(-1, vocab_size),\n                target=target_dist.view(-1, vocab_size),\n                reduction=\"none\",\n            ).view(batch_size, -1)\n        elif self.distill_loss_type == \"kl\":\n            loss_lm = kl_div(\n                input=log_softmax(outputs[\"logits\"], dim=-1),\n                target=target_dist,\n                log_target=False,\n                reduction=\"none\").sum(dim=-1)\n        else:\n            raise ValueError(f\"Not implement distill_loss_type={self.distill_loss_type}\")\n\n        loss_lm = (loss_lm * inputs[\"attention_mask\"]).sum() / inputs[\"attention_mask\"].sum()\n        loss = self.lm_loss_weight * loss + (1.0 - self.lm_loss_weight) * loss_lm\n\n        return (loss, outputs) if return_outputs else loss\n"
  },
  {
    "path": "python/fate_llm/algo/fedmkt/token_alignment/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n"
  },
  {
    "path": "python/fate_llm/algo/fedmkt/token_alignment/spectal_token_mapping.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport transformers\n\n\nTOKENIZER_TO_SPECIAL_TOKEN = {\n    transformers.LlamaTokenizer: '▁',\n    transformers.LlamaTokenizerFast: '▁',\n    transformers.GPTNeoXTokenizerFast: 'Ġ',\n    transformers.GPT2TokenizerFast: 'Ġ',\n    transformers.GPT2Tokenizer: 'Ġ',\n    transformers.BloomTokenizerFast: 'Ġ',\n}\n"
  },
  {
    "path": "python/fate_llm/algo/fedmkt/token_alignment/token_align.py",
    "content": "#\n# NOTE: The dtw function is copied from FuseAI/FuseLLM\n#       and the align_blending_model_logits_with_base_model_logits function is modified from FuseAI/FuseLLM\n# Copyright FuseAI\n#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport logging\nimport transformers\nimport editdistance\nimport numpy as np\n\nfrom typing import Dict, List\nfrom fate_llm.algo.fedmkt.token_alignment.spectal_token_mapping import TOKENIZER_TO_SPECIAL_TOKEN\nfrom fate_llm.algo.fedmkt.utils.vars_define import (\n    PER_STEP_LOGITS,\n    PER_STEP_INDICES,\n    ALIGNED_OTHER_LOGITS,\n    ALIGNED_OTHER_INDICES,\n    ALIGNED_OTHER_METRIC,\n    METRIC\n)\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef dtw(series_1, series_2, norm_func=np.linalg.norm):\n    \"\"\"code refer to: https://github.com/fanqiwan/FuseAI/blob/main/FuseLLM/src/utils/others.py#L318\"\"\"\n\n    matrix = np.zeros((len(series_1) + 1, len(series_2) + 1))\n    matrix[0, :] = np.inf\n    matrix[:, 0] = np.inf\n    matrix[0, 0] = 0\n    for i, vec1 in enumerate(series_1):\n        for j, vec2 in enumerate(series_2):\n            cost = norm_func(vec1, vec2)\n            matrix[i + 1, j + 1] = cost + min(matrix[i, j + 1], matrix[i + 1, j], matrix[i, j])\n    matrix = matrix[1:, 1:]\n    i = matrix.shape[0] - 1\n    j = matrix.shape[1] - 1\n    matches = []\n    mappings_series_1 = [list() for v in range(matrix.shape[0])]\n    
mappings_series_2 = [list() for v in range(matrix.shape[1])]\n    while i > 0 or j > 0:\n        matches.append((i, j))\n        mappings_series_1[i].append(j)\n        mappings_series_2[j].append(i)\n        option_diag = matrix[i - 1, j - 1] if i > 0 and j > 0 else np.inf\n        option_up = matrix[i - 1, j] if i > 0 else np.inf\n        option_left = matrix[i, j - 1] if j > 0 else np.inf\n        move = np.argmin([option_diag, option_up, option_left])\n        if move == 0:\n            i -= 1\n            j -= 1\n        elif move == 1:\n            i -= 1\n        else:\n            j -= 1\n    matches.append((0, 0))\n    mappings_series_1[0].append(0)\n    mappings_series_2[0].append(0)\n    matches.reverse()\n    for mp in mappings_series_1:\n        mp.reverse()\n    for mp in mappings_series_2:\n        mp.reverse()\n\n    return matches, matrix[-1, -1], mappings_series_1, mappings_series_2, matrix\n\n\ndef greedy_dynamic_matching(base_model_tokens, blending_model_tokens, base_model_sp_t, blending_model_sp_t):\n    l1 = len(base_model_tokens)\n    l2 = len(blending_model_tokens)\n\n    base_model_tokens = [token.replace(base_model_sp_t, \"\") for token in base_model_tokens]\n    blending_model_tokens = [token.replace(blending_model_sp_t, \"\") for token in blending_model_tokens]\n\n    dp = np.full((l1 + 1, l2 + 1), -1000000000, dtype=\"int32\")\n    matched_left = np.full((l1, l2), -1, dtype=\"int32\")\n    matched_right = np.full((l1, l2), -1, dtype=\"int32\")\n    trans_left = np.full((l1 + 1, l2 + 1), -1, dtype=\"int32\")\n    trans_right = np.full((l1 + 1, l2 + 1), -1, dtype=\"int32\")\n\n    # this can be optimizer use suffix data structure, but naive implemented for fast trial , it will be optimize later.\n    for i in range(l1):\n        for j in range(l2):\n            if base_model_tokens[i] == blending_model_tokens[j]:\n                matched_left[i][j] = 1\n                matched_right[i][j] = 1\n                continue\n\n            i2, 
j2 = i, j\n            t1 = \"\"\n            t2 = \"\"\n            sq_l1, sq_l2 = 0, 0\n            while i2 >= 0 and j2 >= 0:\n                if len(t1) > len(t2):\n                    t2 = blending_model_tokens[j2] + t2\n                    sq_l2 += 1\n                    j2 -= 1\n                elif len(t1) < len(t2):\n                    t1 = base_model_tokens[i2] + t1\n                    sq_l1 += 1\n                    i2 -= 1\n                else:\n                    if sq_l1 == 0:\n                        sq_l1 += 1\n                        sq_l2 += 1\n                        t1 += base_model_tokens[i2]\n                        t2 += blending_model_tokens[j2]\n                        i2 -= 1\n                        j2 -= 1\n                        continue\n                    if t1 == t2:\n                        matched_left[i][j] = sq_l1\n                        matched_right[i][j] = sq_l2\n                    break\n\n    \"\"\"\n    always shortest matching\n    \"\"\"\n    for i in range(0, l1 + 1):\n        dp[i][0] = 0\n\n    for j in range(0, l2 + 1):\n        dp[0][j] = 1\n\n    for i in range(0, l1):\n        for j in range(0, l2):\n            if matched_left[i][j] == -1:\n                dp[i + 1][j + 1] = max(dp[i + 1][j], dp[i][j + 1])\n                if dp[i + 1][j + 1] == dp[i + 1][j]:\n                    trans_right[i + 1][j + 1] = j\n                else:\n                    trans_left[i + 1][j + 1] = i\n            else:\n                l_len = matched_left[i][j]\n                r_len = matched_right[i][j]\n                dp[i + 1][j + 1] = max(max(dp[i + 1][j], dp[i][j + 1]), dp[i + 1 - l_len][j + 1 - r_len] + l_len)\n                if dp[i + 1][j + 1] == dp[i + 1 - l_len][j + 1 - r_len] + l_len:\n                    trans_left[i + 1][j + 1] = i + 1 - l_len\n                    trans_right[i + 1][j + 1] = j + 1 - r_len\n                    assert l_len > 0 and r_len > 0\n                elif dp[i + 1][j + 1] == dp[i + 
1][j]:\n                    trans_right[i + 1][j + 1] = j\n                else:\n                    trans_left[i + 1][j + 1] = i\n\n    i, j = l1, l2\n    matches = []\n    while i > 0 and j > 0:\n        if trans_left[i][j] != -1 and trans_right[i][j] != -1:\n            l = trans_left[i][j]\n            r = trans_right[i][j]\n            matches.append([(l, i - 1), (r, j - 1)])\n            i, j = l, r\n        elif trans_left[i][j] < 0:\n            j -= 1\n        else:\n            i -= 1\n\n    matches.reverse()\n    return matches\n\n\ndef align_blending_model_logits_with_base_model_logits(base_examples,\n                                                       indices,\n                                                       blending_examples,\n                                                       blending_to_base_mapping,\n                                                       base_tokenizer,\n                                                       blending_tokenizer,\n                                                       blending_model_index,\n                                                       skip_align=False,\n                                                       align_strategy=\"greedy_dp\"):\n    \"\"\"modified from https://github.com/fanqiwan/FuseAI/blob/main/FuseLLM/src/utils/token_alignment.py#L101\"\"\"\n    base_features = [{key: base_examples[key][i] for key in base_examples} for i in\n                     range(len(base_examples[next(iter(base_examples))]))]\n    blending_features = [blending_examples[idx] for idx in indices]\n    aligned_per_step_logits_list, aligned_per_step_indices_list = [], []\n    per_step_logits_list, per_step_indices_list = [], []\n    metric_ce_aligned = []\n    for base_feature, blending_feature in zip(base_features, blending_features):\n        base_feature[PER_STEP_LOGITS] = base_feature[PER_STEP_LOGITS][:len(base_feature['input_ids'])]\n        base_feature[PER_STEP_INDICES] = 
base_feature[PER_STEP_INDICES][:len(base_feature['input_ids'])]\n        blending_feature[PER_STEP_LOGITS] = blending_feature[PER_STEP_LOGITS][:len(blending_feature['input_ids'])]\n        blending_feature[PER_STEP_INDICES] = blending_feature[PER_STEP_INDICES][:len(blending_feature['input_ids'])]\n        if skip_align is True:\n            aligned_blending_model_per_step_logits = blending_feature[PER_STEP_LOGITS]\n            aligned_blending_model_per_step_indices = blending_feature[PER_STEP_INDICES]\n        else:\n            aligned_blending_model_per_step_logits, aligned_blending_model_per_step_indices = transform_step_logits(\n                base_model_tokenizer=base_tokenizer,\n                blending_model_tokenizer=blending_tokenizer,\n                base_model_vocab=base_tokenizer.get_vocab(),\n                base_model_input_ids=base_feature['input_ids'],\n                blending_model_input_ids=blending_feature['input_ids'],\n                blending_model_per_step_logits=blending_feature[PER_STEP_LOGITS],\n                blending_model_per_step_indices=blending_feature[PER_STEP_INDICES],\n                blending_to_base_mapping=blending_to_base_mapping,\n                align_strategy=align_strategy\n            )\n\n        aligned_per_step_logits_list.append(aligned_blending_model_per_step_logits)\n        aligned_per_step_indices_list.append(aligned_blending_model_per_step_indices)\n        per_step_logits_list.append(base_feature[PER_STEP_LOGITS])\n        per_step_indices_list.append(base_feature[PER_STEP_INDICES])\n        metric_ce_aligned.append(blending_feature[METRIC])\n\n    base_examples[PER_STEP_LOGITS] = per_step_logits_list\n    base_examples[PER_STEP_INDICES] = per_step_indices_list\n    base_examples[f\"{ALIGNED_OTHER_LOGITS}_{blending_model_index}\"] = aligned_per_step_logits_list\n    base_examples[f\"{ALIGNED_OTHER_INDICES}_{blending_model_index}\"] = aligned_per_step_indices_list\n    
base_examples[f\"{ALIGNED_OTHER_METRIC}_{blending_model_index}\"] = metric_ce_aligned\n\n    return base_examples\n\n\ndef transform_step_logits(base_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,\n                          blending_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,\n                          base_model_vocab: Dict[str, int],\n                          base_model_input_ids: List[int],\n                          blending_model_input_ids: List[int],\n                          blending_model_per_step_logits: List[List[float]],\n                          blending_model_per_step_indices: List[List[int]],\n                          blending_to_base_mapping: Dict[str, str] = None,\n                          align_strategy: str = \"dtw\"\n                          ):\n    \"\"\"modified from https://github.com/fanqiwan/FuseAI/blob/main/FuseLLM/src/utils/others.py#L364\"\"\"\n    \"\"\"Align blending model per step logits & indices with base model.\"\"\"\n    base_model_tokens = base_model_tokenizer.convert_ids_to_tokens(base_model_input_ids)\n    blending_model_tokens = blending_model_tokenizer.convert_ids_to_tokens(blending_model_input_ids)\n    base_model_special_token = TOKENIZER_TO_SPECIAL_TOKEN[base_model_tokenizer.__class__]\n    blending_model_special_token = TOKENIZER_TO_SPECIAL_TOKEN[blending_model_tokenizer.__class__]\n\n    aligned_blending_model_per_step_logits, aligned_blending_model_per_step_indices = [], []\n    if align_strategy == \"dtw\":\n        def dist_fn(a, b):\n            \"\"\"Calculate editdistance between two tokens, a is from blending model, b is from base model.\"\"\"\n            return editdistance.eval(a.replace(blending_model_special_token, ''),\n                                     b.replace(base_model_special_token, ''))\n\n        _, _, _, base_to_blending, _ = dtw(blending_model_tokens, base_model_tokens, norm_func=dist_fn)\n        for i, blending_idx in 
enumerate(base_to_blending):\n            aligned_blending_model_per_step_logit = []\n            aligned_blending_model_per_step_index = []\n            if len(blending_idx) == 1:  # one base token map to one blending token\n                j = blending_idx[0]\n                base_token = base_model_tokens[i]\n                blending_token = blending_model_tokens[j].replace(blending_model_special_token,\n                                                                  base_model_special_token)\n                if (\n                    blending_model_tokenizer.__class__ == transformers.GPTNeoXTokenizerFast\n                    or blending_model_tokenizer.__class__ == transformers.GPT2TokenizerFast) and i == 0 and base_token.startswith(\n                    base_model_special_token) and not blending_token.startswith(base_model_special_token):\n                    blending_token = base_model_special_token + blending_token  # special case for mpt\n\n                if (base_token == blending_token) or (\n                        blending_token in blending_to_base_mapping and base_token == blending_to_base_mapping[\n                    blending_token]):  # find the aligned mapping, use the corresponding logits\n                    # the logits and indices at this step\n                    for blending_logit, blending_index in zip(blending_model_per_step_logits[j],\n                                                              blending_model_per_step_indices[j]):\n                        # the token corresponds to the logit and indices\n                        blending_t = blending_model_tokenizer.convert_ids_to_tokens([blending_index])[0].replace(\n                            blending_model_special_token, base_model_special_token)\n                        blending_t = blending_to_base_mapping[blending_t]\n                        if blending_t in base_model_vocab:\n                            aligned_index = base_model_vocab[blending_t]  # the index of the token in 
base model vocab\n                            if aligned_index not in aligned_blending_model_per_step_index:\n                                aligned_blending_model_per_step_index.append(aligned_index)\n                                aligned_blending_model_per_step_logit.append(blending_logit)\n                        else:\n                            logger.warning(f\"blending_t: {blending_t} not in base_model_vocab!\")\n                else:  # find error aligned mapping, use the one-hot logits\n                    aligned_blending_model_per_step_index.append(base_model_vocab[base_token])\n                    aligned_blending_model_per_step_logit.append(1.0)\n            else:  # one base token map to multiple blending token, in this case only fit base token. use the one-hot logits\n                base_token = base_model_tokens[i]\n                aligned_blending_model_per_step_index.append(base_model_vocab[base_token])\n                aligned_blending_model_per_step_logit.append(1.0)\n            aligned_blending_model_per_step_indices.append(aligned_blending_model_per_step_index)\n            aligned_blending_model_per_step_logits.append(aligned_blending_model_per_step_logit)\n    elif align_strategy == \"greedy_dp\":\n        matches = greedy_dynamic_matching(base_model_tokens, blending_model_tokens, base_model_special_token, blending_model_special_token)\n        fusion_logits = [[] for _ in range(len(matches))]\n        fusion_indices = [[] for _ in range(len(matches))]\n        match_pos = [-1] * len(base_model_tokens)\n        used = [False] * len(matches)\n\n        for idx, ((start_pos_1, end_pos_1), (start_pos_2, end_pos_2)) in enumerate(matches):\n            fusion_dict = dict()\n            fusion_counter_dict = dict()\n            for blending_pos in range(start_pos_2, end_pos_2 + 1):\n                for blending_logit, blending_index in zip(blending_model_per_step_logits[blending_pos],\n                                                         
 blending_model_per_step_indices[blending_pos]):\n                    if blending_index not in fusion_dict:\n                        fusion_dict[blending_index] = 0\n                        fusion_counter_dict[blending_index] = 0\n                    fusion_dict[blending_index] += blending_logit\n                    fusion_counter_dict[blending_index] += 1\n\n            for j in range(start_pos_1, end_pos_1 + 1):\n                match_pos[j] = idx\n\n            for token_index, token_logit in fusion_dict.items():\n                fusion_logits[idx].append(token_logit / fusion_counter_dict[token_index])\n                fusion_indices[idx].append(token_index)\n\n        for i in range(len(base_model_tokens)):\n            aligned_blending_model_per_step_logit = []\n            aligned_blending_model_per_step_index = []\n            if match_pos[i] == -1 or used[match_pos[i]]:\n                base_token = base_model_tokens[i]\n                aligned_blending_model_per_step_index.append(base_model_vocab[base_token])\n                aligned_blending_model_per_step_logit.append(1.0)\n            else:\n                pos = match_pos[i]\n                used[pos] = True\n                for blending_logit, blending_index in zip(fusion_logits[pos],\n                                                          fusion_indices[pos]):\n                    # the token corresponds to the logit and indices\n                    blending_t = blending_model_tokenizer.convert_ids_to_tokens([blending_index])[0].replace(\n                        blending_model_special_token, base_model_special_token)\n                    blending_t = blending_to_base_mapping[blending_t]\n                    if blending_t in base_model_vocab:\n                        aligned_index = base_model_vocab[blending_t]  # the index of the token in base model vocab\n                        if aligned_index not in aligned_blending_model_per_step_index:\n                            
aligned_blending_model_per_step_index.append(aligned_index)\n                            aligned_blending_model_per_step_logit.append(blending_logit)\n                    else:\n                        logger.warning(f\"blending_t: {blending_t} not in base_model_vocab!\")\n            aligned_blending_model_per_step_indices.append(aligned_blending_model_per_step_index)\n            aligned_blending_model_per_step_logits.append(aligned_blending_model_per_step_logit)\n    else:\n        raise ValueError(f\"{align_strategy} not implemented yet.\")\n\n    return aligned_blending_model_per_step_logits, aligned_blending_model_per_step_indices\n\n\ndef token_align(\n    base_model_logits_datasets,\n    blending_model_logits_dataset,\n    base_tokenizer,\n    blending_tokenizer,\n    blending_to_base_mapping,\n    blending_model_index,\n    batch_size=4,\n    preprocessing_num_workers=4,\n    skip_align=False,\n    align_strategy=\"dtw\",\n):\n    assert len(base_model_logits_datasets) == len(blending_model_logits_dataset)\n    base_model_blending_model_logits_datasets = base_model_logits_datasets.map(\n        align_blending_model_logits_with_base_model_logits,\n        batched=True,\n        batch_size=batch_size,\n        with_indices=True,\n        num_proc=preprocessing_num_workers,\n        load_from_cache_file=True,\n        fn_kwargs={\"blending_examples\": blending_model_logits_dataset,\n                   \"blending_to_base_mapping\": blending_to_base_mapping,\n                   \"base_tokenizer\": base_tokenizer,\n                   \"blending_tokenizer\": blending_tokenizer,\n                   \"blending_model_index\": blending_model_index,\n                   \"skip_align\": skip_align,\n                   \"align_strategy\": align_strategy},\n        keep_in_memory=True,\n        desc=\"Align blending model's logits with base model's logits.\",\n    )\n\n    return base_model_blending_model_logits_datasets\n"
  },
  {
    "path": "python/fate_llm/algo/fedmkt/token_alignment/vocab_mapping.py",
    "content": "#\n# NOTE: The find_best_mapping function is copied from FuseAI/FuseLLM\n# Copyright FuseAI/FuseLLM\n#\n#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport json\nimport editdistance\nimport tqdm\nimport multiprocessing\nimport logging\n\nfrom fate_llm.data.tokenizers.cust_tokenizer import get_tokenizer\nfrom fate_llm.algo.fedmkt.token_alignment.spectal_token_mapping import TOKENIZER_TO_SPECIAL_TOKEN\n\nlogger = logging.getLogger(__name__)\n\n\ndef find_best_mapping(x, base_tokens, blending_model_special_token, base_model_special_token, best_one=True):\n    \"\"\"code refer to https://github.com/fanqiwan/FuseAI/blob/main/FuseLLM/src/utils/vocab_mapping.py#L82\"\"\"\n    tmp_x = x.replace(blending_model_special_token, base_model_special_token)\n    if tmp_x in base_tokens:\n        return tmp_x, tmp_x\n    else:\n        if best_one:\n            return tmp_x, min([(y, editdistance.eval(tmp_x, y)) for y in base_tokens], key=lambda d: d[1])[0]\n        else:\n            token_and_distance = [(y, editdistance.eval(tmp_x, y)) for y in base_tokens]\n            min_distance = min(item[1] for item in token_and_distance)\n            shortest_distance_tokens = [item[0] for item in token_and_distance if item[1] == min_distance]\n            return tmp_x, shortest_distance_tokens\n\n\ndef get_vocab_mappings(model_name_or_path, candidate_model_name_or_path, 
vocab_mapping_save_path, num_processors=8):\n    ori_tokenizer = get_tokenizer(model_name_or_path)\n    candidate_tokenizer = get_tokenizer(candidate_model_name_or_path)\n\n    ori_special_tok = TOKENIZER_TO_SPECIAL_TOKEN[ori_tokenizer.__class__]\n    candidate_special_tok = TOKENIZER_TO_SPECIAL_TOKEN[candidate_tokenizer.__class__]\n\n    candidate_tokens = list(candidate_tokenizer.get_vocab().keys())\n\n    with multiprocessing.Pool(num_processors) as process_pool:\n        func_args = [(tok, candidate_tokens, ori_special_tok, candidate_special_tok) for tok in ori_tokenizer.get_vocab()]\n\n        vocab_mappings = dict(tqdm.tqdm(process_pool.starmap(find_best_mapping, func_args),\n                                        total=len(ori_tokenizer.get_vocab())))\n\n    with open(vocab_mapping_save_path, \"w\") as fout:\n        json.dump(vocab_mappings, fout)\n\n    return vocab_mappings\n"
  },
  {
    "path": "python/fate_llm/algo/fedmkt/utils/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n"
  },
  {
    "path": "python/fate_llm/algo/fedmkt/utils/dataset_sync_util.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport logging\nimport datasets\nimport torch\nimport torch.distributed as dist\nfrom fate_llm.algo.fedmkt.utils.vars_define import (\n    METRIC,\n    PER_STEP_LOGITS,\n    PER_STEP_INDICES,\n)\n\nlogger = logging.getLogger(__name__)\n\n\ndef sync_dataset(dataset, local_rank, world_size, device):\n    integer_keys_2d = [\"input_ids\", \"attention_mask\", \"labels\"]\n    integer_keys_3d = [PER_STEP_INDICES]\n    float_keys_3d = [PER_STEP_LOGITS]\n    float_keys_1d = [METRIC]\n\n    if local_rank == 0:\n        for key in integer_keys_2d + integer_keys_3d + float_keys_3d + float_keys_1d:\n            if key in integer_keys_2d or key in integer_keys_3d:\n                dtype = torch.int32\n            else:\n                dtype = torch.float64\n\n            values = dataset[key]\n            v_tensor = torch.tensor(values, dtype=dtype).cuda(device)\n            shape_tensor = torch.tensor(v_tensor.shape, dtype=torch.int32).cuda(device)\n            shape_tensors = [shape_tensor for _ in range(world_size)]\n            dist.scatter(shape_tensor, shape_tensors, async_op=False)\n\n            v_tensors = [v_tensor for _ in range(world_size)]\n            dist.scatter(v_tensor, v_tensors, async_op=False)\n\n        return dataset\n\n    else:\n        data_dict = dict()\n        for key in integer_keys_2d + integer_keys_3d + 
float_keys_3d + float_keys_1d:\n            if key in integer_keys_2d or key in integer_keys_3d:\n                dtype = torch.int32\n            else:\n                dtype = torch.float64\n\n            if key in integer_keys_2d:\n                shape_tensor = torch.tensor([0, 0], dtype=torch.int32).cuda(device)\n            elif key in float_keys_3d or key in integer_keys_3d:\n                shape_tensor = torch.tensor([0, 0, 0], dtype=torch.int32).cuda(device)\n            else:\n                shape_tensor = torch.tensor([0], dtype=torch.int32).cuda(device)\n\n            dist.scatter(shape_tensor, src=0, async_op=False)\n            v_tensor = torch.zeros(shape_tensor.tolist(), dtype=dtype).cuda(device)\n            dist.scatter(v_tensor, src=0, async_op=False)\n            data_dict[key] = v_tensor.tolist()\n\n        return datasets.Dataset.from_dict(data_dict)\n"
  },
  {
    "path": "python/fate_llm/algo/fedmkt/utils/generate_logit_utils.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport torch\nimport torch.nn.functional as F\nimport gc\nfrom fate_llm.algo.fedmkt.utils.vars_define import (\n    PER_STEP_LOGITS,\n    PER_STEP_INDICES,\n    METRIC\n)\n\n\nclass Metric(object):\n    @classmethod\n    def cal_metric(cls, logits, input_ids, attention_mask, labels, training_args):\n        if training_args.metric_type == \"ce\":\n            return cls.cal_ce(logits, input_ids, attention_mask, labels, training_args)\n        else:\n            raise NotImplemented(f\"metric={training_args.metric_type} is not implemented yet\")\n\n    @classmethod\n    def cal_ce(cls, logits, input_ids, attention_mask, labels, training_args):\n        metric = F.cross_entropy(logits[..., :-1, :].contiguous().view(-1, logits.size(-1)),\n                                 labels[..., 1:].contiguous().view(-1), reduction=\"none\").view(logits.size(0), -1)\n\n        metric = (metric * attention_mask[..., 1:]).sum(dim=-1) / attention_mask[..., 1:].sum(dim=-1)\n\n        return metric\n\n\nclass LogitsSelection(object):\n    @classmethod\n    def select_logits(cls, logits, training_args):\n        if training_args.top_k_strategy == \"highest\":\n            return cls.select_highest(logits, training_args.top_k_logits_keep)\n        else:\n            raise NotImplemented(f\"logits selection strategy={training_args.top_k_strategy} is 
not implemented\")\n\n    @classmethod\n    def select_highest(cls, logits, top_k_logits_keep):\n        top_k_logits, top_k_indices = torch.topk(logits.cuda(), k=top_k_logits_keep)\n        logits.cpu()\n\n        return top_k_logits, top_k_indices\n\n\ndef generate_pub_data_logits(inputs, model, training_args, data_collator):\n    input_keys = [\"attention_mask\", \"input_ids\", \"labels\"]\n    inputs_per_batched = [dict() for _ in range(len(inputs[input_keys[1]]))]\n    for key in input_keys:\n        if key not in inputs:\n            continue\n\n        for idx, _in in enumerate(inputs[key]):\n            inputs_per_batched[idx][key] = _in\n\n    if \"attention_mask\" not in inputs:\n        for idx in range(len(inputs_per_batched)):\n            inputs_per_batched[idx][\"attention_mask\"] = [1] * len(inputs_per_batched[idx][\"input_ids\"])\n\n    inputs_per_batched = data_collator(inputs_per_batched)\n\n    input_ids = inputs_per_batched[\"input_ids\"]\n    attention_mask = inputs_per_batched[\"attention_mask\"]\n    labels = inputs_per_batched[\"labels\"]\n\n    device = next(model.parameters()).device\n    if device.type == \"cuda\":\n        input_ids = input_ids.cuda(device)\n        attention_mask = attention_mask.cuda(device)\n        labels = labels.cuda(device)\n\n    model.eval()\n\n    with torch.no_grad():\n        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits\n\n        metric = Metric.cal_metric(logits, input_ids, attention_mask, labels, training_args)\n\n        input_ids.cpu()\n        del input_ids\n        attention_mask.cpu()\n        del attention_mask\n        labels.cpu()\n        del labels\n        logits.cpu()\n        metric.cpu()\n\n        if training_args.top_k_logits_keep is None:\n            raise ValueError(\"Please specify top_k_logits_keep, fulling save will leak to memory exceeds\")\n\n        selected_logits, selected_indices = LogitsSelection.select_logits(logits=logits, 
training_args=training_args)\n        selected_logits.cpu()\n        selected_indices.cpu()\n\n        inputs[PER_STEP_LOGITS] = selected_logits\n        inputs[PER_STEP_INDICES] = selected_indices\n        inputs[METRIC] = metric\n\n        del logits\n\n        gc.collect()\n\n    model.train()\n\n    return inputs\n"
  },
  {
    "path": "python/fate_llm/algo/fedmkt/utils/tokenizer_tool.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers import AutoConfig\n\n\ndef get_vocab_size(tokenizer_name_or_path):\n    if tokenizer_name_or_path is not None:\n        return AutoConfig.from_pretrained(tokenizer_name_or_path)\n"
  },
  {
    "path": "python/fate_llm/algo/fedmkt/utils/vars_define.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nPER_STEP_LOGITS = \"per_step_logits\"\nPER_STEP_INDICES = \"per_step_indices\"\nMETRIC = \"metric\"\n\nALIGNED_OTHER_LOGITS = \"aligned_other_logits\"\nALIGNED_OTHER_INDICES = \"aligned_other_indices\"\nALIGNED_OTHER_METRIC = \"aligned_other_metrice\"\n\nSELF_TARGET_DIST = \"llm_target_distribution\"\nOTHER_TARGET_DIST = \"slm_target_distribution\"\n\nINPUT_KEYS = {\"input_ids\", \"attention_mask\", \"labels\"}\n"
  },
  {
    "path": "python/fate_llm/algo/inferdpt/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/algo/inferdpt/_encode_decode.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nfrom fate.arch import Context\nfrom typing import List, Dict\nimport logging\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass EncoderDecoder(object):\n\n    def __init__(self, ctx: Context) -> None:\n        self.ctx = ctx\n\n    def encode(self,  docs: List[Dict[str, str]], format_template: str):\n        pass\n\n    def decode(self,  docs: List[Dict[str, str]], format_template: str ):\n        pass\n\n    def inference(self, docs: List[Dict[str, str]], inference_kwargs: dict = {}, format_template: str = None):\n        pass"
  },
  {
    "path": "python/fate_llm/algo/inferdpt/inferdpt.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nimport copy\nfrom jinja2 import Template\nfrom tqdm import tqdm\nfrom fate.arch import Context\nfrom typing import List, Dict, Union\nfrom fate.ml.nn.dataset.base import Dataset\nfrom fate_llm.algo.inferdpt.utils import InferDPTKit\nfrom openai import OpenAI\nimport logging\nfrom fate_llm.inference.inference_base import Inference\nfrom fate_llm.algo.inferdpt._encode_decode import EncoderDecoder\nfrom fate_llm.dataset.hf_dataset import HuggingfaceDataset\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass InferDPTClient(EncoderDecoder):\n\n    def __init__(self, ctx: Context, inferdpt_pertub_kit: InferDPTKit, local_inference_inst: Inference,  epsilon: float = 3.0,) -> None:\n        self.ctx = ctx\n        self.kit = inferdpt_pertub_kit\n        assert epsilon > 0, 'epsilon must be a float > 0'\n        self.ep = epsilon\n        self.comm_idx = 0\n        self.local_inference_inst = local_inference_inst\n\n    def encode(self, docs: List[Dict[str, str]], format_template: str = None, verbose=False, perturb_doc_key: str ='perturbed_doc') -> List[Dict[str, str]]:\n        \n        copy_docs = copy.deepcopy(docs)\n        if format_template is not None:\n            template = Template(format_template)\n        else:\n            template = None\n\n        for doc in tqdm(copy_docs):\n            if template is None:\n         
       rendered_doc = str(doc)\n            else:\n                rendered_doc = template.render(**doc)\n                if verbose:\n                    logger.debug('doc to perturb {}'.format(rendered_doc))\n            p_doc = self.kit.perturb(rendered_doc, self.ep)\n            doc[perturb_doc_key] = p_doc\n\n        return copy_docs\n        \n    def _remote_inference(self, docs: List[Dict[str, str]], \n                     inference_kwargs: dict = {},\n                     format_template: str = None, \n                     perturbed_response_key: str = 'perturbed_response',\n                     verbose=False\n                     ) -> List[Dict[str, str]]:\n\n        copy_docs = copy.deepcopy(docs)\n        if format_template is not None:\n            template = Template(format_template)\n        else:\n            template = None\n\n        infer_docs = []\n        for doc in tqdm(copy_docs):\n            if template is None:\n                rendered_doc = str(doc)\n            else:\n                rendered_doc = template.render(**doc)\n                if verbose:\n                    logger.debug('inference doc {}'.format(rendered_doc))\n\n            infer_docs.append(rendered_doc)\n            doc['perturbed_doc_with_instrcution'] = rendered_doc\n            \n        self.ctx.arbiter.put('client_data_{}'.format(self.comm_idx), (infer_docs, inference_kwargs))\n        perturb_resp = self.ctx.arbiter.get('pdoc_{}'.format(self.comm_idx))\n        self.comm_idx += 1\n        for pr, doc in zip(perturb_resp, copy_docs):\n             doc[perturbed_response_key] = pr\n\n        return copy_docs\n\n    def decode(self, p_docs: List[Dict[str, str]], instruction_template: str = None, decode_template: str = None, verbose=False, \n                     perturbed_response_key: str = 'perturbed_response', result_key: str = 'inferdpt_result',\n                     remote_inference_kwargs: dict = {}, local_inference_kwargs: dict = {}):\n\n        # inference 
using remote large models\n        docs_with_infer_result = self._remote_inference(p_docs, format_template=instruction_template, verbose=verbose, inference_kwargs=remote_inference_kwargs, perturbed_response_key=perturbed_response_key)\n        if decode_template is not None:\n            dt = Template(decode_template)\n            doc_to_decode = [dt.render(**i) for i in docs_with_infer_result]\n        else:\n            doc_to_decode = [str(i) for i in docs_with_infer_result]\n        # local model decode\n        final_result = self.local_inference_inst.inference(doc_to_decode, local_inference_kwargs)\n        for final_r, d in zip(final_result, docs_with_infer_result):\n            d[result_key] = final_r\n\n        return docs_with_infer_result\n\n    def inference(self, docs: Union[List[Dict[str, str]], HuggingfaceDataset],\n                encode_template: str,\n                instruction_template: str,\n                decode_template: str,\n                verbose: bool = False,\n                remote_inference_kwargs: dict = {},\n                local_inference_kwargs: dict = {},\n                perturb_doc_key: str = 'perturbed_doc',\n                perturbed_response_key: str = 'perturbed_response',\n                result_key: str = 'inferdpt_result',\n                ) -> List[Dict[str, str]]:\n        \n        assert (isinstance(docs, list) and isinstance(docs[0], dict)) or isinstance(docs, HuggingfaceDataset), 'Input doc must be a list of dict or HuggingfaceDataset'\n        # perturb doc\n        if isinstance(docs, HuggingfaceDataset):\n            docs = [docs[i] for i in range(len(docs))]\n        docs_with_p = self.encode(docs, format_template=encode_template, verbose=verbose, perturb_doc_key=perturb_doc_key)\n        logger.info('encode done')\n        # inference using perturbed doc\n        final_result = self.decode(\n            docs_with_p,\n            instruction_template,\n            decode_template,\n            verbose,\n       
     perturbed_response_key,\n            result_key,\n            remote_inference_kwargs,\n            local_inference_kwargs,\n        )\n        logger.info('decode done')\n        \n        return final_result\n\n\nclass InferDPTServer(object):\n\n    def __init__(self, ctx: Context, inference_inst: Inference) -> None:\n        \n        self.ctx = ctx\n        self.inference_inst = inference_inst\n        self.comm_idx = 0 \n\n    def inference(self, verbose=False):\n\n        client_data = self.ctx.guest.get('client_data_{}'.format(self.comm_idx))\n        perturbed_docs, inference_kwargs = client_data\n\n        if verbose:\n            logger.info('got data {}'.format(client_data))\n\n        logger.info('start inference')\n        rs_doc = self.inference_inst.inference(perturbed_docs, inference_kwargs)\n        self.ctx.guest.put('pdoc_{}'.format(self.comm_idx), rs_doc)\n        self.comm_idx += 1\n\n    def predict(self):\n        self.inference()\n"
  },
  {
    "path": "python/fate_llm/algo/inferdpt/init/_init.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nfrom fate.arch import Context\nfrom typing import Union\n\n\nclass InferInit(object):\n\n    def __init__(self, ctx: Context):\n        self.ctx = ctx\n\n    def get_inst(self):\n        pass\n\n"
  },
  {
    "path": "python/fate_llm/algo/inferdpt/init/default_init.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom fate_llm.algo.inferdpt.init._init import InferInit\nfrom fate_llm.inference.api import APICompletionInference\nfrom fate_llm.algo.inferdpt import inferdpt\nfrom fate_llm.algo.inferdpt.utils import InferDPTKit\nfrom fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n\n\nclass InferDPTAPIClientInit(InferInit):\n\n    api_url = ''\n    api_model_name = ''\n    api_key = 'EMPTY'\n    inferdpt_kit_path = ''\n    eps = 3.0\n\n    def __init__(self, ctx):\n        super().__init__(ctx)\n        self.ctx = ctx\n\n    def get_inst(self)-> InferDPTClient:\n        inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)\n        kit = InferDPTKit.load_from_path(self.inferdpt_kit_path)\n        inferdpt_client = inferdpt.InferDPTClient(self.ctx, kit, inference, epsilon=self.eps)\n        return inferdpt_client\n\n\nclass InferDPTAPIServerInit(InferInit):\n\n    api_url = ''\n    api_model_name = ''\n    api_key = 'EMPTY'\n\n    def __init__(self, ctx):\n        super().__init__(ctx)\n        self.ctx = ctx\n\n    def get_inst(self)-> InferDPTServer:\n        inference = APICompletionInference(api_url=self.api_url, model_name=self.api_model_name, api_key=self.api_key)\n        inferdpt_server = inferdpt.InferDPTServer(self.ctx,inference_inst=inference)\n        
return inferdpt_server\n"
  },
  {
    "path": "python/fate_llm/algo/inferdpt/utils.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\n\n\"\"\"\nParts of the codes are modified from https://github.com/mengtong0110/InferDPT\n\"\"\"\n\nfrom decimal import getcontext\nfrom transformers import AutoTokenizer\nimport numpy as np\nimport json\nimport tqdm\nfrom typing import List\n\n\ngetcontext().prec = 100\n\n\nclass NumpyEncoder(json.JSONEncoder):\n    \"\"\" Special json encoder for numpy types \"\"\"\n\n    def default(self, obj):\n        if isinstance(obj, np.integer):\n            return int(obj)\n        elif isinstance(obj, np.floating):\n            return float(obj)\n        elif isinstance(obj, np.ndarray):\n            return obj.tolist()\n        return json.JSONEncoder.default(self, obj)\n\n\ndef save_jsonl(filename, data):\n    with open(filename, 'w') as file:\n        for item in data:\n            json.dump(item, file)\n            file.write('\\n')\n\n\ndef create_sensitivity_of_embeddings(all_embedding_matrix):\n    n_dimensions = all_embedding_matrix.shape[1]\n    delta_f_new = np.zeros(n_dimensions)\n    for dim in tqdm.trange(n_dimensions):\n        dim_data = all_embedding_matrix[:, dim]\n        sorted_dim_data = np.sort(dim_data)\n        differences = sorted_dim_data[-1] - sorted_dim_data[0]\n        delta_f_new[dim] = differences\n    return delta_f_new\n\n\ndef create_sorted_embedding_matrix(token_list, similarity_matrix):\n    
token_2_sorted_distances = dict()\n    token_array = np.array(token_list)\n    for idx, token in tqdm.tqdm(enumerate(token_list)):\n        similarity_array = similarity_matrix[idx]\n        sorted_indices = np.argsort(similarity_array)[::-1]\n        token_2_sorted_distances[token] = [token_array[sorted_indices].tolist(), similarity_array[sorted_indices].tolist()]\n    return token_2_sorted_distances\n\n\ndef cosine_similarity_vectors(A, B):\n    dot_product = np.dot(A, B)\n    norm_a = np.linalg.norm(A)\n    norm_b = np.linalg.norm(B)\n    similarity = dot_product / (norm_a * norm_b)\n    return similarity\n\n\nclass InferDPTKit(object):\n\n    def __init__(self, token_to_vector_dict, sorted_similarities, delta_f, tokenizer) -> None:\n        self.token_to_vector_dict = token_to_vector_dict\n        self.sorted_similarities = sorted_similarities\n        self.delta_f = delta_f\n        self.tokenizer = tokenizer\n        assert len(token_to_vector_dict) == len(sorted_similarities)\n    \n\n    def save_to_path(self, path):\n\n        # make folder\n        import os\n        if not os.path.exists(path+'/inferdpt_kit'):\n            os.makedirs(path+'/inferdpt_kit')\n        \n        with open(path+'/inferdpt_kit/token_2_vector.json', 'w', encoding='utf8') as f:\n            json.dump(self.token_to_vector_dict, f, ensure_ascii=False, cls=NumpyEncoder)\n\n        with open(path+'/inferdpt_kit/sorted_similarities.json', 'w') as f:\n            json.dump(self.sorted_similarities, f, cls=NumpyEncoder)\n\n        with open(path+'/inferdpt_kit/delta_f.json', 'w') as f:\n            json.dump(self.delta_f, f, cls=NumpyEncoder)\n\n        self.tokenizer.save_pretrained(path+'/inferdpt_kit/tokenizer/')\n\n    @staticmethod\n    def make_inferdpt_kit_param(embedding_matrix: np.ndarray, token_list: List[str]):\n        \n        def cosine_simi(embedding_matrix1, embedding_matrix2):\n            dot_product = np.dot(embedding_matrix1, embedding_matrix2.T)\n            
norm_matrix1 = np.linalg.norm(embedding_matrix1, axis=1)\n            norm_matrix2 = np.linalg.norm(embedding_matrix2, axis=1)\n            similarity = dot_product / (np.outer(norm_matrix1, norm_matrix2))\n\n            return similarity\n        assert len(embedding_matrix) == len(token_list)\n        similarity_matrix = cosine_simi(embedding_matrix, embedding_matrix)\n        token_sorted_distance_dict = create_sorted_embedding_matrix(token_list, similarity_matrix)\n        delta_f_new = create_sensitivity_of_embeddings(embedding_matrix)\n\n        token_2_embedding = {}\n        for token, embedding in zip(token_list, embedding_matrix):\n            token_2_embedding[token] = embedding\n\n        return token_2_embedding, token_sorted_distance_dict, delta_f_new\n\n    @staticmethod\n    def load_from_path(path):\n        \n        with open(path+'/inferdpt_kit/token_2_vector.json', 'r', encoding='utf8') as f:\n            token_to_vector_dict = json.load(f)\n        with open(path+'/inferdpt_kit/sorted_similarities.json', 'r') as f:\n            sorted_similarities = json.load(f)\n        with open(path+'/inferdpt_kit/delta_f.json', 'r') as f:\n            delta_f = np.array(json.load(f))\n        tokenizer = AutoTokenizer.from_pretrained(path+'/inferdpt_kit/tokenizer/')\n        inferdpt_kit = InferDPTKit(token_to_vector_dict, sorted_similarities, delta_f, tokenizer)\n        return inferdpt_kit\n\n    def perturb(self, doc: str, epsilon: float) -> str:\n        \n        # epsilon > 0\n        assert epsilon > 0, \"epsilon should be greater than 0\"\n        tokenizer = self.tokenizer\n        tokens = tokenizer.tokenize(doc)\n        new_tokens = []\n        Delta_u = 1.0  \n        exp_factor = epsilon / (2 * Delta_u)\n        for origin_token in tokens:\n            if origin_token[0] == ' ':\n                origin_token = origin_token[1:]\n            origin_embed = self.token_to_vector_dict.get(origin_token, None)\n            if origin_embed is None:\n 
               new_tokens.append(origin_token)\n                continue\n            noise_embed = add_laplace_noise_to_vector(origin_embed, epsilon, self.delta_f)\n            similarity = cosine_similarity_vectors(origin_embed, noise_embed)\n            sorted_distances_for_token = self.sorted_similarities.get(origin_token, None)\n            if sorted_distances_for_token is None:\n                continue\n            token_only = sorted_distances_for_token[0]\n            similarity_only = sorted_distances_for_token[1]\n            arr = np.flip(similarity_only)\n            index = np.searchsorted(arr, similarity)\n            index = len(arr) - index\n            close_tokens = token_only[:index]\n            close_similarities = similarity_only[:index]\n            if len(close_tokens) == 0:\n                continue\n            unnormalized_probabilities = np.exp(exp_factor * np.array(close_similarities))\n            total_unnormalized_prob = np.sum(unnormalized_probabilities)\n            probabilities = unnormalized_probabilities / total_unnormalized_prob\n            selected_token = np.random.choice(close_tokens, p=probabilities)\n            new_tokens.append(selected_token)\n        token_ids = tokenizer.convert_tokens_to_ids(new_tokens)\n        sentence = tokenizer.decode(token_ids)\n        return sentence\n\n\ndef cosine_similarity_vectors(A, B):\n    dot_product = np.dot(A, B)\n    norm_a = np.linalg.norm(A)\n    norm_b = np.linalg.norm(B)\n    similarity = dot_product / (norm_a * norm_b)\n    return similarity\n\n\ndef add_laplace_noise_to_vector(vector, epsilon, delta_f_new):\n    vector = np.asarray(vector, dtype=np.longdouble)\n    if epsilon == 0:\n        beta_values = delta_f_new * 0\n    else:\n        beta_values = delta_f_new / (0.5 * epsilon)\n    noise = np.random.laplace(loc=0, scale=beta_values, size=len(beta_values))\n    noisy_vector = vector + noise\n\n    return noisy_vector\n\n\ndef perturb_sentence(sent,\n                   
  epsilon,\n                     tokenizer,\n                     token_to_vector_dict,\n                     sorted_distance_data,\n                     delta_f_new):\n    tokens = tokenizer.tokenize(sent)\n    new_tokens = []\n    Delta_u = 1.0  \n    exp_factor = epsilon / (2 * Delta_u)\n    for origin_token in tokens:\n        if origin_token[0] == ' ':\n            origin_token = origin_token[1:]\n        origin_embed = token_to_vector_dict.get(origin_token, None)\n        if origin_embed is None:\n            new_tokens.append(origin_token)\n            continue\n        noise_embed = add_laplace_noise_to_vector(origin_embed, epsilon, delta_f_new)\n        similarity = cosine_similarity_vectors(origin_embed, noise_embed)\n        sorted_distances_for_token = sorted_distance_data.get(origin_token, None)\n        if sorted_distances_for_token is None:\n            continue\n        token_only = sorted_distances_for_token[0]\n        similarity_only = sorted_distances_for_token[1]\n        arr = np.flip(similarity_only)\n        index = np.searchsorted(arr, similarity)\n        index = len(arr) - index\n        close_tokens = token_only[:index]\n        close_similarities = similarity_only[:index]\n        if len(close_tokens) == 0:\n            continue\n        unnormalized_probabilities = np.exp(exp_factor * np.array(close_similarities))\n        total_unnormalized_prob = np.sum(unnormalized_probabilities)\n        probabilities = unnormalized_probabilities / total_unnormalized_prob\n        selected_token = np.random.choice(close_tokens, p=probabilities)\n        new_tokens.append(selected_token)\n    token_ids = tokenizer.convert_tokens_to_ids(new_tokens)\n    sentence = tokenizer.decode(token_ids)\n    return sentence\n"
  },
  {
    "path": "python/fate_llm/algo/offsite_tuning/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/algo/offsite_tuning/offsite_tuning.py",
    "content": "from fate.ml.aggregator.base import Aggregator\nfrom fate_llm.algo.fedavg.fedavg import Seq2SeqFedAVGClient, Seq2SeqFedAVGServer, Seq2SeqTrainingArguments\nfrom fate.ml.nn.trainer.trainer_base import FedArguments, TrainingArguments\nfrom typing import List, Optional, Callable, Tuple\nfrom fate.arch import Context\nfrom torch.optim import Optimizer\nfrom torch.utils.data import DataLoader, Dataset\nfrom torch.optim.lr_scheduler import _LRScheduler\nfrom transformers.trainer_callback import TrainerCallback\nfrom torch.nn import Module\nfrom transformers import TrainerState, TrainerControl, PreTrainedTokenizer\nfrom fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningBaseModel\nimport logging\nimport torch\nimport torch.distributed as dist\nfrom transformers.modeling_utils import unwrap_model\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass OffsiteTuningTrainerClient(Seq2SeqFedAVGClient):\n    \n    def __init__(\n        self,\n        ctx: Context,\n        model: OffsiteTuningBaseModel,\n        training_args: Seq2SeqTrainingArguments,\n        fed_args: FedArguments,\n        train_set: Dataset,\n        val_set: Dataset = None,\n        optimizer: Optimizer = None,\n        scheduler: _LRScheduler = None,\n        data_collator: Callable = None,\n        tokenizer: Optional[PreTrainedTokenizer] = None,\n        callbacks: List[TrainerCallback] = [],\n        compute_metrics: Callable = None,\n        aggregate_model: bool = False,\n        save_trainable_weights_only: bool = False,\n        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n    ):\n        assert isinstance(model, OffsiteTuningBaseModel), \"model must be the subclass of OffsiteTuningBaseModel\"\n        if aggregate_model == False and fed_args is None:\n            fed_args = FedArguments()\n        elif fed_args is None:\n            raise ValueError(\"fed_args must be provided when aggregate_model is 
True\")\n\n        local_mode = True if not aggregate_model else False\n            \n        super().__init__(\n            ctx,\n            model,\n            training_args,\n            fed_args,\n            train_set,\n            val_set,\n            optimizer,\n            scheduler,\n            data_collator,\n            tokenizer,\n            callbacks,\n            compute_metrics,\n            local_mode,\n            save_trainable_weights_only,\n            preprocess_logits_for_metrics\n        )\n        self._aggregate_model = aggregate_model\n\n\n    def _share_model(self, model, args: Seq2SeqTrainingArguments, sync_trainable_only=True):\n\n        if args.local_rank == 0:\n            for p in model.parameters():\n                if (not sync_trainable_only) or (sync_trainable_only and p.requires_grad):\n                    scatter_list = [p.data for _ in range(args.world_size)]\n                    dist.scatter(p.data, scatter_list, async_op=False)\n        else:\n            for p in model.parameters():\n                if (not sync_trainable_only) or (sync_trainable_only and p.requires_grad):\n                    dist.scatter(p.data, src=0, async_op=False)\n\n    def on_train_begin(self, ctx: Context, aggregator: Aggregator, fed_args: FedArguments, \n                       args: TrainingArguments, model: Module = None, optimizer: Optimizer = None, scheduler: _LRScheduler = None, \n                       dataloader: Tuple[DataLoader]= None, control: TrainerControl= None, \n                       state: TrainerState = None, **kwargs):\n        \n        if args.local_rank == 0: # master\n            logger.info('receiving weights from server')\n            parameters_to_get = ctx.arbiter.get('sub_model_para')\n            model = unwrap_model(model)\n            model.load_submodel_weights(parameters_to_get)\n            logger.info('received submodel weights from the server')\n            if args.world_size > 1:\n                
self._share_model(model, args)\n                logger.info('sharing model parameters done')\n        else:\n            if args.world_size > 1:\n                model = unwrap_model(model)\n                self._share_model(model, args)\n                logger.info('sharing model parameters done')\n\n    def on_federation(\n        self,\n        ctx: Context,\n        aggregator,\n        fed_args: FedArguments,\n        args: TrainingArguments,\n        model: Optional[OffsiteTuningBaseModel] = None,\n        optimizer: Optional[Optimizer] = None,\n        scheduler: Optional[_LRScheduler] = None,\n        dataloader: Optional[Tuple[DataLoader]] = None,\n        control: Optional[TrainerControl] = None,\n        state: Optional[TrainerState] = None,\n        **kwargs,\n    ):\n        if self._aggregate_model:\n            aggregator.model_aggregation(ctx, model)\n\n\n    def on_train_end(self, ctx: Context, aggregator: Aggregator, fed_args: FedArguments, \n                    args: TrainingArguments, model: OffsiteTuningBaseModel = None, optimizer: Optimizer = None, scheduler: _LRScheduler = None, \n                    dataloader: Tuple[DataLoader]= None, control: TrainerControl= None, \n                    state: TrainerState = None, **kwargs):\n\n        if args.local_rank == 0:\n            if args.world_size > 1:\n                model = unwrap_model(model)\n            return_weights = model.get_submodel_weights(with_emulator=False)\n            ctx.arbiter.put('trained_sub_model_para', return_weights)\n            logger.info('weights sent back to the server')\n\n    def init_aggregator(self, ctx: Context, fed_args: FedArguments):\n        if self._aggregate_model:\n            return super().init_aggregator(ctx, fed_args)\n        else:\n            return None\n\n\nclass OffsiteTuningTrainerServer(Seq2SeqFedAVGServer):\n    \n    def __init__(self, ctx: Context, model: OffsiteTuningBaseModel, aggregate_model=False) -> None:\n        
self._aggregate_model = aggregate_model\n        super().__init__(ctx, local_mode=False)\n        assert isinstance(model, OffsiteTuningBaseModel), \"model must be the subclass of OffsiteTuningBaseModel\"\n        self.model = model\n\n    def on_train_begin(self, ctx: Context, aggregator: Aggregator):\n        logger.info('sending weights to clients')\n        parameters_to_send = self.model.get_submodel_weights()\n        ctx.guest.put('sub_model_para', parameters_to_send)\n        if any(p.role=='host' for p in ctx.parties):\n            ctx.hosts.put('sub_model_para', parameters_to_send)\n\n    def on_train_end(self, ctx: Context, aggregator: Aggregator):\n        parameters_to_get = ctx.guest.get('trained_sub_model_para')\n        self.model.load_submodel_weights(parameters_to_get, with_emulator=False)\n        logger.info('received trained submodel weigths from the client')\n\n    def on_federation(self, ctx: Context, aggregator, agg_iter_idx: int):\n        if self._aggregate_model:\n            aggregator.model_aggregation(ctx)\n        else:\n            logger.info('skip aggregation')\n\n    def init_aggregator(self, ctx):\n        if self._aggregate_model:\n            return super().init_aggregator(ctx)\n        else:\n            return None\n        \n    def train(self):\n\n        if self._aggregate_model:\n            super().train()\n        else:\n            # do nothing but send the submodel weights to the client\n            # and then aggregate the weights from the client\n            self.on_init_end(self.ctx, aggregator=self.aggregator)\n            self.on_train_begin(self.ctx, aggregator=self.aggregator)\n            self.on_train_end(self.ctx, aggregator=self.aggregator)\n\n    def save_model(\n        self,\n        output_dir: Optional[str] = None,\n        state_dict=None\n    ):\n        import torch\n        import os\n        if not os.path.exists(output_dir):\n            os.makedirs(output_dir)\n        
torch.save(self.model.state_dict(), output_dir + '/pytorch_model.bin')\n"
  },
  {
    "path": "python/fate_llm/algo/ppc-gpt/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/data/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/data/data_collator/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n"
  },
  {
    "path": "python/fate_llm/data/data_collator/cust_data_collator.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers.data import data_collator\nfrom ..tokenizers.cust_tokenizer import get_tokenizer\n\n\ndef get_data_collator(data_collator_name,\n                      tokenizer_name_or_path=None,\n                      pad_token=None,\n                      bos_token=None,\n                      eos_token=None,\n                      pad_token_id=None,\n                      bos_token_id=None,\n                      eos_token_id=None,\n                      trust_remote_code=False, **kwargs):\n    if not hasattr(data_collator, data_collator_name):\n        support_collator_list = list(filter(lambda module_name: \"collator\" in module_name.lower(), dir(data_collator)))\n        raise ValueError(f\"data_collator's name={data_collator_name} is not in support list={support_collator_list}\")\n\n    tokenizer = get_tokenizer(tokenizer_name_or_path=tokenizer_name_or_path,\n                              pad_token=pad_token,\n                              bos_token=bos_token,\n                              eos_token=eos_token,\n                              pad_token_id=pad_token_id,\n                              bos_token_id=bos_token_id,\n                              eos_token_id=eos_token_id,\n                              trust_remote_code=trust_remote_code)\n\n    return getattr(data_collator, data_collator_name)(tokenizer, 
**kwargs)\n\n\ndef get_seq2seq_data_collator(tokenizer_name_or_path, **kwargs):\n    return get_data_collator(\"DataCollatorForSeq2Seq\", tokenizer_name_or_path=tokenizer_name_or_path, **kwargs)\n"
  },
  {
    "path": "python/fate_llm/data/data_collator/fedcot_collator.py",
    "content": "from transformers import DataCollatorForSeq2Seq \nfrom transformers import AutoTokenizer\nimport pandas as pd\n\nclass PrefixDataCollator(DataCollatorForSeq2Seq):\n    def __call__(self, features, return_tensors=None):\n        features_df = pd.DataFrame(features)\n        cot = super().__call__(list(features_df['predict']), return_tensors)\n        label = super().__call__(list(features_df['rationale']), return_tensors)\n\n        return {\n            'predict': cot,\n            'rationale': label\n        }\n\n\ndef get_prefix_data_collator(tokenizer_name_or_path):\n    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)\n    data_collator = PrefixDataCollator(tokenizer)\n    return data_collator\n"
  },
  {
    "path": "python/fate_llm/data/tokenizers/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n"
  },
  {
    "path": "python/fate_llm/data/tokenizers/cust_tokenizer.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers import AutoTokenizer\n\n\ndef get_tokenizer(\n    tokenizer_name_or_path,\n    trust_remote_code=False,\n    padding_side=None,\n    pad_token=None,\n    bos_token=None,\n    eos_token=None,\n    pad_token_id=None,\n    bos_token_id=None,\n    eos_token_id=None,\n    add_eos_token=True,\n):\n    tokenizer = AutoTokenizer.from_pretrained(\n        tokenizer_name_or_path,\n        trust_remote_code=trust_remote_code,\n        add_eos_token=add_eos_token\n    )\n    if padding_side is not None:\n        tokenizer.padding_side = padding_side\n    if pad_token is not None:\n        tokenizer.add_special_tokens({'pad_token': pad_token})\n    if bos_token is not None:\n        tokenizer.add_special_tokens({'bos_token': bos_token})\n    if eos_token is not None:\n        tokenizer.add_special_tokens({\"eos_token\": eos_token})\n    if pad_token_id is not None:\n        tokenizer.pad_token_id = pad_token_id\n    if bos_token_id is not None:\n        tokenizer.bos_token_id = bos_token_id\n    if eos_token_id is not None:\n        tokenizer.eos_token_id = eos_token_id\n\n    if \"llama\" in tokenizer_name_or_path.lower() or \"gpt2\" in tokenizer_name_or_path.lower():\n        tokenizer.pad_token = tokenizer.eos_token\n\n    return tokenizer\n"
  },
  {
    "path": "python/fate_llm/dataset/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n"
  },
  {
    "path": "python/fate_llm/dataset/data_config/__init__.py",
    "content": "import os\n# absolute path to current directory\nparent_dir = os.path.dirname(os.path.realpath(__file__))\n\nDATA_CONFIG_TEMPLATE = {\"ag_news\": os.path.join(parent_dir, \"default_ag_news.yaml\"),\n                        \"yelp_review\": os.path.join(parent_dir, \"default_yelp_review.yaml\"),}"
  },
  {
    "path": "python/fate_llm/dataset/data_config/default_ag_news.yaml",
    "content": "dataset_kwargs:\n  data_files: ag_news_review/AGnews/train.json\ndataset_path: json\ndoc_to_target: '{{label}}'\nmetric_list:\n- aggregation: mean\n  higher_is_better: true\n  metric: accuracy\noutput_type: generate_until\ntask: ag-news\nvalidation_split: train\nlabel_key: label\ntext_key: text\nsub_domain: AGnews\nfew_shot_num_per_label: 2\ntokenize_format: \"Product type: {{sub_domain}} | Text Category: {{label}}\"\nfew_shot_format: \"- <Category>: {{label}}.\\n- <News>: {{text}}\\n\\n\"\naugment_format: \"The news' topics belong to the following 4 categories: 0.world 1.sports 2.business 3.science and technology. Please generate news according to the following format, bearing in mind that the generated results should not resemble the examples, but should align with the specified category: \\n\"\ntext_with_label_format: \"******\\n {{i}}.\\nNews: {{text}}\\nCategory: {{label}}.\\n\"\nfilter_format: \"I will give you some news samples with their categories, The news' topics belong to the following 4 categories: 0.world 1.sports 2.business 3.science and technology. the samples are delimited by '******':\\n {text_with_label} Please filter out texts that are ambiguous, do not belong to news or do not meet the categories, and leave news texts that meet the categories.\\n You should also filter out news text that are too similar to other samples and keep the most representative ones. Your answer should begin with 'The eligible samples:\\n\\n' and the indexes of the texts you choose, use spaces to separate the indexes and do not provide duplicate indices or indices that exceed the maximum index of samples.\"\nlabel_list:\n  - 'world'\n  - 'sports'\n  - 'business'\n  - 'science and technology'"
  },
  {
    "path": "python/fate_llm/dataset/data_config/default_yelp_review.yaml",
    "content": "dataset_kwargs:\n  data_files: yelp_review/Health/train.json\ndataset_path: json\ndoc_to_target: '{{label}}'\nmetric_list:\n- aggregation: mean\n  higher_is_better: true\n  metric: accuracy\noutput_type: generate_until\ntask: yelp-review\nlabel_key: stars\ntext_key: text\nvalidation_split: train\nsub_domain: Health\nfew_shot_num_per_label: 2\ntokenize_format: \"Product type: {{sub_domain}} | Review Score: {{label}}\"\ntext_with_label_format: \"******\\n {{i}}.\\nReview: {{text}}\\nRating stars: {{label}}.\\n\"\nfew_shot_format: \"******\\n- <Rating>: {{label}} stars.\\n- <Review>: {{text}}\\n\\n\"\naugment_format: \"The reviews are rated from 1 to 5 stars, with 1 being the worst, 3 being neutral and 5 being the best. Please generate more similar samples for each rating star about the Health domain as shown in the following format, bearing in mind that the generated results should not copy or resemble the examples, and should align with the {{sub_domain}} domain and the rating stars.\\nThe examples are delimited by '******'.\"\nfilter_format: \"I will give you some customer review text samples with their rating stars, these samples are indexed starting from 0, the samples are delimited by '******':\\n {{text_with_label}}. These reviews gradually shift from negative to positive from 1 star to 5 stars. 1 star represents the worst, 2 stars are better than 1 star, but still indicate a negative review. 3 stars represent a neutral review. 4 stars indicate a positive review, but less positive than 5 stars. 5 stars represent perfection.\\n Please filter out text that does not belong to customer reviews or does not meet the rating stars, and leave review texts that meet the labels.\\n You should also filter out text that are too similar to other samples and keep the most representative ones. 
Your answer should begin with 'The eligible samples:\\n\\n' and the indexes of the texts you choose, use spaces to separate the indexes and do not provide duplicate indices or indices that exceed the maximum index of samples.\"\nlabel_list:\n  - 1\n  - 2\n  - 3\n  - 4\n  - 5"
  },
  {
    "path": "python/fate_llm/dataset/fedcot_dataset.py",
    "content": "from fate_llm.dataset.input_output_dataset import InputOutputDataset\nfrom transformers.trainer_pt_utils import LabelSmoother\nfrom typing import List, Dict, Union, Literal\nimport logging\nfrom jinja2 import Template\nfrom transformers import AutoTokenizer\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass PrefixDataset(InputOutputDataset):\n\n    def __init__(self, \n                tokenizer_path,\n                predict_input_template: str,\n                predict_output_template: str,\n                rationale_input_template: str,\n                rationale_output_template: str,\n                max_input_length: int = 256, \n                max_target_length: int = 256,\n                load_from: Literal['jsonl', 'hf_load_from_disk', 'hf_load_dataset'] = 'hf_load_from_disk',\n                split_key: str = None\n                ):\n\n        super().__init__(tokenizer_path, predict_input_template, predict_output_template, max_input_length, max_target_length, load_from, split_key)\n        self.r_input_template = Template(rationale_input_template)\n        self.r_output_template = Template(rationale_output_template)\n\n    def load_rationale(self, result_list, key='rationale'):\n        for d, r in zip(self.dataset, result_list):\n            d[key] = r\n\n    def get_str_item(self, i) -> dict:\n\n        data_item = self.dataset[i]\n        p_in = self.input_template.render(data_item)\n        p_out = self.output_template.render(data_item)\n        r_in = self.r_input_template.render(data_item)\n        r_out = self.r_output_template.render(data_item)\n        ret_dict = {\n            'predict':{\n                'input': p_in,\n                'output': p_out\n            },\n            'rationale':{\n                'input': r_in,\n                'output': r_out\n            }\n        }\n        return ret_dict\n    \n    def get_tokenized_item(self, i) -> dict:   \n\n        str_item = self.get_str_item(i)\n        ret_dict = 
{\n            'predict': self._process_item(str_item['predict']),\n            'rationale': self._process_item(str_item['rationale'])\n        }\n\n        return ret_dict\n"
  },
  {
    "path": "python/fate_llm/dataset/flex_dataset.py",
    "content": "#\n#  Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nimport logging\nimport pickle\nimport re\nfrom datasets import load_dataset\nfrom fastchat.model import get_conversation_template\nfrom jinja2 import Template\nfrom ruamel import yaml\nfrom transformers import AutoTokenizer\nfrom typing import Union, Literal\n\nfrom fate.ml.nn.dataset.base import Dataset\nfrom fate_llm.dataset.data_config import DATA_CONFIG_TEMPLATE\n\nlogger = logging.getLogger(__name__)\n\n\n\"\"\"\nImplementation of FDKT augmentation process, adopted from https://arxiv.org/abs/2405.14212\n\"\"\"\n\n\ndef get_jinjax_placeholders(jinjax_text, placeholder_count=2):\n    pattern = r\"<([^>]+)>\"\n    matches = re.findall(pattern, jinjax_text)\n\n    return matches[:placeholder_count]\n\n\ndef regex_replace(string, pattern, repl, count: int = 0):\n    \"\"\"\n    adopted from lm-evaluation-harness/lm-eval/utils.py for offline use\n    Parameters\n    ----------\n    string\n    pattern\n    repl\n    count\n\n    Returns\n    -------\n\n    \"\"\"\n    return re.sub(pattern, repl, string, count=count)\n\n\ndef apply_template(template, data):\n    \"\"\"\n    adopted from lm-evaluation-harness/lm-eval/utils.py for offline use\n    Parameters\n    ----------\n    template\n    data\n\n    Returns\n    -------\n\n    \"\"\"\n    return Template(template).render(data)\n\n\ndef tokenize_flex_dataset(raw_datasets, 
tokenizer, sub_domain, tokenize_format, text_key, label_key, data_part=\"train\",\n                          save_path=None, max_prompt_len=256):\n    tokenizer.pad_token = tokenizer.eos_token\n    column_names = raw_datasets[data_part].column_names\n\n    def tokenize_function(examples):\n        texts = tokenizer(examples[text_key])\n\n        label_processed = [apply_template(tokenize_format,{\"sub_domain\": sub_domain,\"label\": label})\n                           for label in examples[label_key]]\n        labels = tokenizer(label_processed)\n        input_ids = [i2 + i1 for i1, i2 in zip(texts['input_ids'], labels['input_ids'])]\n        attention_mask = [i2 + i1 for i1, i2 in zip(texts['attention_mask'], labels['attention_mask'])]\n\n        \"\"\"\n        cut off max prompt length\n        \"\"\"\n        input_ids = [t[: max_prompt_len] for t in input_ids]\n        attention_mask = [t[: max_prompt_len] for t in attention_mask]\n\n        out = {\"input_ids\": input_ids,\n               \"attention_mask\": attention_mask,\n               \"labels\": input_ids}\n        return out\n\n    tokenized_datasets = raw_datasets.map(\n        tokenize_function,\n        batched=True,\n        num_proc=4,\n        remove_columns=column_names,\n        desc=\"Running tokenizer on dataset\",\n    )\n\n    if save_path is not None:\n        tokenized_datasets.save_to_disk(save_path)\n\n    return tokenized_datasets\n\n\nclass FlexDataset(Dataset):\n    def __init__(self,\n                 tokenizer_name_or_path,\n                 dataset_name: str,\n                 load_from: Literal['json'] = 'json',\n                 data_part: str = None,\n                 config: Union[dict, str] = None,\n                 need_preprocess: bool = True,\n                 random_state: int = None,\n                 max_prompt_len: int = 256,\n                 select_num: int = None,\n                 few_shot_num_per_label: int = None\n                 ):\n\n        
super().__init__()\n        self.tokenizer = None\n        self.tokenizer_name_or_path = tokenizer_name_or_path\n        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=True)\n        self.dataset_name = dataset_name\n        if self.dataset_name and config is None:\n            config = DATA_CONFIG_TEMPLATE.get(self.dataset_name, \"\")\n        self.load_from = load_from\n        self.data_part = data_part\n        self.random_state = random_state\n        self.need_preprocess = need_preprocess\n        self.max_prompt_len = max_prompt_len\n        self.select_num = select_num\n        self.dataset = None\n        self.ds = None\n        self.label_key = None\n        self.text_key = None\n        self.augment_format = None\n        self.filter_format = None\n        self.few_shot_format = None\n        self.tokenize_format = None\n        self.sub_domain = None\n        self.label_list = None\n        self.text_with_label_format = None\n        self.few_shot_num_per_label = few_shot_num_per_label\n        self.config = config\n        if isinstance(config, str):\n            with open(config, 'r') as f:\n                self.config = yaml.safe_load(f)\n        self.parse_config()\n\n    def parse_config(self, config=None):\n        if config is None:\n            config = self.config\n        self.label_key = config.get(\"label_key\", None)\n        self.text_key = config.get(\"text_key\", None)\n        self.augment_format = config.get(\"augment_format\", None)\n        self.filter_format = config.get(\"filter_format\", None)\n        self.tokenize_format = config.get(\"tokenize_format\", None)\n        self.sub_domain = config.get(\"sub_domain\", None)\n        self.label_list = config.get(\"label_list\", None)\n        self.few_shot_format = config.get(\"few_shot_format\", None)\n        self.text_with_label_format = config.get(\"text_with_label_format\", None)\n        if self.few_shot_num_per_label is None:\n     
       self.few_shot_num_per_label = config.get(\"few_shot_num_per_label\", 2)\n\n    def get_generate_prompt(self, tokenize=True, return_tensors=\"pt\"):\n        prompt_list = [apply_template(self.tokenize_format,\n                                      {\"sub_domain\": self.sub_domain,\n                                       \"label\": label}) for label in self.label_list]\n        if tokenize:\n            tokenized_prompts = self.tokenizer(prompt_list, return_tensors=return_tensors)\n            prompt_list = tokenized_prompts['input_ids']\n\n        return {label: prompt for label, prompt in zip(self.label_list, prompt_list)}\n\n    @staticmethod\n    def construct_prompt_list(samples_dict, num_shot_per_label, prompt_num, format_template, random_state=None):\n        from sklearn.utils import resample\n        from collections import deque\n\n        label_samples = {label: deque(resample(samples,\n                                               replace=False,\n                                               n_samples=len(samples))) for label, samples in samples_dict.items()}\n        def get_samples_for_label(label):\n            samples = []\n            while len(samples) < num_shot_per_label:\n                remaining_needed = num_shot_per_label - len(samples)\n                if len(label_samples[label]) < remaining_needed:\n                    batch_samples = list(label_samples[label])\n                    samples.extend(batch_samples)\n                    # reset to allow repetition\n                    label_samples[label] = deque(resample(samples_dict[label],\n                                                          replace=False,\n                                                          n_samples=len(samples_dict[label])))\n                else:\n                    batch_samples = [label_samples[label].popleft() for _ in range(remaining_needed)]\n                    samples.extend(batch_samples)\n            return samples\n\n        result = []\n  
      for _ in range(prompt_num):\n            prompt = ''\n            for label in samples_dict.keys():\n                samples = get_samples_for_label(label)\n                for text in samples:\n                    prompt += apply_template(format_template, {\"text\": text, \"label\": label})\n            result.append(prompt)\n        return result\n\n    @staticmethod\n    def group_text_label_list(text_list, label_list):\n        group_data = [{\"text\": text, \"label\": label} for text, label in zip(text_list, label_list)]\n        return group_data\n\n    def prepare_few_shot(self, text_list, label_list, aug_prompt_num):\n        from collections import defaultdict\n        data_dict = defaultdict(list)\n        for text, label in zip(text_list, label_list):\n            # in case extra labels are present, ignore\n            if label in self.label_list:\n                data_dict[label].append(text)\n        few_shot_list = FlexDataset.construct_prompt_list(samples_dict=data_dict,\n                                                          num_shot_per_label=self.few_shot_num_per_label,\n                                                          prompt_num=aug_prompt_num,\n                                                          format_template=self.few_shot_format,\n                                                          random_state=self.random_state)\n\n        return few_shot_list\n\n    def prepare_augment(self, text_list, label_list, aug_prompt_num):\n        few_shot_samples = self.prepare_few_shot(text_list, label_list, aug_prompt_num)\n        result = []\n        instruction = apply_template(self.augment_format, {\"sub_domain\": self.sub_domain})\n        for i, sample in enumerate(few_shot_samples):\n            query =  instruction + '\\n' + sample\n            formatted_query = self.apply_chat_template(query)\n            result.append(formatted_query)\n        return result\n\n    def abstract_from_augmented(self, sample_list):\n        
label_key, text_key = get_jinjax_placeholders(self.few_shot_format, 2)\n        res = {'inputs': [], 'labels': []}\n        for sample in sample_list:\n            data_list = sample.split('\\n\\n-')\n            for entry in data_list:\n                temp = entry.split(f\"<{text_key}>:\")\n                # print(f\"temp: {temp}\")\n                if len(temp) == 2 and f\"<{label_key}>\" in temp[0]:\n                    label_str, input_str = temp\n                    label = label_str.split(f\"<{label_key}>:\")[1].strip()\n                    if isinstance(self.label_list[0], int) and label[0].isdigit():\n                        label = int(label[0])\n                    elif isinstance(self.label_list[0], float) and re.match(r'^\\d+\\.\\d*?$', label):\n                        label = float(label[0])\n                    # abstracted label value does not match the original label type\n                    elif isinstance(self.label_list[0], int) or isinstance(self.label_list[0], float):\n                        continue\n                    text = input_str.replace('</s>', '').rstrip('*')\n                    text = text.strip()\n                    res['inputs'].append(text)\n                    res['labels'].append(label)\n        # print(f\"res: {res}\")\n        return res\n\n    def prepare_query_to_filter_clustered(self, clustered_sentences_list, clustered_labels_list):\n        prompt_list = []\n        for clustered_sentences, clustered_labels in zip(clustered_sentences_list, clustered_labels_list):\n            text_with_label = ''\n            for i in range(len(clustered_sentences)):\n                formatted_entry = apply_template(self.text_with_label_format, {\"i\": i,\n                                                                               \"text\": clustered_sentences[i],\n                                                                               \"label\": clustered_labels[i]})\n                text_with_label += formatted_entry\n    
        cluster_query = apply_template(self.filter_format, {\"text_with_label\": text_with_label})\n            prompt_list.append(self.apply_chat_template(cluster_query))\n        return prompt_list\n\n    def parse_clustered_response(self, clustered_sentence, clustered_labels, response_list):\n        \"\"\"\n        Parse the response from the clustering model and filter the data per cluster.\n        :param clustered_sentence: nested list of clustered sentences\n        :param clustered_labels: nested list of clustered labels\n        :param response_list: list of responses from the clustering model\n        \"\"\"\n        def parse_response(response):\n            pattern = r'The eligible samples:\\s*((?:\\b\\d+\\b[\\s.,]*)+)'\n            matches = re.search(pattern, response, re.MULTILINE)\n            if matches:\n                digits = [int(i) for i in re.findall(r'\\b\\d+\\b', matches.group())]\n            else:\n                digits = []\n            return list(set(digits))\n\n        filtered_text_list = []\n        filtered_label_list = []\n        for i in range(len(clustered_sentence)):\n            parsed_response = parse_response(response_list[i])\n            for idx in parsed_response:\n                if idx < len(clustered_sentence[i]):\n                    filtered_label_list.append(clustered_labels[i][idx])\n                    filtered_text_list.append(clustered_sentence[i][idx])\n        return filtered_text_list, filtered_label_list\n\n    @staticmethod\n    def group_data_list(data_list, text_key, label_key):\n        inputs = [entry[text_key] for entry in data_list]\n        labels = [entry[label_key] for entry in data_list]\n        data_dict = {text_key: inputs, label_key: labels}\n        return data_dict\n\n    def load(self, path):\n        local_data = load_dataset('json', data_files={self.data_part: path})\n        self.dataset = local_data\n        if not self.need_preprocess:\n            self.ds = local_data\n        
else:\n            tokenized_ds = tokenize_flex_dataset(\n                raw_datasets=local_data,\n                tokenizer=self.tokenizer,\n                sub_domain=self.sub_domain,\n                tokenize_format=self.tokenize_format,\n                text_key=self.text_key,\n                label_key=self.label_key,\n                max_prompt_len=self.max_prompt_len\n            )\n            self.ds = tokenized_ds[self.data_part]\n\n        if self.select_num is not None:\n            self.ds = self.ds.select(range(self.select_num))\n\n    def apply_chat_template(self, query):\n        tokenizer = self.tokenizer\n\n        if \"llama-3\" in self.tokenizer_name_or_path.lower():\n            msg = [\n                {\"role\": \"system\", \"content\": \"You are a helpful assistant. \"},\n                {\"role\": \"user\", \"content\": query}\n            ]\n            prompt = tokenizer.apply_chat_template(msg, add_generation_prompt=True, tokenize=False)\n        else:\n            conv = get_conversation_template(self.tokenizer_name_or_path)\n            conv.append_message(conv.roles[0], query)\n            conv.append_message(conv.roles[1], None)\n            prompt = conv.get_prompt()\n\n        return prompt\n\n    def get_raw_dataset(self):\n        return self.dataset\n\n    def __len__(self):\n        return len(self.ds)\n\n    def get_item(self, i):\n        return self.dataset[self.data_part][i]\n\n    def get_item_dict(self, i):\n        return {\"text\": self.dataset[self.data_part][self.text_key][i],\n                \"label\": self.dataset[self.data_part][self.label_key][i]}\n\n    def __getitem__(self, i) -> dict:\n        return self.ds[i]\n"
  },
  {
    "path": "python/fate_llm/dataset/hf_dataset.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport os\nfrom typing import Optional, Union, Sequence, Mapping, Dict\n\nfrom datasets import load_dataset, Features, Split, DownloadConfig, DownloadMode, VerificationMode, Version, load_from_disk\nfrom transformers import AutoTokenizer\n\nfrom fate.ml.nn.dataset.base import Dataset\n\n# avoid tokenizer parallelism\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n\nclass HuggingfaceDataset(Dataset):\n    \"\"\"\n    A dataset class for huggingface datasets\n    \"\"\"\n\n    def __init__(\n            self,\n            name: Optional[str] = None,\n            data_dir: Optional[str] = None,\n            data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None,\n            split: Optional[Union[str, Split]] = None,\n            cache_dir: Optional[str] = None,\n            features: Optional[Features] = None,\n            download_config: Optional[DownloadConfig] = None,\n            download_mode: Optional[Union[DownloadMode, str]] = None,\n            verification_mode: Optional[Union[VerificationMode, str]] = None,\n            ignore_verifications=\"deprecated\",\n            keep_in_memory: Optional[bool] = None,\n            save_infos: bool = False,\n            revision: Optional[Union[str, Version]] = None,\n            token: Optional[Union[bool, str]] = None,\n          
  use_auth_token=\"deprecated\",\n            task=\"deprecated\",\n            streaming: bool = False,\n            num_proc: Optional[int] = None,\n            storage_options: Optional[Dict] = None,\n            trust_remote_code: bool = None,\n            tokenizer_params: Optional[Dict] = None,\n            tokenizer_apply_params: Optional[Dict] = None,\n            load_from_disk: Optional[bool] = False,\n            inplace_load: Optional[bool] = True,\n            data_split_key: Optional[str] = None,\n            **config_kwargs,\n    ):\n        self.name = name\n        self.data_dir = data_dir\n        self.data_files = data_files\n        self.split = split\n        self.cache_dir = cache_dir\n        self.features = features\n        self.download_config = download_config\n        self.download_mode = download_mode\n        self.verification_mode = verification_mode\n        self.ignore_verifications = ignore_verifications\n        self.keep_in_memory = keep_in_memory\n        self.save_infos = save_infos\n        self.revision = revision\n        self.token = token\n        self.use_auth_token = use_auth_token\n        self.task = task\n        self.streaming = streaming\n        self.num_proc = num_proc\n        self.storage_options = storage_options\n        self.trust_remote_code = trust_remote_code\n        self.tokenizer_params = tokenizer_params\n        self.tokenizer_apply_params = tokenizer_apply_params\n        self.config_kwargs = config_kwargs\n        self.load_from_disk = load_from_disk\n        self.inplace_load = inplace_load\n        self.data_split_key = data_split_key\n        self.ds = None\n\n        super(HuggingfaceDataset, self).__init__()\n\n    def load(self, file_path):\n        if not self.load_from_disk:\n            ds = load_dataset(path=file_path, name=self.name, data_dir=self.data_dir, data_files=self.data_files,\n                                split=self.split, cache_dir=self.cache_dir, features=self.features,\n    
                            download_config=self.download_config, download_mode=self.download_mode,\n                                verification_mode=self.verification_mode, ignore_verifications=self.ignore_verifications,\n                                keep_in_memory=self.keep_in_memory, save_infos=self.save_infos, revision=self.revision,\n                                token=self.token, use_auth_token=self.use_auth_token, task=self.task,\n                                streaming=self.streaming, num_proc=self.num_proc, storage_options=self.storage_options,\n                                trust_remote_code=self.trust_remote_code, **self.config_kwargs)\n        else:\n            ds = load_from_disk(file_path)\n\n        if self.data_split_key is not None:\n            ds = ds[self.data_split_key]\n\n        if self.inplace_load:\n            self.ds = ds\n        else:\n            return ds\n\n    def __getitem__(self, idx):\n        if self.ds is None:\n            raise ValueError('Dataset is not loaded')\n        return self.ds[idx]\n\n    def __len__(self):\n        if self.ds is None:\n            raise ValueError('Dataset is not loaded')\n        return len(self.ds)\n\n\nclass Dolly15K(HuggingfaceDataset):\n    INSTRUCTION_KEY = \"### Instruction:\"\n    INPUT_KEY = \"Input:\"\n    RESPONSE_KEY = \"### Response:\"\n    END_KEY = \"### End\"\n    RESPONSE_KEY_NL = f\"{RESPONSE_KEY}\\n\"\n    DEFAULT_SEED = 42\n    INTRO_BLURB = (\n        \"Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\"\n    )\n    PROMPT_NO_INPUT_FORMAT = \"\"\"{intro}\n{instruction_key}\n{instruction}\n\n{response_key}\n{response}\n\n{end_key}\"\"\".format(\n        intro=INTRO_BLURB,\n        instruction_key=INSTRUCTION_KEY,\n        instruction=\"{instruction}\",\n        response_key=RESPONSE_KEY,\n        response=\"{response}\",\n        end_key=END_KEY,\n    )\n\n    # This is a training prompt that contains an input string that serves as context for the instruction.  For example,\n    # the input might be a passage from Wikipedia and the intruction is to extract some information from it.\n    PROMPT_WITH_INPUT_FORMAT = \"\"\"{intro}\n\n{instruction_key}\n{instruction}\n\n{input_key}\n{input}\n\n{response_key}\n{response}\n\n{end_key}\"\"\".format(\n        intro=INTRO_BLURB,\n        instruction_key=INSTRUCTION_KEY,\n        instruction=\"{instruction}\",\n        input_key=INPUT_KEY,\n        input=\"{input}\",\n        response_key=RESPONSE_KEY,\n        response=\"{response}\",\n        end_key=END_KEY,\n    )\n\n    def __init__(self, *args, **kwargs):\n        super(Dolly15K, self).__init__(*args, **kwargs)\n        self.inplace_load = False\n\n    def load(self, file_path):\n        dataset = super().load(file_path)\n        return self._post_process(dataset)\n\n    def _post_process(self, dataset):\n\n        def _add_text(rec):\n            instruction = rec[\"instruction\"]\n            response = rec[\"response\"]\n            context = rec.get(\"context\")\n\n            if not instruction:\n                raise ValueError(f\"Expected an instruction in: {rec}\")\n\n            if not response:\n                raise ValueError(f\"Expected a response in: {rec}\")\n\n            # For some instructions there is an input that goes along with the instruction, providing context for the\n            # instruction.  
For example, the input might be a passage from Wikipedia and the instruction says to extract\n            # some piece of information from it.  The response is that information to extract.  In other cases there is\n            # no input.  For example, the instruction might be open QA such as asking what year some historic figure was\n            # born.\n            if context:\n                rec[\"text\"] = self.PROMPT_WITH_INPUT_FORMAT.format(instruction=instruction, response=response,\n                                                                   input=context)\n            else:\n                rec[\"text\"] = self.PROMPT_NO_INPUT_FORMAT.format(instruction=instruction, response=response)\n            return rec\n\n        dataset = dataset.map(_add_text)\n\n        tokenizer = AutoTokenizer.from_pretrained(**self.tokenizer_params)\n\n        def tokenize_function(examples):\n            return tokenizer(examples[\"text\"], **self.tokenizer_apply_params)\n\n        dataset = dataset.map(tokenize_function, batched=True)\n        return dataset\n"
  },
  {
    "path": "python/fate_llm/dataset/input_output_dataset.py",
    "content": "from fate.ml.nn.dataset.base import Dataset\nfrom transformers.trainer_pt_utils import LabelSmoother\nfrom typing import List, Dict, Union, Literal\nimport logging\nfrom jinja2 import Template\nfrom transformers import AutoTokenizer\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass InputOutputDataset(Dataset):\n\n    def __init__(self, \n                tokenizer_path,\n                input_template: str,\n                output_template: str,\n                max_input_length: int = 256, \n                max_target_length: int = 256,\n                load_from: Literal['jsonl', 'hf_load_from_disk', 'hf_load_dataset'] = 'hf_load_from_disk',\n                split_key: str = None\n                ):\n\n        super().__init__()\n        self.tokenizer = None\n        self.tokenizer_path = tokenizer_path\n        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path, trust_remote_code=True)\n        self.max_input_length = max_input_length\n        self.max_target_length = max_target_length\n        self.dataset = None\n        self.load_from = load_from\n        self.input_template = Template(input_template)\n        self.output_template = Template(output_template)\n        self.split_key = split_key\n        self.max_seq_length = max_input_length + max_target_length + 1\n\n    def load(self, path):\n        if self.load_from == 'hf_load_from_disk':\n            import datasets\n            self.dataset = datasets.load_from_disk(path)\n            if self.split_key is not None:\n                self.dataset = self.dataset[self.split_key]\n            self.dataset = [i for i in self.dataset]\n        elif self.load_from == 'jsonl':\n            import json\n            with open(path, 'r') as f:\n                json_lines = f.read().split('\\n')\n            self.dataset = []\n            for i in json_lines:\n                try:\n                    self.dataset.append(json.loads(i))\n                except:\n                    
print('skip line')\n        elif self.load_from == 'hf_load_dataset':\n            from datasets import load_dataset\n            self.dataset = load_dataset(path)\n            if self.split_key is not None:\n                self.dataset = self.dataset[self.split_key]\n            self.dataset = [i for i in self.dataset]\n        else:\n            raise ValueError('unknown load format')\n\n        if not isinstance(self.dataset, list) or not isinstance(self.dataset[0], dict):\n            logger.warn('loaded dataset is expected to be a list of dict')\n\n    def get_raw_dataset(self):\n        return self.dataset\n\n    def __len__(self):\n        return len(self.dataset)\n\n    def get_str_item(self, i) -> dict:\n\n        data_item = self.dataset[i]\n        in_ = self.input_template.render(**data_item)\n        out_ = self.output_template.render(**data_item)\n        return {\n            'input': in_,\n            'output': out_\n        }\n\n    def _process_item(self, data_item):\n\n        a_ids = self.tokenizer.encode(text=data_item['input'], add_special_tokens=True, truncation=True,\n                                      max_length=self.max_input_length)\n        b_ids = self.tokenizer.encode(text=data_item['output'], add_special_tokens=False, truncation=True,\n                                      max_length=self.max_target_length)\n\n        context_length = len(a_ids)\n        input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id]\n        labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id]\n\n        pad_len = self.max_seq_length - len(input_ids)\n        input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len\n        labels = labels + [self.tokenizer.pad_token_id] * pad_len\n        labels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels]\n\n        assert len(input_ids) == len(labels), f\"length mismatch: {len(input_ids)} vs {len(labels)}\"\n\n        return {\n            
\"input_ids\": input_ids,\n            \"labels\": labels\n        }\n\n    def get_tokenized_item(self, i) -> dict:   \n\n        str_item = self.get_str_item(i)\n        ret_dict = self._process_item(str_item)\n        return ret_dict\n\n    def __getitem__(self, i) -> dict:\n        item = self.get_tokenized_item(i)\n        return item\n"
  },
  {
    "path": "python/fate_llm/dataset/prompt_dataset.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport copy\nimport json\n\nimport datasets\nimport torch\nfrom fate.ml.nn.dataset.base import Dataset\nfrom ..data.tokenizers.cust_tokenizer import get_tokenizer\n\n\nPROMPT_TEMPLATE = \"{prompt}\"\n\n\nclass PromptDataset(Dataset):\n    def __init__(self,\n                 text_max_length=512,\n                 tokenizer_name_or_path=None,\n                 trust_remote_code=False,\n                 padding=False,\n                 padding_side='left',\n                 pad_token=None,\n                 pad_token_id=None,\n                 bos_token_id=None,\n                 eos_token_id=None,\n                 add_eos_token=True,\n                 prompt_template=None,\n                 add_special_tokens=False,\n                 prompt_column=\"content\",\n                 response_column=\"summary\",\n                 max_prompt_length=256,\n                 file_type=\"jsonl\",\n                 num_proc=4,\n                 ):\n\n        super(PromptDataset, self).__init__()\n        self.tokenizer = None\n        self.tokenizer_name_or_path = tokenizer_name_or_path\n        self.padding = padding\n        self.add_special_tokens = add_special_tokens\n        self.max_prompt_length = max_prompt_length\n        self.text_max_length = text_max_length\n\n        self.tokenizer = get_tokenizer(\n            
tokenizer_name_or_path=tokenizer_name_or_path,\n            trust_remote_code=trust_remote_code,\n            pad_token=pad_token,\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            padding_side=padding_side,\n            add_eos_token=add_eos_token,\n        )\n\n        self.prompt_template = prompt_template if prompt_template else PROMPT_TEMPLATE\n        self.prompt_column = prompt_column\n        self.response_column = response_column\n        self.file_type = file_type\n        self.num_proc = num_proc\n        self._data = None\n\n    def load(self, file_path):\n        if \"jsonl\" in self.file_type:\n            prompts = []\n            responses = []\n            with open(file_path, \"r\") as fin:\n                for line in fin:\n                    line = json.loads(line)\n                    prompts.append(line[self.prompt_column])\n                    responses.append(line[self.response_column])\n\n            ds = datasets.Dataset.from_dict({self.prompt_column: prompts, self.response_column: responses})\n        else:\n            ds = datasets.load_from_disk(file_path)\n\n        self._data = ds.map(\n            self._process_data,\n            fn_kwargs={\"tokenizer\": self.tokenizer,\n                       \"prompt_template\": self.prompt_template,\n                       \"prompt_column\": self.prompt_column,\n                       \"response_column\": self.response_column,\n                       \"max_prompt_length\": self.max_prompt_length,\n                       \"max_length\": self.text_max_length\n                       },\n            batched=True,\n            remove_columns=ds.column_names,\n            num_proc=self.num_proc,\n        )\n\n        max_length = None\n        for d in self._data:\n            if max_length is None:\n                max_length = len(d[\"input_ids\"])\n            else:\n                max_length = max(max_length, 
len(d[\"input_ids\"]))\n\n        self._data = self._data.map(\n            self._pad_to_max_length,\n            batched=True,\n            fn_kwargs={\n                \"tokenizer\": self.tokenizer,\n                \"max_length\": max_length\n            },\n            num_proc=self.num_proc\n        )\n\n    @staticmethod\n    def _process_data(examples, tokenizer, prompt_template, prompt_column,\n                      response_column, max_prompt_length, max_length):\n        prompts = examples[prompt_column]\n        responses = examples[response_column]\n\n        processed_data = dict()\n        input_ids_list = []\n        labels_list = []\n        attention_mask_list = []\n        for _prompt, _response in zip(prompts, responses):\n            if isinstance(_response, list):\n                _response = _response[0]\n            _prompt = prompt_template.format_map(dict(prompt=_prompt))\n            prompt_encoded = tokenizer(_prompt)\n            if len(prompt_encoded['input_ids']) > 0 and prompt_encoded['input_ids'][-1] in tokenizer.all_special_ids:\n                prompt_encoded['input_ids'] = prompt_encoded['input_ids'][:-1]\n                prompt_encoded['attention_mask'] = prompt_encoded['attention_mask'][:-1]\n\n            target_encoded = tokenizer(_response)\n            if len(target_encoded['input_ids']) > 0 and target_encoded['input_ids'][-1] in tokenizer.all_special_ids:\n                target_encoded['input_ids'] = target_encoded['input_ids'][:-1]\n                target_encoded['attention_mask'] = target_encoded['attention_mask'][:-1]\n\n            prompt_ids = prompt_encoded[\"input_ids\"][: max_prompt_length]\n            prompt_attention_mask = prompt_encoded[\"attention_mask\"][:max_prompt_length]\n\n            target_ids = target_encoded[\"input_ids\"][: max_length - len(prompt_ids) - 1]\n            target_attention_mask = target_encoded[\"attention_mask\"][: max_length - len(prompt_ids) - 1]\n\n            if 
tokenizer.bos_token_id is not None:\n                seq_length = len(prompt_ids) + 1\n                input_ids = prompt_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]\n                labels = [-100] * seq_length + input_ids[seq_length:]\n                attention_mask = prompt_attention_mask + [1] + target_attention_mask + [1]\n            else:\n                seq_length = len(prompt_ids)\n                input_ids = prompt_ids + target_ids + [tokenizer.eos_token_id]\n                labels = [-100] * seq_length + input_ids[seq_length:]\n                attention_mask = prompt_attention_mask + target_attention_mask + [1]\n\n            input_ids_list.append(input_ids)\n            labels_list.append(labels)\n            attention_mask_list.append(attention_mask)\n\n        processed_data[\"labels\"] = labels_list\n        processed_data[\"input_ids\"] = input_ids_list\n        processed_data[\"attention_mask\"] = attention_mask_list\n\n        return processed_data\n\n    @staticmethod\n    def _pad_to_max_length(examples, tokenizer, max_length):\n        padded_input_ids = []\n        padded_labels = []\n        padded_attention_mask = []\n\n        labels_list = examples[\"labels\"]\n        input_ids_list = examples[\"input_ids\"]\n        attention_mask_list = examples[\"attention_mask\"]\n\n        for input_ids, attention_mask, labels in zip(input_ids_list, attention_mask_list, labels_list):\n            l = len(input_ids)\n            input_ids = torch.LongTensor(input_ids + [tokenizer.pad_token_id] * (max_length - l))\n            labels = torch.LongTensor(labels + [-100] * (max_length - l))\n            attention_mask = torch.LongTensor(attention_mask + [0] * (max_length - l))\n            padded_input_ids.append(input_ids)\n            padded_labels.append(labels)\n            padded_attention_mask.append(attention_mask)\n\n        return dict(\n            input_ids=padded_input_ids,\n            
attention_mask=padded_attention_mask,\n            labels=padded_labels\n        )\n\n    def get_vocab_size(self):\n        return self.tokenizer.vocab_size\n\n    def __getitem__(self, item):\n        return self._data[item]\n\n    def __len__(self):\n        return len(self._data)\n\n    def __repr__(self):\n        return self.tokenizer.__repr__()\n"
  },
  {
    "path": "python/fate_llm/dataset/qa_dataset.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom datasets import load_from_disk, load_dataset\nfrom transformers import AutoTokenizer\nfrom fate.ml.nn.dataset.base import Dataset\n\n\"\"\"\nThese Data pre-processing templates are from https://github.com/mit-han-lab/offsite-tuning\n\"\"\"\n\n\nclass PIQA:\n    def __init__(self):\n        self._template = \"Question: {}\\nAnswer:\"\n\n    def get_context(self, examples):\n        ctx = examples['goal']\n        return [self._template.format(c) for c in ctx]\n\n    def get_target(self, examples):\n        if -1 in examples[\"label\"]:  # test set\n            return [\"\"] * len(examples[\"label\"])\n        else:\n            gt_tuples = [(\"sol{}\".format(label + 1), idx)\n                         for idx, label in enumerate(examples['label'])]\n            return [examples[k][i] for k, i in gt_tuples]\n\n\nclass SciQ:\n    def __init__(self):\n        self._template = \"{}\\nQuestion: {}\\nAnswer:\"\n\n    def get_context(self, examples):\n        sources = examples['support']\n        queries = examples['question']\n        return [self._template.format(s, q) for s, q in zip(sources, queries)]\n\n    def get_target(self, examples):\n        return examples['correct_answer']\n\n\nclass OpenBookQA:\n    def get_context(self, examples):\n        return examples['question_stem']\n\n    def get_target(self, examples):\n    
    choices = examples['choices']\n        answers = examples['answerKey']\n        targets = []\n        for choice, answer in zip(choices, answers):\n            answer = ord(answer.strip()) - ord('A')\n            targets.append(choice['text'][answer])\n        return targets\n\n\nclass ARC:\n    def __init__(self):\n        self._template = \"Question: {}\\nAnswer:\"\n\n    def get_context(self, examples):\n        ctx = examples['question']\n        return [self._template.format(c) for c in ctx]\n\n    def get_target(self, examples):\n        choices = examples['choices']\n        answers = examples['answerKey']\n        num_to_letter = {\"1\": \"A\", \"2\": \"B\", \"3\": \"C\", \"4\": \"D\", \"5\": \"E\"}\n        for idx, answer in enumerate(answers):\n            answer = num_to_letter.get(answer, answer)\n            answer = ord(answer) - ord(\"A\")\n            answers[idx] = choices[idx][\"text\"][answer]\n        return answers\n\n\nclass WIC:\n    def __init__(self):\n        self._template = \"Sentence 1: {}\\nSentence 2: {}\\nQuestion: Is the word '{}' used in the same way in the\" \\\n                         \" two sentences above?\\nAnswer:\"\n\n    def get_context(self, examples):\n        sentences_1 = examples[\"sentence1\"]\n        sentences_2 = examples[\"sentence2\"]\n        starts_1 = examples[\"start1\"]\n        ends_1 = examples[\"end1\"]\n\n        contexts = []\n        for s1, s2, st, ed in zip(sentences_1, sentences_2, starts_1, ends_1):\n            contexts.append(\n                self._template.format(s1, s2, s1[st: ed])\n            )\n\n        return contexts\n\n    def get_target(self, examples):\n        labels = examples[\"label\"]\n        targets = []\n        for label in labels:\n            targets.append(\" {}\".format({0: \"no\", 1: \"yes\"}[label]))\n\n        return targets\n\n\nclass BoolQ:\n    def __init__(self):\n        self._template = \"{}\\nQuestion: {}?\\nAnswer:\"\n\n    def get_context(self, 
examples):\n        passages = examples[\"passage\"]\n        questions = examples[\"question\"]\n        return [self._template.format(passage, question)\n                for passage, question in zip(passages, questions)\n                ]\n\n    def get_target(self, examples):\n        return [\" \" + \"yes\" if label else \"no\" for label in examples[\"answer\"]]\n\nclass CommonsenseQA:\n    def get_context(self, examples):\n        return examples[\"question\"]\n\n    def get_target(self, examples):\n        choices = examples['choices']\n        answers = examples['answerKey']\n        targets = []\n        for choice, answer in zip(choices, answers):\n            answer = ord(answer.strip()) - ord('A')\n            targets.append(choice['text'][answer])\n        return targets\n\n\nclass RTE:\n    def __init__(self):\n        self._template = \"{}\\nQuestion: {} True or False?\\nAnswer:\"\n\n    def get_context(self, examples):\n        sentences_1 = examples[\"premise\"]\n        sentences_2 = examples[\"hypothesis\"]\n        contexts = []\n        for sentence_1, sentence_2 in zip(sentences_1, sentences_2):\n            contexts.append(\n                self._template.format(sentence_1, sentence_2)\n            )\n\n        return contexts\n\n    def get_target(self, examples):\n        labels = examples[\"label\"]\n        return [\" {}\".format({0: \"True\", 1: \"False\"}[label]) for label in labels]\n\n\ntask_dict = {\n    \"piqa\": PIQA(),\n    \"sciq\": SciQ(),\n    \"openbookqa\": OpenBookQA(),\n    \"arc_easy\": ARC(),\n    \"arc_challenge\": ARC(),\n    \"wic\": WIC(),\n    \"boolq\": BoolQ(),\n    \"commonsenseqa\": CommonsenseQA(),\n    \"rte\": RTE()\n}\n\n\ndef tokenize_qa_dataset(dataset_name, tokenizer, save_path=None, seq_max_len=1000, data_part=\"train\", dataset=None):\n    max_len = seq_max_len\n    assert dataset_name in task_dict.keys(), f\"dataset name must be one of {list(task_dict.keys())}\"\n    if dataset is None:\n        
raw_datasets = load_dataset(dataset_name)\n    else:\n        raw_datasets = dataset\n    task = task_dict[dataset_name]\n\n    column_names = raw_datasets[data_part].column_names\n\n    def tokenize_function(examples):\n        context = task.get_context(examples)\n        target = task.get_target(examples)\n\n        context = tokenizer(context)\n        target = tokenizer(target)\n\n        # if context is ending with special token, remove it\n        if len(context['input_ids'][0]) > 0 and context['input_ids'][0][-1] in tokenizer.all_special_ids:\n            context['input_ids'] = [i[:-1] for i in context['input_ids']]\n            context['attention_mask'] = [a[:-1]\n                                         for a in context['attention_mask']]\n\n        # if target is starting with special token, remove it\n        if len(target['input_ids'][0]) > 0 and target['input_ids'][0][0] in tokenizer.all_special_ids:\n            target['input_ids'] = [i[1:] for i in target['input_ids']]\n            target['attention_mask'] = [a[1:]\n                                        for a in target['attention_mask']]\n\n        out = {}\n        out['input_ids'] = [i1 + i2 for i1,\n                                        i2 in zip(context['input_ids'], target['input_ids'])]\n        out['attention_mask'] = [a1 + a2 for a1,\n                                             a2 in zip(context['attention_mask'], target['attention_mask'])]\n\n        # set -100 for context tokens\n        out[\"labels\"] = [\n            [-100] * len(i1) + i2 for i1, i2 in zip(context['input_ids'], target['input_ids'])]\n\n        return out\n\n    tokenized_datasets = raw_datasets.map(\n        tokenize_function,\n        batched=True,\n        num_proc=4,\n        remove_columns=column_names,\n        load_from_cache_file=True,\n        desc=\"Running tokenizer on dataset\",\n    )\n\n    # pad all instances in lm_datasets to the max length of the dataset\n    max_length = -1\n    for v in 
tokenized_datasets.values():\n        for x in v:\n            max_length = max(max_length, len(x['input_ids']))\n\n    # pad to the multiple of 8\n    max_length = (max_length // 8 + 1) * 8\n\n    block_size = max_len\n    max_length = min(max_length, block_size)\n\n    def pad_function(examples):\n        examples[\"input_ids\"] = [i + [tokenizer.pad_token_id] *\n                                 (max_length - len(i)) for i in examples[\"input_ids\"]]\n        examples[\"attention_mask\"] = [[1] * len(i) + [0] *\n                                      (max_length - len(i)) for i in examples[\"attention_mask\"]]\n        examples[\"labels\"] = [i + [-100] *\n                              (max_length - len(i)) for i in examples[\"labels\"]]\n        # truncate to max_length\n        examples[\"input_ids\"] = [i[:max_length] for i in examples[\"input_ids\"]]\n        examples[\"attention_mask\"] = [a[:max_length]\n                                      for a in examples[\"attention_mask\"]]\n        examples[\"labels\"] = [l[:max_length] for l in examples[\"labels\"]]\n        return examples\n\n    tokenized_datasets = tokenized_datasets.map(\n        pad_function,\n        batched=True,\n        num_proc=4,\n        load_from_cache_file=True,\n        desc=f\"Padding dataset to max length {max_length}\",\n    )\n\n    if save_path is not None:\n        tokenized_datasets.save_to_disk(save_path)\n\n    return tokenized_datasets\n\n\nclass QaDataset(Dataset):\n\n    def __init__(self,\n                 tokenizer_name_or_path,\n                 select_num=None,\n                 start_idx=None,\n                 need_preprocess=False,\n                 dataset_name=None,\n                 data_part=\"train\",\n                 seq_max_len=1000\n                 ):\n        self.select_num = select_num\n        self.start_idx = start_idx\n        self.ds = None\n        self.need_preprocess = need_preprocess\n        self.dataset_name = dataset_name\n        
self.data_part = data_part\n        self.seq_max_len = seq_max_len\n        self.return_with_idx = False\n        if 'llama' in tokenizer_name_or_path.lower():\n            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, unk_token=\"<unk>\", bos_token=\"<s>\",\n                                                           eos_token=\"</s>\", add_eos_token=True)\n            self.tokenizer.pad_token = self.tokenizer.eos_token\n        else:\n            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)\n        if 'gpt2' in tokenizer_name_or_path.lower():\n            self.tokenizer.pad_token = self.tokenizer.eos_token\n\n    def load(self, path):\n        local_data = load_from_disk(path)\n        if not self.need_preprocess:\n            self.ds = local_data[self.data_part]\n        else:\n            tokenized_ds = tokenize_qa_dataset(\n                dataset_name=self.dataset_name,\n                tokenizer=self.tokenizer,\n                seq_max_len=self.seq_max_len,\n                data_part=self.data_part,\n                dataset=local_data\n            )\n\n            self.ds = tokenized_ds[self.data_part]\n\n        if self.select_num is not None:\n            if self.start_idx is not None:\n                self.ds = self.ds.select(range(self.start_idx, min(len(self.ds), self.start_idx + self.select_num)))\n            else:\n                self.ds = self.ds.select(range(self.select_num))\n\n    def set_return_with_idx(self):\n        self.return_with_idx = True\n\n    def reset_return_with_idx(self):\n        self.return_with_idx = False\n\n    def __len__(self):\n        return len(self.ds)\n\n    def __getitem__(self, idx):\n        if self.return_with_idx:\n            return {\n                \"idx\": idx,\n                \"inputs\": self.ds[idx]\n            }\n        else:\n            return self.ds[idx]\n"
  },
  {
    "path": "python/fate_llm/dataset/seq_cls_dataset.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom fate.ml.nn.dataset.base import Dataset\nimport pandas as pd\nimport torch as t\nfrom transformers import AutoTokenizer\nimport os\nimport numpy as np\n\n# avoid tokenizer parallelism\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n\nclass SeqCLSDataset(Dataset):\n    \"\"\"\n    A Dataset for some basic NLP Tasks, this dataset will automatically transform raw text into word indices\n    using AutoTokenizer from transformers library,\n\n    Parameters\n    ----------\n    truncation bool, truncate word sequence to 'text_max_length'\n    text_max_length int, max length of word sequences\n    tokenizer_name_or_path str, name of bert tokenizer(see transformers official for details) or path to local\n                                transformer tokenizer folder\n    return_label bool, return label or not, this option is for host dataset, when running hetero-NN\n    padding bool, whether to pad the word sequence to 'text_max_length'\n    padding_side str, 'left' or 'right', where to pad the word sequence\n    pad_token str, pad token, use this str as pad token, if None, use tokenizer.pad_token\n    return_input_ids bool, whether to return input_ids or not, if False, return word_idx['input_ids']\n    \"\"\"\n\n    def __init__(\n            self,\n            truncation=True,\n            text_max_length=128,\n            
tokenizer_name_or_path=\"bert-base-uncased\",\n            return_label=True,\n            padding=True,\n            padding_side=\"right\",\n            pad_token=None,\n            return_input_ids=True):\n\n        super(SeqCLSDataset, self).__init__()\n        self.text = None\n        self.word_idx = None\n        self.label = None\n        self.tokenizer = None\n        self.sample_ids = None\n        self.padding = padding\n        self.truncation = truncation\n        self.max_length = text_max_length\n        self.with_label = return_label\n        self.tokenizer_name_or_path = tokenizer_name_or_path\n        self.tokenizer = AutoTokenizer.from_pretrained(\n            self.tokenizer_name_or_path)\n        self.tokenizer.padding_side = padding_side\n        self.return_input_ids = return_input_ids\n        if pad_token is not None:\n            self.tokenizer.add_special_tokens({'pad_token': pad_token})\n\n    def load(self, file_path):\n\n        tokenizer = self.tokenizer\n        self.text = pd.read_csv(file_path)\n        text_list = list(self.text.text)\n\n        self.word_idx = tokenizer(\n            text_list,\n            padding=self.padding,\n            return_tensors='pt',\n            truncation=self.truncation,\n            max_length=self.max_length)\n\n        if self.return_input_ids:\n            self.word_idx = self.word_idx['input_ids']\n\n        if self.with_label:\n            self.label = t.Tensor(self.text.label).detach().numpy()\n            self.label = self.label.reshape((len(self.text), -1))\n\n        if 'id' in self.text:\n            self.sample_ids = self.text['id'].values.tolist()\n\n    def get_classes(self):\n        return np.unique(self.label).tolist()\n\n    def get_vocab_size(self):\n        return self.tokenizer.vocab_size\n\n    def get_sample_ids(self):\n        return self.sample_ids\n\n    def __getitem__(self, item):\n\n        if self.return_input_ids:\n            ret = self.word_idx[item]\n        else:\n 
           ret = {k: v[item] for k, v in self.word_idx.items()}\n\n        if self.with_label:\n            return ret, self.label[item]\n\n        return ret\n\n    def __len__(self):\n        return len(self.text)\n\n    def __repr__(self):\n        return self.tokenizer.__repr__()"
  },
  {
    "path": "python/fate_llm/evaluate/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/evaluate/scripts/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/evaluate/scripts/_options.py",
    "content": "import time\n\nimport click\n\nfrom ..utils.config import parse_config, default_eval_config\nfrom ..utils.config import _set_namespace\n\n\ndef parse_custom_type(value):\n    parts = value.split('=')\n    if len(parts) == 2 and parts[1].isdigit():\n        return parts[0], int(parts[1])\n    elif len(parts) == 2 and isinstance(parts[1], str):\n        return parts[0], parts[1]\n    else:\n        raise click.BadParameter('Invalid input format. Use \"str=int\" or \"str=str\".')\n\n\nclass LlmSharedOptions(object):\n    _options = {\n        \"eval_config\": (('-c', '--eval_config'),\n                        dict(type=click.Path(exists=True), help=f\"Manual specify config file\", default=None),\n                        default_eval_config().__str__()),\n        \"yes\": (('-y', '--yes',), dict(type=bool, is_flag=True, help=\"Skip double check\", default=None),\n                False),\n        \"namespace\": (('-n', '--namespace'),\n                      dict(type=str, help=f\"Manual specify fate llm namespace\", default=None),\n                      time.strftime('%Y%m%d%H%M%S'))\n    }\n\n    def __init__(self):\n        self._options_kwargs = {}\n\n    def __getitem__(self, item):\n        return self._options_kwargs[item]\n\n    def get(self, k, default=None):\n        v = self._options_kwargs.get(k, default)\n        if v is None and k in self._options:\n            v = self._options[k][2]\n        return v\n\n    def update(self, **kwargs):\n        for k, v in kwargs.items():\n            if v is not None:\n                self._options_kwargs[k] = v\n\n    def post_process(self):\n        # add defaults here\n        for k, v in self._options.items():\n            if self._options_kwargs.get(k, None) is None:\n                self._options_kwargs[k] = v[2]\n\n        # update config\n        config = parse_config(self._options_kwargs['eval_config'])\n        self._options_kwargs['eval_config'] = config\n\n        
_set_namespace(self._options_kwargs['namespace'])\n\n    @classmethod\n    def get_shared_options(cls, hidden=False):\n        def shared_options(f):\n            for name, option in cls._options.items():\n                f = click.option(*option[0], **dict(option[1], hidden=hidden))(f)\n            return f\n\n        return shared_options\n"
  },
  {
    "path": "python/fate_llm/evaluate/scripts/config_cli.py",
    "content": "#\n#  Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\n\nimport click\nimport yaml\nfrom pathlib import Path\nfrom ..utils.config import create_eval_config, default_eval_config\nfrom ._options import LlmSharedOptions\nfrom ..utils._io import echo\n\n@click.group(\"eval_config\", help=\"fate_llm evaluate config\")\ndef eval_config_group():\n    \"\"\"\n    eval_config fate_llm\n    \"\"\"\n    pass\n\n\n@eval_config_group.command(name=\"new\")\ndef _new():\n    \"\"\"\n    create new fate_llm eval config from template\n    \"\"\"\n    create_eval_config(Path(\"llm_eval_config.yaml\"))\n    click.echo(f\"create eval_config file: llm_eval_config.yaml\")\n\n\n@eval_config_group.command(name=\"edit\")\n@LlmSharedOptions.get_shared_options(hidden=True)\n@click.pass_context\ndef _edit(ctx, **kwargs):\n    \"\"\"\n    edit fate_llm eval_config file\n    \"\"\"\n    ctx.obj.update(**kwargs)\n    eval_config = ctx.obj.get(\"eval_config\")\n    print(f\"eval_config: {eval_config}\")\n    click.edit(filename=eval_config)\n\n\n@eval_config_group.command(name=\"show\")\ndef _show():\n    \"\"\"\n    show fate_test default eval_config path\n    \"\"\"\n    click.echo(f\"default eval_config path is {default_eval_config()}\")\n"
  },
  {
    "path": "python/fate_llm/evaluate/scripts/data_cli.py",
    "content": "#\n#  Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nimport os\nimport copy\nimport click\nimport yaml\nimport warnings\n\nfrom typing import Union\nfrom ._options import LlmSharedOptions\nfrom ..utils.llm_evaluator import download_task\nfrom ..utils._io import echo\n\n@click.command('download_data')\n@click.option('-t', '--tasks', required=False, type=str, multiple=True, default=None,\n              help='tasks whose data will be downloaded')\n# @click.argument('other_args', nargs=-1)\n@LlmSharedOptions.get_shared_options(hidden=True)\n@click.pass_context\ndef download_data(ctx, tasks, **kwargs):\n    \"\"\"\n    Evaluate a pretrained model with specified parameters.\n    \"\"\"\n    ctx.obj.update(**kwargs)\n    ctx.obj.post_process()\n\n    if tasks is None or len(tasks) == 0:\n        tasks = None\n        echo.echo(f\"No task is given, will download data for all built-in tasks.\", fg='red')\n    else:\n        echo.echo(f\"given tasks: {tasks}\", fg='red')\n    download_task(tasks)\n"
  },
  {
    "path": "python/fate_llm/evaluate/scripts/eval_cli.py",
    "content": "#\n#  Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nimport os\nimport copy\nimport click\nimport yaml\nimport warnings\n\nfrom typing import Union\nfrom ._options import LlmSharedOptions\nfrom ..utils.config import default_eval_config\nfrom ..utils.llm_evaluator import evaluate, init_tasks, aggregate_table\nfrom ..utils.model_tools import load_by_loader\nfrom ..utils._io import echo\nfrom ..utils._parser import LlmSuite\n\n@click.command('evaluate')\n@click.option('-i', '--include', required=True, type=click.Path(exists=True),\n              help='Path to model and metrics conf')\n@click.option('-c', '--eval-config', type=click.Path(exists=True), help='Path to FATE Llm evaluation config. 
'\n                                                                        'If not provided, use default config.')\n@click.option('-o', '--result-output', type=click.Path(),\n              help='Path to save evaluation results.')\n# @click.argument('other_args', nargs=-1)\n@LlmSharedOptions.get_shared_options(hidden=True)\n@click.pass_context\ndef run_evaluate(ctx, include, eval_config, result_output, **kwargs):\n    \"\"\"\n    Evaluate a pretrained model with specified parameters.\n    \"\"\"\n    ctx.obj.update(**kwargs)\n    ctx.obj.post_process()\n    # namespace = ctx.obj[\"namespace\"]\n    yes = ctx.obj[\"yes\"]\n\n    echo.echo(f\"include: {include}\", fg='red')\n    try:\n        # include = os.path.abspath(include)\n        suite = LlmSuite.load(include)\n    except Exception as e:\n        raise ValueError(f\"Invalid include path: {include}, please check. {e}\")\n\n    if not eval_config:\n        eval_config = default_eval_config()\n\n    if not os.path.exists(eval_config):\n        eval_config = None\n\n    if not yes and not click.confirm(\"running?\"):\n        return\n    # init tasks\n    init_tasks()\n    # run_suite_eval(suite, eval_config_dict, result_output)\n    run_suite_eval(suite, eval_config, result_output)\n\ndef run_job_eval(job, eval_conf):\n    job_eval_conf = {}\n    if isinstance(eval_conf, dict):\n        job_eval_conf.update(eval_conf)\n    elif eval_conf is not None and os.path.exists(eval_conf):\n        with open(eval_conf, 'r') as f:\n            job_eval_conf.update(yaml.safe_load(f))\n\n    # echo.echo(f\"Evaluating job: {job.job_name} with tasks: {job.tasks}\")\n    if job.eval_conf_path:\n        # job-level eval conf takes priority\n        with open(job.eval_conf_path, 'r') as f:\n            job_eval_conf.update(yaml.safe_load(f))\n    # get loader\n    if job.loader:\n        if job.peft_path:\n            model = load_by_loader(loader_name=job.loader,\n                                   
loader_conf_path=job.loader_conf_path,\n                                   peft_path=job.peft_path)\n        else:\n            model = load_by_loader(loader_name=job.loader,\n                                   loader_conf_path=job.loader_conf_path)\n        result = evaluate(model=model, tasks=job.tasks, include_path=job.include_path, **job_eval_conf)\n    else:\n        # feed in pretrained & peft path\n        job_eval_conf[\"model_args\"][\"pretrained\"] = job.pretrained_model_path\n        if job.peft_path:\n            job_eval_conf[\"model_args\"][\"peft\"] = job.peft_path\n        result = evaluate(tasks=job.tasks, include_path=job.include_path, **job_eval_conf)\n    return result\n\n\ndef run_suite_eval(suite, eval_conf, output_path=None):\n    suite_results = dict()\n    for pair in suite.pairs:\n        job_results = dict()\n        for job in pair.jobs:\n            if not job.evaluate_only:\n                # give warning that job will be skipped\n                warnings.warn(f\"Job {job.job_name} will be skipped since no pretrained model is provided\")\n                continue\n            echo.echo(f\"Evaluating job: {job.job_name} with tasks: {job.tasks}\")\n            result = run_job_eval(job, eval_conf)\n            job_results[job.job_name] = result\n        suite_results[pair.pair_name] = job_results\n    suite_writers = aggregate_table(suite_results)\n    for pair_name, pair_writer in suite_writers.items():\n        echo.sep_line()\n        echo.echo(f\"Pair: {pair_name}\")\n        echo.sep_line()\n        echo.echo(pair_writer.dumps())\n        echo.stdout_newline()\n\n    if output_path:\n        with open(output_path, 'w') as f:\n            for pair_name, pair_writer in suite_writers.items():\n                pair_writer.dumps(f)\n"
  },
  {
    "path": "python/fate_llm/evaluate/scripts/fate_llm_cli.py",
    "content": "#\n#  Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\n\nimport click\nimport yaml\n\nfrom typing import Union\nfrom .eval_cli import run_evaluate\nfrom .config_cli import eval_config_group\nfrom .data_cli import download_data\nfrom ._options import LlmSharedOptions\n\n\ncommands = {\n    \"evaluate\": run_evaluate,\n    \"config\": eval_config_group,\n    \"download\": download_data\n}\n\n# alias -> canonical command name; consulted by FATELlmCLI.get_command (no aliases defined yet)\ncommands_alias = {}\n\n\nclass FATELlmCLI(click.MultiCommand):\n\n    def list_commands(self, ctx):\n        return list(commands)\n\n    def get_command(self, ctx, name):\n        if name not in commands and name in commands_alias:\n            name = commands_alias[name]\n        if name not in commands:\n            ctx.fail(\"No such command '{}'.\".format(name))\n        return commands[name]\n\n@click.command(cls=FATELlmCLI, help=\"A collection of tools to run FATE Llm Evaluation.\",\n               context_settings=dict(help_option_names=[\"-h\", \"--help\"]))\n@LlmSharedOptions.get_shared_options()\n@click.pass_context\ndef fate_llm_cli(ctx, **kwargs):\n    ctx.ensure_object(LlmSharedOptions)\n    ctx.obj.update(**kwargs)\n\n\nif __name__ == '__main__':\n    fate_llm_cli(obj=LlmSharedOptions())"
  },
  {
    "path": "python/fate_llm/evaluate/tasks/__init__.py",
    "content": "#\n#  Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nimport yaml\nimport os\n\n\ndef local_fn_constructor(loader, node):\n    return node\n\n\ndef local_fn_representer(dumper, data):\n    return data\n\n\ndef dump_yaml(dict, path):\n    yaml.add_representer(yaml.ScalarNode, local_fn_representer)\n    with open(path, 'w') as f:\n        yaml.dump(dict, f)\n\nclass Task:\n    _task_name = \"\"\n    _task_dir = \"\"\n    _task_conf_file = \"\"\n    _task_source_url = \"\"\n    script_dir = os.path.dirname(__file__)\n\n    @property\n    def task_name(self):\n        return self._task_name\n\n    @property\n    def task_template(self):\n        yaml.add_constructor(\"!function\", local_fn_constructor)\n        with open(os.path.abspath(os.path.join(self.script_dir, self._task_dir, self._task_conf_file)), \"rb\") as f:\n            task_template = yaml.full_load(f)\n        return task_template\n\n    @property\n    def task_scr_dir(self):\n        return os.path.abspath(os.path.join(self.script_dir, self._task_dir))\n\n    @property\n    def task_conf_path(self):\n        return os.path.abspath(os.path.join(self.script_dir, self._task_dir, self._task_conf_file))\n\n    @property\n    def task_source_url(self):\n        return self._task_source_url\n\n    def download_from_source(self):\n        raise NotImplementedError(f\"Should not be called here.\")\n\n\nclass Dolly(Task):\n    
_task_name = \"dolly-15k\"\n    _task_dir = \"dolly_15k\"\n    _task_conf_file = \"default_dolly_15k.yaml\"\n\n    def download_from_source(self):\n        try:\n            from datasets import load_dataset\n            data = load_dataset(\"databricks/databricks-dolly-15k\", split=\"train\")\n            filename = os.path.join(self.task_scr_dir, \"databricks-dolly-15k.jsonl\")\n            data.to_json(filename)\n            return True\n        except Exception as e:\n            print(f\"Failed to download data from source: {e}\")\n            return False\n\n\nclass AdvertiseGen(Task):\n    _task_name = \"advertise-gen\"\n    _task_dir = \"advertise_gen\"\n    _task_conf_file = \"default_advertise_gen.yaml\"\n    _task_source_url = [\"https://cloud.tsinghua.edu.cn/seafhttp/files/3781289a-5a60-44b1-b5f1-a04364e3eb9d/AdvertiseGen.tar.gz\",\n                        \"https://docs.google.com/uc?export=download&id=13_vf0xRTQsyneRKdD1bZIr93vBGOczrk\"]\n\n    def download_from_source(self):\n        from ..utils.data_tools import download_data\n        result = download_data(self.task_scr_dir, self.task_source_url[0])\n        if not result:\n            print(f\"retry with address: {self.task_source_url[1]}\")\n            return download_data(self.task_scr_dir, self.task_source_url[1])\n        return result\n\n\nbuild_in_tasks = {\"dolly-15k\": Dolly(),\n                  \"advertise-gen\": AdvertiseGen()}\n"
  },
  {
    "path": "python/fate_llm/evaluate/tasks/advertise_gen/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/evaluate/tasks/advertise_gen/advertise_utils.py",
    "content": "# adopted from https://github.com/huggingface/datasets/blob/main/metrics/rouge/rouge.py\n\n\nfrom rouge_score import rouge_scorer\n# from multiprocessing import Pool\n\n\ndef rouge_l(predictions, references, use_stemmer=False):\n    scorer = rouge_scorer.RougeScorer(rouge_types=['rougeL'], use_stemmer=use_stemmer)\n    scores = []\n    for ref, pred in zip(references, predictions):\n        score = scorer.score(ref, pred)\n        scores.append(score)\n\n    rouge_l_score = scores[0]['rougeL'].fmeasure\n    return rouge_l_score\n"
  },
  {
    "path": "python/fate_llm/evaluate/tasks/advertise_gen/default_advertise_gen.yaml",
    "content": "dataset_kwargs:\n  data_files:\n    train: train.json\n    validation: dev.json\ndataset_path: json\ndoc_to_target: '{{summary}}'\ndoc_to_text: '{{content}}'\nmetric_list:\n- aggregation: mean\n  higher_is_better: true\n  metric: !function 'advertise_utils.rouge_l'\noutput_type: generate_until\ntask: advertise-gen\nvalidation_split: validation\n"
  },
  {
    "path": "python/fate_llm/evaluate/tasks/dolly_15k/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/evaluate/tasks/dolly_15k/default_dolly_15k.yaml",
    "content": "dataset_kwargs:\n  data_files: databricks-dolly-15k.jsonl\ndataset_path: json\ndoc_to_target: '{{response}}'\ndoc_to_text: !function 'dolly_utils.doc_to_text'\nmetric_list:\n- aggregation: mean\n  higher_is_better: true\n  metric: !function 'dolly_utils.rouge_l'\noutput_type: generate_until\ntask: dolly-15k\nvalidation_split: train\n"
  },
  {
    "path": "python/fate_llm/evaluate/tasks/dolly_15k/dolly_utils.py",
    "content": "# adopted from https://github.com/huggingface/datasets/blob/main/metrics/rouge/rouge.py\n\n\nfrom rouge_score import rouge_scorer\n\n\ndef rouge_l(predictions, references, use_stemmer=False):\n    scorer = rouge_scorer.RougeScorer(rouge_types=['rougeL'], use_stemmer=use_stemmer)\n    scores = []\n    for ref, pred in zip(references, predictions):\n        score = scorer.score(ref, pred)\n        scores.append(score)\n\n    rouge_l_score = scores[0]['rougeL'].fmeasure\n    return rouge_l_score\n\ndef doc_to_text(doc):\n    if doc[\"context\"]:\n        return f\"context: {doc['context']}\\ninstruction: {doc['instruction']}\\nresponse:\"\n    else:\n        return f\"instruction: {doc['instruction']}\\nresponse:\"\n\n\"\"\"\ndef train_load_evalaute_lm():\n    pipeline.fit(train_data)\n    lm = OTModelLoader().load(path, **args)\n    from fate_llm.evaluator import evaluator\n    # general case\n    evaluator.evaluate(lm, task=\"dolly_15k\", **args)\n\n    # user modified conf\n    config = evaluator.get_task_template(task=\"dolly_15k\") # return dict copy of yaml file\n    config['dataset_kwargs'] = {\"dataset_kwargs\":\n                                    {\"data_files\":\n                                         {\"test\": './dolly_15k_test.csv',\n                                          \"dev\": './dolly_15k_dev.csv'}}}\n    # may provide arbitrary export path, must be of dir, create temp dir under the given path: {$export_path}/temp_dir\n    new_task_dir = evaluator.export_config(config, task=\"dolly_15k\", export_path=None)\n    result = evaluator.evalute(lm, task=\"dolly_15k\", include_path=new_task_dir, **args)\n    print(result) # dict\n    evaluator.delete_config(new_task_dir)\n\"\"\""
  },
  {
    "path": "python/fate_llm/evaluate/utils/__init__.py",
    "content": "from ._parser import LlmJob, LlmPair, LlmSuite"
  },
  {
    "path": "python/fate_llm/evaluate/utils/_io.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport click\nimport loguru\n\n\n# noinspection PyPep8Naming\nclass echo(object):\n    _file = None\n\n    @classmethod\n    def set_file(cls, file):\n        cls._file = file\n\n    @classmethod\n    def echo(cls, message, **kwargs):\n        click.secho(message, **kwargs)\n        click.secho(message, file=cls._file, **kwargs)\n\n    @classmethod\n    def sep_line(cls):\n        click.secho(\"-------------------------------------------------\")\n\n    @classmethod\n    def file(cls, message, **kwargs):\n        click.secho(message, file=cls._file, **kwargs)\n\n    @classmethod\n    def stdout(cls, message, **kwargs):\n        click.secho(message, **kwargs)\n\n    @classmethod\n    def stdout_newline(cls):\n        click.secho(\"\")\n\n    @classmethod\n    def welcome(cls):\n\n        cls.echo(\"Welcome to FATE Llm Evaluator\")\n\n    @classmethod\n    def flush(cls):\n        import sys\n        sys.stdout.flush()\n\n\ndef set_logger(name):\n    loguru.logger.remove()\n    loguru.logger.add(name, level='ERROR', delay=True)\n    return loguru.logger\n\n\nLOGGER = loguru.logger\n"
  },
  {
    "path": "python/fate_llm/evaluate/utils/_parser.py",
    "content": "#\n#  Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nimport os\nimport yaml\nimport typing\nfrom pathlib import Path\n\n\nclass LlmJob(object):\n    def __init__(self, job_name: str, script_path: Path=None, conf_path: Path=None, model_task_name: str=None,\n                 pretrained_model_path: Path=None, peft_path: Path=None,\n                 eval_conf_path: Path=None, loader: str=None, loader_conf_path: Path=None,\n                 tasks: typing.List[str]=None, include_path: Path=None, peft_path_format: str=None):\n        self.job_name = job_name\n        self.script_path = script_path\n        self.conf_path = conf_path\n        self.model_task_name = model_task_name\n        self.pretrained_model_path = pretrained_model_path\n        self.peft_path = peft_path\n        self.loader = loader\n        self.loader_conf_path = loader_conf_path\n        self.eval_conf_path = eval_conf_path\n        self.tasks = tasks\n        self.include_path = include_path\n        self.evaluate_only = self.script_path is None\n        self.peft_path_format = peft_path_format\n\n\nclass LlmPair(object):\n    def __init__(\n            self, pair_name: str, jobs: typing.List[LlmJob]\n    ):\n        self.pair_name = pair_name\n        self.jobs = jobs\n\n\nclass LlmSuite(object):\n    def __init__(\n            self, pairs: typing.List[LlmPair], path: Path, dataset=None\n    ):\n        
self.pairs = pairs\n        self.path = path\n        self.dataset = dataset\n        self._final_status = {}\n\n    @staticmethod\n    def load(path: Path):\n        if isinstance(path, str):\n            path = Path(path)\n        with path.open(\"r\") as f:\n            testsuite_config = yaml.safe_load(f)\n\n        pairs = []\n        for pair_name, pair_configs in testsuite_config.items():\n            if pair_name == \"data\":\n                continue\n            jobs = []\n            for job_name, job_configs in pair_configs.items():\n                # with train\n                script_path = job_configs.get(\"script\", None)\n                if script_path and not os.path.isabs(script_path):\n                    script_path = path.parent.joinpath(script_path).resolve()\n\n                conf_path = job_configs.get(\"conf\", None)\n                if conf_path and not os.path.isabs(conf_path):\n                    conf_path = path.parent.joinpath(conf_path).resolve()\n\n                model_task_name = job_configs.get(\"model_task_name\", None)\n\n                # evaluate only\n                pretrained_model_path = job_configs.get(\"pretrained\", None)\n                if pretrained_model_path and not os.path.isabs(pretrained_model_path):\n                    # make path absolute, else keep original pretrained model name\n                    if \"yaml\" in pretrained_model_path or \"/\" in pretrained_model_path:\n                        pretrained_model_path = path.parent.joinpath(pretrained_model_path).resolve()\n\n                peft_path = job_configs.get(\"peft\", None)\n                if peft_path and not os.path.isabs(peft_path):\n                    peft_path = path.parent.joinpath(peft_path).resolve()\n\n                eval_conf_path = job_configs.get(\"eval_conf\", None)\n                if eval_conf_path and not os.path.isabs(eval_conf_path):\n                    eval_conf_path = path.parent.joinpath(eval_conf_path).resolve()\n\n      
          loader = job_configs.get(\"loader\", None)\n                if job_configs.get(\"loader_conf\"):\n                    loader_conf_path = path.parent.joinpath(job_configs[\"loader_conf\"]).resolve()\n                else:\n                    loader_conf_path = \"\"\n                tasks = job_configs.get(\"tasks\", [])\n                include_path = job_configs.get(\"include_path\", \"\")\n                if include_path and not os.path.isabs(include_path):\n                    include_path = path.parent.joinpath(job_configs[\"include_path\"]).resolve()\n\n                peft_path_format = job_configs.get(\"peft_path_format\", \"{{fate_base}}/fate_flow/model/{{job_id}}/\"\n                                                                       \"guest/{{party_id}}/{{model_task_name}}/0/\"\n                                                                       \"output/output_model/model_directory\")\n\n                jobs.append(\n                    LlmJob(\n                        job_name=job_name, script_path=script_path, conf_path=conf_path,\n                        model_task_name=model_task_name,\n                        pretrained_model_path=pretrained_model_path, peft_path=peft_path, eval_conf_path=eval_conf_path,\n                        loader=loader, loader_conf_path=loader_conf_path, tasks=tasks, include_path=include_path,\n                        peft_path_format=peft_path_format\n                    )\n                )\n\n            pairs.append(\n                LlmPair(\n                    pair_name=pair_name, jobs=jobs\n                )\n            )\n        suite = LlmSuite(pairs=pairs, path=path)\n        return suite\n\n    def update_status(\n            self, pair_name, job_name, job_id=None, status=None, exception_id=None, time_elapsed=None, event=None\n    ):\n        for k, v in locals().items():\n            if k != \"job_name\" and k != \"pair_name\" and v is not None:\n                if 
self._final_status.get(f\"{pair_name}-{job_name}\"):\n                    setattr(self._final_status[f\"{pair_name}-{job_name}\"], k, v)\n\n    def get_final_status(self):\n        return self._final_status\n"
  },
  {
    "path": "python/fate_llm/evaluate/utils/config.py",
    "content": "#\n#  Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nimport os\nimport click\nimport yaml\nimport typing\nfrom pathlib import Path\nfrom ._io import set_logger, echo\n\n\nDEFAULT_FATE_LLM_BASE_PATH = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))\nFATE_LLM_BASE_PATH = os.getenv(\"FATE_LLM_BASE_PATH\") or DEFAULT_FATE_LLM_BASE_PATH\n\n# DEFAULT_TASK_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), \"../tasks\"))\nDEFAULT_FATE_LLM_TASK_PATH = os.path.abspath(os.path.join(FATE_LLM_BASE_PATH, \"tasks\"))\nFATE_LLM_TASK_PATH = os.getenv(\"FATE_LLM_TASK_PATH\") or DEFAULT_FATE_LLM_TASK_PATH\n\n_default_eval_config =  Path(FATE_LLM_BASE_PATH).resolve() / 'llm_eval_config.yaml'\n\ntemplate = \"\"\"# args for evaluate\nbatch_size: 10\nmodel_args:\n    device: cuda\n    dtype: auto\n    trust_remote_code: true\nnum_fewshot: 0\n\"\"\"\n\n\ndef create_eval_config(path: Path, override=False):\n    if path.exists() and not override:\n        raise FileExistsError(f\"{path} exists\")\n\n    with path.open(\"w\") as f:\n        f.write(template)\n\n\ndef default_eval_config():\n    if not _default_eval_config.exists():\n        create_eval_config(_default_eval_config)\n    return _default_eval_config\n\n\nclass Config(object):\n    def __init__(self, config):\n        self.update_conf(**config)\n\n    def update_conf(self, **kwargs):\n        for k, v in 
kwargs.items():\n            setattr(self, k, v)\n\n    @staticmethod\n    def load(path: typing.Union[str, Path], **kwargs):\n        if isinstance(path, str):\n            path = Path(path)\n        config = {}\n        if path is not None:\n            with path.open(\"r\") as f:\n                config.update(yaml.safe_load(f))\n\n        config.update(kwargs)\n        return Config(config)\n\n    @staticmethod\n    def load_from_file(path: typing.Union[str, Path]):\n        \"\"\"\n        Loads conf content from yaml file. Used to read in parameter configuration\n        Parameters\n        ----------\n        path: str, path to conf file, should be absolute path\n\n        Returns\n        -------\n        dict, parameter configuration in dictionary format\n\n        \"\"\"\n        if isinstance(path, str):\n            path = Path(path)\n        config = {}\n        if path is not None:\n            file_type = path.suffix\n            with path.open(\"r\") as f:\n                if file_type == \".yaml\":\n                    config.update(yaml.safe_load(f))\n                else:\n                    raise ValueError(f\"Cannot load conf from file type {file_type}\")\n        return config\n\n\ndef parse_config(config):\n    try:\n        config_inst = Config.load(config)\n    except Exception as e:\n        raise RuntimeError(f\"error parse config from {config}\") from e\n    return config_inst\n\n\ndef _set_namespace(namespace):\n    Path(f\"logs/{namespace}\").mkdir(exist_ok=True, parents=True)\n    set_logger(f\"logs/{namespace}/exception.log\")\n    echo.set_file(click.open_file(f'logs/{namespace}/stdout', \"a\"))\n"
  },
  {
    "path": "python/fate_llm/evaluate/utils/data_tools.py",
    "content": "#\n#  Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\n\ndef download_data(data_dir, data_url, is_tar=True):\n    import os\n    import requests\n    import tarfile\n    import io\n\n    # Create data directory\n    if not os.path.exists(data_dir):\n        os.makedirs(data_dir)\n\n    # Download data\n    try:\n        response = requests.get(data_url)\n        if response.status_code == 200:\n            if is_tar:\n                # extract tar file and write to data_dir\n                with tarfile.open(fileobj=io.BytesIO(response.content), mode='r:gz') as tar:\n                    for member in tar.getmembers():\n                        # check if member is a file\n                        if member.isreg():\n                            member.name = os.path.join(data_dir, os.path.basename(member.name))\n                            tar.extract(member)\n            else:\n                # write to data_dir\n                with open(os.path.join(data_dir, os.path.basename(data_url)), 'wb') as f:\n                    f.write(response.content)\n            return True\n        else:\n            print(f\"Error downloading file: {response.status_code}\")\n            return False\n\n    except Exception as e:\n        print(f\"Error downloading file: {e}\")\n    return False\n"
  },
  {
    "path": "python/fate_llm/evaluate/utils/llm_evaluator.py",
    "content": "#\n#  Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\n# this file is used to evaluate the model on fate-llm built-in tasks and user-given tasks\n\nimport os\nimport tempfile\nimport yaml\nimport shutil\nimport warnings\nfrom pytablewriter import MarkdownTableWriter\n\nimport lm_eval\nfrom lm_eval.utils import load_yaml_config\nfrom ..tasks import build_in_tasks, dump_yaml\nfrom .config import FATE_LLM_BASE_PATH, FATE_LLM_TASK_PATH\n\n\ndef evaluate(tasks, model=\"hf\", model_args=None, include_path=None, task_manager=None, show_result=False, **kwargs):\n    \"\"\"\n    Evaluate the model on given tasks. 
Simplified uses for built-in tasks.\n    Parameters\n    ----------\n    tasks: str or List[str], task name(s)\n    model: str or model object, model to be evaluated,\n        select from lm_eval supported types: {\"hf-auto\", \"hf\", \"huggingface\", \"vllm\"}\n    model_args: model args, str or dict\n    include_path: task path for tasks not in built-in tasks\n    task_manager: lm_eval.TakManger object\n    kwargs\n\n    Returns\n    -------\n\n    \"\"\"\n    if task_manager:\n        if not isinstance(task_manager, lm_eval.tasks.TaskManager):\n            raise ValueError(f\"'task_manager' must be of TaskManager type.\")\n    elif include_path:\n        task_manager = lm_eval.tasks.TaskManager(include_path=str(include_path))\n    else:\n        task_manager = lm_eval.tasks.TaskManager(include_path=str(FATE_LLM_TASK_PATH))\n    task_names = []\n    if isinstance(tasks, str):\n        task_names.append(tasks)\n\n    elif isinstance(tasks, list):\n        for task in tasks:\n            if isinstance(task, str):\n                task_names.append(task)\n            else:\n                raise ValueError(f\"tasks: {task}  of type {type(task)} not valid, please check.\")\n\n    else:\n        raise ValueError(f\"tasks: {tasks}  of type {type(tasks)} not valid, please check.\")\n\n    results = lm_eval.simple_evaluate(\n        model=model,\n        model_args=model_args,\n        tasks=task_names,\n        task_manager=task_manager,\n        **kwargs\n    )\n    if show_result:\n        result_table = lm_eval.utils.make_table(results)\n        print(result_table)\n    return results\n\n\ndef aggregate_table(results):\n    \"\"\"\n    adapted from lm_eval.utils.make_table:\n    https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.2/lm_eval/utils.py    Aggregate results from different models with same tasks\n    Parameters\n    ----------\n    results: dict, results from different models\n\n    Returns\n    -------\n\n    \"\"\"\n\n    suite_writers = 
dict()\n    for pair_name, pair_results in results.items():\n        # job_count = len(pair_results)\n        all_jobs = list(pair_results.keys())\n\n        md_writer = MarkdownTableWriter()\n\n        values = []\n        task_results = dict()\n        # print(f\"pair results: {pair_results}\")\n        for job_name, result_dict in pair_results.items():\n            if \"results\" in result_dict and result_dict[\"results\"]:\n                column = \"results\"\n            else:\n                column = \"groups\"\n            for k, dic in result_dict[column].items():\n\n                if \"alias\" in dic:\n                    # task alias\n                    k = dic.pop(\"alias\")\n\n                for (mf), v in dic.items():\n                    m, _, f = mf.partition(\",\")\n                    if m.endswith(\"_stderr\"):\n                        continue\n\n                    if m + \"_stderr\" + \",\" + f in dic:\n                        se = dic[m + \"_stderr\" + \",\" + f]\n                        if se != \"N/A\":\n                            se = \"%.4f\" % se\n                        v = \"%.4f ± %s\" % (v, se)\n                    else:\n                        v = \"%.4f\" % v\n                    task_results.setdefault(k, {}).setdefault(job_name, {})[m] = v\n\n        # job names as columns\n        # print(f\"task results: {task_results}\")\n        for task_name, task_result in task_results.items():\n            metrics = {inner_key for inner_dict in task_result.values() for inner_key, value in inner_dict.items()}\n            for metric in metrics:\n                row = [f\"{task_name}({metric})\"]\n                for job_name in all_jobs:\n                    if job_name in task_result:\n                        row.append(task_result[job_name].get(metric, \"N/A\"))\n                    else:\n                        row.append(\"N/A\")\n                values.append(row)\n\n        all_headers = [\"Task\"] + list(pair_results.keys())\n 
       md_writer.headers = all_headers\n        md_writer.value_matrix = values\n        suite_writers[pair_name] = md_writer\n    return suite_writers\n\n\ndef get_task_template(task):\n    if not isinstance(task, str) or task not in build_in_tasks:\n        raise ValueError(f\"{task} not found in build in task, please check input.\")\n    result = build_in_tasks.get(task).task_template\n\n    return result\n\n\ndef export_config(config, task, export_dir=None, export_sub_dir=None):\n    scr_dir = build_in_tasks.get(task).task_scr_dir\n    if export_dir is None:\n        export_dir = os.path.dirname(scr_dir)\n\n    if export_sub_dir is None:\n        temp_dir = tempfile.mkdtemp()\n        # make sure the relative path in new file will work\n        full_export_dir = os.path.join(export_dir, os.path.basename(temp_dir))\n        os.rename(temp_dir, full_export_dir)\n    else:\n        full_export_dir = os.path.join(export_dir, export_sub_dir)\n    copy_directory_to_dst(scr_dir, full_export_dir, build_in_tasks.get(task).task_conf_path, config)\n\n    return full_export_dir\n\n\ndef copy_directory_to_dst(src_dir, dst_dir, target_conf_file, new_conf: dict):\n    \"\"\"parent_dir = os.path.dirname(src_dir)\n\n    temp_dir = tempfile.mkdtemp()\n    # make sure the relative path in new file will work\n    temp_dir_in_parent = os.path.join(parent_dir, os.path.basename(temp_dir))\n    os.rename(temp_dir, temp_dir_in_parent)\"\"\"\n\n    for item in os.listdir(src_dir):\n\n        src_item = os.path.join(src_dir, item)\n        dst_item = os.path.join(dst_dir, item)\n        if os.path.isdir(src_item):\n            shutil.copytree(src_item, dst_item)\n        else:\n            if item == target_conf_file:\n                # write new conf file\n                dump_yaml(new_conf, dst_item)\n            else:\n                shutil.copy2(src_item, dst_item)\n            # shutil.copy2(src_item, dst_item)\n\n\ndef contains_subdirectory(path, subdirectories):\n    base_name = 
os.path.basename(path)\n    if base_name in subdirectories:\n        return True\n\n    for root, dirs, files in os.walk(path):\n        for d in dirs:\n            if d in subdirectories:\n                return True\n\n    return False\n\ndef delete_config(target_dir, force=False):\n    if not force:\n        # check if target dir in any of the build in tasks, only rm dir for build in tasks if force=True\n        all_build_in_dir = {task.task_scr_dir for task in build_in_tasks.values()}\n        if contains_subdirectory(target_dir, all_build_in_dir):\n            warnings.warn(f\"Built-in task(s) found in given target directory, please check input or set `force`=True.\")\n            return\n    shutil.rmtree(target_dir)\n\n\ndef set_environ_fate_llm_base(path):\n    if path:\n        os.environ[\"FATE_LLM_BASE_PATH\"] = path\n\n\ndef set_environ_fate_llm_task_base(path):\n    if path:\n        os.environ[\"FATE_LLM_TASK_PATH\"] = path\n\n\ndef init_tasks(root_path=None):\n    \"\"\"\n\n    Parameters\n    ----------\n    root_path: str, default None, root path for all local datasets in built-in tasks, {$root_path}/{$data_files};\n    if not provided, current file path will be used to generate root\n\n    Returns\n    -------\n\n    \"\"\"\n    for task in build_in_tasks.values():\n        conf_path = task.task_conf_path\n        parent_path = os.path.dirname(conf_path)\n        task_template = task.task_template\n        data_args = task_template.get(\"dataset_kwargs\")\n        if data_args:\n            data_files = data_args.get(\"data_files\")\n            if isinstance(data_files, str):\n                if data_files.endswith(\"jsonl\") or data_files.endswith(\"json\"):\n                    if root_path:\n                        parent_dir = os.path.basename(parent_path)\n                        new_conf_path = os.path.join(root_path, parent_dir, os.path.basename(conf_path))\n                    else:\n                        new_conf_path = 
os.path.join(parent_path, data_files)\n                    task_template[\"dataset_kwargs\"][\"data_files\"] = new_conf_path\n            elif isinstance(data_files, dict):\n                for k, v in data_files.items():\n                    if root_path:\n                        parent_dir = os.path.basename(parent_path)\n                        new_conf_path = os.path.join(root_path, parent_dir, os.path.basename(conf_path))\n                    else:\n                        new_conf_path = os.path.join(parent_path, v)\n                    task_template[\"dataset_kwargs\"][\"data_files\"][k] = new_conf_path\n\n        try:\n            dump_yaml(task_template, conf_path)\n        except FileNotFoundError:\n            raise ValueError(f\"Cannot find task config {conf_path}, please check.\")\n        except Exception:\n            raise ValueError(f\"Initialization failed.\")\n\ndef download_task(tasks=None):\n    if tasks is None:\n        tasks = list(build_in_tasks.keys())\n    i = 1\n    if isinstance(tasks, str):\n        tasks = [tasks]\n    n = len(tasks)\n    for task in tasks:\n        task_obj = build_in_tasks.get(task)\n        if task_obj is None:\n            print(f\"Task {task} not found in built-in tasks, please check.\")\n            continue\n        result = task_obj.download_from_source()\n        if result:\n            print(f\"Finish downloading {i}/{n} th task data: {task}, saved to {task_obj.task_scr_dir}.\\n\")\n        else:\n            print(f\"Failed to download {i}/{n} th task data to {task_obj.task_scr_dir}.\\n\")\n        i += 1\n"
  },
  {
    "path": "python/fate_llm/evaluate/utils/model_tools.py",
    "content": "#\n#  Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nimport os\nfrom transformers import AutoModel, AutoTokenizer\nfrom lm_eval.models.huggingface import HFLM\n\n\ndef load_model_from_path(model_path, peft_path=None, peft_config=None, model_args=None):\n    model_args = model_args or {}\n    if peft_path is None:\n        if os.path.isfile(model_path):\n            return HFLM(pretrained=model_path, **model_args)\n        else:\n            raise ValueError(f\"given model path is not valid, please check: {model_path}\")\n    else:\n        import torch\n        from peft import PeftModel, PeftConfig, LoraConfig, TaskType, get_peft_model\n        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)\n        model.half()\n        model.eval()\n        peft_config = peft_config or {}\n        peft_config=LoraConfig(**peft_config)\n        model = get_peft_model(model, peft_config)\n        model.load_state_dict(torch.load(peft_path), strict=False)\n        model.model.half()\n        HFLM(pretrained=model, tokenizer=tokenizer, **model_args)\n\n\ndef load_model(model_path, peft_path=None, model_args=None):\n    model_args = model_args or {}\n    return HFLM(pretrained=model_path, peft_path=peft_path, **model_args)\n\n\ndef load_by_loader(loader_name=None, loader_conf_path=None, 
peft_path=None):\n    #@todo: find loader fn & return loaded model\n    pass"
  },
  {
    "path": "python/fate_llm/inference/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/inference/api.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nfrom fate_llm.inference.inference_base import Inference\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom transformers import GenerationConfig\nfrom typing import List\n\n\nclass APICompletionInference(Inference):\n\n    def __init__(self, api_url: str, model_name: str, api_key: str = 'EMPTY', api_timeout=3600):\n        from openai import OpenAI\n        self.model_name = model_name\n        self.client = OpenAI(\n            api_key=api_key,\n            base_url=api_url,\n            timeout=api_timeout\n        )\n\n    def inference(self, docs: List[str], inference_kwargs: dict = {}) -> List[str]:\n        completion = self.client.completions.create(model=self.model_name, prompt=docs, **inference_kwargs)\n        rs_doc = [completion.choices[i].text for i in range(len(completion.choices))]\n        return rs_doc"
  },
  {
    "path": "python/fate_llm/inference/hf_qw.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nfrom fate_llm.inference.inference_base import Inference\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom typing import List\nimport tqdm\n\n\nclass QwenHFCompletionInference(Inference):\n\n    def __init__(self, model, tokenizer):\n        self.model = model\n        self.tokenizer = tokenizer\n\n    def inference(self, docs: List[str], inference_kwargs: dict = {}) -> List[str]:\n        self.model = self.model.eval()\n        rs_list = []\n        for d in tqdm.tqdm(docs):\n            inputs = self.tokenizer(d, return_tensors='pt')\n            inputs = inputs.to(self.model.device)\n            inputs.update(inference_kwargs)\n            pred = self.model.generate(**inputs)\n            response = self.tokenizer.decode(pred.cpu()[0][len(inputs['input_ids'][0]):], skip_special_tokens=True)\n            rs_list.append(response)\n        self.model = self.model.train()\n        return rs_list\n\n"
  },
  {
    "path": "python/fate_llm/inference/inference_base.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nfrom typing import List\n\n\nclass Inference(object):\n\n    def __init__(self):\n        pass\n\n    def inference(self, docs: List[str], inference_kwargs: dict = {}) -> List[str]:\n        raise NotImplementedError()"
  },
  {
    "path": "python/fate_llm/inference/vllm.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nfrom fate_llm.inference.inference_base import Inference\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom transformers import GenerationConfig\nimport logging\nfrom typing import List\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass VLLMInference(Inference):\n\n    def __init__(self, model_path, num_gpu=1, dtype='float16', gpu_memory_utilization=0.9):\n        from vllm import LLM\n        self.llm = LLM(model=model_path, trust_remote_code=True, dtype=dtype, tensor_parallel_size=num_gpu, gpu_memory_utilization=gpu_memory_utilization)\n        logger.info('vllm model init done, model path is {}'.format(model_path))\n\n    def inference(self, docs: List[str], inference_kwargs: dict = {}) -> List[str]:\n        \n        from vllm import SamplingParams\n        param = SamplingParams(**inference_kwargs)\n        outputs = self.llm.generate(\n            prompts=docs, \n            sampling_params=param)\n\n        rs = []\n        for output in outputs:\n            prompt = output.prompt\n            generated_text = output.outputs[0].text\n            rs.append(generated_text)\n\n        return rs\n\n\n"
  },
  {
    "path": "python/fate_llm/model_zoo/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n"
  },
  {
    "path": "python/fate_llm/model_zoo/embedding_transformer/__init__.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n"
  },
  {
    "path": "python/fate_llm/model_zoo/embedding_transformer/st_model.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom sentence_transformers import SentenceTransformer\nfrom typing import Any, Optional, Dict, Union\n\n\nclass SentenceTransformerModel(object):\n    def __init__(\n        self,\n        model_name_or_path: Optional[str] = None,\n        device: Optional[str] = None,\n        prompts: Optional[Dict[str, str]] = None,\n        default_prompt_name: Optional[str] = None,\n        cache_folder: Optional[str] = None,\n        trust_remote_code: bool = False,\n        revision: Optional[str] = None,\n        local_files_only: bool = False,\n        token: Optional[Union[bool, str]] = None,\n        use_auth_token: Optional[Union[bool, str]] = None,\n        truncate_dim: Optional[int] = None,\n        model_kwargs: Optional[Dict[str, Any]] = None,\n        tokenizer_kwargs: Optional[Dict[str, Any]] = None,\n        config_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> None:\n        self.model_name_or_path = model_name_or_path\n        self.device = device\n        self.prompts = prompts\n        self.default_prompt_name = default_prompt_name\n        self.cache_folder = cache_folder\n        self.trust_remote_code = trust_remote_code\n        self.revision = revision\n        self.local_files_only = local_files_only\n        self.token = token\n        self.use_auth_token = use_auth_token\n        self.truncate_dim = 
truncate_dim\n        self.model_kwargs = model_kwargs\n        self.tokenizer_kwargs = tokenizer_kwargs\n        self.config_kwargs = config_kwargs\n\n    def load(self):\n        model = SentenceTransformer(\n            model_name_or_path=self.model_name_or_path,\n            device=self.device,\n            prompts=self.prompts,\n            default_prompt_name=self.default_prompt_name,\n            cache_folder=self.cache_folder,\n            trust_remote_code=self.trust_remote_code,\n            revision=self.revision,\n            local_files_only=self.local_files_only,\n            token=self.token,\n            use_auth_token=self.use_auth_token,\n            truncate_dim=self.truncate_dim,\n            model_kwargs=self.model_kwargs,\n            tokenizer_kwargs=self.tokenizer_kwargs,\n            config_kwargs=self.config_kwargs\n        )\n\n        return model\n"
  },
  {
    "path": "python/fate_llm/model_zoo/hf_model.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport torch\nfrom transformers import AutoModelForCausalLM\n\n\nclass HFAutoModelForCausalLM:\n\n    def __init__(self, pretrained_model_name_or_path, *model_args, **kwargs) -> None:\n        self.pretrained_model_name_or_path = pretrained_model_name_or_path\n        self.model_args = model_args\n        self.kwargs = kwargs\n        if \"torch_dtype\" in self.kwargs and self.kwargs[\"torch_dtype\"] != \"auto\":\n            dtype = self.kwargs.pop(\"torch_dtype\")\n            self.kwargs[\"torch_dtype\"] = getattr(torch, dtype)\n\n    def load(self):\n        model = AutoModelForCausalLM.from_pretrained(\n            self.pretrained_model_name_or_path, *self.model_args, **self.kwargs\n        )\n        return model\n"
  },
  {
    "path": "python/fate_llm/model_zoo/offsite_tuning/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/model_zoo/offsite_tuning/bloom.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters, split_numpy_array, recover_numpy_array\nfrom transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomModel, BloomConfig\nfrom torch import nn\nimport torch\nfrom typing import Optional, Tuple\n\n\nclass BloomMainModel(OffsiteTuningMainModel):\n\n    def __init__(\n            self,\n            model_name_or_path,\n            emulator_layer_num: int,\n            adapter_top_layer_num: int = 2,\n            adapter_bottom_layer_num: int = 2):\n\n        self.model_name_or_path = model_name_or_path\n        super().__init__(\n            emulator_layer_num,\n            adapter_top_layer_num,\n            adapter_bottom_layer_num)\n\n    def get_base_model(self):\n        return BloomForCausalLM.from_pretrained(self.model_name_or_path)\n\n    def get_model_transformer_blocks(self, model: BloomForCausalLM):\n        return model.transformer.h\n\n    def get_additional_param_state_dict(self):\n        # get parameter of additional parameter\n        model = self.model\n        param_dict = {\n            'wte': model.transformer.word_embeddings,\n            'word_ln': model.transformer.word_embeddings_layernorm,\n            'last_ln_f': model.transformer.ln_f\n      
  }\n\n        addition_weights = self.get_numpy_state_dict(param_dict)\n\n        wte = addition_weights.pop('wte')\n        wte_dict = split_numpy_array(wte, 25, 'wte')\n        addition_weights.update(wte_dict)\n        return addition_weights\n\n    def load_additional_param_state_dict(self, submodel_weights: dict):\n        # load additional weights:\n        model = self.model\n        param_dict = {\n            'wte': model.transformer.word_embeddings,\n            'word_ln': model.transformer.word_embeddings_layernorm,\n            'last_ln_f': model.transformer.ln_f\n        }\n\n        new_submodel_weight = {}\n        new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']\n        new_submodel_weight['word_ln'] = submodel_weights['word_ln']\n        wte_dict = {}\n        for k, v in submodel_weights.items():\n            if 'wte' in k:\n                wte_dict[k] = v\n        wte = recover_numpy_array(wte_dict, 'wte')\n        new_submodel_weight['wte'] = wte\n        self.load_numpy_state_dict(param_dict, new_submodel_weight)\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        inputs_embeds: Optional[torch.Tensor] = None,\n        labels: Optional[torch.Tensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **deprecated_arguments,\n    ):\n\n        return self.model(\n            input_ids,\n            past_key_values,\n            attention_mask,\n            head_mask,\n            inputs_embeds,\n            labels,\n            use_cache,\n            output_attentions,\n            output_hidden_states,\n            return_dict,\n    
        **deprecated_arguments,\n        )\n\n\nclass BloomSubModel(OffsiteTuningSubModel):\n\n    def __init__(\n            self,\n            model_name_or_path,\n            emulator_layer_num: int,\n            adapter_top_layer_num: int = 2,\n            adapter_bottom_layer_num: int = 2,\n            fp16_mix_precision=False,\n            partial_weight_decay=None):\n\n        self.model_name_or_path = model_name_or_path\n        self.emulator_layer_num = emulator_layer_num\n        self.adapter_top_layer_num = adapter_top_layer_num\n        self.adapter_bottom_layer_num = adapter_bottom_layer_num\n        super().__init__(\n            emulator_layer_num,\n            adapter_top_layer_num,\n            adapter_bottom_layer_num,\n            fp16_mix_precision)\n        self.partial_weight_decay = partial_weight_decay\n\n    def get_base_model(self):\n        total_layer_num = self.emulator_layer_num + \\\n            self.adapter_top_layer_num + self.adapter_bottom_layer_num\n        config = BloomConfig.from_pretrained(self.model_name_or_path)\n        config.num_hidden_layers = total_layer_num\n        # initialize a model without pretrained weights\n        return BloomForCausalLM(config)\n\n    def get_model_transformer_blocks(self, model: BloomForCausalLM):\n        return model.transformer.h\n\n    def get_additional_param_state_dict(self):\n        # get parameter of additional parameter\n        model = self.model\n        param_dict = {\n            'wte': model.transformer.word_embeddings,\n            'word_ln': model.transformer.word_embeddings_layernorm,\n            'last_ln_f': model.transformer.ln_f\n        }\n\n        addition_weights = self.get_numpy_state_dict(param_dict)\n\n        wte = addition_weights.pop('wte')\n        wte_dict = split_numpy_array(wte, 25, 'wte')\n        addition_weights.update(wte_dict)\n        return addition_weights\n\n    def load_additional_param_state_dict(self, submodel_weights: dict):\n        # load 
additional weights:\n        model = self.model\n        param_dict = {\n            'wte': model.transformer.word_embeddings,\n            'word_ln': model.transformer.word_embeddings_layernorm,\n            'last_ln_f': model.transformer.ln_f\n        }\n\n        new_submodel_weight = {}\n        new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']\n        new_submodel_weight['word_ln'] = submodel_weights['word_ln']\n        wte_dict = {}\n        for k, v in submodel_weights.items():\n            if 'wte' in k:\n                wte_dict[k] = v\n        wte = recover_numpy_array(wte_dict, 'wte')\n        new_submodel_weight['wte'] = wte\n        self.load_numpy_state_dict(param_dict, new_submodel_weight)\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        inputs_embeds: Optional[torch.Tensor] = None,\n        labels: Optional[torch.Tensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **deprecated_arguments,\n    ):\n\n        return self.model(\n            input_ids,\n            past_key_values,\n            attention_mask,\n            head_mask,\n            inputs_embeds,\n            labels,\n            use_cache,\n            output_attentions,\n            output_hidden_states,\n            return_dict,\n            **deprecated_arguments,\n        )\n\n    def parameters(self, recurse=True):\n        if self.partial_weight_decay is None:\n            return super().parameters(recurse)\n        elif isinstance(self.partial_weight_decay, float):\n            no_decay = [\"bias\", \"layer_norm.weight\"]\n            return [\n                
{\n                    \"params\": [\n                        p for n, p in self.named_parameters() if not any(\n                            nd in n for nd in no_decay)], \"weight_decay\": self.partial_weight_decay}, {\n                    \"params\": [\n                        p for n, p in self.named_parameters() if any(\n                            nd in n for nd in no_decay)], \"weight_decay\": 0.0}]\n        else:\n            raise ValueError(\n                f\"partial_weight_decay should be None or float, but got {self.partial_weight_decay}\")\n\n"
  },
  {
    "path": "python/fate_llm/model_zoo/offsite_tuning/gpt2.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters, split_numpy_array, recover_numpy_array\nfrom transformers import GPT2LMHeadModel, GPT2Config\nimport torch\nfrom typing import Optional, Tuple\n\n\nclass GPT2LMHeadMainModel(OffsiteTuningMainModel):\n\n    def __init__(\n            self,\n            model_name_or_path,\n            emulator_layer_num: int,\n            adapter_top_layer_num: int = 2,\n            adapter_bottom_layer_num: int = 2):\n\n        self.model_name_or_path = model_name_or_path\n        super().__init__(\n            emulator_layer_num,\n            adapter_top_layer_num,\n            adapter_bottom_layer_num)\n\n    def get_base_model(self):\n        return GPT2LMHeadModel.from_pretrained(self.model_name_or_path)\n\n    def get_model_transformer_blocks(self, model: GPT2LMHeadModel):\n        return model.transformer.h\n\n    def forward(self,\n        input_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        token_type_ids: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n     
   inputs_embeds: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,):\n\n        return self.model(\n            input_ids=input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            labels=labels,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict)\n\n    def get_additional_param_state_dict(self):\n        # get parameter of additional parameter\n        model = self.model\n        param_dict = {\n            'wte': model.transformer.wte,\n            'wpe': model.transformer.wpe,\n            'last_ln_f': model.transformer.ln_f\n        }\n\n        addition_weights = self.get_numpy_state_dict(param_dict)\n\n        wte = addition_weights.pop('wte')\n        wte_dict = split_numpy_array(wte, 10, 'wte')\n        wpe = addition_weights.pop('wpe')\n        wpe_dict = split_numpy_array(wpe, 10, 'wpe')\n        addition_weights.update(wte_dict)\n        addition_weights.update(wpe_dict)\n        return addition_weights\n\n    def load_additional_param_state_dict(self, submodel_weights: dict):\n        # load additional weights:\n        model = self.model\n        param_dict = {\n            'wte': model.transformer.wte,\n            'wpe': 
model.transformer.wpe,\n            'last_ln_f': model.transformer.ln_f\n        }\n\n        new_submodel_weight = {}\n        new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']\n        wte_dict, wpe_dict = {}, {}\n        for k, v in submodel_weights.items():\n            if 'wte' in k:\n                wte_dict[k] = v\n            if 'wpe' in k:\n                wpe_dict[k] = v\n        wte = recover_numpy_array(wte_dict, 'wte')\n        wpe = recover_numpy_array(wpe_dict, 'wpe')\n        new_submodel_weight['wte'] = wte\n        new_submodel_weight['wpe'] = wpe\n\n        self.load_numpy_state_dict(param_dict, new_submodel_weight)\n\n\nclass GPT2LMHeadSubModel(OffsiteTuningSubModel):\n\n    def __init__(\n            self,\n            model_name_or_path,\n            emulator_layer_num: int,\n            adapter_top_layer_num: int = 2,\n            adapter_bottom_layer_num: int = 2,\n            fp16_mix_precision=False,\n            partial_weight_decay=None):\n\n        self.model_name_or_path = model_name_or_path\n        self.emulator_layer_num = emulator_layer_num\n        self.adapter_top_layer_num = adapter_top_layer_num\n        self.adapter_bottom_layer_num = adapter_bottom_layer_num\n        super().__init__(\n            emulator_layer_num,\n            adapter_top_layer_num,\n            adapter_bottom_layer_num,\n            fp16_mix_precision)\n        self.partial_weight_decay = partial_weight_decay\n\n    def get_base_model(self):\n        total_layer_num = self.emulator_layer_num + \\\n            self.adapter_top_layer_num + self.adapter_bottom_layer_num\n        config = GPT2Config.from_pretrained(self.model_name_or_path)\n        config.num_hidden_layers = total_layer_num\n        # initialize a model without pretrained weights\n        return GPT2LMHeadModel(config)\n\n    def get_model_transformer_blocks(self, model: GPT2LMHeadModel):\n        return model.transformer.h\n\n    def get_additional_param_state_dict(self):\n    
    # get parameter of additional parameter\n        model = self.model\n        param_dict = {\n            'wte': model.transformer.wte,\n            'wpe': model.transformer.wpe,\n            'last_ln_f': model.transformer.ln_f\n        }\n\n        addition_weights = self.get_numpy_state_dict(param_dict)\n\n        wte = addition_weights.pop('wte')\n        wte_dict = split_numpy_array(wte, 10, 'wte')\n        wpe = addition_weights.pop('wpe')\n        wpe_dict = split_numpy_array(wpe, 10, 'wpe')\n        addition_weights.update(wte_dict)\n        addition_weights.update(wpe_dict)\n        return addition_weights\n\n    def load_additional_param_state_dict(self, submodel_weights: dict):\n        # load additional weights:\n        model = self.model\n        param_dict = {\n            'wte': model.transformer.wte,\n            'wpe': model.transformer.wpe,\n            'last_ln_f': model.transformer.ln_f\n        }\n\n        new_submodel_weight = {}\n        new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']\n        wte_dict, wpe_dict = {}, {}\n        for k, v in submodel_weights.items():\n            if 'wte' in k:\n                wte_dict[k] = v\n            if 'wpe' in k:\n                wpe_dict[k] = v\n        wte = recover_numpy_array(wte_dict, 'wte')\n        wpe = recover_numpy_array(wpe_dict, 'wpe')\n        new_submodel_weight['wte'] = wte\n        new_submodel_weight['wpe'] = wpe\n\n        self.load_numpy_state_dict(param_dict, new_submodel_weight)\n\n    def forward(self,\n        input_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        token_type_ids: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: 
Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,):\n\n        return self.model(\n            input_ids=input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            labels=labels,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict)\n\n    def parameters(self, recurse=True):\n        if self.partial_weight_decay is None:\n            return super().parameters(recurse)\n        elif isinstance(self.partial_weight_decay, float):\n            no_decay = [\"bias\", \"layer_norm.weight\"]\n            return [\n                {\n                    \"params\": [\n                        p for n, p in self.named_parameters() if not any(\n                            nd in n for nd in no_decay)], \"weight_decay\": self.partial_weight_decay}, {\n                    \"params\": [\n                        p for n, p in self.named_parameters() if any(\n                            nd in n for nd in no_decay)], \"weight_decay\": 0.0}]\n        else:\n            raise ValueError(\n                f\"partial_weight_decay should be None or float, but got {self.partial_weight_decay}\")"
  },
  {
    "path": "python/fate_llm/model_zoo/offsite_tuning/llama.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom fate_llm.model_zoo.offsite_tuning.offsite_tuning_model import OffsiteTuningSubModel, OffsiteTuningMainModel, get_dropout_emulator_and_adapters, split_numpy_array, recover_numpy_array\nfrom transformers import LlamaConfig, LlamaForCausalLM\n\n\nclass LlamaMainModel(OffsiteTuningMainModel):\n\n    def __init__(\n            self,\n            model_name_or_path,\n            emulator_layer_num: int,\n            adapter_top_layer_num: int = 2,\n            adapter_bottom_layer_num: int = 2):\n\n        self.model_name_or_path = model_name_or_path\n        super().__init__(\n            emulator_layer_num,\n            adapter_top_layer_num,\n            adapter_bottom_layer_num)\n\n    def get_base_model(self):\n        return LlamaForCausalLM.from_pretrained(self.model_name_or_path)\n\n    def get_model_transformer_blocks(self, model: LlamaForCausalLM):\n        return model.model.layers\n\n    def get_additional_param_state_dict(self):\n        # get parameter of additional parameter\n        model = self.model\n        param_dict = {\n            'wte': model.model.embed_tokens,\n            'last_ln_f': model.model.norm\n        }\n\n        addition_weights = self.get_numpy_state_dict(param_dict)\n\n        wte = addition_weights.pop('wte')\n        wte_dict = split_numpy_array(wte, 25, 'wte')\n        
addition_weights.update(wte_dict)\n        return addition_weights\n\n    def load_additional_param_state_dict(self, submodel_weights: dict):\n        # load additional weights:\n        model = self.model\n        param_dict = {\n            'wte': model.model.embed_tokens,\n            'last_ln_f': model.model.norm\n        }\n\n        new_submodel_weight = {}\n        new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']\n        wte_dict = {}\n        for k, v in submodel_weights.items():\n            if 'wte' in k:\n                wte_dict[k] = v\n        wte = recover_numpy_array(wte_dict, 'wte')\n        new_submodel_weight['wte'] = wte\n        self.load_numpy_state_dict(param_dict, new_submodel_weight)\n\n    def forward(self, **kwargs):\n        return self.model(**kwargs)\n\n\nclass LlamaSubModel(OffsiteTuningSubModel):\n\n    def __init__(\n            self,\n            model_name_or_path,\n            emulator_layer_num: int,\n            adapter_top_layer_num: int = 2,\n            adapter_bottom_layer_num: int = 2,\n            fp16_mix_precision=False,\n            partial_weight_decay=None):\n\n        self.model_name_or_path = model_name_or_path\n        self.emulator_layer_num = emulator_layer_num\n        self.adapter_top_layer_num = adapter_top_layer_num\n        self.adapter_bottom_layer_num = adapter_bottom_layer_num\n        super().__init__(\n            emulator_layer_num,\n            adapter_top_layer_num,\n            adapter_bottom_layer_num,\n            fp16_mix_precision)\n        self.partial_weight_decay = partial_weight_decay\n\n    def get_base_model(self):\n        total_layer_num = self.emulator_layer_num + \\\n            self.adapter_top_layer_num + self.adapter_bottom_layer_num\n        config = LlamaConfig.from_pretrained(self.model_name_or_path)\n        config.num_hidden_layers = total_layer_num\n        # initialize a model without pretrained weights\n        return LlamaForCausalLM(config)\n\n    def 
get_model_transformer_blocks(self, model: LlamaForCausalLM):\n        return model.model.layers\n\n    def get_additional_param_state_dict(self):\n        # get parameter of additional parameter\n        model = self.model\n        param_dict = {\n            'wte': model.model.embed_tokens,\n            'last_ln_f': model.model.norm\n        }\n\n        addition_weights = self.get_numpy_state_dict(param_dict)\n\n        wte = addition_weights.pop('wte')\n        wte_dict = split_numpy_array(wte, 25, 'wte')\n        addition_weights.update(wte_dict)\n        return addition_weights\n\n    def load_additional_param_state_dict(self, submodel_weights: dict):\n        # load additional weights:\n        model = self.model\n        param_dict = {\n            'wte': model.model.embed_tokens,\n            'last_ln_f': model.model.norm\n        }\n\n        new_submodel_weight = {}\n        new_submodel_weight['last_ln_f'] = submodel_weights['last_ln_f']\n        wte_dict = {}\n        for k, v in submodel_weights.items():\n            if 'wte' in k:\n                wte_dict[k] = v\n        wte = recover_numpy_array(wte_dict, 'wte')\n        new_submodel_weight['wte'] = wte\n        self.load_numpy_state_dict(param_dict, new_submodel_weight)\n\n    def forward(self, **kwargs):\n        return self.model(**kwargs)\n\n    def parameters(self, recurse=True):\n        if self.partial_weight_decay is None:\n            return super().parameters(recurse)\n        elif isinstance(self.partial_weight_decay, float):\n            no_decay = [\"bias\", \"layer_norm.weight\"]\n            return [\n                {\n                    \"params\": [\n                        p for n, p in self.named_parameters() if not any(\n                            nd in n for nd in no_decay)], \"weight_decay\": self.partial_weight_decay}, {\n                    \"params\": [\n                        p for n, p in self.named_parameters() if any(\n                            nd in n for nd in 
no_decay)], \"weight_decay\": 0.0}]\n        else:\n            raise ValueError(\n                f\"partial_weight_decay should be None or float, but got {self.partial_weight_decay}\")"
  },
  {
    "path": "python/fate_llm/model_zoo/offsite_tuning/offsite_tuning_model.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport torch as t\nfrom torch import nn\nfrom transformers import AutoModel\nimport numpy as np\nimport logging\n\nlogger = logging.getLogger(__name__)\n\n\ndef get_dropout_emulator_and_adapters(\n        transformer_layers: nn.ModuleList,\n        emulator_layer_num: int,\n        adapter_top_layer_num: int,\n        adapter_bottom_layer_num: int):\n\n    assert adapter_bottom_layer_num > 0 and adapter_top_layer_num > 0, \"adapter layer num must be greater than 0\"\n    assert emulator_layer_num < len(\n        transformer_layers), \"emulator layer num must be less than the number of transformer layers\"\n    assert adapter_bottom_layer_num + adapter_top_layer_num < len(\n        transformer_layers), \"adapter layer num must be less than the number of transformer layers\"\n    assert emulator_layer_num < len(\n        transformer_layers) and emulator_layer_num > 0, \"emulator layer num must be less than the number of transformer layers\"\n\n    bottom_idx = adapter_bottom_layer_num\n    top_idx = len(transformer_layers) - adapter_top_layer_num\n    bottom_layers = transformer_layers[:bottom_idx]\n    top_layers = transformer_layers[top_idx:]\n    kept_layers = transformer_layers[bottom_idx:top_idx]\n    emulator = nn.ModuleList()\n    stride = (len(kept_layers) - 1) / (emulator_layer_num - 1)\n\n    layer_idx = []\n    for i 
in range(emulator_layer_num):\n        idx = int(round(i * stride))\n        layer_idx.append(idx)\n        emulator.append(kept_layers[idx])\n    logger.info(\n        'take layer {} of the original model as the emulator'.format(\n            t.Tensor(layer_idx) +\n            bottom_idx))\n    return nn.ModuleList(emulator), nn.ModuleList(\n        bottom_layers), nn.ModuleList(top_layers)\n\n\n\ndef split_numpy_array(embedding_matrix, n, suffix):\n    # Calculate the indices where the splits should occur\n    embedding_matrix = embedding_matrix['weight']\n    indices = np.linspace(0, embedding_matrix.shape[0], n+1, dtype=int)\n\n    # Split the embedding matrix at the calculated indices\n    slices = [embedding_matrix[indices[i]:indices[i+1]] for i in range(n)]\n\n    # Create a dictionary with the slices\n    result_dict = {suffix+str(i): slice for i, slice in enumerate(slices)}\n    return result_dict\n\n\ndef recover_numpy_array(slices_dict, suffix=\"\"):\n    # Get the slices from the dictionary and concatenate them\n    slices = [slices_dict[suffix + str(i)] for i in range(len(slices_dict))]\n    complete_array = np.concatenate(slices, axis=0)\n    return {'weight': complete_array}\n\n\nclass OffsiteTuningBaseModel(t.nn.Module):\n\n    def __init__(self, emulator_layer_num: int, adapter_top_layer_num: int = 2,\n                 adapter_bottom_layer_num: int = 2, fp16_mix_precision=False):\n        super().__init__()\n        self.fp16_mix_precision = fp16_mix_precision\n        self.model = self.get_base_model()\n        self.initialize_model()\n        self.emulator, self.adapter_bottom, self.adapter_top = get_dropout_emulator_and_adapters(\n            transformer_layers=self.get_model_transformer_blocks(self.model),\n            emulator_layer_num=emulator_layer_num,\n            adapter_top_layer_num=adapter_top_layer_num,\n            adapter_bottom_layer_num=adapter_bottom_layer_num\n        )\n        self.post_initialization()\n\n    def 
initialize_model(self):\n        if self.fp16_mix_precision:\n            self.model.half()\n        for param in self.model.parameters():\n            param.requires_grad = False\n\n    def post_initialization(self):\n        pass\n\n    def get_adapter_top(self):\n        return self.adapter_top\n\n    def get_adapter_bottom(self):\n        return self.adapter_bottom\n\n    def get_emulator(self):\n        return self.emulator\n\n    def get_additional_param_state_dict(self):\n        # get parameter of additional parameter\n        return {}\n\n    def load_additional_param_state_dict(self, submodel_weights: dict):\n        # load additional weights:\n        pass\n\n    def _get_numpy_arr(self, v):\n        if v.dtype == t.bfloat16:\n            # float 32\n            v = v.detach().cpu().float().numpy()\n        else:\n            v = v.detach().cpu().numpy()\n\n        return v\n\n\n    def load_numpy_state_dict(self, module_dict, state_dict):\n        param_dict = module_dict\n\n        for k, v in param_dict.items():\n            if k not in state_dict:\n                continue\n            addition_weights = {\n                k: t.tensor(v) for k,\n                v in state_dict[k].items()}\n            v.load_state_dict(addition_weights)\n\n    def get_numpy_state_dict(self, module_dict):\n\n        weight_dict = {}\n        for k, v in module_dict.items():\n            weight_dict[k] = {\n                k: self._get_numpy_arr(v) for k,\n                v in v.state_dict().items()}\n        return weight_dict\n\n    def get_submodel_weights(self, with_emulator=True) -> dict:\n        if with_emulator:\n            submodel_weights = {\n                \"emulator\": {\n                    k: self._get_numpy_arr(v) for k,\n                    v in self.get_emulator().state_dict().items()},\n                \"adapter_top\": {\n                    k: self._get_numpy_arr(v) for k,\n                    v in self.get_adapter_top().state_dict().items()},\n   
             \"adapter_bottom\": {\n                    k: self._get_numpy_arr(v) for k,\n                    v in self.get_adapter_bottom().state_dict().items()}}\n        else:\n            submodel_weights = {\n                \"adapter_top\": {\n                    k: self._get_numpy_arr(v) for k,\n                    v in self.get_adapter_top().state_dict().items()},\n                \"adapter_bottom\": {\n                    k: self._get_numpy_arr(v) for k,\n                    v in self.get_adapter_bottom().state_dict().items()}}\n        addition_weights = self.get_additional_param_state_dict()\n        submodel_weights.update(addition_weights)\n        return submodel_weights\n\n    def load_submodel_weights(self, submodel_weights: dict, with_emulator=True):\n\n        if with_emulator:\n            emulator_weights = {\n                k: t.tensor(v) for k,\n                v in submodel_weights['emulator'].items()}\n            emulator = self.get_emulator()\n            emulator.load_state_dict(emulator_weights)\n\n\n        adapter_top_weights = {\n            k: t.tensor(v) for k,\n            v in submodel_weights['adapter_top'].items()}\n        adapter_bottom_weights = {\n            k: t.tensor(v) for k,\n            v in submodel_weights['adapter_bottom'].items()}\n        adapter_top = self.get_adapter_top()\n        adapter_bottom = self.get_adapter_bottom()\n\n        \n        adapter_top.load_state_dict(adapter_top_weights)\n        adapter_bottom.load_state_dict(adapter_bottom_weights)\n        self.load_additional_param_state_dict(submodel_weights)\n\n    def forward(self, **kwargs):\n        raise NotImplementedError()\n\n    def get_base_model(self):\n        raise NotImplementedError()\n\n    def get_model_transformer_blocks(self, model: t.nn.Module):\n        raise NotImplementedError()\n\n\nclass OffsiteTuningMainModel(OffsiteTuningBaseModel):\n\n    def post_initialization(self):\n        pass\n\n\nclass 
OffsiteTuningSubModel(OffsiteTuningBaseModel):\n\n    def post_initialization(self):\n        # mix precision model training\n        for param in self.adapter_top.parameters():\n            param.data = param.data.float()\n            param.requires_grad = True\n        for param in self.adapter_bottom.parameters():\n            param.data = param.data.float()\n            param.requires_grad = True"
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/albert.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers import AlbertConfig, AutoConfig\nfrom transformers import AlbertForSequenceClassification\nfrom fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM\n\n\nclass Albert(PELLM):\n\n    config_class = AlbertConfig\n    model_loader = AlbertForSequenceClassification\n\n    def __init__(self, config: dict = None,\n                 pretrained_path: str = None,\n                 peft_type: str = None,\n                 peft_config: dict = None,\n                 **kwargs\n                 ) -> None:\n\n        if pretrained_path is not None:\n            self.check_config(pretain_path=pretrained_path)\n        if config is None and pretrained_path is None:\n            config = AlbertConfig().to_dict()  # use default model setting\n        super().__init__(\n            config=config,\n            pretrained_path=pretrained_path,\n            peft_type=peft_type,\n            peft_config=peft_config,\n            **kwargs)\n\n    def check_config(self, pretain_path):\n        config = AutoConfig.from_pretrained(pretain_path)\n        assert isinstance(\n            config, AlbertConfig), 'The config of pretrained model must be AlbertConfig, but got {}'.format(\n            type(config))\n"
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/bart.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers import BartConfig, AutoConfig\nfrom transformers import BartForSequenceClassification\nfrom fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM\n\n\nclass Bart(PELLM):\n    config_class = BartConfig\n    model_loader = BartForSequenceClassification\n\n    def __init__(self, config: dict = None,\n                 pretrained_path: str = None,\n                 peft_type: str = None,\n                 peft_config: dict = None,\n                 **kwargs) -> None:\n\n        if pretrained_path is not None:\n            self.check_config(pretrain_path=pretrained_path)\n        if config is None and pretrained_path is None:\n            config = BartConfig().to_dict()\n        super().__init__(\n            config=config,\n            pretrained_path=pretrained_path,\n            peft_type=peft_type,\n            peft_config=peft_config,\n            **kwargs)\n\n    def check_config(self, pretrain_path):\n        config = AutoConfig.from_pretrained(pretrain_path)\n        assert isinstance(\n            config, BartConfig), 'The config of pretrained model must be BartConfig, but got {}'.format(\n            type(config))\n"
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/bert.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers import BertConfig, AutoConfig\nfrom transformers import BertForSequenceClassification\nfrom fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM\n\n\nclass Bert(PELLM):\n    config_class = BertConfig\n    model_loader = BertForSequenceClassification\n\n    def __init__(self, config: dict = None,\n                 pretrained_path: str = None,\n                 peft_type: str = None,\n                 peft_config: dict = None,\n                 **kwargs) -> None:\n\n        if pretrained_path is not None:\n            self.check_config(pretrain_path=pretrained_path)\n        if config is None and pretrained_path is None:\n            config = BertConfig().to_dict()\n        super().__init__(\n            config=config,\n            pretrained_path=pretrained_path,\n            peft_type=peft_type,\n            peft_config=peft_config,\n            **kwargs)\n\n    def check_config(self, pretrain_path):\n        config = AutoConfig.from_pretrained(pretrain_path)\n        assert isinstance(\n            config, BertConfig), 'The config of pretrained model must be BertConfig, but got {}'.format(\n            type(config))\n"
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/bloom.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers import BloomConfig\nfrom transformers import BloomForCausalLM\nfrom fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM\n\n\nclass Bloom(PELLM):\n\n    config_class = BloomConfig\n    model_loader = BloomForCausalLM\n\n    def __init__(self, config: dict = None,\n                 pretrained_path: str = None,\n                 peft_type: str = None,\n                 peft_config: dict = None,\n                 **kwargs\n                 ) -> None:\n\n        if config is None and pretrained_path is None:\n            config = BloomConfig().to_dict()  # use default model setting\n        super().__init__(config=config, pretrained_path=pretrained_path,\n                         peft_type=peft_type, peft_config=peft_config, **kwargs)\n"
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/chatglm.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM\nfrom transformers import AutoConfig\n\n\nclass ChatGLM(PELLM):\n    def __init__(self,\n                 pretrained_path: str = None,\n                 peft_type: str = None,\n                 peft_config: dict = None,\n                 pre_seq_len: int = None,\n                 prefix_projection: bool = False,\n                 **kwargs) -> None:\n\n        self.pre_seq_len = pre_seq_len\n        self.prefix_projection = prefix_projection\n\n        super().__init__(pretrained_path=pretrained_path,\n                         peft_type=peft_type,\n                         peft_config=peft_config,\n                         **kwargs\n                         )\n\n    def init_config(self):\n        self.config = AutoConfig.from_pretrained(\n            self.config_path, trust_remote_code=True)\n        self.config.pre_seq_len = self.pre_seq_len\n        self.config.prefix_projection = self.prefix_projection\n\n    def add_peft(self):\n        if self.pre_seq_len:\n            self._pe_lm.half()\n            self._pe_lm.transformer.prefix_encoder.float()\n        else:\n            super(ChatGLM, self).add_peft()\n"
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/deberta.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers import DebertaConfig, AutoConfig\nfrom transformers import DebertaForSequenceClassification\nfrom fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM\n\n\nclass Deberta(PELLM):\n\n    config_class = DebertaConfig\n    model_loader = DebertaForSequenceClassification\n\n    def __init__(self, config: dict = None,\n                 pretrained_path: str = None,\n                 peft_type: str = None,\n                 peft_config: dict = None,\n                 **kwargs) -> None:\n\n        if pretrained_path is not None:\n            self.check_config(pretrain_path=pretrained_path)\n        if config is None and pretrained_path is None:\n            config = DebertaConfig().to_dict()\n        super().__init__(\n            config=config,\n            pretrained_path=pretrained_path,\n            peft_type=peft_type,\n            peft_config=peft_config,\n            **kwargs)\n\n    def check_config(self, pretrain_path):\n        config = AutoConfig.from_pretrained(pretrain_path)\n        assert isinstance(\n            config, DebertaConfig), 'The config of pretrained model must be DebertaConfig, but got {}'.format(\n            type(config))\n"
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/distilbert.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers import DistilBertConfig, AutoConfig\nfrom transformers import DistilBertForSequenceClassification\nfrom fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM\n\n\nclass DistilBert(PELLM):\n    config_class = DistilBertConfig\n    model_loader = DistilBertForSequenceClassification\n\n    def __init__(self, config: dict = None,\n                 pretrained_path: str = None,\n                 peft_type: str = None,\n                 peft_config: dict = None,\n                 **kwargs) -> None:\n\n        if pretrained_path is not None:\n            self.check_config(pretrain_path=pretrained_path)\n        if config is None and pretrained_path is None:\n            config = DistilBertConfig().to_dict()\n        super().__init__(\n            config=config,\n            pretrained_path=pretrained_path,\n            peft_type=peft_type,\n            peft_config=peft_config,\n            **kwargs)\n\n    def check_config(self, pretrain_path):\n        config = AutoConfig.from_pretrained(pretrain_path)\n        assert isinstance(\n            config, DistilBertConfig), 'The config of pretrained model must be DistilBertConfig, but got {}'.format(\n            type(config))\n"
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/gpt2.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers import GPT2Config, AutoConfig\nfrom transformers import GPT2ForSequenceClassification, AutoModelForCausalLM\nfrom fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM\n\n\nclass GPT2(PELLM):\n    config_class = GPT2Config\n    model_loader = GPT2ForSequenceClassification\n\n    def __init__(self,\n                 config: dict = None,\n                 pretrained_path: str = None,\n                 peft_type: str = None,\n                 peft_config: dict = None,\n                 **kwargs) -> None:\n\n        if pretrained_path is not None:\n            self.check_config(pretrain_path=pretrained_path)\n        if config is None and pretrained_path is None:\n            config = GPT2Config().to_dict()\n        super().__init__(\n            config=config,\n            pretrained_path=pretrained_path,\n            peft_type=peft_type,\n            peft_config=peft_config,\n            **kwargs)\n\n    def check_config(self, pretrain_path):\n        config = AutoConfig.from_pretrained(pretrain_path)\n        assert isinstance(\n            config, GPT2Config), 'The config of pretrained model must be GPT2Config, but got {}'.format(\n            type(config))\n\n\nclass GPT2CLM(GPT2):\n    model_loader = AutoModelForCausalLM\n"
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/llama.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM\nfrom transformers import AutoConfig\nfrom transformers import LlamaConfig\nfrom transformers import LlamaForCausalLM\n\n\nclass LLaMa(PELLM):\n    config_class = LlamaConfig\n\n    def __init__(self,\n                 pretrained_path: str = None,\n                 peft_type: str = None,\n                 peft_config: dict = None,\n                 **kwargs) -> None:\n\n        super().__init__(pretrained_path=pretrained_path,\n                         peft_type=peft_type,\n                         peft_config=peft_config,\n                         **kwargs)\n\n    def init_base_lm(self, **kwargs):\n        if self.config is not None:\n            self._pe_lm = LlamaForCausalLM.from_pretrained(self.config_path,\n                                                           config=self.config,\n                                                           torch_dtype=self.torch_dtype,\n                                                           **kwargs)\n        elif self.config_path is not None:\n            self._pe_lm = LlamaForCausalLM.from_pretrained(self.config_path, torch_dtype=self.torch_dtype, **kwargs)\n        else:\n            raise ValueError(\n                'config_path to pretrained model folder cannot be None')\n\n    def check_config(self, 
pretrain_path):\n        config = AutoConfig.from_pretrained(pretrain_path)\n        assert isinstance(\n            config, LlamaConfig), 'The config of pretrained model must be LlamaConfig, but got {}'.format(\n            type(config))\n"
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/opt.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers import OPTConfig\nfrom transformers import OPTForCausalLM\nfrom fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM\n\n\nclass OPT(PELLM):\n\n    config_class = OPTConfig\n    model_loader = OPTForCausalLM\n\n    def __init__(self, config: dict = None,\n                 pretrained_path: str = None,\n                 peft_type: str = None,\n                 peft_config: dict = None,\n                 **kwargs\n                 ) -> None:\n\n        if config is None and pretrained_path is None:\n            config = OPTConfig().to_dict()  # use default model setting\n        super().__init__(config=config, pretrained_path=pretrained_path,\n                         peft_type=peft_type, peft_config=peft_config, **kwargs)\n"
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/parameter_efficient_llm.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport peft\nimport torch\nfrom collections.abc import Mapping\nfrom peft import PeftModel, TaskType\nfrom transformers import AutoConfig\nfrom transformers import AutoModel\nfrom transformers.configuration_utils import PretrainedConfig\nimport logging\n\n\nlogger = logging.getLogger(__name__)\n\n\nAVAILABLE_PEFT_CONFIG = list(\n    filter(\n        lambda peft_type: peft_type.endswith(\"Config\"), dir(peft)\n    )\n)\n\n\nclass PELLM(torch.nn.Module):\n\n    config_class: PretrainedConfig = None\n    model_loader = None\n\n    def __init__(self,\n                 config: dict = None,\n                 pretrained_path: str = None,\n                 peft_type: str = None,\n                 peft_config=None,\n                 torch_dtype: str = None,\n                 trust_remote_code: bool = False,\n                 **kwargs\n                 ) -> None:\n\n        super().__init__()\n        self._pe_lm: PeftModel = None\n        self.config = config\n        self.config_path = pretrained_path\n        self.peft_type = peft_type\n        self.peft_config = peft_config\n        self.torch_dtype = None if not torch_dtype else getattr(torch, torch_dtype)\n        self.trust_remote_code = trust_remote_code\n\n        assert self.config_path is not None or self.config is not None, \\\n            \"At least one of config_path and 
config must be set.\"\n        self._init_pelm(**kwargs)\n\n    def _init_pelm(self, **kwargs):\n        self.init_lm_with_peft(**kwargs)\n        self.model_summary()\n\n    def init_lm_with_peft(self, **kwargs):\n        self.init_config(**kwargs)\n        self.init_base_lm()\n        self.add_peft()\n\n    def init_config(self, **kwargs):\n        if self.config_path is not None:\n            self.config = AutoConfig.from_pretrained(self.config_path, trust_remote_code=self.trust_remote_code)\n        elif self.config is not None and self.config_class is not None:\n            self.config = self.config_class().from_dict(self.config)\n        else:\n            raise ValueError(\n                'config_path to pretrained model folder and model config dict cannot be None at the same time, '\n                'you need to specify one of them')\n\n        if kwargs:\n            self.config.update(kwargs)\n\n    def init_base_lm(self, **kwargs):\n        model_loader = self.model_loader if self.model_loader is not None else AutoModel\n        if self.config is not None:\n            self._pe_lm = model_loader.from_pretrained(\n                self.config_path, config=self.config,\n                torch_dtype=self.torch_dtype, **kwargs,\n                trust_remote_code=self.trust_remote_code\n            )\n        elif self.config_path is not None:\n            self._pe_lm = model_loader.from_pretrained(\n                self.config_path, torch_dtype=self.torch_dtype,\n                trust_remote_code=self.trust_remote_code, **kwargs)\n        else:\n            raise ValueError(\n                'config_path to pretrained model folder cannot be None')\n\n    def add_peft(self):\n        assert self.peft_type in AVAILABLE_PEFT_CONFIG, 'peft name {} not in available config {}'.format(\n            self.peft_type, AVAILABLE_PEFT_CONFIG)\n\n        if self.peft_config is None:\n            peft_config = getattr(peft, self.peft_type)()\n        elif 
isinstance(self.peft_config, dict):\n            peft_config = getattr(peft, self.peft_type)(**self.peft_config)\n        else:\n            raise ValueError(f\"Can not parse peft_config of {type(self.peft_config)}\")\n\n        self._pe_lm = peft.get_peft_model(self._pe_lm, peft_config)\n        self.peft_config = peft_config\n\n    def model_summary(self):\n        if hasattr(self._pe_lm, \"print_trainable_parameters\"):\n            summary = self._pe_lm.print_trainable_parameters()\n            logger.debug(f'PELLM model summary: \\n{summary}')\n\n    def forward(self, *args, **kwargs):\n        forward_ret = self._pe_lm.forward(*args, **kwargs)\n\n        if self.peft_config is None or self.peft_config.task_type != TaskType.SEQ_CLS:\n            return forward_ret\n        else:\n            return forward_ret.logits\n\n    def save_trainable(self, output_path):\n        self._pe_lm.save_pretrained(output_path)\n\n\nclass AutoPELLM(PELLM):\n\n    def __init__(self, **kwargs) -> None:\n        super().__init__(**kwargs)\n"
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/qwen.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers import Qwen2Config\nfrom transformers import Qwen2ForCausalLM\nfrom fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM\n\n\nclass Qwen(PELLM):\n\n    config_class = Qwen2Config\n    model_loader = Qwen2ForCausalLM\n\n    def __init__(self, config: dict = None,\n                 pretrained_path: str = None,\n                 peft_type: str = None,\n                 peft_config: dict = None,\n                 **kwargs\n                 ) -> None:\n\n        if config is None and pretrained_path is None:\n            config = Qwen2Config().to_dict()  # use default model setting\n        super().__init__(config=config, pretrained_path=pretrained_path,\n                         peft_type=peft_type, peft_config=peft_config, **kwargs)\n"
  },
  {
    "path": "python/fate_llm/model_zoo/pellm/roberta.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers import RobertaConfig, AutoConfig\nfrom transformers import RobertaForSequenceClassification\nfrom fate_llm.model_zoo.pellm.parameter_efficient_llm import PELLM\n\n\nclass Roberta(PELLM):\n    config_class = RobertaConfig\n    model_loader = RobertaForSequenceClassification\n\n    def __init__(self, config: dict = None,\n                 pretrained_path: str = None,\n                 peft_type: str = None,\n                 peft_config: dict = None,\n                 **kwargs) -> None:\n\n        if pretrained_path is not None:\n            self.check_config(pretrain_path=pretrained_path)\n        if config is None and pretrained_path is None:\n            config = RobertaConfig().to_dict()\n        super().__init__(\n            config=config,\n            pretrained_path=pretrained_path,\n            peft_type=peft_type,\n            peft_config=peft_config,\n            **kwargs)\n\n    def check_config(self, pretrain_path):\n        config = AutoConfig.from_pretrained(pretrain_path)\n        assert isinstance(\n            config, RobertaConfig), 'The config of pretrained model must be RobertaConfig, but got {}'.format(\n            type(config))\n"
  },
  {
    "path": "python/fate_llm/runner/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/runner/fdkt_runner.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nimport logging\nimport torch\nfrom fate.components.components.nn.nn_runner import (\n    load_model_dict_from_path,\n    dir_warning,\n    loader_load_from_conf,\n    run_dataset_func,\n)\nfrom typing import Dict\nfrom fate.components.components.nn.loader import Loader\nfrom typing import Union, Optional, Literal\nfrom transformers.trainer_utils import get_last_checkpoint\nfrom fate.arch.dataframe import DataFrame\nfrom fate.components.components.nn.runner.homo_default_runner import DefaultRunner\nfrom fate_llm.algo.fdkt import FDKTTrainingArguments, FDKTSLM, FDKTLLM\n\nlogger = logging.getLogger(__name__)\nAUG_DATA_SAVED_PATH_SUFFIX = \"aug_data.pkl\"\nDP_MODEL_SAVED_PATH_SUFFIX = \"dp_model\"\n\n\nclass FDKTRunner(DefaultRunner):\n    def __init__(\n        self,\n        algo: str = \"fdkt\",\n        inference_inst_conf: Optional[Dict] = None,\n        model_conf: Optional[Dict] = None,\n        embedding_model_conf: Optional[Dict] = None,\n        optimizer_conf: Optional[Dict] = None,\n        training_args_conf: Optional[Dict] = None,\n        dataset_conf: Optional[Dict] = None,\n        data_collator_conf: Optional[Dict] = None,\n        tokenizer_conf: Optional[Dict] = None,\n        task_type: Literal[\"causal_lm\", \"others\"] = \"causal_lm\",\n        save_dp_model: bool = False,\n    ) -> None:\n        
super(FDKTRunner, self).__init__()\n        self.algo = algo\n        self.inference_inst_conf = inference_inst_conf\n        self.model_conf = model_conf\n        self.embedding_model_conf = embedding_model_conf\n        self.optimizer_conf = optimizer_conf\n        self.training_args_conf = training_args_conf\n        self.dataset_conf = dataset_conf\n        self.data_collator_conf = data_collator_conf\n        self.tokenizer_conf = tokenizer_conf\n        self.task_type = task_type\n        self.save_dp_model = save_dp_model\n\n        self.training_args = None\n\n        # check param\n        if self.algo.lower() != \"fdkt\":\n            raise ValueError(f\"algo should be fdkt\")\n        if self.task_type not in [\"causal_lm\"]:\n            raise ValueError(\"task_type should be causal_lm\")\n\n    def common_setup(self, saved_model=None, output_dir=None):\n        ctx = self.get_context()\n\n        if output_dir is None:\n            output_dir = \"./\"\n\n        if self.model_conf is not None:\n            model = loader_load_from_conf(self.model_conf)\n        else:\n            model = None\n\n        resume_path = None\n        if saved_model is not None:\n            model_dict = load_model_dict_from_path(saved_model)\n            model.load_state_dict(model_dict)\n            logger.info(f\"loading model dict from {saved_model} to model done\")\n            if get_last_checkpoint(saved_model) is not None:\n                resume_path = saved_model\n                logger.info(f\"checkpoint detected, resume_path set to {resume_path}\")\n\n        # load tokenizer if import conf provided\n        if self.tokenizer_conf is not None:\n            tokenizer = loader_load_from_conf(self.tokenizer_conf)\n        else:\n            tokenizer = None\n\n        # args\n        dir_warning(self.training_args_conf)\n        training_args = FDKTTrainingArguments(**self.training_args_conf)\n        # reset to default, saving to arbitrary path is not allowed 
in\n        # DefaultRunner\n        training_args.output_dir = output_dir\n        training_args.resume_from_checkpoint = resume_path  # resume path\n\n        self.training_args = training_args\n        dataset = loader_load_from_conf(self.dataset_conf)\n\n        return ctx, model, tokenizer, training_args, dataset\n\n    def llm_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None):\n        ctx, model, tokenizer, training_args, dataset = self.common_setup(\n            output_dir=output_dir, saved_model=saved_model)\n\n        if model is not None:\n            model = model.load()\n\n        inference_inst = None\n        if self.inference_inst_conf is not None:\n            inference_inst = loader_load_from_conf(self.inference_inst_conf)\n\n        embedding_model = loader_load_from_conf(self.embedding_model_conf)\n        if embedding_model is None:\n            raise ValueError(f\"model is None, cannot load model from conf {self.model_conf}\")\n        embedding_model = embedding_model.load()\n\n        trainer = FDKTLLM(\n            ctx=ctx,\n            inference_inst=inference_inst,\n            model=model,\n            embedding_model=embedding_model,\n            training_args=training_args,\n            tokenizer=tokenizer,\n            dataset=dataset,\n        )\n\n        return trainer\n\n    def slm_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None):\n        ctx, model, tokenizer, training_args, dataset = self.common_setup(\n            output_dir=output_dir, saved_model=saved_model)\n        model = model.load()\n\n        dataset.load(train_set)\n\n        if self.data_collator_conf is not None:\n            data_collator = loader_load_from_conf(self.data_collator_conf)\n        else:\n            data_collator = None\n\n        optimizer_loader = Loader.from_dict(self.optimizer_conf)\n        optimizer_ = optimizer_loader.load_item()\n        optimizer_params = 
optimizer_loader.kwargs\n        optimizer = optimizer_(model.parameters(), **optimizer_params)\n\n        trainer = FDKTSLM(\n            ctx=ctx,\n            model=model,\n            training_args=training_args,\n            tokenizer=tokenizer,\n            train_set=dataset,\n            data_collator=data_collator,\n            optimizer=optimizer,\n        )\n\n        return trainer\n\n    def train(\n        self,\n        train_data: Optional[Union[str, DataFrame]] = None,\n        validate_data: Optional[Union[str, DataFrame]] = None,\n        output_dir: str = None,\n        saved_model_path: str = None,\n    ):\n\n        if self.is_client():\n            trainer = self.slm_setup(train_set=train_data, validate_set=validate_data, output_dir=output_dir, saved_model=saved_model_path)\n            aug_data = trainer.aug_data()\n\n            data_saved_path = output_dir + '/' + AUG_DATA_SAVED_PATH_SUFFIX\n            logger.info('result save to path {}'.format(data_saved_path))\n            torch.save(aug_data, data_saved_path)\n\n            if self.save_dp_model:\n                model_save_dir = output_dir + \"/\" + DP_MODEL_SAVED_PATH_SUFFIX\n                trainer.save_model(model_save_dir)\n\n        else:\n            trainer = self.llm_setup(\n                train_set=train_data, validate_set=validate_data, output_dir=output_dir, saved_model=saved_model_path\n            )\n            trainer.aug_data()\n\n    def predict(self, *args, **kwargs):\n        pass\n"
  },
  {
    "path": "python/fate_llm/runner/fedcot_runner.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\nimport torch\nfrom fate.components.components.nn.nn_runner import (\n    NNRunner,\n    load_model_dict_from_path,\n    dir_warning,\n    loader_load_from_conf,\n)\nfrom fate_llm.model_zoo.hf_model import HFAutoModelForCausalLM\nfrom fate.components.components.nn.loader import Loader\nfrom fate.arch.dataframe import DataFrame\nfrom fate.ml.nn.dataset.base import Dataset\nfrom typing import Dict\nfrom fate_llm.algo.fedcot.fedcot_trainer import FedCoTTrainerClient, FedCoTTraineServer\nfrom fate_llm.algo.fedcot.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer\nfrom fate_llm.algo.inferdpt.init._init import InferInit\nimport torch.nn as nn\nimport torch.optim as optim\nfrom fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainingArguments\nfrom typing import Union, Type, Callable, Optional\nfrom transformers.trainer_utils import get_last_checkpoint\nfrom typing import Literal\nimport logging\n\n\nlogger = logging.getLogger(__name__)\n\n\n\ndef _check_instances(\n    model: nn.Module = None,\n    optimizer: optim.Optimizer = None,\n    train_args: Seq2SeqTrainingArguments = None,\n    data_collator: Callable = None,\n) -> None:\n    \n    if model is not None and not issubclass(type(model), nn.Module):\n        raise TypeError(f\"SetupReturn Error: model must be a subclass of torch.nn.Module 
but got {type(model)}\")\n\n    if optimizer is not None and not issubclass(type(optimizer), optim.Optimizer):\n        raise TypeError(\n            f\"SetupReturn Error: optimizer must be a subclass of torch.optim.Optimizer but got {type(optimizer)}\"\n        )\n\n    if train_args is not None and not isinstance(train_args, Seq2SeqTrainingArguments):\n        raise TypeError(\n            f\"SetupReturn Error: train_args must be an instance of Seq2SeqTrainingArguments \"\n            f\"but got {type(train_args)}\"\n        )\n\n    if data_collator is not None and not callable(data_collator):\n        raise TypeError(f\"SetupReturn Error: data_collator must be callable but got {type(data_collator)}\")\n\n\nclass FedCoTRunner(NNRunner):\n    def __init__(\n        self,\n        mode: Literal['train_only', 'infer_only', 'infer_and_train'],\n        model_conf: Optional[Dict] = None,\n        dataset_conf: Optional[Dict] = None,\n        optimizer_conf: Optional[Dict] = None,\n        training_args_conf: Optional[Dict] = None,\n        data_collator_conf: Optional[Dict] = None,\n        tokenizer_conf: Optional[Dict] = None,\n        infer_inst_init_conf: Dict = None,\n        encode_template: str = None,\n        instruction_template: str = None,\n        decode_template: str = None,\n        remote_inference_kwargs: Dict = {},\n        local_inference_kwargs: Dict = {},\n        perturb_doc_key: str = 'perturbed_doc',\n        perturbed_response_key: str = 'perturbed_response',\n        result_key: str = 'infer_result',\n    ) -> None:\n        super(NNRunner, self).__init__()\n        self.model_conf = model_conf\n        self.dataset_conf = dataset_conf\n        self.optimizer_conf = optimizer_conf\n        self.training_args_conf = training_args_conf\n        self.data_collator_conf = data_collator_conf\n        self.mode = mode\n        self.tokenizer_conf = tokenizer_conf\n        self.infer_inst_init_conf = infer_inst_init_conf\n        
self.encode_template = encode_template\n        self.instruction_template = instruction_template\n        self.decode_template = decode_template\n        self.remote_inference_kwargs = remote_inference_kwargs\n        self.local_inference_kwargs = local_inference_kwargs\n        self.perturb_doc_key = perturb_doc_key\n        self.perturbed_response_key = perturbed_response_key\n        self.result_key = result_key\n        self._temp_data_path = ''\n\n        # setup var\n        self.trainer = None\n        self.training_args = None\n\n    def _get_infer_inst(self, init_conf):\n        if init_conf is None:\n            return None\n        loader = Loader.from_dict(init_conf)\n        init_inst = loader.load_item()(self.get_context())\n        assert isinstance(init_inst, InferInit), 'Need a InferInit class for initialization, but got {}'.format(type(init_inst))\n        infer_inst = init_inst.get_inst()\n        logger.info('inferdpt inst loaded')\n        return infer_inst\n    \n    def _prepare_data(self, data, data_name):\n        if data is None:\n            return None\n        if isinstance(data, DataFrame) and self.dataset_conf is None:\n            raise RuntimeError('DataFrame format dataset is not supported, please use bind path to load your dataset')\n        else:\n            dataset = loader_load_from_conf(self.dataset_conf)\n            if hasattr(dataset, \"load\"):\n                logger.info(\"load path is {}\".format(data))\n                import os\n                if os.path.exists(data) and os.path.isdir(data):\n                    self._temp_data_path = data\n                    load_output = dataset.load(data)\n                    if load_output is not None:\n                        dataset = load_output\n                        return dataset\n                else:\n                    raise RuntimeError('You must offer an existing folder path as data input, but got {}'.format(data))\n            else:\n                raise 
ValueError(\n                    f\"The dataset {dataset} lacks a load() method, which is required for data parsing in the DefaultRunner. \\\n                                Please implement this method in your dataset class. You can refer to the base class 'Dataset' in 'fate.ml.nn.dataset.base' \\\n                                for the necessary interfaces to implement.\"\n                )\n        if dataset is not None and not issubclass(type(dataset), Dataset):\n            raise TypeError(\n                f\"SetupReturn Error: {data_name}_set must be a subclass of fate built-in Dataset but got {type(dataset)}, \\n\"\n                f\"You can get the class via: from fate.ml.nn.dataset.table import Dataset\"\n            )\n\n        return dataset\n    \n    def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None, stage=\"train\"):\n\n        ctx = self.get_context()\n        model = loader_load_from_conf(self.model_conf)\n        if isinstance(model, HFAutoModelForCausalLM):\n            model = model.load()\n\n        if model is None:\n            raise ValueError(f\"model is None, cannot load model from conf {self.model_conf}\")\n        if output_dir is None:\n            output_dir = \"./\"\n\n        resume_path = None\n        if saved_model is not None:\n            model_dict = load_model_dict_from_path(saved_model)\n            model.load_state_dict(model_dict)\n            logger.info(f\"loading model dict from {saved_model} to model done\")\n            if get_last_checkpoint(saved_model) is not None:\n                resume_path = saved_model\n                logger.info(f\"checkpoint detected, resume_path set to {resume_path}\")\n\n        # load optimizer\n        if self.optimizer_conf:\n            optimizer_loader = Loader.from_dict(self.optimizer_conf)\n            optimizer_ = optimizer_loader.load_item()\n            optimizer_params = optimizer_loader.kwargs\n            optimizer = 
optimizer_(model.parameters(), **optimizer_params)\n        else:\n            optimizer = None\n\n        # load collator func\n        data_collator = loader_load_from_conf(self.data_collator_conf)\n\n        # load tokenizer if import conf provided\n        tokenizer = loader_load_from_conf(self.tokenizer_conf)\n\n        # args\n        dir_warning(self.training_args_conf)\n        training_args = Seq2SeqTrainingArguments(**self.training_args_conf)\n        # reset to default, saving to arbitrary path is not allowed in\n        # DefaultRunner\n        training_args.output_dir = output_dir\n        training_args.resume_from_checkpoint = resume_path  # resume path\n        self.training_args = training_args\n\n        if self.training_args.world_size > 0 and self.training_args.local_rank == 0:\n            infer_client = self._get_infer_inst(self.infer_inst_init_conf)\n        else:\n            infer_client = None # only rank 0 need to load the client\n        \n        # prepare trainer\n        trainer = FedCoTTrainerClient(\n            ctx=ctx,\n            training_args=training_args,\n            train_set=train_set,\n            val_set=validate_set,\n            model=model,\n            tokenizer=tokenizer,\n            mode=self.mode,\n            encode_template=self.encode_template,\n            decode_template=self.decode_template,\n            instruction_template=self.instruction_template,\n            local_inference_kwargs=self.local_inference_kwargs,\n            remote_inference_kwargs=self.remote_inference_kwargs,\n            data_collator=data_collator,\n            optimizer=optimizer,\n            infer_client=infer_client,\n            tmp_data_share_path=self._temp_data_path\n        )\n\n        return trainer\n\n    def server_setup(self, stage=\"train\"):\n        trainer = FedCoTTraineServer(\n            ctx=self.get_context(),\n            infer_server=self._get_infer_inst(self.infer_inst_init_conf)\n        )\n        return 
trainer\n\n    def train(\n        self,\n        train_data: Optional[Union[str]] = None,\n        validate_data: Optional[Union[str]] = None,\n        output_dir: str = None,\n        saved_model_path: str = None,\n    ):\n        if self.is_client():\n            train_set = self._prepare_data(train_data, \"train_data\")\n            validate_set = self._prepare_data(validate_data, \"val_data\")\n            trainer = self.client_setup(\n                train_set=train_set, validate_set=validate_set, output_dir=output_dir, saved_model=saved_model_path\n            )\n            self.trainer = trainer\n            trainer.train()\n\n            if self.mode == 'infer_only':\n                # save result dataset to the output dir\n                saving_path = output_dir + '/' + 'inference_result.pkl'\n                torch.save(train_set.dataset, saving_path)\n                logger.info('inference result saved to {}'.format(saving_path))\n            else:\n                if output_dir is not None:\n                    if self.training_args.deepspeed and self.training_args.local_rank != 0:\n                        pass\n                    else:\n                        trainer.save_model(output_dir)\n\n        elif self.is_server():\n            if self.mode == 'train_only':\n                return \n            else:\n                trainer = self.server_setup()\n                trainer.train()\n\n    def predict(self, test_data: Union[str], saved_model_path: str = None) -> None:\n        logger.warning('The prediction mode is not supported by this algorithm in the current version. Please perform inference using locally saved models.')\n        return \n"
  },
  {
    "path": "python/fate_llm/runner/fedkseed_runner.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nimport logging\nfrom typing import Dict\nfrom typing import Literal\nfrom typing import Optional\n\nimport transformers\nfrom fate.components.components.nn.nn_runner import (\n    NNRunner,\n    dir_warning,\n    loader_load_from_conf,\n)\nfrom fate.components.components.nn.runner.homo_default_runner import DefaultRunner\n\nfrom fate_llm.algo.fedkseed.fedkseed import Trainer, FedKSeedTrainingArguments, ClientTrainer\nfrom fate_llm.algo.fedkseed.zo_utils import build_seed_candidates\nfrom fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainingArguments\n\nlogger = logging.getLogger(__name__)\n\nSUPPORTED_ALGO = [\"fedkseed\"]\n\n\nclass FedKSeedRunner(DefaultRunner):\n    def __init__(\n            self,\n            algo: str = \"fedkseed\",\n            model_conf: Optional[Dict] = None,\n            dataset_conf: Optional[Dict] = None,\n            optimizer_conf: Optional[Dict] = None,\n            training_args_conf: Optional[Dict] = None,\n            fed_args_conf: Optional[Dict] = None,\n            data_collator_conf: Optional[Dict] = None,\n            tokenizer_conf: Optional[Dict] = None,\n            task_type: Literal[\"causal_lm\", \"other\"] = \"causal_lm\",\n            local_mode: bool = False,\n            save_trainable_weights_only: bool = False,\n    ) -> None:\n        super(NNRunner, self).__init__()\n    
    self.algo = algo\n        self.model_conf = model_conf\n        self.dataset_conf = dataset_conf\n        self.optimizer_conf = optimizer_conf\n        self.training_args_conf = training_args_conf\n        self.fed_args_conf = fed_args_conf\n        self.data_collator_conf = data_collator_conf\n        self.local_mode = local_mode\n        self.tokenizer_conf = tokenizer_conf\n        self.task_type = task_type\n        self.save_trainable_weights_only = save_trainable_weights_only\n\n        # check param\n        if self.algo not in SUPPORTED_ALGO:\n            raise ValueError(f\"algo should be one of {SUPPORTED_ALGO}\")\n        if self.task_type not in [\"causal_lm\", \"others\"]:\n            raise ValueError(\"task_type should be one of [binary, multi, regression, others]\")\n        assert isinstance(self.local_mode, bool), \"local should be bool\"\n\n        # setup var\n        self.trainer = None\n        self.training_args = None\n\n    def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None, stage=\"train\"):\n        if self.algo != \"fedkseed\":\n            raise ValueError(f\"algo {self.algo} not supported\")\n\n        ctx = self.get_context()\n\n        model = maybe_loader_load_from_conf(self.model_conf)\n        if model is None:\n            raise ValueError(f\"model is None, cannot load model from conf {self.model_conf}\")\n\n        if output_dir is None:\n            output_dir = \"./\"\n\n        tokenizer = transformers.AutoTokenizer.from_pretrained(**self.data_collator_conf[\"kwargs\"][\"tokenizer_params\"])\n\n        data_collator = transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n        dir_warning(self.training_args_conf)\n\n        training_args = Seq2SeqTrainingArguments(**self.training_args_conf)\n        self.training_args = training_args\n        training_args.output_dir = output_dir\n        fedkseed_args = FedKSeedTrainingArguments(**self.fed_args_conf)\n    
    logger.debug(f\"training_args: {training_args}\")\n        logger.debug(f\"fedkseed_args: {fedkseed_args}\")\n        trainer = ClientTrainer(\n            ctx=ctx,\n            model=model,\n            training_args=training_args,\n            fedkseed_args=fedkseed_args,\n            data_collator=data_collator,\n            tokenizer=tokenizer,\n            train_dataset=train_set,\n            eval_dataset=validate_set,\n        )\n        return trainer\n\n    def server_setup(self, stage=\"train\"):\n\n        if self.algo != \"fedkseed\":\n            raise ValueError(f\"algo {self.algo} not supported\")\n        ctx = self.get_context()\n\n        fedkseed_args = FedKSeedTrainingArguments(**self.fed_args_conf)\n        training_args = Seq2SeqTrainingArguments(**self.training_args_conf)\n\n        seed_candidates = build_seed_candidates(fedkseed_args.k, low=0, high=2 ** 32)\n        trainer = Trainer(ctx=ctx, seed_candidates=seed_candidates, args=training_args, fedkseed_args=fedkseed_args)\n        return trainer\n\n\ndef maybe_loader_load_from_conf(conf):\n    from fate_llm.model_zoo.hf_model import HFAutoModelForCausalLM\n\n    model = loader_load_from_conf(conf)\n    if isinstance(model, HFAutoModelForCausalLM):\n        model = model.load()\n    return model\n"
  },
  {
    "path": "python/fate_llm/runner/fedmkt_runner.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\nimport json\nfrom fate.components.components.nn.nn_runner import (\n    load_model_dict_from_path,\n    dir_warning,\n    loader_load_from_conf,\n    run_dataset_func,\n)\nfrom typing import Dict\nfrom fate.components.components.nn.loader import Loader\nfrom fate.ml.nn.homo.fedavg import FedAVGArguments\nfrom typing import Union, Optional, Literal, List\nfrom transformers.trainer_utils import get_last_checkpoint\nimport logging\nfrom fate.arch.dataframe import DataFrame\nfrom fate.components.components.nn.runner.homo_default_runner import DefaultRunner\nfrom fate_llm.algo.fedmkt import FedMKTTrainingArguments, FedMKTSLM, FedMKTLLM\n\nlogger = logging.getLogger(__name__)\n\n\nclass FedMKTRunner(DefaultRunner):\n\n    def __init__(\n        self,\n        algo: str = \"fedmkt\",\n        model_conf: Optional[Dict] = None,\n        optimizer_conf: Optional[Dict] = None,\n        training_args_conf: Optional[Dict] = None,\n        fed_args_conf: Optional[Dict] = None,\n        pub_dataset_conf: Optional[Dict] = None,\n        priv_dataset_conf: Optional[Dict] = None,\n        data_collator_conf: Optional[Dict] = None,\n        tokenizer_conf: Optional[Dict] = None,\n        llm_tokenizer_conf: Optional[Dict] = None,\n        slm_tokenizers_conf: List[Optional[Dict]] = None,\n        llm_to_slm_vocab_mapping_path: str = None,\n        
slm_to_llm_vocab_mapping_paths: List[str] = None,\n        task_type: Literal[\"causal_lm\", \"others\"] = \"causal_lm\",\n        save_trainable_weights_only: bool = False,\n        pub_dataset_path: str = None,\n    ) -> None:\n        super(FedMKTRunner, self).__init__()\n        self.algo = algo\n        self.model_conf = model_conf\n        self.optimizer_conf = optimizer_conf\n        self.training_args_conf = training_args_conf\n        self.fed_args_conf = fed_args_conf\n        self.pub_dataset_conf = pub_dataset_conf\n        self.priv_dataset_conf = priv_dataset_conf\n        self.data_collator_conf = data_collator_conf\n        self.tokenizer_conf = tokenizer_conf\n        self.llm_tokenizer_conf = llm_tokenizer_conf\n        self.slm_tokenizers_conf = slm_tokenizers_conf\n        self.llm_to_slm_vocab_mapping_path = llm_to_slm_vocab_mapping_path\n        self.slm_to_llm_vocab_mapping_paths = slm_to_llm_vocab_mapping_paths\n        self.task_type = task_type\n        self.pub_dataset_path = pub_dataset_path\n\n        self.save_trainable_weights_only = save_trainable_weights_only\n\n        self.training_args = None\n\n        # check param\n        if self.algo.lower() != \"fedmkt\":\n            raise ValueError(f\"algo should be fedmkt\")\n        if self.task_type not in [\"causal_lm\"]:\n            raise ValueError(\"task_type should be causal_lm\")\n\n    def common_setup(self, saved_model=None, output_dir=None):\n        ctx = self.get_context()\n\n        if output_dir is None:\n            output_dir = \"./\"\n\n        model = loader_load_from_conf(self.model_conf)\n        if model is None:\n            raise ValueError(f\"model is None, cannot load model from conf {self.model_conf}\")\n\n        resume_path = None\n        if saved_model is not None:\n            model_dict = load_model_dict_from_path(saved_model)\n            model.load_state_dict(model_dict)\n            logger.info(f\"loading model dict from {saved_model} to model 
done\")\n            if get_last_checkpoint(saved_model) is not None:\n                resume_path = saved_model\n                logger.info(f\"checkpoint detected, resume_path set to {resume_path}\")\n\n        # load optimizer\n        if self.optimizer_conf:\n            optimizer_loader = Loader.from_dict(self.optimizer_conf)\n            optimizer_ = optimizer_loader.load_item()\n            optimizer_params = optimizer_loader.kwargs\n            optimizer = optimizer_(model.parameters(), **optimizer_params)\n        else:\n            optimizer = None\n\n        # load tokenizer if import conf provided\n        tokenizer = loader_load_from_conf(self.tokenizer_conf)\n\n        # args\n        dir_warning(self.training_args_conf)\n        training_args = FedMKTTrainingArguments(**self.training_args_conf)\n        # reset to default, saving to arbitrary path is not allowed in\n        # DefaultRunner\n        training_args.output_dir = output_dir\n        training_args.resume_from_checkpoint = resume_path  # resume path\n\n        self.training_args = training_args\n\n        if self.fed_args_conf is not None:\n            fed_args = FedAVGArguments(**self.fed_args_conf)\n        else:\n            fed_args = None\n\n        pub_dataset = loader_load_from_conf(self.pub_dataset_conf)\n        pub_dataset.load(self.pub_dataset_path)\n\n        return ctx, model, optimizer, tokenizer, training_args, fed_args, pub_dataset\n\n    def llm_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None):\n        ctx, model, optimizer, tokenizer, training_args, fed_args, pub_dataset = self.common_setup(\n            output_dir=output_dir, saved_model=saved_model)\n\n        if validate_set is not None:\n            validate_dataset = loader_load_from_conf(self.pub_dataset_conf)\n            validate_dataset.load(validate_set)\n        else:\n            validate_dataset = None\n\n        slm_tokenizers = None\n        if self.slm_tokenizers_conf:\n    
        slm_tokenizers = [loader_load_from_conf(tokenizer_conf) for tokenizer_conf in self.slm_tokenizers_conf]\n\n        slm_to_llm_vocab_mappings = []\n        for vocab_mapping_path in self.slm_to_llm_vocab_mapping_paths:\n            with open(vocab_mapping_path, \"r\") as fin:\n                vocab_mapping = json.loads(fin.read())\n                slm_to_llm_vocab_mappings.append(vocab_mapping)\n\n        trainer = FedMKTLLM(\n            ctx=ctx,\n            model=model,\n            training_args=training_args,\n            fed_args=fed_args,\n            train_set=pub_dataset,\n            val_set=validate_dataset,\n            tokenizer=tokenizer,\n            slm_tokenizers=slm_tokenizers,\n            slm_to_llm_vocab_mappings=slm_to_llm_vocab_mappings,\n            save_trainable_weights_only=self.save_trainable_weights_only,\n        )\n\n        return trainer\n\n    def slm_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None):\n        ctx, model, optimizer, tokenizer, training_args, fed_args, pub_dataset = self.common_setup(\n            output_dir=output_dir, saved_model=saved_model)\n\n        priv_dataset = loader_load_from_conf(self.priv_dataset_conf)\n        priv_dataset.load(train_set)\n\n        if validate_set is not None:\n            validate_dataset = loader_load_from_conf(self.priv_dataset_conf)\n            validate_dataset.load(validate_set)\n        else:\n            validate_dataset = None\n\n        llm_tokenizer = loader_load_from_conf(self.llm_tokenizer_conf)\n\n        with open(self.llm_to_slm_vocab_mapping_path, \"r\") as fin:\n            vocab_mapping = json.loads(fin.read())\n\n        priv_data_collator = loader_load_from_conf(self.data_collator_conf)\n\n        trainer = FedMKTSLM(\n            ctx=ctx,\n            model=model,\n            training_args=training_args,\n            fed_args=fed_args,\n            pub_train_set=pub_dataset,\n            priv_train_set=priv_dataset,\n       
     val_set=validate_dataset,\n            tokenizer=tokenizer,\n            save_trainable_weights_only=self.save_trainable_weights_only,\n            llm_tokenizer=llm_tokenizer,\n            llm_to_slm_vocab_mapping=vocab_mapping,\n            data_collator=priv_data_collator\n        )\n\n        return trainer\n\n    def train(\n        self,\n        train_data: Optional[Union[str, DataFrame]] = None,\n        validate_data: Optional[Union[str, DataFrame]] = None,\n        output_dir: str = None,\n        saved_model_path: str = None,\n    ):\n\n        if self.is_client():\n            trainer = self.slm_setup(train_set=train_data, validate_set=validate_data, output_dir=output_dir, saved_model=saved_model_path)\n            trainer.train()\n        else:\n            trainer = self.llm_setup(\n                train_set=train_data, validate_set=validate_data, output_dir=output_dir, saved_model=saved_model_path\n            )\n            trainer.train()\n\n        self.trainer = trainer\n\n        if self.training_args.deepspeed and self.training_args.local_rank != 0:\n            pass\n        else:\n            trainer.save_model(output_dir)\n\n    def predict(self, *args, **kwargs):\n        pass\n"
  },
  {
    "path": "python/fate_llm/runner/homo_seq2seq_runner.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nfrom fate.components.components.nn.nn_runner import (\n    NNRunner,\n    load_model_dict_from_path,\n    dir_warning,\n    loader_load_from_conf,\n    run_dataset_func,\n)\nfrom fate.components.components.nn.runner.homo_default_runner import DefaultRunner\nfrom fate.ml.nn.homo.fedavg import FedAVGArguments\nfrom fate_llm.algo.fedavg.fedavg import Seq2SeqFedAVGClient, Seq2SeqFedAVGServer\nfrom typing import Dict\nfrom fate.components.components.nn.loader import Loader\nimport torch.nn as nn\nimport torch.optim as optim\nfrom fate.ml.nn.trainer.trainer_base import FedArguments, HomoTrainerServer\nfrom fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainingArguments, HomoSeq2SeqTrainerClient\nfrom typing import Union, Type, Callable, Optional\nfrom transformers.trainer_utils import get_last_checkpoint\nfrom typing import Literal\nimport logging\nfrom fate.arch.dataframe import DataFrame\n\nlogger = logging.getLogger(__name__)\n\n\nSUPPORTED_ALGO = [\"fedavg\", \"ot\"]\n\n\ndef _check_instances(\n    trainer: Union[Type[HomoSeq2SeqTrainerClient], Type[HomoTrainerServer]] = None,\n    fed_args: FedArguments = None,\n    model: nn.Module = None,\n    optimizer: optim.Optimizer = None,\n    train_args: Seq2SeqTrainingArguments = None,\n    data_collator: Callable = None,\n) -> None:\n    if trainer is not None and not (\n        
issubclass(type(trainer), HomoSeq2SeqTrainerClient) or issubclass(type(trainer), HomoTrainerServer)\n    ):\n        raise TypeError(\n            f\"SetupReturn Error: trainer must be a subclass of either \"\n            f\"HomoSeq2SeqTrainerClient or HomoTrainerServer but got {type(trainer)}\"\n        )\n\n    if fed_args is not None and not isinstance(fed_args, FedArguments):\n        raise TypeError(f\"SetupReturn Error: fed_args must be an instance of FedArguments but got {type(fed_args)}\")\n\n    if model is not None and not issubclass(type(model), nn.Module):\n        raise TypeError(f\"SetupReturn Error: model must be a subclass of torch.nn.Module but got {type(model)}\")\n\n    if optimizer is not None and not issubclass(type(optimizer), optim.Optimizer):\n        raise TypeError(\n            f\"SetupReturn Error: optimizer must be a subclass of torch.optim.Optimizer but got {type(optimizer)}\"\n        )\n\n    if train_args is not None and not isinstance(train_args, Seq2SeqTrainingArguments):\n        raise TypeError(\n            f\"SetupReturn Error: train_args must be an instance of Seq2SeqTrainingArguments \"\n            f\"but got {type(train_args)}\"\n        )\n\n    if data_collator is not None and not callable(data_collator):\n        raise TypeError(f\"SetupReturn Error: data_collator must be callable but got {type(data_collator)}\")\n\n\nclass Seq2SeqRunner(DefaultRunner):\n    def __init__(\n        self,\n        algo: str = \"fedavg\",\n        model_conf: Optional[Dict] = None,\n        dataset_conf: Optional[Dict] = None,\n        optimizer_conf: Optional[Dict] = None,\n        training_args_conf: Optional[Dict] = None,\n        fed_args_conf: Optional[Dict] = None,\n        data_collator_conf: Optional[Dict] = None,\n        tokenizer_conf: Optional[Dict] = None,\n        task_type: Literal[\"causal_lm\", \"others\"] = \"causal_lm\",\n        local_mode: bool = False,\n        save_trainable_weights_only: bool = False,\n    ) 
-> None:\n        super(NNRunner, self).__init__()\n        self.algo = algo\n        self.model_conf = model_conf\n        self.dataset_conf = dataset_conf\n        self.optimizer_conf = optimizer_conf\n        self.training_args_conf = training_args_conf\n        self.fed_args_conf = fed_args_conf\n        self.data_collator_conf = data_collator_conf\n        self.local_mode = local_mode\n        self.tokenizer_conf = tokenizer_conf\n        self.task_type = task_type\n        self.save_trainable_weights_only = save_trainable_weights_only\n\n        # check param\n        if self.algo not in SUPPORTED_ALGO:\n            raise ValueError(f\"algo should be one of {SUPPORTED_ALGO}\")\n        if self.task_type not in [\"causal_lm\", \"others\"]:\n            raise ValueError(\"task_type should be one of [causal_lm, others]\")\n        assert isinstance(self.local_mode, bool), \"local should be bool\"\n\n        # setup var\n        self.trainer = None\n        self.training_args = None\n\n    def client_setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None, stage=\"train\"):\n        if stage == \"predict\":\n            self.local_mode = True\n\n        if self.algo == \"fedavg\":\n            client_class: Seq2SeqFedAVGClient = Seq2SeqFedAVGClient\n        else:\n            raise ValueError(f\"algo {self.algo} not supported\")\n\n        ctx = self.get_context()\n        model = loader_load_from_conf(self.model_conf)\n        if model is None:\n            raise ValueError(f\"model is None, cannot load model from conf {self.model_conf}\")\n\n        if output_dir is None:\n            output_dir = \"./\"\n\n        resume_path = None\n        if saved_model is not None:\n            model_dict = load_model_dict_from_path(saved_model)\n            model.load_state_dict(model_dict)\n            logger.info(f\"loading model dict from {saved_model} to model done\")\n            if get_last_checkpoint(saved_model) is not 
None:\n                resume_path = saved_model\n                logger.info(f\"checkpoint detected, resume_path set to {resume_path}\")\n        # load optimizer\n        if self.optimizer_conf:\n            optimizer_loader = Loader.from_dict(self.optimizer_conf)\n            optimizer_ = optimizer_loader.load_item()\n            optimizer_params = optimizer_loader.kwargs\n            optimizer = optimizer_(model.parameters(), **optimizer_params)\n        else:\n            optimizer = None\n        # load collator func\n        data_collator = loader_load_from_conf(self.data_collator_conf)\n        # load tokenizer if import conf provided\n        tokenizer = loader_load_from_conf(self.tokenizer_conf)\n        # args\n        dir_warning(self.training_args_conf)\n        training_args = Seq2SeqTrainingArguments(**self.training_args_conf)\n        self.training_args = training_args\n        # reset to default, saving to arbitrary path is not allowed in\n        # DefaultRunner\n        training_args.output_dir = output_dir\n        training_args.resume_from_checkpoint = resume_path  # resume path\n        fed_args = FedAVGArguments(**self.fed_args_conf)\n\n        # prepare trainer\n        trainer = client_class(\n            ctx=ctx,\n            model=model,\n            optimizer=optimizer,\n            training_args=training_args,\n            fed_args=fed_args,\n            data_collator=data_collator,\n            tokenizer=tokenizer,\n            train_set=train_set,\n            val_set=validate_set,\n            local_mode=self.local_mode,\n            save_trainable_weights_only=self.save_trainable_weights_only,\n        )\n\n        _check_instances(\n            trainer=trainer,\n            model=model,\n            optimizer=optimizer,\n            train_args=training_args,\n            fed_args=fed_args,\n            data_collator=data_collator,\n        )\n        return trainer\n\n    def server_setup(self, stage=\"train\"):\n        if stage 
== \"predict\":\n            self.local_mode = True\n        if self.algo == \"fedavg\":\n            server_class: Seq2SeqFedAVGServer = Seq2SeqFedAVGServer\n        else:\n            raise ValueError(f\"algo {self.algo} not supported\")\n        ctx = self.get_context()\n        trainer = server_class(ctx=ctx, local_mode=self.local_mode)\n        _check_instances(trainer)\n        return trainer\n\n    def predict(self, test_data: Union[str, DataFrame], saved_model_path: str = None) -> Union[DataFrame, None]:\n        if self.is_client():\n            test_set = self._prepare_data(test_data, \"test_data\")\n            if self.trainer is not None:\n                trainer = self.trainer\n                logger.info(\"trainer found, skip setting up\")\n            else:\n                trainer = self.client_setup(saved_model=saved_model_path, stage=\"predict\")\n\n            classes = run_dataset_func(test_set, \"get_classes\")\n            match_ids = run_dataset_func(test_set, \"get_match_ids\")\n            sample_ids = run_dataset_func(test_set, \"get_sample_ids\")\n            match_id_name = run_dataset_func(test_set, \"get_match_id_name\")\n            sample_id_name = run_dataset_func(test_set, \"get_sample_id_name\")\n\n            if not self.training_args.predict_with_generate:\n                return\n\n            pred_rs = trainer.predict(test_set)\n\n            if self.training_args and self.training_args.deepspeed and self.training_args.local_rank != 0:\n                return\n\n            rs_df = self.get_nn_output_dataframe(\n                self.get_context(),\n                pred_rs.predictions,\n                pred_rs.label_ids if hasattr(pred_rs, \"label_ids\") else None,\n                match_ids,\n                sample_ids,\n                match_id_name=match_id_name,\n                sample_id_name=sample_id_name,\n                dataframe_format=\"dist_df\",\n                task_type=self.task_type,\n                
classes=classes,\n            )\n            return rs_df\n        else:\n            # server not predict\n            return\n"
  },
  {
    "path": "python/fate_llm/runner/inferdpt_runner.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nfrom fate.components.components.nn.nn_runner import (\n    NNRunner,\n    load_model_dict_from_path,\n    dir_warning,\n    loader_load_from_conf,\n    run_dataset_func,\n)\nimport os\nfrom datetime import datetime\nfrom fate.components.components.nn.nn_runner import NNRunner\nfrom typing import Dict\nfrom fate.components.components.nn.loader import Loader\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom typing import Union, Type, Callable, Optional\nfrom typing import Literal\nimport logging\nfrom fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\nfrom fate_llm.algo.inferdpt.init._init import InferInit\nfrom fate.components.components.nn.loader import Loader\nfrom fate_llm.dataset.hf_dataset import HuggingfaceDataset, Dataset\nfrom fate.arch.dataframe import DataFrame\n\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass InferDPTRunner(NNRunner):\n\n    def __init__(\n        self,\n        inferdpt_init_conf: Dict,\n        encode_template: str = None,\n        instruction_template: str = None,\n        decode_template: str = None,\n        dataset_conf: Optional[Dict] = None,\n        remote_inference_kwargs: Dict = {},\n        local_inference_kwargs: Dict = {},\n        perturb_doc_key: str = 'perturbed_doc',\n        perturbed_response_key: str = 'perturbed_response',\n        
result_key: str = 'inferdpt_result',\n    ) -> None:\n        self.inferdpt_init_conf = inferdpt_init_conf\n        self.encode_template = encode_template\n        self.instruction_template = instruction_template\n        self.decode_template = decode_template\n        self.dataset_conf = dataset_conf\n        self.remote_inference_kwargs = remote_inference_kwargs\n        self.local_inference_kwargs = local_inference_kwargs\n        self.perturb_doc_key = perturb_doc_key\n        self.perturbed_response_key = perturbed_response_key\n        self.result_key = result_key\n\n    def _get_inst(self):\n        loader = Loader.from_dict(self.inferdpt_init_conf)\n        init_inst = loader.load_item()(self.get_context())\n        assert isinstance(init_inst, InferInit), 'Need a InferDPTInit class for initialization, but got {}'.format(type(init_inst))\n        inferdpt_inst = init_inst.get_inst()\n        logger.info('inferdpt inst loaded')\n        return inferdpt_inst\n    \n    def client_setup(self):\n        client_inst = self._get_inst()\n        assert isinstance(client_inst, InferDPTClient), 'Client need to get an InferDPTClient class to run the algo'\n        return client_inst\n\n    def server_setup(self):\n        server_inst = self._get_inst()\n        assert isinstance(server_inst, InferDPTServer), 'Server need to get an InferDPTServer class to run the algo'\n        return server_inst\n\n    def _prepare_data(self, data, data_name):\n        if data is None:\n            return None\n        if isinstance(data, DataFrame) and self.dataset_conf is None:\n            raise ValueError('DataFrame format dataset is not supported, please use bind path to load your dataset')\n        else:\n            dataset = loader_load_from_conf(self.dataset_conf)\n            if hasattr(dataset, \"load\"):\n                logger.info(\"load path is {}\".format(data))\n                load_output = dataset.load(data)\n                if load_output is not None:\n            
        dataset = load_output\n                    return dataset\n            else:\n                raise ValueError(\n                    f\"The dataset {dataset} lacks a load() method, which is required for data parsing in the DefaultRunner. \\\n                                Please implement this method in your dataset class. You can refer to the base class 'Dataset' in 'fate.ml.nn.dataset.base' \\\n                                for the necessary interfaces to implement.\"\n                )\n        if dataset is not None and not issubclass(type(dataset), Dataset):\n            raise TypeError(\n                f\"SetupReturn Error: {data_name}_set must be a subclass of fate built-in Dataset but got {type(dataset)}, \\n\"\n                f\"You can get the class via: from fate.ml.nn.dataset.table import Dataset\"\n            )\n        return dataset\n\n    def train(\n        self,\n        train_data: Optional[Union[str]] = None,\n        validate_data: Optional[Union[str]] = None,\n        output_dir: str = None,\n        saved_model_path: str = None,\n    ) -> None:\n        if self.is_client():\n            dataset_0 = self._prepare_data(train_data, \"train_data\")\n            logger.info('dataset loaded')\n            if dataset_0 is None:\n                raise ValueError('You must provide dataset for inference')\n            assert isinstance(dataset_0, HuggingfaceDataset), 'Currently only support HuggingfaceDataset for inference, but got {}'.format(type(dataset_0))\n            logger.info('initializing inst')\n            client_inst = self.client_setup()\n            pred_rs = client_inst.inference(\n                dataset_0, self.encode_template, self.instruction_template, self.decode_template, \\\n                remote_inference_kwargs=self.remote_inference_kwargs,\n                local_inference_kwargs=self.local_inference_kwargs\n            )\n            logger.info('predict done')\n            saving_path = output_dir + '/' + 
'inference_result.pkl'\n            logger.info('result save to path {}'.format(saving_path))\n            torch.save(pred_rs, saving_path)\n        elif self.is_server():\n            server_inst = self.server_setup()\n            server_inst.inference()\n        else:\n            raise ValueError('Unknown role')\n\n    def predict(\n        self, test_data: Optional[Union[str]] = None, output_dir: str = None, saved_model_path: str = None\n    ):\n        logger.warning('Predicting mode is not supported in this algorithms in current version, please use the train mode to run inferdpt inference.')\n        return \n\n\n"
  },
  {
    "path": "python/fate_llm/runner/offsite_tuning_runner.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nfrom fate.components.components.nn.nn_runner import (\n    load_model_dict_from_path,\n    dir_warning,\n    loader_load_from_conf,\n)\nfrom fate.ml.nn.homo.fedavg import FedAVGArguments\nfrom fate_llm.algo.fedavg.fedavg import Seq2SeqFedAVGServer\nfrom typing import Dict\nfrom fate.components.components.nn.loader import Loader\nfrom fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainingArguments\nfrom typing import Union, Optional\nfrom transformers.trainer_utils import get_last_checkpoint\nfrom typing import Literal\nimport logging\nfrom fate.arch.dataframe import DataFrame\nfrom fate_llm.runner.homo_seq2seq_runner import Seq2SeqRunner, _check_instances\nfrom fate_llm.algo.offsite_tuning.offsite_tuning import OffsiteTuningTrainerClient, OffsiteTuningTrainerServer\n\n\nlogger = logging.getLogger(__name__)\n\n\nSUPPORTED_ALGO = [\"fedavg\"]\n\n\nclass OTRunner(Seq2SeqRunner):\n\n    def __init__(\n        self,\n        model_conf: Optional[Dict] = None,\n        dataset_conf: Optional[Dict] = None,\n        optimizer_conf: Optional[Dict] = None,\n        training_args_conf: Optional[Dict] = None,\n        fed_args_conf: Optional[Dict] = None,\n        data_collator_conf: Optional[Dict] = None,\n        tokenizer_conf: Optional[Dict] = None,\n        task_type: Literal[\"causal_lm\", \"other\"] = \"causal_lm\",\n        
save_trainable_weights_only: bool = False,\n        aggregate_model: bool = False,\n        algo: str = 'ot'\n    ) -> None:\n        super(OTRunner, self).__init__(\n            algo, model_conf, dataset_conf, optimizer_conf, training_args_conf, fed_args_conf,\n            data_collator_conf, tokenizer_conf, task_type, local_mode=False\n        )\n\n        self.aggregate_model = aggregate_model\n        self.save_trainable_weights_only = save_trainable_weights_only\n\n    def setup(self, train_set=None, validate_set=None, output_dir=None, saved_model=None, stage=\"train\"):\n\n        if stage == \"predict\":\n            self.local_mode = True\n            \n        ctx = self.get_context()\n        model = loader_load_from_conf(self.model_conf)\n\n        if model is None:\n            raise ValueError(f\"model is None, cannot load model from conf {self.model_conf}\")\n        \n        if output_dir is None:\n            output_dir = \"./\"\n\n        resume_path = None\n        if saved_model is not None:\n            model_dict = load_model_dict_from_path(saved_model)\n            model.load_state_dict(model_dict)\n            logger.info(f\"loading model dict from {saved_model} to model done\")\n            if get_last_checkpoint(saved_model) is not None:\n                resume_path = saved_model\n                logger.info(f\"checkpoint detected, resume_path set to {resume_path}\")\n\n        # load optimizer\n        if self.optimizer_conf:\n            optimizer_loader = Loader.from_dict(self.optimizer_conf)\n            optimizer_ = optimizer_loader.load_item()\n            optimizer_params = optimizer_loader.kwargs\n            optimizer = optimizer_(model.parameters(), **optimizer_params)\n        else:\n            optimizer = None\n        # load collator func\n        data_collator = loader_load_from_conf(self.data_collator_conf)\n        # load tokenizer if import conf provided\n        tokenizer = loader_load_from_conf(self.tokenizer_conf)\n    
    # args\n        dir_warning(self.training_args_conf)\n        training_args = Seq2SeqTrainingArguments(**self.training_args_conf)\n        self.training_args = training_args\n        # reset to default, saving to arbitrary path is not allowed in\n        # DefaultRunner\n        training_args.output_dir = output_dir\n        training_args.resume_from_checkpoint = resume_path  # resume path\n        fed_args = FedAVGArguments(**self.fed_args_conf)\n\n        # prepare trainer\n        if self.is_client():\n            trainer = OffsiteTuningTrainerClient(\n                ctx=ctx,\n                model=model,\n                optimizer=optimizer,\n                training_args=training_args,\n                fed_args=fed_args,\n                data_collator=data_collator,\n                tokenizer=tokenizer,\n                train_set=train_set,\n                val_set=validate_set,\n                save_trainable_weights_only=self.save_trainable_weights_only,\n                aggregate_model=self.aggregate_model\n            )\n\n        elif self.is_server():\n            trainer = OffsiteTuningTrainerServer(\n                ctx=ctx,\n                model=model,\n                aggregate_model=self.aggregate_model\n            )\n\n        _check_instances(\n            trainer=trainer,\n            model=model,\n            optimizer=optimizer,\n            train_args=training_args,\n            fed_args=fed_args,\n            data_collator=data_collator,\n        )\n\n        return trainer\n\n    def server_setup(self, stage=\"train\"):\n        if stage == \"predict\":\n            self.local_mode = True\n        if self.algo == \"fedavg\":\n            server_class: Seq2SeqFedAVGServer = Seq2SeqFedAVGServer\n        else:\n            raise ValueError(f\"algo {self.algo} not supported\")\n        ctx = self.get_context()\n        trainer = server_class(ctx=ctx, local_mode=self.local_mode)\n        _check_instances(trainer)\n        return trainer\n  
  \n\n    def train(\n        self,\n        train_data: Optional[Union[str, DataFrame]] = None,\n        validate_data: Optional[Union[str, DataFrame]] = None,\n        output_dir: str = None,\n        saved_model_path: str = None,\n    ):\n        \n        if self.is_client():\n            train_set = self._prepare_data(train_data, \"train_data\")\n            validate_set = self._prepare_data(validate_data, \"val_data\")\n            trainer = self.setup(\n                train_set=train_set, validate_set=validate_set, output_dir=output_dir, saved_model=saved_model_path\n            )\n            self.trainer = trainer\n            trainer.train()\n\n        elif self.is_server():\n            trainer = self.setup(\n                train_set=None, validate_set=None, output_dir=output_dir, saved_model=saved_model_path\n            )\n            trainer.train()\n\n        if output_dir is not None:\n            if self.training_args.deepspeed and self.training_args.local_rank != 0:\n                pass\n            else:\n                trainer.save_model(output_dir)\n"
  },
  {
    "path": "python/fate_llm/trainer/__init__.py",
    "content": ""
  },
  {
    "path": "python/fate_llm/trainer/seq2seq_trainer.py",
    "content": "#\n#  Copyright 2019 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\nfrom transformers import Seq2SeqTrainingArguments as _hf_Seq2SeqTrainingArguments, Seq2SeqTrainer\nfrom dataclasses import dataclass, field\nfrom typing import Optional\nfrom fate.ml.nn.trainer.trainer_base import HomoTrainerMixin, FedArguments, get_ith_checkpoint\nimport os\nimport torch\nimport copy\nfrom torch import nn\nfrom typing import Any, Dict, List, Callable\nfrom enum import Enum\nfrom fate.arch import Context\nfrom torch.optim import Optimizer\nfrom torch.utils.data import DataLoader, Dataset\nfrom transformers import PreTrainedTokenizer\nfrom transformers import Trainer, EvalPrediction\nfrom transformers.trainer_utils import has_length\nfrom torch.utils.data import _utils\nfrom transformers.trainer_callback import TrainerCallback\nfrom typing import Optional\nfrom dataclasses import dataclass, field\nfrom transformers.modeling_utils import unwrap_model\n\n\nTRAINABLE_WEIGHTS_NAME = \"adapter_model.bin\"\n\n\n@dataclass\nclass _S2STrainingArguments(_hf_Seq2SeqTrainingArguments):\n    # in fate-2.0, we will control the output dir when using pipeline\n    output_dir: str = field(default=\"./\")\n    disable_tqdm: bool = field(default=True)\n    save_strategy: str = field(default=\"no\")\n    logging_strategy: str = field(default=\"epoch\")\n    logging_steps: int = field(default=1)\n    evaluation_strategy: str = 
field(default=\"no\")\n    logging_dir: str = field(default=None)\n    checkpoint_idx: int = field(default=None)\n    # by default, we use constant learning rate, the same as FATE-1.X\n    lr_scheduler_type: str = field(default=\"constant\")\n    log_level: str = field(default=\"info\")\n    deepspeed: Optional[str] = field(default=None)\n    save_safetensors: bool = field(default=False)\n    use_cpu: bool = field(default=False)\n\n    def __post_init__(self):\n        self.push_to_hub = False\n        self.hub_model_id = None\n        self.hub_strategy = \"every_save\"\n        self.hub_token = None\n        self.hub_private_repo = False\n        self.push_to_hub_model_id = None\n        self.push_to_hub_organization = None\n        self.push_to_hub_token = None\n\n        super().__post_init__()\n\nDEFAULT_ARGS = _S2STrainingArguments().to_dict()\n\n@dataclass\nclass Seq2SeqTrainingArguments(_S2STrainingArguments):\n    # To simplify the to dict result(to_dict only return non-default args)\n\n    def to_dict(self):\n        # Call the superclass's to_dict method\n        all_args = super().to_dict()\n        # Get a dict with default values for all fields\n        default_args = copy.deepcopy(DEFAULT_ARGS)\n        # Filter out args that are equal to their default values\n        set_args = {name: value for name, value in all_args.items() if value != default_args.get(name)}\n        return set_args\n\n\nclass HomoSeq2SeqTrainerClient(Seq2SeqTrainer, HomoTrainerMixin):\n\n    def __init__(\n        self,\n        ctx: Context,\n        model: nn.Module,\n        training_args: Seq2SeqTrainingArguments,\n        fed_args: FedArguments,\n        train_set: Dataset,\n        val_set: Dataset = None,\n        optimizer: torch.optim.Optimizer = None,\n        data_collator: Callable = None,\n        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,\n        tokenizer: Optional[PreTrainedTokenizer] = None,\n        callbacks: 
Optional[List[TrainerCallback]] = [],\n        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,\n        local_mode: bool = False,\n        save_trainable_weights_only: bool = False,\n        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n    ):\n        # in case you forget to set evaluation_strategy\n        if val_set is not None and training_args.evaluation_strategy == \"no\":\n            training_args.evaluation_strategy = \"epoch\"\n\n        HomoTrainerMixin.__init__(\n            self,\n            ctx=ctx,\n            model=model,\n            optimizer=optimizer,\n            training_args=training_args,\n            fed_args=fed_args,\n            train_set=train_set,\n            val_set=val_set,\n            scheduler=scheduler,\n            callbacks=callbacks,\n            compute_metrics=compute_metrics,\n            local_mode=local_mode,\n            save_trainable_weights_only=save_trainable_weights_only,\n        )\n\n        # concat checkpoint path if checkpoint idx is set\n        if self._args.checkpoint_idx is not None:\n            checkpoint_path = self._args.resume_from_checkpoint\n            if checkpoint_path is not None and os.path.exists(checkpoint_path):\n                checkpoint_folder = get_ith_checkpoint(checkpoint_path, self._args.checkpoint_idx)\n                self._args.resume_from_checkpoint = os.path.join(checkpoint_path, checkpoint_folder)\n\n        Trainer.__init__(\n            self,\n            model=model,\n            args=self._args,\n            train_dataset=train_set,\n            eval_dataset=val_set,\n            data_collator=data_collator,\n            optimizers=(optimizer, scheduler),\n            tokenizer=tokenizer,\n            compute_metrics=self._compute_metrics_warp_func,\n            preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n        )\n\n        self._add_fate_callback(self.callback_handler)\n\n    
def _save(\n        self,\n        output_dir: Optional[str] = None,\n        state_dict=None\n    ):\n        if not self._save_trainable_weights_only:\n            return super()._save(output_dir, state_dict)\n        else:\n            model = unwrap_model(self.model)\n\n            if hasattr(model, \"save_trainable\"):\n                model.save_trainable(output_dir)\n            else:\n                state_dict = {\n                    k: p.to(\"cpu\") for k,\n                                       p in model.named_parameters() if p.requires_grad\n                }\n\n                torch.save(state_dict, os.path.join(output_dir, TRAINABLE_WEIGHTS_NAME))\n"
  },
  {
    "path": "python/requirements.txt",
    "content": "accelerate==0.27.2\ndeepspeed==0.13.3\npeft==0.8.2\nsentencepiece==0.2.0\nlm_eval==0.4.2\nrouge-score==0.1.2\ndatasets==2.18.0\neditdistance\ntorch==2.3.1\ntransformers==4.37.2\nopacus==1.4.1\nfastchat\nJinja2\nsentence-transformers\nopenai\n"
  },
  {
    "path": "python/setup.py",
    "content": "# -*- coding: utf-8 -*-\n# \n#  Copyright 2024 The FATE Authors. All Rights Reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n#\n\nfrom setuptools import find_packages, setup\n\n# Define the packages and modules\npackages = find_packages(\".\")\npackage_data = {\"\": [\"*\"]}\n\n# Define dependencies\ninstall_requires = [\n    \"accelerate==0.27.2\",\n    \"deepspeed==0.13.3\", \n    \"peft==0.8.2\",\n    \"sentencepiece==0.2.0\",\n    \"lm_eval==0.4.2\",\n    \"rouge-score==0.1.2\",\n    \"datasets==2.18.0\",\n    \"editdistance\",\n    \"torch==2.3.1\",\n    \"transformers==4.37.2\",\n    \"opacus==1.4.1\",\n    \"fastchat\",\n    \"Jinja2\",\n    \"sentence-transformers\",\n    \"openai\"\n]\n\n# Define the entry points for command-line tools\nentry_points = {\n    \"console_scripts\": [\n        \"fate_llm = fate_llm.evaluate.scripts.fate_llm_cli:fate_llm_cli\"\n    ]\n}\n\nextras_require = {\n    \"fate\": [\"pyfate==2.2.0\"],\n    \"fate_flow\": [\"fate_flow==2.2.0\"],\n    \"fate_client\": [\"fate_client==2.2.0\"]\n}\n\n# Configure and call the setup function\nsetup_kwargs = {\n    \"name\": \"fate_llm\",\n    \"version\": \"2.2.0\",\n    \"description\": \"Federated Learning for Large Language Models\",\n    \"long_description\": \"Federated Learning for Large Language Models (FATE-LLM) provides a framework to train and evaluate large language models in a federated manner.\",\n    \"long_description_content_type\": 
\"text/markdown\",\n    \"author\": \"FederatedAI\",\n    \"author_email\": \"contact@FedAI.org\",\n    \"url\": \"https://fate.fedai.org/\",\n    \"packages\": packages,\n    \"install_requires\": install_requires,\n    \"entry_points\": entry_points,\n    \"extras_require\": extras_require,\n    \"python_requires\": \">=3.8\",\n    \"include_package_data\": True\n}\n\nsetup(**setup_kwargs)\n"
  }
]