Repository: bigscience-workshop/petals Branch: main Commit: 22afba627a7e Files: 100 Total size: 472.8 KB Directory structure: gitextract_tsupof0f/ ├── .github/ │ └── workflows/ │ ├── check-style.yaml │ ├── push-docker-image.yaml │ └── run-tests.yaml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── benchmarks/ │ ├── benchmark_forward.py │ ├── benchmark_inference.py │ └── benchmark_training.py ├── examples/ │ ├── prompt-tuning-personachat.ipynb │ └── prompt-tuning-sst2.ipynb ├── pyproject.toml ├── setup.cfg ├── src/ │ └── petals/ │ ├── __init__.py │ ├── cli/ │ │ ├── __init__.py │ │ ├── run_dht.py │ │ ├── run_prod_server.sh │ │ └── run_server.py │ ├── client/ │ │ ├── __init__.py │ │ ├── config.py │ │ ├── from_pretrained.py │ │ ├── inference_session.py │ │ ├── lm_head.py │ │ ├── ptune.py │ │ ├── remote_forward_backward.py │ │ ├── remote_generation.py │ │ ├── remote_sequential.py │ │ ├── routing/ │ │ │ ├── __init__.py │ │ │ ├── sequence_info.py │ │ │ ├── sequence_manager.py │ │ │ └── spending_policy.py │ │ └── sequential_autograd.py │ ├── constants.py │ ├── data_structures.py │ ├── dht_utils.py │ ├── models/ │ │ ├── __init__.py │ │ ├── bloom/ │ │ │ ├── __init__.py │ │ │ ├── block.py │ │ │ ├── config.py │ │ │ └── model.py │ │ ├── falcon/ │ │ │ ├── __init__.py │ │ │ ├── block.py │ │ │ ├── config.py │ │ │ └── model.py │ │ ├── llama/ │ │ │ ├── __init__.py │ │ │ ├── block.py │ │ │ ├── config.py │ │ │ ├── model.py │ │ │ └── speculative_model.py │ │ └── mixtral/ │ │ ├── __init__.py │ │ ├── block.py │ │ ├── config.py │ │ └── model.py │ ├── server/ │ │ ├── __init__.py │ │ ├── backend.py │ │ ├── block_functions.py │ │ ├── block_selection.py │ │ ├── block_utils.py │ │ ├── from_pretrained.py │ │ ├── handler.py │ │ ├── memory_cache.py │ │ ├── reachability.py │ │ ├── server.py │ │ ├── task_pool.py │ │ ├── task_prioritizer.py │ │ └── throughput.py │ └── utils/ │ ├── __init__.py │ ├── asyncio.py │ ├── auto_config.py │ ├── convert_block.py │ ├── cuda_graphs.py │ ├── dht.py │ 
├── disk_cache.py │ ├── hf_auth.py │ ├── logging.py │ ├── misc.py │ ├── packaging.py │ ├── peft.py │ ├── ping.py │ ├── random.py │ └── version.py └── tests/ ├── bootstrap.id ├── conftest.py ├── server2.id ├── test_aux_functions.py ├── test_block_exact_match.py ├── test_cache.py ├── test_chained_calls.py ├── test_dtype.py ├── test_full_model.py ├── test_optimized_layers.py ├── test_peft.py ├── test_priority_pool.py ├── test_remote_sequential.py ├── test_sequence_manager.py ├── test_server_stats.py ├── test_speculative_generation.py ├── test_tensor_parallel.py └── test_utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/check-style.yaml ================================================ name: Check style on: push: branches: [ main ] pull_request: jobs: black: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: psf/black@stable with: options: "--check --diff" version: "22.3.0" isort: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v3 with: python-version: 3.8 - uses: isort/isort-action@master with: isortVersion: "5.10.1" ================================================ FILE: .github/workflows/push-docker-image.yaml ================================================ name: Push to Docker Hub on: push: branches: [ main ] tags: - "*.*.*" pull_request: branches: [ main ] jobs: build: runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v3 - name: Docker meta id: meta uses: crazy-max/ghaction-docker-meta@v2 with: # list of Docker images to use as base name for tags images: | learningathome/petals # generate Docker tags based on the following events/attributes tags: | type=ref,event=branch type=ref,event=pr type=semver,pattern={{version}} type=semver,pattern={{major}}.{{minor}} type=semver,pattern={{major}} - name: Set up Docker Buildx id: buildx uses: 
docker/setup-buildx-action@v1 - name: Login to Docker Hub if: github.event_name != 'pull_request' uses: docker/login-action@v1 with: username: ${{ secrets.DOCKER_HUB_USERNAME }} password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }} - name: Free disk space on Ubuntu runner uses: kfir4444/free-disk-space@main with: # found in: https://github.com/docker/build-push-action/issues/968 tool-cache: false android: true dotnet: true haskell: true large-packages: true swap-storage: true - name: Build and push id: docker_build uses: docker/build-push-action@v2 with: context: . push: ${{ github.event_name != 'pull_request' }} tags: ${{ steps.meta.outputs.tags }} - name: Image digest run: echo ${{ steps.docker_build.outputs.digest }} ================================================ FILE: .github/workflows/run-tests.yaml ================================================ name: Tests on: push: branches: [ main ] pull_request: jobs: run-tests: strategy: matrix: include: - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.8' } - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' } - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' } - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' } - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' } - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' } - { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.8' } - { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.11' } fail-fast: false runs-on: ${{ matrix.os }}-latest timeout-minutes: 20 steps: - name: Increase swap space if: ${{ matrix.os == 'ubuntu' }} uses: pierotofy/set-swap-space@master with: swap-size-gb: 10 - name: Checkout uses: actions/checkout@v3 - name: Set up Python uses: actions/setup-python@v3 with: python-version: ${{ matrix.python-version }} - name: Cache dependencies uses: actions/cache@v3 with: path: ~/.cache/pip key: Key-v1-${{ 
matrix.python-version }}-${{ hashFiles('setup.cfg') }} - name: Install dependencies run: | python -m pip install --upgrade pip pip install .[dev] - name: Test run: | set -x # Print executed commands export MODEL_NAME="${{ matrix.model }}" export REF_NAME="${{ matrix.model }}" export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}" # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) python -m petals.cli.run_dht \ --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log & BOOTSTRAP_PID=$! export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g # ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs until [ -s bootstrap.log ]; do sleep 5; done # wait for DHT init export RUN_SERVER="python -m petals.cli.run_server $MODEL_NAME \ --device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS" export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}" $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log & SERVER1_PID=$! # ^-- rebalancing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there sleep 10 # wait for the 1st server to choose blocks $RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log & SERVER2_PID=$! $RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \ --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log & SERVER3_PID=$! # ^-- chunking test $RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log & SERVER4_PID=$!
# ^-- tensor parallelism test (not compatible with adapters yet) sleep 5 # wait for the log files to appear tail -n 100 -f bootstrap.log server*.log & LOGGER_PID=$! sleep 30 # wait for servers to eval throughput, download layers, and rebalance kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived init # [Step 2] Run PyTest # Share disk cache between Petals servers, clients, and HF Transformers export TRANSFORMERS_CACHE=~/.cache/petals # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93 export no_proxy=* export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES # Limit default ClientConfig.max_retries to see tracebacks instead of retrying indefinitely export PETALS_MAX_RETRIES=10 pytest tests --durations=0 --durations-min=1.0 -v # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers) python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \ --seq_len 3 python benchmarks/benchmark_forward.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \ --seq_len 3 --batch_size 3 --n_steps 1 python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \ --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task cls python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \ --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm # [Step 4] Clean up kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ 
downloads/ eggs/ .eggs/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ .idea/ ================================================ FILE: Dockerfile ================================================ FROM nvcr.io/nvidia/cuda:11.0.3-cudnn8-devel-ubuntu20.04 LABEL maintainer="bigscience-workshop" LABEL repository="petals" WORKDIR /home # Set en_US.UTF-8 locale by default RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment # Install packages RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential \ wget \ git \ && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/* RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O install_miniconda.sh && \ bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh ENV PATH="/opt/conda/bin:${PATH}" RUN conda install python~=3.10.12 pip && \ pip install --no-cache-dir "torch>=1.12" && \ conda clean --all && rm -rf ~/.cache/pip VOLUME /cache ENV PETALS_CACHE=/cache COPY . 
petals/ RUN pip install --no-cache-dir -e petals WORKDIR /home/petals/ CMD bash ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2022 Petals authors and collaborators Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================


Run large language models at home, BitTorrent-style.
Fine-tuning and inference up to 10x faster than offloading


Generate text with distributed **Llama 3.1** (up to 405B), **Mixtral** (8x22B), **Falcon** (40B+) or **BLOOM** (176B) and fine‑tune them for your own tasks — right from your desktop computer or Google Colab: ```python from transformers import AutoTokenizer from petals import AutoDistributedModelForCausalLM # Choose any model available at https://health.petals.dev model_name = "meta-llama/Meta-Llama-3.1-405B-Instruct" # Connect to a distributed network hosting model layers tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoDistributedModelForCausalLM.from_pretrained(model_name) # Run the model as if it were on your computer inputs = tokenizer("A cat sat", return_tensors="pt")["input_ids"] outputs = model.generate(inputs, max_new_tokens=5) print(tokenizer.decode(outputs[0])) # A cat sat on a mat... ```

🚀  Try now in Colab

🦙 **Want to run Llama?** [Request access](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) to its weights, then run `huggingface-cli login` in the terminal before loading the model. Or just try it in our [chatbot app](https://chat.petals.dev). 🔏 **Privacy.** Your data will be processed with the help of other people in the public swarm. Learn more about privacy [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). For sensitive data, you can set up a [private swarm](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) among people you trust. 💬 **Any questions?** Ping us in [our Discord](https://discord.gg/KdThf2bWVU)! ## Connect your GPU and increase Petals capacity Petals is a community-run system — we rely on people sharing their GPUs. You can help by serving one of the [available models](https://health.petals.dev) or hosting a new model from 🤗 [Model Hub](https://huggingface.co/models)! As an example, here is how to host a part of [Llama 3.1 (405B) Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) on your GPU: 🦙 **Want to host Llama?** [Request access](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) to its weights, then run `huggingface-cli login` in the terminal before loading the model. 🐧 **Linux + Anaconda.** Run these commands for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD): ```bash conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia pip install git+https://github.com/bigscience-workshop/petals python -m petals.cli.run_server meta-llama/Meta-Llama-3.1-405B-Instruct ``` 🪟 **Windows + WSL.** Follow [this guide](https://github.com/bigscience-workshop/petals/wiki/Run-Petals-server-on-Windows) on our Wiki.
🐋 **Docker.** Run our [Docker](https://www.docker.com) image for NVIDIA GPUs (or follow [this](https://github.com/bigscience-workshop/petals/wiki/Running-on-AMD-GPU) for AMD): ```bash sudo docker run -p 31330:31330 --ipc host --gpus all --volume petals-cache:/cache --rm \ learningathome/petals:main \ python -m petals.cli.run_server --port 31330 meta-llama/Meta-Llama-3.1-405B-Instruct ``` 🍏 **macOS + Apple M1/M2 GPU.** Install [Homebrew](https://brew.sh/), then run these commands: ```bash brew install python python3 -m pip install git+https://github.com/bigscience-workshop/petals python3 -m petals.cli.run_server meta-llama/Meta-Llama-3.1-405B-Instruct ```

📚  Learn more (how to use multiple GPUs, start the server on boot, etc.)

🔒 **Security.** Hosting a server does not allow others to run custom code on your computer. Learn more [here](https://github.com/bigscience-workshop/petals/wiki/Security,-privacy,-and-AI-safety). 💬 **Any questions?** Ping us in [our Discord](https://discord.gg/X7DgtxgMhc)! 🏆 **Thank you!** Once you load and host 10+ blocks, we can show your name or link on the [swarm monitor](https://health.petals.dev) as a way to say thanks. You can specify them with `--public_name YOUR_NAME`. ## How does it work? - You load a small part of the model, then join a [network](https://health.petals.dev) of people serving the other parts. Single‑batch inference runs at up to **6 tokens/sec** for **Llama 2** (70B) and up to **4 tokens/sec** for **Falcon** (180B) — enough for [chatbots](https://chat.petals.dev) and interactive apps. - You can employ any fine-tuning and sampling methods, execute custom paths through the model, or see its hidden states. You get the comforts of an API with the flexibility of **PyTorch** and **🤗 Transformers**.

📜  Read paper            📚  See FAQ

## 📚 Tutorials, examples, and more Basic tutorials: - Getting started: [tutorial](https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing) - Prompt-tune Llama-65B for text semantic classification: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb) - Prompt-tune BLOOM to create a personified chatbot: [tutorial](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb) Useful tools: - [Chatbot web app](https://chat.petals.dev) (connects to Petals via an HTTP/WebSocket endpoint): [source code](https://github.com/petals-infra/chat.petals.dev) - [Monitor](https://health.petals.dev) for the public swarm: [source code](https://github.com/petals-infra/health.petals.dev) Advanced guides: - Launch a private swarm: [guide](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) - Run a custom model: [guide](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-Petals) ### Benchmarks Please see **Section 3.3** of our [paper](https://arxiv.org/pdf/2209.01188.pdf). ### 🛠️ Contributing Please see our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#contributing) on contributing. ### 📜 Citations Alexander Borzunov, Dmitry Baranchuk, Tim Dettmers, Max Ryabinin, Younes Belkada, Artem Chumachenko, Pavel Samygin, and Colin Raffel. [Petals: Collaborative Inference and Fine-tuning of Large Models.](https://arxiv.org/abs/2209.01188) _Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)._ 2023. 
```bibtex @inproceedings{borzunov2023petals, title = {Petals: Collaborative Inference and Fine-tuning of Large Models}, author = {Borzunov, Alexander and Baranchuk, Dmitry and Dettmers, Tim and Riabinin, Maksim and Belkada, Younes and Chumachenko, Artem and Samygin, Pavel and Raffel, Colin}, booktitle = {Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)}, pages = {558--568}, year = {2023}, url = {https://arxiv.org/abs/2209.01188} } ``` Alexander Borzunov, Max Ryabinin, Artem Chumachenko, Dmitry Baranchuk, Tim Dettmers, Younes Belkada, Pavel Samygin, and Colin Raffel. [Distributed inference and fine-tuning of large language models over the Internet.](https://arxiv.org/abs/2312.08361) _Advances in Neural Information Processing Systems_ 36 (2023). ```bibtex @inproceedings{borzunov2023distributed, title = {Distributed inference and fine-tuning of large language models over the {I}nternet}, author = {Borzunov, Alexander and Ryabinin, Max and Chumachenko, Artem and Baranchuk, Dmitry and Dettmers, Tim and Belkada, Younes and Samygin, Pavel and Raffel, Colin}, booktitle = {Advances in Neural Information Processing Systems}, volume = {36}, pages = {12312--12331}, year = {2023}, url = {https://arxiv.org/abs/2312.08361} } ``` --------------------------------------------------------------------------------

This project is a part of the BigScience research workshop.

================================================ FILE: benchmarks/benchmark_forward.py ================================================ #!/usr/bin/env python3 import argparse import multiprocessing as mp from time import perf_counter import numpy as np import torch from hivemind.utils.logging import get_logger from petals import AutoDistributedModel from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS logger = get_logger() def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--model", type=str, required=True, help="Model") parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers") parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype") parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes") parser.add_argument("--seq_len", type=int, default=128, help="Sequence length") parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps") parser.add_argument("--batch_size", type=int, required=True, help="Batch size") parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps") args = parser.parse_args() if args.n_processes == "n_gpus": args.n_processes = torch.cuda.device_count() else: args.n_processes = int(args.n_processes) pipe_recv, pipe_send = mp.Pipe(duplex=False) processes = [mp.Process(target=benchmark_forward, args=(i, args, pipe_send)) for i in range(args.n_processes)] for proc in processes: proc.start() for proc in processes: proc.join() speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)]) logger.info(f"Final result: {speed=:.2f}") @torch.inference_mode() def benchmark_forward(process_idx, args, result_pipe): model = AutoDistributedModel.from_pretrained( args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype], ) logger.info(f"Created model: {process_idx=} 
{model.device=}") torch.manual_seed(42) step_times = [] for step in range(args.warmup_steps + args.n_steps): start_time = perf_counter() input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len)) logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}") h = model(input_ids) # We don't use model.lm_head logger.info(f"{process_idx=} Fwd end") if step >= args.warmup_steps: step_times.append(perf_counter() - start_time) speed = input_ids.numel() / np.mean(step_times) logger.info(f"{process_idx=} {step=} {speed=:.2f}") result_pipe.send(speed) if __name__ == "__main__": main() ================================================ FILE: benchmarks/benchmark_inference.py ================================================ #!/usr/bin/env python3 import argparse import multiprocessing as mp from time import perf_counter import numpy as np import torch from hivemind.utils.logging import get_logger from transformers import AutoTokenizer from petals import AutoDistributedModelForCausalLM from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS logger = get_logger() def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--model", type=str, required=True, help="Model") parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers") parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype") parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes") parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length") parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps") args = parser.parse_args() if args.n_processes == "n_gpus": args.n_processes = torch.cuda.device_count() else: args.n_processes = int(args.n_processes) pipe_recv, pipe_send = mp.Pipe(duplex=False) processes = [mp.Process(target=benchmark_inference, args=(i, 
args, pipe_send)) for i in range(args.n_processes)] for proc in processes: proc.start() for proc in processes: proc.join() speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)]) logger.info(f"Final result: {speed=:.2f}") @torch.inference_mode() def benchmark_inference(process_idx, args, result_pipe): tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False) # Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway model = AutoDistributedModelForCausalLM.from_pretrained( args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype] ) logger.info(f"Created model: {process_idx=} {model.device=}") result = "" step_times = [] with model.transformer.h.inference_session(max_length=args.seq_len) as sess: for step in range(args.seq_len): start_time = perf_counter() outputs = model.generate(max_new_tokens=1, session=sess) result += tokenizer.decode(outputs[0]) if step >= args.warmup_steps: step_times.append(perf_counter() - start_time) speed = 1 / np.mean(step_times) logger.info(f"{process_idx=} {step=} {speed=:.2f}") result_pipe.send(speed) if __name__ == "__main__": main() ================================================ FILE: benchmarks/benchmark_training.py ================================================ #!/usr/bin/env python3 import argparse import multiprocessing as mp from time import perf_counter import numpy as np import torch from hivemind.utils.logging import get_logger from petals import AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS logger = get_logger() def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--model", type=str, required=True, help="Model") parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client") parser.add_argument("--task", type=str, 
default="cls", help="Training task type") parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers") parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype") parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes") parser.add_argument("--seq_len", type=int, default=128, help="Sequence length") parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens") parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps") parser.add_argument("--batch_size", type=int, required=True, help="Batch size") parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps") args = parser.parse_args() assert args.task in ["cls", "causal_lm"] if args.n_processes == "n_gpus": args.n_processes = torch.cuda.device_count() else: args.n_processes = int(args.n_processes) pipe_recv, pipe_send = mp.Pipe(duplex=False) processes = [mp.Process(target=benchmark_training, args=(i, args, pipe_send)) for i in range(args.n_processes)] for proc in processes: proc.start() for proc in processes: proc.join() fwd_speed, bwd_speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)], axis=0) logger.info(f"Final result: {fwd_speed=:.2f} {bwd_speed=:.2f}") def benchmark_training(process_idx, args, result_pipe): if args.task == "cls": model = AutoDistributedModelForSequenceClassification.from_pretrained( args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype], tuning_mode="deep_ptune", pre_seq_len=args.pre_seq_len, num_labels=2, ) elif args.task == "causal_lm": model = AutoDistributedModelForCausalLM.from_pretrained( args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype], tuning_mode="deep_ptune", pre_seq_len=args.pre_seq_len, ) model = model.to(args.device) opt = torch.optim.Adam(model.parameters()) logger.info(f"Created model: 
{process_idx=} {model.device=}") torch.manual_seed(42) fwd_times = [] bwd_times = [] for step in range(args.warmup_steps + args.n_steps): input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device) if args.task == "cls": labels = torch.randint(0, 2, size=[args.batch_size], device=args.device) else: labels = input_ids logger.info(f"{process_idx=} {step=} Forward") start_time = perf_counter() outputs = model(input_ids, labels=labels) if step >= args.warmup_steps: fwd_times.append(perf_counter() - start_time) logger.info(f"{process_idx=} {step=} Backward") start_time = perf_counter() outputs.loss.backward() if step >= args.warmup_steps: bwd_times.append(perf_counter() - start_time) logger.info(f"{process_idx=} {step=} Optimizer step") opt.step() opt.zero_grad() if step >= args.warmup_steps: fwd_speed = input_ids.numel() / np.mean(fwd_times) bwd_speed = input_ids.numel() / np.mean(bwd_times) logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}") result_pipe.send((fwd_speed, bwd_speed)) if __name__ == "__main__": main() ================================================ FILE: examples/prompt-tuning-personachat.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "a07e0f5e", "metadata": {}, "source": [ "
\n", " \n", "
\n", "\n", "# Distributed Bloom for Text Generation using Prompt Tuning\n", "\n", "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n", "\n", "We will adapt BLOOM for the task of creating a chatbot with a specific personality using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.\n", "\n", "To use this notebook in Colab:\n", "\n", "1. Follow this link: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)\n", "2. Go to **Runtime** -> **Change runtime type** and select the GPU accelerator." ] }, { "cell_type": "markdown", "id": "a3f8526f", "metadata": {}, "source": [ "First, we have to prepare all dependencies." 
] }, { "cell_type": "code", "execution_count": null, "id": "73bbc648", "metadata": {}, "outputs": [], "source": [ "%pip install -q petals datasets wandb scikit-learn" ] }, { "cell_type": "code", "execution_count": null, "id": "b4ab6ca7", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "import torch\n", "import transformers\n", "import wandb\n", "from datasets import load_dataset\n", "from tqdm import tqdm\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader\n", "from transformers import BloomTokenizerFast, get_scheduler\n", "\n", "from petals import DistributedBloomForCausalLM" ] }, { "cell_type": "markdown", "id": "1bf07b5d", "metadata": {}, "source": [ "Let's set some hyperparameters for training:" ] }, { "cell_type": "code", "execution_count": null, "id": "f04ba4d2", "metadata": {}, "outputs": [], "source": [ "# Choose a model you'd like to prompt-tune. We recommend starting with\n", "# the smaller 7.1B version of BLOOM (bigscience/bloom-7b1-petals) for faster prototyping.\n", "# Once your code is ready, you can switch to full-scale\n", "# 176B-parameter BLOOM (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).\n", "MODEL_NAME = \"bigscience/bloom-7b1-petals\"\n", "\n", "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n", "# The latter fine-tunes separate prefixes for each transformer block,\n", "# so prompt-tuning will take more time but yield better results.\n", "# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n", "TUNING_MODE = 'ptune'\n", "\n", "NUM_PREFIX_TOKENS = 16\n", "DEVICE = 'cuda'\n", "BATCH_SIZE = 8\n", "LR = 1e-2\n", "WEIGHT_DECAY = 0.0\n", "NUM_SAMPLES = 1000\n", "SEED = 42\n", "MODEL_MAX_LENGTH = 256" ] }, { "cell_type": "markdown", "id": "d38316bd", "metadata": {}, "source": [ "Prepare tokenizer and distributed model, connect it to servers." 
] }, { "cell_type": "code", "execution_count": null, "id": "03c6e53e", "metadata": {}, "outputs": [], "source": [ "tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n", "tokenizer.padding_side = 'right'\n", "tokenizer.model_max_length = MODEL_MAX_LENGTH\n", "model = DistributedBloomForCausalLM.from_pretrained(\n", " MODEL_NAME,\n", " pre_seq_len=NUM_PREFIX_TOKENS, \n", " tuning_mode=TUNING_MODE\n", ").to(DEVICE)" ] }, { "cell_type": "markdown", "id": "042e3786", "metadata": {}, "source": [ "Let's prepare the Personachat dataset. We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization." ] }, { "cell_type": "code", "execution_count": null, "id": "9c44d516", "metadata": {}, "outputs": [], "source": [ "dataset = load_dataset(\"bavard/personachat_truecased\")\n", "\n", "\n", "def chunking(examples):\n", " inputs = [\n", " \"\\n-----\\n\".join(history) + \"\\n-----\\n\" + candidate\n", " for history, candidates in zip(examples[\"history\"], examples[\"candidates\"])\n", " for candidate in candidates\n", " ]\n", " return {\"chunks\": inputs}\n", "\n", "\n", "def tokenize(examples):\n", " outputs = {\n", " \"input_ids\": tokenizer(examples[\"chunks\"], padding='max_length', truncation=True)[\"input_ids\"]\n", " }\n", " outputs[\"labels\"] = outputs[\"input_ids\"]\n", " return outputs\n", "\n", "\n", "tokenized_datasets = (\n", " dataset\n", " .map(chunking, batched=True, remove_columns=dataset[\"train\"].column_names)\n", " .map(tokenize, batched=True, remove_columns=[\"chunks\"])\n", ")\n", "\n", "\n", "tokenized_datasets.set_format(\"torch\")\n", "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n", "train_dataloader = DataLoader(\n", " train_dataset.select(list(range(NUM_SAMPLES))),\n", " shuffle=True,\n", " batch_size=BATCH_SIZE,\n", " drop_last=True,\n", ")" ] }, { "cell_type": "markdown", "id": "ef4323fd", "metadata": {}, "source": [ "Before setting up optimizers, check the model 
parameters that will be trained." ] }, { "cell_type": "code", "execution_count": null, "id": "9cc0ba34", "metadata": {}, "outputs": [], "source": [ "for n, p in model.named_parameters():\n", " if p.requires_grad:\n", " print(n, p.requires_grad, p.device)" ] }, { "cell_type": "markdown", "id": "59cffce7", "metadata": {}, "source": [ "The optimizer will only work on **prompts**: they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler." ] }, { "cell_type": "code", "execution_count": null, "id": "ef9bf344", "metadata": {}, "outputs": [], "source": [ "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", "\n", "lr_scheduler = get_scheduler(\n", " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n", ")" ] }, { "cell_type": "markdown", "id": "423c56d5", "metadata": {}, "source": [ "Let's initialize wandb for logging and start the training loop!" ] }, { "cell_type": "code", "execution_count": null, "id": "d9e46807", "metadata": {}, "outputs": [], "source": [ "wandb.init(\n", " project=\"bloom-personachat\",\n", " config={\n", " \"num_samples\": NUM_SAMPLES,\n", " \"batch_size\": BATCH_SIZE,\n", " \"learning_rate\": LR,\n", " \"weight_decay\": WEIGHT_DECAY,\n", " \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n", " \"model_name\": MODEL_NAME,\n", " \"seed\": SEED,\n", " }\n", ")\n", "\n", "for batch in tqdm(train_dataloader):\n", " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", "\n", " model.train()\n", " outputs = model(**batch)\n", " loss = outputs.loss\n", " loss.backward()\n", "\n", " optimizer.step()\n", " lr_scheduler.step()\n", " optimizer.zero_grad()\n", "\n", " wandb.log({\"Train Loss\": loss})" ] }, { "cell_type": "markdown", "id": "0f36cb80", "metadata": {}, "source": [ "Try to talk with the trained model!
Submit an empty input to stop the execution.\n", "\n", "\n", "__Note__: In this example, we use the whole dialogue as a prefix when generating each new reply. In the future, we will support a faster \"interactive\" dialogue mode, so generating a new reply will be able to reuse inference caches from the previous reply." ] }, { "cell_type": "code", "execution_count": null, "id": "720181b7", "metadata": {}, "outputs": [], "source": [ "TOP_K = 100\n", "TEMPERATURE = 0.6\n", "\n", "with model.inference_session(max_length=512) as sess:\n", " while True:\n", " user_phrase = input()\n", " if len(user_phrase) == 0:\n", " break\n", " inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids'].to(DEVICE)\n", " while True:\n", " outputs = model.generate(\n", " inputs,\n", " temperature=TEMPERATURE,\n", " do_sample=True,\n", " top_k=TOP_K,\n", " max_new_tokens=1,\n", " session=sess,\n", " )\n", " bloom_answer_token = tokenizer.decode(outputs[0, -1:])\n", " print(bloom_answer_token, end=\"\", flush=True)\n", " if bloom_answer_token == \"\\n\":\n", " break\n", " inputs = None" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.9 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.9" }, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: examples/prompt-tuning-sst2.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "a07e0f5e", "metadata": { "id": "a07e0f5e" }, "source": [ "
\n", " \n", "
\n", "\n", "# Distributed LLaMA for Text Classification using Prompt Tuning\n", "\n", "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [LLaMA](https://github.com/facebookresearch/llama) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the LLaMA blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n", "\n", "We will adapt LLaMA for the classification task using the [SST-2 dataset](https://nlp.stanford.edu/sentiment/). This dataset is a binary classification task, where the goal is to predict whether a sentence is positive or negative. The SST-2 dataset is a subset of the Stanford Sentiment Treebank, and it is available in the [Hugging Face Datasets](https://huggingface.co/datasets) library.\n", "\n", "To use this notebook in Colab:\n", "\n", "1. Follow this link: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-sst2.ipynb)\n", "2. Go to **Runtime** -> **Change runtime type** and select the GPU accelerator." ] }, { "cell_type": "markdown", "id": "a3f8526f", "metadata": { "id": "a3f8526f" }, "source": [ "First, we have to prepare all dependencies." 
] }, { "cell_type": "code", "execution_count": null, "id": "73bbc648", "metadata": { "id": "73bbc648" }, "outputs": [], "source": [ "%pip install -q datasets wandb scikit-learn\n", "%pip install -q git+https://github.com/bigscience-workshop/petals@main" ] }, { "cell_type": "code", "execution_count": null, "id": "b4ab6ca7", "metadata": { "id": "b4ab6ca7" }, "outputs": [], "source": [ "import os\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import transformers\n", "import wandb\n", "from datasets import load_dataset, load_metric\n", "from tqdm import tqdm\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader\n", "from transformers import LlamaTokenizer, get_scheduler, set_seed\n", "\n", "from petals import DistributedLlamaForSequenceClassification\n", "\n", "set_seed(0)" ] }, { "cell_type": "markdown", "id": "1bf07b5d", "metadata": { "id": "1bf07b5d" }, "source": [ "Let's set some hyperparameters for training:" ] }, { "cell_type": "code", "execution_count": null, "id": "f04ba4d2", "metadata": { "id": "f04ba4d2" }, "outputs": [], "source": [ "MODEL_NAME = \"enoch/llama-65b-hf\"\n", "\n", "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n", "# The latter fine-tunes separate prefixes for each transformer block,\n", "# so prompt-tuning will take more time but yield better results.\n", "# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n", "TUNING_MODE = 'ptune'\n", "\n", "NUM_PREFIX_TOKENS = 8\n", "DEVICE = 'cuda'\n", "BATCH_SIZE = 32\n", "LR = 1e-2\n", "WEIGHT_DECAY = 0.0\n", "NUM_EPOCHS = 3\n", "SEED = 42\n", "MODEL_MAX_LENGTH = 64" ] }, { "cell_type": "markdown", "id": "d38316bd", "metadata": { "id": "d38316bd" }, "source": [ "Here, we prepare tokenizer and distributed model and connect it to the public swarm." 
] }, { "cell_type": "code", "execution_count": null, "id": "03c6e53e", "metadata": { "id": "03c6e53e" }, "outputs": [], "source": [ "tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)\n", "tokenizer.padding_side = 'right'\n", "tokenizer.model_max_length = MODEL_MAX_LENGTH\n", "tokenizer.pad_token = tokenizer.unk_token\n", "model = DistributedLlamaForSequenceClassification.from_pretrained(\n", " MODEL_NAME,\n", " pre_seq_len=NUM_PREFIX_TOKENS,\n", " tuning_mode=TUNING_MODE\n", ").float().to(DEVICE)\n", "model.config.pad_token_id = tokenizer.pad_token_id" ] }, { "cell_type": "markdown", "id": "042e3786", "metadata": { "id": "042e3786" }, "source": [ "Let's prepare the SST-2 dataset. We need just one preprocessing function to tokenize the dataset." ] }, { "cell_type": "code", "execution_count": null, "id": "9c44d516", "metadata": { "id": "9c44d516" }, "outputs": [], "source": [ "task = 'sst2'\n", "\n", "dataset = load_dataset(\"glue\", task)\n", "\n", "def preprocess_function(examples):\n", " return tokenizer(examples[\"sentence\"], padding='max_length', truncation=True, return_token_type_ids=False)\n", "\n", "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n", "tokenized_datasets = tokenized_datasets.remove_columns([\"sentence\", \"idx\", \"attention_mask\"])\n", "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n", "tokenized_datasets.set_format(\"torch\")\n", "\n", "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n", "valid_dataset = tokenized_datasets[\"validation\"].shuffle(seed=SEED)\n", "\n", "train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, drop_last=True)\n", "valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)" ] }, { "cell_type": "markdown", "id": "2a3f3590", "metadata": { "id": "2a3f3590" }, "source": [ "To monitor training, we need the metric function. For SST-2, the target metric is accuracy. We will load it from the datasets library." 
] }, { "cell_type": "code", "execution_count": null, "id": "1e1812be", "metadata": { "id": "1e1812be" }, "outputs": [], "source": [ "metric = load_metric('glue', task)\n", "\n", "def eval_metrics(model, dataloader, device='cpu'):\n", " model.eval()\n", " for batch in dataloader:\n", " batch = {k: v.to(device) for k, v in batch.items()}\n", "\n", " with torch.no_grad():\n", " outputs = model(**batch)\n", "\n", " logits = outputs.logits\n", " predictions = torch.argmax(logits, dim=-1)\n", " metric.add_batch(predictions=predictions, references=batch[\"labels\"])\n", " model.train()\n", " return metric.compute()" ] }, { "cell_type": "markdown", "id": "ef4323fd", "metadata": { "id": "ef4323fd" }, "source": [ "Before setting up optimizers, let's check the model parameters that will be trained." ] }, { "cell_type": "code", "execution_count": null, "id": "9cc0ba34", "metadata": { "id": "9cc0ba34" }, "outputs": [], "source": [ "for n, p in model.named_parameters():\n", " if p.requires_grad:\n", " print(n, p.requires_grad, p.device)" ] }, { "cell_type": "markdown", "id": "59cffce7", "metadata": { "id": "59cffce7" }, "source": [ "The optimizer will only work on the **prompts and the classifier head**: they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler." ] }, { "cell_type": "code", "execution_count": null, "id": "ef9bf344", "metadata": { "id": "ef9bf344" }, "outputs": [], "source": [ "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", "\n", "lr_scheduler = get_scheduler(\n", " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS\n", ")" ] }, { "cell_type": "markdown", "id": "423c56d5", "metadata": { "id": "423c56d5" }, "source": [ "Let's initialize wandb for logging and start the training loop!"
] }, { "cell_type": "code", "execution_count": null, "id": "d9e46807", "metadata": { "id": "d9e46807" }, "outputs": [], "source": [ "wandb.init(\n", " project=\"bloom-sst-2\",\n", " config={\n", " \"num_epochs\": NUM_EPOCHS,\n", " \"batch_size\": BATCH_SIZE,\n", " \"learning_rate\": LR,\n", " \"weight_decay\": WEIGHT_DECAY,\n", " \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n", " \"model_name\": MODEL_NAME,\n", " \"seed\": SEED,\n", " }\n", ")\n", "\n", "scaler = torch.cuda.amp.GradScaler()\n", "\n", "for epoch in range(NUM_EPOCHS):\n", " model.train()\n", " for batch in tqdm(train_dataloader):\n", " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", "\n", " with torch.autocast(device_type=DEVICE, dtype=torch.float16):\n", " outputs = model(**batch)\n", " loss = outputs.loss\n", " scaler.scale(loss).backward()\n", "\n", " scaler.step(optimizer)\n", " scaler.update()\n", " lr_scheduler.step()\n", " optimizer.zero_grad()\n", "\n", " wandb.log({\"Train Loss\": loss.detach()})\n", "\n", " accuracy = eval_metrics(model, valid_dataloader, device=DEVICE)\n", " wandb.log({\"Valid Accuracy\": accuracy}, commit=False)" ] }, { "cell_type": "markdown", "id": "51770911", "metadata": { "id": "51770911" }, "source": [ "Our model has been trained! You can now upload it to the Hub for later use, try out different models [served in the public swarm](https://health.petals.dev/), or [join Petals with your own GPU](https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity)!" 
] }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [], "metadata": { "collapsed": false } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" }, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } }, "colab": { "provenance": [], "gpuType": "T4" }, "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: pyproject.toml ================================================ [build-system] requires = [ "setuptools>=42", "wheel" ] build-backend = "setuptools.build_meta" [tool.black] line-length = 120 required-version = "22.3.0" [tool.isort] profile = "black" line_length = 120 combine_as_imports = true combine_star = true known_local_folder = ["tests", "cli"] known_first_party = ["test_utils"] ================================================ FILE: setup.cfg ================================================ [metadata] name = petals version = attr: petals.__version__ author = Petals Developers author_email = petals-devs@googlegroups.com description = Easy way to efficiently run 100B+ language models without high-end GPUs long_description = file: README.md long_description_content_type = text/markdown url = https://github.com/bigscience-workshop/petals project_urls = Bug Tracker = https://github.com/bigscience-workshop/petals/issues classifiers = Development Status :: 4 - Beta Intended Audience :: Developers Intended Audience :: Science/Research License :: OSI Approved :: MIT License Programming Language :: Python :: 3 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 Programming Language :: Python :: 3.11 Topic :: 
Scientific/Engineering Topic :: Scientific/Engineering :: Mathematics Topic :: Scientific/Engineering :: Artificial Intelligence Topic :: Software Development Topic :: Software Development :: Libraries Topic :: Software Development :: Libraries :: Python Modules [options] package_dir = = src packages = find: python_requires = >=3.8 install_requires = torch>=1.12 bitsandbytes==0.41.1 accelerate>=0.27.2 huggingface-hub>=0.11.1,<1.0.0 tokenizers>=0.13.3 transformers==4.43.1 # if you change this, please also change version assert in petals/__init__.py speedtest-cli==2.1.3 hivemind @ git+https://github.com/learning-at-home/hivemind.git@213bff98a62accb91f254e2afdccbf1d69ebdea9 tensor_parallel==1.0.23 humanfriendly async-timeout>=4.0.2 cpufeature>=0.2.0; platform_machine == "x86_64" packaging>=20.9 sentencepiece>=0.1.99 peft==0.8.2 safetensors>=0.3.1 Dijkstar>=2.6.0 numpy<2 [options.extras_require] dev = pytest==6.2.5 pytest-forked pytest-asyncio==0.16.0 black==22.3.0 isort==5.10.1 psutil [options.packages.find] where = src ================================================ FILE: src/petals/__init__.py ================================================ import os import platform os.environ.setdefault("BITSANDBYTES_NOWELCOME", "1") if platform.system() == "Darwin": # Necessary for forks to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93 os.environ.setdefault("no_proxy", "*") os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES") import hivemind import transformers from packaging import version from petals.client import * from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs __version__ = "2.3.0.dev2" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): assert ( version.parse("4.43.1") <= version.parse(transformers.__version__) < version.parse("4.44.0") ), "Please install a proper transformers version: pip install transformers>=4.43.1,<4.44.0" def 
_override_bfloat16_mode_default(): if os.getenv("USE_LEGACY_BFLOAT16") is None: hivemind.compression.base.USE_LEGACY_BFLOAT16 = False _initialize_logs() _override_bfloat16_mode_default() ================================================ FILE: src/petals/cli/__init__.py ================================================ ================================================ FILE: src/petals/cli/run_dht.py ================================================ """ A copy of run_dht.py from hivemind with the ReachabilityProtocol added: https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_dht.py This script may be used for launching lightweight CPU machines serving as bootstrap nodes to a Petals swarm. This may be eventually merged to the hivemind upstream. """ import argparse import time from secrets import token_hex from hivemind.dht import DHT, DHTNode from hivemind.utils.logging import get_logger, use_hivemind_log_handler from hivemind.utils.networking import log_visible_maddrs from petals.server.reachability import ReachabilityProtocol use_hivemind_log_handler("in_root_logger") logger = get_logger(__name__) async def report_status(dht: DHT, node: DHTNode): logger.info( f"{len(node.protocol.routing_table.uid_to_peer_id) + 1} DHT nodes (including this one) " f"are in the local routing table " ) logger.debug(f"Routing table contents: {node.protocol.routing_table}") logger.info(f"Local storage contains {len(node.protocol.storage)} keys") logger.debug(f"Local storage contents: {node.protocol.storage}") # Contact peers and keep the routing table healthy (remove stale PeerIDs) await node.get(f"heartbeat_{token_hex(16)}", latest=True) def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--initial_peers", nargs="*", help="Multiaddrs of the peers that will welcome you into the existing DHT. 
" "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY", ) parser.add_argument( "--host_maddrs", nargs="*", default=["/ip4/0.0.0.0/tcp/0", "/ip6/::/tcp/0"], help="Multiaddrs to listen for external connections from other DHT instances. " "Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0", ) parser.add_argument( "--announce_maddrs", nargs="*", help="Visible multiaddrs the host announces for external connections from other DHT instances", ) parser.add_argument( "--use_ipfs", action="store_true", help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" ' "part of the multiaddrs for the initial_peers " "(no need to specify a particular IPv4/IPv6 host and port)", ) parser.add_argument( "--identity_path", help="Path to a private key file. If defined, makes the peer ID deterministic. " "If the file does not exist, writes a new private key to this file.", ) parser.add_argument( "--no_relay", action="store_false", dest="use_relay", help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)", ) parser.add_argument( "--use_auto_relay", action="store_true", help="Look for libp2p relays to become reachable if we are behind NAT/firewall", ) parser.add_argument( "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT" ) args = parser.parse_args() dht = DHT( start=True, initial_peers=args.initial_peers, host_maddrs=args.host_maddrs, announce_maddrs=args.announce_maddrs, use_ipfs=args.use_ipfs, identity_path=args.identity_path, use_relay=args.use_relay, use_auto_relay=args.use_auto_relay, ) log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs) reachability_protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True) while True: dht.run_coroutine(report_status, return_future=False) time.sleep(args.refresh_period) if __name__ == "__main__": main() 
================================================ FILE: src/petals/cli/run_prod_server.sh ================================================ #!/bin/bash set -x export HIVEMIND_COLORS=true while true; do pkill -f p2p pkill -f run_server python -m petals.cli.run_server bigscience/bloom-petals "$@" 2>&1 | tee log_`date '+%F_%H:%M:%S'` done ================================================ FILE: src/petals/cli/run_server.py ================================================ import argparse import logging import configargparse import torch from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils import limits from hivemind.utils.logging import get_logger from humanfriendly import parse_size from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS from petals.server.server import Server from petals.utils.convert_block import QuantType from petals.utils.version import validate_version logger = get_logger(__name__) def main(): # fmt:off parser = configargparse.ArgParser(default_config_files=["config.yml"], formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add('-c', '--config', required=False, is_config_file=True, help='config file path') group = parser.add_mutually_exclusive_group(required=True) group.add_argument('--converted_model_name_or_path', type=str, default=None, help="path or name of a pretrained model, converted with cli/convert_model.py") group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path") parser.add_argument("--public_name", type=str, default=None, help="Public name to be reported in the leaderboard") group = parser.add_mutually_exclusive_group(required=False) group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()") group.add_argument("--use_auth_token", action="store_true", dest="token", help="Read token saved by `huggingface-cli login") parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve") 
parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve") parser.add_argument('--dht_prefix', type=str, default=None, help="Announce all blocks with this DHT prefix") parser.add_argument('--port', type=int, required=False, help='Port this server listens to. ' 'This is a simplified way to set the --host_maddrs and --announce_maddrs options (see below) ' 'that sets the port across all interfaces (IPv4, IPv6) and protocols (TCP, etc.) ' 'to the same number. Default: a random free port is chosen for each interface and protocol') parser.add_argument('--public_ip', type=str, required=False, help='Your public IPv4 address, which is visible from the Internet. ' 'This is a simplified way to set the --announce_maddrs option (see below).' 'Default: server announces IPv4/IPv6 addresses of your network interfaces') parser.add_argument("--no_auto_relay", action="store_false", dest="use_auto_relay", help="Do not look for libp2p relays to become reachable if we are behind NAT/firewall") parser.add_argument('--host_maddrs', nargs='+', required=False, help='Multiaddrs to listen for external connections from other peers') parser.add_argument('--announce_maddrs', nargs='+', required=False, help='Visible multiaddrs the host announces for external connections from other peers') parser.add_argument('--daemon_startup_timeout', type=float, default=60, help='Timeout for the libp2p daemon connecting to initial peers') parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication') parser.add_argument('--num_handlers', type=int, default=8, required=False, help='server will use this many processes to handle incoming requests') parser.add_argument('--prefetch_batches', type=int, default=1, required=False, help='Pre-form this many subsequent batches while GPU is processing the current one') parser.add_argument('--sender_threads', type=int, default=1, required=False, help='Use this many threads to 
pass results/exceptions from Runtime to Pools') parser.add_argument('--inference_max_length', type=int, default=None, help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. ' 'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others') parser.add_argument('--min_batch_size', type=int, default=1, help='Minimum required batch size for all operations (in total tokens)') parser.add_argument('--max_batch_size', type=int, default=None, help='The total number of tokens in the same batch will not exceed this value. ' 'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others') parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024, help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks') parser.add_argument('--attn_cache_tokens', type=int, default=None, help='The number of past attention key/value pairs that will be stored between inference steps. ' 'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others') parser.add_argument('--cache_dir', type=str, default=None, help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.') parser.add_argument("--max_disk_space", type=str, default=None, help="Maximal disk space used for caches. Example: 50GB, 100GiB (GB != GiB here). " "Default: unlimited. " "For bigscience/bloom-petals, this default means that the server may use up to " "min(free_disk_space, 350GB) in the worst case, which happens when the server runs " "for a long time and caches all model blocks after a number of rebalancings. 
" "However, this worst case is unlikely, expect the server to consume " "the disk space equal to 2-4x of your GPU memory on average.") parser.add_argument('--device', type=str, default=None, required=False, help='all blocks will use this device in torch notation; default: cuda if available else cpu') parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto", help="Use this dtype to store block weights and do computations. " "By default, respect the dtypes in the pre-trained state dict.") parser.add_argument('--max_alloc_timeout', type=float, default=600, help="If the cache is full, the server will wait for memory to be freed up to this many seconds" " before rejecting the request") parser.add_argument('--revision', type=str, default=None, help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models" "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.") parser.add_argument('--throughput', type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value), default='auto', help='Expected server throughput (a float measured in RPS). ' 'If set to "auto" (default), the script evaluates network and compute throughput ' 'on the first run and uses these estimates for future runs. ' 'If set to "eval", the script re-evaluates the throughput and overrides the cache. 
' 'If set to "dry_run", the script re-evaluates the throughput and exits.') parser.add_argument('--update_period', type=float, required=False, default=120, help='Server will report blocks to DHT once in this many seconds') parser.add_argument('--expiration', type=float, required=False, default=None, help='DHT entries will expire after this many seconds') parser.add_argument('--request_timeout', type=float, required=False, default=3 * 60, help='Timeout (in seconds) for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request') parser.add_argument('--session_timeout', type=float, required=False, default=30 * 60, help='Timeout (in seconds) for the whole inference session') parser.add_argument('--step_timeout', type=float, required=False, default=5 * 60, help="Timeout (in seconds) for waiting the next step's inputs inside an inference session") group = parser.add_mutually_exclusive_group() group.add_argument('--initial_peers', type=str, nargs='+', required=False, default=PUBLIC_INITIAL_PEERS, help='Multiaddrs of one or more DHT peers from the target swarm. 
Default: connects to the public swarm') group.add_argument('--new_swarm', action='store_true', help='Start a new private swarm (i.e., do not connect to any initial peers)') parser.add_argument('--increase_file_limit', type=int, default=4096, help='On *nix, increase the max number of files a server can open ' 'before hitting "Too many open files" (set to zero to keep the system limit)') parser.add_argument('--stats_report_interval', type=int, required=False, help='Interval between two reports of batch processing performance statistics') parser.add_argument('--custom_module_path', type=str, required=False, help='Path of a file with custom nn.modules, wrapped into special decorator') parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P') parser.add_argument("--balance_quality", type=float, default=0.75, help="Rebalance the swarm if its throughput is worse than this share of the optimal " "throughput. Use 0.0 to disable rebalancing, values > 1.0 to force rebalancing " "on each check for debugging purposes.") parser.add_argument("--mean_balance_check_period", type=float, default=60, help="Check the swarm's balance every N seconds (and rebalance it if necessary)") parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType], help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or " "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. " "Default: 'int8' if GPU is available, 'none' otherwise") parser.add_argument("--tensor_parallel_devices", nargs='+', default=None, help= "Split each block between the specified GPUs such that each device holds a portion of every " "weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism") parser.add_argument("--skip_reachability_check", action='store_true', help="Skip checking this server's reachability via health.petals.dev " "when connecting to the public swarm. 
If you connect to a private swarm, " "the check is skipped by default. Use this option only if you know what you are doing") parser.add_argument("--adapters", nargs='*', default=(), help="List of pre-loaded LoRA adapters that can be used for inference or training") # fmt:on args = vars(parser.parse_args()) args.pop("config", None) args["converted_model_name_or_path"] = args.pop("model") or args["converted_model_name_or_path"] host_maddrs = args.pop("host_maddrs") port = args.pop("port") if port is not None: assert host_maddrs is None, "You can't use --port and --host_maddrs at the same time" else: port = 0 if host_maddrs is None: host_maddrs = [f"/ip4/0.0.0.0/tcp/{port}", f"/ip6/::/tcp/{port}"] announce_maddrs = args.pop("announce_maddrs") public_ip = args.pop("public_ip") if public_ip is not None: assert announce_maddrs is None, "You can't use --public_ip and --announce_maddrs at the same time" assert port != 0, "Please specify a fixed non-zero --port when you use --public_ip (e.g., --port 31337)" announce_maddrs = [f"/ip4/{public_ip}/tcp/{port}"] args["startup_timeout"] = args.pop("daemon_startup_timeout") file_limit = args.pop("increase_file_limit") if file_limit: limits.logger.setLevel(logging.WARNING) limits.increase_file_limit(file_limit, file_limit) compression_type = args.pop("compression").upper() compression = getattr(CompressionType, compression_type) max_disk_space = args.pop("max_disk_space") if max_disk_space is not None: max_disk_space = parse_size(max_disk_space) assert isinstance( max_disk_space, (int, type(None)) ), "Unrecognized value for --max_disk_space. 
Correct examples: 1.5GB or 1500MB or 1572864000 (bytes)" if args.pop("new_swarm"): args["initial_peers"] = [] quant_type = args.pop("quant_type") if quant_type is not None: args["quant_type"] = QuantType[quant_type.upper()] validate_version() if not torch.backends.openmp.is_available(): # Necessary to prevent the server from freezing after forks torch.set_num_threads(1) server = Server( **args, host_maddrs=host_maddrs, announce_maddrs=announce_maddrs, compression=compression, max_disk_space=max_disk_space, ) try: server.run() except KeyboardInterrupt: logger.info("Caught KeyboardInterrupt, shutting down") finally: server.shutdown() if __name__ == "__main__": main() ================================================ FILE: src/petals/client/__init__.py ================================================ from petals.client.config import ClientConfig from petals.client.inference_session import InferenceSession from petals.client.remote_sequential import RemoteSequential from petals.client.routing import NoSpendingPolicy, RemoteSequenceManager, SpendingPolicyBase ================================================ FILE: src/petals/client/config.py ================================================ import dataclasses import os from typing import Optional, Sequence, Union from hivemind import PeerID from petals.constants import PUBLIC_INITIAL_PEERS _max_retries = os.getenv("PETALS_MAX_RETRIES") DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None @dataclasses.dataclass class ClientConfig: initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS) # a list of initial peers for hivemind DHT dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name) daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers show_route: Union[str, bool] = "inference" # show chosen route through servers. 
one of [False, "inference", True] allowed_servers: Optional[Sequence[Union[PeerID, str]]] = None # if defined, send requests only to these servers blocked_servers: Optional[Sequence[Union[PeerID, str]]] = None # if defined, do not use these servers use_server_to_server: bool = True # Use direct server-to-server communication connect_timeout: float = 5 # timeout for opening a connection request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests update_period: float = 60 # refresh DHT information once in this many seconds max_retries: Optional[int] = DEFAULT_MAX_RETRIES # max number of retries before an exception (default: inf) min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1) max_backoff: float = 60 # limit maximal sleep time between retries to this value ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo) max_pinged: int = 3 # max servers to ping from each sequence side, per update ping_timeout: float = 2 # max time to wait for pings, per update ================================================ FILE: src/petals/client/from_pretrained.py ================================================ import contextlib import json import os import re import tempfile from contextvars import ContextVar from typing import List, Optional, Tuple, Union from hivemind.utils.logging import get_logger from transformers import BloomPreTrainedModel, modeling_utils from petals.utils.version import get_compatible_model_repo logger = get_logger(__name__) class FromPretrainedMixin: @classmethod def from_pretrained( cls, model_name_or_path: Union[str, os.PathLike, None], *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs, ): model_name_or_path = get_compatible_model_repo(model_name_or_path) if low_cpu_mem_usage is None: low_cpu_mem_usage = True with 
ignore_keys(cls._keys_to_ignore_on_load_unexpected): return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs) from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace( "low_cpu_mem_usage(`bool`, *optional*)", "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)", ).replace( "torch_dtype (`str` or `torch.dtype`, *optional*)", 'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)', ) _ignored_keys = ContextVar("ignored_keys", default=None) @contextlib.contextmanager def ignore_keys(patterns: List[str]): token = _ignored_keys.set(patterns) try: yield finally: _ignored_keys.reset(token) def patched_get_checkpoint_shard_files( pretrained_model_name_or_path, index_filename, *args, **kwargs ) -> Tuple[List[str], dict]: """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys.""" should_ignore_keys = _ignored_keys.get() is not None tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext() with tempdir_ctx as tempdir: if should_ignore_keys: with open(index_filename) as f: index = json.load(f) n_original_shards = len(set(index["weight_map"].values())) index["weight_map"] = { param_name: filename for param_name, filename in index["weight_map"].items() if all(re.search(pattern, param_name) is None for pattern in _ignored_keys.get()) } n_loaded_shards = len(set(index["weight_map"].values())) logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}") # Replace the original index with a patched JSON, where ignored keys are removed index_filename = os.path.join(tempdir, "pytorch_model.bin.index.json") with open(index_filename, "w") as f: json.dump(index, f) return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs) original_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files 
modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files ================================================ FILE: src/petals/client/inference_session.py ================================================ from __future__ import annotations import asyncio import itertools import time import uuid from typing import AsyncIterator, List, Optional, Tuple import torch from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.p2p import P2P from hivemind.proto import runtime_pb2 from hivemind.utils.tensor_descr import BatchTensorDescriptor from petals.client.config import ClientConfig from petals.client.routing import RemoteSequenceManager, maybe_log_traceback from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo from petals.server.handler import TransformerConnectionHandler from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy from petals.utils.packaging import pack_args_kwargs logger = get_logger(__name__) class _ServerInferenceSession: """ An interface to a single multi-step *inference* session for a a set of blocks on a specific server. :note: This class is *not* fault-tolerant out of the box. 
""" def __init__( self, config: ClientConfig, span: RemoteSpanInfo, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator, *, max_length: int, **metadata, ): self.config = config self.span, self.uid, self.rpc_info = span, uid, rpc_info self.num_blocks = uid.count(CHAIN_DELIMITER) + 1 self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter self.session_id = str(uuid.uuid4()) self.session_metadata = dict(max_length=max_length, **metadata) self.stepped = False self.closed = False self._position = 0 self.history = None # Used in case of server failures to regenerate attention caches on new servers self.next_session = None @classmethod async def create( cls, config: ClientConfig, p2p: P2P, span: RemoteSpanInfo, uid: ModuleUID, rpc_info: RPCInfo, **metadata, ) -> _ServerInferenceSession: """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker""" stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id) inputs_queue = asyncio.Queue() outputs_stream = await asyncio.wait_for( stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)), config.connect_timeout, ) return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata) @staticmethod async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator: while True: next_input_message = await asyncio.wait_for(queue.get(), input_timeout) yield next_input_message if not next_input_message.uid and not next_input_message.tensors: break # this message means "done sending" @property def position(self): return self._position @position.setter def position(self, start_from_position: int): assert start_from_position <= self._position self._position = start_from_position if self.history is not None and self.history.shape[1] >= start_from_position: self.history = self.history[:, 
:start_from_position, :] if start_from_position > 0 else None def step( self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str, ) -> torch.Tensor: """ Inference step: send a chunk of input tensors and receive a chunk of outputs :prompts: optional DEEP prompts, added to a prefix of each layer's outputs, if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size] """ if self.closed: raise Exception("Session is closed, cannot perform step") n_input_tokens = inputs.shape[1] if self.history is None: self.history = inputs elif self.history.shape[1] == self._position: self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1) assert self.history.shape[1] == self._position + n_input_tokens, ( f"Broken input cache: span={self.span} shape={self.history.shape} " f"position={self._position} n_input_tokens={n_input_tokens}" ) if not self.stepped: inputs = self.history # Pass full inputs including prefix else: inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further # serialize inputs and put them into the queue input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids) request_metadata = dict(session_id=self.session_id, step_id=step_id) if not self.stepped: request_metadata.update(self.session_metadata) if self._position is not None: request_metadata["start_from_position"] = self._position elif self.config.use_server_to_server: next_servers = self._collect_next_servers() if next_servers: request_metadata["next_servers"] = next_servers request_metadata["args_structure"] = args_structure # TODO: make possible to use different compression method for different tensors server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"] compression = server_side_inference_schema[0].compression inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors) # TODO: create more explicit way to check servers schema and 
client's structure assert len(input_tensors) >= len( server_side_inference_schema ), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step" outputs_serialized = RemoteExpertWorker.run_coroutine( self._step( runtime_pb2.ExpertRequest( uid=self.uid, tensors=[ serialize_torch_tensor(tensor.to(proto.dtype), proto.compression) for tensor, proto in zip(input_tensors, inference_schema) ], metadata=MSGPackSerializer.dumps(request_metadata), ) ) ) outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors)) assert ( outputs[0].shape == inputs.shape ), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}" self._position += n_input_tokens return outputs[0] def _collect_next_servers(self) -> List[Tuple[str, str, int, int]]: next_servers = [] session = self.next_session while session is not None and session.stepped: next_servers.append( (session.span.peer_id.to_base58(), session.session_id, session.span.start, session.span.end) ) session = session.next_session return next_servers async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse: """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker""" await self._inputs_queue.put(inputs_serialized) self.stepped = True return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout) def close(self): """Finish a given inference session, close the underlying connection""" if self._outputs_stream is None: return # already closed RemoteExpertWorker.run_coroutine(self._aclose_stream()) self._outputs_stream = self._inputs_queue = None self.closed = True async def _aclose_stream(self): """Close the inference session. 
This code is meant to be run inside RemoteExpertWorker""" if self._outputs_stream is None: return # already closed if self.stepped: await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session try: await anext(self._outputs_stream) except StopAsyncIteration: pass def __del__(self): self.close() def __enter__(self): assert not self.closed return self def __exit__(self, *exc_details): self.close() class InferenceSession: """ An interface to a multi-step *inference* session for a sequence of remote transformer blocks """ def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int): self._sequence_manager = sequence_manager self._closed = False self._server_sessions = [] self._position = 0 self._max_length = max_length self.output_ids = None self.past_key_values = None @property def num_blocks(self) -> int: return len(self._sequence_manager) @property def position(self) -> int: return self._position @position.setter def position(self, start_from_position: int) -> None: self._position = start_from_position for session in self._server_sessions: assert isinstance(session, _ServerInferenceSession) session.position = start_from_position def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]: server_sessions = [] try: for span in chosen_spans: span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end]) metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id) session = RemoteExpertWorker.run_coroutine( _ServerInferenceSession.create( self._sequence_manager.config, self._sequence_manager.state.p2p, span, span_uids, rpc_info=self._sequence_manager.rpc_info, max_length=self._max_length, **metadata, ) ) server_sessions.append(session) session.__enter__() return server_sessions except: self._exit_server_sessions(server_sessions) raise def _exit_server_sessions(self, server_sessions: 
List[_ServerInferenceSession]) -> None: for session in reversed(server_sessions): try: session.__exit__(None, None, None) except Exception: logger.debug("Caught exception while closing connection to server:", exc_info=True) def __enter__(self) -> "InferenceSession": assert not self._closed and not self._server_sessions return self def step( self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert not self._closed if torch.is_grad_enabled(): logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.") if prompts is None or is_dummy(prompts): prompts = DUMMY else: assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]" assert prompts.shape[0] == self.num_blocks assert prompts.shape[1] in (inputs.shape[0], 1) assert prompts.shape[2] <= inputs.shape[1] assert prompts.shape[3] == inputs.shape[2] if hypo_ids is None or is_dummy(hypo_ids): hypo_ids = DUMMY_INT64 else: assert len(hypo_ids) == len(inputs) assert hypo_ids.dtype == torch.int64 inputs_device = inputs.device inputs_dtype = inputs.dtype inputs = inputs.cpu() prompts = prompts.cpu() hypo_ids = hypo_ids.cpu() step_id = str(uuid.uuid4()) n_input_tokens = inputs.shape[1] if self._position + n_input_tokens > self._max_length: raise ValueError( f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}" ) server_idx = 0 block_idx = 0 while block_idx < self.num_blocks: for attempt_no in itertools.count(): logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}") server_session = None try: if not self._server_sessions or attempt_no >= 1: self._update_sequence(server_idx, block_idx, attempt_no) server_session = self._server_sessions[server_idx] assert server_session.position == self.position, f"{server_session.position} and {self.position}" inputs = server_session.step( 
inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id, ) server_idx += 1 block_idx = server_session.span.end self._sequence_manager.on_request_success(server_session.span.peer_id) break except Exception as e: self._sequence_manager.on_request_failure( server_session.span.peer_id if server_session is not None else None ) if attempt_no + 1 == self._sequence_manager.config.max_retries: raise delay = self._sequence_manager.get_retry_delay(attempt_no) logger.warning( f"Caught exception when running inference via {server_session.span if server_session is not None else None} " f"(retry in {delay:.0f} sec): {repr(e)}" ) maybe_log_traceback(e) time.sleep(delay) self._position += n_input_tokens outputs = inputs[:, -n_input_tokens:] outputs = outputs.to(device=inputs_device, dtype=inputs_dtype) return outputs def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int: # If there is a failed server session, this code closes it self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1]) n_prev_spans = len(self._server_sessions) update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks if attempt_no >= 1: logger.debug( f"Due to a server failure, remote attention caches " f"from block {block_idx} to {update_end} will be regenerated" ) updated_spans = self._sequence_manager.make_sequence( block_idx, update_end, mode="min_latency", cache_tokens_needed=self._max_length ) # make_sequence() could return a longer sequence updated_spans[-1].end = min(updated_spans[-1].end, update_end) updated_sessions = self._enter_server_sessions(updated_spans) logger.debug(f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers") # If there is a failed span, this code replaces it, otherwise it just adds new ones if server_idx < n_prev_spans: updated_sessions[0].history = self._server_sessions[server_idx].history 
self._server_sessions[server_idx : server_idx + 1] = updated_sessions # Update links to the next server session for direct server-to-server communication via rpc_push() for i in range(max(server_idx - 1, 0), min(server_idx + len(updated_spans), len(self._server_sessions) - 1)): self._server_sessions[i].next_session = self._server_sessions[i + 1] def close(self, *exc_details): """Finish a given inference session, close the underlying connection""" if not self._closed: self._exit_server_sessions(self._server_sessions) self._server_sessions.clear() self._closed = True def __exit__(self, *exc_details): self.close(*exc_details) def __del__(self): self.close() @property def last_token_id(self) -> Optional[torch.Tensor]: # Backward compatibility with Petals < 2.1.0 return self.output_ids[:, -1:] if self.output_ids is not None else None @last_token_id.setter def last_token_id(self, value: torch.Tensor): # Backward compatibility with Petals < 2.1.0 if self.output_ids is None: raise RuntimeError("Can't override `last_token_id` since the session has not stepped yet") self.output_ids[:, -1:] = value ================================================ FILE: src/petals/client/lm_head.py ================================================ import dataclasses import platform from typing import Union import torch import torch.nn.functional as F import torch.utils.checkpoint from hivemind import get_logger from torch import nn from transformers import PretrainedConfig logger = get_logger(__name__) @dataclasses.dataclass class LMHeadConfig: # This settings matter for running the client with dtype bfloat16 on CPU. # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations. 
use_chunked_forward: Union[str, bool] = "auto" chunked_forward_step: int = 16384 class LMHead(nn.Module): def __init__(self, config: PretrainedConfig): super().__init__() if not config.tie_word_embeddings: self.weight = nn.Parameter(torch.zeros(config.vocab_size, config.hidden_size)) self.weight.requires_grad = False else: self.weight = None # Will be set to get_input_embeddings().weight during loading the model self.bias = None self.in_features = config.hidden_size # Similar to nn.Linear attributes self.out_features = config.vocab_size self.use_chunked_forward = config.use_chunked_forward if self.use_chunked_forward == "auto": if platform.machine() == "x86_64": # Import of cpufeature may crash on non-x86_64 machines from cpufeature import CPUFeature # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward(). # Otherwise, it's ~8x slower. self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"]) else: self.use_chunked_forward = True self.chunked_forward_step = config.chunked_forward_step self._bf16_warning_shown = False def forward(self, hidden_states): if ( self.weight.dtype in [torch.float16, torch.bfloat16] and self.weight.device.type == "cpu" and self.use_chunked_forward ): lm_logits = self.chunked_forward(hidden_states) else: # Switch dtype in case word_embeddings are fp16/bf16 hidden_states = hidden_states.to(self.weight.dtype) lm_logits = F.linear(hidden_states, self.weight) return lm_logits def chunked_forward(self, hidden_states): """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU. chunked_forward_step: provides trade-off between efficiency and extra memory consumption. """ assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive" if not self._bf16_warning_shown: logger.warning( "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. 
" "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)" ) self._bf16_warning_shown = True hidden_states = hidden_states.float() output = torch.empty(*hidden_states.shape[:-1], self.out_features) for i in range(0, self.out_features, self.chunked_forward_step): chunk = self.weight[i : i + self.chunked_forward_step].float() output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk) return output ================================================ FILE: src/petals/client/ptune.py ================================================ import dataclasses from contextlib import contextmanager from typing import Optional import torch import torch.nn as nn from hivemind import get_logger from transformers import PretrainedConfig from petals.utils.misc import DUMMY logger = get_logger(__name__) @dataclasses.dataclass class PTuneConfig: pre_seq_len: int = 0 # a number of tokens for prompt tuning. tuning_mode: Optional[str] = None # fine-tuning regime, one of [None, "ptune", "deep_ptune"] class PTuneMixin: _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"] def init_prompts(self, config: PretrainedConfig) -> None: if config.tuning_mode and "ptune" in config.tuning_mode: assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0" self.pre_seq_len = config.pre_seq_len self.prefix_tokens = torch.arange(self.pre_seq_len).long() with force_non_empty_weights(): # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32) if config.tuning_mode == "deep_ptune": self.intermediate_prompt_embeddings = nn.Embedding( self.pre_seq_len, config.num_hidden_layers * config.hidden_size, # ^-- TODO: should be num_hidden_layers - 1 dtype=torch.float32, ) elif config.tuning_mode: raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now") def 
get_prompt(self, batch_size): prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1) prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device) prompts = self.prompt_embeddings(prefix_tokens) if self.config.tuning_mode == "deep_ptune": intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens) intermediate_prompts = intermediate_prompts.view( batch_size, self.pre_seq_len, self.config.num_hidden_layers, self.config.hidden_size # TODO: should be num_hidden_layers - 1 ) intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3]) else: intermediate_prompts = DUMMY dtype = self.word_embeddings.weight.dtype return prompts.to(dtype), intermediate_prompts.to(dtype) _original_register_parameter = nn.Module.register_parameter @contextmanager def force_non_empty_weights(): """ This context manager allows to bypass the accelerate.init_empty_weights() context manager (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True. The transformers library should replace all meta tensors by empty tensors by itself but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`). 
[1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515 """ possibly_patched_register_parameter = nn.Module.register_parameter nn.Module.register_parameter = _original_register_parameter try: yield finally: nn.Module.register_parameter = possibly_patched_register_parameter ================================================ FILE: src/petals/client/remote_forward_backward.py ================================================ """ Utility functions that call RPC forward or backward on a single remote server """ import asyncio from typing import Iterable, List, Optional, Sequence, Tuple import torch from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor from hivemind.p2p import StubBase from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE from hivemind.proto import runtime_pb2 from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter from hivemind.utils.streaming import split_for_streaming from hivemind.utils.tensor_descr import BatchTensorDescriptor from petals.client.config import ClientConfig from petals.data_structures import ModuleUID, RPCInfo async def _forward_unary( uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs ) -> List[torch.Tensor]: """Send all serialized tensors in a single unary rpc_forward request (extra kwargs go into the ExpertRequest) and deserialize the response tensors.""" outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward( runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs), timeout=config.request_timeout, ) return [deserialize_torch_tensor(t) for t in outputs.tensors] async def _backward_unary( uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs ) -> List[torch.Tensor]: """Send all serialized tensors in a single unary rpc_backward request (extra kwargs go into the ExpertRequest) and deserialize the returned grad_inputs tensors.""" grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward( runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), 
**kwargs), timeout=config.request_timeout, ) return [deserialize_torch_tensor(t) for t in grad_inputs.tensors] async def _forward_stream( uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs ) -> List[torch.Tensor]: """Split each serialized tensor with split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE), send the parts through rpc_forward_stream and deserialize the streamed response. Opening the stream is bounded by config.connect_timeout; the response stream is wrapped with aiter_with_timeout(config.request_timeout).""" parts = ( runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) ) outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), config.connect_timeout) outputs = aiter_with_timeout(outputs, config.request_timeout) return await deserialize_tensor_stream(msg.tensors async for msg in outputs) async def _backward_stream( uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs ) -> List[torch.Tensor]: """Same as _forward_stream, but sends the parts through rpc_backward_stream and returns the deserialized grad_inputs tensors.""" parts = ( runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) ) grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), config.connect_timeout) grad_inputs = aiter_with_timeout(grad_inputs, config.request_timeout) return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs) async def run_remote_forward( uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, config: ClientConfig, metadata: Optional[bytes] = None, **kwargs, ) -> Tuple[torch.Tensor, ...]: """ Serializes input tensors and calls "rpc_forward" on a remote server. Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198 but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here. 
""" # Note: *inputs are flattened input tensors that follow the expert's info['input_schema'] # detach to avoid pickling the computation graph assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}" kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]} # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors forward_inputs = tuple(nested_flatten((inputs, kwargs))) args_schema, kwargs_schema = rpc_info["forward_schema"] compression = args_schema[0].compression forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs) inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs) # TODO: create more explicit way to check servers schema and client's structure assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step" # Asynchronous serialization loop = asyncio.get_running_loop() serialized_tensors = await asyncio.gather( *( loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression) for tensor, proto in zip(inputs, forward_schema) ) ) # call RPC on remote server; payloads above the size threshold below use the streaming RPC, smaller ones a unary call size = sum(t.element_size() * t.nelement() for t in inputs) forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs) return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"]) async def run_remote_backward( uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs_and_grad_outputs: torch.Tensor, config: ClientConfig, metadata: Optional[bytes] = None, **kwargs, ) -> Sequence[torch.Tensor]: """ Serializes grad outputs and calls "rpc_backward" on a remote server. 
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221 but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here. :returns: the deserialized grad_inputs tensors as sent back by the server (unlike run_remote_forward, they are not re-packed with nested_pack) """ args_schema, kwargs_schema = rpc_info["forward_schema"] outputs_schema = rpc_info["outputs_schema"] compression = args_schema[0].compression backward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in inputs_and_grad_outputs) # TODO: create more explicit way to check servers schema and client's structure assert ( len(inputs_and_grad_outputs) >= len(args_schema) + len(outputs_schema) + 1 ), "Inputs, grad_outputs and prompt tensors are necessary for a backward step" # Asynchronous serialization loop = asyncio.get_running_loop() serialized_tensors = await asyncio.gather( *( loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression) for tensor, proto in zip(inputs_and_grad_outputs, backward_schema) ) ) # choose the streaming or unary RPC based on the raw tensor size, as in run_remote_forward size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs) backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs) return deserialized_grad_inputs ================================================ FILE: src/petals/client/remote_generation.py ================================================ import contextlib import dataclasses from contextvars import ContextVar from typing import Any, ContextManager, Dict, List, Optional, Tuple import torch import transformers from hivemind.utils.logging import get_logger from torch import Tensor from transformers.cache_utils import Cache, DynamicCache from transformers.generation.utils import ModelOutput from petals.client.inference_session import InferenceSession from 
petals.client.remote_sequential import RemoteSequential from petals.utils.misc import DUMMY, docstring_from logger = get_logger(__name__) class RemotePastKeyValues(Cache): """only keeps the number of seen tokens. pretends to be a legit cache""" def __init__(self) -> None: super().__init__() self._seen_tokens = 0 self.hypo_ids: Optional[torch.LongTensor] = None def __getitem__(self, _index: int) -> List[torch.Tensor]: return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation() def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return self._seen_tokens def get_max_length(self) -> Optional[int]: return None def update_seen(self, new_seen: int) -> None: self._seen_tokens += new_seen def reorder_cache(self, beam_idx): raise NotImplementedError("Beam search reordering is not implemented yet") _skipped_tokens = ContextVar("skipped_tokens", default=0) class _SkipTokensMixin: # This override is used in RemoteGenerationMixin by has to be defined in a class not named as "GenerationMixin" # due to how transformers.PreTrainedModel.can_generate() works def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict: input_ids = input_ids[:, _skipped_tokens.get() :] _skipped_tokens.set(0) return super().prepare_inputs_for_generation(input_ids, **kwargs) class RemoteGenerationMixin(_SkipTokensMixin): """ This class is an upgrade to `transformers.GenerationMixin` that: - Designed to be compatible with most `transformers.GenerationMixin` strategies and options - Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and you don't have to rerun the prefix through all the servers to generate each new token - Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or accept prompts from a user in a chat bot (multiple calls like 
`.generate(new_prompts, ...)`). - If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`. Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that. """ @docstring_from(RemoteSequential.active_session) @property def active_session(self) -> Optional[InferenceSession]: return self.transformer.h.active_session @docstring_from(RemoteSequential.use_session) def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]: return self.transformer.h.use_session(session) @docstring_from(RemoteSequential.inference_session) def inference_session(self, **kwargs) -> ContextManager[InferenceSession]: return self.transformer.h.inference_session(**kwargs) @docstring_from(transformers.GenerationMixin.generate.__doc__) def generate( self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs ): self._fix_generate_kwargs(kwargs) if inputs is None: inputs = kwargs.pop("input_ids", None) if session is not None: # If a session specified explicitly, use it context_manager = self.use_session(session) elif self.active_session is not None: # If there's an active session, don't do anything context_manager = contextlib.nullcontext(self.active_session) else: # If there's no active session, create a new one max_length = kwargs.get("max_length") max_new_tokens = kwargs.get("max_new_tokens") assert (max_length is None) != ( max_new_tokens is None ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches" session_max_length = self.transformer.config.pre_seq_len if max_length is not None: session_max_length += max_length else: session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens context_manager = self.inference_session(max_length=session_max_length) with context_manager as session: # Prepend the tokens from the previous .generate() call 
n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0 if n_prev_tokens > 0: if kwargs.get("num_beams", 1) > 1: logger.warning( "Beam search will not work properly in the resumed petals.InferenceSession " "since intermediate beam entries are lost" ) if inputs is not None: inputs = torch.cat([session.output_ids, inputs], dim=1) else: inputs = session.output_ids # Don't actually run all previous tokens through the transformer, # but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty) _skipped_tokens.set(max(0, n_prev_tokens - 1)) if self._supports_cache_class and "past_key_values" not in kwargs: past_key_values = RemotePastKeyValues() past_key_values.update_seen(session.position) kwargs["past_key_values"] = past_key_values result = super().generate(inputs, *args, **kwargs) sequences = result.sequences if isinstance(result, ModelOutput) else result # Save tokens from this .generate() call session.output_ids = sequences # Crop the last tokens from the previous call sequences = sequences[:, n_prev_tokens:].clone() if isinstance(result, ModelOutput): result.sequences = sequences else: result = sequences return result @staticmethod def _fix_generate_kwargs(kwargs: dict): # Suppress inappropriate "Both max_new_tokens and max_length" HF warning if "max_length" in kwargs and kwargs["max_length"] is None: del kwargs["max_length"] # Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0 do_sample = kwargs.get("do_sample") if isinstance(do_sample, int): kwargs["do_sample"] = bool(do_sample) @staticmethod def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues: return dataclasses.replace(past_key_values, hypo_ids=beam_idx) ================================================ FILE: src/petals/client/remote_sequential.py ================================================ from __future__ import annotations from contextlib import contextmanager from contextvars 
import ContextVar from typing import Optional, Union import torch from hivemind import DHT, get_logger from torch import nn from petals.client.config import ClientConfig from petals.client.inference_session import InferenceSession from petals.client.routing import RemoteSequenceManager from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction from petals.data_structures import UID_DELIMITER logger = get_logger(__name__) class RemoteSequential(nn.Module): """ A sequence of transformer blocks hosted by the swarm. """ def __init__( self, config: ClientConfig, *, sequence_manager: Optional[RemoteSequenceManager] = None, dht: Optional[DHT] = None, start_block: Optional[int] = None, end_block: Optional[int] = None, **kwargs, ): """ :param config: client config; provides dht_prefix and num_hidden_layers for the default block range :param sequence_manager: an existing RemoteSequenceManager to reuse; if given, dht, start_block and end_block must be None :param dht, start_block, end_block: used to create a new RemoteSequenceManager for blocks [start_block, end_block) (defaults: 0 and config.num_hidden_layers); extra kwargs are forwarded to it """ super().__init__() self.config = config assert sequence_manager is None or ( dht is None and start_block is None and end_block is None ), "`dht`, `start_block`, and `end_block` have no effect when you provide a custom `sequence_manager`" if sequence_manager is None: if start_block is None: start_block = 0 if end_block is None: end_block = self.config.num_hidden_layers block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block)) sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs) self.sequence_manager = sequence_manager self._active_session = ContextVar("active_session", default=None) def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: """ Run inputs through the remote blocks. Without an active session, this uses _RemoteSequentialAutogradFunction (all extra kwargs must be None); inside an active session, it delegates to session.step(inputs, prompts, **kwargs). """ assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]" if self.active_session is None: assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}" return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager) else: return self.active_session.step(inputs, prompts, **kwargs) @property def active_session(self) -> Optional[InferenceSession]: """ If called inside `with 
model.inference_session(...):` or `with model.use_session(...):`, returns an active InferenceSession. Otherwise, returns None. """ return self._active_session.get() @property def position(self) -> int: """Returns the prefix length (in tokens) in the active inference session or zero if no session is active.""" return self.active_session.position if self.active_session is not None else 0 @contextmanager def use_session(self, session: Optional[InferenceSession]) -> InferenceSession: """Inside this context, forward() will use an _existing_ InferenceSession provided as the argument.""" token = self._active_session.set(session) try: yield session finally: self._active_session.reset(token) @contextmanager def inference_session(self, **kwargs) -> InferenceSession: """ Inside this context, forward() will use a _new_ InferenceSession created with given parameters. :param max_length: Maximal expected length of inference results. Servers use this parameter to calculate the size of attention caches allocated to this client. 
""" with InferenceSession(self.sequence_manager, **kwargs) as session, self.use_session(session): yield session def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential: """Return a RemoteSequential over the block(s) selected by ix, backed by self.sequence_manager[ix].""" return RemoteSequential( self.config, sequence_manager=self.sequence_manager[ix], ) def __iter__(self): """Yield a single-block RemoteSequential for each block, in order.""" for block_index in range(len(self)): yield self[block_index] def __len__(self): """Return the number of blocks in this sequence.""" return len(self.sequence_manager) def extra_repr(self) -> str: return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}" ================================================ FILE: src/petals/client/routing/__init__.py ================================================ from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase ================================================ FILE: src/petals/client/routing/sequence_info.py ================================================ import dataclasses import time from typing import Iterable, List, Optional, Tuple from hivemind import get_logger from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState from petals.utils.dht import compute_spans logger = get_logger(__name__) @dataclasses.dataclass class RemoteSequenceInfo: """ A dataclass that stores general information about which servers hold any given layer; - updated by RemoteSequenceManager in a background thread - accessed by routing strategies in .on_update :note: this class should *not* be modified by RoutingStrategy.on_update to avoid interference between strategies; Any metadata specific to one routing strategy should be stored inside that strategy. Any information that is used by most routing strategies should be moved from said strategies to this class. """ block_uids: Tuple[ModuleUID, ...] block_infos: Tuple[RemoteModuleInfo, ...] 
# note: the contents of RemoteModuleInfo can and will be updated spans_by_priority: List[RemoteSpanInfo] spans_containing_block: Tuple[List[RemoteSpanInfo], ...] last_updated_time: Optional[float] @classmethod def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenceInfo": """Create an info object for block_uids with no known servers and no spans.""" block_uids = tuple(block_uids) empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids) empty_spans = tuple([] for _ in range(len(block_uids))) return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=None) def __getitem__(self, ix: slice): """Return a new RemoteSequenceInfo for a slice of blocks (int indices are rejected), with spans recomputed for that slice.""" assert isinstance(ix, slice) block_uids, block_infos = self.block_uids[ix], self.block_infos[ix] spans_by_priority, spans_containing_block = self._sort_spans(block_infos) return RemoteSequenceInfo( block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time ) def __len__(self): """Return the number of blocks.""" return len(self.block_uids) def update_(self, new_block_infos: List[RemoteModuleInfo]): """Update in place: copy the servers of each new RemoteModuleInfo into the stored infos (uids must match), recompute the spans and refresh last_updated_time.""" assert len(new_block_infos) == len(self.block_uids) for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)): assert uid == info.uid, f"The DHT entry for {uid} actually points to {info.uid}" self.block_infos[block_index].servers = info.servers self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos) self.last_updated_time = time.perf_counter() @staticmethod def _sort_spans(block_infos: List[RemoteModuleInfo]) -> Tuple[List[RemoteSpanInfo], Tuple[List[RemoteSpanInfo], ...]]: """Compute spans of servers whose state is at least ServerState.ONLINE, sorted from longest to shortest, and for each block the list of spans containing it (in the same order).""" spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values()) spans_by_priority.sort(key=lambda span: span.length, reverse=True) spans_containing_block = tuple([] for _ in range(len(block_infos))) for span in spans_by_priority: for block_index in range(span.start, span.end): spans_containing_block[block_index].append(span) return spans_by_priority, spans_containing_block ================================================ FILE: src/petals/client/routing/sequence_manager.py ================================================ from 
__future__ import annotations import asyncio import dataclasses import itertools import logging import random import threading import time import warnings from typing import Any, Dict, List, Optional, Sequence, Set, Union from weakref import WeakMethod import dijkstar import numpy as np from hivemind import DHT, P2P, MSGPackSerializer, PeerID from hivemind.dht.node import Blacklist from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.proto import runtime_pb2 from hivemind.utils.logging import get_logger from petals.client.config import ClientConfig from petals.client.routing.sequence_info import RemoteSequenceInfo from petals.client.routing.spending_policy import NoSpendingPolicy from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState from petals.server.handler import TransformerConnectionHandler from petals.utils.dht import get_remote_module_infos from petals.utils.ping import PingAggregator from petals.utils.random import sample_up_to logger = get_logger(__name__) class SequenceManagerConfig(ClientConfig): """Deprecated alias of ClientConfig; emits a DeprecationWarning when constructed.""" def __init__(self, *args, **kwargs): warnings.warn( "petals.client.routing.SequenceManagerConfig has been moved to petals.ClientConfig. " "This alias will be removed in Petals 2.2.0+", DeprecationWarning, stacklevel=2, ) super().__init__(*args, **kwargs) @dataclasses.dataclass class SequenceManagerState: """ State of a RemoteSequenceManager. __getitem__ returns a shallow copy with a sliced sequence_info, so the p2p and banned_peers objects are shared between a manager and the sub-managers created from it. """ p2p: Optional[P2P] = None sequence_info: Optional[RemoteSequenceInfo] = None rpc_info: Optional[dict] = None banned_peers: Optional[Blacklist] = None def __getitem__(self, ix: Union[int, slice]) -> SequenceManagerState: """Return a copy with a sliced sequence_info; RemoteSequenceInfo.__getitem__ only accepts slices, so int indices must be converted by the caller.""" return dataclasses.replace(self, sequence_info=self.sequence_info[ix]) def __len__(self) -> int: return len(self.sequence_info) class RemoteSequenceManager: """ Sequence manager is a thread that keeps track of remote servers that hold the specified sequence of blocks. TL;DR it tells you, which peers you should ask to get a specific layer. It is used in RemoteSequential. 
When created, RemoteSequenceManager looks up which servers serve necessary layers by reading from DHT. Using this information, sequence manager can form sequences of servers that collectively have the full sequence. To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr). :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid running redundant sequence managers for the same set of layers. """ def __init__( self, config: ClientConfig, block_uids: Sequence[ModuleUID], *, dht: Optional[DHT] = None, state: Optional[SequenceManagerState] = None, ): """ :param config: client config (initial peers, dht_prefix, timeouts, allow/block lists, etc.) :param block_uids: the non-empty sequence of block uids managed here :param dht: a running DHT; if omitted, a client-mode DHT is started from config.initial_peers :param state: existing state to reuse (e.g. when slicing another manager); missing fields are initialized here """ assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`" assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..." assert len(block_uids) > 0, "Sequences must contain at least one block" self.config = config if state is None: state = SequenceManagerState() self.state = state if dht is None: dht = DHT( initial_peers=config.initial_peers, client_mode=True, num_workers=32, startup_timeout=config.daemon_startup_timeout, start=True, ) assert isinstance(dht, DHT) and dht.is_alive(), "`dht` must be a running hivemind.DHT instance" self.dht = dht if state.p2p is None: state.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) self.lock_changes = threading.Lock() self._thread = _SequenceManagerUpdateThread(config.update_period, WeakMethod(self._update)) self._thread_start_lock = threading.Lock() self.policy = NoSpendingPolicy() self.allowed_servers = self._peer_ids_to_set(config.allowed_servers) self.blocked_servers = self._peer_ids_to_set(config.blocked_servers) self.ping_aggregator = PingAggregator(dht) if state.banned_peers is None: state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0) if state.sequence_info is None: state.sequence_info = RemoteSequenceInfo.make_empty(block_uids) if 
state.sequence_info.last_updated_time is not None: assert block_uids == state.sequence_info.block_uids self._thread.ready.set() # no need to await the first dht fetch @staticmethod def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]: """Convert a sequence of PeerIDs or base58 strings into a set of PeerIDs; None stays None.""" if peer_ids is None: return None result = set() for peer_id in peer_ids: if isinstance(peer_id, PeerID): result.add(peer_id) elif isinstance(peer_id, str): result.add(PeerID.from_base58(peer_id)) else: raise TypeError( f"`allowed_servers` and `blocked_servers` have to contain only PeerIDs or strings, but got {type(peer_id)}" ) return result def make_sequence( self, start_index: int = 0, end_index: Optional[int] = None, *, mode: str, cache_tokens_needed: Optional[int] = None, ) -> List[RemoteSpanInfo]: """ Form a sequence of remote servers that collectively serve all consecutive layers :param start_index: optional index of the first module in a sequence, default = the first of block_uids :param end_index: optional index of the last module (non-inclusive), default = after last of block uids :param mode: one of ["max_throughput", "min_latency"] :param cache_tokens_needed: only used in "min_latency" mode; servers whose reported cache_tokens_left may not fit it are penalized :returns: an ordered list of RemoteSpanInfo that together cover blocks start_index..end_index-1 :raises RuntimeError: on an unexpected mode; MissingBlocksError (a RuntimeError) if some block has no servers """ with self._thread_start_lock: if not self.is_alive(): self._thread.start() if not self.ready.is_set(): self.update(wait=True) # this will await an existing update or trigger a new one (if not updating) end_index = end_index if end_index is not None else len(self) if mode == "min_latency": span_sequence = self._make_sequence_with_min_latency( start_index, end_index, cache_tokens_needed=cache_tokens_needed ) elif mode == "max_throughput": span_sequence = self._make_sequence_with_max_throughput(start_index, end_index) else: raise RuntimeError(f"Unexpected mode {mode}") if self.config.show_route is True or (mode == "min_latency" and self.config.show_route == "inference"): route_repr = " => ".join( [f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence] ) logger.info(f"Route found: {route_repr}") return span_sequence def 
_make_sequence_with_min_latency( self, start_index: int, end_index: int, *, cache_tokens_needed: Optional[int] ) -> List[RemoteSpanInfo]: """Find the cheapest path in the graph from _build_inference_graph and merge consecutive blocks served by the same peer into spans.""" if start_index == end_index: return [] with self.lock_changes: missing_blocks = [ block_idx for block_idx in range(start_index, end_index) if not self.state.sequence_info.spans_containing_block[block_idx] ] if missing_blocks: raise MissingBlocksError(missing_blocks) server_infos = { span.peer_id: span.server_info for block_idx in range(start_index, end_index) for span in self.state.sequence_info.spans_containing_block[block_idx] } graph = self._build_inference_graph(start_index, end_index, cache_tokens_needed=cache_tokens_needed) path = dijkstar.find_path(graph, "start", "end") logger.debug(f"Path info: {path}") if start_index == 0 and end_index == len(self): logger.debug(f"Expected speed: {1 / path.total_cost:.1f} steps/sec") span_sequence = [] for peer_id, block_idx in path.nodes[1:-1]: if not span_sequence or span_sequence[-1].peer_id != peer_id: span_sequence.append(RemoteSpanInfo(peer_id, block_idx, block_idx, server_infos[peer_id])) else: span_sequence[-1].end = block_idx # Remove empty spans that can appear if we don't force to go to the end of each server and network delay # don't follow triangle inequality (delay(A, B) + delay(B, C) < delay(A, C)) due to measurement errors span_sequence = [span for span in span_sequence if span.length > 0] return span_sequence def _build_inference_graph( self, start_index: int, end_index: int, *, cache_tokens_needed: Optional[int], overhead_delay: float = 0.018, # Serialization overhead (empirically measured) default_inference_rps: float = 300, # If inference RPS unknown alloc_delay: float = 10, # If not enough cache left, we penalize the edge ) -> dijkstar.Graph: """ Build a dijkstar graph whose nodes are "start", "end" and (peer_id, block_idx). Edge weights are estimated delays: client<->server and server->server network delays (plus overhead and cache penalties), and 1 / inference_rps for each block computed on a server. """ missing_blocks = [ block_idx for block_idx in range(start_index, end_index) if not self.state.sequence_info.spans_containing_block[block_idx] ] if missing_blocks: raise MissingBlocksError(missing_blocks) client_server_rtts = 
self.ping_aggregator.to_dict() graph = dijkstar.Graph() # Client -> server network delays for span in self.state.sequence_info.spans_containing_block[start_index]: delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id)) delay += overhead_delay if not self._has_cache_for(span, cache_tokens_needed): delay += alloc_delay graph.add_edge("start", (span.peer_id, start_index), delay) # Server -> client network delays for span in self.state.sequence_info.spans_containing_block[end_index - 1]: delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id)) graph.add_edge((span.peer_id, end_index), "end", delay) # Server -> server network delays for block_idx in range(start_index + 1, end_index): for cur_span in self.state.sequence_info.spans_containing_block[block_idx - 1]: if cur_span.end != block_idx: # If we choose a server, we force to go to the end of it before switching to a new one # to avoid O(N^2) graphs for N servers continue for next_span in self.state.sequence_info.spans_containing_block[block_idx]: rtt = None if cur_span.server_info.next_pings is not None: rtt = cur_span.server_info.next_pings.get(next_span.peer_id.to_base58()) delay = self._rtt_to_delay(rtt) delay += overhead_delay if not self._has_cache_for(next_span, cache_tokens_needed): delay += alloc_delay graph.add_edge((cur_span.peer_id, block_idx), (next_span.peer_id, block_idx), delay) # Compute delays for span in self.state.sequence_info.spans_by_priority: for block_idx in range(max(span.start, start_index), min(span.end, end_index)): inference_rps = span.server_info.inference_rps if inference_rps is None: inference_rps = default_inference_rps graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), 1.0 / inference_rps) return graph @staticmethod def _rtt_to_delay( rtt: Optional[float], *, default_delay: float = 0.15, # If network delay unknown max_delay: float = 5, # If unreachable, we don't want to discard the edge completely ) -> float: """Convert a round-trip time into a one-way delay estimate (rtt / 2, capped at max_delay); an unknown rtt (None) maps to default_delay.""" if rtt is None: return default_delay return 
min(rtt / 2, max_delay) @staticmethod def _has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional[int] = None) -> bool: """Return True if cache_tokens_needed * 2 * span.length fits the server's cache_tokens_left, or if either value is unknown.""" if cache_tokens_needed is None or span.server_info.cache_tokens_left is None: return True # Here, `span` contains all blocks hosted by a server - but we won't necessarily run all of them through # this particular server in our path. It is difficult to estimate how many blocks we'll use at this stage, # so we assume that we'll use all of them (the worst case for the cache size) and get a pessimistic estimate. # This is okay since false positives are more costly than false negatives here. return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]: """Greedily cover blocks from start_index: at each uncovered block, randomly pick one of the spans containing it with probability proportional to its length (near zero for peers whose ping is inf), then continue from that span's end.""" client_server_rtts = self.ping_aggregator.to_dict() span_sequence = [] current_index = start_index while current_index < end_index: candidate_spans = self.state.sequence_info.spans_containing_block[current_index] if not candidate_spans: raise MissingBlocksError(current_index) # We choose longer servers to minimize the number of hops but leave some randomization # to distribute the load. We also exclude servers known to be unreachable. 
eps = 1e-6 span_weights = np.array( [span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans], dtype=np.float64, ) chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum()) assert chosen_span.start <= current_index < chosen_span.end span_sequence.append(dataclasses.replace(chosen_span, start=current_index)) current_index = chosen_span.end return span_sequence def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager: """Get a RemoteSequenceManager for a sub-sequence of blocks""" assert isinstance(ix, (int, slice)) if not isinstance(ix, slice): ix = slice(int(ix), int(ix) + 1, 1) return type(self)(self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix]) def update(self, *, wait: bool): """Run an asynchronous update in background as soon as possible""" self.ready.clear() self._thread.trigger.set() if wait: self.ready.wait() def _update(self): """Perform an immediate and synchronous refresh, may take time""" new_block_infos = get_remote_module_infos( self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True ) for block_info in new_block_infos: # Apply allow and block lists block_info.servers = { peer_id: server_info for peer_id, server_info in block_info.servers.items() if (self.allowed_servers is None or peer_id in self.allowed_servers) and (self.blocked_servers is None or peer_id not in self.blocked_servers) } # Remove temporarily banned peers, unless there are no peers left valid_servers = { peer_id: server_info for peer_id, server_info in block_info.servers.items() if peer_id not in self.state.banned_peers } if len(valid_servers) < len(block_info.servers): if valid_servers: logger.debug( f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}" ) block_info.servers = valid_servers else: # If we blacklisted all servers, the error may actually be client-caused logger.debug(f"All servers holding {block_info.uid} 
are blacklisted, ignoring blacklist") with self.lock_changes: self.state.sequence_info.update_(new_block_infos) first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]] middle_servers = [ span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans ] last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]] pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged)) pinged_servers = set(sample_up_to(middle_servers, self.config.max_pinged)) pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged)) self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout) self.ready.set() def on_request_failure(self, peer_id: Optional[PeerID]): """remove a given peer from the routing table. If the routing is no longer possible, trigger an update""" if peer_id is not None: logger.debug(f"Peer {peer_id} did not respond, banning it temporarily") self.state.banned_peers.register_failure(peer_id) with self.lock_changes: should_update = False for info in self.state.sequence_info.block_infos: info.servers.pop(peer_id, None) if not info.servers: should_update = True if should_update: self.ready.clear() self.update(wait=False) def on_request_success(self, peer_id: PeerID): """if peer has a failure streak, clear that streak""" self.state.banned_peers.register_success(peer_id) def __len__(self): return len(self.block_uids) @property def is_alive(self): return self._thread.is_alive @property def ready(self) -> threading.Event: return self._thread.ready @property def block_uids(self): return self.state.sequence_info.block_uids @property def rpc_info(self): """Return the rpc_info queried from one of the servers that hold the first block""" if self.state.rpc_info is not None: return self.state.rpc_info with self._thread_start_lock: if not self.is_alive(): self._thread.start() for attempt_no in itertools.count(): peer_id = None try: 
if not self.ready.is_set(): self.update(wait=True) active_servers = [ peer_id for peer_id, server in self.state.sequence_info.block_infos[0].servers.items() if server.state == ServerState.ONLINE ] if not active_servers: raise MissingBlocksError(0) peer_id = random.choice(active_servers) stub = TransformerConnectionHandler.get_stub(self.state.p2p, peer_id) outputs = RemoteExpertWorker.run_coroutine( stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]), timeout=self.config.request_timeout) ) self.state.rpc_info = MSGPackSerializer.loads(outputs.serialized_info) self.on_request_success(peer_id) break except Exception as e: self.on_request_failure(peer_id) if attempt_no + 1 == self.config.max_retries: raise delay = self.get_retry_delay(attempt_no) logger.warning( f"Caught exception when gathering information from peer {peer_id} " f"(retry in {delay:.0f} sec): {repr(e)}" ) maybe_log_traceback(e) time.sleep(delay) return self.state.rpc_info def get_retry_delay(self, attempt_no: int) -> float: if attempt_no == 0: return 0 return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff) def get_request_metadata( self, protocol: str, args_structure: Any = None, *args, **kwargs ) -> Optional[Dict[str, Any]]: """ :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference" :param args_structure: the structure of flattened tensors from pack_args_kwargs in petals.utils.packaging :param args: request-specific inputs, typically block uids and input tensors :param kwargs: additional request context, such as remote peer ID :returns: msgpack-serialized metadata dict that will be passed alongside a given request """ return dict( points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter, args_structure=args_structure, ) def shutdown(self): self._thread.shutdown() class _SequenceManagerUpdateThread(threading.Thread): def __init__(self, update_period: float, ref_update_manager: WeakMethod): 
super().__init__(daemon=True) self.ref_update_manager = ref_update_manager self.ready = threading.Event() self.trigger = threading.Event() self.update_period = update_period self.should_shutdown = False def run(self) -> None: while not self.should_shutdown: update_manager = self.ref_update_manager() if update_manager is None: logger.debug(f"{self.__class__.__name__} exited because the sequence manager no longer exists") break try: self.trigger.clear() update_manager() except Exception as e: logger.exception(e) finally: del update_manager self.trigger.wait(self.update_period) logger.debug(f"{self.__class__.__name__} thread exited") def shutdown(self, timeout: Optional[float] = None): self.should_shutdown = True self.trigger.set() if self.is_alive(): self.join(timeout) def __del__(self): self.shutdown() def maybe_log_traceback(exc: Exception): traceback_level = logging.DEBUG if str(exc) or isinstance(exc, asyncio.TimeoutError) else logging.WARNING logger.log(traceback_level, "See detailed traceback below:", exc_info=True) class MissingBlocksError(RuntimeError): def __init__(self, block_indices: Union[int, Sequence[int]]): super().__init__( f"No servers holding blocks {block_indices} are online. " f"You can check the public swarm's state at https://health.petals.dev " f"If there are not enough servers, please connect your GPU: " f"https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity " ) ================================================ FILE: src/petals/client/routing/spending_policy.py ================================================ """ An interface for exchanging internal "BLOOM points" for higher priority compute requests. NOT IMPLEMENTED. The intent is to let Petals participants earn points by helping others while idle (e.g. at night), then use these points to run their own compute experiments faster. See Section 4 of https://arxiv.org/abs/2209.01188 for discussion. 
""" from abc import ABC, abstractmethod class SpendingPolicyBase(ABC): @abstractmethod def get_points(self, protocol: str, *args, **kwargs) -> float: pass class NoSpendingPolicy(SpendingPolicyBase): def get_points(self, protocol: str, *args, **kwargs) -> float: return 0.0 ================================================ FILE: src/petals/client/sequential_autograd.py ================================================ """ A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner """ import asyncio import itertools from collections import deque from typing import List, Optional, Sequence, Tuple import torch from hivemind import MSGPackSerializer from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.utils.logging import get_logger from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward from petals.client.routing import RemoteSequenceManager, maybe_log_traceback from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo from petals.server.handler import TransformerConnectionHandler from petals.utils.misc import DUMMY, is_dummy from petals.utils.packaging import pack_args_kwargs logger = get_logger(__name__) MAX_TOKENS_IN_BATCH = 1024 async def sequential_forward( inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None, ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]: """ Constructs a routing path from to . Performs chained forward for each subsequence of blocks on the path. If some subsequence fails, reconstructs the remaining path and tries to finish the forward. 
""" assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}" inputs_device = inputs.device inputs_dtype = inputs.dtype inputs = inputs.cpu() prompts = prompts.cpu() end_index = end_index if end_index is not None else len(sequence_manager.block_uids) assert start_index >= 0 and end_index <= len(sequence_manager.block_uids) assert is_dummy(prompts) or len(prompts) == len( sequence_manager.block_uids ) # should be n_layers - 1 but add extra prompts for convenience sequences = deque() intermediate_inputs = [] done_sequences = [] block_idx = start_index while block_idx < end_index: for attempt_no in itertools.count(): logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}") span = None try: if not sequences or attempt_no >= 1: sequences = deque(sequence_manager.make_sequence(block_idx, end_index, mode="max_throughput")) # make_sequence() could return a longer sequence sequences[-1].end = min(sequences[-1].end, end_index) logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers") span = sequences.popleft() stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id) flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end]) span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) metadata = sequence_manager.get_request_metadata( "rpc_forward", args_structure, span_uids, *flat_tensors ) (outputs,) = await run_remote_forward( span_uids, stub, sequence_manager.rpc_info, *flat_tensors, config=sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata), ) assert isinstance(outputs, torch.Tensor) assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}" # Save intermediate inputs and subsequences if the forward is already done for them intermediate_inputs.append(inputs) done_sequences.append(span) inputs = outputs block_idx = span.end 
sequence_manager.on_request_success(span.peer_id) break except Exception as e: sequence_manager.on_request_failure(span.peer_id if span is not None else None) if attempt_no + 1 == sequence_manager.config.max_retries: raise delay = sequence_manager.get_retry_delay(attempt_no) logger.warning( f"Caught exception when running forward via {span} (retry in {delay:.0f} sec): {repr(e)}" ) maybe_log_traceback(e) await asyncio.sleep(delay) outputs = inputs.to(device=inputs_device, dtype=inputs_dtype) intermediate_inputs = [tensor.to(device=inputs_device, dtype=inputs_dtype) for tensor in intermediate_inputs] return outputs, intermediate_inputs, done_sequences async def sequential_backward( grad_outputs: Sequence[torch.Tensor], intermediate_inputs: List[torch.Tensor], prompts: torch.Tensor, forward_sequences: List[RemoteSpanInfo], sequence_manager: RemoteSequenceManager, ) -> Tuple[Sequence[torch.Tensor], torch.Tensor]: """ Performs chained backward for each forward subsequence. If some subsequence fails, reconstructs the particular sub-path and recovers the backward. 
""" assert len(intermediate_inputs) == len(forward_sequences) grad_outputs_device = grad_outputs[0].device if grad_outputs else None grad_outputs_dtype = grad_outputs[0].dtype if grad_outputs else None prompts_device = prompts.device prompts_dtype = prompts.dtype grad_outputs = [tensor.cpu() for tensor in grad_outputs] intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs] prompts = prompts.cpu() grad_prompts_reversed = [] while len(forward_sequences) > 0 and len(intermediate_inputs) > 0: inputs = intermediate_inputs.pop() span = forward_sequences.pop() for attempt_no in itertools.count(): logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}") try: if attempt_no >= 1: _, backup_inputs, backup_sequences = await sequential_forward( inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end ) assert len(backup_inputs) == len(backup_sequences) assert backup_sequences[0].start == span.start assert backup_sequences[-1].end == span.end intermediate_inputs.extend(backup_inputs) forward_sequences.extend(backup_sequences) inputs = intermediate_inputs.pop() span = forward_sequences.pop() grad_outputs_cpu = [grad.cpu() for grad in grad_outputs] flat_tensors, args_structure = pack_args_kwargs( inputs, *grad_outputs_cpu, prompts[span.start : span.end] ) span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id) metadata = sequence_manager.get_request_metadata( "rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id ) grad_outputs, *span_grad_prompts = await run_remote_backward( span_uids, stub, sequence_manager.rpc_info, *flat_tensors, config=sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata), ) grad_outputs = [grad_outputs] grad_prompts_reversed.extend(span_grad_prompts) sequence_manager.on_request_success(span.peer_id) break except Exception as e: 
sequence_manager.on_request_failure(span.peer_id if span is not None else None) if attempt_no + 1 == sequence_manager.config.max_retries: raise delay = sequence_manager.get_retry_delay(attempt_no) logger.warning( f"Caught exception when running backward via {span} (retry in {delay:.0f} sec): {repr(e)}" ) maybe_log_traceback(e) await asyncio.sleep(delay) # For now, we do not support mixed dummy and grad prompts # Concat in num_layer dimension grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None if grad_outputs_dtype is not None: grad_outputs = [tensor.to(device=grad_outputs_device, dtype=grad_outputs_dtype) for tensor in grad_outputs] if grad_prompts is not None: grad_prompts = grad_prompts.to(device=prompts_device, dtype=prompts_dtype) return grad_outputs, grad_prompts async def _gather_forward(input_batches, prompt_batches, sequence_manager): """Wrapper for asyncio.gather to perform parallel sequential forwards""" return await asyncio.gather( *[ sequential_forward(input_batch, prompt_batch, sequence_manager) for input_batch, prompt_batch in zip(input_batches, prompt_batches) ] ) async def _gather_backward( grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager ): """Wrapper for asyncio.gather to perform parallel sequential backwards""" return await asyncio.gather( *[ sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager) for grad_output, input_batch, prompt_batch, spans in zip( grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences ) ] ) class _RemoteSequentialAutogradFunction(torch.autograd.Function): """ PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks. This function splits input data into batches with and performs efficient parallel processing. 
""" @staticmethod def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager): batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1) input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size) if prompts is None or is_dummy(prompts): prompt_batches = [DUMMY] * len(input_batches) else: prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1) sequence_manager.rpc_info # lazy init outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager)) assert len(outputs) == len(input_batches) output_batches = [output[0] for output in outputs] intemediate_input_batches = [output[1] for output in outputs] sequences_for_batches = [output[2] for output in outputs] ctx.prompt_batches = prompt_batches ctx.sequence_manager = sequence_manager ctx.intemediate_input_batches = intemediate_input_batches ctx.sequences_for_batches = sequences_for_batches return torch.cat(output_batches, dim=0) @staticmethod def backward(ctx, grad_outputs: torch.Tensor): intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intemediate_input_batches forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches ctx.sequence_manager.rpc_info # lazy init batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1) grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size) assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences) outputs = RemoteExpertWorker.run_coroutine( _gather_backward( grad_output_batches, intermediate_input_batches, ctx.prompt_batches, forward_sequences, ctx.sequence_manager, ) ) grad_input_batches = [output[0][0] for output in outputs] grad_prompt_batches = [output[1] for output in outputs] grad_inputs = torch.cat(grad_input_batches, dim=0) dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches] grad_prompts = torch.cat(grad_prompt_batches, dim=1) 
if not any(dummy_grad_prompts) else None return (grad_inputs, grad_prompts, None) ================================================ FILE: src/petals/constants.py ================================================ import torch PUBLIC_INITIAL_PEERS = [ # IPv4 DNS addresses "/dns/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", "/dns/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", # IPv6 DNS addresses "/dns6/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", "/dns6/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", # Reserved IPs "/ip4/159.89.214.152/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY", "/ip4/159.203.156.48/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5", ] # The reachability API is currently used only when connecting to the public swarm REACHABILITY_API_URL = "https://health.petals.dev" DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") ================================================ FILE: src/petals/data_structures.py ================================================ import dataclasses from enum import Enum from typing import Any, Dict, Optional, Sequence, Tuple import pydantic.v1 as pydantic from hivemind import PeerID from hivemind.moe.expert_uid import ExpertUID ModuleUID = str UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention" CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. 
"bloom.layer3 bloom.layer4" def parse_uid(uid: ModuleUID) -> Tuple[str, int]: assert CHAIN_DELIMITER not in uid, "parse_uid() does not support chained UIDs" dht_prefix, index = uid.split(UID_DELIMITER) return dht_prefix, int(index) @pydantic.dataclasses.dataclass class ModelInfo: num_blocks: pydantic.conint(ge=1, strict=True) repository: Optional[str] = None def to_dict(self) -> dict: return dataclasses.asdict(self) @classmethod def from_dict(cls, source: dict): return cls(**source) class ServerState(Enum): OFFLINE = 0 JOINING = 1 ONLINE = 2 RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True) @pydantic.dataclasses.dataclass class ServerInfo: state: ServerState throughput: RPS start_block: Optional[pydantic.conint(ge=0, strict=True)] = None end_block: Optional[pydantic.conint(ge=0, strict=True)] = None public_name: Optional[str] = None version: Optional[str] = None network_rps: Optional[RPS] = None forward_rps: Optional[RPS] = None inference_rps: Optional[RPS] = None adapters: Sequence[str] = () torch_dtype: Optional[str] = None quant_type: Optional[str] = None using_relay: Optional[bool] = None cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None next_pings: Optional[Dict[str, pydantic.confloat(ge=0, strict=True)]] = None def to_tuple(self) -> Tuple[int, float, dict]: extra_info = dataclasses.asdict(self) del extra_info["state"], extra_info["throughput"] return (self.state.value, self.throughput, extra_info) @classmethod def from_tuple(cls, source: tuple): state, throughput = source[:2] extra_info = source[2] if len(source) > 2 else {} # pydantic will validate existing fields and ignore extra ones return cls(state=ServerState(state), throughput=throughput, **extra_info) @dataclasses.dataclass class RemoteModuleInfo: """A remote module that is served by one or more servers""" uid: ModuleUID servers: Dict[PeerID, ServerInfo] @dataclasses.dataclass class RemoteSpanInfo: """A chain of remote blocks served by one specific remote peer""" 
peer_id: PeerID start: int end: int server_info: ServerInfo @property def length(self) -> int: return self.end - self.start @property def state(self) -> ServerState: return self.server_info.state @property def throughput(self) -> float: return self.server_info.throughput RPCInfo = Dict[str, Any] Handle = int @dataclasses.dataclass(frozen=True) class InferenceMetadata: uid: ExpertUID prefix_length: int cache_handles: Tuple[Handle, ...] active_adapter: Optional[str] ================================================ FILE: src/petals/dht_utils.py ================================================ import warnings warnings.warn( "petals.dht_utils has been moved to petals.utils.dht. This alias will be removed in Petals 2.2.0+", DeprecationWarning, stacklevel=2, ) from petals.utils.dht import * ================================================ FILE: src/petals/models/__init__.py ================================================ from petals.models.bloom import * from petals.models.falcon import * from petals.models.llama import * from petals.models.mixtral import * ================================================ FILE: src/petals/models/bloom/__init__.py ================================================ from petals.models.bloom.block import WrappedBloomBlock from petals.models.bloom.config import DistributedBloomConfig from petals.models.bloom.model import ( DistributedBloomForCausalLM, DistributedBloomForSequenceClassification, DistributedBloomModel, ) from petals.utils.auto_config import register_model_classes register_model_classes( config=DistributedBloomConfig, model=DistributedBloomModel, model_for_causal_lm=DistributedBloomForCausalLM, model_for_sequence_classification=DistributedBloomForSequenceClassification, ) ================================================ FILE: src/petals/models/bloom/block.py ================================================ """ Bloom intermediate layer Based on 
https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b See commit history for authorship. """ from typing import Optional, Tuple import torch from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.models.bloom.modeling_bloom import BloomBlock, build_alibi_tensor from petals.utils.misc import is_dummy class WrappedBloomBlock(BloomBlock): def forward( self, hidden_states: torch.Tensor, *args, attention_mask: Optional[torch.Tensor] = None, alibi: Optional[torch.Tensor] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs ): assert attention_mask is None, "Non-causal attention masks are not supported yet" batch_size, seq_length = hidden_states.shape[:2] if layer_past is not None and is_dummy(layer_past[0]): # Bloom cannot use cache if it was misconsctructed(e.g. Dummy tensors) # In this case, fallback to the old code: layer_past = None past_length = 0 if layer_past is None else layer_past[0].shape[-1] seq_length_with_past = seq_length + past_length attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) if alibi is None: alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) attention_mask = _prepare_4d_causal_attention_mask( attention_mask=attention_mask, input_shape=(batch_size, seq_length), inputs_embeds=hidden_states, past_key_values_length=past_length, ) attention_mask = attention_mask.bool() return super().forward( hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs ) ================================================ FILE: src/petals/models/bloom/config.py ================================================ import os from typing import Optional, Union from hivemind import get_logger from transformers.models.bloom import BloomConfig from transformers.models.bloom.modeling_bloom import BloomAttention from petals.client.config import 
ClientConfig from petals.client.lm_head import LMHeadConfig from petals.client.ptune import PTuneConfig from petals.models.bloom.block import WrappedBloomBlock logger = get_logger(__name__) class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfig): block_class = WrappedBloomBlock attn_class = BloomAttention block_prefix = "h" num_key_value_groups = 1 @classmethod def from_pretrained( cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs ): logger.info("Make sure you follow the BLOOM terms of use: https://bit.ly/bloom-license") loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) if loading_from_repo and dht_prefix is None: # We need "-petals" for backward compatibility with Petals < 1.2.0 dht_prefix = str(model_name_or_path) + "-petals" dht_prefix = dht_prefix.replace(".", "-") logger.info(f"Using DHT prefix: {dht_prefix}") return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) ================================================ FILE: src/petals/models/bloom/model.py ================================================ from typing import Optional import hivemind import torch import torch.nn as nn from hivemind.utils.logging import get_logger from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel from petals.client.from_pretrained import FromPretrainedMixin from petals.client.lm_head import LMHead from petals.client.ptune import PTuneMixin from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues from petals.client.remote_sequential import RemoteSequential from petals.models.bloom.config import DistributedBloomConfig logger = get_logger(__name__) class DistributedBloomModel(FromPretrainedMixin, 
PTuneMixin, BloomModel): """BloomModel, but all transformer layers are hosted by the swarm""" _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = [r"^h\."] config_class = DistributedBloomConfig def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None): n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization super().__init__(config) assert len(self.h) == 0 config.num_hidden_layers = n_layer self.h = RemoteSequential(config, dht=dht) self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm self.init_prompts(config) def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[RemotePastKeyValues] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") # The causal mask will be added on the server-side assert ( attention_mask is None or (attention_mask == 1).all() ), f"Custom attention masks are not supported, {attention_mask=}" assert head_mask is None, f"Custom head masks are not supported, {head_mask=}" assert use_cache is None or use_cache, f"{use_cache=} is not supported" assert not output_attentions, f"{output_attentions=} is not supported" assert not output_hidden_states, f"{output_hidden_states=} is not supported" assert return_dict is 
None or return_dict, f"{return_dict=} is not supported" if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0 if use_prompts: batch_size = inputs_embeds.shape[0] prompts, intermediate_prompts = self.get_prompt(batch_size) inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) else: prompts = intermediate_prompts = None hidden_states = self.word_embeddings_layernorm(inputs_embeds) output_shape = input_shape + (hidden_states.size(-1),) hidden_states = self.h( hidden_states, prompts=intermediate_prompts, hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None, ) # Remove prefix if use_prompts: hidden_states = hidden_states[:, self.pre_seq_len :] if past_key_values is None: past_key_values = RemotePastKeyValues() past_key_values.update_seen(hidden_states.size(1)) # Add last hidden state hidden_states = self.ln_f(hidden_states) hidden_states = hidden_states.view(output_shape) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=None, attentions=None, ) class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM): _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected _supports_cache_class = True config_class = DistributedBloomConfig def __init__(self, config: DistributedBloomConfig): BloomPreTrainedModel.__init__(self, config) self.transformer = DistributedBloomModel(config) self.lm_head = LMHead(config) # Initialize weights and apply final processing self.post_init() def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 
) -> dict: # Omit tokens covered by past_key_values if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values._seen_tokens max_cache_length = past_key_values.get_max_length() else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] elif past_length < input_ids.shape[1]: input_ids = input_ids[:, past_length:] if ( max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, } ) return model_inputs def _temporary_reorder_cache(self, past_key_values, beam_idx): return past_key_values def get_output_embeddings(self): return self.lm_head class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification): _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected config_class = DistributedBloomConfig def __init__(self, 
config: DistributedBloomConfig): BloomPreTrainedModel.__init__(self, config) self.num_labels = config.num_labels self.transformer = DistributedBloomModel(config) self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() ================================================ FILE: src/petals/models/falcon/__init__.py ================================================ from petals.models.falcon.block import WrappedFalconBlock from petals.models.falcon.config import DistributedFalconConfig from petals.models.falcon.model import ( DistributedFalconForCausalLM, DistributedFalconForSequenceClassification, DistributedFalconModel, ) from petals.utils.auto_config import register_model_classes register_model_classes( config=DistributedFalconConfig, model=DistributedFalconModel, model_for_causal_lm=DistributedFalconForCausalLM, model_for_sequence_classification=DistributedFalconForSequenceClassification, ) ================================================ FILE: src/petals/models/falcon/block.py ================================================ """ Falcon intermediate layer Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py See commit history for authorship. 
""" import math from functools import partial from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from transformers.models.falcon.modeling_falcon import ( FalconAttention, FalconConfig, FalconDecoderLayer, FalconLinear, FalconMLP, FalconModel, LayerNorm, build_alibi_tensor, dropout_add, rotate_half, ) KVCache = Tuple[torch.Tensor, torch.Tensor] INFERENCE_MAX_LENGTH = 8192 def apply_rotary(query, key, cos, sin): return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) class OptimizedFalconRotaryEmbedding(nn.Module): def __init__(self, head_dim: int, base=10000): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self.head_dim = head_dim self.seq_len_cached = -1 self.cuda_graph = None self.input_surface = None self.static_outputs = None def _optimized_apply_rotary(self, query, key, cos, sin): if self.cuda_graph is None: self.cuda_graph = torch.cuda.CUDAGraph() self.input_surface = (query, key, cos, sin) s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): apply_rotary(*self.input_surface) torch.cuda.current_stream().wait_stream(s) with torch.cuda.graph(self.cuda_graph): self.static_outputs = apply_rotary(*self.input_surface) inputs = (query, key, cos, sin) for static_input, data in zip(self.input_surface, inputs): static_input.copy_(data) self.cuda_graph.replay() return tuple(o.detach() for o in self.static_outputs) def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor: total_length = seq_len + past_key_values_length if self.seq_len_cached == -1: # warm up the cache total_length = max(INFERENCE_MAX_LENGTH, total_length) if total_length > self.seq_len_cached: with torch.inference_mode(False): self.seq_len_cached = total_length t = torch.arange(total_length, 
device=device, dtype=self.inv_freq.dtype) freqs = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1).to(device) if dtype in [torch.float16, torch.bfloat16]: emb = emb.float() self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False) return ( self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype), self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype), ) def forward(self, query, key, past_key_values_length=0): batch, seq_len, head_dim = query.shape cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype) if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == "cuda": return self._optimized_apply_rotary(query, key, cos, sin) else: return apply_rotary(query, key, cos, sin) def split_heads( fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch, seq_len, _ = fused_qkv.shape qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim) query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3) key = torch.broadcast_to(key, query.shape) value = torch.broadcast_to(value, query.shape) query, key, value = [x.flatten(2, 3) for x in (query, key, value)] return query, key, value class OptimizedFalconAttention(FalconAttention): def __init__(self, config: FalconConfig): nn.Module.__init__(self) self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.split_size = self.hidden_size self.hidden_dropout = config.hidden_dropout if self.head_dim * self.num_heads != self.hidden_size: raise ValueError( f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" f" {self.num_heads})." 
) self.maybe_rotary = OptimizedFalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k) # Layer-wise attention scaling self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) self.beta = self.inv_norm_factor if config.new_decoder_architecture: qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim elif config.multi_query: qkv_out_dim = self.hidden_size + 2 * self.head_dim else: qkv_out_dim = 3 * self.hidden_size self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias) self.new_decoder_architecture = config.new_decoder_architecture self.multi_query = config.multi_query self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias) self.attention_dropout = nn.Dropout(config.attention_dropout) self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 if self.new_decoder_architecture: self._split_heads = partial( split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim ) self.split_graph = None self.input_surface = None self.static_outputs = None def _optimized_split_heads(self, fused_qkv): if self.split_graph is None: self.split_graph = torch.cuda.CUDAGraph() self.input_surface = fused_qkv s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): self._split_heads(fused_qkv) torch.cuda.current_stream().wait_stream(s) with torch.cuda.graph(self.split_graph): self.static_outputs = self._split_heads(self.input_surface) self.input_surface.copy_(fused_qkv) self.split_graph.replay() return tuple(o.detach() for o in self.static_outputs) def forward( self, hidden_states: torch.Tensor, alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, ): assert not output_attentions 
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] if ( self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda" ): query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv) else: # 3 x [batch_size, seq_length, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) num_kv_heads = self.num_heads batch_size, query_length, _, _ = query_layer.shape query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim) key_layer = key_layer.transpose(1, 2).reshape( batch_size * num_kv_heads, query_length, self.head_dim, ) value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim) past_kv_length = 0 if layer_past is None else layer_past[0].shape[1] query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length) if layer_past is not None: past_key, past_value = layer_past # concatenate along seq_length dimension: # - key: [batch_size * self.num_heads, kv_length, head_dim] # - value: [batch_size * self.num_heads, kv_length, head_dim] key_layer = torch.cat((past_key, key_layer), dim=1) value_layer = torch.cat((past_value, value_layer), dim=1) _, kv_length, _ = key_layer.shape if use_cache: present = (key_layer, value_layer) else: present = None query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim) key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype) if alibi is None: attn_output = F.scaled_dot_product_attention( query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False ) attn_output = attn_output.view(batch_size, 
self.num_heads, query_length, self.head_dim) attn_output = attn_output.permute(0, 2, 1, 3) attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) output_tensor = self.dense(attn_output) return output_tensor, present else: matmul_result = query_layer_ @ key_layer_.transpose(-1, -2) # change view to [batch_size, num_heads, q_length, kv_length] attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] input_dtype = attention_scores.dtype # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` if input_dtype == torch.float16 or input_dtype == torch.bfloat16: attention_scores = attention_scores.to(torch.float32) # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically # equivalent and more performant, but there might be a numerical difference. If you're reading this # and you'd like to experiment and maybe file a PR, feel free! 
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) attention_logits *= self.inv_norm_factor attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype) # [batch_size, num_heads, q_length, kv_length] attention_probs = self.attention_dropout(attention_probs) if head_mask is not None: attention_probs = attention_probs * head_mask # change view [batch_size, num_heads, q_length, kv_length] attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) # matmul: [batch_size * num_heads, q_length, head_dim] context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1) # change view [batch_size, q_length, num_heads * head_dim] context_layer = self._merge_heads(context_layer) output_tensor = self.dense(context_layer) if output_attentions: return output_tensor, present, attention_probs else: return output_tensor, present class OptimizedFalconDecoderLayer(FalconDecoderLayer): def __init__(self, config: FalconConfig): nn.Module.__init__(self) hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.mlp = FalconMLP(config) self.hidden_dropout = config.hidden_dropout self.config = config self.self_attention = OptimizedFalconAttention(config) if self.config.alibi or not config.new_decoder_architecture: if config.new_decoder_architecture: # The layer norm before self-attention self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) # The layer norm before the MLP self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) else: self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if not config.parallel_attn: self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) else: self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_graph = None self.static_input = None 
self.static_outputs = None def _optimized_apply_ln(self, hidden_states): if self.ln_graph is None: self.ln_graph = torch.cuda.CUDAGraph() self.static_input = hidden_states s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): self.ln_attn(hidden_states) self.ln_mlp(hidden_states) torch.cuda.current_stream().wait_stream(s) with torch.cuda.graph(self.ln_graph): ln_attn_output = self.ln_attn(hidden_states) ln_mlp_output = self.ln_mlp(hidden_states) self.static_outputs = (ln_attn_output, ln_mlp_output) self.static_input.copy_(hidden_states) self.ln_graph.replay() return tuple(o.detach() for o in self.static_outputs) def forward( self, hidden_states: torch.Tensor, alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, ): residual = hidden_states if self.config.new_decoder_architecture: if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda": attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states) else: attention_layernorm_out = self.ln_attn(hidden_states) mlp_layernorm_out = self.ln_mlp(hidden_states) else: attention_layernorm_out = self.input_layernorm(hidden_states) attn_outputs = self.self_attention( attention_layernorm_out, layer_past=layer_past, attention_mask=attention_mask, alibi=alibi, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, ) attention_output = attn_outputs[0] if not self.config.new_decoder_architecture: if self.config.parallel_attn: mlp_layernorm_out = attention_layernorm_out else: residual = dropout_add( attention_output, residual, self.config.attention_dropout, training=self.training ) mlp_layernorm_out = self.post_attention_layernorm(residual) outputs = attn_outputs[1:] mlp_output = self.mlp(mlp_layernorm_out) if 
self.config.new_decoder_architecture or self.config.parallel_attn: mlp_output += attention_output output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training) if use_cache: outputs = (output,) + outputs else: outputs = (output,) + outputs[1:] return outputs # hidden_states, present, attentions class WrappedFalconBlock(OptimizedFalconDecoderLayer): def forward( self, hidden_states: torch.Tensor, *args, attention_mask: Optional[torch.Tensor] = None, alibi: Optional[torch.Tensor] = None, layer_past: Optional[KVCache] = None, use_cache: bool = False, **kwargs, ): assert attention_mask is None batch_size, seq_length = hidden_states.shape[:2] if layer_past is not None: layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past) past_length = 0 if layer_past is None else layer_past[0].shape[1] seq_length_with_past = seq_length + past_length attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) if alibi is None and self.config.alibi: alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) outputs = super().forward( hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, use_cache=use_cache, **kwargs, ) if use_cache: present_key_value = outputs[-1] present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value) outputs = outputs[:-1] + (present_key_value,) return outputs def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache: key_states, value_states = key_value key_states = key_states.permute(0, 2, 1) assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] if self.config.new_decoder_architecture: key_states = self._expand_states(key_states) value_states = self._expand_states(value_states) return (key_states, value_states) def 
_reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache: key_states, value_states = key_value if self.config.new_decoder_architecture: key_states = self._collapse_states(key_states) value_states = self._collapse_states(value_states) assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] key_states = key_states.permute(0, 2, 1) return (key_states, value_states) def _expand_states(self, state: torch.Tensor) -> torch.Tensor: batch_size_x_num_kv_heads, seq_len, head_dim = state.shape batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim) state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy return state def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: batch_size_x_num_attn_heads, seq_len, head_dim = state.shape batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim) state = state[:, :, 0] state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim) return state ================================================ FILE: src/petals/models/falcon/config.py ================================================ import os from typing import Optional, Union from hivemind import get_logger from transformers.models.falcon import FalconConfig from transformers.models.falcon.modeling_falcon import FalconAttention from petals.client.config import ClientConfig from petals.client.lm_head import LMHeadConfig from petals.client.ptune import PTuneConfig from petals.models.falcon.block import WrappedFalconBlock from petals.utils.auto_config import DefaultRevisionMixin logger = get_logger(__name__) class DistributedFalconConfig(DefaultRevisionMixin, 
FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig): block_class = WrappedFalconBlock attn_class = FalconAttention block_prefix = "transformer.h" @property def num_key_value_groups(self) -> int: if self.new_decoder_architecture: return self.num_attention_heads // self.num_kv_heads if self.multi_query: return self.num_attention_heads return 1 @classmethod def from_pretrained( cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs ): if "180B" in model_name_or_path.upper(): logger.info("Make sure you follow the Falcon-180B license: https://bit.ly/falcon-180b-license") loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) if loading_from_repo and dht_prefix is None: dht_prefix = str(model_name_or_path) dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts dht_prefix = dht_prefix.replace(".", "-") logger.info(f"Using DHT prefix: {dht_prefix}") result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) config = result[0] if isinstance(result, tuple) else result if config.pad_token_id is None: config.pad_token_id = 0 return result ================================================ FILE: src/petals/models/falcon/model.py ================================================ from typing import Optional import hivemind import torch import torch.nn as nn from hivemind.utils.logging import get_logger from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from transformers.models.falcon import ( FalconForCausalLM, FalconForSequenceClassification, FalconModel, FalconPreTrainedModel, ) from petals.client.from_pretrained import FromPretrainedMixin from petals.client.lm_head import LMHead from petals.client.ptune import PTuneMixin from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues from petals.client.remote_sequential import RemoteSequential from 
petals.models.falcon.config import DistributedFalconConfig from petals.utils.auto_config import DefaultRevisionMixin logger = get_logger(__name__) class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel): """FalconModel, but all transformer layers are hosted by the swarm""" _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."] config_class = DistributedFalconConfig def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None): n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization super().__init__(config) assert len(self.h) == 0 config.num_hidden_layers = n_layer self.h = RemoteSequential(config, dht=dht) self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm self.init_prompts(config) def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[RemotePastKeyValues] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") # The causal mask will be added on the server-side assert ( attention_mask is None or (attention_mask == 1).all() ), f"Custom attention masks are not supported, {attention_mask=}" assert ( position_ids is None or 
(position_ids[:, 1:] - position_ids[:, :-1] == 1).all() ), f"Non-consecutive position_ids are not supported, {position_ids=}" assert head_mask is None, f"Custom head masks are not supported, {head_mask=}" assert use_cache is None or use_cache, f"{use_cache=} is not supported" assert not output_attentions, f"{output_attentions=} is not supported" assert not output_hidden_states, f"{output_hidden_states=} is not supported" assert return_dict is None or return_dict, f"{return_dict=} is not supported" if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0 if use_prompts: batch_size = inputs_embeds.shape[0] prompts, intermediate_prompts = self.get_prompt(batch_size) inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) else: prompts = intermediate_prompts = None hidden_states = self.word_embeddings_layernorm(inputs_embeds) output_shape = input_shape + (hidden_states.size(-1),) hidden_states = self.h( hidden_states, prompts=intermediate_prompts, hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None, ) # Remove prefix if use_prompts: hidden_states = hidden_states[:, self.pre_seq_len :] # Add last hidden state hidden_states = self.ln_f(hidden_states) hidden_states = hidden_states.view(output_shape) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=RemotePastKeyValues(), hidden_states=None, attentions=None, ) @property def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin return nn.Identity() class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM): _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected config_class = DistributedFalconConfig def 
__init__(self, config: DistributedFalconConfig): FalconPreTrainedModel.__init__(self, config) self.transformer = DistributedFalconModel(config) self.lm_head = LMHead(config) # Initialize weights and apply final processing self.post_init() def get_output_embeddings(self): return self.lm_head class DistributedFalconForSequenceClassification( DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification ): _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected config_class = DistributedFalconConfig def __init__(self, config: DistributedFalconConfig): FalconPreTrainedModel.__init__(self, config) self.num_labels = config.num_labels self.transformer = DistributedFalconModel(config) self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() ================================================ FILE: src/petals/models/llama/__init__.py ================================================ from petals.models.llama.block import WrappedLlamaBlock from petals.models.llama.config import DistributedLlamaConfig from petals.models.llama.model import ( DistributedLlamaForCausalLM, DistributedLlamaForSequenceClassification, DistributedLlamaModel, ) from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration from petals.utils.auto_config import register_model_classes register_model_classes( config=DistributedLlamaConfig, model=DistributedLlamaModel, model_for_causal_lm=DistributedLlamaForCausalLM, model_for_speculative=DistributedLlamaForSpeculativeGeneration, model_for_sequence_classification=DistributedLlamaForSequenceClassification, ) ================================================ FILE: src/petals/models/llama/block.py ================================================ """ LLaMA intermediate layer Based on 
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
See commit history for authorship.
"""
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaConfig,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaRMSNorm,
    repeat_kv,
    rotate_half,
)

from petals.utils.cuda_graphs import make_inference_graphed_callable


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to ``q`` and ``k`` (rotate-half form); return ``(q_embed, k_embed)``.

    ``cos`` and ``sin`` are multiplied elementwise, so they must broadcast against ``q`` and ``k``.
    """
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class OptimizedLlamaAttention(LlamaAttention):
    """LlamaAttention with an explicit matmul/softmax attention and a CUDA-graph path for rotary embeddings.

    The KV cache is passed in and returned as a plain ``(key, value)`` tuple.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Built lazily by _optimized_apply_rotary on the first eligible call.
        self._rotary_graph = None

    def _optimized_apply_rotary(self, query_states, key_states, cos, sin):
        """Run ``apply_rotary_pos_emb`` through a graphed callable built from the first call's tensors.

        NOTE(review): the graph is built once from these sample args; presumably later calls must match their
        shapes/dtypes — see ``make_inference_graphed_callable``.
        """
        if self._rotary_graph is None:
            self._rotary_graph = make_inference_graphed_callable(
                apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin)
            )
        return self._rotary_graph(query_states, key_states, cos, sin)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Self-attention over ``hidden_states`` of shape ``[bsz, q_len, hidden_size]``.

        ``past_key_value`` is a ``(key, value)`` pair of shape ``[bsz, num_key_value_heads, past_len, head_dim]``.
        If ``position_ids`` is None, positions continue from ``past_len``. ``attention_mask`` (if given) is added to
        the attention scores. ``cache_position`` is accepted but not used.

        Returns ``(attn_output, None, past_key_value)``; the middle element is always None because
        ``output_attentions`` is not supported. ``past_key_value`` is the extended cache if ``use_cache``, else None.
        """
        assert not output_attentions
        if position_ids is None:
            # Positions continue from the number of cached tokens (dim 2 of the cached keys).
            past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
            position_ids = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            ).unsqueeze(0)

        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            # Compute each projection in `pretraining_tp` weight slices and concatenate the partial outputs.
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # -> [bsz, heads, q_len, head_dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # `rotary_emb` comes from the parent LlamaAttention; unsqueeze adds a heads dim for broadcasting.
        cos, sin = self.rotary_emb(value_states, position_ids)
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

        if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
            query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
        else:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
    """LlamaDecoderLayer built on OptimizedLlamaAttention, with CUDA-graph paths for both RMSNorms on one-token steps."""

    def __init__(self, config: LlamaConfig):
        # Bypasses LlamaDecoderLayer.__init__ and builds the submodules directly.
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size
        self.self_attn = OptimizedLlamaAttention(config=config, layer_idx=0)
        # layer_idx only matters for KV caching, and we re-implement it in Petals
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Graphed layer norms, built lazily on the first eligible call.
        self.pre_attn_graph = None
        self.post_attn_graph = None

    def _optimized_input_layernorm(self, hidden_states):
        """Apply ``input_layernorm`` through a graphed callable built from the first call's input."""
        if self.pre_attn_graph is None:
            self.pre_attn_graph = make_inference_graphed_callable(
                self.input_layernorm.forward, sample_args=(hidden_states,)
            )
        return self.pre_attn_graph(hidden_states)

    def _optimized_output_layernorm(self, hidden_states):
        """Apply ``post_attention_layernorm`` through a graphed callable built from the first call's input."""
        if self.post_attn_graph is None:
            self.post_attn_graph = make_inference_graphed_callable(
                self.post_attention_layernorm.forward, sample_args=(hidden_states,)
            )
        return self.post_attn_graph(hidden_states)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run one pre-norm decoder layer: norm -> self-attention -> residual, then norm -> MLP -> residual.

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up
                decoding (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states

        Returns:
            `(hidden_states,)`, followed by the attention weights if `output_attentions` and the present key/value
            pair if `use_cache`. Note that OptimizedLlamaAttention asserts `output_attentions` is False.
        """

        residual = hidden_states

        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
            hidden_states = self._optimized_input_layernorm(hidden_states)
        else:
            hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        # NOTE(review): any extra **kwargs are forwarded to OptimizedLlamaAttention.forward, which accepts none.
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states

        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
            hidden_states = self._optimized_output_layernorm(hidden_states)
        else:
            hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
    """Llama block adapted to the Petals server calling convention and cache layout.

    ``layer_past`` keys are ``[batch_size * num_kv_heads, head_dim, seq_len]`` and values are
    ``[batch_size * num_kv_heads, seq_len, head_dim]``. They are converted to and from Llama's
    ``[batch_size, num_kv_heads, seq_len, head_dim]`` around the parent forward.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the block on ``hidden_states`` of shape ``[batch_size, seq_length, hidden_size]``.

        ``position_ids`` must be None; the attention layer derives positions from the cache length. The 2-D mask
        (all ones if not given) is expanded with transformers' ``_prepare_4d_causal_attention_mask``.
        """
        batch_size, seq_length, _ = hidden_states.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        past_key_value = layer_past
        if past_key_value is not None:
            # Cached keys are [.., head_dim, seq_len], so dim 2 is the cached length.
            past_key_values_length = past_key_value[0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

        assert position_ids is None

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
            )
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask=attention_mask,
            input_shape=(batch_size, seq_length),
            inputs_embeds=hidden_states,
            past_key_values_length=past_key_values_length,
        )

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            **kwargs,
        )

        if use_cache:
            present_key_value = outputs[-1]
            present_key_value = self._reorder_cache_from_llama_to_bloom(
                present_key_value, batch_size, seq_length_with_past
            )
            outputs = outputs[:-1] + (present_key_value,)

        return outputs

    def _reorder_cache_from_bloom_to_llama(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        """Convert Petals' flat cache (keys transposed) to ``[batch_size, num_kv_heads, seq_length, head_dim]``."""
        key_states, value_states = key_value
        key_states = key_states.permute(0, 2, 1)
        key_states = key_states.view(
            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        value_states = value_states.view(*key_states.shape)
        return (key_states, value_states)

    def _reorder_cache_from_llama_to_bloom(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        """Inverse of ``_reorder_cache_from_bloom_to_llama``: flatten batch and heads, then transpose keys."""
        key_states, value_states = key_value
        value_states = value_states.view(
            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        key_states = key_states.view(*value_states.shape)
        key_states = key_states.permute(0, 2, 1)
        return (key_states, value_states)


================================================ FILE: src/petals/models/llama/config.py ================================================ import os from typing import Optional, Union from hivemind import get_logger from transformers.models.llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaAttention from petals.client.config import ClientConfig from petals.client.lm_head import LMHeadConfig from petals.client.ptune import PTuneConfig from petals.models.llama.block import WrappedLlamaBlock logger = get_logger(__name__) class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfig): block_class = WrappedLlamaBlock attn_class = LlamaAttention block_prefix = "model.layers" @property def num_key_value_groups(self): return self.num_attention_heads // self.num_key_value_heads @classmethod def from_pretrained( cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs ): logger.info( "Make sure you follow the Llama terms of use: " "https://llama.meta.com/llama3/license, https://llama.meta.com/llama2/license" ) loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) if loading_from_repo and dht_prefix is None: dht_prefix = str(model_name_or_path) dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts dht_prefix = dht_prefix.replace(".", "-") if not dht_prefix.endswith("-hf"): dht_prefix += "-hf"
logger.info(f"Using DHT prefix: {dht_prefix}") result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) config = result[0] if isinstance(result, tuple) else result config.pretraining_tp = 1 # This may give less accurate results but it doesn't matter if we use quantization config.use_cache = True # use_cache=False leads to identical results but is slower and not supported by Petals return result ================================================ FILE: src/petals/models/llama/model.py ================================================ from typing import Optional import hivemind import torch import torch.nn as nn from hivemind.utils.logging import get_logger from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel from petals.client.from_pretrained import FromPretrainedMixin from petals.client.lm_head import LMHead from petals.client.ptune import PTuneMixin from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues from petals.client.remote_sequential import RemoteSequential from petals.models.llama.config import DistributedLlamaConfig logger = get_logger(__name__) class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): """LlamaModel, but all transformer layers are hosted by the swarm""" _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = [r"^model\.layers\."] config_class = DistributedLlamaConfig def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hivemind.DHT] = None): n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization super().__init__(config) assert len(self.layers) == 0 config.num_hidden_layers = n_layer self.layers = RemoteSequential(config, dht=dht) self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm 
self.init_prompts(config) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[RemotePastKeyValues] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> BaseModelOutputWithPast: if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") # The causal mask will be added on the server-side assert ( attention_mask is None or (attention_mask == 1).all() ), f"Custom attention masks are not supported, {attention_mask=}" if cache_position is not None: assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item() assert ( position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all() ), f"Non-consecutive position_ids are not supported, {position_ids=}" assert use_cache is None or use_cache, f"{use_cache=} is not supported" assert not output_attentions, f"{output_attentions=} is not supported" assert not output_hidden_states, f"{output_hidden_states=} is not supported" assert return_dict is None or return_dict, f"{return_dict=} is not supported" if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0 if use_prompts: batch_size = inputs_embeds.shape[0] prompts, intermediate_prompts = self.get_prompt(batch_size) inputs_embeds 
= torch.cat([prompts, inputs_embeds], dim=1) else: prompts = intermediate_prompts = None hidden_states = inputs_embeds output_shape = input_shape + (hidden_states.size(-1),) hidden_states = self.layers( hidden_states, prompts=intermediate_prompts, hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None, ) if past_key_values is None: past_key_values = RemotePastKeyValues() past_key_values.update_seen(hidden_states.size(1)) # Remove prefix if use_prompts: hidden_states = hidden_states[:, self.pre_seq_len :] # Add last hidden state hidden_states = self.norm(hidden_states) hidden_states = hidden_states.view(output_shape) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=None, attentions=None, ) @property def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin return self.embed_tokens @property def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin return nn.Identity() @property def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin return self.layers @property def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin return self.norm class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM): _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected config_class = DistributedLlamaConfig def __init__(self, config: DistributedLlamaConfig): LlamaPreTrainedModel.__init__(self, config) self.model = DistributedLlamaModel(config) self.pretraining_tp = config.pretraining_tp self.vocab_size = config.vocab_size self.lm_head = LMHead(config) # Initialize weights and apply final processing self.post_init() def get_output_embeddings(self): return self.lm_head @property def transformer(self) -> DistributedLlamaModel: # For 
compatibility with RemoteGenerationMixin return self.model class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification): _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected config_class = DistributedLlamaConfig def __init__(self, config): LlamaPreTrainedModel.__init__(self, config) self.num_labels = config.num_labels self.model = DistributedLlamaModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() @property def transformer(self) -> DistributedLlamaModel: # For compatibility with RemoteGenerationMixin return self.model ================================================ FILE: src/petals/models/llama/speculative_model.py ================================================ from typing import Optional, Union import torch from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList from transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama import LlamaForCausalLM from petals.models.llama.config import DistributedLlamaConfig from petals.models.llama.model import DistributedLlamaForCausalLM class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin): def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM): DistributedLlamaForCausalLM.__init__(self, config) self.small_model = small_model def _sample( self, input_ids: torch.LongTensor, logits_processor: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, synced_gpus: bool, streamer: Optional["BaseStreamer"], logits_warper: Optional[LogitsProcessorList], speculative_inference_iteration_size: int = 10, 
**model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: assert not generation_config.do_sample, "sample is not working for speculative generation now" assert not synced_gpus, "synced_gpus is not working for speculative generation now" assert ( not generation_config.return_dict_in_generate ), "return_dict_in_generate is not working for speculative generation now" has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) # keep track of which sequences are already finished batch_size = input_ids.shape[0] unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) finished = False firsts = True while not finished: speculative_inference_iteration_size = min( speculative_inference_iteration_size, self.active_session._max_length - input_ids.shape[1] ) with torch.no_grad(): speculative_outputs = self.small_model.generate( input_ids, max_new_tokens=speculative_inference_iteration_size, do_sample=False, ) speculative_tokens = speculative_outputs[:, -speculative_inference_iteration_size:] full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1) assert input_ids.shape[1] + speculative_inference_iteration_size == full_sequence.shape[1] input_for_validation = full_sequence if not firsts: self.active_session.position = input_ids.shape[1] - 1 input_for_validation = input_for_validation[:, -speculative_inference_iteration_size - 1 :] else: firsts = False input_for_validation = input_for_validation[:, :-1] with torch.no_grad(): precise_model_outputs = self(input_for_validation) full_token_logits = precise_model_outputs.logits[:, -speculative_inference_iteration_size:, :].clone() all_valid_tokens = [] first_token = None for i in range(speculative_inference_iteration_size): token_logits = full_token_logits[:, i, :] token_scores = logits_processor( input_for_validation[:, : -speculative_inference_iteration_size + 1 + i], token_logits ) valid_token = torch.argmax(token_scores, dim=-1) if 
first_token is None: first_token = valid_token if valid_token.item() == speculative_tokens[:, i].item(): all_valid_tokens.append(valid_token.unsqueeze(-1)) else: break if not all_valid_tokens and first_token is not None: all_valid_tokens.append(first_token.unsqueeze(-1)) all_valid_tokens = torch.cat(all_valid_tokens, dim=-1) # finished sentences should have their next token be a padding token if has_eos_stopping_criteria: all_valid_tokens = all_valid_tokens * unfinished_sequences + generation_config.pad_token_id * ( 1 - unfinished_sequences ) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1) if streamer is not None: streamer.put(all_valid_tokens.cpu()) unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None) finished = unfinished_sequences.max() == 0 del precise_model_outputs if streamer is not None: streamer.end() return input_ids ================================================ FILE: src/petals/models/mixtral/__init__.py ================================================ from petals.models.mixtral.block import WrappedMixtralBlock from petals.models.mixtral.config import DistributedMixtralConfig from petals.models.mixtral.model import ( DistributedMixtralForCausalLM, DistributedMixtralForSequenceClassification, DistributedMixtralModel, ) from petals.utils.auto_config import register_model_classes register_model_classes( config=DistributedMixtralConfig, model=DistributedMixtralModel, model_for_causal_lm=DistributedMixtralForCausalLM, model_for_sequence_classification=DistributedMixtralForSequenceClassification, ) ================================================ FILE: src/petals/models/mixtral/block.py ================================================ from typing import Optional, Tuple import torch from transformers import MixtralConfig from transformers.cache_utils import DynamicCache from transformers.modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, 
_prepare_4d_causal_attention_mask_for_sdpa, ) from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer class WrappedMixtralBlock(MixtralDecoderLayer): def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__(config, layer_idx) self._attn_implementation = config._attn_implementation self.sliding_window = config.sliding_window self.layer_idx = layer_idx def forward( self, hidden_states: torch.Tensor, *args, attention_mask: Optional[torch.Tensor] = None, layer_past: Optional[Tuple[torch.Tensor]] = None, use_cache: bool = False, **kwargs ): batch_size, seq_length, _ = hidden_states.shape seq_length_with_past = seq_length past_key_values_length = 0 past_key_value = layer_past if past_key_value is not None: past_key_values_length = past_key_value[0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length) past_key_value = DynamicCache() past_key_value.key_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[0]] past_key_value.value_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[1]] past_key_value._seen_tokens = past_key_values_length if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._attn_implementation == "sdpa": # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. 
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length, ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length, sliding_window=self.sliding_window, ) position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) outputs = super().forward( hidden_states, *args, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache, **kwargs ) if use_cache: present_key_value = outputs[-1] present_key_value = present_key_value[self.layer_idx] present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past) outputs = outputs[:-1] + (present_key_value,) return outputs def _reorder_cache_from_bloom( self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int ) -> Tuple[torch.Tensor]: # TODO: Move to mixin key_states, value_states = key_value key_states = key_states.permute(0, 2, 1) key_states = key_states.view( batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim ) value_states = value_states.view(*key_states.shape) return (key_states, value_states) def _reorder_cache_to_bloom( self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int ) -> Tuple[torch.Tensor]: # TODO: Move to mixin key_states, value_states = key_value value_states = value_states.view( batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim ) key_states = key_states.view(*value_states.shape) key_states = key_states.permute(0, 2, 1) return (key_states, value_states) ================================================ FILE: src/petals/models/mixtral/config.py ================================================ 
import os from typing import Optional, Union from hivemind import get_logger from transformers.models.mixtral import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralAttention from petals.client.config import ClientConfig from petals.client.lm_head import LMHeadConfig from petals.client.ptune import PTuneConfig from petals.models.mixtral.block import WrappedMixtralBlock logger = get_logger(__name__) class DistributedMixtralConfig(MixtralConfig, ClientConfig, PTuneConfig, LMHeadConfig): block_class = WrappedMixtralBlock attn_class = MixtralAttention block_prefix = "model.layers" num_key_value_groups = 1 @classmethod def from_pretrained( cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs ): loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) if loading_from_repo and dht_prefix is None: dht_prefix = str(model_name_or_path) dht_prefix = dht_prefix.replace(".", "-") logger.info(f"Using DHT prefix: {dht_prefix}") result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) config = result[0] if isinstance(result, tuple) else result if config.pad_token_id is None: config.pad_token_id = 0 return result ================================================ FILE: src/petals/models/mixtral/model.py ================================================ from typing import Optional import torch import torch.nn as nn from hivemind import DHT from hivemind.utils.logging import get_logger from transformers.modeling_outputs import MoeModelOutputWithPast from transformers.models.mixtral import ( MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel, MixtralPreTrainedModel, ) from petals.client.from_pretrained import FromPretrainedMixin from petals.client.lm_head import LMHead from petals.client.ptune import PTuneMixin from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues from 
petals.client.remote_sequential import RemoteSequential from petals.models.mixtral.config import DistributedMixtralConfig from petals.utils.auto_config import DefaultRevisionMixin logger = get_logger(__name__) class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, MixtralModel): """MixtralModel, but all transformer layers are hosted by the swarm""" _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = [r"^model\.layers\."] config_class = DistributedMixtralConfig def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None): n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization super().__init__(config) assert len(self.layers) == 0 config.num_hidden_layers = n_layer self.layers = RemoteSequential(config, dht=dht) self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm self.init_prompts(config) def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[RemotePastKeyValues] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") # The causal mask will be added on the server-side 
assert ( attention_mask is None or (attention_mask == 1).all() ), f"Custom attention masks are not supported, {attention_mask=}" if cache_position is not None: assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item() assert ( position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all() ), f"Non-consecutive position_ids are not supported, {position_ids=}" assert head_mask is None, f"Custom head masks are not supported, {head_mask=}" assert use_cache is None or use_cache, f"{use_cache=} is not supported" assert not output_attentions, f"{output_attentions=} is not supported" assert not output_hidden_states, f"{output_hidden_states=} is not supported" assert return_dict is None or return_dict, f"{return_dict=} is not supported" assert not output_router_logits, f"{output_router_logits=} is not supported" if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0 if use_prompts: batch_size = inputs_embeds.shape[0] prompts, intermediate_prompts = self.get_prompt(batch_size) inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) else: prompts = intermediate_prompts = None hidden_states = inputs_embeds output_shape = input_shape + (hidden_states.size(-1),) if past_key_values is None: past_key_values = RemotePastKeyValues() past_key_values.update_seen(hidden_states.size(1)) hidden_states = self.layers( hidden_states, prompts=intermediate_prompts, hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None, ) # Remove prefix if use_prompts: hidden_states = hidden_states[:, self.pre_seq_len :] # Add last hidden state hidden_states = self.norm(hidden_states) hidden_states = hidden_states.view(output_shape) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, hidden_states=None, attentions=None, ) @property def word_embeddings(self) -> nn.Embedding: # 
For compatibility with RemoteGenerationMixin return self.embed_tokens @property def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests return nn.Identity() @property def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin return self.layers @property def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests return self.norm class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM): _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected config_class = DistributedMixtralConfig def __init__(self, config: DistributedMixtralConfig): MixtralPreTrainedModel.__init__(self, config) self.model = DistributedMixtralModel(config) self.lm_head = LMHead(config) # Initialize weights and apply final processing self.post_init() def get_output_embeddings(self): return self.lm_head @property def transformer(self) -> DistributedMixtralModel: # For compatibility with RemoteGenerationMixin return self.model class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification): _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected config_class = DistributedMixtralConfig def __init__(self, config: DistributedMixtralConfig): MixtralPreTrainedModel.__init__(self, config) self.num_labels = config.num_labels self.model = DistributedMixtralModel(config) self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() @property def transformer(self) -> DistributedMixtralModel: # For compatibility with RemoteGenerationMixin return self.model ================================================ 
FILE: src/petals/server/__init__.py ================================================ ================================================ FILE: src/petals/server/backend.py ================================================ from __future__ import annotations from collections import Counter from itertools import chain from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch from hivemind import BatchTensorDescriptor, TensorDescriptor from hivemind.moe.expert_uid import ExpertUID from hivemind.moe.server.module_backend import ModuleBackend from hivemind.utils import get_logger from tensor_parallel import TensorParallel from tensor_parallel.tensor_parallel import PerDeviceTensors from transformers import PretrainedConfig from petals.data_structures import InferenceMetadata from petals.server.memory_cache import MemoryCache from petals.server.task_pool import PrioritizedTaskPool from petals.utils.misc import get_size_in_bytes, is_dummy logger = get_logger(__name__) class TransformerBackend(ModuleBackend): """A wrapper for a transformer block that can process requests for forward, backward and inference""" _peft_module = None def __init__( self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, max_chunk_size_bytes: int, **kwargs, ): import petals.utils.peft as _peft_module self._peft_module = _peft_module super().__init__(*args, **kwargs) assert isinstance(self.module, TensorParallel) self.config = config self.memory_cache = memory_cache self.max_chunk_size_bytes = max_chunk_size_bytes for name, param in self.module.named_parameters(): assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does" for name, buf in self.module.named_buffers(): assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does" max_batch_size = self.forward_pool.max_batch_size device = self.module.devices[self.module.output_device_index] self.inference_pool = 
PrioritizedTaskPool( self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference" ) # note: inference_pools may be merged later, see merge_inference_pools_inplace self.forward_pool = PrioritizedTaskPool( self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward" ) self.backward_pool = PrioritizedTaskPool( self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward" ) self.dtype = backend_dtype self.dtype_bytes = get_size_in_bytes(self.dtype) self.shard_num_heads = [] for shard in self.module.module_shards: for submodule in shard.modules(): if isinstance(submodule, config.attn_class): self.shard_num_heads.append(submodule.num_heads) assert len(self.shard_num_heads) == len(self.module.devices) assert sum(self.shard_num_heads) == config.num_attention_heads self.inference_schema = ( ( *self.args_schema, BatchTensorDescriptor((), dtype=self.dtype), BatchTensorDescriptor((), dtype=torch.int64), ), self.kwargs_schema, ) self.cache_bytes_per_token: Dict[torch.device, int] = Counter() for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1): self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype) def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]: """Create tensor descriptors for attention cache tensors used during inference_step""" head_dim = self.config.hidden_size // self.config.num_attention_heads cache_tensors = [] for device, num_heads in zip(self.module.devices, self.shard_num_heads): num_heads //= self.config.num_key_value_groups if hasattr(self.config, "num_key_value_heads"): num_heads = self.config.num_key_value_heads keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device) values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device) cache_tensors.extend((keys, values)) 
# NOTE(review): the next statement is the tail of the preceding TransformerBackend method, whose start is outside this view.
        return cache_tensors

    def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
        """Run the parent forward pass with the adapter named by the last positional argument active.

        :param inputs: tensors for the parent class's forward, followed by the active adapter name as the last element
        :returns: whatever the parent class's forward returns
        """
        *inputs, active_adapter = inputs
        with self._peft_module.using_adapter(active_adapter):
            return super().forward(*inputs)

    def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
        """Run the parent backward pass with the adapter named by the last positional argument active.

        :param inputs: tensors for the parent class's backward, followed by the active adapter name as the last element
        :returns: whatever the parent class's backward returns
        """
        *inputs, active_adapter = inputs
        with self._peft_module.using_adapter(active_adapter):
            return super().backward(*inputs)

    @torch.inference_mode()
    def inference_step(
        self,
        hidden_states: torch.Tensor,
        hypo_ids: torch.LongTensor,
        inference_info: InferenceMetadata,
    ) -> Tuple[torch.Tensor, ...]:
        """Run one incremental inference step through this block, reading and updating its attention cache.

        :param hidden_states: inputs of shape [batch_size, seq_len, hid_size]
        :param hypo_ids: unless it is a dummy tensor, indices used to reorder every cache tensor before this step
        :param inference_info: cache handles, current prefix length and active adapter for this block
        :returns: a 1-tuple with the output hidden states
        """
        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
        seq_len = hidden_states.shape[1]

        with self.memory_cache.use_cache(
            *inference_info.cache_handles
        ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
            self._reorder_cache_inplace(cache_tensors, hypo_ids)

            # We chunk the inputs so that peak memory for long sequences fits into `autograd_memory`
            # reserved in `Server._choose_num_blocks()`. This saves us from OOMs if `max_chunk_size_bytes`
            # is at least 4-6x less than `autograd_memory`.
            max_chunk_length = self._estimate_max_chunk_length(hidden_states, inference_info)
            # A preallocated output buffer is only needed when the input is split into several chunks
            output_hidden_states = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None
            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
            # NOTE(review): if seq_len == 0 this loop never runs and `new_kvs` below is unbound;
            # iterate_rpc_inference only submits non-empty hidden states, so this is not hit from there.
            for offset in range(0, seq_len, max_chunk_length):
                hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]
                output_hidden_states_chunk, new_kvs = self.module.forward(
                    hidden_states_chunk, layer_past=layer_past, use_cache=True
                )
                if seq_len > max_chunk_length:
                    output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk
                else:
                    output_hidden_states = output_hidden_states_chunk  # saves one memcopy
                # Feed the key/value cache returned for this chunk into the next chunk
                layer_past = new_kvs

            # Write the entries for positions [prefix_length, new_length) back into the cache
            self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
            return (output_hidden_states,)

    def _estimate_max_chunk_length(self, hidden_states: torch.Tensor, inference_info: InferenceMetadata) -> int:
        """Return how many tokens to process per chunk so the estimated attention memory fits `max_chunk_size_bytes`.

        The per-token estimate is max(shard_num_heads) * batch_size * dtype_bytes * (prefix_length + seq_length);
        the result is never less than 1.
        """
        # We assume that attention logit matrices are the main thing that consumes memory, given that
        # the model uses multi-query attention
        batch_size, seq_length, hidden_size = hidden_states.shape
        worst_case_length = inference_info.prefix_length + seq_length
        attn_bytes_per_token = max(self.shard_num_heads) * batch_size * self.dtype_bytes * worst_case_length
        return max(1, self.max_chunk_size_bytes // attn_bytes_per_token)

    def _reorder_cache_inplace(self, cache_tensors: Sequence[torch.Tensor], hypo_ids: torch.Tensor):
        """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
        # A dummy hypo_ids tensor means "keep the current order"
        if not is_dummy(hypo_ids):
            for cache_tensor in cache_tensors:
                cache_tensor[...] = cache_tensor[hypo_ids.to(cache_tensor.device)]  # in-place reorder cache by hypo ids

    def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
        """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
        # Even-indexed cache tensors hold keys, odd-indexed ones hold values
        key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])
        for i in range(len(key_cache)):
            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]
            # shape: [batch * num_kv_heads, head_dim, kv_length]
            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]
            # shape: [batch * num_kv_heads, kv_length, head_dim]
        # Re-interleave as (key, value, key, value, ...)
        layer_past = tuple(chain(*zip(key_cache, value_cache)))
        # With more than one module shard, the tensors are wrapped in PerDeviceTensors
        return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past

    def _update_cache_inplace(
        self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
    ):
        """Writes new key/value tensors back into cache, works in-place"""
        # Only positions [prefix_length, new_length) are copied; earlier positions were read from this cache
        _batch_size_times_num_kv_heads, head_dim, new_length = new_kvs[0].shape
        for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
            new_key = new_key.view(*cache_key.shape[:3], new_length)
            cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
        for cache_value, new_value in zip(cache_tensors[1::2], new_kvs[1::2]):
            new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim)
            cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]

    def get_pools(self) -> Sequence[PrioritizedTaskPool]:
        """Return this backend's (forward, backward, inference) task pools."""
        return self.forward_pool, self.backward_pool, self.inference_pool

    def get_info(self) -> Dict[str, Any]:
        """Get module parameters and stats.
        Used by RemoteExpert to check shapes and for DMoE orchestration."""
        return dict(super().get_info(), inference_schema=self.inference_schema)

    def shutdown(self):
        """Drop references to the task pools and replace every parameter's data with an empty tensor."""
        # Break the cyclic references, otherwise TransformerBackend may be not garbage-collected
        self.forward_pool = self.backward_pool = self.inference_pool = None

        # Explicitly free the GPU memory. This is not necessary at the time this code is written,
        # but may help to avoid future issues when the module is not garbage-collected for some reasons
        dummy = torch.tensor([])
        for p in self.module.parameters():
            p.data = dummy


def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
    """Replace each backend's rpc_inference pool with one combined pool that runs multiple blocks in one call.

    The merged pool takes `max_batch_size` and `device` from the first backend's inference pool.
    Every existing inference pool is asserted to be not alive, so this must run before the pools are started.
    """
    assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
    first_pool = next(iter(backends.values())).inference_pool
    merged_pool = PrioritizedTaskPool(
        _MergedInferenceStep(backends),
        max_batch_size=first_pool.max_batch_size,
        device=first_pool.device,
        name=f"merged_inference",
    )
    for backend in backends.values():
        assert not backend.inference_pool.is_alive()
        backend.inference_pool = merged_pool


class _MergedInferenceStep:
    """Callable used by the merged inference pool: runs inference_step through several backends in sequence."""

    def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):
        self.backends = backends

    @torch.inference_mode()
    def __call__(
        self,
        hidden_states: torch.Tensor,
        hypo_ids: torch.LongTensor,
        inference_infos: Sequence[InferenceMetadata],
        *optional_prompts: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, ...]:
        """Apply each block listed in inference_infos in order, adding that block's prompt (if any) first.

        :param inference_infos: one InferenceMetadata per block; its `uid` selects the backend
        :param optional_prompts: one prompt or None per block, added to the first tokens of hidden_states
        :returns: a 1-tuple with the hidden states after the last block
        """
        assert len(inference_infos) == len(
            optional_prompts
        ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
        for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
            if optional_prompt is not None:
                # NOTE: this modifies hidden_states in-place
                hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
            (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
        return (hidden_states,)
================================================ FILE: src/petals/server/block_functions.py ================================================ """ This module implements server-side computations on served blocks: forward, backward and inference; used by handler """ from __future__ import annotations from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union import torch from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor from hivemind.moe.expert_uid import ExpertUID from hivemind.proto import runtime_pb2 from hivemind.utils.logging import get_logger from hivemind.utils.nested import nested_flatten from petals.data_structures import Handle, InferenceMetadata from petals.server.backend import TransformerBackend from petals.server.task_pool import PrioritizedTaskPool from petals.server.task_prioritizer import TaskPrioritizerBase from petals.utils.convert_block import QuantType from petals.utils.misc import DUMMY, is_dummy from petals.utils.packaging import unpack_args_kwargs # We prioritize short inference requests and make them use a *merged* inference pool, # so they are processed without interruptions and extra overheads # TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward MAX_SHORT_INFERENCE_TOKENS = 128 MAX_NF4_SHORT_INFERENCE_TOKENS = 1 logger = get_logger(__name__) async def run_rpc_forward( *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, args_structure: Any = None, ) -> torch.Tensor: """ Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy) :param requested_backends: a sequence of 
transformer blocks in the same order as they appear in forward pass :returns: hidden states after the last layer [batch_size, seq_length, hid_size] """ if args_structure is not None: # TODO: kwargs currently is unused, it can be used later for peft-like adaptation flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure) hidden_states, prompts, *_ = flat_tensors dtype = requested_backends[0].dtype # check parse input tensors and cast dtypes hidden_states = hidden_states.to(dtype) assert hidden_states.ndim == 3 if prompts is None or is_dummy(prompts): prompts = [DUMMY] * len(requested_backends) else: prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] # Run a chain of requested backends for backend, prompt in zip(requested_backends, prompts): if not is_dummy(prompt): hidden_states[:, : prompt.shape[1]] += prompt assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" priority = prioritizer.prioritize( hidden_states, points=points / len(requested_backends), backend=backend, type="forward" ) (hidden_states,) = await backend.forward_pool.submit_task( hidden_states, active_adapter, priority=priority, ) assert isinstance(hidden_states, torch.Tensor) assert ( hidden_states.ndim == 3 ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states" return hidden_states async def run_rpc_backward( *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, args_structure: Any = None, ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: if args_structure is not None: # TODO: kwargs currently is unused, it can be used later for peft-like adaptation flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure) inputs, grad_outputs, prompts, *_ = flat_tensors # Cast inputs & grad outputs to backend dtype inputs = inputs.to(requested_backends[0].dtype) 
grad_outputs = grad_outputs.to(requested_backends[-1].dtype) if prompts is None or is_dummy(prompts): prompts = [DUMMY] * len(requested_backends) else: prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] # Run a forward chain to collect intermediate inputs # Note that we do not forward for the last module since we do not need its output inter_inputs = [] for backend, prompt in zip(requested_backends[:-1], prompts[:-1]): assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" if not is_dummy(prompt): inputs[:, : prompt.shape[1]] += prompt inter_inputs.append(inputs) assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" priority = prioritizer.prioritize( inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward" ) (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority) assert isinstance(inputs, torch.Tensor) if not is_dummy(prompts[-1]): inputs[:, : prompts[-1].shape[1]] += prompts[-1] inter_inputs.append(inputs) assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward" grad_prompts_reversed = [] # Run a chain of requested backends for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))): assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" priority = prioritizer.prioritize( inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" ) (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority) assert isinstance(grad_outputs, torch.Tensor) if not is_dummy(prompt): grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY return 
[grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape async def iterate_rpc_inference( requested_uids: Sequence[ExpertUID], requested_backends: Sequence[TransformerBackend], active_adapter: Optional[str], input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]], cache_handles: Sequence[Sequence[Handle]], *, max_length: int, prioritizer: TaskPrioritizerBase, points: int, quant_type: QuantType, args_structure: Any = None, ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool, Dict]]: assert len(cache_handles) == len(requested_backends) prefix_length = 0 point_per_piece = points / max_length if max_length > 0 else 0.0 async for request, step_metadata in input_iterator: if "start_from_position" in step_metadata: start_from_position = step_metadata["start_from_position"] assert ( prefix_length >= start_from_position, ), f"prefix_length={prefix_length}, start_from_position={start_from_position}" prefix_length = start_from_position flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors) if args_structure is not None: # TODO: kwargs currently is unused, it can be used later for peft-like adaptation flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure) hidden_states, prompts, hypo_ids, *_ = flat_tensors batch_size, length_increment, _ = hidden_states.shape # Cast inputs to backend dtype hidden_states = hidden_states.to(requested_backends[0].dtype) assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}" # parse deep prompts (optional argument) has_prompts = prompts is not None and not is_dummy(prompts) if not has_prompts: prompts = [None] * len(requested_backends) else: prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts] if not (len(requested_backends) == len(prompts)): raise ValueError(f"Received {len(prompts)} prompts for 
{len(requested_backends)} backends") if prefix_length + length_increment > max_length: raise ValueError( f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}" f" exceeds pre-allocated maximum {max_length}" ) merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS can_merge_pools = batch_size * length_increment <= merge_max_tokens priority = prioritizer.prioritize( hidden_states, hypo_ids, points=point_per_piece, requested_uids=requested_uids, type="inference", ) # A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g. # when user wants to pre-allocate cache or check that server *can* allocate that cache. if hidden_states.numel() > 0: assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor" if can_merge_pools: inference_infos = tuple( InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter) for uid, handles in zip(requested_uids, cache_handles) ) (hidden_states,) = await requested_backends[0].inference_pool.submit_task( hidden_states, hypo_ids, inference_infos, *prompts, priority=priority ) else: for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts): inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),) (hidden_states,) = await backend.inference_pool.submit_task( hidden_states, hypo_ids, inference_infos, prompt, priority=priority ) # serialize and send last layer outputs output_tensors = [ serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)) ] can_push = not has_prompts yield output_tensors, can_push, step_metadata # prepare for next step prefix_length += length_increment ================================================ FILE: src/petals/server/block_selection.py ================================================ from 
typing import Dict, List

import numpy as np
from hivemind import PeerID, get_logger

from petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState
from petals.utils.dht import compute_spans

logger = get_logger(__name__)


def compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_blocks: int) -> np.ndarray:
    """Return an array of length total_blocks holding, per block, the summed throughput of all spans covering it."""
    # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
    # If the order were not defined, we would get slightly different values due to floating point errors,
    # which may cause excess block replacements.
    throughputs = np.zeros(total_blocks)
    for span in sorted(spans.values(), key=lambda span: span.peer_id):
        throughputs[span.start : span.end] += span.throughput
    return throughputs


def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
    """Return the start index of the num_blocks-long window that is the least served.

    Windows are compared by their throughputs sorted in ascending order (lexicographically), so the window
    containing the weakest block wins; ties fall back to the next-weakest block, then to the smaller start index.
    NOTE(review): if num_blocks > len(throughputs) there are no windows and min() raises ValueError.
    """
    options = ((sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1))
    return min(options)[-1]


def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]:
    """Return the indices of the num_blocks consecutive blocks that currently have the lowest throughput.

    Spans are computed with min_state=ServerState.JOINING (presumably so that joining servers are counted too).
    """
    spans = compute_spans(module_infos, min_state=ServerState.JOINING)
    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))

    start = _choose_best_start(throughputs, num_blocks)
    return list(range(start, start + num_blocks))


def _move_span(span: RemoteSpanInfo, new_start: int):
    """Shift `span` in-place so that it starts at new_start, keeping its length."""
    span.start, span.end = new_start, new_start + span.length


def should_choose_other_blocks(
    local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float
) -> bool:
    """Decide whether this server should move to other blocks to raise the swarm's minimum per-block throughput.

    The function simulates moving this server to its best start. It then repeatedly re-places every server
    (in an order shuffled with the global NumPy RNG) until no span moves. It returns True only if
    initial_min / new_min < balance_quality - eps. A balance_quality above 1.0 always returns True.

    :param local_peer_id: this server's peer id; it must be present among the spans computed from module_infos
    :param module_infos: per-block info used to compute the current spans
    :param balance_quality: threshold for the throughput ratio described above
    """
    if balance_quality > 1.0:
        return True  # Forces rebalancing on each check (may be used for debugging purposes)

    spans = compute_spans(module_infos, min_state=ServerState.JOINING)
    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
    initial_throughput = throughputs.min()
    eps = 1e-3

    assert local_peer_id in spans, "Span served by this server is not present in the DHT"
    local_span = spans[local_peer_id]
    throughputs[local_span.start : local_span.end] -= local_span.throughput * (1 + eps)
    # Without (1 + eps) here, we would sometimes subtract a value slightly less than local_span.throughput
    # due to the floating point error, which would cause excess block replacements.
    # Also, subtracting local_span.throughput * (1 + eps) makes _choose_best_start() prefer
    # the previous server position in case of other things being almost equal.

    if initial_throughput > eps and throughputs.min() <= 0:
        return False  # Switching blocks would make the swarm disjoint

    new_start = _choose_best_start(throughputs, local_span.length)
    if local_span.start == new_start:
        return False  # This server is on its best place already

    # Undo the extra eps before moving, then add this server's throughput at its new position
    throughputs[local_span.start : local_span.end] += local_span.throughput * eps
    # NOTE(review): _move_span mutates the span objects returned by compute_spans; presumably these are fresh
    # objects that do not alias module_infos - confirm.
    _move_span(local_span, new_start)
    throughputs[local_span.start : local_span.end] += local_span.throughput

    # Let every server (including this one) greedily re-place itself until nobody moves
    moved = True
    while moved:
        servers = list(spans.keys())
        np.random.shuffle(servers)

        moved = False
        for peer_id in servers:
            span = spans[peer_id]
            throughputs[span.start : span.end] -= span.throughput * (1 + eps)

            new_start = _choose_best_start(throughputs, span.length)

            throughputs[span.start : span.end] += span.throughput * eps
            if span.start != new_start:
                _move_span(span, new_start)
                moved = True
            throughputs[span.start : span.end] += span.throughput

    new_throughput = throughputs.min()
    if new_throughput < initial_throughput or new_throughput < eps:
        return False

    actual_quality = initial_throughput / new_throughput
    logger.info(f"Swarm balance quality: {actual_quality * 100:.1f}%")

    return actual_quality < balance_quality - eps
================================================
FILE: src/petals/server/block_utils.py
================================================
from typing import Optional, Union

import torch
from accelerate import init_empty_weights
from transformers import PretrainedConfig, PreTrainedModel

from petals.models.mixtral.block import WrappedMixtralBlock
from petals.utils.convert_block import QuantType
from petals.utils.misc
import get_size_in_bytes def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype: """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise.""" if dtype not in ("auto", None): return dtype if config.torch_dtype not in ("auto", None, torch.float32): # If config specifies float32, we override it to the default dtype below return config.torch_dtype return torch.bfloat16 def get_block_size( config: PretrainedConfig, location: str, *, dtype: Optional[Union[str, torch.dtype]] = None, quant_type: QuantType = QuantType.NONE, eps: float = 0.01, # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc. ) -> int: if location == "memory": assert ( dtype is not None and quant_type is not None ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations' with init_empty_weights(include_buffers=False): block = get_model_block(config) n_params = sum(param.numel() for param in block.parameters()) if location == "memory": if quant_type == QuantType.NONE: dtype = resolve_block_dtype(config, dtype) bytes_per_value = get_size_in_bytes(dtype) elif quant_type == QuantType.INT8: bytes_per_value = 1 elif quant_type == QuantType.NF4: bytes_per_value = 4.25 / 8 # Bitness of NF4 with this config (measured empirically) else: raise ValueError(f"Unsupported quant_type={quant_type}") elif location == "disk": dtype = resolve_block_dtype(config, "auto") bytes_per_value = get_size_in_bytes(dtype) return round(n_params * bytes_per_value * (1 + eps)) def get_model_block(config, layer_idx: int = 0): """ The function to create a model block based on the block class kwargs argument **only** is necessary for specific classes, like Mixtral. They will not be passed to other block constructors. 
""" if config.block_class == WrappedMixtralBlock: config = PreTrainedModel._autoset_attn_implementation(config) return config.block_class(config, layer_idx) return config.block_class(config) ================================================ FILE: src/petals/server/from_pretrained.py ================================================ """ Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code. If necessary, one can rewrite this to implement a different behavior, such as: - loading files from a local data source (e.g. S3) - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to ) - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html ) """ import json import time from contextlib import suppress from typing import Dict, Optional, Union import safetensors import torch import torch.nn as nn from accelerate import init_empty_weights from accelerate.utils import set_module_tensor_to_device from hivemind.utils.logging import get_logger from huggingface_hub import get_hf_file_metadata, hf_hub_url from huggingface_hub.utils import EntryNotFoundError from transformers import PretrainedConfig, PreTrainedModel from transformers.utils import get_file_from_repo from petals.constants import DTYPE_MAP from petals.models.mixtral import WrappedMixtralBlock from petals.server.block_utils import get_model_block, resolve_block_dtype from petals.utils.auto_config import AutoDistributedConfig from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for from petals.utils.hf_auth import always_needs_auth logger = get_logger(__name__) def load_pretrained_block( model_name: str, block_index: int, *, config: Optional[PretrainedConfig] = None, torch_dtype: Union[torch.dtype, str] = "auto", revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: Optional[str] = 
None, max_disk_space: Optional[int] = None, ) -> nn.Module: if config is None: config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=token) if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" torch_dtype = resolve_block_dtype(config, torch_dtype) with init_empty_weights(): block = get_model_block(config, layer_idx=block_index) block_prefix = f"{config.block_prefix}.{block_index}." state_dict = _load_state_dict_from_repo( model_name, block_prefix, revision=revision, token=token, cache_dir=cache_dir, max_disk_space=max_disk_space, ) for param_name, _ in block.named_parameters(): assert param_name in state_dict, f"{param_name} not in state dict" param = state_dict[param_name] if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): param = param.to(torch_dtype) set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype) logger.info(f"Loaded {model_name} block {block_index}") return block StateDict = Dict[str, torch.Tensor] def _load_state_dict_from_repo( model_name: str, block_prefix: str, *, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str, max_disk_space: Optional[int] = None, ) -> StateDict: if always_needs_auth(model_name) and token is None: token = True index_file = _find_index_file(model_name, revision=revision, token=token, cache_dir=cache_dir) if index_file.endswith(".index.json"): # Sharded model path = get_file_from_repo(model_name, filename=index_file, use_auth_token=token, cache_dir=cache_dir) if path is None: # _find_index_file() told that a file exists but we can't get it (e.g., it just disappeared) raise ValueError(f"Failed to get file {index_file}") with open(path) as f: index = json.load(f) filenames = { filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix) } if not filenames: raise 
RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}") else: # Non-sharded model filenames = {index_file} logger.debug(f"Loading {block_prefix}* from {filenames}") state_dict = {} for filename in filenames: shard_state_dict = _load_state_dict_from_repo_file( model_name, filename, block_prefix=block_prefix, revision=revision, token=token, cache_dir=cache_dir, max_disk_space=max_disk_space, ) shard_state_dict = { param_name[len(block_prefix) :]: param for param_name, param in shard_state_dict.items() if param_name.startswith(block_prefix) } # Remove unused parameters from memory state_dict.update(shard_state_dict) return state_dict INDEX_FILES = ["model.safetensors.index.json", "model.safetensors", "pytorch_model.bin.index.json", "pytorch_model.bin"] def _find_index_file( model_name: str, *, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str ) -> str: # If we have cached weights (e.g., Pickle from older Petals versions), reuse them for filename in INDEX_FILES: path = get_file_from_repo( model_name, filename, revision=revision, use_auth_token=token, cache_dir=cache_dir, local_files_only=True, ) if path is not None: return filename # If we don't, prefer Safetensors when possible # (we don't download files here since we can't account for max_disk_space in case of large files) for filename in INDEX_FILES: with suppress(EntryNotFoundError): get_hf_file_metadata(hf_hub_url(model_name, filename, revision=revision), token=token) return filename raise ValueError( f"Repo {model_name} does not contain weights in a supported format: files {INDEX_FILES} do not exist" ) def _load_state_dict_from_repo_file( model_name: str, filename: str, *, block_prefix: Optional[str] = None, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str, max_disk_space: Optional[int] = None, delay: float = 30, ) -> StateDict: # First, try to find the weights locally try: with 
allow_cache_reads(cache_dir): path = get_file_from_repo( model_name, filename, revision=revision, use_auth_token=token, cache_dir=cache_dir, local_files_only=True, ) if path is not None: return _load_state_dict_from_local_file(path, block_prefix=block_prefix) except Exception: logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True) # If not found, ensure that we have enough disk space to download them (maybe remove something) while True: try: with allow_cache_writes(cache_dir): url = hf_hub_url(model_name, filename, revision=revision) file_size = get_hf_file_metadata(url, token=token).size if file_size is not None: free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space) else: logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}") path = get_file_from_repo( model_name, filename, revision=revision, use_auth_token=token, cache_dir=cache_dir, local_files_only=False, ) if path is None: raise RuntimeError(f"File {filename} does not exist in repo {model_name}") return _load_state_dict_from_local_file(path, block_prefix=block_prefix) except Exception as e: logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True) time.sleep(delay) def _load_state_dict_from_local_file(path: str, *, block_prefix: Optional[str] = None) -> StateDict: if path.endswith(".bin"): return torch.load(path, map_location="cpu") if path.endswith(".safetensors"): with safetensors.safe_open(path, framework="pt", device="cpu") as f: return {key: f.get_tensor(key) for key in f.keys() if block_prefix is None or key.startswith(block_prefix)} raise ValueError(f"Unknown weight format: {path}") ================================================ FILE: src/petals/server/handler.py ================================================ from __future__ import annotations import asyncio import contextlib import multiprocessing as mp import sys from enum import Enum from itertools 
import chain from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple import torch from async_timeout import timeout from hivemind import ( DHT, MSGPackSerializer, P2PContext, PeerID, deserialize_tensor_stream, deserialize_torch_tensor, nested_flatten, nested_pack, serialize_torch_tensor, ) from hivemind.moe.server.connection_handler import ConnectionHandler from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE from hivemind.proto import runtime_pb2 from hivemind.utils.asyncio import amap_in_executor, anext from hivemind.utils.logging import get_logger from hivemind.utils.streaming import split_for_streaming import petals from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, Handle, ModuleUID from petals.server.backend import TransformerBackend from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase from petals.utils.convert_block import QuantType logger = get_logger(__name__) # Fix pickling protobufs, see https://stackoverflow.com/a/74873028 sys.modules["runtime_pb2"] = runtime_pb2 CACHE_TOKENS_AVAILABLE = "cache_tokens_available" class Event(Enum): NEW_SESSION = 0 END_SESSION = 1 PUSH = 2 SHUTDOWN = 3 class TransformerConnectionHandler(ConnectionHandler): """Handles three request types: forward, backward and forward-incremental (inference)""" module_backends: Dict[ModuleUID, TransformerBackend] def __init__( self, dht: DHT, module_backends: Dict[str, TransformerBackend], *, adapters: Optional[Sequence[str]], dht_prefix: str, handler_event_queues: Sequence[mp.Queue], handler_index: int, inference_max_length: int, request_timeout: float, session_timeout: float, step_timeout: float, task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(), quant_type: QuantType, ): super().__init__(dht, module_backends) for module_backend in self.module_backends.values(): assert isinstance(module_backend, 
TransformerBackend) self.dht_prefix = dht_prefix self.adapters = adapters self._handler_event_queues = handler_event_queues self._handler_index = handler_index self._own_event_queue = handler_event_queues[handler_index] self._listener_task: Optional[asyncio.Task] = None self._session_queues: Dict[str, asyncio.Queue] = {} self._session_handlers: Dict[str, int] = {} self.inference_max_length = inference_max_length self.request_timeout = request_timeout self.session_timeout, self.step_timeout = session_timeout, step_timeout self._prioritizer = task_prioritizer self.quant_type = quant_type async def add_p2p_handlers(self, *args, **kwargs) -> None: if self._listener_task is None: # Start listening to our own event queue before we accept any requests self._listener_task = asyncio.create_task(self._listen_to_event_queue()) await super().add_p2p_handlers(*args, **kwargs) def shutdown(self): if self.is_alive(): self._outer_pipe.send("_shutdown") self._own_event_queue.put((Event.SHUTDOWN, None, None)) self.join(self.shutdown_timeout) if self.is_alive(): logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM") self.terminate() async def _gather_inputs( self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext ) -> Tuple[str, List[torch.Tensor], Dict]: block_uid, metadata = None, None def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]: nonlocal block_uid, metadata if block_uid is None: block_uid = req.uid elif block_uid != req.uid: raise ValueError("Block uids differ in one request") if metadata is None: metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {} return req.tensors tensors_stream = amap_in_executor(_unpack, requests) inputs = await deserialize_tensor_stream(tensors_stream) assert isinstance(block_uid, str) and isinstance(metadata, dict) return block_uid, inputs, metadata async def rpc_inference( self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: 
P2PContext, ) -> AsyncIterator[runtime_pb2.ExpertResponse]: """Compute a single step of inference using attention cache; update attention cache accordingly.""" async with timeout(self.session_timeout): try: request = await asyncio.wait_for(anext(requests), self.step_timeout) except asyncio.TimeoutError: self._log_request("rpc_inference.open", None, context, warning="timed out") return requested_uids = self._check_uids(request.uid) self._log_request("rpc_inference.open", requested_uids, context) try: metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) max_length = metadata.get("max_length") points = metadata.get("points", 0) session_id = metadata.get("session_id") alloc_timeout = float(metadata.get("alloc_timeout", 0.0)) args_structure = metadata.get("args_structure") if not requested_uids: raise ValueError("User must specify at least one block for inference, but got none") assert isinstance( max_length, int ), f"rpc_inference metadata must contain int max_length, got {max_length}" assert isinstance( points, (float, int) ), f"rpc_inference should have number of points as a number or None, got {points}" if not 0 <= max_length <= self.inference_max_length: raise ValueError( f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}" ) batch_size = request.tensors[0].size[0] if request.tensors else 1 async with self._allocate_cache( requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout ) as cache_handles: background_tasks = set() async for output_tensors, can_push, step_metadata in iterate_rpc_inference( requested_uids=requested_uids, requested_backends=requested_backends, active_adapter=self._get_active_adapter(metadata), input_iterator=self._iterate_inference_steps( request, requests, session_id, requested_uids, context ), cache_handles=cache_handles, max_length=max_length, 
prioritizer=self._prioritizer, points=points, quant_type=self.quant_type, args_structure=args_structure, ): if can_push: task = asyncio.create_task(self._push_outputs(request, output_tensors[0], step_metadata)) background_tasks.add(task) # Keep reference until it is done to save it from GC task.add_done_callback(background_tasks.discard) yield runtime_pb2.ExpertResponse(tensors=output_tensors) finally: self._log_request("rpc_inference.close", requested_uids, context) @contextlib.contextmanager def _managed_session(self, session_id: str): assert session_id not in self._session_queues, f"session id {session_id} is not unique" try: self._session_queues[session_id] = asyncio.Queue() self._session_handlers[session_id] = self._handler_index for other_index, other_queue in enumerate(self._handler_event_queues): if other_index != self._handler_index: other_queue.put_nowait((Event.NEW_SESSION, session_id, self._handler_index)) yield finally: self._session_queues.pop(session_id).put_nowait(None) # put None so that the get task will not hang del self._session_handlers[session_id] for other_index, other_queue in enumerate(self._handler_event_queues): if other_index != self._handler_index: other_queue.put_nowait((Event.END_SESSION, session_id, self._handler_index)) def _put_into_session_queue(self, session_id: str, request: runtime_pb2.ExpertRequest): handler_index = self._session_handlers.get(session_id) if handler_index is None: logger.debug(f"Ignored rpc_push to unknown session ID: {session_id}") elif handler_index == self._handler_index: self._session_queues[session_id].put_nowait(request) else: self._handler_event_queues[handler_index].put_nowait((Event.PUSH, session_id, request)) async def _get_from_session_queue(self, session_id: str) -> Optional[runtime_pb2.ExpertRequest]: assert self._session_handlers[session_id] == self._handler_index, "session belongs to another handler" return await self._session_queues[session_id].get() async def _listen_to_event_queue(self): loop 
= asyncio.get_event_loop() while True: try: event, session_id, payload = await loop.run_in_executor(None, self._own_event_queue.get) if event == Event.SHUTDOWN: break elif event == Event.NEW_SESSION: self._session_handlers[session_id] = payload # index of the handler that owns that session elif event == Event.END_SESSION: self._session_handlers.pop(session_id, None) elif event == Event.PUSH: maybe_session_queue = self._session_queues.get(session_id) if maybe_session_queue is not None: maybe_session_queue.put_nowait(payload) else: raise RuntimeError(f"Unexpected event: {event}") except Exception as e: logger.exception(e) async def _iterate_inference_steps( self, first_request: runtime_pb2.ExpertRequest, requests: AsyncIterator[runtime_pb2.ExpertRequest], session_id: Optional[str], requested_uids: Sequence[str], context: P2PContext, ) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]: processed_step_ids = set() n_pushes = n_late_pushes = 0 request = first_request anext_task = get_push_task = None try: with self._managed_session(session_id) if session_id is not None else contextlib.nullcontext(): while request.tensors: # iterate while user is willing to supply tensors metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} step_id = metadata.get("step_id") pushed = metadata.get("pushed") if pushed: n_pushes += 1 self._log_request("rpc_inference.push", requested_uids, context, debug=f"session received push") if step_id is None or step_id not in processed_step_ids: yield request, metadata if step_id is not None: processed_step_ids.add(step_id) elif pushed: n_late_pushes += 1 self._log_request( "rpc_inference.push", requested_uids, context, warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time", ) # Wait for the next request, coming either from the `requests` iterator or `push_queue` if anext_task is None: anext_task = asyncio.create_task(anext(requests)) if get_push_task is None: if session_id is not None: get_push_task 
= asyncio.create_task(self._get_from_session_queue(session_id)) else: get_push_task = asyncio.create_task(asyncio.Event().wait()) # Dummy never-ending task done, _ = await asyncio.wait( [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED ) if anext_task in done: request = await anext_task anext_task = None elif get_push_task in done: request = await get_push_task get_push_task = None else: self._log_request("rpc_inference.step", requested_uids, context, warning="timed out") anext_task.cancel() get_push_task.cancel() return except Exception: logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True) raise async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: """Directly push activation tensors from one server to another""" requested_uids = self._check_uids(request.uid) metadata = MSGPackSerializer.loads(request.metadata) session_id = metadata["session_id"] self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}") self._put_into_session_queue(session_id, request) return runtime_pb2.ExpertResponse() async def _push_outputs( self, request: runtime_pb2.ExpertRequest, serialized_outputs: runtime_pb2.Tensor, metadata: dict ) -> None: try: next_servers = metadata.get("next_servers") if not next_servers: return next_peer_id, next_session_id, next_start, next_end = next_servers[0] next_peer_id = PeerID.from_base58(next_peer_id) next_uid = CHAIN_DELIMITER.join(f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(next_start, next_end)) # Sending hidden states serialized with output_schema to avoid double serialization next_tensors = [serialized_outputs] + request.tensors[1:] next_metadata = metadata.copy() next_metadata.update(session_id=next_session_id, next_servers=next_servers[1:], pushed=True) stub = self.get_stub(self._p2p, next_peer_id) await stub.rpc_push( runtime_pb2.ExpertRequest( uid=next_uid, 
tensors=next_tensors, metadata=MSGPackSerializer.dumps(next_metadata), ), timeout=self.request_timeout, ) except Exception: logger.debug( f"Failed to push outputs to peer_id={next_peer_id}, session_id={next_session_id}, blocks={next_start}:{next_end}:", exc_info=True, ) async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: async with timeout(self.request_timeout): # Parse request and prepare backends flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors] requested_uids = self._check_uids(request.uid) self._log_request("rpc_forward", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) args_structure = metadata.get("args_structure") assert isinstance( points, (float, int) ), f"rpc_forward should have number of points as number or None, got {points}" hidden_states = await run_rpc_forward( *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, args_structure=args_structure, ) return runtime_pb2.ExpertResponse( tensors=self._serialize_outputs(hidden_states, requested_backends, metadata) ) async def rpc_forward_stream( self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext ) -> AsyncIterator[runtime_pb2.ExpertRequest]: async with timeout(self.request_timeout): # Parse requests and prepare backends uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context) requested_uids = self._check_uids(uid_str) self._log_request("rpc_forward_stream", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) args_structure = 
metadata.get("args_structure") assert isinstance( points, (float, int) ), f"rpc_forward_stream should have number of points as number or None, got {points}" hidden_states = await run_rpc_forward( *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, args_structure=args_structure, ) # Split the serialized_output for streaming and respond to client for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata): for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE): yield runtime_pb2.ExpertResponse(tensors=[part]) def _serialize_outputs( self, hidden_states: torch.Tensor, requested_backends: Sequence[TransformerBackend], metadata: Dict[str, Any], ) -> Sequence[runtime_pb2.Tensor]: """Serialize forward outputs using either outputs_schema or custom user-specified schema""" assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor" outputs_schema = requested_backends[-1].outputs_schema if metadata.get("output_compression") is not None: assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list" output_compression = tuple(metadata["output_compression"]) assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers" assert len(output_compression) == 1, f"output_compression tuple should have 1 element" else: output_compression = tuple(tensor.compression for tensor in outputs_schema) return [ serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True) for result, proto, compression in zip([hidden_states], outputs_schema, output_compression) ] async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: async with timeout(self.request_timeout): # Parse requests and prepare backends flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors] 
requested_uids = self._check_uids(request.uid) self._log_request("rpc_backward", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) args_structure = metadata.get("args_structure") assert isinstance( points, (float, int) ), f"rpc_backward should have number of points as number or None, got {points}" grads = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, args_structure=args_structure, ) return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata)) async def rpc_backward_stream( self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext ) -> AsyncIterator[runtime_pb2.ExpertResponse]: async with timeout(self.request_timeout): uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context) requested_uids = self._check_uids(uids_header) self._log_request("rpc_backward_stream", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) args_structure = metadata.get("args_structure") assert isinstance( points, (float, int) ), f"rpc_backward_stream should have number of points as number or None, got {points}" grads = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, args_structure=args_structure, ) # Split the serialized_grad_inputs for streaming and respond for tensor in self._serialize_grads(grads, requested_backends, metadata): for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE): yield runtime_pb2.ExpertResponse(tensors=[part]) 
def _get_active_adapter(self, metadata: dict) -> str: active_adapter = metadata.get("active_adapter", "") if active_adapter and (active_adapter not in self.adapters): raise KeyError(f"adapter {active_adapter} not found") return active_adapter def _serialize_grads( self, grads: Sequence[torch.Tensor], requested_backends: Sequence[TransformerBackend], metadata: Dict[str, Any], ) -> Sequence[runtime_pb2.Tensor]: """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema""" # Modify grad_inputs_schema to support grad_prompts assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize flat_grads_schema = tuple( nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema)) ) # TODO generalize if metadata.get("output_compression") is not None: assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list" output_compression = tuple(metadata["output_compression"]) assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers" assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements" else: output_compression = tuple(tensor.compression for tensor in flat_grads_schema) return [ serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True) for result, proto, compression in zip(grads, flat_grads_schema, output_compression) ] def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]: """Check that the first request to rpc_inference is valid""" uids = (uids or "").split(CHAIN_DELIMITER) if not uids: raise RuntimeError("User did not provide any uids") for uid in uids: if uid not in self.module_backends: raise RuntimeError(f"Remote peer does not serve {uid}") return tuple(uids) @contextlib.asynccontextmanager async def _allocate_cache( self, backends: Sequence[TransformerBackend], *, batch_size: int, max_length: int, 
timeout: Optional[float], ) -> Sequence[Sequence[Handle]]: """ Allocate memory cache for all transformer blocks, return cache handle :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend """ descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends] async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles: yield nested_pack(handles, descriptors) def _log_request( self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, debug: Optional[str] = None, warning: Optional[str] = None, ) -> None: if uids is not None: friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid] friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()] friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids else: friendly_uids = "n/a" friendly_remote_id = "..." + str(context.remote_id)[-6:] message = f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})" if warning is not None: logger.warning(f"{message}: {warning}") elif debug is not None: logger.debug(f"{message}: {debug}") else: logger.info(message) async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo: """Return metadata about stored block uids and current load""" backend = self.module_backends[request.uid] if request.uid else next(iter(self.module_backends.values())) result = { "version": petals.__version__, "dht_client_mode": self.dht.client_mode, CACHE_TOKENS_AVAILABLE: backend.memory_cache.bytes_left // max(backend.cache_bytes_per_token.values()), } if request.uid: block_info = self.module_backends[request.uid].get_info() common_keys = set(result.keys()) & set(block_info.keys()) if common_keys: raise RuntimeError(f"The block's rpc_info has keys reserved for the server's rpc_info: {common_keys}") result.update(block_info) return 
runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result)) ================================================ FILE: src/petals/server/memory_cache.py ================================================ """ A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime. For now, the only purpose of this code is to ensure that allocated memory will be deleted properly. """ import asyncio import contextlib import ctypes import multiprocessing as mp import os import time from typing import AsyncContextManager, Dict, Optional, Sequence import async_timeout import torch from hivemind.utils import TensorDescriptor, enter_asynchronously, get_logger from petals.data_structures import Handle from petals.utils.asyncio import shield_and_wait from petals.utils.misc import get_size_in_bytes logger = get_logger(__name__) class MemoryCache: """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs""" def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: Optional[float] = None): self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1) self.max_alloc_timeout = max_alloc_timeout self._lock_metadata = mp.Lock() self._current_size = mp.Value(ctypes.c_int64, 0, lock=False) self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=True) self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False) self._allocated_tensors: Dict[Handle, torch.Tensor] = {} self.runtime_pid = os.getpid() self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False) # any ConnectionHandler -> runtime self._lock_acquire_memory = mp.Lock() self._memory_freed_event = mp.Event() @property def current_size_bytes(self) -> int: return self._current_size.value @current_size_bytes.setter def current_size_bytes(self, value: int): self._current_size.value = value @property def enqueued_size_bytes(self) -> int: return self._enqueued_size.value @enqueued_size_bytes.setter def 
enqueued_size_bytes(self, value: int): self._enqueued_size.value = value @property def bytes_left(self) -> int: return self.max_size_bytes - self.current_size_bytes @property def handle_counter(self) -> int: return self._handle_counter.value @handle_counter.setter def handle_counter(self, value: int): self._handle_counter.value = value @contextlib.asynccontextmanager async def allocate_cache( self, *descriptors: TensorDescriptor, timeout: float ) -> AsyncContextManager[Sequence[Handle]]: """ Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed. :param descriptors: one or more tensors tensor of this size, dtype, etc :param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices; if not, it will count maximum tensor allocation across devices for the purposes of size limit :note: This function should be called by connection handlers, it can be called concurrently from multiple processes. Furthermore, it can be called concurrently with at most one use_cache call in runtime. 
""" assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime" assert all(descr.device is not None for descr in descriptors), "please specify allocated devices" if self.max_alloc_timeout is not None: timeout = min(timeout, self.max_alloc_timeout) max_alloc_size = self.get_allocation_size(*descriptors) gib = 1024**3 cur_size, max_size = self.current_size_bytes, self.max_size_bytes friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf" logger.info( f"rpc_inference.wait_for_alloc(size={max_alloc_size / gib:.2f} GiB), " f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)" ) alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout)) try: handles = await shield_and_wait(alloc_task) logger.info(f"rpc_inference.alloc_done(size={max_alloc_size / gib:.2f} GiB)") yield handles finally: self._free(max_alloc_size, alloc_task) @staticmethod def get_allocation_size(*descriptors: TensorDescriptor) -> int: """Return the memory size (bytes) to be allocated on a device. 
If there are many devices, return maximum""" alloc_size_by_device = {} for descr in descriptors: tensor_size = descr.numel() * get_size_in_bytes(descr.dtype) alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size return max(alloc_size_by_device.values()) async def _schedule_alloc( self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float] ) -> Sequence[Handle]: """ This method should be called inside asyncio.shield() because: - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation """ try: async with self._wait_for_free_memory(alloc_size, timeout): with self._lock_metadata: handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors))) self.current_size_bytes += alloc_size self.handle_counter += len(handles) # note: this will eventually overflow and it is okay self._pipe_send.send((handles, descriptors)) return handles except TimeoutError: raise AllocationFailed(f"Could not allocate {alloc_size} (timeout={timeout})") @contextlib.asynccontextmanager async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float]): start_time = time.perf_counter() loop = asyncio.get_event_loop() with self._enqueued_size.get_lock(): self._enqueued_size.value += alloc_size allocated = False try: context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack() # contextlib.AsyncExitStack() is used as a null context here async with context_manager: if timeout == 0 and self.current_size_bytes + self.enqueued_size_bytes > self.max_size_bytes: raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory") async with enter_asynchronously(self._lock_acquire_memory): if self.current_size_bytes + alloc_size > self.max_size_bytes: if timeout == 0: raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory") elapsed_time = time.perf_counter() - start_time remaining_timeout = max(0.0, 
timeout - elapsed_time) if timeout is not None else None await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout) allocated = True with self._enqueued_size.get_lock(): self._enqueued_size.value -= alloc_size yield except asyncio.TimeoutError: raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds") finally: if not allocated: with self._enqueued_size.get_lock(): self._enqueued_size.value -= alloc_size def _free(self, alloc_size: int, alloc_task: asyncio.Task): if alloc_task.exception() is not None: return handles = alloc_task.result() with self._lock_metadata: self._pipe_send.send((handles, None)) # signal runtime to free these handles self.current_size_bytes -= alloc_size self._memory_freed_event.set() def _wait_until_available(self, allocated_size: int, timeout: Optional[float] = None): # note: this function should only be called inside _lock_acquire_memory! if allocated_size > self.max_size_bytes: raise AllocationFailed( f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes" ) timeout = timeout if timeout != float("inf") else None deadline = None if timeout is None else time.perf_counter() + timeout while self.current_size_bytes + allocated_size > self.max_size_bytes: remaining_time = None if timeout is None else deadline - time.perf_counter() if not self._memory_freed_event.wait(remaining_time): raise AllocationFailed( f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds" ) self._memory_freed_event.clear() @contextlib.contextmanager def use_cache(self, *handles: Handle) -> Sequence[torch.Tensor]: """ Return one or more tensors previously allocated with allocate_cache, :note: This method is called by ModuleBackend in runtime: a single process with NO process parallelism. 
However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache """ assert os.getpid() == self.runtime_pid # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here # read creation/deletion requests from connection handlers while self._pipe_recv.poll(): recv_handles, recv_data = self._pipe_recv.recv() if recv_data is not None: # create new tensors assert len(recv_handles) == len(recv_data) for handle, descr in zip(recv_handles, recv_data): self._allocated_tensors[handle] = descr.make_zeros() assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})" else: # delete tensors by handle for handle in recv_handles: if handle not in self._allocated_tensors: logger.warning( f"Sanity check failed: asked to delete handle {handle}, but there is no such handle" ) self._allocated_tensors.pop(handle, None) yield tuple(self._allocated_tensors[handle] for handle in handles) class AllocationFailed(Exception): pass ================================================ FILE: src/petals/server/reachability.py ================================================ import asyncio import math import threading import time from concurrent.futures import Future from contextlib import asynccontextmanager from functools import partial from typing import Optional import requests from hivemind.dht import DHT, DHTNode from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase from hivemind.proto import dht_pb2 from hivemind.utils import get_logger from petals.constants import REACHABILITY_API_URL logger = get_logger(__name__) def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None: """verify that your peer is reachable from a (centralized) validator, whether directly or through a relay""" for attempt_no in range(math.floor(wait_time / retry_delay) + 1): try: r = 
requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10) r.raise_for_status() response = r.json() if response["success"]: logger.info("Server is reachable from the Internet. It will appear at https://health.petals.dev soon") return if attempt_no == 0: # Usually, libp2p manages to set up relays before we finish loading blocks. # In other cases, we may need to wait for up to `wait_time` seconds before it's done. logger.info("Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes") time.sleep(retry_delay) except Exception as e: logger.warning(f"Skipping reachability check because health.petals.dev is down: {repr(e)}") return raise RuntimeError( f"Server has not become reachable from the Internet:\n\n" f"{response['message']}\n\n" f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n" f" 1. Choose a specific port for the Petals server, for example, 31337.\n" f" 2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n" f" 3. Add these arguments to explicitly announce your IP address and port to other peers:\n" f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n" f" 4. 
If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n" ) def check_direct_reachability(max_peers: int = 5, threshold: float = 0.5, **kwargs) -> Optional[bool]: """test if your peer is accessible by others in the swarm with the specified network options in **kwargs""" async def _check_direct_reachability(): target_dht = await DHTNode.create(client_mode=True, **kwargs) try: protocol = ReachabilityProtocol(probe=target_dht.protocol.p2p) async with protocol.serve(target_dht.protocol.p2p): successes = requests = 0 for remote_peer in list(target_dht.protocol.routing_table.peer_id_to_uid.keys()): probe_available = await protocol.call_check(remote_peer=remote_peer, check_peer=target_dht.peer_id) if probe_available is None: continue # remote peer failed to check probe successes += probe_available requests += 1 if requests >= max_peers: break logger.debug(f"Direct reachability: {successes}/{requests}") return (successes / requests) >= threshold if requests > 0 else None finally: await target_dht.shutdown() return RemoteExpertWorker.run_coroutine(_check_direct_reachability()) STRIPPED_PROBE_ARGS = dict( dht_mode="client", use_relay=False, auto_nat=False, nat_port_map=False, no_listen=True, startup_timeout=60 ) class ReachabilityProtocol(ServicerBase): """Mini protocol to test if a locally running peer is accessible by other devices in the swarm""" def __init__(self, *, probe: Optional[P2P] = None, wait_timeout: float = 5.0): self.probe = probe self.wait_timeout = wait_timeout self._event_loop = self._stop = None async def call_check(self, remote_peer: PeerID, *, check_peer: PeerID) -> Optional[bool]: """Returns True if remote_peer can reach check_peer, False if it cannot, None if it did not respond""" try: request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=check_peer.to_bytes())) timeout = self.wait_timeout if check_peer == remote_peer else self.wait_timeout * 2 response = await self.get_stub(self.probe, remote_peer).rpc_check(request, 
timeout=timeout) logger.debug(f"call_check(remote_peer={remote_peer}, check_peer={check_peer}) -> {response.available}") return response.available except Exception as e: logger.debug(f"Requested {remote_peer} to check {check_peer}, but got:", exc_info=True) return None async def rpc_check(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse: """Help another peer to check its reachability""" response = dht_pb2.PingResponse(available=True) check_peer = PeerID(request.peer.node_id) if check_peer != context.local_id: # remote peer wants us to check someone other than ourselves response.available = await self.call_check(check_peer, check_peer=check_peer) is True logger.info( f"reachability.rpc_check(remote_peer=...{str(context.remote_id)[-6:]}, " f"check_peer=...{str(check_peer)[-6:]}) -> {response.available}" ) return response @asynccontextmanager async def serve(self, p2p: P2P): try: await self.add_p2p_handlers(p2p) yield self finally: await self.remove_p2p_handlers(p2p) @classmethod def attach_to_dht(cls, dht: DHT, await_ready: bool = False, **kwargs) -> Optional["ReachabilityProtocol"]: protocol = cls(**kwargs) ready = Future() async def _serve_with_probe(): try: common_p2p = await dht.replicate_p2p() protocol._event_loop = asyncio.get_event_loop() protocol._stop = asyncio.Event() initial_peers = [str(addr) for addr in await common_p2p.get_visible_maddrs(latest=True)] for info in await common_p2p.list_peers(): initial_peers.extend(f"{addr}/p2p/{info.peer_id}" for addr in info.addrs) protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS) ready.set_result(True) logger.debug("Reachability service started") async with protocol.serve(common_p2p): await protocol._stop.wait() except Exception as e: logger.debug("Reachability service failed:", exc_info=True) if not ready.done(): ready.set_exception(e) finally: if protocol is not None and protocol.probe is not None: await protocol.probe.shutdown() logger.debug("Reachability 
service shut down") threading.Thread(target=partial(asyncio.run, _serve_with_probe()), daemon=True).start() if await_ready: ready.result() # Propagates startup exceptions, if any return protocol def shutdown(self): if self._event_loop is not None and self._stop is not None: self._event_loop.call_soon_threadsafe(self._stop.set) ================================================ FILE: src/petals/server/server.py ================================================ from __future__ import annotations import gc import math import multiprocessing as mp import os import random import sys import threading import time from typing import Dict, List, Optional, Sequence, Union import hivemind import psutil import torch import torch.mps from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind.moe.server.layers import add_custom_models_from_file from hivemind.moe.server.runtime import Runtime from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils.logging import get_logger from transformers import PretrainedConfig import petals from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState, parse_uid from petals.server import block_selection from petals.server.backend import TransformerBackend, merge_inference_pools_inplace from petals.server.block_utils import get_block_size, resolve_block_dtype from petals.server.from_pretrained import load_pretrained_block from petals.server.handler import TransformerConnectionHandler from petals.server.memory_cache import MemoryCache from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability from petals.server.throughput import get_dtype_name, get_server_throughput from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, check_device_balance, convert_block from petals.utils.dht 
import declare_active_modules, get_remote_module_infos from petals.utils.misc import get_size_in_bytes from petals.utils.ping import PingAggregator from petals.utils.random import sample_up_to from petals.utils.version import get_compatible_model_repo logger = get_logger(__name__) class Server: """ Runs ModuleContainer, periodically checks that the network is balanced, restarts the ModuleContainer with other layers if the imbalance is significant """ def __init__( self, *, initial_peers: List[str], dht_prefix: Optional[str], converted_model_name_or_path: str, public_name: Optional[str] = None, throughput: Union[float, str], num_blocks: Optional[int] = None, block_indices: Optional[str] = None, num_handlers: int = 8, inference_max_length: Optional[int] = None, min_batch_size: int = 1, max_batch_size: Optional[int] = None, max_chunk_size_bytes: int = 256 * 1024 * 1024, max_alloc_timeout: float = 600, attn_cache_tokens: Optional[int] = None, torch_dtype: str = "auto", revision: Optional[str] = None, cache_dir: Optional[str] = None, max_disk_space: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, custom_module_path=None, update_period: float = 60, expiration: Optional[float] = None, request_timeout: float = 3 * 60, session_timeout: float = 30 * 60, step_timeout: float = 5 * 60, prefetch_batches: int = 1, sender_threads: int = 1, balance_quality: float = 0.75, mean_balance_check_period: float = 120, mean_block_selection_delay: float = 5, token: Optional[Union[str, bool]] = None, quant_type: Optional[QuantType] = None, tensor_parallel_devices: Optional[Sequence[torch.device]] = None, skip_reachability_check: bool = False, reachable_via_relay: Optional[bool] = None, use_relay: bool = True, use_auto_relay: bool = True, adapters: Sequence[str] = (), **kwargs, ): """Create a server with one or more bloom blocks. 
See run_server.py for documentation.""" converted_model_name_or_path = get_compatible_model_repo(converted_model_name_or_path) self.converted_model_name_or_path = converted_model_name_or_path self.num_handlers = num_handlers self.compression = compression self.stats_report_interval, self.update_period = stats_report_interval, update_period self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads self.revision, self.token = revision, token if custom_module_path is not None: add_custom_models_from_file(custom_module_path) self.block_config = AutoDistributedConfig.from_pretrained( converted_model_name_or_path, use_auth_token=token, revision=revision, ) if dht_prefix is None: dht_prefix = self.block_config.dht_prefix assert UID_DELIMITER not in dht_prefix and CHAIN_DELIMITER not in dht_prefix, ( f"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. " f"Please specify another --dht_prefix manually when starting a server" ) self.dht_prefix = dht_prefix if expiration is None: expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) self.expiration = expiration self.request_timeout = request_timeout self.session_timeout, self.step_timeout = session_timeout, step_timeout self.module_uids = [ f"{self.dht_prefix}{UID_DELIMITER}{block_index}" for block_index in range(self.block_config.num_hidden_layers) ] if reachable_via_relay is None: is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs) reachable_via_relay = is_reachable is False # if can't check reachability (returns None), run a full peer logger.info(f"This server is accessible {'via relays' if reachable_via_relay else 'directly'}") self.dht = DHT( initial_peers=initial_peers, start=True, num_workers=self.block_config.num_hidden_layers, use_relay=use_relay, use_auto_relay=use_auto_relay, client_mode=reachable_via_relay, **kwargs, ) self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not 
reachable_via_relay else None visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()] if initial_peers == PUBLIC_INITIAL_PEERS: logger.info("Connecting to the public swarm") else: logger.info(f"Connecting to a private swarm, initial peers: {initial_peers}") logger.info(f"Running a server on {visible_maddrs_str}") self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS if device is None: if torch.cuda.is_available(): device = "cuda" elif torch.backends.mps.is_available(): device = "mps" else: device = "cpu" device = torch.device(device) if device.type == "cuda" and device.index is None: device = torch.device(device.type, index=0) self.device = device torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype]) if device.type == "cpu" and torch_dtype == torch.float16: raise ValueError( f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16" ) if device.type == "mps" and torch_dtype == torch.bfloat16: logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead") torch_dtype = torch.float16 self.torch_dtype = torch_dtype if tensor_parallel_devices is None: tensor_parallel_devices = (device,) self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices)) if len(self.tensor_parallel_devices) > 1: logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}") check_device_balance(self.tensor_parallel_devices) if quant_type is None: quant_type = QuantType.NF4 if device.type == "cuda" else QuantType.NONE self.quant_type = quant_type logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format") is_multiquery_attn = self.block_config.num_key_value_groups > 1 if max_batch_size is None: max_batch_size = 8192 if is_multiquery_attn else 2048 if inference_max_length is None: inference_max_length = 8192 if is_multiquery_attn else 2048 self.min_batch_size, 
self.max_batch_size = min_batch_size, max_batch_size self.inference_max_length = inference_max_length self.max_chunk_size_bytes = max_chunk_size_bytes self.max_alloc_timeout = max_alloc_timeout # For attention cache in GPU or RAM if attn_cache_tokens is None: attn_cache_tokens = 16384 if is_multiquery_attn else 4096 cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens cache_values_per_block //= self.block_config.num_key_value_groups self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype) # For disk cache self.cache_dir = cache_dir self.max_disk_space = max_disk_space self.adapters = adapters assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both" if num_blocks is None and block_indices is None: num_blocks = self._choose_num_blocks() if num_blocks is not None: num_blocks = min(num_blocks, self.block_config.num_hidden_layers) if block_indices is not None: try: start_block, end_block = [int(index.strip()) for index in block_indices.split(":")] except Exception as e: raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 
0:18)") block_indices = range(start_block, end_block) num_blocks = len(block_indices) self.strict_block_indices, self.num_blocks = block_indices, num_blocks gib = 1024**3 self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB") assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"] if throughput in ["auto", "eval", "dry_run"]: force_eval = throughput in ["eval", "dry_run"] throughput_info = get_server_throughput( converted_model_name_or_path, self.block_config, device, torch_dtype, num_blocks=num_blocks, quant_type=quant_type, tensor_parallel_devices=self.tensor_parallel_devices, reachable_via_relay=reachable_via_relay, force_eval=force_eval, cache_dir=cache_dir, ) if throughput == "dry_run": logger.info("Finished estimating throughput, exiting") sys.exit(0) else: throughput_info = {"throughput": throughput} self.server_info = ServerInfo( state=ServerState.JOINING, public_name=public_name, version=petals.__version__, adapters=tuple(adapters), torch_dtype=str(torch_dtype).replace("torch.", ""), quant_type=quant_type.name.lower(), using_relay=reachable_via_relay, **throughput_info, ) self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers) if not os.path.isdir(converted_model_name_or_path): self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path self.balance_quality = balance_quality self.mean_balance_check_period = mean_balance_check_period self.mean_block_selection_delay = mean_block_selection_delay self.module_container = None self.stop = threading.Event() def _choose_num_blocks(self) -> int: assert self.device.type in ("cuda", "mps"), ( "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. 
" "CPU-only servers in the public swarm are discouraged since they are much slower" ) num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1 if num_devices > 1: assert self.device.type == "cuda", f"Tensor parallelism is not supported on {self.device.type.upper()}" memory_per_device = tuple( torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices ) total_memory = min(memory_per_device) * num_devices if max(memory_per_device) / min(memory_per_device) > 1.5: raise ValueError( "GPU devices have highly uneven memory, which makes tensor parallelism inefficient. " "Please launch individual servers on each GPU or set --num_blocks manually to " "override this exception." ) elif self.device.type == "cuda": total_memory = torch.cuda.get_device_properties(self.device).total_memory else: total_memory = psutil.virtual_memory().total gib = 1024**3 # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models) autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type) total_memory_per_block = block_size + self._cache_bytes_per_block if self.adapters: # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes from petals.utils.peft import estimate_adapter_memory_per_block total_memory_per_block += estimate_adapter_memory_per_block( self.block_config, self.torch_dtype, self.adapters, token=self.token, cache_dir=self.cache_dir, max_disk_space=self.max_disk_space, ) num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block) assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block" num_blocks = min(num_blocks, self.block_config.num_hidden_layers) logger.info( f"Server will fill your GPU memory with {num_blocks} transformer blocks. 
" f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually" ) return num_blocks def run(self): while True: block_indices = self._choose_blocks() self.module_container = ModuleContainer.create( dht=self.dht, dht_prefix=self.dht_prefix, converted_model_name_or_path=self.converted_model_name_or_path, block_config=self.block_config, attn_cache_bytes=self.attn_cache_bytes, server_info=self.server_info, model_info=self.model_info, block_indices=block_indices, num_handlers=self.num_handlers, min_batch_size=self.min_batch_size, max_batch_size=self.max_batch_size, max_chunk_size_bytes=self.max_chunk_size_bytes, max_alloc_timeout=self.max_alloc_timeout, inference_max_length=self.inference_max_length, torch_dtype=self.torch_dtype, cache_dir=self.cache_dir, max_disk_space=self.max_disk_space, device=self.device, compression=self.compression, stats_report_interval=self.stats_report_interval, update_period=self.update_period, expiration=self.expiration, request_timeout=self.request_timeout, session_timeout=self.session_timeout, step_timeout=self.step_timeout, prefetch_batches=self.prefetch_batches, sender_threads=self.sender_threads, revision=self.revision, token=self.token, quant_type=self.quant_type, tensor_parallel_devices=self.tensor_parallel_devices, should_validate_reachability=self.should_validate_reachability, start=True, ) try: self.module_container.ready.wait() while True: timeout = random.random() * 2 * self.mean_balance_check_period if self.stop.wait(timeout): return if not self.module_container.is_healthy(): logger.warning("One of subprocesses crashed, restarting the server") break if self._should_choose_other_blocks(): logger.info("Swarm is imbalanced, server will load other blocks") break # Stop serving this set of modules finally: self.module_container.shutdown() self._clean_memory_and_fds() def _clean_memory_and_fds(self): self.module_container = None gc.collect() # In particular, this closes unused file descriptors if 
self.device.type == "cuda": torch.cuda.empty_cache() allocated_vram = torch.cuda.memory_allocated(self.device) reserved_vram = torch.cuda.memory_reserved(self.device) gib = 1024**3 logger.info( f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, " f"{reserved_vram / gib:.1f} GiB reserved memory" ) elif self.device.type == "mps": torch.mps.empty_cache() def _choose_blocks(self) -> List[int]: if self.strict_block_indices is not None: return self.strict_block_indices # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time, # this delay decreases the probability of a race condition while choosing the best blocks to serve. time.sleep(random.random() * 2 * self.mean_block_selection_delay) module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True) return block_selection.choose_best_blocks(self.num_blocks, module_infos) def _should_choose_other_blocks(self) -> bool: if self.strict_block_indices is not None: return False module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True) return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality) def shutdown(self, timeout: Optional[float] = 5): self.stop.set() if self.module_container is not None and self.module_container.is_alive(): self.module_container.join(timeout) if self.reachability_protocol is not None: self.reachability_protocol.shutdown() self.dht.shutdown() self.dht.join() class ModuleContainer(threading.Thread): """Serves a set of specific Bloom layers for inference, forward, and backward. 
Announces itself over the DHT.""" # noinspection PyMethodOverriding @classmethod def create( cls, *, dht: DHT, dht_prefix: str, converted_model_name_or_path: str, block_config: PretrainedConfig, attn_cache_bytes: int, server_info: ServerInfo, model_info: ModelInfo, block_indices: List[int], min_batch_size: int, max_batch_size: int, max_chunk_size_bytes: int, max_alloc_timeout: float, torch_dtype: torch.dtype, cache_dir: str, max_disk_space: int, device: Union[str, torch.device], compression: CompressionType, update_period: float, expiration: Optional[float], revision: Optional[str], token: Optional[Union[str, bool]], quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], should_validate_reachability: bool, **kwargs, ) -> ModuleContainer: module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices] memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout) server_info.state = ServerState.JOINING dht_announcer = ModuleAnnouncerThread( module_uids, dht, server_info, model_info, block_config=block_config, memory_cache=memory_cache, update_period=update_period, expiration=expiration, daemon=True, ) dht_announcer.start() logger.info(f"Announced that blocks {block_indices} are joining") assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices) blocks = {} try: for module_uid, block_index in zip(module_uids, block_indices): block = load_pretrained_block( converted_model_name_or_path, block_index, config=block_config, torch_dtype=torch_dtype, revision=revision, token=token, cache_dir=cache_dir, max_disk_space=max_disk_space, ) block = convert_block( block, block_index, block_config, tensor_parallel_devices, device, quant_type, adapters=server_info.adapters, freeze=True, token=token, cache_dir=cache_dir, max_disk_space=max_disk_space, ) blocks[module_uid] = TransformerBackend( module_uid, block, config=block_config, memory_cache=memory_cache, backend_dtype=torch_dtype, 
max_chunk_size_bytes=max_chunk_size_bytes, args_schema=( BatchTensorDescriptor( 1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression ), ), kwargs_schema={}, outputs_schema=( BatchTensorDescriptor( 1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression ), ), min_batch_size=min_batch_size, max_batch_size=max_batch_size, ) merge_inference_pools_inplace(blocks) if should_validate_reachability: validate_reachability(dht.peer_id) except: logger.debug("Shutting down backends") for backend in blocks.values(): backend.shutdown() dht_announcer.announce(ServerState.OFFLINE) logger.info(f"Announced that blocks {module_uids} are offline") raise return cls( dht, dht_prefix, blocks, dht_announcer=dht_announcer, server_info=server_info, update_period=update_period, expiration=expiration, **kwargs, ) def __init__( self, dht: DHT, dht_prefix: str, module_backends: Dict[str, TransformerBackend], *, inference_max_length: int, num_handlers: int, dht_announcer: ModuleAnnouncerThread, server_info: ServerInfo, update_period: float, expiration: Optional[float] = None, request_timeout: float, session_timeout: float, step_timeout: float, start: bool, **kwargs, ): super().__init__() self.dht, self.module_backends = dht, module_backends self.server_info, self.update_period, self.expiration = server_info, update_period, expiration handler_event_queues = [mp.Queue() for _ in range(num_handlers)] self.conn_handlers = [ TransformerConnectionHandler( dht, self.module_backends, adapters=server_info.adapters, dht_prefix=dht_prefix, handler_event_queues=handler_event_queues, handler_index=i, inference_max_length=inference_max_length, request_timeout=request_timeout, session_timeout=session_timeout, step_timeout=step_timeout, quant_type=QuantType[server_info.quant_type.upper()], ) for i in range(num_handlers) ] self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs) # note: We set device=None in runtime to avoid moving all 
modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed. dht_announcer.announce(ServerState.ONLINE) self.dht_announcer = dht_announcer if start: self.run_in_background(await_ready=True) def run(self): """ Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers, runs Runtime (self.runtime) to process incoming requests. """ for handler in self.conn_handlers: handler.run_in_background() self.runtime.run() def run_in_background(self, await_ready=True, timeout=None): """ Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container is ready to process incoming requests or for :timeout: seconds max. """ self.start() if await_ready and not self.ready.wait(timeout=timeout): raise TimeoutError("ModuleContainer didn't notify .ready in {timeout} seconds") @property def ready(self) -> mp.synchronize.Event: """ An event (multiprocessing.Event) that is set when the container is ready to process requests. Example ======= >>> container.start() >>> container.ready.wait(timeout=10) >>> print("Container ready" if container.ready.is_set() else "Container didn't start in 10 seconds") """ return self.runtime.ready # mp.Event that is true if self is ready to process batches def is_healthy(self) -> bool: return all(handler.is_alive() for handler in self.conn_handlers) and all( pool.is_alive() for pool in self.runtime.pools ) def shutdown(self): """ Gracefully terminate the container, process-safe. Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes. If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL). 
""" self.dht_announcer.announce(ServerState.OFFLINE) logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline") self.ready.clear() logger.debug("Shutting down connection handlers") for handler in self.conn_handlers: handler.shutdown() logger.debug(f"Shutting down pools") for pool in self.runtime.pools: if pool.is_alive(): pool.shutdown() logger.debug(f"Shutting down runtime") self.runtime.shutdown() logger.debug("Shutting down backends") for backend in self.module_backends.values(): backend.shutdown() logger.info("Module container shut down successfully") class ModuleAnnouncerThread(threading.Thread): """Periodically announces that this container hosts the specified modules, visible to all DHT peers""" def __init__( self, module_uids: List[str], dht: DHT, server_info: ServerInfo, model_info: ModelInfo, *, block_config: PretrainedConfig, memory_cache: MemoryCache, update_period: float, expiration: float, max_pinged: int = 5, **kwargs, ): super().__init__(**kwargs) self.module_uids = module_uids self.dht = dht self.server_info = server_info self.model_info = model_info self.memory_cache = memory_cache self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype]) self.bytes_per_token //= block_config.num_key_value_groups self.update_period = update_period self.expiration = expiration self.trigger = threading.Event() self.dht_prefix = parse_uid(module_uids[0])[0] block_indices = [parse_uid(uid)[1] for uid in module_uids] self.server_info.start_block = min(block_indices) self.server_info.end_block = max(block_indices) + 1 self.max_pinged = max_pinged self.next_uids = [ f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(self.server_info.start_block + 1, self.server_info.end_block + 1) ] self.ping_aggregator = PingAggregator(self.dht) def run(self) -> None: while True: start_time = time.perf_counter() self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token if 
self.server_info.state != ServerState.OFFLINE: self._ping_next_servers() self.server_info.next_pings = { peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.to_dict().items() } else: self.server_info.next_pings = None # No need to ping if we're disconnecting declare_active_modules( self.dht, self.module_uids, self.server_info, expiration_time=get_dht_time() + self.expiration, ) if self.server_info.state == ServerState.OFFLINE: break if not self.dht_prefix.startswith("_"): # Not private self.dht.store( key="_petals.models", subkey=self.dht_prefix, value=self.model_info.to_dict(), expiration_time=get_dht_time() + self.expiration, ) delay = self.update_period - (time.perf_counter() - start_time) if delay < 0: logger.warning( f"Declaring blocks to DHT takes more than --update_period, consider increasing it (currently {self.update_period})" ) self.trigger.wait(max(delay, 0)) self.trigger.clear() def announce(self, state: ServerState) -> None: self.server_info.state = state self.trigger.set() if state == ServerState.OFFLINE: self.join() def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]: module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True) middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers} pinged_servers = set(sample_up_to(middle_servers, self.max_pinged)) pinged_servers.discard(self.dht.peer_id) # Sample servers hosting the block after the last one (most likely continuations) separately pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged)) self.ping_aggregator.ping(list(pinged_servers)) class RuntimeWithDeduplicatedPools(Runtime): """A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.pools = tuple(set(self.pools)) ================================================ FILE: src/petals/server/task_pool.py 
================================================ import ctypes import multiprocessing as mp import threading import time from concurrent.futures._base import PENDING from dataclasses import dataclass, field from queue import PriorityQueue from typing import Any, List, Optional, Sequence, Tuple, Union import torch from hivemind import get_logger from hivemind.utils.mpfuture import ALL_STATES, MPFuture logger = get_logger(__name__) @dataclass(order=True, frozen=True) class Task: priority: float time_submitted: float future: MPFuture = field(compare=False) args: Sequence[torch.Tensor] = field(compare=False) @property def uid(self) -> int: return self.future._uid class PrioritizedTaskPool(threading.Thread): """ Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then returns results (or exception) to the corresponding ConnectionHandler. Runs a background process. A single PrioritizedTaskPool services a specific function (e.g. layer1.forward, layer2.forward or layer1.backward) :note: unlike hivemind.moe TaskPool, this pool does *not* combine incoming requests into batches. This would require grouping requests of different length. :param process_func: function to be applied to every formed batch; called by Runtime Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors :param max_batch_size: process at most this many inputs in a batch (task contains have one or several inputs) Measured in the total number of tokens (i.e. 
batch size * sequence length) :param name: pool name, used for logging :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more :param device: if specified, input tensors will be moved to that device by default :param start: if True, start automatically at the end of __init__ """ def __init__( self, process_func: callable, max_batch_size: int, name: str, min_batch_size=1, device: Optional[torch.device] = None, daemon=True, start=False, ): super().__init__(daemon=daemon, name=name) self.process_func = process_func # the lower the priority is, the more urgent it is to process this pool self._priority = mp.Value(ctypes.c_double, 1.0) self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size self.device = device self.submitted_tasks = mp.SimpleQueue() # interaction with ConnectionHandlers self._ordered_tasks = PriorityQueue() # interaction with Runtime - only valid inside Runtime self._dispatched_tasks = {} self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False) self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0) self.priority = float("inf"), float("inf") # (first task priority, first task timestamp) if start: self.start() def run(self): """Read tasks from incoming queue and put them into a local priority queue""" while True: task = self.submitted_tasks.get() if task is None: logger.debug("Shutting down prioritizer thread") break self._ordered_tasks.put(task, block=True) def terminate(self): """An alias for hivemind.Runtime that assumes that each TaskPool is a process""" self.shutdown() def shutdown(self): self.submitted_tasks.put(None) # Shuts down self.run() def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture: """Add task to this pool's queue, return Future for its output""" future = MPFuture() # Remove shmem from MPFuture. 
This disables the .cancel() feature but # saves the server from "could not unlink the shared memory file" crashes during rebalancing future._shared_state_code = torch.tensor([ALL_STATES.index(PENDING)], dtype=torch.uint8) task = Task(priority, time.monotonic(), future, args) if self.get_task_size(task) > self.max_batch_size: exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed") task.future.set_exception(exc) else: self.submitted_tasks.put(task) self.batch_sender.send(None) # use this pipe to count the number of unfinished batches if (task.priority, task.time_submitted) < self.priority: self.priority = (task.priority, task.time_submitted) return task.future def get_task_size(self, task: Task) -> int: """compute task processing complexity; defaults to the total number of tokens""" if task.args and task.args[0].ndim >= 2: return task.args[0].shape[0] * task.args[0].shape[1] return 1 def load_batch_to_runtime( self, timeout: Optional[float] = None, device: Optional[torch.device] = None ) -> Tuple[Any, List[torch.Tensor]]: """receive next batch of arrays""" device = device if device is not None else self.device task = self._ordered_tasks.get(block=True, timeout=timeout) batch_inputs = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.args] self._dispatched_tasks[task.uid] = task self.batch_receiver.recv() # reduce the number of active batches if not self._ordered_tasks.empty(): first_remaining_task: Task = self._ordered_tasks.queue[0] self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted) return task.uid, batch_inputs def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]): """send results for a processed batch, previously loaded through load_batch_to_runtime""" batch_outputs = [_move_to_device_if_tensor(output, device="cpu", share_memory=True) for output in batch_outputs] task = self._dispatched_tasks.pop(uid, None) if task is None: 
logger.error( f"Internal error: task task with index {uid} is missing from the dictionary; " f"Could not set result" ) else: task.future.set_result(batch_outputs) def send_exception_from_runtime(self, uid: int, exception: BaseException): task = self._dispatched_tasks.pop(uid, None) if task is None: logger.error( f"Internal error: task task with index {uid} is missing from the dictionary; " f"Could not set exception {exception}" ) else: task.future.set_exception(exception) @property def empty(self): return not self.batch_receiver.poll() @property def priority(self) -> Tuple[float, float]: """The priority of this pool equals the (priority, timestamp) of the most important task in it.""" return float(self._priority.value), float(self._oldest_undispatched_timestamp.value) @priority.setter def priority(self, item: Tuple[float, float]): assert len(item) == 2 self._priority.value = float(item[0]) self._oldest_undispatched_timestamp.value = float(item[1]) def _move_to_device_if_tensor(arg: Any, device: Union[torch.device, str], share_memory: bool = False): if isinstance(arg, torch.Tensor): arg = arg.detach().to(device, non_blocking=not share_memory).requires_grad_(arg.requires_grad) # note: it is important that non_blocking is disabled if share_memory=True; using share_memory on a tensor # produced by a non-blocking copy will result in undefined behavior (depending on your gpu speed) if share_memory: arg = arg.share_memory_() return arg ================================================ FILE: src/petals/server/task_prioritizer.py ================================================ from abc import ABC, abstractmethod import torch class TaskPrioritizerBase(ABC): """Abstract class for TaskPrioritizer whose responsibility is to evaluate task priority""" @abstractmethod def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float: """Evaluates task value by the amount of points given, task input and additional kwargs. 
Lower priority is better""" pass class DummyTaskPrioritizer(TaskPrioritizerBase): def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float: # Inference steps go first since they are more latency-sensitive if kwargs.get("type") == "inference": return 1.0 return 2.0 # Forward, backward ================================================ FILE: src/petals/server/throughput.py ================================================ import fcntl import json import math import multiprocessing as mp import os import time from collections import Counter from pathlib import Path from typing import Dict, Optional, Sequence, Union import torch import torch.mps from hivemind.utils.logging import get_logger from transformers import PretrainedConfig from petals.server.block_utils import get_model_block, resolve_block_dtype from petals.utils.convert_block import QuantType, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR from petals.utils.misc import DUMMY_KEY_PAST logger = get_logger(__name__) try: import speedtest except ImportError: raise ImportError("Please `pip install speedtest-cli==2.1.3`") if not hasattr(speedtest, "Speedtest"): raise ImportError( "You are using the wrong speedtest module. Please replace speedtest with speedtest-cli.\n" "To do that, run `pip uninstall -y speedtest`. Depending on your python environment, " "you may need to run uninstall speedtest two or more times, until it says 'not installed'.\n" "After that, please `pip install speedtest-cli==2.1.3` to install the correct version." 
) def get_server_throughput( model_name: str, config: PretrainedConfig, device: torch.device, dtype: Union[str, torch.dtype], *, num_blocks: int, quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], reachable_via_relay: bool, relay_penalty: float = 0.2, force_eval: bool = False, cache_dir: Optional[str] = None, ) -> Dict[str, float]: dtype = resolve_block_dtype(config, dtype) if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR lock_path = Path(cache_dir, "throughput.lock") cache_path = Path(cache_dir, "throughput_v5.json") # We use the system-wide lock since only one process at a time can measure the host throughput os.makedirs(lock_path.parent, exist_ok=True) with open(lock_path, "wb+") as lock_fd: logger.info("Loading throughput info") fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX) # The OS will release the lock when lock_fd is closed or the process is killed cache_key = f"model_{model_name}" cache_key += f"_device_{get_device_name(device).replace(' ', '_')}" cache_key += f"_dtype_{get_dtype_name(dtype, quant_type)}" if len(tensor_parallel_devices) > 1: for i, device_i in enumerate(tensor_parallel_devices): cache_key += f"_tp{i}_{get_device_name(device_i).replace(' ', '_')}" cache = {} try: if not force_eval and os.path.exists(cache_path): with open(cache_path) as cache_fd: cache = json.load(cache_fd) assert isinstance(cache, dict) except Exception: logger.exception(f"Failed to read throughput info from {cache_path}") cache = {} if cache_key not in cache: cache[cache_key] = measure_throughput_info( config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices ) try: os.makedirs(cache_path.parent, exist_ok=True) with open(cache_path, "w") as cache_fd: json.dump(cache, cache_fd) except Exception: logger.exception(f"Failed to save throughput info in {cache_path}") throughput_info = cache[cache_key] # Most requests start at some block hosted by a server, then use all next blocks hosted on this server. 
# Assuming the start block index is distributed uniformly, the average number of blocks used per request is # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2 average_blocks_used = (num_blocks + 1) / 2 throughput = throughput_info["forward_rps"] / average_blocks_used network_rps = throughput_info["network_rps"] * (relay_penalty if reachable_via_relay else 1) throughput = min(throughput, network_rps) throughput_info["throughput"] = throughput logger.info(f"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks") return throughput_info def measure_throughput_info( config: PretrainedConfig, device: torch.device, dtype: torch.dtype, *, quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], ) -> Dict[str, float]: logger.info( "Measuring network and compute throughput. This takes about a minute and will be cached for future runs" ) return { "inference_rps": measure_compute_rps( config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices, n_tokens=1, n_steps=100, inference=True, ), "forward_rps": measure_compute_rps( config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices, n_tokens=1024, n_steps=10, inference=False, ), "network_rps": measure_network_rps(config), } def measure_network_rps( config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 100e6 # 100 Mbit/s ) -> Optional[float]: bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward try: pipe_recv, pipe_send = mp.Pipe(duplex=False) process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,)) process.start() if not pipe_recv.poll(timeout): process.terminate() raise RuntimeError(f"speedtest did not finish in {timeout} seconds") network_info = pipe_recv.recv() if "exception" in network_info: raise RuntimeError(f"speedtest failed: {network_info['exception']}") network_rps = min(network_info["download"], network_info["upload"]) / 
bits_per_request if network_rps == 0: raise RuntimeError("speedtest has returned network_rps == 0") logger.info( f"Network throughput: {network_rps:.1f} tokens/sec " f"({network_info['download'] / 1e6:.2f} Mbit/s on download, " f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)" ) return network_rps except RuntimeError as e: logger.info(f"Network throughput is not available: {e}. Using default of {default_speed / 1e6:.2f} Mbit/s") return default_speed / bits_per_request def _measure_bits_per_second(pipe_send: mp.Pipe): try: s = speedtest.Speedtest() s.get_servers() s.get_best_server() s.download() s.upload() pipe_send.send(s.results.dict()) except Exception as e: pipe_send.send({"exception": repr(e)}) def measure_compute_rps( config: PretrainedConfig, device: torch.device, dtype: torch.dtype, *, quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], n_tokens: int, n_steps: int, inference: bool, ) -> float: device = torch.device(device) if not tensor_parallel_devices: tensor_parallel_devices = (device,) with torch.inference_mode(): block = get_model_block(config) block = block.to(dtype) block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) cache = (DUMMY_KEY_PAST.to(dtype=dtype, device=device), DUMMY_KEY_PAST.to(dtype=dtype, device=device)) elapsed = 0 dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype) # Skip the 1st step to exclude the initialization time def step(cache_): outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None) return outputs[1] if inference else None cache = step(cache) synchronize(device) start_time = time.perf_counter() for _ in range(n_steps): cache = step(cache) synchronize(device) elapsed = time.perf_counter() - start_time device_rps = n_steps * n_tokens / elapsed devices_repr = get_device_name(device) if len(tensor_parallel_devices) > 1: device_names = tuple(map(get_device_name, 
map(torch.device, tensor_parallel_devices))) devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common()) logger.info( f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} tokens/sec per block " f"({n_tokens} tokens/batch, {devices_repr}, {get_dtype_name(dtype, quant_type)})" ) return device_rps def synchronize(device: torch.device): if device.type == "cuda": torch.cuda.synchronize(device) elif device.type == "mps": torch.mps.synchronize() def get_device_name(device: torch.device) -> str: return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else device.type.upper() def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str: name = str(dtype).replace("torch.", "") if quant_type != QuantType.NONE: name += f", quantized to {quant_type.name.lower()}" return name ================================================ FILE: src/petals/utils/__init__.py ================================================ from petals.utils.auto_config import ( AutoDistributedConfig, AutoDistributedModel, AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification, AutoDistributedSpeculativeModel, ) from petals.utils.dht import declare_active_modules, get_remote_module_infos ================================================ FILE: src/petals/utils/asyncio.py ================================================ import asyncio async def shield_and_wait(task): """ Works like asyncio.shield(), but waits for the task to finish before raising CancelledError to the caller. 
""" if not isinstance(task, asyncio.Task): task = asyncio.create_task(task) cancel_exc = None while True: try: result = await asyncio.shield(task) break except asyncio.CancelledError as e: cancel_exc = e if cancel_exc is not None: raise cancel_exc return result ================================================ FILE: src/petals/utils/auto_config.py ================================================ import os from dataclasses import dataclass from typing import Optional, Type, Union from hivemind import get_logger from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from petals.utils.hf_auth import always_needs_auth logger = get_logger(__name__) @dataclass class _ModelClasses: config: Type[PretrainedConfig] model: Optional[Type[PreTrainedModel]] = None model_for_causal_lm: Optional[Type[PreTrainedModel]] = None model_for_speculative: Optional[Type[PreTrainedModel]] = None model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None _CLASS_MAPPING = {} # Populated by petals.models.* subpackages with register_model_classes() def register_model_classes(*, config: Type[PretrainedConfig], **kwargs): assert issubclass(config, PretrainedConfig) assert config.model_type not in _CLASS_MAPPING, f"Model type {config.model_type} is already registered" _CLASS_MAPPING[config.model_type] = _ModelClasses(config=config, **kwargs) class _AutoDistributedBase: _mapping_field = None # Should be defined in child classes @classmethod def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig: if ( always_needs_auth(model_name_or_path) and kwargs.get("token") is None and kwargs.get("use_auth_token") is None ): kwargs["use_auth_token"] = True config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs) if config.model_type not in _CLASS_MAPPING: raise ValueError(f"Petals does not support model type {config.model_type}") proper_cls = getattr(_CLASS_MAPPING[config.model_type], cls._mapping_field) if 
proper_cls is None: raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}") return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs) class DefaultRevisionMixin: """ Petals only supports Falcon loaded in the new in-library format (transformers.FalconModel). TII models were recently converted to this format but then reverted back due to compatibility issues. We chose to support only the new format since HF staff promised to eventually convert these models to the new format again, see https://huggingface.co/tiiuae/falcon-40b/discussions/90#64b4d23bf44fd957492f7602 Until it happens, we override the default `main` revision for the TII repos with the commit pointing out to the model in the in-library format. """ DEFAULT_REVISIONS = { "tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232", "tiiuae/falcon-40b-instruct": "7475ff8cfc36ed9a962b658ae3c33391566a85a5", "tiiuae/falcon-7b": "4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76", "tiiuae/falcon-7b-instruct": "f8dac3fff96d5debd43edf56fb4e1abcfffbef28", } @classmethod def from_pretrained( cls, model_name_or_path: Union[str, os.PathLike, None], *args, revision: Optional[str] = None, **kwargs ): if revision is None and model_name_or_path in cls.DEFAULT_REVISIONS: revision = cls.DEFAULT_REVISIONS[model_name_or_path] logger.info(f"Loading {model_name_or_path}, revision {revision}") return super().from_pretrained(model_name_or_path, *args, revision=revision, **kwargs) class AutoDistributedConfig(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "config" class AutoDistributedModel(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "model" class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "model_for_causal_lm" class AutoDistributedSpeculativeModel(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "model_for_speculative" class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, 
_AutoDistributedBase): _mapping_field = "model_for_sequence_classification" ================================================ FILE: src/petals/utils/convert_block.py ================================================ """ Tools for converting transformer blocks, applying quantization and/or tensor parallelism """ import re from enum import Enum from typing import Optional, Sequence import tensor_parallel as tp import torch import torch.nn as nn from hivemind.utils.logging import get_logger, use_hivemind_log_handler from tensor_parallel.slicing_configs import get_bloom_config from transformers import PretrainedConfig use_hivemind_log_handler("in_root_logger") logger = get_logger(__name__) class QuantType(Enum): NONE = 0 INT8 = 1 # 8-bit as in the LLM.int8() paper NF4 = 2 # 4-bit as in the QLoRA paper def convert_block( block: nn.Module, block_index: int, config: PretrainedConfig, tensor_parallel_devices: Sequence[torch.device], output_device: torch.device, quant_type: QuantType, freeze: bool = True, adapters: Optional[Sequence[str]] = None, **kwargs, ) -> tp.TensorParallel: """ Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization :note: some optimizations will modify the input block in-place! 
:param block: a single transformer block, either pre-trained or newly initialized :param config: HF transformers config for the full model :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices :note: if there is only a single device, model wil still be wrapped with TensorParallel (for uniformity) :param output_device: if tensor_parallel_devices is True, output :param quant_type: quantization type :param freeze: if True (default), make all module parameters non-trainable :return: a module that acts like the original block, but runs with all specified optimizations """ if freeze: block.requires_grad_(False) block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device) if quant_type != QuantType.NONE: block = quantize_module(block, quant_type=quant_type) for shard, device in zip(block.module_shards, block.devices): shard.to(device) if adapters: from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft create_lora_adapter(block) for adapter_name in adapters: adapter_config, adapter_state_dict = load_peft( adapter_name, block_idx=block_index, **kwargs, ) add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict) return block def quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Module: # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes import bitsandbytes as bnb for n, module in model.named_children(): if len(list(module.children())) > 0: quantize_module(module, quant_type=quant_type) if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]: assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}" if quant_type == QuantType.INT8: model._modules[n] = bnb.nn.Linear8bitLt( module.in_features, module.out_features, module.bias is not None, has_fp16_weights=False, threshold=6.0, # Default from the 
LLM.int8() paper ) model._modules[n].weight = bnb.nn.Int8Params( module.weight.data, requires_grad=False, has_fp16_weights=False ).to(module.weight.dtype) elif quant_type == QuantType.NF4: compress_statistics = True model._modules[n] = bnb.nn.LinearNF4( module.in_features, module.out_features, module.bias is not None, compress_statistics=compress_statistics, ) model._modules[n].weight = bnb.nn.Params4bit( module.weight.data, requires_grad=False, quant_type="nf4", blocksize=64, compress_statistics=compress_statistics, ).to(module.weight.dtype) else: raise ValueError(f"Unsupported quant_type='{quant_type}'") model._modules[n].bias = module.bias return model def make_tensor_parallel( block: nn.Module, model_config: PretrainedConfig, devices: Sequence[torch.device], output_device: torch.device ) -> nn.Module: if model_config.model_type == "bloom": tp_config = get_bloom_config(model_config, devices) del tp_config.state_rules[re.compile(".*word_embeddings.weight$")] else: if len(devices) > 1: logger.warning("Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution") tp_config = None tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True) total_heads = 0 for tp_shard in tp_block.module_shards: for submodule in tp_shard.modules(): if isinstance(submodule, model_config.attn_class): total_heads += submodule.num_heads assert total_heads == model_config.num_attention_heads return tp_block def check_device_balance(devices: Sequence[torch.device]): if not all(device.type == "cuda" for device in devices): logger.warning("Running tensor parallelism on non-GPU devices; proceed at your own risk") return unique_device_capabilities = set(map(torch.cuda.get_device_capability, devices)) if len(unique_device_capabilities) > 1: logger.warning( f"Found GPUs with uneven capabilities: {unique_device_capabilities}. " f"Using GPUs with different performance will cause the server to wait for the slowest GPU." 
) memory_per_device = tuple(torch.cuda.get_device_properties(device).total_memory for device in devices) used_memory = min(memory_per_device) * len(memory_per_device) wasted_memory_rate = (sum(memory_per_device) - used_memory) / sum(memory_per_device) if wasted_memory_rate > 0.05: logger.warning( f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. " f"Consider running high-memory GPUs in a separate server." ) ================================================ FILE: src/petals/utils/cuda_graphs.py ================================================ import torch from torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten def make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3): """Similar to torch.cuda.make_graphed_callables, but takes only one function and does not build a graph for the backward pass""" assert not isinstance(callable, torch.nn.Module) if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled(): raise RuntimeError( "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`." ) flatten_arg, _ = _tree_flatten(sample_args) flatten_sample_args = tuple(flatten_arg) assert all( isinstance(arg, torch.Tensor) for arg in flatten_arg ), "In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed." len_user_args = len(sample_args) static_input_surface = flatten_sample_args graph = torch.cuda.CUDAGraph() # Warmup # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work # from ending up in any captures. 
s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(num_warmup_iters): outputs, _ = _tree_flatten(callable(*sample_args)) del outputs torch.cuda.current_stream().wait_stream(s) # Capture forward graph with torch.cuda.graph(graph): outputs = callable(*sample_args) flatten_outputs, output_unflatten_spec = _tree_flatten(outputs) static_outputs = tuple(flatten_outputs) def make_graphed_function( graph, len_user_args, output_unflatten_spec, static_input_surface, static_outputs, ): def replay_graph(*inputs): # At this stage, only the user args may (potentially) be new tensors. for i in range(len_user_args): if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): static_input_surface[i].copy_(inputs[i]) graph.replay() assert isinstance(static_outputs, tuple) return tuple(o.detach() for o in static_outputs) def functionalized(*user_args): # Runs the autograd function with inputs == all inputs to the graph that might require grad # (explicit user args + module parameters) # Assumes module params didn't change since capture. flatten_user_args, _ = _tree_flatten(user_args) out = replay_graph(*flatten_user_args) return _tree_unflatten(out, output_unflatten_spec) return functionalized # Put together the final graphed callable graphed = make_graphed_function( graph, len_user_args, output_unflatten_spec, static_input_surface, static_outputs, ) return graphed ================================================ FILE: src/petals/utils/dht.py ================================================ """ Utilities for declaring and retrieving active model layers using a shared DHT. 
""" from __future__ import annotations import math from functools import partial from typing import Dict, List, Optional, Sequence, Union from hivemind.dht import DHT, DHTNode, DHTValue from hivemind.p2p import PeerID from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger from petals.data_structures import ( CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerInfo, ServerState, parse_uid, ) logger = get_logger(__name__) def declare_active_modules( dht: DHT, uids: Sequence[ModuleUID], server_info: ServerInfo, expiration_time: DHTExpiration, wait: bool = True, ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]: """ Declare that your node serves the specified modules; update timestamps if declared previously :param uids: a list of module ids to declare :param wait: if True, awaits for declaration to finish, otherwise runs in background :param throughput: specify your performance in terms of compute throughput :param expiration_time: declared modules will be visible for this many seconds :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected) """ if isinstance(uids, str): uids = [uids] if not isinstance(uids, list): uids = list(uids) for uid in uids: assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid return dht.run_coroutine( partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time), return_future=not wait, ) async def _declare_active_modules( dht: DHT, node: DHTNode, uids: List[ModuleUID], server_info: ServerInfo, expiration_time: DHTExpiration, ) -> Dict[ModuleUID, bool]: num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) return await node.store_many( keys=uids, subkeys=[dht.peer_id.to_base58()] * len(uids), values=[server_info.to_tuple()] * len(uids), expiration_time=expiration_time, num_workers=num_workers, ) def 
get_remote_module_infos( dht: DHT, uids: Sequence[ModuleUID], expiration_time: Optional[DHTExpiration] = None, active_adapter: Optional[str] = None, *, latest: bool = False, return_future: bool = False, ) -> Union[List[RemoteModuleInfo], MPFuture]: return dht.run_coroutine( partial( _get_remote_module_infos, uids=uids, active_adapter=active_adapter, expiration_time=expiration_time, latest=latest, ), return_future=return_future, ) async def _get_remote_module_infos( dht: DHT, node: DHTNode, uids: List[ModuleUID], active_adapter: Optional[str], expiration_time: Optional[DHTExpiration], latest: bool, ) -> List[RemoteModuleInfo]: if latest: assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both" expiration_time = math.inf elif expiration_time is None: expiration_time = get_dht_time() num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers) modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids] for module_info in modules: metadata = found[module_info.uid] if metadata is None or not isinstance(metadata.value, dict): if metadata is not None: logger.warning(f"Incorrect metadata for {module_info.uid}: {metadata}") continue for peer_id, server_info in metadata.value.items(): try: peer_id = PeerID.from_base58(peer_id) server_info = ServerInfo.from_tuple(server_info.value) if active_adapter and active_adapter not in server_info.adapters: logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}") continue module_info.servers[peer_id] = server_info except (TypeError, ValueError) as e: logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}") return modules def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]: block_offset = parse_uid(module_infos[0].uid)[1] if 
module_infos else 0 num_blocks = len(module_infos) spans = {} for block_idx, module_info in enumerate(module_infos): for peer_id, server_info in sorted(module_info.servers.items()): if server_info.state.value < min_state.value: continue if peer_id not in spans or spans[peer_id].state.value < server_info.state.value: spans[peer_id] = RemoteSpanInfo( peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info ) if server_info.start_block is not None and server_info.end_block is not None: spans[peer_id].start = max(server_info.start_block - block_offset, 0) spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks) elif spans[peer_id].state == server_info.state: spans[peer_id].end = max(spans[peer_id].end, block_idx + 1) return spans ================================================ FILE: src/petals/utils/disk_cache.py ================================================ import fcntl import os import shutil from contextlib import contextmanager from pathlib import Path from typing import Optional import huggingface_hub from hivemind.utils.logging import get_logger logger = get_logger(__name__) DEFAULT_CACHE_DIR = os.getenv("PETALS_CACHE", Path(Path.home(), ".cache", "petals")) BLOCKS_LOCK_FILE = "blocks.lock" @contextmanager def _blocks_lock(cache_dir: Optional[str], mode: int): if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR lock_path = Path(cache_dir, BLOCKS_LOCK_FILE) os.makedirs(lock_path.parent, exist_ok=True) with open(lock_path, "wb+") as lock_fd: fcntl.flock(lock_fd.fileno(), mode) # The OS will release the lock when lock_fd is closed or the process is killed yield def allow_cache_reads(cache_dir: Optional[str]): """Allows simultaneous reads, guarantees that blocks won't be removed along the way (shared lock)""" return _blocks_lock(cache_dir, fcntl.LOCK_SH) def allow_cache_writes(cache_dir: Optional[str]): """Allows saving new blocks and removing the old ones (exclusive lock)""" return _blocks_lock(cache_dir, fcntl.LOCK_EX) def 
free_disk_space_for( size: int, *, cache_dir: Optional[str], max_disk_space: Optional[int], os_quota: int = 1024**3, # Minimal space we should leave to keep OS function normally ): if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR cache_info = huggingface_hub.scan_cache_dir(cache_dir) available_space = shutil.disk_usage(cache_dir).free - os_quota if max_disk_space is not None: available_space = min(available_space, max_disk_space - cache_info.size_on_disk) gib = 1024**3 logger.debug(f"Disk space: required {size / gib:.1f} GiB, available {available_space / gib:.1f} GiB") if size <= available_space: return cached_files = [file for repo in cache_info.repos for revision in repo.revisions for file in revision.files] # Remove as few least recently used files as possible removed_files = [] freed_space = 0 extra_space_needed = size - available_space for file in sorted(cached_files, key=lambda file: file.blob_last_accessed): os.remove(file.file_path) # Remove symlink os.remove(file.blob_path) # Remove contents removed_files.append(file) freed_space += file.size_on_disk if freed_space >= extra_space_needed: break if removed_files: logger.info(f"Removed {len(removed_files)} files to free {freed_space / gib:.1f} GiB of disk space") logger.debug(f"Removed paths: {[str(file.file_path) for file in removed_files]}") if freed_space < extra_space_needed: raise RuntimeError( f"Insufficient disk space to load a block. 
Please free {(extra_space_needed - freed_space) / gib:.1f} GiB " f"on the volume for {cache_dir} or increase --max_disk_space if you set it manually" ) ================================================ FILE: src/petals/utils/hf_auth.py ================================================ import os from typing import Union def always_needs_auth(model_name: Union[str, os.PathLike, None]) -> bool: loading_from_repo = model_name is not None and not os.path.isdir(model_name) return loading_from_repo and model_name.startswith("meta-llama/Llama-2-") ================================================ FILE: src/petals/utils/logging.py ================================================ import os from hivemind.utils import logging as hm_logging def initialize_logs(): """Initialize Petals logging tweaks. This function is called when you import the `petals` module.""" # Env var PETALS_LOGGING=False prohibits Petals do anything with logs if os.getenv("PETALS_LOGGING", "True").lower() in ("false", "0"): return hm_logging.use_hivemind_log_handler("in_root_logger") # We suppress asyncio error logs by default since they are mostly not relevant for the end user, # unless there is env var PETALS_ASYNCIO_LOGLEVEL asyncio_loglevel = os.getenv("PETALS_ASYNCIO_LOGLEVEL", "FATAL" if hm_logging.loglevel != "DEBUG" else "DEBUG") hm_logging.get_logger("asyncio").setLevel(asyncio_loglevel) ================================================ FILE: src/petals/utils/misc.py ================================================ import torch DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter parameters DUMMY_INT64 = torch.empty(0, dtype=torch.int64) DUMMY_KEY_PAST = torch.empty((0, 0, 0)) def is_dummy(tensor: torch.Tensor) -> bool: return tensor.numel() == 0 SPECIAL_DTYPE_SIZES = {torch.bool: 1, torch.qint8: 1, torch.qint32: 4} def get_size_in_bytes(dtype: torch.dtype) -> int: if dtype in SPECIAL_DTYPE_SIZES: return SPECIAL_DTYPE_SIZES[dtype] get_info = torch.finfo if 
dtype.is_floating_point else torch.iinfo return (get_info(dtype).bits * (1 + dtype.is_complex)) // 8 def docstring_from(source): def add_docstring(dest): dest.__doc__ = source.__doc__ return dest return add_docstring ================================================ FILE: src/petals/utils/packaging.py ================================================ from typing import Any, Dict, List, Tuple import torch from hivemind import nested_flatten, nested_pack # TODO: Move functions to hivemind def _mark_masked_tensor(index: int) -> bytes: return b"__T" + str(index).encode() def _is_masked_tensor(item: Any) -> bool: return isinstance(item, bytes) and item.startswith(b"__T") def _get_tensor_index(item: bytes) -> int: return int(item[3:]) def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]: """ Check the function's arguments and pack all tensors into different flattened lists. :returns: a flattened list of tensors and args and kwargs, where tensors were masked """ masked_flat_values, flat_tensors, tensor_to_index = [], [], {} for value in nested_flatten((args, kwargs)): if isinstance(value, torch.Tensor): tensor_index = tensor_to_index.setdefault(value, len(flat_tensors)) if tensor_index == len(flat_tensors): flat_tensors.append(value) masked_flat_values.append(_mark_masked_tensor(tensor_index)) else: masked_flat_values.append(value) return flat_tensors, nested_pack(masked_flat_values, (args, kwargs)) def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any): """ Restore arguments after `pack_args_kwargs` function. 
:returns: list of args and dict of kwargs """ return nested_pack( ( value if not _is_masked_tensor(value) else flat_tensors[_get_tensor_index(value)] for value in nested_flatten(args_structure) ), args_structure, ) ================================================ FILE: src/petals/utils/peft.py ================================================ import contextlib import re import time from typing import List, Optional, Sequence, Union import bitsandbytes as bnb import torch import torch.nn as nn import transformers from accelerate import init_empty_weights from hivemind.utils.logging import get_logger from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url from peft.config import PeftConfig from peft.tuners import lora from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME from safetensors import safe_open from safetensors.torch import load_file from transformers.utils import get_file_from_repo from petals.server.block_utils import get_model_block, resolve_block_dtype from petals.utils.convert_block import QuantType from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for from petals.utils.misc import get_size_in_bytes logger = get_logger(__name__) COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"] def check_peft_repository(repo_id: str) -> bool: return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}") def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None): tensors = dict() is_tensors_found = dict() common_layer_patter_re = ( ".+\." + "".join(f"({common_name})?" 
for common_name in COMMON_LAYERS_PATTERN) + f"\.({block_idx})?\..+" ) with safe_open(filepath, framework=framework, device=device) as f: for k in f.keys(): if re.match(common_layer_patter_re, k): is_tensors_found[block_idx] = True tensors[k] = f.get_tensor(k) if not is_tensors_found.get(block_idx, False): logger.warning(f"There is no peft weights for block {block_idx}") return tensors def get_adapter_from_repo( repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, *, token: Optional[Union[str, bool]] = None, **kwargs, ): config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs) if config_path is None: raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}") config = PeftConfig.from_json_file(config_path) weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, use_auth_token=token, **kwargs) if weight_path is None: raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}") if block_idx is None: return config, load_file(weight_path) return config, load_specific_module(block_idx, weight_path, device=device) def load_peft( repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, *, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str, max_disk_space: Optional[int] = None, delay: float = 30, ): # TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here if not check_peft_repository(repo_id): raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.") try: with allow_cache_reads(cache_dir): return get_adapter_from_repo( repo_id, block_idx, device, revision=revision, token=token, cache_dir=cache_dir, local_files_only=False, ) except Exception: logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True) while True: try: with allow_cache_writes(cache_dir): config_url = 
hf_hub_url(repo_id, CONFIG_NAME, revision=revision) config_file_size = get_hf_file_metadata(config_url, token=token).size weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision) weight_file_size = get_hf_file_metadata(weight_url, token=token).size file_size = config_file_size + weight_file_size if file_size is not None: free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space) else: logger.warning(f"Failed to fetch size from peft repo {repo_id}") return get_adapter_from_repo( repo_id, block_idx, device, revision=revision, token=token, cache_dir=cache_dir, local_files_only=False, ) except Exception as e: logger.warning( f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True ) time.sleep(delay) class AdapterContextMixin: """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context""" ADAPTER_NOT_SET = "__ADAPTER_NOT_SET" _context_active_adapter = ADAPTER_NOT_SET @staticmethod @contextlib.contextmanager def using_adapter(active_adapter: Optional[str]): prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter try: yield finally: AdapterContextMixin._context_active_adapter = prev @property def active_adapter(self): if self._context_active_adapter == self.ADAPTER_NOT_SET: logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debug") return self._context_active_adapter @active_adapter.setter def active_adapter(self, value: Optional[str]): assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" "" @property def active_adapters(self): return [self._context_active_adapter] def set_adapter(self, adapter_names) -> None: """ In PEFT, this function makes the adapter trainable. However, in Petals environment this is not possible now. Thus, this code removes this functionality. 
Link to peft code: https://github.com/huggingface/peft/blob/98f4db2c7990ef9c879a0e1da9a28a19a04701ef/src/peft/tuners/tuners_utils.py#L463 """ pass using_adapter = AdapterContextMixin.using_adapter class LoraLinear(AdapterContextMixin, lora.Linear): """LoRA linear layer that uses adapter selected via using_adapter""" def __init__(self, base_layer, adapter_name: str): nn.Module.__init__(self) lora.LoraLayer.__init__(self, base_layer) self._active_adapter = adapter_name self.is_target_conv_1d_layer = False class LoraLinear8bitLt(LoraLinear, lora.Linear8bitLt): """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter""" class LoraLinear4bit(LoraLinear, lora.Linear4bit): """LoRA linear 4-bit that uses adapter selected via using_adapter""" def create_lora_adapter(block): for module_name, module in block.named_modules(): if isinstance(module, LoraLinear): continue for child_name, child in module.named_children(): lora_class = None if isinstance(child, nn.Linear): lora_class = LoraLinear elif isinstance(child, bnb.nn.Linear8bitLt): lora_class = LoraLinear8bitLt elif isinstance(child, bnb.nn.Linear4bit): lora_class = LoraLinear4bit if lora_class: lora_wrapped_child = lora_class( child, AdapterContextMixin.ADAPTER_NOT_SET, ) setattr(module, child_name, lora_wrapped_child) def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict): assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters" if peft_config["lora_dropout"] > 0: logger.info(f"Adapter {adapter_name} has dropout enabled, this server will disable dropout") for _, module in block.named_modules(): for child_name, child in module.named_children(): if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)): continue if child_name in peft_config["target_modules"] or ( isinstance(peft_config["target_modules"], str) and re.fullmatch(peft_config["target_modules"], child_name) ): is_lora_a_loaded = False is_lora_b_loaded = False for 
peft_key in peft_state_dict: if child_name not in peft_key: continue if adapter_name not in child.lora_A: child.update_layer( adapter_name, peft_config["r"], peft_config["lora_alpha"], use_rslora=peft_config.get("use_rslora", False), lora_dropout=peft_config["lora_dropout"], init_lora_weights=peft_config["init_lora_weights"], ) child.train(False) for p in child.parameters(): p.requires_grad = False if peft_key.endswith(".lora_A.weight"): child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key] is_lora_a_loaded = True elif peft_key.endswith(".lora_A.bias"): raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}") elif peft_key.endswith(".lora_B.weight"): child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key] is_lora_b_loaded = True elif peft_key.endswith(".lora_B.bias"): raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}") if is_lora_a_loaded and is_lora_b_loaded: logger.debug(f"Loaded adapter {adapter_name} for block {block_index}.{child_name}") elif is_lora_a_loaded or is_lora_b_loaded: raise ValueError(f"Invalid adapter {adapter_name} for block {block_index}.{child_name}") logger.info(f"Loaded adapter {adapter_name} for block {block_index}") def estimate_adapter_memory_per_block( block_config: transformers.PretrainedConfig, torch_dtype: Optional[torch.dtype], adapters: Sequence[str], **load_peft_kwargs, ) -> int: """Get the number of extra bytes used to store a set of adapters per given block""" with init_empty_weights(include_buffers=False): block = get_model_block(block_config) base_block_parameters = sum(p.numel() for p in block.parameters()) create_lora_adapter(block) for adapter in adapters: peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs) assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now" add_adapter_to_block( block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict ) 
adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype)) return adapter_parameters * bytes_per_parameter ================================================ FILE: src/petals/utils/ping.py ================================================ import asyncio import math import threading import time from functools import partial from typing import Dict, Sequence import hivemind from hivemind.proto import dht_pb2 from hivemind.utils.logging import get_logger logger = get_logger(__name__) async def ping( peer_id: hivemind.PeerID, _dht: hivemind.DHT, node: hivemind.dht.DHTNode, *, wait_timeout: float = 5, ) -> float: try: ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info) start_time = time.perf_counter() await node.protocol.get_stub(peer_id).rpc_ping(ping_request, timeout=wait_timeout) return time.perf_counter() - start_time except Exception as e: if str(e) == "protocol not supported": # Happens on servers with client-mode DHT (e.g., reachable via relays) return time.perf_counter() - start_time logger.debug(f"Failed to ping {peer_id}:", exc_info=True) return math.inf async def ping_parallel(peer_ids: Sequence[hivemind.PeerID], *args, **kwargs) -> Dict[hivemind.PeerID, float]: rpc_infos = await asyncio.gather(*[ping(peer_id, *args, **kwargs) for peer_id in peer_ids]) return dict(zip(peer_ids, rpc_infos)) class PingAggregator: def __init__(self, dht: hivemind.DHT, *, ema_alpha: float = 0.2, expiration: float = 300): self.dht = dht self.ema_alpha = ema_alpha self.expiration = expiration self.ping_emas = hivemind.TimedStorage() self.lock = threading.Lock() def ping(self, peer_ids: Sequence[hivemind.PeerID], **kwargs) -> None: current_rtts = self.dht.run_coroutine(partial(ping_parallel, peer_ids, **kwargs)) logger.debug(f"Current RTTs: {current_rtts}") with self.lock: expiration = hivemind.get_dht_time() + self.expiration for peer_id, rtt in 
current_rtts.items(): prev_rtt = self.ping_emas.get(peer_id) if prev_rtt is not None and prev_rtt.value != math.inf: rtt = self.ema_alpha * rtt + (1 - self.ema_alpha) * prev_rtt.value # Exponential smoothing self.ping_emas.store(peer_id, rtt, expiration) def to_dict(self) -> Dict[hivemind.PeerID, float]: with self.lock, self.ping_emas.freeze(): smoothed_rtts = {peer_id: rtt.value for peer_id, rtt in self.ping_emas.items()} logger.debug(f"Smothed RTTs: {smoothed_rtts}") return smoothed_rtts ================================================ FILE: src/petals/utils/random.py ================================================ import random from typing import Collection, TypeVar T = TypeVar("T") def sample_up_to(population: Collection[T], k: int) -> T: if not isinstance(population, list): population = list(population) if len(population) > k: population = random.sample(population, k) return population ================================================ FILE: src/petals/utils/version.py ================================================ import os import re from typing import Union import requests from hivemind.utils.logging import TextStyle, get_logger from packaging.version import parse import petals logger = get_logger(__name__) def validate_version() -> None: logger.info(f"Running {TextStyle.BOLD}Petals {petals.__version__}{TextStyle.RESET}") try: r = requests.get("https://pypi.python.org/pypi/petals/json") r.raise_for_status() response = r.json() versions = [parse(ver) for ver in response.get("releases")] latest = max(ver for ver in versions if not ver.is_prerelease) if parse(petals.__version__) < latest: logger.info( f"A newer version {latest} is available. 
Please upgrade with: " f"{TextStyle.BOLD}pip install --upgrade petals{TextStyle.RESET}" ) except Exception as e: logger.warning("Failed to fetch the latest Petals version from PyPI:", exc_info=True) def get_compatible_model_repo(model_name_or_path: Union[str, os.PathLike, None]) -> Union[str, os.PathLike, None]: if model_name_or_path is None: return None match = re.fullmatch(r"(bigscience/.+)-petals", str(model_name_or_path)) if match is None: return model_name_or_path logger.info( f"Loading model from {match.group(1)}, since Petals 1.2.0+ uses original repos instead of converted ones" ) return match.group(1) ================================================ FILE: tests/conftest.py ================================================ import asyncio import gc from contextlib import suppress import psutil import pytest from hivemind.utils.crypto import RSAPrivateKey from hivemind.utils.logging import get_logger from hivemind.utils.mpfuture import MPFuture logger = get_logger(__name__) @pytest.fixture def event_loop(): """ This overrides the ``event_loop`` fixture from pytest-asyncio (e.g. to make it compatible with ``asyncio.subprocess``). This fixture is identical to the original one but does not call ``loop.close()`` in the end. Indeed, at this point, the loop is already stopped (i.e. next tests are free to create new loops). However, finalizers of objects created in the current test may reference the current loop and fail if it is closed. For example, this happens while using ``asyncio.subprocess`` (the ``asyncio.subprocess.Process`` finalizer fails if the loop is closed, but works if the loop is only stopped). 
""" yield asyncio.get_event_loop() @pytest.fixture(autouse=True, scope="session") def cleanup_children(): yield with RSAPrivateKey._process_wide_key_lock: RSAPrivateKey._process_wide_key = None gc.collect() # Call .__del__() for removed objects children = psutil.Process().children(recursive=True) if children: logger.info(f"Cleaning up {len(children)} leftover child processes") for child in children: with suppress(psutil.NoSuchProcess): child.terminate() psutil.wait_procs(children, timeout=1) for child in children: with suppress(psutil.NoSuchProcess): child.kill() MPFuture.reset_backend() ================================================ FILE: tests/test_aux_functions.py ================================================ import subprocess import sys import pytest import torch from hivemind import nested_compare, nested_flatten from petals import AutoDistributedConfig from petals.server.throughput import measure_compute_rps from petals.utils.convert_block import QuantType from petals.utils.misc import DUMMY, is_dummy from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs from test_utils import MODEL_NAME def test_bnb_not_imported_when_unnecessary(): """ We avoid importing bitsandbytes when it's not used, since bitsandbytes doesn't always find correct CUDA libs and may raise exceptions because of that. If this test fails, please change your code to import bitsandbytes and/or petals.utils.peft in the function's/method's code when it's actually needed instead of importing them in the beginning of the file. This won't slow down the code - importing a module for the 2nd time doesn't rerun module code. 
""" subprocess.check_call([sys.executable, "-c", "import petals, sys; assert 'bitsandbytes' not in sys.modules"]) @pytest.mark.forked @pytest.mark.parametrize("inference", [False, True]) @pytest.mark.parametrize("n_tokens", [1, 16]) @pytest.mark.parametrize("tensor_parallel", [False, True]) def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool): config = AutoDistributedConfig.from_pretrained(MODEL_NAME) if tensor_parallel and config.model_type != "bloom": pytest.skip("Tensor parallelism is implemented only for BLOOM for now") tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else () compute_rps = measure_compute_rps( config, device=torch.device("cpu"), dtype=torch.bfloat16, quant_type=QuantType.NONE, tensor_parallel_devices=tensor_parallel_devices, n_tokens=n_tokens, n_steps=5, inference=inference, ) assert isinstance(compute_rps, float) and compute_rps > 0 @pytest.mark.forked def test_pack_inputs(): x = torch.ones(3) y = torch.arange(5) z = DUMMY args = (x, z, None, (y, y), z) kwargs = dict(foo=torch.zeros(1, 1), bar={"l": "i", "g": "h", "t": ("y", "e", "a", "r", torch.rand(1), x, y)}) flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs) assert len(flat_tensors) == 5 assert all(isinstance(t, torch.Tensor) for t in flat_tensors) restored_args, restored_kwargs = unpack_args_kwargs(flat_tensors, args_structure) assert len(restored_args) == len(args) assert torch.all(restored_args[0] == x).item() and restored_args[2] is None assert nested_compare((args, kwargs), (restored_args, restored_kwargs)) for original, restored in zip(nested_flatten((args, kwargs)), nested_flatten((restored_args, restored_kwargs))): if isinstance(original, torch.Tensor): assert torch.all(original == restored) else: assert original == restored ================================================ FILE: tests/test_block_exact_match.py ================================================ import random import pytest import torch from petals import 
AutoDistributedConfig, RemoteSequential from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS from petals.server.from_pretrained import load_pretrained_block from test_utils import * @pytest.mark.forked def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3): config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) remote_sequential = RemoteSequential(config) block_index = random.randint(0, config.num_hidden_layers - 1) remote_block = remote_sequential[block_index] inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS + 8, config.hidden_size) outputs_forward = remote_block(inputs) outputs_inference = [] with torch.inference_mode(): with remote_block.inference_session(max_length=inputs.shape[1]) as sess: # Test long inference (unmerged inference pools) outputs_inference.append(sess.step(inputs[:, : MAX_SHORT_INFERENCE_TOKENS + 1, :])) # Test short inference (merged inference pools) for i in range(MAX_SHORT_INFERENCE_TOKENS + 1, inputs.shape[1]): outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) # test that max length is respected with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info: sess.step(inputs[:, -1:, :]) assert "Maximum length exceeded" in repr(exc_info.value) outputs_inference = torch.cat(outputs_inference, dim=1) ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32) (outputs_local,) = ref_block(inputs) assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward) assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference) ================================================ FILE: tests/test_cache.py ================================================ import asyncio import multiprocessing as mp import random import time from typing import Optional import pytest import pytest_asyncio # make sure the module exists; otherwise the test will be skipped import torch from hivemind import 
TensorDescriptor from petals.server.memory_cache import AllocationFailed, MemoryCache from petals.utils.misc import get_size_in_bytes def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None): if dtype is None: dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool)) elem_size_bytes = get_size_in_bytes(dtype) descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype)) return descr @pytest.mark.asyncio async def test_cache_timeout(): cache = MemoryCache(max_size_bytes=1024, max_alloc_timeout=0.5) cache.runtime_pid += 1 # pretend we're another process async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0): pass async with cache.allocate_cache(_make_tensor_descriptor(100), timeout=999): async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0): async with cache.allocate_cache(_make_tensor_descriptor(128), _make_tensor_descriptor(32), timeout=1): t_start = time.perf_counter() with pytest.raises(AllocationFailed): async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0.1): pass assert 0.1 < time.perf_counter() - t_start < 0.2, "wait time exceeds alloc timeout" async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")): pass t_start = time.perf_counter() with pytest.raises(AllocationFailed): async with cache.allocate_cache(_make_tensor_descriptor(384), timeout=1.0): # exceeds max timeout pass assert 0.5 < time.perf_counter() - t_start < 0.6, "wait time exceeds max alloc timeout" # test memory allocation when another task frees the memory async def _klog_the_cache(): async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2): pass large_alloc_task = asyncio.create_task(_klog_the_cache()) t_start = time.perf_counter() await asyncio.sleep(0.05) # wait for large alloc to enqueue async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")): # exceeds max 
timeout pass # this memory should allocate once the background task clears the queue assert 0.2 < time.perf_counter() - t_start < 0.3, "memory should be allocated after background task clears" with pytest.raises(AllocationFailed): await large_alloc_task # test that zero-timeout allocation fails instantaneously even if someone else is awaiting alloc large_alloc_task = asyncio.create_task(_klog_the_cache()) t_start = time.perf_counter() await asyncio.sleep(0.05) # wait for large alloc to enqueue with pytest.raises(AllocationFailed): async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0): pass # this memory should allocate once the background task clears the queue assert time.perf_counter() - t_start < 0.1, "zero-timeout task should fail (or succeed) instantaneously" with pytest.raises(AllocationFailed): await large_alloc_task @pytest.mark.asyncio async def test_unlimited_timeout(): cache = MemoryCache(max_size_bytes=1024) cache.runtime_pid += 1 # pretend we're another process t_start = time.perf_counter() async def _klog_the_cache(): async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2): await asyncio.sleep(0.5) alloc_task = asyncio.create_task(_klog_the_cache()) await asyncio.sleep(0.1) async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=float("inf")): await alloc_task assert 0.5 < time.perf_counter() - t_start < 0.6, "memory should be allocated after background task clears" @pytest.mark.asyncio async def test_cache_usage(): cache = MemoryCache(max_size_bytes=2048) alloc_event, dealloc_a_event, dealloc_bcd_event, dealloc_e_event, dealloc_f_event = (mp.Event() for _ in range(5)) pipe_receiver, pipe_sender = mp.Pipe(duplex=False) with pytest.raises(AssertionError): async with cache.allocate_cache(_make_tensor_descriptor(123), timeout=1): pass # fails because cache must be allocated from another process descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8)) # 768 bytes descr_b = 
TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64)) # 8 bytes descr_c = TensorDescriptor.from_tensor(torch.empty((33,), dtype=torch.bool)) # 33 bytes descr_d = TensorDescriptor.from_tensor(torch.empty((0,), dtype=torch.int64)) # 0 bytes descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16)) # 1536 bytes descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8)) # 1792 bytes async def _allocate_and_wait(dealloc_event, *descrs, timeout=None): loop = asyncio.get_event_loop() async with cache.allocate_cache(*descrs, timeout=timeout) as handles: pipe_sender.send(handles) await loop.run_in_executor(None, dealloc_event.wait) async def _allocate_af(): alloc_event.wait() allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a)) await allocate_a_task allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f)) # klogs the cache await allocate_f_task alloc_process1 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_af()), daemon=True) alloc_process1.start() async def _allocate_bcde(): alloc_event.wait() await asyncio.sleep(0.1) # ensure that the other tensor is always allocated (and sent through pipe) first allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d)) allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e)) # doesn't fit await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED) alloc_process2 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_bcde()), daemon=True) alloc_process2.start() assert cache.current_size_bytes == 0 alloc_event.set() (handle_a,) = pipe_receiver.recv() handle_b, handle_c, handle_d = pipe_receiver.recv() with cache.use_cache(handle_a) as (tensor_a,): assert tensor_a.dtype == torch.uint8 tensor_a[2:5] = torch.tensor((42, 43, 44)) with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, 
tensor_d): assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0 assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0 tensor_a += 1 tensor_b[...] = -1.337 assert cache.current_size_bytes == 809 # this checks a,b,c,d are allocated but b still awaits memory dealloc_bcd_event.set() await asyncio.sleep(0.1) assert cache.current_size_bytes == 768 # only tensor a should be allocated with pytest.raises(KeyError): with cache.use_cache(handle_a, handle_b): pass # one of handles (c) is deallocated with pytest.raises(KeyError): with cache.use_cache(handle_d): pass # handle_d is deallocated correctly, even though it is never used with cache.use_cache(handle_a) as (tensor_a,): assert tuple(tensor_a[2:5]) == (43, 44, 45) dealloc_a_event.set() (handle_e,) = pipe_receiver.recv() # e can finally be allocated await asyncio.sleep(0.1) assert cache.current_size_bytes == 1536 # tensor e should finally be able to allocate with pytest.raises(KeyError): with cache.use_cache(handle_a): pass # tensor a is no longer allocated with cache.use_cache(handle_e) as (tensor_e,): assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8) dealloc_e_event.set() await asyncio.sleep(0.1) assert cache.current_size_bytes == 1792 # only tensor f is still allocated dealloc_f_event.set() alloc_process1.join() alloc_process2.join() await asyncio.sleep(0.1) assert cache.current_size_bytes == 0 assert cache.current_size_bytes == 0 assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details" assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details" ================================================ FILE: tests/test_chained_calls.py ================================================ ###### # Warning:torch this test is a work in progress. It will be modified soon. 
# - if you want more stable tests, see test_block_exact_match # - if you want to figure out chained inference, ask yozh import pytest import torch from petals import AutoDistributedConfig from petals.client.remote_sequential import RemoteSequential from petals.server.from_pretrained import load_pretrained_block from petals.utils.misc import DUMMY_KEY_PAST from test_utils import * @pytest.mark.forked def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1): config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) remote_blocks = RemoteSequential(config, start_block=3, end_block=6) assert isinstance(remote_blocks, RemoteSequential) ref_blocks = [ load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32), load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32), load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32), ] inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True) outputs_rpc = remote_blocks.forward(inputs) outputs_rpc.sum().backward() grads_rpc = inputs.grad inputs.grad = None hidden_states = inputs for ref_block in ref_blocks: hidden_states = ref_block.forward(hidden_states)[0] outputs_ref = hidden_states outputs_ref.sum().backward() grads_ref = inputs.grad assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward) assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward) @pytest.mark.forked def test_chained_inference_exact_match(atol_inference=1e-4): config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) remote_blocks = RemoteSequential(config, start_block=3, end_block=5) inputs = torch.randn(1, 8, config.hidden_size) outputs_inference = [] with remote_blocks.inference_session(max_length=inputs.shape[1]) as sess: for i in range(inputs.shape[1]): outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) outputs_inference = torch.cat(outputs_inference, dim=1) dtype = torch.float32 
ref_blocks = [ load_pretrained_block(MODEL_NAME, 3, torch_dtype=dtype), load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype), ] outputs_ref = [] cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype)) caches = [cache, cache] for i in range(inputs.shape[1]): new_caches = [] hidden_states = inputs[:, i : i + 1, :] for ref_block, cache in zip(ref_blocks, caches): with torch.no_grad(): hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache) new_caches.append(new_cache) outputs_ref.append(hidden_states) caches = new_caches outputs_ref = torch.cat(outputs_ref, dim=1) assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference) ================================================ FILE: tests/test_dtype.py ================================================ import pytest import torch from petals.server.block_utils import resolve_block_dtype from petals.server.from_pretrained import load_pretrained_block from petals.utils.auto_config import AutoDistributedConfig from test_utils import MODEL_NAME @pytest.mark.forked @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"]) def test_block_dtype(torch_dtype): config = AutoDistributedConfig.from_pretrained(MODEL_NAME) block = load_pretrained_block(MODEL_NAME, 0, config=config, torch_dtype=torch_dtype) expected_dtype = resolve_block_dtype(config, torch_dtype) assert all(param.dtype == expected_dtype for param in block.parameters()) ================================================ FILE: tests/test_full_model.py ================================================ import peft import pytest import torch import transformers from hivemind import get_logger from petals import AutoDistributedModelForCausalLM from test_utils import * logger = get_logger(__name__) @pytest.fixture def tokenizer(): # We set use_fast=False since LlamaTokenizerFast is slow on load return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False) @pytest.fixture def 
model(): return AutoDistributedModelForCausalLM.from_pretrained( MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32 ) @pytest.fixture def ref_model(): return transformers.AutoModelForCausalLM.from_pretrained( REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32 ) @pytest.mark.forked @pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,)) @pytest.mark.parametrize("pass_empty_tensors", (True, False)) def test_full_model_exact_match(tokenizer, model, ref_model, use_peft, pass_empty_tensors, atol=1e-3): if use_peft: model.config.active_adapter = ADAPTER_NAME ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME) ref_model.train(False) test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"] with torch.inference_mode(): parallel_outputs = model.forward(test_inputs).logits assert torch.all(torch.isfinite(parallel_outputs)) logger.info("Forward outputs are finite") embs = model.transformer.word_embeddings(test_inputs) embs = model.transformer.word_embeddings_layernorm(embs) recurrent_outputs = [] with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess: if pass_empty_tensors: recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size))) for t in range(embs.shape[1]): if t == 4: recurrent_outputs.append(sess.step(embs[:, 4:9, :])) elif 4 < t < 9: continue else: recurrent_outputs.append(sess.step(embs[:, t : t + 1, :])) if t == 2 and pass_empty_tensors: recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size))) recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size))) recurrent_outputs = torch.cat(recurrent_outputs, dim=1) recurrent_outputs = model.transformer.ln_f(recurrent_outputs) recurrent_outputs = model.lm_head(recurrent_outputs) assert torch.allclose( recurrent_outputs, parallel_outputs, rtol=0, atol=atol ), "Inference differs from forward pass" ref_outputs = 
ref_model.forward(test_inputs).logits.float() assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol), "Outputs are not identical to HF" def make_generate_calls(model, inputs, *, max_new_tokens, multiple_calls=False, **kwargs): if not multiple_calls: return model.generate(inputs, max_new_tokens=max_new_tokens, **kwargs) with model.inference_session(max_length=inputs.shape[1] + max_new_tokens) as sess: return torch.cat( [ # Sessions provided both explicitly and implicitly should work model.generate(inputs, max_new_tokens=1, **kwargs, session=sess), model.generate(None, max_new_tokens=max_new_tokens - 2, **kwargs), model.generate(None, max_new_tokens=1, **kwargs), ], dim=1, ) @pytest.mark.forked def test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4): inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[ "input_ids" ] options = dict(max_new_tokens=max_new_tokens, do_sample=False) for multiple_calls in [False, True]: for inputs in [inputs_single, inputs_batch]: outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options) ref_outputs = ref_model.generate(inputs, **options) assert torch.allclose( outputs, ref_outputs ), f"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}" @pytest.mark.forked def test_sampling(tokenizer, model, ref_model, max_new_tokens=10): inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[ "input_ids" ] for options in [ dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9), dict(do_sample=True, temperature=0.5, 
repetition_penalty=1.2), ]: options.update(max_new_tokens=max_new_tokens) for multiple_calls in [False, True]: for inputs in [inputs_single, inputs_batch]: torch.manual_seed(0) outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options) torch.manual_seed(0) ref_outputs = ref_model.generate(inputs, **options) assert torch.allclose( outputs, ref_outputs ), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}" @pytest.mark.skipif( "bloom" not in MODEL_NAME.lower(), reason="Mixtral and Llama use DynamicCache, which can change based on beam search choices", ) @pytest.mark.forked def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5): inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] options = dict(max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False) outputs = make_generate_calls(model, inputs, **options) ref_outputs = ref_model.generate(inputs, **options) assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF" @pytest.mark.forked def test_input_ids(tokenizer, model, ref_model, max_new_tokens=4): inputs = tokenizer("A cat sat on a mat", return_tensors="pt") assert inputs.keys() == {"input_ids", "attention_mask"} outputs = model.generate(**inputs, max_new_tokens=max_new_tokens) ref_outputs = ref_model.generate(**inputs, max_new_tokens=max_new_tokens) assert torch.allclose(outputs, ref_outputs), f"Outputs are not identical to HF" with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens): outputs = torch.cat( [ model.generate(**inputs, max_new_tokens=2), model.generate(None, max_new_tokens=max_new_tokens - 2), ], dim=1, ) assert torch.allclose(outputs, ref_outputs), f"Multi-call outputs are not identical to HF" ================================================ FILE: tests/test_optimized_layers.py ================================================ from typing import Optional, 
Tuple import pytest import torch from transformers.cache_utils import DynamicCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel from petals.server.block_utils import get_model_block from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, convert_block from test_utils import MODEL_NAME KVCache = Tuple[torch.Tensor, torch.Tensor] class UnoptimizedWrappedFalconBlock(FalconDecoderLayer): def forward( self, hidden_states: torch.Tensor, *args, attention_mask: Optional[torch.Tensor] = None, alibi: Optional[torch.Tensor] = None, layer_past: Optional[KVCache] = None, use_cache: bool = False, **kwargs, ): batch_size, seq_length = hidden_states.shape[:2] if layer_past is not None: layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past) past_length = 0 if layer_past is None else layer_past[0].shape[1] seq_length_with_past = seq_length + past_length attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) if alibi is None and self.config.alibi: alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) outputs = super().forward( hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, use_cache=use_cache, **kwargs, ) if use_cache: present_key_value = outputs[-1] present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value) outputs = outputs[:-1] + (present_key_value,) return outputs def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache: key_states, value_states = key_value key_states = key_states.permute(0, 2, 1) assert key_states.shape == 
value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] if self.config.new_decoder_architecture: key_states = self._expand_states(key_states) value_states = self._expand_states(value_states) return (key_states, value_states) def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache: key_states, value_states = key_value if self.config.new_decoder_architecture: key_states = self._collapse_states(key_states) value_states = self._collapse_states(value_states) assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] key_states = key_states.permute(0, 2, 1) return (key_states, value_states) def _expand_states(self, state: torch.Tensor) -> torch.Tensor: batch_size_x_num_kv_heads, seq_len, head_dim = state.shape batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim) state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy return state def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: batch_size_x_num_attn_heads, seq_len, head_dim = state.shape batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim) state = state[:, :, 0] state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim) return state class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer): def forward( self, hidden_states: torch.Tensor, *args, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, layer_past: Optional[Tuple[torch.Tensor]] = None, use_cache: bool = False, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: batch_size, seq_length, _ = hidden_states.shape 
seq_length_with_past = seq_length past_key_values_length = 0 past_key_value = layer_past if past_key_value is not None: past_key_values_length = past_key_value[0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length) elif use_cache: past_key_value = DynamicCache() if position_ids is None: device = hidden_states.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() # embed positions if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device ) attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length ) outputs = super().forward( hidden_states, *args, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache, **kwargs, ) if use_cache: present_key_value = outputs[-1] present_key_value = self._reorder_cache_from_llama_to_bloom( present_key_value, batch_size, seq_length_with_past ) outputs = outputs[:-1] + (present_key_value,) return outputs def _reorder_cache_from_bloom_to_llama( self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int ) -> DynamicCache: key_states, value_states = key_value key_states = key_states.permute(0, 2, 1) key_states = key_states.view( batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim ) value_states = value_states.view(*key_states.shape) past_key_values = ((key_states, value_states),) return DynamicCache.from_legacy_cache(past_key_values) def _reorder_cache_from_llama_to_bloom( self, key_value: DynamicCache, batch_size: int, seq_length: int ) -> Tuple[torch.Tensor]: 
key_states, value_states = key_value.to_legacy_cache()[0] value_states = value_states.view( batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim ) key_states = key_states.view(*value_states.shape) key_states = key_states.permute(0, 2, 1) return (key_states, value_states) @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) @pytest.mark.forked def test_optimized_block(device): if device == "cuda:0" and not torch.cuda.is_available(): pytest.skip("CUDA tests can be run only in CUDA-enabled setups") config = AutoDistributedConfig.from_pretrained(MODEL_NAME) tensor_parallel_devices = (device,) dtype = torch.bfloat16 quant_type = QuantType.NONE block_idx = 1 block = get_model_block(config, layer_idx=block_idx).to(dtype) block = convert_block(block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) if config.model_type == "falcon": unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype) elif config.model_type == "llama": unopt_block = UnoptimizedWrappedLlamaBlock(config, layer_idx=0).to(dtype) else: pytest.skip(f"This test is not applicable to {config.model_type} models") unopt_block = convert_block( unopt_block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True ) unopt_block.load_state_dict(block.state_dict()) cache = unopt_cache = None with torch.inference_mode(): for length in [10, 1, 1, 1]: dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype) block_output, cache = block(dummy_input, layer_past=cache, use_cache=True) unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True) assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length ================================================ FILE: tests/test_peft.py 
================================================ import os import shutil import pytest from huggingface_hub import snapshot_download from petals.utils.peft import check_peft_repository, load_peft UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft" SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft" TMP_CACHE_DIR = "tmp_cache/" def clear_dir(path_to_dir): shutil.rmtree(path_to_dir) os.mkdir(path_to_dir) def dir_empty(path_to_dir): files = os.listdir(path_to_dir) return len(files) == 0 @pytest.mark.forked def test_check_peft(): assert not check_peft_repository(UNSAFE_PEFT_REPO), "NOSAFE_PEFT_REPO is safe to load." assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load." @pytest.mark.forked def test_load_noncached(tmpdir): clear_dir(tmpdir) with pytest.raises(Exception): load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir) assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded" load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir) assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded" @pytest.mark.forked def test_load_cached(tmpdir): clear_dir(tmpdir) snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir) load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir) @pytest.mark.forked def test_load_layer_exists(tmpdir): clear_dir(tmpdir) load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir) @pytest.mark.forked def test_load_layer_nonexists(tmpdir): clear_dir(tmpdir) load_peft( SAFE_PEFT_REPO, block_idx=1337, cache_dir=tmpdir, ) ================================================ FILE: tests/test_priority_pool.py ================================================ import multiprocessing as mp import platform import time import pytest import torch from hivemind.moe.server.runtime import Runtime from petals.server.task_pool import PrioritizedTaskPool def _submit_tasks(runtime_ready, pools, results_valid): runtime_ready.wait() futures = [] futures.append(pools[0].submit_task(torch.tensor([0]), priority=1)) futures.append(pools[0].submit_task(torch.tensor([1]), priority=1)) 
time.sleep(0.01) futures.append(pools[1].submit_task(torch.tensor([2]), priority=1)) futures.append(pools[0].submit_task(torch.tensor([3]), priority=2)) futures.append(pools[0].submit_task(torch.tensor([4]), priority=10)) futures.append(pools[0].submit_task(torch.tensor([5]), priority=0)) futures.append(pools[0].submit_task(torch.tensor([6]), priority=1)) futures.append(pools[1].submit_task(torch.tensor([7]), priority=11)) futures.append(pools[1].submit_task(torch.tensor([8]), priority=1)) for i, f in enumerate(futures): assert f.result()[0].item() == i**2 results_valid.set() @pytest.mark.skipif(platform.system() == "Darwin", reason="Flapping on macOS due to multiprocessing quirks") @pytest.mark.forked def test_priority_pools(): outputs_queue = mp.SimpleQueue() runtime_ready = mp.Event() results_valid = mp.Event() def dummy_pool_func(x): time.sleep(0.1) y = x**2 outputs_queue.put((x, y)) return (y,) class DummyBackend: def __init__(self, pools): self.pools = pools def get_pools(self): return self.pools pools = ( PrioritizedTaskPool(dummy_pool_func, name="A", max_batch_size=1), PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1), ) # Simulate requests coming from ConnectionHandlers proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid)) proc.start() runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0) runtime.ready = runtime_ready runtime.start() proc.join() assert results_valid.is_set() ordered_outputs = [] while not outputs_queue.empty(): ordered_outputs.append(outputs_queue.get()[0].item()) assert ordered_outputs == [0, 5, 1, 2, 6, 8, 3, 4, 7] # 0 - first batch is loaded immediately, before everything else # 5 - highest priority task overall # 1 - first of several tasks with equal lowest priority (1) # 2 - second earliest task with priority 1, fetched from pool B # 6 - third earliest task with priority 1, fetched from pool A again # 8 - last priority-1 task, pool 
B # 3 - task with priority 2 from pool A # 4 - task with priority 10 from pool A # 7 - task with priority 11 from pool B runtime.shutdown() ================================================ FILE: tests/test_remote_sequential.py ================================================ import pytest import torch import torch.nn.functional as F from hivemind import DHT, BatchTensorDescriptor, get_logger from hivemind.proto import runtime_pb2 from petals import AutoDistributedConfig from petals.client import RemoteSequenceManager, RemoteSequential from petals.data_structures import UID_DELIMITER from petals.server.from_pretrained import load_pretrained_block from test_utils import * logger = get_logger(__name__) @pytest.mark.forked def test_remote_sequential(): config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True) test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True) grad_proj = torch.randn(1, 5, config.hidden_size) sequential = RemoteSequential(config, dht=dht) full_outputs = sequential(test_inputs) (full_outputs * grad_proj).sum().backward() assert test_inputs.grad is not None full_grad = test_inputs.grad.clone() test_inputs.grad.data.zero_() first_half = sequential[: config.num_hidden_layers // 2] second_half = sequential[config.num_hidden_layers // 2 :] assert len(first_half) + len(second_half) == len(sequential) assert abs(len(first_half) - len(second_half)) == config.num_hidden_layers % 2 for m in sequential, first_half, second_half: assert isinstance(repr(m), str) hidden = first_half(test_inputs) assert isinstance(hidden, torch.Tensor) assert hidden.shape == test_inputs.shape assert hidden.requires_grad second_half_outputs = second_half(hidden) assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3) (second_half_outputs * grad_proj).sum().backward() assert torch.allclose(test_inputs.grad, full_grad, atol=3e-2) # test RemoteSequential 
with lossy compression block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)] lossy_sequential = RemoteSequential( config, sequence_manager=DummyCustomSequenceManager(config, block_uids, dht=dht) ) test_inputs.grad = None approx_outputs = lossy_sequential(test_inputs) (approx_outputs * grad_proj).sum().backward() assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used" assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-3), "compression was not used" assert abs(approx_outputs - full_outputs).mean() < 0.01 absmax = abs(full_grad).max() assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05 class DummyCustomSequenceManager(RemoteSequenceManager): """A sequence manager that compresses inputs/outputs during forward and backward pass.""" @property def rpc_info(self): rpc_info = super().rpc_info dims = (2048, 1024) compressed_input_schema = BatchTensorDescriptor(dims, compression=runtime_pb2.CompressionType.FLOAT16) rpc_info["forward_schema"] = (compressed_input_schema,), dict() # (args, kwargs) return rpc_info def get_request_metadata(self, protocol: str, *args, **kwargs): metadata = super().get_request_metadata(protocol, *args, **kwargs) if protocol == "rpc_forward": metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,) elif protocol == "rpc_backward": metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,) # FIXME: Initially, we used CompressionType.BLOCKWISE_8BIT for rpc_backward() here. # This is currently broken since hivemind==1.1.8 is not compatible with bitsandbytes==0.39.1. 
# Please revert to BLOCKWISE_8BIT once this is fixed: https://github.com/learning-at-home/hivemind/issues/572 return metadata @pytest.mark.forked def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3): config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) remote_sequential = RemoteSequential(config) inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1) output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1) input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1) intermediate_prompts = torch.randn( config.num_hidden_layers, batch_size, pre_seq_len, config.hidden_size, requires_grad=True ) input_prompts = input_prompts.detach().requires_grad_(True) intermediate_prompts = intermediate_prompts.detach().requires_grad_(True) inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1) assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size) outputs = remote_sequential(inputs_with_prompts, prompts=intermediate_prompts) (outputs * output_proj).sum().backward() assert intermediate_prompts.grad is not None input_prompts_ref = input_prompts.clone().detach().requires_grad_(True) intermediate_prompts_ref = intermediate_prompts.clone().detach().requires_grad_(True) assert input_prompts_ref.grad is None assert intermediate_prompts_ref.grad is None outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1) for block_index in range(config.num_hidden_layers): block_prompt = intermediate_prompts_ref[block_index] outputs_ref[:, : block_prompt.shape[1]] += block_prompt block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32) (outputs_ref,) = block(outputs_ref) assert torch.allclose(outputs_ref, outputs, atol=1e-3) (outputs_ref * output_proj).sum().backward() assert input_prompts_ref.grad is not None assert 
torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=3e-2) assert intermediate_prompts_ref.grad is not None assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2) ================================================ FILE: tests/test_sequence_manager.py ================================================ import threading import time import pytest import torch from hivemind import DHT, get_logger from petals import AutoDistributedConfig from petals.client import RemoteSequenceManager, RemoteSequential from petals.data_structures import UID_DELIMITER from test_utils import * logger = get_logger(__name__) @pytest.mark.forked @pytest.mark.parametrize("mode", ["max_throughput", "min_latency"]) def test_sequence_manager_basics(mode: str): config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True) sequential = RemoteSequential(config, dht=dht) shutdown_evt = threading.Event() # test RemoteSequential with lossy compression block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)] sequential = RemoteSequential( config, sequence_manager=RemoteSequenceManagerWithChecks(config, block_uids, dht=dht, _was_shut_down=shutdown_evt), ) sequence = sequential.sequence_manager.make_sequence(mode=mode) assert all(sequence[i].peer_id != sequence[i + 1].peer_id for i in range(len(sequence) - 1)) assert sequential.sequence_manager.is_alive() assert sequential.sequence_manager._thread.ready.is_set() assert not shutdown_evt.is_set() sequential(torch.randn(1, 2, config.hidden_size)) sequential.sequence_manager.shutdown() del sequential time.sleep(1) assert shutdown_evt.is_set() class RemoteSequenceManagerWithChecks(RemoteSequenceManager): """A sequence manager that signals if it was shut down""" def __init__(self, *args, _was_shut_down: threading.Event, **kwargs): super().__init__(*args, **kwargs) 
self._was_shut_down = _was_shut_down def shutdown(self): super().shutdown() assert not self.is_alive() self._was_shut_down.set() ================================================ FILE: tests/test_server_stats.py ================================================ import time import hivemind import pytest import torch from petals import AutoDistributedConfig, RemoteSequential from petals.server.handler import CACHE_TOKENS_AVAILABLE from test_utils import * @pytest.mark.forked def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50): config = AutoDistributedConfig.from_pretrained(MODEL_NAME) config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"] # PeerID from server2.id dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to) blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to) info_before = blocks1.sequence_manager.rpc_info with blocks1.inference_session(max_length=max_length) as sess: sess.step(torch.randn(1, 1, config.hidden_size)) blocks1.sequence_manager.state.rpc_info = None # invalidate cache info_inside = blocks1.sequence_manager.rpc_info with blocks2.inference_session(max_length=max_length2) as sess2: sess2.step(torch.randn(1, 1, config.hidden_size)) blocks2.sequence_manager.state.rpc_info = None # invalidate cache info_inside2 = blocks2.sequence_manager.rpc_info time.sleep(0.1) blocks1.sequence_manager.state.rpc_info = None # invalidate cache info_after = blocks1.sequence_manager.rpc_info assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE] assert info_before[CACHE_TOKENS_AVAILABLE] - info_inside[CACHE_TOKENS_AVAILABLE] == max_length * len(blocks1) assert info_inside[CACHE_TOKENS_AVAILABLE] - info_inside2[CACHE_TOKENS_AVAILABLE] == max_length2 * len(blocks2) ================================================ FILE: 
tests/test_speculative_generation.py ================================================ import random import pytest import torch import transformers from petals import ( AutoDistributedConfig, AutoDistributedSpeculativeModel, DistributedLlamaForSpeculativeGeneration, RemoteSequential, ) from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS from petals.server.from_pretrained import load_pretrained_block from test_utils import * @pytest.mark.forked def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3): config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) remote_sequential = RemoteSequential(config) block_index = random.randint(0, config.num_hidden_layers - 1) remote_block = remote_sequential[block_index] inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size) short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size) short_inputs[:, :2, :] = inputs[:, :2, :] initial_outputs_inference = None secondary_outputs_inference = None with torch.inference_mode(): with remote_block.inference_session(max_length=inputs.shape[1]) as sess: initial_outputs_inference = sess.step(inputs) sess.position = 2 secondary_outputs_inference = sess.step(short_inputs[:, 2:, :]) result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1) ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32) (outputs_local,) = ref_block(short_inputs) assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference) @pytest.fixture def noisy_model(): noisy_model = transformers.AutoModelForCausalLM.from_pretrained( REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32 ) lm_head = noisy_model.get_output_embeddings() assert isinstance(lm_head, torch.nn.Linear) with torch.no_grad(): lm_head.weight += torch.randn_like(lm_head.weight) * 0.02 return noisy_model @pytest.fixture def model(): return 
transformers.AutoModelForCausalLM.from_pretrained( MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32 ) @pytest.fixture def tokenizer(): # We set use_fast=False since LlamaTokenizerFast is slow on load return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False) @pytest.mark.forked @pytest.mark.skipif( "llama" not in MODEL_NAME.lower(), reason="Speculative generation now works only for llama models", ) def test_remote_speculative_generation(tokenizer, model, noisy_model, atol_inference=1e-3): speculated_distributed_model = AutoDistributedSpeculativeModel.from_pretrained( MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32, small_model=noisy_model ) inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] generated_spec = speculated_distributed_model.generate(inputs_single, max_new_tokens=100, do_sample=False) generated_local = model.generate(inputs_single, max_new_tokens=100, do_sample=False) assert torch.allclose(generated_spec, generated_local, rtol=0, atol=atol_inference) ================================================ FILE: tests/test_tensor_parallel.py ================================================ import random import pytest import torch import transformers from tensor_parallel import TensorParallel from tensor_parallel.slicing_configs import get_bloom_config from petals.server.from_pretrained import load_pretrained_block from test_utils import MODEL_NAME @pytest.mark.forked @pytest.mark.parametrize("custom_config", [True, False]) @pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4]) def test_tp_block(devices, custom_config): model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME) if model_config.model_type != "bloom": pytest.skip("Tensor parallelism is implemented only for BLOOM for now") block_index = random.randint(0, 10) block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0]) tp_config = None if 
custom_config: tp_config = get_bloom_config(model_config, devices) batch_size = 2 prefix_length = 5 test_inputs1 = torch.randn(batch_size, 3, 1024, requires_grad=True, device=devices[0]) test_inputs2 = test_inputs1.detach().clone().requires_grad_(True) test_prefix1 = torch.randn(batch_size, prefix_length, 1024, requires_grad=True, device=devices[0]) test_prefix2 = test_prefix1.detach().clone().requires_grad_(True) grad_proj = torch.rand_like(test_inputs1) y_prefix_ref, layer_past = block(test_prefix1, use_cache=True) y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past) y_ref.backward(grad_proj) block_tp = TensorParallel(block, devices, config=tp_config) y_prefix, layer_past = block_tp(test_prefix2, use_cache=True) y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past) y_ours.backward(grad_proj) assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5) assert torch.allclose(y_ours, y_ref, atol=1e-5) assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4) assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4) ================================================ FILE: tests/test_utils.py ================================================ import os INITIAL_PEERS = os.environ.get("INITIAL_PEERS") if not INITIAL_PEERS: raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids") INITIAL_PEERS = INITIAL_PEERS.split() MODEL_NAME = os.environ.get("MODEL_NAME") if not MODEL_NAME: raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested") REF_NAME = os.environ.get("REF_NAME") ADAPTER_NAME = os.environ.get("ADAPTER_NAME")