Showing preview only (501K chars total). Download the full file or copy to clipboard to get everything.
Repository: bigscience-workshop/petals
Branch: main
Commit: 22afba627a7e
Files: 100
Total size: 472.8 KB
Directory structure:
gitextract_tsupof0f/
├── .github/
│ └── workflows/
│ ├── check-style.yaml
│ ├── push-docker-image.yaml
│ └── run-tests.yaml
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── benchmarks/
│ ├── benchmark_forward.py
│ ├── benchmark_inference.py
│ └── benchmark_training.py
├── examples/
│ ├── prompt-tuning-personachat.ipynb
│ └── prompt-tuning-sst2.ipynb
├── pyproject.toml
├── setup.cfg
├── src/
│ └── petals/
│ ├── __init__.py
│ ├── cli/
│ │ ├── __init__.py
│ │ ├── run_dht.py
│ │ ├── run_prod_server.sh
│ │ └── run_server.py
│ ├── client/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── from_pretrained.py
│ │ ├── inference_session.py
│ │ ├── lm_head.py
│ │ ├── ptune.py
│ │ ├── remote_forward_backward.py
│ │ ├── remote_generation.py
│ │ ├── remote_sequential.py
│ │ ├── routing/
│ │ │ ├── __init__.py
│ │ │ ├── sequence_info.py
│ │ │ ├── sequence_manager.py
│ │ │ └── spending_policy.py
│ │ └── sequential_autograd.py
│ ├── constants.py
│ ├── data_structures.py
│ ├── dht_utils.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── bloom/
│ │ │ ├── __init__.py
│ │ │ ├── block.py
│ │ │ ├── config.py
│ │ │ └── model.py
│ │ ├── falcon/
│ │ │ ├── __init__.py
│ │ │ ├── block.py
│ │ │ ├── config.py
│ │ │ └── model.py
│ │ ├── llama/
│ │ │ ├── __init__.py
│ │ │ ├── block.py
│ │ │ ├── config.py
│ │ │ ├── model.py
│ │ │ └── speculative_model.py
│ │ └── mixtral/
│ │ ├── __init__.py
│ │ ├── block.py
│ │ ├── config.py
│ │ └── model.py
│ ├── server/
│ │ ├── __init__.py
│ │ ├── backend.py
│ │ ├── block_functions.py
│ │ ├── block_selection.py
│ │ ├── block_utils.py
│ │ ├── from_pretrained.py
│ │ ├── handler.py
│ │ ├── memory_cache.py
│ │ ├── reachability.py
│ │ ├── server.py
│ │ ├── task_pool.py
│ │ ├── task_prioritizer.py
│ │ └── throughput.py
│ └── utils/
│ ├── __init__.py
│ ├── asyncio.py
│ ├── auto_config.py
│ ├── convert_block.py
│ ├── cuda_graphs.py
│ ├── dht.py
│ ├── disk_cache.py
│ ├── hf_auth.py
│ ├── logging.py
│ ├── misc.py
│ ├── packaging.py
│ ├── peft.py
│ ├── ping.py
│ ├── random.py
│ └── version.py
└── tests/
├── bootstrap.id
├── conftest.py
├── server2.id
├── test_aux_functions.py
├── test_block_exact_match.py
├── test_cache.py
├── test_chained_calls.py
├── test_dtype.py
├── test_full_model.py
├── test_optimized_layers.py
├── test_peft.py
├── test_priority_pool.py
├── test_remote_sequential.py
├── test_sequence_manager.py
├── test_server_stats.py
├── test_speculative_generation.py
├── test_tensor_parallel.py
└── test_utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/check-style.yaml
================================================
# Style checks: black formatting and isort import ordering, run on pushes to main and on every PR.
name: Check style

on:
  push:
    branches: [ main ]
  pull_request:

jobs:
  black:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: psf/black@stable
        with:
          options: "--check --diff"
          version: "22.3.0"

  isort:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          # Quoted so YAML keeps it a string (an unquoted 3.10 would silently become the float 3.1)
          python-version: "3.8"
      - uses: isort/isort-action@master
        with:
          isortVersion: "5.10.1"
================================================
FILE: .github/workflows/push-docker-image.yaml
================================================
# Builds the Docker image for pushes/PRs targeting main and for version tags.
# The image is pushed to Docker Hub only for non-pull-request events.
name: Push to Docker Hub

on:
  push:
    branches:
      - main
    tags:
      - "*.*.*"
  pull_request:
    branches:
      - main

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Docker meta
        id: meta
        uses: crazy-max/ghaction-docker-meta@v2
        with:
          # Docker image(s) used as the base name for the generated tags
          images: |
            learningathome/petals
          # Tag rules: branch name, PR reference, and semver variants for release tags
          tags: |
            type=ref,event=branch
            type=ref,event=pr
            type=semver,pattern={{version}}
            type=semver,pattern={{major}}.{{minor}}
            type=semver,pattern={{major}}

      - name: Set up Docker Buildx
        id: buildx
        uses: docker/setup-buildx-action@v1

      - name: Login to Docker Hub
        if: github.event_name != 'pull_request'
        uses: docker/login-action@v1
        with:
          username: ${{ secrets.DOCKER_HUB_USERNAME }}
          password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}

      - name: Free disk space on Ubuntu runner
        uses: kfir4444/free-disk-space@main
        with:
          # found in: https://github.com/docker/build-push-action/issues/968
          tool-cache: false
          android: true
          dotnet: true
          haskell: true
          large-packages: true
          swap-storage: true

      - name: Build and push
        id: docker_build
        uses: docker/build-push-action@v2
        with:
          context: .
          push: ${{ github.event_name != 'pull_request' }}
          tags: ${{ steps.meta.outputs.tags }}

      - name: Image digest
        run: echo ${{ steps.docker_build.outputs.digest }}
================================================
FILE: .github/workflows/run-tests.yaml
================================================
# CI: for each model/OS/Python combination below, launch a tiny local Petals swarm
# (one DHT bootstrap peer + four CPU servers), run the test suite against it,
# then smoke-run the benchmark scripts.
name: Tests

on:
  push:
    branches: [ main ]
  pull_request:

jobs:
  run-tests:
    strategy:
      matrix:
        include:
          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.8' }
          - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
          - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
          - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
          - { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.8' }
          - { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.11' }
      # Keep running the remaining combinations when one of them fails
      fail-fast: false
    runs-on: ${{ matrix.os }}-latest
    timeout-minutes: 20
    steps:
      - name: Increase swap space
        if: ${{ matrix.os == 'ubuntu' }}
        uses: pierotofy/set-swap-space@master
        with:
          swap-size-gb: 10
      - name: Checkout
        uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Cache dependencies
        uses: actions/cache@v3
        with:
          path: ~/.cache/pip
          # Cache key depends on the Python version and the contents of setup.cfg
          key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install .[dev]
      - name: Test
        run: |
          set -x # Print executed commands
          export MODEL_NAME="${{ matrix.model }}"
          export REF_NAME="${{ matrix.model }}"
          # Only the BLOOM configuration gets a test adapter; for other models ADAPTER_NAME is empty
          export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"

          # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)

          python -m petals.cli.run_dht \
            --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
          BOOTSTRAP_PID=$!

          export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
          # ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs

          until [ -s bootstrap.log ]; do sleep 5; done # wait for DHT init

          export RUN_SERVER="python -m petals.cli.run_server $MODEL_NAME \
            --device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS"
          # Only the BLOOM configuration exercises tensor parallelism (two CPU "devices")
          export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"

          $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log &
          SERVER1_PID=$!
          # ^-- rebalancing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there

          sleep 10 # wait for the 1st server to choose blocks

          $RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log &
          SERVER2_PID=$!

          $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \
            --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log &
          SERVER3_PID=$!
          # ^-- chunking test

          $RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log &
          SERVER4_PID=$!
          # ^-- tensor parallelism test (not compatible with adapters yet)

          sleep 5 # wait for the log files to appear

          tail -n 100 -f bootstrap.log server*.log &
          LOGGER_PID=$!
          sleep 30 # wait for servers to eval throughput, download layers, and rebalance
          # `kill -0` sends no signal; it fails if any of these processes has already exited
          kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived init

          # [Step 2] Run PyTest

          # Share disk cache between Petals servers, clients, and HF Transformers
          export TRANSFORMERS_CACHE=~/.cache/petals

          # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
          export no_proxy=*
          export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES

          # Limit default ClientConfig.max_retries to see tracebacks instead of retrying indefinitely
          export PETALS_MAX_RETRIES=10

          pytest tests --durations=0 --durations-min=1.0 -v

          # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)

          python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
            --seq_len 3
          python benchmarks/benchmark_forward.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
            --seq_len 3 --batch_size 3 --n_steps 1
          python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
            --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task cls
          python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
            --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm

          # [Step 4] Clean up

          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.idea/
================================================
FILE: Dockerfile
================================================
# Petals image: CUDA base + Miniconda Python 3.10 + PyTorch + an editable install of this repo.
FROM nvcr.io/nvidia/cuda:11.0.3-cudnn8-devel-ubuntu20.04
LABEL maintainer="bigscience-workshop"
LABEL repository="petals"

WORKDIR /home
# Set en_US.UTF-8 locale by default
# NOTE(review): /etc/environment is read by PAM login sessions, not by Docker RUN/CMD processes,
# and this image does not generate the en_US.UTF-8 locale; confirm this line has the intended effect.
RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment

# Install packages.
# RUN executes via /bin/sh (dash on Ubuntu), which has no brace expansion, so the apt cache cleanup
# uses a plain glob over /var/lib/apt/lists/ rather than a `{a,b,...}` list.
RUN apt-get update && apt-get install -y --no-install-recommends \
  build-essential \
  wget \
  git \
  && apt-get clean autoclean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*

RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O install_miniconda.sh && \
  bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh
ENV PATH="/opt/conda/bin:${PATH}"

# -y keeps conda from prompting for confirmation during the non-interactive build
RUN conda install -y python~=3.10.12 pip && \
  pip install --no-cache-dir "torch>=1.12" && \
  conda clean --all -y && rm -rf ~/.cache/pip

# Model weights and other Petals downloads are cached in this volume
VOLUME /cache
ENV PETALS_CACHE=/cache

COPY . petals/
RUN pip install --no-cache-dir -e petals

WORKDIR /home/petals/
CMD ["bash"]
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2022 Petals authors and collaborators
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
<p align="center">
<img src="https://i.imgur.com/7eR7Pan.png" width="400"><br>
Run large language models at home, BitTorrent-style.<br>
Fine-tuning and inference <a href="https://github.com/bigscience-workshop/petals#benchmarks">up to 10x faster</a> than offloading
<br><br>
<a href="https://pypi.org/project/petals/"><img src="https://img.shields.io/pypi/v/petals.svg?color=green"></a>
<a href="https://discord.gg/tfHfe8B34k"><img src="https://img.shields.io/discord/865254854262652969?label=discord&logo=discord&logoColor=white"></a>
<br>
</p>
Generate text with distributed **Llama 3.1** (up to 405B), **Mixtral** (8x22B), **Falcon** (40B+) or **BLOOM** (176B) and fine‑tune them for your own tasks — right from your desktop computer or Google Colab:
```python
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM
# Choose any model available at https://health.petals.dev
model_name = "meta-llama/Meta-Llama-3.1-405B-Instruct"
# Connect to a distributed network hosting model layers
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
# Run the model as if it were on your computer
inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0])) # A cat sat on a mat...
```
<p align="center">
🚀 <b><a href="https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing">Try now in Colab</a></b>
</p>
🦙 **Want to run Llama?** [Request access](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) to its weights, then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev).
🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust.
💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)!
## Connect your GPU and increase Petals capacity
Petals is a community-run system — we rely on people sharing their GPUs. You can help by serving one of the [available models](https://health.petals.dev) or hosting a new model from 🤗 [Model Hub](https://huggingface.co/models)!
As an example, here is how to host a part of [Llama 3.1 (405B) Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) on your GPU:
🦙 **Want to host Llama?** [Request access](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) to its weights, then run `huggingface-cli login` in the terminal before loading the model.
🐧 **Linux + Anaconda.** Run these commands for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
```bash
conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
pip install git+https://github.com/bigscience-workshop/petals
python -m petals.cli.run_server meta-llama/Meta-Llama-3.1-405B-Instruct
```
🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki.
🐋 **Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD):
```bash
sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \
learningathome/petals:main \
python -m petals.cli.run_server --port 31330 meta-llama/Meta-Llama-3.1-405B-Instruct
```
🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands:
```bash
brew install python
python3 -m pip install git+https://github.com/bigscience-workshop/petals
python3 -m petals.cli.run_server meta-llama/Meta-Llama-3.1-405B-Instruct
```
<p align="center">
📚 <b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-server">Learn more</a></b> (how to use multiple GPUs, start the server on boot, etc.)
</p>
🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety).
💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)!
🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`.
## How does it work?
- You load a small part of the model, then join a [network](https://health.petals.dev) of people serving the other parts. Single‑batch inference runs at up to **6 tokens/sec** for **Llama 2** (70B) and up to **4 tokens/sec** for **Falcon** (180B) — enough for [chatbots](https://chat.petals.dev) and interactive apps.
- You can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of **PyTorch** and **🤗 Transformers**.
<p align="center">
<img src="https://i.imgur.com/RTYF3yW.png" width="800">
</p>
<p align="center">
📜 <b><a href="https://arxiv.org/pdf/2209.01188.pdf">Read paper</a></b>
📚 <b><a href="https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions">See FAQ</a></b>
</p>
## 📚 Tutorials, examples, and more
Basic tutorials:
- Getting started: [tutorial](https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing)
- Prompt-tune Llama-65B for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)
- Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)
Useful tools:
- [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/petals-infra/chat.petals.dev)
- [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/petals-infra/health.petals.dev)
Advanced guides:
- Launch a private swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
- Run a custom model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals)
### Benchmarks
Please see **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
### 🛠️ Contributing
Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing.
### 📜 Citations
Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel.
[Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188)
_Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)._ 2023.
```bibtex
@inproceedings{borzunov2023petals,
title = {Petals: Collaborative Inference and Fine-tuning of Large Models},
author = {Borzunov, Alexander and Baranchuk, Dmitry and Dettmers, Tim and Riabinin, Maksim and Belkada, Younes and Chumachenko, Artem and Samygin, Pavel and Raffel, Colin},
booktitle = {Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
pages = {558--568},
year = {2023},
url = {https://arxiv.org/abs/2209.01188}
}
```
Alexander Borzunov, Max Ryabinin, Artem Chumachenko, Dmitry Baranchuk, Tim Dettmers, Younes Belkada, Pavel Samygin, and Colin Raffel.
[Distributed inference and fine-tuning of large language models over the Internet.](https://arxiv.org/abs/2312.08361)
_Advances in Neural Information Processing Systems_ 36 (2023).
```bibtex
@inproceedings{borzunov2023distributed,
title = {Distributed inference and fine-tuning of large language models over the {I}nternet},
author = {Borzunov, Alexander and Ryabinin, Max and Chumachenko, Artem and Baranchuk, Dmitry and Dettmers, Tim and Belkada, Younes and Samygin, Pavel and Raffel, Colin},
booktitle = {Advances in Neural Information Processing Systems},
volume = {36},
pages = {12312--12331},
year = {2023},
url = {https://arxiv.org/abs/2312.08361}
}
```
--------------------------------------------------------------------------------
<p align="center">
This project is a part of the <a href="https://bigscience.huggingface.co/">BigScience</a> research workshop.
</p>
<p align="center">
<img src="https://petals.dev/bigscience.png" width="150">
</p>
================================================
FILE: benchmarks/benchmark_forward.py
================================================
#!/usr/bin/env python3
import argparse
import multiprocessing as mp
from time import perf_counter
import numpy as np
import torch
from hivemind.utils.logging import get_logger
from petals import AutoDistributedModel
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
logger = get_logger()
def main():
    """Parse CLI arguments, run `benchmark_forward` in worker processes, and log their mean speed.

    Raises:
        RuntimeError: if a worker exits without sending its result (e.g. it crashed while loading the model).
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    # "n_gpus" means one worker per visible CUDA device
    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_forward, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    # Drop the parent's copy of the write end: once every worker has exited, recv() raises EOFError
    # instead of blocking forever when some worker died before reporting its result.
    pipe_send.close()

    # Drain results before joining, as recommended for processes that communicate through a channel
    try:
        speeds = [pipe_recv.recv() for _ in range(args.n_processes)]
    except EOFError:
        raise RuntimeError("A benchmark worker exited without reporting its result") from None
    for proc in processes:
        proc.join()

    speed = np.mean(speeds)
    logger.info(f"Final result: {speed=:.2f}")
@torch.inference_mode()
def benchmark_forward(process_idx, args, result_pipe):
    """Measure forward-pass throughput (tokens/sec) of a distributed model and send it to the parent.

    Runs `args.warmup_steps` untimed steps, then `args.n_steps` timed steps, each on random token ids of
    shape (batch_size, seq_len). Speed is input_ids.numel() divided by the mean timed step duration.
    The last computed speed is sent through `result_pipe`.
    """
    model = AutoDistributedModel.from_pretrained(
        args.model,
        initial_peers=args.initial_peers,
        torch_dtype=DTYPE_MAP[args.torch_dtype],
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    torch.manual_seed(42)
    step_times = []
    for step in range(args.warmup_steps + args.n_steps):
        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))

        logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
        # Time only the forward call (input generation and logging are excluded),
        # consistent with how benchmark_training.py times its forward/backward passes
        start_time = perf_counter()
        model(input_ids)  # We don't use model.lm_head
        elapsed = perf_counter() - start_time
        logger.info(f"{process_idx=} Fwd end")

        if step >= args.warmup_steps:
            step_times.append(elapsed)
            speed = input_ids.numel() / np.mean(step_times)
            logger.info(f"{process_idx=} {step=} {speed=:.2f}")

    result_pipe.send(speed)
if __name__ == "__main__":
    # Guard required by multiprocessing start methods (e.g. "spawn") that re-import this module in workers
    main()
================================================
FILE: benchmarks/benchmark_inference.py
================================================
#!/usr/bin/env python3
import argparse
import multiprocessing as mp
from time import perf_counter
import numpy as np
import torch
from hivemind.utils.logging import get_logger
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
logger = get_logger()
def main():
    """Parse CLI arguments, run `benchmark_inference` in worker processes, and log their mean speed.

    Raises:
        RuntimeError: if a worker exits without sending its result (e.g. it crashed while loading the model).
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    # "n_gpus" means one worker per visible CUDA device
    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_inference, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    # Drop the parent's copy of the write end: once every worker has exited, recv() raises EOFError
    # instead of blocking forever when some worker died before reporting its result.
    pipe_send.close()

    # Drain results before joining, as recommended for processes that communicate through a channel
    try:
        speeds = [pipe_recv.recv() for _ in range(args.n_processes)]
    except EOFError:
        raise RuntimeError("A benchmark worker exited without reporting its result") from None
    for proc in processes:
        proc.join()

    speed = np.mean(speeds)
    logger.info(f"Final result: {speed=:.2f}")
@torch.inference_mode()
def benchmark_inference(process_idx, args, result_pipe):
    """Measure single-token generation speed (steps/sec) within one inference session.

    Calls `model.generate(max_new_tokens=1, ...)` `args.seq_len` times in a session opened with
    `max_length=args.seq_len`, decoding each output as part of the timed step. The first
    `args.warmup_steps` steps are not timed. The last running-average speed is sent through `result_pipe`.
    """
    # use_fast=False: LlamaTokenizerFast takes a long time to start, and tokens are decoded one at a time anyway
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    model = AutoDistributedModelForCausalLM.from_pretrained(
        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    decoded_text = ""
    timings = []
    with model.transformer.h.inference_session(max_length=args.seq_len) as session:
        for step in range(args.seq_len):
            t0 = perf_counter()

            generated = model.generate(max_new_tokens=1, session=session)
            decoded_text += tokenizer.decode(generated[0])

            if step < args.warmup_steps:
                continue  # warmup steps are excluded from the measurement

            timings.append(perf_counter() - t0)
            speed = 1 / np.mean(timings)
            logger.info(f"{process_idx=} {step=} {speed=:.2f}")

    result_pipe.send(speed)
if __name__ == "__main__":
    # Guard required by multiprocessing start methods (e.g. "spawn") that re-import this module in workers
    main()
================================================
FILE: benchmarks/benchmark_training.py
================================================
#!/usr/bin/env python3
import argparse
import multiprocessing as mp
from time import perf_counter
import numpy as np
import torch
from hivemind.utils.logging import get_logger
from petals import AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
logger = get_logger()
def main():
    """Parse CLI arguments, run `benchmark_training` in worker processes, and log their mean speeds.

    Raises:
        RuntimeError: if a worker exits without sending its result (e.g. it crashed while loading the model).
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client")
    # `choices` rejects unknown tasks at parse time; a bare `assert` would be skipped under `python -O`
    parser.add_argument("--task", type=str, default="cls", choices=["cls", "causal_lm"], help="Training task type")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens")
    parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    # "n_gpus" means one worker per visible CUDA device
    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_training, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    # Drop the parent's copy of the write end: once every worker has exited, recv() raises EOFError
    # instead of blocking forever when some worker died before reporting its result.
    pipe_send.close()

    # Drain results before joining, as recommended for processes that communicate through a channel
    try:
        results = [pipe_recv.recv() for _ in range(args.n_processes)]
    except EOFError:
        raise RuntimeError("A benchmark worker exited without reporting its result") from None
    for proc in processes:
        proc.join()

    # Each worker reports a (fwd_speed, bwd_speed) pair; average each component across workers
    fwd_speed, bwd_speed = np.mean(results, axis=0)
    logger.info(f"Final result: {fwd_speed=:.2f} {bwd_speed=:.2f}")
def benchmark_training(process_idx, args, result_pipe):
    """Train a deep-prompt-tuned distributed model on random data and report its training throughput.

    Runs ``args.warmup_steps`` untimed steps followed by ``args.n_steps`` timed steps. Each step is a forward
    pass, a backward pass and an Adam step. Forward and backward speeds are measured in tokens per second
    (``batch_size * seq_len`` divided by the mean step time) and sent once through ``result_pipe`` as
    ``(fwd_speed, bwd_speed)``.

    :param process_idx: index of this worker, used only in log messages
    :param args: parsed CLI arguments produced by ``main()``
    :param result_pipe: sending end of a ``multiprocessing.Pipe``
    :raises ValueError: if ``args.n_steps < 1`` or ``args.task`` is not ``"cls"``/``"causal_lm"``
    """
    if args.n_steps < 1:
        raise ValueError(f"n_steps must be at least 1, got {args.n_steps}")

    if args.task == "cls":
        model = AutoDistributedModelForSequenceClassification.from_pretrained(
            args.model,
            initial_peers=args.initial_peers,
            torch_dtype=DTYPE_MAP[args.torch_dtype],
            tuning_mode="deep_ptune",
            pre_seq_len=args.pre_seq_len,
            num_labels=2,
        )
    elif args.task == "causal_lm":
        model = AutoDistributedModelForCausalLM.from_pretrained(
            args.model,
            initial_peers=args.initial_peers,
            torch_dtype=DTYPE_MAP[args.torch_dtype],
            tuning_mode="deep_ptune",
            pre_seq_len=args.pre_seq_len,
        )
    else:
        raise ValueError(f"Unknown task: {args.task}")
    model = model.to(args.device)
    opt = torch.optim.Adam(model.parameters())
    logger.info(f"Created model: {process_idx=} {model.device=}")

    device = torch.device(args.device)

    def _synchronize():
        # CUDA kernels are launched asynchronously; wait for them so perf_counter() measures the real work
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    torch.manual_seed(42)
    fwd_times = []
    bwd_times = []
    for step in range(args.warmup_steps + args.n_steps):
        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device)
        if args.task == "cls":
            labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
        else:
            labels = input_ids

        logger.info(f"{process_idx=} {step=} Forward")
        _synchronize()
        start_time = perf_counter()
        outputs = model(input_ids, labels=labels)
        _synchronize()
        if step >= args.warmup_steps:
            fwd_times.append(perf_counter() - start_time)

        logger.info(f"{process_idx=} {step=} Backward")
        start_time = perf_counter()
        outputs.loss.backward()
        _synchronize()
        if step >= args.warmup_steps:
            bwd_times.append(perf_counter() - start_time)

        logger.info(f"{process_idx=} {step=} Optimizer step")
        opt.step()
        opt.zero_grad()

        if step >= args.warmup_steps:
            # Running averages over the timed steps completed so far
            fwd_speed = input_ids.numel() / np.mean(fwd_times)
            bwd_speed = input_ids.numel() / np.mean(bwd_times)
            logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")

    # Exactly one result per worker: main() receives one message per process
    result_pipe.send((fwd_speed, bwd_speed))


if __name__ == "__main__":
    main()
================================================
FILE: examples/prompt-tuning-personachat.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "a07e0f5e",
"metadata": {},
"source": [
"<div>\n",
"<img src=\"https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67\" width=\"40%\"> \n",
"</div>\n",
"\n",
"# Distributed Bloom for Text Generation using Prompt Tuning\n",
"\n",
"In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
"\n",
"We will adapt BLOOM for the task of creating a chatbot with a specific personality using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.\n",
"\n",
"To use this notebook in Colab:\n",
"\n",
"1. Follow this link: [](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)\n",
"2. Go to **Runtime** -> **Change runtime type** and select the GPU accelerator."
]
},
{
"cell_type": "markdown",
"id": "a3f8526f",
"metadata": {},
"source": [
"First, we have to prepare all dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "73bbc648",
"metadata": {},
"outputs": [],
"source": [
"%pip install -q petals datasets wandb scikit-learn"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4ab6ca7",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torch\n",
"import transformers\n",
"import wandb\n",
"from datasets import load_dataset\n",
"from tqdm import tqdm\n",
"from torch.optim import AdamW\n",
"from torch.utils.data import DataLoader\n",
"from transformers import BloomTokenizerFast, get_scheduler\n",
"\n",
"from petals import DistributedBloomForCausalLM"
]
},
{
"cell_type": "markdown",
"id": "1bf07b5d",
"metadata": {},
"source": [
"Let's set some hyperparameters for training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f04ba4d2",
"metadata": {},
"outputs": [],
"source": [
"# Choose a model you'd like to prompt-tune. We recommend starting with\n",
"# the smaller 7.1B version of BLOOM (bigscience/bloom-7b1-petals) for faster prototyping.\n",
"# Once your code is ready, you can switch to full-scale\n",
"# 176B-parameter BLOOM (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).\n",
"MODEL_NAME = \"bigscience/bloom-7b1-petals\"\n",
"\n",
"# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n",
"# The latter fine-tunes separate prefixes for each transformer block,\n",
"# so prompt-tuning will take more time but yield better results.\n",
"# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n",
"TUNING_MODE = 'ptune'\n",
"\n",
"NUM_PREFIX_TOKENS = 16\n",
"DEVICE = 'cuda'\n",
"BATCH_SIZE = 8\n",
"LR = 1e-2\n",
"WEIGHT_DECAY = 0.0\n",
"NUM_SAMPLES = 1000\n",
"SEED = 42\n",
"MODEL_MAX_LENGTH = 256"
]
},
{
"cell_type": "markdown",
"id": "d38316bd",
"metadata": {},
"source": [
"Prepare tokenizer and distributed model, connect it to servers."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03c6e53e",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
"tokenizer.padding_side = 'right'\n",
"tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
"model = DistributedBloomForCausalLM.from_pretrained(\n",
" MODEL_NAME,\n",
" pre_seq_len=NUM_PREFIX_TOKENS, \n",
" tuning_mode=TUNING_MODE\n",
").to(DEVICE)"
]
},
{
"cell_type": "markdown",
"id": "042e3786",
"metadata": {},
"source": [
"Let's prepare the Personachat dataset. We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c44d516",
"metadata": {},
"outputs": [],
"source": [
"dataset = load_dataset(\"bavard/personachat_truecased\")\n",
"\n",
"\n",
"def chunking(examples):\n",
" inputs = [\n",
" \"\\n-----\\n\".join(history) + \"\\n-----\\n\" + candidate\n",
" for history, candidates in zip(examples[\"history\"], examples[\"candidates\"])\n",
" for candidate in candidates\n",
" ]\n",
" return {\"chunks\": inputs}\n",
"\n",
"\n",
"def tokenize(examples):\n",
" outputs = {\n",
" \"input_ids\": tokenizer(examples[\"chunks\"], padding='max_length', truncation=True)[\"input_ids\"]\n",
" }\n",
" outputs[\"labels\"] = outputs[\"input_ids\"]\n",
" return outputs\n",
"\n",
"\n",
"tokenized_datasets = (\n",
" dataset\n",
" .map(chunking, batched=True, remove_columns=dataset[\"train\"].column_names)\n",
" .map(tokenize, batched=True, remove_columns=[\"chunks\"])\n",
")\n",
"\n",
"\n",
"tokenized_datasets.set_format(\"torch\")\n",
"train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n",
"train_dataloader = DataLoader(\n",
" train_dataset.select(list(range(NUM_SAMPLES))),\n",
" shuffle=True,\n",
" batch_size=BATCH_SIZE,\n",
" drop_last=True,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ef4323fd",
"metadata": {},
"source": [
"Before setting up optimizers, check the model parameters that will be trained."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9cc0ba34",
"metadata": {},
"outputs": [],
"source": [
"for n, p in model.named_parameters():\n",
" if p.requires_grad:\n",
" print(n, p.requires_grad, p.device)"
]
},
{
"cell_type": "markdown",
"id": "59cffce7",
"metadata": {},
"source": [
"The optimizer will only work on **prompts**, they are only trainable parameters. Let's initialize optimizer and learning rate scheduler."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef9bf344",
"metadata": {},
"outputs": [],
"source": [
"optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
"\n",
"lr_scheduler = get_scheduler(\n",
" name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "423c56d5",
"metadata": {},
"source": [
"Let's initialize wandb for logging and start the training loop!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9e46807",
"metadata": {},
"outputs": [],
"source": [
"wandb.init(\n",
" project=\"bloom-personachat\",\n",
" config={\n",
" \"num_samples\": NUM_SAMPLES,\n",
" \"batch_size\": BATCH_SIZE,\n",
" \"learning_rate\": LR,\n",
" \"weight_decay\": WEIGHT_DECAY,\n",
" \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n",
" \"model_name\": MODEL_NAME,\n",
" \"seed\": SEED,\n",
" }\n",
")\n",
"\n",
"for batch in tqdm(train_dataloader):\n",
" batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
"\n",
" model.train()\n",
" outputs = model(**batch)\n",
" loss = outputs.loss\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
" lr_scheduler.step()\n",
" optimizer.zero_grad()\n",
"\n",
" wandb.log({\"Train Loss\": loss})"
]
},
{
"cell_type": "markdown",
"id": "0f36cb80",
"metadata": {},
"source": [
"Try to talk with the trained model! Submit an empty input to stop the execution.\n",
"\n",
"\n",
"__Note__: In this example, we the whole dialogue as a prefix when generating each new replica. In the future, we will support a faster \"interactive\" dialogue mode, so generating a new replica will be able to reuse inference caches from the previous replica."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "720181b7",
"metadata": {},
"outputs": [],
"source": [
"TOP_K = 100\n",
"TEMPERATURE = 0.6\n",
"\n",
"with model.inference_session(max_length=512) as sess:\n",
" while True:\n",
" user_phrase = input()\n",
" if len(user_phrase) == 0:\n",
" break\n",
" inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids'].to(DEVICE)\n",
" while True:\n",
" outputs = model.generate(\n",
" inputs,\n",
" temperature=TEMPERATURE,\n",
" do_sample=True,\n",
" top_k=TOP_K,\n",
" max_new_tokens=1,\n",
" session=sess,\n",
" )\n",
" bloom_answer_token = tokenizer.decode(outputs[0, -1:])\n",
" print(bloom_answer_token, end=\"\", flush=True)\n",
" if bloom_answer_token == \"\\n\":\n",
" break\n",
" inputs = None"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.9 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.9"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: examples/prompt-tuning-sst2.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"id": "a07e0f5e",
"metadata": {
"id": "a07e0f5e"
},
"source": [
"<div>\n",
"<img src=\"https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67\" width=\"40%\"> \n",
"</div>\n",
"\n",
"# Distributed LLaMA for Text Classification using Prompt Tuning\n",
"\n",
"In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [LLaMA](https://github.com/facebookresearch/llama) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the LLaMA blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
"\n",
"We will adapt LLaMA for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n",
"\n",
"To use this notebook in Colab:\n",
"\n",
"1. Follow this link: [](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)\n",
"2. Go to **Runtime** -> **Change runtime type** and select the GPU accelerator."
]
},
{
"cell_type": "markdown",
"id": "a3f8526f",
"metadata": {
"id": "a3f8526f"
},
"source": [
"First, we have to prepare all dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "73bbc648",
"metadata": {
"id": "73bbc648"
},
"outputs": [],
"source": [
"%pip install -q datasets wandb scikit-learn\n",
"%pip install -q git+https://github.com/bigscience-workshop/petals@main"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4ab6ca7",
"metadata": {
"id": "b4ab6ca7"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import transformers\n",
"import wandb\n",
"from datasets import load_dataset, load_metric\n",
"from tqdm import tqdm\n",
"from torch.optim import AdamW\n",
"from torch.utils.data import DataLoader\n",
"from transformers import LlamaTokenizer, get_scheduler, set_seed\n",
"\n",
"from petals import DistributedLlamaForSequenceClassification\n",
"\n",
"set_seed(0)"
]
},
{
"cell_type": "markdown",
"id": "1bf07b5d",
"metadata": {
"id": "1bf07b5d"
},
"source": [
"Let's set some hyperparameters for training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f04ba4d2",
"metadata": {
"id": "f04ba4d2"
},
"outputs": [],
"source": [
"MODEL_NAME = \"enoch/llama-65b-hf\"\n",
"\n",
"# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n",
"# The latter fine-tunes separate prefixes for each transformer block,\n",
"# so prompt-tuning will take more time but yield better results.\n",
"# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n",
"TUNING_MODE = 'ptune'\n",
"\n",
"NUM_PREFIX_TOKENS = 8\n",
"DEVICE = 'cuda'\n",
"BATCH_SIZE = 32\n",
"LR = 1e-2\n",
"WEIGHT_DECAY = 0.0\n",
"NUM_EPOCHS = 3\n",
"SEED = 42\n",
"MODEL_MAX_LENGTH = 64"
]
},
{
"cell_type": "markdown",
"id": "d38316bd",
"metadata": {
"id": "d38316bd"
},
"source": [
"Here, we prepare tokenizer and distributed model and connect it to the public swarm."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03c6e53e",
"metadata": {
"id": "03c6e53e"
},
"outputs": [],
"source": [
"tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)\n",
"tokenizer.padding_side = 'right'\n",
"tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
"tokenizer.pad_token = tokenizer.unk_token\n",
"model = DistributedLlamaForSequenceClassification.from_pretrained(\n",
" MODEL_NAME,\n",
" pre_seq_len=NUM_PREFIX_TOKENS,\n",
" tuning_mode=TUNING_MODE\n",
").float().to(DEVICE)\n",
"model.config.pad_token_id = tokenizer.pad_token_id"
]
},
{
"cell_type": "markdown",
"id": "042e3786",
"metadata": {
"id": "042e3786"
},
"source": [
"Let's prepare the SST-2 dataset. We need just one preprocessing function to tokenize the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c44d516",
"metadata": {
"id": "9c44d516"
},
"outputs": [],
"source": [
"task = 'sst2'\n",
"\n",
"dataset = load_dataset(\"glue\", task)\n",
"\n",
"def preprocess_function(examples):\n",
" return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True, return_token_type_ids=False)\n",
"\n",
"tokenized_datasets = dataset.map(preprocess_function, batched=True)\n",
"tokenized_datasets = tokenized_datasets.remove_columns([\"sentence\", \"idx\", \"attention_mask\"])\n",
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
"tokenized_datasets.set_format(\"torch\")\n",
"\n",
"train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n",
"valid_dataset = tokenized_datasets[\"validation\"].shuffle(seed=SEED)\n",
"\n",
"train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, drop_last=True)\n",
"valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)"
]
},
{
"cell_type": "markdown",
"id": "2a3f3590",
"metadata": {
"id": "2a3f3590"
},
"source": [
"To monitor training, we need the metric function. For SST-2, the target metric is accuracy. We will load it from the datasets library."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e1812be",
"metadata": {
"id": "1e1812be"
},
"outputs": [],
"source": [
"metric = load_metric('glue', task)\n",
"\n",
"def eval_metrics(model, dataloader, device='cpu'):\n",
" model.eval()\n",
" for batch in dataloader:\n",
" batch = {k: v.to(device) for k, v in batch.items()}\n",
"\n",
" with torch.no_grad():\n",
" outputs = model(**batch)\n",
"\n",
" logits = outputs.logits\n",
" predictions = torch.argmax(logits, dim=-1)\n",
" metric.add_batch(predictions=predictions, references=batch[\"labels\"])\n",
" model.train()\n",
" return metric.compute()"
]
},
{
"cell_type": "markdown",
"id": "ef4323fd",
"metadata": {
"id": "ef4323fd"
},
"source": [
"Before setting up optimizers, let's check the model parameters that will be trained."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9cc0ba34",
"metadata": {
"id": "9cc0ba34"
},
"outputs": [],
"source": [
"for n, p in model.named_parameters():\n",
" if p.requires_grad:\n",
" print(n, p.requires_grad, p.device)"
]
},
{
"cell_type": "markdown",
"id": "59cffce7",
"metadata": {
"id": "59cffce7"
},
"source": [
"The optimizer will only work on **prompts and classifier head**: they are only trainable parameters. Let's initialize the optimizer and the learning rate scheduler."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ef9bf344",
"metadata": {
"id": "ef9bf344"
},
"outputs": [],
"source": [
"optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
"\n",
"lr_scheduler = get_scheduler(\n",
" name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS\n",
")"
]
},
{
"cell_type": "markdown",
"id": "423c56d5",
"metadata": {
"id": "423c56d5"
},
"source": [
"Let's initialize wandb for logging and start the training loop!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9e46807",
"metadata": {
"id": "d9e46807"
},
"outputs": [],
"source": [
"wandb.init(\n",
" project=\"bloom-sst-2\",\n",
" config={\n",
" \"num_epochs\": NUM_EPOCHS,\n",
" \"batch_size\": BATCH_SIZE,\n",
" \"learning_rate\": LR,\n",
" \"weight_decay\": WEIGHT_DECAY,\n",
" \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n",
" \"model_name\": MODEL_NAME,\n",
" \"seed\": SEED,\n",
" }\n",
")\n",
"\n",
"scaler = torch.cuda.amp.GradScaler()\n",
"\n",
"for epoch in range(NUM_EPOCHS):\n",
" model.train()\n",
" for batch in tqdm(train_dataloader):\n",
" batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
"\n",
" with torch.autocast(device_type=DEVICE, dtype=torch.float16):\n",
" outputs = model(**batch)\n",
" loss = outputs.loss\n",
" scaler.scale(loss).backward()\n",
"\n",
" scaler.step(optimizer)\n",
" scaler.update()\n",
" lr_scheduler.step()\n",
" optimizer.zero_grad()\n",
"\n",
" wandb.log({\"Train Loss\": loss.detach()})\n",
"\n",
" accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n",
" wandb.log({\"Valid Accuracy\": accuracy}, commit=False)"
]
},
{
"cell_type": "markdown",
"id": "51770911",
"metadata": {
"id": "51770911"
},
"source": [
"Our model has been trained! You can now upload it to the Hub for later use, try out different models [served in the public swarm](https://health.petals.dev/), or [join Petals with your own GPU](https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity)!"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
},
"colab": {
"provenance": [],
"gpuType": "T4"
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 5
}
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = [
"setuptools>=42",
"wheel"
]
build-backend = "setuptools.build_meta"
[tool.black]
line-length = 120
required-version = "22.3.0"
[tool.isort]
profile = "black"
line_length = 120
combine_as_imports = true
combine_star = true
known_local_folder = ["tests", "cli"]
known_first_party = ["test_utils"]
================================================
FILE: setup.cfg
================================================
[metadata]
name = petals
version = attr: petals.__version__
author = Petals Developers
author_email = petals-devs@googlegroups.com
description = Easy way to efficiently run 100B+ language models without high-end GPUs
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/bigscience-workshop/petals
project_urls =
Bug Tracker = https://github.com/bigscience-workshop/petals/issues
classifiers =
Development Status :: 4 - Beta
Intended Audience :: Developers
Intended Audience :: Science/Research
License :: OSI Approved :: MIT License
Programming Language :: Python :: 3
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Topic :: Scientific/Engineering
Topic :: Scientific/Engineering :: Mathematics
Topic :: Scientific/Engineering :: Artificial Intelligence
Topic :: Software Development
Topic :: Software Development :: Libraries
Topic :: Software Development :: Libraries :: Python Modules
[options]
package_dir =
= src
packages = find:
python_requires = >=3.8
install_requires =
torch>=1.12
bitsandbytes==0.41.1
accelerate>=0.27.2
huggingface-hub>=0.11.1,<1.0.0
tokenizers>=0.13.3
transformers==4.43.1 # if you change this, please also change version assert in petals/__init__.py
speedtest-cli==2.1.3
hivemind @ git+https://github.com/learning-at-home/hivemind.git@213bff98a62accb91f254e2afdccbf1d69ebdea9
tensor_parallel==1.0.23
humanfriendly
async-timeout>=4.0.2
cpufeature>=0.2.0; platform_machine == "x86_64"
packaging>=20.9
sentencepiece>=0.1.99
peft==0.8.2
safetensors>=0.3.1
Dijkstar>=2.6.0
numpy<2
[options.extras_require]
dev =
pytest==6.2.5
pytest-forked
pytest-asyncio==0.16.0
black==22.3.0
isort==5.10.1
psutil
[options.packages.find]
where = src
================================================
FILE: src/petals/__init__.py
================================================
import os
import platform
# Environment defaults applied before hivemind, transformers and the petals submodules are imported below.
# setdefault() never overrides a value the user already exported.
# presumably silences the bitsandbytes welcome banner on import — TODO confirm
os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1")
if platform.system() == "Darwin":
    # Necessary for forks to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
    os.environ.setdefault("no_proxy", "*")
    os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES")
import hivemind
import transformers
from packaging import version
from petals.client import *
from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs
__version__ = "2.3.0.dev2"


# Keep this range in sync with the transformers pin in setup.cfg.
# Set PETALS_IGNORE_DEPENDENCY_VERSION to any non-empty value to skip the check at your own risk.
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
    # The requirement is quoted in the message so that copy-pasting it into a shell does not treat
    # ">" and "<" as output/input redirections.
    assert (
        version.parse("4.43.1") <= version.parse(transformers.__version__) < version.parse("4.44.0")
    ), 'Please install a proper transformers version: pip install "transformers>=4.43.1,<4.44.0"'


def _override_bfloat16_mode_default():
    """Set hivemind's ``USE_LEGACY_BFLOAT16`` flag to False by default.

    If the ``USE_LEGACY_BFLOAT16`` environment variable is set (to any value), hivemind's own setting is
    left untouched.
    """
    if os.getenv("USE_LEGACY_BFLOAT16") is None:
        hivemind.compression.base.USE_LEGACY_BFLOAT16 = False


_initialize_logs()
_override_bfloat16_mode_default()
================================================
FILE: src/petals/cli/__init__.py
================================================
================================================
FILE: src/petals/cli/run_dht.py
================================================
"""
A copy of run_dht.py from hivemind with the ReachabilityProtocol added:
https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_dht.py
This script may be used for launching lightweight CPU machines serving as bootstrap nodes to a Petals swarm.
This may eventually be merged into hivemind upstream.
"""
import argparse
import time
from secrets import token_hex
from hivemind.dht import DHT, DHTNode
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from hivemind.utils.networking import log_visible_maddrs
from petals.server.reachability import ReachabilityProtocol
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)


async def report_status(dht: DHT, node: DHTNode):
    """Log a summary of the local DHT node's state, then query the DHT to keep the routing table fresh.

    Called through ``dht.run_coroutine(report_status, ...)`` from ``main()``; ``dht`` itself is unused here.
    """
    protocol = node.protocol
    routing_table = protocol.routing_table
    storage = protocol.storage

    # "+ 1" counts this node, which is not part of its own routing table
    known_nodes = len(routing_table.uid_to_peer_id) + 1
    logger.info(f"{known_nodes} DHT nodes (including this one) are in the local routing table ")
    logger.debug(f"Routing table contents: {routing_table}")
    logger.info(f"Local storage contains {len(storage)} keys")
    logger.debug(f"Local storage contents: {storage}")

    # Looking up a fresh random key makes the node contact peers,
    # which keeps the routing table healthy (removes stale PeerIDs)
    heartbeat_key = f"heartbeat_{token_hex(16)}"
    await node.get(heartbeat_key, latest=True)
def main():
    """Start a standalone DHT node with the Petals ReachabilityProtocol attached and log its status periodically.

    Runs until interrupted. The DHT is shut down on exit (including Ctrl+C).
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--initial_peers",
        nargs="*",
        help="Multiaddrs of the peers that will welcome you into the existing DHT. "
        "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY",
    )
    parser.add_argument(
        "--host_maddrs",
        nargs="*",
        default=["/ip4/0.0.0.0/tcp/0", "/ip6/::/tcp/0"],
        help="Multiaddrs to listen for external connections from other DHT instances. "
        "Defaults to all IPv4 and IPv6 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0 /ip6/::/tcp/0",
    )
    parser.add_argument(
        "--announce_maddrs",
        nargs="*",
        help="Visible multiaddrs the host announces for external connections from other DHT instances",
    )
    parser.add_argument(
        "--use_ipfs",
        action="store_true",
        help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" '
        "part of the multiaddrs for the initial_peers "
        "(no need to specify a particular IPv4/IPv6 host and port)",
    )
    parser.add_argument(
        "--identity_path",
        help="Path to a private key file. If defined, makes the peer ID deterministic. "
        "If the file does not exist, writes a new private key to this file.",
    )
    parser.add_argument(
        "--no_relay",
        action="store_false",
        dest="use_relay",
        help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
    )
    parser.add_argument(
        "--use_auto_relay",
        action="store_true",
        help="Look for libp2p relays to become reachable if we are behind NAT/firewall",
    )
    parser.add_argument(
        "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"
    )

    args = parser.parse_args()

    dht = DHT(
        start=True,
        initial_peers=args.initial_peers,
        host_maddrs=args.host_maddrs,
        announce_maddrs=args.announce_maddrs,
        use_ipfs=args.use_ipfs,
        identity_path=args.identity_path,
        use_relay=args.use_relay,
        use_auto_relay=args.use_auto_relay,
    )
    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)

    # Keep a reference to the attached protocol for the lifetime of the node
    reachability_protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True)

    try:
        while True:
            dht.run_coroutine(report_status, return_future=False)
            time.sleep(args.refresh_period)
    finally:
        # Shut the DHT down on Ctrl+C or any other error instead of leaving it running
        dht.shutdown()


if __name__ == "__main__":
    main()
================================================
FILE: src/petals/cli/run_prod_server.sh
================================================
#!/bin/bash
# Keep a Petals server for bigscience/bloom-petals running forever: whenever it exits, kill leftover
# processes and start it again. All script arguments are forwarded to petals.cli.run_server.
set -x

export HIVEMIND_COLORS=true

while true; do
    # Clean up leftovers from the previous run: any process whose command line matches "p2p" or "run_server"
    pkill -f p2p
    pkill -f run_server

    # Each run logs to its own timestamped file
    python -m petals.cli.run_server bigscience/bloom-petals "$@" 2>&1 | tee "log_$(date '+%F_%H:%M:%S')"

    # Avoid a tight restart loop if the server fails immediately; this also guarantees the next run gets a
    # different (1-second resolution) log filename instead of overwriting this one
    sleep 1
done
================================================
FILE: src/petals/cli/run_server.py
================================================
import argparse
import logging
import configargparse
import torch
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils import limits
from hivemind.utils.logging import get_logger
from humanfriendly import parse_size
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
from petals.server.server import Server
from petals.utils.convert_block import QuantType
from petals.utils.version import validate_version
# Module-level logger for the run_server CLI entry point
logger = get_logger(__name__)
def main():
# fmt:off
parser = configargparse.ArgParser(default_config_files=["config.yml"],
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--converted_model_name_or_path', type=str, default=None,
help="path or name of a pretrained model, converted with cli/convert_model.py")
group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
parser.add_argument("--public_name", type=str, default=None, help="Public name to be reported in the leaderboard")
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()")
group.add_argument("--use_auth_token", action="store_true", dest="token",
help="Read token saved by `huggingface-cli login")
parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
parser.add_argument('--dht_prefix', type=str, default=None, help="Announce all blocks with this DHT prefix")
parser.add_argument('--port', type=int, required=False,
help='Port this server listens to. '
'This is a simplified way to set the --host_maddrs and --announce_maddrs options (see below) '
'that sets the port across all interfaces (IPv4, IPv6) and protocols (TCP, etc.) '
'to the same number. Default: a random free port is chosen for each interface and protocol')
parser.add_argument('--public_ip', type=str, required=False,
help='Your public IPv4 address, which is visible from the Internet. '
'This is a simplified way to set the --announce_maddrs option (see below).'
'Default: server announces IPv4/IPv6 addresses of your network interfaces')
parser.add_argument("--no_auto_relay", action="store_false", dest="use_auto_relay",
help="Do not look for libp2p relays to become reachable if we are behind NAT/firewall")
parser.add_argument('--host_maddrs', nargs='+', required=False,
help='Multiaddrs to listen for external connections from other peers')
parser.add_argument('--announce_maddrs', nargs='+', required=False,
help='Visible multiaddrs the host announces for external connections from other peers')
parser.add_argument('--daemon_startup_timeout', type=float, default=60,
help='Timeout for the libp2p daemon connecting to initial peers')
parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
parser.add_argument('--num_handlers', type=int, default=8, required=False,
help='server will use this many processes to handle incoming requests')
parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
help='Pre-form this many subsequent batches while GPU is processing the current one')
parser.add_argument('--sender_threads', type=int, default=1, required=False,
help='Use this many threads to pass results/exceptions from Runtime to Pools')
parser.add_argument('--inference_max_length', type=int, default=None,
help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. '
'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
parser.add_argument('--min_batch_size', type=int, default=1,
help='Minimum required batch size for all operations (in total tokens)')
parser.add_argument('--max_batch_size', type=int, default=None,
help='The total number of tokens in the same batch will not exceed this value. '
'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024,
help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks')
parser.add_argument('--attn_cache_tokens', type=int, default=None,
help='The number of past attention key/value pairs that will be stored between inference steps. '
'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others')
parser.add_argument('--cache_dir', type=str, default=None,
help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
parser.add_argument("--max_disk_space", type=str, default=None,
help="Maximal disk space used for caches. Example: 50GB, 100GiB (GB != GiB here). "
"Default: unlimited. "
"For bigscience/bloom-petals, this default means that the server may use up to "
"min(free_disk_space, 350GB) in the worst case, which happens when the server runs "
"for a long time and caches all model blocks after a number of rebalancings. "
"However, this worst case is unlikely, expect the server to consume "
"the disk space equal to 2-4x of your GPU memory on average.")
parser.add_argument('--device', type=str, default=None, required=False,
help='all blocks will use this device in torch notation; default: cuda if available else cpu')
parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--max_alloc_timeout', type=float, default=600,
help="If the cache is full, the server will wait for memory to be freed up to this many seconds"
" before rejecting the request")
parser.add_argument('--revision', type=str, default=None,
help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
parser.add_argument('--throughput',
type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),
default='auto',
help='Expected server throughput (a float measured in RPS). '
'If set to "auto" (default), the script evaluates network and compute throughput '
'on the first run and uses these estimates for future runs. '
'If set to "eval", the script re-evaluates the throughput and overrides the cache. '
'If set to "dry_run", the script re-evaluates the throughput and exits.')
parser.add_argument('--update_period', type=float, required=False, default=120,
help='Server will report blocks to DHT once in this many seconds')
parser.add_argument('--expiration', type=float, required=False, default=None,
help='DHT entries will expire after this many seconds')
parser.add_argument('--request_timeout', type=float, required=False, default=3 * 60,
help='Timeout (in seconds) for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
parser.add_argument('--session_timeout', type=float, required=False, default=30 * 60,
help='Timeout (in seconds) for the whole inference session')
parser.add_argument('--step_timeout', type=float, required=False, default=5 * 60,
help="Timeout (in seconds) for waiting the next step's inputs inside an inference session")
group = parser.add_mutually_exclusive_group()
group.add_argument('--initial_peers', type=str, nargs='+', required=False, default=PUBLIC_INITIAL_PEERS,
help='Multiaddrs of one or more DHT peers from the target swarm. Default: connects to the public swarm')
group.add_argument('--new_swarm', action='store_true',
help='Start a new private swarm (i.e., do not connect to any initial peers)')
parser.add_argument('--increase_file_limit', type=int, default=4096,
help='On *nix, increase the max number of files a server can open '
'before hitting "Too many open files" (set to zero to keep the system limit)')
parser.add_argument('--stats_report_interval', type=int, required=False,
help='Interval between two reports of batch processing performance statistics')
parser.add_argument('--custom_module_path', type=str, required=False,
help='Path of a file with custom nn.modules, wrapped into special decorator')
parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
parser.add_argument("--balance_quality", type=float, default=0.75,
help="Rebalance the swarm if its throughput is worse than this share of the optimal "
"throughput. Use 0.0 to disable rebalancing, values > 1.0 to force rebalancing "
"on each check for debugging purposes.")
parser.add_argument("--mean_balance_check_period", type=float, default=60,
help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
"4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "
"Default: 'int8' if GPU is available, 'none' otherwise")
parser.add_argument("--tensor_parallel_devices", nargs='+', default=None,
help=
"Split each block between the specified GPUs such that each device holds a portion of every "
"weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism")
parser.add_argument("--skip_reachability_check", action='store_true',
help="Skip checking this server's reachability via health.petals.dev "
"when connecting to the public swarm. If you connect to a private swarm, "
"the check is skipped by default. Use this option only if you know what you are doing")
parser.add_argument("--adapters", nargs='*', default=(),
help="List of pre-loaded LoRA adapters that can be used for inference or training")
# fmt:on
args = vars(parser.parse_args())
args.pop("config", None)
args["converted_model_name_or_path"] = args.pop("model") or args["converted_model_name_or_path"]
host_maddrs = args.pop("host_maddrs")
port = args.pop("port")
if port is not None:
assert host_maddrs is None, "You can't use --port and --host_maddrs at the same time"
else:
port = 0
if host_maddrs is None:
host_maddrs = [f"/ip4/0.0.0.0/tcp/{port}", f"/ip6/::/tcp/{port}"]
announce_maddrs = args.pop("announce_maddrs")
public_ip = args.pop("public_ip")
if public_ip is not None:
assert announce_maddrs is None, "You can't use --public_ip and --announce_maddrs at the same time"
assert port != 0, "Please specify a fixed non-zero --port when you use --public_ip (e.g., --port 31337)"
announce_maddrs = [f"/ip4/{public_ip}/tcp/{port}"]
args["startup_timeout"] = args.pop("daemon_startup_timeout")
file_limit = args.pop("increase_file_limit")
if file_limit:
limits.logger.setLevel(logging.WARNING)
limits.increase_file_limit(file_limit, file_limit)
compression_type = args.pop("compression").upper()
compression = getattr(CompressionType, compression_type)
max_disk_space = args.pop("max_disk_space")
if max_disk_space is not None:
max_disk_space = parse_size(max_disk_space)
assert isinstance(
max_disk_space, (int, type(None))
), "Unrecognized value for --max_disk_space. Correct examples: 1.5GB or 1500MB or 1572864000 (bytes)"
if args.pop("new_swarm"):
args["initial_peers"] = []
quant_type = args.pop("quant_type")
if quant_type is not None:
args["quant_type"] = QuantType[quant_type.upper()]
validate_version()
if not torch.backends.openmp.is_available():
# Necessary to prevent the server from freezing after forks
torch.set_num_threads(1)
server = Server(
**args,
host_maddrs=host_maddrs,
announce_maddrs=announce_maddrs,
compression=compression,
max_disk_space=max_disk_space,
)
try:
server.run()
except KeyboardInterrupt:
logger.info("Caught KeyboardInterrupt, shutting down")
finally:
server.shutdown()
if __name__ == "__main__":
main()
================================================
FILE: src/petals/client/__init__.py
================================================
from petals.client.config import ClientConfig
from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.client.routing import NoSpendingPolicy, RemoteSequenceManager, SpendingPolicyBase
================================================
FILE: src/petals/client/config.py
================================================
import dataclasses
import os
from typing import Optional, Sequence, Union
from hivemind import PeerID
from petals.constants import PUBLIC_INITIAL_PEERS
# Default retry limit for clients, taken from the PETALS_MAX_RETRIES environment variable.
# None means "no limit". An unset, empty or whitespace-only variable is treated as unset; previously an empty
# value crashed the import with int("").
_max_retries = os.getenv("PETALS_MAX_RETRIES")
DEFAULT_MAX_RETRIES = int(_max_retries) if _max_retries is not None and _max_retries.strip() else None
@dataclasses.dataclass
class ClientConfig:
    """Client-side settings: which swarm to join, which servers to route through, and timeout/retry policy."""

    initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS)  # a list of initial peers for hivemind DHT
    dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
    daemon_startup_timeout: float = 60  # timeout for the libp2p daemon connecting to initial peers

    show_route: Union[str, bool] = "inference"  # show chosen route through servers. one of [False, "inference", True]
    allowed_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
    blocked_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, do not use these servers
    use_server_to_server: bool = True  # Use direct server-to-server communication

    connect_timeout: float = 5  # timeout for opening a connection
    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
    update_period: float = 60  # refresh DHT information once in this many seconds

    # max number of retries before an exception; None retries forever (default: PETALS_MAX_RETRIES env var, else None)
    max_retries: Optional[int] = DEFAULT_MAX_RETRIES
    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
    max_backoff: float = 60  # limit maximal sleep time between retries to this value
    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds

    active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)

    max_pinged: int = 3  # max servers to ping from each sequence side, per update
    ping_timeout: float = 2  # max time to wait for pings, per update
================================================
FILE: src/petals/client/from_pretrained.py
================================================
import contextlib
import json
import os
import re
import tempfile
from contextvars import ContextVar
from typing import List, Optional, Tuple, Union
from hivemind.utils.logging import get_logger
from transformers import BloomPreTrainedModel, modeling_utils
from petals.utils.version import get_compatible_model_repo
logger = get_logger(__name__)


class FromPretrainedMixin:
    """Mixin that adapts Hugging Face's ``from_pretrained()`` for Petals models.

    Compared to the parent implementation, it:
    - resolves the model name via ``get_compatible_model_repo()``;
    - uses ``low_cpu_mem_usage=True`` unless the caller passes a value;
    - loads inside ``ignore_keys(cls._keys_to_ignore_on_load_unexpected)``, so checkpoint shards that only hold
      ignored parameters are not fetched (see ``patched_get_checkpoint_shard_files``).

    It must be mixed into a class whose MRO provides ``from_pretrained`` (it calls ``super().from_pretrained``).
    """

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, os.PathLike, None],
        *args,
        low_cpu_mem_usage: Optional[bool] = None,
        **kwargs,
    ):
        model_name_or_path = get_compatible_model_repo(model_name_or_path)
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True

        with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
            return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)

    # Reuse the Hugging Face docstring, noting the Petals-specific default for low_cpu_mem_usage.
    # torch_dtype is forwarded untouched (no Petals default is applied above), so its documentation is left as-is
    # instead of claiming a default of "auto". __doc__ is None under `python -OO`; patch it only if it exists.
    if BloomPreTrainedModel.from_pretrained.__doc__ is not None:
        from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
            "low_cpu_mem_usage(`bool`, *optional*)",
            "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
        )
# Regex patterns of parameter names to skip while loading checkpoints; None means "skip nothing / no patching"
_ignored_keys = ContextVar("ignored_keys", default=None)


@contextlib.contextmanager
def ignore_keys(patterns: List[str]):
    """Temporarily set the ignored-parameter patterns used by ``patched_get_checkpoint_shard_files()``.

    The previous value of the context variable is restored when the ``with`` block exits, whether it finishes
    normally or raises. Nested uses therefore stack correctly.
    """
    with contextlib.ExitStack() as cleanup:
        cleanup.callback(_ignored_keys.reset, _ignored_keys.set(patterns))
        yield
def patched_get_checkpoint_shard_files(
    pretrained_model_name_or_path, index_filename, *args, **kwargs
) -> Tuple[List[str], dict]:
    """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys.

    Outside an ``ignore_keys(...)`` context this simply delegates to the original function. Inside it, the shard
    index JSON is rewritten into a temporary file with every parameter removed from ``weight_map`` whose name
    matches (via ``re.search``) any ignored pattern. Shards referenced only by ignored parameters are therefore
    never requested. All other arguments are forwarded unchanged.
    """
    ignored_patterns = _ignored_keys.get()
    if ignored_patterns is None:
        return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)

    # Compile once, instead of re-reading the context variable and re-parsing every regex for each parameter name
    compiled_patterns = [re.compile(pattern) for pattern in ignored_patterns]

    with open(index_filename, encoding="utf-8") as f:
        index = json.load(f)
    n_original_shards = len(set(index["weight_map"].values()))
    index["weight_map"] = {
        param_name: filename
        for param_name, filename in index["weight_map"].items()
        if not any(pattern.search(param_name) for pattern in compiled_patterns)
    }
    n_loaded_shards = len(set(index["weight_map"].values()))
    logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")

    with tempfile.TemporaryDirectory() as tempdir:
        # Replace the original index with a patched JSON, where ignored keys are removed.
        # The original function must run inside this block because the temporary file is deleted on exit.
        patched_index_filename = os.path.join(tempdir, "pytorch_model.bin.index.json")
        with open(patched_index_filename, "w", encoding="utf-8") as f:
            json.dump(index, f)
        return original_get_checkpoint_shard_files(
            pretrained_model_name_or_path, patched_index_filename, *args, **kwargs
        )
# Monkey-patch transformers at import time: keep a reference to the original function (the patched version
# delegates to it), then install the patched version in its place.
original_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files
================================================
FILE: src/petals/client/inference_session.py
================================================
from __future__ import annotations
import asyncio
import itertools
import time
import uuid
from typing import AsyncIterator, List, Optional, Tuple
import torch
from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P
from hivemind.proto import runtime_pb2
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from petals.client.config import ClientConfig
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
from petals.utils.packaging import pack_args_kwargs
logger = get_logger(__name__)


class _ServerInferenceSession:
    """
    An interface to a single multi-step *inference* session for a set of blocks on a specific server.

    Requests are put into an asyncio queue that feeds the server's rpc_inference stream; each step then reads one
    response from that stream. The session keeps every input it has received in ``history``. On its first request it
    sends that whole history, which is used in case of server failures to regenerate attention caches on new servers.

    :note: This class is *not* fault-tolerant out of the box.
    """

    def __init__(
        self,
        config: ClientConfig,
        span: RemoteSpanInfo,
        uid: ModuleUID,
        rpc_info: RPCInfo,
        inputs_queue: asyncio.Queue,
        outputs_aiter: AsyncIterator,
        *,
        max_length: int,
        **metadata,
    ):
        """
        Wrap an already-opened rpc_inference stream. ``create()`` opens the stream and calls this constructor.

        :param config: client settings; step() reads ``request_timeout`` and ``use_server_to_server``
        :param span: the server and the range of blocks this session covers
        :param uid: UIDs of the served blocks joined with CHAIN_DELIMITER
        :param rpc_info: server RPC info; its ``"inference_schema"`` determines the tensor compression
        :param inputs_queue: queue whose messages are streamed to the server
        :param outputs_aiter: async iterator over the server's responses
        :param max_length: maximum session length, sent to the server with the first request
        :param metadata: extra fields sent to the server with the first request
        """
        self.config = config
        self.span, self.uid, self.rpc_info = span, uid, rpc_info
        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
        self.session_id = str(uuid.uuid4())
        self.session_metadata = dict(max_length=max_length, **metadata)
        self.stepped = False  # becomes True once the first request has been put into the queue (see _step)
        self.closed = False

        self._position = 0  # number of tokens this session has processed so far
        self.history = None  # Used in case of server failures to regenerate attention caches on new servers

        self.next_session = None  # session for the following span; walked by _collect_next_servers()

    @classmethod
    async def create(
        cls,
        config: ClientConfig,
        p2p: P2P,
        span: RemoteSpanInfo,
        uid: ModuleUID,
        rpc_info: RPCInfo,
        **metadata,
    ) -> _ServerInferenceSession:
        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker

        Opens the rpc_inference stream to ``span.peer_id``. Opening is bounded by ``config.connect_timeout``
        (asyncio.wait_for raises on timeout). The stream's inputs are read from a fresh queue owned by the session.
        """
        stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
        inputs_queue = asyncio.Queue()
        outputs_stream = await asyncio.wait_for(
            stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
            config.connect_timeout,
        )
        return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)

    @staticmethod
    async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
        """Yield messages from ``queue`` until an empty request (no uid and no tensors) arrives.

        The empty request is still yielded, so the server sees it, and then the generator stops.
        """
        while True:
            next_input_message = await asyncio.wait_for(queue.get(), input_timeout)
            yield next_input_message
            if not next_input_message.uid and not next_input_message.tensors:
                break  # this message means "done sending"

    @property
    def position(self):
        """Number of tokens processed by this session so far."""
        return self._position

    @position.setter
    def position(self, start_from_position: int):
        """Roll the session back to an earlier position (it cannot move forward) and truncate ``history`` to match."""
        assert start_from_position <= self._position
        self._position = start_from_position
        if self.history is not None and self.history.shape[1] >= start_from_position:
            self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None

    def step(
        self,
        inputs: torch.Tensor,
        prompts: torch.Tensor,
        hypo_ids: torch.LongTensor,
        *,
        step_id: str,
    ) -> torch.Tensor:
        """
        Inference step: send a chunk of input tensors and receive a chunk of outputs
        :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
          if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
        :returns: the first output tensor. It has the same shape as the inputs actually sent, which on the
          session's first request is the whole ``history``.
        :raises Exception: if the session is already closed
        """
        if self.closed:
            raise Exception("Session is closed, cannot perform step")

        # Record inputs in history. If history is longer than the current position, it must already contain
        # these inputs; the assertion below checks the invariant either way.
        n_input_tokens = inputs.shape[1]
        if self.history is None:
            self.history = inputs
        elif self.history.shape[1] == self._position:
            self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1)
        assert self.history.shape[1] == self._position + n_input_tokens, (
            f"Broken input cache: span={self.span} shape={self.history.shape} "
            f"position={self._position} n_input_tokens={n_input_tokens}"
        )

        if not self.stepped:
            inputs = self.history  # Pass full inputs including prefix
        else:
            inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further

        # serialize inputs and put them into the queue
        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)

        request_metadata = dict(session_id=self.session_id, step_id=step_id)
        if not self.stepped:
            # The first request carries the session-level metadata (max_length and extra fields from create())
            request_metadata.update(self.session_metadata)
            if self._position is not None:  # NOTE(review): _position is always an int here, so this always holds
                request_metadata["start_from_position"] = self._position
        elif self.config.use_server_to_server:
            # Tell this server which already-started downstream sessions it may forward outputs to
            next_servers = self._collect_next_servers()
            if next_servers:
                request_metadata["next_servers"] = next_servers

        request_metadata["args_structure"] = args_structure

        # TODO: make possible to use different compression method for different tensors
        server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
        compression = server_side_inference_schema[0].compression
        inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)

        # TODO: create more explicit way to check servers schema and client's structure
        assert len(input_tensors) >= len(
            server_side_inference_schema
        ), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"

        outputs_serialized = RemoteExpertWorker.run_coroutine(
            self._step(
                runtime_pb2.ExpertRequest(
                    uid=self.uid,
                    tensors=[
                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                        for tensor, proto in zip(input_tensors, inference_schema)
                    ],
                    metadata=MSGPackSerializer.dumps(request_metadata),
                )
            )
        )
        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
        assert (
            outputs[0].shape == inputs.shape
        ), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}"

        self._position += n_input_tokens

        return outputs[0]

    def _collect_next_servers(self) -> List[Tuple[str, str, int, int]]:
        """Return (peer_id in base58, session_id, span start, span end) for each following session in the chain,
        stopping at the first one that has not stepped yet."""
        next_servers = []
        session = self.next_session
        while session is not None and session.stepped:
            next_servers.append(
                (session.span.peer_id.to_base58(), session.session_id, session.span.start, session.span.end)
            )
            session = session.next_session
        return next_servers

    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker

        Enqueues the request, marks the session as stepped, and waits up to ``config.request_timeout``
        for the next response from the stream.
        """
        await self._inputs_queue.put(inputs_serialized)
        self.stepped = True
        return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout)

    def close(self):
        """Finish a given inference session, close the underlying connection"""
        if self._outputs_stream is None:
            return  # already closed
        RemoteExpertWorker.run_coroutine(self._aclose_stream())
        self._outputs_stream = self._inputs_queue = None
        self.closed = True

    async def _aclose_stream(self):
        """Close the inference session. This code is meant to be run inside RemoteExpertWorker

        If anything was sent, an empty request signals the end of the session, and one more item is read
        from the output stream; the stream finishing (StopAsyncIteration) is expected and ignored.
        """
        if self._outputs_stream is None:
            return  # already closed
        if self.stepped:
            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
            try:
                await anext(self._outputs_stream)
            except StopAsyncIteration:
                pass

    def __del__(self):
        self.close()

    def __enter__(self):
        assert not self.closed
        return self

    def __exit__(self, *exc_details):
        self.close()
class InferenceSession:
    """
    An interface to a multi-step *inference* session for a sequence of remote transformer blocks

    Each step passes the inputs through a chain of per-server sessions, one per span of blocks. When a server
    fails, its span is re-routed through the sequence manager and the call is retried after a back-off delay.
    Retries stop after ``config.max_retries`` attempts; if that is None, they never stop.
    """

    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
        """
        :param sequence_manager: provides block UIDs, routing (make_sequence), retry delays and peer bookkeeping
        :param max_length: maximum total number of tokens in this session; step() raises ValueError beyond it
        """
        self._sequence_manager = sequence_manager
        self._closed = False
        self._server_sessions = []  # one _ServerInferenceSession per span, in block order
        self._position = 0  # number of tokens processed so far
        self._max_length = max_length
        self.output_ids = None
        self.past_key_values = None

    @property
    def num_blocks(self) -> int:
        """Total number of remote blocks covered by the session."""
        return len(self._sequence_manager)

    @property
    def position(self) -> int:
        """Number of tokens processed so far."""
        return self._position

    @position.setter
    def position(self, start_from_position: int) -> None:
        """Move every server session to ``start_from_position``. Each server session only accepts moving back."""
        self._position = start_from_position
        for session in self._server_sessions:
            assert isinstance(session, _ServerInferenceSession)
            session.position = start_from_position

    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
        """Open one server session per span. If any of them fails to open, close those already opened and re-raise."""
        server_sessions = []
        try:
            for span in chosen_spans:
                span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
                metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
                session = RemoteExpertWorker.run_coroutine(
                    _ServerInferenceSession.create(
                        self._sequence_manager.config,
                        self._sequence_manager.state.p2p,
                        span,
                        span_uids,
                        rpc_info=self._sequence_manager.rpc_info,
                        max_length=self._max_length,
                        **metadata,
                    )
                )
                server_sessions.append(session)
                session.__enter__()
            return server_sessions
        except:  # bare on purpose: clean up on any exception (incl. KeyboardInterrupt), then re-raise it
            self._exit_server_sessions(server_sessions)
            raise

    def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None:
        """Close the given sessions in reverse order, logging (not raising) errors from individual sessions."""
        for session in reversed(server_sessions):
            try:
                session.__exit__(None, None, None)
            except Exception:
                logger.debug("Caught exception while closing connection to server:", exc_info=True)

    def __enter__(self) -> "InferenceSession":
        assert not self._closed and not self._server_sessions
        return self

    def step(
        self,
        inputs: torch.Tensor,
        prompts: Optional[torch.Tensor] = None,
        hypo_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run one inference step through all remote blocks and return the outputs for the new tokens.

        :param inputs: hidden states for the new tokens; presumably [batch_size, n_tokens, hid_size]
          (only the prompt checks below pin dims 0-2) — TODO confirm against callers
        :param prompts: optional deep prompts of shape [num_blocks, batch_size or 1, prefix_len, hid_size]
        :param hypo_ids: optional int64 tensor with one entry per row of ``inputs``
        :returns: the last ``inputs.shape[1]`` positions of the final outputs, cast back to the device and dtype
          of ``inputs``
        :raises ValueError: if the step would exceed the session's ``max_length``
        """
        assert not self._closed
        if torch.is_grad_enabled():
            logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")

        if prompts is None or is_dummy(prompts):
            prompts = DUMMY
        else:
            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
            assert prompts.shape[0] == self.num_blocks
            assert prompts.shape[1] in (inputs.shape[0], 1)
            assert prompts.shape[2] <= inputs.shape[1]
            assert prompts.shape[3] == inputs.shape[2]

        if hypo_ids is None or is_dummy(hypo_ids):
            hypo_ids = DUMMY_INT64
        else:
            assert len(hypo_ids) == len(inputs)
            assert hypo_ids.dtype == torch.int64

        # Remote calls operate on CPU tensors; remember the caller's device/dtype to restore them at the end
        inputs_device = inputs.device
        inputs_dtype = inputs.dtype
        inputs = inputs.cpu()
        prompts = prompts.cpu()
        hypo_ids = hypo_ids.cpu()
        step_id = str(uuid.uuid4())

        n_input_tokens = inputs.shape[1]
        if self._position + n_input_tokens > self._max_length:
            raise ValueError(
                f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}"
            )

        server_idx = 0
        block_idx = 0
        while block_idx < self.num_blocks:
            for attempt_no in itertools.count():
                logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                server_session = None
                try:
                    # Build the route lazily on the first step, and rebuild the failed part on every retry
                    if not self._server_sessions or attempt_no >= 1:
                        self._update_sequence(server_idx, block_idx, attempt_no)

                    server_session = self._server_sessions[server_idx]
                    # NOTE(review): sessions opened by _update_sequence() start at position 0, so after a failure at
                    # self.position > 0 this assertion looks like it would fail on every retry — confirm failover
                    assert server_session.position == self.position, f"{server_session.position} and {self.position}"
                    # Each server's outputs become the next server's inputs
                    inputs = server_session.step(
                        inputs,
                        prompts[server_session.span.start : server_session.span.end],
                        hypo_ids,
                        step_id=step_id,
                    )

                    server_idx += 1
                    block_idx = server_session.span.end
                    self._sequence_manager.on_request_success(server_session.span.peer_id)
                    break
                except Exception as e:
                    self._sequence_manager.on_request_failure(
                        server_session.span.peer_id if server_session is not None else None
                    )
                    # With max_retries=None this never matches, so retries continue indefinitely
                    if attempt_no + 1 == self._sequence_manager.config.max_retries:
                        raise
                    delay = self._sequence_manager.get_retry_delay(attempt_no)
                    logger.warning(
                        f"Caught exception when running inference via {server_session.span if server_session is not None else None} "
                        f"(retry in {delay:.0f} sec): {repr(e)}"
                    )
                    maybe_log_traceback(e)
                    time.sleep(delay)

        self._position += n_input_tokens
        outputs = inputs[:, -n_input_tokens:]
        outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
        return outputs

    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> None:
        """Replace the session at ``server_idx`` (or append new ones past the end) with a fresh route.

        The new route covers blocks from ``block_idx`` to the end of the replaced span, or to the last block
        if there is no session at ``server_idx``. The replaced session's input history moves to the first
        new session, and the ``next_session`` links around the change are rebuilt.
        """
        # If there is a failed server session, this code closes it
        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])

        n_prev_spans = len(self._server_sessions)
        update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
        if attempt_no >= 1:
            logger.debug(
                f"Due to a server failure, remote attention caches "
                f"from block {block_idx} to {update_end} will be regenerated"
            )

        updated_spans = self._sequence_manager.make_sequence(
            block_idx, update_end, mode="min_latency", cache_tokens_needed=self._max_length
        )
        # make_sequence() could return a longer sequence
        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
        updated_sessions = self._enter_server_sessions(updated_spans)
        logger.debug(f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers")

        # If there is a failed span, this code replaces it, otherwise it just adds new ones
        if server_idx < n_prev_spans:
            updated_sessions[0].history = self._server_sessions[server_idx].history
        self._server_sessions[server_idx : server_idx + 1] = updated_sessions

        # Update links to the next server session for direct server-to-server communication via rpc_push()
        for i in range(max(server_idx - 1, 0), min(server_idx + len(updated_spans), len(self._server_sessions) - 1)):
            self._server_sessions[i].next_session = self._server_sessions[i + 1]

    def close(self, *exc_details):
        """Finish a given inference session, close the underlying connection"""
        if not self._closed:
            self._exit_server_sessions(self._server_sessions)
            self._server_sessions.clear()
            self._closed = True

    def __exit__(self, *exc_details):
        self.close(*exc_details)

    def __del__(self):
        self.close()

    @property
    def last_token_id(self) -> Optional[torch.Tensor]:  # Backward compatibility with Petals < 2.1.0
        return self.output_ids[:, -1:] if self.output_ids is not None else None

    @last_token_id.setter
    def last_token_id(self, value: torch.Tensor):  # Backward compatibility with Petals < 2.1.0
        if self.output_ids is None:
            raise RuntimeError("Can't override `last_token_id` since the session has not stepped yet")
        self.output_ids[:, -1:] = value
================================================
FILE: src/petals/client/lm_head.py
================================================
import dataclasses
import platform
from typing import Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import get_logger
from torch import nn
from transformers import PretrainedConfig
logger = get_logger(__name__)


@dataclasses.dataclass
class LMHeadConfig:
    """Options controlling how LMHead computes logits."""

    # These settings matter for running the client with dtype bfloat16 on CPU.
    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
    # "auto" lets LMHead decide from the CPU features; True/False turn chunking on/off for fp16/bf16 weights on CPU.
    use_chunked_forward: Union[str, bool] = "auto"
    # Number of vocabulary rows cast to float32 at a time in chunked_forward()
    chunked_forward_step: int = 16384
class LMHead(nn.Module):
    """Projects hidden states to vocabulary logits, with an optional chunked float32 path for CPU.

    The weight is either a frozen ``[vocab_size, hidden_size]`` parameter of its own or, when
    ``config.tie_word_embeddings`` is set, ``None`` until the model loader assigns the input embeddings' weight.
    There is no bias.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()

        if not config.tie_word_embeddings:
            self.weight = nn.Parameter(torch.zeros(config.vocab_size, config.hidden_size))
            self.weight.requires_grad = False
        else:
            self.weight = None  # Will be set to get_input_embeddings().weight during loading the model
        self.bias = None
        self.in_features = config.hidden_size  # Similar to nn.Linear attributes
        self.out_features = config.vocab_size

        self.use_chunked_forward = config.use_chunked_forward
        if self.use_chunked_forward == "auto":
            if platform.machine() == "x86_64":
                # Import of cpufeature may crash on non-x86_64 machines
                from cpufeature import CPUFeature

                # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
                # Otherwise, it's ~8x slower.
                self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
            else:
                self.use_chunked_forward = True
        self.chunked_forward_step = config.chunked_forward_step
        self._bf16_warning_shown = False

    def forward(self, hidden_states):
        """Return logits ``hidden_states @ weight.T``.

        The chunked float32 path is used for fp16/bf16 weights on CPU when chunking is enabled; otherwise the
        hidden states are cast to the weight's dtype.
        """
        if (
            self.weight.dtype in [torch.float16, torch.bfloat16]
            and self.weight.device.type == "cpu"
            and self.use_chunked_forward
        ):
            lm_logits = self.chunked_forward(hidden_states)
        else:
            # Switch dtype in case word_embeddings are fp16/bf16
            hidden_states = hidden_states.to(self.weight.dtype)
            lm_logits = F.linear(hidden_states, self.weight)
        return lm_logits

    def chunked_forward(self, hidden_states):
        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
        chunked_forward_step: provides trade-off between efficiency and extra memory consumption.

        Returns float32 logits regardless of torch's global default dtype.
        """
        assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"

        if not self._bf16_warning_shown:
            logger.warning(
                "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
                "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
            )
            self._bf16_warning_shown = True

        hidden_states = hidden_states.float()
        # Allocate explicitly in float32: a bare torch.empty() follows torch.get_default_dtype(), which would
        # silently downcast the fp32 chunk results if the user changed the global default (e.g., to bfloat16)
        output = torch.empty(
            *hidden_states.shape[:-1], self.out_features, dtype=torch.float32, device=hidden_states.device
        )

        for i in range(0, self.out_features, self.chunked_forward_step):
            chunk = self.weight[i : i + self.chunked_forward_step].float()
            output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
        return output
================================================
FILE: src/petals/client/ptune.py
================================================
import dataclasses
from contextlib import contextmanager
from typing import Optional
import torch
import torch.nn as nn
from hivemind import get_logger
from transformers import PretrainedConfig
from petals.utils.misc import DUMMY
logger = get_logger(__name__)
@dataclasses.dataclass
class PTuneConfig:
    """Prompt-tuning options consumed by PTuneMixin.init_prompts()."""

    # Number of trainable prefix tokens used for prompt tuning
    pre_seq_len: int = 0
    # Fine-tuning regime: None, "ptune", or "deep_ptune"
    tuning_mode: Optional[str] = None
class PTuneMixin:
    """
    Adds (deep) prompt-tuning embeddings to a model.

    The host class is expected to provide `self.config` and `self.word_embeddings` (read by get_prompt()).
    """

    # NOTE(review): presumably listed because pretrained checkpoints do not contain prompt embeddings
    _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]

    def init_prompts(self, config: "PretrainedConfig") -> None:
        """
        Create prompt embeddings according to `config.tuning_mode`.

        :param config: reads `tuning_mode`, `pre_seq_len`, `hidden_size` and `num_hidden_layers`
        :raises NotImplementedError: if `tuning_mode` is set but does not contain "ptune"
        """
        if config.tuning_mode and "ptune" in config.tuning_mode:
            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = config.pre_seq_len
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()

            with force_non_empty_weights():
                # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
                if config.tuning_mode == "deep_ptune":
                    self.intermediate_prompt_embeddings = nn.Embedding(
                        self.pre_seq_len,
                        config.num_hidden_layers * config.hidden_size,
                        # ^-- TODO: should be num_hidden_layers - 1
                        dtype=torch.float32,
                    )
        elif config.tuning_mode:
            # Read the mode from `config`: the mixin never defines `self.tuning_mode`, so using it here
            # would raise AttributeError instead of this error
            raise NotImplementedError(f"{config.tuning_mode} mode is not supported for now")

    def get_prompt(self, batch_size):
        """
        Return (prompts, intermediate_prompts) cast to the word-embedding dtype.

        prompts has shape [batch_size, pre_seq_len, hidden_size]. For "deep_ptune", intermediate_prompts has shape
        [num_hidden_layers, batch_size, pre_seq_len, hidden_size]; otherwise it is DUMMY.
        """
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)

        if self.config.tuning_mode == "deep_ptune":
            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
            intermediate_prompts = intermediate_prompts.view(
                batch_size,
                self.pre_seq_len,
                self.config.num_hidden_layers,
                self.config.hidden_size
                # TODO: should be num_hidden_layers - 1
            )
            # Move the layer axis first: [layers, batch, seq, hidden]
            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
        else:
            intermediate_prompts = DUMMY

        dtype = self.word_embeddings.weight.dtype
        return prompts.to(dtype), intermediate_prompts.to(dtype)
# nn.Module.register_parameter as it was when this module was first imported;
# force_non_empty_weights() puts it back temporarily.
_original_register_parameter = nn.Module.register_parameter


@contextmanager
def force_non_empty_weights():
    """
    Temporarily restore the original nn.Module.register_parameter inside this context.

    This bypasses accelerate.init_empty_weights(), which patches parameter registration so that all nn.Parameters
    become PyTorch meta tensors when low_cpu_mem_usage=True. The transformers library should replace all meta tensors
    by empty tensors by itself, but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
    Whatever implementation was active on entry is reinstated on exit, even if the body raises.

    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
    """
    active_impl = nn.Module.register_parameter  # possibly patched by accelerate
    nn.Module.register_parameter = _original_register_parameter
    try:
        yield
    finally:
        nn.Module.register_parameter = active_impl
================================================
FILE: src/petals/client/remote_forward_backward.py
================================================
"""
Utility functions that call RPC forward or backward on a single remote server
"""
import asyncio
from typing import Iterable, List, Optional, Sequence, Tuple
import torch
from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
from hivemind.p2p import StubBase
from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
from hivemind.utils.streaming import split_for_streaming
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from petals.client.config import ClientConfig
from petals.data_structures import ModuleUID, RPCInfo
async def _forward_unary(
    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
    """Send every tensor in one rpc_forward request and deserialize the tensors of the reply."""
    request = runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
    response: runtime_pb2.ExpertResponse = await stub.rpc_forward(request, timeout=config.request_timeout)
    return list(map(deserialize_torch_tensor, response.tensors))
async def _backward_unary(
    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
    """Send every tensor in one rpc_backward request and deserialize the returned input gradients."""
    request = runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
    response: runtime_pb2.ExpertResponse = await stub.rpc_backward(request, timeout=config.request_timeout)
    return list(map(deserialize_torch_tensor, response.tensors))
async def _forward_stream(
    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
    """
    Stream tensors to rpc_forward_stream in parts of at most DEFAULT_MAX_MSG_SIZE, then deserialize the reply stream.

    Opening the stream is bounded by `config.connect_timeout`; waiting for reply messages by `config.request_timeout`.
    """
    messages = (
        runtime_pb2.ExpertRequest(uid=uid, tensors=[piece], **kwargs)
        for serialized in serialized_tensors
        for piece in split_for_streaming(serialized, DEFAULT_MAX_MSG_SIZE)
    )
    reply_stream = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(messages)), config.connect_timeout)
    reply_stream = aiter_with_timeout(reply_stream, config.request_timeout)
    return await deserialize_tensor_stream(reply.tensors async for reply in reply_stream)
async def _backward_stream(
    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
    """
    Stream tensors to rpc_backward_stream in parts of at most DEFAULT_MAX_MSG_SIZE, then deserialize the gradients.

    Opening the stream is bounded by `config.connect_timeout`; waiting for reply messages by `config.request_timeout`.
    """
    messages = (
        runtime_pb2.ExpertRequest(uid=uid, tensors=[piece], **kwargs)
        for serialized in serialized_tensors
        for piece in split_for_streaming(serialized, DEFAULT_MAX_MSG_SIZE)
    )
    reply_stream = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(messages)), config.connect_timeout)
    reply_stream = aiter_with_timeout(reply_stream, config.request_timeout)
    return await deserialize_tensor_stream(reply.tensors async for reply in reply_stream)
async def run_remote_forward(
    uid: ModuleUID,
    stub: StubBase,
    rpc_info: RPCInfo,
    *inputs: torch.Tensor,
    config: ClientConfig,
    metadata: Optional[bytes] = None,
    **kwargs,
) -> Tuple[torch.Tensor, ...]:
    """
    Serializes input tensors and calls "rpc_forward" on a remote server.
    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.

    Payloads larger than MAX_UNARY_PAYLOAD_SIZE // 2 bytes are streamed; smaller ones go in a single request.
    The reply is packed according to rpc_info["outputs_schema"].
    """
    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
    keyword_names = rpc_info["keyword_names"]
    assert len(kwargs) == len(keyword_names), f"Keyword args should be {keyword_names}"
    # Put keyword arguments in the same order as on the server to prevent f(a=1, b=2) != f(b=2, a=1) errors
    kwargs = {name: kwargs[name] for name in keyword_names}

    flat_inputs = tuple(nested_flatten((inputs, kwargs)))
    args_schema, _kwargs_schema = rpc_info["forward_schema"]
    compression = args_schema[0].compression
    forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in flat_inputs)
    # Detach to avoid pickling the computation graph
    inputs = tuple(tensor.cpu().detach() for tensor in flat_inputs)

    # TODO: create more explicit way to check servers schema and client's structure
    assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step"

    # Asynchronous serialization in the default executor
    loop = asyncio.get_running_loop()
    serialization_jobs = [
        loop.run_in_executor(None, serialize_torch_tensor, tensor.to(descr.dtype), descr.compression)
        for tensor, descr in zip(inputs, forward_schema)
    ]
    serialized_tensors = await asyncio.gather(*serialization_jobs)

    # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
    payload_bytes = sum(t.element_size() * t.nelement() for t in inputs)
    forward_fn = _forward_stream if payload_bytes > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary

    raw_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
    return nested_pack(raw_outputs, structure=rpc_info["outputs_schema"])
async def run_remote_backward(
    uid: ModuleUID,
    stub: StubBase,
    rpc_info: RPCInfo,
    *inputs_and_grad_outputs: torch.Tensor,
    config: ClientConfig,
    metadata: Optional[bytes] = None,
    **kwargs,
) -> Sequence[torch.Tensor]:
    """
    Serializes grad outputs and calls "rpc_backward" on a remote server.
    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.

    Payloads larger than MAX_UNARY_PAYLOAD_SIZE // 2 bytes are streamed; smaller ones go in a single request.
    Returns the deserialized gradients as received from the server.
    """
    args_schema, _kwargs_schema = rpc_info["forward_schema"]
    outputs_schema = rpc_info["outputs_schema"]
    compression = args_schema[0].compression
    backward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in inputs_and_grad_outputs)

    # TODO: create more explicit way to check servers schema and client's structure
    required = len(args_schema) + len(outputs_schema) + 1
    assert (
        len(inputs_and_grad_outputs) >= required
    ), "Inputs, grad_outputs and prompt tensors are necessary for a backward step"

    # Asynchronous serialization in the default executor
    loop = asyncio.get_running_loop()
    serialization_jobs = [
        loop.run_in_executor(None, serialize_torch_tensor, tensor.to(descr.dtype), descr.compression)
        for tensor, descr in zip(inputs_and_grad_outputs, backward_schema)
    ]
    serialized_tensors = await asyncio.gather(*serialization_jobs)

    # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
    payload_bytes = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
    backward_fn = _backward_stream if payload_bytes > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary

    return await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
================================================
FILE: src/petals/client/remote_generation.py
================================================
import contextlib
import dataclasses
from contextvars import ContextVar
from typing import Any, ContextManager, Dict, List, Optional, Tuple
import torch
import transformers
from hivemind.utils.logging import get_logger
from torch import Tensor
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation.utils import ModelOutput
from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.utils.misc import DUMMY, docstring_from
logger = get_logger(__name__)
class RemotePastKeyValues(Cache):
    """
    A placeholder cache: it only counts how many tokens have been seen and otherwise pretends to be a legit cache.

    The attention caches themselves are stored by remote servers inside an InferenceSession.
    """

    def __init__(self) -> None:
        super().__init__()
        self._seen_tokens = 0
        self.hypo_ids: Optional[torch.LongTensor] = None

    def __getitem__(self, _index: int) -> List[torch.Tensor]:
        # Any layer index yields [DUMMY], for compatibility with BloomForCausalLM.prepare_inputs_for_generation()
        return [DUMMY]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # The same count is reported for every layer
        return self._seen_tokens

    def get_max_length(self) -> Optional[int]:
        # No maximum length is reported
        return None

    def update_seen(self, new_seen: int) -> None:
        self._seen_tokens = self._seen_tokens + new_seen

    def reorder_cache(self, beam_idx):
        raise NotImplementedError("Beam search reordering is not implemented yet")
# How many leading tokens the next prepare_inputs_for_generation() call should drop (consumed on read)
_skipped_tokens = ContextVar("skipped_tokens", default=0)


class _SkipTokensMixin:
    # This override is used in RemoteGenerationMixin but has to be defined in a class not named as "GenerationMixin"
    # due to how transformers.PreTrainedModel.can_generate() works
    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
        """Drop the pending number of skipped tokens from input_ids, reset the counter, then defer to super()."""
        n_skip = _skipped_tokens.get()
        _skipped_tokens.set(0)
        trimmed_ids = input_ids[:, n_skip:]
        return super().prepare_inputs_for_generation(trimmed_ids, **kwargs)
class RemoteGenerationMixin(_SkipTokensMixin):
    """
    This class is an upgrade to `transformers.GenerationMixin` that:

    - Is designed to be compatible with most `transformers.GenerationMixin` strategies and options
    - Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and
      you don't have to rerun the prefix through all the servers to generate each new token
    - Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation
      by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or
      accept prompts from a user in a chat bot (multiple calls like `.generate(new_prompts, ...)`).
    - If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`.
      Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.
    """

    @docstring_from(RemoteSequential.active_session)
    @property
    def active_session(self) -> Optional[InferenceSession]:
        return self.transformer.h.active_session

    @docstring_from(RemoteSequential.use_session)
    def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
        return self.transformer.h.use_session(session)

    @docstring_from(RemoteSequential.inference_session)
    def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
        return self.transformer.h.inference_session(**kwargs)

    @docstring_from(transformers.GenerationMixin.generate.__doc__)
    def generate(
        self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
    ):
        self._fix_generate_kwargs(kwargs)
        if inputs is None:
            inputs = kwargs.pop("input_ids", None)

        if session is not None:
            # If a session specified explicitly, use it
            context_manager = self.use_session(session)
        elif self.active_session is not None:
            # If there's an active session, don't do anything
            context_manager = contextlib.nullcontext(self.active_session)
        else:
            # If there's no active session, create a new one
            max_length = kwargs.get("max_length")
            max_new_tokens = kwargs.get("max_new_tokens")
            assert (max_length is None) != (
                max_new_tokens is None
            ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"

            # Reserve room for the prompt-tuning prefix in addition to the generated sequence
            session_max_length = self.transformer.config.pre_seq_len
            if max_length is not None:
                session_max_length += max_length
            else:
                session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
            context_manager = self.inference_session(max_length=session_max_length)

        with context_manager as session:
            # Prepend the tokens from the previous .generate() call
            n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0
            if n_prev_tokens > 0:
                # `num_beams=None` may be passed explicitly; treat it like the default of 1 instead of crashing
                if (kwargs.get("num_beams") or 1) > 1:
                    logger.warning(
                        "Beam search will not work properly in the resumed petals.InferenceSession "
                        "since intermediate beam entries are lost"
                    )

                if inputs is not None:
                    inputs = torch.cat([session.output_ids, inputs], dim=1)
                else:
                    inputs = session.output_ids

                # Don't actually run all previous tokens through the transformer,
                # but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
                _skipped_tokens.set(max(0, n_prev_tokens - 1))

            if self._supports_cache_class and "past_key_values" not in kwargs:
                past_key_values = RemotePastKeyValues()
                past_key_values.update_seen(session.position)
                kwargs["past_key_values"] = past_key_values

            result = super().generate(inputs, *args, **kwargs)

            sequences = result.sequences if isinstance(result, ModelOutput) else result
            # Save tokens from this .generate() call
            session.output_ids = sequences
            # Crop the last tokens from the previous call
            sequences = sequences[:, n_prev_tokens:].clone()
            if isinstance(result, ModelOutput):
                result.sequences = sequences
            else:
                result = sequences

        return result

    @staticmethod
    def _fix_generate_kwargs(kwargs: dict):
        """Normalize generate() kwargs in place."""
        # Suppress inappropriate "Both max_new_tokens and max_length" HF warning
        if "max_length" in kwargs and kwargs["max_length"] is None:
            del kwargs["max_length"]

        # Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
        do_sample = kwargs.get("do_sample")
        if isinstance(do_sample, int):
            kwargs["do_sample"] = bool(do_sample)

    @staticmethod
    def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
        """Return a copy of `past_key_values` whose `hypo_ids` are set to `beam_idx`; the input is left untouched."""
        # RemotePastKeyValues is not a dataclass, so dataclasses.replace() raised TypeError here.
        # A shallow copy keeps the same non-mutating semantics.
        import copy

        reordered = copy.copy(past_key_values)
        reordered.hypo_ids = beam_idx
        return reordered
================================================
FILE: src/petals/client/remote_sequential.py
================================================
from __future__ import annotations
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Optional, Union
import torch
from hivemind import DHT, get_logger
from torch import nn
from petals.client.config import ClientConfig
from petals.client.inference_session import InferenceSession
from petals.client.routing import RemoteSequenceManager
from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
from petals.data_structures import UID_DELIMITER
logger = get_logger(__name__)
class RemoteSequential(nn.Module):
    """
    A sequence of transformer blocks hosted by the swarm.

    Outside an inference session, forward() runs the blocks through an autograd function backed by the sequence
    manager; inside one (see inference_session / use_session), forward() steps the active InferenceSession.
    """

    def __init__(
        self,
        config: ClientConfig,
        *,
        sequence_manager: Optional[RemoteSequenceManager] = None,
        dht: Optional[DHT] = None,
        start_block: Optional[int] = None,
        end_block: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()
        self.config = config

        manager_given = sequence_manager is not None
        routing_args_given = not (dht is None and start_block is None and end_block is None)
        assert not (
            manager_given and routing_args_given
        ), "`dht`, `start_block`, and `end_block` have no effect when you provide a custom `sequence_manager`"

        if not manager_given:
            first = 0 if start_block is None else start_block
            last = self.config.num_hidden_layers if end_block is None else end_block
            block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(first, last))
            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
        self.sequence_manager = sequence_manager

        self._active_session = ContextVar("active_session", default=None)

    def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
        session = self.active_session
        if session is not None:
            return session.step(inputs, prompts, **kwargs)
        assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
        return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)

    @property
    def active_session(self) -> Optional[InferenceSession]:
        """
        If called inside `with model.inference_session(...):` or `with model.use_session(...):`,
        returns an active InferenceSession. Otherwise, returns None.
        """
        return self._active_session.get()

    @property
    def position(self) -> int:
        """Returns the prefix length (in tokens) in the active inference session or zero if no session is active."""
        session = self.active_session
        return 0 if session is None else session.position

    @contextmanager
    def use_session(self, session: Optional[InferenceSession]) -> InferenceSession:
        """Inside this context, forward() will use an _existing_ InferenceSession provided as the argument."""
        reset_token = self._active_session.set(session)
        try:
            yield session
        finally:
            self._active_session.reset(reset_token)

    @contextmanager
    def inference_session(self, **kwargs) -> InferenceSession:
        """
        Inside this context, forward() will use a _new_ InferenceSession created with given parameters.

        :param max_length: Maximal expected length of inference results. Servers use this parameter
          to calculate the size of attention caches allocated to this client.
        """
        with InferenceSession(self.sequence_manager, **kwargs) as session:
            with self.use_session(session):
                yield session

    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
        sub_manager = self.sequence_manager[ix]
        return RemoteSequential(self.config, sequence_manager=sub_manager)

    def __iter__(self):
        yield from (self[i] for i in range(len(self)))

    def __len__(self):
        return len(self.sequence_manager)

    def extra_repr(self) -> str:
        uids = self.sequence_manager.block_uids
        return f"modules={uids[0]}..{uids[-1]}"
================================================
FILE: src/petals/client/routing/__init__.py
================================================
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
================================================
FILE: src/petals/client/routing/sequence_info.py
================================================
import dataclasses
import time
from typing import Iterable, List, Optional, Tuple
from hivemind import get_logger
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from petals.utils.dht import compute_spans
logger = get_logger(__name__)
@dataclasses.dataclass
class RemoteSequenceInfo:
    """
    A dataclass that stores general information about which servers hold any given layer;
    - updated by RemoteSequenceManager in a background thread
    - accessed by routing strategies in .on_update

    :note: this class should *not* be modified by RoutingStrategy.on_update to avoid interference between strategies.
      Metadata specific to one routing strategy should be stored inside that strategy; information used by most
      routing strategies should be moved from said strategies to this class.
    """

    block_uids: Tuple[ModuleUID, ...]
    block_infos: Tuple[RemoteModuleInfo, ...]  # note: the contents of RemoteModuleInfo can and will be updated
    spans_by_priority: List[RemoteSpanInfo]
    spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
    last_updated_time: Optional[float]

    @classmethod
    def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenceInfo":
        """Create an info with no servers and no spans for every block; last_updated_time is None."""
        uids = tuple(block_uids)
        infos = tuple(RemoteModuleInfo(uid, {}) for uid in uids)
        no_spans = tuple([] for _ in uids)
        return cls(uids, infos, [], no_spans, last_updated_time=None)

    def __getitem__(self, ix: slice):
        """Return a new info restricted to a slice of blocks, with spans recomputed for that slice."""
        assert isinstance(ix, slice)
        sub_uids = self.block_uids[ix]
        sub_infos = self.block_infos[ix]
        by_priority, containing = self._sort_spans(sub_infos)
        return RemoteSequenceInfo(sub_uids, sub_infos, by_priority, containing, self.last_updated_time)

    def __len__(self):
        return len(self.block_uids)

    def update_(self, new_block_infos: List[RemoteModuleInfo]):
        """Copy server lists from fresh block infos into the existing ones (in place) and recompute spans."""
        assert len(new_block_infos) == len(self.block_uids)
        for idx, (uid, fresh_info) in enumerate(zip(self.block_uids, new_block_infos)):
            assert uid == fresh_info.uid, f"The DHT entry for {uid} actually points to {fresh_info.uid}"
            self.block_infos[idx].servers = fresh_info.servers

        self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos)
        self.last_updated_time = time.perf_counter()

    @staticmethod
    def _sort_spans(block_infos: List[RemoteModuleInfo]):
        """Return ONLINE spans sorted longest-first, and for each block the list of spans covering it."""
        all_spans = compute_spans(block_infos, min_state=ServerState.ONLINE).values()
        by_priority = sorted(all_spans, key=lambda span: span.length, reverse=True)

        containing = tuple([] for _ in block_infos)
        for span in by_priority:
            for idx in range(span.start, span.end):
                containing[idx].append(span)
        return by_priority, containing
================================================
FILE: src/petals/client/routing/sequence_manager.py
================================================
from __future__ import annotations
import asyncio
import dataclasses
import itertools
import logging
import random
import threading
import time
import warnings
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from weakref import WeakMethod
import dijkstar
import numpy as np
from hivemind import DHT, P2P, MSGPackSerializer, PeerID
from hivemind.dht.node import Blacklist
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger
from petals.client.config import ClientConfig
from petals.client.routing.sequence_info import RemoteSequenceInfo
from petals.client.routing.spending_policy import NoSpendingPolicy
from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
from petals.server.handler import TransformerConnectionHandler
from petals.utils.dht import get_remote_module_infos
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to
logger = get_logger(__name__)
class SequenceManagerConfig(ClientConfig):
    """Deprecated alias of petals.ClientConfig; emits a DeprecationWarning when instantiated."""

    def __init__(self, *args, **kwargs):
        deprecation_message = (
            "petals.client.routing.SequenceManagerConfig has been moved to petals.ClientConfig. "
            "This alias will be removed in Petals 2.2.0+"
        )
        warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
        super().__init__(*args, **kwargs)
@dataclasses.dataclass
class SequenceManagerState:
    """State that RemoteSequenceManager fills in lazily; slicing it shares everything except sequence_info."""

    p2p: P2P = None
    sequence_info: Optional[RemoteSequenceInfo] = None
    rpc_info: Optional[dict] = None
    banned_peers: Optional[Blacklist] = None

    def __getitem__(self, ix: Union[int, slice]) -> SequenceManagerState:
        sliced_info = self.sequence_info[ix]
        return dataclasses.replace(self, sequence_info=sliced_info)

    def __len__(self) -> int:
        return len(self.sequence_info)
class RemoteSequenceManager:
"""
Sequence manager is a thread that keeps track of remote servers that hold the specified sequence of blocks.
TL;DR it tells you, which peers you should ask to get a specific layer. It is used in RemoteSequential.
When created, RemoteSequenceManager looks up which servers serve necessary layers by reading from DHT.
Using this information, sequence manager can form sequences of servers that collectively have the full sequence.
To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr).
:note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
running redundant sequence managers for the same set of layers.
"""
def __init__(
    self,
    config: ClientConfig,
    block_uids: Sequence[ModuleUID],
    *,
    dht: Optional[DHT] = None,
    state: Optional[SequenceManagerState] = None,
):
    """
    :param config: client config (initial peers, dht_prefix, timeouts, allowed/blocked servers, ...)
    :param block_uids: non-empty sequence of block uids managed by this instance
    :param dht: a running hivemind.DHT; if None, a client-mode DHT is started from `config.initial_peers`
    :param state: optional pre-existing state (missing fields are filled in); if its sequence_info was already
      fetched (last_updated_time is set), the first DHT fetch is skipped
    """
    assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
    assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
    assert len(block_uids) > 0, "Sequences must contain at least one block"

    self.config = config
    if state is None:
        state = SequenceManagerState()
    self.state = state

    if dht is None:
        dht = DHT(
            initial_peers=config.initial_peers,
            client_mode=True,
            num_workers=32,
            startup_timeout=config.daemon_startup_timeout,
            start=True,
        )
    assert isinstance(dht, DHT) and dht.is_alive(), "`dht` must be a running hivemind.DHT instance"
    self.dht = dht

    if state.p2p is None:
        state.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())

    self.lock_changes = threading.Lock()
    # WeakMethod: presumably so the background thread does not keep this manager alive — TODO confirm
    self._thread = _SequenceManagerUpdateThread(config.update_period, WeakMethod(self._update))
    self._thread_start_lock = threading.Lock()
    self.policy = NoSpendingPolicy()

    self.allowed_servers = self._peer_ids_to_set(config.allowed_servers)
    self.blocked_servers = self._peer_ids_to_set(config.blocked_servers)

    self.ping_aggregator = PingAggregator(dht)

    if state.banned_peers is None:
        state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0)
    if state.sequence_info is None:
        state.sequence_info = RemoteSequenceInfo.make_empty(block_uids)

    if state.sequence_info.last_updated_time is not None:
        # sequence_info stores block_uids as a tuple; compare as tuples so that an equal list is accepted
        assert (
            tuple(block_uids) == state.sequence_info.block_uids
        ), "block_uids do not match the block uids stored in the provided state"
        self._thread.ready.set()  # no need to await the first dht fetch
@staticmethod
def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
    """Convert PeerIDs and base58 strings into a set of PeerIDs; None is passed through as None."""
    if peer_ids is None:
        return None

    normalized = set()
    for item in peer_ids:
        if isinstance(item, PeerID):
            normalized.add(item)
        elif isinstance(item, str):
            normalized.add(PeerID.from_base58(item))
        else:
            raise TypeError(
                f"`allowed_servers` and `blocked_servers` have to contain only PeerIDs or strings, but got {type(item)}"
            )
    return normalized
def make_sequence(
    self,
    start_index: int = 0,
    end_index: Optional[int] = None,
    *,
    mode: str,
    cache_tokens_needed: Optional[int] = None,
) -> List[RemoteSpanInfo]:
    """
    Form a sequence of remote servers that collectively serve all consecutive layers

    :param start_index: optional index of the first module in a sequence, default = the first of block_uids
    :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
    :param mode: one of ["max_throughput", "min_latency"]
    :param cache_tokens_needed: only used in "min_latency" mode, where it is forwarded to the latency-based router
    :raises RuntimeError: for an unknown `mode`
    """
    # Lazily start the background updater, then make sure at least one update has completed
    with self._thread_start_lock:
        if not self.is_alive():
            self._thread.start()
    if not self.ready.is_set():
        self.update(wait=True)  # this will await an existing update or trigger a new one (if not updating)

    if end_index is None:
        end_index = len(self)

    if mode == "max_throughput":
        span_sequence = self._make_sequence_with_max_throughput(start_index, end_index)
    elif mode == "min_latency":
        span_sequence = self._make_sequence_with_min_latency(
            start_index, end_index, cache_tokens_needed=cache_tokens_needed
        )
    else:
        raise RuntimeError(f"Unexpected mode {mode}")

    show_route = self.config.show_route
    if show_route is True or (mode == "min_latency" and show_route == "inference"):
        route_repr = " => ".join(
            f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence
        )
        logger.info(f"Route found: {route_repr}")
    return span_sequence
def _make_sequence_with_min_latency(
    self, start_index: int, end_index: int, *, cache_tokens_needed: Optional[int]
) -> List[RemoteSpanInfo]:
    """
    Route blocks [start_index, end_index) along the cheapest path of a delay graph.

    Under `lock_changes`, checks that every block is served, collects server infos and builds the graph
    (see _build_inference_graph). Then it finds the shortest "start" -> "end" path with dijkstar and merges
    consecutive path nodes served by the same peer into spans.

    :raises MissingBlocksError: if some block in the range is not covered by any known span
    """
    if start_index == end_index:
        return []

    with self.lock_changes:
        uncovered = [
            block_idx
            for block_idx in range(start_index, end_index)
            if not self.state.sequence_info.spans_containing_block[block_idx]
        ]
        if uncovered:
            raise MissingBlocksError(uncovered)

        # peer_id -> server_info for every server that may appear on the path
        server_infos = {}
        for block_idx in range(start_index, end_index):
            for span in self.state.sequence_info.spans_containing_block[block_idx]:
                server_infos[span.peer_id] = span.server_info

        graph = self._build_inference_graph(start_index, end_index, cache_tokens_needed=cache_tokens_needed)

    path = dijkstar.find_path(graph, "start", "end")
    logger.debug(f"Path info: {path}")
    if start_index == 0 and end_index == len(self):
        logger.debug(f"Expected speed: {1 / path.total_cost:.1f} steps/sec")

    route: List[RemoteSpanInfo] = []
    # Skip the artificial "start"/"end" nodes; every other node is a (peer_id, block_idx) pair
    for peer_id, block_idx in path.nodes[1:-1]:
        if route and route[-1].peer_id == peer_id:
            route[-1].end = block_idx
        else:
            route.append(RemoteSpanInfo(peer_id, block_idx, block_idx, server_infos[peer_id]))

    # Remove empty spans that can appear if we don't force to go to the end of each server and network delay
    # don't follow triangle inequality (delay(A, B) + delay(B, C) < delay(A, C)) due to measurement errors
    return [span for span in route if span.length > 0]
def _build_inference_graph(
self,
start_index: int,
end_index: int,
*,
cache_tokens_needed: Optional[int],
overhead_delay: float = 0.018, # Serialization overhead (empirically measured)
default_inference_rps: float = 300, # If inference RPS unknown
alloc_delay: float = 10, # If not enough cache left, we penalize the edge
) -> dijkstar.Graph:
"""
Build a weighted graph whose shortest "start" -> "end" path is the lowest-latency route for one inference step.

Nodes are "start", "end" and (peer_id, block_idx) pairs, where (peer_id, i) means "activations for block i are at
peer_id". Edge weights are delays (presumably seconds, since compute edges weigh 1 / inference_rps):
- "start" -> (peer, start_index): client-to-server delay + overhead_delay (+ alloc_delay if the server may lack cache)
- (peer, end_index) -> "end": server-to-client delay
- (cur_peer, b) -> (next_peer, b): server-to-server delay from cur_peer's next_pings + overhead (+ alloc_delay),
  only where cur_peer's span ends exactly at b
- (peer, b) -> (peer, b + 1): compute time of block b on that peer, 1 / inference_rps

:raises MissingBlocksError: if some block in [start_index, end_index) is not held by any known server
"""
missing_blocks = [
block_idx
for block_idx in range(start_index, end_index)
if not self.state.sequence_info.spans_containing_block[block_idx]
]
if missing_blocks:
raise MissingBlocksError(missing_blocks)
client_server_rtts = self.ping_aggregator.to_dict()
graph = dijkstar.Graph()
# Client -> server network delays
for span in self.state.sequence_info.spans_containing_block[start_index]:
delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
delay += overhead_delay
if not self._has_cache_for(span, cache_tokens_needed):
delay += alloc_delay
graph.add_edge("start", (span.peer_id, start_index), delay)
# Server -> client network delays
for span in self.state.sequence_info.spans_containing_block[end_index - 1]:
delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
graph.add_edge((span.peer_id, end_index), "end", delay)
# Server -> server network delays
for block_idx in range(start_index + 1, end_index):
for cur_span in self.state.sequence_info.spans_containing_block[block_idx - 1]:
if cur_span.end != block_idx:
# If we choose a server, we force to go to the end of it before switching to a new one
# to avoid O(N^2) graphs for N servers
continue
for next_span in self.state.sequence_info.spans_containing_block[block_idx]:
rtt = None
# next_pings is keyed by the base58 form of the peer ID
if cur_span.server_info.next_pings is not None:
rtt = cur_span.server_info.next_pings.get(next_span.peer_id.to_base58())
delay = self._rtt_to_delay(rtt)
delay += overhead_delay
if not self._has_cache_for(next_span, cache_tokens_needed):
delay += alloc_delay
graph.add_edge((cur_span.peer_id, block_idx), (next_span.peer_id, block_idx), delay)
# Compute delays
for span in self.state.sequence_info.spans_by_priority:
# Only the part of each span that overlaps [start_index, end_index) gets compute edges
for block_idx in range(max(span.start, start_index), min(span.end, end_index)):
inference_rps = span.server_info.inference_rps
if inference_rps is None:
inference_rps = default_inference_rps
graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), 1.0 / inference_rps)
return graph
@staticmethod
def _rtt_to_delay(
rtt: float,
*,
default_delay: float = 0.15, # If network delay unknown
max_delay: float = 5, # If unreachable, we don't want to discard the edge completely
) -> float:
if rtt is None:
return default_delay
return min(rtt / 2, max_delay)
@staticmethod
def _has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional[int] = None) -> bool:
if cache_tokens_needed is None or span.server_info.cache_tokens_left is None:
return True
# Here, `span` contains all blocks hosted by a server - but we won't necessarily run all of them through
# this particular server in our path. It is difficult to estimate how many blocks we'll use at this stage,
# so we assume that we'll use all of them (the worst case for the cache size) and get a pessimistic estimate.
# This is okay since false positives are more costly than false negatives here.
return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
client_server_rtts = self.ping_aggregator.to_dict()
span_sequence = []
current_index = start_index
while current_index < end_index:
candidate_spans = self.state.sequence_info.spans_containing_block[current_index]
if not candidate_spans:
raise MissingBlocksError(current_index)
# We choose longer servers to minimize the number of hops but leave some randomization
# to distribute the load. We also exclude servers known to be unreachable.
eps = 1e-6
span_weights = np.array(
[span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans],
dtype=np.float64,
)
chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
assert chosen_span.start <= current_index < chosen_span.end
span_sequence.append(dataclasses.replace(chosen_span, start=current_index))
current_index = chosen_span.end
return span_sequence
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
"""Get a RemoteSequenceManager for a sub-sequence of blocks"""
assert isinstance(ix, (int, slice))
if not isinstance(ix, slice):
ix = slice(int(ix), int(ix) + 1, 1)
return type(self)(self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix])
def update(self, *, wait: bool):
"""Run an asynchronous update in background as soon as possible"""
self.ready.clear()
self._thread.trigger.set()
if wait:
self.ready.wait()
def _update(self):
    """Perform an immediate and synchronous refresh of the routing state; may take time.

    Fetches the latest module infos for all blocks, applies the allow/block lists and the temporary ban
    list, stores the result in ``self.state.sequence_info``, pings a sample of servers from the start,
    middle and end of the chain, and finally sets ``self.ready``.
    """
    new_block_infos = get_remote_module_infos(
        self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True
    )

    for block_info in new_block_infos:
        # Apply allow and block lists
        block_info.servers = {
            peer_id: server_info
            for peer_id, server_info in block_info.servers.items()
            if (self.allowed_servers is None or peer_id in self.allowed_servers)
            and (self.blocked_servers is None or peer_id not in self.blocked_servers)
        }

        # Remove temporarily banned peers, unless there are no peers left
        valid_servers = {
            peer_id: server_info
            for peer_id, server_info in block_info.servers.items()
            if peer_id not in self.state.banned_peers
        }
        if len(valid_servers) < len(block_info.servers):
            if valid_servers:
                logger.debug(
                    f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}"
                )
                block_info.servers = valid_servers
            else:
                # If we blacklisted all servers, the error may actually be client-caused
                logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist")

    with self.lock_changes:
        self.state.sequence_info.update_(new_block_infos)

        first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
        middle_servers = [
            span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans
        ]
        last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]

    # Union all three samples. The first sample must not be overwritten by the middle one: client->server
    # RTTs of first-block servers weigh the "start" edges of the min-latency routing graph.
    pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
    pinged_servers |= set(sample_up_to(middle_servers, self.config.max_pinged))
    pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
    self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)

    self.ready.set()
def on_request_failure(self, peer_id: Optional[PeerID]):
"""remove a given peer from the routing table. If the routing is no longer possible, trigger an update"""
if peer_id is not None:
logger.debug(f"Peer {peer_id} did not respond, banning it temporarily")
self.state.banned_peers.register_failure(peer_id)
with self.lock_changes:
should_update = False
for info in self.state.sequence_info.block_infos:
info.servers.pop(peer_id, None)
if not info.servers:
should_update = True
if should_update:
self.ready.clear()
self.update(wait=False)
def on_request_success(self, peer_id: PeerID):
"""if peer has a failure streak, clear that streak"""
self.state.banned_peers.register_success(peer_id)
def __len__(self):
return len(self.block_uids)
@property
def is_alive(self):
return self._thread.is_alive
@property
def ready(self) -> threading.Event:
return self._thread.ready
@property
def block_uids(self):
return self.state.sequence_info.block_uids
@property
def rpc_info(self):
"""
Return the rpc_info queried from one of the servers that hold the first block.

The result is cached in self.state.rpc_info after the first successful query. The first call also starts the
background update thread. On failure, the chosen peer is reported via on_request_failure and the query is retried
after get_retry_delay(attempt_no) seconds; the last exception is re-raised after config.max_retries attempts.
"""
if self.state.rpc_info is not None:
return self.state.rpc_info
with self._thread_start_lock:
# NOTE(review): a thread that already finished is not alive, and start() on it raises RuntimeError
if not self.is_alive():
self._thread.start()
# NOTE(review): if config.max_retries is None, `attempt_no + 1 == None` is never true, so this retries forever
for attempt_no in itertools.count():
# Reset so that on_request_failure gets None when we fail before a peer is chosen
peer_id = None
try:
if not self.ready.is_set():
self.update(wait=True)
active_servers = [
peer_id
for peer_id, server in self.state.sequence_info.block_infos[0].servers.items()
if server.state == ServerState.ONLINE
]
if not active_servers:
raise MissingBlocksError(0)
peer_id = random.choice(active_servers)
stub = TransformerConnectionHandler.get_stub(self.state.p2p, peer_id)
outputs = RemoteExpertWorker.run_coroutine(
stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]), timeout=self.config.request_timeout)
)
self.state.rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
self.on_request_success(peer_id)
break
except Exception as e:
self.on_request_failure(peer_id)
if attempt_no + 1 == self.config.max_retries:
raise
delay = self.get_retry_delay(attempt_no)
logger.warning(
f"Caught exception when gathering information from peer {peer_id} "
f"(retry in {delay:.0f} sec): {repr(e)}"
)
maybe_log_traceback(e)
time.sleep(delay)
return self.state.rpc_info
def get_retry_delay(self, attempt_no: int) -> float:
if attempt_no == 0:
return 0
return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
def get_request_metadata(
self, protocol: str, args_structure: Any = None, *args, **kwargs
) -> Optional[Dict[str, Any]]:
"""
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
:param args_structure: the structure of flattened tensors from pack_args_kwargs in petals.utils.packaging
:param args: request-specific inputs, typically block uids and input tensors
:param kwargs: additional request context, such as remote peer ID
:returns: msgpack-serialized metadata dict that will be passed alongside a given request
"""
return dict(
points=self.policy.get_points(protocol, *args, **kwargs),
active_adapter=self.config.active_adapter,
args_structure=args_structure,
)
def shutdown(self):
self._thread.shutdown()
class _SequenceManagerUpdateThread(threading.Thread):
def __init__(self, update_period: float, ref_update_manager: WeakMethod):
super().__init__(daemon=True)
self.ref_update_manager = ref_update_manager
self.ready = threading.Event()
self.trigger = threading.Event()
self.update_period = update_period
self.should_shutdown = False
def run(self) -> None:
while not self.should_shutdown:
update_manager = self.ref_update_manager()
if update_manager is None:
logger.debug(f"{self.__class__.__name__} exited because the sequence manager no longer exists")
break
try:
self.trigger.clear()
update_manager()
except Exception as e:
logger.exception(e)
finally:
del update_manager
self.trigger.wait(self.update_period)
logger.debug(f"{self.__class__.__name__} thread exited")
def shutdown(self, timeout: Optional[float] = None):
self.should_shutdown = True
self.trigger.set()
if self.is_alive():
self.join(timeout)
def __del__(self):
self.shutdown()
def maybe_log_traceback(exc: Exception):
    """Log the traceback of the exception currently being handled (call from an `except` block).

    The traceback goes to DEBUG if `exc` has a non-empty message or is an asyncio timeout, and to WARNING otherwise.
    """
    is_self_explanatory = str(exc) or isinstance(exc, asyncio.TimeoutError)
    if is_self_explanatory:
        traceback_level = logging.DEBUG
    else:
        traceback_level = logging.WARNING
    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
class MissingBlocksError(RuntimeError):
    """Raised when no online server holds one or more of the requested blocks.

    :param block_indices: index (or indices) of the blocks with no online servers; also kept as `self.block_indices`
    """

    def __init__(self, block_indices: Union[int, Sequence[int]]):
        self.block_indices = block_indices
        # Sentence break after the URL and no trailing whitespace, so the message reads correctly when logged
        super().__init__(
            f"No servers holding blocks {block_indices} are online. "
            "You can check the public swarm's state at https://health.petals.dev. "
            "If there are not enough servers, please connect your GPU: "
            "https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity"
        )
================================================
FILE: src/petals/client/routing/spending_policy.py
================================================
"""
An interface for exchanging internal "BLOOM points" for higher priority compute requests. NOT IMPLEMENTED.
The intent is to let Petals participants earn points by helping others while idle (e.g. at night), then use these
points to run their own compute experiments faster. See Section 4 of https://arxiv.org/abs/2209.01188 for discussion.
"""
from abc import ABC, abstractmethod
class SpendingPolicyBase(ABC):
    """Decides how many points to spend on a given request."""

    @abstractmethod
    def get_points(self, protocol: str, *args, **kwargs) -> float:
        """Return the number of points to spend on a request made via `protocol`."""


class NoSpendingPolicy(SpendingPolicyBase):
    """Default policy: never spends any points."""

    def get_points(self, protocol: str, *args, **kwargs) -> float:
        return 0.0
================================================
FILE: src/petals/client/sequential_autograd.py
================================================
"""
A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
"""
import asyncio
import itertools
from collections import deque
from typing import List, Optional, Sequence, Tuple
import torch
from hivemind import MSGPackSerializer
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils.logging import get_logger
from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import pack_args_kwargs
logger = get_logger(__name__)
# Approximate token budget per chunk: forward/backward split the batch into chunks of
# max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1) sequences (shape[1] is presumably the sequence length)
# and send the chunks through the swarm concurrently.
MAX_TOKENS_IN_BATCH = 1024
async def sequential_forward(
inputs: torch.Tensor,
prompts: torch.Tensor,
sequence_manager: RemoteSequenceManager,
start_index: int = 0,
end_index: Optional[int] = None,
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
"""
Constructs a routing path from <start_index> to <end_index>.
Performs chained forward for each subsequence of blocks on the path.
If some subsequence fails, reconstructs the remaining path and tries to finish the forward.

:returns: (outputs, intermediate_inputs, done_sequences), where intermediate_inputs[i] is the tensor that was sent
  to the span done_sequences[i]; all returned tensors are moved back to the device/dtype of the original inputs
:raises: the last exception once a span fails config.max_retries times in a row
"""
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
# Requests are built from CPU tensors; the original device/dtype is restored before returning
inputs_device = inputs.device
inputs_dtype = inputs.dtype
inputs = inputs.cpu()
prompts = prompts.cpu()
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
assert is_dummy(prompts) or len(prompts) == len(
sequence_manager.block_uids
) # should be n_layers - 1 but add extra prompts for convenience
sequences = deque()
intermediate_inputs = []
done_sequences = []
block_idx = start_index
while block_idx < end_index:
for attempt_no in itertools.count():
logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}")
span = None
try:
# Re-plan the remaining route when the current plan is used up or after a failed attempt
if not sequences or attempt_no >= 1:
sequences = deque(sequence_manager.make_sequence(block_idx, end_index, mode="max_throughput"))
# make_sequence() could return a longer sequence
sequences[-1].end = min(sequences[-1].end, end_index)
logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers")
span = sequences.popleft()
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
# Each span only receives the prompts of its own blocks
flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end])
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
metadata = sequence_manager.get_request_metadata(
"rpc_forward", args_structure, span_uids, *flat_tensors
)
(outputs,) = await run_remote_forward(
span_uids,
stub,
sequence_manager.rpc_info,
*flat_tensors,
config=sequence_manager.config,
metadata=MSGPackSerializer.dumps(metadata),
)
assert isinstance(outputs, torch.Tensor)
assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
# Save intermediate inputs and subsequences if the forward is already done for them
intermediate_inputs.append(inputs)
done_sequences.append(span)
inputs = outputs
block_idx = span.end
sequence_manager.on_request_success(span.peer_id)
break
except Exception as e:
sequence_manager.on_request_failure(span.peer_id if span is not None else None)
if attempt_no + 1 == sequence_manager.config.max_retries:
raise
delay = sequence_manager.get_retry_delay(attempt_no)
logger.warning(
f"Caught exception when running forward via {span} (retry in {delay:.0f} sec): {repr(e)}"
)
maybe_log_traceback(e)
await asyncio.sleep(delay)
outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
intermediate_inputs = [tensor.to(device=inputs_device, dtype=inputs_dtype) for tensor in intermediate_inputs]
return outputs, intermediate_inputs, done_sequences
async def sequential_backward(
grad_outputs: Sequence[torch.Tensor],
intermediate_inputs: List[torch.Tensor],
prompts: torch.Tensor,
forward_sequences: List[RemoteSpanInfo],
sequence_manager: RemoteSequenceManager,
) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
"""
Performs chained backward for each forward subsequence.
If some subsequence fails, reconstructs the particular sub-path and recovers the backward.

Spans are processed from last to first. NOTE: `forward_sequences` is consumed in place (popped, and extended with
re-routed sub-spans on retries); `intermediate_inputs` is copied first, so the caller's list is left intact.

:returns: (grad_outputs, grad_prompts): a one-item list with the gradient w.r.t. the inputs of the first span,
  and the prompt gradients concatenated along dim 0 (None if no span returned any)
"""
assert len(intermediate_inputs) == len(forward_sequences)
grad_outputs_device = grad_outputs[0].device if grad_outputs else None
grad_outputs_dtype = grad_outputs[0].dtype if grad_outputs else None
prompts_device = prompts.device
prompts_dtype = prompts.dtype
grad_outputs = [tensor.cpu() for tensor in grad_outputs]
intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
prompts = prompts.cpu()
grad_prompts_reversed = []
while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
inputs = intermediate_inputs.pop()
span = forward_sequences.pop()
for attempt_no in itertools.count():
logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}")
try:
if attempt_no >= 1:
# Recovery: redo the forward pass for this span on a freshly routed path, push the new
# sub-spans and their inputs, then continue the backward from the last of them
_, backup_inputs, backup_sequences = await sequential_forward(
inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
)
assert len(backup_inputs) == len(backup_sequences)
assert backup_sequences[0].start == span.start
assert backup_sequences[-1].end == span.end
intermediate_inputs.extend(backup_inputs)
forward_sequences.extend(backup_sequences)
inputs = intermediate_inputs.pop()
span = forward_sequences.pop()
grad_outputs_cpu = [grad.cpu() for grad in grad_outputs]
flat_tensors, args_structure = pack_args_kwargs(
inputs, *grad_outputs_cpu, prompts[span.start : span.end]
)
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
metadata = sequence_manager.get_request_metadata(
"rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id
)
# The first returned tensor is the gradient w.r.t. this span's inputs; the rest are prompt gradients
grad_outputs, *span_grad_prompts = await run_remote_backward(
span_uids,
stub,
sequence_manager.rpc_info,
*flat_tensors,
config=sequence_manager.config,
metadata=MSGPackSerializer.dumps(metadata),
)
# The input gradient of this span becomes the output gradient of the previous span
grad_outputs = [grad_outputs]
grad_prompts_reversed.extend(span_grad_prompts)
sequence_manager.on_request_success(span.peer_id)
break
except Exception as e:
sequence_manager.on_request_failure(span.peer_id if span is not None else None)
if attempt_no + 1 == sequence_manager.config.max_retries:
raise
delay = sequence_manager.get_retry_delay(attempt_no)
logger.warning(
f"Caught exception when running backward via {span} (retry in {delay:.0f} sec): {repr(e)}"
)
maybe_log_traceback(e)
await asyncio.sleep(delay)
# For now, we do not support mixed dummy and grad prompts
# Concat in num_layer dimension
# NOTE(review): reversing the flat list restores block order only if each span returned at most one
# prompt-gradient tensor; confirm against run_remote_backward's return format
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
if grad_outputs_dtype is not None:
grad_outputs = [tensor.to(device=grad_outputs_device, dtype=grad_outputs_dtype) for tensor in grad_outputs]
if grad_prompts is not None:
grad_prompts = grad_prompts.to(device=prompts_device, dtype=prompts_dtype)
return grad_outputs, grad_prompts
async def _gather_forward(input_batches, prompt_batches, sequence_manager):
    """Run sequential_forward concurrently for every (input batch, prompt batch) pair; results keep the batch order"""
    coroutines = [
        sequential_forward(batch_inputs, batch_prompts, sequence_manager)
        for batch_inputs, batch_prompts in zip(input_batches, prompt_batches)
    ]
    return await asyncio.gather(*coroutines)
async def _gather_backward(
    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
):
    """Run sequential_backward concurrently for every batch; results keep the batch order.

    sequential_backward pops spans from the span list it receives (and extends it with re-routed sub-spans on
    retries). Each call therefore gets a shallow copy, so the caller's per-batch span lists (stored on the autograd
    ctx) stay intact and a second backward pass (e.g. with retain_graph=True) sees the same routes as the first.
    """
    return await asyncio.gather(
        *[
            sequential_backward((grad_output,), input_batch, prompt_batch, list(spans), sequence_manager)
            for grad_output, input_batch, prompt_batch, spans in zip(
                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
            )
        ]
    )
class _RemoteSequentialAutogradFunction(torch.autograd.Function):
    """
    Autograd function that runs forward and backward through the whole chain of remote transformer blocks.

    The batch is cut into chunks of roughly <MAX_TOKENS_IN_BATCH> tokens, and all chunks are sent through the swarm
    concurrently (one sequential_forward / sequential_backward per chunk).
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
        # Sequences per chunk (inputs.shape[1] is presumably the sequence length), at least one
        batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
        if prompts is not None and not is_dummy(prompts):
            # Prompts are chunked along dim=1; dim 0 indexes layers (sequential_backward concatenates along it)
            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
        else:
            prompt_batches = [DUMMY] * len(input_batches)

        sequence_manager.rpc_info  # lazy init: the property fetches and caches rpc_info on first access
        per_chunk_results = RemoteExpertWorker.run_coroutine(
            _gather_forward(input_batches, prompt_batches, sequence_manager)
        )
        assert len(per_chunk_results) == len(input_batches)

        # Each result is (outputs, intermediate_inputs, spans) as returned by sequential_forward
        output_batches = [result[0] for result in per_chunk_results]
        ctx.prompt_batches = prompt_batches
        ctx.sequence_manager = sequence_manager
        ctx.intemediate_input_batches = [result[1] for result in per_chunk_results]
        ctx.sequences_for_batches = [result[2] for result in per_chunk_results]
        return torch.cat(output_batches, dim=0)

    @staticmethod
    def backward(ctx, grad_outputs: torch.Tensor):
        intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intemediate_input_batches
        forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
        ctx.sequence_manager.rpc_info  # lazy init

        # Re-chunk the gradients exactly like the inputs were chunked in forward()
        batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
        grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)

        per_chunk_results = RemoteExpertWorker.run_coroutine(
            _gather_backward(
                grad_output_batches,
                intermediate_input_batches,
                ctx.prompt_batches,
                forward_sequences,
                ctx.sequence_manager,
            )
        )
        # Each result is (grad_outputs, grad_prompts) from sequential_backward; grad_outputs is a one-item list
        grad_input_batches = [result[0][0] for result in per_chunk_results]
        grad_prompt_batches = [result[1] for result in per_chunk_results]
        grad_inputs = torch.cat(grad_input_batches, dim=0)
        if any(grad_prompt is None for grad_prompt in grad_prompt_batches):
            grad_prompts = None
        else:
            grad_prompts = torch.cat(grad_prompt_batches, dim=1)
        return (grad_inputs, grad_prompts, None)
================================================
FILE: src/petals/constants.py
================================================
import torch
# Bootstrap peers of the public swarm. Each of the two peers is listed three ways:
# via an IPv4 DNS name, an IPv6 DNS name, and a raw IPv4 address.
PUBLIC_INITIAL_PEERS = [
# IPv4 DNS addresses
"/dns/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
"/dns/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
# IPv6 DNS addresses
"/dns6/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
"/dns6/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
# Reserved IPs
"/ip4/159.89.214.152/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
"/ip4/159.203.156.48/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
]
# The reachability API is currently used only when connecting to the public swarm
REACHABILITY_API_URL = "https://health.petals.dev"
# Maps dtype names to torch dtypes; "auto" is passed through as the plain string "auto"
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
================================================
FILE: src/petals/data_structures.py
================================================
import dataclasses
from enum import Enum
from typing import Any, Dict, Optional, Sequence, Tuple
import pydantic.v1 as pydantic
from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
ModuleUID = str
UID_DELIMITER = "."  # separates the DHT prefix from the block index in a module uid, e.g. "bloom-petals.4"
CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom-petals.3 bloom-petals.4"


def parse_uid(uid: ModuleUID) -> Tuple[str, int]:
    """Split a single (non-chained) block uid "<dht_prefix>.<index>" into (dht_prefix, index).

    Raises AssertionError for chained uids and ValueError unless the uid contains exactly one UID_DELIMITER
    followed by an integer.
    """
    assert CHAIN_DELIMITER not in uid, "parse_uid() does not support chained UIDs"
    prefix, block_index = uid.split(UID_DELIMITER)
    return prefix, int(block_index)
@pydantic.dataclasses.dataclass
class ModelInfo:
    """Model-level metadata: the number of blocks (a strict int >= 1) and, optionally, the source repository."""

    num_blocks: pydantic.conint(ge=1, strict=True)
    repository: Optional[str] = None

    def to_dict(self) -> dict:
        """Plain-dict form of this record"""
        as_dict = dataclasses.asdict(self)
        return as_dict

    @classmethod
    def from_dict(cls, source: dict):
        """Inverse of to_dict; pydantic validates the fields on construction"""
        return cls(**source)
class ServerState(Enum):
    """Lifecycle state of a server.

    The integer values are what ServerInfo.to_tuple / from_tuple exchange, so they must stay stable.
    """

    OFFLINE = 0
    JOINING = 1
    ONLINE = 2
# Non-negative, finite float (requests per second)
RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)


@pydantic.dataclasses.dataclass
class ServerInfo:
    """Information about a server: its state and throughput are required, everything else is optional."""

    state: ServerState
    throughput: RPS

    start_block: Optional[pydantic.conint(ge=0, strict=True)] = None
    end_block: Optional[pydantic.conint(ge=0, strict=True)] = None

    public_name: Optional[str] = None
    version: Optional[str] = None

    network_rps: Optional[RPS] = None
    forward_rps: Optional[RPS] = None
    inference_rps: Optional[RPS] = None

    adapters: Sequence[str] = ()
    torch_dtype: Optional[str] = None
    quant_type: Optional[str] = None
    using_relay: Optional[bool] = None
    cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None
    # RTTs to other servers, keyed by the base58 form of their peer IDs
    next_pings: Optional[Dict[str, pydantic.confloat(ge=0, strict=True)]] = None

    def to_tuple(self) -> Tuple[int, float, dict]:
        """Serialize as (state value, throughput, dict of all remaining fields)"""
        extra_info = dataclasses.asdict(self)
        for required_field in ("state", "throughput"):
            del extra_info[required_field]
        return self.state.value, self.throughput, extra_info

    @classmethod
    def from_tuple(cls, source: tuple):
        """Inverse of to_tuple; the third element (extra fields) may be absent"""
        state, throughput = source[:2]
        extra_info = {} if len(source) <= 2 else source[2]
        # pydantic will validate existing fields and ignore extra ones
        return cls(state=ServerState(state), throughput=throughput, **extra_info)
@dataclasses.dataclass
class RemoteModuleInfo:
"""A remote module that is served by one or more servers"""
uid: ModuleUID
# Maps each peer that holds this module to the ServerInfo it reported
servers: Dict[PeerID, ServerInfo]
@dataclasses.dataclass
class RemoteSpanInfo:
    """A contiguous chain of blocks [start, end) served by one specific remote peer"""

    peer_id: PeerID
    start: int
    end: int
    server_info: ServerInfo

    @property
    def length(self) -> int:
        """Number of blocks in the span"""
        num_blocks = self.end - self.start
        return num_blocks

    @property
    def state(self) -> ServerState:
        """State reported by the serving peer"""
        info = self.server_info
        return info.state

    @property
    def throughput(self) -> float:
        """Throughput reported by the serving peer"""
        info = self.server_info
        return info.throughput
# Info dict returned by a server's rpc_info call (deserialized with MSGPackSerializer)
RPCInfo = Dict[str, Any]
# Integer handle; presumably identifies a server-side cache allocation (see InferenceMetadata.cache_handles)
Handle = int
@dataclasses.dataclass(frozen=True)
class InferenceMetadata:
# Immutable per-module metadata: uid, prefix length, cache handles and the active adapter (None if no adapter)
uid: ExpertUID
prefix_length: int
cache_handles: Tuple[Handle, ...]
active_adapter: Optional[str]
================================================
FILE: src/petals/dht_utils.py
================================================
import warnings
# Backward-compatibility shim: importing this module emits a DeprecationWarning attributed to the importer
# (stacklevel=2); the star import below re-exports petals.utils.dht
warnings.warn(
    "petals.dht_utils has been moved to petals.utils.dht. This alias will be removed in Petals 2.2.0+",
    category=DeprecationWarning,
    stacklevel=2,
)
from petals.utils.dht import *
================================================
FILE: src/petals/models/__init__.py
================================================
from petals.models.bloom import *
from petals.models.falcon import *
from petals.models.llama import *
from petals.models.mixtral import *
================================================
FILE: src/petals/models/bloom/__init__.py
================================================
from petals.models.bloom.block import WrappedBloomBlock
from petals.models.bloom.config import DistributedBloomConfig
from petals.models.bloom.model import (
DistributedBloomForCausalLM,
DistributedBloomForSequenceClassification,
DistributedBloomModel,
)
from petals.utils.auto_config import register_model_classes
# Register the distributed BLOOM classes with petals.utils.auto_config, presumably so that the
# AutoDistributed* entry points can map a BLOOM config to these implementations
register_model_classes(
config=DistributedBloomConfig,
model=DistributedBloomModel,
model_for_causal_lm=DistributedBloomForCausalLM,
model_for_sequence_classification=DistributedBloomForSequenceClassification,
)
================================================
FILE: src/petals/models/bloom/block.py
================================================
"""
Bloom intermediate layer
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
from typing import Optional, Tuple
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.bloom.modeling_bloom import BloomBlock, build_alibi_tensor
from petals.utils.misc import is_dummy
class WrappedBloomBlock(BloomBlock):
    """BloomBlock that builds its own causal attention mask and, if not given, its ALiBi tensor."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs
    ):
        assert attention_mask is None, "Non-causal attention masks are not supported yet"
        batch_size, seq_length = hidden_states.shape[:2]
        if layer_past is not None and is_dummy(layer_past[0]):
            # Bloom cannot use a cache made of dummy (misconstructed) tensors: run as if there were no cache
            layer_past = None
        # Presumably the cached keys keep the sequence in their last dimension
        past_length = layer_past[0].shape[-1] if layer_past is not None else 0

        full_mask = torch.ones((batch_size, seq_length + past_length), device=hidden_states.device)
        if alibi is None:
            alibi = build_alibi_tensor(full_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        causal_mask = _prepare_4d_causal_attention_mask(
            attention_mask=full_mask,
            input_shape=(batch_size, seq_length),
            inputs_embeds=hidden_states,
            past_key_values_length=past_length,
        ).bool()
        return super().forward(
            hidden_states, *args, attention_mask=causal_mask, alibi=alibi, layer_past=layer_past, **kwargs
        )
================================================
FILE: src/petals/models/bloom/config.py
================================================
import os
from typing import Optional, Union
from hivemind import get_logger
from transformers.models.bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention
from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.models.bloom.block import WrappedBloomBlock
logger = get_logger(__name__)


class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedBloomBlock
    attn_class = BloomAttention
    block_prefix = "h"
    num_key_value_groups = 1

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        """Load the config; when loading from a hub repo without an explicit dht_prefix, derive one
        as "<repo>-petals" with dots replaced by dashes."""
        logger.info("Make sure you follow the BLOOM terms of use: https://bit.ly/bloom-license")

        is_local_or_none = model_name_or_path is None or os.path.isdir(model_name_or_path)
        if not is_local_or_none and dht_prefix is None:
            # We need "-petals" for backward compatibility with Petals < 1.2.0
            dht_prefix = (str(model_name_or_path) + "-petals").replace(".", "-")
            logger.info(f"Using DHT prefix: {dht_prefix}")

        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
================================================
FILE: src/petals/models/bloom/model.py
================================================
from typing import Optional
import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.bloom.config import DistributedBloomConfig
logger = get_logger(__name__)
class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm"""

    # Reuses PTuneMixin's patterns for keys allowed to be missing from checkpoints
    # (presumably the prompt-tuning parameters — confirm in ptune.py).
    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
    # Checkpoint weights under "h." (the transformer blocks) are not loaded locally: `self.h` is a RemoteSequential.
    _keys_to_ignore_on_load_unexpected = [r"^h\."]

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
        """Build embeddings/layernorms locally and route all transformer blocks through a RemoteSequential.

        :param config: model config; ``num_hidden_layers`` is temporarily set to 0 during construction and restored.
        :param dht: optional hivemind DHT passed to RemoteSequential.
        """
        # Construct the HF model with zero layers so no local transformer blocks are allocated, then restore the count
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.h) == 0
        config.num_hidden_layers = n_layer

        self.h = RemoteSequential(config, dht=dht)

        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
        # Called after requires_grad_(False), so parameters created here keep their own requires_grad setting
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[RemotePastKeyValues] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        """Embed inputs locally, run them through the remote blocks, and apply the final layernorm.

        Only the default settings are supported: attention_mask must be None or all ones, head_mask must be None,
        and attentions / hidden states cannot be returned (enforced by the assertions below).

        :returns: BaseModelOutputWithPastAndCrossAttentions whose ``past_key_values`` is a RemotePastKeyValues
            that has been told how many new positions were processed.
        :raises ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            # Collapse any extra leading dimensions into the batch dimension
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # The causal mask will be added on the server-side
        assert (
            attention_mask is None or (attention_mask == 1).all()
        ), f"Custom attention masks are not supported, {attention_mask=}"
        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
        assert not output_attentions, f"{output_attentions=} is not supported"
        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
        assert return_dict is None or return_dict, f"{return_dict=} is not supported"

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # Prompts are prepended only when p-tuning is enabled and the remote sequence is at position 0
        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
        if use_prompts:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
        else:
            prompts = intermediate_prompts = None

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        hidden_states = self.h(
            hidden_states,
            prompts=intermediate_prompts,
            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
        )

        # Remove prefix
        if use_prompts:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        if past_key_values is None:
            past_key_values = RemotePastKeyValues()
        # Record how many positions were processed (after stripping the prompt prefix)
        past_key_values.update_seen(hidden_states.size(1))

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )
class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
    """BloomForCausalLM whose transformer blocks are hosted by the swarm (via DistributedBloomModel)."""

    # Build a NEW list. The previous `+= [...]` extended DistributedBloomModel's list in place; that list is the
    # very object stored on PTuneMixin, so the lm_head pattern leaked into every class sharing it.
    _keys_to_ignore_on_load_missing = [
        *DistributedBloomModel._keys_to_ignore_on_load_missing,
        r"^lm_head\.",  # Missing since they are shared with input embeddings
    ]
    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
    _supports_cache_class = True

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        # Calls BloomPreTrainedModel.__init__ directly instead of BloomForCausalLM.__init__,
        # then attaches the distributed transformer and the Petals LMHead.
        BloomPreTrainedModel.__init__(self, config)
        self.transformer = DistributedBloomModel(config)
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ) -> dict:
        """Trim inputs already covered by ``past_key_values`` and assemble the kwargs for ``forward``.

        :returns: dict with either ``input_ids`` or ``inputs_embeds`` (the latter only on the first step),
            plus ``position_ids``, ``past_key_values``, ``use_cache`` and ``attention_mask``.
        """
        # Omit tokens covered by past_key_values
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values._seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                # Legacy tuple cache: sequence length is read from dim 2 of the first key tensor
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def _temporary_reorder_cache(self, past_key_values, beam_idx):
        # Returns the cache unchanged; reordering for beams presumably happens elsewhere
        # (e.g. RemoteGenerationMixin._reorder_cache) — NOTE(review): confirm.
        return past_key_values

    def get_output_embeddings(self):
        """Return the Petals LMHead used as the output embedding layer."""
        return self.lm_head
class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
    """Sequence classifier: a local linear ``score`` head over a swarm-hosted DistributedBloomModel."""

    _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        # Bypass BloomForSequenceClassification.__init__ and wire up the distributed transformer by hand
        BloomPreTrainedModel.__init__(self, config)

        label_count = config.num_labels
        self.num_labels = label_count
        self.transformer = DistributedBloomModel(config)
        self.score = nn.Linear(config.hidden_size, label_count, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
================================================
FILE: src/petals/models/falcon/__init__.py
================================================
from petals.models.falcon.block import WrappedFalconBlock
from petals.models.falcon.config import DistributedFalconConfig
from petals.models.falcon.model import (
DistributedFalconForCausalLM,
DistributedFalconForSequenceClassification,
DistributedFalconModel,
)
from petals.utils.auto_config import register_model_classes
# Register the distributed Falcon config and model classes with petals.utils.auto_config at import time
register_model_classes(
    config=DistributedFalconConfig,
    model=DistributedFalconModel,
    model_for_causal_lm=DistributedFalconForCausalLM,
    model_for_sequence_classification=DistributedFalconForSequenceClassification,
)
================================================
FILE: src/petals/models/falcon/block.py
================================================
"""
Falcon intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
See commit history for authorship.
"""
import math
from functools import partial
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.falcon.modeling_falcon import (
FalconAttention,
FalconConfig,
FalconDecoderLayer,
FalconLinear,
FalconMLP,
FalconModel,
LayerNorm,
build_alibi_tensor,
dropout_add,
rotate_half,
)
# A (key, value) pair of tensors — presumably one layer's attention cache; confirm against its users
KVCache = Tuple[torch.Tensor, torch.Tensor]
# Minimum number of positions in the rotary cos/sin table built on first use (OptimizedFalconRotaryEmbedding.cos_sin)
INFERENCE_MAX_LENGTH = 8192
def apply_rotary(query, key, cos, sin):
    """Apply rotary position embeddings to ``query`` and ``key``.

    Each input ``x`` is mapped to ``x * cos + rotate_half(x) * sin``; ``cos`` and ``sin`` must broadcast
    against both inputs. Returns the rotated ``(query, key)`` pair.
    """
    rotated_query = query * cos + rotate_half(query) * sin
    rotated_key = key * cos + rotate_half(key) * sin
    return rotated_query, rotated_key
class OptimizedFalconRotaryEmbedding(nn.Module):
    """Rotary position embedding for Falcon with a cached cos/sin table and a CUDA-graph path for single tokens.

    The cos/sin table is built lazily by ``cos_sin``: at least ``INFERENCE_MAX_LENGTH`` positions on first use,
    and it is rebuilt (grown) whenever a longer sequence is requested.
    """

    def __init__(self, head_dim: int, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.head_dim = head_dim
        self.seq_len_cached = -1

        # CUDA-graph state for _optimized_apply_rotary (captured lazily)
        self.cuda_graph = None
        self.input_surface = None
        self.static_outputs = None

    def _optimized_apply_rotary(self, query, key, cos, sin):
        """Run ``apply_rotary`` through a captured CUDA graph.

        The graph is (re)captured when no graph exists yet or when any input's shape, dtype or device differs
        from the captured static inputs (e.g. a different batch size). Returned tensors are detached views of
        the graph's static outputs and are overwritten by the next replay.
        """
        inputs = (query, key, cos, sin)
        needs_capture = self.cuda_graph is None or any(
            static.shape != data.shape or static.dtype != data.dtype or static.device != data.device
            for static, data in zip(self.input_surface, inputs)
        )
        if needs_capture:
            self.cuda_graph = torch.cuda.CUDAGraph()
            # Dedicated static input buffers. They must not alias the caller's tensors: `cos`/`sin` can be views into
            # `cos_cached`/`sin_cached` (Tensor.type returns self when the dtype already matches), so copying later
            # inputs into them would overwrite entries of the cached table.
            self.input_surface = tuple(t.clone() for t in inputs)

            # Warm up on a side stream before capture
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                for _ in range(3):
                    apply_rotary(*self.input_surface)
            torch.cuda.current_stream().wait_stream(s)

            with torch.cuda.graph(self.cuda_graph):
                self.static_outputs = apply_rotary(*self.input_surface)

        for static_input, data in zip(self.input_surface, inputs):
            static_input.copy_(data)
        self.cuda_graph.replay()
        return tuple(o.detach() for o in self.static_outputs)

    def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
        """Return ``(cos, sin)`` for positions ``[past_key_values_length, past_key_values_length + seq_len)``.

        Each is sliced from the cached table (shape ``[1, positions, head_dim]``) and cast to ``dtype``.
        """
        total_length = seq_len + past_key_values_length
        if self.seq_len_cached == -1:
            # warm up the cache
            total_length = max(INFERENCE_MAX_LENGTH, total_length)

        if total_length > self.seq_len_cached:
            # Build the table outside inference mode so the buffers are regular tensors
            with torch.inference_mode(False):
                self.seq_len_cached = total_length
                t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
                freqs = torch.einsum("i,j->ij", t, self.inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1).to(device)

                if dtype in [torch.float16, torch.bfloat16]:
                    emb = emb.float()

                self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False)
                self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False)

        return (
            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
        )

    def forward(self, query, key, past_key_values_length=0):
        """Rotate ``query`` and ``key`` (3-D tensors, sequence on dim 1) starting at ``past_key_values_length``.

        Single-token CUDA calls under inference mode use the CUDA-graph path; everything else runs eagerly.
        """
        batch, seq_len, head_dim = query.shape
        cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
        if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == "cuda":
            return self._optimized_apply_rotary(query, key, cos, sin)
        else:
            return apply_rotary(query, key, cos, sin)
def split_heads(
    fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split a fused grouped-query QKV projection into query, key and value heads.

    ``fused_qkv`` is ``[batch, seq_len, features]``. Its last dimension is viewed as groups (count inferred by
    ``view``, expected to be ``num_kv_heads``). Each group holds ``num_heads // num_kv_heads`` query heads followed
    by one key head and one value head. The key/value head of every group is repeated for each of that group's
    query heads.

    :returns: ``(query, key, value)``, each ``[batch, seq_len, groups * (num_heads // num_kv_heads), head_dim]``.
    """
    queries_per_group = num_heads // num_kv_heads
    batch_size, seq_length, _ = fused_qkv.shape
    grouped = fused_qkv.view(batch_size, seq_length, -1, queries_per_group + 2, head_dim)
    q, k, v = grouped.split([queries_per_group, 1, 1], dim=3)
    k, v = k.expand_as(q), v.expand_as(q)
    return q.flatten(2, 3), k.flatten(2, 3), v.flatten(2, 3)
class OptimizedFalconAttention(FalconAttention):
    """Falcon self-attention with a CUDA-graph fast path for single-token inference.

    ``__init__`` calls ``nn.Module.__init__`` directly (bypassing ``FalconAttention.__init__``). It builds the
    projections itself and uses ``OptimizedFalconRotaryEmbedding`` for rotary models.
    """

    def __init__(self, config: FalconConfig):
        nn.Module.__init__(self)

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.split_size = self.hidden_size
        self.hidden_dropout = config.hidden_dropout

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.maybe_rotary = OptimizedFalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.beta = self.inv_norm_factor

        if config.new_decoder_architecture:
            qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
        elif config.multi_query:
            qkv_out_dim = self.hidden_size + 2 * self.head_dim
        else:
            qkv_out_dim = 3 * self.hidden_size
        self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
        self.new_decoder_architecture = config.new_decoder_architecture
        self.multi_query = config.multi_query
        self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
        self.attention_dropout = nn.Dropout(config.attention_dropout)
        self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1

        if self.new_decoder_architecture:
            # Grouped-query split (module-level split_heads) replaces the inherited _split_heads
            self._split_heads = partial(
                split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
            )
            # CUDA-graph state for _optimized_split_heads (captured lazily)
            self.split_graph = None
            self.input_surface = None
            self.static_outputs = None

    def _optimized_split_heads(self, fused_qkv):
        """Run ``self._split_heads`` through a captured CUDA graph.

        The graph is (re)captured when none exists yet or when ``fused_qkv``'s shape, dtype or device differs from
        the captured static input. Without this, a different batch size would make ``copy_`` fail, or silently
        broadcast when the new batch is 1. Returned tensors are detached views of the graph's static outputs and
        are overwritten by the next replay.
        """
        surface = self.input_surface
        needs_capture = (
            self.split_graph is None
            or surface.shape != fused_qkv.shape
            or surface.dtype != fused_qkv.dtype
            or surface.device != fused_qkv.device
        )
        if needs_capture:
            self.split_graph = torch.cuda.CUDAGraph()
            # Dedicated static input buffer; later calls copy their data into it before replaying
            self.input_surface = fused_qkv.clone()

            # Warm up on a side stream before capture
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                for _ in range(3):
                    self._split_heads(fused_qkv)
            torch.cuda.current_stream().wait_stream(s)

            with torch.cuda.graph(self.split_graph):
                self.static_outputs = self._split_heads(self.input_surface)

        self.input_surface.copy_(fused_qkv)
        self.split_graph.replay()
        return tuple(o.detach() for o in self.static_outputs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Compute self-attention for ``hidden_states``, appending to ``layer_past`` when given.

        ``attention_mask`` is expected to be boolean with True marking masked positions (it is used directly as
        the ``masked_fill`` mask). Returns ``(output, present)``, where ``present`` is the concatenated
        ``(key, value)`` when ``use_cache`` is set and None otherwise.
        """
        assert not output_attentions

        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

        if (
            self.new_decoder_architecture
            and hidden_states.size(1) == 1
            and torch.is_inference_mode_enabled()
            and hidden_states.device.type == "cuda"
        ):
            query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv)
        else:
            # 3 x [batch_size, seq_length, num_heads, head_dim]
            (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

        num_kv_heads = self.num_heads
        batch_size, query_length, _, _ = query_layer.shape

        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
        key_layer = key_layer.transpose(1, 2).reshape(
            batch_size * num_kv_heads,
            query_length,
            self.head_dim,
        )
        value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)

        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
            #  - key: [batch_size * self.num_heads, kv_length, head_dim]
            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
            key_layer = torch.cat((past_key, key_layer), dim=1)
            value_layer = torch.cat((past_value, value_layer), dim=1)

        _, kv_length, _ = key_layer.shape
        if use_cache:
            present = (key_layer, value_layer)
        else:
            present = None

        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
        key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
        value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)

        # Additive mask: -1e9 where attention_mask is True, 0 elsewhere
        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)

        if alibi is None:
            attn_output = F.scaled_dot_product_attention(
                query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False
            )

            attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
            attn_output = attn_output.permute(0, 2, 1, 3)
            attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)

            output_tensor = self.dense(attn_output)

            return output_tensor, present
        else:
            matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)

            # change view to [batch_size, num_heads, q_length, kv_length]
            attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)

            # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
            input_dtype = attention_scores.dtype
            # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
            if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
                attention_scores = attention_scores.to(torch.float32)
            # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
            # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
            # equivalent and more performant, but there might be a numerical difference. If you're reading this
            # and you'd like to experiment and maybe file a PR, feel free!
            attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
            attention_logits *= self.inv_norm_factor
            attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
            # [batch_size, num_heads, q_length, kv_length]
            attention_probs = self.attention_dropout(attention_probs)

            if head_mask is not None:
                attention_probs = attention_probs * head_mask

            # change view [batch_size, num_heads, q_length, kv_length]
            attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)

            # matmul: [batch_size * num_heads, q_length, head_dim]
            context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)

            # change view [batch_size, q_length, num_heads * head_dim]
            context_layer = self._merge_heads(context_layer)

            output_tensor = self.dense(context_layer)

            if output_attentions:
                return output_tensor, present, attention_probs
            else:
                return output_tensor, present
class OptimizedFalconDecoderLayer(FalconDecoderLayer):
def __init__(self, config: FalconConfig):
nn.Module.__init__(self)
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.mlp = FalconMLP(config)
self.hidden_dropout = config.hidden_dropout
self.config = config
self.self_attention = OptimizedFalconAttention(config)
if self.config.alibi or not config.new_decoder_architecture:
if config.new_decoder_architecture:
# The layer norm before self-attention
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
# The layer norm before the MLP
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
else:
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if not config.parallel_attn:
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
else:
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.ln_graph = None
self.static_input = None
self.static_outputs = None
def _optimized_apply_ln(self, hidden_states):
if self.ln_graph is None:
self.ln_graph = torch.cuda.CUDAGraph()
self.static_input = hidden_states
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
self.ln_attn(hidden_states)
self.ln_mlp(hidden_states)
torch.cuda.current_stream().wait_stream(s)
with torch.cuda.graph(self.ln_graph):
ln_attn_output = self.ln_attn(hidden_states)
ln_mlp_output = self.ln_mlp(hidden_states)
self.static_outputs = (ln_attn_output, ln_mlp_output)
self.static_input.copy_(hidden_st
gitextract_tsupof0f/
├── .github/
│ └── workflows/
│ ├── check-style.yaml
│ ├── push-docker-image.yaml
│ └── run-tests.yaml
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── benchmarks/
│ ├── benchmark_forward.py
│ ├── benchmark_inference.py
│ └── benchmark_training.py
├── examples/
│ ├── prompt-tuning-personachat.ipynb
│ └── prompt-tuning-sst2.ipynb
├── pyproject.toml
├── setup.cfg
├── src/
│ └── petals/
│ ├── __init__.py
│ ├── cli/
│ │ ├── __init__.py
│ │ ├── run_dht.py
│ │ ├── run_prod_server.sh
│ │ └── run_server.py
│ ├── client/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── from_pretrained.py
│ │ ├── inference_session.py
│ │ ├── lm_head.py
│ │ ├── ptune.py
│ │ ├── remote_forward_backward.py
│ │ ├── remote_generation.py
│ │ ├── remote_sequential.py
│ │ ├── routing/
│ │ │ ├── __init__.py
│ │ │ ├── sequence_info.py
│ │ │ ├── sequence_manager.py
│ │ │ └── spending_policy.py
│ │ └── sequential_autograd.py
│ ├── constants.py
│ ├── data_structures.py
│ ├── dht_utils.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── bloom/
│ │ │ ├── __init__.py
│ │ │ ├── block.py
│ │ │ ├── config.py
│ │ │ └── model.py
│ │ ├── falcon/
│ │ │ ├── __init__.py
│ │ │ ├── block.py
│ │ │ ├── config.py
│ │ │ └── model.py
│ │ ├── llama/
│ │ │ ├── __init__.py
│ │ │ ├── block.py
│ │ │ ├── config.py
│ │ │ ├── model.py
│ │ │ └── speculative_model.py
│ │ └── mixtral/
│ │ ├── __init__.py
│ │ ├── block.py
│ │ ├── config.py
│ │ └── model.py
│ ├── server/
│ │ ├── __init__.py
│ │ ├── backend.py
│ │ ├── block_functions.py
│ │ ├── block_selection.py
│ │ ├── block_utils.py
│ │ ├── from_pretrained.py
│ │ ├── handler.py
│ │ ├── memory_cache.py
│ │ ├── reachability.py
│ │ ├── server.py
│ │ ├── task_pool.py
│ │ ├── task_prioritizer.py
│ │ └── throughput.py
│ └── utils/
│ ├── __init__.py
│ ├── asyncio.py
│ ├── auto_config.py
│ ├── convert_block.py
│ ├── cuda_graphs.py
│ ├── dht.py
│ ├── disk_cache.py
│ ├── hf_auth.py
│ ├── logging.py
│ ├── misc.py
│ ├── packaging.py
│ ├── peft.py
│ ├── ping.py
│ ├── random.py
│ └── version.py
└── tests/
├── bootstrap.id
├── conftest.py
├── server2.id
├── test_aux_functions.py
├── test_block_exact_match.py
├── test_cache.py
├── test_chained_calls.py
├── test_dtype.py
├── test_full_model.py
├── test_optimized_layers.py
├── test_peft.py
├── test_priority_pool.py
├── test_remote_sequential.py
├── test_sequence_manager.py
├── test_server_stats.py
├── test_speculative_generation.py
├── test_tensor_parallel.py
└── test_utils.py
SYMBOL INDEX (511 symbols across 73 files)
FILE: benchmarks/benchmark_forward.py
function main (line 17) | def main():
function benchmark_forward (line 46) | def benchmark_forward(process_idx, args, result_pipe):
FILE: benchmarks/benchmark_inference.py
function main (line 18) | def main():
function benchmark_inference (line 45) | def benchmark_inference(process_idx, args, result_pipe):
FILE: benchmarks/benchmark_training.py
function main (line 17) | def main():
function benchmark_training (line 50) | def benchmark_training(process_idx, args, result_pipe):
FILE: src/petals/__init__.py
function _override_bfloat16_mode_default (line 29) | def _override_bfloat16_mode_default():
FILE: src/petals/cli/run_dht.py
function report_status (line 24) | async def report_status(dht: DHT, node: DHTNode):
function main (line 37) | def main():
FILE: src/petals/cli/run_server.py
function main (line 19) | def main():
FILE: src/petals/client/config.py
class ClientConfig (line 14) | class ClientConfig:
FILE: src/petals/client/from_pretrained.py
class FromPretrainedMixin (line 17) | class FromPretrainedMixin:
method from_pretrained (line 19) | def from_pretrained(
function ignore_keys (line 46) | def ignore_keys(patterns: List[str]):
function patched_get_checkpoint_shard_files (line 54) | def patched_get_checkpoint_shard_files(
FILE: src/petals/client/inference_session.py
class _ServerInferenceSession (line 26) | class _ServerInferenceSession:
method __init__ (line 33) | def __init__(
method create (line 60) | async def create(
method _read_inputs_from_queue (line 79) | async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout:...
method position (line 87) | def position(self):
method position (line 91) | def position(self, start_from_position: int):
method step (line 97) | def step(
method _collect_next_servers (line 174) | def _collect_next_servers(self) -> List[Tuple[str, str, int, int]]:
method _step (line 184) | async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) ->...
method close (line 190) | def close(self):
method _aclose_stream (line 198) | async def _aclose_stream(self):
method __del__ (line 209) | def __del__(self):
method __enter__ (line 212) | def __enter__(self):
method __exit__ (line 216) | def __exit__(self, *exc_details):
class InferenceSession (line 220) | class InferenceSession:
method __init__ (line 225) | def __init__(self, sequence_manager: RemoteSequenceManager, max_length...
method num_blocks (line 235) | def num_blocks(self) -> int:
method position (line 239) | def position(self) -> int:
method position (line 243) | def position(self, start_from_position: int) -> None:
method _enter_server_sessions (line 249) | def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -...
method _exit_server_sessions (line 273) | def _exit_server_sessions(self, server_sessions: List[_ServerInference...
method __enter__ (line 280) | def __enter__(self) -> "InferenceSession":
method step (line 284) | def step(
method _update_sequence (line 364) | def _update_sequence(self, server_idx: int, block_idx: int, attempt_no...
method close (line 393) | def close(self, *exc_details):
method __exit__ (line 400) | def __exit__(self, *exc_details):
method __del__ (line 403) | def __del__(self):
method last_token_id (line 407) | def last_token_id(self) -> Optional[torch.Tensor]: # Backward compati...
method last_token_id (line 411) | def last_token_id(self, value: torch.Tensor): # Backward compatibilit...
FILE: src/petals/client/lm_head.py
class LMHeadConfig (line 16) | class LMHeadConfig:
class LMHead (line 23) | class LMHead(nn.Module):
method __init__ (line 24) | def __init__(self, config: PretrainedConfig):
method forward (line 50) | def forward(self, hidden_states):
method chunked_forward (line 63) | def chunked_forward(self, hidden_states):
FILE: src/petals/client/ptune.py
class PTuneConfig (line 16) | class PTuneConfig:
class PTuneMixin (line 21) | class PTuneMixin:
method init_prompts (line 24) | def init_prompts(self, config: PretrainedConfig) -> None:
method get_prompt (line 43) | def get_prompt(self, batch_size):
function force_non_empty_weights (line 69) | def force_non_empty_weights():
FILE: src/petals/client/remote_forward_backward.py
function _forward_unary (line 21) | async def _forward_unary(
function _backward_unary (line 31) | async def _backward_unary(
function _forward_stream (line 41) | async def _forward_stream(
function _backward_stream (line 54) | async def _backward_stream(
function run_remote_forward (line 67) | async def run_remote_forward(
function run_remote_backward (line 113) | async def run_remote_backward(
FILE: src/petals/client/remote_generation.py
class RemotePastKeyValues (line 20) | class RemotePastKeyValues(Cache):
method __init__ (line 23) | def __init__(self) -> None:
method __getitem__ (line 28) | def __getitem__(self, _index: int) -> List[torch.Tensor]:
method get_seq_length (line 31) | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
method get_max_length (line 34) | def get_max_length(self) -> Optional[int]:
method update_seen (line 37) | def update_seen(self, new_seen: int) -> None:
method reorder_cache (line 40) | def reorder_cache(self, beam_idx):
class _SkipTokensMixin (line 47) | class _SkipTokensMixin:
method prepare_inputs_for_generation (line 50) | def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, *...
class RemoteGenerationMixin (line 56) | class RemoteGenerationMixin(_SkipTokensMixin):
method active_session (line 72) | def active_session(self) -> Optional[InferenceSession]:
method use_session (line 76) | def use_session(self, session: Optional[InferenceSession]) -> ContextM...
method inference_session (line 80) | def inference_session(self, **kwargs) -> ContextManager[InferenceSessi...
method generate (line 84) | def generate(
method _fix_generate_kwargs (line 152) | def _fix_generate_kwargs(kwargs: dict):
method _reorder_cache (line 163) | def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: tor...
FILE: src/petals/client/remote_sequential.py
class RemoteSequential (line 20) | class RemoteSequential(nn.Module):
method __init__ (line 25) | def __init__(
method forward (line 52) | def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor...
method active_session (line 61) | def active_session(self) -> Optional[InferenceSession]:
method position (line 70) | def position(self) -> int:
method use_session (line 76) | def use_session(self, session: Optional[InferenceSession]) -> Inferenc...
method inference_session (line 86) | def inference_session(self, **kwargs) -> InferenceSession:
method __getitem__ (line 97) | def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
method __iter__ (line 103) | def __iter__(self):
method __len__ (line 107) | def __len__(self):
method extra_repr (line 110) | def extra_repr(self) -> str:
FILE: src/petals/client/routing/sequence_info.py
class RemoteSequenceInfo (line 14) | class RemoteSequenceInfo:
method make_empty (line 31) | def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenc...
method __getitem__ (line 37) | def __getitem__(self, ix: slice):
method __len__ (line 45) | def __len__(self):
method update_ (line 48) | def update_(self, new_block_infos: List[RemoteModuleInfo]):
method _sort_spans (line 58) | def _sort_spans(block_infos: List[RemoteModuleInfo]):
FILE: src/petals/client/routing/sequence_manager.py
class SequenceManagerConfig (line 34) | class SequenceManagerConfig(ClientConfig):
method __init__ (line 35) | def __init__(self, *args, **kwargs):
class SequenceManagerState (line 46) | class SequenceManagerState:
method __getitem__ (line 52) | def __getitem__(self, ix: Union[int, slice]) -> SequenceManagerState:
method __len__ (line 55) | def __len__(self) -> int:
class RemoteSequenceManager (line 59) | class RemoteSequenceManager:
method __init__ (line 71) | def __init__(
method _peer_ids_to_set (line 122) | def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]])...
method make_sequence (line 138) | def make_sequence(
method _make_sequence_with_min_latency (line 177) | def _make_sequence_with_min_latency(
method _build_inference_graph (line 217) | def _build_inference_graph(
method _rtt_to_delay (line 281) | def _rtt_to_delay(
method _has_cache_for (line 292) | def _has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional...
method _make_sequence_with_max_throughput (line 302) | def _make_sequence_with_max_throughput(self, start_index: int, end_ind...
method __getitem__ (line 326) | def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
method update (line 333) | def update(self, *, wait: bool):
method _update (line 340) | def _update(self):
method on_request_failure (line 388) | def on_request_failure(self, peer_id: Optional[PeerID]):
method on_request_success (line 403) | def on_request_success(self, peer_id: PeerID):
method __len__ (line 407) | def __len__(self):
method is_alive (line 411) | def is_alive(self):
method ready (line 415) | def ready(self) -> threading.Event:
method block_uids (line 419) | def block_uids(self):
method rpc_info (line 423) | def rpc_info(self):
method get_retry_delay (line 468) | def get_retry_delay(self, attempt_no: int) -> float:
method get_request_metadata (line 473) | def get_request_metadata(
method shutdown (line 489) | def shutdown(self):
class _SequenceManagerUpdateThread (line 493) | class _SequenceManagerUpdateThread(threading.Thread):
method __init__ (line 494) | def __init__(self, update_period: float, ref_update_manager: WeakMethod):
method run (line 502) | def run(self) -> None:
method shutdown (line 521) | def shutdown(self, timeout: Optional[float] = None):
method __del__ (line 527) | def __del__(self):
function maybe_log_traceback (line 531) | def maybe_log_traceback(exc: Exception):
class MissingBlocksError (line 536) | class MissingBlocksError(RuntimeError):
method __init__ (line 537) | def __init__(self, block_indices: Union[int, Sequence[int]]):
FILE: src/petals/client/routing/spending_policy.py
class SpendingPolicyBase (line 9) | class SpendingPolicyBase(ABC):
method get_points (line 11) | def get_points(self, protocol: str, *args, **kwargs) -> float:
class NoSpendingPolicy (line 15) | class NoSpendingPolicy(SpendingPolicyBase):
method get_points (line 16) | def get_points(self, protocol: str, *args, **kwargs) -> float:
FILE: src/petals/client/sequential_autograd.py
function sequential_forward (line 26) | async def sequential_forward(
function sequential_backward (line 113) | async def sequential_backward(
function _gather_forward (line 199) | async def _gather_forward(input_batches, prompt_batches, sequence_manager):
function _gather_backward (line 209) | async def _gather_backward(
class _RemoteSequentialAutogradFunction (line 223) | class _RemoteSequentialAutogradFunction(torch.autograd.Function):
method forward (line 230) | def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence...
method backward (line 253) | def backward(ctx, grad_outputs: torch.Tensor):
FILE: src/petals/data_structures.py
function parse_uid (line 14) | def parse_uid(uid: ModuleUID) -> Tuple[str, int]:
class ModelInfo (line 21) | class ModelInfo:
method to_dict (line 25) | def to_dict(self) -> dict:
method from_dict (line 29) | def from_dict(cls, source: dict):
class ServerState (line 33) | class ServerState(Enum):
class ServerInfo (line 43) | class ServerInfo:
method to_tuple (line 64) | def to_tuple(self) -> Tuple[int, float, dict]:
method from_tuple (line 70) | def from_tuple(cls, source: tuple):
class RemoteModuleInfo (line 78) | class RemoteModuleInfo:
class RemoteSpanInfo (line 86) | class RemoteSpanInfo:
method length (line 95) | def length(self) -> int:
method state (line 99) | def state(self) -> ServerState:
method throughput (line 103) | def throughput(self) -> float:
class InferenceMetadata (line 113) | class InferenceMetadata:
FILE: src/petals/models/bloom/block.py
class WrappedBloomBlock (line 15) | class WrappedBloomBlock(BloomBlock):
method forward (line 16) | def forward(
FILE: src/petals/models/bloom/config.py
class DistributedBloomConfig (line 16) | class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMH...
method from_pretrained (line 24) | def from_pretrained(
FILE: src/petals/models/bloom/model.py
class DistributedBloomModel (line 21) | class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
method __init__ (line 29) | def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hi...
method forward (line 40) | def forward(
class DistributedBloomForCausalLM (line 111) | class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationM...
method __init__ (line 119) | def __init__(self, config: DistributedBloomConfig):
method prepare_inputs_for_generation (line 127) | def prepare_inputs_for_generation(
method _temporary_reorder_cache (line 176) | def _temporary_reorder_cache(self, past_key_values, beam_idx):
method get_output_embeddings (line 179) | def get_output_embeddings(self):
class DistributedBloomForSequenceClassification (line 183) | class DistributedBloomForSequenceClassification(FromPretrainedMixin, Blo...
method __init__ (line 189) | def __init__(self, config: DistributedBloomConfig):
FILE: src/petals/models/falcon/block.py
function apply_rotary (line 30) | def apply_rotary(query, key, cos, sin):
class OptimizedFalconRotaryEmbedding (line 34) | class OptimizedFalconRotaryEmbedding(nn.Module):
method __init__ (line 35) | def __init__(self, head_dim: int, base=10000):
method _optimized_apply_rotary (line 46) | def _optimized_apply_rotary(self, query, key, cos, sin):
method cos_sin (line 67) | def cos_sin(self, seq_len: int, past_key_values_length: int, device="c...
method forward (line 91) | def forward(self, query, key, past_key_values_length=0):
function split_heads (line 100) | def split_heads(
class OptimizedFalconAttention (line 113) | class OptimizedFalconAttention(FalconAttention):
method __init__ (line 114) | def __init__(self, config: FalconConfig):
method _optimized_split_heads (line 155) | def _optimized_split_heads(self, fused_qkv):
method forward (line 174) | def forward(
class OptimizedFalconDecoderLayer (line 286) | class OptimizedFalconDecoderLayer(FalconDecoderLayer):
method __init__ (line 287) | def __init__(self, config: FalconConfig):
method _optimized_apply_ln (line 317) | def _optimized_apply_ln(self, hidden_states):
method forward (line 339) | def forward(
class WrappedFalconBlock (line 398) | class WrappedFalconBlock(OptimizedFalconDecoderLayer):
method forward (line 399) | def forward(
method _reorder_cache_from_bloom_to_falcon (line 440) | def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> K...
method _reorder_cache_from_falcon_to_bloom (line 452) | def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> K...
method _expand_states (line 464) | def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
method _collapse_states (line 473) | def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
FILE: src/petals/models/falcon/config.py
class DistributedFalconConfig (line 17) | class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, Client...
method num_key_value_groups (line 23) | def num_key_value_groups(self) -> int:
method from_pretrained (line 31) | def from_pretrained(
FILE: src/petals/models/falcon/model.py
class DistributedFalconModel (line 26) | class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, ...
method __init__ (line 34) | def __init__(self, config: DistributedFalconConfig, *, dht: Optional[h...
method forward (line 45) | def forward(
method word_embeddings_layernorm (line 116) | def word_embeddings_layernorm(self) -> nn.Module: # For compatibility...
class DistributedFalconForCausalLM (line 120) | class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedM...
method __init__ (line 126) | def __init__(self, config: DistributedFalconConfig):
method get_output_embeddings (line 134) | def get_output_embeddings(self):
class DistributedFalconForSequenceClassification (line 138) | class DistributedFalconForSequenceClassification(
method __init__ (line 146) | def __init__(self, config: DistributedFalconConfig):
FILE: src/petals/models/llama/block.py
function apply_rotary_pos_emb (line 26) | def apply_rotary_pos_emb(q, k, cos, sin):
class OptimizedLlamaAttention (line 32) | class OptimizedLlamaAttention(LlamaAttention):
method __init__ (line 33) | def __init__(self, *args, **kwargs):
method _optimized_apply_rotary (line 37) | def _optimized_apply_rotary(self, query_states, key_states, cos, sin):
method forward (line 44) | def forward(
class OptimizedLlamaDecoderLayer (line 130) | class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
method __init__ (line 131) | def __init__(self, config: LlamaConfig):
method _optimized_input_layernorm (line 143) | def _optimized_input_layernorm(self, hidden_states):
method _optimized_output_layernorm (line 150) | def _optimized_output_layernorm(self, hidden_states):
method forward (line 157) | def forward(
class WrappedLlamaBlock (line 225) | class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
method forward (line 226) | def forward(
method _reorder_cache_from_bloom_to_llama (line 280) | def _reorder_cache_from_bloom_to_llama(
method _reorder_cache_from_llama_to_bloom (line 291) | def _reorder_cache_from_llama_to_bloom(
FILE: src/petals/models/llama/config.py
class DistributedLlamaConfig (line 16) | class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMH...
method num_key_value_groups (line 22) | def num_key_value_groups(self):
method from_pretrained (line 26) | def from_pretrained(
FILE: src/petals/models/llama/model.py
class DistributedLlamaModel (line 20) | class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
method __init__ (line 28) | def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hi...
method forward (line 39) | def forward(
method word_embeddings (line 116) | def word_embeddings(self) -> nn.Embedding: # For compatibility with R...
method word_embeddings_layernorm (line 120) | def word_embeddings_layernorm(self) -> nn.Module: # For compatibility...
method h (line 124) | def h(self) -> RemoteSequential: # For compatibility with RemoteGener...
method ln_f (line 128) | def ln_f(self) -> nn.Module: # For compatibility with RemoteGeneratio...
class DistributedLlamaForCausalLM (line 132) | class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationM...
method __init__ (line 138) | def __init__(self, config: DistributedLlamaConfig):
method get_output_embeddings (line 148) | def get_output_embeddings(self):
method transformer (line 152) | def transformer(self) -> DistributedLlamaModel: # For compatibility w...
class DistributedLlamaForSequenceClassification (line 156) | class DistributedLlamaForSequenceClassification(FromPretrainedMixin, Lla...
method __init__ (line 162) | def __init__(self, config):
method transformer (line 173) | def transformer(self) -> DistributedLlamaModel: # For compatibility w...
FILE: src/petals/models/llama/speculative_model.py
class DistributedLlamaForSpeculativeGeneration (line 13) | class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausal...
method __init__ (line 14) | def __init__(self, config: DistributedLlamaConfig, small_model: LlamaF...
method _sample (line 18) | def _sample(
FILE: src/petals/models/mixtral/block.py
class WrappedMixtralBlock (line 13) | class WrappedMixtralBlock(MixtralDecoderLayer):
method __init__ (line 14) | def __init__(self, config: MixtralConfig, layer_idx: int):
method forward (line 21) | def forward(
method _reorder_cache_from_bloom (line 91) | def _reorder_cache_from_bloom(
method _reorder_cache_to_bloom (line 103) | def _reorder_cache_to_bloom(
FILE: src/petals/models/mixtral/config.py
class DistributedMixtralConfig (line 16) | class DistributedMixtralConfig(MixtralConfig, ClientConfig, PTuneConfig,...
method from_pretrained (line 24) | def from_pretrained(
FILE: src/petals/models/mixtral/model.py
class DistributedMixtralModel (line 26) | class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin,...
method __init__ (line 34) | def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[...
method forward (line 45) | def forward(
method word_embeddings (line 125) | def word_embeddings(self) -> nn.Embedding: # For compatibility with R...
method word_embeddings_layernorm (line 129) | def word_embeddings_layernorm(self) -> nn.Module: # For compatibility...
method h (line 133) | def h(self) -> RemoteSequential: # For compatibility with RemoteGener...
method ln_f (line 137) | def ln_f(self) -> nn.Module: # For compatibility with RemoteGeneratio...
class DistributedMixtralForCausalLM (line 141) | class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGeneratio...
method __init__ (line 147) | def __init__(self, config: DistributedMixtralConfig):
method get_output_embeddings (line 155) | def get_output_embeddings(self):
method transformer (line 159) | def transformer(self) -> DistributedMixtralModel: # For compatibility...
class DistributedMixtralForSequenceClassification (line 163) | class DistributedMixtralForSequenceClassification(FromPretrainedMixin, M...
method __init__ (line 169) | def __init__(self, config: DistributedMixtralConfig):
method transformer (line 180) | def transformer(self) -> DistributedMixtralModel: # For compatibility...
FILE: src/petals/server/backend.py
class TransformerBackend (line 24) | class TransformerBackend(ModuleBackend):
method __init__ (line 29) | def __init__(
method get_inference_cache_descriptors (line 88) | def get_inference_cache_descriptors(self, batch_size: int, max_length:...
method forward (line 101) | def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Te...
method backward (line 106) | def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.T...
method inference_step (line 112) | def inference_step(
method _estimate_max_chunk_length (line 146) | def _estimate_max_chunk_length(self, hidden_states: torch.Tensor, infe...
method _reorder_cache_inplace (line 154) | def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids...
method _select_layer_past (line 160) | def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], pr...
method _update_cache_inplace (line 171) | def _update_cache_inplace(
method get_pools (line 183) | def get_pools(self) -> Sequence[PrioritizedTaskPool]:
method get_info (line 186) | def get_info(self) -> Dict[str, Any]:
method shutdown (line 190) | def shutdown(self):
function merge_inference_pools_inplace (line 201) | def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerB...
class _MergedInferenceStep (line 216) | class _MergedInferenceStep:
method __init__ (line 217) | def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):
method __call__ (line 221) | def __call__(
FILE: src/petals/server/block_functions.py
function run_rpc_forward (line 32) | async def run_rpc_forward(
function run_rpc_backward (line 84) | async def run_rpc_backward(
function iterate_rpc_inference (line 144) | async def iterate_rpc_inference(
FILE: src/petals/server/block_selection.py
function compute_throughputs (line 12) | def compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_bl...
function _choose_best_start (line 23) | def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
function choose_best_blocks (line 28) | def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleI...
function _move_span (line 36) | def _move_span(span: RemoteSpanInfo, new_start: int):
function should_choose_other_blocks (line 40) | def should_choose_other_blocks(
FILE: src/petals/server/block_utils.py
function resolve_block_dtype (line 12) | def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torc...
function get_block_size (line 22) | def get_block_size(
function get_model_block (line 56) | def get_model_block(config, layer_idx: int = 0):
FILE: src/petals/server/from_pretrained.py
function load_pretrained_block (line 35) | def load_pretrained_block(
function _load_state_dict_from_repo (line 81) | def _load_state_dict_from_repo(
function _find_index_file (line 134) | def _find_index_file(
function _load_state_dict_from_repo_file (line 162) | def _load_state_dict_from_repo_file(
function _load_state_dict_from_local_file (line 216) | def _load_state_dict_from_local_file(path: str, *, block_prefix: Optiona...
FILE: src/petals/server/handler.py
class Event (line 48) | class Event(Enum):
class TransformerConnectionHandler (line 55) | class TransformerConnectionHandler(ConnectionHandler):
method __init__ (line 60) | def __init__(
method add_p2p_handlers (line 94) | async def add_p2p_handlers(self, *args, **kwargs) -> None:
method shutdown (line 100) | def shutdown(self):
method _gather_inputs (line 109) | async def _gather_inputs(
method rpc_inference (line 132) | async def rpc_inference(
method _managed_session (line 198) | def _managed_session(self, session_id: str):
method _put_into_session_queue (line 214) | def _put_into_session_queue(self, session_id: str, request: runtime_pb...
method _get_from_session_queue (line 223) | async def _get_from_session_queue(self, session_id: str) -> Optional[r...
method _listen_to_event_queue (line 227) | async def _listen_to_event_queue(self):
method _iterate_inference_steps (line 247) | async def _iterate_inference_steps(
method rpc_push (line 310) | async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: ...
method _push_outputs (line 320) | async def _push_outputs(
method rpc_forward (line 352) | async def rpc_forward(self, request: runtime_pb2.ExpertRequest, contex...
method rpc_forward_stream (line 380) | async def rpc_forward_stream(
method _serialize_outputs (line 411) | def _serialize_outputs(
method rpc_backward (line 434) | async def rpc_backward(self, request: runtime_pb2.ExpertRequest, conte...
method rpc_backward_stream (line 461) | async def rpc_backward_stream(
method _get_active_adapter (line 490) | def _get_active_adapter(self, metadata: dict) -> str:
method _serialize_grads (line 496) | def _serialize_grads(
method _check_uids (line 522) | def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]:
method _allocate_cache (line 533) | async def _allocate_cache(
method _log_request (line 549) | def _log_request(
method rpc_info (line 575) | async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PC...
FILE: src/petals/server/memory_cache.py
class MemoryCache (line 26) | class MemoryCache:
method __init__ (line 29) | def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: O...
method current_size_bytes (line 44) | def current_size_bytes(self) -> int:
method current_size_bytes (line 48) | def current_size_bytes(self, value: int):
method enqueued_size_bytes (line 52) | def enqueued_size_bytes(self) -> int:
method enqueued_size_bytes (line 56) | def enqueued_size_bytes(self, value: int):
method bytes_left (line 60) | def bytes_left(self) -> int:
method handle_counter (line 64) | def handle_counter(self) -> int:
method handle_counter (line 68) | def handle_counter(self, value: int):
method allocate_cache (line 72) | async def allocate_cache(
method get_allocation_size (line 110) | def get_allocation_size(*descriptors: TensorDescriptor) -> int:
method _schedule_alloc (line 118) | async def _schedule_alloc(
method _wait_for_free_memory (line 137) | async def _wait_for_free_memory(self, alloc_size: int, timeout: Option...
method _free (line 169) | def _free(self, alloc_size: int, alloc_task: asyncio.Task):
method _wait_until_available (line 179) | def _wait_until_available(self, allocated_size: int, timeout: Optional...
method use_cache (line 196) | def use_cache(self, *handles: Handle) -> Sequence[torch.Tensor]:
class AllocationFailed (line 224) | class AllocationFailed(Exception):
FILE: src/petals/server/reachability.py
function validate_reachability (line 22) | def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_dela...
function check_direct_reachability (line 55) | def check_direct_reachability(max_peers: int = 5, threshold: float = 0.5...
class ReachabilityProtocol (line 86) | class ReachabilityProtocol(ServicerBase):
method __init__ (line 89) | def __init__(self, *, probe: Optional[P2P] = None, wait_timeout: float...
method call_check (line 94) | async def call_check(self, remote_peer: PeerID, *, check_peer: PeerID)...
method rpc_check (line 106) | async def rpc_check(self, request: dht_pb2.PingRequest, context: P2PCo...
method serve (line 119) | async def serve(self, p2p: P2P):
method attach_to_dht (line 127) | def attach_to_dht(cls, dht: DHT, await_ready: bool = False, **kwargs) ...
method shutdown (line 162) | def shutdown(self):
FILE: src/petals/server/server.py
class Server (line 46) | class Server:
method __init__ (line 52) | def __init__(
method _choose_num_blocks (line 275) | def _choose_num_blocks(self) -> int:
method run (line 328) | def run(self):
method _clean_memory_and_fds (line 386) | def _clean_memory_and_fds(self):
method _choose_blocks (line 403) | def _choose_blocks(self) -> List[int]:
method _should_choose_other_blocks (line 413) | def _should_choose_other_blocks(self) -> bool:
method shutdown (line 420) | def shutdown(self, timeout: Optional[float] = 5):
class ModuleContainer (line 431) | class ModuleContainer(threading.Thread):
method create (line 436) | def create(
method __init__ (line 557) | def __init__(
method run (line 607) | def run(self):
method run_in_background (line 617) | def run_in_background(self, await_ready=True, timeout=None):
method ready (line 627) | def ready(self) -> mp.synchronize.Event:
method is_healthy (line 639) | def is_healthy(self) -> bool:
method shutdown (line 644) | def shutdown(self):
class ModuleAnnouncerThread (line 674) | class ModuleAnnouncerThread(threading.Thread):
method __init__ (line 677) | def __init__(
method run (line 717) | def run(self) -> None:
method announce (line 754) | def announce(self, state: ServerState) -> None:
method _ping_next_servers (line 760) | def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
class RuntimeWithDeduplicatedPools (line 770) | class RuntimeWithDeduplicatedPools(Runtime):
method __init__ (line 773) | def __init__(self, *args, **kwargs):
FILE: src/petals/server/task_pool.py
class Task (line 18) | class Task:
method uid (line 25) | def uid(self) -> int:
class PrioritizedTaskPool (line 29) | class PrioritizedTaskPool(threading.Thread):
method __init__ (line 49) | def __init__(
method run (line 78) | def run(self):
method terminate (line 88) | def terminate(self):
method shutdown (line 92) | def shutdown(self):
method submit_task (line 95) | def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
method get_task_size (line 113) | def get_task_size(self, task: Task) -> int:
method load_batch_to_runtime (line 119) | def load_batch_to_runtime(
method send_outputs_from_runtime (line 133) | def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torc...
method send_exception_from_runtime (line 144) | def send_exception_from_runtime(self, uid: int, exception: BaseExcepti...
method empty (line 155) | def empty(self):
method priority (line 159) | def priority(self) -> Tuple[float, float]:
method priority (line 164) | def priority(self, item: Tuple[float, float]):
function _move_to_device_if_tensor (line 170) | def _move_to_device_if_tensor(arg: Any, device: Union[torch.device, str]...
FILE: src/petals/server/task_prioritizer.py
class TaskPrioritizerBase (line 6) | class TaskPrioritizerBase(ABC):
method prioritize (line 10) | def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwar...
class DummyTaskPrioritizer (line 15) | class DummyTaskPrioritizer(TaskPrioritizerBase):
method prioritize (line 16) | def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwar...
FILE: src/petals/server/throughput.py
function get_server_throughput (line 37) | def get_server_throughput(
function measure_throughput_info (line 111) | def measure_throughput_info(
function measure_network_rps (line 147) | def measure_network_rps(
function _measure_bits_per_second (line 178) | def _measure_bits_per_second(pipe_send: mp.Pipe):
function measure_compute_rps (line 190) | def measure_compute_rps(
function synchronize (line 240) | def synchronize(device: torch.device):
function get_device_name (line 247) | def get_device_name(device: torch.device) -> str:
function get_dtype_name (line 251) | def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
FILE: src/petals/utils/asyncio.py
function shield_and_wait (line 4) | async def shield_and_wait(task):
FILE: src/petals/utils/auto_config.py
class _ModelClasses (line 14) | class _ModelClasses:
function register_model_classes (line 25) | def register_model_classes(*, config: Type[PretrainedConfig], **kwargs):
class _AutoDistributedBase (line 32) | class _AutoDistributedBase:
method from_pretrained (line 36) | def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, N...
class DefaultRevisionMixin (line 55) | class DefaultRevisionMixin:
method from_pretrained (line 73) | def from_pretrained(
class AutoDistributedConfig (line 82) | class AutoDistributedConfig(DefaultRevisionMixin, _AutoDistributedBase):
class AutoDistributedModel (line 86) | class AutoDistributedModel(DefaultRevisionMixin, _AutoDistributedBase):
class AutoDistributedModelForCausalLM (line 90) | class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistrib...
class AutoDistributedSpeculativeModel (line 94) | class AutoDistributedSpeculativeModel(DefaultRevisionMixin, _AutoDistrib...
class AutoDistributedModelForSequenceClassification (line 98) | class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin...
FILE: src/petals/utils/convert_block.py
class QuantType (line 19) | class QuantType(Enum):
function convert_block (line 25) | def convert_block(
function quantize_module (line 76) | def quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Mo...
function make_tensor_parallel (line 118) | def make_tensor_parallel(
function check_device_balance (line 138) | def check_device_balance(devices: Sequence[torch.device]):
FILE: src/petals/utils/cuda_graphs.py
function make_inference_graphed_callable (line 5) | def make_inference_graphed_callable(callable: callable, sample_args, num...
FILE: src/petals/utils/dht.py
function declare_active_modules (line 28) | def declare_active_modules(
function _declare_active_modules (line 57) | async def _declare_active_modules(
function get_remote_module_infos (line 74) | def get_remote_module_infos(
function _get_remote_module_infos (line 95) | async def _get_remote_module_infos(
function compute_spans (line 134) | def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: Se...
FILE: src/petals/utils/disk_cache.py
function _blocks_lock (line 19) | def _blocks_lock(cache_dir: Optional[str], mode: int):
function allow_cache_reads (line 31) | def allow_cache_reads(cache_dir: Optional[str]):
function allow_cache_writes (line 36) | def allow_cache_writes(cache_dir: Optional[str]):
function free_disk_space_for (line 41) | def free_disk_space_for(
FILE: src/petals/utils/hf_auth.py
function always_needs_auth (line 5) | def always_needs_auth(model_name: Union[str, os.PathLike, None]) -> bool:
FILE: src/petals/utils/logging.py
function initialize_logs (line 6) | def initialize_logs():
FILE: src/petals/utils/misc.py
function is_dummy (line 10) | def is_dummy(tensor: torch.Tensor) -> bool:
function get_size_in_bytes (line 17) | def get_size_in_bytes(dtype: torch.dtype) -> int:
function docstring_from (line 24) | def docstring_from(source):
FILE: src/petals/utils/packaging.py
function _mark_masked_tensor (line 9) | def _mark_masked_tensor(index: int) -> bytes:
function _is_masked_tensor (line 13) | def _is_masked_tensor(item: Any) -> bool:
function _get_tensor_index (line 17) | def _get_tensor_index(item: bytes) -> int:
function pack_args_kwargs (line 21) | def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
function unpack_args_kwargs (line 38) | def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure:...
FILE: src/petals/utils/peft.py
function check_peft_repository (line 31) | def check_peft_repository(repo_id: str) -> bool:
function load_specific_module (line 35) | def load_specific_module(block_idx: int, filepath: str, framework: str =...
function get_adapter_from_repo (line 51) | def get_adapter_from_repo(
function load_peft (line 72) | def load_peft(
class AdapterContextMixin (line 132) | class AdapterContextMixin:
method using_adapter (line 140) | def using_adapter(active_adapter: Optional[str]):
method active_adapter (line 148) | def active_adapter(self):
method active_adapter (line 154) | def active_adapter(self, value: Optional[str]):
method active_adapters (line 158) | def active_adapters(self):
method set_adapter (line 161) | def set_adapter(self, adapter_names) -> None:
class LoraLinear (line 173) | class LoraLinear(AdapterContextMixin, lora.Linear):
method __init__ (line 176) | def __init__(self, base_layer, adapter_name: str):
class LoraLinear8bitLt (line 184) | class LoraLinear8bitLt(LoraLinear, lora.Linear8bitLt):
class LoraLinear4bit (line 188) | class LoraLinear4bit(LoraLinear, lora.Linear4bit):
function create_lora_adapter (line 192) | def create_lora_adapter(block):
function add_adapter_to_block (line 212) | def add_adapter_to_block(block, block_index, adapter_name, peft_config, ...
function estimate_adapter_memory_per_block (line 263) | def estimate_adapter_memory_per_block(
FILE: src/petals/utils/ping.py
function ping (line 15) | async def ping(
function ping_parallel (line 35) | async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kw...
class PingAggregator (line 40) | class PingAggregator:
method __init__ (line 41) | def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expir...
method ping (line 48) | def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs) -> None:
method to_dict (line 60) | def to_dict(self) -> Dict[hivemind.PeerID, float]:
FILE: src/petals/utils/random.py
function sample_up_to (line 7) | def sample_up_to(population: Collection[T], k: int) -> T:
FILE: src/petals/utils/version.py
function validate_version (line 14) | def validate_version() -> None:
function get_compatible_model_repo (line 33) | def get_compatible_model_repo(model_name_or_path: Union[str, os.PathLike...
FILE: tests/conftest.py
function event_loop (line 15) | def event_loop():
function cleanup_children (line 31) | def cleanup_children():
FILE: tests/test_aux_functions.py
function test_bnb_not_imported_when_unnecessary (line 16) | def test_bnb_not_imported_when_unnecessary():
function test_compute_throughput (line 33) | def test_compute_throughput(inference: bool, n_tokens: int, tensor_paral...
function test_pack_inputs (line 53) | def test_pack_inputs():
FILE: tests/test_block_exact_match.py
function test_remote_block_exact_match (line 13) | def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
FILE: tests/test_cache.py
function _make_tensor_descriptor (line 16) | def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype]...
function test_cache_timeout (line 25) | async def test_cache_timeout():
function test_unlimited_timeout (line 76) | async def test_unlimited_timeout():
function test_cache_usage (line 93) | async def test_cache_usage():
FILE: tests/test_chained_calls.py
function test_forward_backward_exact_match (line 18) | def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1...
function test_chained_inference_exact_match (line 46) | def test_chained_inference_exact_match(atol_inference=1e-4):
FILE: tests/test_dtype.py
function test_block_dtype (line 12) | def test_block_dtype(torch_dtype):
FILE: tests/test_full_model.py
function tokenizer (line 14) | def tokenizer():
function model (line 20) | def model():
function ref_model (line 27) | def ref_model():
function test_full_model_exact_match (line 36) | def test_full_model_exact_match(tokenizer, model, ref_model, use_peft, p...
function make_generate_calls (line 80) | def make_generate_calls(model, inputs, *, max_new_tokens, multiple_calls...
function test_greedy_generation (line 97) | def test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
function test_sampling (line 117) | def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
function test_beam_search_generation (line 149) | def test_beam_search_generation(tokenizer, model, ref_model, max_new_tok...
function test_input_ids (line 159) | def test_input_ids(tokenizer, model, ref_model, max_new_tokens=4):
FILE: tests/test_optimized_layers.py
class UnoptimizedWrappedFalconBlock (line 18) | class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
method forward (line 19) | def forward(
method _reorder_cache_from_bloom_to_falcon (line 58) | def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> K...
method _reorder_cache_from_falcon_to_bloom (line 70) | def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> K...
method _expand_states (line 82) | def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
method _collapse_states (line 91) | def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
class UnoptimizedWrappedLlamaBlock (line 101) | class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
method forward (line 102) | def forward(
method _reorder_cache_from_bloom_to_llama (line 163) | def _reorder_cache_from_bloom_to_llama(
method _reorder_cache_from_llama_to_bloom (line 175) | def _reorder_cache_from_llama_to_bloom(
function test_optimized_block (line 189) | def test_optimized_block(device):
FILE: tests/test_peft.py
function clear_dir (line 14) | def clear_dir(path_to_dir):
function dir_empty (line 19) | def dir_empty(path_to_dir):
function test_check_peft (line 25) | def test_check_peft():
function test_load_noncached (line 31) | def test_load_noncached(tmpdir):
function test_load_cached (line 44) | def test_load_cached(tmpdir):
function test_load_layer_exists (line 52) | def test_load_layer_exists(tmpdir):
function test_load_layer_nonexists (line 59) | def test_load_layer_nonexists(tmpdir):
FILE: tests/test_priority_pool.py
function _submit_tasks (line 12) | def _submit_tasks(runtime_ready, pools, results_valid):
function test_priority_pools (line 33) | def test_priority_pools():
FILE: tests/test_remote_sequential.py
function test_remote_sequential (line 17) | def test_remote_sequential():
class DummyCustomSequenceManager (line 65) | class DummyCustomSequenceManager(RemoteSequenceManager):
method rpc_info (line 69) | def rpc_info(self):
method get_request_metadata (line 76) | def get_request_metadata(self, protocol: str, *args, **kwargs):
function test_remote_sequential_prompts (line 89) | def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
FILE: tests/test_sequence_manager.py
function test_sequence_manager_basics (line 18) | def test_sequence_manager_basics(mode: str):
class RemoteSequenceManagerWithChecks (line 46) | class RemoteSequenceManagerWithChecks(RemoteSequenceManager):
method __init__ (line 49) | def __init__(self, *args, _was_shut_down: threading.Event, **kwargs):
method shutdown (line 53) | def shutdown(self):
FILE: tests/test_server_stats.py
function test_server_info (line 13) | def test_server_info(block_from: int = 2, block_to: int = 5, max_length:...
FILE: tests/test_speculative_generation.py
function test_remote_block_with_cache_invalidation_exact_match (line 19) | def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1...
function noisy_model (line 46) | def noisy_model():
function model (line 58) | def model():
function tokenizer (line 65) | def tokenizer():
function test_remote_speculative_generation (line 75) | def test_remote_speculative_generation(tokenizer, model, noisy_model, at...
FILE: tests/test_tensor_parallel.py
function test_tp_block (line 16) | def test_tp_block(devices, custom_config):
Condensed preview — 100 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (506K chars).
[
{
"path": ".github/workflows/check-style.yaml",
"chars": 517,
"preview": "name: Check style\n\non:\n push:\n branches: [ main ]\n pull_request:\n\njobs:\n black:\n runs-on: ubuntu-latest\n ste"
},
{
"path": ".github/workflows/push-docker-image.yaml",
"chars": 1769,
"preview": "name: Push to Docker Hub\n\non:\n push:\n branches: [ main ]\n tags:\n - \"*.*.*\"\n pull_request:\n branches: [ m"
},
{
"path": ".github/workflows/run-tests.yaml",
"chars": 5566,
"preview": "name: Tests\n\non:\n push:\n branches: [ main ]\n pull_request:\n\njobs:\n run-tests:\n strategy:\n matrix:\n "
},
{
"path": ".gitignore",
"chars": 1802,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": "Dockerfile",
"chars": 926,
"preview": "FROM nvcr.io/nvidia/cuda:11.0.3-cudnn8-devel-ubuntu20.04\nLABEL maintainer=\"bigscience-workshop\"\nLABEL repository=\"petals"
},
{
"path": "LICENSE",
"chars": 1089,
"preview": "MIT License\n\nCopyright (c) 2022 Petals authors and collaborators\n\nPermission is hereby granted, free of charge, to any p"
},
{
"path": "README.md",
"chars": 9093,
"preview": "<p align=\"center\">\n <img src=\"https://i.imgur.com/7eR7Pan.png\" width=\"400\"><br>\n Run large language models at home"
},
{
"path": "benchmarks/benchmark_forward.py",
"chars": 2737,
"preview": "#!/usr/bin/env python3\n\nimport argparse\nimport multiprocessing as mp\nfrom time import perf_counter\n\nimport numpy as np\ni"
},
{
"path": "benchmarks/benchmark_inference.py",
"chars": 2699,
"preview": "#!/usr/bin/env python3\n\nimport argparse\nimport multiprocessing as mp\nfrom time import perf_counter\n\nimport numpy as np\ni"
},
{
"path": "benchmarks/benchmark_training.py",
"chars": 4323,
"preview": "#!/usr/bin/env python3\n\nimport argparse\nimport multiprocessing as mp\nfrom time import perf_counter\n\nimport numpy as np\ni"
},
{
"path": "examples/prompt-tuning-personachat.ipynb",
"chars": 10448,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"a07e0f5e\",\n \"metadata\": {},\n \"source\": [\n \"<div>\\n\",\n "
},
{
"path": "examples/prompt-tuning-sst2.ipynb",
"chars": 11084,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"id\": \"a07e0f5e\",\n \"metadata\": {\n \"id\": \"a07e0f5e\"\n },\n \"sourc"
},
{
"path": "pyproject.toml",
"chars": 333,
"preview": "[build-system]\nrequires = [\n \"setuptools>=42\",\n \"wheel\"\n]\nbuild-backend = \"setuptools.build_meta\"\n\n[tool.black]\nli"
},
{
"path": "setup.cfg",
"chars": 1986,
"preview": "[metadata]\nname = petals\nversion = attr: petals.__version__\nauthor = Petals Developers\nauthor_email = petals-devs@google"
},
{
"path": "src/petals/__init__.py",
"chars": 1062,
"preview": "import os\nimport platform\n\nos.environ.setdefault(\"BITSANDBYTES_NOWELCOME\", \"1\")\n\nif platform.system() == \"Darwin\":\n #"
},
{
"path": "src/petals/cli/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/petals/cli/run_dht.py",
"chars": 3871,
"preview": "\"\"\"\nA copy of run_dht.py from hivemind with the ReachabilityProtocol added:\nhttps://github.com/learning-at-home/hivemind"
},
{
"path": "src/petals/cli/run_prod_server.sh",
"chars": 223,
"preview": "#!/bin/bash\nset -x\n\nexport HIVEMIND_COLORS=true\nwhile true; do\n pkill -f p2p\n pkill -f run_server\n "
},
{
"path": "src/petals/cli/run_server.py",
"chars": 14520,
"preview": "import argparse\nimport logging\n\nimport configargparse\nimport torch\nfrom hivemind.proto.runtime_pb2 import CompressionTyp"
},
{
"path": "src/petals/client/__init__.py",
"chars": 262,
"preview": "from petals.client.config import ClientConfig\nfrom petals.client.inference_session import InferenceSession\nfrom petals.c"
},
{
"path": "src/petals/client/config.py",
"chars": 2007,
"preview": "import dataclasses\nimport os\nfrom typing import Optional, Sequence, Union\n\nfrom hivemind import PeerID\n\nfrom petals.cons"
},
{
"path": "src/petals/client/from_pretrained.py",
"chars": 3120,
"preview": "import contextlib\nimport json\nimport os\nimport re\nimport tempfile\nfrom contextvars import ContextVar\nfrom typing import "
},
{
"path": "src/petals/client/inference_session.py",
"chars": 17406,
"preview": "from __future__ import annotations\n\nimport asyncio\nimport itertools\nimport time\nimport uuid\nfrom typing import AsyncIter"
},
{
"path": "src/petals/client/lm_head.py",
"chars": 3443,
"preview": "import dataclasses\nimport platform\nfrom typing import Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.u"
},
{
"path": "src/petals/client/ptune.py",
"chars": 3485,
"preview": "import dataclasses\nfrom contextlib import contextmanager\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn"
},
{
"path": "src/petals/client/remote_forward_backward.py",
"chars": 7141,
"preview": "\"\"\"\nUtility functions that call RPC forward or backward on a single remote server\n\"\"\"\nimport asyncio\nfrom typing import "
},
{
"path": "src/petals/client/remote_generation.py",
"chars": 7276,
"preview": "import contextlib\nimport dataclasses\nfrom contextvars import ContextVar\nfrom typing import Any, ContextManager, Dict, Li"
},
{
"path": "src/petals/client/remote_sequential.py",
"chars": 4269,
"preview": "from __future__ import annotations\n\nfrom contextlib import contextmanager\nfrom contextvars import ContextVar\nfrom typing"
},
{
"path": "src/petals/client/routing/__init__.py",
"chars": 181,
"preview": "from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback\nfrom petals.client.routing"
},
{
"path": "src/petals/client/routing/sequence_info.py",
"chars": 3066,
"preview": "import dataclasses\nimport time\nfrom typing import Iterable, List, Optional, Tuple\n\nfrom hivemind import get_logger\n\nfrom"
},
{
"path": "src/petals/client/routing/sequence_manager.py",
"chars": 23156,
"preview": "from __future__ import annotations\n\nimport asyncio\nimport dataclasses\nimport itertools\nimport logging\nimport random\nimpo"
},
{
"path": "src/petals/client/routing/spending_policy.py",
"chars": 649,
"preview": "\"\"\"\nAn interface for exchanging internal \"BLOOM points\" for higher priority compute requests. NOT IMPLEMENTED.\nThe inten"
},
{
"path": "src/petals/client/sequential_autograd.py",
"chars": 12656,
"preview": "\"\"\"\nA PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner\n\"\""
},
{
"path": "src/petals/constants.py",
"chars": 904,
"preview": "import torch\n\nPUBLIC_INITIAL_PEERS = [\n # IPv4 DNS addresses\n \"/dns/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmU"
},
{
"path": "src/petals/data_structures.py",
"chars": 3170,
"preview": "import dataclasses\nfrom enum import Enum\nfrom typing import Any, Dict, Optional, Sequence, Tuple\n\nimport pydantic.v1 as "
},
{
"path": "src/petals/dht_utils.py",
"chars": 212,
"preview": "import warnings\n\nwarnings.warn(\n \"petals.dht_utils has been moved to petals.utils.dht. This alias will be removed in "
},
{
"path": "src/petals/models/__init__.py",
"chars": 139,
"preview": "from petals.models.bloom import *\nfrom petals.models.falcon import *\nfrom petals.models.llama import *\nfrom petals.model"
},
{
"path": "src/petals/models/bloom/__init__.py",
"chars": 556,
"preview": "from petals.models.bloom.block import WrappedBloomBlock\nfrom petals.models.bloom.config import DistributedBloomConfig\nfr"
},
{
"path": "src/petals/models/bloom/block.py",
"chars": 1939,
"preview": "\"\"\"\nBloom intermediate layer\nBased on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ff"
},
{
"path": "src/petals/models/bloom/config.py",
"chars": 1396,
"preview": "import os\nfrom typing import Optional, Union\n\nfrom hivemind import get_logger\nfrom transformers.models.bloom import Bloo"
},
{
"path": "src/petals/models/bloom/model.py",
"chars": 8455,
"preview": "from typing import Optional\n\nimport hivemind\nimport torch\nimport torch.nn as nn\nfrom hivemind.utils.logging import get_l"
},
{
"path": "src/petals/models/falcon/__init__.py",
"chars": 568,
"preview": "from petals.models.falcon.block import WrappedFalconBlock\nfrom petals.models.falcon.config import DistributedFalconConfi"
},
{
"path": "src/petals/models/falcon/block.py",
"chars": 20422,
"preview": "\"\"\"\nFalcon intermediate layer\nBased on https://github.com/huggingface/transformers/blob/main/src/transformers/models/fal"
},
{
"path": "src/petals/models/falcon/config.py",
"chars": 1980,
"preview": "import os\nfrom typing import Optional, Union\n\nfrom hivemind import get_logger\nfrom transformers.models.falcon import Fal"
},
{
"path": "src/petals/models/falcon/model.py",
"chars": 6441,
"preview": "from typing import Optional\n\nimport hivemind\nimport torch\nimport torch.nn as nn\nfrom hivemind.utils.logging import get_l"
},
{
"path": "src/petals/models/llama/__init__.py",
"chars": 715,
"preview": "from petals.models.llama.block import WrappedLlamaBlock\nfrom petals.models.llama.config import DistributedLlamaConfig\nfr"
},
{
"path": "src/petals/models/llama/block.py",
"chars": 12660,
"preview": "\"\"\"\nLLaMA intermediate layer\nBased on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llam"
},
{
"path": "src/petals/models/llama/config.py",
"chars": 2022,
"preview": "import os\nfrom typing import Optional, Union\n\nfrom hivemind import get_logger\nfrom transformers.models.llama import Llam"
},
{
"path": "src/petals/models/llama/model.py",
"chars": 7174,
"preview": "from typing import Optional\n\nimport hivemind\nimport torch\nimport torch.nn as nn\nfrom hivemind.utils.logging import get_l"
},
{
"path": "src/petals/models/llama/speculative_model.py",
"chars": 4994,
"preview": "from typing import Optional, Union\n\nimport torch\nfrom transformers.generation import GenerationConfig, LogitsProcessorLi"
},
{
"path": "src/petals/models/mixtral/__init__.py",
"chars": 580,
"preview": "from petals.models.mixtral.block import WrappedMixtralBlock\nfrom petals.models.mixtral.config import DistributedMixtralC"
},
{
"path": "src/petals/models/mixtral/block.py",
"chars": 4612,
"preview": "from typing import Optional, Tuple\n\nimport torch\nfrom transformers import MixtralConfig\nfrom transformers.cache_utils im"
},
{
"path": "src/petals/models/mixtral/config.py",
"chars": 1408,
"preview": "import os\nfrom typing import Optional, Union\n\nfrom hivemind import get_logger\nfrom transformers.models.mixtral import Mi"
},
{
"path": "src/petals/models/mixtral/model.py",
"chars": 7526,
"preview": "from typing import Optional\n\nimport torch\nimport torch.nn as nn\nfrom hivemind import DHT\nfrom hivemind.utils.logging imp"
},
{
"path": "src/petals/server/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/petals/server/backend.py",
"chars": 11945,
"preview": "from __future__ import annotations\n\nfrom collections import Counter\nfrom itertools import chain\nfrom typing import Any, "
},
{
"path": "src/petals/server/block_functions.py",
"chars": 11111,
"preview": "\"\"\"\nThis module implements server-side computations on served blocks: forward, backward and inference; used by handler\n\""
},
{
"path": "src/petals/server/block_selection.py",
"chars": 3936,
"preview": "from typing import Dict, List\n\nimport numpy as np\nfrom hivemind import PeerID, get_logger\n\nfrom petals.data_structures i"
},
{
"path": "src/petals/server/block_utils.py",
"chars": 2583,
"preview": "from typing import Optional, Union\n\nimport torch\nfrom accelerate import init_empty_weights\nfrom transformers import Pret"
},
{
"path": "src/petals/server/from_pretrained.py",
"chars": 8719,
"preview": "\"\"\"\nUtils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code."
},
{
"path": "src/petals/server/handler.py",
"chars": 28238,
"preview": "from __future__ import annotations\n\nimport asyncio\nimport contextlib\nimport multiprocessing as mp\nimport sys\nfrom enum i"
},
{
"path": "src/petals/server/memory_cache.py",
"chars": 10780,
"preview": "\"\"\"\nA pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime.\n"
},
{
"path": "src/petals/server/reachability.py",
"chars": 7707,
"preview": "import asyncio\nimport math\nimport threading\nimport time\nfrom concurrent.futures import Future\nfrom contextlib import asy"
},
{
"path": "src/petals/server/server.py",
"chars": 32735,
"preview": "from __future__ import annotations\n\nimport gc\nimport math\nimport multiprocessing as mp\nimport os\nimport random\nimport sy"
},
{
"path": "src/petals/server/task_pool.py",
"chars": 7838,
"preview": "import ctypes\nimport multiprocessing as mp\nimport threading\nimport time\nfrom concurrent.futures._base import PENDING\nfro"
},
{
"path": "src/petals/server/task_prioritizer.py",
"chars": 747,
"preview": "from abc import ABC, abstractmethod\n\nimport torch\n\n\nclass TaskPrioritizerBase(ABC):\n \"\"\"Abstract class for TaskPriori"
},
{
"path": "src/petals/server/throughput.py",
"chars": 9355,
"preview": "import fcntl\nimport json\nimport math\nimport multiprocessing as mp\nimport os\nimport time\nfrom collections import Counter\n"
},
{
"path": "src/petals/utils/__init__.py",
"chars": 296,
"preview": "from petals.utils.auto_config import (\n AutoDistributedConfig,\n AutoDistributedModel,\n AutoDistributedModelForC"
},
{
"path": "src/petals/utils/asyncio.py",
"chars": 525,
"preview": "import asyncio\n\n\nasync def shield_and_wait(task):\n \"\"\"\n Works like asyncio.shield(), but waits for the task to fin"
},
{
"path": "src/petals/utils/auto_config.py",
"chars": 4030,
"preview": "import os\nfrom dataclasses import dataclass\nfrom typing import Optional, Type, Union\n\nfrom hivemind import get_logger\nfr"
},
{
"path": "src/petals/utils/convert_block.py",
"chars": 6571,
"preview": "\"\"\"\nTools for converting transformer blocks, applying quantization and/or tensor parallelism\n\"\"\"\nimport re\nfrom enum imp"
},
{
"path": "src/petals/utils/cuda_graphs.py",
"chars": 2950,
"preview": "import torch\nfrom torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten\n\n\ndef make"
},
{
"path": "src/petals/utils/dht.py",
"chars": 5670,
"preview": "\"\"\"\nUtilities for declaring and retrieving active model layers using a shared DHT.\n\"\"\"\nfrom __future__ import annotation"
},
{
"path": "src/petals/utils/disk_cache.py",
"chars": 2964,
"preview": "import fcntl\nimport os\nimport shutil\nfrom contextlib import contextmanager\nfrom pathlib import Path\nfrom typing import O"
},
{
"path": "src/petals/utils/hf_auth.py",
"chars": 270,
"preview": "import os\nfrom typing import Union\n\n\ndef always_needs_auth(model_name: Union[str, os.PathLike, None]) -> bool:\n loadi"
},
{
"path": "src/petals/utils/logging.py",
"chars": 745,
"preview": "import os\n\nfrom hivemind.utils import logging as hm_logging\n\n\ndef initialize_logs():\n \"\"\"Initialize Petals logging tw"
},
{
"path": "src/petals/utils/misc.py",
"chars": 751,
"preview": "import torch\n\nDUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter parameters\n\nDUMMY_INT64 = tor"
},
{
"path": "src/petals/utils/packaging.py",
"chars": 1637,
"preview": "from typing import Any, Dict, List, Tuple\n\nimport torch\nfrom hivemind import nested_flatten, nested_pack\n\n# TODO: Move f"
},
{
"path": "src/petals/utils/peft.py",
"chars": 11705,
"preview": "import contextlib\nimport re\nimport time\nfrom typing import List, Optional, Sequence, Union\n\nimport bitsandbytes as bnb\ni"
},
{
"path": "src/petals/utils/ping.py",
"chars": 2460,
"preview": "import asyncio\nimport math\nimport threading\nimport time\nfrom functools import partial\nfrom typing import Dict, Sequence\n"
},
{
"path": "src/petals/utils/random.py",
"chars": 310,
"preview": "import random\nfrom typing import Collection, TypeVar\n\nT = TypeVar(\"T\")\n\n\ndef sample_up_to(population: Collection[T], k: "
},
{
"path": "src/petals/utils/version.py",
"chars": 1438,
"preview": "import os\nimport re\nfrom typing import Union\n\nimport requests\nfrom hivemind.utils.logging import TextStyle, get_logger\nf"
},
{
"path": "tests/conftest.py",
"chars": 1664,
"preview": "import asyncio\nimport gc\nfrom contextlib import suppress\n\nimport psutil\nimport pytest\nfrom hivemind.utils.crypto import "
},
{
"path": "tests/test_aux_functions.py",
"chars": 2938,
"preview": "import subprocess\nimport sys\n\nimport pytest\nimport torch\nfrom hivemind import nested_compare, nested_flatten\n\nfrom petal"
},
{
"path": "tests/test_block_exact_match.py",
"chars": 1873,
"preview": "import random\n\nimport pytest\nimport torch\n\nfrom petals import AutoDistributedConfig, RemoteSequential\nfrom petals.server"
},
{
"path": "tests/test_cache.py",
"chars": 9054,
"preview": "import asyncio\nimport multiprocessing as mp\nimport random\nimport time\nfrom typing import Optional\n\nimport pytest\nimport "
},
{
"path": "tests/test_chained_calls.py",
"chars": 3066,
"preview": "######\n# Warning:torch this test is a work in progress. It will be modified soon.\n# - if you want more stable tests, see"
},
{
"path": "tests/test_dtype.py",
"chars": 672,
"preview": "import pytest\nimport torch\n\nfrom petals.server.block_utils import resolve_block_dtype\nfrom petals.server.from_pretrained"
},
{
"path": "tests/test_full_model.py",
"chars": 7219,
"preview": "import peft\nimport pytest\nimport torch\nimport transformers\nfrom hivemind import get_logger\n\nfrom petals import AutoDistr"
},
{
"path": "tests/test_optimized_layers.py",
"chars": 9386,
"preview": "from typing import Optional, Tuple\n\nimport pytest\nimport torch\nfrom transformers.cache_utils import DynamicCache\nfrom tr"
},
{
"path": "tests/test_peft.py",
"chars": 1535,
"preview": "import os\nimport shutil\n\nimport pytest\nfrom huggingface_hub import snapshot_download\n\nfrom petals.utils.peft import chec"
},
{
"path": "tests/test_priority_pool.py",
"chars": 3176,
"preview": "import multiprocessing as mp\nimport platform\nimport time\n\nimport pytest\nimport torch\nfrom hivemind.moe.server.runtime im"
},
{
"path": "tests/test_remote_sequential.py",
"chars": 6052,
"preview": "import pytest\nimport torch\nimport torch.nn.functional as F\nfrom hivemind import DHT, BatchTensorDescriptor, get_logger\nf"
},
{
"path": "tests/test_sequence_manager.py",
"chars": 1922,
"preview": "import threading\nimport time\n\nimport pytest\nimport torch\nfrom hivemind import DHT, get_logger\n\nfrom petals import AutoDi"
},
{
"path": "tests/test_server_stats.py",
"chars": 1810,
"preview": "import time\n\nimport hivemind\nimport pytest\nimport torch\n\nfrom petals import AutoDistributedConfig, RemoteSequential\nfrom"
},
{
"path": "tests/test_speculative_generation.py",
"chars": 3208,
"preview": "import random\n\nimport pytest\nimport torch\nimport transformers\n\nfrom petals import (\n AutoDistributedConfig,\n AutoD"
},
{
"path": "tests/test_tensor_parallel.py",
"chars": 2035,
"preview": "import random\n\nimport pytest\nimport torch\nimport transformers\nfrom tensor_parallel import TensorParallel\nfrom tensor_par"
},
{
"path": "tests/test_utils.py",
"chars": 466,
"preview": "import os\n\nINITIAL_PEERS = os.environ.get(\"INITIAL_PEERS\")\nif not INITIAL_PEERS:\n raise RuntimeError(\"Must specify IN"
}
]
// ... and 2 more files (download for full content)
About this extraction
This page contains the full source code of the bigscience-workshop/petals GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 100 files (472.8 KB), approximately 112.5k tokens, and a symbol index with 511 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.