[
  {
    "path": ".github/workflows/check-style.yaml",
    "content": "name: Check style\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n\njobs:\n  black:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v3\n      - uses: psf/black@stable\n        with:\n          options: \"--check --diff\"\n          version: \"22.3.0\"\n  isort:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v3\n      - uses: actions/setup-python@v3\n        with:\n          python-version: 3.8\n      - uses: isort/isort-action@master\n        with:\n          isortVersion: \"5.10.1\"\n"
  },
  {
    "path": ".github/workflows/push-docker-image.yaml",
    "content": "name: Push to Docker Hub\n\non:\n  push:\n    branches: [ main ]\n    tags:\n      - \"*.*.*\"\n  pull_request:\n    branches: [ main ]\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v3\n\n      - name: Docker meta\n        id: meta\n        uses: crazy-max/ghaction-docker-meta@v2\n        with:\n          # list of Docker images to use as base name for tags\n          images: |\n            learningathome/petals\n          # generate Docker tags based on the following events/attributes\n          tags: |\n            type=ref,event=branch\n            type=ref,event=pr\n            type=semver,pattern={{version}}\n            type=semver,pattern={{major}}.{{minor}}\n            type=semver,pattern={{major}}\n\n      - name: Set up Docker Buildx\n        id: buildx\n        uses: docker/setup-buildx-action@v1\n\n      - name: Login to Docker Hub\n        if: github.event_name != 'pull_request'\n        uses: docker/login-action@v1\n        with:\n          username: ${{ secrets.DOCKER_HUB_USERNAME }}\n          password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}\n\n      - name: Free disk space on Ubuntu runner\n        uses: kfir4444/free-disk-space@main\n        with:\n          # found in: https://github.com/docker/build-push-action/issues/968\n          tool-cache: false\n          android: true\n          dotnet: true\n          haskell: true\n          large-packages: true\n          swap-storage: true\n\n      - name: Build and push\n        id: docker_build\n        uses: docker/build-push-action@v2\n        with:\n          context: .\n          push: ${{ github.event_name != 'pull_request' }}\n          tags: ${{ steps.meta.outputs.tags }}\n\n      - name: Image digest\n        run: echo ${{ steps.docker_build.outputs.digest }}\n"
  },
  {
    "path": ".github/workflows/run-tests.yaml",
    "content": "name: Tests\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n\njobs:\n  run-tests:\n    strategy:\n      matrix:\n        include:\n          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.8' }\n          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }\n          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }\n          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }\n          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }\n          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }\n          - { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.8' }\n          - { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.11' }\n      fail-fast: false\n    runs-on: ${{ matrix.os }}-latest\n    timeout-minutes: 20\n    steps:\n      - name: Increase swap space\n        if: ${{ matrix.os == 'ubuntu' }}\n        uses: pierotofy/set-swap-space@master\n        with:\n          swap-size-gb: 10\n      - name: Checkout\n        uses: actions/checkout@v3\n      - name: Set up Python\n        uses: actions/setup-python@v3\n        with:\n          python-version: ${{ matrix.python-version }}\n      - name: Cache dependencies\n        uses: actions/cache@v3\n        with:\n          path: ~/.cache/pip\n          key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }}\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          pip install .[dev]\n      - name: Test\n        run: |\n          set -x  # Print executed commands\n          export MODEL_NAME=\"${{ matrix.model }}\"\n          export REF_NAME=\"${{ matrix.model }}\"\n          export ADAPTER_NAME=\"${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}\"\n\n          # [Step 1] Set up a tiny test swarm (see 
https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)\n\n          python -m petals.cli.run_dht \\\n            --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &\n          BOOTSTRAP_PID=$!\n\n          export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g\n          # ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs\n\n          until [ -s bootstrap.log ]; do sleep 5; done  # wait for DHT init\n\n          export RUN_SERVER=\"python -m petals.cli.run_server $MODEL_NAME \\\n            --device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS\"\n          export TENSOR_PARALLEL_ARGS=\"${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}\"\n\n          $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log &\n          SERVER1_PID=$!\n          # ^-- rebalacing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there\n\n          sleep 10  # wait for the 1st server to choose blocks\n\n          $RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log &\n          SERVER2_PID=$!\n\n          $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \\\n            --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log &\n          SERVER3_PID=$!\n          # ^-- chunking test\n\n          $RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log &\n          SERVER4_PID=$!\n          # ^-- tensor parallelism test (not compatible with adapters yet)\n\n          sleep 5  # wait for the log files to appear\n\n          tail -n 100 -f bootstrap.log server*.log &\n          LOGGER_PID=$!\n\n          sleep 30  # wait for servers to eval throughput, download layers, and rebalance\n          
kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID  # ensure all peers survived init\n\n          # [Step 2] Run PyTest\n\n          # Share disk cache between Petals servers, clients, and HF Transformers\n          export TRANSFORMERS_CACHE=~/.cache/petals\n\n          # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93\n          export no_proxy=*\n          export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES\n\n          # Limit default ClientConfig.max_retries to see tracebacks instead of retrying indefinitely\n          export PETALS_MAX_RETRIES=10\n\n          pytest tests --durations=0 --durations-min=1.0 -v\n\n          # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)\n\n          python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \\\n            --seq_len 3\n          python benchmarks/benchmark_forward.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \\\n            --seq_len 3 --batch_size 3 --n_steps 1\n          python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \\\n            --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task cls\n          python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \\\n            --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm\n\n          # [Step 4] Clean up\n\n          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n.idea/\n"
  },
  {
    "path": "Dockerfile",
    "content": "FROM nvcr.io/nvidia/cuda:11.0.3-cudnn8-devel-ubuntu20.04\nLABEL maintainer=\"bigscience-workshop\"\nLABEL repository=\"petals\"\n\nWORKDIR /home\n# Set en_US.UTF-8 locale by default\nRUN echo \"LC_ALL=en_US.UTF-8\" >> /etc/environment\n\n# Install packages\nRUN apt-get update && apt-get install -y --no-install-recommends \\\n  build-essential \\\n  wget \\\n  git \\\n  && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/*\n\nRUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O install_miniconda.sh && \\\n  bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh\nENV PATH=\"/opt/conda/bin:${PATH}\"\n\nRUN conda install python~=3.10.12 pip && \\\n    pip install --no-cache-dir \"torch>=1.12\" && \\\n    conda clean --all && rm -rf ~/.cache/pip\n\nVOLUME /cache\nENV PETALS_CACHE=/cache\n\nCOPY . petals/\nRUN pip install --no-cache-dir -e petals\n\nWORKDIR /home/petals/\nCMD bash\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 Petals authors and collaborators\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n    <img src=\"https://i.imgur.com/7eR7Pan.png\" width=\"400\"><br>\n    Run large language models at home, BitTorrent-style.<br>\n    Fine-tuning and inference <a href=\"https://github.com/bigscience-workshop/petals#benchmarks\">up to 10x faster</a> than offloading\n    <br><br>\n    <a href=\"https://pypi.org/project/petals/\"><img src=\"https://img.shields.io/pypi/v/petals.svg?color=green\"></a>\n    <a href=\"https://discord.gg/tfHfe8B34k\"><img src=\"https://img.shields.io/discord/865254854262652969?label=discord&logo=discord&logoColor=white\"></a>\n    <br>\n</p>\n\nGenerate text with distributed **Llama 3.1** (up to 405B), **Mixtral** (8x22B), **Falcon** (40B+) or **BLOOM** (176B) and fine‑tune them for your own tasks &mdash; right from your desktop computer or Google Colab:\n\n```python\nfrom transformers import AutoTokenizer\nfrom petals import AutoDistributedModelForCausalLM\n\n# Choose any model available at https://health.petals.dev\nmodel_name = \"meta-llama/Meta-Llama-3.1-405B-Instruct\"\n\n# Connect to a distributed network hosting model layers\ntokenizer = AutoTokenizer.from_pretrained(model_name)\nmodel = AutoDistributedModelForCausalLM.from_pretrained(model_name)\n\n# Run the model as if it were on your computer\ninputs = tokenizer(\"A cat sat\", return_tensors=\"pt\")[\"input_ids\"]\noutputs = model.generate(inputs, max_new_tokens=5)\nprint(tokenizer.decode(outputs[0]))  # A cat sat on a mat...\n```\n\n<p align=\"center\">\n    🚀 &nbsp;<b><a href=\"https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing\">Try now in Colab</a></b>\n</p>\n\n🦙 **Want to run Llama?** [Request access](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) to its weights, then run `huggingface-cli login` in the terminal before loading the model. 
Or just try it in our [chatbot app](https://chat.petals.dev).\n\n🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.\n\n💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!\n\n## Connect your GPU and increase Petals capacity\n\nPetals is a community-run system &mdash; we rely on people sharing their GPUs. You can help serve one of the [available models](https://health.petals.dev) or host a new model from 🤗 [Model Hub](https://huggingface.co/models)!\n\nAs an example, here is how to host a part of [Llama 3.1 (405B) Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) on your GPU:\n\n🦙 **Want to host Llama?** [Request access](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) to its weights, then run `huggingface-cli login` in the terminal before loading the model.\n\n🐧 **Linux + Anaconda.** Run these commands for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):\n\n```bash\nconda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia\npip install git+https://github.com/bigscience-workshop/petals\npython -m petals.cli.run_server meta-llama/Meta-Llama-3.1-405B-Instruct\n```\n\n🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki.\n\n🐋 **Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):\n\n```bash\nsudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \\\n    learningathome/petals:main \\\n    python -m 
petals.cli.run_server --port 31330 meta-llama/Meta-Llama-3.1-405B-Instruct\n```\n\n🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands:\n\n```bash\nbrew install python\npython3 -m pip install git+https://github.com/bigscience-workshop/petals\npython3 -m petals.cli.run_server meta-llama/Meta-Llama-3.1-405B-Instruct\n```\n\n<p align=\"center\">\n    📚 &nbsp;<b><a href=\"https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server\">Learn more</a></b> (how to use multiple GPUs, start the server on boot, etc.)\n</p>\n\n🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).\n\n💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!\n\n🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`.\n\n## How does it work?\n\n- You load a small part of the model, then join a [network](https://health.petals.dev) of people serving the other parts. Single‑batch inference runs at up to **6 tokens/sec** for **Llama 2** (70B) and up to **4 tokens/sec** for **Falcon** (180B) — enough for [chatbots](https://chat.petals.dev) and interactive apps.\n- You can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. 
You get the comforts of an API with the flexibility of **PyTorch** and **🤗 Transformers**.\n\n<p align=\"center\">\n    <img src=\"https://i.imgur.com/RTYF3yW.png\" width=\"800\">\n</p>\n\n<p align=\"center\">\n    📜 &nbsp;<b><a href=\"https://arxiv.org/pdf/2209.01188.pdf\">Read paper</a></b>\n    &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;\n    📚 &nbsp;<b><a href=\"https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions\">See FAQ</a></b>\n</p>\n\n## 📚 Tutorials, examples, and more\n\nBasic tutorials:\n\n- Getting started: [tutorial](https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing)\n- Prompt-tune Llama-65B for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)\n- Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)\n\nUseful tools:\n\n- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/petals-infra/chat.petals.dev)\n- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/petals-infra/health.petals.dev)\n\nAdvanced guides:\n\n- Launch a private swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)\n- Run a custom model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)\n\n### Benchmarks\n\nPlease see **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).\n\n### 🛠️ Contributing\n\nPlease see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing.\n\n### 📜 Citations\n\nAlexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and 
Colin Raffel.\n[Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188)\n_Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)._ 2023.\n\n```bibtex\n@inproceedings{borzunov2023petals,\n  title = {Petals: Collaborative Inference and Fine-tuning of Large Models},\n  author = {Borzunov, Alexander and Baranchuk, Dmitry and Dettmers, Tim and Riabinin, Maksim and Belkada, Younes and Chumachenko, Artem and Samygin, Pavel and Raffel, Colin},\n  booktitle = {Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},\n  pages = {558--568},\n  year = {2023},\n  url = {https://arxiv.org/abs/2209.01188}\n}\n```\n\nAlexander Borzunov, Max Ryabinin, Artem Chumachenko, Dmitry Baranchuk, Tim Dettmers, Younes Belkada, Pavel Samygin, and Colin Raffel.\n[Distributed inference and fine-tuning of large language models over the Internet.](https://arxiv.org/abs/2312.08361)\n_Advances in Neural Information Processing Systems_ 36 (2023).\n\n```bibtex\n@inproceedings{borzunov2023distributed,\n  title = {Distributed inference and fine-tuning of large language models over the {I}nternet},\n  author = {Borzunov, Alexander and Ryabinin, Max and Chumachenko, Artem and Baranchuk, Dmitry and Dettmers, Tim and Belkada, Younes and Samygin, Pavel and Raffel, Colin},\n  booktitle = {Advances in Neural Information Processing Systems},\n  volume = {36},\n  pages = {12312--12331},\n  year = {2023},\n  url = {https://arxiv.org/abs/2312.08361}\n}\n```\n\n--------------------------------------------------------------------------------\n\n<p align=\"center\">\n    This project is a part of the <a href=\"https://bigscience.huggingface.co/\">BigScience</a> research workshop.\n</p>\n<p align=\"center\">\n    <img src=\"https://petals.dev/bigscience.png\" width=\"150\">\n</p>\n"
  },
  {
    "path": "benchmarks/benchmark_forward.py",
    "content": "#!/usr/bin/env python3\n\nimport argparse\nimport multiprocessing as mp\nfrom time import perf_counter\n\nimport numpy as np\nimport torch\nfrom hivemind.utils.logging import get_logger\n\nfrom petals import AutoDistributedModel\nfrom petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS\n\nlogger = get_logger()\n\n\ndef main():\n    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add_argument(\"--model\", type=str, required=True, help=\"Model\")\n    parser.add_argument(\"--initial_peers\", type=str, nargs=\"+\", default=PUBLIC_INITIAL_PEERS, help=\"Initial peers\")\n    parser.add_argument(\"--torch_dtype\", type=str, default=\"float32\", help=\"Torch dtype\")\n    parser.add_argument(\"--n_processes\", type=str, default=1, help=\"Number of concurrent processes\")\n    parser.add_argument(\"--seq_len\", type=int, default=128, help=\"Sequence length\")\n    parser.add_argument(\"--n_steps\", type=int, default=100, help=\"Number of benchmark steps\")\n    parser.add_argument(\"--batch_size\", type=int, required=True, help=\"Batch size\")\n    parser.add_argument(\"--warmup_steps\", type=int, default=1, help=\"Number of warmup steps\")\n    args = parser.parse_args()\n\n    if args.n_processes == \"n_gpus\":\n        args.n_processes = torch.cuda.device_count()\n    else:\n        args.n_processes = int(args.n_processes)\n\n    pipe_recv, pipe_send = mp.Pipe(duplex=False)\n    processes = [mp.Process(target=benchmark_forward, args=(i, args, pipe_send)) for i in range(args.n_processes)]\n    for proc in processes:\n        proc.start()\n    for proc in processes:\n        proc.join()\n\n    speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])\n    logger.info(f\"Final result: {speed=:.2f}\")\n\n\n@torch.inference_mode()\ndef benchmark_forward(process_idx, args, result_pipe):\n    model = AutoDistributedModel.from_pretrained(\n        args.model,\n        
initial_peers=args.initial_peers,\n        torch_dtype=DTYPE_MAP[args.torch_dtype],\n    )\n    logger.info(f\"Created model: {process_idx=} {model.device=}\")\n\n    torch.manual_seed(42)\n    step_times = []\n    for step in range(args.warmup_steps + args.n_steps):\n        start_time = perf_counter()\n\n        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))\n\n        logger.info(f\"{process_idx=} Fwd begin {input_ids.shape=}\")\n        h = model(input_ids)\n        # We don't use model.lm_head\n        logger.info(f\"{process_idx=} Fwd end\")\n\n        if step >= args.warmup_steps:\n            step_times.append(perf_counter() - start_time)\n            speed = input_ids.numel() / np.mean(step_times)\n            logger.info(f\"{process_idx=} {step=} {speed=:.2f}\")\n\n    result_pipe.send(speed)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmarks/benchmark_inference.py",
    "content": "#!/usr/bin/env python3\n\nimport argparse\nimport multiprocessing as mp\nfrom time import perf_counter\n\nimport numpy as np\nimport torch\nfrom hivemind.utils.logging import get_logger\nfrom transformers import AutoTokenizer\n\nfrom petals import AutoDistributedModelForCausalLM\nfrom petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS\n\nlogger = get_logger()\n\n\ndef main():\n    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add_argument(\"--model\", type=str, required=True, help=\"Model\")\n    parser.add_argument(\"--initial_peers\", type=str, nargs=\"+\", default=PUBLIC_INITIAL_PEERS, help=\"Initial peers\")\n    parser.add_argument(\"--torch_dtype\", type=str, default=\"float32\", help=\"Torch dtype\")\n    parser.add_argument(\"--n_processes\", type=str, default=1, help=\"Number of concurrent processes\")\n    parser.add_argument(\"--seq_len\", type=int, default=2048, help=\"Sequence length\")\n    parser.add_argument(\"--warmup_steps\", type=int, default=1, help=\"Number of warmup steps\")\n    args = parser.parse_args()\n\n    if args.n_processes == \"n_gpus\":\n        args.n_processes = torch.cuda.device_count()\n    else:\n        args.n_processes = int(args.n_processes)\n\n    pipe_recv, pipe_send = mp.Pipe(duplex=False)\n    processes = [mp.Process(target=benchmark_inference, args=(i, args, pipe_send)) for i in range(args.n_processes)]\n    for proc in processes:\n        proc.start()\n    for proc in processes:\n        proc.join()\n\n    speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])\n    logger.info(f\"Final result: {speed=:.2f}\")\n\n\n@torch.inference_mode()\ndef benchmark_inference(process_idx, args, result_pipe):\n    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)\n    # Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway\n\n    model = 
AutoDistributedModelForCausalLM.from_pretrained(\n        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]\n    )\n    logger.info(f\"Created model: {process_idx=} {model.device=}\")\n\n    result = \"\"\n    step_times = []\n    with model.transformer.h.inference_session(max_length=args.seq_len) as sess:\n        for step in range(args.seq_len):\n            start_time = perf_counter()\n\n            outputs = model.generate(max_new_tokens=1, session=sess)\n            result += tokenizer.decode(outputs[0])\n\n            if step >= args.warmup_steps:\n                step_times.append(perf_counter() - start_time)\n                speed = 1 / np.mean(step_times)\n                logger.info(f\"{process_idx=} {step=} {speed=:.2f}\")\n\n    result_pipe.send(speed)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmarks/benchmark_training.py",
    "content": "#!/usr/bin/env python3\n\nimport argparse\nimport multiprocessing as mp\nfrom time import perf_counter\n\nimport numpy as np\nimport torch\nfrom hivemind.utils.logging import get_logger\n\nfrom petals import AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification\nfrom petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS\n\nlogger = get_logger()\n\n\ndef main():\n    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add_argument(\"--model\", type=str, required=True, help=\"Model\")\n    parser.add_argument(\"--device\", type=str, default=\"cpu\", help=\"Torch device hosting the client\")\n    parser.add_argument(\"--task\", type=str, default=\"cls\", help=\"Training task type\")\n    parser.add_argument(\"--initial_peers\", type=str, nargs=\"+\", default=PUBLIC_INITIAL_PEERS, help=\"Initial peers\")\n    parser.add_argument(\"--torch_dtype\", type=str, default=\"float32\", help=\"Torch dtype\")\n    parser.add_argument(\"--n_processes\", type=str, default=1, help=\"Number of concurrent processes\")\n    parser.add_argument(\"--seq_len\", type=int, default=128, help=\"Sequence length\")\n    parser.add_argument(\"--pre_seq_len\", type=int, default=16, help=\"Number of trainable tokens\")\n    parser.add_argument(\"--n_steps\", type=int, default=10, help=\"Number of benchmark steps\")\n    parser.add_argument(\"--batch_size\", type=int, required=True, help=\"Batch size\")\n    parser.add_argument(\"--warmup_steps\", type=int, default=1, help=\"Number of warmup steps\")\n    args = parser.parse_args()\n\n    assert args.task in [\"cls\", \"causal_lm\"]\n\n    if args.n_processes == \"n_gpus\":\n        args.n_processes = torch.cuda.device_count()\n    else:\n        args.n_processes = int(args.n_processes)\n\n    pipe_recv, pipe_send = mp.Pipe(duplex=False)\n    processes = [mp.Process(target=benchmark_training, args=(i, args, pipe_send)) for i in range(args.n_processes)]\n    
for proc in processes:\n        proc.start()\n    for proc in processes:\n        proc.join()\n\n    fwd_speed, bwd_speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)], axis=0)\n    logger.info(f\"Final result: {fwd_speed=:.2f} {bwd_speed=:.2f}\")\n\n\ndef benchmark_training(process_idx, args, result_pipe):\n    if args.task == \"cls\":\n        model = AutoDistributedModelForSequenceClassification.from_pretrained(\n            args.model,\n            initial_peers=args.initial_peers,\n            torch_dtype=DTYPE_MAP[args.torch_dtype],\n            tuning_mode=\"deep_ptune\",\n            pre_seq_len=args.pre_seq_len,\n            num_labels=2,\n        )\n    elif args.task == \"causal_lm\":\n        model = AutoDistributedModelForCausalLM.from_pretrained(\n            args.model,\n            initial_peers=args.initial_peers,\n            torch_dtype=DTYPE_MAP[args.torch_dtype],\n            tuning_mode=\"deep_ptune\",\n            pre_seq_len=args.pre_seq_len,\n        )\n    model = model.to(args.device)\n    opt = torch.optim.Adam(model.parameters())\n    logger.info(f\"Created model: {process_idx=} {model.device=}\")\n\n    torch.manual_seed(42)\n    fwd_times = []\n    bwd_times = []\n    for step in range(args.warmup_steps + args.n_steps):\n        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device)\n        if args.task == \"cls\":\n            labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)\n        else:\n            labels = input_ids\n\n        logger.info(f\"{process_idx=} {step=} Forward\")\n        start_time = perf_counter()\n        outputs = model(input_ids, labels=labels)\n        if step >= args.warmup_steps:\n            fwd_times.append(perf_counter() - start_time)\n\n        logger.info(f\"{process_idx=} {step=} Backward\")\n        start_time = perf_counter()\n        outputs.loss.backward()\n        if step >= args.warmup_steps:\n        
    bwd_times.append(perf_counter() - start_time)\n\n        logger.info(f\"{process_idx=} {step=} Optimizer step\")\n        opt.step()\n        opt.zero_grad()\n\n        if step >= args.warmup_steps:\n            fwd_speed = input_ids.numel() / np.mean(fwd_times)\n            bwd_speed = input_ids.numel() / np.mean(bwd_times)\n            logger.info(f\"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}\")\n\n    result_pipe.send((fwd_speed, bwd_speed))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/prompt-tuning-personachat.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a07e0f5e\",\n   \"metadata\": {},\n   \"source\": [\n    \"<div>\\n\",\n    \"<img src=\\\"https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67\\\" width=\\\"40%\\\">  \\n\",\n    \"</div>\\n\",\n    \"\\n\",\n    \"# Distributed Bloom for Text Generation using Prompt Tuning\\n\",\n    \"\\n\",\n    \"In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\\n\",\n    \"\\n\",\n    \"We will adapt BLOOM for the task of creating a chatbot with a specific personality using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.\\n\",\n    \"\\n\",\n    \"To use this notebook in Colab:\\n\",\n    \"\\n\",\n    \"1. Follow this link: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)\\n\",\n    \"2. 
Go to **Runtime** -> **Change runtime type** and select the GPU accelerator.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a3f8526f\",\n   \"metadata\": {},\n   \"source\": [\n    \"First, we have to prepare all dependencies.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"73bbc648\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"%pip install -q petals datasets wandb scikit-learn\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b4ab6ca7\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"\\n\",\n    \"import torch\\n\",\n    \"import transformers\\n\",\n    \"import wandb\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"from tqdm import tqdm\\n\",\n    \"from torch.optim import AdamW\\n\",\n    \"from torch.utils.data import DataLoader\\n\",\n    \"from transformers import BloomTokenizerFast, get_scheduler\\n\",\n    \"\\n\",\n    \"from petals import DistributedBloomForCausalLM\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1bf07b5d\",\n   \"metadata\": {},\n   \"source\": [\n    \"Let's set some hyperparameters for training:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"f04ba4d2\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Choose a model you'd like to prompt-tune. 
We recommend starting with\\n\",\n    \"# the smaller 7.1B version of BLOOM (bigscience/bloom-7b1-petals) for faster prototyping.\\n\",\n    \"# Once your code is ready, you can switch to full-scale\\n\",\n    \"# 176B-parameter BLOOM (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).\\n\",\n    \"MODEL_NAME = \\\"bigscience/bloom-7b1-petals\\\"\\n\",\n    \"\\n\",\n    \"# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\\n\",\n    \"# The latter fine-tunes separate prefixes for each transformer block,\\n\",\n    \"# so prompt-tuning will take more time but yield better results.\\n\",\n    \"# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\\n\",\n    \"TUNING_MODE = 'ptune'\\n\",\n    \"\\n\",\n    \"NUM_PREFIX_TOKENS = 16\\n\",\n    \"DEVICE = 'cuda'\\n\",\n    \"BATCH_SIZE = 8\\n\",\n    \"LR = 1e-2\\n\",\n    \"WEIGHT_DECAY = 0.0\\n\",\n    \"NUM_SAMPLES = 1000\\n\",\n    \"SEED = 42\\n\",\n    \"MODEL_MAX_LENGTH = 256\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d38316bd\",\n   \"metadata\": {},\n   \"source\": [\n    \"Prepare tokenizer and distributed model, connect it to servers.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"03c6e53e\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\\n\",\n    \"tokenizer.padding_side = 'right'\\n\",\n    \"tokenizer.model_max_length = MODEL_MAX_LENGTH\\n\",\n    \"model = DistributedBloomForCausalLM.from_pretrained(\\n\",\n    \"    MODEL_NAME,\\n\",\n    \"    pre_seq_len=NUM_PREFIX_TOKENS, \\n\",\n    \"    tuning_mode=TUNING_MODE\\n\",\n    \").to(DEVICE)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"042e3786\",\n   \"metadata\": {},\n   \"source\": [\n    \"Let's prepare the Personachat dataset. 
We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9c44d516\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"dataset = load_dataset(\\\"bavard/personachat_truecased\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def chunking(examples):\\n\",\n    \"    inputs = [\\n\",\n    \"        \\\"\\\\n-----\\\\n\\\".join(history) + \\\"\\\\n-----\\\\n\\\" + candidate\\n\",\n    \"        for history, candidates in zip(examples[\\\"history\\\"], examples[\\\"candidates\\\"])\\n\",\n    \"        for candidate in candidates\\n\",\n    \"    ]\\n\",\n    \"    return {\\\"chunks\\\": inputs}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def tokenize(examples):\\n\",\n    \"    outputs = {\\n\",\n    \"        \\\"input_ids\\\": tokenizer(examples[\\\"chunks\\\"], padding='max_length', truncation=True)[\\\"input_ids\\\"]\\n\",\n    \"    }\\n\",\n    \"    outputs[\\\"labels\\\"] = outputs[\\\"input_ids\\\"]\\n\",\n    \"    return outputs\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"tokenized_datasets = (\\n\",\n    \"    dataset\\n\",\n    \"        .map(chunking, batched=True, remove_columns=dataset[\\\"train\\\"].column_names)\\n\",\n    \"        .map(tokenize, batched=True, remove_columns=[\\\"chunks\\\"])\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"tokenized_datasets.set_format(\\\"torch\\\")\\n\",\n    \"train_dataset = tokenized_datasets[\\\"train\\\"].shuffle(seed=SEED)\\n\",\n    \"train_dataloader = DataLoader(\\n\",\n    \"    train_dataset.select(list(range(NUM_SAMPLES))),\\n\",\n    \"    shuffle=True,\\n\",\n    \"    batch_size=BATCH_SIZE,\\n\",\n    \"    drop_last=True,\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ef4323fd\",\n   \"metadata\": {},\n   \"source\": [\n    \"Before setting up optimizers, check the model parameters that will be 
trained.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9cc0ba34\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"for n, p in model.named_parameters():\\n\",\n    \"    if p.requires_grad:\\n\",\n    \"        print(n, p.requires_grad, p.device)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"59cffce7\",\n   \"metadata\": {},\n   \"source\": [\n    \"The optimizer will only work on **prompts**: they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"ef9bf344\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\\n\",\n    \"\\n\",\n    \"lr_scheduler = get_scheduler(\\n\",\n    \"    name=\\\"linear\\\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"423c56d5\",\n   \"metadata\": {},\n   \"source\": [\n    \"Let's initialize wandb for logging and start the training loop!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d9e46807\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"wandb.init(\\n\",\n    \"    project=\\\"bloom-personachat\\\",\\n\",\n    \"    config={\\n\",\n    \"        \\\"num_samples\\\": NUM_SAMPLES,\\n\",\n    \"        \\\"batch_size\\\": BATCH_SIZE,\\n\",\n    \"        \\\"learning_rate\\\": LR,\\n\",\n    \"        \\\"weight_decay\\\": WEIGHT_DECAY,\\n\",\n    \"        \\\"num_prefix_tokens\\\": NUM_PREFIX_TOKENS,\\n\",\n    \"        \\\"model_name\\\": MODEL_NAME,\\n\",\n    \"        \\\"seed\\\": SEED,\\n\",\n    \"    }\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"for batch in tqdm(train_dataloader):\\n\",\n    \"    batch = {k: v.to(DEVICE) for k, v in batch.items()}\\n\",\n    \"\\n\",\n    \"    model.train()\\n\",\n    \"    outputs = model(**batch)\\n\",\n    \"    loss = outputs.loss\\n\",\n    \"    loss.backward()\\n\",\n    \"\\n\",\n    \"    optimizer.step()\\n\",\n    \"    lr_scheduler.step()\\n\",\n    \"    optimizer.zero_grad()\\n\",\n    \"\\n\",\n    \"    wandb.log({\\\"Train Loss\\\": loss})\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"0f36cb80\",\n   \"metadata\": {},\n   \"source\": [\n    \"Try to talk with the trained model! Submit an empty input to stop the execution.\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"__Note__: In this example, we use the whole dialogue as a prefix when generating each new reply. In the future, we will support a faster \\\"interactive\\\" dialogue mode, so generating a new reply will be able to reuse inference caches from the previous reply.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"720181b7\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"TOP_K = 100\\n\",\n    \"TEMPERATURE = 0.6\\n\",\n    \"\\n\",\n    \"with model.inference_session(max_length=512) as sess:\\n\",\n    \"    while True:\\n\",\n    \"        user_phrase = input()\\n\",\n    \"        if len(user_phrase) == 0:\\n\",\n    \"            break\\n\",\n    \"        inputs = tokenizer([f\\\"{user_phrase}\\\\n-----\\\\n\\\"], return_tensors='pt')['input_ids'].to(DEVICE)\\n\",\n    \"        while True:\\n\",\n    \"            outputs = model.generate(\\n\",\n    \"                inputs,\\n\",\n    \"                temperature=TEMPERATURE,\\n\",\n    \"                do_sample=True,\\n\",\n    \"                top_k=TOP_K,\\n\",\n    \"                max_new_tokens=1,\\n\",\n    \"                session=sess,\\n\",\n    \"            )\\n\",\n    \"            bloom_answer_token = tokenizer.decode(outputs[0, -1:])\\n\",\n    \"            print(bloom_answer_token, end=\\\"\\\", flush=True)\\n\",\n    \"            if bloom_answer_token == \\\"\\\\n\\\":\\n\",\n    \"                break\\n\",\n    \"            inputs = None\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.9 64-bit\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.9\"\n  },\n  \"vscode\": {\n   \"interpreter\": {\n    \"hash\": \"31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6\"\n   }\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "examples/prompt-tuning-sst2.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a07e0f5e\",\n   \"metadata\": {\n    \"id\": \"a07e0f5e\"\n   },\n   \"source\": [\n    \"<div>\\n\",\n    \"<img src=\\\"https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67\\\" width=\\\"40%\\\">  \\n\",\n    \"</div>\\n\",\n    \"\\n\",\n    \"# Distributed LLaMA for Text Classification using Prompt Tuning\\n\",\n    \"\\n\",\n    \"In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [LLaMA](https://github.com/facebookresearch/llama) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the LLaMA blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\\n\",\n    \"\\n\",\n    \"We will adapt LLaMA for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\\n\",\n    \"\\n\",\n    \"To use this notebook in Colab:\\n\",\n    \"\\n\",\n    \"1. Follow this link: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)\\n\",\n    \"2. 
Go to **Runtime** -> **Change runtime type** and select the GPU accelerator.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"a3f8526f\",\n   \"metadata\": {\n    \"id\": \"a3f8526f\"\n   },\n   \"source\": [\n    \"First, we have to prepare all dependencies.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"73bbc648\",\n   \"metadata\": {\n    \"id\": \"73bbc648\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"%pip install -q datasets wandb scikit-learn\\n\",\n    \"%pip install -q git+https://github.com/bigscience-workshop/petals@main\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b4ab6ca7\",\n   \"metadata\": {\n    \"id\": \"b4ab6ca7\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"\\n\",\n    \"import torch\\n\",\n    \"import torch.nn as nn\\n\",\n    \"import torch.nn.functional as F\\n\",\n    \"import transformers\\n\",\n    \"import wandb\\n\",\n    \"from datasets import load_dataset, load_metric\\n\",\n    \"from tqdm import tqdm\\n\",\n    \"from torch.optim import AdamW\\n\",\n    \"from torch.utils.data import DataLoader\\n\",\n    \"from transformers import LlamaTokenizer, get_scheduler, set_seed\\n\",\n    \"\\n\",\n    \"from petals import DistributedLlamaForSequenceClassification\\n\",\n    \"\\n\",\n    \"set_seed(0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"1bf07b5d\",\n   \"metadata\": {\n    \"id\": \"1bf07b5d\"\n   },\n   \"source\": [\n    \"Let's set some hyperparameters for training:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"f04ba4d2\",\n   \"metadata\": {\n    \"id\": \"f04ba4d2\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"MODEL_NAME = \\\"enoch/llama-65b-hf\\\"\\n\",\n    \"\\n\",\n    \"# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\\n\",\n    \"# The latter fine-tunes separate prefixes for each 
transformer block,\\n\",\n    \"# so prompt-tuning will take more time but yield better results.\\n\",\n    \"# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\\n\",\n    \"TUNING_MODE = 'ptune'\\n\",\n    \"\\n\",\n    \"NUM_PREFIX_TOKENS = 8\\n\",\n    \"DEVICE = 'cuda'\\n\",\n    \"BATCH_SIZE = 32\\n\",\n    \"LR = 1e-2\\n\",\n    \"WEIGHT_DECAY = 0.0\\n\",\n    \"NUM_EPOCHS = 3\\n\",\n    \"SEED = 42\\n\",\n    \"MODEL_MAX_LENGTH = 64\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"d38316bd\",\n   \"metadata\": {\n    \"id\": \"d38316bd\"\n   },\n   \"source\": [\n    \"Here, we prepare tokenizer and distributed model and connect it to the public swarm.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"03c6e53e\",\n   \"metadata\": {\n    \"id\": \"03c6e53e\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)\\n\",\n    \"tokenizer.padding_side = 'right'\\n\",\n    \"tokenizer.model_max_length = MODEL_MAX_LENGTH\\n\",\n    \"tokenizer.pad_token = tokenizer.unk_token\\n\",\n    \"model = DistributedLlamaForSequenceClassification.from_pretrained(\\n\",\n    \"    MODEL_NAME,\\n\",\n    \"    pre_seq_len=NUM_PREFIX_TOKENS,\\n\",\n    \"    tuning_mode=TUNING_MODE\\n\",\n    \").float().to(DEVICE)\\n\",\n    \"model.config.pad_token_id = tokenizer.pad_token_id\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"042e3786\",\n   \"metadata\": {\n    \"id\": \"042e3786\"\n   },\n   \"source\": [\n    \"Let's prepare the SST-2 dataset. 
We need just one preprocessing function to tokenize the dataset.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9c44d516\",\n   \"metadata\": {\n    \"id\": \"9c44d516\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"task = 'sst2'\\n\",\n    \"\\n\",\n    \"dataset = load_dataset(\\\"glue\\\", task)\\n\",\n    \"\\n\",\n    \"def preprocess_function(examples):\\n\",\n    \"    return tokenizer(examples[\\\"sentence\\\"], padding='max_length', truncation=True, return_token_type_ids=False)\\n\",\n    \"\\n\",\n    \"tokenized_datasets = dataset.map(preprocess_function, batched=True)\\n\",\n    \"tokenized_datasets = tokenized_datasets.remove_columns([\\\"sentence\\\", \\\"idx\\\", \\\"attention_mask\\\"])\\n\",\n    \"tokenized_datasets = tokenized_datasets.rename_column(\\\"label\\\", \\\"labels\\\")\\n\",\n    \"tokenized_datasets.set_format(\\\"torch\\\")\\n\",\n    \"\\n\",\n    \"train_dataset = tokenized_datasets[\\\"train\\\"].shuffle(seed=SEED)\\n\",\n    \"valid_dataset = tokenized_datasets[\\\"validation\\\"].shuffle(seed=SEED)\\n\",\n    \"\\n\",\n    \"train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, drop_last=True)\\n\",\n    \"valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2a3f3590\",\n   \"metadata\": {\n    \"id\": \"2a3f3590\"\n   },\n   \"source\": [\n    \"To monitor training, we need the metric function. For SST-2, the target metric is accuracy. 
We will load it from the datasets library.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"1e1812be\",\n   \"metadata\": {\n    \"id\": \"1e1812be\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"metric = load_metric('glue', task)\\n\",\n    \"\\n\",\n    \"def eval_metrics(model, dataloader, device='cpu'):\\n\",\n    \"    model.eval()\\n\",\n    \"    for batch in dataloader:\\n\",\n    \"        batch = {k: v.to(device) for k, v in batch.items()}\\n\",\n    \"\\n\",\n    \"        with torch.no_grad():\\n\",\n    \"            outputs = model(**batch)\\n\",\n    \"\\n\",\n    \"        logits = outputs.logits\\n\",\n    \"        predictions = torch.argmax(logits, dim=-1)\\n\",\n    \"        metric.add_batch(predictions=predictions, references=batch[\\\"labels\\\"])\\n\",\n    \"    model.train()\\n\",\n    \"    return metric.compute()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"ef4323fd\",\n   \"metadata\": {\n    \"id\": \"ef4323fd\"\n   },\n   \"source\": [\n    \"Before setting up optimizers, let's check the model parameters that will be trained.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9cc0ba34\",\n   \"metadata\": {\n    \"id\": \"9cc0ba34\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"for n, p in model.named_parameters():\\n\",\n    \"    if p.requires_grad:\\n\",\n    \"        print(n, p.requires_grad, p.device)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"59cffce7\",\n   \"metadata\": {\n    \"id\": \"59cffce7\"\n   },\n   \"source\": [\n    \"The optimizer will only work on **prompts and the classifier head**: they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"ef9bf344\",\n   \"metadata\": {\n    \"id\": \"ef9bf344\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\\n\",\n    \"\\n\",\n    \"lr_scheduler = get_scheduler(\\n\",\n    \"    name=\\\"linear\\\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"423c56d5\",\n   \"metadata\": {\n    \"id\": \"423c56d5\"\n   },\n   \"source\": [\n    \"Let's initialize wandb for logging and start the training loop!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d9e46807\",\n   \"metadata\": {\n    \"id\": \"d9e46807\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"wandb.init(\\n\",\n    \"    project=\\\"bloom-sst-2\\\",\\n\",\n    \"    config={\\n\",\n    \"        \\\"num_epochs\\\": NUM_EPOCHS,\\n\",\n    \"        \\\"batch_size\\\": BATCH_SIZE,\\n\",\n    \"        \\\"learning_rate\\\": LR,\\n\",\n    \"        \\\"weight_decay\\\": WEIGHT_DECAY,\\n\",\n    \"        \\\"num_prefix_tokens\\\": NUM_PREFIX_TOKENS,\\n\",\n    \"        \\\"model_name\\\": MODEL_NAME,\\n\",\n    \"        \\\"seed\\\": SEED,\\n\",\n    \"    }\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"scaler = torch.cuda.amp.GradScaler()\\n\",\n    \"\\n\",\n    \"for epoch in range(NUM_EPOCHS):\\n\",\n    \"    model.train()\\n\",\n    \"    for batch in tqdm(train_dataloader):\\n\",\n    \"        batch = {k: v.to(DEVICE) for k, v in batch.items()}\\n\",\n    \"\\n\",\n    \"        with torch.autocast(device_type=DEVICE, dtype=torch.float16):\\n\",\n    \"          outputs = model(**batch)\\n\",\n    \"        loss = outputs.loss\\n\",\n    \"        scaler.scale(loss).backward()\\n\",\n    \"\\n\",\n    \"        scaler.step(optimizer)\\n\",\n    \"        scaler.update()\\n\",\n    \"        lr_scheduler.step()\\n\",\n    \"        optimizer.zero_grad()\\n\",\n    \"\\n\",\n    \"        wandb.log({\\\"Train Loss\\\": loss.detach()})\\n\",\n    \"\\n\",\n    \"    accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\\n\",\n    \"    wandb.log({\\\"Valid Accuracy\\\": accuracy}, commit=False)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"51770911\",\n   \"metadata\": {\n    \"id\": \"51770911\"\n   },\n   \"source\": [\n    \"Our model has been trained! You can now upload it to the Hub for later use, try out different models [served in the public swarm](https://health.petals.dev/), or [join Petals with your own GPU](https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity)!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"outputs\": [],\n   \"source\": [],\n   \"metadata\": {\n    \"collapsed\": false\n   }\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.8\"\n  },\n  \"vscode\": {\n   \"interpreter\": {\n    \"hash\": \"31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6\"\n   }\n  },\n  \"colab\": {\n   \"provenance\": [],\n   \"gpuType\": \"T4\"\n  },\n  \"accelerator\": \"GPU\"\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\n    \"setuptools>=42\",\n    \"wheel\"\n]\nbuild-backend = \"setuptools.build_meta\"\n\n[tool.black]\nline-length = 120\nrequired-version = \"22.3.0\"\n\n[tool.isort]\nprofile = \"black\"\nline_length = 120\ncombine_as_imports = true\ncombine_star = true\nknown_local_folder = [\"tests\", \"cli\"]\nknown_first_party = [\"test_utils\"]\n"
  },
  {
    "path": "setup.cfg",
    "content": "[metadata]\nname = petals\nversion = attr: petals.__version__\nauthor = Petals Developers\nauthor_email = petals-devs@googlegroups.com\ndescription = Easy way to efficiently run 100B+ language models without high-end GPUs\nlong_description = file: README.md\nlong_description_content_type = text/markdown\nurl = https://github.com/bigscience-workshop/petals\nproject_urls =\n    Bug Tracker = https://github.com/bigscience-workshop/petals/issues\nclassifiers =\n    Development Status :: 4 - Beta\n    Intended Audience :: Developers\n    Intended Audience :: Science/Research\n    License :: OSI Approved :: MIT License\n    Programming Language :: Python :: 3\n    Programming Language :: Python :: 3.8\n    Programming Language :: Python :: 3.9\n    Programming Language :: Python :: 3.10\n    Programming Language :: Python :: 3.11\n    Topic :: Scientific/Engineering\n    Topic :: Scientific/Engineering :: Mathematics\n    Topic :: Scientific/Engineering :: Artificial Intelligence\n    Topic :: Software Development\n    Topic :: Software Development :: Libraries\n    Topic :: Software Development :: Libraries :: Python Modules\n\n[options]\npackage_dir =\n    = src\npackages = find:\npython_requires = >=3.8\ninstall_requires =\n    torch>=1.12\n    bitsandbytes==0.41.1\n    accelerate>=0.27.2\n    huggingface-hub>=0.11.1,<1.0.0\n    tokenizers>=0.13.3\n    transformers==4.43.1  # if you change this, please also change version assert in petals/__init__.py\n    speedtest-cli==2.1.3\n    hivemind @ git+https://github.com/learning-at-home/hivemind.git@213bff98a62accb91f254e2afdccbf1d69ebdea9\n    tensor_parallel==1.0.23\n    humanfriendly\n    async-timeout>=4.0.2\n    cpufeature>=0.2.0; platform_machine == \"x86_64\"\n    packaging>=20.9\n    sentencepiece>=0.1.99\n    peft==0.8.2\n    safetensors>=0.3.1\n    Dijkstar>=2.6.0\n    numpy<2\n\n[options.extras_require]\ndev =\n    pytest==6.2.5\n    pytest-forked\n    pytest-asyncio==0.16.0\n    black==22.3.0\n   
 isort==5.10.1\n    psutil\n\n[options.packages.find]\nwhere = src\n"
  },
  {
    "path": "src/petals/__init__.py",
    "content": "import os\nimport platform\n\nos.environ.setdefault(\"BITSANDBYTES_NOWELCOME\", \"1\")\n\nif platform.system() == \"Darwin\":\n    # Necessary for forks to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93\n    os.environ.setdefault(\"no_proxy\", \"*\")\n    os.environ.setdefault(\"OBJC_DISABLE_INITIALIZE_FORK_SAFETY\", \"YES\")\n\nimport hivemind\nimport transformers\nfrom packaging import version\n\nfrom petals.client import *\nfrom petals.models import *\nfrom petals.utils import *\nfrom petals.utils.logging import initialize_logs as _initialize_logs\n\n__version__ = \"2.3.0.dev2\"\n\n\nif not os.getenv(\"PETALS_IGNORE_DEPENDENCY_VERSION\"):\n    assert (\n        version.parse(\"4.43.1\") <= version.parse(transformers.__version__) < version.parse(\"4.44.0\")\n    ), \"Please install a proper transformers version: pip install transformers>=4.43.1,<4.44.0\"\n\n\ndef _override_bfloat16_mode_default():\n    if os.getenv(\"USE_LEGACY_BFLOAT16\") is None:\n        hivemind.compression.base.USE_LEGACY_BFLOAT16 = False\n\n\n_initialize_logs()\n_override_bfloat16_mode_default()\n"
  },
  {
    "path": "src/petals/cli/__init__.py",
    "content": ""
  },
  {
    "path": "src/petals/cli/run_dht.py",
    "content": "\"\"\"\nA copy of run_dht.py from hivemind with the ReachabilityProtocol added:\nhttps://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_dht.py\n\nThis script may be used for launching lightweight CPU machines serving as bootstrap nodes to a Petals swarm.\n\nThis may be eventually merged to the hivemind upstream.\n\"\"\"\n\nimport argparse\nimport time\nfrom secrets import token_hex\n\nfrom hivemind.dht import DHT, DHTNode\nfrom hivemind.utils.logging import get_logger, use_hivemind_log_handler\nfrom hivemind.utils.networking import log_visible_maddrs\n\nfrom petals.server.reachability import ReachabilityProtocol\n\nuse_hivemind_log_handler(\"in_root_logger\")\nlogger = get_logger(__name__)\n\n\nasync def report_status(dht: DHT, node: DHTNode):\n    logger.info(\n        f\"{len(node.protocol.routing_table.uid_to_peer_id) + 1} DHT nodes (including this one) \"\n        f\"are in the local routing table \"\n    )\n    logger.debug(f\"Routing table contents: {node.protocol.routing_table}\")\n    logger.info(f\"Local storage contains {len(node.protocol.storage)} keys\")\n    logger.debug(f\"Local storage contents: {node.protocol.storage}\")\n\n    # Contact peers and keep the routing table healthy (remove stale PeerIDs)\n    await node.get(f\"heartbeat_{token_hex(16)}\", latest=True)\n\n\ndef main():\n    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add_argument(\n        \"--initial_peers\",\n        nargs=\"*\",\n        help=\"Multiaddrs of the peers that will welcome you into the existing DHT. \"\n        \"Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY\",\n    )\n    parser.add_argument(\n        \"--host_maddrs\",\n        nargs=\"*\",\n        default=[\"/ip4/0.0.0.0/tcp/0\", \"/ip6/::/tcp/0\"],\n        help=\"Multiaddrs to listen for external connections from other DHT instances. 
\"\n        \"Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0\",\n    )\n    parser.add_argument(\n        \"--announce_maddrs\",\n        nargs=\"*\",\n        help=\"Visible multiaddrs the host announces for external connections from other DHT instances\",\n    )\n    parser.add_argument(\n        \"--use_ipfs\",\n        action=\"store_true\",\n        help='Use IPFS to find initial_peers. If enabled, you only need to provide the \"/p2p/XXXX\" '\n        \"part of the multiaddrs for the initial_peers \"\n        \"(no need to specify a particular IPv4/IPv6 host and port)\",\n    )\n    parser.add_argument(\n        \"--identity_path\",\n        help=\"Path to a private key file. If defined, makes the peer ID deterministic. \"\n        \"If the file does not exist, writes a new private key to this file.\",\n    )\n    parser.add_argument(\n        \"--no_relay\",\n        action=\"store_false\",\n        dest=\"use_relay\",\n        help=\"Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)\",\n    )\n    parser.add_argument(\n        \"--use_auto_relay\",\n        action=\"store_true\",\n        help=\"Look for libp2p relays to become reachable if we are behind NAT/firewall\",\n    )\n    parser.add_argument(\n        \"--refresh_period\", type=int, default=30, help=\"Period (in seconds) for fetching the keys from DHT\"\n    )\n\n    args = parser.parse_args()\n\n    dht = DHT(\n        start=True,\n        initial_peers=args.initial_peers,\n        host_maddrs=args.host_maddrs,\n        announce_maddrs=args.announce_maddrs,\n        use_ipfs=args.use_ipfs,\n        identity_path=args.identity_path,\n        use_relay=args.use_relay,\n        use_auto_relay=args.use_auto_relay,\n    )\n    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)\n\n    reachability_protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True)\n\n    while True:\n        
dht.run_coroutine(report_status, return_future=False)\n        time.sleep(args.refresh_period)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/petals/cli/run_prod_server.sh",
    "content": "#!/bin/bash\nset -x\n\nexport HIVEMIND_COLORS=true\nwhile true; do\n        pkill -f p2p\n        pkill -f run_server\n        python -m petals.cli.run_server bigscience/bloom-petals \"$@\" 2>&1 | tee log_`date '+%F_%H:%M:%S'`\ndone\n"
  },
  {
    "path": "src/petals/cli/run_server.py",
    "content": "import argparse\nimport logging\n\nimport configargparse\nimport torch\nfrom hivemind.proto.runtime_pb2 import CompressionType\nfrom hivemind.utils import limits\nfrom hivemind.utils.logging import get_logger\nfrom humanfriendly import parse_size\n\nfrom petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS\nfrom petals.server.server import Server\nfrom petals.utils.convert_block import QuantType\nfrom petals.utils.version import validate_version\n\nlogger = get_logger(__name__)\n\n\ndef main():\n    # fmt:off\n    parser = configargparse.ArgParser(default_config_files=[\"config.yml\"],\n                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')\n\n    group = parser.add_mutually_exclusive_group(required=True)\n    group.add_argument('--converted_model_name_or_path', type=str, default=None,\n                       help=\"path or name of a pretrained model, converted with cli/convert_model.py\")\n    group.add_argument('model', nargs='?', type=str, help=\"same as --converted_model_name_or_path\")\n\n    parser.add_argument(\"--public_name\", type=str, default=None, help=\"Public name to be reported in the leaderboard\")\n\n    group = parser.add_mutually_exclusive_group(required=False)\n    group.add_argument(\"--token\", type=str, default=None, help=\"Hugging Face hub auth token for .from_pretrained()\")\n    group.add_argument(\"--use_auth_token\", action=\"store_true\", dest=\"token\",\n                       help=\"Read token saved by `huggingface-cli login\")\n\n    parser.add_argument('--num_blocks', type=int, default=None, help=\"The number of blocks to serve\")\n    parser.add_argument('--block_indices', type=str, default=None, help=\"Specific block indices to serve\")\n    parser.add_argument('--dht_prefix', type=str, default=None, help=\"Announce all blocks with this DHT prefix\")\n\n    
parser.add_argument('--port', type=int, required=False,\n                        help='Port this server listens to. '\n                             'This is a simplified way to set the --host_maddrs and --announce_maddrs options (see below) '\n                             'that sets the port across all interfaces (IPv4, IPv6) and protocols (TCP, etc.) '\n                             'to the same number. Default: a random free port is chosen for each interface and protocol')\n    parser.add_argument('--public_ip', type=str, required=False,\n                        help='Your public IPv4 address, which is visible from the Internet. '\n                             'This is a simplified way to set the --announce_maddrs option (see below). '\n                             'Default: server announces IPv4/IPv6 addresses of your network interfaces')\n\n    parser.add_argument(\"--no_auto_relay\", action=\"store_false\", dest=\"use_auto_relay\",\n                        help=\"Do not look for libp2p relays to become reachable if we are behind NAT/firewall\")\n\n    parser.add_argument('--host_maddrs', nargs='+', required=False,\n                        help='Multiaddrs to listen for external connections from other peers')\n    parser.add_argument('--announce_maddrs', nargs='+', required=False,\n                        help='Visible multiaddrs the host announces for external connections from other peers')\n\n    parser.add_argument('--daemon_startup_timeout', type=float, default=60,\n                        help='Timeout for the libp2p daemon connecting to initial peers')\n\n    parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')\n\n    parser.add_argument('--num_handlers', type=int, default=8, required=False,\n                        help='server will use this many processes to handle incoming requests')\n    parser.add_argument('--prefetch_batches', type=int, default=1, required=False,\n                      
  help='Pre-form this many subsequent batches while GPU is processing the current one')\n    parser.add_argument('--sender_threads', type=int, default=1, required=False,\n                        help='Use this many threads to pass results/exceptions from Runtime to Pools')\n\n    parser.add_argument('--inference_max_length', type=int, default=None,\n                        help='Maximum total sequence length permitted per inference. '\n                             'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')\n    parser.add_argument('--min_batch_size', type=int, default=1,\n                        help='Minimum required batch size for all operations (in total tokens)')\n    parser.add_argument('--max_batch_size', type=int, default=None,\n                        help='The total number of tokens in the same batch will not exceed this value. '\n                             'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')\n    parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024,\n                        help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks')\n    parser.add_argument('--attn_cache_tokens', type=int, default=None,\n                        help='The number of past attention key/value pairs that will be stored between inference steps. '\n                             'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others')\n\n    parser.add_argument('--cache_dir', type=str, default=None,\n                        help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')\n    parser.add_argument(\"--max_disk_space\", type=str, default=None,\n                        help=\"Maximal disk space used for caches. 
Example: 50GB, 100GiB (GB != GiB here). \"\n                             \"Default: unlimited. \"\n                             \"For bigscience/bloom-petals, this default means that the server may use up to \"\n                             \"min(free_disk_space, 350GB) in the worst case, which happens when the server runs \"\n                             \"for a long time and caches all model blocks after a number of rebalancings. \"\n                             \"However, this worst case is unlikely, expect the server to consume \"\n                             \"the disk space equal to 2-4x of your GPU memory on average.\")\n\n    parser.add_argument('--device', type=str, default=None, required=False,\n                        help='all blocks will use this device in torch notation; default: cuda if available else cpu')\n    parser.add_argument(\"--torch_dtype\", type=str, choices=DTYPE_MAP.keys(), default=\"auto\",\n                        help=\"Use this dtype to store block weights and do computations. \"\n                             \"By default, respect the dtypes in the pre-trained state dict.\")\n    parser.add_argument('--max_alloc_timeout', type=float, default=600,\n                        help=\"If the cache is full, the server will wait for memory to be freed up to this many seconds\"\n                             \" before rejecting the request\")\n    parser.add_argument('--revision', type=str, default=None,\n                        help=\"The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models \"\n                             \"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.\")\n\n    parser.add_argument('--throughput',\n                        type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),\n                        default='auto',\n                        help='Expected server throughput (a float measured in RPS). '\n                             'If set to \"auto\" (default), the script evaluates network and compute throughput '\n                             'on the first run and uses these estimates for future runs. '\n                             'If set to \"eval\", the script re-evaluates the throughput and overrides the cache. '\n                             'If set to \"dry_run\", the script re-evaluates the throughput and exits.')\n    parser.add_argument('--update_period', type=float, required=False, default=120,\n                        help='Server will report blocks to DHT once in this many seconds')\n    parser.add_argument('--expiration', type=float, required=False, default=None,\n                        help='DHT entries will expire after this many seconds')\n    parser.add_argument('--request_timeout', type=float, required=False, default=3 * 60,\n                        help='Timeout (in seconds) for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')\n    parser.add_argument('--session_timeout', type=float, required=False, default=30 * 60,\n                        help='Timeout (in seconds) for the whole inference session')\n    parser.add_argument('--step_timeout', type=float, required=False, default=5 * 60,\n                        help=\"Timeout (in seconds) for waiting for the next step's inputs inside an inference session\")\n\n    group = parser.add_mutually_exclusive_group()\n    group.add_argument('--initial_peers', 
type=str, nargs='+', required=False, default=PUBLIC_INITIAL_PEERS,\n                       help='Multiaddrs of one or more DHT peers from the target swarm. Default: connects to the public swarm')\n    group.add_argument('--new_swarm', action='store_true',\n                       help='Start a new private swarm (i.e., do not connect to any initial peers)')\n\n    parser.add_argument('--increase_file_limit', type=int, default=4096,\n                        help='On *nix, increase the max number of files a server can open '\n                             'before hitting \"Too many open files\" (set to zero to keep the system limit)')\n    parser.add_argument('--stats_report_interval', type=int, required=False,\n                        help='Interval between two reports of batch processing performance statistics')\n\n    parser.add_argument('--custom_module_path', type=str, required=False,\n                        help='Path of a file with custom nn.modules, wrapped into special decorator')\n    parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')\n\n    parser.add_argument(\"--balance_quality\", type=float, default=0.75,\n                        help=\"Rebalance the swarm if its throughput is worse than this share of the optimal \"\n                             \"throughput. Use 0.0 to disable rebalancing, values > 1.0 to force rebalancing \"\n                             \"on each check for debugging purposes.\")\n    parser.add_argument(\"--mean_balance_check_period\", type=float, default=60,\n                        help=\"Check the swarm's balance every N seconds (and rebalance it if necessary)\")\n\n    parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],\n                        help=\"Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or \"\n                             \"4-bit (nf4 from the QLoRA paper) formats to save GPU memory. 
\"\n                             \"Default: 'int8' if GPU is available, 'none' otherwise\")\n    parser.add_argument(\"--tensor_parallel_devices\", nargs='+', default=None,\n                        help=\n                        \"Split each block between the specified GPUs such that each device holds a portion of every \"\n                        \"weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism\")\n\n    parser.add_argument(\"--skip_reachability_check\", action='store_true',\n                        help=\"Skip checking this server's reachability via health.petals.dev \"\n                             \"when connecting to the public swarm. If you connect to a private swarm, \"\n                             \"the check is skipped by default. Use this option only if you know what you are doing\")\n\n    parser.add_argument(\"--adapters\", nargs='*', default=(),\n                        help=\"List of pre-loaded LoRA adapters that can be used for inference or training\")\n\n    # fmt:on\n    args = vars(parser.parse_args())\n    args.pop(\"config\", None)\n\n    args[\"converted_model_name_or_path\"] = args.pop(\"model\") or args[\"converted_model_name_or_path\"]\n\n    host_maddrs = args.pop(\"host_maddrs\")\n    port = args.pop(\"port\")\n    if port is not None:\n        assert host_maddrs is None, \"You can't use --port and --host_maddrs at the same time\"\n    else:\n        port = 0\n    if host_maddrs is None:\n        host_maddrs = [f\"/ip4/0.0.0.0/tcp/{port}\", f\"/ip6/::/tcp/{port}\"]\n\n    announce_maddrs = args.pop(\"announce_maddrs\")\n    public_ip = args.pop(\"public_ip\")\n    if public_ip is not None:\n        assert announce_maddrs is None, \"You can't use --public_ip and --announce_maddrs at the same time\"\n        assert port != 0, \"Please specify a fixed non-zero --port when you use --public_ip (e.g., --port 31337)\"\n        announce_maddrs = [f\"/ip4/{public_ip}/tcp/{port}\"]\n\n    
args[\"startup_timeout\"] = args.pop(\"daemon_startup_timeout\")\n\n    file_limit = args.pop(\"increase_file_limit\")\n    if file_limit:\n        limits.logger.setLevel(logging.WARNING)\n        limits.increase_file_limit(file_limit, file_limit)\n\n    compression_type = args.pop(\"compression\").upper()\n    compression = getattr(CompressionType, compression_type)\n\n    max_disk_space = args.pop(\"max_disk_space\")\n    if max_disk_space is not None:\n        max_disk_space = parse_size(max_disk_space)\n    assert isinstance(\n        max_disk_space, (int, type(None))\n    ), \"Unrecognized value for --max_disk_space. Correct examples: 1.5GB or 1500MB or 1572864000 (bytes)\"\n\n    if args.pop(\"new_swarm\"):\n        args[\"initial_peers\"] = []\n\n    quant_type = args.pop(\"quant_type\")\n    if quant_type is not None:\n        args[\"quant_type\"] = QuantType[quant_type.upper()]\n\n    validate_version()\n\n    if not torch.backends.openmp.is_available():\n        # Necessary to prevent the server from freezing after forks\n        torch.set_num_threads(1)\n\n    server = Server(\n        **args,\n        host_maddrs=host_maddrs,\n        announce_maddrs=announce_maddrs,\n        compression=compression,\n        max_disk_space=max_disk_space,\n    )\n    try:\n        server.run()\n    except KeyboardInterrupt:\n        logger.info(\"Caught KeyboardInterrupt, shutting down\")\n    finally:\n        server.shutdown()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/petals/client/__init__.py",
    "content": "from petals.client.config import ClientConfig\nfrom petals.client.inference_session import InferenceSession\nfrom petals.client.remote_sequential import RemoteSequential\nfrom petals.client.routing import NoSpendingPolicy, RemoteSequenceManager, SpendingPolicyBase\n"
  },
  {
    "path": "src/petals/client/config.py",
    "content": "import dataclasses\nimport os\nfrom typing import Optional, Sequence, Union\n\nfrom hivemind import PeerID\n\nfrom petals.constants import PUBLIC_INITIAL_PEERS\n\n_max_retries = os.getenv(\"PETALS_MAX_RETRIES\")\nDEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None\n\n\n@dataclasses.dataclass\nclass ClientConfig:\n    initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS)  # a list of initial peers for hivemind DHT\n    dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)\n    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers\n\n    show_route: Union[str, bool] = \"inference\"  # show chosen route through servers. one of [False, \"inference\", True]\n    allowed_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, send requests only to these servers\n    blocked_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, do not use these servers\n    use_server_to_server: bool = True  # Use direct server-to-server communication\n\n    connect_timeout: float = 5  # timeout for opening a connection\n    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests\n    update_period: float = 60  # refresh DHT information once in this many seconds\n\n    max_retries: Optional[int] = DEFAULT_MAX_RETRIES  # max number of retries before an exception (default: inf)\n    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)\n    max_backoff: float = 60  # limit maximal sleep time between retries to this value\n    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds\n    active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)\n\n    max_pinged: int = 3  # max servers to ping from each sequence side, per update\n    
ping_timeout: float = 2  # max time to wait for pings, per update\n"
  },
  {
    "path": "src/petals/client/from_pretrained.py",
    "content": "import contextlib\nimport json\nimport os\nimport re\nimport tempfile\nfrom contextvars import ContextVar\nfrom typing import List, Optional, Tuple, Union\n\nfrom hivemind.utils.logging import get_logger\nfrom transformers import BloomPreTrainedModel, modeling_utils\n\nfrom petals.utils.version import get_compatible_model_repo\n\nlogger = get_logger(__name__)\n\n\nclass FromPretrainedMixin:\n    @classmethod\n    def from_pretrained(\n        cls,\n        model_name_or_path: Union[str, os.PathLike, None],\n        *args,\n        low_cpu_mem_usage: Optional[bool] = None,\n        **kwargs,\n    ):\n        model_name_or_path = get_compatible_model_repo(model_name_or_path)\n        if low_cpu_mem_usage is None:\n            low_cpu_mem_usage = True\n\n        with ignore_keys(cls._keys_to_ignore_on_load_unexpected):\n            return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)\n\n    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(\n        \"low_cpu_mem_usage(`bool`, *optional*)\",\n        \"low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)\",\n    ).replace(\n        \"torch_dtype (`str` or `torch.dtype`, *optional*)\",\n        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `\"auto\"` in Petals)',\n    )\n\n\n_ignored_keys = ContextVar(\"ignored_keys\", default=None)\n\n\n@contextlib.contextmanager\ndef ignore_keys(patterns: List[str]):\n    token = _ignored_keys.set(patterns)\n    try:\n        yield\n    finally:\n        _ignored_keys.reset(token)\n\n\ndef patched_get_checkpoint_shard_files(\n    pretrained_model_name_or_path, index_filename, *args, **kwargs\n) -> Tuple[List[str], dict]:\n    \"\"\"Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys.\"\"\"\n\n    should_ignore_keys = _ignored_keys.get() is not None\n    tempdir_ctx = tempfile.TemporaryDirectory() if 
should_ignore_keys else contextlib.nullcontext()\n    with tempdir_ctx as tempdir:\n        if should_ignore_keys:\n            with open(index_filename) as f:\n                index = json.load(f)\n            n_original_shards = len(set(index[\"weight_map\"].values()))\n\n            index[\"weight_map\"] = {\n                param_name: filename\n                for param_name, filename in index[\"weight_map\"].items()\n                if all(re.search(pattern, param_name) is None for pattern in _ignored_keys.get())\n            }\n            n_loaded_shards = len(set(index[\"weight_map\"].values()))\n            logger.debug(f\"Loading {n_loaded_shards} shards out of {n_original_shards}\")\n\n            # Replace the original index with a patched JSON, where ignored keys are removed\n            index_filename = os.path.join(tempdir, \"pytorch_model.bin.index.json\")\n            with open(index_filename, \"w\") as f:\n                json.dump(index, f)\n\n        return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)\n\n\noriginal_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files\nmodeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files\n"
  },
  {
    "path": "src/petals/client/inference_session.py",
    "content": "from __future__ import annotations\n\nimport asyncio\nimport itertools\nimport time\nimport uuid\nfrom typing import AsyncIterator, List, Optional, Tuple\n\nimport torch\nfrom hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor\nfrom hivemind.moe.client.remote_expert_worker import RemoteExpertWorker\nfrom hivemind.p2p import P2P\nfrom hivemind.proto import runtime_pb2\nfrom hivemind.utils.tensor_descr import BatchTensorDescriptor\n\nfrom petals.client.config import ClientConfig\nfrom petals.client.routing import RemoteSequenceManager, maybe_log_traceback\nfrom petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo\nfrom petals.server.handler import TransformerConnectionHandler\nfrom petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy\nfrom petals.utils.packaging import pack_args_kwargs\n\nlogger = get_logger(__name__)\n\n\nclass _ServerInferenceSession:\n    \"\"\"\n    An interface to a single multi-step *inference* session for a a set of blocks on a specific server.\n\n    :note: This class is *not* fault-tolerant out of the box.\n    \"\"\"\n\n    def __init__(\n        self,\n        config: ClientConfig,\n        span: RemoteSpanInfo,\n        uid: ModuleUID,\n        rpc_info: RPCInfo,\n        inputs_queue: asyncio.Queue,\n        outputs_aiter: AsyncIterator,\n        *,\n        max_length: int,\n        **metadata,\n    ):\n        self.config = config\n        self.span, self.uid, self.rpc_info = span, uid, rpc_info\n        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1\n        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue\n        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter\n        self.session_id = str(uuid.uuid4())\n        self.session_metadata = dict(max_length=max_length, **metadata)\n        self.stepped = False\n        self.closed = False\n\n        self._position = 0\n        
self.history = None  # Used in case of server failures to regenerate attention caches on new servers\n        self.next_session = None\n\n    @classmethod\n    async def create(\n        cls,\n        config: ClientConfig,\n        p2p: P2P,\n        span: RemoteSpanInfo,\n        uid: ModuleUID,\n        rpc_info: RPCInfo,\n        **metadata,\n    ) -> _ServerInferenceSession:\n        \"\"\"Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker\"\"\"\n        stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)\n        inputs_queue = asyncio.Queue()\n        outputs_stream = await asyncio.wait_for(\n            stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),\n            config.connect_timeout,\n        )\n        return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)\n\n    @staticmethod\n    async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:\n        while True:\n            next_input_message = await asyncio.wait_for(queue.get(), input_timeout)\n            yield next_input_message\n            if not next_input_message.uid and not next_input_message.tensors:\n                break  # this message means \"done sending\"\n\n    @property\n    def position(self):\n        return self._position\n\n    @position.setter\n    def position(self, start_from_position: int):\n        assert start_from_position <= self._position\n        self._position = start_from_position\n        if self.history is not None and self.history.shape[1] >= start_from_position:\n            self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None\n\n    def step(\n        self,\n        inputs: torch.Tensor,\n        prompts: torch.Tensor,\n        hypo_ids: torch.LongTensor,\n        *,\n        step_id: str,\n    ) -> torch.Tensor:\n        \"\"\"\n        Inference step: send a chunk of 
input tensors and receive a chunk of outputs\n        :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,\n          if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]\n        \"\"\"\n        if self.closed:\n            raise Exception(\"Session is closed, cannot perform step\")\n\n        n_input_tokens = inputs.shape[1]\n        if self.history is None:\n            self.history = inputs\n        elif self.history.shape[1] == self._position:\n            self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1)\n        assert self.history.shape[1] == self._position + n_input_tokens, (\n            f\"Broken input cache: span={self.span} shape={self.history.shape} \"\n            f\"position={self._position} n_input_tokens={n_input_tokens}\"\n        )\n\n        if not self.stepped:\n            inputs = self.history  # Pass full inputs including prefix\n        else:\n            inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further\n\n        # serialize inputs and put them into the queue\n        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)\n\n        request_metadata = dict(session_id=self.session_id, step_id=step_id)\n        if not self.stepped:\n            request_metadata.update(self.session_metadata)\n        if self._position is not None:\n            request_metadata[\"start_from_position\"] = self._position\n        elif self.config.use_server_to_server:\n            next_servers = self._collect_next_servers()\n            if next_servers:\n                request_metadata[\"next_servers\"] = next_servers\n\n        request_metadata[\"args_structure\"] = args_structure\n\n        # TODO: make possible to use different compression method for different tensors\n        server_side_inference_schema, kwargs_schema = self.rpc_info[\"inference_schema\"]\n        compression = server_side_inference_schema[0].compression\n    
    inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)\n\n        # TODO: create more explicit way to check servers schema and client's structure\n        assert len(input_tensors) >= len(\n            server_side_inference_schema\n        ), \"Hidden_state, prompts and hypo_ids tensors are necessary for an inference step\"\n\n        outputs_serialized = RemoteExpertWorker.run_coroutine(\n            self._step(\n                runtime_pb2.ExpertRequest(\n                    uid=self.uid,\n                    tensors=[\n                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)\n                        for tensor, proto in zip(input_tensors, inference_schema)\n                    ],\n                    metadata=MSGPackSerializer.dumps(request_metadata),\n                )\n            )\n        )\n        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))\n        assert (\n            outputs[0].shape == inputs.shape\n        ), f\"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}\"\n\n        self._position += n_input_tokens\n\n        return outputs[0]\n\n    def _collect_next_servers(self) -> List[Tuple[str, str, int, int]]:\n        next_servers = []\n        session = self.next_session\n        while session is not None and session.stepped:\n            next_servers.append(\n                (session.span.peer_id.to_base58(), session.session_id, session.span.start, session.span.end)\n            )\n            session = session.next_session\n        return next_servers\n\n    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:\n        \"\"\"Inference step on serialized data. 
This code is meant to be run inside RemoteExpertWorker\"\"\"\n        await self._inputs_queue.put(inputs_serialized)\n        self.stepped = True\n        return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout)\n\n    def close(self):\n        \"\"\"Finish a given inference session, close the underlying connection\"\"\"\n        if self._outputs_stream is None:\n            return  # already closed\n        RemoteExpertWorker.run_coroutine(self._aclose_stream())\n        self._outputs_stream = self._inputs_queue = None\n        self.closed = True\n\n    async def _aclose_stream(self):\n        \"\"\"Close the inference session. This code is meant to be run inside RemoteExpertWorker\"\"\"\n        if self._outputs_stream is None:\n            return  # already closed\n        if self.stepped:\n            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session\n            try:\n                await anext(self._outputs_stream)\n            except StopAsyncIteration:\n                pass\n\n    def __del__(self):\n        self.close()\n\n    def __enter__(self):\n        assert not self.closed\n        return self\n\n    def __exit__(self, *exc_details):\n        self.close()\n\n\nclass InferenceSession:\n    \"\"\"\n    An interface to a multi-step *inference* session for a sequence of remote transformer blocks\n    \"\"\"\n\n    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):\n        self._sequence_manager = sequence_manager\n        self._closed = False\n        self._server_sessions = []\n        self._position = 0\n        self._max_length = max_length\n        self.output_ids = None\n        self.past_key_values = None\n\n    @property\n    def num_blocks(self) -> int:\n        return len(self._sequence_manager)\n\n    @property\n    def position(self) -> int:\n        return self._position\n\n    @position.setter\n    def position(self, 
start_from_position: int) -> None:\n        self._position = start_from_position\n        for session in self._server_sessions:\n            assert isinstance(session, _ServerInferenceSession)\n            session.position = start_from_position\n\n    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:\n        server_sessions = []\n        try:\n            for span in chosen_spans:\n                span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])\n                metadata = self._sequence_manager.get_request_metadata(\"rpc_inference\", span_uids, peer_id=span.peer_id)\n                session = RemoteExpertWorker.run_coroutine(\n                    _ServerInferenceSession.create(\n                        self._sequence_manager.config,\n                        self._sequence_manager.state.p2p,\n                        span,\n                        span_uids,\n                        rpc_info=self._sequence_manager.rpc_info,\n                        max_length=self._max_length,\n                        **metadata,\n                    )\n                )\n                server_sessions.append(session)\n                session.__enter__()\n            return server_sessions\n        except:\n            self._exit_server_sessions(server_sessions)\n            raise\n\n    def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None:\n        for session in reversed(server_sessions):\n            try:\n                session.__exit__(None, None, None)\n            except Exception:\n                logger.debug(\"Caught exception while closing connection to server:\", exc_info=True)\n\n    def __enter__(self) -> \"InferenceSession\":\n        assert not self._closed and not self._server_sessions\n        return self\n\n    def step(\n        self,\n        inputs: torch.Tensor,\n        prompts: Optional[torch.Tensor] = None,\n        hypo_ids: 
Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        assert not self._closed\n        if torch.is_grad_enabled():\n            logger.warning(\"Running inference session with grad enabled. Gradients will *not* be propagated correctly.\")\n\n        if prompts is None or is_dummy(prompts):\n            prompts = DUMMY\n        else:\n            assert prompts.ndim == 4, \"deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]\"\n            assert prompts.shape[0] == self.num_blocks\n            assert prompts.shape[1] in (inputs.shape[0], 1)\n            assert prompts.shape[2] <= inputs.shape[1]\n            assert prompts.shape[3] == inputs.shape[2]\n\n        if hypo_ids is None or is_dummy(hypo_ids):\n            hypo_ids = DUMMY_INT64\n        else:\n            assert len(hypo_ids) == len(inputs)\n            assert hypo_ids.dtype == torch.int64\n\n        inputs_device = inputs.device\n        inputs_dtype = inputs.dtype\n        inputs = inputs.cpu()\n        prompts = prompts.cpu()\n        hypo_ids = hypo_ids.cpu()\n        step_id = str(uuid.uuid4())\n\n        n_input_tokens = inputs.shape[1]\n        if self._position + n_input_tokens > self._max_length:\n            raise ValueError(\n                f\"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}\"\n            )\n\n        server_idx = 0\n        block_idx = 0\n        while block_idx < self.num_blocks:\n            for attempt_no in itertools.count():\n                logger.debug(f\"Inference: block {block_idx}, attempt {attempt_no}\")\n                server_session = None\n                try:\n                    if not self._server_sessions or attempt_no >= 1:\n                        self._update_sequence(server_idx, block_idx, attempt_no)\n\n                    server_session = self._server_sessions[server_idx]\n                    assert server_session.position == 
self.position, f\"{server_session.position} and {self.position}\"\n                    inputs = server_session.step(\n                        inputs,\n                        prompts[server_session.span.start : server_session.span.end],\n                        hypo_ids,\n                        step_id=step_id,\n                    )\n\n                    server_idx += 1\n                    block_idx = server_session.span.end\n                    self._sequence_manager.on_request_success(server_session.span.peer_id)\n                    break\n                except Exception as e:\n                    self._sequence_manager.on_request_failure(\n                        server_session.span.peer_id if server_session is not None else None\n                    )\n                    if attempt_no + 1 == self._sequence_manager.config.max_retries:\n                        raise\n                    delay = self._sequence_manager.get_retry_delay(attempt_no)\n                    logger.warning(\n                        f\"Caught exception when running inference via {server_session.span if server_session is not None else None} \"\n                        f\"(retry in {delay:.0f} sec): {repr(e)}\"\n                    )\n                    maybe_log_traceback(e)\n                    time.sleep(delay)\n\n        self._position += n_input_tokens\n        outputs = inputs[:, -n_input_tokens:]\n        outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)\n        return outputs\n\n    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:\n        # If there is a failed server session, this code closes it\n        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])\n\n        n_prev_spans = len(self._server_sessions)\n        update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks\n        if attempt_no >= 1:\n            logger.debug(\n                f\"Due 
to a server failure, remote attention caches \"\n                f\"from block {block_idx} to {update_end} will be regenerated\"\n            )\n\n        updated_spans = self._sequence_manager.make_sequence(\n            block_idx, update_end, mode=\"min_latency\", cache_tokens_needed=self._max_length\n        )\n        # make_sequence() could return a longer sequence\n        updated_spans[-1].end = min(updated_spans[-1].end, update_end)\n        updated_sessions = self._enter_server_sessions(updated_spans)\n        logger.debug(f\"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers\")\n\n        # If there is a failed span, this code replaces it, otherwise it just adds new ones\n        if server_idx < n_prev_spans:\n            updated_sessions[0].history = self._server_sessions[server_idx].history\n        self._server_sessions[server_idx : server_idx + 1] = updated_sessions\n\n        # Update links to the next server session for direct server-to-server communication via rpc_push()\n        for i in range(max(server_idx - 1, 0), min(server_idx + len(updated_spans), len(self._server_sessions) - 1)):\n            self._server_sessions[i].next_session = self._server_sessions[i + 1]\n\n    def close(self, *exc_details):\n        \"\"\"Finish a given inference session, close the underlying connection\"\"\"\n        if not self._closed:\n            self._exit_server_sessions(self._server_sessions)\n            self._server_sessions.clear()\n            self._closed = True\n\n    def __exit__(self, *exc_details):\n        self.close(*exc_details)\n\n    def __del__(self):\n        self.close()\n\n    @property\n    def last_token_id(self) -> Optional[torch.Tensor]:  # Backward compatibility with Petals < 2.1.0\n        return self.output_ids[:, -1:] if self.output_ids is not None else None\n\n    @last_token_id.setter\n    def last_token_id(self, value: torch.Tensor):  # Backward compatibility with Petals < 2.1.0\n        if 
self.output_ids is None:\n            raise RuntimeError(\"Can't override `last_token_id` since the session has not stepped yet\")\n        self.output_ids[:, -1:] = value\n"
  },
  {
    "path": "src/petals/client/lm_head.py",
    "content": "import dataclasses\nimport platform\nfrom typing import Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom hivemind import get_logger\nfrom torch import nn\nfrom transformers import PretrainedConfig\n\nlogger = get_logger(__name__)\n\n\n@dataclasses.dataclass\nclass LMHeadConfig:\n    # This settings matter for running the client with dtype bfloat16 on CPU.\n    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.\n    use_chunked_forward: Union[str, bool] = \"auto\"\n    chunked_forward_step: int = 16384\n\n\nclass LMHead(nn.Module):\n    def __init__(self, config: PretrainedConfig):\n        super().__init__()\n\n        if not config.tie_word_embeddings:\n            self.weight = nn.Parameter(torch.zeros(config.vocab_size, config.hidden_size))\n            self.weight.requires_grad = False\n        else:\n            self.weight = None  # Will be set to get_input_embeddings().weight during loading the model\n        self.bias = None\n        self.in_features = config.hidden_size  # Similar to nn.Linear attributes\n        self.out_features = config.vocab_size\n\n        self.use_chunked_forward = config.use_chunked_forward\n        if self.use_chunked_forward == \"auto\":\n            if platform.machine() == \"x86_64\":\n                # Import of cpufeature may crash on non-x86_64 machines\n                from cpufeature import CPUFeature\n\n                # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().\n                # Otherwise, it's ~8x slower.\n                self.use_chunked_forward = not (CPUFeature[\"AVX512f\"] and CPUFeature[\"OS_AVX512\"])\n            else:\n                self.use_chunked_forward = True\n        self.chunked_forward_step = config.chunked_forward_step\n        self._bf16_warning_shown = False\n\n    def forward(self, hidden_states):\n        if (\n            self.weight.dtype in [torch.float16, 
torch.bfloat16]\n            and self.weight.device.type == \"cpu\"\n            and self.use_chunked_forward\n        ):\n            lm_logits = self.chunked_forward(hidden_states)\n        else:\n            # Switch dtype in case word_embeddings are fp16/bf16\n            hidden_states = hidden_states.to(self.weight.dtype)\n            lm_logits = F.linear(hidden_states, self.weight)\n        return lm_logits\n\n    def chunked_forward(self, hidden_states):\n        \"\"\"Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.\n        chunked_forward_step: provides trade-off between efficiency and extra memory consumption.\n        \"\"\"\n        assert self.chunked_forward_step > 0, \"Chunk size for chunked forward must be positive\"\n\n        if not self._bf16_warning_shown:\n            logger.warning(\n                \"Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. \"\n                \"To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)\"\n            )\n            self._bf16_warning_shown = True\n\n        hidden_states = hidden_states.float()\n        output = torch.empty(*hidden_states.shape[:-1], self.out_features)\n\n        for i in range(0, self.out_features, self.chunked_forward_step):\n            chunk = self.weight[i : i + self.chunked_forward_step].float()\n            output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)\n        return output\n"
  },
  {
    "path": "src/petals/client/ptune.py",
    "content": "import dataclasses\nfrom contextlib import contextmanager\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nfrom hivemind import get_logger\nfrom transformers import PretrainedConfig\n\nfrom petals.utils.misc import DUMMY\n\nlogger = get_logger(__name__)\n\n\n@dataclasses.dataclass\nclass PTuneConfig:\n    pre_seq_len: int = 0  # a number of tokens for prompt tuning.\n    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, \"ptune\", \"deep_ptune\"]\n\n\nclass PTuneMixin:\n    _keys_to_ignore_on_load_missing = [r\"(intermediate_)?prompt_embeddings\\.weight$\"]\n\n    def init_prompts(self, config: PretrainedConfig) -> None:\n        if config.tuning_mode and \"ptune\" in config.tuning_mode:\n            assert config.pre_seq_len > 0, \"The number of prefix tokens must be > 0\"\n            self.pre_seq_len = config.pre_seq_len\n            self.prefix_tokens = torch.arange(self.pre_seq_len).long()\n\n            with force_non_empty_weights():\n                # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality\n                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)\n                if config.tuning_mode == \"deep_ptune\":\n                    self.intermediate_prompt_embeddings = nn.Embedding(\n                        self.pre_seq_len,\n                        config.num_hidden_layers * config.hidden_size,\n                        # ^-- TODO: should be num_hidden_layers - 1\n                        dtype=torch.float32,\n                    )\n        elif config.tuning_mode:\n            raise NotImplementedError(f\"{self.tuning_mode} mode is not supported for now\")\n\n    def get_prompt(self, batch_size):\n        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)\n        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)\n        prompts = 
self.prompt_embeddings(prefix_tokens)\n\n        if self.config.tuning_mode == \"deep_ptune\":\n            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)\n            intermediate_prompts = intermediate_prompts.view(\n                batch_size,\n                self.pre_seq_len,\n                self.config.num_hidden_layers,\n                self.config.hidden_size\n                # TODO: should be num_hidden_layers - 1\n            )\n            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])\n        else:\n            intermediate_prompts = DUMMY\n\n        dtype = self.word_embeddings.weight.dtype\n        return prompts.to(dtype), intermediate_prompts.to(dtype)\n\n\n_original_register_parameter = nn.Module.register_parameter\n\n\n@contextmanager\ndef force_non_empty_weights():\n    \"\"\"\n    This context manager allows to bypass the accelerate.init_empty_weights() context manager\n    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.\n    The transformers library should replace all meta tensors by empty tensors by itself\n    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).\n\n    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515\n    \"\"\"\n\n    possibly_patched_register_parameter = nn.Module.register_parameter\n    nn.Module.register_parameter = _original_register_parameter\n    try:\n        yield\n    finally:\n        nn.Module.register_parameter = possibly_patched_register_parameter\n"
  },
  {
    "path": "src/petals/client/remote_forward_backward.py",
    "content": "\"\"\"\nUtility functions that call RPC forward or backward on a single remote server\n\"\"\"\nimport asyncio\nfrom typing import Iterable, List, Optional, Sequence, Tuple\n\nimport torch\nfrom hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor\nfrom hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor\nfrom hivemind.p2p import StubBase\nfrom hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE\nfrom hivemind.proto import runtime_pb2\nfrom hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter\nfrom hivemind.utils.streaming import split_for_streaming\nfrom hivemind.utils.tensor_descr import BatchTensorDescriptor\n\nfrom petals.client.config import ClientConfig\nfrom petals.data_structures import ModuleUID, RPCInfo\n\n\nasync def _forward_unary(\n    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs\n) -> List[torch.Tensor]:\n    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(\n        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),\n        timeout=config.request_timeout,\n    )\n    return [deserialize_torch_tensor(t) for t in outputs.tensors]\n\n\nasync def _backward_unary(\n    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs\n) -> List[torch.Tensor]:\n    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(\n        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),\n        timeout=config.request_timeout,\n    )\n    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]\n\n\nasync def _forward_stream(\n    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs\n) -> List[torch.Tensor]:\n    parts = (\n        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)\n        for tensor 
in serialized_tensors\n        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)\n    )\n    outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), config.connect_timeout)\n    outputs = aiter_with_timeout(outputs, config.request_timeout)\n    return await deserialize_tensor_stream(msg.tensors async for msg in outputs)\n\n\nasync def _backward_stream(\n    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs\n) -> List[torch.Tensor]:\n    parts = (\n        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)\n        for tensor in serialized_tensors\n        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)\n    )\n    grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), config.connect_timeout)\n    grad_inputs = aiter_with_timeout(grad_inputs, config.request_timeout)\n    return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs)\n\n\nasync def run_remote_forward(\n    uid: ModuleUID,\n    stub: StubBase,\n    rpc_info: RPCInfo,\n    *inputs: torch.Tensor,\n    config: ClientConfig,\n    metadata: Optional[bytes] = None,\n    **kwargs,\n) -> Tuple[torch.Tensor, ...]:\n    \"\"\"\n    Serializes input tensors and calls \"rpc_forward\" on a remote server.\n    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198\n    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.\n    \"\"\"\n\n    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']\n    # detach to avoid pickling the computation graph\n    assert len(kwargs) == len(rpc_info[\"keyword_names\"]), f\"Keyword args should be {rpc_info['keyword_names']}\"\n    kwargs = {key: kwargs[key] for key in rpc_info[\"keyword_names\"]}\n\n    # Note: we put keyword arguments in the same order as on a server to 
prevent f(a=1, b=2) != f(b=2, a=1) errors\n    forward_inputs = tuple(nested_flatten((inputs, kwargs)))\n    args_schema, kwargs_schema = rpc_info[\"forward_schema\"]\n    compression = args_schema[0].compression\n    forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs)\n    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)\n    # TODO: create more explicit way to check servers schema and client's structure\n    assert len(inputs) >= len(args_schema) + 1, \"Inputs and prompt tensors are necessary for a forward step\"\n\n    # Asynchronous serialization\n    loop = asyncio.get_running_loop()\n    serialized_tensors = await asyncio.gather(\n        *(\n            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)\n            for tensor, proto in zip(inputs, forward_schema)\n        )\n    )\n\n    # call RPC on remote server\n    size = sum(t.element_size() * t.nelement() for t in inputs)\n    forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary\n    # Hotfix: we use \"// 2\" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space\n    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)\n    return nested_pack(deserialized_outputs, structure=rpc_info[\"outputs_schema\"])\n\n\nasync def run_remote_backward(\n    uid: ModuleUID,\n    stub: StubBase,\n    rpc_info: RPCInfo,\n    *inputs_and_grad_outputs: torch.Tensor,\n    config: ClientConfig,\n    metadata: Optional[bytes] = None,\n    **kwargs,\n) -> Sequence[torch.Tensor]:\n    \"\"\"\n    Serializes grad outputs and calls \"rpc_backward\" on a remote server.\n    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221\n    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock 
here.\n    \"\"\"\n    args_schema, kwargs_schema = rpc_info[\"forward_schema\"]\n    outputs_schema = rpc_info[\"outputs_schema\"]\n    compression = args_schema[0].compression\n    backward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in inputs_and_grad_outputs)\n    # TODO: create more explicit way to check servers schema and client's structure\n    assert (\n        len(inputs_and_grad_outputs) >= len(args_schema) + len(outputs_schema) + 1\n    ), \"Inputs, grad_outputs and prompt tensors are necessary for a backward step\"\n\n    # Asynchronous serialization\n    loop = asyncio.get_running_loop()\n    serialized_tensors = await asyncio.gather(\n        *(\n            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)\n            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)\n        )\n    )\n\n    size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)\n    backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary\n    # Hotfix: we use \"// 2\" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space\n    deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)\n    return deserialized_grad_inputs\n"
  },
  {
    "path": "src/petals/client/remote_generation.py",
    "content": "import contextlib\nimport dataclasses\nfrom contextvars import ContextVar\nfrom typing import Any, ContextManager, Dict, List, Optional, Tuple\n\nimport torch\nimport transformers\nfrom hivemind.utils.logging import get_logger\nfrom torch import Tensor\nfrom transformers.cache_utils import Cache, DynamicCache\nfrom transformers.generation.utils import ModelOutput\n\nfrom petals.client.inference_session import InferenceSession\nfrom petals.client.remote_sequential import RemoteSequential\nfrom petals.utils.misc import DUMMY, docstring_from\n\nlogger = get_logger(__name__)\n\n\nclass RemotePastKeyValues(Cache):\n    \"\"\"only keeps the number of seen tokens. pretends to be a legit cache\"\"\"\n\n    def __init__(self) -> None:\n        super().__init__()\n        self._seen_tokens = 0\n        self.hypo_ids: Optional[torch.LongTensor] = None\n\n    def __getitem__(self, _index: int) -> List[torch.Tensor]:\n        return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()\n\n    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:\n        return self._seen_tokens\n\n    def get_max_length(self) -> Optional[int]:\n        return None\n\n    def update_seen(self, new_seen: int) -> None:\n        self._seen_tokens += new_seen\n\n    def reorder_cache(self, beam_idx):\n        raise NotImplementedError(\"Beam search reordering is not implemented yet\")\n\n\n_skipped_tokens = ContextVar(\"skipped_tokens\", default=0)\n\n\nclass _SkipTokensMixin:\n    # This override is used in RemoteGenerationMixin by has to be defined in a class not named as \"GenerationMixin\"\n    # due to how transformers.PreTrainedModel.can_generate() works\n    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:\n        input_ids = input_ids[:, _skipped_tokens.get() :]\n        _skipped_tokens.set(0)\n        return super().prepare_inputs_for_generation(input_ids, **kwargs)\n\n\nclass 
RemoteGenerationMixin(_SkipTokensMixin):\n    \"\"\"\n    This class is an upgrade to `transformers.GenerationMixin` that:\n\n    - Designed to be compatible with most `transformers.GenerationMixin` strategies and options\n    - Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and\n      you don't have to rerun the prefix through all the servers to generate each new token\n    - Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation\n      by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or\n      accept prompts from a user in a chat bot (multiple calls like `.generate(new_prompts, ...)`).\n    - If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`.\n      Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.\n    \"\"\"\n\n    @docstring_from(RemoteSequential.active_session)\n    @property\n    def active_session(self) -> Optional[InferenceSession]:\n        return self.transformer.h.active_session\n\n    @docstring_from(RemoteSequential.use_session)\n    def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:\n        return self.transformer.h.use_session(session)\n\n    @docstring_from(RemoteSequential.inference_session)\n    def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:\n        return self.transformer.h.inference_session(**kwargs)\n\n    @docstring_from(transformers.GenerationMixin.generate.__doc__)\n    def generate(\n        self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs\n    ):\n        self._fix_generate_kwargs(kwargs)\n        if inputs is None:\n            inputs = kwargs.pop(\"input_ids\", None)\n\n        if session is not None:\n            # If a session specified 
explicitly, use it\n            context_manager = self.use_session(session)\n        elif self.active_session is not None:\n            # If there's an active session, don't do anything\n            context_manager = contextlib.nullcontext(self.active_session)\n        else:\n            # If there's no active session, create a new one\n\n            max_length = kwargs.get(\"max_length\")\n            max_new_tokens = kwargs.get(\"max_new_tokens\")\n            assert (max_length is None) != (\n                max_new_tokens is None\n            ), \"You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches\"\n\n            session_max_length = self.transformer.config.pre_seq_len\n            if max_length is not None:\n                session_max_length += max_length\n            else:\n                session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens\n            context_manager = self.inference_session(max_length=session_max_length)\n\n        with context_manager as session:\n            # Prepend the tokens from the previous .generate() call\n            n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0\n            if n_prev_tokens > 0:\n                if kwargs.get(\"num_beams\", 1) > 1:\n                    logger.warning(\n                        \"Beam search will not work properly in the resumed petals.InferenceSession \"\n                        \"since intermediate beam entries are lost\"\n                    )\n\n                if inputs is not None:\n                    inputs = torch.cat([session.output_ids, inputs], dim=1)\n                else:\n                    inputs = session.output_ids\n\n                # Don't actually run all previous tokens through the transformer,\n                # but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)\n                _skipped_tokens.set(max(0, 
n_prev_tokens - 1))\n\n            if self._supports_cache_class and \"past_key_values\" not in kwargs:\n                past_key_values = RemotePastKeyValues()\n                past_key_values.update_seen(session.position)\n                kwargs[\"past_key_values\"] = past_key_values\n\n            result = super().generate(inputs, *args, **kwargs)\n\n            sequences = result.sequences if isinstance(result, ModelOutput) else result\n            # Save tokens from this .generate() call\n            session.output_ids = sequences\n            # Crop the last tokens from the previous call\n            sequences = sequences[:, n_prev_tokens:].clone()\n            if isinstance(result, ModelOutput):\n                result.sequences = sequences\n            else:\n                result = sequences\n\n        return result\n\n    @staticmethod\n    def _fix_generate_kwargs(kwargs: dict):\n        # Suppress inappropriate \"Both max_new_tokens and max_length\" HF warning\n        if \"max_length\" in kwargs and kwargs[\"max_length\"] is None:\n            del kwargs[\"max_length\"]\n\n        # Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0\n        do_sample = kwargs.get(\"do_sample\")\n        if isinstance(do_sample, int):\n            kwargs[\"do_sample\"] = bool(do_sample)\n\n    @staticmethod\n    def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:\n        return dataclasses.replace(past_key_values, hypo_ids=beam_idx)\n"
  },
  {
    "path": "src/petals/client/remote_sequential.py",
    "content": "from __future__ import annotations\n\nfrom contextlib import contextmanager\nfrom contextvars import ContextVar\nfrom typing import Optional, Union\n\nimport torch\nfrom hivemind import DHT, get_logger\nfrom torch import nn\n\nfrom petals.client.config import ClientConfig\nfrom petals.client.inference_session import InferenceSession\nfrom petals.client.routing import RemoteSequenceManager\nfrom petals.client.sequential_autograd import _RemoteSequentialAutogradFunction\nfrom petals.data_structures import UID_DELIMITER\n\nlogger = get_logger(__name__)\n\n\nclass RemoteSequential(nn.Module):\n    \"\"\"\n    A sequence of transformer blocks hosted by the swarm.\n    \"\"\"\n\n    def __init__(\n        self,\n        config: ClientConfig,\n        *,\n        sequence_manager: Optional[RemoteSequenceManager] = None,\n        dht: Optional[DHT] = None,\n        start_block: Optional[int] = None,\n        end_block: Optional[int] = None,\n        **kwargs,\n    ):\n        super().__init__()\n        self.config = config\n\n        assert sequence_manager is None or (\n            dht is None and start_block is None and end_block is None\n        ), \"`dht`, `start_block`, and `end_block` have no effect when you provide a custom `sequence_manager`\"\n        if sequence_manager is None:\n            if start_block is None:\n                start_block = 0\n            if end_block is None:\n                end_block = self.config.num_hidden_layers\n            block_uids = tuple(f\"{config.dht_prefix}{UID_DELIMITER}{i}\" for i in range(start_block, end_block))\n            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)\n        self.sequence_manager = sequence_manager\n\n        self._active_session = ContextVar(\"active_session\", default=None)\n\n    def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:\n        assert inputs.ndim == 3, \"inputs must be a tensor of 
shape [batch_size, seq_length, hidden_size]\"\n        if self.active_session is None:\n            assert all(v is None for v in kwargs.values()), f\"Extra kwargs are not supported in forward: {kwargs}\"\n            return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)\n        else:\n            return self.active_session.step(inputs, prompts, **kwargs)\n\n    @property\n    def active_session(self) -> Optional[InferenceSession]:\n        \"\"\"\n        If called inside `with model.inference_session(...):` or `with model.use_session(...):`,\n        returns an active InferenceSession. Otherwise, returns None.\n        \"\"\"\n\n        return self._active_session.get()\n\n    @property\n    def position(self) -> int:\n        \"\"\"Returns the prefix length (in tokens) in the active inference session or zero if no session is active.\"\"\"\n\n        return self.active_session.position if self.active_session is not None else 0\n\n    @contextmanager\n    def use_session(self, session: Optional[InferenceSession]) -> InferenceSession:\n        \"\"\"Inside this context, forward() will use an _existing_ InferenceSession provided as the argument.\"\"\"\n\n        token = self._active_session.set(session)\n        try:\n            yield session\n        finally:\n            self._active_session.reset(token)\n\n    @contextmanager\n    def inference_session(self, **kwargs) -> InferenceSession:\n        \"\"\"\n        Inside this context, forward() will use a _new_ InferenceSession created with given parameters.\n\n        :param max_length: Maximal expected length of inference results. 
Servers use this parameter\n                           to calculate the size of attention caches allocated to this client.\n        \"\"\"\n\n        with InferenceSession(self.sequence_manager, **kwargs) as session, self.use_session(session):\n            yield session\n\n    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:\n        return RemoteSequential(\n            self.config,\n            sequence_manager=self.sequence_manager[ix],\n        )\n\n    def __iter__(self):\n        for block_index in range(len(self)):\n            yield self[block_index]\n\n    def __len__(self):\n        return len(self.sequence_manager)\n\n    def extra_repr(self) -> str:\n        return f\"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}\"\n"
  },
  {
    "path": "src/petals/client/routing/__init__.py",
    "content": "from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback\nfrom petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase\n"
  },
  {
    "path": "src/petals/client/routing/sequence_info.py",
    "content": "import dataclasses\nimport time\nfrom typing import Iterable, List, Optional, Tuple\n\nfrom hivemind import get_logger\n\nfrom petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState\nfrom petals.utils.dht import compute_spans\n\nlogger = get_logger(__name__)\n\n\n@dataclasses.dataclass\nclass RemoteSequenceInfo:\n    \"\"\"\n    A dataclass that stores general information about which servers hold any given layer;\n    - updated by RemoteSequenceManager in a background thread\n    - accessed by routing strategies in .on_update\n    :note: this class should *not* be modified by RoutingStrategy.on_update to avoid interference between strategies;\n     Any metadata specific to one routing strategy, it should be stored inside that strategy. Any information that\n     is used by most routing strategies should be moved from said strategies to this class.\n    \"\"\"\n\n    block_uids: Tuple[ModuleUID, ...]\n    block_infos: Tuple[RemoteModuleInfo, ...]  
# note: the contents of RemoteModuleInfo can and will be updated\n    spans_by_priority: List[RemoteSpanInfo]\n    spans_containing_block: Tuple[List[RemoteSpanInfo], ...]\n    last_updated_time: Optional[float]\n\n    @classmethod\n    def make_empty(cls, block_uids: Iterable[ModuleUID]) -> \"RemoteSequenceInfo\":\n        block_uids = tuple(block_uids)\n        empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)\n        empty_spans = tuple([] for _ in range(len(block_uids)))\n        return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=None)\n\n    def __getitem__(self, ix: slice):\n        assert isinstance(ix, slice)\n        block_uids, block_infos = self.block_uids[ix], self.block_infos[ix]\n        spans_by_priority, spans_containing_block = self._sort_spans(block_infos)\n        return RemoteSequenceInfo(\n            block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time\n        )\n\n    def __len__(self):\n        return len(self.block_uids)\n\n    def update_(self, new_block_infos: List[RemoteModuleInfo]):\n        assert len(new_block_infos) == len(self.block_uids)\n        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):\n            assert uid == info.uid, f\"The DHT entry for {uid} actually points to {info.uid}\"\n            self.block_infos[block_index].servers = info.servers\n\n        self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos)\n        self.last_updated_time = time.perf_counter()\n\n    @staticmethod\n    def _sort_spans(block_infos: List[RemoteModuleInfo]):\n        spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values())\n        spans_by_priority.sort(key=lambda span: span.length, reverse=True)\n\n        spans_containing_block = tuple([] for _ in range(len(block_infos)))\n        for span in spans_by_priority:\n            for block_index in 
range(span.start, span.end):\n                spans_containing_block[block_index].append(span)\n\n        return spans_by_priority, spans_containing_block\n"
  },
  {
    "path": "src/petals/client/routing/sequence_manager.py",
    "content": "from __future__ import annotations\n\nimport asyncio\nimport dataclasses\nimport itertools\nimport logging\nimport random\nimport threading\nimport time\nimport warnings\nfrom typing import Any, Dict, List, Optional, Sequence, Set, Union\nfrom weakref import WeakMethod\n\nimport dijkstar\nimport numpy as np\nfrom hivemind import DHT, P2P, MSGPackSerializer, PeerID\nfrom hivemind.dht.node import Blacklist\nfrom hivemind.moe.client.remote_expert_worker import RemoteExpertWorker\nfrom hivemind.proto import runtime_pb2\nfrom hivemind.utils.logging import get_logger\n\nfrom petals.client.config import ClientConfig\nfrom petals.client.routing.sequence_info import RemoteSequenceInfo\nfrom petals.client.routing.spending_policy import NoSpendingPolicy\nfrom petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState\nfrom petals.server.handler import TransformerConnectionHandler\nfrom petals.utils.dht import get_remote_module_infos\nfrom petals.utils.ping import PingAggregator\nfrom petals.utils.random import sample_up_to\n\nlogger = get_logger(__name__)\n\n\nclass SequenceManagerConfig(ClientConfig):\n    def __init__(self, *args, **kwargs):\n        warnings.warn(\n            \"petals.client.routing.SequenceManagerConfig has been moved to petals.ClientConfig. 
\"\n            \"This alias will be removed in Petals 2.2.0+\",\n            DeprecationWarning,\n            stacklevel=2,\n        )\n        super().__init__(*args, **kwargs)\n\n\n@dataclasses.dataclass\nclass SequenceManagerState:\n    p2p: P2P = None\n    sequence_info: Optional[RemoteSequenceInfo] = None\n    rpc_info: Optional[dict] = None\n    banned_peers: Optional[Blacklist] = None\n\n    def __getitem__(self, ix: Union[int, slice]) -> SequenceManagerState:\n        return dataclasses.replace(self, sequence_info=self.sequence_info[ix])\n\n    def __len__(self) -> int:\n        return len(self.sequence_info)\n\n\nclass RemoteSequenceManager:\n    \"\"\"\n    Sequence manager is a thread that keeps track of remote servers that hold the specified sequence of blocks.\n    TL;DR it tells you, which peers you should ask to get a specific layer. It is used in RemoteSequential.\n    When created, RemoteSequenceManager looks up which servers serve necessary layers by reading from DHT.\n    Using this information, sequence manager can form sequences of servers that collectively have the full sequence.\n    To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr).\n\n    :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. 
It is recommended to avoid\n      running redundant sequence managers for the same set of layers.\n    \"\"\"\n\n    def __init__(\n        self,\n        config: ClientConfig,\n        block_uids: Sequence[ModuleUID],\n        *,\n        dht: Optional[DHT] = None,\n        state: Optional[SequenceManagerState] = None,\n    ):\n        assert config.initial_peers or dht is not None, \"Please specify `config.initial_peers` or `dht`\"\n        assert config.dht_prefix, \"Could not find dht_prefix in config, please create model with dht_prefix=...\"\n        assert len(block_uids) > 0, \"Sequences must contain at least one block\"\n\n        self.config = config\n        if state is None:\n            state = SequenceManagerState()\n        self.state = state\n\n        if dht is None:\n            dht = DHT(\n                initial_peers=config.initial_peers,\n                client_mode=True,\n                num_workers=32,\n                startup_timeout=config.daemon_startup_timeout,\n                start=True,\n            )\n        assert isinstance(dht, DHT) and dht.is_alive(), \"`dht` must be a running hivemind.DHT instance\"\n        self.dht = dht\n\n        if state.p2p is None:\n            state.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())\n\n        self.lock_changes = threading.Lock()\n        self._thread = _SequenceManagerUpdateThread(config.update_period, WeakMethod(self._update))\n        self._thread_start_lock = threading.Lock()\n        self.policy = NoSpendingPolicy()\n\n        self.allowed_servers = self._peer_ids_to_set(config.allowed_servers)\n        self.blocked_servers = self._peer_ids_to_set(config.blocked_servers)\n\n        self.ping_aggregator = PingAggregator(dht)\n\n        if state.banned_peers is None:\n            state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0)\n        if state.sequence_info is None:\n            state.sequence_info = 
RemoteSequenceInfo.make_empty(block_uids)\n\n        if state.sequence_info.last_updated_time is not None:\n            assert block_uids == state.sequence_info.block_uids\n            self._thread.ready.set()  # no need to await the first dht fetch\n\n    @staticmethod\n    def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:\n        if peer_ids is None:\n            return None\n\n        result = set()\n        for peer_id in peer_ids:\n            if isinstance(peer_id, PeerID):\n                result.add(peer_id)\n            elif isinstance(peer_id, str):\n                result.add(PeerID.from_base58(peer_id))\n            else:\n                raise TypeError(\n                    f\"`allowed_servers` and `blocked_servers` have to contain only PeerIDs or strings, but got {type(peer_id)}\"\n                )\n        return result\n\n    def make_sequence(\n        self,\n        start_index: int = 0,\n        end_index: Optional[int] = None,\n        *,\n        mode: str,\n        cache_tokens_needed: Optional[int] = None,\n    ) -> List[RemoteSpanInfo]:\n        \"\"\"\n        Form a sequence of remote servers that collectively serve all consecutive layers\n\n        :param start_index: optional index of the first module in a sequence, default = the first of block_uids\n        :param end_index: optional index of the last module (non-inclusive), default = after last of block uids\n        :param mode: one of [\"max_throughput\", \"min_latency\"]\n        \"\"\"\n        with self._thread_start_lock:\n            if not self.is_alive():\n                self._thread.start()\n        if not self.ready.is_set():\n            self.update(wait=True)  # this will await an existing update or trigger a new one (if not updating)\n\n        end_index = end_index if end_index is not None else len(self)\n\n        if mode == \"min_latency\":\n            span_sequence = self._make_sequence_with_min_latency(\n            
    start_index, end_index, cache_tokens_needed=cache_tokens_needed\n            )\n        elif mode == \"max_throughput\":\n            span_sequence = self._make_sequence_with_max_throughput(start_index, end_index)\n        else:\n            raise RuntimeError(f\"Unexpected mode {mode}\")\n\n        if self.config.show_route is True or (mode == \"min_latency\" and self.config.show_route == \"inference\"):\n            route_repr = \" => \".join(\n                [f\"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}\" for span in span_sequence]\n            )\n            logger.info(f\"Route found: {route_repr}\")\n        return span_sequence\n\n    def _make_sequence_with_min_latency(\n        self, start_index: int, end_index: int, *, cache_tokens_needed: Optional[int]\n    ) -> List[RemoteSpanInfo]:\n        if start_index == end_index:\n            return []\n\n        with self.lock_changes:\n            missing_blocks = [\n                block_idx\n                for block_idx in range(start_index, end_index)\n                if not self.state.sequence_info.spans_containing_block[block_idx]\n            ]\n            if missing_blocks:\n                raise MissingBlocksError(missing_blocks)\n            server_infos = {\n                span.peer_id: span.server_info\n                for block_idx in range(start_index, end_index)\n                for span in self.state.sequence_info.spans_containing_block[block_idx]\n            }\n\n            graph = self._build_inference_graph(start_index, end_index, cache_tokens_needed=cache_tokens_needed)\n\n        path = dijkstar.find_path(graph, \"start\", \"end\")\n        logger.debug(f\"Path info: {path}\")\n        if start_index == 0 and end_index == len(self):\n            logger.debug(f\"Expected speed: {1 / path.total_cost:.1f} steps/sec\")\n\n        span_sequence = []\n        for peer_id, block_idx in path.nodes[1:-1]:\n            if not span_sequence or span_sequence[-1].peer_id != 
peer_id:\n                span_sequence.append(RemoteSpanInfo(peer_id, block_idx, block_idx, server_infos[peer_id]))\n            else:\n                span_sequence[-1].end = block_idx\n\n        # Remove empty spans that can appear if we don't force to go to the end of each server and network delay\n        # don't follow triangle inequality (delay(A, B) + delay(B, C) < delay(A, C)) due to measurement errors\n        span_sequence = [span for span in span_sequence if span.length > 0]\n\n        return span_sequence\n\n    def _build_inference_graph(\n        self,\n        start_index: int,\n        end_index: int,\n        *,\n        cache_tokens_needed: Optional[int],\n        overhead_delay: float = 0.018,  # Serialization overhead (empirically measured)\n        default_inference_rps: float = 300,  # If inference RPS unknown\n        alloc_delay: float = 10,  # If not enough cache left, we penalize the edge\n    ) -> dijkstar.Graph:\n        missing_blocks = [\n            block_idx\n            for block_idx in range(start_index, end_index)\n            if not self.state.sequence_info.spans_containing_block[block_idx]\n        ]\n        if missing_blocks:\n            raise MissingBlocksError(missing_blocks)\n\n        client_server_rtts = self.ping_aggregator.to_dict()\n\n        graph = dijkstar.Graph()\n\n        # Clent -> server network delays\n        for span in self.state.sequence_info.spans_containing_block[start_index]:\n            delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))\n            delay += overhead_delay\n            if not self._has_cache_for(span, cache_tokens_needed):\n                delay += alloc_delay\n            graph.add_edge(\"start\", (span.peer_id, start_index), delay)\n\n        # Server -> client network delays\n        for span in self.state.sequence_info.spans_containing_block[end_index - 1]:\n            delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))\n            
graph.add_edge((span.peer_id, end_index), \"end\", delay)\n\n        # Server -> server network delays\n        for block_idx in range(start_index + 1, end_index):\n            for cur_span in self.state.sequence_info.spans_containing_block[block_idx - 1]:\n                if cur_span.end != block_idx:\n                    # If we choose a server, we force to go to the end of it before switching to a new one\n                    # to avoid O(N^2) graphs for N servers\n                    continue\n\n                for next_span in self.state.sequence_info.spans_containing_block[block_idx]:\n                    rtt = None\n                    if cur_span.server_info.next_pings is not None:\n                        rtt = cur_span.server_info.next_pings.get(next_span.peer_id.to_base58())\n                    delay = self._rtt_to_delay(rtt)\n                    delay += overhead_delay\n                    if not self._has_cache_for(next_span, cache_tokens_needed):\n                        delay += alloc_delay\n                    graph.add_edge((cur_span.peer_id, block_idx), (next_span.peer_id, block_idx), delay)\n\n        # Compute delays\n        for span in self.state.sequence_info.spans_by_priority:\n            for block_idx in range(max(span.start, start_index), min(span.end, end_index)):\n                inference_rps = span.server_info.inference_rps\n                if inference_rps is None:\n                    inference_rps = default_inference_rps\n                graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), 1.0 / inference_rps)\n\n        return graph\n\n    @staticmethod\n    def _rtt_to_delay(\n        rtt: float,\n        *,\n        default_delay: float = 0.15,  # If network delay unknown\n        max_delay: float = 5,  # If unreachable, we don't want to discard the edge completely\n    ) -> float:\n        if rtt is None:\n            return default_delay\n        return min(rtt / 2, max_delay)\n\n    @staticmethod\n    def 
_has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional[int] = None) -> bool:\n        if cache_tokens_needed is None or span.server_info.cache_tokens_left is None:\n            return True\n\n        # Here, `span` contains all blocks hosted by a server - but we won't necessarily run all of them through\n        # this particular server in our path. It is difficult to estimate how many blocks we'll use at this stage,\n        # so we assume that we'll use all of them (the worst case for the cache size) and get a pessimistic estimate.\n        # This is okay since false positives are more costly than false negatives here.\n        return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left\n\n    def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:\n        client_server_rtts = self.ping_aggregator.to_dict()\n\n        span_sequence = []\n        current_index = start_index\n        while current_index < end_index:\n            candidate_spans = self.state.sequence_info.spans_containing_block[current_index]\n            if not candidate_spans:\n                raise MissingBlocksError(current_index)\n\n            # We choose longer servers to minimize the number of hops but leave some randomization\n            # to distribute the load. 
We also exclude servers known to be unreachable.\n            eps = 1e-6\n            span_weights = np.array(\n                [span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans],\n                dtype=np.float64,\n            )\n            chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())\n\n            assert chosen_span.start <= current_index < chosen_span.end\n            span_sequence.append(dataclasses.replace(chosen_span, start=current_index))\n            current_index = chosen_span.end\n        return span_sequence\n\n    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:\n        \"\"\"Get a RemoteSequenceManager for a sub-sequence of blocks\"\"\"\n        assert isinstance(ix, (int, slice))\n        if not isinstance(ix, slice):\n            ix = slice(int(ix), int(ix) + 1, 1)\n        return type(self)(self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix])\n\n    def update(self, *, wait: bool):\n        \"\"\"Run an asynchronous update in background as soon as possible\"\"\"\n        self.ready.clear()\n        self._thread.trigger.set()\n        if wait:\n            self.ready.wait()\n\n    def _update(self):\n        \"\"\"Perform an immediate and synchronous refresh, may take time\"\"\"\n\n        new_block_infos = get_remote_module_infos(\n            self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True\n        )\n\n        for block_info in new_block_infos:\n            # Apply allow and block lists\n            block_info.servers = {\n                peer_id: server_info\n                for peer_id, server_info in block_info.servers.items()\n                if (self.allowed_servers is None or peer_id in self.allowed_servers)\n                and (self.blocked_servers is None or peer_id not in self.blocked_servers)\n            }\n\n            # Remove temporarily banned peers, unless there are 
no peers left\n            valid_servers = {\n                peer_id: server_info\n                for peer_id, server_info in block_info.servers.items()\n                if peer_id not in self.state.banned_peers\n            }\n            if len(valid_servers) < len(block_info.servers):\n                if valid_servers:\n                    logger.debug(\n                        f\"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}\"\n                    )\n                    block_info.servers = valid_servers\n                else:\n                    # If we blacklisted all servers, the error may actually be client-caused\n                    logger.debug(f\"All servers holding {block_info.uid} are blacklisted, ignoring blacklist\")\n\n        with self.lock_changes:\n            self.state.sequence_info.update_(new_block_infos)\n\n            first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]\n            middle_servers = [\n                span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans\n            ]\n            last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]\n\n        pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))\n        pinged_servers = set(sample_up_to(middle_servers, self.config.max_pinged))\n        pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))\n        self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)\n\n        self.ready.set()\n\n    def on_request_failure(self, peer_id: Optional[PeerID]):\n        \"\"\"remove a given peer from the routing table. 
If the routing is no longer possible, trigger an update\"\"\"\n        if peer_id is not None:\n            logger.debug(f\"Peer {peer_id} did not respond, banning it temporarily\")\n            self.state.banned_peers.register_failure(peer_id)\n        with self.lock_changes:\n            should_update = False\n            for info in self.state.sequence_info.block_infos:\n                info.servers.pop(peer_id, None)\n                if not info.servers:\n                    should_update = True\n            if should_update:\n                self.ready.clear()\n                self.update(wait=False)\n\n    def on_request_success(self, peer_id: PeerID):\n        \"\"\"if peer has a failure streak, clear that streak\"\"\"\n        self.state.banned_peers.register_success(peer_id)\n\n    def __len__(self):\n        return len(self.block_uids)\n\n    @property\n    def is_alive(self):\n        return self._thread.is_alive\n\n    @property\n    def ready(self) -> threading.Event:\n        return self._thread.ready\n\n    @property\n    def block_uids(self):\n        return self.state.sequence_info.block_uids\n\n    @property\n    def rpc_info(self):\n        \"\"\"Return the rpc_info queried from one of the servers that hold the first block\"\"\"\n        if self.state.rpc_info is not None:\n            return self.state.rpc_info\n\n        with self._thread_start_lock:\n            if not self.is_alive():\n                self._thread.start()\n\n        for attempt_no in itertools.count():\n            peer_id = None\n            try:\n                if not self.ready.is_set():\n                    self.update(wait=True)\n\n                active_servers = [\n                    peer_id\n                    for peer_id, server in self.state.sequence_info.block_infos[0].servers.items()\n                    if server.state == ServerState.ONLINE\n                ]\n                if not active_servers:\n                    raise MissingBlocksError(0)\n             
   peer_id = random.choice(active_servers)\n\n                stub = TransformerConnectionHandler.get_stub(self.state.p2p, peer_id)\n                outputs = RemoteExpertWorker.run_coroutine(\n                    stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]), timeout=self.config.request_timeout)\n                )\n                self.state.rpc_info = MSGPackSerializer.loads(outputs.serialized_info)\n                self.on_request_success(peer_id)\n                break\n            except Exception as e:\n                self.on_request_failure(peer_id)\n                if attempt_no + 1 == self.config.max_retries:\n                    raise\n                delay = self.get_retry_delay(attempt_no)\n                logger.warning(\n                    f\"Caught exception when gathering information from peer {peer_id} \"\n                    f\"(retry in {delay:.0f} sec): {repr(e)}\"\n                )\n                maybe_log_traceback(e)\n                time.sleep(delay)\n\n        return self.state.rpc_info\n\n    def get_retry_delay(self, attempt_no: int) -> float:\n        if attempt_no == 0:\n            return 0\n        return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)\n\n    def get_request_metadata(\n        self, protocol: str, args_structure: Any = None, *args, **kwargs\n    ) -> Optional[Dict[str, Any]]:\n        \"\"\"\n        :param protocol: one of \"rpc_forward\", \"rpc_backward\" or \"rpc_inference\"\n        :param args_structure: the structure of flattened tensors from pack_args_kwargs in petals.utils.packaging\n        :param args: request-specific inputs, typically block uids and input tensors\n        :param kwargs: additional request context, such as remote peer ID\n        :returns: msgpack-serialized metadata dict that will be passed alongside a given request\n        \"\"\"\n        return dict(\n            points=self.policy.get_points(protocol, *args, **kwargs),\n            
active_adapter=self.config.active_adapter,\n            args_structure=args_structure,\n        )\n\n    def shutdown(self):\n        self._thread.shutdown()\n\n\nclass _SequenceManagerUpdateThread(threading.Thread):\n    def __init__(self, update_period: float, ref_update_manager: WeakMethod):\n        super().__init__(daemon=True)\n        self.ref_update_manager = ref_update_manager\n        self.ready = threading.Event()\n        self.trigger = threading.Event()\n        self.update_period = update_period\n        self.should_shutdown = False\n\n    def run(self) -> None:\n        while not self.should_shutdown:\n            update_manager = self.ref_update_manager()\n            if update_manager is None:\n                logger.debug(f\"{self.__class__.__name__} exited because the sequence manager no longer exists\")\n                break\n\n            try:\n                self.trigger.clear()\n                update_manager()\n            except Exception as e:\n                logger.exception(e)\n            finally:\n                del update_manager\n\n            self.trigger.wait(self.update_period)\n\n        logger.debug(f\"{self.__class__.__name__} thread exited\")\n\n    def shutdown(self, timeout: Optional[float] = None):\n        self.should_shutdown = True\n        self.trigger.set()\n        if self.is_alive():\n            self.join(timeout)\n\n    def __del__(self):\n        self.shutdown()\n\n\ndef maybe_log_traceback(exc: Exception):\n    traceback_level = logging.DEBUG if str(exc) or isinstance(exc, asyncio.TimeoutError) else logging.WARNING\n    logger.log(traceback_level, \"See detailed traceback below:\", exc_info=True)\n\n\nclass MissingBlocksError(RuntimeError):\n    def __init__(self, block_indices: Union[int, Sequence[int]]):\n        super().__init__(\n            f\"No servers holding blocks {block_indices} are online. 
\"\n            f\"You can check the public swarm's state at https://health.petals.dev \"\n            f\"If there are not enough servers, please connect your GPU: \"\n            f\"https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity \"\n        )\n"
  },
  {
    "path": "src/petals/client/routing/spending_policy.py",
    "content": "\"\"\"\nAn interface for exchanging internal \"BLOOM points\" for higher priority compute requests. NOT IMPLEMENTED.\nThe intent is to let Petals participants earn points by helping others while idle (e.g. at night), then use these\n points to run their own compute experiments faster. See Section 4 of https://arxiv.org/abs/2209.01188 for discussion.\n\"\"\"\nfrom abc import ABC, abstractmethod\n\n\nclass SpendingPolicyBase(ABC):\n    @abstractmethod\n    def get_points(self, protocol: str, *args, **kwargs) -> float:\n        pass\n\n\nclass NoSpendingPolicy(SpendingPolicyBase):\n    def get_points(self, protocol: str, *args, **kwargs) -> float:\n        return 0.0\n"
  },
  {
    "path": "src/petals/client/sequential_autograd.py",
    "content": "\"\"\"\nA PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner\n\"\"\"\nimport asyncio\nimport itertools\nfrom collections import deque\nfrom typing import List, Optional, Sequence, Tuple\n\nimport torch\nfrom hivemind import MSGPackSerializer\nfrom hivemind.moe.client.remote_expert_worker import RemoteExpertWorker\nfrom hivemind.utils.logging import get_logger\n\nfrom petals.client.remote_forward_backward import run_remote_backward, run_remote_forward\nfrom petals.client.routing import RemoteSequenceManager, maybe_log_traceback\nfrom petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo\nfrom petals.server.handler import TransformerConnectionHandler\nfrom petals.utils.misc import DUMMY, is_dummy\nfrom petals.utils.packaging import pack_args_kwargs\n\nlogger = get_logger(__name__)\n\nMAX_TOKENS_IN_BATCH = 1024\n\n\nasync def sequential_forward(\n    inputs: torch.Tensor,\n    prompts: torch.Tensor,\n    sequence_manager: RemoteSequenceManager,\n    start_index: int = 0,\n    end_index: Optional[int] = None,\n) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:\n    \"\"\"\n    Constructs a routing path from <start_index> to <end_index>.\n    Performs chained forward for each subsequence of blocks on the path.\n    If some subsequence fails, reconstructs the remaining path and tries to finish the forward.\n    \"\"\"\n\n    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f\"{type(inputs)}: {inputs.ndim}\"\n\n    inputs_device = inputs.device\n    inputs_dtype = inputs.dtype\n    inputs = inputs.cpu()\n    prompts = prompts.cpu()\n\n    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)\n    assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)\n    assert is_dummy(prompts) or len(prompts) == len(\n        sequence_manager.block_uids\n    )  # should be n_layers - 1 but add extra prompts for 
convenience\n\n    sequences = deque()\n    intermediate_inputs = []\n    done_sequences = []\n\n    block_idx = start_index\n    while block_idx < end_index:\n        for attempt_no in itertools.count():\n            logger.debug(f\"Forward: block {block_idx}, attempt {attempt_no}\")\n            span = None\n            try:\n                if not sequences or attempt_no >= 1:\n                    sequences = deque(sequence_manager.make_sequence(block_idx, end_index, mode=\"max_throughput\"))\n                    # make_sequence() could return a longer sequence\n                    sequences[-1].end = min(sequences[-1].end, end_index)\n                    logger.debug(f\"Found path from block {block_idx} to {end_index} via {len(sequences)} servers\")\n\n                span = sequences.popleft()\n\n                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)\n                flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end])\n\n                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])\n                metadata = sequence_manager.get_request_metadata(\n                    \"rpc_forward\", args_structure, span_uids, *flat_tensors\n                )\n                (outputs,) = await run_remote_forward(\n                    span_uids,\n                    stub,\n                    sequence_manager.rpc_info,\n                    *flat_tensors,\n                    config=sequence_manager.config,\n                    metadata=MSGPackSerializer.dumps(metadata),\n                )\n\n                assert isinstance(outputs, torch.Tensor)\n                assert outputs.shape == inputs.shape, f\"Expected output {inputs.shape}, got {outputs.shape}\"\n\n                # Save intermediate inputs and subsequences if the forward is already done for them\n                intermediate_inputs.append(inputs)\n                
done_sequences.append(span)\n\n                inputs = outputs\n                block_idx = span.end\n                sequence_manager.on_request_success(span.peer_id)\n                break\n            except Exception as e:\n                sequence_manager.on_request_failure(span.peer_id if span is not None else None)\n                if attempt_no + 1 == sequence_manager.config.max_retries:\n                    raise\n                delay = sequence_manager.get_retry_delay(attempt_no)\n                logger.warning(\n                    f\"Caught exception when running forward via {span} (retry in {delay:.0f} sec): {repr(e)}\"\n                )\n                maybe_log_traceback(e)\n                await asyncio.sleep(delay)\n\n    outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)\n    intermediate_inputs = [tensor.to(device=inputs_device, dtype=inputs_dtype) for tensor in intermediate_inputs]\n    return outputs, intermediate_inputs, done_sequences\n\n\nasync def sequential_backward(\n    grad_outputs: Sequence[torch.Tensor],\n    intermediate_inputs: List[torch.Tensor],\n    prompts: torch.Tensor,\n    forward_sequences: List[RemoteSpanInfo],\n    sequence_manager: RemoteSequenceManager,\n) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:\n    \"\"\"\n    Performs chained backward for each forward subsequence.\n    If some subsequence fails, reconstructs the particular sub-path and recovers the backward.\n    \"\"\"\n    assert len(intermediate_inputs) == len(forward_sequences)\n\n    grad_outputs_device = grad_outputs[0].device if grad_outputs else None\n    grad_outputs_dtype = grad_outputs[0].dtype if grad_outputs else None\n    prompts_device = prompts.device\n    prompts_dtype = prompts.dtype\n\n    grad_outputs = [tensor.cpu() for tensor in grad_outputs]\n    intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]\n    prompts = prompts.cpu()\n\n    grad_prompts_reversed = []\n    while len(forward_sequences) > 0 and 
len(intermediate_inputs) > 0:\n        inputs = intermediate_inputs.pop()\n        span = forward_sequences.pop()\n        for attempt_no in itertools.count():\n            logger.debug(f\"Backward: block {span.end - 1}, attempt {attempt_no}\")\n            try:\n                if attempt_no >= 1:\n                    _, backup_inputs, backup_sequences = await sequential_forward(\n                        inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end\n                    )\n                    assert len(backup_inputs) == len(backup_sequences)\n                    assert backup_sequences[0].start == span.start\n                    assert backup_sequences[-1].end == span.end\n\n                    intermediate_inputs.extend(backup_inputs)\n                    forward_sequences.extend(backup_sequences)\n                    inputs = intermediate_inputs.pop()\n                    span = forward_sequences.pop()\n\n                grad_outputs_cpu = [grad.cpu() for grad in grad_outputs]\n                flat_tensors, args_structure = pack_args_kwargs(\n                    inputs, *grad_outputs_cpu, prompts[span.start : span.end]\n                )\n\n                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])\n                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)\n                metadata = sequence_manager.get_request_metadata(\n                    \"rpc_backward\", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id\n                )\n                grad_outputs, *span_grad_prompts = await run_remote_backward(\n                    span_uids,\n                    stub,\n                    sequence_manager.rpc_info,\n                    *flat_tensors,\n                    config=sequence_manager.config,\n                    metadata=MSGPackSerializer.dumps(metadata),\n                )\n                grad_outputs = [grad_outputs]\n     
           grad_prompts_reversed.extend(span_grad_prompts)\n                sequence_manager.on_request_success(span.peer_id)\n                break\n            except Exception as e:\n                sequence_manager.on_request_failure(span.peer_id if span is not None else None)\n                if attempt_no + 1 == sequence_manager.config.max_retries:\n                    raise\n                delay = sequence_manager.get_retry_delay(attempt_no)\n                logger.warning(\n                    f\"Caught exception when running backward via {span} (retry in {delay:.0f} sec): {repr(e)}\"\n                )\n                maybe_log_traceback(e)\n                await asyncio.sleep(delay)\n\n    # For now, we do not support mixed dummy and grad prompts\n    # Concat in num_layer dimension\n    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None\n\n    if grad_outputs_dtype is not None:\n        grad_outputs = [tensor.to(device=grad_outputs_device, dtype=grad_outputs_dtype) for tensor in grad_outputs]\n    if grad_prompts is not None:\n        grad_prompts = grad_prompts.to(device=prompts_device, dtype=prompts_dtype)\n    return grad_outputs, grad_prompts\n\n\nasync def _gather_forward(input_batches, prompt_batches, sequence_manager):\n    \"\"\"Wrapper for asyncio.gather to perform parallel sequential forwards\"\"\"\n    return await asyncio.gather(\n        *[\n            sequential_forward(input_batch, prompt_batch, sequence_manager)\n            for input_batch, prompt_batch in zip(input_batches, prompt_batches)\n        ]\n    )\n\n\nasync def _gather_backward(\n    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager\n):\n    \"\"\"Wrapper for asyncio.gather to perform parallel sequential backwards\"\"\"\n    return await asyncio.gather(\n        *[\n            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)\n            
for grad_output, input_batch, prompt_batch, spans in zip(\n                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences\n            )\n        ]\n    )\n\n\nclass _RemoteSequentialAutogradFunction(torch.autograd.Function):\n    \"\"\"\n    PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks.\n    This function splits input data into batches with <MAX_TOKENS_IN_BATCH> and performs efficient parallel processing.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):\n        batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)\n        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)\n        if prompts is None or is_dummy(prompts):\n            prompt_batches = [DUMMY] * len(input_batches)\n        else:\n            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)\n\n        sequence_manager.rpc_info  # lazy init\n        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))\n        assert len(outputs) == len(input_batches)\n\n        output_batches = [output[0] for output in outputs]\n        intemediate_input_batches = [output[1] for output in outputs]\n        sequences_for_batches = [output[2] for output in outputs]\n\n        ctx.prompt_batches = prompt_batches\n        ctx.sequence_manager = sequence_manager\n        ctx.intemediate_input_batches = intemediate_input_batches\n        ctx.sequences_for_batches = sequences_for_batches\n        return torch.cat(output_batches, dim=0)\n\n    @staticmethod\n    def backward(ctx, grad_outputs: torch.Tensor):\n        intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intemediate_input_batches\n        forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches\n        
ctx.sequence_manager.rpc_info  # lazy init\n\n        batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)\n        grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)\n        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)\n\n        outputs = RemoteExpertWorker.run_coroutine(\n            _gather_backward(\n                grad_output_batches,\n                intermediate_input_batches,\n                ctx.prompt_batches,\n                forward_sequences,\n                ctx.sequence_manager,\n            )\n        )\n        grad_input_batches = [output[0][0] for output in outputs]\n        grad_prompt_batches = [output[1] for output in outputs]\n\n        grad_inputs = torch.cat(grad_input_batches, dim=0)\n        dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]\n        grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None\n        return (grad_inputs, grad_prompts, None)\n"
  },
  {
    "path": "src/petals/constants.py",
    "content": "import torch\n\nPUBLIC_INITIAL_PEERS = [\n    # IPv4 DNS addresses\n    \"/dns/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY\",\n    \"/dns/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5\",\n    # IPv6 DNS addresses\n    \"/dns6/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY\",\n    \"/dns6/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5\",\n    # Reserved IPs\n    \"/ip4/159.89.214.152/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY\",\n    \"/ip4/159.203.156.48/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5\",\n]\n\n# The reachability API is currently used only when connecting to the public swarm\nREACHABILITY_API_URL = \"https://health.petals.dev\"\n\nDTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto=\"auto\")\n"
  },
  {
    "path": "src/petals/data_structures.py",
    "content": "import dataclasses\nfrom enum import Enum\nfrom typing import Any, Dict, Optional, Sequence, Tuple\n\nimport pydantic.v1 as pydantic\nfrom hivemind import PeerID\nfrom hivemind.moe.expert_uid import ExpertUID\n\nModuleUID = str\nUID_DELIMITER = \".\"  # delimits parts of one module uid, e.g. \"bloom.transformer.h.4.self_attention\"\nCHAIN_DELIMITER = \" \"  # delimits multiple uids in a sequence, e.g. \"bloom.layer3 bloom.layer4\"\n\n\ndef parse_uid(uid: ModuleUID) -> Tuple[str, int]:\n    assert CHAIN_DELIMITER not in uid, \"parse_uid() does not support chained UIDs\"\n    dht_prefix, index = uid.split(UID_DELIMITER)\n    return dht_prefix, int(index)\n\n\n@pydantic.dataclasses.dataclass\nclass ModelInfo:\n    num_blocks: pydantic.conint(ge=1, strict=True)\n    repository: Optional[str] = None\n\n    def to_dict(self) -> dict:\n        return dataclasses.asdict(self)\n\n    @classmethod\n    def from_dict(cls, source: dict):\n        return cls(**source)\n\n\nclass ServerState(Enum):\n    OFFLINE = 0\n    JOINING = 1\n    ONLINE = 2\n\n\nRPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)\n\n\n@pydantic.dataclasses.dataclass\nclass ServerInfo:\n    state: ServerState\n    throughput: RPS\n\n    start_block: Optional[pydantic.conint(ge=0, strict=True)] = None\n    end_block: Optional[pydantic.conint(ge=0, strict=True)] = None\n\n    public_name: Optional[str] = None\n    version: Optional[str] = None\n\n    network_rps: Optional[RPS] = None\n    forward_rps: Optional[RPS] = None\n    inference_rps: Optional[RPS] = None\n\n    adapters: Sequence[str] = ()\n    torch_dtype: Optional[str] = None\n    quant_type: Optional[str] = None\n    using_relay: Optional[bool] = None\n    cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None\n    next_pings: Optional[Dict[str, pydantic.confloat(ge=0, strict=True)]] = None\n\n    def to_tuple(self) -> Tuple[int, float, dict]:\n        extra_info = dataclasses.asdict(self)\n        del 
extra_info[\"state\"], extra_info[\"throughput\"]\n        return (self.state.value, self.throughput, extra_info)\n\n    @classmethod\n    def from_tuple(cls, source: tuple):\n        state, throughput = source[:2]\n        extra_info = source[2] if len(source) > 2 else {}\n        # pydantic will validate existing fields and ignore extra ones\n        return cls(state=ServerState(state), throughput=throughput, **extra_info)\n\n\n@dataclasses.dataclass\nclass RemoteModuleInfo:\n    \"\"\"A remote module that is served by one or more servers\"\"\"\n\n    uid: ModuleUID\n    servers: Dict[PeerID, ServerInfo]\n\n\n@dataclasses.dataclass\nclass RemoteSpanInfo:\n    \"\"\"A chain of remote blocks served by one specific remote peer\"\"\"\n\n    peer_id: PeerID\n    start: int\n    end: int\n    server_info: ServerInfo\n\n    @property\n    def length(self) -> int:\n        return self.end - self.start\n\n    @property\n    def state(self) -> ServerState:\n        return self.server_info.state\n\n    @property\n    def throughput(self) -> float:\n        return self.server_info.throughput\n\n\nRPCInfo = Dict[str, Any]\n\nHandle = int\n\n\n@dataclasses.dataclass(frozen=True)\nclass InferenceMetadata:\n    uid: ExpertUID\n    prefix_length: int\n    cache_handles: Tuple[Handle, ...]\n    active_adapter: Optional[str]\n"
  },
  {
    "path": "src/petals/dht_utils.py",
    "content": "import warnings\n\nwarnings.warn(\n    \"petals.dht_utils has been moved to petals.utils.dht. This alias will be removed in Petals 2.2.0+\",\n    DeprecationWarning,\n    stacklevel=2,\n)\n\nfrom petals.utils.dht import *\n"
  },
  {
    "path": "src/petals/models/__init__.py",
    "content": "from petals.models.bloom import *\nfrom petals.models.falcon import *\nfrom petals.models.llama import *\nfrom petals.models.mixtral import *\n"
  },
  {
    "path": "src/petals/models/bloom/__init__.py",
    "content": "from petals.models.bloom.block import WrappedBloomBlock\nfrom petals.models.bloom.config import DistributedBloomConfig\nfrom petals.models.bloom.model import (\n    DistributedBloomForCausalLM,\n    DistributedBloomForSequenceClassification,\n    DistributedBloomModel,\n)\nfrom petals.utils.auto_config import register_model_classes\n\nregister_model_classes(\n    config=DistributedBloomConfig,\n    model=DistributedBloomModel,\n    model_for_causal_lm=DistributedBloomForCausalLM,\n    model_for_sequence_classification=DistributedBloomForSequenceClassification,\n)\n"
  },
  {
    "path": "src/petals/models/bloom/block.py",
    "content": "\"\"\"\nBloom intermediate layer\nBased on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b\nSee commit history for authorship.\n\"\"\"\nfrom typing import Optional, Tuple\n\nimport torch\nfrom transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask\nfrom transformers.models.bloom.modeling_bloom import BloomBlock, build_alibi_tensor\n\nfrom petals.utils.misc import is_dummy\n\n\nclass WrappedBloomBlock(BloomBlock):\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        *args,\n        attention_mask: Optional[torch.Tensor] = None,\n        alibi: Optional[torch.Tensor] = None,\n        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        **kwargs\n    ):\n        assert attention_mask is None, \"Non-causal attention masks are not supported yet\"\n        batch_size, seq_length = hidden_states.shape[:2]\n        if layer_past is not None and is_dummy(layer_past[0]):\n            # Bloom cannot use cache if it was misconsctructed(e.g. 
Dummy tensors)\n            # In this case, fallback to the old code:\n            layer_past = None\n        past_length = 0 if layer_past is None else layer_past[0].shape[-1]\n        seq_length_with_past = seq_length + past_length\n        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)\n        if alibi is None:\n            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)\n        attention_mask = _prepare_4d_causal_attention_mask(\n            attention_mask=attention_mask,\n            input_shape=(batch_size, seq_length),\n            inputs_embeds=hidden_states,\n            past_key_values_length=past_length,\n        )\n        attention_mask = attention_mask.bool()\n        return super().forward(\n            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs\n        )\n"
  },
  {
    "path": "src/petals/models/bloom/config.py",
    "content": "import os\nfrom typing import Optional, Union\n\nfrom hivemind import get_logger\nfrom transformers.models.bloom import BloomConfig\nfrom transformers.models.bloom.modeling_bloom import BloomAttention\n\nfrom petals.client.config import ClientConfig\nfrom petals.client.lm_head import LMHeadConfig\nfrom petals.client.ptune import PTuneConfig\nfrom petals.models.bloom.block import WrappedBloomBlock\n\nlogger = get_logger(__name__)\n\n\nclass DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfig):\n    block_class = WrappedBloomBlock\n    attn_class = BloomAttention\n    block_prefix = \"h\"\n\n    num_key_value_groups = 1\n\n    @classmethod\n    def from_pretrained(\n        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs\n    ):\n        logger.info(\"Make sure you follow the BLOOM terms of use: https://bit.ly/bloom-license\")\n\n        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)\n        if loading_from_repo and dht_prefix is None:\n            # We need \"-petals\" for backward compatibility with Petals < 1.2.0\n            dht_prefix = str(model_name_or_path) + \"-petals\"\n            dht_prefix = dht_prefix.replace(\".\", \"-\")\n            logger.info(f\"Using DHT prefix: {dht_prefix}\")\n        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)\n"
  },
  {
    "path": "src/petals/models/bloom/model.py",
    "content": "from typing import Optional\n\nimport hivemind\nimport torch\nimport torch.nn as nn\nfrom hivemind.utils.logging import get_logger\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions\nfrom transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel\n\nfrom petals.client.from_pretrained import FromPretrainedMixin\nfrom petals.client.lm_head import LMHead\nfrom petals.client.ptune import PTuneMixin\nfrom petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues\nfrom petals.client.remote_sequential import RemoteSequential\nfrom petals.models.bloom.config import DistributedBloomConfig\n\nlogger = get_logger(__name__)\n\n\nclass DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):\n    \"\"\"BloomModel, but all transformer layers are hosted by the swarm\"\"\"\n\n    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing\n    _keys_to_ignore_on_load_unexpected = [r\"^h\\.\"]\n\n    config_class = DistributedBloomConfig\n\n    def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):\n        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization\n        super().__init__(config)\n        assert len(self.h) == 0\n        config.num_hidden_layers = n_layer\n\n        self.h = RemoteSequential(config, dht=dht)\n\n        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm\n        self.init_prompts(config)\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[RemotePastKeyValues] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.LongTensor] = None,\n        inputs_embeds: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n   
     output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ):\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        # The causal mask will be added on the server-side\n        assert (\n            attention_mask is None or (attention_mask == 1).all()\n        ), f\"Custom attention masks are not supported, {attention_mask=}\"\n        assert head_mask is None, f\"Custom head masks are not supported, {head_mask=}\"\n        assert use_cache is None or use_cache, f\"{use_cache=} is not supported\"\n        assert not output_attentions, f\"{output_attentions=} is not supported\"\n        assert not output_hidden_states, f\"{output_hidden_states=} is not supported\"\n        assert return_dict is None or return_dict, f\"{return_dict=} is not supported\"\n\n        if inputs_embeds is None:\n            inputs_embeds = self.word_embeddings(input_ids)\n\n        use_prompts = self.config.tuning_mode and \"ptune\" in self.config.tuning_mode and self.h.position == 0\n        if use_prompts:\n            batch_size = inputs_embeds.shape[0]\n            prompts, intermediate_prompts = self.get_prompt(batch_size)\n            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)\n        else:\n            prompts = intermediate_prompts = None\n\n        hidden_states = self.word_embeddings_layernorm(inputs_embeds)\n        output_shape = input_shape + (hidden_states.size(-1),)\n\n        hidden_states = self.h(\n          
  hidden_states,\n            prompts=intermediate_prompts,\n            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,\n        )\n\n        # Remove prefix\n        if use_prompts:\n            hidden_states = hidden_states[:, self.pre_seq_len :]\n\n        if past_key_values is None:\n            past_key_values = RemotePastKeyValues()\n        past_key_values.update_seen(hidden_states.size(1))\n\n        # Add last hidden state\n        hidden_states = self.ln_f(hidden_states)\n        hidden_states = hidden_states.view(output_shape)\n        return BaseModelOutputWithPastAndCrossAttentions(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values,\n            hidden_states=None,\n            attentions=None,\n        )\n\n\nclass DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):\n    _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing\n    _keys_to_ignore_on_load_missing += [r\"^lm_head\\.\"]  # Missing since they are shared with input embeddings\n    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected\n    _supports_cache_class = True\n\n    config_class = DistributedBloomConfig\n\n    def __init__(self, config: DistributedBloomConfig):\n        BloomPreTrainedModel.__init__(self, config)\n        self.transformer = DistributedBloomModel(config)\n        self.lm_head = LMHead(config)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def prepare_inputs_for_generation(\n        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs\n    ) -> dict:\n        # Omit tokens covered by past_key_values\n        if past_key_values is not None:\n            if isinstance(past_key_values, Cache):\n                cache_length = past_key_values.get_seq_length()\n                past_length = 
past_key_values._seen_tokens\n                max_cache_length = past_key_values.get_max_length()\n            else:\n                cache_length = past_length = past_key_values[0][0].shape[2]\n                max_cache_length = None\n\n            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:\n                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]\n            elif past_length < input_ids.shape[1]:\n                input_ids = input_ids[:, past_length:]\n\n            if (\n                max_cache_length is not None\n                and attention_mask is not None\n                and cache_length + input_ids.shape[1] > max_cache_length\n            ):\n                attention_mask = attention_mask[:, -max_cache_length:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -input_ids.shape[1] :]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n            }\n        )\n        return model_inputs\n\n    def _temporary_reorder_cache(self, past_key_values, beam_idx):\n        return past_key_values\n\n    def get_output_embeddings(self):\n        return 
self.lm_head\n\n\nclass DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):\n    _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing\n    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected\n\n    config_class = DistributedBloomConfig\n\n    def __init__(self, config: DistributedBloomConfig):\n        BloomPreTrainedModel.__init__(self, config)\n        self.num_labels = config.num_labels\n\n        self.transformer = DistributedBloomModel(config)\n        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n"
  },
  {
    "path": "src/petals/models/falcon/__init__.py",
    "content": "from petals.models.falcon.block import WrappedFalconBlock\nfrom petals.models.falcon.config import DistributedFalconConfig\nfrom petals.models.falcon.model import (\n    DistributedFalconForCausalLM,\n    DistributedFalconForSequenceClassification,\n    DistributedFalconModel,\n)\nfrom petals.utils.auto_config import register_model_classes\n\nregister_model_classes(\n    config=DistributedFalconConfig,\n    model=DistributedFalconModel,\n    model_for_causal_lm=DistributedFalconForCausalLM,\n    model_for_sequence_classification=DistributedFalconForSequenceClassification,\n)\n"
  },
  {
    "path": "src/petals/models/falcon/block.py",
    "content": "\"\"\"\nFalcon intermediate layer\nBased on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py\nSee commit history for authorship.\n\"\"\"\nimport math\nfrom functools import partial\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom transformers.models.falcon.modeling_falcon import (\n    FalconAttention,\n    FalconConfig,\n    FalconDecoderLayer,\n    FalconLinear,\n    FalconMLP,\n    FalconModel,\n    LayerNorm,\n    build_alibi_tensor,\n    dropout_add,\n    rotate_half,\n)\n\nKVCache = Tuple[torch.Tensor, torch.Tensor]\nINFERENCE_MAX_LENGTH = 8192\n\n\ndef apply_rotary(query, key, cos, sin):\n    return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)\n\n\nclass OptimizedFalconRotaryEmbedding(nn.Module):\n    def __init__(self, head_dim: int, base=10000):\n        super().__init__()\n        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.head_dim = head_dim\n        self.seq_len_cached = -1\n\n        self.cuda_graph = None\n        self.input_surface = None\n        self.static_outputs = None\n\n    def _optimized_apply_rotary(self, query, key, cos, sin):\n        if self.cuda_graph is None:\n            self.cuda_graph = torch.cuda.CUDAGraph()\n            self.input_surface = (query, key, cos, sin)\n\n            s = torch.cuda.Stream()\n            s.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(s):\n                for _ in range(3):\n                    apply_rotary(*self.input_surface)\n            torch.cuda.current_stream().wait_stream(s)\n\n            with torch.cuda.graph(self.cuda_graph):\n                self.static_outputs = apply_rotary(*self.input_surface)\n\n        inputs = (query, key, cos, sin)\n        for static_input, data in 
zip(self.input_surface, inputs):\n            static_input.copy_(data)\n        self.cuda_graph.replay()\n        return tuple(o.detach() for o in self.static_outputs)\n\n    def cos_sin(self, seq_len: int, past_key_values_length: int, device=\"cpu\", dtype=torch.bfloat16) -> torch.Tensor:\n        total_length = seq_len + past_key_values_length\n        if self.seq_len_cached == -1:\n            # warm up the cache\n            total_length = max(INFERENCE_MAX_LENGTH, total_length)\n\n        if total_length > self.seq_len_cached:\n            with torch.inference_mode(False):\n                self.seq_len_cached = total_length\n                t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)\n                freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n                emb = torch.cat((freqs, freqs), dim=-1).to(device)\n\n                if dtype in [torch.float16, torch.bfloat16]:\n                    emb = emb.float()\n\n                self.register_buffer(\"cos_cached\", emb.cos()[None, :, :].type(dtype), persistent=False)\n                self.register_buffer(\"sin_cached\", emb.sin()[None, :, :].type(dtype), persistent=False)\n\n        return (\n            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),\n            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),\n        )\n\n    def forward(self, query, key, past_key_values_length=0):\n        batch, seq_len, head_dim = query.shape\n        cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)\n        if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == \"cuda\":\n            return self._optimized_apply_rotary(query, key, cos, sin)\n        else:\n            return apply_rotary(query, key, cos, sin)\n\n\ndef split_heads(\n    fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int\n) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor]:\n    batch, seq_len, _ = fused_qkv.shape\n    qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)\n    query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3)\n    key = torch.broadcast_to(key, query.shape)\n    value = torch.broadcast_to(value, query.shape)\n\n    query, key, value = [x.flatten(2, 3) for x in (query, key, value)]\n    return query, key, value\n\n\nclass OptimizedFalconAttention(FalconAttention):\n    def __init__(self, config: FalconConfig):\n        nn.Module.__init__(self)\n\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.split_size = self.hidden_size\n        self.hidden_dropout = config.hidden_dropout\n\n        if self.head_dim * self.num_heads != self.hidden_size:\n            raise ValueError(\n                f\"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n            )\n\n        self.maybe_rotary = OptimizedFalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)\n\n        # Layer-wise attention scaling\n        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)\n        self.beta = self.inv_norm_factor\n        if config.new_decoder_architecture:\n            qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim\n        elif config.multi_query:\n            qkv_out_dim = self.hidden_size + 2 * self.head_dim\n        else:\n            qkv_out_dim = 3 * self.hidden_size\n        self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)\n        self.new_decoder_architecture = config.new_decoder_architecture\n        self.multi_query = config.multi_query\n        self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)\n     
   self.attention_dropout = nn.Dropout(config.attention_dropout)\n        self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1\n\n        if self.new_decoder_architecture:\n            self._split_heads = partial(\n                split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim\n            )\n            self.split_graph = None\n            self.input_surface = None\n            self.static_outputs = None\n\n    def _optimized_split_heads(self, fused_qkv):\n        if self.split_graph is None:\n            self.split_graph = torch.cuda.CUDAGraph()\n            self.input_surface = fused_qkv\n\n            s = torch.cuda.Stream()\n            s.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(s):\n                for _ in range(3):\n                    self._split_heads(fused_qkv)\n            torch.cuda.current_stream().wait_stream(s)\n\n            with torch.cuda.graph(self.split_graph):\n                self.static_outputs = self._split_heads(self.input_surface)\n\n        self.input_surface.copy_(fused_qkv)\n        self.split_graph.replay()\n        return tuple(o.detach() for o in self.static_outputs)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        alibi: Optional[torch.Tensor],\n        attention_mask: torch.Tensor,\n        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        use_cache: bool = False,\n        output_attentions: bool = False,\n    ):\n        assert not output_attentions\n\n        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]\n\n        if (\n            self.new_decoder_architecture\n            and hidden_states.size(1) == 1\n            and torch.is_inference_mode_enabled()\n            and hidden_states.device.type == \"cuda\"\n        ):\n            
query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv)\n        else:\n            # 3 x [batch_size, seq_length, num_heads, head_dim]\n            (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)\n\n        num_kv_heads = self.num_heads\n        batch_size, query_length, _, _ = query_layer.shape\n\n        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)\n        key_layer = key_layer.transpose(1, 2).reshape(\n            batch_size * num_kv_heads,\n            query_length,\n            self.head_dim,\n        )\n        value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)\n\n        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]\n        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)\n\n        if layer_past is not None:\n            past_key, past_value = layer_past\n            # concatenate along seq_length dimension:\n            #  - key: [batch_size * self.num_heads, kv_length, head_dim]\n            #  - value: [batch_size * self.num_heads, kv_length, head_dim]\n            key_layer = torch.cat((past_key, key_layer), dim=1)\n            value_layer = torch.cat((past_value, value_layer), dim=1)\n\n        _, kv_length, _ = key_layer.shape\n        if use_cache:\n            present = (key_layer, value_layer)\n        else:\n            present = None\n\n        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)\n        key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)\n        value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)\n\n        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float(\"-1e9\")).to(query_layer.dtype)\n\n        if alibi is None:\n            attn_output = F.scaled_dot_product_attention(\n                
query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False\n            )\n\n            attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)\n            attn_output = attn_output.permute(0, 2, 1, 3)\n            attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)\n\n            output_tensor = self.dense(attn_output)\n\n            return output_tensor, present\n        else:\n            matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)\n\n            # change view to [batch_size, num_heads, q_length, kv_length]\n            attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)\n\n            # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]\n            input_dtype = attention_scores.dtype\n            # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`\n            if input_dtype == torch.float16 or input_dtype == torch.bfloat16:\n                attention_scores = attention_scores.to(torch.float32)\n            # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by\n            # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically\n            # equivalent and more performant, but there might be a numerical difference. 
If you're reading this\n            # and you'd like to experiment and maybe file a PR, feel free!\n            attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)\n            attention_logits *= self.inv_norm_factor\n            attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)\n            # [batch_size, num_heads, q_length, kv_length]\n            attention_probs = self.attention_dropout(attention_probs)\n\n            if head_mask is not None:\n                attention_probs = attention_probs * head_mask\n\n            # change view [batch_size, num_heads, q_length, kv_length]\n            attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)\n\n            # matmul: [batch_size * num_heads, q_length, head_dim]\n            context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)\n\n            # change view [batch_size, q_length, num_heads * head_dim]\n            context_layer = self._merge_heads(context_layer)\n\n            output_tensor = self.dense(context_layer)\n\n            if output_attentions:\n                return output_tensor, present, attention_probs\n            else:\n                return output_tensor, present\n\n\nclass OptimizedFalconDecoderLayer(FalconDecoderLayer):\n    def __init__(self, config: FalconConfig):\n        nn.Module.__init__(self)\n        hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n\n        self.mlp = FalconMLP(config)\n        self.hidden_dropout = config.hidden_dropout\n        self.config = config\n\n        self.self_attention = OptimizedFalconAttention(config)\n\n        if self.config.alibi or not config.new_decoder_architecture:\n            if config.new_decoder_architecture:\n                # The layer norm before self-attention\n                self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n      
          # The layer norm before the MLP\n                self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n            else:\n                self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n                if not config.parallel_attn:\n                    self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n\n        else:\n            self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n\n            self.ln_graph = None\n            self.static_input = None\n            self.static_outputs = None\n\n    def _optimized_apply_ln(self, hidden_states):\n        if self.ln_graph is None:\n            self.ln_graph = torch.cuda.CUDAGraph()\n            self.static_input = hidden_states\n\n            s = torch.cuda.Stream()\n            s.wait_stream(torch.cuda.current_stream())\n            with torch.cuda.stream(s):\n                for _ in range(3):\n                    self.ln_attn(hidden_states)\n                    self.ln_mlp(hidden_states)\n            torch.cuda.current_stream().wait_stream(s)\n\n            with torch.cuda.graph(self.ln_graph):\n                ln_attn_output = self.ln_attn(hidden_states)\n                ln_mlp_output = self.ln_mlp(hidden_states)\n                self.static_outputs = (ln_attn_output, ln_mlp_output)\n\n        self.static_input.copy_(hidden_states)\n        self.ln_graph.replay()\n        return tuple(o.detach() for o in self.static_outputs)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        alibi: Optional[torch.Tensor],\n        attention_mask: torch.Tensor,\n        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        use_cache: bool = False,\n        output_attentions: bool = False,\n    ):\n        residual = hidden_states\n\n        if 
self.config.new_decoder_architecture:\n            if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == \"cuda\":\n                attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)\n            else:\n                attention_layernorm_out = self.ln_attn(hidden_states)\n                mlp_layernorm_out = self.ln_mlp(hidden_states)\n        else:\n            attention_layernorm_out = self.input_layernorm(hidden_states)\n\n        attn_outputs = self.self_attention(\n            attention_layernorm_out,\n            layer_past=layer_past,\n            attention_mask=attention_mask,\n            alibi=alibi,\n            head_mask=head_mask,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n        )\n\n        attention_output = attn_outputs[0]\n\n        if not self.config.new_decoder_architecture:\n            if self.config.parallel_attn:\n                mlp_layernorm_out = attention_layernorm_out\n            else:\n                residual = dropout_add(\n                    attention_output, residual, self.config.attention_dropout, training=self.training\n                )\n                mlp_layernorm_out = self.post_attention_layernorm(residual)\n\n        outputs = attn_outputs[1:]\n\n        mlp_output = self.mlp(mlp_layernorm_out)\n\n        if self.config.new_decoder_architecture or self.config.parallel_attn:\n            mlp_output += attention_output\n\n        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)\n\n        if use_cache:\n            outputs = (output,) + outputs\n        else:\n            outputs = (output,) + outputs[1:]\n\n        return outputs  # hidden_states, present, attentions\n\n\nclass WrappedFalconBlock(OptimizedFalconDecoderLayer):\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        *args,\n        attention_mask: 
Optional[torch.Tensor] = None,\n        alibi: Optional[torch.Tensor] = None,\n        layer_past: Optional[KVCache] = None,\n        use_cache: bool = False,\n        **kwargs,\n    ):\n        assert attention_mask is None\n\n        batch_size, seq_length = hidden_states.shape[:2]\n\n        if layer_past is not None:\n            layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)\n        past_length = 0 if layer_past is None else layer_past[0].shape[1]\n        seq_length_with_past = seq_length + past_length\n\n        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)\n        if alibi is None and self.config.alibi:\n            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)\n        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)\n\n        outputs = super().forward(\n            hidden_states,\n            *args,\n            attention_mask=attention_mask,\n            alibi=alibi,\n            layer_past=layer_past,\n            use_cache=use_cache,\n            **kwargs,\n        )\n\n        if use_cache:\n            present_key_value = outputs[-1]\n            present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)\n            outputs = outputs[:-1] + (present_key_value,)\n\n        return outputs\n\n    def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:\n        key_states, value_states = key_value\n\n        key_states = key_states.permute(0, 2, 1)\n        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]\n\n        if self.config.new_decoder_architecture:\n            key_states = self._expand_states(key_states)\n            value_states = self._expand_states(value_states)\n\n        return (key_states, value_states)\n\n    def _reorder_cache_from_falcon_to_bloom(self, key_value: 
KVCache) -> KVCache:\n        key_states, value_states = key_value\n\n        if self.config.new_decoder_architecture:\n            key_states = self._collapse_states(key_states)\n            value_states = self._collapse_states(value_states)\n\n        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]\n        key_states = key_states.permute(0, 2, 1)\n\n        return (key_states, value_states)\n\n    def _expand_states(self, state: torch.Tensor) -> torch.Tensor:\n        batch_size_x_num_kv_heads, seq_len, head_dim = state.shape\n        batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads\n\n        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)\n        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy\n        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy\n        return state\n\n    def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:\n        batch_size_x_num_attn_heads, seq_len, head_dim = state.shape\n        batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads\n\n        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)\n        state = state[:, :, 0]\n        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)\n        return state\n"
  },
  {
    "path": "src/petals/models/falcon/config.py",
    "content": "import os\nfrom typing import Optional, Union\n\nfrom hivemind import get_logger\nfrom transformers.models.falcon import FalconConfig\nfrom transformers.models.falcon.modeling_falcon import FalconAttention\n\nfrom petals.client.config import ClientConfig\nfrom petals.client.lm_head import LMHeadConfig\nfrom petals.client.ptune import PTuneConfig\nfrom petals.models.falcon.block import WrappedFalconBlock\nfrom petals.utils.auto_config import DefaultRevisionMixin\n\nlogger = get_logger(__name__)\n\n\nclass DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):\n    block_class = WrappedFalconBlock\n    attn_class = FalconAttention\n    block_prefix = \"transformer.h\"\n\n    @property\n    def num_key_value_groups(self) -> int:\n        if self.new_decoder_architecture:\n            return self.num_attention_heads // self.num_kv_heads\n        if self.multi_query:\n            return self.num_attention_heads\n        return 1\n\n    @classmethod\n    def from_pretrained(\n        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs\n    ):\n        if \"180B\" in model_name_or_path.upper():\n            logger.info(\"Make sure you follow the Falcon-180B license: https://bit.ly/falcon-180b-license\")\n\n        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)\n        if loading_from_repo and dht_prefix is None:\n            dht_prefix = str(model_name_or_path)\n            dht_prefix = dht_prefix.split(\"/\")[-1]  # Use only repo name to merge blocks hosted by different accounts\n            dht_prefix = dht_prefix.replace(\".\", \"-\")\n            logger.info(f\"Using DHT prefix: {dht_prefix}\")\n\n        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)\n        config = result[0] if isinstance(result, tuple) else result\n        if config.pad_token_id is None:\n  
          config.pad_token_id = 0\n        return result\n"
  },
  {
    "path": "src/petals/models/falcon/model.py",
    "content": "from typing import Optional\n\nimport hivemind\nimport torch\nimport torch.nn as nn\nfrom hivemind.utils.logging import get_logger\nfrom transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions\nfrom transformers.models.falcon import (\n    FalconForCausalLM,\n    FalconForSequenceClassification,\n    FalconModel,\n    FalconPreTrainedModel,\n)\n\nfrom petals.client.from_pretrained import FromPretrainedMixin\nfrom petals.client.lm_head import LMHead\nfrom petals.client.ptune import PTuneMixin\nfrom petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues\nfrom petals.client.remote_sequential import RemoteSequential\nfrom petals.models.falcon.config import DistributedFalconConfig\nfrom petals.utils.auto_config import DefaultRevisionMixin\n\nlogger = get_logger(__name__)\n\n\nclass DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):\n    \"\"\"FalconModel, but all transformer layers are hosted by the swarm\"\"\"\n\n    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing\n    _keys_to_ignore_on_load_unexpected = [r\"^transformer\\.h\\.\"]\n\n    config_class = DistributedFalconConfig\n\n    def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None):\n        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization\n        super().__init__(config)\n        assert len(self.h) == 0\n        config.num_hidden_layers = n_layer\n\n        self.h = RemoteSequential(config, dht=dht)\n\n        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm\n        self.init_prompts(config)\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[RemotePastKeyValues] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        
head_mask: Optional[torch.LongTensor] = None,\n        inputs_embeds: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ):\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        # The causal mask will be added on the server-side\n        assert (\n            attention_mask is None or (attention_mask == 1).all()\n        ), f\"Custom attention masks are not supported, {attention_mask=}\"\n        assert (\n            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()\n        ), f\"Non-consecutive position_ids are not supported, {position_ids=}\"\n        assert head_mask is None, f\"Custom head masks are not supported, {head_mask=}\"\n        assert use_cache is None or use_cache, f\"{use_cache=} is not supported\"\n        assert not output_attentions, f\"{output_attentions=} is not supported\"\n        assert not output_hidden_states, f\"{output_hidden_states=} is not supported\"\n        assert return_dict is None or return_dict, f\"{return_dict=} is not supported\"\n\n        if inputs_embeds is None:\n            inputs_embeds = self.word_embeddings(input_ids)\n\n        use_prompts = self.config.tuning_mode and \"ptune\" in self.config.tuning_mode and self.h.position == 0\n        if use_prompts:\n            batch_size = inputs_embeds.shape[0]\n            prompts, intermediate_prompts = 
self.get_prompt(batch_size)\n            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)\n        else:\n            prompts = intermediate_prompts = None\n\n        hidden_states = self.word_embeddings_layernorm(inputs_embeds)\n        output_shape = input_shape + (hidden_states.size(-1),)\n\n        hidden_states = self.h(\n            hidden_states,\n            prompts=intermediate_prompts,\n            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,\n        )\n\n        # Remove prefix\n        if use_prompts:\n            hidden_states = hidden_states[:, self.pre_seq_len :]\n\n        # Add last hidden state\n        hidden_states = self.ln_f(hidden_states)\n        hidden_states = hidden_states.view(output_shape)\n        return BaseModelOutputWithPastAndCrossAttentions(\n            last_hidden_state=hidden_states,\n            past_key_values=RemotePastKeyValues(),\n            hidden_states=None,\n            attentions=None,\n        )\n\n    @property\n    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin\n        return nn.Identity()\n\n\nclass DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):\n    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing\n    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected\n\n    config_class = DistributedFalconConfig\n\n    def __init__(self, config: DistributedFalconConfig):\n        FalconPreTrainedModel.__init__(self, config)\n        self.transformer = DistributedFalconModel(config)\n        self.lm_head = LMHead(config)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n\nclass DistributedFalconForSequenceClassification(\n    DefaultRevisionMixin, FromPretrainedMixin, 
FalconForSequenceClassification\n):\n    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing\n    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected\n\n    config_class = DistributedFalconConfig\n\n    def __init__(self, config: DistributedFalconConfig):\n        FalconPreTrainedModel.__init__(self, config)\n        self.num_labels = config.num_labels\n\n        self.transformer = DistributedFalconModel(config)\n        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n"
  },
  {
    "path": "src/petals/models/llama/__init__.py",
    "content": "from petals.models.llama.block import WrappedLlamaBlock\nfrom petals.models.llama.config import DistributedLlamaConfig\nfrom petals.models.llama.model import (\n    DistributedLlamaForCausalLM,\n    DistributedLlamaForSequenceClassification,\n    DistributedLlamaModel,\n)\nfrom petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration\nfrom petals.utils.auto_config import register_model_classes\n\nregister_model_classes(\n    config=DistributedLlamaConfig,\n    model=DistributedLlamaModel,\n    model_for_causal_lm=DistributedLlamaForCausalLM,\n    model_for_speculative=DistributedLlamaForSpeculativeGeneration,\n    model_for_sequence_classification=DistributedLlamaForSequenceClassification,\n)\n"
  },
  {
    "path": "src/petals/models/llama/block.py",
    "content": "\"\"\"\nLLaMA intermediate layer\nBased on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py\nSee commit history for authorship.\n\"\"\"\nimport math\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask\nfrom transformers.models.llama.modeling_llama import (\n    LlamaAttention,\n    LlamaConfig,\n    LlamaDecoderLayer,\n    LlamaMLP,\n    LlamaRMSNorm,\n    repeat_kv,\n    rotate_half,\n)\n\nfrom petals.utils.cuda_graphs import make_inference_graphed_callable\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin):\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass OptimizedLlamaAttention(LlamaAttention):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._rotary_graph = None\n\n    def _optimized_apply_rotary(self, query_states, key_states, cos, sin):\n        if self._rotary_graph is None:\n            self._rotary_graph = make_inference_graphed_callable(\n                apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin)\n            )\n        return self._rotary_graph(query_states, key_states, cos, sin)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        assert not output_attentions\n        if position_ids is None:\n            past_seen_tokens = past_key_value[0].shape[2] if 
past_key_value is not None else 0\n            position_ids = torch.arange(\n                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device\n            ).unsqueeze(0)\n\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.config.pretraining_tp > 1:\n            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp\n            query_slices = self.q_proj.weight.split(\n                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0\n            )\n            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)\n            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)\n\n            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]\n            query_states = torch.cat(query_states, dim=-1)\n\n            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]\n            key_states = torch.cat(key_states, dim=-1)\n\n            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]\n            value_states = torch.cat(value_states, dim=-1)\n\n        else:\n            query_states = self.q_proj(hidden_states)\n            key_states = self.k_proj(hidden_states)\n            value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        cos, sin = self.rotary_emb(value_states, position_ids)\n        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)\n\n        if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == \"cuda\":\n            query_states, 
key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)\n        else:\n            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        if past_key_value is not None:\n            # reuse k, v, self_attention\n            key_states = torch.cat([past_key_value[0], key_states], dim=2)\n            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n        past_key_value = (key_states, value_states) if use_cache else None\n\n        # repeat k/v heads if n_kv_heads < n_heads\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attention_mask is not None:\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n        if self.config.pretraining_tp > 1:\n            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)\n            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)\n            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])\n        else:\n            attn_output = self.o_proj(attn_output)\n\n        return attn_output, None, past_key_value\n\n\nclass OptimizedLlamaDecoderLayer(LlamaDecoderLayer):\n    def __init__(self, config: LlamaConfig):\n        nn.Module.__init__(self)\n        self.hidden_size = config.hidden_size\n        self.self_attn = 
OptimizedLlamaAttention(config=config, layer_idx=0)\n        # layer_idx only matters for KV caching, and we re-implement it in Petals\n        self.mlp = LlamaMLP(config)\n        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.pre_attn_graph = None\n        self.post_attn_graph = None\n\n    def _optimized_input_layernorm(self, hidden_states):\n        if self.pre_attn_graph is None:\n            self.pre_attn_graph = make_inference_graphed_callable(\n                self.input_layernorm.forward, sample_args=(hidden_states,)\n            )\n        return self.pre_attn_graph(hidden_states)\n\n    def _optimized_output_layernorm(self, hidden_states):\n        if self.post_attn_graph is None:\n            self.post_attn_graph = make_inference_graphed_callable(\n                self.post_attention_layernorm.forward, sample_args=(hidden_states,)\n            )\n        return self.post_attn_graph(hidden_states)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                
Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == \"cuda\":\n            hidden_states = self._optimized_input_layernorm(hidden_states)\n        else:\n            hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n\n        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == \"cuda\":\n            hidden_states = self._optimized_output_layernorm(hidden_states)\n        else:\n            hidden_states = self.post_attention_layernorm(hidden_states)\n\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nclass 
WrappedLlamaBlock(OptimizedLlamaDecoderLayer):\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        *args,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        layer_past: Optional[Tuple[torch.Tensor]] = None,\n        use_cache: bool = False,\n        **kwargs,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        batch_size, seq_length, _ = hidden_states.shape\n\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n\n        past_key_value = layer_past\n        if past_key_value is not None:\n            past_key_values_length = past_key_value[0].shape[2]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)\n\n        assert position_ids is None\n\n        # embed positions\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device\n            )\n        attention_mask = _prepare_4d_causal_attention_mask(\n            attention_mask=attention_mask,\n            input_shape=(batch_size, seq_length),\n            inputs_embeds=hidden_states,\n            past_key_values_length=past_key_values_length,\n        )\n\n        outputs = super().forward(\n            hidden_states,\n            *args,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            use_cache=use_cache,\n            **kwargs,\n        )\n\n        if use_cache:\n            present_key_value = outputs[-1]\n            present_key_value = self._reorder_cache_from_llama_to_bloom(\n                present_key_value, batch_size, seq_length_with_past\n            )\n            outputs = 
outputs[:-1] + (present_key_value,)\n\n        return outputs\n\n    def _reorder_cache_from_bloom_to_llama(\n        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int\n    ) -> Tuple[torch.Tensor]:\n        key_states, value_states = key_value\n        key_states = key_states.permute(0, 2, 1)\n        key_states = key_states.view(\n            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim\n        )\n        value_states = value_states.view(*key_states.shape)\n        return (key_states, value_states)\n\n    def _reorder_cache_from_llama_to_bloom(\n        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int\n    ) -> Tuple[torch.Tensor]:\n        key_states, value_states = key_value\n        value_states = value_states.view(\n            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim\n        )\n        key_states = key_states.view(*value_states.shape)\n        key_states = key_states.permute(0, 2, 1)\n        return (key_states, value_states)\n"
  },
  {
    "path": "src/petals/models/llama/config.py",
    "content": "import os\nfrom typing import Optional, Union\n\nfrom hivemind import get_logger\nfrom transformers.models.llama import LlamaConfig\nfrom transformers.models.llama.modeling_llama import LlamaAttention\n\nfrom petals.client.config import ClientConfig\nfrom petals.client.lm_head import LMHeadConfig\nfrom petals.client.ptune import PTuneConfig\nfrom petals.models.llama.block import WrappedLlamaBlock\n\nlogger = get_logger(__name__)\n\n\nclass DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfig):\n    block_class = WrappedLlamaBlock\n    attn_class = LlamaAttention\n    block_prefix = \"model.layers\"\n\n    @property\n    def num_key_value_groups(self):\n        return self.num_attention_heads // self.num_key_value_heads\n\n    @classmethod\n    def from_pretrained(\n        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs\n    ):\n        logger.info(\n            \"Make sure you follow the Llama terms of use: \"\n            \"https://llama.meta.com/llama3/license, https://llama.meta.com/llama2/license\"\n        )\n\n        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)\n        if loading_from_repo and dht_prefix is None:\n            dht_prefix = str(model_name_or_path)\n            dht_prefix = dht_prefix.split(\"/\")[-1]  # Use only repo name to merge blocks hosted by different accounts\n            dht_prefix = dht_prefix.replace(\".\", \"-\")\n            if not dht_prefix.endswith(\"-hf\"):\n                dht_prefix += \"-hf\"\n            logger.info(f\"Using DHT prefix: {dht_prefix}\")\n\n        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)\n        config = result[0] if isinstance(result, tuple) else result\n        config.pretraining_tp = 1  # This may give less accurate results but it doesn't matter if we use quantization\n        config.use_cache = True  # 
use_cache=False leads to identical results but is slower and not supported by Petals\n        return result\n"
  },
  {
    "path": "src/petals/models/llama/model.py",
    "content": "from typing import Optional\n\nimport hivemind\nimport torch\nimport torch.nn as nn\nfrom hivemind.utils.logging import get_logger\nfrom transformers.modeling_outputs import BaseModelOutputWithPast\nfrom transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel\n\nfrom petals.client.from_pretrained import FromPretrainedMixin\nfrom petals.client.lm_head import LMHead\nfrom petals.client.ptune import PTuneMixin\nfrom petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues\nfrom petals.client.remote_sequential import RemoteSequential\nfrom petals.models.llama.config import DistributedLlamaConfig\n\nlogger = get_logger(__name__)\n\n\nclass DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):\n    \"\"\"LlamaModel, but all transformer layers are hosted by the swarm\"\"\"\n\n    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing\n    _keys_to_ignore_on_load_unexpected = [r\"^model\\.layers\\.\"]\n\n    config_class = DistributedLlamaConfig\n\n    def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hivemind.DHT] = None):\n        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization\n        super().__init__(config)\n        assert len(self.layers) == 0\n        config.num_hidden_layers = n_layer\n\n        self.layers = RemoteSequential(config, dht=dht)\n\n        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm\n        self.init_prompts(config)\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[RemotePastKeyValues] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: 
Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> BaseModelOutputWithPast:\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        # The causal mask will be added on the server-side\n        assert (\n            attention_mask is None or (attention_mask == 1).all()\n        ), f\"Custom attention masks are not supported, {attention_mask=}\"\n        if cache_position is not None:\n            assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()\n        assert (\n            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()\n        ), f\"Non-consecutive position_ids are not supported, {position_ids=}\"\n        assert use_cache is None or use_cache, f\"{use_cache=} is not supported\"\n        assert not output_attentions, f\"{output_attentions=} is not supported\"\n        assert not output_hidden_states, f\"{output_hidden_states=} is not supported\"\n        assert return_dict is None or return_dict, f\"{return_dict=} is not supported\"\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        use_prompts = self.config.tuning_mode and \"ptune\" in self.config.tuning_mode and self.layers.position == 0\n        if use_prompts:\n            batch_size = inputs_embeds.shape[0]\n            prompts, intermediate_prompts = self.get_prompt(batch_size)\n       
     inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)\n        else:\n            prompts = intermediate_prompts = None\n\n        hidden_states = inputs_embeds\n        output_shape = input_shape + (hidden_states.size(-1),)\n\n        hidden_states = self.layers(\n            hidden_states,\n            prompts=intermediate_prompts,\n            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,\n        )\n\n        if past_key_values is None:\n            past_key_values = RemotePastKeyValues()\n        past_key_values.update_seen(hidden_states.size(1))\n\n        # Remove prefix\n        if use_prompts:\n            hidden_states = hidden_states[:, self.pre_seq_len :]\n\n        # Add last hidden state\n        hidden_states = self.norm(hidden_states)\n        hidden_states = hidden_states.view(output_shape)\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values,\n            hidden_states=None,\n            attentions=None,\n        )\n\n    @property\n    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin\n        return self.embed_tokens\n\n    @property\n    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin\n        return nn.Identity()\n\n    @property\n    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin\n        return self.layers\n\n    @property\n    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin\n        return self.norm\n\n\nclass DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM):\n    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing\n    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected\n\n    config_class = DistributedLlamaConfig\n\n    def __init__(self, config: 
DistributedLlamaConfig):\n        LlamaPreTrainedModel.__init__(self, config)\n        self.model = DistributedLlamaModel(config)\n        self.pretraining_tp = config.pretraining_tp\n        self.vocab_size = config.vocab_size\n        self.lm_head = LMHead(config)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    @property\n    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin\n        return self.model\n\n\nclass DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):\n    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing\n    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected\n\n    config_class = DistributedLlamaConfig\n\n    def __init__(self, config):\n        LlamaPreTrainedModel.__init__(self, config)\n        self.num_labels = config.num_labels\n\n        self.model = DistributedLlamaModel(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    @property\n    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin\n        return self.model\n"
  },
  {
    "path": "src/petals/models/llama/speculative_model.py",
    "content": "from typing import Optional, Union\n\nimport torch\nfrom transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList\nfrom transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin\nfrom transformers.modeling_outputs import BaseModelOutputWithPast\nfrom transformers.models.llama import LlamaForCausalLM\n\nfrom petals.models.llama.config import DistributedLlamaConfig\nfrom petals.models.llama.model import DistributedLlamaForCausalLM\n\n\nclass DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):\n    def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM):\n        DistributedLlamaForCausalLM.__init__(self, config)\n        self.small_model = small_model\n\n    def _sample(\n        self,\n        input_ids: torch.LongTensor,\n        logits_processor: LogitsProcessorList,\n        stopping_criteria: StoppingCriteriaList,\n        generation_config: GenerationConfig,\n        synced_gpus: bool,\n        streamer: Optional[\"BaseStreamer\"],\n        logits_warper: Optional[LogitsProcessorList],\n        speculative_inference_iteration_size: int = 10,\n        **model_kwargs,\n    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:\n        assert not generation_config.do_sample, \"sample is not working for speculative generation now\"\n        assert not synced_gpus, \"synced_gpus is not working for speculative generation now\"\n        assert (\n            not generation_config.return_dict_in_generate\n        ), \"return_dict_in_generate is not working for speculative generation now\"\n\n        has_eos_stopping_criteria = any(hasattr(criteria, \"eos_token_id\") for criteria in stopping_criteria)\n\n        # keep track of which sequences are already finished\n        batch_size = input_ids.shape[0]\n        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)\n        finished = False\n        firsts = 
True\n\n        while not finished:\n            speculative_inference_iteration_size = min(\n                speculative_inference_iteration_size, self.active_session._max_length - input_ids.shape[1]\n            )\n            with torch.no_grad():\n                speculative_outputs = self.small_model.generate(\n                    input_ids,\n                    max_new_tokens=speculative_inference_iteration_size,\n                    do_sample=False,\n                )\n                speculative_tokens = speculative_outputs[:, -speculative_inference_iteration_size:]\n\n            full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1)\n            assert input_ids.shape[1] + speculative_inference_iteration_size == full_sequence.shape[1]\n\n            input_for_validation = full_sequence\n            if not firsts:\n                self.active_session.position = input_ids.shape[1] - 1\n                input_for_validation = input_for_validation[:, -speculative_inference_iteration_size - 1 :]\n            else:\n                firsts = False\n            input_for_validation = input_for_validation[:, :-1]\n            with torch.no_grad():\n                precise_model_outputs = self(input_for_validation)\n            full_token_logits = precise_model_outputs.logits[:, -speculative_inference_iteration_size:, :].clone()\n\n            all_valid_tokens = []\n            first_token = None\n            for i in range(speculative_inference_iteration_size):\n                token_logits = full_token_logits[:, i, :]\n                token_scores = logits_processor(\n                    input_for_validation[:, : -speculative_inference_iteration_size + 1 + i], token_logits\n                )\n                valid_token = torch.argmax(token_scores, dim=-1)\n\n                if first_token is None:\n                    first_token = valid_token\n\n                if valid_token.item() == speculative_tokens[:, i].item():\n                    
all_valid_tokens.append(valid_token.unsqueeze(-1))\n                else:\n                    break\n\n            if not all_valid_tokens and first_token is not None:\n                all_valid_tokens.append(first_token.unsqueeze(-1))\n            all_valid_tokens = torch.cat(all_valid_tokens, dim=-1)\n\n            # finished sentences should have their next token be a padding token\n            if has_eos_stopping_criteria:\n                all_valid_tokens = all_valid_tokens * unfinished_sequences + generation_config.pad_token_id * (\n                    1 - unfinished_sequences\n                )\n\n            # update generated ids, model inputs, and length for next step\n            input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1)\n\n            if streamer is not None:\n                streamer.put(all_valid_tokens.cpu())\n\n            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)\n            finished = unfinished_sequences.max() == 0\n\n            del precise_model_outputs\n\n        if streamer is not None:\n            streamer.end()\n\n        return input_ids\n"
  },
  {
    "path": "src/petals/models/mixtral/__init__.py",
    "content": "from petals.models.mixtral.block import WrappedMixtralBlock\nfrom petals.models.mixtral.config import DistributedMixtralConfig\nfrom petals.models.mixtral.model import (\n    DistributedMixtralForCausalLM,\n    DistributedMixtralForSequenceClassification,\n    DistributedMixtralModel,\n)\nfrom petals.utils.auto_config import register_model_classes\n\nregister_model_classes(\n    config=DistributedMixtralConfig,\n    model=DistributedMixtralModel,\n    model_for_causal_lm=DistributedMixtralForCausalLM,\n    model_for_sequence_classification=DistributedMixtralForSequenceClassification,\n)\n"
  },
  {
    "path": "src/petals/models/mixtral/block.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom transformers import MixtralConfig\nfrom transformers.cache_utils import DynamicCache\nfrom transformers.modeling_attn_mask_utils import (\n    _prepare_4d_causal_attention_mask,\n    _prepare_4d_causal_attention_mask_for_sdpa,\n)\nfrom transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer\n\n\nclass WrappedMixtralBlock(MixtralDecoderLayer):\n    def __init__(self, config: MixtralConfig, layer_idx: int):\n        super().__init__(config, layer_idx)\n\n        self._attn_implementation = config._attn_implementation\n        self.sliding_window = config.sliding_window\n        self.layer_idx = layer_idx\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        *args,\n        attention_mask: Optional[torch.Tensor] = None,\n        layer_past: Optional[Tuple[torch.Tensor]] = None,\n        use_cache: bool = False,\n        **kwargs\n    ):\n        batch_size, seq_length, _ = hidden_states.shape\n\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n\n        past_key_value = layer_past\n\n        if past_key_value is not None:\n            past_key_values_length = past_key_value[0].shape[2]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n            _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)\n            past_key_value = DynamicCache()\n            past_key_value.key_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[0]]\n            past_key_value.value_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[1]]\n            past_key_value._seen_tokens = past_key_values_length\n\n        if self._attn_implementation == \"flash_attention_2\":\n            # 2d mask is passed through the layers\n            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) 
else None\n        elif self._attn_implementation == \"sdpa\":\n            # output_attentions=True can not be supported when using SDPA, and we fall back on\n            # the manual implementation that requires a 4D causal mask in all cases.\n            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(\n                attention_mask,\n                (batch_size, seq_length),\n                hidden_states,\n                past_key_values_length,\n            )\n        else:\n            # 4d mask is passed through the layers\n            attention_mask = _prepare_4d_causal_attention_mask(\n                attention_mask,\n                (batch_size, seq_length),\n                hidden_states,\n                past_key_values_length,\n                sliding_window=self.sliding_window,\n            )\n\n        position_ids = torch.arange(\n            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device\n        )\n        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n\n        outputs = super().forward(\n            hidden_states,\n            *args,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            use_cache=use_cache,\n            **kwargs\n        )\n\n        if use_cache:\n            present_key_value = outputs[-1]\n            present_key_value = present_key_value[self.layer_idx]\n            present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)\n            outputs = outputs[:-1] + (present_key_value,)\n\n        return outputs\n\n    def _reorder_cache_from_bloom(\n        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int\n    ) -> Tuple[torch.Tensor]:\n        # TODO: Move to mixin\n        key_states, value_states = key_value\n        key_states = key_states.permute(0, 2, 1)\n        key_states = 
key_states.view(\n            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim\n        )\n        value_states = value_states.view(*key_states.shape)\n        return (key_states, value_states)\n\n    def _reorder_cache_to_bloom(\n        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int\n    ) -> Tuple[torch.Tensor]:\n        # TODO: Move to mixin\n        key_states, value_states = key_value\n        value_states = value_states.view(\n            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim\n        )\n        key_states = key_states.view(*value_states.shape)\n        key_states = key_states.permute(0, 2, 1)\n        return (key_states, value_states)\n"
  },
  {
    "path": "src/petals/models/mixtral/config.py",
    "content": "import os\nfrom typing import Optional, Union\n\nfrom hivemind import get_logger\nfrom transformers.models.mixtral import MixtralConfig\nfrom transformers.models.mixtral.modeling_mixtral import MixtralAttention\n\nfrom petals.client.config import ClientConfig\nfrom petals.client.lm_head import LMHeadConfig\nfrom petals.client.ptune import PTuneConfig\nfrom petals.models.mixtral.block import WrappedMixtralBlock\n\nlogger = get_logger(__name__)\n\n\nclass DistributedMixtralConfig(MixtralConfig, ClientConfig, PTuneConfig, LMHeadConfig):\n    block_class = WrappedMixtralBlock\n    attn_class = MixtralAttention\n    block_prefix = \"model.layers\"\n\n    num_key_value_groups = 1\n\n    @classmethod\n    def from_pretrained(\n        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs\n    ):\n        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)\n        if loading_from_repo and dht_prefix is None:\n            dht_prefix = str(model_name_or_path)\n            dht_prefix = dht_prefix.replace(\".\", \"-\")\n            logger.info(f\"Using DHT prefix: {dht_prefix}\")\n        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)\n        config = result[0] if isinstance(result, tuple) else result\n        if config.pad_token_id is None:\n            config.pad_token_id = 0\n        return result\n"
  },
  {
    "path": "src/petals/models/mixtral/model.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn as nn\nfrom hivemind import DHT\nfrom hivemind.utils.logging import get_logger\nfrom transformers.modeling_outputs import MoeModelOutputWithPast\nfrom transformers.models.mixtral import (\n    MixtralForCausalLM,\n    MixtralForSequenceClassification,\n    MixtralModel,\n    MixtralPreTrainedModel,\n)\n\nfrom petals.client.from_pretrained import FromPretrainedMixin\nfrom petals.client.lm_head import LMHead\nfrom petals.client.ptune import PTuneMixin\nfrom petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues\nfrom petals.client.remote_sequential import RemoteSequential\nfrom petals.models.mixtral.config import DistributedMixtralConfig\nfrom petals.utils.auto_config import DefaultRevisionMixin\n\nlogger = get_logger(__name__)\n\n\nclass DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, MixtralModel):\n    \"\"\"MixtralModel, but all transformer layers are hosted by the swarm\"\"\"\n\n    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing\n    _keys_to_ignore_on_load_unexpected = [r\"^model\\.layers\\.\"]\n\n    config_class = DistributedMixtralConfig\n\n    def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None):\n        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization\n        super().__init__(config)\n        assert len(self.layers) == 0\n        config.num_hidden_layers = n_layer\n\n        self.layers = RemoteSequential(config, dht=dht)\n\n        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm\n        self.init_prompts(config)\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[RemotePastKeyValues] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        
head_mask: Optional[torch.LongTensor] = None,\n        inputs_embeds: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ):\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        # The causal mask will be added on the server-side\n        assert (\n            attention_mask is None or (attention_mask == 1).all()\n        ), f\"Custom attention masks are not supported, {attention_mask=}\"\n        if cache_position is not None:\n            assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()\n        assert (\n            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()\n        ), f\"Non-consecutive position_ids are not supported, {position_ids=}\"\n        assert head_mask is None, f\"Custom head masks are not supported, {head_mask=}\"\n        assert use_cache is None or use_cache, f\"{use_cache=} is not supported\"\n        assert not output_attentions, f\"{output_attentions=} is not supported\"\n        assert not output_hidden_states, f\"{output_hidden_states=} is not supported\"\n        assert return_dict is None or return_dict, f\"{return_dict=} is not supported\"\n        assert not output_router_logits, f\"{output_router_logits=} is not 
supported\"\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        use_prompts = self.config.tuning_mode and \"ptune\" in self.config.tuning_mode and self.h.position == 0\n        if use_prompts:\n            batch_size = inputs_embeds.shape[0]\n            prompts, intermediate_prompts = self.get_prompt(batch_size)\n            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)\n        else:\n            prompts = intermediate_prompts = None\n\n        hidden_states = inputs_embeds\n        output_shape = input_shape + (hidden_states.size(-1),)\n\n        if past_key_values is None:\n            past_key_values = RemotePastKeyValues()\n        past_key_values.update_seen(hidden_states.size(1))\n\n        hidden_states = self.layers(\n            hidden_states,\n            prompts=intermediate_prompts,\n            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,\n        )\n\n        # Remove prefix\n        if use_prompts:\n            hidden_states = hidden_states[:, self.pre_seq_len :]\n\n        # Add last hidden state\n        hidden_states = self.norm(hidden_states)\n        hidden_states = hidden_states.view(output_shape)\n        return MoeModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values,\n            hidden_states=None,\n            attentions=None,\n        )\n\n    @property\n    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin\n        return self.embed_tokens\n\n    @property\n    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin in tests\n        return nn.Identity()\n\n    @property\n    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin\n        return self.layers\n\n    @property\n    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin in tests\n        
return self.norm\n\n\nclass DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):\n    _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing\n    _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected\n\n    config_class = DistributedMixtralConfig\n\n    def __init__(self, config: DistributedMixtralConfig):\n        MixtralPreTrainedModel.__init__(self, config)\n        self.model = DistributedMixtralModel(config)\n        self.lm_head = LMHead(config)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    @property\n    def transformer(self) -> DistributedMixtralModel:  # For compatibility with RemoteGenerationMixin\n        return self.model\n\n\nclass DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification):\n    _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing\n    _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected\n\n    config_class = DistributedMixtralConfig\n\n    def __init__(self, config: DistributedMixtralConfig):\n        MixtralPreTrainedModel.__init__(self, config)\n        self.num_labels = config.num_labels\n\n        self.model = DistributedMixtralModel(config)\n        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    @property\n    def transformer(self) -> DistributedMixtralModel:  # For compatibility with RemoteGenerationMixin\n        return self.model\n"
  },
  {
    "path": "src/petals/server/__init__.py",
    "content": ""
  },
  {
    "path": "src/petals/server/backend.py",
    "content": "from __future__ import annotations\n\nfrom collections import Counter\nfrom itertools import chain\nfrom typing import Any, Dict, Optional, Sequence, Tuple, Union\n\nimport torch\nfrom hivemind import BatchTensorDescriptor, TensorDescriptor\nfrom hivemind.moe.expert_uid import ExpertUID\nfrom hivemind.moe.server.module_backend import ModuleBackend\nfrom hivemind.utils import get_logger\nfrom tensor_parallel import TensorParallel\nfrom tensor_parallel.tensor_parallel import PerDeviceTensors\nfrom transformers import PretrainedConfig\n\nfrom petals.data_structures import InferenceMetadata\nfrom petals.server.memory_cache import MemoryCache\nfrom petals.server.task_pool import PrioritizedTaskPool\nfrom petals.utils.misc import get_size_in_bytes, is_dummy\n\nlogger = get_logger(__name__)\n\n\nclass TransformerBackend(ModuleBackend):\n    \"\"\"A wrapper for a transformer block that can process requests for forward, backward and inference\"\"\"\n\n    _peft_module = None\n\n    def __init__(\n        self,\n        *args,\n        config: PretrainedConfig,\n        memory_cache: MemoryCache,\n        backend_dtype: torch.dtype,\n        max_chunk_size_bytes: int,\n        **kwargs,\n    ):\n        import petals.utils.peft as _peft_module\n\n        self._peft_module = _peft_module\n\n        super().__init__(*args, **kwargs)\n        assert isinstance(self.module, TensorParallel)\n        self.config = config\n        self.memory_cache = memory_cache\n        self.max_chunk_size_bytes = max_chunk_size_bytes\n\n        for name, param in self.module.named_parameters():\n            assert not param.requires_grad, f\"Block parameters must not accumulate gradients, but {name} does\"\n        for name, buf in self.module.named_buffers():\n            assert not buf.requires_grad, f\"Block parameters must not accumulate gradients, but {name} does\"\n\n        max_batch_size = self.forward_pool.max_batch_size\n        device = 
self.module.devices[self.module.output_device_index]\n        self.inference_pool = PrioritizedTaskPool(\n            self.inference_step, max_batch_size=max_batch_size, device=device, name=f\"{self.name}_inference\"\n        )  # note: inference_pools may be merged later, see merge_inference_pools_inplace\n        self.forward_pool = PrioritizedTaskPool(\n            self.forward, max_batch_size=max_batch_size, device=device, name=f\"{self.name}_forward\"\n        )\n        self.backward_pool = PrioritizedTaskPool(\n            self.backward, max_batch_size=max_batch_size, device=device, name=f\"{self.name}_backward\"\n        )\n\n        self.dtype = backend_dtype\n        self.dtype_bytes = get_size_in_bytes(self.dtype)\n        self.shard_num_heads = []\n        for shard in self.module.module_shards:\n            for submodule in shard.modules():\n                if isinstance(submodule, config.attn_class):\n                    self.shard_num_heads.append(submodule.num_heads)\n        assert len(self.shard_num_heads) == len(self.module.devices)\n        assert sum(self.shard_num_heads) == config.num_attention_heads\n\n        self.inference_schema = (\n            (\n                *self.args_schema,\n                BatchTensorDescriptor((), dtype=self.dtype),\n                BatchTensorDescriptor((), dtype=torch.int64),\n            ),\n            self.kwargs_schema,\n        )\n\n        self.cache_bytes_per_token: Dict[torch.device, int] = Counter()\n        for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):\n            self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype)\n\n    def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:\n        \"\"\"Create tensor descriptors for attention cache tensors used during inference_step\"\"\"\n        head_dim = self.config.hidden_size // self.config.num_attention_heads\n        cache_tensors = 
[]\n        for device, num_heads in zip(self.module.devices, self.shard_num_heads):\n            num_heads //= self.config.num_key_value_groups\n            if hasattr(self.config, \"num_key_value_heads\"):\n                num_heads = self.config.num_key_value_heads\n            keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)\n            values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)\n            cache_tensors.extend((keys, values))\n        return cache_tensors\n\n    def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:\n        *inputs, active_adapter = inputs\n        with self._peft_module.using_adapter(active_adapter):\n            return super().forward(*inputs)\n\n    def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:\n        *inputs, active_adapter = inputs\n        with self._peft_module.using_adapter(active_adapter):\n            return super().backward(*inputs)\n\n    @torch.inference_mode()\n    def inference_step(\n        self,\n        hidden_states: torch.Tensor,\n        hypo_ids: torch.LongTensor,\n        inference_info: InferenceMetadata,\n    ) -> Tuple[torch.Tensor, ...]:\n        assert hidden_states.ndim == 3, \"expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]\"\n        seq_len = hidden_states.shape[1]\n\n        with self.memory_cache.use_cache(\n            *inference_info.cache_handles\n        ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):\n            self._reorder_cache_inplace(cache_tensors, hypo_ids)\n\n            # We chunk the inputs so that peak memory for long sequences fits into `autograd_memory`\n            # reserved in `Server._choose_num_blocks()`. 
This saves us from OOMs if `max_chunk_size_bytes`\n            # is at least 4-6x less than `autograd_memory`.\n            max_chunk_length = self._estimate_max_chunk_length(hidden_states, inference_info)\n            output_hidden_states = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None\n            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)\n            for offset in range(0, seq_len, max_chunk_length):\n                hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]\n                output_hidden_states_chunk, new_kvs = self.module.forward(\n                    hidden_states_chunk, layer_past=layer_past, use_cache=True\n                )\n                if seq_len > max_chunk_length:\n                    output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk\n                else:\n                    output_hidden_states = output_hidden_states_chunk  # saves one memcopy\n                layer_past = new_kvs\n\n            self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)\n            return (output_hidden_states,)\n\n    def _estimate_max_chunk_length(self, hidden_states: torch.Tensor, inference_info: InferenceMetadata) -> int:\n        # We assume that attention logit matrices are the main thing that consumes memory, given that\n        # the model uses multi-query attention\n        batch_size, seq_length, hidden_size = hidden_states.shape\n        worst_case_length = inference_info.prefix_length + seq_length\n        attn_bytes_per_token = max(self.shard_num_heads) * batch_size * self.dtype_bytes * worst_case_length\n        return max(1, self.max_chunk_size_bytes // attn_bytes_per_token)\n\n    def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):\n        \"\"\"If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from 
hypo_ids\"\"\"\n        if not is_dummy(hypo_ids):\n            for cache_tensor in cache_tensors:\n                cache_tensor[...] = cache_tensor[hypo_ids.to(cache_tensor.device)]  # in-place reorder cache by hypo ids\n\n    def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:\n        \"\"\"Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past\"\"\"\n        key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])\n        for i in range(len(key_cache)):\n            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]\n            # shape: [batch * num_kv_heads, head_dim, kv_length]\n            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]\n            # shape: [batch * num_kv_heads, kv_length, head_dim]\n        layer_past = tuple(chain(*zip(key_cache, value_cache)))\n        return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past\n\n    def _update_cache_inplace(\n        self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int\n    ):\n        \"\"\"Writes new key/value tensors back into cache, works in-place\"\"\"\n        _batch_size_times_num_kv_heads, head_dim, new_length = new_kvs[0].shape\n        for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):\n            new_key = new_key.view(*cache_key.shape[:3], new_length)\n            cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]\n        for cache_value, new_value in zip(cache_tensors[1::2], new_kvs[1::2]):\n            new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim)\n            cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]\n\n    def get_pools(self) -> Sequence[PrioritizedTaskPool]:\n        return self.forward_pool, self.backward_pool, 
self.inference_pool\n\n    def get_info(self) -> Dict[str, Any]:\n        \"\"\"Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration.\"\"\"\n        return dict(super().get_info(), inference_schema=self.inference_schema)\n\n    def shutdown(self):\n        # Break the cyclic references, otherwise TransformerBackend may be not garbage-collected\n        self.forward_pool = self.backward_pool = self.inference_pool = None\n\n        # Explicitly free the GPU memory. This is not necessary at the time this code is written,\n        # but may help to avoid future issues when the module is not garbage-collected for some reasons\n        dummy = torch.tensor([])\n        for p in self.module.parameters():\n            p.data = dummy\n\n\ndef merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):\n    \"\"\"Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call\"\"\"\n    assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())\n    first_pool = next(iter(backends.values())).inference_pool\n    merged_pool = PrioritizedTaskPool(\n        _MergedInferenceStep(backends),\n        max_batch_size=first_pool.max_batch_size,\n        device=first_pool.device,\n        name=f\"merged_inference\",\n    )\n    for backend in backends.values():\n        assert not backend.inference_pool.is_alive()\n        backend.inference_pool = merged_pool\n\n\nclass _MergedInferenceStep:\n    def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):\n        self.backends = backends\n\n    @torch.inference_mode()\n    def __call__(\n        self,\n        hidden_states: torch.Tensor,\n        hypo_ids: torch.LongTensor,\n        inference_infos: Sequence[InferenceMetadata],\n        *optional_prompts: Optional[torch.Tensor],\n    ) -> Tuple[torch.Tensor, ...]:\n        assert len(inference_infos) == len(\n            optional_prompts\n        
), f\"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts\"\n        for inference_info, optional_prompt in zip(inference_infos, optional_prompts):\n            if optional_prompt is not None:\n                hidden_states[:, : optional_prompt.shape[1]] += optional_prompt\n            (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)\n        return (hidden_states,)\n"
  },
  {
    "path": "src/petals/server/block_functions.py",
    "content": "\"\"\"\nThis module implements server-side computations on served blocks: forward, backward and inference; used by handler\n\"\"\"\nfrom __future__ import annotations\n\nfrom typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union\n\nimport torch\nfrom hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor\nfrom hivemind.moe.expert_uid import ExpertUID\nfrom hivemind.proto import runtime_pb2\nfrom hivemind.utils.logging import get_logger\nfrom hivemind.utils.nested import nested_flatten\n\nfrom petals.data_structures import Handle, InferenceMetadata\nfrom petals.server.backend import TransformerBackend\nfrom petals.server.task_pool import PrioritizedTaskPool\nfrom petals.server.task_prioritizer import TaskPrioritizerBase\nfrom petals.utils.convert_block import QuantType\nfrom petals.utils.misc import DUMMY, is_dummy\nfrom petals.utils.packaging import unpack_args_kwargs\n\n# We prioritize short inference requests and make them use a *merged* inference pool,\n# so they are processed without interruptions and extra overheads\n# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward\nMAX_SHORT_INFERENCE_TOKENS = 128\nMAX_NF4_SHORT_INFERENCE_TOKENS = 1\n\nlogger = get_logger(__name__)\n\n\nasync def run_rpc_forward(\n    *flat_tensors: torch.Tensor,\n    requested_backends: Sequence[TransformerBackend],\n    active_adapter: str = \"\",\n    prioritizer: TaskPrioritizerBase,\n    points: int = 0,\n    args_structure: Any = None,\n) -> torch.Tensor:\n    \"\"\"\n    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream\n\n    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors\n    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)\n    :param requested_backends: a sequence of transformer blocks in the 
same order as they appear in forward pass\n    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]\n    \"\"\"\n    if args_structure is not None:\n        # TODO: kwargs currently is unused, it can be used later for peft-like adaptation\n        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)\n    hidden_states, prompts, *_ = flat_tensors\n\n    dtype = requested_backends[0].dtype\n    # check parse input tensors and cast dtypes\n    hidden_states = hidden_states.to(dtype)\n    assert hidden_states.ndim == 3\n    if prompts is None or is_dummy(prompts):\n        prompts = [DUMMY] * len(requested_backends)\n    else:\n        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]\n\n    # Run a chain of requested backends\n    for backend, prompt in zip(requested_backends, prompts):\n        if not is_dummy(prompt):\n            hidden_states[:, : prompt.shape[1]] += prompt\n\n        assert isinstance(backend.inference_pool, PrioritizedTaskPool), \"petals support only prioritized pools\"\n        priority = prioritizer.prioritize(\n            hidden_states, points=points / len(requested_backends), backend=backend, type=\"forward\"\n        )\n        (hidden_states,) = await backend.forward_pool.submit_task(\n            hidden_states,\n            active_adapter,\n            priority=priority,\n        )\n        assert isinstance(hidden_states, torch.Tensor)\n        assert (\n            hidden_states.ndim == 3\n        ), f\"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states\"\n\n    return hidden_states\n\n\nasync def run_rpc_backward(\n    *flat_tensors: torch.Tensor,\n    requested_backends: Sequence[TransformerBackend],\n    active_adapter: str = \"\",\n    prioritizer: TaskPrioritizerBase,\n    points: int = 0,\n    args_structure: Any = None,\n) -> Union[torch.Tensor, Sequence[torch.Tensor]]:\n    if args_structure is not None:\n        
# TODO: kwargs currently is unused, it can be used later for peft-like adaptation\n        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)\n    inputs, grad_outputs, prompts, *_ = flat_tensors\n\n    # Cast inputs & grad outputs to backend dtype\n    inputs = inputs.to(requested_backends[0].dtype)\n    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)\n\n    if prompts is None or is_dummy(prompts):\n        prompts = [DUMMY] * len(requested_backends)\n    else:\n        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]\n\n    # Run a forward chain to collect intermediate inputs\n    # Note that we do not forward for the last module since we do not need its output\n    inter_inputs = []\n    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):\n        assert inputs.ndim == 3, f\"inputs to {type(backend)} must be a single 3d tensor of hidden states\"\n        if not is_dummy(prompt):\n            inputs[:, : prompt.shape[1]] += prompt\n        inter_inputs.append(inputs)\n        assert isinstance(backend.inference_pool, PrioritizedTaskPool), \"petals support only prioritized pools\"\n        priority = prioritizer.prioritize(\n            inputs, points=points / len(requested_backends), backend=backend, type=\"forward_in_backward\"\n        )\n        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)\n\n        assert isinstance(inputs, torch.Tensor)\n\n    if not is_dummy(prompts[-1]):\n        inputs[:, : prompts[-1].shape[1]] += prompts[-1]\n    inter_inputs.append(inputs)\n\n    assert len(inter_inputs) == len(prompts) == len(requested_backends), \"internal shape error during backward\"\n    grad_prompts_reversed = []\n    # Run a chain of requested backends\n    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):\n        assert isinstance(backend.inference_pool, PrioritizedTaskPool), 
\"petals support only prioritized pools\"\n        priority = prioritizer.prioritize(\n            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type=\"backward\"\n        )\n        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)\n\n        assert isinstance(grad_outputs, torch.Tensor)\n        if not is_dummy(prompt):\n            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))\n\n    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY\n    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape\n\n\nasync def iterate_rpc_inference(\n    requested_uids: Sequence[ExpertUID],\n    requested_backends: Sequence[TransformerBackend],\n    active_adapter: Optional[str],\n    input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],\n    cache_handles: Sequence[Sequence[Handle]],\n    *,\n    max_length: int,\n    prioritizer: TaskPrioritizerBase,\n    points: int,\n    quant_type: QuantType,\n    args_structure: Any = None,\n) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool, Dict]]:\n    assert len(cache_handles) == len(requested_backends)\n\n    prefix_length = 0\n    point_per_piece = points / max_length if max_length > 0 else 0.0\n\n    async for request, step_metadata in input_iterator:\n        if \"start_from_position\" in step_metadata:\n            start_from_position = step_metadata[\"start_from_position\"]\n            assert (\n                prefix_length >= start_from_position,\n            ), f\"prefix_length={prefix_length}, start_from_position={start_from_position}\"\n            prefix_length = start_from_position\n\n        flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)\n        if args_structure is not None:\n            # TODO: kwargs currently is unused, it can be used later 
for peft-like adaptation\n            flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)\n\n        hidden_states, prompts, hypo_ids, *_ = flat_tensors\n        batch_size, length_increment, _ = hidden_states.shape\n\n        # Cast inputs to backend dtype\n        hidden_states = hidden_states.to(requested_backends[0].dtype)\n        assert hypo_ids.dtype == torch.int64, f\"hypo ids must be int64, got {hypo_ids.dtype}\"\n\n        # parse deep prompts (optional argument)\n        has_prompts = prompts is not None and not is_dummy(prompts)\n        if not has_prompts:\n            prompts = [None] * len(requested_backends)\n        else:\n            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]\n            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]\n\n        if not (len(requested_backends) == len(prompts)):\n            raise ValueError(f\"Received {len(prompts)} prompts for {len(requested_backends)} backends\")\n\n        if prefix_length + length_increment > max_length:\n            raise ValueError(\n                f\"Maximum length exceeded: prefix {prefix_length} + current {length_increment}\"\n                f\" exceeds pre-allocated maximum {max_length}\"\n            )\n\n        merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS\n        can_merge_pools = batch_size * length_increment <= merge_max_tokens\n        priority = prioritizer.prioritize(\n            hidden_states,\n            hypo_ids,\n            points=point_per_piece,\n            requested_uids=requested_uids,\n            type=\"inference\",\n        )\n\n        # A client may pass a tensor with 0 tokens. 
This is a special case that occurs, e.g.\n        # when user wants to pre-allocate cache or check that server *can* allocate that cache.\n        if hidden_states.numel() > 0:\n            assert hidden_states.ndim == 3, f\"hidden states must be a single 3d tensor\"\n            if can_merge_pools:\n                inference_infos = tuple(\n                    InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)\n                    for uid, handles in zip(requested_uids, cache_handles)\n                )\n                (hidden_states,) = await requested_backends[0].inference_pool.submit_task(\n                    hidden_states, hypo_ids, inference_infos, *prompts, priority=priority\n                )\n            else:\n                for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):\n                    inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)\n                    (hidden_states,) = await backend.inference_pool.submit_task(\n                        hidden_states, hypo_ids, inference_infos, prompt, priority=priority\n                    )\n\n        # serialize and send last layer outputs\n        output_tensors = [\n            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)\n            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))\n        ]\n        can_push = not has_prompts\n        yield output_tensors, can_push, step_metadata\n\n        # prepare for next step\n        prefix_length += length_increment\n"
  },
  {
    "path": "src/petals/server/block_selection.py",
    "content": "from typing import Dict, List\n\nimport numpy as np\nfrom hivemind import PeerID, get_logger\n\nfrom petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState\nfrom petals.utils.dht import compute_spans\n\nlogger = get_logger(__name__)\n\n\ndef compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_blocks: int) -> np.ndarray:\n    # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.\n    # If the order were not defined, we would get slightly different values due to floating point errors,\n    # which may cause excess block replacements.\n\n    throughputs = np.zeros(total_blocks)\n    for span in sorted(spans.values(), key=lambda span: span.peer_id):\n        throughputs[span.start : span.end] += span.throughput\n    return throughputs\n\n\ndef _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:\n    options = ((sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1))\n    return min(options)[-1]\n\n\ndef choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]:\n    spans = compute_spans(module_infos, min_state=ServerState.JOINING)\n    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))\n\n    start = _choose_best_start(throughputs, num_blocks)\n    return list(range(start, start + num_blocks))\n\n\ndef _move_span(span: RemoteSpanInfo, new_start: int):\n    span.start, span.end = new_start, new_start + span.length\n\n\ndef should_choose_other_blocks(\n    local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float\n) -> bool:\n    if balance_quality > 1.0:\n        return True  # Forces rebalancing on each check (may be used for debugging purposes)\n\n    spans = compute_spans(module_infos, min_state=ServerState.JOINING)\n    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))\n    initial_throughput = 
throughputs.min()\n    eps = 1e-3\n\n    assert local_peer_id in spans, \"Span served by this server is not present in the DHT\"\n    local_span = spans[local_peer_id]\n    throughputs[local_span.start : local_span.end] -= local_span.throughput * (1 + eps)\n    # Without (1 + eps) here, we would sometimes subtract a value slightly less than local_span.throughput\n    # due to the floating point error, which would cause excess block replacements.\n    # Also, subtracting local_span.throughput * (1 + eps) makes _choose_best_start() prefer\n    # the previous server position in case of other things being almost equal.\n\n    if initial_throughput > eps and throughputs.min() <= 0:\n        return False  # Switching blocks would make the swarm disjoint\n\n    new_start = _choose_best_start(throughputs, local_span.length)\n    if local_span.start == new_start:\n        return False  # This server is on its best place already\n\n    throughputs[local_span.start : local_span.end] += local_span.throughput * eps\n    _move_span(local_span, new_start)\n    throughputs[local_span.start : local_span.end] += local_span.throughput\n\n    moved = True\n    while moved:\n        servers = list(spans.keys())\n        np.random.shuffle(servers)\n\n        moved = False\n        for peer_id in servers:\n            span = spans[peer_id]\n            throughputs[span.start : span.end] -= span.throughput * (1 + eps)\n\n            new_start = _choose_best_start(throughputs, span.length)\n\n            throughputs[span.start : span.end] += span.throughput * eps\n            if span.start != new_start:\n                _move_span(span, new_start)\n                moved = True\n            throughputs[span.start : span.end] += span.throughput\n\n    new_throughput = throughputs.min()\n    if new_throughput < initial_throughput or new_throughput < eps:\n        return False\n\n    actual_quality = initial_throughput / new_throughput\n    logger.info(f\"Swarm balance quality: {actual_quality 
* 100:.1f}%\")\n\n    return actual_quality < balance_quality - eps\n"
  },
  {
    "path": "src/petals/server/block_utils.py",
    "content": "from typing import Optional, Union\n\nimport torch\nfrom accelerate import init_empty_weights\nfrom transformers import PretrainedConfig, PreTrainedModel\n\nfrom petals.models.mixtral.block import WrappedMixtralBlock\nfrom petals.utils.convert_block import QuantType\nfrom petals.utils.misc import get_size_in_bytes\n\n\ndef resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:\n    \"\"\"If dtype is \"auto\", resolves it using BloomConfig. Returns `dtype` intact otherwise.\"\"\"\n    if dtype not in (\"auto\", None):\n        return dtype\n    if config.torch_dtype not in (\"auto\", None, torch.float32):\n        # If config specifies float32, we override it to the default dtype below\n        return config.torch_dtype\n    return torch.bfloat16\n\n\ndef get_block_size(\n    config: PretrainedConfig,\n    location: str,\n    *,\n    dtype: Optional[Union[str, torch.dtype]] = None,\n    quant_type: QuantType = QuantType.NONE,\n    eps: float = 0.01,  # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc.\n) -> int:\n    if location == \"memory\":\n        assert (\n            dtype is not None and quant_type is not None\n        ), 'get_block_size(..., location=\"memory\") requires to specify dtype and quant_type for calculations'\n\n    with init_empty_weights(include_buffers=False):\n        block = get_model_block(config)\n        n_params = sum(param.numel() for param in block.parameters())\n\n    if location == \"memory\":\n        if quant_type == QuantType.NONE:\n            dtype = resolve_block_dtype(config, dtype)\n            bytes_per_value = get_size_in_bytes(dtype)\n        elif quant_type == QuantType.INT8:\n            bytes_per_value = 1\n        elif quant_type == QuantType.NF4:\n            bytes_per_value = 4.25 / 8  # Bitness of NF4 with this config (measured empirically)\n        else:\n            raise ValueError(f\"Unsupported quant_type={quant_type}\")\n  
  elif location == \"disk\":\n        dtype = resolve_block_dtype(config, \"auto\")\n        bytes_per_value = get_size_in_bytes(dtype)\n\n    return round(n_params * bytes_per_value * (1 + eps))\n\n\ndef get_model_block(config, layer_idx: int = 0):\n    \"\"\"\n    The function to create a model block based on the block class\n    kwargs argument **only** is necessary for specific classes, like Mixtral.\n    They will not be passed to other block constructors.\n    \"\"\"\n    if config.block_class == WrappedMixtralBlock:\n        config = PreTrainedModel._autoset_attn_implementation(config)\n        return config.block_class(config, layer_idx)\n    return config.block_class(config)\n"
  },
  {
    "path": "src/petals/server/from_pretrained.py",
    "content": "\"\"\"\nUtils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.\nIf necessary, one can rewrite this to implement a different behavior, such as:\n - loading files from a local data source (e.g. S3)\n - loading files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS ( https://docs.ipfs.io/how-to )\n - fetching the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )\n\n\"\"\"\nimport json\nimport time\nfrom contextlib import suppress\nfrom typing import Dict, Optional, Union\n\nimport safetensors\nimport torch\nimport torch.nn as nn\nfrom accelerate import init_empty_weights\nfrom accelerate.utils import set_module_tensor_to_device\nfrom hivemind.utils.logging import get_logger\nfrom huggingface_hub import get_hf_file_metadata, hf_hub_url\nfrom huggingface_hub.utils import EntryNotFoundError\nfrom transformers import PretrainedConfig, PreTrainedModel\nfrom transformers.utils import get_file_from_repo\n\nfrom petals.constants import DTYPE_MAP\nfrom petals.models.mixtral import WrappedMixtralBlock\nfrom petals.server.block_utils import get_model_block, resolve_block_dtype\nfrom petals.utils.auto_config import AutoDistributedConfig\nfrom petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for\nfrom petals.utils.hf_auth import always_needs_auth\n\nlogger = get_logger(__name__)\n\n\ndef load_pretrained_block(\n    model_name: str,\n    block_index: int,\n    *,\n    config: Optional[PretrainedConfig] = None,\n    torch_dtype: Union[torch.dtype, str] = \"auto\",\n    revision: Optional[str] = None,\n    token: Optional[Union[str, bool]] = None,\n    cache_dir: Optional[str] = None,\n    max_disk_space: Optional[int] = None,\n) -> nn.Module:\n    \"\"\"\n    Create block `block_index` of `model_name` and fill it with pretrained weights on CPU.\n\n    Weights are cast to `torch_dtype` (\"auto\" is resolved with resolve_block_dtype()),\n    except integer and boolean tensors, which keep their stored dtype.\n    \"\"\"\n    if config is None:\n        config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=token)\n    if cache_dir is None:\n        cache_dir = DEFAULT_CACHE_DIR\n\n    assert torch_dtype in DTYPE_MAP.values(), f\"torch_dtype must be one of {list(DTYPE_MAP.values())}\"\n    torch_dtype = resolve_block_dtype(config, torch_dtype)\n\n    with init_empty_weights():\n        block = get_model_block(config, layer_idx=block_index)\n\n    block_prefix = f\"{config.block_prefix}.{block_index}.\"\n    state_dict = _load_state_dict_from_repo(\n        model_name,\n        block_prefix,\n        revision=revision,\n        token=token,\n        cache_dir=cache_dir,\n        max_disk_space=max_disk_space,\n    )\n\n    for param_name, _ in block.named_parameters():\n        assert param_name in state_dict, f\"{param_name} not in state dict\"\n        param = state_dict[param_name]\n        if not str(param.dtype).startswith((\"torch.uint\", \"torch.int\", \"torch.bool\")):\n            param = param.to(torch_dtype)\n        set_module_tensor_to_device(block, param_name, \"cpu\", value=param, dtype=param.dtype)\n\n    logger.info(f\"Loaded {model_name} block {block_index}\")\n    return block\n\n\nStateDict = Dict[str, torch.Tensor]\n\n\ndef _load_state_dict_from_repo(\n    model_name: str,\n    block_prefix: str,\n    *,\n    revision: Optional[str] = None,\n    token: Optional[Union[str, bool]] = None,\n    cache_dir: str,\n    max_disk_space: Optional[int] = None,\n) -> StateDict:\n    \"\"\"Load all tensors whose names start with `block_prefix`, returning them with that prefix stripped.\"\"\"\n    if always_needs_auth(model_name) and token is None:\n        token = True\n\n    index_file = _find_index_file(model_name, revision=revision, token=token, cache_dir=cache_dir)\n    if index_file.endswith(\".index.json\"):  # Sharded model\n        # Fetch the index from the same revision as the shards it points to\n        path = get_file_from_repo(\n            model_name, filename=index_file, revision=revision, use_auth_token=token, cache_dir=cache_dir\n        )\n        if path is None:\n            # _find_index_file() told that a file exists but we can't get it (e.g., it just disappeared)\n            raise ValueError(f\"Failed to get file {index_file}\")\n\n        with open(path) as f:\n            index = json.load(f)\n        filenames = {\n            filename for param_name, filename in index[\"weight_map\"].items() if param_name.startswith(block_prefix)\n        }\n        if not filenames:\n            raise RuntimeError(f\"Block {block_prefix}* not found in the index: {index['weight_map']}\")\n    else:  # Non-sharded model\n        filenames = {index_file}\n    logger.debug(f\"Loading {block_prefix}* from {filenames}\")\n\n    state_dict = {}\n    for filename in filenames:\n        shard_state_dict = _load_state_dict_from_repo_file(\n            model_name,\n            filename,\n            block_prefix=block_prefix,\n            revision=revision,\n            token=token,\n            cache_dir=cache_dir,\n            max_disk_space=max_disk_space,\n        )\n        shard_state_dict = {\n            param_name[len(block_prefix) :]: param\n            for param_name, param in shard_state_dict.items()\n            if param_name.startswith(block_prefix)\n        }  # Remove unused parameters from memory\n        state_dict.update(shard_state_dict)\n    return state_dict\n\n\nINDEX_FILES = [\"model.safetensors.index.json\", \"model.safetensors\", \"pytorch_model.bin.index.json\", \"pytorch_model.bin\"]\n\n\ndef _find_index_file(\n    model_name: str, *, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str\n) -> str:\n    \"\"\"\n    Choose which weights file (or shard index) from INDEX_FILES to load for `model_name`.\n\n    Files already present in `cache_dir` win; otherwise, the first file in INDEX_FILES that exists in the repo\n    is returned. Raises ValueError if none of them exist.\n    \"\"\"\n    # If we have cached weights (e.g., Pickle from older Petals versions), reuse them\n    for filename in INDEX_FILES:\n        path = get_file_from_repo(\n            model_name,\n            filename,\n            revision=revision,\n            use_auth_token=token,\n            cache_dir=cache_dir,\n            local_files_only=True,\n        )\n        if path is not None:\n            return filename\n\n    # If we don't, prefer Safetensors when possible\n    # (we don't download files here since we can't account for max_disk_space in case of large files)\n    for filename in INDEX_FILES:\n        with suppress(EntryNotFoundError):\n            get_hf_file_metadata(hf_hub_url(model_name, filename, revision=revision), token=token)\n            return filename\n\n    raise ValueError(\n        f\"Repo {model_name} does not contain weights in a supported format: files {INDEX_FILES} do not exist\"\n    )\n\n\ndef _load_state_dict_from_repo_file(\n    model_name: str,\n    filename: str,\n    *,\n    block_prefix: Optional[str] = None,\n    revision: Optional[str] = None,\n    token: Optional[Union[str, bool]] = None,\n    cache_dir: str,\n    max_disk_space: Optional[int] = None,\n    delay: float = 30,\n) -> StateDict:\n    \"\"\"\n    Load tensors from a single file of the repo, using the local cache when possible.\n\n    If the file is not cached (or reading the cache fails), free_disk_space_for() is called with the file size\n    and the file is downloaded. Failures are logged and retried every `delay` seconds with no retry limit.\n    \"\"\"\n    # First, try to find the weights locally\n    try:\n        with allow_cache_reads(cache_dir):\n            path = get_file_from_repo(\n                model_name,\n                filename,\n                revision=revision,\n                use_auth_token=token,\n                cache_dir=cache_dir,\n                local_files_only=True,\n            )\n            if path is not None:\n                return _load_state_dict_from_local_file(path, block_prefix=block_prefix)\n    except Exception:\n        logger.warning(f\"Cache for file {filename} is corrupted, it will be downloaded again\", exc_info=True)\n\n    # If not found, ensure that we have enough disk space to download them (maybe remove something)\n    # NOTE: this loop has no retry limit, so it also keeps retrying if the file is permanently missing\n    while True:\n        try:\n            with allow_cache_writes(cache_dir):\n                url = hf_hub_url(model_name, filename, revision=revision)\n                file_size = get_hf_file_metadata(url, token=token).size\n                if file_size is not None:\n                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)\n                else:\n                    logger.warning(f\"Failed to fetch size of file {filename} from repo {model_name}\")\n\n                path = get_file_from_repo(\n                    model_name,\n                    filename,\n                    revision=revision,\n                    use_auth_token=token,\n                    cache_dir=cache_dir,\n                    local_files_only=False,\n                )\n                if path is None:\n                    raise RuntimeError(f\"File {filename} does not exist in repo {model_name}\")\n                return _load_state_dict_from_local_file(path, block_prefix=block_prefix)\n        except Exception:\n            logger.warning(f\"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)\", exc_info=True)\n            time.sleep(delay)\n\n\ndef _load_state_dict_from_local_file(path: str, *, block_prefix: Optional[str] = None) -> StateDict:\n    \"\"\"Load tensors from a local .bin or .safetensors file, keeping only keys that start with `block_prefix`.\"\"\"\n    if path.endswith(\".bin\"):\n        # NOTE(security): torch.load() unpickles the file, which can run arbitrary code shipped in the repo;\n        # when nothing is cached yet, _find_index_file() picks Safetensors instead if the repo provides them\n        state_dict = torch.load(path, map_location=\"cpu\")\n        if block_prefix is not None:\n            state_dict = {key: value for key, value in state_dict.items() if key.startswith(block_prefix)}\n        return state_dict\n\n    if path.endswith(\".safetensors\"):\n        with safetensors.safe_open(path, framework=\"pt\", device=\"cpu\") as f:\n            return {key: f.get_tensor(key) for key in f.keys() if block_prefix is None or key.startswith(block_prefix)}\n\n    raise ValueError(f\"Unknown weight format: {path}\")\n"
  },
  {
    "path": "src/petals/server/handler.py",
    "content": "from __future__ import annotations\n\nimport asyncio\nimport contextlib\nimport multiprocessing as mp\nimport sys\nfrom enum import Enum\nfrom itertools import chain\nfrom typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple\n\nimport torch\nfrom async_timeout import timeout\nfrom hivemind import (\n    DHT,\n    MSGPackSerializer,\n    P2PContext,\n    PeerID,\n    deserialize_tensor_stream,\n    deserialize_torch_tensor,\n    nested_flatten,\n    nested_pack,\n    serialize_torch_tensor,\n)\nfrom hivemind.moe.server.connection_handler import ConnectionHandler\nfrom hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE\nfrom hivemind.proto import runtime_pb2\nfrom hivemind.utils.asyncio import amap_in_executor, anext\nfrom hivemind.utils.logging import get_logger\nfrom hivemind.utils.streaming import split_for_streaming\n\nimport petals\nfrom petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, Handle, ModuleUID\nfrom petals.server.backend import TransformerBackend\nfrom petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward\nfrom petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase\nfrom petals.utils.convert_block import QuantType\n\nlogger = get_logger(__name__)\n\n\n# Fix pickling protobufs, see https://stackoverflow.com/a/74873028\nsys.modules[\"runtime_pb2\"] = runtime_pb2\n\n\nCACHE_TOKENS_AVAILABLE = \"cache_tokens_available\"\n\n\nclass Event(Enum):\n    NEW_SESSION = 0\n    END_SESSION = 1\n    PUSH = 2\n    SHUTDOWN = 3\n\n\nclass TransformerConnectionHandler(ConnectionHandler):\n    \"\"\"Handles three request types: forward, backward and forward-incremental (inference)\"\"\"\n\n    module_backends: Dict[ModuleUID, TransformerBackend]\n\n    def __init__(\n        self,\n        dht: DHT,\n        module_backends: Dict[str, TransformerBackend],\n        *,\n        adapters: Optional[Sequence[str]],\n        dht_prefix: str,\n        
handler_event_queues: Sequence[mp.Queue],\n        handler_index: int,\n        inference_max_length: int,\n        request_timeout: float,\n        session_timeout: float,\n        step_timeout: float,\n        task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),\n        quant_type: QuantType,\n    ):\n        super().__init__(dht, module_backends)\n        for module_backend in self.module_backends.values():\n            assert isinstance(module_backend, TransformerBackend)\n        self.dht_prefix = dht_prefix\n        self.adapters = adapters\n        self._handler_event_queues = handler_event_queues\n        self._handler_index = handler_index\n        self._own_event_queue = handler_event_queues[handler_index]\n        self._listener_task: Optional[asyncio.Task] = None\n        self._session_queues: Dict[str, asyncio.Queue] = {}\n        self._session_handlers: Dict[str, int] = {}\n\n        self.inference_max_length = inference_max_length\n        self.request_timeout = request_timeout\n        self.session_timeout, self.step_timeout = session_timeout, step_timeout\n        self._prioritizer = task_prioritizer\n        self.quant_type = quant_type\n\n    async def add_p2p_handlers(self, *args, **kwargs) -> None:\n        if self._listener_task is None:\n            # Start listening to our own event queue before we accept any requests\n            self._listener_task = asyncio.create_task(self._listen_to_event_queue())\n        await super().add_p2p_handlers(*args, **kwargs)\n\n    def shutdown(self):\n        if self.is_alive():\n            self._outer_pipe.send(\"_shutdown\")\n            self._own_event_queue.put((Event.SHUTDOWN, None, None))\n            self.join(self.shutdown_timeout)\n            if self.is_alive():\n                logger.warning(f\"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM\")\n                self.terminate()\n\n    async def _gather_inputs(\n        self, requests: 
AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext\n    ) -> Tuple[str, List[torch.Tensor], Dict]:\n        block_uid, metadata = None, None\n\n        def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:\n            nonlocal block_uid, metadata\n\n            if block_uid is None:\n                block_uid = req.uid\n            elif block_uid != req.uid:\n                raise ValueError(\"Block uids differ in one request\")\n\n            if metadata is None:\n                metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {}\n\n            return req.tensors\n\n        tensors_stream = amap_in_executor(_unpack, requests)\n        inputs = await deserialize_tensor_stream(tensors_stream)\n        assert isinstance(block_uid, str) and isinstance(metadata, dict)\n        return block_uid, inputs, metadata\n\n    async def rpc_inference(\n        self,\n        requests: AsyncIterator[runtime_pb2.ExpertRequest],\n        context: P2PContext,\n    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:\n        \"\"\"Compute a single step of inference using attention cache; update attention cache accordingly.\"\"\"\n        async with timeout(self.session_timeout):\n            try:\n                request = await asyncio.wait_for(anext(requests), self.step_timeout)\n            except asyncio.TimeoutError:\n                self._log_request(\"rpc_inference.open\", None, context, warning=\"timed out\")\n                return\n\n            requested_uids = self._check_uids(request.uid)\n            self._log_request(\"rpc_inference.open\", requested_uids, context)\n            try:\n                metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}\n                requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)\n                max_length = metadata.get(\"max_length\")\n                points = metadata.get(\"points\", 0)\n                session_id = 
metadata.get(\"session_id\")\n                alloc_timeout = float(metadata.get(\"alloc_timeout\", 0.0))\n                args_structure = metadata.get(\"args_structure\")\n                if not requested_uids:\n                    raise ValueError(\"User must specify at least one block for inference, but got none\")\n                assert isinstance(\n                    max_length, int\n                ), f\"rpc_inference metadata must contain int max_length, got {max_length}\"\n                assert isinstance(\n                    points, (float, int)\n                ), f\"rpc_inference should have number of points as a number or None, got {points}\"\n                if not 0 <= max_length <= self.inference_max_length:\n                    raise ValueError(\n                        f\"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}\"\n                    )\n\n                batch_size = request.tensors[0].size[0] if request.tensors else 1\n\n                async with self._allocate_cache(\n                    requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout\n                ) as cache_handles:\n                    background_tasks = set()\n                    async for output_tensors, can_push, step_metadata in iterate_rpc_inference(\n                        requested_uids=requested_uids,\n                        requested_backends=requested_backends,\n                        active_adapter=self._get_active_adapter(metadata),\n                        input_iterator=self._iterate_inference_steps(\n                            request, requests, session_id, requested_uids, context\n                        ),\n                        cache_handles=cache_handles,\n                        max_length=max_length,\n                        prioritizer=self._prioritizer,\n                        points=points,\n                        quant_type=self.quant_type,\n                        
args_structure=args_structure,\n                    ):\n                        if can_push:\n                            task = asyncio.create_task(self._push_outputs(request, output_tensors[0], step_metadata))\n                            background_tasks.add(task)  # Keep reference until it is done to save it from GC\n                            task.add_done_callback(background_tasks.discard)\n                        yield runtime_pb2.ExpertResponse(tensors=output_tensors)\n\n            finally:\n                self._log_request(\"rpc_inference.close\", requested_uids, context)\n\n    @contextlib.contextmanager\n    def _managed_session(self, session_id: str):\n        assert session_id not in self._session_queues, f\"session id {session_id} is not unique\"\n        try:\n            self._session_queues[session_id] = asyncio.Queue()\n            self._session_handlers[session_id] = self._handler_index\n            for other_index, other_queue in enumerate(self._handler_event_queues):\n                if other_index != self._handler_index:\n                    other_queue.put_nowait((Event.NEW_SESSION, session_id, self._handler_index))\n            yield\n        finally:\n            self._session_queues.pop(session_id).put_nowait(None)  # put None so that the get task will not hang\n            del self._session_handlers[session_id]\n            for other_index, other_queue in enumerate(self._handler_event_queues):\n                if other_index != self._handler_index:\n                    other_queue.put_nowait((Event.END_SESSION, session_id, self._handler_index))\n\n    def _put_into_session_queue(self, session_id: str, request: runtime_pb2.ExpertRequest):\n        handler_index = self._session_handlers.get(session_id)\n        if handler_index is None:\n            logger.debug(f\"Ignored rpc_push to unknown session ID: {session_id}\")\n        elif handler_index == self._handler_index:\n            self._session_queues[session_id].put_nowait(request)\n 
       else:\n            self._handler_event_queues[handler_index].put_nowait((Event.PUSH, session_id, request))\n\n    async def _get_from_session_queue(self, session_id: str) -> Optional[runtime_pb2.ExpertRequest]:\n        assert self._session_handlers[session_id] == self._handler_index, \"session belongs to another handler\"\n        return await self._session_queues[session_id].get()\n\n    async def _listen_to_event_queue(self):\n        loop = asyncio.get_event_loop()\n        while True:\n            try:\n                event, session_id, payload = await loop.run_in_executor(None, self._own_event_queue.get)\n                if event == Event.SHUTDOWN:\n                    break\n                elif event == Event.NEW_SESSION:\n                    self._session_handlers[session_id] = payload  # index of the handler that owns that session\n                elif event == Event.END_SESSION:\n                    self._session_handlers.pop(session_id, None)\n                elif event == Event.PUSH:\n                    maybe_session_queue = self._session_queues.get(session_id)\n                    if maybe_session_queue is not None:\n                        maybe_session_queue.put_nowait(payload)\n                else:\n                    raise RuntimeError(f\"Unexpected event: {event}\")\n            except Exception as e:\n                logger.exception(e)\n\n    async def _iterate_inference_steps(\n        self,\n        first_request: runtime_pb2.ExpertRequest,\n        requests: AsyncIterator[runtime_pb2.ExpertRequest],\n        session_id: Optional[str],\n        requested_uids: Sequence[str],\n        context: P2PContext,\n    ) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]:\n        processed_step_ids = set()\n        n_pushes = n_late_pushes = 0\n        request = first_request\n        anext_task = get_push_task = None\n        try:\n            with self._managed_session(session_id) if session_id is not None else 
contextlib.nullcontext():\n                while request.tensors:  # iterate while user is willing to supply tensors\n                    metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}\n                    step_id = metadata.get(\"step_id\")\n\n                    pushed = metadata.get(\"pushed\")\n                    if pushed:\n                        n_pushes += 1\n                        self._log_request(\"rpc_inference.push\", requested_uids, context, debug=f\"session received push\")\n\n                    if step_id is None or step_id not in processed_step_ids:\n                        yield request, metadata\n                        if step_id is not None:\n                            processed_step_ids.add(step_id)\n                    elif pushed:\n                        n_late_pushes += 1\n                        self._log_request(\n                            \"rpc_inference.push\",\n                            requested_uids,\n                            context,\n                            warning=f\"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time\",\n                        )\n\n                    # Wait for the next request, coming either from the `requests` iterator or `push_queue`\n                    if anext_task is None:\n                        anext_task = asyncio.create_task(anext(requests))\n                    if get_push_task is None:\n                        if session_id is not None:\n                            get_push_task = asyncio.create_task(self._get_from_session_queue(session_id))\n                        else:\n                            get_push_task = asyncio.create_task(asyncio.Event().wait())  # Dummy never-ending task\n                    done, _ = await asyncio.wait(\n                        [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED\n                    )\n\n                    if anext_task in done:\n                   
     request = await anext_task\n                        anext_task = None\n                    elif get_push_task in done:\n                        request = await get_push_task\n                        get_push_task = None\n                    else:\n                        self._log_request(\"rpc_inference.step\", requested_uids, context, warning=\"timed out\")\n                        anext_task.cancel()\n                        get_push_task.cancel()\n                        return\n        except Exception:\n            logger.warning(\"rpc_inference._iterate_inference_steps() exception:\", exc_info=True)\n            raise\n\n    async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:\n        \"\"\"Directly push activation tensors from one server to another\"\"\"\n\n        requested_uids = self._check_uids(request.uid)\n        metadata = MSGPackSerializer.loads(request.metadata)\n        session_id = metadata[\"session_id\"]\n        self._log_request(\"rpc_push\", requested_uids, context, debug=f\"session_id={session_id}\")\n        self._put_into_session_queue(session_id, request)\n        return runtime_pb2.ExpertResponse()\n\n    async def _push_outputs(\n        self, request: runtime_pb2.ExpertRequest, serialized_outputs: runtime_pb2.Tensor, metadata: dict\n    ) -> None:\n        try:\n            next_servers = metadata.get(\"next_servers\")\n            if not next_servers:\n                return\n\n            next_peer_id, next_session_id, next_start, next_end = next_servers[0]\n            next_peer_id = PeerID.from_base58(next_peer_id)\n            next_uid = CHAIN_DELIMITER.join(f\"{self.dht_prefix}{UID_DELIMITER}{i}\" for i in range(next_start, next_end))\n\n            # Sending hidden states serialized with output_schema to avoid double serialization\n            next_tensors = [serialized_outputs] + request.tensors[1:]\n            next_metadata = metadata.copy()\n            
next_metadata.update(session_id=next_session_id, next_servers=next_servers[1:], pushed=True)\n\n            stub = self.get_stub(self._p2p, next_peer_id)\n            await stub.rpc_push(\n                runtime_pb2.ExpertRequest(\n                    uid=next_uid,\n                    tensors=next_tensors,\n                    metadata=MSGPackSerializer.dumps(next_metadata),\n                ),\n                timeout=self.request_timeout,\n            )\n        except Exception:\n            logger.debug(\n                f\"Failed to push outputs to peer_id={next_peer_id}, session_id={next_session_id}, blocks={next_start}:{next_end}:\",\n                exc_info=True,\n            )\n\n    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:\n        async with timeout(self.request_timeout):\n            # Parse request and prepare backends\n            flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]\n            requested_uids = self._check_uids(request.uid)\n            self._log_request(\"rpc_forward\", requested_uids, context)\n\n            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)\n            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}\n            active_adapter = self._get_active_adapter(metadata)\n            points = metadata.get(\"points\", 0)\n            args_structure = metadata.get(\"args_structure\")\n            assert isinstance(\n                points, (float, int)\n            ), f\"rpc_forward should have number of points as number or None, got {points}\"\n\n            hidden_states = await run_rpc_forward(\n                *flat_inputs,\n                requested_backends=requested_backends,\n                prioritizer=self._prioritizer,\n                active_adapter=active_adapter,\n                points=points,\n                args_structure=args_structure,\n    
        )\n            return runtime_pb2.ExpertResponse(\n                tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)\n            )\n\n    async def rpc_forward_stream(\n        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext\n    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:\n        async with timeout(self.request_timeout):\n            # Parse requests and prepare backends\n            uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)\n            requested_uids = self._check_uids(uid_str)\n            self._log_request(\"rpc_forward_stream\", requested_uids, context)\n\n            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)\n            active_adapter = self._get_active_adapter(metadata)\n            points = metadata.get(\"points\", 0)\n            args_structure = metadata.get(\"args_structure\")\n            assert isinstance(\n                points, (float, int)\n            ), f\"rpc_forward_stream should have number of points as number or None, got {points}\"\n\n            hidden_states = await run_rpc_forward(\n                *flat_inputs,\n                requested_backends=requested_backends,\n                prioritizer=self._prioritizer,\n                active_adapter=active_adapter,\n                points=points,\n                args_structure=args_structure,\n            )\n\n            # Split the serialized_output for streaming and respond to client\n            for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):\n                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):\n                    yield runtime_pb2.ExpertResponse(tensors=[part])\n\n    def _serialize_outputs(\n        self,\n        hidden_states: torch.Tensor,\n        requested_backends: Sequence[TransformerBackend],\n        metadata: Dict[str, Any],\n    ) -> Sequence[runtime_pb2.Tensor]:\n     
   \"\"\"Serialize forward outputs using either outputs_schema or custom user-specified schema\"\"\"\n        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, \"hidden_states must be a 3d tensor\"\n        outputs_schema = requested_backends[-1].outputs_schema\n\n        if metadata.get(\"output_compression\") is not None:\n            assert isinstance(metadata[\"output_compression\"], (list, tuple)), \"output_compression must be a tuple/list\"\n            output_compression = tuple(metadata[\"output_compression\"])\n            assert all(isinstance(c, int) for c in output_compression), \"output_compression must contain integers\"\n            assert len(output_compression) == 1, f\"output_compression tuple should have 1 element\"\n        else:\n            output_compression = tuple(tensor.compression for tensor in outputs_schema)\n\n        return [\n            serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)\n            for result, proto, compression in zip([hidden_states], outputs_schema, output_compression)\n        ]\n\n    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:\n        async with timeout(self.request_timeout):\n            # Parse requests and prepare backends\n            flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]\n            requested_uids = self._check_uids(request.uid)\n            self._log_request(\"rpc_backward\", requested_uids, context)\n\n            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)\n            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}\n            active_adapter = self._get_active_adapter(metadata)\n            points = metadata.get(\"points\", 0)\n            args_structure = metadata.get(\"args_structure\")\n            assert isinstance(\n                points, (float, int)\n      
      ), f\"rpc_backward should have number of points as number or None, got {points}\"\n\n            grads = await run_rpc_backward(\n                *flat_tensors,\n                requested_backends=requested_backends,\n                prioritizer=self._prioritizer,\n                active_adapter=active_adapter,\n                points=points,\n                args_structure=args_structure,\n            )\n\n            return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))\n\n    async def rpc_backward_stream(\n        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext\n    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:\n        async with timeout(self.request_timeout):\n            uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)\n            requested_uids = self._check_uids(uids_header)\n            self._log_request(\"rpc_backward_stream\", requested_uids, context)\n\n            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)\n            active_adapter = self._get_active_adapter(metadata)\n            points = metadata.get(\"points\", 0)\n            args_structure = metadata.get(\"args_structure\")\n            assert isinstance(\n                points, (float, int)\n            ), f\"rpc_backward_stream should have number of points as number or None, got {points}\"\n\n            grads = await run_rpc_backward(\n                *flat_tensors,\n                requested_backends=requested_backends,\n                prioritizer=self._prioritizer,\n                active_adapter=active_adapter,\n                points=points,\n                args_structure=args_structure,\n            )\n            # Split the serialized_grad_inputs for streaming and respond\n            for tensor in self._serialize_grads(grads, requested_backends, metadata):\n                for part in split_for_streaming(tensor, 
DEFAULT_MAX_MSG_SIZE):\n                    yield runtime_pb2.ExpertResponse(tensors=[part])\n\n    def _get_active_adapter(self, metadata: dict) -> str:\n        active_adapter = metadata.get(\"active_adapter\", \"\")\n        if active_adapter and (active_adapter not in self.adapters):\n            raise KeyError(f\"adapter {active_adapter} not found\")\n        return active_adapter\n\n    def _serialize_grads(\n        self,\n        grads: Sequence[torch.Tensor],\n        requested_backends: Sequence[TransformerBackend],\n        metadata: Dict[str, Any],\n    ) -> Sequence[runtime_pb2.Tensor]:\n        \"\"\"Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema\"\"\"\n        # Modify grad_inputs_schema to support grad_prompts\n        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize\n        flat_grads_schema = tuple(\n            nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema))\n        )  # TODO generalize\n\n        if metadata.get(\"output_compression\") is not None:\n            assert isinstance(metadata[\"output_compression\"], (list, tuple)), \"output_compression must be a tuple/list\"\n            output_compression = tuple(metadata[\"output_compression\"])\n            assert all(isinstance(c, int) for c in output_compression), \"output_compression must contain integers\"\n            assert len(output_compression) == len(grads), f\"output_compression should have {len(grads)} elements\"\n        else:\n            output_compression = tuple(tensor.compression for tensor in flat_grads_schema)\n\n        return [\n            serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)\n            for result, proto, compression in zip(grads, flat_grads_schema, output_compression)\n        ]\n\n    def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]:\n        \"\"\"Check that the 
first request to rpc_inference is valid\"\"\"\n        uids = (uids or \"\").split(CHAIN_DELIMITER)\n        if not uids:\n            raise RuntimeError(\"User did not provide any uids\")\n        for uid in uids:\n            if uid not in self.module_backends:\n                raise RuntimeError(f\"Remote peer does not serve {uid}\")\n        return tuple(uids)\n\n    @contextlib.asynccontextmanager\n    async def _allocate_cache(\n        self,\n        backends: Sequence[TransformerBackend],\n        *,\n        batch_size: int,\n        max_length: int,\n        timeout: Optional[float],\n    ) -> Sequence[Sequence[Handle]]:\n        \"\"\"\n        Allocate memory cache for all transformer blocks, return cache handle\n        :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend\n        \"\"\"\n        descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends]\n        async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:\n            yield nested_pack(handles, descriptors)\n\n    def _log_request(\n        self,\n        method: str,\n        uids: Optional[Sequence[ModuleUID]],\n        context: P2PContext,\n        *,\n        debug: Optional[str] = None,\n        warning: Optional[str] = None,\n    ) -> None:\n        if uids is not None:\n            friendly_uids = [uid.split(\".\")[-1] for uid in uids if \".\" in uid]\n            friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]\n            friendly_uids = f\"{min(friendly_uids)}:{max(friendly_uids) + 1}\" if friendly_uids else uids\n        else:\n            friendly_uids = \"n/a\"\n\n        friendly_remote_id = \"...\" + str(context.remote_id)[-6:]\n\n        message = f\"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})\"\n        if warning is not None:\n            logger.warning(f\"{message}: {warning}\")\n   
     elif debug is not None:\n            logger.debug(f\"{message}: {debug}\")\n        else:\n            logger.info(message)\n\n    async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:\n        \"\"\"Return metadata about stored block uids and current load\"\"\"\n\n        backend = self.module_backends[request.uid] if request.uid else next(iter(self.module_backends.values()))\n        result = {\n            \"version\": petals.__version__,\n            \"dht_client_mode\": self.dht.client_mode,\n            CACHE_TOKENS_AVAILABLE: backend.memory_cache.bytes_left // max(backend.cache_bytes_per_token.values()),\n        }\n\n        if request.uid:\n            block_info = self.module_backends[request.uid].get_info()\n            common_keys = set(result.keys()) & set(block_info.keys())\n            if common_keys:\n                raise RuntimeError(f\"The block's rpc_info has keys reserved for the server's rpc_info: {common_keys}\")\n            result.update(block_info)\n\n        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))\n"
  },
  {
    "path": "src/petals/server/memory_cache.py",
    "content": "\"\"\"\nA pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime.\n\nFor now, the only purpose of this code is to ensure that allocated memory will be deleted properly.\n\n\"\"\"\nimport asyncio\nimport contextlib\nimport ctypes\nimport multiprocessing as mp\nimport os\nimport time\nfrom typing import AsyncContextManager, Dict, Optional, Sequence\n\nimport async_timeout\nimport torch\nfrom hivemind.utils import TensorDescriptor, enter_asynchronously, get_logger\n\nfrom petals.data_structures import Handle\nfrom petals.utils.asyncio import shield_and_wait\nfrom petals.utils.misc import get_size_in_bytes\n\nlogger = get_logger(__name__)\n\n\nclass MemoryCache:\n    \"\"\"A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs\"\"\"\n\n    def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: Optional[float] = None):\n        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)\n        self.max_alloc_timeout = max_alloc_timeout\n        self._lock_metadata = mp.Lock()\n        self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)\n        self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=True)\n        self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)\n        self._allocated_tensors: Dict[Handle, torch.Tensor] = {}\n        self.runtime_pid = os.getpid()\n\n        self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime\n        self._lock_acquire_memory = mp.Lock()\n        self._memory_freed_event = mp.Event()\n\n    @property\n    def current_size_bytes(self) -> int:\n        return self._current_size.value\n\n    @current_size_bytes.setter\n    def current_size_bytes(self, value: int):\n        self._current_size.value = value\n\n    @property\n    def enqueued_size_bytes(self) -> int:\n        return self._enqueued_size.value\n\n    
@enqueued_size_bytes.setter\n    def enqueued_size_bytes(self, value: int):\n        self._enqueued_size.value = value\n\n    @property\n    def bytes_left(self) -> int:\n        return self.max_size_bytes - self.current_size_bytes\n\n    @property\n    def handle_counter(self) -> int:\n        return self._handle_counter.value\n\n    @handle_counter.setter\n    def handle_counter(self, value: int):\n        self._handle_counter.value = value\n\n    @contextlib.asynccontextmanager\n    async def allocate_cache(\n        self, *descriptors: TensorDescriptor, timeout: float\n    ) -> AsyncContextManager[Sequence[Handle]]:\n        \"\"\"\n        Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.\n\n        :param descriptors: one or more tensors tensor of this size, dtype, etc\n        :param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit\n\n        :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;\n          if not, it will count maximum tensor allocation across devices for the purposes of size limit\n\n        :note: This function should be called by connection handlers, it can be called concurrently from multiple processes.\n        Furthermore, it can be called concurrently with at most one use_cache call in runtime.\n        \"\"\"\n        assert os.getpid() != self.runtime_pid, \"must be called by a ConnectionHandler, not runtime\"\n        assert all(descr.device is not None for descr in descriptors), \"please specify allocated devices\"\n        if self.max_alloc_timeout is not None:\n            timeout = min(timeout, self.max_alloc_timeout)\n        max_alloc_size = self.get_allocation_size(*descriptors)\n\n        gib = 1024**3\n        cur_size, max_size = self.current_size_bytes, self.max_size_bytes\n        friendly_max_size = f\"{max_size / gib:.2f}\" if max_size != 2**64 - 1 else 
\"inf\"\n        logger.info(\n            f\"rpc_inference.wait_for_alloc(size={max_alloc_size / gib:.2f} GiB), \"\n            f\"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)\"\n        )\n\n        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))\n        try:\n            handles = await shield_and_wait(alloc_task)\n            logger.info(f\"rpc_inference.alloc_done(size={max_alloc_size / gib:.2f} GiB)\")\n            yield handles\n        finally:\n            self._free(max_alloc_size, alloc_task)\n\n    @staticmethod\n    def get_allocation_size(*descriptors: TensorDescriptor) -> int:\n        \"\"\"Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum\"\"\"\n        alloc_size_by_device = {}\n        for descr in descriptors:\n            tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)\n            alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size\n        return max(alloc_size_by_device.values())\n\n    async def _schedule_alloc(\n        self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]\n    ) -> Sequence[Handle]:\n        \"\"\"\n        This method should be called inside asyncio.shield() because:\n            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation\n        \"\"\"\n        try:\n            async with self._wait_for_free_memory(alloc_size, timeout):\n                with self._lock_metadata:\n                    handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))\n                    self.current_size_bytes += alloc_size\n                    self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay\n                    self._pipe_send.send((handles, descriptors))\n                    return handles\n        except 
TimeoutError:\n            raise AllocationFailed(f\"Could not allocate {alloc_size} (timeout={timeout})\")\n\n    @contextlib.asynccontextmanager\n    async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float]):\n        start_time = time.perf_counter()\n        loop = asyncio.get_event_loop()\n\n        with self._enqueued_size.get_lock():\n            self._enqueued_size.value += alloc_size\n        allocated = False\n        try:\n            context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()\n            # contextlib.AsyncExitStack() is used as a null context here\n            async with context_manager:\n                if timeout == 0 and self.current_size_bytes + self.enqueued_size_bytes > self.max_size_bytes:\n                    raise AllocationFailed(f\"Could not allocate {alloc_size} bytes immediately: out of memory\")\n                async with enter_asynchronously(self._lock_acquire_memory):\n                    if self.current_size_bytes + alloc_size > self.max_size_bytes:\n                        if timeout == 0:\n                            raise AllocationFailed(f\"Could not allocate {alloc_size} bytes immediately: out of memory\")\n                        elapsed_time = time.perf_counter() - start_time\n                        remaining_timeout = max(0.0, timeout - elapsed_time) if timeout is not None else None\n                        await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)\n\n                allocated = True\n                with self._enqueued_size.get_lock():\n                    self._enqueued_size.value -= alloc_size\n                yield\n        except asyncio.TimeoutError:\n            raise AllocationFailed(f\"Could not allocate {alloc_size} within {timeout} seconds\")\n        finally:\n            if not allocated:\n                with self._enqueued_size.get_lock():\n                    self._enqueued_size.value 
-= alloc_size\n\n    def _free(self, alloc_size: int, alloc_task: asyncio.Task):\n        if alloc_task.exception() is not None:\n            return\n        handles = alloc_task.result()\n\n        with self._lock_metadata:\n            self._pipe_send.send((handles, None))  # signal runtime to free these handles\n            self.current_size_bytes -= alloc_size\n        self._memory_freed_event.set()\n\n    def _wait_until_available(self, allocated_size: int, timeout: Optional[float] = None):\n        # note: this function should only be called inside _lock_acquire_memory!\n        if allocated_size > self.max_size_bytes:\n            raise AllocationFailed(\n                f\"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes\"\n            )\n        timeout = timeout if timeout != float(\"inf\") else None\n        deadline = None if timeout is None else time.perf_counter() + timeout\n        while self.current_size_bytes + allocated_size > self.max_size_bytes:\n            remaining_time = None if timeout is None else deadline - time.perf_counter()\n            if not self._memory_freed_event.wait(remaining_time):\n                raise AllocationFailed(\n                    f\"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds\"\n                )\n            self._memory_freed_event.clear()\n\n    @contextlib.contextmanager\n    def use_cache(self, *handles: Handle) -> Sequence[torch.Tensor]:\n        \"\"\"\n        Return one or more tensors previously allocated with allocate_cache,\n\n        :note: This method is called by ModuleBackend in runtime: a single process with NO process parallelism.\n        However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache\n        \"\"\"\n        assert os.getpid() == self.runtime_pid\n        # note: this specific function is not concurrent, so you can safely 
allocate/offload/defragment data here\n\n        # read creation/deletion requests from connection handlers\n        while self._pipe_recv.poll():\n            recv_handles, recv_data = self._pipe_recv.recv()\n            if recv_data is not None:  # create new tensors\n                assert len(recv_handles) == len(recv_data)\n                for handle, descr in zip(recv_handles, recv_data):\n                    self._allocated_tensors[handle] = descr.make_zeros()\n                    assert handle in self._allocated_tensors, f\"Sanity check failed: no such handle ({handle})\"\n            else:  # delete tensors by handle\n                for handle in recv_handles:\n                    if handle not in self._allocated_tensors:\n                        logger.warning(\n                            f\"Sanity check failed: asked to delete handle {handle}, but there is no such handle\"\n                        )\n                    self._allocated_tensors.pop(handle, None)\n        yield tuple(self._allocated_tensors[handle] for handle in handles)\n\n\nclass AllocationFailed(Exception):\n    pass\n"
  },
  {
    "path": "src/petals/server/reachability.py",
    "content": "import asyncio\nimport math\nimport threading\nimport time\nfrom concurrent.futures import Future\nfrom contextlib import asynccontextmanager\nfrom functools import partial\nfrom typing import Optional\n\nimport requests\nfrom hivemind.dht import DHT, DHTNode\nfrom hivemind.moe.client.remote_expert_worker import RemoteExpertWorker\nfrom hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase\nfrom hivemind.proto import dht_pb2\nfrom hivemind.utils import get_logger\n\nfrom petals.constants import REACHABILITY_API_URL\n\nlogger = get_logger(__name__)\n\n\ndef validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:\n    \"\"\"verify that your peer is reachable from a (centralized) validator, whether directly or through a relay\"\"\"\n    for attempt_no in range(math.floor(wait_time / retry_delay) + 1):\n        try:\n            r = requests.get(f\"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}\", timeout=10)\n            r.raise_for_status()\n            response = r.json()\n\n            if response[\"success\"]:\n                logger.info(\"Server is reachable from the Internet. It will appear at https://health.petals.dev soon\")\n                return\n\n            if attempt_no == 0:\n                # Usually, libp2p manages to set up relays before we finish loading blocks.\n                # In other cases, we may need to wait for up to `wait_time` seconds before it's done.\n                logger.info(\"Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes\")\n            time.sleep(retry_delay)\n        except Exception as e:\n            logger.warning(f\"Skipping reachability check because health.petals.dev is down: {repr(e)}\")\n            return\n\n    raise RuntimeError(\n        f\"Server has not become reachable from the Internet:\\n\\n\"\n        f\"{response['message']}\\n\\n\"\n        f\"You need to fix your port forwarding and/or firewall settings. 
How to do that:\\n\\n\"\n        f\"    1. Choose a specific port for the Petals server, for example, 31337.\\n\"\n        f\"    2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\\n\"\n        f\"    3. Add these arguments to explicitly announce your IP address and port to other peers:\\n\"\n        f\"        python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\\n\"\n        f\"    4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\\n\"\n    )\n\n\ndef check_direct_reachability(max_peers: int = 5, threshold: float = 0.5, **kwargs) -> Optional[bool]:\n    \"\"\"test if your peer is accessible by others in the swarm with the specified network options in **kwargs\"\"\"\n\n    async def _check_direct_reachability():\n        target_dht = await DHTNode.create(client_mode=True, **kwargs)\n        try:\n            protocol = ReachabilityProtocol(probe=target_dht.protocol.p2p)\n            async with protocol.serve(target_dht.protocol.p2p):\n                successes = requests = 0\n                for remote_peer in list(target_dht.protocol.routing_table.peer_id_to_uid.keys()):\n                    probe_available = await protocol.call_check(remote_peer=remote_peer, check_peer=target_dht.peer_id)\n                    if probe_available is None:\n                        continue  # remote peer failed to check probe\n                    successes += probe_available\n                    requests += 1\n                    if requests >= max_peers:\n                        break\n\n            logger.debug(f\"Direct reachability: {successes}/{requests}\")\n            return (successes / requests) >= threshold if requests > 0 else None\n        finally:\n            await target_dht.shutdown()\n\n    return RemoteExpertWorker.run_coroutine(_check_direct_reachability())\n\n\nSTRIPPED_PROBE_ARGS = dict(\n    dht_mode=\"client\", use_relay=False, auto_nat=False, 
nat_port_map=False, no_listen=True, startup_timeout=60\n)\n\n\nclass ReachabilityProtocol(ServicerBase):\n    \"\"\"Mini protocol to test if a locally running peer is accessible by other devices in the swarm\"\"\"\n\n    def __init__(self, *, probe: Optional[P2P] = None, wait_timeout: float = 5.0):\n        self.probe = probe\n        self.wait_timeout = wait_timeout\n        self._event_loop = self._stop = None\n\n    async def call_check(self, remote_peer: PeerID, *, check_peer: PeerID) -> Optional[bool]:\n        \"\"\"Returns True if remote_peer can reach check_peer, False if it cannot, None if it did not respond\"\"\"\n        try:\n            request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=check_peer.to_bytes()))\n            timeout = self.wait_timeout if check_peer == remote_peer else self.wait_timeout * 2\n            response = await self.get_stub(self.probe, remote_peer).rpc_check(request, timeout=timeout)\n            logger.debug(f\"call_check(remote_peer={remote_peer}, check_peer={check_peer}) -> {response.available}\")\n            return response.available\n        except Exception as e:\n            logger.debug(f\"Requested {remote_peer} to check {check_peer}, but got:\", exc_info=True)\n            return None\n\n    async def rpc_check(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:\n        \"\"\"Help another peer to check its reachability\"\"\"\n        response = dht_pb2.PingResponse(available=True)\n        check_peer = PeerID(request.peer.node_id)\n        if check_peer != context.local_id:  # remote peer wants us to check someone other than ourselves\n            response.available = await self.call_check(check_peer, check_peer=check_peer) is True\n            logger.info(\n                f\"reachability.rpc_check(remote_peer=...{str(context.remote_id)[-6:]}, \"\n                f\"check_peer=...{str(check_peer)[-6:]}) -> {response.available}\"\n            )\n        return response\n\n    
@asynccontextmanager\n    async def serve(self, p2p: P2P):\n        try:\n            await self.add_p2p_handlers(p2p)\n            yield self\n        finally:\n            await self.remove_p2p_handlers(p2p)\n\n    @classmethod\n    def attach_to_dht(cls, dht: DHT, await_ready: bool = False, **kwargs) -> Optional[\"ReachabilityProtocol\"]:\n        protocol = cls(**kwargs)\n        ready = Future()\n\n        async def _serve_with_probe():\n            try:\n                common_p2p = await dht.replicate_p2p()\n                protocol._event_loop = asyncio.get_event_loop()\n                protocol._stop = asyncio.Event()\n\n                initial_peers = [str(addr) for addr in await common_p2p.get_visible_maddrs(latest=True)]\n                for info in await common_p2p.list_peers():\n                    initial_peers.extend(f\"{addr}/p2p/{info.peer_id}\" for addr in info.addrs)\n                protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)\n\n                ready.set_result(True)\n                logger.debug(\"Reachability service started\")\n\n                async with protocol.serve(common_p2p):\n                    await protocol._stop.wait()\n            except Exception as e:\n                logger.debug(\"Reachability service failed:\", exc_info=True)\n\n                if not ready.done():\n                    ready.set_exception(e)\n            finally:\n                if protocol is not None and protocol.probe is not None:\n                    await protocol.probe.shutdown()\n                logger.debug(\"Reachability service shut down\")\n\n        threading.Thread(target=partial(asyncio.run, _serve_with_probe()), daemon=True).start()\n        if await_ready:\n            ready.result()  # Propagates startup exceptions, if any\n        return protocol\n\n    def shutdown(self):\n        if self._event_loop is not None and self._stop is not None:\n            self._event_loop.call_soon_threadsafe(self._stop.set)\n"
  },
  {
    "path": "src/petals/server/server.py",
    "content": "from __future__ import annotations\n\nimport gc\nimport math\nimport multiprocessing as mp\nimport os\nimport random\nimport sys\nimport threading\nimport time\nfrom typing import Dict, List, Optional, Sequence, Union\n\nimport hivemind\nimport psutil\nimport torch\nimport torch.mps\nfrom hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time\nfrom hivemind.moe.server.layers import add_custom_models_from_file\nfrom hivemind.moe.server.runtime import Runtime\nfrom hivemind.proto.runtime_pb2 import CompressionType\nfrom hivemind.utils.logging import get_logger\nfrom transformers import PretrainedConfig\n\nimport petals\nfrom petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS\nfrom petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState, parse_uid\nfrom petals.server import block_selection\nfrom petals.server.backend import TransformerBackend, merge_inference_pools_inplace\nfrom petals.server.block_utils import get_block_size, resolve_block_dtype\nfrom petals.server.from_pretrained import load_pretrained_block\nfrom petals.server.handler import TransformerConnectionHandler\nfrom petals.server.memory_cache import MemoryCache\nfrom petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability\nfrom petals.server.throughput import get_dtype_name, get_server_throughput\nfrom petals.utils.auto_config import AutoDistributedConfig\nfrom petals.utils.convert_block import QuantType, check_device_balance, convert_block\nfrom petals.utils.dht import declare_active_modules, get_remote_module_infos\nfrom petals.utils.misc import get_size_in_bytes\nfrom petals.utils.ping import PingAggregator\nfrom petals.utils.random import sample_up_to\nfrom petals.utils.version import get_compatible_model_repo\n\nlogger = get_logger(__name__)\n\n\nclass Server:\n    \"\"\"\n    Runs ModuleContainer, periodically checks that the network is balanced,\n    
restarts the ModuleContainer with other layers if the imbalance is significant\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        initial_peers: List[str],\n        dht_prefix: Optional[str],\n        converted_model_name_or_path: str,\n        public_name: Optional[str] = None,\n        throughput: Union[float, str],\n        num_blocks: Optional[int] = None,\n        block_indices: Optional[str] = None,\n        num_handlers: int = 8,\n        inference_max_length: Optional[int] = None,\n        min_batch_size: int = 1,\n        max_batch_size: Optional[int] = None,\n        max_chunk_size_bytes: int = 256 * 1024 * 1024,\n        max_alloc_timeout: float = 600,\n        attn_cache_tokens: Optional[int] = None,\n        torch_dtype: str = \"auto\",\n        revision: Optional[str] = None,\n        cache_dir: Optional[str] = None,\n        max_disk_space: Optional[int] = None,\n        device: Optional[Union[str, torch.device]] = None,\n        compression=CompressionType.NONE,\n        stats_report_interval: Optional[int] = None,\n        custom_module_path=None,\n        update_period: float = 60,\n        expiration: Optional[float] = None,\n        request_timeout: float = 3 * 60,\n        session_timeout: float = 30 * 60,\n        step_timeout: float = 5 * 60,\n        prefetch_batches: int = 1,\n        sender_threads: int = 1,\n        balance_quality: float = 0.75,\n        mean_balance_check_period: float = 120,\n        mean_block_selection_delay: float = 5,\n        token: Optional[Union[str, bool]] = None,\n        quant_type: Optional[QuantType] = None,\n        tensor_parallel_devices: Optional[Sequence[torch.device]] = None,\n        skip_reachability_check: bool = False,\n        reachable_via_relay: Optional[bool] = None,\n        use_relay: bool = True,\n        use_auto_relay: bool = True,\n        adapters: Sequence[str] = (),\n        **kwargs,\n    ):\n        \"\"\"Create a server with one or more bloom blocks. 
See run_server.py for documentation.\"\"\"\n\n        converted_model_name_or_path = get_compatible_model_repo(converted_model_name_or_path)\n        self.converted_model_name_or_path = converted_model_name_or_path\n\n        self.num_handlers = num_handlers\n        self.compression = compression\n        self.stats_report_interval, self.update_period = stats_report_interval, update_period\n        self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads\n        self.revision, self.token = revision, token\n\n        if custom_module_path is not None:\n            add_custom_models_from_file(custom_module_path)\n\n        self.block_config = AutoDistributedConfig.from_pretrained(\n            converted_model_name_or_path,\n            use_auth_token=token,\n            revision=revision,\n        )\n\n        if dht_prefix is None:\n            dht_prefix = self.block_config.dht_prefix\n        assert UID_DELIMITER not in dht_prefix and CHAIN_DELIMITER not in dht_prefix, (\n            f\"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. 
\"\n            f\"Please specify another --dht_prefix manually when starting a server\"\n        )\n        self.dht_prefix = dht_prefix\n\n        if expiration is None:\n            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)\n        self.expiration = expiration\n\n        self.request_timeout = request_timeout\n        self.session_timeout, self.step_timeout = session_timeout, step_timeout\n\n        self.module_uids = [\n            f\"{self.dht_prefix}{UID_DELIMITER}{block_index}\"\n            for block_index in range(self.block_config.num_hidden_layers)\n        ]\n\n        if reachable_via_relay is None:\n            is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)\n            reachable_via_relay = is_reachable is False  # if can't check reachability (returns None), run a full peer\n            logger.info(f\"This server is accessible {'via relays' if reachable_via_relay else 'directly'}\")\n        self.dht = DHT(\n            initial_peers=initial_peers,\n            start=True,\n            num_workers=self.block_config.num_hidden_layers,\n            use_relay=use_relay,\n            use_auto_relay=use_auto_relay,\n            client_mode=reachable_via_relay,\n            **kwargs,\n        )\n        self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not reachable_via_relay else None\n\n        visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]\n        if initial_peers == PUBLIC_INITIAL_PEERS:\n            logger.info(\"Connecting to the public swarm\")\n        else:\n            logger.info(f\"Connecting to a private swarm, initial peers: {initial_peers}\")\n        logger.info(f\"Running a server on {visible_maddrs_str}\")\n        self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS\n\n        if device is None:\n            if torch.cuda.is_available():\n                
device = \"cuda\"\n            elif torch.backends.mps.is_available():\n                device = \"mps\"\n            else:\n                device = \"cpu\"\n        device = torch.device(device)\n        if device.type == \"cuda\" and device.index is None:\n            device = torch.device(device.type, index=0)\n        self.device = device\n\n        torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])\n        if device.type == \"cpu\" and torch_dtype == torch.float16:\n            raise ValueError(\n                f\"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16\"\n            )\n        if device.type == \"mps\" and torch_dtype == torch.bfloat16:\n            logger.warning(f\"Type bfloat16 is not supported on MPS, using float16 instead\")\n            torch_dtype = torch.float16\n        self.torch_dtype = torch_dtype\n\n        if tensor_parallel_devices is None:\n            tensor_parallel_devices = (device,)\n        self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices))\n        if len(self.tensor_parallel_devices) > 1:\n            logger.info(f\"Model weights will be split between {', '.join(tensor_parallel_devices)}\")\n            check_device_balance(self.tensor_parallel_devices)\n\n        if quant_type is None:\n            quant_type = QuantType.NF4 if device.type == \"cuda\" else QuantType.NONE\n        self.quant_type = quant_type\n        logger.info(f\"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format\")\n\n        is_multiquery_attn = self.block_config.num_key_value_groups > 1\n        if max_batch_size is None:\n            max_batch_size = 8192 if is_multiquery_attn else 2048\n        if inference_max_length is None:\n            inference_max_length = 8192 if is_multiquery_attn else 2048\n        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size\n        self.inference_max_length = 
inference_max_length\n        self.max_chunk_size_bytes = max_chunk_size_bytes\n        self.max_alloc_timeout = max_alloc_timeout\n\n        # For attention cache in GPU or RAM\n        if attn_cache_tokens is None:\n            attn_cache_tokens = 16384 if is_multiquery_attn else 4096\n        cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens\n        cache_values_per_block //= self.block_config.num_key_value_groups\n        self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)\n\n        # For disk cache\n        self.cache_dir = cache_dir\n        self.max_disk_space = max_disk_space\n        self.adapters = adapters\n\n        assert num_blocks is None or block_indices is None, \"Please specify num_blocks or block_indices, not both\"\n        if num_blocks is None and block_indices is None:\n            num_blocks = self._choose_num_blocks()\n        if num_blocks is not None:\n            num_blocks = min(num_blocks, self.block_config.num_hidden_layers)\n        if block_indices is not None:\n            try:\n                start_block, end_block = [int(index.strip()) for index in block_indices.split(\":\")]\n            except Exception as e:\n                raise ValueError(f\"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 
0:18)\")\n            block_indices = range(start_block, end_block)\n            num_blocks = len(block_indices)\n        self.strict_block_indices, self.num_blocks = block_indices, num_blocks\n\n        gib = 1024**3\n        self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks\n        logger.info(f\"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB\")\n\n        assert isinstance(throughput, float) or throughput in [\"auto\", \"eval\", \"dry_run\"]\n        if throughput in [\"auto\", \"eval\", \"dry_run\"]:\n            force_eval = throughput in [\"eval\", \"dry_run\"]\n            throughput_info = get_server_throughput(\n                converted_model_name_or_path,\n                self.block_config,\n                device,\n                torch_dtype,\n                num_blocks=num_blocks,\n                quant_type=quant_type,\n                tensor_parallel_devices=self.tensor_parallel_devices,\n                reachable_via_relay=reachable_via_relay,\n                force_eval=force_eval,\n                cache_dir=cache_dir,\n            )\n            if throughput == \"dry_run\":\n                logger.info(\"Finished estimating throughput, exiting\")\n                sys.exit(0)\n        else:\n            throughput_info = {\"throughput\": throughput}\n        self.server_info = ServerInfo(\n            state=ServerState.JOINING,\n            public_name=public_name,\n            version=petals.__version__,\n            adapters=tuple(adapters),\n            torch_dtype=str(torch_dtype).replace(\"torch.\", \"\"),\n            quant_type=quant_type.name.lower(),\n            using_relay=reachable_via_relay,\n            **throughput_info,\n        )\n        self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)\n        if not os.path.isdir(converted_model_name_or_path):\n            self.model_info.repository = \"https://huggingface.co/\" + 
converted_model_name_or_path\n\n        self.balance_quality = balance_quality\n        self.mean_balance_check_period = mean_balance_check_period\n        self.mean_block_selection_delay = mean_block_selection_delay\n\n        self.module_container = None\n        self.stop = threading.Event()\n\n    def _choose_num_blocks(self) -> int:\n        assert self.device.type in (\"cuda\", \"mps\"), (\n            \"GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. \"\n            \"CPU-only servers in the public swarm are discouraged since they are much slower\"\n        )\n        num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1\n\n        if num_devices > 1:\n            assert self.device.type == \"cuda\", f\"Tensor parallelism is not supported on {self.device.type.upper()}\"\n            memory_per_device = tuple(\n                torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices\n            )\n            total_memory = min(memory_per_device) * num_devices\n            if max(memory_per_device) / min(memory_per_device) > 1.5:\n                raise ValueError(\n                    \"GPU devices have highly uneven memory, which makes tensor parallelism inefficient. 
\"\n                    \"Please launch individual servers on each GPU or set --num_blocks manually to \"\n                    \"override this exception.\"\n                )\n        elif self.device.type == \"cuda\":\n            total_memory = torch.cuda.get_device_properties(self.device).total_memory\n        else:\n            total_memory = psutil.virtual_memory().total\n\n        gib = 1024**3\n        # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)\n        autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size\n\n        block_size = get_block_size(self.block_config, \"memory\", dtype=self.torch_dtype, quant_type=self.quant_type)\n        total_memory_per_block = block_size + self._cache_bytes_per_block\n        if self.adapters:\n            # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes\n            from petals.utils.peft import estimate_adapter_memory_per_block\n\n            total_memory_per_block += estimate_adapter_memory_per_block(\n                self.block_config,\n                self.torch_dtype,\n                self.adapters,\n                token=self.token,\n                cache_dir=self.cache_dir,\n                max_disk_space=self.max_disk_space,\n            )\n\n        num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)\n        assert num_blocks >= 1, \"Your GPU does not have enough memory to serve at least one block\"\n\n        num_blocks = min(num_blocks, self.block_config.num_hidden_layers)\n        logger.info(\n            f\"Server will fill your GPU memory with {num_blocks} transformer blocks. 
\"\n            f\"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually\"\n        )\n        return num_blocks\n\n    def run(self):\n        while True:\n            block_indices = self._choose_blocks()\n            self.module_container = ModuleContainer.create(\n                dht=self.dht,\n                dht_prefix=self.dht_prefix,\n                converted_model_name_or_path=self.converted_model_name_or_path,\n                block_config=self.block_config,\n                attn_cache_bytes=self.attn_cache_bytes,\n                server_info=self.server_info,\n                model_info=self.model_info,\n                block_indices=block_indices,\n                num_handlers=self.num_handlers,\n                min_batch_size=self.min_batch_size,\n                max_batch_size=self.max_batch_size,\n                max_chunk_size_bytes=self.max_chunk_size_bytes,\n                max_alloc_timeout=self.max_alloc_timeout,\n                inference_max_length=self.inference_max_length,\n                torch_dtype=self.torch_dtype,\n                cache_dir=self.cache_dir,\n                max_disk_space=self.max_disk_space,\n                device=self.device,\n                compression=self.compression,\n                stats_report_interval=self.stats_report_interval,\n                update_period=self.update_period,\n                expiration=self.expiration,\n                request_timeout=self.request_timeout,\n                session_timeout=self.session_timeout,\n                step_timeout=self.step_timeout,\n                prefetch_batches=self.prefetch_batches,\n                sender_threads=self.sender_threads,\n                revision=self.revision,\n                token=self.token,\n                quant_type=self.quant_type,\n                tensor_parallel_devices=self.tensor_parallel_devices,\n                should_validate_reachability=self.should_validate_reachability,\n                
start=True,\n            )\n            try:\n                self.module_container.ready.wait()\n\n                while True:\n                    timeout = random.random() * 2 * self.mean_balance_check_period\n                    if self.stop.wait(timeout):\n                        return\n\n                    if not self.module_container.is_healthy():\n                        logger.warning(\"One of subprocesses crashed, restarting the server\")\n                        break\n\n                    if self._should_choose_other_blocks():\n                        logger.info(\"Swarm is imbalanced, server will load other blocks\")\n                        break  # Stop serving this set of modules\n            finally:\n                self.module_container.shutdown()\n\n            self._clean_memory_and_fds()\n\n    def _clean_memory_and_fds(self):\n        self.module_container = None\n        gc.collect()  # In particular, this closes unused file descriptors\n\n        if self.device.type == \"cuda\":\n            torch.cuda.empty_cache()\n\n            allocated_vram = torch.cuda.memory_allocated(self.device)\n            reserved_vram = torch.cuda.memory_reserved(self.device)\n            gib = 1024**3\n            logger.info(\n                f\"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, \"\n                f\"{reserved_vram / gib:.1f} GiB reserved memory\"\n            )\n        elif self.device.type == \"mps\":\n            torch.mps.empty_cache()\n\n    def _choose_blocks(self) -> List[int]:\n        if self.strict_block_indices is not None:\n            return self.strict_block_indices\n\n        # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,\n        # this delay decreases the probability of a race condition while choosing the best blocks to serve.\n        time.sleep(random.random() * 2 * self.mean_block_selection_delay)\n        module_infos = 
get_remote_module_infos(self.dht, self.module_uids, latest=True)\n        return block_selection.choose_best_blocks(self.num_blocks, module_infos)\n\n    def _should_choose_other_blocks(self) -> bool:\n        if self.strict_block_indices is not None:\n            return False\n\n        module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)\n        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)\n\n    def shutdown(self, timeout: Optional[float] = 5):\n        self.stop.set()\n        if self.module_container is not None and self.module_container.is_alive():\n            self.module_container.join(timeout)\n\n        if self.reachability_protocol is not None:\n            self.reachability_protocol.shutdown()\n        self.dht.shutdown()\n        self.dht.join()\n\n\nclass ModuleContainer(threading.Thread):\n    \"\"\"Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT.\"\"\"\n\n    # noinspection PyMethodOverriding\n    @classmethod\n    def create(\n        cls,\n        *,\n        dht: DHT,\n        dht_prefix: str,\n        converted_model_name_or_path: str,\n        block_config: PretrainedConfig,\n        attn_cache_bytes: int,\n        server_info: ServerInfo,\n        model_info: ModelInfo,\n        block_indices: List[int],\n        min_batch_size: int,\n        max_batch_size: int,\n        max_chunk_size_bytes: int,\n        max_alloc_timeout: float,\n        torch_dtype: torch.dtype,\n        cache_dir: str,\n        max_disk_space: int,\n        device: Union[str, torch.device],\n        compression: CompressionType,\n        update_period: float,\n        expiration: Optional[float],\n        revision: Optional[str],\n        token: Optional[Union[str, bool]],\n        quant_type: QuantType,\n        tensor_parallel_devices: Sequence[torch.device],\n        should_validate_reachability: bool,\n        **kwargs,\n    
) -> ModuleContainer:\n        module_uids = [f\"{dht_prefix}{UID_DELIMITER}{block_index}\" for block_index in block_indices]\n        memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)\n\n        server_info.state = ServerState.JOINING\n        dht_announcer = ModuleAnnouncerThread(\n            module_uids,\n            dht,\n            server_info,\n            model_info,\n            block_config=block_config,\n            memory_cache=memory_cache,\n            update_period=update_period,\n            expiration=expiration,\n            daemon=True,\n        )\n        dht_announcer.start()\n        logger.info(f\"Announced that blocks {block_indices} are joining\")\n\n        assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)\n\n        blocks = {}\n        try:\n            for module_uid, block_index in zip(module_uids, block_indices):\n                block = load_pretrained_block(\n                    converted_model_name_or_path,\n                    block_index,\n                    config=block_config,\n                    torch_dtype=torch_dtype,\n                    revision=revision,\n                    token=token,\n                    cache_dir=cache_dir,\n                    max_disk_space=max_disk_space,\n                )\n                block = convert_block(\n                    block,\n                    block_index,\n                    block_config,\n                    tensor_parallel_devices,\n                    device,\n                    quant_type,\n                    adapters=server_info.adapters,\n                    freeze=True,\n                    token=token,\n                    cache_dir=cache_dir,\n                    max_disk_space=max_disk_space,\n                )\n                blocks[module_uid] = TransformerBackend(\n                    module_uid,\n                    block,\n                    config=block_config,\n                    
memory_cache=memory_cache,\n                    backend_dtype=torch_dtype,\n                    max_chunk_size_bytes=max_chunk_size_bytes,\n                    args_schema=(\n                        BatchTensorDescriptor(\n                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression\n                        ),\n                    ),\n                    kwargs_schema={},\n                    outputs_schema=(\n                        BatchTensorDescriptor(\n                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression\n                        ),\n                    ),\n                    min_batch_size=min_batch_size,\n                    max_batch_size=max_batch_size,\n                )\n\n            merge_inference_pools_inplace(blocks)\n\n            if should_validate_reachability:\n                validate_reachability(dht.peer_id)\n        except:\n            logger.debug(\"Shutting down backends\")\n            for backend in blocks.values():\n                backend.shutdown()\n\n            dht_announcer.announce(ServerState.OFFLINE)\n            logger.info(f\"Announced that blocks {module_uids} are offline\")\n            raise\n\n        return cls(\n            dht,\n            dht_prefix,\n            blocks,\n            dht_announcer=dht_announcer,\n            server_info=server_info,\n            update_period=update_period,\n            expiration=expiration,\n            **kwargs,\n        )\n\n    def __init__(\n        self,\n        dht: DHT,\n        dht_prefix: str,\n        module_backends: Dict[str, TransformerBackend],\n        *,\n        inference_max_length: int,\n        num_handlers: int,\n        dht_announcer: ModuleAnnouncerThread,\n        server_info: ServerInfo,\n        update_period: float,\n        expiration: Optional[float] = None,\n        request_timeout: float,\n        session_timeout: float,\n        step_timeout: float,\n   
     start: bool,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.dht, self.module_backends = dht, module_backends\n        self.server_info, self.update_period, self.expiration = server_info, update_period, expiration\n\n        handler_event_queues = [mp.Queue() for _ in range(num_handlers)]\n        self.conn_handlers = [\n            TransformerConnectionHandler(\n                dht,\n                self.module_backends,\n                adapters=server_info.adapters,\n                dht_prefix=dht_prefix,\n                handler_event_queues=handler_event_queues,\n                handler_index=i,\n                inference_max_length=inference_max_length,\n                request_timeout=request_timeout,\n                session_timeout=session_timeout,\n                step_timeout=step_timeout,\n                quant_type=QuantType[server_info.quant_type.upper()],\n            )\n            for i in range(num_handlers)\n        ]\n\n        self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)\n        # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.\n\n        dht_announcer.announce(ServerState.ONLINE)\n        self.dht_announcer = dht_announcer\n\n        if start:\n            self.run_in_background(await_ready=True)\n\n    def run(self):\n        \"\"\"\n        Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,\n        runs Runtime (self.runtime) to process incoming requests.\n        \"\"\"\n        for handler in self.conn_handlers:\n            handler.run_in_background()\n\n        self.runtime.run()\n\n    def run_in_background(self, await_ready=True, timeout=None):\n        \"\"\"\n        Starts ModuleContainer in a background thread. 
if await_ready, this method will wait until the container\n        is ready to process incoming requests or for :timeout: seconds max.\n        \"\"\"\n        self.start()\n        if await_ready and not self.ready.wait(timeout=timeout):\n            raise TimeoutError(f\"ModuleContainer didn't notify .ready in {timeout} seconds\")\n\n    @property\n    def ready(self) -> mp.synchronize.Event:\n        \"\"\"\n        An event (multiprocessing.Event) that is set when the container is ready to process requests.\n\n        Example\n        =======\n        >>> container.start()\n        >>> container.ready.wait(timeout=10)\n        >>> print(\"Container ready\" if container.ready.is_set() else \"Container didn't start in 10 seconds\")\n        \"\"\"\n        return self.runtime.ready  # mp.Event that is true if self is ready to process batches\n\n    def is_healthy(self) -> bool:\n        return all(handler.is_alive() for handler in self.conn_handlers) and all(\n            pool.is_alive() for pool in self.runtime.pools\n        )\n\n    def shutdown(self):\n        \"\"\"\n        Gracefully terminate the container, process-safe.\n        Please note that terminating container otherwise (e.g. 
by killing processes) may result in zombie processes.\n        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).\n        \"\"\"\n        self.dht_announcer.announce(ServerState.OFFLINE)\n        logger.info(f\"Announced that blocks {list(self.module_backends.keys())} are offline\")\n\n        self.ready.clear()\n\n        logger.debug(\"Shutting down connection handlers\")\n        for handler in self.conn_handlers:\n            handler.shutdown()\n\n        logger.debug(f\"Shutting down pools\")\n        for pool in self.runtime.pools:\n            if pool.is_alive():\n                pool.shutdown()\n\n        logger.debug(f\"Shutting down runtime\")\n        self.runtime.shutdown()\n\n        logger.debug(\"Shutting down backends\")\n        for backend in self.module_backends.values():\n            backend.shutdown()\n\n        logger.info(\"Module container shut down successfully\")\n\n\nclass ModuleAnnouncerThread(threading.Thread):\n    \"\"\"Periodically announces that this container hosts the specified modules, visible to all DHT peers\"\"\"\n\n    def __init__(\n        self,\n        module_uids: List[str],\n        dht: DHT,\n        server_info: ServerInfo,\n        model_info: ModelInfo,\n        *,\n        block_config: PretrainedConfig,\n        memory_cache: MemoryCache,\n        update_period: float,\n        expiration: float,\n        max_pinged: int = 5,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.module_uids = module_uids\n        self.dht = dht\n        self.server_info = server_info\n        self.model_info = model_info\n        self.memory_cache = memory_cache\n\n        self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])\n        self.bytes_per_token //= block_config.num_key_value_groups\n\n        self.update_period = update_period\n        self.expiration = expiration\n        self.trigger = 
threading.Event()\n\n        self.dht_prefix = parse_uid(module_uids[0])[0]\n        block_indices = [parse_uid(uid)[1] for uid in module_uids]\n        self.server_info.start_block = min(block_indices)\n        self.server_info.end_block = max(block_indices) + 1\n\n        self.max_pinged = max_pinged\n        self.next_uids = [\n            f\"{self.dht_prefix}{UID_DELIMITER}{i}\"\n            for i in range(self.server_info.start_block + 1, self.server_info.end_block + 1)\n        ]\n        self.ping_aggregator = PingAggregator(self.dht)\n\n    def run(self) -> None:\n        while True:\n            start_time = time.perf_counter()\n\n            self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token\n            if self.server_info.state != ServerState.OFFLINE:\n                self._ping_next_servers()\n                self.server_info.next_pings = {\n                    peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.to_dict().items()\n                }\n            else:\n                self.server_info.next_pings = None  # No need to ping if we're disconnecting\n\n            declare_active_modules(\n                self.dht,\n                self.module_uids,\n                self.server_info,\n                expiration_time=get_dht_time() + self.expiration,\n            )\n            if self.server_info.state == ServerState.OFFLINE:\n                break\n            if not self.dht_prefix.startswith(\"_\"):  # Not private\n                self.dht.store(\n                    key=\"_petals.models\",\n                    subkey=self.dht_prefix,\n                    value=self.model_info.to_dict(),\n                    expiration_time=get_dht_time() + self.expiration,\n                )\n\n            delay = self.update_period - (time.perf_counter() - start_time)\n            if delay < 0:\n                logger.warning(\n                    f\"Declaring blocks to DHT takes more than 
--update_period, consider increasing it (currently {self.update_period})\"\n                )\n            self.trigger.wait(max(delay, 0))\n            self.trigger.clear()\n\n    def announce(self, state: ServerState) -> None:\n        self.server_info.state = state\n        self.trigger.set()\n        if state == ServerState.OFFLINE:\n            self.join()\n\n    def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:\n        module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)\n        middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers}\n        pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))\n        pinged_servers.discard(self.dht.peer_id)\n        # Sample servers hosting the block after the last one (most likely continuations) separately\n        pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))\n        self.ping_aggregator.ping(list(pinged_servers))\n\n\nclass RuntimeWithDeduplicatedPools(Runtime):\n    \"\"\"A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.pools = tuple(set(self.pools))\n"
  },
  {
    "path": "src/petals/server/task_pool.py",
    "content": "import ctypes\nimport multiprocessing as mp\nimport threading\nimport time\nfrom concurrent.futures._base import PENDING\nfrom dataclasses import dataclass, field\nfrom queue import PriorityQueue\nfrom typing import Any, List, Optional, Sequence, Tuple, Union\n\nimport torch\nfrom hivemind import get_logger\nfrom hivemind.utils.mpfuture import ALL_STATES, MPFuture\n\nlogger = get_logger(__name__)\n\n\n@dataclass(order=True, frozen=True)\nclass Task:\n    priority: float\n    time_submitted: float\n    future: MPFuture = field(compare=False)\n    args: Sequence[torch.Tensor] = field(compare=False)\n\n    @property\n    def uid(self) -> int:\n        return self.future._uid\n\n\nclass PrioritizedTaskPool(threading.Thread):\n    \"\"\"\n    Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then\n    returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.\n    A single PrioritizedTaskPool services a specific function (e.g. layer1.forward, layer2.forward or layer1.backward)\n\n    :note: unlike hivemind.moe TaskPool, this pool does *not* combine incoming requests into batches.\n      This would require grouping requests of different length.\n\n    :param process_func: function to be applied to every formed batch; called by Runtime\n        Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors\n    :param max_batch_size: process at most this many inputs in a batch (task contains have one or several inputs)\n         Measured in the total number of tokens (i.e. 
batch size * sequence length)\n\n    :param name: pool name, used for logging\n    :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more\n    :param device: if specified, input tensors will be moved to that device by default\n    :param start: if True, start automatically at the end of __init__\n    \"\"\"\n\n    def __init__(\n        self,\n        process_func: callable,\n        max_batch_size: int,\n        name: str,\n        min_batch_size=1,\n        device: Optional[torch.device] = None,\n        daemon=True,\n        start=False,\n    ):\n        super().__init__(daemon=daemon, name=name)\n        self.process_func = process_func\n        # the lower the priority is, the more urgent it is to process this pool\n        self._priority = mp.Value(ctypes.c_double, 1.0)\n\n        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size\n        self.device = device\n\n        self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers\n        self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime\n\n        self._dispatched_tasks = {}\n        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)\n        self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)\n        self.priority = float(\"inf\"), float(\"inf\")  # (first task priority, first task timestamp)\n\n        if start:\n            self.start()\n\n    def run(self):\n        \"\"\"Read tasks from incoming queue and put them into a local priority queue\"\"\"\n        while True:\n            task = self.submitted_tasks.get()\n            if task is None:\n                logger.debug(\"Shutting down prioritizer thread\")\n                break\n\n            self._ordered_tasks.put(task, block=True)\n\n    def terminate(self):\n        \"\"\"An alias for hivemind.Runtime that assumes that each TaskPool is a process\"\"\"\n        self.shutdown()\n\n    def 
shutdown(self):\n        self.submitted_tasks.put(None)  # Shuts down self.run()\n\n    def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:\n        \"\"\"Add task to this pool's queue, return Future for its output\"\"\"\n        future = MPFuture()\n        # Remove shmem from MPFuture. This disables the .cancel() feature but\n        # saves the server from \"could not unlink the shared memory file\" crashes during rebalancing\n        future._shared_state_code = torch.tensor([ALL_STATES.index(PENDING)], dtype=torch.uint8)\n\n        task = Task(priority, time.monotonic(), future, args)\n        if self.get_task_size(task) > self.max_batch_size:\n            exc = ValueError(f\"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed\")\n            task.future.set_exception(exc)\n        else:\n            self.submitted_tasks.put(task)\n            self.batch_sender.send(None)  # use this pipe to count the number of unfinished batches\n            if (task.priority, task.time_submitted) < self.priority:\n                self.priority = (task.priority, task.time_submitted)\n        return task.future\n\n    def get_task_size(self, task: Task) -> int:\n        \"\"\"compute task processing complexity; defaults to the total number of tokens\"\"\"\n        if task.args and task.args[0].ndim >= 2:\n            return task.args[0].shape[0] * task.args[0].shape[1]\n        return 1\n\n    def load_batch_to_runtime(\n        self, timeout: Optional[float] = None, device: Optional[torch.device] = None\n    ) -> Tuple[Any, List[torch.Tensor]]:\n        \"\"\"receive next batch of arrays\"\"\"\n        device = device if device is not None else self.device\n        task = self._ordered_tasks.get(block=True, timeout=timeout)\n        batch_inputs = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.args]\n        self._dispatched_tasks[task.uid] = task\n        self.batch_receiver.recv()  # reduce the 
number of active batches\n        if not self._ordered_tasks.empty():\n            first_remaining_task: Task = self._ordered_tasks.queue[0]\n            self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted)\n        return task.uid, batch_inputs\n\n    def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):\n        \"\"\"send results for a processed batch, previously loaded through load_batch_to_runtime\"\"\"\n        batch_outputs = [_move_to_device_if_tensor(output, device=\"cpu\", share_memory=True) for output in batch_outputs]\n        task = self._dispatched_tasks.pop(uid, None)\n        if task is None:\n            logger.error(\n                f\"Internal error: task with index {uid} is missing from the dictionary; \" f\"Could not set result\"\n            )\n        else:\n            task.future.set_result(batch_outputs)\n\n    def send_exception_from_runtime(self, uid: int, exception: BaseException):\n        task = self._dispatched_tasks.pop(uid, None)\n        if task is None:\n            logger.error(\n                f\"Internal error: task with index {uid} is missing from the dictionary; \"\n                f\"Could not set exception {exception}\"\n            )\n        else:\n            task.future.set_exception(exception)\n\n    @property\n    def empty(self):\n        return not self.batch_receiver.poll()\n\n    @property\n    def priority(self) -> Tuple[float, float]:\n        \"\"\"The priority of this pool equals the (priority, timestamp) of the most important task in it.\"\"\"\n        return float(self._priority.value), float(self._oldest_undispatched_timestamp.value)\n\n    @priority.setter\n    def priority(self, item: Tuple[float, float]):\n        assert len(item) == 2\n        self._priority.value = float(item[0])\n        self._oldest_undispatched_timestamp.value = float(item[1])\n\n\ndef _move_to_device_if_tensor(arg: Any, device: Union[torch.device, str], 
share_memory: bool = False):\n    if isinstance(arg, torch.Tensor):\n        arg = arg.detach().to(device, non_blocking=not share_memory).requires_grad_(arg.requires_grad)\n        # note: it is important that non_blocking is disabled if share_memory=True; using share_memory on a tensor\n        # produced by a non-blocking copy will result in undefined behavior (depending on your gpu speed)\n        if share_memory:\n            arg = arg.share_memory_()\n    return arg\n"
  },
  {
    "path": "src/petals/server/task_prioritizer.py",
    "content": "from abc import ABC, abstractmethod\n\nimport torch\n\n\nclass TaskPrioritizerBase(ABC):\n    \"\"\"Abstract class for TaskPrioritizer whose responsibility is to evaluate task priority\"\"\"\n\n    @abstractmethod\n    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:\n        \"\"\"Evaluates task value by the amount of points given, task input and additional kwargs. Lower priority is better\"\"\"\n        pass\n\n\nclass DummyTaskPrioritizer(TaskPrioritizerBase):\n    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:\n        # Inference steps go first since they are more latency-sensitive\n        if kwargs.get(\"type\") == \"inference\":\n            return 1.0\n        return 2.0  # Forward, backward\n"
  },
  {
    "path": "src/petals/server/throughput.py",
    "content": "import fcntl\nimport json\nimport math\nimport multiprocessing as mp\nimport os\nimport time\nfrom collections import Counter\nfrom pathlib import Path\nfrom typing import Dict, Optional, Sequence, Union\n\nimport torch\nimport torch.mps\nfrom hivemind.utils.logging import get_logger\nfrom transformers import PretrainedConfig\n\nfrom petals.server.block_utils import get_model_block, resolve_block_dtype\nfrom petals.utils.convert_block import QuantType, convert_block\nfrom petals.utils.disk_cache import DEFAULT_CACHE_DIR\nfrom petals.utils.misc import DUMMY_KEY_PAST\n\nlogger = get_logger(__name__)\n\ntry:\n    import speedtest\nexcept ImportError:\n    raise ImportError(\"Please `pip install speedtest-cli==2.1.3`\")\n\nif not hasattr(speedtest, \"Speedtest\"):\n    raise ImportError(\n        \"You are using the wrong speedtest module. Please replace speedtest with speedtest-cli.\\n\"\n        \"To do that, run `pip uninstall -y speedtest`. Depending on your python environment, \"\n        \"you may need to run uninstall speedtest two or more times, until it says 'not installed'.\\n\"\n        \"After that, please `pip install speedtest-cli==2.1.3` to install the correct version.\"\n    )\n\n\ndef get_server_throughput(\n    model_name: str,\n    config: PretrainedConfig,\n    device: torch.device,\n    dtype: Union[str, torch.dtype],\n    *,\n    num_blocks: int,\n    quant_type: QuantType,\n    tensor_parallel_devices: Sequence[torch.device],\n    reachable_via_relay: bool,\n    relay_penalty: float = 0.2,\n    force_eval: bool = False,\n    cache_dir: Optional[str] = None,\n) -> Dict[str, float]:\n    dtype = resolve_block_dtype(config, dtype)\n\n    if cache_dir is None:\n        cache_dir = DEFAULT_CACHE_DIR\n    lock_path = Path(cache_dir, \"throughput.lock\")\n    cache_path = Path(cache_dir, \"throughput_v5.json\")\n\n    # We use the system-wide lock since only one process at a time can measure the host throughput\n    
os.makedirs(lock_path.parent, exist_ok=True)\n    with open(lock_path, \"wb+\") as lock_fd:\n        logger.info(\"Loading throughput info\")\n        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)\n        # The OS will release the lock when lock_fd is closed or the process is killed\n\n        cache_key = f\"model_{model_name}\"\n        cache_key += f\"_device_{get_device_name(device).replace(' ', '_')}\"\n        cache_key += f\"_dtype_{get_dtype_name(dtype, quant_type)}\"\n        if len(tensor_parallel_devices) > 1:\n            for i, device_i in enumerate(tensor_parallel_devices):\n                cache_key += f\"_tp{i}_{get_device_name(device_i).replace(' ', '_')}\"\n\n        cache = {}\n        try:\n            if not force_eval and os.path.exists(cache_path):\n                with open(cache_path) as cache_fd:\n                    cache = json.load(cache_fd)\n                assert isinstance(cache, dict)\n        except Exception:\n            logger.exception(f\"Failed to read throughput info from {cache_path}\")\n            cache = {}\n\n        if cache_key not in cache:\n            cache[cache_key] = measure_throughput_info(\n                config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices\n            )\n\n            try:\n                os.makedirs(cache_path.parent, exist_ok=True)\n                with open(cache_path, \"w\") as cache_fd:\n                    json.dump(cache, cache_fd)\n            except Exception:\n                logger.exception(f\"Failed to save throughput info in {cache_path}\")\n\n    throughput_info = cache[cache_key]\n\n    # Most requests start at some block hosted by a server, then use all next blocks hosted on this server.\n    # Assuming the start block index is distributed uniformly, the average number of blocks used per request is\n    # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2\n    average_blocks_used = (num_blocks + 1) / 2\n    throughput = 
throughput_info[\"forward_rps\"] / average_blocks_used\n\n    network_rps = throughput_info[\"network_rps\"] * (relay_penalty if reachable_via_relay else 1)\n    throughput = min(throughput, network_rps)\n\n    throughput_info[\"throughput\"] = throughput\n    logger.info(f\"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks\")\n\n    return throughput_info\n\n\ndef measure_throughput_info(\n    config: PretrainedConfig,\n    device: torch.device,\n    dtype: torch.dtype,\n    *,\n    quant_type: QuantType,\n    tensor_parallel_devices: Sequence[torch.device],\n) -> Dict[str, float]:\n    logger.info(\n        \"Measuring network and compute throughput. This takes about a minute and will be cached for future runs\"\n    )\n    return {\n        \"inference_rps\": measure_compute_rps(\n            config,\n            device,\n            dtype,\n            quant_type=quant_type,\n            tensor_parallel_devices=tensor_parallel_devices,\n            n_tokens=1,\n            n_steps=100,\n            inference=True,\n        ),\n        \"forward_rps\": measure_compute_rps(\n            config,\n            device,\n            dtype,\n            quant_type=quant_type,\n            tensor_parallel_devices=tensor_parallel_devices,\n            n_tokens=1024,\n            n_steps=10,\n            inference=False,\n        ),\n        \"network_rps\": measure_network_rps(config),\n    }\n\n\ndef measure_network_rps(\n    config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 100e6  # 100 Mbit/s\n) -> Optional[float]:\n    bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward\n    try:\n        pipe_recv, pipe_send = mp.Pipe(duplex=False)\n        process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))\n        process.start()\n\n        if not pipe_recv.poll(timeout):\n            process.terminate()\n            raise RuntimeError(f\"speedtest did not finish 
in {timeout} seconds\")\n        network_info = pipe_recv.recv()\n        if \"exception\" in network_info:\n            raise RuntimeError(f\"speedtest failed: {network_info['exception']}\")\n\n        network_rps = min(network_info[\"download\"], network_info[\"upload\"]) / bits_per_request\n        if network_rps == 0:\n            raise RuntimeError(\"speedtest has returned network_rps == 0\")\n\n        logger.info(\n            f\"Network throughput: {network_rps:.1f} tokens/sec \"\n            f\"({network_info['download'] / 1e6:.2f} Mbit/s on download, \"\n            f\"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)\"\n        )\n        return network_rps\n    except RuntimeError as e:\n        logger.info(f\"Network throughput is not available: {e}. Using default of {default_speed / 1e6:.2f} Mbit/s\")\n        return default_speed / bits_per_request\n\n\ndef _measure_bits_per_second(pipe_send: mp.Pipe):\n    try:\n        s = speedtest.Speedtest()\n        s.get_servers()\n        s.get_best_server()\n        s.download()\n        s.upload()\n        pipe_send.send(s.results.dict())\n    except Exception as e:\n        pipe_send.send({\"exception\": repr(e)})\n\n\ndef measure_compute_rps(\n    config: PretrainedConfig,\n    device: torch.device,\n    dtype: torch.dtype,\n    *,\n    quant_type: QuantType,\n    tensor_parallel_devices: Sequence[torch.device],\n    n_tokens: int,\n    n_steps: int,\n    inference: bool,\n) -> float:\n    device = torch.device(device)\n    if not tensor_parallel_devices:\n        tensor_parallel_devices = (device,)\n    with torch.inference_mode():\n        block = get_model_block(config)\n        block = block.to(dtype)\n        block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)\n\n        cache = (DUMMY_KEY_PAST.to(dtype=dtype, device=device), DUMMY_KEY_PAST.to(dtype=dtype, device=device))\n        elapsed = 0\n        dummy_input = torch.randn(1, n_tokens, 
config.hidden_size, device=device, dtype=dtype)\n\n        # Skip the 1st step to exclude the initialization time\n        def step(cache_):\n            outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)\n            return outputs[1] if inference else None\n\n        cache = step(cache)\n        synchronize(device)\n\n        start_time = time.perf_counter()\n        for _ in range(n_steps):\n            cache = step(cache)\n        synchronize(device)\n        elapsed = time.perf_counter() - start_time\n        device_rps = n_steps * n_tokens / elapsed\n\n    devices_repr = get_device_name(device)\n    if len(tensor_parallel_devices) > 1:\n        device_names = tuple(map(get_device_name, map(torch.device, tensor_parallel_devices)))\n        devices_repr = \", \".join(f\"{count}x {name}\" for name, count in Counter(device_names).most_common())\n\n    logger.info(\n        f\"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} tokens/sec per block \"\n        f\"({n_tokens} tokens/batch, {devices_repr}, {get_dtype_name(dtype, quant_type)})\"\n    )\n    return device_rps\n\n\ndef synchronize(device: torch.device):\n    if device.type == \"cuda\":\n        torch.cuda.synchronize(device)\n    elif device.type == \"mps\":\n        torch.mps.synchronize()\n\n\ndef get_device_name(device: torch.device) -> str:\n    return f\"{torch.cuda.get_device_name(device)} GPU\" if device.type == \"cuda\" else device.type.upper()\n\n\ndef get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:\n    name = str(dtype).replace(\"torch.\", \"\")\n    if quant_type != QuantType.NONE:\n        name += f\", quantized to {quant_type.name.lower()}\"\n    return name\n"
  },
  {
    "path": "src/petals/utils/__init__.py",
    "content": "from petals.utils.auto_config import (\n    AutoDistributedConfig,\n    AutoDistributedModel,\n    AutoDistributedModelForCausalLM,\n    AutoDistributedModelForSequenceClassification,\n    AutoDistributedSpeculativeModel,\n)\nfrom petals.utils.dht import declare_active_modules, get_remote_module_infos\n"
  },
  {
    "path": "src/petals/utils/asyncio.py",
    "content": "import asyncio\n\n\nasync def shield_and_wait(task):\n    \"\"\"\n    Works like asyncio.shield(), but waits for the task to finish before raising CancelledError to the caller.\n    \"\"\"\n\n    if not isinstance(task, asyncio.Task):\n        task = asyncio.create_task(task)\n\n    cancel_exc = None\n    while True:\n        try:\n            result = await asyncio.shield(task)\n            break\n        except asyncio.CancelledError as e:\n            cancel_exc = e\n    if cancel_exc is not None:\n        raise cancel_exc\n    return result\n"
  },
  {
    "path": "src/petals/utils/auto_config.py",
    "content": "import os\nfrom dataclasses import dataclass\nfrom typing import Optional, Type, Union\n\nfrom hivemind import get_logger\nfrom transformers import AutoConfig, PretrainedConfig, PreTrainedModel\n\nfrom petals.utils.hf_auth import always_needs_auth\n\nlogger = get_logger(__name__)\n\n\n@dataclass\nclass _ModelClasses:\n    config: Type[PretrainedConfig]\n    model: Optional[Type[PreTrainedModel]] = None\n    model_for_causal_lm: Optional[Type[PreTrainedModel]] = None\n    model_for_speculative: Optional[Type[PreTrainedModel]] = None\n    model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None\n\n\n_CLASS_MAPPING = {}  # Populated by petals.models.* subpackages with register_model_classes()\n\n\ndef register_model_classes(*, config: Type[PretrainedConfig], **kwargs):\n    assert issubclass(config, PretrainedConfig)\n    assert config.model_type not in _CLASS_MAPPING, f\"Model type {config.model_type} is already registered\"\n\n    _CLASS_MAPPING[config.model_type] = _ModelClasses(config=config, **kwargs)\n\n\nclass _AutoDistributedBase:\n    _mapping_field = None  # Should be defined in child classes\n\n    @classmethod\n    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig:\n        if (\n            always_needs_auth(model_name_or_path)\n            and kwargs.get(\"token\") is None\n            and kwargs.get(\"use_auth_token\") is None\n        ):\n            kwargs[\"use_auth_token\"] = True\n\n        config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs)\n        if config.model_type not in _CLASS_MAPPING:\n            raise ValueError(f\"Petals does not support model type {config.model_type}\")\n\n        proper_cls = getattr(_CLASS_MAPPING[config.model_type], cls._mapping_field)\n        if proper_cls is None:\n            raise ValueError(f\"Petals does not have {cls.__name__} for model type {config.model_type}\")\n\n        return 
proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)\n\n\nclass DefaultRevisionMixin:\n    \"\"\"\n    Petals only supports Falcon loaded in the new in-library format (transformers.FalconModel).\n    TII models were recently converted to this format but then reverted back due to compatibility issues.\n    We chose to support only the new format since HF staff promised to eventually convert these models\n    to the new format again, see https://huggingface.co/tiiuae/falcon-40b/discussions/90#64b4d23bf44fd957492f7602\n    Until that happens, we override the default `main` revision for the TII repos with the commit\n    pointing to the model in the in-library format.\n    \"\"\"\n\n    DEFAULT_REVISIONS = {\n        \"tiiuae/falcon-40b\": \"f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232\",\n        \"tiiuae/falcon-40b-instruct\": \"7475ff8cfc36ed9a962b658ae3c33391566a85a5\",\n        \"tiiuae/falcon-7b\": \"4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76\",\n        \"tiiuae/falcon-7b-instruct\": \"f8dac3fff96d5debd43edf56fb4e1abcfffbef28\",\n    }\n\n    @classmethod\n    def from_pretrained(\n        cls, model_name_or_path: Union[str, os.PathLike, None], *args, revision: Optional[str] = None, **kwargs\n    ):\n        if revision is None and model_name_or_path in cls.DEFAULT_REVISIONS:\n            revision = cls.DEFAULT_REVISIONS[model_name_or_path]\n            logger.info(f\"Loading {model_name_or_path}, revision {revision}\")\n        return super().from_pretrained(model_name_or_path, *args, revision=revision, **kwargs)\n\n\nclass AutoDistributedConfig(DefaultRevisionMixin, _AutoDistributedBase):\n    _mapping_field = \"config\"\n\n\nclass AutoDistributedModel(DefaultRevisionMixin, _AutoDistributedBase):\n    _mapping_field = \"model\"\n\n\nclass AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase):\n    _mapping_field = \"model_for_causal_lm\"\n\n\nclass AutoDistributedSpeculativeModel(DefaultRevisionMixin, _AutoDistributedBase):\n    _mapping_field = \"model_for_speculative\"\n\n\nclass AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):\n    _mapping_field = \"model_for_sequence_classification\"\n"
  },
  {
    "path": "src/petals/utils/convert_block.py",
    "content": "\"\"\"\nTools for converting transformer blocks, applying quantization and/or tensor parallelism\n\"\"\"\nimport re\nfrom enum import Enum\nfrom typing import Optional, Sequence\n\nimport tensor_parallel as tp\nimport torch\nimport torch.nn as nn\nfrom hivemind.utils.logging import get_logger, use_hivemind_log_handler\nfrom tensor_parallel.slicing_configs import get_bloom_config\nfrom transformers import PretrainedConfig\n\nuse_hivemind_log_handler(\"in_root_logger\")\nlogger = get_logger(__name__)\n\n\nclass QuantType(Enum):\n    NONE = 0\n    INT8 = 1  # 8-bit as in the LLM.int8() paper\n    NF4 = 2  # 4-bit as in the QLoRA paper\n\n\ndef convert_block(\n    block: nn.Module,\n    block_index: int,\n    config: PretrainedConfig,\n    tensor_parallel_devices: Sequence[torch.device],\n    output_device: torch.device,\n    quant_type: QuantType,\n    freeze: bool = True,\n    adapters: Optional[Sequence[str]] = None,\n    **kwargs,\n) -> tp.TensorParallel:\n    \"\"\"\n    Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or quantization\n\n    :note: some optimizations will modify the input block in-place!\n    :param block: a single transformer block, either pre-trained or newly initialized\n    :param block_index: index of this block in the model, passed to load_peft and add_adapter_to_block\n    :param config: HF transformers config for the full model\n    :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices\n    :note: if there is only a single device, the model will still be wrapped with TensorParallel (for uniformity)\n    :param output_device: the device on which the TensorParallel wrapper returns the block's outputs\n    :param quant_type: quantization type: QuantType.NONE, INT8 (as in LLM.int8()) or NF4 (as in QLoRA)\n    :param freeze: if True (default), make all module parameters non-trainable\n    :param adapters: optional names of adapters to load with load_peft and add to the block as LoRA adapters\n    :param kwargs: extra keyword arguments forwarded to load_peft\n    :return: a module that acts like the original block, but runs with all specified optimizations\n\n    \"\"\"\n    if freeze:\n        block.requires_grad_(False)\n\n    block = make_tensor_parallel(block, config, tensor_parallel_devices, 
output_device=output_device)\n\n    if quant_type != QuantType.NONE:\n        block = quantize_module(block, quant_type=quant_type)\n\n    for shard, device in zip(block.module_shards, block.devices):\n        shard.to(device)\n\n    if adapters:\n        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft\n\n        create_lora_adapter(block)\n        for adapter_name in adapters:\n            adapter_config, adapter_state_dict = load_peft(\n                adapter_name,\n                block_idx=block_index,\n                **kwargs,\n            )\n            add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)\n\n    return block\n\n\ndef quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Module:\n    # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes\n    import bitsandbytes as bnb\n\n    for n, module in model.named_children():\n        if len(list(module.children())) > 0:\n            quantize_module(module, quant_type=quant_type)\n\n        if isinstance(module, torch.nn.Linear) and n not in [\"lm_head\", \"score\"]:\n            assert module.weight.device.type == \"cpu\", f\"expected linear layers on CPU, got {module.weight.device}\"\n            if quant_type == QuantType.INT8:\n                model._modules[n] = bnb.nn.Linear8bitLt(\n                    module.in_features,\n                    module.out_features,\n                    module.bias is not None,\n                    has_fp16_weights=False,\n                    threshold=6.0,  # Default from the LLM.int8() paper\n                )\n                model._modules[n].weight = bnb.nn.Int8Params(\n                    module.weight.data, requires_grad=False, has_fp16_weights=False\n                ).to(module.weight.dtype)\n            elif quant_type == QuantType.NF4:\n                compress_statistics = True\n                model._modules[n] = 
bnb.nn.LinearNF4(\n                    module.in_features,\n                    module.out_features,\n                    module.bias is not None,\n                    compress_statistics=compress_statistics,\n                )\n                model._modules[n].weight = bnb.nn.Params4bit(\n                    module.weight.data,\n                    requires_grad=False,\n                    quant_type=\"nf4\",\n                    blocksize=64,\n                    compress_statistics=compress_statistics,\n                ).to(module.weight.dtype)\n            else:\n                raise ValueError(f\"Unsupported quant_type='{quant_type}'\")\n            model._modules[n].bias = module.bias\n    return model\n\n\ndef make_tensor_parallel(\n    block: nn.Module, model_config: PretrainedConfig, devices: Sequence[torch.device], output_device: torch.device\n) -> nn.Module:\n    if model_config.model_type == \"bloom\":\n        tp_config = get_bloom_config(model_config, devices)\n        del tp_config.state_rules[re.compile(\".*word_embeddings.weight$\")]\n    else:\n        if len(devices) > 1:\n            logger.warning(\"Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution\")\n        tp_config = None\n    tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)\n    total_heads = 0\n    for tp_shard in tp_block.module_shards:\n        for submodule in tp_shard.modules():\n            if isinstance(submodule, model_config.attn_class):\n                total_heads += submodule.num_heads\n    assert total_heads == model_config.num_attention_heads\n    return tp_block\n\n\ndef check_device_balance(devices: Sequence[torch.device]):\n    if not all(device.type == \"cuda\" for device in devices):\n        logger.warning(\"Running tensor parallelism on non-GPU devices; proceed at your own risk\")\n        return\n    unique_device_capabilities = set(map(torch.cuda.get_device_capability, 
devices))\n    if len(unique_device_capabilities) > 1:\n        logger.warning(\n            f\"Found GPUs with uneven capabilities: {unique_device_capabilities}. \"\n            f\"Using GPUs with different performance will cause the server to wait for the slowest GPU.\"\n        )\n\n    memory_per_device = tuple(torch.cuda.get_device_properties(device).total_memory for device in devices)\n    used_memory = min(memory_per_device) * len(memory_per_device)\n    wasted_memory_rate = (sum(memory_per_device) - used_memory) / sum(memory_per_device)\n    if wasted_memory_rate > 0.05:\n        logger.warning(\n            f\"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. \"\n            f\"Consider running high-memory GPUs in a separate server.\"\n        )\n"
  },
  {
    "path": "src/petals/utils/cuda_graphs.py",
    "content": "import torch\nfrom torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten\n\n\ndef make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3):\n    \"\"\"Similar to torch.cuda.make_graphed_callables, but takes only one function and does not build a graph for the backward pass\"\"\"\n    assert not isinstance(callable, torch.nn.Module)\n    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():\n        raise RuntimeError(\n            \"make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`.\"\n        )\n\n    flatten_arg, _ = _tree_flatten(sample_args)\n    flatten_sample_args = tuple(flatten_arg)\n    assert all(\n        isinstance(arg, torch.Tensor) for arg in flatten_arg\n    ), \"In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed.\"\n\n    len_user_args = len(sample_args)\n    static_input_surface = flatten_sample_args\n\n    graph = torch.cuda.CUDAGraph()\n\n    # Warmup\n    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work\n    # from ending up in any captures.\n    s = torch.cuda.Stream()\n    s.wait_stream(torch.cuda.current_stream())\n    with torch.cuda.stream(s):\n        for _ in range(num_warmup_iters):\n            outputs, _ = _tree_flatten(callable(*sample_args))\n        del outputs\n    torch.cuda.current_stream().wait_stream(s)\n\n    # Capture forward graph\n    with torch.cuda.graph(graph):\n        outputs = callable(*sample_args)\n\n    flatten_outputs, output_unflatten_spec = _tree_flatten(outputs)\n    static_outputs = tuple(flatten_outputs)\n\n    def make_graphed_function(\n        graph,\n        len_user_args,\n        output_unflatten_spec,\n        static_input_surface,\n        static_outputs,\n    ):\n        def replay_graph(*inputs):\n            # At this stage, only the user args may (potentially) be new 
tensors.\n            for i in range(len_user_args):\n                if static_input_surface[i].data_ptr() != inputs[i].data_ptr():\n                    static_input_surface[i].copy_(inputs[i])\n            graph.replay()\n            assert isinstance(static_outputs, tuple)\n            return tuple(o.detach() for o in static_outputs)\n\n        def functionalized(*user_args):\n            # Flattens the user args and replays the captured forward graph; no autograd function is involved\n            # (the backward pass is not graphed and the outputs are detached)\n            # Assumes module params didn't change since capture.\n            flatten_user_args, _ = _tree_flatten(user_args)\n            out = replay_graph(*flatten_user_args)\n            return _tree_unflatten(out, output_unflatten_spec)\n\n        return functionalized\n\n    # Put together the final graphed callable\n    graphed = make_graphed_function(\n        graph,\n        len_user_args,\n        output_unflatten_spec,\n        static_input_surface,\n        static_outputs,\n    )\n    return graphed\n"
  },
  {
    "path": "src/petals/utils/dht.py",
    "content": "\"\"\"\nUtilities for declaring and retrieving active model layers using a shared DHT.\n\"\"\"\nfrom __future__ import annotations\n\nimport math\nfrom functools import partial\nfrom typing import Dict, List, Optional, Sequence, Union\n\nfrom hivemind.dht import DHT, DHTNode, DHTValue\nfrom hivemind.p2p import PeerID\nfrom hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger\n\nfrom petals.data_structures import (\n    CHAIN_DELIMITER,\n    UID_DELIMITER,\n    ModuleUID,\n    RemoteModuleInfo,\n    RemoteSpanInfo,\n    ServerInfo,\n    ServerState,\n    parse_uid,\n)\n\nlogger = get_logger(__name__)\n\n\ndef declare_active_modules(\n    dht: DHT,\n    uids: Sequence[ModuleUID],\n    server_info: ServerInfo,\n    expiration_time: DHTExpiration,\n    wait: bool = True,\n) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:\n    \"\"\"\n    Declare that your node serves the specified modules; update timestamps if declared previously\n\n    :param dht: the DHT used to store the declarations (this peer's id is used as the subkey)\n    :param uids: a module uid or a sequence of module uids to declare\n    :param server_info: information about this server, stored under each uid as server_info.to_tuple()\n    :param expiration_time: absolute DHT time (in the units of get_dht_time()) at which the declarations expire\n    :param wait: if True, waits for the declaration to finish, otherwise runs in background\n    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected);\n        otherwise, returns an MPFuture that resolves to the same dict\n    \"\"\"\n    if isinstance(uids, str):\n        uids = [uids]\n    if not isinstance(uids, list):\n        uids = list(uids)\n    for uid in uids:\n        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid\n\n    return dht.run_coroutine(\n        partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),\n        return_future=not wait,\n    )\n\n\nasync def _declare_active_modules(\n    dht: DHT,\n    node: DHTNode,\n    uids: List[ModuleUID],\n    server_info: ServerInfo,\n    expiration_time: 
DHTExpiration,\n) -> Dict[ModuleUID, bool]:\n    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)\n    return await node.store_many(\n        keys=uids,\n        subkeys=[dht.peer_id.to_base58()] * len(uids),\n        values=[server_info.to_tuple()] * len(uids),\n        expiration_time=expiration_time,\n        num_workers=num_workers,\n    )\n\n\ndef get_remote_module_infos(\n    dht: DHT,\n    uids: Sequence[ModuleUID],\n    expiration_time: Optional[DHTExpiration] = None,\n    active_adapter: Optional[str] = None,\n    *,\n    latest: bool = False,\n    return_future: bool = False,\n) -> Union[List[RemoteModuleInfo], MPFuture]:\n    return dht.run_coroutine(\n        partial(\n            _get_remote_module_infos,\n            uids=uids,\n            active_adapter=active_adapter,\n            expiration_time=expiration_time,\n            latest=latest,\n        ),\n        return_future=return_future,\n    )\n\n\nasync def _get_remote_module_infos(\n    dht: DHT,\n    node: DHTNode,\n    uids: List[ModuleUID],\n    active_adapter: Optional[str],\n    expiration_time: Optional[DHTExpiration],\n    latest: bool,\n) -> List[RemoteModuleInfo]:\n    if latest:\n        assert expiration_time is None, \"You should define either `expiration_time` or `latest`, not both\"\n        expiration_time = math.inf\n    elif expiration_time is None:\n        expiration_time = get_dht_time()\n    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)\n    found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)\n\n    modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids]\n    for module_info in modules:\n        metadata = found[module_info.uid]\n        if metadata is None or not isinstance(metadata.value, dict):\n            if metadata is not None:\n                logger.warning(f\"Incorrect metadata for {module_info.uid}: {metadata}\")\n         
   continue\n\n        for peer_id, server_info in metadata.value.items():\n            try:\n                peer_id = PeerID.from_base58(peer_id)\n                server_info = ServerInfo.from_tuple(server_info.value)\n\n                if active_adapter and active_adapter not in server_info.adapters:\n                    logger.debug(f\"Skipped server {peer_id} since it does not have adapter {active_adapter}\")\n                    continue\n\n                module_info.servers[peer_id] = server_info\n            except (TypeError, ValueError) as e:\n                logger.warning(f\"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}\")\n    return modules\n\n\ndef compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]:\n    block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0\n    num_blocks = len(module_infos)\n\n    spans = {}\n    for block_idx, module_info in enumerate(module_infos):\n        for peer_id, server_info in sorted(module_info.servers.items()):\n            if server_info.state.value < min_state.value:\n                continue\n\n            if peer_id not in spans or spans[peer_id].state.value < server_info.state.value:\n                spans[peer_id] = RemoteSpanInfo(\n                    peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info\n                )\n                if server_info.start_block is not None and server_info.end_block is not None:\n                    spans[peer_id].start = max(server_info.start_block - block_offset, 0)\n                    spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)\n            elif spans[peer_id].state == server_info.state:\n                spans[peer_id].end = max(spans[peer_id].end, block_idx + 1)\n    return spans\n"
  },
  {
    "path": "src/petals/utils/disk_cache.py",
    "content": "import fcntl\nimport os\nimport shutil\nfrom contextlib import contextmanager\nfrom pathlib import Path\nfrom typing import Optional\n\nimport huggingface_hub\nfrom hivemind.utils.logging import get_logger\n\nlogger = get_logger(__name__)\n\nDEFAULT_CACHE_DIR = os.getenv(\"PETALS_CACHE\", Path(Path.home(), \".cache\", \"petals\"))\n\nBLOCKS_LOCK_FILE = \"blocks.lock\"\n\n\n@contextmanager\ndef _blocks_lock(cache_dir: Optional[str], mode: int):\n    if cache_dir is None:\n        cache_dir = DEFAULT_CACHE_DIR\n    lock_path = Path(cache_dir, BLOCKS_LOCK_FILE)\n\n    os.makedirs(lock_path.parent, exist_ok=True)\n    with open(lock_path, \"wb+\") as lock_fd:\n        fcntl.flock(lock_fd.fileno(), mode)\n        # The OS will release the lock when lock_fd is closed or the process is killed\n        yield\n\n\ndef allow_cache_reads(cache_dir: Optional[str]):\n    \"\"\"Allows simultaneous reads, guarantees that blocks won't be removed along the way (shared lock)\"\"\"\n    return _blocks_lock(cache_dir, fcntl.LOCK_SH)\n\n\ndef allow_cache_writes(cache_dir: Optional[str]):\n    \"\"\"Allows saving new blocks and removing the old ones (exclusive lock)\"\"\"\n    return _blocks_lock(cache_dir, fcntl.LOCK_EX)\n\n\ndef free_disk_space_for(\n    size: int,\n    *,\n    cache_dir: Optional[str],\n    max_disk_space: Optional[int],\n    os_quota: int = 1024**3,  # Minimal space we should leave to keep OS function normally\n):\n    if cache_dir is None:\n        cache_dir = DEFAULT_CACHE_DIR\n    cache_info = huggingface_hub.scan_cache_dir(cache_dir)\n\n    available_space = shutil.disk_usage(cache_dir).free - os_quota\n    if max_disk_space is not None:\n        available_space = min(available_space, max_disk_space - cache_info.size_on_disk)\n\n    gib = 1024**3\n    logger.debug(f\"Disk space: required {size / gib:.1f} GiB, available {available_space / gib:.1f} GiB\")\n    if size <= available_space:\n        return\n\n    cached_files = [file for repo in 
cache_info.repos for revision in repo.revisions for file in revision.files]\n\n    # Remove as few least recently used files as possible\n    removed_files = []\n    freed_space = 0\n    extra_space_needed = size - available_space\n    for file in sorted(cached_files, key=lambda file: file.blob_last_accessed):\n        os.remove(file.file_path)  # Remove symlink\n        os.remove(file.blob_path)  # Remove contents\n\n        removed_files.append(file)\n        freed_space += file.size_on_disk\n        if freed_space >= extra_space_needed:\n            break\n    if removed_files:\n        logger.info(f\"Removed {len(removed_files)} files to free {freed_space / gib:.1f} GiB of disk space\")\n        logger.debug(f\"Removed paths: {[str(file.file_path) for file in removed_files]}\")\n\n    if freed_space < extra_space_needed:\n        raise RuntimeError(\n            f\"Insufficient disk space to load a block. Please free {(extra_space_needed - freed_space) / gib:.1f} GiB \"\n            f\"on the volume for {cache_dir} or increase --max_disk_space if you set it manually\"\n        )\n"
  },
  {
    "path": "src/petals/utils/hf_auth.py",
    "content": "import os\nfrom typing import Union\n\n\ndef always_needs_auth(model_name: Union[str, os.PathLike, None]) -> bool:\n    loading_from_repo = model_name is not None and not os.path.isdir(model_name)\n    return loading_from_repo and model_name.startswith(\"meta-llama/Llama-2-\")\n"
  },
  {
    "path": "src/petals/utils/logging.py",
    "content": "import os\n\nfrom hivemind.utils import logging as hm_logging\n\n\ndef initialize_logs():\n    \"\"\"Initialize Petals logging tweaks. This function is called when you import the `petals` module.\"\"\"\n\n    # Env var PETALS_LOGGING=False prohibits Petals do anything with logs\n    if os.getenv(\"PETALS_LOGGING\", \"True\").lower() in (\"false\", \"0\"):\n        return\n\n    hm_logging.use_hivemind_log_handler(\"in_root_logger\")\n\n    # We suppress asyncio error logs by default since they are mostly not relevant for the end user,\n    # unless there is env var PETALS_ASYNCIO_LOGLEVEL\n    asyncio_loglevel = os.getenv(\"PETALS_ASYNCIO_LOGLEVEL\", \"FATAL\" if hm_logging.loglevel != \"DEBUG\" else \"DEBUG\")\n    hm_logging.get_logger(\"asyncio\").setLevel(asyncio_loglevel)\n"
  },
  {
    "path": "src/petals/utils/misc.py",
    "content": "import torch\n\nDUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters\n\nDUMMY_INT64 = torch.empty(0, dtype=torch.int64)\n\nDUMMY_KEY_PAST = torch.empty((0, 0, 0))\n\n\ndef is_dummy(tensor: torch.Tensor) -> bool:\n    return tensor.numel() == 0\n\n\nSPECIAL_DTYPE_SIZES = {torch.bool: 1, torch.qint8: 1, torch.qint32: 4}\n\n\ndef get_size_in_bytes(dtype: torch.dtype) -> int:\n    if dtype in SPECIAL_DTYPE_SIZES:\n        return SPECIAL_DTYPE_SIZES[dtype]\n    get_info = torch.finfo if dtype.is_floating_point else torch.iinfo\n    return (get_info(dtype).bits * (1 + dtype.is_complex)) // 8\n\n\ndef docstring_from(source):\n    def add_docstring(dest):\n        dest.__doc__ = source.__doc__\n        return dest\n\n    return add_docstring\n"
  },
  {
    "path": "src/petals/utils/packaging.py",
    "content": "from typing import Any, Dict, List, Tuple\n\nimport torch\nfrom hivemind import nested_flatten, nested_pack\n\n# TODO: Move functions to hivemind\n\n\ndef _mark_masked_tensor(index: int) -> bytes:\n    return b\"__T\" + str(index).encode()\n\n\ndef _is_masked_tensor(item: Any) -> bool:\n    return isinstance(item, bytes) and item.startswith(b\"__T\")\n\n\ndef _get_tensor_index(item: bytes) -> int:\n    return int(item[3:])\n\n\ndef pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:\n    \"\"\"\n    Check the function's arguments and pack all tensors into different flattened lists.\n    :returns: a flattened list of tensors and args and kwargs, where tensors were masked\n    \"\"\"\n    masked_flat_values, flat_tensors, tensor_to_index = [], [], {}\n    for value in nested_flatten((args, kwargs)):\n        if isinstance(value, torch.Tensor):\n            tensor_index = tensor_to_index.setdefault(value, len(flat_tensors))\n            if tensor_index == len(flat_tensors):\n                flat_tensors.append(value)\n            masked_flat_values.append(_mark_masked_tensor(tensor_index))\n        else:\n            masked_flat_values.append(value)\n    return flat_tensors, nested_pack(masked_flat_values, (args, kwargs))\n\n\ndef unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any):\n    \"\"\"\n    Restore arguments after `pack_args_kwargs` function.\n    :returns: list of args and dict of kwargs\n    \"\"\"\n    return nested_pack(\n        (\n            value if not _is_masked_tensor(value) else flat_tensors[_get_tensor_index(value)]\n            for value in nested_flatten(args_structure)\n        ),\n        args_structure,\n    )\n"
  },
  {
    "path": "src/petals/utils/peft.py",
    "content": "import contextlib\nimport re\nimport time\nfrom typing import List, Optional, Sequence, Union\n\nimport bitsandbytes as bnb\nimport torch\nimport torch.nn as nn\nimport transformers\nfrom accelerate import init_empty_weights\nfrom hivemind.utils.logging import get_logger\nfrom huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url\nfrom peft.config import PeftConfig\nfrom peft.tuners import lora\nfrom peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME\nfrom safetensors import safe_open\nfrom safetensors.torch import load_file\nfrom transformers.utils import get_file_from_repo\n\nfrom petals.server.block_utils import get_model_block, resolve_block_dtype\nfrom petals.utils.convert_block import QuantType\nfrom petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for\nfrom petals.utils.misc import get_size_in_bytes\n\nlogger = get_logger(__name__)\n\n\nCOMMON_LAYERS_PATTERN = [\"layers\", \"h\", \"block\", \"blocks\", \"layer\"]\n\n\ndef check_peft_repository(repo_id: str) -> bool:\n    return HfFileSystem().exists(f\"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}\")\n\n\ndef load_specific_module(block_idx: int, filepath: str, framework: str = \"pt\", device: Optional[int] = None):\n    tensors = dict()\n    is_tensors_found = dict()\n    common_layer_patter_re = (\n        \".+\\.\" + \"\".join(f\"({common_name})?\" for common_name in COMMON_LAYERS_PATTERN) + f\"\\.({block_idx})?\\..+\"\n    )\n    with safe_open(filepath, framework=framework, device=device) as f:\n        for k in f.keys():\n            if re.match(common_layer_patter_re, k):\n                is_tensors_found[block_idx] = True\n                tensors[k] = f.get_tensor(k)\n        if not is_tensors_found.get(block_idx, False):\n            logger.warning(f\"There is no peft weights for block {block_idx}\")\n        return tensors\n\n\ndef get_adapter_from_repo(\n    repo_id: str,\n    block_idx: Optional[int] = None,\n    device: 
Optional[int] = None,\n    *,\n    token: Optional[Union[str, bool]] = None,\n    **kwargs,\n):\n    config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs)\n    if config_path is None:\n        raise RuntimeError(f\"File {CONFIG_NAME} does not exist in repo {repo_id}\")\n    config = PeftConfig.from_json_file(config_path)\n\n    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, use_auth_token=token, **kwargs)\n    if weight_path is None:\n        raise RuntimeError(f\"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}\")\n    if block_idx is None:\n        return config, load_file(weight_path)\n    return config, load_specific_module(block_idx, weight_path, device=device)\n\n\ndef load_peft(\n    repo_id: str,\n    block_idx: Optional[int] = None,\n    device: Optional[int] = None,\n    *,\n    revision: Optional[str] = None,\n    token: Optional[Union[str, bool]] = None,\n    cache_dir: str,\n    max_disk_space: Optional[int] = None,\n    delay: float = 30,\n):\n    # TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here\n\n    if not check_peft_repository(repo_id):\n        raise ValueError(f\"Repo: {repo_id} doesn't have safetensors inside for a safe loading.\")\n\n    try:\n        with allow_cache_reads(cache_dir):\n            return get_adapter_from_repo(\n                repo_id,\n                block_idx,\n                device,\n                revision=revision,\n                token=token,\n                cache_dir=cache_dir,\n                local_files_only=False,\n            )\n    except Exception:\n        logger.warning(f\"Cache for peft weights {repo_id} is corrupted, it will be downloaded again\", exc_info=True)\n\n    while True:\n        try:\n            with allow_cache_writes(cache_dir):\n                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)\n                config_file_size = 
get_hf_file_metadata(config_url, token=token).size\n                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)\n                weight_file_size = get_hf_file_metadata(weight_url, token=token).size\n\n                file_size = config_file_size + weight_file_size\n                if file_size is not None:\n                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)\n                else:\n                    logger.warning(f\"Failed to fetch size from peft repo {repo_id}\")\n\n                return get_adapter_from_repo(\n                    repo_id,\n                    block_idx,\n                    device,\n                    revision=revision,\n                    token=token,\n                    cache_dir=cache_dir,\n                    local_files_only=False,\n                )\n        except Exception as e:\n            logger.warning(\n                f\"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)\", exc_info=True\n            )\n            time.sleep(delay)\n\n\nclass AdapterContextMixin:\n    \"\"\"A mixin that makes LoRA-wrapped linear layers obey an adapter set from context\"\"\"\n\n    ADAPTER_NOT_SET = \"__ADAPTER_NOT_SET\"\n    _context_active_adapter = ADAPTER_NOT_SET\n\n    @staticmethod\n    @contextlib.contextmanager\n    def using_adapter(active_adapter: Optional[str]):\n        prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter\n        try:\n            yield\n        finally:\n            AdapterContextMixin._context_active_adapter = prev\n\n    @property\n    def active_adapter(self):\n        if self._context_active_adapter == self.ADAPTER_NOT_SET:\n            logger.warning(f\"Layer {self} was called without using_adapter. 
This should only be used for debug\")\n        return self._context_active_adapter\n\n    @active_adapter.setter\n    def active_adapter(self, value: Optional[str]):\n        assert value == self.ADAPTER_NOT_SET, \"active adapter can only be changed via .using_adapter\" \"\"\n\n    @property\n    def active_adapters(self):\n        return [self._context_active_adapter]\n\n    def set_adapter(self, adapter_names) -> None:\n        \"\"\"\n        In PEFT, this function makes the adapter trainable. However, in Petals environment this is not possible now. Thus,\n        this code removes this functionality.\n        Link to peft code: https://github.com/huggingface/peft/blob/98f4db2c7990ef9c879a0e1da9a28a19a04701ef/src/peft/tuners/tuners_utils.py#L463\n        \"\"\"\n        pass\n\n\nusing_adapter = AdapterContextMixin.using_adapter\n\n\nclass LoraLinear(AdapterContextMixin, lora.Linear):\n    \"\"\"LoRA linear layer that uses adapter selected via using_adapter\"\"\"\n\n    def __init__(self, base_layer, adapter_name: str):\n        nn.Module.__init__(self)\n        lora.LoraLayer.__init__(self, base_layer)\n\n        self._active_adapter = adapter_name\n        self.is_target_conv_1d_layer = False\n\n\nclass LoraLinear8bitLt(LoraLinear, lora.Linear8bitLt):\n    \"\"\"LoRA linear 8-bit with outliers that uses adapter selected via using_adapter\"\"\"\n\n\nclass LoraLinear4bit(LoraLinear, lora.Linear4bit):\n    \"\"\"LoRA linear 4-bit that uses adapter selected via using_adapter\"\"\"\n\n\ndef create_lora_adapter(block):\n    for module_name, module in block.named_modules():\n        if isinstance(module, LoraLinear):\n            continue\n        for child_name, child in module.named_children():\n            lora_class = None\n            if isinstance(child, nn.Linear):\n                lora_class = LoraLinear\n            elif isinstance(child, bnb.nn.Linear8bitLt):\n                lora_class = LoraLinear8bitLt\n            elif isinstance(child, 
bnb.nn.Linear4bit):\n                lora_class = LoraLinear4bit\n            if lora_class:\n                lora_wrapped_child = lora_class(\n                    child,\n                    AdapterContextMixin.ADAPTER_NOT_SET,\n                )\n                setattr(module, child_name, lora_wrapped_child)\n\n\ndef add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):\n    assert peft_config[\"peft_type\"] == \"LORA\", \"Petals works only with LORA adapters\"\n    if peft_config[\"lora_dropout\"] > 0:\n        logger.info(f\"Adapter {adapter_name} has dropout enabled, this server will disable dropout\")\n\n    for _, module in block.named_modules():\n        for child_name, child in module.named_children():\n            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):\n                continue\n\n            if child_name in peft_config[\"target_modules\"] or (\n                isinstance(peft_config[\"target_modules\"], str)\n                and re.fullmatch(peft_config[\"target_modules\"], child_name)\n            ):\n                is_lora_a_loaded = False\n                is_lora_b_loaded = False\n                for peft_key in peft_state_dict:\n                    if child_name not in peft_key:\n                        continue\n\n                    if adapter_name not in child.lora_A:\n                        child.update_layer(\n                            adapter_name,\n                            peft_config[\"r\"],\n                            peft_config[\"lora_alpha\"],\n                            use_rslora=peft_config.get(\"use_rslora\", False),\n                            lora_dropout=peft_config[\"lora_dropout\"],\n                            init_lora_weights=peft_config[\"init_lora_weights\"],\n                        )\n                        child.train(False)\n                        for p in child.parameters():\n                            p.requires_grad = False\n\n        
            if peft_key.endswith(\".lora_A.weight\"):\n                        child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key]\n                        is_lora_a_loaded = True\n                    elif peft_key.endswith(\".lora_A.bias\"):\n                        raise NotImplementedError(f\"LoRA adapters with bias not supported: {peft_key}\")\n                    elif peft_key.endswith(\".lora_B.weight\"):\n                        child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key]\n                        is_lora_b_loaded = True\n                    elif peft_key.endswith(\".lora_B.bias\"):\n                        raise NotImplementedError(f\"LoRA adapters with bias not supported: {peft_key}\")\n\n                if is_lora_a_loaded and is_lora_b_loaded:\n                    logger.debug(f\"Loaded adapter {adapter_name} for block {block_index}.{child_name}\")\n                elif is_lora_a_loaded or is_lora_b_loaded:\n                    raise ValueError(f\"Invalid adapter {adapter_name} for block {block_index}.{child_name}\")\n    logger.info(f\"Loaded adapter {adapter_name} for block {block_index}\")\n\n\ndef estimate_adapter_memory_per_block(\n    block_config: transformers.PretrainedConfig,\n    torch_dtype: Optional[torch.dtype],\n    adapters: Sequence[str],\n    **load_peft_kwargs,\n) -> int:\n    \"\"\"Get the number of extra bytes used to store a set of adapters per given block\"\"\"\n    with init_empty_weights(include_buffers=False):\n        block = get_model_block(block_config)\n        base_block_parameters = sum(p.numel() for p in block.parameters())\n        create_lora_adapter(block)\n\n        for adapter in adapters:\n            peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)\n            assert peft_config[\"peft_type\"].upper() == \"LORA\", \"only LoRA adapters are supported for now\"\n            add_adapter_to_block(\n                block, block_index=0, 
adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict\n            )\n        adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters\n    bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype))\n    return adapter_parameters * bytes_per_parameter\n"
  },
  {
    "path": "src/petals/utils/ping.py",
    "content": "import asyncio\nimport math\nimport threading\nimport time\nfrom functools import partial\nfrom typing import Dict, Sequence\n\nimport hivemind\nfrom hivemind.proto import dht_pb2\nfrom hivemind.utils.logging import get_logger\n\nlogger = get_logger(__name__)\n\n\nasync def ping(\n    peer_id: hivemind.PeerID,\n    _dht: hivemind.DHT,\n    node: hivemind.dht.DHTNode,\n    *,\n    wait_timeout: float = 5,\n) -> float:\n    try:\n        ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info)\n        start_time = time.perf_counter()\n        await node.protocol.get_stub(peer_id).rpc_ping(ping_request, timeout=wait_timeout)\n        return time.perf_counter() - start_time\n    except Exception as e:\n        if str(e) == \"protocol not supported\":  # Happens on servers with client-mode DHT (e.g., reachable via relays)\n            return time.perf_counter() - start_time\n\n        logger.debug(f\"Failed to ping {peer_id}:\", exc_info=True)\n        return math.inf\n\n\nasync def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) -> Dict[hivemind.PeerID, float]:\n    rpc_infos = await asyncio.gather(*[ping(peer_id, *args, **kwargs) for peer_id in peer_ids])\n    return dict(zip(peer_ids, rpc_infos))\n\n\nclass PingAggregator:\n    def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 300):\n        self.dht = dht\n        self.ema_alpha = ema_alpha\n        self.expiration = expiration\n        self.ping_emas = hivemind.TimedStorage()\n        self.lock = threading.Lock()\n\n    def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs) -> None:\n        current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs))\n        logger.debug(f\"Current RTTs: {current_rtts}\")\n\n        with self.lock:\n            expiration = hivemind.get_dht_time() + self.expiration\n            for peer_id, rtt in current_rtts.items():\n                prev_rtt = 
self.ping_emas.get(peer_id)\n                if prev_rtt is not None and prev_rtt.value != math.inf:\n                    rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value  # Exponential smoothing\n                self.ping_emas.store(peer_id, rtt, expiration)\n\n    def to_dict(self) -> Dict[hivemind.PeerID, float]:\n        with self.lock, self.ping_emas.freeze():\n            smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()}\n            logger.debug(f\"Smothed RTTs: {smoothed_rtts}\")\n            return smoothed_rtts\n"
  },
  {
    "path": "src/petals/utils/random.py",
    "content": "import random\nfrom typing import Collection, TypeVar\n\nT = TypeVar(\"T\")\n\n\ndef sample_up_to(population: Collection[T], k: int) -> T:\n    if not isinstance(population, list):\n        population = list(population)\n    if len(population) > k:\n        population = random.sample(population, k)\n    return population\n"
  },
  {
    "path": "src/petals/utils/version.py",
    "content": "import os\nimport re\nfrom typing import Union\n\nimport requests\nfrom hivemind.utils.logging import TextStyle, get_logger\nfrom packaging.version import parse\n\nimport petals\n\nlogger = get_logger(__name__)\n\n\ndef validate_version() -> None:\n    logger.info(f\"Running {TextStyle.BOLD}Petals {petals.__version__}{TextStyle.RESET}\")\n    try:\n        r = requests.get(\"https://pypi.python.org/pypi/petals/json\")\n        r.raise_for_status()\n        response = r.json()\n\n        versions = [parse(ver) for ver in response.get(\"releases\")]\n        latest = max(ver for ver in versions if not ver.is_prerelease)\n\n        if parse(petals.__version__) < latest:\n            logger.info(\n                f\"A newer version {latest} is available. Please upgrade with: \"\n                f\"{TextStyle.BOLD}pip install --upgrade petals{TextStyle.RESET}\"\n            )\n    except Exception as e:\n        logger.warning(\"Failed to fetch the latest Petals version from PyPI:\", exc_info=True)\n\n\ndef get_compatible_model_repo(model_name_or_path: Union[str, os.PathLike, None]) -> Union[str, os.PathLike, None]:\n    if model_name_or_path is None:\n        return None\n\n    match = re.fullmatch(r\"(bigscience/.+)-petals\", str(model_name_or_path))\n    if match is None:\n        return model_name_or_path\n\n    logger.info(\n        f\"Loading model from {match.group(1)}, since Petals 1.2.0+ uses original repos instead of converted ones\"\n    )\n    return match.group(1)\n"
  },
  {
    "path": "tests/conftest.py",
    "content": "import asyncio\nimport gc\nfrom contextlib import suppress\n\nimport psutil\nimport pytest\nfrom hivemind.utils.crypto import RSAPrivateKey\nfrom hivemind.utils.logging import get_logger\nfrom hivemind.utils.mpfuture import MPFuture\n\nlogger = get_logger(__name__)\n\n\n@pytest.fixture\ndef event_loop():\n    \"\"\"\n    This overrides the ``event_loop`` fixture from pytest-asyncio\n    (e.g. to make it compatible with ``asyncio.subprocess``).\n\n    This fixture is identical to the original one but does not call ``loop.close()`` in the end.\n    Indeed, at this point, the loop is already stopped (i.e. next tests are free to create new loops).\n    However, finalizers of objects created in the current test may reference the current loop and fail if it is closed.\n    For example, this happens while using ``asyncio.subprocess`` (the ``asyncio.subprocess.Process`` finalizer\n    fails if the loop is closed, but works if the loop is only stopped).\n    \"\"\"\n\n    yield asyncio.get_event_loop()\n\n\n@pytest.fixture(autouse=True, scope=\"session\")\ndef cleanup_children():\n    yield\n\n    with RSAPrivateKey._process_wide_key_lock:\n        RSAPrivateKey._process_wide_key = None\n\n    gc.collect()  # Call .__del__() for removed objects\n\n    children = psutil.Process().children(recursive=True)\n    if children:\n        logger.info(f\"Cleaning up {len(children)} leftover child processes\")\n        for child in children:\n            with suppress(psutil.NoSuchProcess):\n                child.terminate()\n        psutil.wait_procs(children, timeout=1)\n        for child in children:\n            with suppress(psutil.NoSuchProcess):\n                child.kill()\n\n    MPFuture.reset_backend()\n"
  },
  {
    "path": "tests/test_aux_functions.py",
    "content": "import subprocess\nimport sys\n\nimport pytest\nimport torch\nfrom hivemind import nested_compare, nested_flatten\n\nfrom petals import AutoDistributedConfig\nfrom petals.server.throughput import measure_compute_rps\nfrom petals.utils.convert_block import QuantType\nfrom petals.utils.misc import DUMMY, is_dummy\nfrom petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs\nfrom test_utils import MODEL_NAME\n\n\ndef test_bnb_not_imported_when_unnecessary():\n    \"\"\"\n    We avoid importing bitsandbytes when it's not used,\n    since bitsandbytes doesn't always find correct CUDA libs and may raise exceptions because of that.\n\n    If this test fails, please change your code to import bitsandbytes and/or petals.utils.peft\n    in the function's/method's code when it's actually needed instead of importing them in the beginning of the file.\n    This won't slow down the code - importing a module for the 2nd time doesn't rerun module code.\n    \"\"\"\n\n    subprocess.check_call([sys.executable, \"-c\", \"import petals, sys; assert 'bitsandbytes' not in sys.modules\"])\n\n\n@pytest.mark.forked\n@pytest.mark.parametrize(\"inference\", [False, True])\n@pytest.mark.parametrize(\"n_tokens\", [1, 16])\n@pytest.mark.parametrize(\"tensor_parallel\", [False, True])\ndef test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool):\n    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)\n    if tensor_parallel and config.model_type != \"bloom\":\n        pytest.skip(\"Tensor parallelism is implemented only for BLOOM for now\")\n\n    tensor_parallel_devices = (\"cpu\", \"cpu\") if tensor_parallel else ()\n    compute_rps = measure_compute_rps(\n        config,\n        device=torch.device(\"cpu\"),\n        dtype=torch.bfloat16,\n        quant_type=QuantType.NONE,\n        tensor_parallel_devices=tensor_parallel_devices,\n        n_tokens=n_tokens,\n        n_steps=5,\n        inference=inference,\n    )\n    assert 
isinstance(compute_rps, float) and compute_rps > 0\n\n\n@pytest.mark.forked\ndef test_pack_inputs():\n    x = torch.ones(3)\n    y = torch.arange(5)\n    z = DUMMY\n\n    args = (x, z, None, (y, y), z)\n    kwargs = dict(foo=torch.zeros(1, 1), bar={\"l\": \"i\", \"g\": \"h\", \"t\": (\"y\", \"e\", \"a\", \"r\", torch.rand(1), x, y)})\n\n    flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)\n\n    assert len(flat_tensors) == 5\n    assert all(isinstance(t, torch.Tensor) for t in flat_tensors)\n\n    restored_args, restored_kwargs = unpack_args_kwargs(flat_tensors, args_structure)\n\n    assert len(restored_args) == len(args)\n    assert torch.all(restored_args[0] == x).item() and restored_args[2] is None\n    assert nested_compare((args, kwargs), (restored_args, restored_kwargs))\n    for original, restored in zip(nested_flatten((args, kwargs)), nested_flatten((restored_args, restored_kwargs))):\n        if isinstance(original, torch.Tensor):\n            assert torch.all(original == restored)\n        else:\n            assert original == restored\n"
  },
  {
    "path": "tests/test_block_exact_match.py",
    "content": "import random\n\nimport pytest\nimport torch\n\nfrom petals import AutoDistributedConfig, RemoteSequential\nfrom petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS\nfrom petals.server.from_pretrained import load_pretrained_block\nfrom test_utils import *\n\n\n@pytest.mark.forked\ndef test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):\n    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)\n    remote_sequential = RemoteSequential(config)\n\n    block_index = random.randint(0, config.num_hidden_layers - 1)\n    remote_block = remote_sequential[block_index]\n\n    inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS + 8, config.hidden_size)\n    outputs_forward = remote_block(inputs)\n\n    outputs_inference = []\n    with torch.inference_mode():\n        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:\n            # Test long inference (unmerged inference pools)\n            outputs_inference.append(sess.step(inputs[:, : MAX_SHORT_INFERENCE_TOKENS + 1, :]))\n\n            # Test short inference (merged inference pools)\n            for i in range(MAX_SHORT_INFERENCE_TOKENS + 1, inputs.shape[1]):\n                outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))\n\n            # test that max length is respected\n            with pytest.raises(ValueError, match=r\"Maximum length exceeded\") as exc_info:\n                sess.step(inputs[:, -1:, :])\n            assert \"Maximum length exceeded\" in repr(exc_info.value)\n    outputs_inference = torch.cat(outputs_inference, dim=1)\n\n    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)\n    (outputs_local,) = ref_block(inputs)\n\n    assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)\n    assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)\n"
  },
  {
    "path": "tests/test_cache.py",
    "content": "import asyncio\nimport multiprocessing as mp\nimport random\nimport time\nfrom typing import Optional\n\nimport pytest\nimport pytest_asyncio  # make sure the module exists; otherwise the test will be skipped\nimport torch\nfrom hivemind import TensorDescriptor\n\nfrom petals.server.memory_cache import AllocationFailed, MemoryCache\nfrom petals.utils.misc import get_size_in_bytes\n\n\ndef _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None):\n    if dtype is None:\n        dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool))\n    elem_size_bytes = get_size_in_bytes(dtype)\n    descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype))\n    return descr\n\n\n@pytest.mark.asyncio\nasync def test_cache_timeout():\n    cache = MemoryCache(max_size_bytes=1024, max_alloc_timeout=0.5)\n    cache.runtime_pid += 1  # pretend we're another process\n    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0):\n        pass\n\n    async with cache.allocate_cache(_make_tensor_descriptor(100), timeout=999):\n        async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):\n            async with cache.allocate_cache(_make_tensor_descriptor(128), _make_tensor_descriptor(32), timeout=1):\n                t_start = time.perf_counter()\n                with pytest.raises(AllocationFailed):\n                    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0.1):\n                        pass\n                assert 0.1 < time.perf_counter() - t_start < 0.2, \"wait time exceeds alloc timeout\"\n                async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float(\"inf\")):\n                    pass\n\n                t_start = time.perf_counter()\n                with pytest.raises(AllocationFailed):\n                    async with 
cache.allocate_cache(_make_tensor_descriptor(384), timeout=1.0):  # exceeds max timeout\n                        pass\n                assert 0.5 < time.perf_counter() - t_start < 0.6, \"wait time exceeds max alloc timeout\"\n\n            # test memory allocation when another task frees the memory\n            async def _klog_the_cache():\n                async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):\n                    pass\n\n            large_alloc_task = asyncio.create_task(_klog_the_cache())\n\n            t_start = time.perf_counter()\n            await asyncio.sleep(0.05)  # wait for large alloc to enqueue\n            async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float(\"inf\")):  # exceeds max timeout\n                pass  # this memory should allocate once the background task clears the queue\n            assert 0.2 < time.perf_counter() - t_start < 0.3, \"memory should be allocated after background task clears\"\n            with pytest.raises(AllocationFailed):\n                await large_alloc_task\n\n            # test that zero-timeout allocation fails instantaneously even if someone else is awaiting alloc\n            large_alloc_task = asyncio.create_task(_klog_the_cache())\n            t_start = time.perf_counter()\n            await asyncio.sleep(0.05)  # wait for large alloc to enqueue\n            with pytest.raises(AllocationFailed):\n                async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):\n                    pass  # never reached: a zero-timeout allocation fails immediately instead of waiting\n            assert time.perf_counter() - t_start < 0.1, \"zero-timeout task should fail (or succeed) instantaneously\"\n            with pytest.raises(AllocationFailed):\n                await large_alloc_task\n\n\n@pytest.mark.asyncio\nasync def test_unlimited_timeout():\n    cache = MemoryCache(max_size_bytes=1024)\n    cache.runtime_pid += 1  # pretend we're another 
process\n    t_start = time.perf_counter()\n\n    async def _klog_the_cache():\n        async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):\n            await asyncio.sleep(0.5)\n\n    alloc_task = asyncio.create_task(_klog_the_cache())\n    await asyncio.sleep(0.1)\n    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=float(\"inf\")):\n        await alloc_task\n    assert 0.5 < time.perf_counter() - t_start < 0.6, \"memory should be allocated after background task clears\"\n\n\n@pytest.mark.asyncio\nasync def test_cache_usage():\n    cache = MemoryCache(max_size_bytes=2048)\n    alloc_event, dealloc_a_event, dealloc_bcd_event, dealloc_e_event, dealloc_f_event = (mp.Event() for _ in range(5))\n    pipe_receiver, pipe_sender = mp.Pipe(duplex=False)\n    with pytest.raises(AssertionError):\n        async with cache.allocate_cache(_make_tensor_descriptor(123), timeout=1):\n            pass  # fails because cache must be allocated from another process\n\n    descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8))  # 768 bytes\n    descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64))  # 8 bytes\n    descr_c = TensorDescriptor.from_tensor(torch.empty((33,), dtype=torch.bool))  # 33 bytes\n    descr_d = TensorDescriptor.from_tensor(torch.empty((0,), dtype=torch.int64))  # 0 bytes\n    descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16))  # 1536 bytes\n    descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8))  # 1792 bytes\n\n    async def _allocate_and_wait(dealloc_event, *descrs, timeout=None):\n        loop = asyncio.get_event_loop()\n        async with cache.allocate_cache(*descrs, timeout=timeout) as handles:\n            pipe_sender.send(handles)\n            await loop.run_in_executor(None, dealloc_event.wait)\n\n    async def _allocate_af():\n        alloc_event.wait()\n        allocate_a_task = 
asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a))\n        await allocate_a_task\n        allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f))  # klogs the cache\n        await allocate_f_task\n\n    alloc_process1 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_af()), daemon=True)\n    alloc_process1.start()\n\n    async def _allocate_bcde():\n        alloc_event.wait()\n        await asyncio.sleep(0.1)  # ensure that the other tensor is always allocated (and sent through pipe) first\n        allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d))\n        allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e))  # doesn't fit\n        await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)\n\n    alloc_process2 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)\n    alloc_process2.start()\n    assert cache.current_size_bytes == 0\n    alloc_event.set()\n    (handle_a,) = pipe_receiver.recv()\n\n    handle_b, handle_c, handle_d = pipe_receiver.recv()\n\n    with cache.use_cache(handle_a) as (tensor_a,):\n        assert tensor_a.dtype == torch.uint8\n        tensor_a[2:5] = torch.tensor((42, 43, 44))\n\n    with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d):\n        assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0\n        assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0\n        tensor_a += 1\n        tensor_b[...] 
= -1.337\n    assert cache.current_size_bytes == 809  # this checks a,b,c,d are allocated but e still awaits memory\n\n    dealloc_bcd_event.set()\n    await asyncio.sleep(0.1)\n    assert cache.current_size_bytes == 768  # only tensor a should be allocated\n    with pytest.raises(KeyError):\n        with cache.use_cache(handle_a, handle_b):\n            pass  # one of the handles (b) is deallocated\n    with pytest.raises(KeyError):\n        with cache.use_cache(handle_d):\n            pass  # handle_d is deallocated correctly, even though it is never used\n    with cache.use_cache(handle_a) as (tensor_a,):\n        assert tuple(tensor_a[2:5]) == (43, 44, 45)\n\n    dealloc_a_event.set()\n    (handle_e,) = pipe_receiver.recv()  # e can finally be allocated\n    await asyncio.sleep(0.1)\n    assert cache.current_size_bytes == 1536  # tensor e should finally be able to allocate\n\n    with pytest.raises(KeyError):\n        with cache.use_cache(handle_a):\n            pass  # tensor a is no longer allocated\n    with cache.use_cache(handle_e) as (tensor_e,):\n        assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8)\n\n    dealloc_e_event.set()\n    await asyncio.sleep(0.1)\n    assert cache.current_size_bytes == 1792  # only tensor f is still allocated\n    dealloc_f_event.set()\n\n    alloc_process1.join()\n    alloc_process2.join()\n    await asyncio.sleep(0.1)\n    assert cache.current_size_bytes == 0\n    assert cache.current_size_bytes == 0\n    assert alloc_process1.exitcode == 0, \"allocation process 1 failed or did not finish, see stderr for details\"\n    assert alloc_process2.exitcode == 0, \"allocation process 2 failed or did not finish, see stderr for details\"\n"
  },
  {
    "path": "tests/test_chained_calls.py",
    "content": "######\n# Warning: this test is a work in progress. It will be modified soon.\n# - if you want more stable tests, see test_block_exact_match\n# - if you want to figure out chained inference, ask yozh\n\n\nimport pytest\nimport torch\n\nfrom petals import AutoDistributedConfig\nfrom petals.client.remote_sequential import RemoteSequential\nfrom petals.server.from_pretrained import load_pretrained_block\nfrom petals.utils.misc import DUMMY_KEY_PAST\nfrom test_utils import *\n\n\n@pytest.mark.forked\ndef test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):\n    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)\n    remote_blocks = RemoteSequential(config, start_block=3, end_block=6)\n    assert isinstance(remote_blocks, RemoteSequential)\n\n    ref_blocks = [\n        load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),\n        load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),\n        load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32),\n    ]\n    inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)\n    outputs_rpc = remote_blocks.forward(inputs)\n    outputs_rpc.sum().backward()\n    grads_rpc = inputs.grad\n\n    inputs.grad = None\n    hidden_states = inputs\n    for ref_block in ref_blocks:\n        hidden_states = ref_block.forward(hidden_states)[0]\n    outputs_ref = hidden_states\n    outputs_ref.sum().backward()\n    grads_ref = inputs.grad\n\n    assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward)\n    assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)\n\n\n@pytest.mark.forked\ndef test_chained_inference_exact_match(atol_inference=1e-4):\n    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)\n    remote_blocks = RemoteSequential(config, start_block=3, end_block=5)\n\n    inputs = torch.randn(1, 8, config.hidden_size)\n\n   
 outputs_inference = []\n    with remote_blocks.inference_session(max_length=inputs.shape[1]) as sess:\n        for i in range(inputs.shape[1]):\n            outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))\n    outputs_inference = torch.cat(outputs_inference, dim=1)\n\n    dtype = torch.float32\n    ref_blocks = [\n        load_pretrained_block(MODEL_NAME, 3, torch_dtype=dtype),\n        load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),\n    ]\n    outputs_ref = []\n    cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))\n    caches = [cache, cache]\n    for i in range(inputs.shape[1]):\n        new_caches = []\n        hidden_states = inputs[:, i : i + 1, :]\n        for ref_block, cache in zip(ref_blocks, caches):\n            with torch.no_grad():\n                hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)\n                new_caches.append(new_cache)\n\n        outputs_ref.append(hidden_states)\n        caches = new_caches\n    outputs_ref = torch.cat(outputs_ref, dim=1)\n    assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference)\n"
  },
  {
    "path": "tests/test_dtype.py",
    "content": "import pytest\nimport torch\n\nfrom petals.server.block_utils import resolve_block_dtype\nfrom petals.server.from_pretrained import load_pretrained_block\nfrom petals.utils.auto_config import AutoDistributedConfig\nfrom test_utils import MODEL_NAME\n\n\n@pytest.mark.forked\n@pytest.mark.parametrize(\"torch_dtype\", [torch.float32, torch.float16, \"auto\"])\ndef test_block_dtype(torch_dtype):\n    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)\n    block = load_pretrained_block(MODEL_NAME, 0, config=config, torch_dtype=torch_dtype)\n    expected_dtype = resolve_block_dtype(config, torch_dtype)\n    assert all(param.dtype == expected_dtype for param in block.parameters())\n"
  },
  {
    "path": "tests/test_full_model.py",
    "content": "import peft\nimport pytest\nimport torch\nimport transformers\nfrom hivemind import get_logger\n\nfrom petals import AutoDistributedModelForCausalLM\nfrom test_utils import *\n\nlogger = get_logger(__name__)\n\n\n@pytest.fixture\ndef tokenizer():\n    # We set use_fast=False since LlamaTokenizerFast is slow on load\n    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)\n\n\n@pytest.fixture\ndef model():\n    return AutoDistributedModelForCausalLM.from_pretrained(\n        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32\n    )\n\n\n@pytest.fixture\ndef ref_model():\n    return transformers.AutoModelForCausalLM.from_pretrained(\n        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32\n    )\n\n\n@pytest.mark.forked\n@pytest.mark.parametrize(\"use_peft\", (True, False) if ADAPTER_NAME else (False,))\n@pytest.mark.parametrize(\"pass_empty_tensors\", (True, False))\ndef test_full_model_exact_match(tokenizer, model, ref_model, use_peft, pass_empty_tensors, atol=1e-3):\n    if use_peft:\n        model.config.active_adapter = ADAPTER_NAME\n\n        ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)\n        ref_model.train(False)\n\n    test_inputs = tokenizer(\"A quick brown fox was minding its own buisness\", return_tensors=\"pt\")[\"input_ids\"]\n\n    with torch.inference_mode():\n        parallel_outputs = model.forward(test_inputs).logits\n        assert torch.all(torch.isfinite(parallel_outputs))\n        logger.info(\"Forward outputs are finite\")\n\n        embs = model.transformer.word_embeddings(test_inputs)\n        embs = model.transformer.word_embeddings_layernorm(embs)\n        recurrent_outputs = []\n        with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:\n            if pass_empty_tensors:\n                recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))\n\n            for t in range(embs.shape[1]):\n    
            if t == 4:\n                    recurrent_outputs.append(sess.step(embs[:, 4:9, :]))\n                elif 4 < t < 9:\n                    continue\n                else:\n                    recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))\n\n                if t == 2 and pass_empty_tensors:\n                    recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))\n                    recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))\n\n        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)\n        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)\n        recurrent_outputs = model.lm_head(recurrent_outputs)\n        assert torch.allclose(\n            recurrent_outputs, parallel_outputs, rtol=0, atol=atol\n        ), \"Inference differs from forward pass\"\n\n        ref_outputs = ref_model.forward(test_inputs).logits.float()\n        assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol), \"Outputs are not identical to HF\"\n\n\ndef make_generate_calls(model, inputs, *, max_new_tokens, multiple_calls=False, **kwargs):\n    if not multiple_calls:\n        return model.generate(inputs, max_new_tokens=max_new_tokens, **kwargs)\n\n    with model.inference_session(max_length=inputs.shape[1] + max_new_tokens) as sess:\n        return torch.cat(\n            [\n                # Sessions provided both explicitly and implicitly should work\n                model.generate(inputs, max_new_tokens=1, **kwargs, session=sess),\n                model.generate(None, max_new_tokens=max_new_tokens - 2, **kwargs),\n                model.generate(None, max_new_tokens=1, **kwargs),\n            ],\n            dim=1,\n        )\n\n\n@pytest.mark.forked\ndef test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):\n    inputs_single = tokenizer(\"A cat sat on a mat\", return_tensors=\"pt\")[\"input_ids\"]\n\n    if tokenizer.pad_token_id is None:\n     
   tokenizer.pad_token_id = tokenizer.eos_token_id\n    inputs_batch = tokenizer([\"A cat sat on a mat\", \"A dog sat on a mat\"], return_tensors=\"pt\", padding=True)[\n        \"input_ids\"\n    ]\n\n    options = dict(max_new_tokens=max_new_tokens, do_sample=False)\n    for multiple_calls in [False, True]:\n        for inputs in [inputs_single, inputs_batch]:\n            outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)\n            ref_outputs = ref_model.generate(inputs, **options)\n            assert torch.allclose(\n                outputs, ref_outputs\n            ), f\"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}\"\n\n\n@pytest.mark.forked\ndef test_sampling(tokenizer, model, ref_model, max_new_tokens=10):\n    inputs_single = tokenizer(\"A cat sat on a mat\", return_tensors=\"pt\")[\"input_ids\"]\n\n    if tokenizer.pad_token_id is None:\n        tokenizer.pad_token_id = tokenizer.eos_token_id\n    inputs_batch = tokenizer([\"A cat sat on a mat\", \"A dog sat on a mat\"], return_tensors=\"pt\", padding=True)[\n        \"input_ids\"\n    ]\n\n    for options in [\n        dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9),\n        dict(do_sample=True, temperature=0.5, repetition_penalty=1.2),\n    ]:\n        options.update(max_new_tokens=max_new_tokens)\n        for multiple_calls in [False, True]:\n            for inputs in [inputs_single, inputs_batch]:\n                torch.manual_seed(0)\n                outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)\n\n                torch.manual_seed(0)\n                ref_outputs = ref_model.generate(inputs, **options)\n\n                assert torch.allclose(\n                    outputs, ref_outputs\n                ), f\"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}\"\n\n\n@pytest.mark.skipif(\n    \"bloom\" not in MODEL_NAME.lower(),\n    
reason=\"Mixtral and Llama use DynamicCache, which can change based on beam search choices\",\n)\n@pytest.mark.forked\ndef test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):\n    inputs = tokenizer(\"A cat sat on a mat\", return_tensors=\"pt\")[\"input_ids\"]\n\n    options = dict(max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)\n    outputs = make_generate_calls(model, inputs, **options)\n    ref_outputs = ref_model.generate(inputs, **options)\n    assert torch.allclose(outputs, ref_outputs), f\"Beam search results are not identical to HF\"\n\n\n@pytest.mark.forked\ndef test_input_ids(tokenizer, model, ref_model, max_new_tokens=4):\n    inputs = tokenizer(\"A cat sat on a mat\", return_tensors=\"pt\")\n    assert inputs.keys() == {\"input_ids\", \"attention_mask\"}\n\n    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)\n    ref_outputs = ref_model.generate(**inputs, max_new_tokens=max_new_tokens)\n    assert torch.allclose(outputs, ref_outputs), f\"Outputs are not identical to HF\"\n\n    with model.inference_session(max_length=inputs[\"input_ids\"].shape[1] + max_new_tokens):\n        outputs = torch.cat(\n            [\n                model.generate(**inputs, max_new_tokens=2),\n                model.generate(None, max_new_tokens=max_new_tokens - 2),\n            ],\n            dim=1,\n        )\n    assert torch.allclose(outputs, ref_outputs), f\"Multi-call outputs are not identical to HF\"\n"
  },
  {
    "path": "tests/test_optimized_layers.py",
    "content": "from typing import Optional, Tuple\n\nimport pytest\nimport torch\nfrom transformers.cache_utils import DynamicCache\nfrom transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask\nfrom transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor\nfrom transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel\n\nfrom petals.server.block_utils import get_model_block\nfrom petals.utils.auto_config import AutoDistributedConfig\nfrom petals.utils.convert_block import QuantType, convert_block\nfrom test_utils import MODEL_NAME\n\nKVCache = Tuple[torch.Tensor, torch.Tensor]\n\n\nclass UnoptimizedWrappedFalconBlock(FalconDecoderLayer):\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        *args,\n        attention_mask: Optional[torch.Tensor] = None,\n        alibi: Optional[torch.Tensor] = None,\n        layer_past: Optional[KVCache] = None,\n        use_cache: bool = False,\n        **kwargs,\n    ):\n        batch_size, seq_length = hidden_states.shape[:2]\n\n        if layer_past is not None:\n            layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)\n        past_length = 0 if layer_past is None else layer_past[0].shape[1]\n        seq_length_with_past = seq_length + past_length\n\n        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)\n        if alibi is None and self.config.alibi:\n            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)\n        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)\n\n        outputs = super().forward(\n            hidden_states,\n            *args,\n            attention_mask=attention_mask,\n            alibi=alibi,\n            layer_past=layer_past,\n            use_cache=use_cache,\n            **kwargs,\n        )\n\n        if 
use_cache:\n            present_key_value = outputs[-1]\n            present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)\n            outputs = outputs[:-1] + (present_key_value,)\n\n        return outputs\n\n    def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:\n        key_states, value_states = key_value\n\n        key_states = key_states.permute(0, 2, 1)\n        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]\n\n        if self.config.new_decoder_architecture:\n            key_states = self._expand_states(key_states)\n            value_states = self._expand_states(value_states)\n\n        return (key_states, value_states)\n\n    def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:\n        key_states, value_states = key_value\n\n        if self.config.new_decoder_architecture:\n            key_states = self._collapse_states(key_states)\n            value_states = self._collapse_states(value_states)\n\n        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]\n        key_states = key_states.permute(0, 2, 1)\n\n        return (key_states, value_states)\n\n    def _expand_states(self, state: torch.Tensor) -> torch.Tensor:\n        batch_size_x_num_kv_heads, seq_len, head_dim = state.shape\n        batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads\n\n        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)\n        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy\n        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy\n        return state\n\n    def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:\n        batch_size_x_num_attn_heads, seq_len, head_dim = state.shape\n        batch_size = batch_size_x_num_attn_heads // 
self.config.num_attention_heads\n\n        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)\n        state = state[:, :, 0]\n        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)\n        return state\n\n\nclass UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        *args,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        layer_past: Optional[Tuple[torch.Tensor]] = None,\n        use_cache: bool = False,\n        **kwargs,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        batch_size, seq_length, _ = hidden_states.shape\n\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n\n        past_key_value = layer_past\n        if past_key_value is not None:\n            past_key_values_length = past_key_value[0].shape[2]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)\n        elif use_cache:\n            past_key_value = DynamicCache()\n\n        if position_ids is None:\n            device = hidden_states.device\n            position_ids = torch.arange(\n                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device\n            )\n            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n        else:\n            position_ids = position_ids.view(-1, seq_length).long()\n\n        # embed positions\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device\n            )\n\n        attention_mask = 
_prepare_4d_causal_attention_mask(\n            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length\n        )\n\n        outputs = super().forward(\n            hidden_states,\n            *args,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            use_cache=use_cache,\n            **kwargs,\n        )\n\n        if use_cache:\n            present_key_value = outputs[-1]\n            present_key_value = self._reorder_cache_from_llama_to_bloom(\n                present_key_value, batch_size, seq_length_with_past\n            )\n            outputs = outputs[:-1] + (present_key_value,)\n\n        return outputs\n\n    def _reorder_cache_from_bloom_to_llama(\n        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int\n    ) -> DynamicCache:\n        key_states, value_states = key_value\n        key_states = key_states.permute(0, 2, 1)\n        key_states = key_states.view(\n            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim\n        )\n        value_states = value_states.view(*key_states.shape)\n        past_key_values = ((key_states, value_states),)\n        return DynamicCache.from_legacy_cache(past_key_values)\n\n    def _reorder_cache_from_llama_to_bloom(\n        self, key_value: DynamicCache, batch_size: int, seq_length: int\n    ) -> Tuple[torch.Tensor]:\n        key_states, value_states = key_value.to_legacy_cache()[0]\n        value_states = value_states.view(\n            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim\n        )\n        key_states = key_states.view(*value_states.shape)\n        key_states = key_states.permute(0, 2, 1)\n        return (key_states, value_states)\n\n\n@pytest.mark.parametrize(\"device\", [\"cpu\", \"cuda:0\"])\n@pytest.mark.forked\ndef test_optimized_block(device):\n    if device == \"cuda:0\" and not 
torch.cuda.is_available():\n        pytest.skip(\"CUDA tests can be run only in CUDA-enabled setups\")\n\n    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)\n\n    tensor_parallel_devices = (device,)\n    dtype = torch.bfloat16\n    quant_type = QuantType.NONE\n\n    block_idx = 1\n    block = get_model_block(config, layer_idx=block_idx).to(dtype)\n    block = convert_block(block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)\n\n    if config.model_type == \"falcon\":\n        unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)\n    elif config.model_type == \"llama\":\n        unopt_block = UnoptimizedWrappedLlamaBlock(config, layer_idx=0).to(dtype)\n    else:\n        pytest.skip(f\"This test is not applicable to {config.model_type} models\")\n\n    unopt_block = convert_block(\n        unopt_block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True\n    )\n\n    unopt_block.load_state_dict(block.state_dict())\n    cache = unopt_cache = None\n\n    with torch.inference_mode():\n        for length in [10, 1, 1, 1]:\n            dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype)\n            block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)\n            unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)\n            assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length\n            assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length\n            assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length\n"
  },
  {
    "path": "tests/test_peft.py",
    "content": "import os\nimport shutil\n\nimport pytest\nfrom huggingface_hub import snapshot_download\n\nfrom petals.utils.peft import check_peft_repository, load_peft\n\nUNSAFE_PEFT_REPO = \"artek0chumak/bloom-560m-unsafe-peft\"\nSAFE_PEFT_REPO = \"artek0chumak/bloom-560m-safe-peft\"\nTMP_CACHE_DIR = \"tmp_cache/\"\n\n\ndef clear_dir(path_to_dir):\n    shutil.rmtree(path_to_dir)\n    os.mkdir(path_to_dir)\n\n\ndef dir_empty(path_to_dir):\n    files = os.listdir(path_to_dir)\n    return len(files) == 0\n\n\n@pytest.mark.forked\ndef test_check_peft():\n    assert not check_peft_repository(UNSAFE_PEFT_REPO), \"NOSAFE_PEFT_REPO is safe to load.\"\n    assert check_peft_repository(SAFE_PEFT_REPO), \"SAFE_PEFT_REPO is not safe to load.\"\n\n\n@pytest.mark.forked\ndef test_load_noncached(tmpdir):\n    clear_dir(tmpdir)\n    with pytest.raises(Exception):\n        load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)\n\n    assert dir_empty(tmpdir), \"UNSAFE_PEFT_REPO is loaded\"\n\n    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)\n\n    assert not dir_empty(tmpdir), \"SAFE_PEFT_REPO is not loaded\"\n\n\n@pytest.mark.forked\ndef test_load_cached(tmpdir):\n    clear_dir(tmpdir)\n    snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)\n\n    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)\n\n\n@pytest.mark.forked\ndef test_load_layer_exists(tmpdir):\n    clear_dir(tmpdir)\n\n    load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir)\n\n\n@pytest.mark.forked\ndef test_load_layer_nonexists(tmpdir):\n    clear_dir(tmpdir)\n\n    load_peft(\n        SAFE_PEFT_REPO,\n        block_idx=1337,\n        cache_dir=tmpdir,\n    )\n"
  },
  {
    "path": "tests/test_priority_pool.py",
    "content": "import multiprocessing as mp\nimport platform\nimport time\n\nimport pytest\nimport torch\nfrom hivemind.moe.server.runtime import Runtime\n\nfrom petals.server.task_pool import PrioritizedTaskPool\n\n\ndef _submit_tasks(runtime_ready, pools, results_valid):\n    runtime_ready.wait()\n\n    futures = []\n    futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))\n    futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))\n    time.sleep(0.01)\n    futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))\n    futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))\n    futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))\n    futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))\n    futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))\n    futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))\n    futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))\n    for i, f in enumerate(futures):\n        assert f.result()[0].item() == i**2\n    results_valid.set()\n\n\n@pytest.mark.skipif(platform.system() == \"Darwin\", reason=\"Flapping on macOS due to multiprocessing quirks\")\n@pytest.mark.forked\ndef test_priority_pools():\n    outputs_queue = mp.SimpleQueue()\n    runtime_ready = mp.Event()\n    results_valid = mp.Event()\n\n    def dummy_pool_func(x):\n        time.sleep(0.1)\n        y = x**2\n        outputs_queue.put((x, y))\n        return (y,)\n\n    class DummyBackend:\n        def __init__(self, pools):\n            self.pools = pools\n\n        def get_pools(self):\n            return self.pools\n\n    pools = (\n        PrioritizedTaskPool(dummy_pool_func, name=\"A\", max_batch_size=1),\n        PrioritizedTaskPool(dummy_pool_func, name=\"B\", max_batch_size=1),\n    )\n\n    # Simulate requests coming from ConnectionHandlers\n    proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, 
results_valid))\n    proc.start()\n\n    runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)\n    runtime.ready = runtime_ready\n    runtime.start()\n\n    proc.join()\n    assert results_valid.is_set()\n\n    ordered_outputs = []\n    while not outputs_queue.empty():\n        ordered_outputs.append(outputs_queue.get()[0].item())\n\n    assert ordered_outputs == [0, 5, 1, 2, 6, 8, 3, 4, 7]\n    #                          0 - first batch is loaded immediately, before everything else\n    #                             5 - highest priority task overall\n    #                                1 - first of several tasks with equal lowest priority (1)\n    #                                   2 - second earliest task with priority 1, fetched from pool B\n    #                                      6 - third earliest task with priority 1, fetched from pool A again\n    #                                         8 - last priority-1 task, pool B\n    #                                            3 - task with priority 2 from pool A\n    #                                               4 - task with priority 10 from pool A\n    #                                                  7 - task with priority 11 from pool B\n\n    runtime.shutdown()\n"
  },
  {
    "path": "tests/test_remote_sequential.py",
    "content": "import pytest\nimport torch\nimport torch.nn.functional as F\nfrom hivemind import DHT, BatchTensorDescriptor, get_logger\nfrom hivemind.proto import runtime_pb2\n\nfrom petals import AutoDistributedConfig\nfrom petals.client import RemoteSequenceManager, RemoteSequential\nfrom petals.data_structures import UID_DELIMITER\nfrom petals.server.from_pretrained import load_pretrained_block\nfrom test_utils import *\n\nlogger = get_logger(__name__)\n\n\n@pytest.mark.forked\ndef test_remote_sequential():\n    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)\n    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)\n    test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)\n    grad_proj = torch.randn(1, 5, config.hidden_size)\n\n    sequential = RemoteSequential(config, dht=dht)\n\n    full_outputs = sequential(test_inputs)\n    (full_outputs * grad_proj).sum().backward()\n    assert test_inputs.grad is not None\n    full_grad = test_inputs.grad.clone()\n    test_inputs.grad.data.zero_()\n\n    first_half = sequential[: config.num_hidden_layers // 2]\n    second_half = sequential[config.num_hidden_layers // 2 :]\n    assert len(first_half) + len(second_half) == len(sequential)\n    assert abs(len(first_half) - len(second_half)) == config.num_hidden_layers % 2\n    for m in sequential, first_half, second_half:\n        assert isinstance(repr(m), str)\n\n    hidden = first_half(test_inputs)\n    assert isinstance(hidden, torch.Tensor)\n    assert hidden.shape == test_inputs.shape\n    assert hidden.requires_grad\n    second_half_outputs = second_half(hidden)\n    assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3)\n\n    (second_half_outputs * grad_proj).sum().backward()\n    assert torch.allclose(test_inputs.grad, full_grad, atol=3e-2)\n\n    # test RemoteSequential with lossy compression\n    block_uids = [f\"{config.dht_prefix}{UID_DELIMITER}{i}\" for i in range(config.num_hidden_layers)]\n    lossy_sequential = RemoteSequential(\n        config, sequence_manager=DummyCustomSequenceManager(config, block_uids, dht=dht)\n    )\n\n    test_inputs.grad = None\n    approx_outputs = lossy_sequential(test_inputs)\n    (approx_outputs * grad_proj).sum().backward()\n\n    assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), \"compression was not used\"\n    assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-3), \"compression was not used\"\n    assert abs(approx_outputs - full_outputs).mean() < 0.01\n    absmax = abs(full_grad).max()\n    assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05\n\n\nclass DummyCustomSequenceManager(RemoteSequenceManager):\n    \"\"\"A sequence manager that compresses inputs/outputs during forward and backward pass.\"\"\"\n\n    @property\n    def rpc_info(self):\n        rpc_info = super().rpc_info\n        dims = (2048, 1024)\n        compressed_input_schema = BatchTensorDescriptor(dims, compression=runtime_pb2.CompressionType.FLOAT16)\n        rpc_info[\"forward_schema\"] = (compressed_input_schema,), dict()  # (args, kwargs)\n        return rpc_info\n\n    def get_request_metadata(self, protocol: str, *args, **kwargs):\n        metadata = super().get_request_metadata(protocol, *args, **kwargs)\n        if protocol == \"rpc_forward\":\n            metadata[\"output_compression\"] = (runtime_pb2.CompressionType.FLOAT16,)\n        elif protocol == \"rpc_backward\":\n            metadata[\"output_compression\"] = (runtime_pb2.CompressionType.FLOAT16,)\n            # FIXME: Initially, we used CompressionType.BLOCKWISE_8BIT for rpc_backward() here.\n            # This is currently broken since hivemind==1.1.8 is not compatible with bitsandbytes==0.39.1.\n            # Please revert to BLOCKWISE_8BIT once this is fixed: https://github.com/learning-at-home/hivemind/issues/572\n        return metadata\n\n\n@pytest.mark.forked\ndef test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):\n    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)\n    remote_sequential = RemoteSequential(config)\n\n    inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)\n    output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)\n    input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size), dim=-1)\n    intermediate_prompts = torch.randn(config.num_hidden_layers, batch_size, pre_seq_len, config.hidden_size)\n\n    # Mark the prompts as trainable leaf tensors so that backward() populates their .grad\n    input_prompts.requires_grad_(True)\n    intermediate_prompts.requires_grad_(True)\n\n    inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1)\n    assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size)\n\n    outputs = remote_sequential(inputs_with_prompts, prompts=intermediate_prompts)\n\n    (outputs * output_proj).sum().backward()\n    assert input_prompts.grad is not None\n    assert intermediate_prompts.grad is not None\n\n    input_prompts_ref = input_prompts.clone().detach().requires_grad_(True)\n    intermediate_prompts_ref = intermediate_prompts.clone().detach().requires_grad_(True)\n\n    assert input_prompts_ref.grad is None\n    assert intermediate_prompts_ref.grad is None\n\n    outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)\n    for block_index in range(config.num_hidden_layers):\n        block_prompt = intermediate_prompts_ref[block_index]\n        outputs_ref[:, : block_prompt.shape[1]] += block_prompt\n\n        block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)\n        (outputs_ref,) = block(outputs_ref)\n\n    assert torch.allclose(outputs_ref, outputs, atol=1e-3)\n\n    (outputs_ref * output_proj).sum().backward()\n    assert input_prompts_ref.grad is not None\n    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=3e-2)\n    assert intermediate_prompts_ref.grad is not None\n    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)\n"
  },
  {
    "path": "tests/test_sequence_manager.py",
    "content": "import threading\nimport time\n\nimport pytest\nimport torch\nfrom hivemind import DHT, get_logger\n\nfrom petals import AutoDistributedConfig\nfrom petals.client import RemoteSequenceManager, RemoteSequential\nfrom petals.data_structures import UID_DELIMITER\nfrom test_utils import *\n\nlogger = get_logger(__name__)\n\n\n@pytest.mark.forked\n@pytest.mark.parametrize(\"mode\", [\"max_throughput\", \"min_latency\"])\ndef test_sequence_manager_basics(mode: str):\n    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)\n    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)\n    shutdown_evt = threading.Event()\n\n    # test RemoteSequential with a sequence manager that signals when it is shut down\n    block_uids = [f\"{config.dht_prefix}{UID_DELIMITER}{i}\" for i in range(config.num_hidden_layers)]\n    sequential = RemoteSequential(\n        config,\n        sequence_manager=RemoteSequenceManagerWithChecks(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),\n    )\n\n    sequence = sequential.sequence_manager.make_sequence(mode=mode)\n    assert all(sequence[i].peer_id != sequence[i + 1].peer_id for i in range(len(sequence) - 1))\n\n    assert sequential.sequence_manager.is_alive()\n    assert sequential.sequence_manager._thread.ready.is_set()\n    assert not shutdown_evt.is_set()\n    sequential(torch.randn(1, 2, config.hidden_size))\n\n    sequential.sequence_manager.shutdown()\n    del sequential\n    time.sleep(1)\n\n    assert shutdown_evt.is_set()\n\n\nclass RemoteSequenceManagerWithChecks(RemoteSequenceManager):\n    \"\"\"A sequence manager that signals if it was shut down\"\"\"\n\n    def __init__(self, *args, _was_shut_down: threading.Event, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._was_shut_down = _was_shut_down\n\n    def shutdown(self):\n        super().shutdown()\n        assert not self.is_alive()\n        self._was_shut_down.set()\n"
  },
  {
    "path": "tests/test_server_stats.py",
    "content": "import time\n\nimport hivemind\nimport pytest\nimport torch\n\nfrom petals import AutoDistributedConfig, RemoteSequential\nfrom petals.server.handler import CACHE_TOKENS_AVAILABLE\nfrom test_utils import *\n\n\n@pytest.mark.forked\ndef test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50):\n    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)\n    config.allowed_servers = [\"QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX\"]  # PeerID from server2.id\n\n    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)\n    blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to)\n    blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to)\n\n    info_before = blocks1.sequence_manager.rpc_info\n\n    with blocks1.inference_session(max_length=max_length) as sess:\n        sess.step(torch.randn(1, 1, config.hidden_size))\n        blocks1.sequence_manager.state.rpc_info = None  # invalidate cache\n        info_inside = blocks1.sequence_manager.rpc_info\n\n        with blocks2.inference_session(max_length=max_length2) as sess2:\n            sess2.step(torch.randn(1, 1, config.hidden_size))\n            blocks2.sequence_manager.state.rpc_info = None  # invalidate cache\n            info_inside2 = blocks2.sequence_manager.rpc_info\n\n    # The server may need a moment to release the sessions' cache after they close,\n    # so poll for a few seconds instead of relying on a single fixed sleep\n    deadline = time.monotonic() + 5\n    while True:\n        time.sleep(0.1)\n        blocks1.sequence_manager.state.rpc_info = None  # invalidate cache\n        info_after = blocks1.sequence_manager.rpc_info\n        if info_after[CACHE_TOKENS_AVAILABLE] == info_before[CACHE_TOKENS_AVAILABLE] or time.monotonic() >= deadline:\n            break\n\n    assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE]\n    assert info_before[CACHE_TOKENS_AVAILABLE] - info_inside[CACHE_TOKENS_AVAILABLE] == max_length * len(blocks1)\n    assert info_inside[CACHE_TOKENS_AVAILABLE] - info_inside2[CACHE_TOKENS_AVAILABLE] == max_length2 * len(blocks2)\n"
  },
  {
    "path": "tests/test_speculative_generation.py",
    "content": "import random\n\nimport pytest\nimport torch\nimport transformers\n\nfrom petals import (\n    AutoDistributedConfig,\n    AutoDistributedSpeculativeModel,\n    DistributedLlamaForSpeculativeGeneration,\n    RemoteSequential,\n)\nfrom petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS\nfrom petals.server.from_pretrained import load_pretrained_block\nfrom test_utils import *\n\n\n@pytest.mark.forked\ndef test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):\n    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)\n    remote_sequential = RemoteSequential(config)\n\n    block_index = random.randint(0, config.num_hidden_layers - 1)\n    remote_block = remote_sequential[block_index]\n\n    inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)\n    short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)\n    short_inputs[:, :2, :] = inputs[:, :2, :]\n\n    with torch.inference_mode():\n        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:\n            initial_outputs_inference = sess.step(inputs)\n            sess.position = 2\n            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])\n            result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)\n\n    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)\n    (outputs_local,) = ref_block(short_inputs)\n\n    assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)\n\n\n@pytest.fixture\ndef noisy_model():\n    # REF_NAME is optional in test_utils, so skip instead of crashing in from_pretrained(None)\n    if not REF_NAME:\n        pytest.skip(\"REF_NAME is not set, so the noisy small model cannot be loaded\")\n    noisy_model = transformers.AutoModelForCausalLM.from_pretrained(\n        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32\n    )\n    lm_head = noisy_model.get_output_embeddings()\n    assert isinstance(lm_head, torch.nn.Linear)\n    with torch.no_grad():\n        lm_head.weight += torch.randn_like(lm_head.weight) * 0.02\n    return noisy_model\n\n\n@pytest.fixture\ndef model():\n    return transformers.AutoModelForCausalLM.from_pretrained(\n        MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32\n    )\n\n\n@pytest.fixture\ndef tokenizer():\n    # We set use_fast=False since LlamaTokenizerFast is slow on load\n    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)\n\n\n@pytest.mark.forked\n@pytest.mark.skipif(\n    \"llama\" not in MODEL_NAME.lower(),\n    reason=\"Speculative generation now works only for llama models\",\n)\ndef test_remote_speculative_generation(tokenizer, model, noisy_model, atol_inference=1e-3):\n    speculated_distributed_model = AutoDistributedSpeculativeModel.from_pretrained(\n        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32, small_model=noisy_model\n    )\n\n    inputs_single = tokenizer(\"A cat sat on a mat\", return_tensors=\"pt\")[\"input_ids\"]\n\n    generated_spec = speculated_distributed_model.generate(inputs_single, max_new_tokens=100, do_sample=False)\n    generated_local = model.generate(inputs_single, max_new_tokens=100, do_sample=False)\n\n    assert torch.allclose(generated_spec, generated_local, rtol=0, atol=atol_inference)\n"
  },
  {
    "path": "tests/test_tensor_parallel.py",
    "content": "import random\n\nimport pytest\nimport torch\nimport transformers\nfrom tensor_parallel import TensorParallel\nfrom tensor_parallel.slicing_configs import get_bloom_config\n\nfrom petals.server.from_pretrained import load_pretrained_block\nfrom test_utils import MODEL_NAME\n\n\n@pytest.mark.forked\n@pytest.mark.parametrize(\"custom_config\", [True, False])\n@pytest.mark.parametrize(\"devices\", [(\"cpu\",) * 2, (\"cpu\",) * 3, (\"cpu\",) * 4])\ndef test_tp_block(devices, custom_config):\n    model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)\n    if model_config.model_type != \"bloom\":\n        pytest.skip(\"Tensor parallelism is implemented only for BLOOM for now\")\n\n    # keep the original upper bound of 10, but never pick a block past the model's depth\n    block_index = random.randint(0, min(10, model_config.num_hidden_layers - 1))\n    block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])\n\n    tp_config = None\n    if custom_config:\n        tp_config = get_bloom_config(model_config, devices)\n\n    batch_size = 2\n    prefix_length = 5\n    hidden_size = model_config.hidden_size  # use the model's width instead of a hard-coded 1024\n\n    test_inputs1 = torch.randn(batch_size, 3, hidden_size, requires_grad=True, device=devices[0])\n    test_inputs2 = test_inputs1.detach().clone().requires_grad_(True)\n    test_prefix1 = torch.randn(batch_size, prefix_length, hidden_size, requires_grad=True, device=devices[0])\n    test_prefix2 = test_prefix1.detach().clone().requires_grad_(True)\n    grad_proj = torch.rand_like(test_inputs1)\n\n    y_prefix_ref, layer_past = block(test_prefix1, use_cache=True)\n    y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past)\n    y_ref.backward(grad_proj)\n\n    block_tp = TensorParallel(block, devices, config=tp_config)\n    y_prefix, layer_past = block_tp(test_prefix2, use_cache=True)\n    y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)\n    y_ours.backward(grad_proj)\n\n    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5)\n    assert torch.allclose(y_ours, y_ref, atol=1e-5)\n    assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)\n    assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)\n"
  },
  {
    "path": "tests/test_utils.py",
    "content": "import os\n\nINITIAL_PEERS = os.environ.get(\"INITIAL_PEERS\")\nif not INITIAL_PEERS:\n    raise RuntimeError(\"Must specify INITIAL_PEERS environment variable with one or more peer ids\")\nINITIAL_PEERS = INITIAL_PEERS.split()\n\n\nMODEL_NAME = os.environ.get(\"MODEL_NAME\")\nif not MODEL_NAME:\n    raise RuntimeError(\"Must specify MODEL_NAME environment variable with the name or path of the model to be tested\")\n\nREF_NAME = os.environ.get(\"REF_NAME\")\n\nADAPTER_NAME = os.environ.get(\"ADAPTER_NAME\")\n"
  }
]