Repository: huggingface/text-generation-inference
Branch: main
Commit: db931fcf42da
Files: 864
Total size: 6.3 MB
Directory structure:
gitextract_uvlvpncm/
├── .dockerignore
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug-report.yml
│ │ ├── config.yml
│ │ ├── feature-request.yml
│ │ └── new-model-addition.yml
│ ├── PULL_REQUEST_TEMPLATE.md
│ └── workflows/
│ ├── autodocs.yaml
│ ├── build.yaml
│ ├── build_documentation.yaml
│ ├── build_pr_documentation.yaml
│ ├── ci_build.yaml
│ ├── client-tests.yaml
│ ├── codeql.yml
│ ├── integration_tests.yaml
│ ├── load_test.yaml
│ ├── nix_build.yaml
│ ├── nix_cache.yaml
│ ├── nix_tests.yaml
│ ├── stale.yaml
│ ├── tests.yaml
│ ├── trufflehog.yaml
│ └── upload_pr_documentation.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .redocly.lint-ignore.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Cargo.toml
├── Dockerfile
├── Dockerfile.neuron
├── Dockerfile.nix
├── Dockerfile_amd
├── Dockerfile_gaudi
├── Dockerfile_intel
├── Dockerfile_llamacpp
├── Dockerfile_trtllm
├── LICENSE
├── Makefile
├── README.md
├── assets/
│ └── tgi_grafana.json
├── backends/
│ ├── client/
│ │ ├── Cargo.toml
│ │ ├── build.rs
│ │ └── src/
│ │ ├── lib.rs
│ │ ├── v2/
│ │ │ ├── client.rs
│ │ │ ├── mod.rs
│ │ │ └── sharded_client.rs
│ │ └── v3/
│ │ ├── client.rs
│ │ ├── mod.rs
│ │ └── sharded_client.rs
│ ├── gaudi/
│ │ ├── Makefile
│ │ ├── README.md
│ │ ├── examples/
│ │ │ └── docker_commands/
│ │ │ └── docker_commands.md
│ │ ├── server/
│ │ │ ├── .gitignore
│ │ │ ├── Makefile
│ │ │ ├── Makefile-awq
│ │ │ ├── Makefile-eetq
│ │ │ ├── Makefile-fbgemm
│ │ │ ├── Makefile-flash-att
│ │ │ ├── Makefile-flash-att-v2
│ │ │ ├── Makefile-selective-scan
│ │ │ ├── Makefile-vllm
│ │ │ ├── README.md
│ │ │ ├── dill-0.3.7-patch.sh
│ │ │ ├── dill-0.3.8-patch.sh
│ │ │ ├── pyproject.toml
│ │ │ ├── requirements.txt
│ │ │ └── text_generation_server/
│ │ │ ├── __init__.py
│ │ │ ├── adapters/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── config.py
│ │ │ │ ├── lora.py
│ │ │ │ └── weights.py
│ │ │ ├── cache.py
│ │ │ ├── cli.py
│ │ │ ├── interceptor.py
│ │ │ ├── layers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── attention/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── common.py
│ │ │ │ │ ├── hpu.py
│ │ │ │ │ └── kv_cache.py
│ │ │ │ ├── awq/
│ │ │ │ │ ├── conversion_utils.py
│ │ │ │ │ └── quantize/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── hpu.py
│ │ │ │ ├── bnb.py
│ │ │ │ ├── compressed_tensors/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── loader.py
│ │ │ │ │ └── w8an_fp.py
│ │ │ │ ├── conv.py
│ │ │ │ ├── exl2.py
│ │ │ │ ├── fp8.py
│ │ │ │ ├── gptq/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── hpu.py
│ │ │ │ │ ├── quantize.py
│ │ │ │ │ └── utils.py
│ │ │ │ ├── layernorm.py
│ │ │ │ ├── linear.py
│ │ │ │ ├── lora.py
│ │ │ │ ├── medusa.py
│ │ │ │ ├── mlp.py
│ │ │ │ ├── moe/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── fp8.py
│ │ │ │ │ ├── fused_moe.py
│ │ │ │ │ └── unquantized.py
│ │ │ │ ├── rotary.py
│ │ │ │ ├── speculative.py
│ │ │ │ └── tensor_parallel.py
│ │ │ ├── models/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── custom_modeling/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── bloom_modeling.py
│ │ │ │ │ ├── clip.py
│ │ │ │ │ ├── flash_cohere_modeling.py
│ │ │ │ │ ├── flash_dbrx_modeling.py
│ │ │ │ │ ├── flash_deepseek_v2_modeling.py
│ │ │ │ │ ├── flash_deepseek_v3_modeling.py
│ │ │ │ │ ├── flash_gemma2_modeling.py
│ │ │ │ │ ├── flash_gemma3_modeling.py
│ │ │ │ │ ├── flash_gemma_modeling.py
│ │ │ │ │ ├── flash_gpt2_modeling.py
│ │ │ │ │ ├── flash_gptj_modeling.py
│ │ │ │ │ ├── flash_llama4_modeling.py
│ │ │ │ │ ├── flash_llama_modeling.py
│ │ │ │ │ ├── flash_llava_next.py
│ │ │ │ │ ├── flash_mistral_modeling.py
│ │ │ │ │ ├── flash_mixtral_modeling.py
│ │ │ │ │ ├── flash_mllama.py
│ │ │ │ │ ├── flash_neox_modeling.py
│ │ │ │ │ ├── flash_pali_gemma_modeling.py
│ │ │ │ │ ├── flash_phi_modeling.py
│ │ │ │ │ ├── flash_phi_moe_modeling.py
│ │ │ │ │ ├── flash_qwen2_modeling.py
│ │ │ │ │ ├── flash_qwen3_modeling.py
│ │ │ │ │ ├── flash_qwen3_moe_modeling.py
│ │ │ │ │ ├── flash_rw_modeling.py
│ │ │ │ │ ├── flash_santacoder_modeling.py
│ │ │ │ │ ├── flash_starcoder2_modeling.py
│ │ │ │ │ ├── idefics2.py
│ │ │ │ │ ├── idefics3.py
│ │ │ │ │ ├── mamba_modeling.py
│ │ │ │ │ ├── qwen2_5_vl.py
│ │ │ │ │ ├── qwen2_vl.py
│ │ │ │ │ ├── siglip.py
│ │ │ │ │ └── vlm.py
│ │ │ │ ├── flash_causal_lm.py
│ │ │ │ ├── flash_vlm_causal_lm.py
│ │ │ │ ├── globals.py
│ │ │ │ ├── mllama_causal_lm.py
│ │ │ │ ├── model.py
│ │ │ │ ├── seq2seq_lm.py
│ │ │ │ └── types.py
│ │ │ ├── pb/
│ │ │ │ └── .gitignore
│ │ │ ├── server.py
│ │ │ ├── tracing.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── adapter.py
│ │ │ ├── chunks.py
│ │ │ ├── convert.py
│ │ │ ├── debug.py
│ │ │ ├── dist.py
│ │ │ ├── hub.py
│ │ │ ├── import_utils.py
│ │ │ ├── kernels.py
│ │ │ ├── log.py
│ │ │ ├── logits_process.py
│ │ │ ├── merges/
│ │ │ │ ├── strategies.py
│ │ │ │ └── utils.py
│ │ │ ├── peft.py
│ │ │ ├── prefill_chunking.py
│ │ │ ├── quantization.py
│ │ │ ├── segments.py
│ │ │ ├── sgmv.py
│ │ │ ├── speculate.py
│ │ │ ├── tokens.py
│ │ │ ├── version.py
│ │ │ ├── watermark.py
│ │ │ └── weights.py
│ │ └── tgi-entrypoint.sh
│ ├── grpc-metadata/
│ │ ├── Cargo.toml
│ │ └── src/
│ │ └── lib.rs
│ ├── llamacpp/
│ │ ├── Cargo.toml
│ │ ├── README.md
│ │ ├── build.rs
│ │ ├── requirements.txt
│ │ └── src/
│ │ ├── backend.rs
│ │ ├── llamacpp.rs
│ │ ├── main.rs
│ │ └── quantize.rs
│ ├── neuron/
│ │ ├── Cargo.toml
│ │ ├── Makefile
│ │ ├── README.md
│ │ ├── server/
│ │ │ ├── .gitignore
│ │ │ ├── Makefile
│ │ │ ├── build-requirements.txt
│ │ │ ├── pyproject.toml
│ │ │ └── text_generation_server/
│ │ │ ├── cli.py
│ │ │ ├── generator.py
│ │ │ ├── interceptor.py
│ │ │ ├── model.py
│ │ │ ├── server.py
│ │ │ └── tgi_env.py
│ │ ├── tests/
│ │ │ ├── conftest.py
│ │ │ ├── fixtures/
│ │ │ │ └── model.py
│ │ │ ├── prune_test_models.py
│ │ │ ├── pytest.ini
│ │ │ ├── requirements.txt
│ │ │ ├── server/
│ │ │ │ ├── helpers.py
│ │ │ │ ├── test_cached_model.py
│ │ │ │ ├── test_continuous_batching.py
│ │ │ │ ├── test_decode.py
│ │ │ │ ├── test_generator_slot.py
│ │ │ │ ├── test_info.py
│ │ │ │ └── test_prefill.py
│ │ │ └── test_entry_point.py
│ │ ├── tgi-entrypoint.sh
│ │ └── tgi_entry_point.py
│ ├── trtllm/
│ │ ├── CMakeLists.txt
│ │ ├── Cargo.toml
│ │ ├── README.md
│ │ ├── build.rs
│ │ ├── cmake/
│ │ │ ├── json.cmake
│ │ │ ├── spdlog.cmake
│ │ │ ├── trtllm.cmake
│ │ │ └── utils/
│ │ │ └── detect_cuda_arch.cu
│ │ ├── csrc/
│ │ │ ├── backend.cpp
│ │ │ ├── backend.hpp
│ │ │ ├── ffi.hpp
│ │ │ └── hardware.hpp
│ │ ├── scripts/
│ │ │ ├── install_tensorrt.sh
│ │ │ └── setup_sccache.py
│ │ ├── src/
│ │ │ ├── errors.rs
│ │ │ ├── lib.rs
│ │ │ ├── looper.rs
│ │ │ ├── main.rs
│ │ │ └── utils.rs
│ │ └── tests/
│ │ ├── test_backend.cpp
│ │ └── test_hardware.cpp
│ ├── v2/
│ │ ├── Cargo.toml
│ │ ├── build.rs
│ │ └── src/
│ │ ├── backend.rs
│ │ ├── client/
│ │ │ ├── grpc_client.rs
│ │ │ ├── mod.rs
│ │ │ └── sharded_client.rs
│ │ ├── lib.rs
│ │ ├── main.rs
│ │ └── queue.rs
│ └── v3/
│ ├── Cargo.toml
│ ├── benches/
│ │ └── prefix_cache.rs
│ ├── build.rs
│ └── src/
│ ├── backend.rs
│ ├── block_allocator.rs
│ ├── client/
│ │ ├── grpc_client.rs
│ │ ├── mod.rs
│ │ └── sharded_client.rs
│ ├── lib.rs
│ ├── main.rs
│ ├── queue.rs
│ └── radix.rs
├── benchmark/
│ ├── Cargo.toml
│ ├── README.md
│ └── src/
│ ├── app.rs
│ ├── event.rs
│ ├── generation.rs
│ ├── lib.rs
│ ├── main.rs
│ ├── table.rs
│ └── utils.rs
├── clients/
│ └── python/
│ ├── .gitignore
│ ├── Makefile
│ ├── README.md
│ ├── pyproject.toml
│ ├── tests/
│ │ ├── conftest.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ ├── test_inference_api.py
│ │ └── test_types.py
│ └── text_generation/
│ ├── __init__.py
│ ├── client.py
│ ├── errors.py
│ ├── inference_api.py
│ └── types.py
├── crate-hashes.json
├── docs/
│ ├── README.md
│ ├── index.html
│ ├── openapi.json
│ └── source/
│ ├── _toctree.yml
│ ├── architecture.md
│ ├── backends/
│ │ ├── gaudi.mdx
│ │ ├── llamacpp.md
│ │ ├── neuron.md
│ │ └── trtllm.md
│ ├── basic_tutorials/
│ │ ├── consuming_tgi.md
│ │ ├── gated_model_access.md
│ │ ├── monitoring.md
│ │ ├── non_core_models.md
│ │ ├── preparing_model.md
│ │ ├── safety.md
│ │ ├── train_medusa.md
│ │ ├── using_cli.md
│ │ ├── using_guidance.md
│ │ └── visual_language_models.md
│ ├── conceptual/
│ │ ├── chunking.md
│ │ ├── external.md
│ │ ├── flash_attention.md
│ │ ├── guidance.md
│ │ ├── lora.md
│ │ ├── paged_attention.md
│ │ ├── quantization.md
│ │ ├── safetensors.md
│ │ ├── speculation.md
│ │ ├── streaming.md
│ │ └── tensor_parallelism.md
│ ├── index.md
│ ├── installation.md
│ ├── installation_amd.md
│ ├── installation_gaudi.md
│ ├── installation_inferentia.md
│ ├── installation_intel.md
│ ├── installation_nvidia.md
│ ├── installation_tpu.md
│ ├── multi_backend_support.md
│ ├── quicktour.md
│ ├── reference/
│ │ ├── api_reference.md
│ │ ├── launcher.md
│ │ └── metrics.md
│ ├── supported_models.md
│ └── usage_statistics.md
├── flake.nix
├── integration-tests/
│ ├── conftest.py
│ ├── fixtures/
│ │ ├── gaudi/
│ │ │ └── service.py
│ │ └── neuron/
│ │ ├── export_models.py
│ │ └── service.py
│ ├── gaudi/
│ │ ├── capture_expected_outputs.py
│ │ └── test_gaudi_generate.py
│ ├── models/
│ │ ├── __snapshots__/
│ │ │ ├── test.py
│ │ │ ├── test_bloom_560m/
│ │ │ │ ├── test_bloom_560m.json
│ │ │ │ ├── test_bloom_560m_all_params.json
│ │ │ │ └── test_bloom_560m_load.json
│ │ │ ├── test_bloom_560m_sharded/
│ │ │ │ ├── test_bloom_560m_sharded.json
│ │ │ │ └── test_bloom_560m_sharded_load.json
│ │ │ ├── test_chat_llama/
│ │ │ │ └── test_flash_llama_simple.json
│ │ │ ├── test_completion_prompts/
│ │ │ │ ├── test_chat_hfhub_nousage.json
│ │ │ │ ├── test_chat_hfhub_usage.json
│ │ │ │ ├── test_chat_openai_nousage.json
│ │ │ │ ├── test_chat_openai_usage.json
│ │ │ │ ├── test_flash_llama_completion_many_prompts.json
│ │ │ │ ├── test_flash_llama_completion_many_prompts_stream.json
│ │ │ │ ├── test_flash_llama_completion_single_prompt.json
│ │ │ │ └── test_flash_llama_completion_stream_usage.json
│ │ │ ├── test_compressed_tensors_w8a8_int/
│ │ │ │ ├── test_compressed_tensors_w8a8_int.json
│ │ │ │ ├── test_compressed_tensors_w8a8_int_all_params.json
│ │ │ │ └── test_compressed_tensors_w8a8_int_load.json
│ │ │ ├── test_compressed_tensors_w8a8_int_dynamic_weight/
│ │ │ │ ├── test_compressed_tensors_w8a8_int_dynamic_weight.json
│ │ │ │ ├── test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json
│ │ │ │ └── test_compressed_tensors_w8a8_int_dynamic_weight_load.json
│ │ │ ├── test_compressed_tensors_w8an_fp/
│ │ │ │ ├── test_compressed_tensors_w8an.json
│ │ │ │ ├── test_compressed_tensors_w8an_all_params.json
│ │ │ │ └── test_compressed_tensors_w8an_load.json
│ │ │ ├── test_compressed_tensors_wna16_int/
│ │ │ │ ├── test_compressed_tensors_wna16.json
│ │ │ │ ├── test_compressed_tensors_wna16_all_params.json
│ │ │ │ └── test_compressed_tensors_wna16_load.json
│ │ │ ├── test_compressed_tensors_wna16_int_24/
│ │ │ │ ├── test_compressed_tensors_wna16_int_24.json
│ │ │ │ ├── test_compressed_tensors_wna16_int_24_all_params.json
│ │ │ │ └── test_compressed_tensors_wna16_int_24_load.json
│ │ │ ├── test_continue_final_message/
│ │ │ │ ├── test_llama_completion_single_prompt.json
│ │ │ │ └── test_llama_completion_single_prompt_continue.json
│ │ │ ├── test_flash_awq/
│ │ │ │ ├── test_flash_llama_awq.json
│ │ │ │ ├── test_flash_llama_awq_all_params.json
│ │ │ │ └── test_flash_llama_awq_load.json
│ │ │ ├── test_flash_awq_sharded/
│ │ │ │ ├── test_flash_llama_awq_load_sharded.json
│ │ │ │ └── test_flash_llama_awq_sharded.json
│ │ │ ├── test_flash_deepseek_v2/
│ │ │ │ ├── test_flash_deepseek_v2.json
│ │ │ │ ├── test_flash_deepseek_v2_all_params.json
│ │ │ │ └── test_flash_deepseek_v2_load.json
│ │ │ ├── test_flash_falcon/
│ │ │ │ ├── test_flash_falcon.json
│ │ │ │ ├── test_flash_falcon_all_params.json
│ │ │ │ └── test_flash_falcon_load.json
│ │ │ ├── test_flash_gemma/
│ │ │ │ ├── test_flash_gemma_all_params.json
│ │ │ │ ├── test_flash_gemma_load.json
│ │ │ │ └── test_flash_gemma_simple.json
│ │ │ ├── test_flash_gemma2/
│ │ │ │ ├── test_flash_gemma2.json
│ │ │ │ └── test_flash_gemma2_load.json
│ │ │ ├── test_flash_gemma3/
│ │ │ │ ├── test_exceed_window.json
│ │ │ │ ├── test_flash_gemma3.json
│ │ │ │ ├── test_flash_gemma3_image_base64_rgb_jpg.json
│ │ │ │ ├── test_flash_gemma3_image_base64_rgb_png.json
│ │ │ │ ├── test_flash_gemma3_image_base64_rgba.json
│ │ │ │ ├── test_flash_gemma3_image_cow.json
│ │ │ │ └── test_flash_gemma3_image_cow_dog.json
│ │ │ ├── test_flash_gemma_gptq/
│ │ │ │ ├── test_flash_gemma_gptq.json
│ │ │ │ ├── test_flash_gemma_gptq_all_params.json
│ │ │ │ └── test_flash_gemma_gptq_load.json
│ │ │ ├── test_flash_gpt2/
│ │ │ │ ├── test_flash_gpt2.json
│ │ │ │ └── test_flash_gpt2_load.json
│ │ │ ├── test_flash_grammar_llama/
│ │ │ │ ├── test_flash_llama_grammar.json
│ │ │ │ ├── test_flash_llama_grammar_json.json
│ │ │ │ ├── test_flash_llama_grammar_load.json
│ │ │ │ ├── test_flash_llama_grammar_regex.json
│ │ │ │ └── test_flash_llama_grammar_single_load_instance.json
│ │ │ ├── test_flash_llama/
│ │ │ │ ├── test_flash_llama_all_params.json
│ │ │ │ ├── test_flash_llama_load.json
│ │ │ │ └── test_flash_llama_simple.json
│ │ │ ├── test_flash_llama_exl2/
│ │ │ │ ├── test_flash_llama_exl2.json
│ │ │ │ ├── test_flash_llama_exl2_all_params.json
│ │ │ │ └── test_flash_llama_exl2_load.json
│ │ │ ├── test_flash_llama_fp8/
│ │ │ │ ├── test_flash_llama_fp8.json
│ │ │ │ ├── test_flash_llama_fp8_all_params.json
│ │ │ │ └── test_flash_llama_fp8_load.json
│ │ │ ├── test_flash_llama_fp8_kv_cache/
│ │ │ │ ├── test_flash_llama_fp8_kv_cache.json
│ │ │ │ ├── test_flash_llama_fp8_kv_cache_all_params.json
│ │ │ │ └── test_flash_llama_fp8_kv_cache_load.json
│ │ │ ├── test_flash_llama_gptq/
│ │ │ │ ├── test_flash_llama_gptq.json
│ │ │ │ ├── test_flash_llama_gptq_all_params.json
│ │ │ │ └── test_flash_llama_gptq_load.json
│ │ │ ├── test_flash_llama_marlin/
│ │ │ │ ├── test_flash_llama_marlin.json
│ │ │ │ ├── test_flash_llama_marlin_all_params.json
│ │ │ │ └── test_flash_llama_marlin_load.json
│ │ │ ├── test_flash_llama_marlin_24/
│ │ │ │ ├── test_flash_llama_marlin.json
│ │ │ │ ├── test_flash_llama_marlin24_all_params.json
│ │ │ │ └── test_flash_llama_marlin24_load.json
│ │ │ ├── test_flash_llama_prefix/
│ │ │ │ └── test_flash_llama_load.json
│ │ │ ├── test_flash_llama_prefix_flashdecoding/
│ │ │ │ └── test_flash_llama_flashdecoding.json
│ │ │ ├── test_flash_medusa/
│ │ │ │ ├── test_flash_medusa_all_params.json
│ │ │ │ ├── test_flash_medusa_load.json
│ │ │ │ └── test_flash_medusa_simple.json
│ │ │ ├── test_flash_mistral/
│ │ │ │ ├── test_flash_mistral.json
│ │ │ │ ├── test_flash_mistral_all_params.json
│ │ │ │ └── test_flash_mistral_load.json
│ │ │ ├── test_flash_mixtral/
│ │ │ │ ├── test_flash_mixtral.json
│ │ │ │ ├── test_flash_mixtral_all_params.json
│ │ │ │ └── test_flash_mixtral_load.json
│ │ │ ├── test_flash_mixtral_awq/
│ │ │ │ ├── test_flash_mixtral_awq.json
│ │ │ │ ├── test_flash_mixtral_awq_all_params.json
│ │ │ │ └── test_flash_mixtral_awq_load.json
│ │ │ ├── test_flash_mixtral_gptq/
│ │ │ │ ├── test_flash_mixtral_gptq.json
│ │ │ │ ├── test_flash_mixtral_gptq_all_params.json
│ │ │ │ └── test_flash_mixtral_gptq_load.json
│ │ │ ├── test_flash_neox/
│ │ │ │ ├── test_flash_neox.json
│ │ │ │ └── test_flash_neox_load.json
│ │ │ ├── test_flash_neox_sharded/
│ │ │ │ ├── test_flash_neox.json
│ │ │ │ └── test_flash_neox_load.json
│ │ │ ├── test_flash_pali_gemma/
│ │ │ │ ├── test_flash_pali_gemma.json
│ │ │ │ └── test_flash_pali_gemma_two_images.json
│ │ │ ├── test_flash_pali_gemma2/
│ │ │ │ └── test_flash_pali_gemma_image.json
│ │ │ ├── test_flash_phi/
│ │ │ │ ├── test_flash_phi.json
│ │ │ │ ├── test_flash_phi_all_params.json
│ │ │ │ └── test_flash_phi_load.json
│ │ │ ├── test_flash_phi35_moe/
│ │ │ │ ├── test_flash_phi35_moe.json
│ │ │ │ ├── test_flash_phi35_moe_all_params.json
│ │ │ │ └── test_flash_phi35_moe_load.json
│ │ │ ├── test_flash_qwen2/
│ │ │ │ ├── test_flash_qwen2.json
│ │ │ │ ├── test_flash_qwen2_all_params.json
│ │ │ │ └── test_flash_qwen2_load.json
│ │ │ ├── test_flash_qwen2_5_vl/
│ │ │ │ ├── test_flash_qwen2_5_vl_bay.json
│ │ │ │ ├── test_flash_qwen2_5_vl_inpaint.json
│ │ │ │ ├── test_flash_qwen2_5_vl_simple.json
│ │ │ │ └── test_flash_qwen2_5_vl_simple_streaming.json
│ │ │ ├── test_flash_qwen2_vl/
│ │ │ │ ├── test_flash_qwen2_vl_bay.json
│ │ │ │ ├── test_flash_qwen2_vl_inpaint.json
│ │ │ │ ├── test_flash_qwen2_vl_simple.json
│ │ │ │ └── test_flash_qwen2_vl_simple_streaming.json
│ │ │ ├── test_flash_santacoder/
│ │ │ │ ├── test_flash_santacoder.json
│ │ │ │ └── test_flash_santacoder_load.json
│ │ │ ├── test_flash_starcoder/
│ │ │ │ ├── test_flash_starcoder.json
│ │ │ │ ├── test_flash_starcoder_default_params.json
│ │ │ │ └── test_flash_starcoder_load.json
│ │ │ ├── test_flash_starcoder2/
│ │ │ │ ├── test_flash_starcoder2.json
│ │ │ │ ├── test_flash_starcoder2_default_params.json
│ │ │ │ └── test_flash_starcoder2_load.json
│ │ │ ├── test_flash_starcoder2_lora/
│ │ │ │ ├── test_flash_starcoder2.json
│ │ │ │ ├── test_flash_starcoder2_default_params.json
│ │ │ │ ├── test_flash_starcoder2_load.json
│ │ │ │ └── test_flash_starcoder2_with_hugcode_adapter.json
│ │ │ ├── test_flash_starcoder_gptq/
│ │ │ │ ├── test_flash_starcoder_gptq.json
│ │ │ │ ├── test_flash_starcoder_gptq_default_params.json
│ │ │ │ └── test_flash_starcoder_gptq_load.json
│ │ │ ├── test_grammar_llama/
│ │ │ │ └── test_non_flash_llama_grammar_json.json
│ │ │ ├── test_grammar_response_format_llama/
│ │ │ │ ├── test_grammar_response_format_llama_json.1.json
│ │ │ │ ├── test_grammar_response_format_llama_json.2.json
│ │ │ │ └── test_grammar_response_format_llama_json.json
│ │ │ ├── test_idefics/
│ │ │ │ ├── test_idefics.json
│ │ │ │ ├── test_idefics_load.json
│ │ │ │ └── test_idefics_two_images.json
│ │ │ ├── test_idefics2/
│ │ │ │ ├── test_flash_idefics2_next_all_params.json
│ │ │ │ ├── test_flash_idefics2_next_load.json
│ │ │ │ ├── test_flash_idefics2_next_simple.json
│ │ │ │ └── test_flash_idefics2_two_images.json
│ │ │ ├── test_idefics3/
│ │ │ │ └── test_flash_idefics3_next_simple_url.json
│ │ │ ├── test_json_schema_constrain/
│ │ │ │ ├── test_json_schema_basic.json
│ │ │ │ ├── test_json_schema_complex.json
│ │ │ │ └── test_json_schema_stream.json
│ │ │ ├── test_llava_next/
│ │ │ │ ├── test_flash_llava_next_all_params.json
│ │ │ │ ├── test_flash_llava_next_load.json
│ │ │ │ └── test_flash_llava_next_simple.json
│ │ │ ├── test_lora_mistral/
│ │ │ │ ├── test_lora_mistral_with_customer_support_adapter.json
│ │ │ │ ├── test_lora_mistral_with_dbpedia_adapter.json
│ │ │ │ ├── test_lora_mistral_without_adapter.json
│ │ │ │ └── test_lora_mistral_without_customer_support_adapter.json
│ │ │ ├── test_mamba/
│ │ │ │ ├── test_mamba.json
│ │ │ │ ├── test_mamba_all_params.json
│ │ │ │ └── test_mamba_load.json
│ │ │ ├── test_mllama/
│ │ │ │ ├── test_mllama_load.json
│ │ │ │ └── test_mllama_simpl.json
│ │ │ ├── test_mpt/
│ │ │ │ ├── test_mpt.json
│ │ │ │ └── test_mpt_load.json
│ │ │ ├── test_mt0_base/
│ │ │ │ ├── test_mt0_base.json
│ │ │ │ ├── test_mt0_base_all_params.json
│ │ │ │ └── test_mt0_base_load.json
│ │ │ ├── test_neox/
│ │ │ │ ├── test_neox.json
│ │ │ │ └── test_neox_load.json
│ │ │ ├── test_neox_sharded/
│ │ │ │ ├── test_neox.json
│ │ │ │ └── test_neox_load.json
│ │ │ ├── test_server_gptq_quantized/
│ │ │ │ ├── test_server_gptq_quantized.json
│ │ │ │ ├── test_server_gptq_quantized_all_params.json
│ │ │ │ └── test_server_gptq_quantized_load.json
│ │ │ ├── test_smolvlm/
│ │ │ │ └── test_flash_smolvlm_next_simple_url.json
│ │ │ ├── test_t5_sharded/
│ │ │ │ ├── test_t5_sharded.json
│ │ │ │ └── test_t5_sharded_load.json
│ │ │ ├── test_tools_llama/
│ │ │ │ ├── test_flash_llama_grammar_tools_auto_nostream.json
│ │ │ │ ├── test_flash_llama_grammar_tools_choice_nostream.json
│ │ │ │ ├── test_flash_llama_grammar_tools_choice_stream.json
│ │ │ │ ├── test_flash_llama_grammar_tools_insufficient_information_nostream.json
│ │ │ │ ├── test_flash_llama_grammar_tools_insufficient_information_stream.json
│ │ │ │ ├── test_flash_llama_grammar_tools_nostream.json
│ │ │ │ ├── test_flash_llama_grammar_tools_openai.json
│ │ │ │ ├── test_flash_llama_grammar_tools_sea_creatures_stream_auto.json
│ │ │ │ ├── test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json
│ │ │ │ ├── test_flash_llama_grammar_tools_sea_creatures_stream_none.json
│ │ │ │ ├── test_flash_llama_grammar_tools_sea_creatures_stream_required.json
│ │ │ │ └── test_flash_llama_tool_reply_response.json
│ │ │ ├── test_transformers_llama4/
│ │ │ │ ├── test_flash_llama4.json
│ │ │ │ ├── test_flash_llama4_image_base64_rgb_jpg.json
│ │ │ │ ├── test_flash_llama4_image_base64_rgb_png.json
│ │ │ │ ├── test_flash_llama4_image_base64_rgba.json
│ │ │ │ ├── test_flash_llama4_image_cow.json
│ │ │ │ └── test_flash_llama4_image_cow_dog.json
│ │ │ └── test_transformers_olmo/
│ │ │ ├── test_flash_llama_load.json
│ │ │ └── test_flash_llama_simple.json
│ │ ├── test_bloom_560m.py
│ │ ├── test_bloom_560m_sharded.py
│ │ ├── test_chat_llama.py
│ │ ├── test_chat_stream_options.py
│ │ ├── test_completion_prompts.py
│ │ ├── test_compressed_tensors_w8a8_int.py
│ │ ├── test_compressed_tensors_w8a8_int_dynamic_weight.py
│ │ ├── test_compressed_tensors_w8an_fp.py
│ │ ├── test_compressed_tensors_wna16_int.py
│ │ ├── test_compressed_tensors_wna16_int_24.py
│ │ ├── test_continue_final_message.py
│ │ ├── test_flash_awq.py
│ │ ├── test_flash_awq_sharded.py
│ │ ├── test_flash_deepseek_v2.py
│ │ ├── test_flash_falcon.py
│ │ ├── test_flash_gemma.py
│ │ ├── test_flash_gemma2.py
│ │ ├── test_flash_gemma3.py
│ │ ├── test_flash_gemma_gptq.py
│ │ ├── test_flash_gpt2.py
│ │ ├── test_flash_grammar_llama.py
│ │ ├── test_flash_llama.py
│ │ ├── test_flash_llama_exl2.py
│ │ ├── test_flash_llama_fp8.py
│ │ ├── test_flash_llama_fp8_kv_cache.py
│ │ ├── test_flash_llama_gptq.py
│ │ ├── test_flash_llama_marlin.py
│ │ ├── test_flash_llama_marlin_24.py
│ │ ├── test_flash_llama_prefix.py
│ │ ├── test_flash_llama_prefix_flashdecoding.py
│ │ ├── test_flash_medusa.py
│ │ ├── test_flash_mistral.py
│ │ ├── test_flash_mixtral.py
│ │ ├── test_flash_mixtral_awq.py
│ │ ├── test_flash_mixtral_gptq.py
│ │ ├── test_flash_neox.py
│ │ ├── test_flash_neox_sharded.py
│ │ ├── test_flash_pali_gemma.py
│ │ ├── test_flash_pali_gemma2.py
│ │ ├── test_flash_phi.py
│ │ ├── test_flash_phi35_moe.py
│ │ ├── test_flash_qwen2.py
│ │ ├── test_flash_qwen2_5_vl.py
│ │ ├── test_flash_qwen2_vl.py
│ │ ├── test_flash_santacoder.py
│ │ ├── test_flash_starcoder.py
│ │ ├── test_flash_starcoder2.py
│ │ ├── test_flash_starcoder2_lora.py
│ │ ├── test_flash_starcoder_gptq.py
│ │ ├── test_grammar_llama.py
│ │ ├── test_grammar_response_format_llama.py
│ │ ├── test_idefics.py
│ │ ├── test_idefics2.py
│ │ ├── test_idefics3.py
│ │ ├── test_json_schema_constrain.py
│ │ ├── test_llava_next.py
│ │ ├── test_lora_mistral.py
│ │ ├── test_mamba.py
│ │ ├── test_mllama.py
│ │ ├── test_mpt.py
│ │ ├── test_mt0_base.py
│ │ ├── test_neox.py
│ │ ├── test_neox_sharded.py
│ │ ├── test_opt.py
│ │ ├── test_smolvlm.py
│ │ ├── test_t5_sharded.py
│ │ ├── test_tools_llama.py
│ │ ├── test_transformers_llama4.py
│ │ └── test_transformers_olmo.py
│ ├── neuron/
│ │ ├── test_generate.py
│ │ └── test_implicit_env.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ └── requirements.txt
├── launcher/
│ ├── Cargo.toml
│ ├── build.rs
│ └── src/
│ ├── env_runtime.rs
│ ├── gpu.rs
│ └── main.rs
├── load_tests/
│ ├── Makefile
│ ├── benchmarks.py
│ ├── common.js
│ ├── filter.py
│ ├── long.js
│ ├── long.py
│ ├── long_prompt2.py
│ ├── orca.py
│ └── pyproject.toml
├── nix/
│ ├── client.nix
│ ├── crate-overrides.nix
│ ├── docker.nix
│ ├── impure-shell.nix
│ ├── overlay.nix
│ └── server.nix
├── proto/
│ ├── generate.proto
│ └── v3/
│ └── generate.proto
├── router/
│ ├── Cargo.toml
│ ├── README.md
│ ├── build.rs
│ └── src/
│ ├── chat.rs
│ ├── config.rs
│ ├── infer/
│ │ ├── chat_template.rs
│ │ ├── mod.rs
│ │ └── tool_grammar.rs
│ ├── kserve.rs
│ ├── lib.rs
│ ├── logging.rs
│ ├── sagemaker.rs
│ ├── server.rs
│ ├── usage_stats.rs
│ ├── validation.rs
│ └── vertex.rs
├── rust-toolchain.toml
├── sagemaker-entrypoint.sh
├── server/
│ ├── .gitignore
│ ├── Makefile
│ ├── Makefile-awq
│ ├── Makefile-eetq
│ ├── Makefile-exllamav2
│ ├── Makefile-flash-att
│ ├── Makefile-flash-att-v2
│ ├── Makefile-flashinfer
│ ├── Makefile-selective-scan
│ ├── Makefile-vllm
│ ├── README.md
│ ├── bounds-from-nix.py
│ ├── custom_kernels/
│ │ ├── custom_kernels/
│ │ │ ├── fused_attention_cuda.cu
│ │ │ └── fused_bloom_attention_cuda.cu
│ │ └── setup.py
│ ├── exllama_kernels/
│ │ ├── exllama_kernels/
│ │ │ ├── cu_compat.cuh
│ │ │ ├── cuda_buffers.cu
│ │ │ ├── cuda_buffers.cuh
│ │ │ ├── cuda_func/
│ │ │ │ ├── column_remap.cu
│ │ │ │ ├── column_remap.cuh
│ │ │ │ ├── q4_matmul.cu
│ │ │ │ ├── q4_matmul.cuh
│ │ │ │ ├── q4_matrix.cu
│ │ │ │ └── q4_matrix.cuh
│ │ │ ├── exllama_ext.cpp
│ │ │ ├── hip_compat.cuh
│ │ │ ├── matrix.cuh
│ │ │ ├── tuning.h
│ │ │ └── util.cuh
│ │ └── setup.py
│ ├── exllamav2_kernels/
│ │ ├── exllamav2_kernels/
│ │ │ ├── config.h
│ │ │ ├── cpp/
│ │ │ │ └── util.h
│ │ │ ├── cuda/
│ │ │ │ ├── compat.cuh
│ │ │ │ ├── matrix_view.cuh
│ │ │ │ ├── q_gemm.cu
│ │ │ │ ├── q_gemm.cuh
│ │ │ │ ├── q_gemm_kernel.cuh
│ │ │ │ ├── q_gemm_kernel_gptq.cuh
│ │ │ │ ├── q_matrix.cu
│ │ │ │ ├── q_matrix.cuh
│ │ │ │ ├── quant/
│ │ │ │ │ ├── qdq_2.cuh
│ │ │ │ │ ├── qdq_3.cuh
│ │ │ │ │ ├── qdq_4.cuh
│ │ │ │ │ ├── qdq_5.cuh
│ │ │ │ │ ├── qdq_6.cuh
│ │ │ │ │ ├── qdq_8.cuh
│ │ │ │ │ └── qdq_util.cuh
│ │ │ │ └── util.cuh
│ │ │ └── ext.cpp
│ │ └── setup.py
│ ├── pyproject.toml
│ ├── req.txt
│ ├── requirements_cuda.txt
│ ├── requirements_gen.txt
│ ├── requirements_intel.txt
│ ├── requirements_rocm.txt
│ ├── tests/
│ │ ├── conftest.py
│ │ ├── models/
│ │ │ ├── test_bloom.py
│ │ │ ├── test_causal_lm.py
│ │ │ ├── test_model.py
│ │ │ ├── test_santacoder.py
│ │ │ └── test_seq2seq_lm.py
│ │ └── utils/
│ │ ├── test_adapter.py
│ │ ├── test_convert.py
│ │ ├── test_hub.py
│ │ ├── test_layers.py
│ │ ├── test_tokens.py
│ │ ├── test_watermark.py
│ │ └── test_weights.py
│ └── text_generation_server/
│ ├── __init__.py
│ ├── adapters/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── lora.py
│ │ └── weights.py
│ ├── cache.py
│ ├── cli.py
│ ├── interceptor.py
│ ├── layers/
│ │ ├── __init__.py
│ │ ├── attention/
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── cuda.py
│ │ │ ├── flash_attn_triton.py
│ │ │ ├── flashinfer.py
│ │ │ ├── ipex.py
│ │ │ ├── kv_cache.py
│ │ │ └── rocm.py
│ │ ├── awq/
│ │ │ ├── conversion_utils.py
│ │ │ └── quantize/
│ │ │ ├── __init__.py
│ │ │ ├── cuda.py
│ │ │ └── ipex.py
│ │ ├── bnb.py
│ │ ├── compressed_tensors/
│ │ │ ├── __init__.py
│ │ │ ├── loader.py
│ │ │ ├── w8a8_int.py
│ │ │ ├── w8an_fp.py
│ │ │ ├── wna16_int.py
│ │ │ └── wna16_int_24.py
│ │ ├── conv.py
│ │ ├── eetq.py
│ │ ├── exl2.py
│ │ ├── fp8.py
│ │ ├── gptq/
│ │ │ ├── __init__.py
│ │ │ ├── custom_autotune.py
│ │ │ ├── exllama.py
│ │ │ ├── exllamav2.py
│ │ │ ├── ipex.py
│ │ │ ├── quantize.py
│ │ │ ├── triton.py
│ │ │ └── utils.py
│ │ ├── layernorm.py
│ │ ├── linear.py
│ │ ├── lora.py
│ │ ├── marlin/
│ │ │ ├── __init__.py
│ │ │ ├── fp8.py
│ │ │ ├── gptq.py
│ │ │ ├── marlin.py
│ │ │ └── util.py
│ │ ├── medusa.py
│ │ ├── mlp.py
│ │ ├── moe/
│ │ │ ├── __init__.py
│ │ │ ├── fp8.py
│ │ │ ├── fused_moe_ipex.py
│ │ │ ├── gptq_marlin.py
│ │ │ └── unquantized.py
│ │ ├── rotary.py
│ │ ├── speculative.py
│ │ └── tensor_parallel.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── bloom.py
│ │ ├── causal_lm.py
│ │ ├── custom_modeling/
│ │ │ ├── __init__.py
│ │ │ ├── bloom_modeling.py
│ │ │ ├── clip.py
│ │ │ ├── flash_cohere_modeling.py
│ │ │ ├── flash_dbrx_modeling.py
│ │ │ ├── flash_deepseek_v2_modeling.py
│ │ │ ├── flash_deepseek_v3_modeling.py
│ │ │ ├── flash_gemma2_modeling.py
│ │ │ ├── flash_gemma3_modeling.py
│ │ │ ├── flash_gemma_modeling.py
│ │ │ ├── flash_gpt2_modeling.py
│ │ │ ├── flash_gptj_modeling.py
│ │ │ ├── flash_llama_modeling.py
│ │ │ ├── flash_mistral_modeling.py
│ │ │ ├── flash_mixtral_modeling.py
│ │ │ ├── flash_neox_modeling.py
│ │ │ ├── flash_pali_gemma_modeling.py
│ │ │ ├── flash_phi_modeling.py
│ │ │ ├── flash_phi_moe_modeling.py
│ │ │ ├── flash_qwen2_modeling.py
│ │ │ ├── flash_rw_modeling.py
│ │ │ ├── flash_santacoder_modeling.py
│ │ │ ├── flash_starcoder2_modeling.py
│ │ │ ├── gemma3/
│ │ │ │ ├── configuration_gemma3.py
│ │ │ │ ├── image_processing_gemma3.py
│ │ │ │ ├── processing_gemma3.py
│ │ │ │ └── utils.py
│ │ │ ├── idefics2.py
│ │ │ ├── idefics3.py
│ │ │ ├── idefics_config.py
│ │ │ ├── idefics_image_processing.py
│ │ │ ├── idefics_modeling.py
│ │ │ ├── idefics_perceiver.py
│ │ │ ├── idefics_processing.py
│ │ │ ├── idefics_vision.py
│ │ │ ├── llava_next.py
│ │ │ ├── mamba_modeling.py
│ │ │ ├── mllama.py
│ │ │ ├── mpt_modeling.py
│ │ │ ├── neox_modeling.py
│ │ │ ├── opt_modeling.py
│ │ │ ├── phi_modeling.py
│ │ │ ├── qwen2_5_vl.py
│ │ │ ├── qwen2_vl.py
│ │ │ ├── siglip.py
│ │ │ ├── t5_modeling.py
│ │ │ └── vlm.py
│ │ ├── flash_causal_lm.py
│ │ ├── galactica.py
│ │ ├── globals.py
│ │ ├── idefics_causal_lm.py
│ │ ├── mamba.py
│ │ ├── metadata_kernels.py
│ │ ├── mllama_causal_lm.py
│ │ ├── model.py
│ │ ├── seq2seq_lm.py
│ │ ├── transformers_flash_causal_lm.py
│ │ ├── transformers_flash_vlm.py
│ │ ├── types.py
│ │ └── vlm_causal_lm.py
│ ├── pb/
│ │ └── .gitignore
│ ├── server.py
│ ├── tracing.py
│ └── utils/
│ ├── __init__.py
│ ├── adapter.py
│ ├── chunks.py
│ ├── convert.py
│ ├── dist.py
│ ├── hub.py
│ ├── import_utils.py
│ ├── kernels.py
│ ├── log.py
│ ├── logits_process.py
│ ├── merges/
│ │ ├── strategies.py
│ │ └── utils.py
│ ├── peft.py
│ ├── prefill_chunking.py
│ ├── quantization.py
│ ├── segments.py
│ ├── speculate.py
│ ├── tokens.py
│ ├── watermark.py
│ └── weights.py
├── tgi-entrypoint.sh
└── update_doc.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .dockerignore
================================================
aml
target
server/transformers
server/flash-attention
cmake-build-debug/
cmake-build-release/
Dockerfile*
================================================
FILE: .github/ISSUE_TEMPLATE/bug-report.yml
================================================
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve text-generation-inference
body:
- type: textarea
id: system-info
attributes:
label: System Info
description: |
Please share your system info with us (`text-generation-launcher --env` if installed locally).
The full command line used that causes issues:
OS version:
Rust version (if self-compiling, `cargo version`):
Model being used (`curl 127.0.0.1:8080/info | jq`):
If using a local model, please specify the kind of model and/or its equivalents.
Hardware used (GPUs, how many, on which cloud) (`nvidia-smi`):
Deployment specifics (Kubernetes, EKS, AKS, any particular deployments):
The current version being used:
placeholder: text-generation-inference version, platform, python version, ...
validations:
required: true
- type: checkboxes
id: information-scripts-examples
attributes:
label: Information
description: 'The problem arises when using:'
options:
- label: "Docker"
- label: "The CLI directly"
- type: checkboxes
id: information-tasks
attributes:
label: Tasks
description: "The thing I am working on is:"
options:
- label: "An officially supported command"
- label: "My own modifications"
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
If you have code snippets, error messages, stack traces please provide them here as well.
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
placeholder: |
Steps to reproduce the behavior:
1.
2.
3.
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: Expected behavior
description: "A clear and concise description of what you would expect to happen."
================================================
FILE: .github/ISSUE_TEMPLATE/config.yml
================================================
blank_issues_enabled: true
version: 2.1
================================================
FILE: .github/ISSUE_TEMPLATE/feature-request.yml
================================================
name: "\U0001F680 Feature request"
description: Submit a proposal/request for a new text-generation-inference feature
labels: [ "feature" ]
body:
- type: textarea
id: feature-request
validations:
required: true
attributes:
label: Feature request
description: |
A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.
- type: textarea
id: motivation
validations:
required: true
attributes:
label: Motivation
description: |
Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
- type: textarea
id: contribution
validations:
required: true
attributes:
label: Your contribution
description: |
Is there any way that you could help, e.g. by submitting a PR? Make sure to read the [CONTRIBUTING.md](https://github.com/huggingface/text-generation-inference/blob/main/CONTRIBUTING.md) guide.
================================================
FILE: .github/ISSUE_TEMPLATE/new-model-addition.yml
================================================
name: "\U0001F31F New model addition"
description: Submit a proposal/request to implement a new model
labels: [ "New model" ]
body:
- type: textarea
id: description-request
validations:
required: true
attributes:
label: Model description
description: |
Provide any and all important information relating to the model.
- type: checkboxes
id: information-tasks
attributes:
label: Open source status
description: |
Please note that if the model implementation isn't available or if the weights aren't open-source, we are less likely to implement it in `text-generation-inference`.
options:
- label: "The model implementation is available"
- label: "The model weights are available"
- type: textarea
id: additional-info
attributes:
label: Provide useful links for the implementation
description: |
Please provide information regarding the implementation, the weights, and the authors.
Please mention the authors by @gh-username if you're aware of their usernames.
================================================
FILE: .github/PULL_REQUEST_TEMPLATE.md
================================================
# What does this PR do?
Fixes # (issue)
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/text-generation-inference/blob/main/CONTRIBUTING.md),
Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the
[documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and
[here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?
## Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
================================================
FILE: .github/workflows/autodocs.yaml
================================================
name: Automatic Documentation for Launcher
on:
pull_request:
jobs:
update_docs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Rust
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
- name: Install Protocol Buffers compiler
run: |
sudo apt-get update
sudo apt-get install -y protobuf-compiler libprotobuf-dev
- name: Install Launcher
id: install-launcher
run: cargo install --path launcher/
- name: Install router
id: install-router
run: cargo install --path backends/v3/
- uses: actions/setup-node@v4
with:
node-version: 22
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'
- name: Check that documentation is up-to-date
run: |
npm install -g @redocly/cli@1.34.2
python update_doc.py --check
================================================
FILE: .github/workflows/build.yaml
================================================
name: Build and push docker image to internal registry
on:
workflow_call:
inputs:
hardware:
type: string
description: Hardware
# options:
# - cuda
# - cuda-trtllm
# - rocm
# - intel-xpu
# - intel-cpu
# - neuron
# - gaudi
required: true
release-tests:
description: "Run release integration tests"
required: true
default: false
type: boolean
jobs:
build-and-push:
outputs:
docker_image: ${{ steps.final.outputs.docker_image }}
docker_volume: ${{ steps.final.outputs.docker_volume }}
docker_devices: ${{ steps.final.outputs.docker_devices }}
runs_on: ${{ steps.final.outputs.runs_on }}
label_extension: ${{ steps.final.outputs.label_extension }}
extra_pytest: ${{ steps.final.outputs.extra_pytest }}
concurrency:
group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on:
group: aws-highmemory-64-plus-priv
permissions:
contents: write
packages: write
id-token: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Inject required variables for sccache to interact with Github Actions Cache
uses: actions/github-script@v7
with:
script: |
core.exportVariable('ACTIONS_RESULTS_URL', process.env.ACTIONS_RESULTS_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
- name: Extract TensorRT-LLM version
run: |
echo "TENSORRT_LLM_VERSION=$(grep -oP '([a-z,0-9]{40})' $GITHUB_WORKSPACE/backends/trtllm/cmake/trtllm.cmake)" >> $GITHUB_ENV
echo "TensorRT-LLM version: ${{ env.TENSORRT_LLM_VERSION }}"
- name: Construct hardware variables
shell: bash
run: |
case ${{ inputs.hardware }} in
cuda)
export dockerfile="Dockerfile"
export label_extension=""
export docker_volume="/mnt/cache"
export docker_devices=""
export runs_on="aws-g6-12xl-plus-priv-cache"
export platform=""
export extra_pytest=""
export target=""
;;
cuda-trtllm)
export dockerfile="Dockerfile_trtllm"
export label_extension="-trtllm"
export docker_volume="/mnt/cache"
export docker_devices=""
export runs_on="ubuntu-latest"
export platform=""
export extra_pytest=""
if [[ "${GITHUB_REF}" == refs/tags/* ]]; then
export build_type="release";
export target="";
else
export build_type="dev";
export target="ci-runtime";
fi
;;
rocm)
export dockerfile="Dockerfile_amd"
export label_extension="-rocm"
export docker_devices="/dev/kfd,/dev/dri"
export docker_volume="/mnt"
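# The dedicated ROCm runner was deactivated; fall back to ubuntu-latest, which means the integration_tests job (guarded by runs_on != 'ubuntu-latest') is skipped for this hardware.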
export runs_on="ubuntu-latest"
export platform=""
export extra_pytest="-k test_flash_gemma_gptq_load"
export target=""
;;
intel-xpu)
export dockerfile="Dockerfile_intel"
export label_extension="-intel-xpu"
export docker_devices=""
export docker_volume="/mnt/cache"
export runs_on="ubuntu-latest"
export platform="xpu"
export extra_pytest=""
export target=""
;;
intel-cpu)
export dockerfile="Dockerfile_intel"
export label_extension="-intel-cpu"
export docker_devices="none"
export docker_volume="/mnt/cache"
# export runs_on="ubuntu-latest"
export runs_on="aws-highmemory-32-plus-priv"
export platform="cpu"
export extra_pytest="-k test_flash_gemma_simple"
export target=""
;;
neuron)
export dockerfile="Dockerfile.neuron"
export label_extension="-neuron"
export docker_devices="/dev/neuron0"
export docker_volume="/mnt/cache"
export runs_on="aws-inf2-8xlarge"
export platform="cpu"
export extra_pytest="--neuron"
export target=""
;;
gaudi)
export dockerfile="Dockerfile_gaudi"
export label_extension="-gaudi"
export docker_volume="/mnt/cache"
export docker_devices=""
export runs_on="itac-bm-emr-gaudi3-dell-2gaudi"
export platform=""
export extra_pytest="--gaudi"
export target=""
esac
echo $dockerfile
echo "Dockerfile=${dockerfile}"
echo $label_extension
echo $docker_devices
echo $runs_on
echo $platform
echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV
echo "LABEL_EXTENSION=${label_extension}" >> $GITHUB_ENV
echo "PLATFORM=${platform}" >> $GITHUB_ENV
echo "DOCKER_VOLUME=${docker_volume}" >> $GITHUB_ENV
echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
echo "EXTRA_PYTEST=${extra_pytest}" >> $GITHUB_ENV
echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
echo "TARGET=${target}" >> $GITHUB_ENV
echo "BUILD_TYPE=${build_type}" >> $GITHUB_ENV
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v3
with:
install: true
buildkitd-config: /tmp/buildkitd.toml
- name: Login to internal Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
username: ${{ secrets.REGISTRY_USERNAME }}
password: ${{ secrets.REGISTRY_PASSWORD }}
registry: registry.internal.huggingface.tech
- name: Login to GitHub Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Login to Docker Hub Container Registry
uses: docker/login-action@v3
with:
registry: docker.io
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: configure aws credentials
id: aws-creds
uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502
with:
role-to-assume: ${{ secrets.AWS_ROLE_GITHUB_BUILDX_CACHE }}
role-duration-seconds: 18000
aws-region: us-east-1
output-credentials: true
# If pull request
- name: Extract metadata (tags, labels) for Docker
if: ${{ github.event_name == 'pull_request' }}
id: meta-pr
uses: docker/metadata-action@v5
with:
images: |
docker.io/huggingface/text-generation-inference-ci
tags: |
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
# If main, release or tag
- name: Extract metadata (tags, labels) for Docker
if: ${{ github.event_name != 'pull_request' }}
id: meta
uses: docker/metadata-action@v4.3.0
with:
flavor: |
latest=false
images: |
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
ghcr.io/huggingface/text-generation-inference
tags: |
type=semver,pattern={{version}}${{ env.LABEL_EXTENSION }}
type=semver,pattern={{major}}.{{minor}}${{ env.LABEL_EXTENSION }}
type=raw,value=latest${{ env.LABEL_EXTENSION }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
- name: Build and push Docker image
id: build-and-push
uses: docker/build-push-action@v4
env:
DOCKER_BUILD_SUMMARY: false
with:
context: .
file: ${{ env.DOCKERFILE }}
push: true
platforms: 'linux/amd64'
build-args: |
GIT_SHA=${{ env.GITHUB_SHA }}
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}
PLATFORM=${{ env.PLATFORM }}
build_type=${{ env.BUILD_TYPE }}
sccache_gha_enabled=on
secrets: |
actions_results_url=${{ env.ACTIONS_RESULTS_URL }}
actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}
target: ${{ env.TARGET }}
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
cache-from: type=s3,region=us-east-1,bucket=${{ vars.AWS_S3BUCKET_GITHUB_BUILDX_CACHE }},name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ steps.aws-creds.outputs.aws-access-key-id }},secret_access_key=${{ steps.aws-creds.outputs.aws-secret-access-key }},session_token=${{ steps.aws-creds.outputs.aws-session-token }},mode=max
cache-to: type=s3,region=us-east-1,bucket=${{ vars.AWS_S3BUCKET_GITHUB_BUILDX_CACHE }},name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ steps.aws-creds.outputs.aws-access-key-id }},secret_access_key=${{ steps.aws-creds.outputs.aws-secret-access-key }},session_token=${{ steps.aws-creds.outputs.aws-session-token }},mode=max
- name: Final
id: final
run: |
if [ "${{ github.event_name }}" = "pull_request" ]; then
echo "docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
else
echo "docker_image=ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
fi
echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT"
echo "docker_volume=${{ env.DOCKER_VOLUME }}" >> "$GITHUB_OUTPUT"
echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT"
echo "label_extension=${{ env.LABEL_EXTENSION }}" >> "$GITHUB_OUTPUT"
echo "extra_pytest=${{ env.EXTRA_PYTEST }}" >> "$GITHUB_OUTPUT"
precompile_neuron_models:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
needs: build-and-push
if: needs.build-and-push.outputs.label_extension == '-neuron'
runs-on:
group: ${{ needs.build-and-push.outputs.runs_on }}
env:
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install
run: |
make install-integration-tests
- name: Export neuron models
run: |
export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
echo $DOCKER_IMAGE
docker pull $DOCKER_IMAGE
export HF_TOKEN=${{ secrets.HF_TOKEN_NEURON }}
python integration-tests/fixtures/neuron/export_models.py
integration_tests:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
needs: [precompile_neuron_models, build-and-push]
if: ${{ always() && !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && needs.build-and-push.outputs.runs_on != 'ubuntu-latest' }}
runs-on:
group: ${{ needs.build-and-push.outputs.runs_on }}
env:
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install
run: |
make install-integration-tests
- name: Run tests
run: |
export DOCKER_VOLUME=${{ needs.build-and-push.outputs.docker_volume }}
export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
export EXTRA_PYTEST="${{ needs.build-and-push.outputs.extra_pytest }}"
export HF_TOKEN=${{ secrets.HF_TOKEN }}
echo $DOCKER_IMAGE
docker pull $DOCKER_IMAGE
pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}
backend_trtllm_cxx_tests:
needs: build-and-push
if: needs.build-and-push.outputs.label_extension == '-trtllm'
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-trtllm-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on:
group: aws-g6-12xl-plus-priv-cache
container:
image: ${{ needs.build-and-push.outputs.docker_image }}
credentials:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
options: --gpus all --shm-size=8g
steps:
- name: Run C++/CUDA tests
if: ${{ env.LABEL_EXTENSION == 'ci-runtime' }}
run: /usr/local/tgi/bin/tgi_trtllm_backend_tests
================================================
FILE: .github/workflows/build_documentation.yaml
================================================
name: Build documentation
on:
push:
paths:
- "docs/source/**"
branches:
- main
- doc-builder*
- v*-release
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
with:
commit_sha: ${{ github.sha }}
package: text-generation-inference
additional_args: --not_python_module
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
================================================
FILE: .github/workflows/build_pr_documentation.yaml
================================================
name: Build PR Documentation
on:
pull_request:
paths:
- "docs/source/**"
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
with:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: text-generation-inference
additional_args: --not_python_module
================================================
FILE: .github/workflows/ci_build.yaml
================================================
name: CI build
on:
push:
branches:
- 'main'
tags:
- 'v*'
pull_request:
paths:
- ".github/workflows/build.yaml"
- "integration-tests/**"
- "backends/**"
- "server/**"
- "proto/**"
- "router/**"
- "launcher/**"
- "Cargo.lock"
- "rust-toolchain.toml"
- "Dockerfile"
- "Dockerfile_amd"
- "Dockerfile_intel"
- "Dockerfile.neuron"
- "Dockerfile_gaudi"
branches:
- "main"
workflow_dispatch:
inputs:
release-tests:
description: "Run release integration tests"
required: true
default: false
type: boolean
jobs:
build:
strategy:
# Disable fail-fast (it is true by default) so that every hardware build
# reports its result even if another one fails.
fail-fast: false
matrix:
hardware: ["cuda", "cuda-trtllm", "rocm", "intel-xpu", "intel-cpu", "neuron", "gaudi"]
uses: ./.github/workflows/build.yaml # reusable workflow defined in .github/workflows/build.yaml
permissions:
contents: write
packages: write
id-token: write
with:
hardware: ${{ matrix.hardware }}
# https://github.com/actions/runner/issues/2206
release-tests: ${{ inputs.release-tests == true }}
secrets: inherit
================================================
FILE: .github/workflows/client-tests.yaml
================================================
name: Python Client Tests
on:
pull_request:
paths:
- ".github/workflows/client-tests.yaml"
- "clients/python/**"
jobs:
run_tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: 3.9
- name: Install
run: |
cd clients/python && pip install .
- name: Run tests
run: |
pip install pytest pytest-asyncio
export HF_TOKEN=${{ secrets.HF_TOKEN }}
make python-client-tests
================================================
FILE: .github/workflows/codeql.yml
================================================
---
name: CodeQL Security Analysis For Github Actions
on:
push:
branches: ["main"]
workflow_dispatch:
# pull_request:
jobs:
codeql:
name: CodeQL Analysis
uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1.2.0
permissions:
security-events: write
packages: read
actions: read
contents: read
with:
languages: '["actions"]'
queries: 'security-extended,security-and-quality'
runner: 'ubuntu-latest' #optional if need custom runner
use-runner-group: false #optional
# if need to use runner group:
# runner: 'cpu-low'
# use-runner-group: true
================================================
FILE: .github/workflows/integration_tests.yaml
================================================
name: Integration tests
on:
workflow_call:
inputs:
docker_image:
type: string
description: Docker image to run the integration tests against
required: true
docker_devices:
type: string
description: Devices to expose to the Docker container
runs_on:
type: string
required: true
description: Hardware to run integration tests
jobs:
integration_tests:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on: ${{ inputs.runs_on }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9
- name: Install
run: |
make install-integration-tests
- name: Run tests
run: |
export DOCKER_VOLUME=/mnt/cache
export DOCKER_IMAGE=${{ inputs.docker_image }}
export DOCKER_DEVICES=${{ inputs.docker_devices }}
export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv integration-tests
================================================
FILE: .github/workflows/load_test.yaml
================================================
name: Nightly load test
on:
schedule:
- cron: '0 0 * * 1-5'
workflow_call:
workflow_dispatch:
pull_request:
paths:
- ".github/workflows/load_test.yaml"
env:
AWS_DEFAULT_REGION: us-east-1
AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}
jobs:
load-tests:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on:
group: aws-g6-12xl-plus-priv-cache
env:
DOCKER_VOLUME: /cache
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Install Python 3.11
uses: actions/setup-python@v2
with:
python-version: 3.11
- name: Install poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
export PATH="$HOME/.local/bin:$PATH"
poetry --version
- name: Run bench test
run: |
export PATH="$HOME/.local/bin:$PATH"
cd load_tests
poetry install
poetry run python benchmarks.py --sha ${{ github.sha }} --results-file "s3://text-generation-inference-ci/benchmarks/ci/${{ github.sha }}.parquet"
shell: bash
env:
HF_TOKEN: ${{ secrets.HF_TOKEN_BENCHMARK }}
================================================
FILE: .github/workflows/nix_build.yaml
================================================
name: "Nix Build Docker image"
on:
pull_request:
push:
branches:
- 'main'
tags:
- 'v*'
concurrency:
group: nix-image-${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
build_nix_image:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v4
- uses: cachix/install-nix-action@v27
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: huggingface
# If you chose a signing key for write access
# authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
env:
USER: github_runner
- name: Build
run: nix build .#dockerImage
- name: Initialize Docker Buildx
uses: docker/setup-buildx-action@v3
with:
install: true
buildkitd-config: /tmp/buildkitd.toml
- name: Inject slug/short variables
uses: rlespinasse/github-slug-action@v4.4.1
- name: Login to internal Container Registry
# if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
username: ${{ secrets.REGISTRY_USERNAME }}
password: ${{ secrets.REGISTRY_PASSWORD }}
registry: registry.internal.huggingface.tech
- name: Push to docker
run: |
if [ "${{ github.event_name }}" = "pull_request" ]; then
export TAG=nix-sha-${{ env.GITHUB_SHA_SHORT }}
else
export TAG=${{ github.ref_name }}-nix
fi
export IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:$TAG
nix-shell -p skopeo --command "skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://$IMAGE --dest-compress-format zstd"
================================================
FILE: .github/workflows/nix_cache.yaml
================================================
name: "Cache devshells"
on:
pull_request:
paths:
- "flake.nix"
- "flake.lock"
- "nix/**"
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
tests:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v4
- uses: cachix/install-nix-action@v27
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: huggingface
# If you chose a signing key for write access
#authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
env:
USER: github_runner
- name: Build impure devshell
run: nix build .\#devShells.x86_64-linux.impure
- name: Build impure devshell (CUDA dev)
run: nix build .\#devShells.x86_64-linux.impureWithCuda
# Pure shell dependencies are covered by Nix tests.
# - name: Build pure devshell
# run: nix build .\#devShells.x86_64-linux.pure
================================================
FILE: .github/workflows/nix_tests.yaml
================================================
name: "Nix Tests"
on:
pull_request:
paths:
- ".github/workflows/nix_tests.yaml"
- "server/**"
- "proto/**"
- "router/**"
- "launcher/**"
- "backends/**"
- "Cargo.lock"
- "rust-toolchain.toml"
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
tests:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v4
- uses: cachix/install-nix-action@v27
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: huggingface
# If you chose a signing key for write access
#authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
env:
USER: github_runner
- name: Nix info
run: nix-shell -p nix-info --run "nix-info -m"
- name: Build
run: nix develop .#test --command echo "Ok"
- name: Pre-commit tests.
run: nix develop .#test --command pre-commit run --all-files
- name: Python tests.
run: nix develop .#test --command python -m pytest server/tests/
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
- name: Rust tests.
run: nix develop .#test --command cargo test
================================================
FILE: .github/workflows/stale.yaml
================================================
name: 'Close stale issues and PRs'
on:
schedule:
- cron: '30 1 * * *'
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v8
with:
stale-issue-message: 'This issue is stale because it has been open 30 days with no activity. Remove the stale label or comment, or this will be closed in 5 days.'
days-before-stale: 30
days-before-close: 5
================================================
FILE: .github/workflows/tests.yaml
================================================
name: Server Tests
on:
pull_request:
paths:
- ".github/workflows/tests.yaml"
- "server/**"
- "proto/**"
- "router/**"
- "launcher/**"
- "backends/**"
- "Cargo.lock"
- "rust-toolchain.toml"
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
run_tests:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
id: python
with:
python-version: 3.11
- uses: dtolnay/rust-toolchain@1.85.0
with:
components: rustfmt, clippy
- name: Install Protoc
uses: arduino/setup-protoc@v1
- name: Clean unused files
run: |
sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
- name: Install
run: |
sudo apt update
sudo apt install python3.11-dev -y
pip install -U pip uv
uv venv
source ./.venv/bin/activate
make install-cpu
- name: Download locked kernels
run: |
source ./.venv/bin/activate
kernels download server
- name: Run server tests
run: |
source ./.venv/bin/activate
uv pip install pytest
export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv server/tests
- name: Pre-commit checks
run: |
pip install pre-commit
pre-commit install
pre-commit run --all-files
- name: Run Rust tests
run: |
cargo test
- name: Run Rust tests with google feature
run: |
cargo test --features google
================================================
FILE: .github/workflows/trufflehog.yaml
================================================
on:
push:
name: Secret Leaks
permissions:
contents: read
jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
with:
# exclude buggy postgres detector that is causing false positives and not relevant to our codebase
extra_args: --results=verified,unknown --exclude-detectors=postgres
================================================
FILE: .github/workflows/upload_pr_documentation.yaml
================================================
name: Upload PR Documentation
on:
workflow_run:
workflows: ["Build PR Documentation"]
types:
- completed
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
with:
package_name: text-generation-inference
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
================================================
FILE: .gitignore
================================================
.idea
target
router/tokenizer.json
*__pycache__*
backends/v2/src/client/pb
backends/v3/src/client/pb
backends/client/src/v2/pb
backends/client/src/v3/pb
# ROCm auto-generated files
*.hip
server/exllamav2
server/exllama_kernels/exllama_kernels/hip/
server/exllama_kernels/exllama_kernels/hip_func/
*_hip.cuh
server/exllama_kernels/exllama_kernels/hip_buffers.cuh
server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
data/
load_tests/*.json
server/fbgemmm
.direnv/
.venv/
# Gaudi auto-generated files
hl-smi_log*.txt
.graph_dumps
out
hqt_output
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
exclude: crate-hashes.json
- id: trailing-whitespace
exclude: docs/source/reference/launcher.md
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/doublify/pre-commit-rust
rev: v1.0
hooks:
- id: cargo-check
- id: fmt
- id: clippy
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
================================================
FILE: .redocly.lint-ignore.yaml
================================================
# This file instructs Redocly's linter to ignore the rules contained for specific parts of your API.
# See https://redoc.ly/docs/cli/ for more information.
docs/openapi.json:
no-empty-servers:
- '#/openapi'
spec:
- >-
#/components/schemas/GenerateParameters/properties/best_of/exclusiveMinimum
- >-
#/components/schemas/GenerateParameters/properties/frequency_penalty/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/grammar/nullable'
- >-
#/components/schemas/GenerateParameters/properties/repetition_penalty/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/seed/exclusiveMinimum'
- >-
#/components/schemas/GenerateParameters/properties/temperature/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/top_k/exclusiveMinimum'
- >-
#/components/schemas/GenerateParameters/properties/top_n_tokens/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/top_p/exclusiveMinimum'
- >-
#/components/schemas/GenerateParameters/properties/typical_p/exclusiveMinimum
- '#/components/schemas/GenerateResponse/properties/details/nullable'
- '#/components/schemas/StreamResponse/properties/details/nullable'
- '#/components/schemas/ChatRequest/properties/response_format/nullable'
- '#/components/schemas/ChatRequest/properties/stream_options/nullable'
- '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
- '#/components/schemas/ToolChoice/nullable'
- '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
- '#/components/schemas/ChatCompletionChunk/properties/usage/nullable'
- '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
no-invalid-media-type-examples:
- '#/paths/~1/post/responses/422/content/application~1json/example'
- '#/paths/~1/post/responses/424/content/application~1json/example'
- '#/paths/~1/post/responses/429/content/application~1json/example'
- '#/paths/~1/post/responses/500/content/application~1json/example'
- '#/paths/~1generate/post/responses/422/content/application~1json/example'
- '#/paths/~1generate/post/responses/424/content/application~1json/example'
- '#/paths/~1generate/post/responses/429/content/application~1json/example'
- '#/paths/~1generate/post/responses/500/content/application~1json/example'
- >-
#/paths/~1generate_stream/post/responses/422/content/text~1event-stream/example
- >-
#/paths/~1generate_stream/post/responses/424/content/text~1event-stream/example
- >-
#/paths/~1generate_stream/post/responses/429/content/text~1event-stream/example
- >-
#/paths/~1generate_stream/post/responses/500/content/text~1event-stream/example
- '#/paths/~1tokenize/post/responses/404/content/application~1json/example'
- >-
#/paths/~1v1~1chat~1completions/post/responses/422/content/application~1json/example
- >-
#/paths/~1v1~1chat~1completions/post/responses/424/content/application~1json/example
- >-
#/paths/~1v1~1chat~1completions/post/responses/429/content/application~1json/example
- >-
#/paths/~1v1~1chat~1completions/post/responses/500/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/422/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/424/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/429/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/500/content/application~1json/example
operation-4xx-response:
- '#/paths/~1health/get/responses'
- '#/paths/~1info/get/responses'
- '#/paths/~1metrics/get/responses'
no-unused-components:
- '#/components/schemas/Completion'
security-defined:
- '#/paths/~1/post'
- '#/paths/~1generate/post'
- '#/paths/~1generate_stream/post'
- '#/paths/~1health/get'
- '#/paths/~1info/get'
- '#/paths/~1metrics/get'
- '#/paths/~1tokenize/post'
- '#/paths/~1v1~1chat~1completions/post'
- '#/paths/~1v1~1completions/post'
- '#/paths/~1v1~1models/get'
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or advances of
any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
feedback@huggingface.co.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of
actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the
community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].
[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations
================================================
FILE: CONTRIBUTING.md
================================================
# Contribute to text-generation-inference
Everyone is welcome to contribute, and we value everybody's contribution. Code
contributions are not the only way to help the community. Answering questions, helping
others, and improving the documentation are also immensely valuable.
It also helps us if you spread the word! Reference the library in blog posts
about the awesome projects it made possible, shout out on Twitter every time it has
helped you, or simply ⭐️ the repository to say thank you.
However you choose to contribute, please be mindful and respect our
[code of conduct](https://github.com/huggingface/text-generation-inference/blob/main/CODE_OF_CONDUCT.md).
**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
## Ways to contribute
There are several ways you can contribute to text-generation-inference.
* Fix outstanding issues with the existing code.
* Submit issues related to bugs or desired new features.
* Contribute to the examples or to the documentation.
> All contributions are equally valuable to the community. 🥰
## Fixing outstanding issues
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request) and open
a Pull Request!
## Submitting a bug-related issue or feature request
Do your best to follow these guidelines when submitting a bug-related issue or a feature
request. It will make it easier for us to come back to you quickly and with good
feedback.
### Did you find a bug?
The text-generation-inference library is robust and reliable thanks to users who report the problems they encounter.
Before you report an issue, we would really appreciate it if you could **make sure the bug was not
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the
library itself, and not your code.
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so
we can quickly resolve it:
* Your **OS type and version**, as well as your environment versions (versions of rust, python, and dependencies).
* A short, self-contained, code snippet that allows us to reproduce the bug.
* The *full* traceback if an exception is raised.
* Attach any other additional information, like screenshots, you think may help.
To get the OS and software versions automatically, you can re-run the launcher with the `--env` flag:
```bash
text-generation-launcher --env
```
This will print information about your environment before launching the model. We recommend pasting
that information into your issue report.
### Do you want a new feature?
If there is a new feature you'd like to see in text-generation-inference, please open an issue and describe:
1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it
a feature related to something you need for a project? Is it something you worked on and think could benefit
the community?
Whatever it is, we'd love to hear about it!
2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better
we'll be able to help you.
3. Provide a *code snippet* that demonstrates the feature's usage.
4. If the feature is related to a paper, please include a link.
If your issue is well written, we're already 80% of the way there by the time you create it.
We have added [templates](https://github.com/huggingface/text-generation-inference/tree/main/.github/ISSUE_TEMPLATE)
to help you get started with your issue.
## Do you want to implement a new model?
New models are constantly released and if you want to implement a new model, please provide the following information:
* A short description of the model and a link to the paper.
* Link to the implementation if it is open-sourced.
* Link to the model weights if they are available.
If you are willing to contribute the model yourself, let us know so we can help you add it to text-generation-inference!
## Do you want to add documentation?
We're always looking for improvements to the documentation that make it clearer and more accurate. Please let us know
how the documentation can be improved, such as fixing typos or addressing content that is missing, unclear, or inaccurate. We'll be
happy to make the changes or help you make a contribution if you're interested!
## I want to become a maintainer of the project. How do I get there?
TGI is a project led and managed by Hugging Face as it powers our internal services. However, we are happy to have
motivated individuals from other organizations join us as maintainers with the goal of making TGI the best inference
service.
If you are such an individual (or organization), please reach out to us and let's collaborate.
================================================
FILE: Cargo.toml
================================================
[workspace]
members = [
"benchmark",
"backends/v2",
"backends/v3",
"backends/grpc-metadata",
"backends/trtllm",
"backends/llamacpp",
"launcher",
"router"
]
default-members = [
"benchmark",
"backends/v2",
"backends/v3",
"backends/grpc-metadata",
# "backends/trtllm",
"launcher",
"router"
]
resolver = "2"
[workspace.package]
version = "3.3.6-dev0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.20.0", features = ["http"] }
hf-hub = { version = "0.4.2", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
[profile.release]
incremental = true
[profile.release-binary]
inherits = "release"
debug = 1
incremental = true
panic = "abort"
[profile.release-opt]
inherits = "release"
debug = 0
incremental = false
lto = "fat"
opt-level = 3
codegen-units = 1
================================================
FILE: Dockerfile
================================================
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt --frozen
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
WORKDIR /usr/src/
# NOTE: When updating the PyTorch version, remember to revisit (and likely remove) the `uv pip install nvidia-nccl-cu12==2.25.1` pin below in this Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.7
ARG PYTHON_VERSION=3.11
# Keep in sync with `server/pyproject.toml`
# Automatically set by buildx
ARG TARGETPLATFORM
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
ccache \
curl \
git && \
rm -rf /var/lib/apt/lists/*
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
ENV PATH="$PATH:/root/.local/bin"
RUN uv python install ${PYTHON_VERSION}
RUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} torchvision pip setuptools packaging
ENV VIRTUAL_ENV=/usr/src/.venv/
ENV PATH="$PATH:/usr/src/.venv/bin/"
# CUDA kernels builder image
FROM pytorch-install AS kernel-builder
ARG MAX_JOBS=8
ENV TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0+PTX"
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
ninja-build cmake \
&& rm -rf /var/lib/apt/lists/*
# Build Flash Attention CUDA kernels
FROM kernel-builder AS flash-att-builder
WORKDIR /usr/src
COPY server/Makefile-flash-att Makefile
# Build specific version of flash attention
RUN . .venv/bin/activate && make build-flash-attention
# Build Flash Attention v2 CUDA kernels
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile
# Build specific version of flash attention v2
RUN . .venv/bin/activate && make build-flash-attention-v2-cuda
# Build Transformers exllama kernels
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
RUN . .venv/bin/activate && python setup.py build
# Build Transformers exllamav2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-exllamav2/ Makefile
# Build specific version of exllamav2
RUN . .venv/bin/activate && make build-exllamav2
# Build Transformers awq kernels
FROM kernel-builder AS awq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-awq Makefile
# Build specific version of the awq kernels
RUN . .venv/bin/activate && make build-awq
# Build Transformers CUDA kernels
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
# Build specific version of transformers
RUN . .venv/bin/activate && python setup.py build
# Build mamba kernels
FROM kernel-builder AS mamba-builder
WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile
RUN . .venv/bin/activate && make build-all
# Build flashinfer
FROM kernel-builder AS flashinfer-builder
WORKDIR /usr/src
COPY server/Makefile-flashinfer Makefile
RUN . .venv/bin/activate && make install-flashinfer
# Text Generation Inference base image
FROM nvidia/cuda:12.4.0-base-ubuntu22.04 AS base
# Text Generation Inference base env
ENV HF_HOME=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80
WORKDIR /usr/src
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
libssl-dev \
ca-certificates \
make \
curl \
git \
&& rm -rf /var/lib/apt/lists/*
# RUN curl -LsSf https://astral.sh/uv/install.sh | sh
# ENV PATH="$PATH:/root/.local/bin"
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
# Install flash-attention dependencies
# RUN pip install einops --no-cache-dir
# Copy env with PyTorch installed
COPY --from=pytorch-install /usr/src/.venv /usr/src/.venv
ENV PYTHON_VERSION=3.11
RUN uv python install ${PYTHON_VERSION}
ENV VIRTUAL_ENV=/usr/src/.venv/
ENV PATH="$PATH:/usr/src/.venv/bin/"
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
ENV HF_KERNELS_CACHE=/kernels
RUN cd server && \
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --no-install-project --active && \
make gen-server-raw && \
kernels download .
RUN cd server && \
uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --active --python=${PYTHON_VERSION} && \
uv pip install nvidia-nccl-cu12==2.25.1 && \
pwd && \
text-generation-server --help
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/.venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages
COPY --from=flashinfer-builder /usr/src/.venv/lib/python3.11/site-packages/flashinfer/ /usr/src/.venv/lib/python3.11/site-packages/flashinfer/
# ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
# Required to find libpython within the rust binaries
# This is needed because exl2 tries to load flash-attn
# And fails with our builds.
ENV EXLLAMA_NO_FLASH_ATTN=1
# Deps before the binaries
# The binaries change on every build given we burn the SHA into them
# The deps change less often.
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
g++ \
&& rm -rf /var/lib/apt/lists/*
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# AWS Sagemaker compatible image
FROM base AS sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh
ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/"
ENTRYPOINT ["/tgi-entrypoint.sh"]
# CMD ["--json-output"]
================================================
FILE: Dockerfile.neuron
================================================
# Fetch and extract the TGI sources
FROM alpine AS tgi
RUN mkdir -p /tgi
# Fetch the optimum-neuron sources directly to avoid relying on pypi deployments
FROM alpine AS optimum-neuron
RUN mkdir -p /optimum-neuron
ADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.3.0.tar.gz /optimum-neuron/sources.tar.gz
RUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1
# Build cargo components (adapted from TGI original Dockerfile)
# Note: we cannot use the cargo-chef base image as it uses python 3.11
FROM ubuntu:22.04 AS chef
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
curl ca-certificates build-essential \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y
ENV PATH="/root/.cargo/bin:${PATH}"
RUN cargo install cargo-chef --locked
WORKDIR /usr/src
FROM chef AS planner
COPY backends/neuron/Cargo.toml Cargo.toml
COPY Cargo.lock Cargo.lock
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
unzip python3-dev libssl-dev pkg-config \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP
COPY backends/neuron/Cargo.toml Cargo.toml
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json
COPY Cargo.lock Cargo.lock
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --release
# Python base image
FROM ubuntu:22.04 AS base
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
python3-pip \
python3-setuptools \
python-is-python3 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN pip3 --no-cache-dir install --upgrade pip
# Python server build image
FROM base AS pyserver
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
make \
python3-venv \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN install -d /pyserver
WORKDIR /pyserver
COPY backends/neuron/server server
COPY proto proto
RUN pip3 install -r server/build-requirements.txt
RUN VERBOSE=1 BUILDDIR=/pyserver/build PROTODIR=/pyserver/proto make -C server package
# Neuron base image (used for deployment)
FROM base AS neuron
# Install system prerequisites
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
gnupg2 \
wget \
python3-dev \
libexpat1 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
RUN echo "deb https://apt.repos.neuron.amazonaws.com jammy main" > /etc/apt/sources.list.d/neuron.list
RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -
# Install neuronx packages
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
aws-neuronx-dkms=2.22.2.0 \
aws-neuronx-collectives=2.26.43.0-47cc904ea \
aws-neuronx-runtime-lib=2.26.42.0-2ff3b5c7d \
aws-neuronx-tools=2.24.54.0 \
libxml2 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}"
# Manually install the CPU version of torch to avoid pulling in CUDA
RUN pip3 install \
torch==2.7.0 \
torchvision==0.22.0 \
--index-url https://download.pytorch.org/whl/cpu
RUN pip3 install \
neuronx-cc==2.19.8089.0+8ab9f450 \
torch-neuronx==2.7.0.2.8.6734+ac864f72 \
neuronx-distributed==0.13.14393+b8569585 \
libneuronxla==2.2.4410.0+835a67fb \
--extra-index-url=https://pip.repos.neuron.amazonaws.com
# Install HuggingFace packages
RUN pip3 install \
hf_transfer huggingface_hub
# Install optimum-neuron
COPY --from=optimum-neuron /optimum-neuron optimum-neuron
RUN pip3 install ./optimum-neuron
# TGI base env
ENV HUGGINGFACE_HUB_CACHE=/tmp \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80
# Disable color logs as they are not supported by CloudWatch
ENV LOGURU_COLORIZE=NO
ENV LOG_COLORIZE=0
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router-v2 /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
# Install python server
COPY --from=pyserver /pyserver/build/dist dist
RUN pip install dist/text_generation_server*.tar.gz
# Final image
FROM neuron
COPY backends/neuron/tgi_entry_point.py /tgi_entry_point.py
COPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]
================================================
FILE: Dockerfile.nix
================================================
# Build the image and get out the docker file:
#
# docker build -t tgi-nix-builder -f Dockerfile.nix .
# docker run --log-driver=none tgi-nix-builder | docker load
FROM nixos/nix:2.18.8 AS builder
RUN echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf
RUN nix profile install nixpkgs#cachix
RUN cachix use huggingface
WORKDIR /root
ADD . .
RUN nix build .
RUN mkdir /tmp/nix-store-closure
RUN cp -R $(nix-store -qR result/) /tmp/nix-store-closure
FROM ubuntu:24.04
WORKDIR /app
# Copy /nix/store
COPY --from=builder /tmp/nix-store-closure /nix/store
COPY --from=builder /root/result /app
RUN ldconfig
CMD ["ldconfig", "/app/bin/text-generation-launcher"]
================================================
FILE: Dockerfile_amd
================================================
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
ARG GIT_SHA
ARG DOCKER_LABEL
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt --frozen
FROM rocm/dev-ubuntu-22.04:6.3.1-complete AS base
ARG HIPBLASLT_BRANCH="4d40e36"
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
ARG LEGACY_HIPBLASLT_OPTION=
ARG RCCL_BRANCH="648a58d"
ARG RCCL_REPO="https://github.com/ROCm/rccl"
ARG TRITON_BRANCH="e5be006"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
ARG PYTORCH_BRANCH="3a585126"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG AITER_BRANCH="21d47a9"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ENV PATH=/opt/rocm/llvm/bin:$PATH
ENV ROCM_PATH=/opt/rocm
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
ARG PYTHON_VERSION=3.11
RUN mkdir -p /app
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
ccache \
curl \
git \
ninja-build \
cmake \
software-properties-common \
python3.11-dev \
python3.11-venv && \
rm -rf /var/lib/apt/lists/*
COPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/
ENV PATH="$PATH:/root/.local/bin"
RUN uv python install ${PYTHON_VERSION}
RUN uv venv --python ${PYTHON_VERSION} && uv pip install pip setuptools packaging
ENV VIRTUAL_ENV=/usr/src/.venv/
ENV PATH="$PATH:/usr/src/.venv/bin/"
RUN . .venv/bin/activate && pip install -U packaging cmake ninja wheel setuptools pybind11 Cython
FROM base AS build_hipblaslt
# Pass "--legacy_hipblas_direct" as LEGACY_HIPBLASLT_OPTION when targeting ROCm <= 6.2.
ARG LEGACY_HIPBLASLT_OPTION
ARG HIPBLAS_COMMON_BRANCH
ARG HIPBLASLT_BRANCH
# hipBLAS-common: build its .deb package and install it into this stage first.
RUN git clone https://github.com/ROCm/hipBLAS-common.git
RUN . .venv/bin/activate \
    && cd hipBLAS-common \
    && git checkout ${HIPBLAS_COMMON_BRANCH} \
    && mkdir build \
    && cd build \
    && cmake .. \
    && make package \
    && dpkg -i ./*.deb
# hipBLASLt for the configured GPU architectures, then package it as .deb.
RUN git clone https://github.com/ROCm/hipBLASLt
RUN . .venv/bin/activate \
    && cd hipBLASLt \
    && git checkout ${HIPBLASLT_BRANCH} \
    && ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
    && cd build/release \
    && make package
# Stage both sets of .deb packages in /app/install for export.
RUN mkdir -p /app/install \
    && cp /app/hipBLASLt/build/release/*.deb /app/install \
    && cp /app/hipBLAS-common/build/*.deb /app/install
FROM base AS build_rccl
ARG RCCL_BRANCH
ARG RCCL_REPO
# Clone into an explicit `rccl` directory so the steps below keep working when
# RCCL_REPO is overridden with a fork whose URL basename differs.
RUN git clone ${RCCL_REPO} rccl
RUN . .venv/bin/activate && cd rccl \
    && git checkout ${RCCL_BRANCH} \
    && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
# Stage the built .deb packages in /app/install for export.
RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
FROM base AS build_triton
ARG TRITON_BRANCH
ARG TRITON_REPO
# Clone into an explicit `triton` directory so overriding TRITON_REPO with a
# differently named fork does not break the `cd triton` below.
RUN git clone ${TRITON_REPO} triton
RUN . .venv/bin/activate && cd triton \
    && git checkout ${TRITON_BRANCH} \
    && cd python \
    && python3 setup.py bdist_wheel --dist-dir=dist
# Stage the wheel in /app/install for export.
RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install
FROM base AS build_amdsmi
# Build a wheel from the amd_smi package shipped under /opt/rocm/share/amd_smi,
# using the /app virtualenv's pip.
RUN cd /opt/rocm/share/amd_smi \
    && . /app/.venv/bin/activate \
    && pip wheel . --wheel-dir=dist
RUN mkdir -p /app/install \
    && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
FROM base AS build_pytorch
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
# Parallel compile jobs for the flash-attention build; override to tune build
# parallelism for the build host.
ARG FA_MAX_JOBS=64
# PyTorch: run the ROCm source conversion (tools/amd_build/build_amd.py), build
# a wheel against the venv prefix, and install it so the builds below can use it.
RUN git clone ${PYTORCH_REPO} pytorch
RUN . .venv/bin/activate && cd pytorch && git checkout ${PYTORCH_BRANCH} && \
    pip install -r requirements.txt && git submodule update --init --recursive \
    && python3 tools/amd_build/build_amd.py \
    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl
# torchvision, built after the PyTorch wheel above is installed.
RUN git clone ${PYTORCH_VISION_REPO} vision
RUN . .venv/bin/activate && cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
    && python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl
# flash-attention: cloned into an explicit directory so overriding FA_REPO with
# a differently named fork still works.
RUN git clone ${FA_REPO} flash-attention
RUN . .venv/bin/activate && cd flash-attention \
    && git checkout ${FA_BRANCH} \
    && git submodule update --init \
    && MAX_JOBS=${FA_MAX_JOBS} GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
# Stage all three wheels in /app/install for export.
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
    && cp /app/vision/dist/*.whl /app/install \
    && cp /app/flash-attention/dist/*.whl /app/install
FROM base AS final
# Install the custom-built hipBLASLt / hipBLAS-common packages. The sed edits
# then rewrite /var/lib/dpkg/status, dropping the `hipblaslt-dev ...` and
# `hipblaslt ...` entries (with whatever follows them, up to the last
# `, hipcub-dev` / `, hipfft` on the line, since `.*` is greedy) from dependency lists.
# NOTE(review): presumably this stops a ROCm meta-package that pins the stock
# versions from being reported as broken; confirm before changing.
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
    dpkg -i /install/*deb \
    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
# Same treatment for the custom-built RCCL packages.
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
    dpkg -i /install/*deb \
    && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
    && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
# Python wheels produced by the from-source build stages, installed into /app/.venv.
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
    . .venv/bin/activate && \
    pip install /install/*.whl
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
    . .venv/bin/activate && \
    pip install /install/*.whl
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    . .venv/bin/activate && \
    pip install /install/*.whl
ARG AITER_REPO
ARG AITER_BRANCH
# GPU architecture(s) the AITER kernels are prebuilt for.
ARG AITER_GPU_ARCHS=gfx942
# AITER is installed in develop (editable) mode from /app/aiter. It is cloned
# into an explicit directory so a differently named AITER_REPO fork still works.
RUN git clone --recursive ${AITER_REPO} aiter
RUN . .venv/bin/activate && cd aiter \
    && git checkout ${AITER_BRANCH} \
    && git submodule update --init --recursive \
    && pip install -r requirements.txt \
    && PREBUILD_KERNELS=1 GPU_ARCHS=${AITER_GPU_ARCHS} python3 setup.py develop && pip show aiter
RUN rm -rf /var/lib/apt/lists/*
# Shared base for the kernel build stages that follow: the assembled `final`
# stage (custom ROCm packages plus the wheels installed into /app/.venv).
FROM final AS kernel-builder
# Build vllm kernels
FROM kernel-builder AS vllm-builder
COPY server/Makefile-vllm Makefile
# Install setuptools_scm into the venv, then build the specific vllm version
# selected by the `build-vllm-rocm` Makefile target.
RUN . .venv/bin/activate \
    && pip install setuptools_scm \
    && make build-vllm-rocm
# Build Transformers CUDA kernels (gpt-neox and bloom)
# Each kernel stage below copies its sources into /app and writes the wheel
# to /app/dist.
FROM kernel-builder AS custom-kernels-builder
COPY server/custom_kernels/ .
RUN . .venv/bin/activate \
    && python3 setup.py bdist_wheel --dist-dir=dist

# Build exllama kernels
FROM kernel-builder AS exllama-kernels-builder
COPY server/exllama_kernels/ .
RUN . .venv/bin/activate \
    && python3 setup.py bdist_wheel --dist-dir=dist

# Build exllama v2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
COPY server/exllamav2_kernels/ .
RUN . .venv/bin/activate \
    && python3 setup.py bdist_wheel --dist-dir=dist
FROM kernel-builder AS marlin-kernels
ENV MARLIN_KERNELS_BRANCH=v0.3.6 \
    VLLM_TARGET_DEVICE=rocm
# Build a wheel of the pinned marlin-kernels tag (with VLLM_TARGET_DEVICE=rocm
# in the environment) into /app/marlin-kernels/dist.
RUN . .venv/bin/activate \
    && git clone https://github.com/danieldk/marlin-kernels.git marlin-kernels \
    && cd marlin-kernels \
    && git checkout ${MARLIN_KERNELS_BRANCH} \
    && python3 setup.py bdist_wheel --dist-dir=dist
FROM kernel-builder AS moe-kernels
ENV MOE_KERNELS_BRANCH=v0.8.2 \
    VLLM_TARGET_DEVICE=rocm
# Build a wheel of the pinned moe-kernels tag (with VLLM_TARGET_DEVICE=rocm in
# the environment) into /app/moe-kernels/dist.
RUN . .venv/bin/activate \
    && git clone https://github.com/danieldk/moe-kernels.git moe-kernels \
    && cd moe-kernels \
    && git checkout ${MOE_KERNELS_BRANCH} \
    && python3 setup.py bdist_wheel --dist-dir=dist
FROM final AS base-copy
# Text Generation Inference base env
ENV HF_HOME=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80
# `uv pip` below installs into the virtualenv named by VIRTUAL_ENV.
ENV VIRTUAL_ENV=/app/.venv/
ENV PATH="$PATH:/app/.venv/bin/"
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    uv pip install grpcio-tools mypy-protobuf && \
    uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir && \
    make gen-server-raw
# Smoke test: fail the build early if the server entry point cannot run.
RUN cd server && \
    pwd && \
    text-generation-server --help
# Install the wheels produced by each kernel build stage (one step per stage).
RUN --mount=type=bind,from=vllm-builder,src=/app/vllm/dist,target=/install \
    uv pip install /install/*.whl
RUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \
    uv pip install /install/*.whl
RUN --mount=type=bind,from=exllama-kernels-builder,src=/app/dist,target=/install \
    uv pip install /install/*.whl
RUN --mount=type=bind,from=exllamav2-kernels-builder,src=/app/dist,target=/install \
    uv pip install /install/*.whl
RUN --mount=type=bind,from=marlin-kernels,src=/app/marlin-kernels/dist,target=/install \
    uv pip install /install/*.whl
RUN --mount=type=bind,from=moe-kernels,src=/app/moe-kernels/dist,target=/install \
    uv pip install /install/*.whl
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# AWS Sagemaker compatible image
# Derived from `base-copy`, the stage that installs the server and the
# router/launcher/benchmarker binaries. The plain `base` stage in this file only
# holds the ROCm toolchain and an empty venv.
# NOTE(review): the ROCm tuning ENV defaults set on the default final image are
# not inherited by this target; confirm whether Sagemaker needs them too.
FROM base-copy AS sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh
ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base-copy
# ROCm runtime defaults.
# - HIP_FORCE_DEV_KERNARG follows AMD's recommendation:
#   https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
# - On MI250 and MI300, flash attention through Triton is slightly faster than
#   CK, but Triton needs tuning for each prompt length, which is prohibitive,
#   so the Triton FA path stays disabled.
ENV HIP_FORCE_DEV_KERNARG=1 \
    ROCM_USE_FLASH_ATTN_V2_TRITON=0 \
    ROCM_USE_CUSTOM_PAGED_ATTN=1 \
    PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0 \
    VLLM_MOE_PADDING=0 \
    ATTENTION=paged \
    PREFIX_CACHING=0 \
    PREFILL_CHUNKING=0 \
    ROCM_USE_SKINNY_GEMM=1
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]
# Expose the uv-managed CPython's lib directory and the venv's site-packages.
# NOTE(review): cpython-3.11.11 is hard-coded; it must match the interpreter that
# `uv python install ${PYTHON_VERSION}` resolved in the base stage.
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib"
ENV PYTHONPATH=/app/.venv/lib/python3.11/site-packages
# CMD ["--json-output"]
================================================
FILE: Dockerfile_gaudi
================================================
# Those arguments are required to build the image.
# They are declared before the first FROM (global scope) so the base stage's
# FROM line can use them to select the Habana PyTorch installer image.
ARG HABANA_VERSION=1.21.0
ARG PYTORCH_VERSION=2.6.0
# Rust builder
# Shared cargo-chef image: the planner computes a dependency recipe, and the
# builder pre-compiles dependencies from it before building the workspace.
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src
# Cargo reads this variable as the `registries.crates-io.protocol` setting (sparse index).
# NOTE(review): declared in the chef stage; confirm it reaches the cargo
# invocations of the derived stages.
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
# Planner: copy the workspace manifests and sources, then let cargo-chef write
# recipe.json, which the builder stage uses to pre-build dependencies.
FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
# The Rust binaries are built against Python 3.10, the version the Gaudi base
# image is asserted to provide; pyo3 is pointed at a uv-managed 3.10.
ENV PYO3_PYTHON="/root/.local/bin/python" \
    PYTHON_SYS_EXECUTABLE="/root/.local/bin/python" \
    PYO3_PYTHON_VERSION="3.10"
RUN curl -LsSf https://astral.sh/uv/install.sh | sh \
    && . $HOME/.local/bin/env \
    && uv python install 3.10 --default --preview \
    && test -f /root/.local/bin/python || (echo "Python 3.10 not found at /root/.local/bin/python" && exit 1)
# protoc 21.12 (compiler binary plus bundled include files) into /usr/local.
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP
# Pre-build dependencies from the planner's recipe (cached layer).
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
# Exposed to the cargo build below as environment variables.
ARG GIT_SHA
ARG DOCKER_LABEL
# Copy the lockfile too and build with --frozen (= --locked --offline). This
# keeps the build on exactly the locked dependency versions that
# `cargo chef cook` already fetched.
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt --frozen
# Text Generation Inference base image
# HABANA_VERSION / PYTORCH_VERSION in this FROM line resolve to the global ARGs
# declared before the first FROM of this file. An ARG placed after the
# builder's FROM would be scoped to the builder stage and have no effect here.
FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
ENV ATTENTION=paged
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV PT_HPU_LAZY_MODE=1
ENV PT_HPU_WEIGHT_SHARING=0
ENV VLLM_EXPONENTIAL_BUCKETING=true
# Text Generation Inference base env
ENV HF_HOME=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80
# Assert that Python 3.10 is installed as the launcher is compiled with Python 3.10
RUN python3.10 --version || (echo "Python 3.10 is not installed" && exit 1)
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default; install it and
# delete the downloaded package afterwards.
# NOTE(review): fetched over plain HTTP without a checksum check; consider
# verifying a pinned sha256 before `dpkg -i`.
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
    rm -f ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
WORKDIR /usr/src
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    libssl-dev \
    ca-certificates \
    make \
    curl \
    git \
    && rm -rf /var/lib/apt/lists/*
# Install server
COPY proto proto
COPY backends/gaudi/server server
COPY backends/gaudi/server/Makefile server/Makefile
# Re-declared so the global HABANA_VERSION is visible to the RUN steps below.
# NOTE(review): presumably consumed by the server Makefile/build; confirm.
ARG HABANA_VERSION
RUN cd server && \
    make gen-server && \
    pip install --no-deps -r requirements.txt && \
    bash ./dill-0.3.8-patch.sh && \
    pip install . --no-cache-dir
# NOTE(review): installs from a personal fork's branch (`bmax_fix`), which is a
# mutable ref; consider pinning a commit SHA.
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
RUN pip install compressed-tensors==0.9.1
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# AWS Sagemaker compatible image
# Derived from `base`, which in this file already contains the server and the
# router/launcher/benchmarker binaries; only the entrypoint script differs.
FROM base AS sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh
ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base
# Runtime defaults: expose all Habana devices, and set
# OMPI_MCA_btl_vader_single_copy_mechanism=NONE (an Open MPI shared-memory
# setting commonly needed inside containers).
ENV HF_HUB_ENABLE_HF_TRANSFER=1 \
    HABANA_VISIBLE_DEVICES=all \
    OMPI_MCA_btl_vader_single_copy_mechanism=NONE
COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]
================================================
FILE: Dockerfile_intel
================================================
# Name of the stage (`xpu` or `cpu`, both defined below) used as the final image.
ARG PLATFORM=xpu
# Shared cargo-chef image for the Rust planner/builder stages.
FROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef
WORKDIR /usr/src
# Selects cargo's sparse crates.io index protocol.
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
# Planner: copy the workspace manifests and sources, then let cargo-chef write
# recipe.json, which the builder stage uses to pre-build dependencies.
FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
# Python 3.11 development headers for the Rust build.
# NOTE(review): presumably required by a pyo3-based workspace crate; confirm.
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends python3.11-dev
# protoc 21.12: extract the compiler binary and its include/ tree into /usr/local.
RUN PROTOC_ARCHIVE=protoc-21.12-linux-x86_64.zip \
    && curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ARCHIVE \
    && unzip -o $PROTOC_ARCHIVE -d /usr/local bin/protoc \
    && unzip -o $PROTOC_ARCHIVE -d /usr/local 'include/*' \
    && rm -f $PROTOC_ARCHIVE
# Pre-build dependencies from the planner's recipe (cached layer).
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json
# Build metadata, exposed to the cargo build below as environment variables.
ARG GIT_SHA
ARG DOCKER_LABEL
# Workspace sources, then the real build against the locked dependency set.
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt --frozen
# Text Generation Inference base image for Intel
FROM intel/oneapi-basekit:2025.1.3-0-devel-ubuntu22.04 AS xpu
USER root
ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH=/opt/conda/bin:$PATH
# Python comes from a Mambaforge install under /opt/conda, pinned to
# PYTHON_VERSION below, rather than from the distro packages.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
    *) MAMBA_ARCH=x86_64 ;; \
    esac && \
    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    rm ~/mambaforge.sh
# arm64 is rejected here (exit 1); x86_64 gets the pinned Python.
RUN case ${TARGETPLATFORM} in \
    "linux/arm64") exit 1 ;; \
    *) /opt/conda/bin/conda update -y conda && \
    /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
    esac && \
    /opt/conda/bin/conda clean -ya
# TGI needs libssl.so.1.1, which Ubuntu 22.04 does not ship by default: install
# the libssl1.1 package and delete the downloaded .deb afterwards.
# NOTE(review): fetched over plain HTTP without a checksum check; consider
# verifying a pinned sha256 before `dpkg -i`.
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
    rm -f ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
# NOTE(review): no apt source configured in this stage references this keyring.
RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
    | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/intel-for-pytorch-gpu-dev all main" > /tmp/intel-for-pytorch-gpu-dev.list
RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
# apt-get (stable CLI for scripts); drop the package lists once installed.
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200 \
    && rm -rf /var/lib/apt/lists/*
# Text Generation Inference base env
ENV HF_HOME=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80
WORKDIR /usr/src
RUN pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/xpu
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \
    make gen-server && \
    pip install -U pip uv && \
    uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib
ENV CCL_ZE_IPC_EXCHANGE=sockets
ENV TORCH_LLM_ALLREDUCE=1
ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.8.10%2Bxpu-cp311-cp311-linux_x86_64.whl
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# Text Generation Inference base image for Intel-cpu
FROM ubuntu:22.04 AS cpu
# Build tooling; the apt package lists are dropped once installed.
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    curl \
    ca-certificates \
    make \
    g++-12 \
    gcc-12 \
    git \
    wget \
    cmake \
    libnuma-dev \
    && rm -rf /var/lib/apt/lists/*
# Make gcc/g++ 12 the default gcc, g++, cc and c++.
RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 12
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12
RUN update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 30
RUN update-alternatives --set cc /usr/bin/gcc
RUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30
RUN update-alternatives --set c++ /usr/bin/g++
# NOTE(review): unlike the xpu stage (HF_HOME=/data), this sets
# HUGGINGFACE_HUB_CACHE=/data, so the hub cache lives directly in /data rather
# than /data/hub. It is left as-is so existing cache volumes are not relocated.
ENV HUGGINGFACE_HUB_CACHE=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80
ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH=/opt/conda/bin:$PATH
# Python comes from a Mambaforge install under /opt/conda, pinned to
# PYTHON_VERSION below, rather than from the distro packages.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
    "linux/arm64") MAMBA_ARCH=aarch64 ;; \
    *) MAMBA_ARCH=x86_64 ;; \
    esac && \
    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    rm ~/mambaforge.sh
# arm64 is rejected here (exit 1); x86_64 gets the pinned Python.
RUN case ${TARGETPLATFORM} in \
    "linux/arm64") exit 1 ;; \
    *) /opt/conda/bin/conda update -y conda && \
    /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
    esac && \
    /opt/conda/bin/conda clean -ya
# gperftools provides the libtcmalloc preloaded below.
RUN conda install -c conda-forge gperftools mkl
RUN pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cpu
RUN pip install triton==3.2.0 py-libnuma
WORKDIR /usr/src
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/intel_extension_for_pytorch-2.7.0%2Bcpu-cp311-cp311-linux_x86_64.whl
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/oneccl_bind_pt-2.7.0%2Bcpu-cp311-cp311-linux_x86_64.whl
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
# oneCCL / MPI / libfabric locations inside the oneccl_bind_pt package.
ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
ENV I_MPI_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/lib
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
ENV UV_SYSTEM_PYTHON=1
RUN cd server && \
    make gen-server && \
    pip install -U pip uv && \
    uv pip install -e ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# Final image: whichever of the `xpu` / `cpu` stages the global PLATFORM build
# arg names (default: xpu).
FROM ${PLATFORM} AS final
# Server defaults: IPEX flash-decoding attention, prefix caching and prefill
# chunking enabled, CUDA graphs disabled.
ENV ATTENTION=flashdecoding-ipex
ENV PREFIX_CACHING=1
ENV PREFILL_CHUNKING=1
ENV CUDA_GRAPHS=0
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
================================================
FILE: Dockerfile_llamacpp
================================================
# Shared build base: llama.cpp compiled from source and installed under /usr,
# plus a Rust toolchain and cargo-chef.
FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 AS deps
# llama.cpp release tag to build and its build switches (CUDA off, native CPU
# tuning on by default).
ARG llamacpp_version=b4827
ARG llamacpp_cuda=OFF
ARG llamacpp_native=ON
ARG llamacpp_cpu_arm_arch=native
# Passed to CMake as CMAKE_CUDA_ARCHITECTURES.
ARG cuda_arch=75-real;80-real;86-real;89-real;90-real
WORKDIR /opt/src
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt upgrade -y && apt install -y \
clang \
cmake \
curl \
git \
python3-dev \
libssl-dev \
pkg-config \
tar
# Source tarball for the pinned tag, saved as /opt/src/<tag>.tar.gz.
# NOTE(review): this ADD has no --checksum pin, so the archive is not verified.
ADD https://github.com/ggml-org/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
# Configure with clang, build and install llama.cpp's libraries into /usr and
# /usr/lib (common, tests, examples and server are disabled).
RUN mkdir -p llama.cpp \
&& tar -xzf ${llamacpp_version}.tar.gz -C llama.cpp --strip-components=1 \
&& cd llama.cpp \
&& cmake -B build \
-DCMAKE_INSTALL_PREFIX=/usr \
-DCMAKE_INSTALL_LIBDIR=/usr/lib \
-DCMAKE_C_COMPILER=clang \
-DCMAKE_CXX_COMPILER=clang++ \
-DCMAKE_CUDA_ARCHITECTURES=${cuda_arch} \
-DGGML_CUDA=${llamacpp_cuda} \
-DGGML_NATIVE=${llamacpp_native} \
-DGGML_CPU_ARM_ARCH=${llamacpp_cpu_arm_arch} \
-DLLAMA_BUILD_COMMON=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_EXAMPLES=OFF \
-DLLAMA_BUILD_SERVER=OFF \
&& cmake --build build --parallel --config Release \
&& cmake --install build
WORKDIR /app
COPY rust-toolchain.toml rust-toolchain.toml
# Rust 1.85.1 via rustup; the installer leaves PATH alone, so it is set explicitly below.
RUN curl -sSf https://sh.rustup.rs | sh -s -- --no-modify-path --default-toolchain 1.85.1 --profile minimal -y
ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef --locked
# Planner: compute the cargo-chef dependency recipe from the full build context.
FROM deps AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json
FROM deps AS builder
COPY --from=planner /app/recipe.json recipe.json
# Dependency-only build of the llama.cpp router, cached separately from the
# workspace sources.
RUN cargo chef cook --profile release --package text-generation-router-llamacpp --recipe-path recipe.json
COPY . .
# Full build of the router binary into target/release.
RUN cargo build --frozen --profile release --package text-generation-router-llamacpp
FROM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
# Python runtime for the GGUF conversion tooling; the apt package lists are
# dropped afterwards since nothing later in this stage uses apt.
RUN apt update && apt upgrade -y && apt install -y \
    python3-venv \
    python3-pip \
    && rm -rf /var/lib/apt/lists/*
RUN python3 -m venv /venv
ENV PATH="/venv/bin:$PATH"
# Python requirements plus llama.cpp's gguf-py package (editable) and its
# convert_hf_to_gguf.py script.
COPY backends/llamacpp/requirements.txt requirements.txt
COPY --from=builder /opt/src/llama.cpp/gguf-py gguf-py
COPY --from=builder /opt/src/llama.cpp/convert_hf_to_gguf.py /bin/
RUN pip3 install --no-cache-dir \
    -r requirements.txt \
    -e gguf-py
# llama.cpp shared libraries and the router binary.
COPY --from=builder /usr/lib/libllama.so /usr/lib/
COPY --from=builder /usr/lib/libggml*.so /usr/lib/
COPY --from=builder /app/target/release/text-generation-router-llamacpp /usr/bin/
ENV HF_HUB_ENABLE_HF_TRANSFER=1
ENTRYPOINT ["text-generation-router-llamacpp"]
================================================
FILE: Dockerfile_trtllm
================================================
# Global build arguments, declared before the first FROM so the FROM lines can
# use them and stages can re-import them with a bare `ARG name`.
ARG cuda_arch_list="75-real;80-real;86-real;89-real;90-real;100-real;120-real"
ARG cuda_base=12.8.0
ARG build_type=release
ARG ompi_version=4.1.7
ARG sccache_gha_enabled=off
# NOTE(review): not referenced by any visible stage; tgi-builder receives these
# values through BuildKit secrets with the same ids instead.
ARG actions_results_url=""
ARG actions_runtime_token=""
# CUDA dependent dependencies resolver stage
# Shared toolchain for the builder stages: gcc/g++ 14, lld, cmake/ninja,
# sanitizer runtimes, UCX, and Python/pipx tooling.
FROM nvidia/cuda:${cuda_base}-cudnn-devel-ubuntu24.04 AS cuda-builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    cmake \
    curl \
    gcc-14 \
    g++-14 \
    git \
    git-lfs \
    lld \
    libssl-dev \
    libucx-dev \
    libasan8 \
    libubsan1 \
    ninja-build \
    pkg-config \
    pipx \
    python3 \
    python3-dev \
    python3-setuptools \
    tar \
    wget \
    && pipx ensurepath
ENV TGI_INSTALL_PREFIX=/usr/local/tgi \
    TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
# Install OpenMPI
FROM cuda-builder AS mpi-builder
WORKDIR /opt/src/mpi
ARG ompi_version
ENV OMPI_VERSION=${ompi_version}
ENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2
# NOTE(review): both the checksum and the `v4.1` URL segment match only the
# default ompi_version; overriding ompi_version requires updating both.
ADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \
    https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .
# Build and install into /usr/local/mpi (CUDA and Slurm support), then delete
# the downloaded tarball.
RUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} && \
    ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
    make -j all && \
    make install && \
    rm -f ${OMPI_TARBALL_FILENAME}
# Install TensorRT
# Runs the repository's install_tensorrt.sh. The tgi-builder and runtime stages
# copy /usr/local/tensorrt out of this stage.
FROM cuda-builder AS trt-builder
COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
RUN chmod +x /opt/install_tensorrt.sh && \
/opt/install_tensorrt.sh
# Build Backend
FROM cuda-builder AS tgi-builder
WORKDIR /usr/src/text-generation-inference
# Scoped global args reuse
ARG cuda_arch_list
ARG build_type
ARG sccache_gha_enabled
# Install Rust
# rustup (toolchain 1.85.1) plus sccache, used below as the rustc wrapper and
# CMake compiler launcher.
ENV PATH="/root/.cargo/bin:$PATH"
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y && \
chmod -R a+w /root/.rustup && \
chmod -R a+w /root/.cargo && \
cargo install sccache --version ">=0.10.0" --locked
# Let the native build find OpenMPI and TensorRT (both copied in below).
# NOTE(review): USE_LLD_LINKER / CUDA_ARCH_LIST are presumably read by the trtllm
# backend's build script; confirm.
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig"
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt"
ENV USE_LLD_LINKER=ON
ENV CUDA_ARCH_LIST=${cuda_arch_list}
# SCCACHE Specifics args - before finding a better, more generic, way...
ENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY router router
COPY backends backends
COPY benchmark benchmark
COPY launcher launcher
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
ENV RUSTC_WRAPPER=sccache
ENV CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX
# The GitHub Actions cache credentials arrive as BuildKit secrets, exposed as
# environment variables only for this RUN (presumably for sccache's GHA
# backend). The install prefix and its include/ and lib/ subdirectories are
# created before the build; runtime copies /usr/local/tgi out of this stage.
RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
export CMAKE_C_COMPILER_LAUNCHER=sccache && \
export CMAKE_CXX_COMPILER_LAUNCHER=sccache && \
export CMAKE_CUDA_COMPILER_LAUNCHER=sccache && \
mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
cargo build --profile ${build_type} --package text-generation-backends-trtllm --bin text-generation-backends-trtllm && \
sccache --show-stats
# Production runtime image: CUDA/cuDNN runtime base plus the MPI, TensorRT and TGI
# artifacts produced by the builder stages.
FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS runtime
# RUN executes via /bin/sh (dash on Ubuntu), which has no brace expansion, so cleanup
# paths are spelled out. Only the apt package lists are removed; /var/lib/dpkg holds the
# package database and must be kept. apt-get is used as apt's CLI is not stable for scripts.
RUN apt-get update && \
    apt-get install -y libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
    rm -rf /var/lib/apt/lists/* && \
    pipx ensurepath && \
    pipx install --include-deps transformers tokenizers
WORKDIR /usr/local/tgi/bin
# Expose the pipx-managed transformers venv (and its python) on PATH.
ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV TOKENIZERS_PARALLELISM=false
ENV OMPI_MCA_plm_rsh_agent=""
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
# Release build output is installed under the launcher name.
COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
# This is used only for the CI/CD
# Same layout as the runtime image, but ships the debug build of the backend.
FROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS ci-runtime
# libasan8/libubsan1 are the AddressSanitizer/UBSan runtime libraries (presumably needed by
# an instrumented debug build — confirm).
# RUN executes via /bin/sh (dash on Ubuntu), which has no brace expansion, so cleanup
# paths are spelled out. Only the apt package lists are removed; /var/lib/dpkg holds the
# package database and must be kept.
RUN apt-get update && \
    apt-get install -y libasan8 libubsan1 libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \
    rm -rf /var/lib/apt/lists/* && \
    pipx ensurepath && \
    pipx install --include-deps transformers tokenizers
WORKDIR /usr/local/tgi/bin
ENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV TOKENIZERS_PARALLELISM=false
ENV OMPI_MCA_plm_rsh_agent=""
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
# Basically we copy from target/debug instead of target/release
COPY --from=tgi-builder /usr/src/text-generation-inference/target/debug/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
# This is the final image: the production runtime plus image metadata and entrypoint.
FROM runtime
LABEL co.huggingface.vendor="Hugging Face Inc."
LABEL org.opencontainers.image.authors="hardware@hf.co"
# The OCI pre-defined key for the title is `org.opencontainers.image.title`; the legacy
# `org.opencontainers.title` key is kept so existing label queries keep matching.
LABEL org.opencontainers.image.title="Text-Generation-Inference TensorRT-LLM Backend"
LABEL org.opencontainers.title="Text-Generation-Inference TensorRT-LLM Backend"
# Relative path resolves against WORKDIR /usr/local/tgi/bin set in the runtime stage.
ENTRYPOINT ["./text-generation-launcher"]
CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2022 Hugging Face
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: Makefile
================================================
# Every target below names a command, not a file it produces. Declaring them phony makes
# make always run their recipes even when a same-named path exists (the repository
# already contains an `integration-tests/` directory) and skips implicit-rule lookup.
.PHONY: install-server install-server-cpu install-router install-launcher install-benchmark \
        install install-cpu server-dev router-dev rust-tests install-integration-tests \
        integration-tests update-integration-tests python-server-tests python-client-tests \
        python-tests run-falcon-7b-instruct run-falcon-7b-instruct-quantize clean preview_doc

# Python server targets delegate to server/Makefile. $(MAKE) is used for the recursive
# call so command-line flags (-n, -j, ...) propagate to the sub-make.
install-server:
	$(MAKE) -C server install

install-server-cpu:
	$(MAKE) -C server install-server

install-router:
	cargo install --path backends/v3/

install-launcher:
	cargo install --path launcher/

install-benchmark:
	cargo install --path benchmark/

install: install-server install-router install-launcher

install-cpu: install-server-cpu install-router install-launcher

server-dev:
	$(MAKE) -C server run-dev

router-dev:
	cd router && cargo run -- --port 8080

rust-tests: install-router install-launcher
	cargo test

# Each recipe line runs in its own shell, so each `cd` only affects its own line.
install-integration-tests:
	cd integration-tests && pip install -r requirements.txt
	cd clients/python && pip install .

integration-tests: install-integration-tests
	pytest -s -vv -m "not private" integration-tests

update-integration-tests: install-integration-tests
	pytest -s -vv --snapshot-update integration-tests

python-server-tests:
	HF_HUB_ENABLE_HF_TRANSFER=1 pytest -s -vv -m "not private" server/tests

python-client-tests:
	pytest clients/python/tests

python-tests: python-server-tests python-client-tests

run-falcon-7b-instruct:
	text-generation-launcher --model-id tiiuae/falcon-7b-instruct --port 8080

run-falcon-7b-instruct-quantize:
	text-generation-launcher --model-id tiiuae/falcon-7b-instruct --quantize bitsandbytes --port 8080

clean:
	rm -rf target aml

preview_doc:
	doc-builder preview text-generation-inference docs/source --not_python_module
================================================
FILE: README.md
================================================
> [!CAUTION]
> text-generation-inference is now in maintenance mode. Going forward, we will accept pull requests for minor bug fixes, documentation improvements and lightweight maintenance tasks.
>
> TGI has initiated the movement for optimized inference engines to rely on `transformers` model architectures. This approach is now adopted by downstream inference engines, which we contribute to and recommend using going forward: [vllm](https://github.com/vllm-project/vllm), [SGLang](https://github.com/sgl-project/sglang), as well as local engines with inter-compatibility such as llama.cpp or MLX.
# Text Generation Inference
A Rust, Python and gRPC server for text generation inference. Used in production at [Hugging Face](https://huggingface.co)
to power Hugging Chat, the Inference API and Inference Endpoints.
## Table of contents
- [Get Started](#get-started)
- [Docker](#docker)
- [API documentation](#api-documentation)
- [Using a private or gated model](#using-a-private-or-gated-model)
- [A note on Shared Memory (shm)](#a-note-on-shared-memory-shm)
- [Distributed Tracing](#distributed-tracing)
- [Architecture](#architecture)
- [Local install](#local-install)
- [Local install (Nix)](#local-install-nix)
- [Optimized architectures](#optimized-architectures)
- [Run locally](#run-locally)
- [Run](#run)
- [Quantization](#quantization)
- [Develop](#develop)
- [Testing](#testing)
Text Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and [more](https://huggingface.co/docs/text-generation-inference/supported_models). TGI implements many features, such as:
- Simple launcher to serve most popular LLMs
- Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
- Tensor Parallelism for faster inference on multiple GPUs
- Token streaming using Server-Sent Events (SSE)
- Continuous batching of incoming requests for increased total throughput
- [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) compatible with Open AI Chat Completion API
- Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
- Quantization with :
- [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [GPT-Q](https://arxiv.org/abs/2210.17323)
- [EETQ](https://github.com/NetEase-FuXi/EETQ)
- [AWQ](https://github.com/casper-hansen/AutoAWQ)
- [Marlin](https://github.com/IST-DASLab/marlin)
- [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
- Stop sequences
- Log probabilities
- [Speculation](https://huggingface.co/docs/text-generation-inference/conceptual/speculation) for ~2x latency improvement
- [Guidance/JSON](https://huggingface.co/docs/text-generation-inference/conceptual/guidance). Specify the output format to speed up inference and make sure the output is valid according to a given specification.
- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output
- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance
### Hardware support
- [Nvidia](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference)
- [AMD](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference) (-rocm)
- [Inferentia](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference)
- [Intel GPU](https://github.com/huggingface/text-generation-inference/pull/1475)
- [Gaudi](https://github.com/huggingface/tgi-gaudi)
- [Google TPU](https://huggingface.co/docs/optimum-tpu/howto/serving)
## Get Started
### Docker
For a detailed starting guide, please see the [Quick Tour](https://huggingface.co/docs/text-generation-inference/quicktour). The easiest way of getting started is using the official Docker container:
```shell
model=HuggingFaceH4/zephyr-7b-beta
# share a volume with the Docker container to avoid downloading weights every run
volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.5 --model-id $model
```
And then you can make requests like
```bash
curl 127.0.0.1:8080/generate_stream \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses.
```bash
curl localhost:8080/v1/chat/completions \
-X POST \
-d '{
"model": "tgi",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What is deep learning?"
}
],
"stream": true,
"max_tokens": 20
}' \
-H 'Content-Type: application/json'
```
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`. Please note that CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.5-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```
text-generation-launcher --help
```
### API documentation
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
### Using a private or gated model
You have the option to utilize the `HF_TOKEN` environment variable for configuring the token employed by
`text-generation-inference`. This allows you to gain access to protected resources.
For example, if you want to serve the gated Llama V2 model variants:
1. Go to https://huggingface.co/settings/tokens
2. Copy your CLI READ token
3. Export `HF_TOKEN=<your CLI READ token>`
or with Docker:
```shell
model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
token=<your CLI READ token>
docker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.5 --model-id $model
```
### A note on Shared Memory (shm)
[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
`PyTorch` to do distributed training/inference. `text-generation-inference` makes
use of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.
In order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if
peer-to-peer using NVLink or PCI is not possible.
To allow the container to use 1G of Shared Memory and support SHM sharing, we add `--shm-size 1g` on the above command.
If you are running `text-generation-inference` inside `Kubernetes`, you can also add Shared Memory to the container by
creating a volume with:
```yaml
- name: shm
emptyDir:
medium: Memory
sizeLimit: 1Gi
```
and mounting it to `/dev/shm`.
Finally, you can also disable SHM sharing by using the `NCCL_SHM_DISABLE=1` environment variable. However, note that
this will impact performance.
### Distributed Tracing
`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
by setting the address to an OTLP collector with the `--otlp-endpoint` argument. The default service name can be
overridden with the `--otlp-service-name` argument.
### Architecture

A detailed blog post by Adyen on TGI's inner workings: [LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)
### Local install
You can also opt to install `text-generation-inference` locally.
First clone the repository and change directory into it:
```shell
git clone https://github.com/huggingface/text-generation-inference
cd text-generation-inference
```
Then [install Rust](https://rustup.rs/) and create a Python virtual environment with at least
Python 3.9, e.g. using `conda` or `python venv`:
```shell
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
#using conda
conda create -n text-generation-inference python=3.11
conda activate text-generation-inference
#using python venv
python3 -m venv .venv
source .venv/bin/activate
```
You may also need to install Protoc.
On Linux:
```shell
PROTOC_ZIP=protoc-21.12-linux-x86_64.zip
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP
sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc
sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'
rm -f $PROTOC_ZIP
```
On MacOS, using Homebrew:
```shell
brew install protobuf
```
Then run:
```shell
BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
```
**Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:
```shell
sudo apt-get install libssl-dev gcc -y
```
### Local install (Nix)
Another option is to install `text-generation-inference` locally using [Nix](https://nixos.org). Currently,
we only support Nix on x86_64 Linux with CUDA GPUs. When using Nix, all dependencies can
be pulled from a binary cache, removing the need to build them locally.
First follow the instructions to [install Cachix and enable the Hugging Face cache](https://app.cachix.org/cache/huggingface).
Setting up the cache is important, otherwise Nix will build many of the dependencies
locally, which can take hours.
After that you can run TGI with `nix run`:
```shell
cd text-generation-inference
nix run --extra-experimental-features nix-command --extra-experimental-features flakes . -- --model-id meta-llama/Llama-3.1-8B-Instruct
```
**Note:** when you are using Nix on a non-NixOS system, you have to [make some symlinks](https://danieldk.eu/Nix-CUDA-on-non-NixOS-systems#make-runopengl-driverlib-and-symlink-the-driver-library)
to make the CUDA driver libraries visible to Nix packages.
For TGI development, you can use the `impure` dev shell:
```shell
nix develop .#impure
# Only needed the first time the devshell is started or after updating the protobuf.
(
cd server
mkdir text_generation_server/pb || true
python -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \
--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py
)
```
All development dependencies (cargo, Python, Torch), etc. are available in this
dev shell.
## Optimized architectures
TGI works out of the box to serve optimized versions of all modern models. They can be found in [this list](https://huggingface.co/docs/text-generation-inference/supported_models).
Other architectures are supported on a best-effort basis using:
`AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`
or
`AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")`
## Run locally
### Run
```shell
text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
```
### Quantization
You can also run pre-quantized weights (AWQ, GPTQ, Marlin) or on-the-fly quantize weights with bitsandbytes, EETQ, fp8, to reduce the VRAM requirement:
```shell
text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize <quantization-method>
```
4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
Read more about quantization in the [Quantization documentation](https://huggingface.co/docs/text-generation-inference/en/conceptual/quantization).
## Develop
```shell
make server-dev
make router-dev
```
## Testing
```shell
# python
make python-server-tests
make python-client-tests
# or both server and client tests
make python-tests
# rust cargo tests
make rust-tests
# integration tests
make integration-tests
```
================================================
FILE: assets/tgi_grafana.json
================================================
{
"__inputs": [
{
"name": "DS_PROMETHEUS_EKS API INFERENCE PROD",
"label": "Prometheus EKS API Inference Prod",
"description": "",
"type": "datasource",
"pluginId": "prometheus",
"pluginName": "Prometheus"
}
],
"__elements": {},
"__requires": [
{
"type": "panel",
"id": "gauge",
"name": "Gauge",
"version": ""
},
{
"type": "grafana",
"id": "grafana",
"name": "Grafana",
"version": "10.0.2"
},
{
"type": "panel",
"id": "heatmap",
"name": "Heatmap",
"version": ""
},
{
"type": "datasource",
"id": "prometheus",
"name": "Prometheus",
"version": "1.0.0"
},
{
"type": "panel",
"id": "timeseries",
"name": "Time series",
"version": ""
}
],
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": {
"type": "grafana",
"uid": "-- Grafana --"
},
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"target": {
"limit": 100,
"matchAny": false,
"tags": [],
"type": "dashboard"
},
"type": "dashboard"
}
]
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 2,
"id": 551,
"links": [],
"liveNow": false,
"panels": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"fieldMinMax": false,
"mappings": [],
"min": 0,
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 1000
}
]
},
"unit": "ms"
},
"overrides": []
},
"gridPos": {
"h": 7,
"w": 8,
"x": 0,
"y": 0
},
"id": 49,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"reduceOptions": {
"calcs": [
"mean"
],
"fields": "",
"values": false
},
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "10.4.2",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "(histogram_quantile(0.5, sum by (le) (rate(tgi_request_queue_duration_bucket{container=\"$service\"}[10m]))) * 1000) > 0",
"hide": true,
"instant": false,
"legendFormat": "__auto",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "(histogram_quantile(0.5, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\"prefill\", container=\"$service\"}[10m]))) * 1000) > 0",
"hide": true,
"instant": false,
"legendFormat": "__auto",
"range": true,
"refId": "C"
},
{
"datasource": {
"name": "Expression",
"type": "__expr__",
"uid": "__expr__"
},
"expression": "$B + $C",
"hide": false,
"refId": "D",
"type": "math"
}
],
"title": "Time to first token",
"type": "stat"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"min": 0,
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "ms"
},
"overrides": []
},
"gridPos": {
"h": 7,
"w": 8,
"x": 9,
"y": 0
},
"id": 44,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"reduceOptions": {
"calcs": [
"mean"
],
"fields": "",
"values": false
},
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "10.4.2",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "(histogram_quantile(0.5, sum by (le) (rate(tgi_batch_forward_duration_bucket{method=\"decode\", container=\"$service\"}[10m]))) * 1000)>0",
"instant": false,
"range": true,
"refId": "A"
}
],
"title": "Decode per-token latency",
"type": "stat"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "thresholds"
},
"mappings": [],
"min": 0,
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
}
]
},
"unit": "short"
},
"overrides": []
},
"gridPos": {
"h": 7,
"w": 7,
"x": 17,
"y": 0
},
"id": 45,
"options": {
"colorMode": "value",
"graphMode": "area",
"justifyMode": "auto",
"orientation": "auto",
"reduceOptions": {
"calcs": [
"mean"
],
"fields": "",
"values": false
},
"showPercentChange": false,
"textMode": "auto",
"wideLayout": true
},
"pluginVersion": "10.4.2",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "sum((rate(tgi_request_generated_tokens_sum{container=\"$service\"}[10m]) / rate(tgi_request_generated_tokens_count{container=\"$service\"}[10m]))>0)",
"instant": false,
"range": true,
"refId": "A"
}
],
"title": "Throughput (generated tok/s)",
"type": "stat"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "none"
},
"overrides": [
{
"matcher": {
"id": "byName",
"options": "p50"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "green",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p90"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "orange",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p99"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "red",
"mode": "fixed"
}
}
]
}
]
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 7
},
"id": 48,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_input_length_bucket{container=\"$service\"}[10m])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_input_length_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p90",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_request_input_length_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Number of tokens per prompt",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "none"
},
"overrides": [
{
"matcher": {
"id": "byName",
"options": "p50"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "green",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p90"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "orange",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p99"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "red",
"mode": "fixed"
}
}
]
}
]
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 7
},
"id": 30,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_generated_tokens_bucket{container=\"$service\"}[10m])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_generated_tokens_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p90",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_request_generated_tokens_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Number of generated tokens per request",
"type": "timeseries"
},
{
"collapsed": false,
"gridPos": {
"h": 1,
"w": 24,
"x": 0,
"y": 15
},
"id": 20,
"panels": [],
"title": "General",
"type": "row"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 30,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 6,
"x": 0,
"y": 16
},
"id": 4,
"maxDataPoints": 100,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "sum(increase(tgi_request_success{container=\"$service\"}[1m]))",
"legendFormat": "Success",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "sum(increase(tgi_request_failure{container=\"$service\"}[1m])) by (err)",
"hide": false,
"legendFormat": "Error: {{err}}",
"range": true,
"refId": "B"
}
],
"title": "Requests",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "s"
},
"overrides": [
{
"matcher": {
"id": "byName",
"options": "p50"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "green",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p90"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "orange",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p99"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "red",
"mode": "fixed"
}
}
]
}
]
},
"gridPos": {
"h": 13,
"w": 9,
"x": 6,
"y": 16
},
"id": 6,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_mean_time_per_token_duration_bucket{container=\"$service\"}[10m])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_mean_time_per_token_duration_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p90",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_request_mean_time_per_token_duration_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Mean Time Per Token quantiles",
"type": "timeseries"
},
{
"cards": {},
"color": {
"cardColor": "#5794F2",
"colorScale": "linear",
"colorScheme": "interpolateSpectral",
"exponent": 0.5,
"min": 0,
"mode": "opacity"
},
"dataFormat": "tsbuckets",
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"custom": {
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"scaleDistribution": {
"type": "linear"
}
}
},
"overrides": []
},
"gridPos": {
"h": 13,
"w": 9,
"x": 15,
"y": 16
},
"heatmap": {},
"hideZeroBuckets": false,
"highlightCards": true,
"id": 13,
"legend": {
"show": false
},
"maxDataPoints": 25,
"options": {
"calculate": false,
"calculation": {},
"cellGap": 2,
"cellValues": {},
"color": {
"exponent": 0.5,
"fill": "#5794F2",
"min": 0,
"mode": "scheme",
"reverse": false,
"scale": "exponential",
"scheme": "Spectral",
"steps": 128
},
"exemplars": {
"color": "rgba(255,0,255,0.7)"
},
"filterValues": {
"le": 1e-9
},
"legend": {
"show": false
},
"rowsFrame": {
"layout": "auto"
},
"showValue": "never",
"tooltip": {
"mode": "single",
"showColorScale": false,
"yHistogram": false
},
"yAxis": {
"axisPlacement": "left",
"decimals": 1,
"reverse": false,
"unit": "s"
}
},
"pluginVersion": "10.4.2",
"reverseYBuckets": false,
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"exemplar": true,
"expr": "sum(increase(tgi_request_mean_time_per_token_duration_bucket{container=\"$service\"}[5m])) by (le)",
"format": "heatmap",
"interval": "",
"legendFormat": "{{ le }}",
"range": true,
"refId": "A"
}
],
"title": "Mean Time Per Token",
"tooltip": {
"show": true,
"showHistogram": false
},
"type": "heatmap",
"xAxis": {
"show": true
},
"yAxis": {
"decimals": 1,
"format": "s",
"logBase": 1,
"show": true
},
"yBucketBound": "auto"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "percentage",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "orange",
"value": 70
},
{
"color": "red",
"value": 85
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 5,
"w": 3,
"x": 0,
"y": 24
},
"id": 18,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": false
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"pluginVersion": "9.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "count(tgi_request_count{container=\"$service\"})",
"legendFormat": "Replicas",
"range": true,
"refId": "A"
}
],
"title": "Number of replicas",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"mappings": [],
"thresholds": {
"mode": "percentage",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "orange",
"value": 70
},
{
"color": "red",
"value": 85
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 5,
"w": 3,
"x": 3,
"y": 24
},
"id": 32,
"options": {
"minVizHeight": 75,
"minVizWidth": 75,
"orientation": "auto",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showThresholdLabels": false,
"showThresholdMarkers": true,
"sizing": "auto"
},
"pluginVersion": "10.4.2",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "sum(tgi_queue_size{container=\"$service\"})",
"legendFormat": "__auto",
"range": true,
"refId": "A"
}
],
"title": "Queue Size",
"type": "gauge"
},
{
"collapsed": false,
"gridPos": {
"h": 1,
"w": 24,
"x": 0,
"y": 29
},
"id": 26,
"panels": [],
"title": "Batching",
"type": "row"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "bars",
"fillOpacity": 50,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "normal"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 5,
"w": 6,
"x": 0,
"y": 30
},
"id": 29,
"maxDataPoints": 40,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": false
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"pluginVersion": "9.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "avg(tgi_batch_current_max_tokens{container=\"$service\"})",
"legendFormat": "{{ pod }}",
"range": true,
"refId": "A"
}
],
"title": "Max tokens per batch",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "none"
},
"overrides": [
{
"matcher": {
"id": "byName",
"options": "p50"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "green",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p90"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "orange",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p99"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "red",
"mode": "fixed"
}
}
]
}
]
},
"gridPos": {
"h": 9,
"w": 4,
"x": 6,
"y": 30
},
"id": 33,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_skipped_tokens_bucket{container=\"$service\"}[10m])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_skipped_tokens_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p90",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_request_skipped_tokens_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Speculated Tokens",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "none"
},
"overrides": [
{
"matcher": {
"id": "byName",
"options": "p50"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "green",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p90"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "orange",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p99"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "red",
"mode": "fixed"
}
}
]
}
]
},
"gridPos": {
"h": 9,
"w": 5,
"x": 10,
"y": 30
},
"id": 46,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_input_length_bucket{container=\"$service\"}[10m])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_input_length_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p90",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_request_input_length_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Prompt Tokens",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "s"
},
"overrides": [
{
"matcher": {
"id": "byName",
"options": "p50"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "green",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p90"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "orange",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p99"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "red",
"mode": "fixed"
}
}
]
}
]
},
"gridPos": {
"h": 9,
"w": 9,
"x": 15,
"y": 30
},
"id": 8,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_duration_bucket{container=\"$service\"}[10m])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_duration_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p90",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_request_duration_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Latency quantiles",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "bars",
"fillOpacity": 50,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "normal"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 4,
"w": 6,
"x": 0,
"y": 35
},
"id": 27,
"maxDataPoints": 40,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": false
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"pluginVersion": "9.1.0",
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "avg(tgi_batch_current_size{container=\"$service\"})",
"legendFormat": "{{ pod }}",
"range": true,
"refId": "A"
}
],
"title": "Batch Size",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 30,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 6,
"x": 0,
"y": 39
},
"id": 28,
"maxDataPoints": 100,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "sum(increase(tgi_batch_concat{container=\"$service\"}[1m])) by (reason)",
"hide": false,
"legendFormat": "Reason: {{ reason }}",
"range": true,
"refId": "B"
}
],
"title": "Concatenates",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "s"
},
"overrides": [
{
"matcher": {
"id": "byName",
"options": "p50"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "green",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p90"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "orange",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p99"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "red",
"mode": "fixed"
}
}
]
}
]
},
"gridPos": {
"h": 9,
"w": 9,
"x": 6,
"y": 39
},
"id": 31,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_request_queue_duration_bucket{container=\"$service\"}[10m])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_request_queue_duration_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p90",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_request_queue_duration_bucket{container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Queue quantiles",
"type": "timeseries"
},
{
"collapsed": false,
"gridPos": {
"h": 1,
"w": 24,
"x": 0,
"y": 48
},
"id": 22,
"panels": [],
"title": "Prefill",
"type": "row"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "s"
},
"overrides": [
{
"matcher": {
"id": "byName",
"options": "p50"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "green",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p90"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "orange",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p99"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "red",
"mode": "fixed"
}
}
]
}
]
},
"gridPos": {
"h": 11,
"w": 12,
"x": 0,
"y": 49
},
"id": 7,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\"prefill\", container=\"$service\"}[10m])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\"prefill\", container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p90",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\"prefill\", container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Prefill Quantiles",
"type": "timeseries"
},
{
"cards": {},
"color": {
"cardColor": "#5794F2",
"colorScale": "linear",
"colorScheme": "interpolateSpectral",
"exponent": 0.5,
"min": 0,
"mode": "opacity"
},
"dataFormat": "tsbuckets",
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"custom": {
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"scaleDistribution": {
"type": "linear"
}
}
},
"overrides": []
},
"gridPos": {
"h": 11,
"w": 12,
"x": 12,
"y": 49
},
"heatmap": {},
"hideZeroBuckets": false,
"highlightCards": true,
"id": 14,
"legend": {
"show": false
},
"maxDataPoints": 25,
"options": {
"calculate": false,
"calculation": {},
"cellGap": 2,
"cellValues": {},
"color": {
"exponent": 0.5,
"fill": "#5794F2",
"min": 0,
"mode": "scheme",
"reverse": false,
"scale": "exponential",
"scheme": "Spectral",
"steps": 128
},
"exemplars": {
"color": "rgba(255,0,255,0.7)"
},
"filterValues": {
"le": 1e-9
},
"legend": {
"show": false
},
"rowsFrame": {
"layout": "auto"
},
"showValue": "never",
"tooltip": {
"mode": "single",
"showColorScale": false,
"yHistogram": false
},
"yAxis": {
"axisPlacement": "left",
"decimals": 1,
"reverse": false,
"unit": "s"
}
},
"pluginVersion": "10.4.2",
"reverseYBuckets": false,
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"exemplar": true,
"expr": "sum(increase(tgi_batch_inference_duration_bucket{method=\"prefill\", container=\"$service\"}[5m])) by (le)",
"format": "heatmap",
"interval": "",
"legendFormat": "{{ le }}",
"range": true,
"refId": "A"
}
],
"title": "Prefill Latency",
"tooltip": {
"show": true,
"showHistogram": false
},
"type": "heatmap",
"xAxis": {
"show": true
},
"yAxis": {
"decimals": 1,
"format": "s",
"logBase": 1,
"show": true
},
"yBucketBound": "auto"
},
{
"collapsed": false,
"gridPos": {
"h": 1,
"w": 24,
"x": 0,
"y": 60
},
"id": 24,
"panels": [],
"title": "Decode",
"type": "row"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "s"
},
"overrides": [
{
"matcher": {
"id": "byName",
"options": "p50"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "green",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p90"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "orange",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p99"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "red",
"mode": "fixed"
}
}
]
}
]
},
"gridPos": {
"h": 11,
"w": 12,
"x": 0,
"y": 61
},
"id": 11,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p90",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Decode Quantiles",
"type": "timeseries"
},
{
"cards": {},
"color": {
"cardColor": "#5794F2",
"colorScale": "linear",
"colorScheme": "interpolateSpectral",
"exponent": 0.5,
"min": 0,
"mode": "opacity"
},
"dataFormat": "tsbuckets",
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"custom": {
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"scaleDistribution": {
"type": "linear"
}
}
},
"overrides": []
},
"gridPos": {
"h": 11,
"w": 12,
"x": 12,
"y": 61
},
"heatmap": {},
"hideZeroBuckets": false,
"highlightCards": true,
"id": 15,
"legend": {
"show": false
},
"maxDataPoints": 25,
"options": {
"calculate": false,
"calculation": {},
"cellGap": 2,
"cellValues": {},
"color": {
"exponent": 0.5,
"fill": "#5794F2",
"min": 0,
"mode": "scheme",
"reverse": false,
"scale": "exponential",
"scheme": "Spectral",
"steps": 128
},
"exemplars": {
"color": "rgba(255,0,255,0.7)"
},
"filterValues": {
"le": 1e-9
},
"legend": {
"show": false
},
"rowsFrame": {
"layout": "auto"
},
"showValue": "never",
"tooltip": {
"mode": "single",
"showColorScale": false,
"yHistogram": false
},
"yAxis": {
"axisPlacement": "left",
"decimals": 1,
"reverse": false,
"unit": "s"
}
},
"pluginVersion": "10.4.2",
"reverseYBuckets": false,
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"exemplar": true,
"expr": "sum(increase(tgi_batch_inference_duration_bucket{method=\"decode\", container=\"$service\"}[5m])) by (le)",
"format": "heatmap",
"interval": "",
"legendFormat": "{{ le }}",
"range": true,
"refId": "A"
}
],
"title": "Decode Latency",
"tooltip": {
"show": true,
"showHistogram": false
},
"type": "heatmap",
"xAxis": {
"show": true
},
"yAxis": {
"decimals": 1,
"format": "s",
"logBase": 1,
"show": true
},
"yBucketBound": "auto"
},
{
"collapsed": false,
"gridPos": {
"h": 1,
"w": 24,
"x": 0,
"y": 72
},
"id": 43,
"panels": [],
"title": "Debug",
"type": "row"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "s"
},
"overrides": [
{
"matcher": {
"id": "byName",
"options": "p50"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "green",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p90"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "orange",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p99"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "red",
"mode": "fixed"
}
}
]
}
]
},
"gridPos": {
"h": 11,
"w": 6,
"x": 0,
"y": 73
},
"id": 38,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_batch_forward_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_batch_forward_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p90",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_batch_forward_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Forward Quantiles",
"type": "timeseries"
},
{
"cards": {},
"color": {
"cardColor": "#5794F2",
"colorScale": "linear",
"colorScheme": "interpolateSpectral",
"exponent": 0.5,
"min": 0,
"mode": "opacity"
},
"dataFormat": "tsbuckets",
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"custom": {
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"scaleDistribution": {
"type": "linear"
}
}
},
"overrides": []
},
"gridPos": {
"h": 11,
"w": 6,
"x": 6,
"y": 73
},
"heatmap": {},
"hideZeroBuckets": false,
"highlightCards": true,
"id": 35,
"legend": {
"show": false
},
"maxDataPoints": 25,
"options": {
"calculate": false,
"calculation": {},
"cellGap": 2,
"cellValues": {},
"color": {
"exponent": 0.5,
"fill": "#5794F2",
"min": 0,
"mode": "scheme",
"reverse": false,
"scale": "exponential",
"scheme": "Spectral",
"steps": 128
},
"exemplars": {
"color": "rgba(255,0,255,0.7)"
},
"filterValues": {
"le": 1e-9
},
"legend": {
"show": false
},
"rowsFrame": {
"layout": "auto"
},
"showValue": "never",
"tooltip": {
"mode": "single",
"showColorScale": false,
"yHistogram": false
},
"yAxis": {
"axisPlacement": "left",
"decimals": 1,
"reverse": false,
"unit": "s"
}
},
"pluginVersion": "10.4.2",
"reverseYBuckets": false,
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"exemplar": true,
"expr": "sum(increase(tgi_batch_forward_duration_bucket{method=\"decode\", container=\"$service\"}[5m])) by (le)",
"format": "heatmap",
"interval": "",
"legendFormat": "{{ le }}",
"range": true,
"refId": "A"
}
],
"title": "Forward Latency",
"tooltip": {
"show": true,
"showHistogram": false
},
"type": "heatmap",
"xAxis": {
"show": true
},
"yAxis": {
"decimals": 1,
"format": "s",
"logBase": 1,
"show": true
},
"yBucketBound": "auto"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "s"
},
"overrides": [
{
"matcher": {
"id": "byName",
"options": "p50"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "green",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p90"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "orange",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p99"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "red",
"mode": "fixed"
}
}
]
}
]
},
"gridPos": {
"h": 11,
"w": 6,
"x": 12,
"y": 73
},
"id": 34,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_batch_decode_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_batch_decode_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p90",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_batch_decode_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Token Decode Quantiles",
"type": "timeseries"
},
{
"cards": {},
"color": {
"cardColor": "#5794F2",
"colorScale": "linear",
"colorScheme": "interpolateSpectral",
"exponent": 0.5,
"min": 0,
"mode": "opacity"
},
"dataFormat": "tsbuckets",
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"custom": {
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"scaleDistribution": {
"type": "linear"
}
}
},
"overrides": []
},
"gridPos": {
"h": 11,
"w": 6,
"x": 18,
"y": 73
},
"heatmap": {},
"hideZeroBuckets": false,
"highlightCards": true,
"id": 40,
"legend": {
"show": false
},
"maxDataPoints": 25,
"options": {
"calculate": false,
"calculation": {},
"cellGap": 2,
"cellValues": {},
"color": {
"exponent": 0.5,
"fill": "#5794F2",
"min": 0,
"mode": "scheme",
"reverse": false,
"scale": "exponential",
"scheme": "Spectral",
"steps": 128
},
"exemplars": {
"color": "rgba(255,0,255,0.7)"
},
"filterValues": {
"le": 1e-9
},
"legend": {
"show": false
},
"rowsFrame": {
"layout": "auto"
},
"showValue": "never",
"tooltip": {
"mode": "single",
"showColorScale": false,
"yHistogram": false
},
"yAxis": {
"axisPlacement": "left",
"decimals": 1,
"reverse": false,
"unit": "s"
}
},
"pluginVersion": "10.4.2",
"reverseYBuckets": false,
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"exemplar": true,
"expr": "sum(increase(tgi_batch_decode_duration_bucket{method=\"decode\", container=\"$service\"}[5m])) by (le)",
"format": "heatmap",
"interval": "",
"legendFormat": "{{ le }}",
"range": true,
"refId": "A"
}
],
"title": "Token Decode Latency",
"tooltip": {
"show": true,
"showHistogram": false
},
"type": "heatmap",
"xAxis": {
"show": true
},
"yAxis": {
"decimals": 1,
"format": "s",
"logBase": 1,
"show": true
},
"yBucketBound": "auto"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "s"
},
"overrides": [
{
"matcher": {
"id": "byName",
"options": "p50"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "green",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p90"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "orange",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p99"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "red",
"mode": "fixed"
}
}
]
}
]
},
"gridPos": {
"h": 11,
"w": 6,
"x": 0,
"y": 84
},
"id": 42,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_batch_filter_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_batch_filter_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p90",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_batch_filter_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Filter Batch Quantiles",
"type": "timeseries"
},
{
"cards": {},
"color": {
"cardColor": "#5794F2",
"colorScale": "linear",
"colorScheme": "interpolateSpectral",
"exponent": 0.5,
"min": 0,
"mode": "opacity"
},
"dataFormat": "tsbuckets",
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"custom": {
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"scaleDistribution": {
"type": "linear"
}
}
},
"overrides": []
},
"gridPos": {
"h": 11,
"w": 6,
"x": 6,
"y": 84
},
"heatmap": {},
"hideZeroBuckets": false,
"highlightCards": true,
"id": 39,
"legend": {
"show": false
},
"maxDataPoints": 25,
"options": {
"calculate": false,
"calculation": {},
"cellGap": 2,
"cellValues": {},
"color": {
"exponent": 0.5,
"fill": "#5794F2",
"min": 0,
"mode": "scheme",
"reverse": false,
"scale": "exponential",
"scheme": "Spectral",
"steps": 128
},
"exemplars": {
"color": "rgba(255,0,255,0.7)"
},
"filterValues": {
"le": 1e-9
},
"legend": {
"show": false
},
"rowsFrame": {
"layout": "auto"
},
"showValue": "never",
"tooltip": {
"mode": "single",
"showColorScale": false,
"yHistogram": false
},
"yAxis": {
"axisPlacement": "left",
"decimals": 1,
"reverse": false,
"unit": "s"
}
},
"pluginVersion": "10.4.2",
"reverseYBuckets": false,
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"exemplar": true,
"expr": "sum(increase(tgi_batch_filter_duration_bucket{method=\"decode\", container=\"$service\"}[5m])) by (le)",
"format": "heatmap",
"interval": "",
"legendFormat": "{{ le }}",
"range": true,
"refId": "A"
}
],
"title": "Filter Batch Latency",
"tooltip": {
"show": true,
"showHistogram": false
},
"type": "heatmap",
"xAxis": {
"show": true
},
"yAxis": {
"decimals": 1,
"format": "s",
"logBase": 1,
"show": true
},
"yBucketBound": "auto"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "never",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
},
"unit": "s"
},
"overrides": [
{
"matcher": {
"id": "byName",
"options": "p50"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "green",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p90"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "orange",
"mode": "fixed"
}
}
]
},
{
"matcher": {
"id": "byName",
"options": "p99"
},
"properties": [
{
"id": "color",
"value": {
"fixedColor": "red",
"mode": "fixed"
}
}
]
}
]
},
"gridPos": {
"h": 11,
"w": 6,
"x": 12,
"y": 84
},
"id": 36,
"options": {
"legend": {
"calcs": [
"min",
"max"
],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.5, sum by (le) (rate(tgi_batch_concat_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"legendFormat": "p50",
"range": true,
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.9, sum by (le) (rate(tgi_batch_concat_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p90",
"range": true,
"refId": "B"
},
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"expr": "histogram_quantile(0.99, sum by (le) (rate(tgi_batch_concat_duration_bucket{method=\"decode\", container=\"$service\"}[10m])))",
"hide": false,
"legendFormat": "p99",
"range": true,
"refId": "C"
}
],
"title": "Batch Concat Quantiles",
"type": "timeseries"
},
{
"cards": {},
"color": {
"cardColor": "#5794F2",
"colorScale": "linear",
"colorScheme": "interpolateSpectral",
"exponent": 0.5,
"min": 0,
"mode": "opacity"
},
"dataFormat": "tsbuckets",
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"fieldConfig": {
"defaults": {
"custom": {
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"scaleDistribution": {
"type": "linear"
}
}
},
"overrides": []
},
"gridPos": {
"h": 11,
"w": 6,
"x": 18,
"y": 84
},
"heatmap": {},
"hideZeroBuckets": false,
"highlightCards": true,
"id": 41,
"legend": {
"show": false
},
"maxDataPoints": 25,
"options": {
"calculate": false,
"calculation": {},
"cellGap": 2,
"cellValues": {},
"color": {
"exponent": 0.5,
"fill": "#5794F2",
"min": 0,
"mode": "scheme",
"reverse": false,
"scale": "exponential",
"scheme": "Spectral",
"steps": 128
},
"exemplars": {
"color": "rgba(255,0,255,0.7)"
},
"filterValues": {
"le": 1e-9
},
"legend": {
"show": false
},
"rowsFrame": {
"layout": "auto"
},
"showValue": "never",
"tooltip": {
"mode": "single",
"showColorScale": false,
"yHistogram": false
},
"yAxis": {
"axisPlacement": "left",
"decimals": 1,
"reverse": false,
"unit": "s"
}
},
"pluginVersion": "10.4.2",
"reverseYBuckets": false,
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"editorMode": "code",
"exemplar": true,
"expr": "sum(increase(tgi_batch_concat_duration_bucket{method=\"decode\", container=\"$service\"}[5m])) by (le)",
"format": "heatmap",
"interval": "",
"legendFormat": "{{ le }}",
"range": true,
"refId": "A"
}
],
"title": "Batch Concat Latency",
"tooltip": {
"show": true,
"showHistogram": false
},
"type": "heatmap",
"xAxis": {
"show": true
},
"yAxis": {
"decimals": 1,
"format": "s",
"logBase": 1,
"show": true
},
"yBucketBound": "auto"
}
],
"refresh": "",
"schemaVersion": 39,
"tags": [],
"templating": {
"list": [
{
"current": {
"selected": false,
"text": "gpu-txt-gen-cohereforai-c4ai-command-r-plu-ba7f1",
"value": "gpu-txt-gen-cohereforai-c4ai-command-r-plu-ba7f1"
},
"datasource": {
"type": "prometheus",
"uid": "${DS_PROMETHEUS_EKS API INFERENCE PROD}"
},
"definition": "label_values(tgi_request_count, container)",
"hide": 0,
"includeAll": false,
"multi": false,
"name": "service",
"options": [],
"query": {
"query": "label_values(tgi_request_count, container)",
"refId": "StandardVariableQuery"
},
"refresh": 1,
"regex": "",
"skipUrlSync": false,
"sort": 1,
"type": "query"
}
]
},
"time": {
"from": "now-30m",
"to": "now-30s"
},
"timepicker": {
"nowDelay": "30s"
},
"timezone": "",
"title": "Text Generation Inference",
"uid": "RHSk7EL4kdqsd",
"version": 12,
"weekStart": ""
}
================================================
FILE: backends/client/Cargo.toml
================================================
[package]
name = "text-generation-client"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[dependencies]
async-trait = "^0.1"
base64 = { workspace = true }
futures = "^0.3"
grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.12"
thiserror = "^1.0"
tokio = { version = "^1.32", features = ["sync"] }
tonic = "^0.10"
tower = "^0.4"
tracing = "^0.1"
[build-dependencies]
tonic-build = "0.10.1"
prost-build = "0.12.1"
================================================
FILE: backends/client/build.rs
================================================
use std::fs;
fn main() -> Result<(), Box> {
println!("cargo:rerun-if-changed=../../proto/");
fs::create_dir_all("src/v2/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/v2/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
.map_err(|e| match e.kind(){
std::io::ErrorKind::NotFound => {panic!("`protoc` not found, install libprotoc")},
std::io::ErrorKind::Other => {panic!("`protoc` version unsupported, upgrade protoc: https://github.com/protocolbuffers/protobuf/releases")},
e => {e}
}).unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
fs::create_dir_all("src/v3/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/v3/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(())
}
================================================
FILE: backends/client/src/lib.rs
================================================
//! Text Generation gRPC client library
pub mod v2;
pub mod v3;
use async_trait::async_trait;
use base64::{engine::general_purpose::STANDARD, Engine};
use thiserror::Error;
use tonic::transport;
use tonic::Status;
pub use v3::{Chunk, Image, Input, InputChunk};
/// Health checks a client can run against its generate server(s).
#[async_trait]
pub trait Health {
/// Check if a generate server is healthy by asking it to allocate a tensor on device.
///
/// Lighter-weight than `model_health`, which runs a full forward pass.
async fn device_health(&self) -> Result<()>;
/// Check if a generate server is healthy by doing a forward pass.
///
/// EXPENSIVE: prefer `device_health` for frequent probing.
async fn model_health(&self) -> Result<()>;
}
/// Metadata describing a model shard.
///
/// NOTE(review): the field meanings below are inferred from their names;
/// confirm them against the shard's info response.
#[derive(Debug)]
pub struct ShardInfo {
/// Whether the shard requires padded inputs.
pub requires_padding: bool,
/// Model data type, as reported by the shard.
pub dtype: String,
/// Device type the shard runs on, as reported by the shard.
pub device_type: String,
/// Optional window size; presumably the attention window of sliding-window models — confirm.
pub window_size: Option,
/// Presumably the number of speculative tokens per step — confirm.
pub speculate: u32,
}
/// Errors returned by the Text Generation gRPC client.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
/// Connecting to the server failed. Built from `tonic::transport::Error`, and also
/// used by the v2 client's `service_discovery` when the server lacks the v2 interface.
#[error("Could not connect to Text Generation server: {0}")]
Connection(String),
/// The server answered with an error `Status`; holds the status message.
#[error("Server error: {0}")]
Generation(String),
/// A sharded call yielded no results (presumably raised by the sharded client — confirm).
#[error("Sharded results are empty")]
EmptyResults,
}
impl From for ClientError {
fn from(err: Status) -> Self {
let err = Self::Generation(err.message().to_string());
tracing::error!("{err}");
err
}
}
impl From for ClientError {
fn from(err: transport::Error) -> Self {
let err = Self::Connection(err.to_string());
tracing::error!("{err}");
err
}
}
// Small convenience re-wrapping of `Chunk`.
impl From for InputChunk {
fn from(chunk: Chunk) -> Self {
InputChunk { chunk: Some(chunk) }
}
}
/// Flattens chunked inputs into a single stringly-typed input, for backwards
/// compatibility with backends that haven't implemented chunked inputs.
pub trait ChunksToString {
/// Convert the chunks into a single string.
fn chunks_to_string(&self) -> String;
}
impl ChunksToString for Vec {
fn chunks_to_string(&self) -> String {
let mut output = String::new();
self.iter().for_each(|c| match &c.chunk {
Some(Chunk::Text(text)) => output.push_str(text),
Some(Chunk::Image(Image { data, mimetype })) => {
let encoded = STANDARD.encode(data);
output.push_str(&format!("", mimetype, encoded))
}
// We don't create empty chunks, so this should be unreachable.
None => unreachable!("Chunks should never be empty"),
});
output
}
}
/// Base64-encoded PNG image (the `iVBORw0KGgo` prefix is the base64 form of the PNG
/// file signature). Imported by the v2 client as `crate::WARMUP_IMAGE_BASE64`,
/// presumably to attach an image to a warmup request so vision heads are exercised.
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
pub type Result = std::result::Result;
================================================
FILE: backends/client/src/v2/client.rs
================================================
//! Single shard Client
use crate::v2::pb;
use crate::{ClientError, Result};
use crate::WARMUP_IMAGE_BASE64;
use grpc_metadata::InjectTelemetryContext;
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v2::*;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
stub: TextGenerationServiceClient,
}
impl Client {
/// Returns a client connected to the given url
pub async fn connect(uri: Uri) -> Result {
let channel = Channel::builder(uri).connect().await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result {
let channel = Channel::from_shared("http://[::]:50051".to_string())
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
.await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result> {
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await.map_err(|_| {
ClientError::Connection("Server does not support v2 interface".to_string())
})?;
let urls = response
.into_inner()
.urls
.into_iter()
// Remove unix socket prefix
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
})
.collect();
Ok(urls)
}
/// Get model info.
///
/// Issues the `Info` RPC (with `inject_context` from
/// `grpc_metadata::InjectTelemetryContext` applied) and returns the response body.
///
/// # Errors
/// RPC failures are converted into a `ClientError` by `?`.
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result {
let request = tonic::Request::new(InfoRequest {}).inject_context();
let response = self.stub.info(request).await?.into_inner();
Ok(response)
}
/// Get model health.
///
/// Issues the `Health` RPC (with `inject_context` applied) and returns the response body.
///
/// # Errors
/// RPC failures are converted into a `ClientError` by `?`.
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result {
let request = tonic::Request::new(HealthRequest {}).inject_context();
let response = self.stub.health(request).await?.into_inner();
Ok(response)
}
/// Clear the past generations cache.
///
/// `batch_id` is forwarded as the request's `id`. Presumably `None` clears every
/// cached batch; confirm against the server implementation. The response body is discarded.
///
/// # Errors
/// RPC failures are converted into a `ClientError` by `?`.
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
self.stub.clear_cache(request).await?;
Ok(())
}
/// Filter a cached batch.
///
/// Sends `batch_id` and `request_ids` to the `FilterBatch` RPC (presumably the
/// requests to keep in the batch; confirm) and returns the response's `batch` field.
///
/// # Errors
/// RPC failures are converted into a `ClientError` by `?`.
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec,
) -> Result> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
request_ids,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option,
) -> Result> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str(&format!(
"",
));
}
requests.push(Request {
id: 0,
inputs,
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
stop_sequences: vec![],
ignore_eos_token: true,
}),
prefill_logprobs: true,
top_n_tokens: 20,
});
n_tokens += max_input_length;
// Check max_batch_size
if Some(requests.len()) == max_batch_size {
break;
}
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_length,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec, Option, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))]
pub async fn decode(
&mut self,
batches: Vec,
) -> Result<(Vec, Option, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,
response.batch,
DecodeTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}
}
/// Timings reported by the server for a prefill call.
///
/// Each field is the matching nanosecond counter of the prefill response converted to a
/// [`Duration`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrefillTimings {
    /// Converted from the response's `forward_ns`.
    pub forward: Duration,
    /// Converted from the response's `decode_ns`.
    pub decode: Duration,
    /// Converted from the response's `total_ns`.
    pub total: Duration,
}

impl PrefillTimings {
    /// Converts the raw nanosecond counters of a prefill response into durations.
    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}
/// Timings reported by the server for a decode call.
///
/// Each field is the matching nanosecond counter of the decode response converted to a
/// [`Duration`]; `concat` is only present when the response carries `concat_ns`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DecodeTimings {
    /// Converted from the response's optional `concat_ns`.
    pub concat: Option<Duration>,
    /// Converted from the response's `forward_ns`.
    pub forward: Duration,
    /// Converted from the response's `decode_ns`.
    pub decode: Duration,
    /// Converted from the response's `total_ns`.
    pub total: Duration,
}

impl DecodeTimings {
    /// Converts the raw nanosecond counters of a decode response into durations.
    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            concat: concat_ns.map(Duration::from_nanos),
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}
================================================
FILE: backends/client/src/v2/mod.rs
================================================
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod client;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
================================================
FILE: backends/client/src/v2/sharded_client.rs
================================================
/// Multi shard Client
use crate::{v2, Health, ShardInfo};
use crate::{ClientError, Result};
use crate::v2::InfoResponse;
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
use v2::client::{DecodeTimings, PrefillTimings};
use v2::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
clients: Vec,
}
impl ShardedClient {
fn new(clients: Vec) -> Self {
Self { clients }
}
/// Create a new ShardedClient from a master client. The master client will communicate with
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result {
// Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await?;
let futures = uris.into_iter().map(Client::connect_uds);
let clients: Result> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?))
}
/// Returns a client connected to the given uri
pub async fn connect(uri: Uri) -> Result {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result {
let master_client = Client::connect_uds(path).await?;
Self::from_master_client(master_client).await
}
/// Get the model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.info())
.collect();
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
}
/// GRPC health check
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.health())
.collect();
join_all(futures).await.pop().unwrap()
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.clear_cache(batch_id))
.collect();
join_all(futures).await.into_iter().collect()
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec,
) -> Result> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option,
) -> Result> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_size,
))
})
.collect();
// Take the minimum value
let results = join_all(futures)
.await
.into_iter()
.collect::>>>()?;
Ok(results.into_iter().flatten().min())
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec, Option, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result, Option, PrefillTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
pub async fn decode(
&mut self,
batches: Vec,
) -> Result<(Vec, Option, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result, Option, DecodeTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
}
impl From for ShardInfo {
fn from(value: InfoResponse) -> Self {
Self {
requires_padding: value.requires_padding,
dtype: value.dtype,
device_type: value.device_type,
window_size: value.window_size,
speculate: value.speculate,
}
}
}
#[async_trait]
impl Health for ShardedClient {
    /// Runs the gRPC `health` call; `health` needs `&mut self`, hence the clone.
    async fn device_health(&self) -> Result<()> {
        self.clone().health().await.map(|_| ())
    }

    /// Prefills a one-request `"liveness"` batch asking for a single new token with sampling
    /// disabled (`do_sample: false`); any error from the prefill is returned.
    async fn model_health(&self) -> Result<()> {
        // Dummy batch of 1 token and 1 generated token
        let sampling = NextTokenChooserParameters {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            typical_p: 1.0,
            do_sample: false,
            seed: 0,
            repetition_penalty: 1.0,
            frequency_penalty: 0.0,
            watermark: false,
            grammar: String::new(),
            grammar_type: GrammarType::None as i32,
        };
        let stopping = StoppingCriteriaParameters {
            max_new_tokens: 1,
            stop_sequences: vec![],
            ignore_eos_token: false,
        };
        let liveness_request = Request {
            id: u64::MAX,
            inputs: "liveness".to_string(),
            truncate: 10,
            prefill_logprobs: false,
            parameters: Some(sampling),
            stopping_parameters: Some(stopping),
            top_n_tokens: 0,
        };
        let batch = Batch {
            id: u64::MAX,
            requests: vec![liveness_request],
            size: 1,
            max_tokens: 2,
        };
        self.clone().prefill(batch).await.map(|_| ())
    }
}
================================================
FILE: backends/client/src/v3/client.rs
================================================
use crate::v3::{pb, Chunk};
use crate::{ClientError, Result, WARMUP_IMAGE_BASE64};
/// Single shard Client
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
stub: TextGenerationServiceClient,
}
impl Client {
/// Returns a client connected to the given url
pub async fn connect(uri: Uri) -> Result {
let channel = Channel::builder(uri).connect().await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result {
let channel = Channel::from_shared("http://[::]:50051".to_string())
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
.await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result> {
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await.map_err(|_| {
ClientError::Connection("Server does not support v3 interface".to_string())
})?;
let urls = response
.into_inner()
.urls
.into_iter()
// Remove unix socket prefix
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
})
.collect();
Ok(urls)
}
/// Get model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result {
let request = tonic::Request::new(InfoRequest {}).inject_context();
let response = self.stub.info(request).await?.into_inner();
Ok(response)
}
/// Get model health
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result {
let request = tonic::Request::new(HealthRequest {}).inject_context();
let response = self.stub.health(request).await?.into_inner();
Ok(response)
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
self.stub.clear_cache(request).await?;
Ok(())
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec,
) -> Result> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
request_ids,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_tokens: Option,
max_prefill_tokens: u32,
max_total_tokens: Option,
max_batch_size: Option,
) -> Result<(Option, u32, u32)> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let mut truncate = max_prefill_tokens - n_tokens;
if let Some(max_input_tokens) = max_input_tokens {
truncate = min(max_input_tokens, truncate);
}
let mut input_chunks = Vec::new();
input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
if n_tokens == 0 {
input_chunks.push(
Chunk::Image(Image {
// Safe unwrap, because we control the data.
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
mimetype: "image/jpeg;base64".to_string(),
})
.into(),
);
}
// Send stringly-typed inputs for compatibility for backends that haven't
// been updated to support chunks.
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str(&format!(
"",
));
}
let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
max_total_tokens - truncate
} else {
1
};
requests.push(Request {
id: 0,
inputs,
input_chunks: Some(Input {
chunks: input_chunks,
}),
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Most request will have that
add_special_tokens: true,
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
cache_len: 0,
chunk_len: None,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens,
stop_sequences: vec![],
ignore_eos_token: true,
}),
prefill_logprobs: true,
top_n_tokens: 20,
adapter_id: None,
});
n_tokens += truncate;
// Check max_batch_size
if Some(requests.len()) == max_batch_size {
break;
}
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: max_input_tokens.unwrap_or(0),
max_blocks: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_tokens,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok((
response.max_supported_total_tokens,
response.max_input_tokens,
response.max_total_tokens,
))
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option,
) -> Result<(Vec, Option, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::()))]
pub async fn decode(
&mut self,
batches: Vec,
) -> Result<(Vec, Option, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,
response.batch,
DecodeTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}
}
/// Timings reported by the server for a prefill call.
///
/// Each field is the matching nanosecond counter of the prefill response converted to a
/// [`Duration`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrefillTimings {
    /// Converted from the response's `forward_ns`.
    pub forward: Duration,
    /// Converted from the response's `decode_ns`.
    pub decode: Duration,
    /// Converted from the response's `total_ns`.
    pub total: Duration,
}

impl PrefillTimings {
    /// Converts the raw nanosecond counters of a prefill response into durations.
    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}
/// Timings reported by the server for a decode call.
///
/// Each field is the matching nanosecond counter of the decode response converted to a
/// [`Duration`]; `concat` is only present when the response carries `concat_ns`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DecodeTimings {
    /// Converted from the response's optional `concat_ns`.
    pub concat: Option<Duration>,
    /// Converted from the response's `forward_ns`.
    pub forward: Duration,
    /// Converted from the response's `decode_ns`.
    pub decode: Duration,
    /// Converted from the response's `total_ns`.
    pub total: Duration,
}

impl DecodeTimings {
    /// Converts the raw nanosecond counters of a decode response into durations.
    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        Self {
            concat: concat_ns.map(Duration::from_nanos),
            forward: Duration::from_nanos(forward_ns),
            decode: Duration::from_nanos(decode_ns),
            total: Duration::from_nanos(total_ns),
        }
    }
}
================================================
FILE: backends/client/src/v3/mod.rs
================================================
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod client;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
================================================
FILE: backends/client/src/v3/sharded_client.rs
================================================
/// Multi shard Client
use crate::{v3, Health, ShardInfo};
use crate::{ClientError, Result};
use crate::v3::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings};
use v3::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
clients: Vec,
}
impl ShardedClient {
fn new(clients: Vec) -> Self {
Self { clients }
}
/// Create a new ShardedClient from a master client. The master client will communicate with
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result {
// Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await?;
let futures = uris.into_iter().map(Client::connect_uds);
let clients: Result> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?))
}
/// Returns a client connected to the given uri
pub async fn connect(uri: Uri) -> Result {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result {
let master_client = Client::connect_uds(path).await?;
Self::from_master_client(master_client).await
}
/// Get the model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.info())
.collect();
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
}
/// GRPC health check
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.health())
.collect();
join_all(futures).await.pop().unwrap()
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.clear_cache(batch_id))
.collect();
join_all(futures).await.into_iter().collect()
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec,
) -> Result> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: Option,
max_prefill_tokens: u32,
max_total_tokens: Option,
max_batch_size: Option,
) -> Result<(Option, u32, u32)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_size,
))
})
.collect();
// Take the minimum value
let results = join_all(futures)
.await
.into_iter()
.collect::, u32, u32)>>>()?;
// Take the minimum value
// Different shards hold different parts of vocab, might yield
// different available block size.
let min = results
.iter()
.min()
.expect("Expect at least 1 warmup result");
Ok(*min)
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option,
) -> Result<(Vec, Option, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result, Option, PrefillTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
pub async fn decode(
&mut self,
batches: Vec,
) -> Result<(Vec, Option, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result, Option, DecodeTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
}
impl From for ShardInfo {
fn from(value: InfoResponse) -> Self {
Self {
requires_padding: value.requires_padding,
dtype: value.dtype,
device_type: value.device_type,
window_size: value.window_size,
speculate: value.speculate,
}
}
}
#[async_trait]
impl Health for ShardedClient {
    /// Runs the gRPC `health` call; `health` needs `&mut self`, hence the clone.
    async fn device_health(&self) -> Result<()> {
        self.clone().health().await.map(|_| ())
    }

    /// Prefills a one-request `"liveness"` batch asking for a single new token with sampling
    /// disabled (`do_sample: false`); any error from the prefill is returned.
    async fn model_health(&self) -> Result<()> {
        // Dummy batch of 1 token and 1 generated token
        let prompt = "liveness";
        let sampling = NextTokenChooserParameters {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            typical_p: 1.0,
            do_sample: false,
            seed: 0,
            repetition_penalty: 1.0,
            frequency_penalty: 0.0,
            watermark: false,
            grammar: String::new(),
            grammar_type: GrammarType::None as i32,
        };
        let stopping = StoppingCriteriaParameters {
            max_new_tokens: 1,
            stop_sequences: vec![],
            ignore_eos_token: false,
        };
        let liveness_request = Request {
            id: u64::MAX,
            inputs: prompt.to_string(),
            input_chunks: Some(Input {
                chunks: vec![Chunk::Text(prompt.into()).into()],
            }),
            truncate: 10,
            add_special_tokens: true,
            prefill_logprobs: false,
            parameters: Some(sampling),
            stopping_parameters: Some(stopping),
            top_n_tokens: 0,
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
            cache_len: 0,
            chunk_len: None,
            adapter_id: None,
        };
        let batch = Batch {
            id: u64::MAX,
            requests: vec![liveness_request],
            size: 1,
            max_tokens: 2,
            max_blocks: 1,
        };
        self.clone().prefill(batch, None).await.map(|_| ())
    }
}
================================================
FILE: backends/gaudi/Makefile
================================================
# Build, development and test helpers for the TGI Intel Gaudi backend.
mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
# Repository root: two directories above this Makefile, normalized by abspath.
root_dir := $(abspath $(mkfile_dir)/../..)

HABANA_VERSION := 1.21.0
PYTORCH_VERSION := 2.6.0

.PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install
.PHONY: run-integration-tests run-integration-tests-with-all-models capture-expected-outputs-for-integration-tests

# Build the tgi-gaudi image, using the repository root as the build context.
image:
	docker build --ulimit nofile=4096 -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)

# Start an interactive Habana PyTorch container with the repository mounted at
# /text-generation-inference and the caller's Hugging Face cache mounted at /data.
# The repository path comes from root_dir (not the shell's working directory),
# so the mount is identical whether make is run from the root with -C or from here.
run-local-dev-container:
	docker run -it \
	    --runtime=habana \
	    --ipc=host \
	    --cap-add=sys_nice \
	    --net=host \
	    -e HABANA_VISIBLE_DEVICES=all \
	    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
	    -e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
	    -e HF_TOKEN=`cat $(HOME)/.cache/huggingface/token` \
	    -e LOG_LEVEL=debug \
	    -e PORT=8080 \
	    -v $(HOME)/.cache/huggingface:/data \
	    -v $(root_dir):/text-generation-inference \
	    -w /text-generation-inference \
	    vault.habana.ai/gaudi-docker/$(HABANA_VERSION)/ubuntu22.04/habanalabs/pytorch-installer-$(PYTORCH_VERSION):latest

install-dependencies:
	pip install git+https://github.com/HabanaAI/DeepSpeed.git@$(HABANA_VERSION)
	pip install outlines~=0.0.34
	curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y

install-server:
	$(MAKE) -C ${root_dir}/backends/gaudi/server install PROTO_PATH=../../../proto/v3

install-router:
	$(MAKE) -C ${root_dir} install-router

install-launcher:
	$(MAKE) -C ${root_dir} install-launcher

# use source to load the rust in path: the installs run in a fresh bash that
# sources ~/.cargo/env (written by the rustup installer above), so cargo is on
# PATH for the router/launcher builds. `-f $(mkfile_path)` keeps the recursive
# calls working regardless of the directory make was started from.
local-dev-install: install-dependencies
	bash -c 'source "$$HOME/.cargo/env" && \
	    $(MAKE) -f $(mkfile_path) install-server && \
	    $(MAKE) -f $(mkfile_path) install-router && \
	    $(MAKE) -f $(mkfile_path) install-launcher'

# In order to run the integration tests, you need to first build the image (make -C backends/gaudi image)
run-integration-tests:
	DOCKER_VOLUME=${root_dir}/data \
	HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
	pytest --durations=0 -s -vv ${root_dir}/integration-tests --gaudi

run-integration-tests-with-all-models:
	DOCKER_VOLUME=${root_dir}/data \
	HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
	pytest --durations=0 -s -vv ${root_dir}/integration-tests --gaudi --gaudi-all-models

# This is used to capture the expected outputs for the integration tests offering an easy way to add more models to the integration tests
capture-expected-outputs-for-integration-tests:
	pip install -U pip uv
	DOCKER_VOLUME=${root_dir}/data \
	HF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \
	uv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py
================================================
FILE: backends/gaudi/README.md
================================================
# Text-generation-inference - Gaudi backend
## Description
This is the TGI backend for Intel Gaudi. It consists of the TGI server optimized for Gaudi hardware.
## Build your own image
The simplest way to build TGI with the Gaudi backend is to use the provided `Makefile`:
Option 1: From the project root directory:
```bash
make -C backends/gaudi image
```
Option 2: From the Gaudi backend directory:
```bash
cd backends/gaudi
make image
```
You can now run the server with the following command:
Option 1: Sharded:
```bash
model=meta-llama/Llama-3.1-8B-Instruct
hf_token=$(cat ${HOME}/.cache/huggingface/token)
volume=${HOME}/.cache/huggingface
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-p 8080:80 -v $volume:/data \
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
tgi-gaudi --model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 8 --max-batch-prefill-tokens 2048
```
Option 2: Non-sharded:
```bash
model=meta-llama/Llama-3.1-8B-Instruct
hf_token=$(cat ${HOME}/.cache/huggingface/token)
volume=${HOME}/.cache/huggingface
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-p 8080:80 -v $volume:/data \
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
tgi-gaudi --model-id $model \
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
```
## Contributing
### Local Development
This is useful if you want to run the server locally for better debugging.
```bash
make -C backends/gaudi run-local-dev-container
```
Then run the following command inside the container to install TGI for Gaudi:
```bash
make -C backends/gaudi local-dev-install
```
Add Rust to your `PATH`:
```bash
. "$HOME/.cargo/env"
```
Option 1: Run the server (sharded model):
```bash
LOG_LEVEL=debug text-generation-launcher \
--model-id meta-llama/Llama-3.1-8B-Instruct \
--sharded true \
--num-shard 8 \
--max-input-tokens 512 \
--max-total-tokens 1024 \
--max-batch-size 8 \
--max-batch-prefill-tokens 2048
```
Option 2: Run the server (non-sharded model):
```bash
LOG_LEVEL=debug text-generation-launcher \
--model-id meta-llama/Llama-3.1-8B-Instruct \
--max-input-tokens 512 \
--max-total-tokens 1024 \
--max-batch-size 4 \
--max-batch-prefill-tokens 2048
```
You can then test the server with the following curl command from another terminal (can be outside the container):
```bash
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
### Integration tests
Install the dependencies:
```bash
pip install -r integration-tests/requirements.txt
```
To run the integration tests, you need to first build the image:
```bash
make -C backends/gaudi image
```
Then run the following command to run the integration tests (CI tests):
```bash
make -C backends/gaudi run-integration-tests
```
To run the integration tests with all models, you can run the following command:
```bash
make -C backends/gaudi run-integration-tests-with-all-models
```
To capture the expected outputs for the integration tests, you can run the following command:
```bash
make -C backends/gaudi capture-expected-outputs-for-integration-tests
```
#### How the integration tests work
The integration tests work as follows:
1. Start a tgi server in a container, similar to the command:
```bash
docker run --runtime=habana --ipc=host --cap-add=sys_nice \
-p 8080:80 -v $volume:/data \
-e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \
tgi-gaudi --model-id $model \
--max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048
```
2. Do a /generate request to the server, similar to the command:
```bash
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
```
3. Check the output of the server against the expected output:
```python
assert curl_output == expected_output
```
This is then repeated for a set of models and configurations.
================================================
FILE: backends/gaudi/examples/docker_commands/docker_commands.md
================================================
# Examples of Docker Commands for Gaudi Backend
This page gives a list of examples of docker run commands for some of the most popular models.
> **Note:** The parameters are chosen to maximize performance on Gaudi2 hardware; please adjust them for your hardware. For example, if you are using Gaudi3, you may want to increase the batch size.
## Default Precision (BF16)
### Llama3.1-8B on 1 card (BF16)
```bash
model=meta-llama/Meta-Llama-3.1-8B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
--model-id $model \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
### Llama3.1-70B on 8 cards (BF16)
```bash
model=meta-llama/Meta-Llama-3.1-70B-Instruct
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
--model-id $model \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
### Llava-v1.6-Mistral-7B on 1 card (BF16)
```bash
model=llava-hf/llava-v1.6-mistral-7b-hf
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
--model-id $model \
--max-input-tokens 4096 --max-batch-prefill-tokens 16384 \
--max-total-tokens 8192 --max-batch-size 4
```
## FP8 Precision
You can also set the KV cache dtype to FP8 when launching the server (`--kv-cache-dtype fp8_e4m3fn`); `fp8_e4m3fn` is supported on Gaudi.
### Llama3-8B on 1 card (FP8)
```bash
model=RedHatAI/Meta-Llama-3-8B-Instruct-FP8-KV
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
--model-id $model \
--kv-cache-dtype fp8_e4m3fn \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 2048 --max-batch-size 32 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64
```
### Llama3-70B on 8 cards (FP8)
```bash
model=RedHatAI/Meta-Llama-3-70B-Instruct-FP8
hf_token=YOUR_ACCESS_TOKEN
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run -p 8080:80 \
--runtime=habana \
--cap-add=sys_nice \
--ipc=host \
-v $volume:/data \
-e HF_TOKEN=$hf_token \
ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \
--model-id $model \
--kv-cache-dtype fp8_e4m3fn \
--sharded true --num-shard 8 \
--max-input-tokens 1024 --max-total-tokens 2048 \
--max-batch-prefill-tokens 4096 --max-batch-size 256 \
--max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512
```
================================================
FILE: backends/gaudi/server/.gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
text_generation_server/__pycache__/
text_generation_server/pb/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
transformers
safetensors
flash-attention/
flash-attention-v2/
vllm/
llm-awq/
eetq/
mamba/
================================================
FILE: backends/gaudi/server/Makefile
================================================
include Makefile-flash-att
include Makefile-flash-att-v2
include Makefile-vllm
include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan

# Directory containing generate.proto; override on the command line
# (e.g. `make install PROTO_PATH=../../../proto/v3`).
PROTO_PATH ?= ../proto/v3

.PHONY: unit-tests gen-server install run-dev install-poetry update-lock export-requirements

unit-tests:
	pytest -s -vv -m "not private" tests

# Compile protos: generate the Python gRPC/mypy stubs into
# text_generation_server/pb, then rewrite their top-level `import ..._pb2`
# lines into package-relative `from . import ..._pb2` imports.
gen-server:
	pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
	mkdir -p text_generation_server/pb
	python -m grpc_tools.protoc -I$(PROTO_PATH) --python_out=text_generation_server/pb \
		--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb $(PROTO_PATH)/generate.proto
	find text_generation_server/pb/ -type f -name "*.py" -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
	touch text_generation_server/pb/__init__.py

install: gen-server
	pip install pip --upgrade
	pip install --no-deps -r requirements.txt
	pip install -e "."

run-dev:
	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

install-poetry:
	curl -sSL https://install.python-poetry.org | python3 -

# Regenerate poetry.lock from scratch. pyproject.toml uses
# [tool.poetry.requires-plugins], which needs Poetry 2; Poetry 2 removed the
# `--no-update` flag (not updating locked versions is now the default).
update-lock:
	rm -f poetry.lock
	poetry lock

export-requirements:
	poetry export -o requirements.txt --without-hashes
================================================
FILE: backends/gaudi/server/Makefile-awq
================================================
# Fork that adds only the correct stream to this kernel in order
# to make cuda graphs work.
awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4

.PHONY: awq build-awq install-awq

# Always start from a fresh clone of the fork.
awq:
	rm -rf llm-awq
	git clone https://github.com/huggingface/llm-awq

build-awq: awq
	cd llm-awq/ && git fetch && git checkout $(awq_commit)
	cd llm-awq/awq/kernels && python setup.py build

# Remove any previously installed kernel package before installing the new build.
install-awq: build-awq
	pip uninstall awq_inference_engine -y || true
	cd llm-awq/awq/kernels && python setup.py install
================================================
FILE: backends/gaudi/server/Makefile-eetq
================================================
eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0

# Clone eetq into ./eetq. The target name equals the clone directory, so once
# that directory exists make considers this target up to date and skips the
# clone; do not mark it .PHONY.
eetq:
	pip install packaging
	git clone https://github.com/NetEase-FuXi/EETQ.git eetq

.PHONY: build-eetq install-eetq

build-eetq: eetq
	cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive
	cd eetq && python setup.py build

install-eetq: build-eetq
	cd eetq && python setup.py install
================================================
FILE: backends/gaudi/server/Makefile-fbgemm
================================================
fbgemm_commit := v0.8.0
# CUDA architecture settings (sm_80 and sm_90a) shared by the build and install steps.
fbgemm_cuda_env := CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a"

.PHONY: build-fbgemm install-fbgemm

# Clone FBGEMM once, then build the genai variant of fbgemm_gpu at the pinned tag.
build-fbgemm:
	@if [ ! -d "fbgemm" ]; then \
		git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
	fi
	cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
		git submodule update --init --recursive && \
		cd fbgemm_gpu && \
		pip install -r requirements.txt && \
		$(fbgemm_cuda_env) python setup.py --package_variant genai build

install-fbgemm: build-fbgemm
	cd fbgemm/fbgemm_gpu && \
		$(fbgemm_cuda_env) python setup.py --package_variant genai install
================================================
FILE: backends/gaudi/server/Makefile-flash-att
================================================
flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec

.PHONY: build-flash-attention install-flash-attention

# Clone flash-attention once, then build the main package and its
# csrc/layer_norm and csrc/rotary extensions at the pinned commit.
build-flash-attention:
	if [ ! -d 'flash-attention' ]; then \
		pip install -U packaging ninja --no-cache-dir && \
		git clone https://github.com/HazyResearch/flash-attention.git; \
	fi
	cd flash-attention && git fetch && git checkout $(flash_att_commit) && \
		MAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build

install-flash-attention: build-flash-attention
	cd flash-attention && git checkout $(flash_att_commit) && MAX_JOBS=8 python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
================================================
FILE: backends/gaudi/server/Makefile-flash-att-v2
================================================
flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4

.PHONY: build-flash-attention-v2-cuda install-flash-attention-v2-cuda build-flash-attention-v2-rocm install-flash-attention-v2-rocm

# CUDA: install the flash-attn package pinned to flash_att_v2_commit_cuda.
build-flash-attention-v2-cuda:
	pip install -U packaging wheel
	pip install flash-attn==$(flash_att_v2_commit_cuda)

install-flash-attention-v2-cuda: build-flash-attention-v2-cuda
	echo "Flash v2 installed"

# ROCm: clone the fork once, then always check out and build the pinned commit,
# so an existing checkout is not silently left at another revision or unbuilt.
build-flash-attention-v2-rocm:
	if [ ! -d 'flash-attention-v2' ]; then \
		pip install -U packaging ninja --no-cache-dir && \
		git clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2; \
	fi
	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \
		git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

install-flash-attention-v2-rocm: build-flash-attention-v2-rocm
	cd flash-attention-v2 && \
		GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install
================================================
FILE: backends/gaudi/server/Makefile-selective-scan
================================================
selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137

.PHONY: build-causal-conv1d install-causal-conv1d build-selective-scan install-selective-scan build-all

# Fresh clone of causal-conv1d. The target name equals the clone directory, so
# it is skipped while that directory exists.
causal-conv1d:
	rm -rf causal-conv1d
	git clone https://github.com/Dao-AILab/causal-conv1d.git

build-causal-conv1d: causal-conv1d
	cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag
	cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build

install-causal-conv1d: build-causal-conv1d
	pip uninstall causal-conv1d -y || true
	cd causal-conv1d/ && pip install .

# selective-scan depends on causal-conv1d
selective-scan:
	rm -rf mamba
	git clone https://github.com/state-spaces/mamba.git mamba

build-selective-scan: selective-scan
	cd mamba/ && git fetch && git checkout $(selective_scan_commit)
	cd mamba && python setup.py build

install-selective-scan: install-causal-conv1d build-selective-scan
	pip uninstall selective-scan-cuda -y || true
	cd mamba && pip install .

build-all: build-causal-conv1d build-selective-scan
================================================
FILE: backends/gaudi/server/Makefile-vllm
================================================
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247

.PHONY: build-vllm-cuda install-vllm-cuda build-vllm-rocm install-vllm-rocm

# NOTE(review): the CUDA and ROCm targets clone different forks into the same
# ./vllm directory; whichever clones first determines the remote used by both.

# CUDA: clone the fork once, then build at the pinned commit.
build-vllm-cuda:
	if [ ! -d 'vllm' ]; then \
		pip install -U ninja packaging --no-cache-dir && \
		git clone https://github.com/Narsil/vllm.git vllm; \
	fi
	cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build

install-vllm-cuda: build-vllm-cuda
	cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e .

# ROCm: clone the fork once, then build at the pinned commit for gfx90a/gfx942.
build-vllm-rocm:
	if [ ! -d 'vllm' ]; then \
		pip install -U ninja packaging --no-cache-dir && \
		git clone https://github.com/mht-sharma/vllm.git vllm; \
	fi
	cd vllm && git fetch && git checkout $(commit_rocm) && \
		PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

install-vllm-rocm: build-vllm-rocm
	cd vllm && git fetch && git checkout $(commit_rocm) && \
		PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .
================================================
FILE: backends/gaudi/server/README.md
================================================
# Text Generation Inference Python gRPC Server
A Python gRPC server for Text Generation Inference
## Install
```shell
make install
```
## Run
```shell
make run-dev
```
================================================
FILE: backends/gaudi/server/dill-0.3.7-patch.sh
================================================
#!/bin/bash
# Build and install dill 0.3.7 with a patch that replaces dill's module-level
# `import __main__` with a `_LazyMainModule` wrapper that imports `__main__`
# on first access to `.module`.
# Stop on the first failing command (e.g. a failed clone or pushd).
set -euo pipefail
git clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git
pushd dill
# Write everything up to the EOF line into dill-0.3.7.patch. The delimiter is
# unquoted, so the shell unescapes `\\` to `\` in the body (restoring the Python
# line continuations). git apply is whitespace-sensitive: each patch line must
# keep its leading ' ', '+' or '-' and its original indentation.
cat <<EOF > dill-0.3.7.patch
diff --git a/dill/_dill.py b/dill/_dill.py
index d0cf543..f6eb662 100644
--- a/dill/_dill.py
+++ b/dill/_dill.py
@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
XRangeType = range
from types import MappingProxyType as DictProxyType, new_class
from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
-import __main__ as _main_module
+class _LazyMainModule(object):
+ _module = None
+ @property
+ def module(self):
+ if self._module is None:
+ import __main__ as _m_module
+ self._module = _m_module
+ return self._module
+_main_module = _LazyMainModule()
import marshal
import gc
# import zlib
@@ -353,7 +361,7 @@ class Pickler(StockPickler):
_fmode = kwds.pop('fmode', None)
_recurse = kwds.pop('recurse', None)
StockPickler.__init__(self, file, *args, **kwds)
- self._main = _main_module
+ self._main = _main_module.module
self._diff_cache = {}
self._byref = settings['byref'] if _byref is None else _byref
self._strictio = False #_strictio
@@ -435,12 +443,12 @@ class Unpickler(StockUnpickler):
settings = Pickler.settings
_ignore = kwds.pop('ignore', None)
StockUnpickler.__init__(self, *args, **kwds)
- self._main = _main_module
+ self._main = _main_module.module
self._ignore = settings['ignore'] if _ignore is None else _ignore
def load(self): #NOTE: if settings change, need to update attributes
obj = StockUnpickler.load(self)
- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
+ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
if not self._ignore:
# point obj class to main
try: obj.__class__ = getattr(self._main, type(obj).__name__)
@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj):
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
logger.trace(pickler, "# D1")
- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
+ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
logger.trace(pickler, "# D3")
- elif '__name__' in obj and obj != _main_module.__dict__ \\
+ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
and type(obj['__name__']) is str \\
and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
diff --git a/dill/session.py b/dill/session.py
index 74234ab..1be8d89 100644
--- a/dill/session.py
+++ b/dill/session.py
@@ -233,7 +233,7 @@ def dump_module(
protocol = settings['protocol']
main = module
if main is None:
- main = _main_module
+ main = _main_module.module
elif isinstance(main, str):
main = _import_module(main)
if not isinstance(main, ModuleType):
@@ -501,7 +501,7 @@ def load_module(
pass
assert loaded is main
_restore_modules(unpickler, main)
- if main is _main_module or main is module:
+ if main is _main_module.module or main is module:
return None
else:
return main
EOF
# Refuse to install an unpatched dill: clean up and fail if the patch does not apply.
if ! git apply dill-0.3.7.patch; then
    echo "error: could not apply dill-0.3.7.patch" >&2
    popd
    rm -fr dill
    exit 1
fi
# Install the patched checkout into the active Python environment, then remove it.
python -m pip install .
popd
rm -fr dill
================================================
FILE: backends/gaudi/server/dill-0.3.8-patch.sh
================================================
#!/bin/bash
# Build and install dill 0.3.8 with a patch that replaces dill's module-level
# `import __main__` with a `_LazyMainModule` wrapper that imports `__main__`
# on first access to `.module`.
# Stop on the first failing command (e.g. a failed clone or pushd).
set -euo pipefail
git clone -b 0.3.8 https://github.com/uqfoundation/dill.git
pushd dill
# Write everything up to the EOF line into dill-0.3.8.patch. The delimiter is
# unquoted, so the shell unescapes `\\` to `\` in the body (restoring the Python
# line continuations). git apply is whitespace-sensitive: each patch line must
# keep its leading ' ', '+' or '-' and its original indentation.
cat <<EOF > dill-0.3.8.patch
diff --git a/dill/_dill.py b/dill/_dill.py
index d42432f..1d251e6 100644
--- a/dill/_dill.py
+++ b/dill/_dill.py
@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
XRangeType = range
from types import MappingProxyType as DictProxyType, new_class
from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
-import __main__ as _main_module
+class _LazyMainModule(object):
+ _module = None
+ @property
+ def module(self):
+ if self._module is None:
+ import __main__ as _m_module
+ self._module = _m_module
+ return self._module
+_main_module = _LazyMainModule()
import marshal
import gc
# import zlib
@@ -355,7 +363,7 @@ class Pickler(StockPickler):
_fmode = kwds.pop('fmode', None)
_recurse = kwds.pop('recurse', None)
StockPickler.__init__(self, file, *args, **kwds)
- self._main = _main_module
+ self._main = _main_module.module
self._diff_cache = {}
self._byref = settings['byref'] if _byref is None else _byref
self._strictio = False #_strictio
@@ -437,12 +445,12 @@ class Unpickler(StockUnpickler):
settings = Pickler.settings
_ignore = kwds.pop('ignore', None)
StockUnpickler.__init__(self, *args, **kwds)
- self._main = _main_module
+ self._main = _main_module.module
self._ignore = settings['ignore'] if _ignore is None else _ignore
def load(self): #NOTE: if settings change, need to update attributes
obj = StockUnpickler.load(self)
- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
+ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
if not self._ignore:
# point obj class to main
try: obj.__class__ = getattr(self._main, type(obj).__name__)
@@ -1199,11 +1207,11 @@ def save_module_dict(pickler, obj):
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
logger.trace(pickler, "# D1")
- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
+ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
logger.trace(pickler, "# D3")
- elif '__name__' in obj and obj != _main_module.__dict__ \\
+ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
and type(obj['__name__']) is str \\
and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
diff --git a/dill/session.py b/dill/session.py
index e91068a..a921b43 100644
--- a/dill/session.py
+++ b/dill/session.py
@@ -233,7 +233,7 @@ def dump_module(
protocol = settings['protocol']
main = module
if main is None:
- main = _main_module
+ main = _main_module.module
elif isinstance(main, str):
main = _import_module(main)
if not isinstance(main, ModuleType):
@@ -501,7 +501,7 @@ def load_module(
pass
assert loaded is main
_restore_modules(unpickler, main)
- if main is _main_module or main is module:
+ if main is _main_module.module or main is module:
return None
else:
return main
EOF
# Refuse to install an unpatched dill: clean up and fail if the patch does not apply.
if ! git apply dill-0.3.8.patch; then
    echo "error: could not apply dill-0.3.8.patch" >&2
    popd
    rm -fr dill
    exit 1
fi
# Install the patched checkout into the active Python environment, then remove it.
python -m pip install .
popd
rm -fr dill
================================================
FILE: backends/gaudi/server/pyproject.toml
================================================
[tool.poetry]
name = "text-generation-server"
version = "2.0.4"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene "]
[tool.poetry.scripts]
text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies]
python = ">=3.9,<3.13"
protobuf = "^5.0"
grpcio = "^1.71.1"
grpcio-status = "*"
grpcio-reflection = "*"
grpc-interceptor = "^0.15.0"
typer = "^0.15.0"
loguru = "^0.7.3"
opentelemetry-api = "^1.32.0"
opentelemetry-exporter-otlp = "^1.32.0"
opentelemetry-instrumentation-grpc = "^0.53b0"
hf-transfer = "^0.1.9"
sentencepiece = "^0.2.0"
peft = "^0.15"
transformers = "^4.52.4"
numpy = "^1.26"
accelerate = "^1.7.0"
outlines= { version = "^0.0.36", optional = true }
prometheus-client = "^0.21.1"
py-cpuinfo = "^9.0.0"
[tool.poetry.group.dev.dependencies]
grpcio-tools = "*"
pytest = "^8.3.5"
[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.poetry.requires-plugins]
poetry-plugin-export = ">=1.8"
================================================
FILE: backends/gaudi/server/requirements.txt
================================================
accelerate==1.7.0 ; python_version >= "3.9" and python_version < "3.13"
annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.8 ; python_version >= "3.9" and python_version < "3.13"
cloudpickle==3.1.1 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Windows" or python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
deprecated==1.2.18 ; python_version >= "3.9" and python_version < "3.13"
diffusers==0.31.0 ; python_version >= "3.9" and python_version < "3.13"
diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.18.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2025.3.2 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.70.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.72.0rc1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.9 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.30.2 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==8.6.1 ; python_version >= "3.9" and python_version < "3.13"
interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.6 ; python_version >= "3.9" and python_version < "3.13"
joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.7.3 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13"
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
peft==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==11.2.1 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
protobuf==5.29.4 ; python_version >= "3.9" and python_version < "3.13"
psutil==7.0.0 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pydantic-core==2.33.1 ; python_version >= "3.9" and python_version < "3.13"
pydantic==2.11.3 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.19.1 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.11.6 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==14.0.0 ; python_version >= "3.9" and python_version < "3.13"
rpds-py==0.24.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.5.3 ; python_version >= "3.9" and python_version < "3.13"
scikit-learn==1.6.1 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==78.1.0 ; python_version >= "3.12" and python_version < "3.13"
shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.52.4 ; python_version >= "3.9" and python_version < "3.13"
triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"
typing-inspection==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.4.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.2.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.17.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.21.0 ; python_version >= "3.9" and python_version < "3.13"
================================================
FILE: backends/gaudi/server/text_generation_server/__init__.py
================================================
================================================
FILE: backends/gaudi/server/text_generation_server/adapters/__init__.py
================================================
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/__init__.py
# License: Apache License Version 2.0, January 2004
from text_generation_server.adapters.weights import (
AdapterBatchData,
AdapterBatchMetadata,
)
# Public API of the adapters package: the batch-level adapter data structures
# re-exported from `text_generation_server.adapters.weights` (imported above).
__all__ = [
    "AdapterBatchData",
    "AdapterBatchMetadata",
]
================================================
FILE: backends/gaudi/server/text_generation_server/adapters/config.py
================================================
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/config.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Set, Tuple
import torch
from text_generation_server.adapters.weights import AdapterWeights
@dataclass
class ModuleMap:
    """Adapter tensors that target one module of the base model.

    ``module_weights`` maps a weight role (e.g. ``"lora_A"``) to a pair of
    ``(tensor, name of that tensor in the adapter checkpoint)``.
    """

    # Name of the targeted base-model module.
    module_name: str
    # role -> (tensor, original tensor name)
    module_weights: Dict[str, Tuple[torch.Tensor, str]]
@dataclass
class AdapterConfig(ABC):
    """Base configuration shared by all adapter types."""

    # Identifier/path of the base model the adapter was trained on.
    base_model_name_or_path: str

    @abstractmethod
    def map_weights_for_model(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        """Match adapter tensors against the base model's ``weight_names``.

        Returns the module mapping together with the set of adapter tensor
        names it references. Concrete adapter configs must implement this.
        """
================================================
FILE: backends/gaudi/server/text_generation_server/adapters/lora.py
================================================
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/lora.py
# License: Apache License Version 2.0, January 2004
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.adapters.weights import (
AdapterBatchMetadata,
AdapterWeights,
BatchAdapterWeights,
)
from text_generation_server.utils.sgmv import (
BGMV_MAX_RANK,
MAX_RANK_CUSTOM,
get_tmp_tensors,
orient_for_rank,
pad_rank,
use_cutlass_shrink,
)
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
    """Return the half-open ``(start, stop)`` bounds of ``rank``'s shard.

    ``size`` is split into ``world_size`` equal blocks of ``size // world_size``
    elements, shifted by ``offset``. Any remainder (``size % world_size``) is not
    assigned to any rank.
    """
    block = size // world_size
    start = offset + block * rank
    return start, start + block
def shard_on_dim(
    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
):
    """Return this rank's contiguous slice of ``t`` along ``dim``.

    The bounds come from ``get_start_stop_idxs_for_rank`` using the process
    group's rank and world size, so a trailing remainder of ``t.shape[dim]``
    that does not divide evenly is dropped.

    Any dimension (negative indices included) is supported. The result is a
    view of ``t``, exactly like the basic slices ``t[start:stop]`` /
    ``t[:, start:stop]`` previously used for dims 0 and 1.
    """
    world_size = process_group.size()
    rank = process_group.rank()
    size = t.shape[dim]
    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
    # `narrow` matches basic slicing for dims 0/1 and generalizes to any dim.
    return t.narrow(dim, start, stop - start)
def shard_lora_weights(
    weights_a: List[torch.Tensor],
    weights_b: List[torch.Tensor],
    split_dim: int,
    process_group: ProcessGroup,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Take this rank's slice of every per-layer LoRA A and B matrix.

    A matrices ([hidden_size, r]) are split along ``split_dim``. B matrices
    ([r, hidden_size]) are always split along dim 1.
    """
    sharded_a = []
    for w in weights_a:
        sharded_a.append(shard_on_dim(w, dim=split_dim, process_group=process_group))

    sharded_b = []
    for w in weights_b:
        sharded_b.append(shard_on_dim(w, dim=1, process_group=process_group))

    return sharded_a, sharded_b
@dataclass
class LoraConfig(AdapterConfig):
    """LoRA hyper-parameters, loaded from a PEFT ``LoraConfig``."""

    r: int
    target_modules: Optional[Union[List[str], str]]
    fan_in_fan_out: bool
    lora_alpha: int
    use_rslora: bool

    def map_weights_for_model(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        """Pair each base weight name with its ``lora_A``/``lora_B`` tensors.

        Weights without both tensors present in ``adapter_weights`` are skipped.
        Returns ``(module_map, names_of_used_adapter_tensors)``.
        """
        module_map = {}
        used_names = set()
        for name in weight_names:
            a_key = f"base_model.model.{name}.lora_A.weight"
            b_key = f"base_model.model.{name}.lora_B.weight"
            if a_key in adapter_weights and b_key in adapter_weights:
                module_map[name] = {
                    "lora_A": (adapter_weights[a_key], a_key),
                    "lora_B": (adapter_weights[b_key], b_key),
                }
                used_names.update((a_key, b_key))
        return module_map, used_names

    @classmethod
    def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
        """Read the PEFT config of ``adapter_id`` and convert it to this class."""
        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
        return cls(
            base_model_name_or_path=hf_config.base_model_name_or_path,
            r=hf_config.r,
            target_modules=hf_config.target_modules,
            fan_in_fan_out=hf_config.fan_in_fan_out,
            lora_alpha=hf_config.lora_alpha,
            # Older PEFT configs may not define `use_rslora`.
            use_rslora=getattr(hf_config, "use_rslora", False),
        )
class LoraWeights(AdapterWeights):
    """LoRA weights for a single adapter merged across all layers."""

    def __init__(
        self,
        weights_a: List[torch.Tensor],
        weights_b: List[torch.Tensor],
        adapter_config: LoraConfig,
    ):
        # Ranks are read from the first layer; 1 when the list is empty.
        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
        # Guard on `weights_b` itself so an empty B list cannot raise IndexError.
        self.lora_b_r = weights_b[0].size(0) if len(weights_b) > 0 else 1

        self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
        self._is_transposed = False

        # [num_layers, hidden_size, r]
        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
        self._weights_a = torch.stack(weights_a)

        # [num_layers, r, hidden_size]
        self._weights_b = torch.stack(weights_b)

        self.adapter_config = adapter_config

    @property
    def weights_a(self) -> torch.Tensor:
        # Untransposed A (SGMV layout); restores it if currently transposed.
        if self._is_transposed:
            self._transpose_weights()
        return self._weights_a

    @property
    def weights_b(self) -> torch.Tensor:
        # Untransposed B (SGMV layout); restores it if currently transposed.
        if self._is_transposed:
            self._transpose_weights()
        return self._weights_b

    @property
    def weights_a_t(self) -> torch.Tensor:
        # Transposed A (BGMV layout); switches layout if needed.
        if not self._is_transposed:
            self._transpose_weights()
        return self._weights_a

    @property
    def weights_b_t(self) -> torch.Tensor:
        # Transposed B (BGMV layout); switches layout if needed.
        if not self._is_transposed:
            self._transpose_weights()
        return self._weights_b

    def _transpose_weights(self):
        """Toggle between the SGMV and BGMV weight orientations."""
        if self._use_cutlass_shrink:
            # Only the cutlass shrink path uses a different orientation for SGMV
            # and BGMV; otherwise both kernels share one layout and only the flag
            # is toggled.
            self._weights_a = self._weights_a.transpose(1, 2).contiguous()
            self._weights_b = self._weights_b.transpose(1, 2).contiguous()
        self._is_transposed = not self._is_transposed

    @classmethod
    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
        return [BatchLoraWeights]

    @classmethod
    def prepare_weights(
        cls,
        config: LoraConfig,
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
        nlayers: int,
        dtype: torch.dtype,
        world_size: int,
        process_group: ProcessGroup,
        target_to_layer: Dict[str, Tuple[str, torch.Tensor]],
    ) -> Optional[AdapterWeights]:
        """Prepare pre-loaded LoRA weights of one ``layer_type`` for all layers.

        For each of the ``nlayers`` layers this looks up the target weight via
        ``target_to_layer``, fetches its ``lora_A``/``lora_B`` from ``module_map``,
        moves them to the base layer's device and ``dtype``, transposes them and
        folds the LoRA scaling factor into B. Every consumed tensor name is
        removed from ``unused_weight_names``. Ranks are then padded for SGMV,
        ``config.r`` is updated to the padded rank, and the weights are sharded
        across ``process_group``.

        Returns ``None`` when the adapter has no weights for this layer type.
        """
        lora_a_list = [None] * nlayers
        lora_b_list = [None] * nlayers

        for layer_id in range(nlayers):
            key = (layer_id, layer_type)
            weight_name, layer = target_to_layer[key]
            base_weight = layer.base_layer.linear.weight
            base_device = base_weight.device

            if weight_name not in module_map:
                # There is no LoRA weight for this layer type in the adapter
                return None

            lora_a, lora_a_name = module_map[weight_name]["lora_A"]
            lora_a = lora_a.to(base_device, dtype)

            lora_b, lora_b_name = module_map[weight_name]["lora_B"]
            lora_b = lora_b.to(base_device, dtype)

            # NOTE(review): `config.r` is overwritten with the padded rank below;
            # if the same config object is reused for later layer types, their
            # scale is computed from the padded rank — confirm this is intended.
            scale = get_scaling_factor(
                config.lora_alpha,
                config.r,
                uses_rslora=config.use_rslora,
            )

            unused_weight_names.discard(lora_a_name)
            unused_weight_names.discard(lora_b_name)

            # Merge scaling factor into lora_b due to associativity of matrix multiplication:
            # (A * B) * C = A * (B * C)
            lora_a_list[layer_id] = lora_a.transpose(0, 1)
            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

        # pad lora ranks to be compatible with sgmv
        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]
        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]

        if lora_a_list:
            # update rank if it was padded
            padded_rank = lora_a_list[0].size(1)
            config.r = padded_rank

        return LoraWeights(
            *shard_lora_weights(
                weights_a=lora_a_list,
                weights_b=lora_b_list,
                split_dim=0 if layer_type in {"o_proj", "down_proj", "lm_head"} else 1,
                process_group=process_group,
            ),
            config,
        )
@dataclass
class RankSegments:
    """Kernel inputs for all batch segments that share one LoRA rank."""

    rank: int
    # Per-segment raw pointers to the A / B weights.
    lora_a_ptr: torch.Tensor
    lora_b_ptr: torch.Tensor

    # Used by the prefill path (sgmv); None otherwise.
    tmp_shrink: torch.Tensor
    tmp_expand: torch.Tensor
    segment_starts: torch.Tensor
    segment_ends: torch.Tensor

    # Used by the decode path (bgmv); None otherwise.
    indices: torch.Tensor
@dataclass
class BatchLoraWeights(BatchAdapterWeights):
    """LoRA weights of the adapters used by one batch, grouped by rank.

    Attributes:
        lora_a / lora_b: adapter index -> that adapter's stacked A / B weights.
        adapter_index_configs: adapter index -> its ``LoraConfig``.
        rank_data: LoRA rank -> kernel inputs for the segments with that rank.
        use_sgmv: ``True`` when the SGMV path was selected, ``False`` for BGMV.
    """

    lora_a: Dict[int, torch.Tensor]
    lora_b: Dict[int, torch.Tensor]
    adapter_index_configs: Dict[int, LoraConfig]
    rank_data: Dict[int, RankSegments]
    use_sgmv: bool

    def has_adapter(self, adapter_index: int) -> bool:
        return adapter_index in self.adapter_index_configs

    def can_vectorize(self, pg: ProcessGroup) -> bool:
        # Every per-shard rank must fit within the custom kernel's maximum rank.
        return all(
            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
            for rank_data in self.rank_data.values()
        )

    @classmethod
    def load(
        cls,
        adapter_weights: Dict[int, AdapterWeights],
        meta: AdapterBatchMetadata,
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Optional["BatchLoraWeights"]:
        """Batch the LoRA weights referenced by ``meta``'s segments.

        SGMV is used for prefill or when some rank exceeds ``BGMV_MAX_RANK``;
        otherwise BGMV with transposed weights is used. When
        ``prefill_head_indices`` is given, SGMV segment bounds are remapped onto
        those head positions. Returns ``None`` if no LoRA weights are present.
        """
        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
        adapter_weights = {
            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
        }
        if not adapter_weights:
            return None

        first_weights = next(iter(adapter_weights.values()))
        device = first_weights.weights_a.device
        segment_indices = meta.segment_indices

        lora_a = {
            idx: adapter_weights[idx].weights_a
            for idx in segment_indices
            if idx in adapter_weights
        }
        lora_b = {
            idx: adapter_weights[idx].weights_b
            for idx in segment_indices
            if idx in adapter_weights
        }

        max_rank = max(
            (
                adapter_weights[idx].lora_a_r
                for idx in segment_indices
                if idx in adapter_weights
            ),
            default=0,
        )

        def _data_ptrs(attr: str) -> torch.Tensor:
            # One raw weight pointer per segment (0 for segments without weights).
            return torch.tensor(
                [
                    (
                        getattr(adapter_weights[idx], attr).data_ptr()
                        if idx in adapter_weights
                        else 0
                    )
                    for idx in segment_indices
                ],
                dtype=torch.int64,
                device=device,
            )

        # Reading these properties may switch the stored weight layout (see
        # LoraWeights._transpose_weights), so A is always read before B.
        if prefill or max_rank > BGMV_MAX_RANK:
            use_sgmv = True
            lora_a_ptr = _data_ptrs("weights_a")
            lora_b_ptr = _data_ptrs("weights_b")
        else:
            use_sgmv = False
            lora_a_ptr = _data_ptrs("weights_a_t")
            lora_b_ptr = _data_ptrs("weights_b_t")

        adapter_index_configs = {
            idx: adapter_weights[idx].adapter_config
            for idx in segment_indices
            if idx in adapter_weights
        }

        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}

        # rank -> indices of the segments whose adapter has that rank
        rank_indices = defaultdict(list)
        for segment_idx, adapter_idx in enumerate(segment_indices):
            if adapter_idx not in adapter_weights:
                continue
            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)

        if prefill_head_indices is not None:
            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
            for head_index in prefill_head_indices:
                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
                if head_index < meta.adapter_segments[j]:
                    prefill_head_segment_ends[-1] += 1
                else:
                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
                    j += 1

        # For BGMV, the segment of every batch entry does not depend on the rank:
        # compute it once (and only when some rank will actually use it).
        entry_segments = None
        if not use_sgmv and rank_indices:
            entry_segments = [
                adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
            ]

        rank_data = {}
        for rank, indices in rank_indices.items():
            tmp_shrink = None
            tmp_expand = None
            segment_starts = None
            segment_ends = None
            batch_indices = None

            segment_a_ptr = lora_a_ptr[indices]
            segment_b_ptr = lora_b_ptr[indices]

            if use_sgmv:
                tmp_shrink, tmp_expand = get_tmp_tensors(
                    segment_a_ptr.size(0), rank, device
                )
                segment_starts = meta.adapter_segments[indices]
                segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
                if prefill_head_indices is not None:
                    for i, segment_index in enumerate(indices):
                        segment_starts[i] = prefill_head_segment_starts[segment_index]
                        segment_ends[i] = prefill_head_segment_ends[segment_index]
            else:
                # Distinct name so the `rank_indices` dict being iterated is not
                # shadowed; entries of other ranks are marked with -1.
                segments_with_rank = set(indices)
                batch_indices = torch.tensor(
                    [
                        idx if idx in segments_with_rank else -1
                        for idx in entry_segments
                    ],
                    dtype=torch.int64,
                    device=device,
                )

            rank_data[rank] = RankSegments(
                rank=rank,
                tmp_shrink=tmp_shrink,
                tmp_expand=tmp_expand,
                lora_a_ptr=segment_a_ptr,
                lora_b_ptr=segment_b_ptr,
                segment_starts=segment_starts,
                segment_ends=segment_ends,
                indices=batch_indices,
            )

        return BatchLoraWeights(
            lora_a=lora_a,
            lora_b=lora_b,
            adapter_index_configs=adapter_index_configs,
            rank_data=rank_data,
            use_sgmv=use_sgmv,
        )
def get_scaling_factor(
    lora_alpha: int,
    r: int,
    uses_rslora: bool = False,
) -> float:
    """Return the LoRA output scale ``lora_alpha / r``.

    With rank-stabilized LoRA (``uses_rslora=True``) the divisor is ``sqrt(r)``
    instead of ``r``.
    """
    divisor = r**0.5 if uses_rslora else r
    return lora_alpha / divisor
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
    """Unwrap objects exposing a ``lora_weights`` attribute; return others as-is."""
    return getattr(v, "lora_weights", v)
================================================
FILE: backends/gaudi/server/text_generation_server/adapters/weights.py
================================================
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/weights.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractclassmethod, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Type
import torch
@dataclass
class AdapterBatchMetadata:
    """How the entries of a batch map onto loaded adapters."""

    # Adapter index of each batch entry; shape [batch_size].
    adapter_indices: torch.Tensor

    # Distinct adapter indices used by the batch ([num_adapters]).
    adapter_set: Set[int]

    # Segment boundaries; [num_segments + 1] entries.
    adapter_segments: torch.Tensor

    # [num_segments]: maps a segment index to its adapter index, i.e.
    # segment_indices[s] == adapter_indices[i] for entries i of segment s.
    segment_indices: List[int]
class AdapterWeights(ABC):
    """Base class for the loaded weights of a single adapter."""

    # `abc.abstractclassmethod` is deprecated; stacking `classmethod` over
    # `abstractmethod` is the documented replacement with the same semantics.
    @classmethod
    @abstractmethod
    def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
        """Return the batch classes able to batch these weights."""

    @property
    def speculative_tokens(self) -> int:
        # Defaults to 0; subclasses may override.
        return 0
class BatchAdapterWeights(ABC):
    """Adapter weights batched for one forward pass over a batch."""

    # `has_adapter` is an instance method in implementations, so declare it as
    # one (it was declared with the deprecated `abstractclassmethod`).
    @abstractmethod
    def has_adapter(self, adapter_index: int) -> bool:
        """Return whether adapter ``adapter_index`` has weights in this batch."""

    @classmethod
    @abstractmethod
    def load(
        cls,
        adapter_weights: Dict[int, "AdapterWeights"],
        meta: "AdapterBatchMetadata",
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Optional["BatchAdapterWeights"]:
        """Batch ``adapter_weights`` according to ``meta``.

        ``prefill_head_indices`` may be ``None`` (callers pass it only for the
        ``lm_head`` layer). Returns ``None`` when there is nothing to batch.
        """
class LayerAdapterWeights:
    """Adapter weights that apply to a particular layer."""

    def __init__(self):
        # adapter index -> that adapter's weights for this layer
        self.adapter_weights: Dict[int, "AdapterWeights"] = {}

    def add_adapter(self, adapter_idx: int, weights: "AdapterWeights"):
        """Register (or replace) the weights of adapter ``adapter_idx``."""
        self.adapter_weights[adapter_idx] = weights

    def remove_adapter(self, adapter_idx: int):
        """Forget adapter ``adapter_idx``; unknown indices are ignored."""
        self.adapter_weights.pop(adapter_idx, None)

    def is_empty(self) -> bool:
        return len(self.adapter_weights) == 0

    def get_data(
        self,
        meta: "AdapterBatchMetadata",
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Optional["BatchAdapterWeights"]:
        """Batch this layer's adapter weights for the current batch.

        Returns the object produced by the batch type's ``load``, or ``None``
        when no batch type produced anything. ``None`` (rather than an empty
        dict) is what ``AdapterBatchData.ranks`` and other consumers test for.
        """
        # batch type -> {adapter index -> weights}
        adapter_batch_types = defaultdict(dict)
        for adapter_index, adapter_weights in self.adapter_weights.items():
            for batch_type in adapter_weights.get_batch_types():
                adapter_batch_types[batch_type][adapter_index] = adapter_weights

        batch_data = None
        for batch_type, adapter_weights in adapter_batch_types.items():
            batched_weights = batch_type.load(
                adapter_weights, meta, prefill, prefill_head_indices
            )
            if batched_weights is not None:
                # With several batch types the last non-None result wins; LoRA
                # is the only batch type defined in this file.
                batch_data = batched_weights
        return batch_data
@dataclass
class AdapterBatchData:
    """Per-layer batched adapter data for one forward pass."""

    meta: AdapterBatchMetadata
    # layer type -> adapter type -> batch weight data
    # NOTE(review): `from_meta` stores whatever `LayerAdapterWeights.get_data`
    # returns and `ranks()` reads `.rank_data` straight off each value, so in
    # practice values look like batched weight objects rather than nested
    # dicts — confirm and fix the annotation.
    data: Dict[str, Dict[str, BatchAdapterWeights]]
    prefill: bool

    @staticmethod
    def from_meta(
        meta: AdapterBatchMetadata,
        weights: Dict[str, LayerAdapterWeights],
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> "AdapterBatchData":
        """Collect batched data for every layer that has adapters loaded.

        ``prefill_head_indices`` is only forwarded to the ``lm_head`` layer.
        """
        data = {
            layer_name: layer_weights.get_data(
                meta,
                prefill,
                prefill_head_indices if layer_name == "lm_head" else None,
            )
            for layer_name, layer_weights in weights.items()
            if not layer_weights.is_empty()
        }
        return AdapterBatchData(meta=meta, data=data, prefill=prefill)

    def ranks(self) -> Set[int]:
        """All LoRA ranks used by any layer (entries that are None are skipped)."""
        # TODO(travis): refactor to be less coupled to lora implementation
        return {
            segment.rank
            for lora_data in self.data.values()
            if lora_data is not None
            for segment in lora_data.rank_data.values()
        }

    def layer_names(self) -> Set[str]:
        return set(self.data)

    def adapter_keys(self) -> Set[str]:
        return set().union(*(layer_data.keys() for layer_data in self.data.values()))

    @property
    def max_rank(self) -> int:
        """Largest rank in ``ranks()``, or 0 when there is none."""
        return max(self.ranks(), default=0)
================================================
FILE: backends/gaudi/server/text_generation_server/cache.py
================================================
import torch
from typing import Dict, Optional, TypeVar
from text_generation_server.models.types import Batch
B = TypeVar("B", bound=Batch)


class Cache:
    """In-memory registry of live batches, keyed by ``batch_id``."""

    def __init__(self):
        self.cache: Dict[int, B] = {}

    def pop(self, batch_id: int) -> Optional[B]:
        """Remove and return the batch stored under ``batch_id`` (None if absent)."""
        return self.cache.pop(batch_id, None)

    def set(self, entry: B):
        """Store ``entry`` under its ``batch_id``; ``None`` is ignored."""
        if entry is None:
            return
        self.cache[entry.batch_id] = entry

    def delete(self, batch_id: int):
        """Drop a batch and release cached CUDA memory when CUDA is available."""
        # The popped batch is discarded immediately, dropping this reference
        # before the allocator cache is flushed.
        self.pop(batch_id)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def clear(self):
        """Delete every cached batch."""
        for batch_id in list(self.cache):
            self.delete(batch_id)

    def __len__(self):
        return len(self.cache)
================================================
FILE: backends/gaudi/server/text_generation_server/cli.py
================================================
import os
import sys
import typer
from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download
from text_generation_server.utils.adapter import parse_lora_adapters
# Typer application; the functions decorated with `@app.command()` in this
# module become its sub-commands.
app = typer.Typer()
class Quantization(str, Enum):
    """Quantization schemes accepted by ``--quantize``."""

    gptq = "gptq"
    awq = "awq"
    fp8 = "fp8"
    # The CLI value uses a hyphen; the member name cannot.
    compressed_tensors = "compressed-tensors"
class Dtype(str, Enum):
    """Compute dtypes accepted by ``--dtype``.

    ``bfloat16`` is the canonical member. ``bloat16`` is the original,
    misspelled member name, kept as an alias for backward compatibility. Enum
    aliases are not iterated, so the CLI choices are unchanged.
    """

    float16 = "float16"
    bfloat16 = "bfloat16"
    # Backward-compatible alias for the misspelled member name.
    bloat16 = "bfloat16"
class KVCacheDtype(str, Enum):
    """FP8 formats accepted by ``--kv-cache-dtype``."""

    fp8_e4m3fn = "fp8_e4m3fn"
    fp8_e5m2 = "fp8_e5m2"
@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
    speculate: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    kv_cache_dtype: Optional[KVCacheDtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
    otlp_service_name: str = "text-generation-inference.server",
    max_input_tokens: Optional[int] = None,
):
    # CLI entry point for the gRPC model server. It checks the sharding
    # environment, configures logging and optional OTLP tracing, parses
    # LORA_ADAPTERS, turns the enum options into plain strings and calls
    # `server.serve`. (Comments rather than a docstring: typer would turn a
    # docstring into this command's --help text.)
    if sharded:
        # Sharded mode requires these environment variables (presumably read
        # by the distributed setup elsewhere — not visible here).
        # assert (
        #     os.getenv("RANK", None) is not None
        # ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import server
    from text_generation_server.tracing import setup_tracing

    # Setup OpenTelemetry distributed tracing
    if otlp_endpoint is not None:
        setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)

    lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))

    # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
    # and warn the user
    if lora_adapters:
        logger.warning("LoRA adapters enabled (experimental feature).")

        if "CUDA_GRAPHS" in os.environ:
            logger.warning(
                "LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs."
            )
            # NOTE(review): this only rebinds a module-level name in this CLI
            # module; nothing visible here reads it, so it may not actually
            # disable CUDA graphs — confirm.
            global CUDA_GRAPHS
            CUDA_GRAPHS = None

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    # dtype defaults to bfloat16 when not given.
    dtype = "bfloat16" if dtype is None else dtype.value
    kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
    logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
    # NOTE(review): `dtype` is never None after the line above, and every
    # `Quantization` value is in this set, so this error appears unreachable
    # from the CLI.
    if dtype is not None and quantize not in {
        None,
        "bitsandbytes",
        "bitsandbytes-nf4",
        "bitsandbytes-fp4",
        "gptq",
        "awq",
        "fp8",
        "compressed-tensors",
    }:
        raise RuntimeError(
            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
        )
    server.serve(
        model_id,
        lora_adapters,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        kv_cache_dtype,
        trust_remote_code,
        uds_path,
        max_input_tokens,
    )
@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    auto_convert: bool = True,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    merge_lora: bool = False,
):
    # Make the weights of `model_id` available locally. Steps, in order:
    # 1. Use weights that are already present.
    # 2. For hub models, merge/download PEFT adapters and their parent model.
    # 3. Download weights with `extension`.
    # 4. Otherwise fall back to PyTorch .bin/.pt weights and, if
    #    `auto_convert`, convert them to safetensors.
    # (Comments rather than a docstring: typer would turn a docstring into this
    # command's --help text.)

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )

    # Import here after the logger is added to log potential import exceptions
    from text_generation_server import utils

    # Test if files were already downloaded
    try:
        utils.weight_files(model_id, revision, extension)
        logger.info("Files are already present on the host. " "Skipping download.")
        return
    # Local files not found
    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
        pass

    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
        "WEIGHTS_CACHE_OVERRIDE", None
    ) is not None

    if not is_local_model:
        # TODO: maybe reverse the default value of merge_lora?
        # currently by default we don't merge the weights with the base model
        if merge_lora:
            try:
                hf_hub_download(
                    model_id, revision=revision, filename="adapter_config.json"
                )
                utils.download_and_unload_peft(
                    model_id, revision, trust_remote_code=trust_remote_code
                )
                # Not dead: if weight_files below raises, execution continues
                # and the later config lookup must treat the model as local.
                is_local_model = True
                utils.weight_files(model_id, revision, extension)
                return
            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
                pass
        else:
            try:
                utils.peft.download_peft(
                    model_id, revision, trust_remote_code=trust_remote_code
                )
            except Exception:
                pass

        try:
            import json

            config = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
            with open(config, "r") as f:
                config = json.load(f)

            # Adapters/Medusa heads reference a parent model: fetch it too.
            base_model_id = config.get("base_model_name_or_path", None)
            if base_model_id and base_model_id != model_id:
                try:
                    logger.info(f"Downloading parent model {base_model_id}")
                    download_weights(
                        model_id=base_model_id,
                        revision="main",
                        extension=extension,
                        auto_convert=auto_convert,
                        logger_level=logger_level,
                        json_output=json_output,
                        trust_remote_code=trust_remote_code,
                    )
                except Exception:
                    pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

        # Try to download weights from the hub
        try:
            filenames = utils.weight_hub_files(model_id, revision, extension)
            utils.download_weights(filenames, model_id, revision)
            # Successfully downloaded weights
            return

        # No weights found on the hub with this extension
        except utils.EntryNotFoundError as e:
            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
            if not extension == ".safetensors" or not auto_convert:
                raise e

    elif (Path(model_id) / "adapter_config.json").exists():
        # Try to load as a local PEFT model
        try:
            utils.download_and_unload_peft(
                model_id, revision, trust_remote_code=trust_remote_code
            )
            utils.weight_files(model_id, revision, extension)
            return
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass
    elif (Path(model_id) / "config.json").exists():
        # Try to load as a local Medusa model
        try:
            import json

            config = Path(model_id) / "config.json"
            with open(config, "r") as f:
                config = json.load(f)

            base_model_id = config.get("base_model_name_or_path", None)
            if base_model_id:
                try:
                    logger.info(f"Downloading parent model {base_model_id}")
                    download_weights(
                        model_id=base_model_id,
                        revision="main",
                        extension=extension,
                        auto_convert=auto_convert,
                        logger_level=logger_level,
                        json_output=json_output,
                        trust_remote_code=trust_remote_code,
                    )
                except Exception:
                    pass
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            pass

    # Try to see if there are local pytorch weights
    try:
        # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
        try:
            local_pt_files = utils.weight_files(model_id, revision, ".bin")
        except Exception:
            local_pt_files = utils.weight_files(model_id, revision, ".pt")

    # No local pytorch weights
    except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
        if extension == ".safetensors":
            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                f"Downloading PyTorch weights."
            )

        # Try to see if there are pytorch weights on the hub
        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
        # Download pytorch weights
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

    if auto_convert:
        if not trust_remote_code:
            # NOTE(review): this warns that conversion is disabled, yet the
            # conversion below still runs — confirm the intended behaviour.
            logger.warning(
                "🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
                "Pickle files are unsafe and can essentially contain remote code execution! "
                "Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
            )

        logger.warning(
            f"No safetensors weights found for model {model_id} at revision {revision}. "
            f"Converting PyTorch weights to safetensors."
        )

        # Safetensors final filenames. `str.lstrip` strips a *set of characters*,
        # not a prefix (e.g. "consolidated.00" became "nsolidated.00"), so
        # remove only the literal "pytorch_" prefix.
        local_st_files = [
            p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
            for p in local_pt_files
        ]
        try:
            import transformers
            import json

            if is_local_model:
                config_filename = os.path.join(model_id, "config.json")
            else:
                config_filename = hf_hub_download(
                    model_id, revision=revision, filename="config.json"
                )
            with open(config_filename, "r") as f:
                config = json.load(f)
            architecture = config["architectures"][0]

            class_ = getattr(transformers, architecture)

            # Name for this variable depends on transformers version.
            discard_names = getattr(class_, "_tied_weights_keys", [])

        except Exception:
            discard_names = []
        # Convert pytorch weights to safetensors
        utils.convert_files(local_pt_files, local_st_files, discard_names)
@app.command()
def quantize(
    model_id: str,
    output_dir: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    trust_remote_code: bool = False,
    upload_to_model_id: Optional[str] = None,
    percdamp: float = 0.01,
    act_order: bool = False,
    groupsize: int = 128,
):
    # GPTQ-quantize `model_id` to 4 bits (symmetric) into `output_dir`, after
    # making sure its weights are available locally. (Comments rather than a
    # docstring so the typer --help output stays unchanged.)
    resolved_revision = "main" if revision is None else revision

    download_weights(
        model_id=model_id,
        revision=resolved_revision,
        logger_level=logger_level,
        json_output=json_output,
    )

    # Deferred import: only performed once the weights are in place.
    from text_generation_server.layers.gptq.quantize import quantize as run_gptq

    run_gptq(
        model_id=model_id,
        bits=4,
        groupsize=groupsize,
        output_dir=output_dir,
        revision=resolved_revision,
        trust_remote_code=trust_remote_code,
        upload_to_model_id=upload_to_model_id,
        percdamp=percdamp,
        act_order=act_order,
        sym=True,
    )
# Allow running this module directly as the typer CLI.
if __name__ == "__main__":
    app()
================================================
FILE: backends/gaudi/server/text_generation_server/interceptor.py
================================================
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import torch
import grpc
from google.rpc import status_pb2, code_pb2
from grpc_status import rpc_status
from grpc_interceptor.server import AsyncServerInterceptor
from loguru import logger
from typing import Callable, Any
import traceback
import os
class ExceptionInterceptor(AsyncServerInterceptor):
    """gRPC server interceptor that turns handler exceptions into INTERNAL statuses.

    A ``RuntimeError`` is treated as unrecoverable and terminates the process
    with exit code 1. Any other exception is logged, reported through
    ``dbg_trace`` and returned to the client via ``context.abort_with_status``.
    When the ``DUMP_STACK`` environment variable is set to a non-empty value,
    the formatted traceback is appended to the status message.
    """

    async def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        try:
            response = method(request_or_iterator, context)
            return await response
        except Exception as err:
            trace = " " + traceback.format_exc() if os.environ.get("DUMP_STACK") else ""
            method_name = method_name.split("/")[-1]
            logger.exception(f"Method {method_name} encountered an error.")

            # Runtime Error cannot be recovered from
            if isinstance(err, RuntimeError):
                # Use `sys.exit` rather than the `exit` builtin: the latter is
                # added by the `site` module for interactive use and is absent
                # when Python runs without `site` (e.g. `python -S`).
                import sys

                sys.exit(1)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            from .utils.debug import dbg_trace

            dbg_trace("EXCEPTION", traceback.format_exc())
            await context.abort_with_status(
                rpc_status.to_status(
                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
                )
            )
================================================
FILE: backends/gaudi/server/text_generation_server/layers/__init__.py
================================================
from text_generation_server.layers.tensor_parallel import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
TensorParallelEmbedding,
)
from text_generation_server.layers.linear import (
get_linear,
FastLinear,
)
from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d
from text_generation_server.layers.fp8 import Fp8Linear
from text_generation_server.layers.lora import (
LoraLinear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
# Public API of the layers package. `load_layer_norm` and `load_conv2d` are
# imported above mainly for their side effect of adding the `load` methods,
# and are re-exported here as well.
__all__ = [
    "get_linear",
    "FastLinear",
    "TensorParallelColumnLinear",
    "TensorParallelRowLinear",
    "TensorParallelEmbedding",
    "SpeculativeHead",
    "LoraLinear",
    "Fp8Linear",
    "TensorParallelMultiAdapterLinear",
    "TensorParallelAdapterRowLinear",
    "load_layer_norm",
    "load_conv2d",
]
================================================
FILE: backends/gaudi/server/text_generation_server/layers/attention/__init__.py
================================================
from .common import (
Seqlen,
HPUPagedAttentionMetadata,
trim_attn_metadata,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
)
from .hpu import (
SUPPORTS_WINDOWING,
attention,
paged_attention,
paged_attention_mla,
set_block_mapping,
)
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache, get_kv_scales, KVCompressCache
# Public API of the attention package, re-exported from `.common`, `.hpu` and
# `.kv_cache`. `_async_h2d_tensor_copy` is listed explicitly because
# `from ... import *` would otherwise skip its leading-underscore name.
__all__ = [
    "attention",
    "get_kv_scales",
    "paged_attention",
    "paged_attention_mla",
    "set_block_mapping",
    "SUPPORTS_WINDOWING",
    "KVCache",
    "KVCompressCache",
    "Seqlen",
    "HPUPagedAttentionMetadata",
    "trim_seqlen_metadata",
    "trim_attn_metadata",
    "_async_h2d_tensor_copy",
]
================================================
FILE: backends/gaudi/server/text_generation_server/layers/attention/common.py
================================================
from dataclasses import dataclass
import torch
from typing import Optional, List, Dict
import collections
import torch.nn.functional as F
# Namedtuple classes created by `subtuple`, keyed by (typename, field names) so
# repeated calls with the same layout reuse one class.
_TYPE_CACHE = {}


@dataclass
class HPUPagedAttentionMetadata:
    """Metadata for PagedAttention."""

    block_list: Optional[torch.Tensor]
    block_mapping: Optional[torch.Tensor]
    block_usage: Optional[torch.Tensor]
    block_groups: Optional[torch.Tensor]
    attn_bias: Optional[torch.Tensor]
    # Sliding-window counterparts of the fields above; left as None when no
    # window is in use.
    slots_in_window_mask: Optional[torch.Tensor] = None
    block_list_in_window: Optional[torch.Tensor] = None
    block_mapping_in_window: Optional[torch.Tensor] = None
    block_usage_in_window: Optional[torch.Tensor] = None
    block_groups_in_window: Optional[torch.Tensor] = None
    attn_bias_in_window: Optional[torch.Tensor] = None


def subtuple(
    obj: object,
    typename: str,
    to_copy: List[str],
    to_override: Optional[Dict[str, object]] = None,
):
    """Build a namedtuple holding a chosen subset of ``obj``'s fields.

    Args:
        obj: Source object (read with ``getattr``) or ``dict`` (read by key).
            ``None`` short-circuits and returns ``None``.
        typename: Name of the namedtuple class to create or reuse.
        to_copy: Field names to copy from ``obj``.
        to_override: Optional mapping of field name -> value. Overrides win
            over values on ``obj`` and may name fields ``obj`` lacks.

    Returns:
        An instance of the cached namedtuple type, or ``None`` when ``obj`` is
        ``None``. For a ``dict`` source, a field missing from both the dict and
        the overrides is omitted, so namedtuple construction raises
        ``TypeError``.
    """
    if obj is None:
        return None
    if to_override is None:
        to_override = {}

    # Deterministic order: `to_copy` first, then extra override keys. A set
    # would make the order depend on the per-process string hash seed.
    fields = tuple(dict.fromkeys([*to_copy, *to_override]))
    is_mapping = isinstance(obj, dict)

    def _lookup(name):
        # Overrides take precedence; only touch `obj` when needed so an
        # override for a missing attribute does not raise.
        if name in to_override:
            return to_override[name]
        return obj[name] if is_mapping else getattr(obj, name)

    if is_mapping:
        values = {f: _lookup(f) for f in fields if f in to_override or f in obj}
    else:
        values = {f: _lookup(f) for f in fields}

    cache_key = (typename, fields)
    tuple_cls = _TYPE_CACHE.get(cache_key)
    if tuple_cls is None:
        tuple_cls = collections.namedtuple(typename, fields)
        _TYPE_CACHE[cache_key] = tuple_cls
    return tuple_cls(**values)
def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
    """Return a namedtuple view of ``metadata`` suitable as an HPU graph input.

    NOTE(kzawora): Trimming metadata is required when using HPUGraphs. The
    attention metadata is hashed by the PT bridge and HPUGraphs are matched
    on the hash of all inputs (see ``input_hash`` in
    habana_frameworks/torch/hpu/graphs.py, or call
    ``torch.hpu.graphs.input_hash(attention_metadata)`` manually).

    Before adding keys here, know their type and how they hash. Primitive
    values hash by value: an int such as ``seq_len`` here causes many extra
    graph captures and eventually an OOM. If a scalar is needed, wrap it in a
    tensor. Tensors hash by metadata, not by value:
        input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
        input_hash(123) != input_hash(321)
        input_hash("abc") != input_hash("cba")
    """
    kept_fields = [
        # full-context paged-attention metadata
        "block_list",
        "block_mapping",
        "block_usage",
        "block_groups",
        "attn_bias",
        # sliding-window variants
        "slots_in_window_mask",
        "block_list_in_window",
        "block_mapping_in_window",
        "block_usage_in_window",
        "block_groups_in_window",
        "attn_bias_in_window",
    ]
    return subtuple(metadata, "TrimmedAttentionMetadata", kept_fields)
@dataclass
class Seqlen:
    """Per-batch sequence-length information passed to the attention layers.

    Only ``input_lengths`` is a constructor argument. ``attn_mask`` defaults
    to ``None`` and is assigned afterwards, e.g. with the result of
    ``make_sliding_window_bias``.
    """

    input_lengths: torch.Tensor
    attn_mask: Optional[torch.Tensor] = None

    def __init__(
        self,
        input_lengths,
    ):
        # Explicit __init__ (kept by @dataclass) so callers pass input_lengths only.
        self.input_lengths = input_lengths

    def clamp(self, max):
        # Flash decoding doesn't need to clamp
        return self

    def make_sliding_window_bias(
        self,
        seq_lens: List[int],
        window_size: Optional[int],
        dtype: torch.dtype,
        padded_input_len: Optional[int],
        padded_bs: Optional[int],
    ) -> torch.Tensor:
        """Build left-padded causal (optionally sliding-window) boolean masks.

        Args:
            seq_lens: Length of each sequence. ``0`` marks a padding entry whose
                mask is entirely ``False``.
            window_size: If set, query position ``i`` keeps key positions ``j``
                with ``i - window_size < j <= i``. Otherwise the mask is plain causal.
            dtype: Dtype used while building each mask, before the final bool cast.
            padded_input_len: Side length of every mask. Shorter sequences are
                padded on the left and top. Defaults to ``max(seq_lens)``
                (``0`` for an empty list) when ``None``.
            padded_bs: Unused; accepted for interface compatibility.

        Returns:
            A ``torch.bool`` tensor of shape
            ``(len(seq_lens), 1, padded_input_len, padded_input_len)``. ``True``
            entries are the query/key pairs kept by the causal/window pattern.
        """
        if padded_input_len is None:
            padded_input_len = max(seq_lens, default=0)

        attn_biases = []
        for seq_len in seq_lens:
            if seq_len != 0:
                tensor = torch.full(
                    (1, seq_len, seq_len),
                    dtype=dtype,
                    fill_value=1,
                )
                shift = 0
                mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
                if window_size is not None:
                    mask = torch.triu(mask, diagonal=shift - window_size + 1)
                pad = padded_input_len - seq_len
                # F.pad pads from the last dim backwards:
                # (left, right, top, bottom, front, back).
                mask = F.pad(mask, (pad, 0, pad, 0, 0, 0), value=0)
            else:
                mask = torch.full(
                    (1, padded_input_len, padded_input_len),
                    dtype=dtype,
                    fill_value=0,
                )
            attn_biases.append(mask)

        if not attn_biases:
            # torch.stack rejects an empty list; return an empty batch instead.
            return torch.zeros(
                (0, 1, padded_input_len, padded_input_len), dtype=torch.bool
            )
        attn_biases = torch.stack(attn_biases, dim=0)
        return attn_biases.to(torch.bool)
def _async_h2d_tensor_copy(source, device="hpu"):
if source is None:
return None
if source.device.type == "hpu":
return source
assert source.device.type == "cpu", "Source tensor is not present in host memory!"
target = torch.empty(source.shape, dtype=source.dtype, device=device)
target.copy_(source, non_blocking=True)
return target
def trim_seqlen_metadata(metadata: Seqlen) -> object:
    """Return a namedtuple copy of ``metadata`` suitable as an HPU graph input.

    NOTE(kzawora): Trimming metadata is required when using HPUGraphs. The
    metadata is hashed by the PT bridge and HPUGraphs are matched on the hash
    of all inputs (see ``input_hash`` in habana_frameworks/torch/hpu/graphs.py,
    or ``torch.hpu.graphs.input_hash(attention_metadata)``).

    Only add fields whose hashing you understand. Primitive values hash by
    value: an int such as ``seq_len`` here causes many extra graph captures
    and eventually an OOM. If a scalar is needed, wrap it in a tensor. Tensors
    hash by metadata, not by value:
        input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
        input_hash(123) != input_hash(321)
        input_hash("abc") != input_hash("cba")
    """
    kept_fields = ["input_lengths", "attn_mask"]
    return subtuple(metadata, "TrimmedSeqlen", kept_fields)
================================================
FILE: backends/gaudi/server/text_generation_server/layers/attention/hpu.py
================================================
import torch
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from typing import Optional
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from vllm_hpu_extension import ops
from vllm_hpu_extension.utils import Matmul
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
from text_generation_server.models.globals import BLOCK_SIZE
import math
# Exported flag advertising that this HPU backend does not support sliding-window
# attention. NOTE(review): `attention`/`paged_attention` below do accept
# `window_size_left`; confirm whether this flag is intentionally still False.
SUPPORTS_WINDOWING = False
class FP8Matmul(torch.nn.Module):
    """Matmul that casts both operands to float8_e4m3fn before an HPU FP8 GEMM.

    The left operand uses a fixed scale of 1.0 (a bfloat16 tensor created on
    the HPU). The right operand uses ``scale_other`` from construction. The
    product is returned in bfloat16.
    """

    def __init__(self, scale_other):
        super().__init__()
        self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
        self.scale_other = scale_other

    def quant_input(self, x, scale):
        """Cast ``x`` to float8_e4m3fn using ``scale``; return only the cast tensor."""
        cast_result = torch.ops.hpu.cast_to_fp8_v2(
            x, scale, False, False, torch.float8_e4m3fn
        )
        return cast_result[0]

    def matmul_fp8(
        self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
    ):
        """Run the HPU FP8 GEMM ``x @ other`` with the given inverse scales."""
        gemm_kwargs = dict(
            A=x,
            trans_A=False,
            B=other,
            trans_B=False,
            D=None,
            out_dtype=out_dtype,
            A_scale_inv=scale_input_inv,
            B_scale_inv=scale_other_inv,
            bias=None,
            accumulate=False,
        )
        return torch.ops.hpu.fp8_gemm_v2(**gemm_kwargs)

    def forward(self, input, other):
        lhs = self.quant_input(input, self.scale_input)
        rhs = self.quant_input(other, self.scale_other)
        return self.matmul_fp8(
            lhs,
            rhs,
            out_dtype=torch.bfloat16,
            scale_input_inv=1.0 / self.scale_input,
            scale_other_inv=1.0 / self.scale_other,
        )
class FetchFromCache(torch.nn.Module):
def __init__(self, scale_inv):
super().__init__()
self.scale_inv = scale_inv
def forward(self, cache, blocks):
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
out = cache[: blocks.size(0)]
else:
out = cache.index_select(0, blocks)
if out.dtype == torch.float8_e4m3fn:
out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
return out
def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    """Prefill attention through Habana's fused SDPA kernel.

    ``query``/``key``/``value`` are 3-D ``(tokens, heads, head_size)`` tensors.
    They are viewed as ``(bs, -1, heads, head_size)`` with
    ``bs = len(seqlen.input_lengths)``, then transposed to
    ``(bs, heads, seq, head_size)``.

    Without a window (``window_size_left == -1``), the kernel runs with
    ``is_causal=causal`` and per-sequence ``valid_sequence_lengths``. With a
    window, the precomputed ``seqlen.attn_mask`` is passed instead and
    ``is_causal`` is False.

    ``kv_cache``, ``kv_scales`` and ``softcap`` are accepted but not used.
    """
    windowed = window_size_left != -1
    batch = seqlen.input_lengths.shape[0]
    _, n_heads, head_dim = query.shape
    _, n_kv_heads, head_dim = key.shape

    def to_bhsd(t, heads):
        # (tokens, heads, dim) -> (batch, heads, seq, dim)
        return t.view(batch, -1, heads, head_dim).transpose(1, 2)

    sdpa = ModuleFusedSDPA(FusedSDPA)
    out = sdpa(
        to_bhsd(query, n_heads),
        to_bhsd(key, n_kv_heads),
        to_bhsd(value, n_kv_heads),
        attn_mask=seqlen.attn_mask if windowed else None,
        dropout_p=0.0,
        is_causal=False if windowed else causal,
        scale=softmax_scale,
        softmax_mode="None",
        recompute_mode=None,
        valid_sequence_lengths=None if windowed else seqlen.input_lengths,
        padding_side="left",
    )
    return out.transpose(1, 2).squeeze(0).contiguous()
def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size):
    """Return metadata with the derived block-mapping and bias fields filled in.

    ``hpu_attention_meta`` must support ``_replace`` (a namedtuple such as the
    output of ``trim_attn_metadata``).

    * ``block_mapping``: one-hot of ``block_groups`` over ``batch_size``
      classes, cast to ``block_usage.dtype``.
    * ``attn_bias``: for each block, ``0`` at slot indices below
      ``block_usage`` and ``-inf`` for the remaining slots up to ``BLOCK_SIZE``.
    * If ``block_groups_in_window`` is set, ``block_mapping_in_window`` is built
      the same way, and ``attn_bias_in_window`` is the ``log`` of
      ``slots_in_window_mask`` as float32 (``0``/``-inf`` for a 0/1 mask).
    """
    one_hot = torch.nn.functional.one_hot
    groups_onehot = one_hot(hpu_attention_meta.block_groups, num_classes=batch_size)
    usage = hpu_attention_meta.block_usage
    dtype = usage.dtype
    slot_index = torch.arange(
        0, BLOCK_SIZE, device=usage.device, dtype=torch.int32
    ).unsqueeze(0)
    unused_slot = slot_index >= usage.unsqueeze(-1)
    bias = torch.zeros_like(unused_slot, dtype=dtype).masked_fill_(
        unused_slot, -math.inf
    )
    meta = hpu_attention_meta._replace(
        attn_bias=bias, block_mapping=groups_onehot.to(dtype)
    )

    if meta.block_groups_in_window is None:
        return meta

    window_onehot = one_hot(meta.block_groups_in_window, num_classes=batch_size)
    window_bias = torch.log(meta.slots_in_window_mask.float())
    return meta._replace(
        attn_bias_in_window=window_bias,
        block_mapping_in_window=window_onehot.to(dtype),
    )
def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    seqlen: Seqlen,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
    hpu_attention_meta: HPUPagedAttentionMetadata,
    window_size_left: int = -1,
):
    """Decode-time attention over the paged KV cache using ``ops.flat_pa``.

    ``query`` is ``(batch_size, head_num, head_size)`` and is passed as one
    token per sequence. The output is reshaped back to that shape. When
    ``window_size_left != -1``, the ``*_in_window`` block metadata replaces the
    full-context metadata. With a float8_e4m3fn cache, the QK/AV matmuls use
    ``FP8Matmul``. Cache reads always go through ``FetchFromCache``.
    ``kv_head_mapping``, ``seqlen`` and ``softcap`` are not used.
    """
    batch_size, head_num, head_size = query.shape
    fp8_kv = kv_cache.dtype == torch.float8_e4m3fn

    meta = hpu_attention_meta
    if window_size_left == -1:
        blocks, mapping, bias, groups = (
            meta.block_list,
            meta.block_mapping,
            meta.attn_bias,
            meta.block_groups,
        )
    else:
        blocks, mapping, bias, groups = (
            meta.block_list_in_window,
            meta.block_mapping_in_window,
            meta.attn_bias_in_window,
            meta.block_groups_in_window,
        )

    output = ops.flat_pa(
        query=query.view(batch_size, 1, head_num * head_size),
        key_cache=kv_cache.key,
        value_cache=kv_cache.value,
        block_list=blocks,
        block_mapping=mapping,
        block_bias=bias,
        block_groups=groups,
        block_size=BLOCK_SIZE,
        scale=softmax_scale,
        matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
        matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
        batch2block_matmul_op=Matmul(),
        block2batch_matmul_op=Matmul(),
        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
        values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
    )
    return output.view(batch_size, head_num, head_size)
def paged_attention_mla(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    seqlen: Seqlen,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
    hpu_attention_meta: HPUPagedAttentionMetadata,
    kv_lora_rank: int = 0,
):
    """Decode-time MLA attention over a key-only paged cache (``ops.flat_pa_mla``).

    Only ``kv_cache.key`` is passed; there is no separate value cache or value
    fetch function. Full-context block metadata is always used (no window).
    ``kv_head_mapping``, ``seqlen`` and ``softcap`` are not used. The result is
    viewed as ``(batch_size, head_num, -1)``.
    """
    batch_size, head_num, _ = query.shape
    use_fp8 = kv_cache.dtype == torch.float8_e4m3fn
    meta = hpu_attention_meta

    qk_op = FP8Matmul(kv_scales.key_scale) if use_fp8 else Matmul()
    av_op = FP8Matmul(kv_scales.value_scale) if use_fp8 else Matmul()
    output = ops.flat_pa_mla(
        query=query,
        key_cache=kv_cache.key,
        value_cache=None,
        block_list=meta.block_list,
        block_mapping=meta.block_mapping,
        block_bias=meta.attn_bias,
        block_groups=meta.block_groups,
        block_size=BLOCK_SIZE,
        scale=softmax_scale,
        matmul_qk_op=qk_op,
        matmul_av_op=av_op,
        batch2block_matmul_op=Matmul(),
        block2batch_matmul_op=Matmul(),
        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
        values_fetch_func=None,
        kv_lora_rank=kv_lora_rank,
    )
    return output.view(batch_size, head_num, -1)
# Public API of the HPU attention module (re-exported by the attention package).
__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
    "paged_attention_mla",
    "set_block_mapping",
]
================================================
FILE: backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py
================================================
from typing import Tuple
from dataclasses import dataclass, field
import torch
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.weights import Weights
@dataclass
class KVScales:
    """
    Key-value scales for FP8 KV cache.

    Each scale is stored twice: as a scalar tensor (``key_scale`` /
    ``value_scale``, on whatever device it was loaded to) and as a host-side
    Python float (``key_scale_cpu`` / ``value_scale_cpu``, derived with
    ``.item()`` in ``__post_init__``). Some consumers take the tensor form
    (e.g. the FP8 casts in ``FP8Matmul`` and ``paged_reshape_and_cache``).
    Others need a plain float (e.g. the inverse scale given to
    ``FetchFromCache``).

    Raises:
        ValueError: if either scale tensor does not hold exactly one element.
    """

    key_scale: torch.Tensor
    value_scale: torch.Tensor
    key_scale_cpu: float = field(init=False)
    value_scale_cpu: float = field(init=False)

    def __post_init__(self):
        if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
            raise ValueError("Key and value scales must be scalar tensors.")

        self.key_scale_cpu = self.key_scale.item()
        self.value_scale_cpu = self.value_scale.item()
class KVCache:
    """
    Key-value cache for attention layers.

    Holds a ``(key, value)`` tuple of zero-initialised tensors, each of shape
    ``(num_blocks * BLOCK_SIZE, num_heads, head_size)``. Rows are addressed by
    slot index.
    """

    kv_cache: Tuple[torch.Tensor, torch.Tensor]

    def __init__(
        self,
        *,
        num_blocks: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Construct the key-value cache for a layer."""
        ## TODO FP8 kv cache support
        if dtype is torch.float8_e5m2:
            raise ValueError("torch.float8_e5m2 is not supported in hpu. ")

        shape = (num_blocks * BLOCK_SIZE, num_heads, head_size)
        self.kv_cache = tuple(
            torch.zeros(shape, dtype=dtype, device=device) for _ in range(2)
        )

    @property
    def dtype(self):
        """Get the data type of the cache."""
        return self.kv_cache[0].dtype

    @property
    def key(self):
        """Get the key cache."""
        return self.kv_cache[0]

    @property
    def value(self):
        """Get the value cache."""
        return self.kv_cache[1]

    def store(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        slots: torch.Tensor,
        kv_scales: KVScales,
    ):
        """Store the key and value at the given slots (FP8-cast if the cache is e4m3)."""
        ## TODO FP8 kv cache support
        key_cache, value_cache = self.kv_cache
        paged_reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slots,
            kv_scales.key_scale,
            kv_scales.value_scale,
        )
class KVCompressCache(KVCache):
    """
    Key-value cache for attention layers.

    Compressed variant: one tensor of shape
    ``(num_blocks * BLOCK_SIZE, 1, head_size)`` is returned for both ``key``
    and ``value``. ``store`` writes only ``key``.
    """

    kv_cache: torch.Tensor

    def __init__(
        self,
        *,
        num_blocks: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Construct the key-value cache for a layer."""
        ## TODO FP8 kv cache support
        if dtype is torch.float8_e5m2:
            raise ValueError("torch.float8_e5m2 is not supported in hpu. ")

        shape = (num_blocks * BLOCK_SIZE, 1, head_size)
        self.kv_cache = torch.zeros(shape, dtype=dtype, device=device)

    @property
    def dtype(self):
        """Get the data type of the cache."""
        return self.kv_cache.dtype

    @property
    def key(self):
        """Get the key cache."""
        return self.kv_cache

    @property
    def value(self):
        """Get the value cache (the same tensor as the key cache)."""
        return self.kv_cache

    def store(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        slots: torch.Tensor,
        kv_scales: KVScales,
    ):
        """Write ``key`` into the shared cache at ``slots``; ``value`` is ignored."""
        ## TODO FP8 kv cache support
        entries = key
        if self.kv_cache.dtype == torch.float8_e4m3fn:
            entries = torch.ops.hpu.cast_to_fp8_v2(
                entries, kv_scales.key_scale, False, False, torch.float8_e4m3fn
            )[0]
        self.kv_cache.index_copy_(0, slots, entries)
def paged_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
):
if key_cache.dtype == torch.float8_e4m3fn:
key = torch.ops.hpu.cast_to_fp8_v2(
key, k_scale, False, False, torch.float8_e4m3fn
)[0]
value = torch.ops.hpu.cast_to_fp8_v2(
value, v_scale, False, False, torch.float8_e4m3fn
)[0]
key_cache.index_copy_(0, slots, key)
value_cache.index_copy_(0, slots, value)
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
    """Load KV cache scales.

    Per-tensor ``{prefix}.k_scale`` and ``{prefix}.v_scale`` are used when both
    exist. Otherwise a shared ``{prefix}.kv_scale`` is used for both. If
    neither is present, both scales are ``1.0``. All scales are returned as
    float32 tensors.
    """
    key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
    value_scale = key_scale
    if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
        f"{prefix}.v_scale"
    ):
        key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
        value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
    elif weights.has_tensor(f"{prefix}.kv_scale"):
        # Fall back to older more coarse-grained scale when available.
        # Load it as stored (`to_dtype=False`), like the branch above. With the
        # default, get_tensor presumably casts to the model dtype (e.g. bf16)
        # first, rounding the scale before the float32 conversion. Confirm
        # against Weights.get_tensor.
        key_scale = weights.get_tensor(f"{prefix}.kv_scale", to_dtype=False).float()
        value_scale = key_scale

    return KVScales(key_scale=key_scale, value_scale=value_scale)
================================================
FILE: backends/gaudi/server/text_generation_server/layers/awq/conversion_utils.py
================================================
import torch
from typing import List
# Order in which AWQ interleaves the eight 4-bit values inside one int32 word.
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
# Inverse permutation of AWQ_PACK_ORDER: REVERSE[AWQ_PACK_ORDER[i]] == i. It is
# applied after unpacking to restore logical order (see fast_awq_to_gptq).
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def pack(imatrix: torch.Tensor, direction: str = "column"):
    """
    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.

    Each value is first masked to its low 4 bits. Eight consecutive values
    (along columns for "column", along rows for "row") are then combined
    into one int32, with the first value in the lowest nibble.

    Args:
        imatrix (torch.Tensor): 2-D matrix of integers; the packed dimension
            must be a multiple of 8
        direction (str): direction of packing, either "column" or "row"
    Returns:
        qmatrix (torch.Tensor): packed int32 matrix
    Raises:
        ValueError: if ``direction`` is neither "column" nor "row"
    """
    if direction not in ("column", "row"):
        raise ValueError(f"direction must be 'column' or 'row', got {direction!r}")

    values_per_word = 32 // 4
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow

    if direction == "column":
        imatrix = imatrix.view(-1, imatrix.shape[1] // values_per_word, values_per_word)
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
    else:
        imatrix = imatrix.view(imatrix.shape[0] // values_per_word, values_per_word, -1)
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)

    qmatrix = qmatrix.to(torch.int32)

    return qmatrix
def unpack(qmatrix: torch.Tensor, direction: str = "column"):
    """
    Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.

    Each int32 is split into eight nibbles (lowest nibble first). They are
    laid out along columns for "column" or along rows for "row", and returned
    as int8 values in [0, 15].

    Args:
        qmatrix (torch.Tensor): 2-D matrix of packed integers
        direction (str): direction of unpacking, either "column" or "row"
    Returns:
        imatrix (torch.Tensor): int8 matrix of 4-bit values
    Raises:
        ValueError: if ``direction`` is neither "column" nor "row"
    """
    shifts = torch.arange(0, 32, 4, device=qmatrix.device)

    if direction == "column":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, :, None], shifts[None, None, :]
        ).view(qmatrix.shape[0], -1)
    elif direction == "row":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, None, :], shifts[None, :, None]
        ).view(-1, qmatrix.shape[-1])
    else:
        raise ValueError(f"direction must be 'column' or 'row', got {direction!r}")

    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow

    return imatrix
def apply_order(
    imatrix: torch.Tensor,
    direction: str = "column",
    order: List[int] = AWQ_PACK_ORDER,
):
    """
    Applies the order to a 4-bit integer matrix.

    ``order`` permutes the values inside every group of 8 consecutive entries
    along the chosen direction. For "column" these are 8 adjacent columns
    within a row. For "row" they are 8 adjacent rows, matching the group
    layout used by ``pack``/``unpack``. The output has the input's shape.

    Args:
        imatrix (torch.Tensor): 2-D matrix of integers; the grouped dimension
            must be a multiple of 8
        direction (str): direction of applying order, either "column" or "row"
        order (List[int]): order to apply, default is AWQ_PACK_ORDER
    Returns:
        imatrix (torch.Tensor): matrix of integers
    Raises:
        ValueError: if ``direction`` is neither "column" nor "row"
    """
    group = 32 // 4
    if direction == "column":
        return imatrix.view(-1, group)[:, order].view(imatrix.shape)
    if direction == "row":
        # Permute rows inside each block of 8 rows. Viewing as (8, -1) would
        # instead permute 8 large chunks of the whole matrix, which only
        # coincides with this when the matrix has exactly 8 rows.
        return imatrix.view(-1, group, imatrix.shape[-1])[:, order, :].view(
            imatrix.shape
        )
    raise ValueError(f"direction must be 'column' or 'row', got {direction!r}")
def fast_awq_to_gptq(qweight, qzeros):
    """Repack AWQ ``qweight``/``qzeros`` into the GPTQ (exllama) layout.

    AWQ column-packs both tensors with an interleaved nibble order. Both are
    unpacked and put back into logical order. The zero points are shifted
    by -1 (GPTQ adds the 1 back), and the tensors are repacked: weights
    row-wise, zeros column-wise.

    Returns:
        (qweight, qzeros) in GPTQ layout.
    """

    def _to_logical(packed):
        # awq uses column packing for both weights and zeros
        nibbles = unpack(packed, direction="column")
        return apply_order(nibbles, direction="column", order=REVERSE_AWQ_PACK_ORDER)

    izeros = _to_logical(qzeros)
    iweights = _to_logical(qweight)
    # exllama uses row packing for weights and column packing for zeros
    gptq_zeros = pack(izeros - 1, direction="column")
    gptq_weight = pack(iweights, direction="row")
    return gptq_weight, gptq_zeros
================================================
FILE: backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py
================================================
from .hpu import WQLinear
# Public API of the AWQ quantize package: the HPU `WQLinear` implementation.
__all__ = ["WQLinear"]
================================================
FILE: backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py
================================================
from typing import Optional
import torch
import torch.nn as nn
# Resolve the HPU uint4 -> dense conversion op used by WQLinear.forward. If the
# Habana PyTorch bridge (or the op) cannot be loaded, install a stand-in that
# raises on first use rather than failing at import time.
try:
    import habana_frameworks.torch.hpu  # noqa: F401

    convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as e:
    # `e` is unbound once the except block ends; keep a reference for the
    # deferred error message.
    hpu_import_exception = e

    def error_raiser_hpu(*args, **kwargs):
        raise ValueError(
            f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
        )

    convert_from_uint4 = error_raiser_hpu

# Column permutation that undoes AWQ's interleaved nibble order within each
# group of eight unpacked values (applied in reverse_awq_order).
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
    """Split each int32 of AWQ-packed ``qweight``/``qzeros`` into ``32 // bits`` values.

    Values are unpacked column-wise in storage order (lowest bits first). They
    are NOT yet masked to ``bits`` bits or reordered; see
    ``unpack_weight_and_zeros``. ``qzeros`` may be ``None``.

    Returns:
        (iweights, izeros): int8 tensors with ``32 // bits`` times as many
        columns as their inputs. ``izeros`` is ``None`` when ``qzeros`` is.
    """

    def _unpack_columns(packed):
        # Build the shifts on each tensor's own device. Reading
        # `qzeros.device` for both would crash when qzeros is None.
        shifts = torch.arange(0, 32, bits, device=packed.device)
        # unpacking columnwise
        unpacked = torch.bitwise_right_shift(
            packed[:, :, None], shifts[None, None, :]
        ).to(
            torch.int8  # smallest dtype available
        )
        return unpacked.view(unpacked.shape[0], -1)

    iweights = _unpack_columns(qweight)
    izeros = None if qzeros is None else _unpack_columns(qzeros)
    return iweights, izeros
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
    """Undo AWQ's interleaved nibble order on unpacked weights (and zeros).

    Within every group of ``32 // bits`` columns, columns are reordered by
    ``AWQ_REVERSE_ORDER``. That list has 8 entries, so this is only
    meaningful for ``bits == 4``. ``izeros`` may be ``None``.

    Returns:
        (iweights, izeros) with columns reordered (``izeros`` stays ``None``).
    """
    # Build the permutation from iweights, since izeros is allowed to be None.
    reverse_order_tensor = torch.arange(
        iweights.shape[-1],
        dtype=torch.int32,
        device=iweights.device,
    )
    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
    reverse_order_tensor = reverse_order_tensor.view(-1)

    if izeros is not None:
        izeros = izeros[:, reverse_order_tensor.to(izeros.device)]
    iweights = iweights[:, reverse_order_tensor]

    return iweights, izeros
def unpack_weight_and_zeros(qweight, qzeros, bits):
    """Unpack AWQ-packed weights/zeros into logical order, masked to ``bits`` bits.

    ``qzeros`` may be ``None``, in which case ``None`` is returned for the zeros.

    Returns:
        (iweight, izeros) as int8 tensors.
    """
    # Unpack the qweight and qzeros tensors
    iweight, izeros = unpack_awq(qweight, qzeros, bits)
    # Reverse the order of the iweight and izeros tensors
    iweight, izeros = reverse_awq_order(iweight, izeros, bits)

    # overflow checks: keep only the low `bits` bits of each unpacked value
    value_mask = (2**bits) - 1
    iweight = torch.bitwise_and(iweight, value_mask)
    if izeros is not None:
        izeros = torch.bitwise_and(izeros, value_mask)

    return iweight, izeros
def pack_tensor(input, bits=4):
    """Pack every ``32 // bits`` consecutive columns of ``input`` into one int32 column.

    Value ``k`` of a group is placed at bit offset ``bits * k`` (first value in
    the lowest bits). Values are OR-ed in without masking, so they should
    already fit in ``bits`` bits.

    Args:
        input: 2-D integer tensor ``(rows, cols)``. ``bits`` should divide 32
            and ``cols`` should be a multiple of ``32 // bits``. Trailing
            columns that do not fill a whole word are dropped.
        bits: width of each packed value.

    Returns:
        int32 tensor of shape ``(rows, cols // (32 // bits))``.
    """
    values_per_word = 32 // bits
    normal = input.to(torch.int32)
    # Whole words per row (== cols * bits // 32). The old `cols // 32 * bits`
    # form truncated whenever cols was not a multiple of 32.
    num_words = normal.shape[1] // values_per_word
    q = torch.zeros(
        (normal.shape[0], num_words),
        dtype=torch.int32,
        device=input.device,
    )
    # One vectorised pass per position inside a word, instead of one Python
    # iteration per output column: column `c` of the strided slice is input
    # column `c * values_per_word + k`.
    for k in range(values_per_word):
        q |= normal[:, k::values_per_word][:, :num_words] << (bits * k)
    return q
class WQLinear(nn.Module):
    """AWQ 4-bit linear layer that dequantizes with the HPU uint4 conversion op.

    At construction the AWQ-packed ``qweight``/``qzeros`` are unpacked and
    reordered on the CPU, repacked with ``pack_tensor`` and moved back to
    their original device. ``forward`` materialises the weights with
    ``convert_from_uint4`` and runs a dense matmul.
    """

    def __init__(
        self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
    ):
        super().__init__()

        if w_bit not in (4,):
            raise NotImplementedError("Only 4-bit are supported for now.")

        self.w_bit = w_bit
        self.in_features = qweight.shape[0]
        self.out_features = qweight.shape[1] * 32 // w_bit
        # -1 means a single group spanning the whole input dimension.
        self.group_size = self.in_features if group_size == -1 else group_size

        # Alignment checks: groups must tile the inputs and outputs must fill
        # whole 32-bit words.
        assert self.in_features % self.group_size == 0
        assert self.out_features % (32 // self.w_bit) == 0

        self.qweight, self.qzeros, self.scales, self.bias = qweight, qzeros, scales, bias
        self._preprocessing()

    def _preprocessing(self):
        # Repack on the CPU (presumably into the layout convert_from_uint4
        # expects — confirm), then return to the original device.
        target_device = self.qweight.device
        iweight, izeros = unpack_weight_and_zeros(
            self.qweight.cpu(), self.qzeros.cpu(), self.w_bit
        )
        self.qweight = pack_tensor(iweight).to(target_device)
        self.qzeros = pack_tensor(izeros).to(target_device)

    @torch.no_grad()
    def forward(self, x):
        out_shape = x.shape[:-1] + (self.out_features,)
        flat = x.reshape(-1, x.shape[-1])
        dense_weight = convert_from_uint4(
            self.qweight, self.scales, self.qzeros, flat.dtype
        )
        result = torch.matmul(flat, dense_weight)
        if self.bias is not None:
            result = result + self.bias
        return result.reshape(out_shape)
================================================
FILE: backends/gaudi/server/text_generation_server/layers/bnb.py
================================================
from dataclasses import dataclass
import bitsandbytes as bnb
import torch
from bitsandbytes.nn import Int8Params, Params4bit
from text_generation_server.utils.weights import UnquantizedWeight
@dataclass
class BNBWeight(UnquantizedWeight):
    """Weight that is materialised as a bitsandbytes int8 ``Linear8bitLt`` layer."""

    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        # Inference-only int8 layer; threshold is forwarded to the bnb matmul state.
        return Linear8bitLt(
            weight=self.weight,
            bias=bias,
            threshold=6.0,
            has_fp16_weights=False,
        )
class Linear8bitLt(torch.nn.Module):
    """Linear layer backed by a bitsandbytes ``Int8Params`` weight and ``bnb.matmul``.

    Args:
        weight: tensor whose ``.data`` is wrapped in ``Int8Params``.
        bias: optional bias; cast in place to the input dtype in ``forward``.
        has_fp16_weights: stored on the matmul state and used as the weight's
            ``requires_grad`` flag.
        memory_efficient_backward: deprecated; must be ``False`` (asserted).
        threshold: copied to ``state.threshold``. When ``> 0`` and
            ``has_fp16_weights`` is False, ``state.use_pool`` is enabled.
        index: stored as ``self.index``; not used within this class.
    """

    def __init__(
        self,
        weight,
        bias,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__()
        assert (
            not memory_efficient_backward
        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index

        # Necessary for stacked layers
        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )
        # NOTE(review): `.cuda()` on Int8Params, targeting the original weight's
        # device; presumably this is where bitsandbytes quantizes the weight and
        # fills CB/SCB (read in forward). Confirm, since this is the Gaudi backend.
        self.weight.cuda(weight.device)
        self.bias = bias

    def init_8bit_state(self):
        """Move the weight's CB/SCB buffers onto the matmul state and clear them on the weight."""
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        self.state.is_training = self.training
        # CB is still on the weight only until init_8bit_state has run once.
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
        return out
@dataclass
class BNBFP4Weight(UnquantizedWeight):
    """Weight that is materialised as a bitsandbytes 4-bit layer with the "fp4" type."""

    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        return Linear4bit(weight=self.weight, bias=bias, quant_type="fp4")
@dataclass
class BNBNF4Weight(UnquantizedWeight):
    """Weight that is materialised as a bitsandbytes 4-bit layer with the "nf4" type."""

    weight: torch.Tensor

    def get_linear(self, bias: torch.Tensor):
        return Linear4bit(weight=self.weight, bias=bias, quant_type="nf4")
class Linear4bit(torch.nn.Module):
    """Linear layer backed by a bitsandbytes ``Params4bit`` weight and ``bnb.matmul_4bit``.

    Args:
        weight: tensor whose ``.data`` is wrapped in ``Params4bit``
            (``requires_grad=False``, ``compress_statistics=True``).
        bias: optional bias; cast in place to the input dtype in ``forward``.
        quant_type: bitsandbytes 4-bit data type, e.g. ``"fp4"`` or ``"nf4"``.
    """

    def __init__(self, weight, bias, quant_type):
        super().__init__()
        self.weight = Params4bit(
            weight.data,
            requires_grad=False,
            compress_statistics=True,
            quant_type=quant_type,
        )
        # Not reassigned anywhere in this class, so forward skips the input
        # cast and calls `self.bias.to(None)`.
        self.compute_dtype = None
        # NOTE(review): presumably `.cuda()` is what quantizes Params4bit and
        # sets `quant_state` (checked in forward). Confirm for the Gaudi backend.
        self.weight.cuda(weight.device)
        self.bias = bias

    def forward(self, x: torch.Tensor):
        # The weight's dtype is handled by bitsandbytes (Params4bit), but the
        # bias has to be cast manually.
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        # Only prints a warning; the matmul below still runs.
        if getattr(self.weight, "quant_state", None) is None:
            print(
                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
            )
        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        # NOTE(review): with compute_dtype None this is `bias.to(None)`. Confirm
        # it is a no-op on the installed torch version.
        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(
            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
        )

        # Return in the caller's dtype.
        out = out.to(inp_dtype)

        return out
================================================
FILE: backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py
================================================
from .loader import CompressedTensorsLoader
# Public API of the compressed-tensors package: the checkpoint weights loader.
__all__ = ["CompressedTensorsLoader"]
================================================
FILE: backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py
================================================
from typing import Any, Dict, List, Union
from compressed_tensors import QuantizationConfig, QuantizationStatus
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import (
QuantizationScheme,
QuantizationType,
find_name_or_class_matches,
)
from loguru import logger
from pydantic import ValidationError
from torch import nn
from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
WeightsLoader,
)
# compressed-tensors can match modules as quantization targets, but the
# matcher needs module instances rather than classes or class names. Because
# we need to match `Linear` targets, keep one zero-sized `nn.Linear` instance
# that can be reused (presumably as the module argument to
# `find_name_or_class_matches`; confirm in the loader below).
_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)
class CompressedTensorsLoader(WeightsLoader):
    """Loader for checkpoints stored in the compressed-tensors format.

    The checkpoint's quantization config is parsed once at construction time
    and turned into a mapping from compressed-tensors *target* to the
    `WeightsLoader` that handles that target. Every `get_*` method resolves
    the loader for the requested prefix and delegates to it.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Args:
            config: Model configuration dictionary. Quantization settings are
                read from `quantization_config`, falling back to the legacy
                `compression_config` key.

        Raises:
            ValueError: If no compressed-tensors configuration is present, it
                cannot be parsed, its quantization status is neither
                `COMPRESSED` nor `FROZEN`, a target has more than one group,
                or a group uses an unsupported scheme.
        """
        quantization_config_raw = config.get("quantization_config")
        if quantization_config_raw is None:
            # `compression_config` was renamed to `quantization_config`; support
            # retained for backward compatibility.
            quantization_config_raw = config.get("compression_config")
        if quantization_config_raw is None:
            raise ValueError(
                "Checkpoint does not have compressed-tensors configuration"
            )

        try:
            quantization_config = QuantizationConfig.model_validate(
                quantization_config_raw
            )
        except ValidationError as e:
            raise ValueError("Cannot parse compressed-tensors configuration") from e

        # Only checkpoints whose quantization has been completed can be loaded.
        if quantization_config.quantization_status not in (
            QuantizationStatus.COMPRESSED,
            QuantizationStatus.FROZEN,
        ):
            raise ValueError(
                f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
            )

        # Modules matching these names/classes are loaded unquantized.
        self.ignore = (
            quantization_config.ignore if quantization_config.ignore is not None else []
        )
        self.loaders = self._get_target_loaders(quantization_config)

        for target, loader in self.loaders.items():
            log_once(
                logger.info,
                f"Using {loader} for compressed-tensors target '{target}'",
            )

    def get_weights(self, weights: Weights, prefix: str):
        loader = self._lookup_loader(prefix)
        return loader.get_weights(weights, prefix)

    def get_weights_col_packed(
        self,
        weights: "Weights",
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        loader = self._lookup_loader(prefix)
        return loader.get_weights_col_packed(weights, prefix, block_sizes)

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        # All fused prefixes are assumed to share the loader of the first one.
        loader = self._lookup_loader(prefixes[0])
        return loader.get_multi_weights_col(weights, prefixes, dim)

    def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
        # All fused prefixes are assumed to share the loader of the first one.
        loader = self._lookup_loader(prefixes[0])
        return loader.get_multi_weights(weights, prefixes, dim)

    def get_weights_row(self, weights: Weights, prefix: str):
        loader = self._lookup_loader(prefix)
        return loader.get_weights_row(weights, prefix)

    def _get_target_loaders(
        self, quantization_config: QuantizationConfig
    ) -> Dict[str, WeightsLoader]:
        """
        A compressed-tensors checkpoint can use different quantizations
        for different targets. This method returns a dictionary with a
        loader per target.

        Raises:
            ValueError: If a target appears in more than one config group.
        """
        loaders: Dict[str, WeightsLoader] = {}

        format = quantization_config.format

        for group_name, group in quantization_config.config_groups.items():
            # The group configuration can be a string, but does that ever
            # happen in a serialized quantization config?
            assert isinstance(group, QuantizationScheme)

            loader = self._create_loader_for_group(format, group_name, group)

            # A quantized parameter group can have multiple targets, add the
            # loader for all the targets.
            for target in group.targets:
                if target in loaders:
                    raise ValueError(
                        f"Target '{target}' has multiple configured loaders"
                    )
                loaders[target] = loader

        return loaders

    def _create_loader_for_group(
        self, format: str, group_name: str, group: QuantizationScheme
    ) -> WeightsLoader:
        """
        Find and create a loader for the group with the given quantization
        scheme.

        Raises:
            ValueError: If the group's scheme is not supported (only 8-bit
                float weights in the `float-quantized`/`naive-quantized`
                formats are handled).
        """
        # NOTE: we ignore group.output_activations because we don't support
        # output quantization yet.

        input_activations = group.input_activations
        weights = group.weights
        if (
            format
            in {
                CompressionFormat.float_quantized.value,
                CompressionFormat.naive_quantized.value,
            }
            and weights is not None
            and weights.type == QuantizationType.FLOAT
            and weights.num_bits == 8
        ):
            # FP W8A8 or W8A16.
            return W8ANFpLoader(input_activations=input_activations, weights=weights)
        else:
            raise ValueError(
                f"Group '{group_name}' has unsupported compressed-tensors configuration"
            )

    def _lookup_loader(self, prefix: str) -> WeightsLoader:
        """
        Look up the loader to use for a given parameter name (prefix).

        Prefixes matched by the `ignore` list get an unquantized default
        loader; otherwise the loader of the first matching target is used.

        Raises:
            ValueError: If no configured target matches `prefix`.
        """

        if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
            return DefaultWeightsLoader(UnquantizedWeight)

        # We currently only handle linear layers, so unconditionally pass
        # a `Linear` instance.
        targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
        if len(targets) == 0:
            raise ValueError(
                f"Cannot find compressed-tensors target for prefix: {prefix}"
            )
        return self.loaders[targets[0]]
================================================
FILE: backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
================================================
from typing import List, Optional, Union
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
from text_generation_server.layers.fp8 import (
Fp8Weight,
_load_scalar_or_matrix_scale,
requantize_with_max_scale,
)
from text_generation_server.utils.weights import Weights, WeightsLoader
class W8ANFpLoader(WeightsLoader):
    """
    Loader for W8A8/W8A16 FP compressed-tensors parameters.

    Weights must be 8-bit float (asserted in `__init__`). When the checkpoint
    has static weight scales, every load path requantizes the (possibly
    fused) weight to a single max scale with `requantize_with_max_scale`.
    Static input scales are loaded alongside when the input activations are
    static.
    """

    def __init__(
        self,
        *,
        input_activations: Optional[QuantizationArgs],
        weights: QuantizationArgs,
    ):
        """
        Args:
            input_activations: Quantization args for the layer inputs, or
                `None` when the group does not describe them.
            weights: Quantization args for the weights (must be 8-bit float).
        """
        assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8

        # The `strategy` option (per-tensor, per-channel or per-token scales)
        # is ignored on purpose. Which scales are usable depends on the kernel
        # in use (e.g. cutlass can do tokenwise, Torch cannot, and FP8-Marlin
        # does not quantize inputs at all), so the best-possible configuration
        # is chosen instead.
        has_input_args = input_activations is not None
        self.load_weight_scale = not weights.dynamic
        self.load_input_scale = has_input_args and not input_activations.dynamic
        self.force_w8a16 = has_input_args and input_activations.num_bits == 16

    def __str__(self) -> str:
        def describe(is_static):
            return "static" if is_static else "dynamic"

        activation_bits = 16 if self.force_w8a16 else 8
        return (
            f"{self.__class__.__name__} (W8A{activation_bits}, "
            f"weight: {describe(self.load_weight_scale)}, "
            f"input: {describe(self.load_input_scale)})"
        )

    # ------------------------------------------------------------------
    # Private helpers shared by the public load paths.
    # ------------------------------------------------------------------

    @staticmethod
    def _requantize(weights, w, per_row_scale, logical_widths):
        """Requantize `w` to one max scale; returns `(w, scale)`."""
        return requantize_with_max_scale(
            w,
            per_row_scale.unsqueeze(-1).to(weights.device),
            logical_widths,
            weights.dtype,
        )

    @staticmethod
    def _packed_scale(weights, name, block_sizes):
        """Load a scale, sharding it like the packed weight if it is not scalar."""
        loaded = weights.get_tensor(name, to_dtype=False)
        if loaded.numel() > 1:
            loaded = weights.get_packed_sharded(
                name,
                dim=0,
                block_sizes=block_sizes,
                to_dtype=False,
            )
        return loaded

    def _make_weight(self, weights, w, weight_scale, input_scale):
        """Wrap the loaded tensors in an `Fp8Weight`."""
        return Fp8Weight(
            weight=w,
            weight_scale=weight_scale,
            input_scale=input_scale,
            dtype=weights.dtype,
            force_w8a16=self.force_w8a16,
        )

    def _finish_single(self, weights, prefix, w):
        """Load scales for the single-prefix weight `w` and build the result."""
        weight_scale = None
        if self.load_weight_scale:
            rows = w.shape[0]
            per_row = (
                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(rows)
            )
            w, weight_scale = self._requantize(weights, w, per_row, [rows])

        input_scale = None
        if self.load_input_scale:
            input_scale = weights.get_tensor(
                f"{prefix}.input_scale", to_dtype=False
            ).reshape(-1)

        return self._make_weight(weights, w, weight_scale, input_scale)

    # ------------------------------------------------------------------
    # WeightsLoader interface.
    # ------------------------------------------------------------------

    def get_weights(self, weights: "Weights", prefix: str):
        return self._finish_single(weights, prefix, weights.get_tensor(f"{prefix}.weight"))

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes
        )

        weight_scale = None
        if self.load_weight_scale:
            rows = w.shape[0]
            per_row = self._packed_scale(
                weights, f"{prefix}.weight_scale", block_sizes
            )
            per_row = per_row.reshape(-1).expand(rows)
            w, weight_scale = self._requantize(weights, w, per_row, [rows])

        input_scale = None
        if self.load_input_scale:
            input_scale = (
                self._packed_scale(weights, f"{prefix}.input_scale", block_sizes)
                .reshape(-1)
                .max()
            )

        return self._make_weight(weights, w, weight_scale, input_scale)

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        parts = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
        ]
        shapes = [part.shape for part in parts]

        # Concat then send to the device
        w = torch.cat(parts, dim=dim).to(weights.device)

        weight_scale = None
        if self.load_weight_scale:
            per_part = [
                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
                for p, shape in zip(prefixes, shapes)
            ]
            stacked = torch.cat(per_part, dim=0).reshape(-1)
            w, weight_scale = self._requantize(
                weights, w, stacked, [shape[0] for shape in shapes]
            )

        input_scale = None
        if self.load_input_scale:
            found = [
                _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            # Either every fused part has an input scale or none does.
            assert len(found) == 0 or len(found) == len(prefixes)
            input_scale = torch.cat(found, dim=0).reshape(-1).max() if found else None

        return self._make_weight(weights, w, weight_scale, input_scale)

    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        parts = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
        shapes = [part.shape for part in parts]

        # Concat then send to the device
        w = torch.cat(parts, dim=dim).to(weights.device)

        weight_scale = None
        if self.load_weight_scale:
            per_part = [
                weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
                .reshape(-1)
                .expand(shape[0])
                for p, shape in zip(prefixes, shapes)
            ]
            stacked = torch.cat(per_part, dim=0).reshape(-1)
            w, weight_scale = self._requantize(
                weights, w, stacked, [shape[0] for shape in shapes]
            )

        input_scale = None
        if self.load_input_scale:
            found = [
                weights.get_tensor(f"{p}.input_scale", to_dtype=False)
                .reshape(-1)
                .expand(shape[0])
                for p, shape in zip(prefixes, shapes)
                if weights.has_tensor(f"{p}.input_scale")
            ]
            # Either every fused part has an input scale or none does.
            assert len(found) == 0 or len(found) == len(prefixes)
            input_scale = torch.cat(found, dim=0).reshape(-1).max() if found else None

        return self._make_weight(weights, w, weight_scale, input_scale)

    def get_weights_row(self, weights: "Weights", prefix: str):
        return self._finish_single(
            weights, prefix, weights.get_sharded(f"{prefix}.weight", dim=1)
        )
================================================
FILE: backends/gaudi/server/text_generation_server/layers/conv.py
================================================
from accelerate import init_empty_weights
import torch
@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    """
    Build a `cls` convolution (with bias) from `{prefix}.weight` and
    `{prefix}.bias`.

    The module is instantiated inside accelerate's `init_empty_weights()`
    context, and its `weight` and `bias` are then replaced by the loaded
    tensors wrapped in `torch.nn.Parameter`.
    """
    kernel = weights.get_tensor(f"{prefix}.weight")
    offset = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        module = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )
    module.weight = torch.nn.Parameter(kernel)
    module.bias = torch.nn.Parameter(offset)
    return module
@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    """
    Build a bias-free `cls` convolution from `{prefix}.weight`.

    The module is instantiated inside accelerate's `init_empty_weights()`
    context. Its `weight` is then replaced by the loaded tensor and its
    `bias` is set to `None`.
    """
    kernel = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        module = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )
    module.weight = torch.nn.Parameter(kernel)
    module.bias = None
    return module
# Expose the two classmethod loaders on `torch.nn.Conv2d` so callers can use
# `torch.nn.Conv2d.load(...)` / `torch.nn.Conv2d.load_no_bias(...)`.
setattr(torch.nn.Conv2d, "load", load_conv2d)
setattr(torch.nn.Conv2d, "load_no_bias", load_conv2d_no_bias)
================================================
FILE: backends/gaudi/server/text_generation_server/layers/exl2.py
================================================
from dataclasses import dataclass
from typing import List, Union
import torch
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
@dataclass
class Exl2Weight(Weight):
    """
    Exllama2 (exl2) quantized weights.

    After construction, `q_scale_max` is divided by 256 in place and
    `q_invperm` is converted to int16.
    """

    q_weight: torch.Tensor
    q_scale: torch.Tensor
    q_invperm: torch.Tensor
    q_scale_max: torch.Tensor
    q_groups: torch.Tensor

    def __post_init__(self):
        # NOTE(review): presumably the exl2 kernel expects the stored scale
        # maxima pre-divided by 256 and an int16 permutation; confirm against
        # the kernel implementation.
        self.q_scale_max /= 256
        self.q_invperm = self.q_invperm.to(torch.int16)

    @property
    def device(self) -> torch.device:
        """Device on which the packed quantized weight lives."""
        return self.q_weight.device

    def get_linear(self, bias: torch.Tensor):
        """Wrap these weights in an `ExllamaQuantLinear` layer."""
        # NOTE(review): the visible part of `layers/gptq/__init__.py` on this
        # backend only imports `QuantLinear`; confirm `ExllamaQuantLinear` is
        # actually exported there.
        from text_generation_server.layers.gptq import ExllamaQuantLinear

        return ExllamaQuantLinear(self, bias)
class Exl2WeightsLoader(WeightsLoader):
    """Loader for exl2-quantized weights.

    Tensor parallelism is not supported. The column and row variants return
    the full, unsharded weights, and packed or multi-prefix column loading
    raises.
    """

    def get_weights(self, weights: "Weights", prefix: str):
        """
        Get weights at the given prefix and apply without tensor parallelism.

        Raises:
            RuntimeError: If `{prefix}.q_weight` cannot be loaded.
        """
        try:
            q_weight = weights.get_tensor(f"{prefix}.q_weight")
        except RuntimeError as e:
            raise RuntimeError(
                "Cannot load `exl2`-quantized weight, make sure the model is already quantized."
            ) from e

        q_scale = weights.get_tensor(f"{prefix}.q_scale")
        q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
        q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
        q_groups = weights.get_tensor(f"{prefix}.q_groups")

        return Exl2Weight(
            q_weight=q_weight,
            q_scale=q_scale,
            q_invperm=q_invperm,
            q_scale_max=q_scale_max,
            q_groups=q_groups,
        )

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        raise RuntimeError("Column-packed weights are not supported for exl2")

    def get_weights_col(self, weights: Weights, prefix: str):
        # Sharding is not yet supported, so we return the weights as-is.
        return self.get_weights(weights, prefix)

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        raise ValueError("get_multi_weights_col is not supported for exl2")

    def get_weights_row(self, weights: Weights, prefix: str):
        # Sharding is not yet supported, so we return the weights as-is.
        return self.get_weights(weights, prefix)
================================================
FILE: backends/gaudi/server/text_generation_server/layers/fp8.py
================================================
from dataclasses import dataclass
from typing import Optional, Tuple, Type, Union, List
import torch
from text_generation_server.utils.weights import (
Weight,
WeightsLoader,
UnquantizedWeight,
Weights,
)
from vllm_hpu_extension.ops import scaled_fp8_quant
from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
quant_dtype: torch.dtype = torch.float8_e4m3fn
# Largest FP8 magnitude used by `dynamic_quant` to derive scales. Gaudi2
# uses the range of float8_e4m3fnuz, other devices that of float8_e4m3fn.
FP8_MAX = torch.finfo(
    torch.float8_e4m3fnuz if is_hpu_gaudi2() else torch.float8_e4m3fn
).max
def pad_weight(weight, block_size):
    """
    Zero-pad the last two dimensions of `weight` up to multiples of `block_size`.

    Args:
        weight: Tensor with at least two dimensions.
        block_size: `(block_m, block_n)` pair.

    Returns:
        `(weight, M, N)` where `M`, `N` are the original sizes of the last two
        dimensions (needed to unpad later). If no padding is needed, the input
        tensor itself is returned.
    """
    rows, cols = weight.shape[-2:]
    block_m, block_n = block_size
    # Distance to the next multiple of the block size (0 when already aligned).
    extra_rows = -rows % block_m
    extra_cols = -cols % block_n
    if extra_rows or extra_cols:
        weight = torch.nn.functional.pad(
            weight, (0, extra_cols, 0, extra_rows), mode="constant", value=0
        )
    return weight, rows, cols
def unpad_weight(weight, original_M, original_N, keep_first_dim=False):
    """
    Crop `weight` back to `original_M x original_N` after `pad_weight`.

    If the last two dimensions already match, `weight` is returned unchanged.
    With `keep_first_dim=True` the leading dimension is preserved (3D case);
    otherwise the first two dimensions are sliced.
    """
    already_unpadded = (
        weight.shape[-2] == original_M and weight.shape[-1] == original_N
    )
    if already_unpadded:
        return weight
    if not keep_first_dim:
        return weight[:original_M, :original_N]
    return weight[:, :original_M, :original_N]
def pad_block_fp8_weight_naive(weight, weight_scale, block_size):
    """
    Pad `weight` to whole blocks and check the result against `weight_scale`.

    Returns:
        `(padded_weight, orig_M, orig_N)` as returned by `pad_weight`.

    Raises:
        AssertionError: If `block_size` is not a pair, or if the scale grid
            does not have one entry per block of the padded weight.
    """
    assert len(block_size) == 2
    block_m, block_n = block_size
    scale_rows, scale_cols = weight_scale.shape[-2:]

    padded, orig_M, orig_N = pad_weight(weight, block_size)
    rows, cols = padded.shape[-2:]

    assert scale_rows == rows // block_m
    assert scale_cols == cols // block_n
    return padded, orig_M, orig_N
def dynamic_quant(data, single_scale=False):
    """
    Quantize `data` to FP8 (float8_e4m3fn) with scales computed from the data.

    Args:
        data: Tensor to quantize.
        single_scale: If True, one scale for the whole tensor. Otherwise one
            scale per slice along the last dimension, kept with a trailing
            dimension of size 1.

    Returns:
        `(data_fp8, scale)` with `scale` converted to float32.
    """
    magnitude = torch.abs(data)
    # The small epsilon keeps the scale non-zero for all-zero inputs.
    if single_scale:
        scale = (magnitude.max() + 1e-8) / FP8_MAX
    else:
        scale = ((magnitude.max(dim=-1).values + 1e-8) / FP8_MAX).unsqueeze(-1)
    quantized = torch.ops.hpu.cast_to_fp8_v2(
        data, 1.0 / scale, False, False, torch.float8_e4m3fn
    )[0]
    return quantized, scale.float()
def dequant_block_fp8_weight_naive(
    weight,
    weight_scale,
    block_size,
    dtype=torch.bfloat16,
    original_M=None,
    original_N=None,
    do_unpad=False,
):
    """
    Dequantize a block-wise FP8 weight to `dtype`.

    Each `(block_m, block_n)` block of `weight` is multiplied by the matching
    entry of `weight_scale`. Both 2D weights and 3D weights (with a leading
    dimension shared with `weight_scale`) are supported.

    Args:
        weight: FP8 weight whose trailing dims are multiples of the block size.
        weight_scale: Per-block scales, or `None` to return `weight` as-is.
        block_size: `(block_m, block_n)` pair.
        dtype: Dtype of the dequantized result.
        original_M, original_N: Pre-padding sizes, used when `do_unpad` is set.
        do_unpad: Crop the result back to `original_M x original_N`.

    Raises:
        ValueError: If `weight` is neither 2D nor 3D.
    """
    if weight_scale is None:
        return weight

    assert len(block_size) == 2
    block_m, block_n = block_size

    def _apply_scale(blocks, scales):
        if is_hpu_gaudi2():
            # Gaudi2: convert through a CPU copy, then return to the device.
            widened = blocks.cpu().to(dtype).to(blocks.device)
        else:
            widened = blocks.to(dtype)
        return widened * scales.to(dtype)

    ndim = len(weight.shape)
    if ndim == 2:
        rows, cols = weight_scale.shape
        scales = weight_scale.view(rows, 1, cols, 1)
        blocks = weight.view(rows, block_m, cols, block_n)
        result = _apply_scale(blocks, scales).view(rows * block_m, cols * block_n)
        keep_first_dim = False
    elif ndim == 3:
        lead, rows, cols = weight_scale.shape
        scales = weight_scale.view(lead, rows, 1, cols, 1)
        blocks = weight.view(lead, rows, block_m, cols, block_n)
        result = _apply_scale(blocks, scales).view(
            lead, rows * block_m, cols * block_n
        )
        keep_first_dim = True
    else:
        raise ValueError("Only support original weight shape is either 2 or 3")

    if do_unpad:
        result = unpad_weight(
            result, original_M, original_N, keep_first_dim=keep_first_dim
        )
    return result
def apply_block_fp8_linear_hpu_dynamic(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    FP8 linear layer with dynamically quantized activations on HPU.

    The input is flattened to 2D and quantized with `dynamic_quant`. It is
    then multiplied with `weight` (`trans_B=True`) through
    `torch.ops.hpu.fp8_gemm_v2` with a bfloat16 output. `bias` is added, and
    the result is cast back to `input.dtype` with shape
    `(*input.shape[:-1], weight.shape[0])`.

    Args:
        input: Activations; the last dimension is the feature dimension.
        weight: FP8 weight.
        weight_scale: Scale(s) for `weight`, passed as the B scale.
        input_scale: Unused here; activation scales are always computed
            dynamically.
        bias: Optional bias added after the GEMM.
    """
    # `reshape` also accepts non-contiguous inputs (where `view` raises) and
    # is still zero-copy whenever `view` would have succeeded.
    input_2d = input.reshape(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

    x_fp8, x_scale = dynamic_quant(input_2d)

    output = torch.ops.hpu.fp8_gemm_v2(
        x_fp8,
        False,
        weight,
        True,
        None,
        torch.bfloat16,
        x_scale,
        weight_scale,
        None,
        False,
    )
    if bias is not None:
        output = output + bias
    return output.to(dtype=input.dtype).reshape(*output_shape)
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
    """
    Return an FP8 linear `Module` that is compatible with the current system.

    This backend has a single implementation, `Fp8Linear`, which is always
    returned. `force_w8a16` is accepted but currently ignored.
    """
    linear_cls = Fp8Linear
    return linear_cls
def normalize_e4m3fn_to_native_float8(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    No-op on this backend: `(weight, weight_scale, input_scale)` are
    returned unchanged as a tuple.
    """
    unchanged = (weight, weight_scale, input_scale)
    return unchanged
def per_tensor_dequantize(
    tensor: torch.Tensor,
    inv_scale: Union[float, torch.Tensor],
    dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """
    Dequantize `tensor` by converting it to bfloat16 and multiplying by
    `inv_scale`. The result lives on `tensor`'s original device.

    NOTE(review): the `dtype` argument is ignored; bfloat16 is always used.
    Confirm this is intentional before relying on `dtype`.
    """
    home_device = tensor.device
    # Dequant on CPU to avoid NaN on Gaudi2.
    source = tensor.to("cpu") if is_hpu_gaudi2() else tensor
    return source.to(torch.bfloat16).to(home_device) * inv_scale
def requantize_with_max_scale(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    logical_widths: List[int],
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Requantize an FP8 weight so all its logical parts share one scale.

    The rows of `weight` are split into consecutive slices of
    `logical_widths` rows. Each slice is dequantized with its own rows of
    `weight_scale` and requantized with the maximum of `weight_scale`
    (multiplied by the Gaudi2 scale factor on Gaudi2).

    Args:
        weight: FP8 weight. Modified in place (slice assignment) and also
            returned.
        weight_scale: Per-row scales, indexed as `weight_scale[start:end, :]`.
        logical_widths: Number of rows of each consecutive logical part.
        dtype: Forwarded to `per_tensor_dequantize`.

    Returns:
        `(weight, scale)`, where `scale` is the scale returned by
        `fp8_quantize` for the requantized slices.

    Raises:
        ValueError: If `logical_widths` is empty.
    """
    if len(logical_widths) == 0:
        # Without any slice there is no requantized scale to return.
        raise ValueError("`logical_widths` must contain at least one width")

    # Max scale to be used for requantization.
    max_w_scale = weight_scale.max()

    if is_hpu_gaudi2():
        max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()

    start = 0
    max_w_scale_normalized = None
    for logical_width in logical_widths:
        end = start + logical_width
        weight_dq = per_tensor_dequantize(
            weight[start:end, :], weight_scale[start:end, :], dtype
        )
        weight[start:end, :], max_w_scale_normalized = fp8_quantize(
            weight_dq, max_w_scale
        )
        start = end

    return weight, max_w_scale_normalized
def fp8_quantize(
    weight: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    scale_upper_bound: Optional[torch.Tensor] = None,
    qdtype: torch.dtype = torch.float8_e4m3fn,
    scalar: bool = False,
):
    """
    This function returns a reciprocal of the scale, so that a tensor can be unscaled
    by multiplying it with the returned scale. If a scale is given through the `scale`
    argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
    be used without modification).

    The tensor is flattened to 2D (keeping its last dimension) for
    `scaled_fp8_quant` and reshaped back afterwards. When no `scale` is
    given, per-token dynamic scaling is requested unless `scalar` is True.

    NOTE(review): `qdtype` is accepted but not used by this implementation.
    """
    original_shape = weight.shape
    flat = weight.reshape(-1, original_shape[-1])
    quantized, out_scale = scaled_fp8_quant(
        flat,
        scale=scale,
        scale_ub=scale_upper_bound,
        # TODO: don't do this when we have to use the Torch kernel.
        use_per_token_if_dynamic=not scalar,
    )
    return quantized.reshape(original_shape), out_scale
class HybridFP8UnquantLoader(WeightsLoader):
    """
    Weight loader that loads FP8 and unquantized Torch tensors.

    The dtype of the loaded weight selects the path:

    * `float8_e4m3fn` with `weight_block_size` set: block-wise FP8, using the
      `weight_scale_inv` tensor(s) as scale. `get_weights_col_packed` does
      not handle this case.
    * `float8_e4m3fn` otherwise: per-tensor FP8. The `weight_scale`
      tensor(s) are requantized to a single max scale, and any
      `input_scale` tensors are reduced to their maximum.
    * any other dtype: an `Fp8Weight` without scale when `to_fp8` is set
      (quantized when the linear layer is built), else `UnquantizedWeight`.
    """

    def __init__(
        self,
        activation_scale_ub: Optional[float],
        to_fp8: bool,
        weight_block_size: Optional[List[int]] = None,
    ):
        self.activation_scale_ub = activation_scale_ub
        self.to_fp8 = to_fp8
        self.weight_block_size = weight_block_size

    # ------------------------------------------------------------------
    # Private helpers shared by the public load paths.
    # ------------------------------------------------------------------

    def _non_fp8_weight(self, w, weights):
        """Result for weights that are not stored as FP8."""
        if self.to_fp8:
            return Fp8Weight(weight=w, dtype=weights.dtype)
        return UnquantizedWeight(w)

    def _block_fp8_weight(self, w, scale, weights):
        """Result for block-wise FP8 weights."""
        return Fp8Weight(
            weight=w,
            weight_scale=scale,
            activation_scale_ub=self.activation_scale_ub,
            dtype=weights.dtype,
            weight_block_size=self.weight_block_size,
        )

    def _per_tensor_fp8_weight(self, w, scale, input_scale, weights):
        """Result for per-tensor FP8 weights."""
        return Fp8Weight(
            weight=w,
            weight_scale=scale,
            input_scale=input_scale,
            activation_scale_ub=self.activation_scale_ub,
            dtype=weights.dtype,
        )

    @staticmethod
    def _requantize(w, per_row_scale, logical_widths, weights):
        """Requantize `w` to one max scale; returns `(w, scale)`."""
        return requantize_with_max_scale(
            w,
            per_row_scale.unsqueeze(-1).to(weights.device),
            logical_widths,
            weights.dtype,
        )

    @staticmethod
    def _packed_scale(weights, name, block_sizes):
        """Load a scale, sharding it like the packed weight if it is not scalar."""
        loaded = weights.get_tensor(name, to_dtype=False)
        if loaded.numel() > 1:
            loaded = weights.get_packed_sharded(
                name,
                dim=0,
                block_sizes=block_sizes,
                to_dtype=False,
            )
        return loaded

    @staticmethod
    def _optional_max_input_scale(weights, prefix):
        """Max of `{prefix}.input_scale`, or `None` when it is absent."""
        name = f"{prefix}.input_scale"
        if not weights.has_tensor(name):
            return None
        return weights.get_tensor(name, to_dtype=False).reshape(-1).max()

    @staticmethod
    def _reduce_input_scales(found, expected):
        """Max over the input scales of fused parts (all or none present)."""
        assert len(found) == 0 or len(found) == expected
        if not found:
            return None
        return torch.cat(found, dim=0).reshape(-1).max()

    # ------------------------------------------------------------------
    # WeightsLoader interface.
    # ------------------------------------------------------------------

    def get_weights(self, weights: "Weights", prefix: str):
        w = weights.get_tensor(f"{prefix}.weight")
        if w.dtype != torch.float8_e4m3fn:
            return self._non_fp8_weight(w, weights)

        if self.weight_block_size is not None:
            block_scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
            return self._block_fp8_weight(w, block_scale, weights)

        # Per-tensor FP8 branch.
        rows = w.shape[0]
        per_row = (
            weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            .reshape(-1)
            .expand(rows)
        )
        w, scale = self._requantize(w, per_row, [rows], weights)
        input_scale = self._optional_max_input_scale(weights, prefix)
        return self._per_tensor_fp8_weight(w, scale, input_scale, weights)

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        w = weights.get_packed_sharded(
            f"{prefix}.weight", dim=0, block_sizes=block_sizes
        )
        if w.dtype != torch.float8_e4m3fn:
            return self._non_fp8_weight(w, weights)

        # Per-tensor FP8 branch.
        rows = w.shape[0]
        per_row = self._packed_scale(weights, f"{prefix}.weight_scale", block_sizes)
        per_row = per_row.reshape(-1).expand(rows)
        w, scale = self._requantize(w, per_row, [rows], weights)

        input_scale = None
        if weights.has_tensor(f"{prefix}.input_scale"):
            input_scale = (
                self._packed_scale(weights, f"{prefix}.input_scale", block_sizes)
                .reshape(-1)
                .max()
            )

        return self._per_tensor_fp8_weight(w, scale, input_scale, weights)

    def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        parts = [
            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
        ]
        shapes = [part.shape for part in parts]
        # Concat then send to the device
        w = torch.cat(parts, dim=dim).to(weights.device)

        if w.dtype != torch.float8_e4m3fn:
            return self._non_fp8_weight(w, weights)

        if self.weight_block_size is not None:
            block_scales = [
                weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
                for p in prefixes
            ]
            block_scale = torch.cat(block_scales, dim=dim).to(weights.device)
            return self._block_fp8_weight(w, block_scale, weights)

        # Per-tensor FP8 branch.
        per_part = [
            _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
            for p, shape in zip(prefixes, shapes)
        ]
        stacked = torch.cat(per_part, dim=0).reshape(-1)
        w, scale = self._requantize(
            w, stacked, [shape[0] for shape in shapes], weights
        )

        found = [
            _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
            for p, shape in zip(prefixes, shapes)
            if weights.has_tensor(f"{p}.input_scale")
        ]
        input_scale = self._reduce_input_scales(found, len(prefixes))
        return self._per_tensor_fp8_weight(w, scale, input_scale, weights)

    def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
        parts = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
        shapes = [part.shape for part in parts]
        # Concat then send to the device
        w = torch.cat(parts, dim=dim).to(weights.device)

        if w.dtype != torch.float8_e4m3fn:
            return self._non_fp8_weight(w, weights)

        if self.weight_block_size is not None:
            block_scales = [
                weights.get_tensor(f"{p}.weight_scale_inv", to_device=False)
                for p in prefixes
            ]
            block_scale = torch.cat(block_scales, dim=dim).to(weights.device)
            return self._block_fp8_weight(w, block_scale, weights)

        # Per-tensor FP8 branch.
        per_part = [
            weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
            .reshape(-1)
            .expand(shape[0])
            for p, shape in zip(prefixes, shapes)
        ]
        stacked = torch.cat(per_part, dim=0).reshape(-1)
        w, scale = self._requantize(
            w, stacked, [shape[0] for shape in shapes], weights
        )

        found = [
            weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
            for p in prefixes
            if weights.has_tensor(f"{p}.input_scale")
        ]
        input_scale = self._reduce_input_scales(found, len(prefixes))
        return self._per_tensor_fp8_weight(w, scale, input_scale, weights)

    def get_weights_row(self, weights: "Weights", prefix: str):
        w = weights.get_sharded(f"{prefix}.weight", dim=1)
        if w.dtype != torch.float8_e4m3fn:
            return self._non_fp8_weight(w, weights)

        if self.weight_block_size is not None:
            # XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.
            block_scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
            return self._block_fp8_weight(w, block_scale, weights)

        # Per-tensor FP8 branch.
        rows = w.shape[0]
        per_row = (
            weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
            .reshape(-1)
            .expand(rows)
        )
        w, scale = self._requantize(w, per_row, [rows], weights)
        input_scale = self._optional_max_input_scale(weights, prefix)
        return self._per_tensor_fp8_weight(w, scale, input_scale, weights)
@dataclass
class Fp8Weight(Weight):
    """FP8 (or to-be-FP8) weight together with its quantization metadata."""

    # Weight tensor; FP8 when `weight_scale` is set, otherwise unquantized.
    weight: torch.Tensor
    # Model dtype forwarded to the linear layer.
    dtype: torch.dtype
    # Weight scale; `None` means the weight is quantized when building the layer.
    weight_scale: Optional[torch.Tensor] = None
    # Static activation scale, if the checkpoint provides one.
    input_scale: Optional[torch.Tensor] = None
    # Upper bound forwarded as `scale_upper_bound` to the linear layer.
    activation_scale_ub: Optional[float] = None
    force_w8a16: bool = False
    # Block shape for block-wise FP8 checkpoints.
    weight_block_size: Optional[List[int]] = None

    def get_linear(self, bias: torch.Tensor):
        """Build the FP8 linear layer for this weight."""
        linear_cls = get_fp8_linear(force_w8a16=self.force_w8a16)

        if self.weight_scale is None:
            return linear_cls.from_unquant(self.weight, bias, self.dtype)

        # This is not checked by the fbgemm kernels, but they require contiguous
        # memory. Can be non-contiguous when we e.g. expand from scalars.
        self.weight_scale = self.weight_scale.contiguous()
        return linear_cls.from_fp8(
            weight=self.weight,
            scale=self.weight_scale,
            dtype=self.dtype,
            bias=bias,
            input_scale=self.input_scale,
            scale_upper_bound=self.activation_scale_ub,
            weight_block_size=self.weight_block_size,
        )
class Fp8Linear(torch.nn.Module):
    """
    Linear layer backed by HPU FP8 GEMM kernels.

    `forward` uses dynamically quantized activations
    (`apply_block_fp8_linear_hpu_dynamic`) when `weight_block_size` is set or
    no static `input_scale` is known. Otherwise it casts the input to FP8
    with the static scale and calls `torch.ops.hpu.fp8_gemm_v2` directly.
    """

    _device_identity_cache = {}

    def __init__(
        self,
        qweight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        input_scale: Optional[torch.Tensor] = None,
        scale_upper_bound: Optional[float] = None,
        weight_block_size: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        self.dtype = dtype
        self.qweight = qweight
        # Scales are stored as float32.
        self.scale = scale.float()
        self.input_scale = None if input_scale is None else input_scale.float()
        self.weight_block_size = weight_block_size
        self.scale_upper_bound = scale_upper_bound
        self.bias = bias

    @classmethod
    def from_unquant(cls, weight, bias, dtype):
        """Quantize an unquantized `weight` (with `scalar=True`) and wrap it."""
        qweight, scale = fp8_quantize(weight, scalar=True)
        return cls(
            qweight=qweight,
            scale=scale,
            dtype=dtype,
            bias=bias,
            input_scale=None,
            scale_upper_bound=None,
        )

    @classmethod
    def from_fp8(
        cls,
        weight: torch.Tensor,
        scale: torch.Tensor,
        dtype: torch.dtype,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> "Fp8Linear":
        """
        Wrap an already-FP8 `weight`.

        Block-wise weights (`weight_block_size` in `kwargs`) are padded,
        dequantized, cropped and requantized with `dynamic_quant`, which
        gives one scale per row.
        """
        input_scale = kwargs.get("input_scale", None)
        scale_upper_bound = kwargs.get("scale_upper_bound", None)
        weight_block_size = kwargs.get("weight_block_size", None)

        if weight_block_size is not None:
            padded, orig_M, orig_N = pad_block_fp8_weight_naive(
                weight, scale, weight_block_size
            )
            dequantized = dequant_block_fp8_weight_naive(
                padded,
                scale,
                weight_block_size,
                original_M=orig_M,
                original_N=orig_N,
                do_unpad=True,
            )
            weight, scale = dynamic_quant(dequantized)
            scale = scale.squeeze(-1)

        return cls(
            qweight=weight,
            scale=scale,
            input_scale=input_scale,
            scale_upper_bound=scale_upper_bound,
            bias=bias,
            dtype=dtype,
            weight_block_size=weight_block_size,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        use_dynamic = self.weight_block_size is not None or self.input_scale is None
        if use_dynamic:
            return apply_block_fp8_linear_hpu_dynamic(
                input, self.qweight, self.scale, self.input_scale, self.bias
            )

        # Static path: quantize the input with the known activation scale.
        x_fp8 = torch.ops.hpu.cast_to_fp8_v2(
            input, 1.0 / self.input_scale, False, False, torch.float8_e4m3fn
        )[0]
        return torch.ops.hpu.fp8_gemm_v2(
            A=x_fp8,
            trans_A=False,
            B=self.qweight,
            trans_B=True,
            D=None,
            out_dtype=input.dtype,
            A_scale_inv=self.input_scale,
            B_scale_inv=self.scale,
            bias=self.bias,
            accumulate=False,
        )
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
    """
    Load the scale tensor `prefix` and broadcast it to `shape[0]` entries.

    A scale with at most one element is used as-is. Larger scales are
    re-read sharded along dim 0 (presumably to match a weight sharded the
    same way). The result is flattened and expanded to `shape[0]` elements.
    """
    full = weights.get_tensor(prefix, to_dtype=False)
    is_scalar = full.numel() <= 1
    scale = full if is_scalar else weights.get_sharded(prefix, dim=0, to_dtype=False)
    return scale.reshape(-1).expand(shape[0])
================================================
FILE: backends/gaudi/server/text_generation_server/layers/gptq/__init__.py
================================================
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
from loguru import logger
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
Weight,
Weights,
WeightsLoader,
DefaultWeightsLoader,
)
from .hpu import QuantLinear
@dataclass
class GPTQWeight(Weight):
    """
    GPTQ/AWQ-packed quantized weight.

    `get_linear` builds an AWQ `WQLinear` when `use_awq_kernel` is set, and
    the HPU `QuantLinear` otherwise. `use_exllama` is stored but not used by
    `get_linear` on this backend.
    """

    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: Optional[torch.Tensor]
    bits: int
    groupsize: int
    # Defaults let callers omit these flags: `GPTQWeightsLoader.get_weights`
    # constructs this class without `use_awq_kernel`, which would otherwise
    # raise a TypeError. Field order (and so positional use) is unchanged.
    use_awq_kernel: bool = False
    use_exllama: bool = False

    def __post_init__(self):
        # float32 scales are converted to float16.
        if self.scales.dtype == torch.float:
            self.scales = self.scales.half()

    @property
    def device(self) -> torch.device:
        """Device on which the packed quantized weight lives."""
        return self.qweight.device

    def get_linear(self, bias: torch.Tensor):
        """
        Build the quantized linear layer for this weight.

        Raises:
            NotImplementedError: If `use_awq_kernel` is set but the AWQ
                module cannot be imported.
        """
        if self.use_awq_kernel:
            try:
                from text_generation_server.layers.awq.quantize import WQLinear

                return WQLinear(
                    w_bit=self.bits,
                    group_size=self.groupsize,
                    qweight=self.qweight,
                    qzeros=self.qzeros,
                    scales=self.scales,
                    bias=bias,
                )
            except ImportError as e:
                raise NotImplementedError(
                    "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `--quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
                ) from e
        else:
            return QuantLinear(
                self.qweight,
                self.qzeros,
                self.scales,
                self.g_idx,
                bias,
                self.bits,
                self.groupsize,
            )
class GPTQWeightsLoader(WeightsLoader):
    """
    Loader for GPTQ- and AWQ-quantized weights.

    When `quantize == "gptq"` but the checkpoint was produced with AWQ
    (`quant_method == "awq"`), the packed `qweight`/`qzeros` are converted to the
    GPTQ packing format while loading.
    """

    def __init__(
        self,
        *,
        bits: int,
        desc_act: bool,
        groupsize: int,
        quant_method: str,
        quantize: str,
        sym: bool,
        modules_to_not_convert: List[str],
    ):
        self.bits = bits
        self.desc_act = desc_act
        self.groupsize = groupsize
        self.quant_method = quant_method
        self.quantize = quantize
        self.sym = sym
        self.modules_to_not_convert = modules_to_not_convert

    def is_layer_skipped_quantization(
        self, prefix: str, modules_to_not_convert: List[str]
    ):
        """Return True if any entry of `modules_to_not_convert` occurs in `prefix`."""
        return any(module_name in prefix for module_name in modules_to_not_convert)

    def _unquantized_loader(self):
        """Loader used for layers listed in `modules_to_not_convert`.

        `DefaultWeightsLoader` methods are instance methods, so they must be called
        on an instance (calling them on the class binds `weights` as `self`).
        """
        from text_generation_server.utils.weights import UnquantizedWeight

        return DefaultWeightsLoader(UnquantizedWeight)

    def _awq_to_gptq(self, qweight: torch.Tensor, qzeros: torch.Tensor):
        """Repack AWQ-packed `qweight`/`qzeros` into the GPTQ packing format."""
        log_once(logger.info, "Converting AWQ model to Exllama/GPTQ packing format.")
        from text_generation_server.layers.awq.conversion_utils import (
            fast_awq_to_gptq,
        )

        return fast_awq_to_gptq(qweight, qzeros)

    def _awq_g_idx(self, qweight: torch.Tensor) -> torch.Tensor:
        """Trivial group index for a converted AWQ weight: channel i -> i // groupsize.

        `qweight.shape[0] * (32 // bits)` is the number of input channels packed
        into the GPTQ-format `qweight`.
        """
        return (
            torch.arange(
                qweight.shape[0] * (32 // self.bits),
                device=qweight.device,
            )
            // self.groupsize
        ).to(dtype=torch.int32)

    def get_weights(self, weights: Weights, prefix: str):
        """Load the (unsharded) quantized weight stored under `prefix`."""
        self._get_gptq_params(weights)

        use_exllama = True
        if self.bits != 4:
            use_exllama = False

        if self.desc_act:
            log_once(logger.warning, "Disabling exllama because desc_act=True")
            use_exllama = False

        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
            return self._unquantized_loader().get_weights(weights, prefix)

        try:
            qweight = weights.get_tensor(f"{prefix}.qweight")
        except RuntimeError as err:
            raise RuntimeError(
                "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
            ) from err

        if self.quantize == "gptq" and self.quant_method == "gptq":
            g_idx = weights.get_tensor(f"{prefix}.g_idx")
        else:
            g_idx = None

        qzeros = weights.get_tensor(f"{prefix}.qzeros")
        scales = weights.get_tensor(f"{prefix}.scales")

        if use_exllama and g_idx is not None:
            g_idx = g_idx - g_idx[0]

        if self.quantize == "gptq" and self.quant_method == "awq":
            qweight, qzeros = self._awq_to_gptq(qweight, qzeros)
            g_idx = None if use_exllama else self._awq_g_idx(qweight)

        return GPTQWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            bits=self.bits,
            groupsize=self.groupsize,
            # Previously missing: GPTQWeight has no default for this field.
            use_awq_kernel=self.quantize == "awq",
            use_exllama=use_exllama,
        )

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        """Load a column-sharded weight whose output dim packs several blocks."""
        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
            return self._unquantized_loader().get_weights_col_packed(
                weights, prefix, block_sizes
            )
        try:
            qweight = weights.get_packed_sharded(
                f"{prefix}.qweight", dim=1, block_sizes=block_sizes
            )
        except RuntimeError as err:
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
            ) from err
        scales = weights.get_packed_sharded(
            f"{prefix}.scales", dim=1, block_sizes=block_sizes
        )
        scales = scales.to(dtype=weights.dtype)

        self._get_gptq_params(weights)

        qzeros = weights.get_packed_sharded(
            f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
        )
        if self.quantize == "gptq" and self.quant_method == "gptq":
            g_idx = weights.get_tensor(f"{prefix}.g_idx")
        elif self.quantize == "gptq" and self.quant_method == "awq":
            qweight, qzeros = self._awq_to_gptq(qweight, qzeros)
            g_idx = self._awq_g_idx(qweight)
        else:
            g_idx = None

        return GPTQWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            bits=self.bits,
            groupsize=self.groupsize,
            use_awq_kernel=self.quantize == "awq",
            use_exllama=False,
        )

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        """Load and concatenate (along dim 1) column-sharded weights of `prefixes`."""
        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
            return self._unquantized_loader().get_multi_weights_col(
                weights, prefixes, dim
            )
        try:
            qweight = torch.cat(
                [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
            )
        except RuntimeError as err:
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
            ) from err

        scales = torch.cat(
            [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
        )

        self._get_gptq_params(weights)

        qzeros = torch.cat(
            [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
        )

        use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act

        if self.quantize == "gptq" and self.quant_method == "gptq":
            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
            # All fused projections must share the same channel -> group mapping.
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]
        elif self.quantize == "gptq" and self.quant_method == "awq":
            qweight, qzeros = self._awq_to_gptq(qweight, qzeros)
            g_idx = None if use_exllama else self._awq_g_idx(qweight)
        else:
            g_idx = None

        return GPTQWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            bits=self.bits,
            groupsize=self.groupsize,
            use_awq_kernel=self.quantize == "awq",
            use_exllama=use_exllama,
        )

    def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
        """Load and concatenate (along dim 1) the full weights of `prefixes`."""
        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
            return self._unquantized_loader().get_multi_weights(weights, prefixes, dim)
        try:
            qweight = torch.cat(
                [weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1
            )
        except RuntimeError as err:
            raise RuntimeError(
                f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
            ) from err

        scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1)

        self._get_gptq_params(weights)

        qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1)

        use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act

        if self.quantize == "gptq" and self.quant_method == "gptq":
            w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
            for w2 in w[1:]:
                torch.testing.assert_close(w2, w[0])
            g_idx = w[0]
        elif self.quantize == "gptq" and self.quant_method == "awq":
            qweight, qzeros = self._awq_to_gptq(qweight, qzeros)
            # Uses the shared helper, which divides by groupsize (previously omitted
            # here, mapping every input channel to its own group).
            g_idx = None if use_exllama else self._awq_g_idx(qweight)
        else:
            g_idx = None

        return GPTQWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            bits=self.bits,
            groupsize=self.groupsize,
            use_awq_kernel=self.quantize == "awq",
            use_exllama=use_exllama,
        )

    def get_weights_row(self, weights: Weights, prefix: str):
        """Load a row-parallel weight (qweight sharded along dim 0)."""
        self._get_gptq_params(weights)

        use_exllama = True
        desc_act = self.desc_act
        if self.bits != 4:
            use_exllama = False

        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):
            return self._unquantized_loader().get_weights_row(weights, prefix)

        if self.desc_act:
            log_once(logger.warning, "Disabling exllama because desc_act=True")
            use_exllama = False

        try:
            qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
        except RuntimeError as err:
            raise RuntimeError(
                "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
            ) from err

        if self.quantize == "gptq" and self.quant_method == "gptq":
            g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
        else:
            g_idx = None

        if weights.process_group.size() > 1:
            if g_idx is not None:
                if (
                    not torch.equal(
                        # Remove g_idx[0] to adapt the check with TP>1.
                        (g_idx - g_idx[0]).cpu(),
                        torch.tensor(
                            [i // self.groupsize for i in range(g_idx.shape[0])],
                            dtype=torch.int32,
                        ),
                    )
                    and not (g_idx == 0).all()
                ):
                    # Exllama implementation does not support row tensor parallelism with act-order, as
                    # it would require to reorder input activations that are split unto several GPUs
                    use_exllama = False
                    desc_act = True

        if not desc_act and self.groupsize != -1:
            qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
            scales = weights.get_sharded(f"{prefix}.scales", dim=0)
            if g_idx is not None:
                # qzeros, scales sharded, and g_idx must be adjusted accordingly
                g_idx = g_idx - g_idx[0]
        else:
            qzeros = weights.get_tensor(f"{prefix}.qzeros")
            scales = weights.get_tensor(f"{prefix}.scales")

        if self.quantize == "gptq" and self.quant_method == "awq":
            qweight, qzeros = self._awq_to_gptq(qweight, qzeros)
            g_idx = None if use_exllama else self._awq_g_idx(qweight)

        return GPTQWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            bits=self.bits,
            groupsize=self.groupsize,
            use_awq_kernel=self.quantize == "awq",
            use_exllama=use_exllama,
        )

    def _get_gptq_params(self, weights: Weights):
        """Override bits/groupsize/sym from `gptq_*` tensors stored in the checkpoint.

        When both `gptq_bits` and `gptq_groupsize` tensors exist, they take
        precedence; `desc_act` is forced to False and `quant_method` to "gptq".
        """
        if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
            self.bits = weights.get_tensor("gptq_bits").item()
            self.groupsize = weights.get_tensor("gptq_groupsize").item()
            self.desc_act = False
            # `server quantize` used asymmetric quantization unconditionally
            # before the `gptq_sym` setting tensor was added.
            self.sym = (
                weights.get_tensor("gptq_sym").item()
                if weights.has_tensor("gptq_sym")
                else False
            )
            self.quant_method = "gptq"
================================================
FILE: backends/gaudi/server/text_generation_server/layers/gptq/hpu.py
================================================
import math
import numpy as np
import torch
import torch.nn as nn
# Resolve the HPU dequantization op once at import time. Without the HPU stack,
# install a stub with the same name that reports the original import failure
# only when it is actually called.
try:
    convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as import_failure:
    hpu_import_exception = import_failure

    def error_raiser_hpu(*args, **kwargs):
        """Stand-in for `convert_from_uint4`; always raises ValueError."""
        raise ValueError(
            f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
        )

    convert_from_uint4 = error_raiser_hpu
def pack_tensor(input, bits=4):
normal = input.to(torch.int32)
q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
i = 0
col = 0
while col < q.shape[1]:
for j in range(i, i + (32 // bits)):
q[:, col] |= normal[:, j] << (bits * (j - i))
i += 32 // bits
col += 1
q = q.to(torch.int32)
return q
class QuantLinear(nn.Module):
    """4-bit GPTQ linear layer evaluated with the HPU `convert_from_uint4` op.

    At construction, the GPTQ-packed `qweight`/`qzeros` are unpacked and re-packed
    with `pack_tensor`. If `g_idx` is not the trivial `i // groupsize` mapping,
    zeros and scales are expanded to one row per input feature (groupsize 1).
    `forward` dequantizes the weight with `convert_from_uint4` and runs a matmul.
    """

    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
        super().__init__()
        self.register_buffer("qweight", qweight)
        self.register_buffer("qzeros", qzeros)
        self.register_buffer("scales", scales)
        self.register_buffer("g_idx", g_idx)
        if bias is not None:
            self.register_buffer("bias", bias)
        else:
            self.bias = None
        if bits not in [4]:
            raise NotImplementedError("Only 4 bits are supported.")
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize

        self.outfeatures = qweight.shape[1]
        self.infeatures = qweight.shape[0] * 32 // bits
        # Bit offsets of each packed value inside an int32 word: 0, bits, 2*bits, ...
        self.wf = torch.tensor(
            list(range(0, 32, self.bits)), dtype=torch.int32
        ).unsqueeze(0)
        self._preprocessing()

    def unpack_zeros_from_cuda_old_format(self):
        """Unpack `qzeros` (packed along dim 1) to per-group zeros in scales.dtype."""
        zeros = torch.bitwise_right_shift(
            torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
            self.wf.unsqueeze(0),
        ).to(torch.int16 if self.bits == 8 else torch.int8)

        zeros = zeros + 1
        zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to(
            self.scales.dtype
        )  # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
        zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
        return zeros

    def unpack_weight_from_cuda_old_format(self):
        """Unpack `qweight` (packed along dim 0) to one integer row per input feature."""
        weight = torch.bitwise_right_shift(
            torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
            self.wf.unsqueeze(-1),
        ).to(torch.int16 if self.bits == 8 else torch.int8)
        weight = torch.bitwise_and(weight, (2**self.bits) - 1)
        weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
        return weight

    def _preprocessing(self):
        """Re-pack weights/zeros for the HPU op and flatten non-trivial g_idx."""
        orig_device = self.qweight.device
        self.qweight = self.qweight.cpu()
        weight = self.unpack_weight_from_cuda_old_format()
        new_qweight = pack_tensor(weight)
        self.qweight = new_qweight.to(orig_device)
        # TODO: Support group indexing and remove the check
        columns = self.qweight.shape[0]
        if self.g_idx is None:
            # No g_idx supplied (e.g. AWQ->GPTQ conversion with exllama enabled):
            # treat it as the trivial channel -> group mapping.
            sort_zeros = False
        else:
            g_idx_trivial = [i // self.groupsize for i in range(columns)]
            g_idx_trivial = torch.tensor(
                g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
            )
            sort_zeros = not (torch.equal(self.g_idx, g_idx_trivial))
        self.qzeros = self.qzeros.cpu()
        zeros = self.unpack_zeros_from_cuda_old_format()
        if sort_zeros:
            # Expand per-group zeros/scales to one row per input feature so the
            # layer can run with groupsize 1 regardless of the g_idx permutation.
            # Vectorized gather equivalent to the per-row copy loop.
            row_groups = self.g_idx[: self.infeatures].cpu().long()
            zeros_group_1 = zeros[row_groups]
            scale_group_1 = self.scales.cpu()[row_groups]
            self.qzeros = pack_tensor(zeros_group_1).to(orig_device)
            self.scales = scale_group_1.to(orig_device)
            self.groupsize = 1
            self.g_idx = None
        else:
            new_qzeros = pack_tensor(zeros)
            self.qzeros = new_qzeros.to(orig_device)

    @classmethod
    def new(cls, bits, groupsize, infeatures, outfeatures, bias):
        """Create a zero-initialized layer with trivial g_idx (float16 scales/bias)."""
        if bits not in [4]:
            raise NotImplementedError("Only 4 bits are supported.")

        qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
        qzeros = torch.zeros(
            (math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
            dtype=torch.int32,
        )
        scales = torch.zeros(
            (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
        )
        g_idx = torch.tensor(
            [i // groupsize for i in range(infeatures)], dtype=torch.int32
        )
        if bias:
            bias = torch.zeros((outfeatures), dtype=torch.float16)
        else:
            bias = None
        return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)

    def pack(self, linear, scales, zeros, g_idx=None):
        """Quantize `linear.weight` with the given scales/zeros and store it packed.

        `scales`/`zeros` are transposed before use (one row per group afterwards).
        The resulting `qweight` and `qzeros` are CPU int32 tensors.
        """
        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        intweight = []
        for idx in range(self.infeatures):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
                    / self.scales[self.g_idx[idx]]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        qweight = np.zeros(
            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
        )
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [4]:
                for j in range(i, i + (32 // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += 32 // self.bits
                row += 1
            else:
                raise NotImplementedError("Only 4 bits are supported.")

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

        # Out-of-place: `.t().contiguous()` may return the caller's own storage
        # (e.g. a single-group tensor), which an in-place `-=` would mutate.
        zeros = zeros - 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros(
            (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
        )
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [4]:
                for j in range(i, i + (32 // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += 32 // self.bits
                col += 1
            else:
                raise NotImplementedError("Only 4 bits are supported.")

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)

    def forward(self, x):
        """Dequantize the weight with `convert_from_uint4` and apply x @ W (+ bias)."""
        out_shape = x.shape[:-1] + (self.outfeatures,)
        x = x.reshape(-1, x.shape[-1])
        weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
        out = torch.matmul(x, weight)
        out = out.reshape(out_shape)
        out = out + self.bias if self.bias is not None else out
        return out
================================================
FILE: backends/gaudi/server/text_generation_server/layers/gptq/quantize.py
================================================
import time
import torch.nn as nn
import math
import json
import os
import torch
import transformers
from texttable import Texttable
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
from huggingface_hub import HfApi
from accelerate import init_empty_weights
from text_generation_server.utils import initialize_torch_distributed, Weights
from text_generation_server.utils.hub import weight_files
from text_generation_server.layers.gptq import QuantLinear
from loguru import logger
from typing import Optional
from text_generation_server.layers.gptq.utils import torch_snr_error
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
# Module-level CUDA device constant (first GPU).
DEV = torch.device("cuda:0")
class Quantizer(nn.Module):
    """Fake-quantizer used by GPTQ.

    `find_params` derives a scale and zero point (per tensor or per channel) and
    `quantize` rounds values onto that integer grid, returning dequantized floats.
    Buffers: `maxq` (largest integer level, -1 in ternary mode), `scale`, `zero`.
    """

    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer("maxq", torch.tensor(0))
        self.register_buffer("scale", torch.zeros(shape))
        self.register_buffer("zero", torch.zeros(shape))

    def configure(
        self,
        bits,
        perchannel=False,
        sym=True,
        mse=False,
        norm=2.4,
        grid=100,
        maxshrink=0.8,
        trits=False,
    ):
        """Set quantization hyper-parameters.

        Args:
            bits: bit width; integer levels are 0 .. 2**bits - 1.
            perchannel: one scale/zero per row instead of one per tensor.
            sym: symmetric range around zero (zero point fixed at the midpoint).
            mse: search shrunken ranges minimizing the `norm`-power error.
            norm, grid, maxshrink: parameters of that search.
            trits: ternary mode, signalled by `maxq == -1`.
        """
        self.maxq = torch.tensor(2**bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            self.maxq = torch.tensor(-1)
        # Zero scale makes ready() False until find_params() runs.
        self.scale = torch.zeros_like(self.scale)

    def _quantize(self, x, scale, zero, maxq):
        """Round `x` onto the grid defined by scale/zero and return dequantized values."""
        if maxq < 0:
            # Ternary mode: `scale` holds xmax and `zero` holds xmin (see find_params).
            return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)

    def find_params(self, x, weight=False):
        """Compute `self.scale`/`self.zero` from the range of `x`.

        With `weight=True` the results are reshaped to (-1, 1, ...) so they
        broadcast over the rows of a weight matrix; otherwise they are reshaped to
        broadcast over the last (or channel) dimension of an activation.
        """
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                # Move the channel dimension first so each row is one channel.
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        # The range always includes 0.
        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        # Avoid a zero-width range (and a zero scale) for all-zero rows.
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        if self.maxq < 0:
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            # Try ranges shrunk by factor p and keep, per row, the one with the
            # lowest sum of |dequant(x) - x| ** norm.
            best = torch.full([x.shape[0]], float("inf"), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = self._quantize(
                    x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq
                )
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            # Broadcast the single per-tensor value to every channel.
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        """Fake-quantize `x` if parameters are ready; otherwise return `x` unchanged."""
        if self.ready():
            return self._quantize(x, self.scale, self.zero, self.maxq)

        return x

    def enabled(self):
        """True when a positive integer range is configured."""
        return self.maxq > 0

    def ready(self):
        """True once every scale entry is non-zero (i.e. find_params has run)."""
        return torch.all(self.scale != 0)
class GPTQ:
    """GPTQ quantization state for a single layer.

    `add_batch` accumulates a (columns x columns) matrix `H` from layer inputs;
    `fasterquant` uses it to quantize the layer weight column by column,
    propagating each column's rounding error to the remaining columns.
    """

    def __init__(self, layer, observe=False):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        # Bring the weight to a 2-D (rows, columns) view.
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quantizer = Quantizer()
        # When True, the last input/output are kept so print_loss can report SNRs.
        self.observe = observe

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into the running average of 2 * X X^T."""
        # Hessian H = 2 X XT + λ I
        # (the λ I damping term is added later, in fasterquant via `percdamp`).
        if self.observe:
            self.inp1 = inp
            self.out1 = out
        else:
            self.inp1 = None
            self.out1 = None

        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(
            self.layer, transformers.Conv1D
        ):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        # Rescale the previous average before adding the new samples.
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # inp = inp.float()
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H += inp.matmul(inp.t())

    def print_loss(self, name, q_weight, weight_error, timecost):
        """Write `q_weight` into the layer and print one row of the report table.

        SNR columns are computed only when inputs were observed; otherwise "-".
        """
        table = Texttable()
        length = 28
        name = (
            (name + " " * (length - len(name)))
            if len(name) <= length
            else name[:length]
        )

        table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])

        # assign weight
        self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
            self.layer.weight.data.dtype
        )

        if self.inp1 is not None:
            # quantize input to int8
            quantizer = Quantizer()
            quantizer.configure(8, perchannel=False, sym=True, mse=False)
            quantizer.find_params(self.inp1)
            q_in = quantizer.quantize(self.inp1).type(torch.float16)
            q_out = self.layer(q_in)

            # get kinds of SNR
            q_SNR = torch_snr_error(q_out, self.out1).item()
            fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
        else:
            q_SNR = "-"
            fp_SNR = "-"

        table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
        # Print only the data row of the rendered table.
        print(table.draw().split("\n")[-2])

    def fasterquant(
        self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=""
    ):
        """Quantize the layer weight in place with GPTQ.

        Args:
            blocksize: number of columns processed per block.
            percdamp: damping added to the diagonal, as a fraction of its mean.
            groupsize: columns per quantization group (-1: one group for all).
            act_order: process columns in descending order of diag(H).
            name: label used in the printed report.

        Returns:
            (scale, zero, g_idx, error): per-group scales and zeros concatenated
            along dim 1, the int32 column -> group index, and the summed loss.
        """
        self.layer.to(self.dev)

        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        if not self.observe:
            del self.H
        # Columns whose inputs were always zero: pin them to zero.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if act_order:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        # H <- upper Cholesky factor of H^-1.
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        try:
            H = torch.linalg.cholesky(H, upper=True)
        except Exception:
            # Addition because Falcon fails on h_to_4h
            H = torch.linalg.cholesky(
                H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True
            )
        Hinv = H

        g_idx = []
        scale = []
        zero = []
        now_idx = 1

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    # New group: recompute scale/zero from the (partially
                    # error-updated) weights of this group.
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(
                            W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
                        )

                    # Record each group's parameters exactly once.
                    if ((i1 + i) // groupsize) - now_idx == -1:
                        scale.append(self.quantizer.scale)
                        zero.append(self.quantizer.zero)
                        now_idx += 1

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d**2

                # Spread this column's error over the remaining block columns.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            # Lazy batch update of all columns after this block.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        torch.cuda.synchronize()
        error = torch.sum(Losses).item()

        groupsize = groupsize if groupsize != -1 else self.columns
        g_idx = [i // groupsize for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if act_order:
            # Undo the column permutation for both the weights and g_idx.
            invperm = torch.argsort(perm)
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()

        self.print_loss(
            name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
        )

        # Without grouping, the single set of quantizer parameters is returned.
        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx, error

    def free(self):
        """Drop references to accumulated tensors and empty the CUDA cache."""
        self.inp1 = None
        self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()
def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):
    """Build WikiText-2 calibration data.

    Returns ``(trainloader, testenc)``: ``trainloader`` is a list of ``nsamples``
    ``(input_ids, targets)`` windows of ``seqlen`` tokens sampled (seeded by
    ``seed``) from the tokenized train split, with every target except the last
    set to -100; ``testenc`` is the tokenized test split.
    """
    import random

    from datasets import load_dataset

    train_split = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    test_split = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    # Prefer the slow tokenizer; fall back to the fast one if it cannot load.
    try:
        tok = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        tok = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
    train_ids = tok("\n\n".join(train_split["text"]), return_tensors="pt")
    testenc = tok("\n\n".join(test_split["text"]), return_tensors="pt")

    random.seed(seed)
    last_start = train_ids.input_ids.shape[1] - seqlen - 1
    trainloader = []
    for _ in range(nsamples):
        start = random.randint(0, last_start)
        window = train_ids.input_ids[:, start : start + seqlen]
        target = window.clone()
        target[:, :-1] = -100
        trainloader.append((window, target))
    return trainloader, testenc
def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):
    """Build Penn Treebank calibration data.

    Returns ``(trainloader, testenc)``: ``trainloader`` is a list of ``nsamples``
    ``(input_ids, targets)`` windows of ``seqlen`` tokens sampled (seeded by
    ``seed``) from the tokenized train split, with every target except the last
    set to -100; ``testenc`` is the tokenized *validation* split.
    """
    import random

    from datasets import load_dataset

    train_split = load_dataset("ptb_text_only", "penn_treebank", split="train")
    val_split = load_dataset("ptb_text_only", "penn_treebank", split="validation")

    # Prefer the slow tokenizer; fall back to the fast one if it cannot load.
    try:
        tok = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        tok = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
    train_ids = tok("\n\n".join(train_split["sentence"]), return_tensors="pt")
    testenc = tok("\n\n".join(val_split["sentence"]), return_tensors="pt")

    random.seed(seed)
    last_start = train_ids.input_ids.shape[1] - seqlen - 1
    trainloader = []
    for _ in range(nsamples):
        start = random.randint(0, last_start)
        window = train_ids.input_ids[:, start : start + seqlen]
        target = window.clone()
        target[:, :-1] = -100
        trainloader.append((window, target))
    return trainloader, testenc
def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):
    """Build C4 calibration data.

    Returns ``(trainloader, valenc)``.
    - ``trainloader``: a list of ``nsamples`` ``(input_ids, targets)`` windows of
      ``seqlen`` tokens, each taken from a random training document, with every
      target except the last set to -100.
    - ``valenc``: a wrapper whose ``input_ids`` concatenates 256 validation
      windows, sampled with a fixed seed of 0.

    Only documents strictly longer than ``seqlen`` tokens are used. The window start
    is drawn from ``[0, len - seqlen - 1]``, which is empty when ``len == seqlen``;
    such documents previously made ``random.randint`` raise ``ValueError``.
    """
    import random

    from datasets import load_dataset

    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
        use_auth_token=False,
    )
    valdata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split="validation",
        use_auth_token=False,
    )

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    # Validation windows use a fixed seed, independent of `seed`.
    random.seed(0)
    valenc = []
    for _ in range(256):
        while True:
            i = random.randint(0, len(valdata) - 1)
            tmp = tokenizer(valdata[i]["text"], return_tensors="pt")
            if tmp.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        valenc.append(tmp.input_ids[:, i:j])
    valenc = torch.hstack(valenc)

    class TokenizerWrapper:
        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc
def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):
    """Build Penn Treebank calibration data (space-joined sentences, test split).

    Returns ``(trainloader, testenc)``: ``trainloader`` is a list of ``nsamples``
    ``(input_ids, targets)`` windows of ``seqlen`` tokens sampled (seeded by
    ``seed``) from the tokenized train split, with every target except the last
    set to -100; ``testenc`` is the tokenized *test* split.
    """
    import random

    from datasets import load_dataset

    train_split = load_dataset("ptb_text_only", "penn_treebank", split="train")
    test_split = load_dataset("ptb_text_only", "penn_treebank", split="test")

    # Prefer the slow tokenizer; fall back to the fast one if it cannot load.
    try:
        tok = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        tok = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
    train_ids = tok(" ".join(train_split["sentence"]), return_tensors="pt")
    testenc = tok(" ".join(test_split["sentence"]), return_tensors="pt")

    random.seed(seed)
    last_start = train_ids.input_ids.shape[1] - seqlen - 1
    trainloader = []
    for _ in range(nsamples):
        start = random.randint(0, last_start)
        window = train_ids.input_ids[:, start : start + seqlen]
        target = window.clone()
        target[:, :-1] = -100
        trainloader.append((window, target))
    return trainloader, testenc
def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
    """Build C4 calibration data (validation = first 1100 documents).

    Returns ``(trainloader, valenc)``.
    - ``trainloader``: ``nsamples`` ``(input_ids, targets)`` windows of ``seqlen``
      tokens from random training documents, with every target except the last
      set to -100.
    - ``valenc``: a wrapper whose ``input_ids`` holds the first ``256 * seqlen``
      tokens of the space-joined first 1100 validation documents.

    Training documents must be strictly longer than ``seqlen`` tokens. The window
    start is drawn from ``[0, len - seqlen - 1]``, which is empty when
    ``len == seqlen`` and previously made ``random.randint`` raise ``ValueError``.
    """
    import random

    from datasets import load_dataset

    traindata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
    )
    valdata = load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split="validation",
    )

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]["text"], return_tensors="pt")
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt")
    valenc = valenc.input_ids[:, : (256 * seqlen)]

    class TokenizerWrapper:
        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc
def get_loaders(
    name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False
):
    """Dispatch to a calibration-set builder by substring of `name`.

    Checks, in order: "wikitext2", "ptb" (with "new" -> get_ptb_new), and "c4"
    (with "new" -> get_c4_new). Returns None when nothing matches.
    """
    args = (nsamples, seed, seqlen, model_id, trust_remote_code)
    if "wikitext2" in name:
        return get_wikitext2(*args)
    if "ptb" in name:
        builder = get_ptb_new if "new" in name else get_ptb
        return builder(*args)
    if "c4" in name:
        builder = get_c4_new if "new" in name else get_c4
        return builder(*args)
    return None
def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
    """Map dotted names to submodules that are instances of `layers`.

    Any module whose dotted name contains "lm_head" is never returned.
    `isinstance` is used (not exact type equality) so subclasses match too,
    e.g. Falcon's Linear subclass.
    """
    if "lm_head" not in name and isinstance(module, layers):
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name != "" else child_name
        found.update(find_layers(child, layers=layers, name=qualified))
    return found
@torch.no_grad()
def sequential(
    model,
    dataloader,
    dev,
    nsamples,
    bits,
    groupsize,
    *,
    hooks,
    percdamp=0.01,
    sym: bool = False,
    act_order: bool = False,
):
    """Quantize the decoder layers of `model` one after another with GPTQ.

    The inputs of the first decoder layer are captured once. Each layer is then
    calibrated on the outputs of the previously quantized layer.

    Returns:
        dict mapping "{prefix}.{layer_idx}.{sublayer_name}" to
        (quantizer, scale, zero, g_idx, bits, groupsize), all moved to CPU.
    """
    print("Starting ...")

    use_cache = model.config.use_cache
    model.config.use_cache = False
    # Locate the decoder layer list (Llama-style `model.model.layers`, otherwise
    # `model.transformer.h`).
    try:
        layers = model.model.layers
        prefix = "model.layers"
    except Exception:
        layers = model.transformer.h
        prefix = "transformer.h"

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )

    cache = {"i": 0}
    extra = {}

    class Catcher(nn.Module):
        """Records the first layer's input and kwargs, then aborts the forward pass."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            extra.update(kwargs.copy())
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].cuda())
        except ValueError:
            # Raised deliberately by Catcher once the input is recorded.
            pass
    layers[0] = layers[0].module

    # layers[0] = layers[0].cpu()
    # model.model.embed_tokens = model.model.embed_tokens.cpu()
    # model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()
    for hook in hooks:
        hook.remove()

    outs = torch.zeros_like(inps)

    extra = {
        k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items()
    }

    print("Ready.")

    quantizers = {}
    for i in range(len(layers)):
        print(f"Quantizing layer {i+1}/{len(layers)}..")
        print("+------------------+--------------+------------+-----------+-------+")
        print("| name | weight_error | fp_inp_SNR | q_inp_SNR | time |")
        print("+==================+==============+============+===========+=======+")

        layer = layers[i]
        # NOTE(review): load()/unload() are provided by the layer objects the caller
        # builds; presumably they materialize and release weights — confirm.
        layer.load()
        full = find_layers(layer)
        # All sub-layers are quantized together in a single group.
        sequential = [list(full.keys())]

        for names in sequential:
            subset = {n: full[n] for n in names}
            gptq = {}
            for name in subset:
                gptq[name] = GPTQ(subset[name])
                gptq[name].quantizer.configure(
                    bits, perchannel=True, sym=sym, mse=False
                )
                pass

            def add_batch(name):
                nonlocal gptq

                def tmp(_, inp, out):
                    gptq[name].add_batch(inp[0].data, out.data)

                return tmp

            # Forward hooks feed each sub-layer's inputs into its GPTQ state.
            handles = []
            for name in subset:
                handles.append(subset[name].register_forward_hook(add_batch(name)))
            for j in range(nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]
            for h in handles:
                h.remove()

            for name in subset:
                scale, zero, g_idx, error = gptq[name].fasterquant(
                    percdamp=percdamp,
                    groupsize=groupsize,
                    act_order=act_order,
                    name=name,
                )
                quantizers[f"{prefix}.{i}.{name}"] = (
                    gptq[name].quantizer.cpu(),
                    scale.cpu(),
                    zero.cpu(),
                    g_idx.cpu(),
                    bits,
                    groupsize,
                )

                gptq[name].free()

        # Re-run the now-quantized layer to produce the next layer's inputs.
        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]

        layer.unload()
        del layer
        del gptq
        torch.cuda.empty_cache()

        inps, outs = outs, inps
        print("+------------------+--------------+------------+-----------+-------+")
        print("\n")

    model.config.use_cache = use_cache

    return quantizers
def make_quant_linear(module, names, bits, groupsize, name=""):
    """Recursively swap selected sub-modules of ``module`` for fresh ``QuantLinear`` layers.

    Args:
        module: Root of the subtree to rewrite. A ``QuantLinear`` root is left untouched.
        names: Container of dotted module names (relative to the top-level call)
            that should be replaced; membership is tested with ``in``.
        bits: Bit width forwarded to ``QuantLinear.new``.
        groupsize: Group size forwarded to ``QuantLinear.new``.
        name: Dotted path of ``module`` itself ("" for the top-level call).

    Each replacement is built from the old layer's ``in_features``,
    ``out_features`` and whether it has a bias. Its packed weights are
    presumably filled in later (see ``pack``).
    """
    if isinstance(module, QuantLinear):
        return

    path_prefix = f"{name}." if name != "" else ""

    for attribute in dir(module):
        # Every attribute is fetched (as before), even ones that are not replaced.
        current = getattr(module, attribute)
        if path_prefix + attribute not in names:
            continue
        delattr(module, attribute)
        replacement = QuantLinear.new(
            bits,
            groupsize,
            current.in_features,
            current.out_features,
            current.bias is not None,
        )
        setattr(module, attribute, replacement)

    for child_name, child in module.named_children():
        make_quant_linear(child, names, bits, groupsize, path_prefix + child_name)
# TODO: perform packing on GPU
def pack(model, quantizers, bits, groupsize):
    """Replace the quantized linear layers of ``model`` with packed ``QuantLinear`` layers.

    Args:
        model: Module whose linear layers were quantized.
        quantizers: Mapping from dotted layer name to
            ``(quantizer, scale, zero, g_idx, bits, groupsize)`` as produced by
            ``sequential``. For every packed layer the entry is overwritten in
            place with just its quantizer object.
        bits: Bit width forwarded to ``make_quant_linear``.
        groupsize: Group size forwarded to ``make_quant_linear``.

    Returns:
        ``model`` (modified in place).
    """
    # Grab the original float layers before they are swapped out.
    all_float_layers = find_layers(model)
    float_layers = {key: all_float_layers[key] for key in quantizers}

    make_quant_linear(model, quantizers, bits, groupsize)
    packed_layers = find_layers(model, (QuantLinear,))

    print("Packing ...")
    for key, packed in packed_layers.items():
        print(key)
        quantizer, scale, zero, g_idx, _bits, _groupsize = quantizers[key]
        quantizers[key] = quantizer
        packed.pack(float_layers[key], scale, zero, g_idx)
    print("Done.")
    return model
def setdeepattr(module, full_name, tensor):
    """Set the attribute at dotted path ``full_name`` (e.g. ``"a.b.weight"``) on ``module``.

    Every path component but the last must already exist; the last one is
    assigned ``tensor`` with ``setattr``. Raises ``AttributeError`` if an
    intermediate component is missing.
    """
    parent_path, dot, leaf = full_name.rpartition(".")
    owner = module
    if dot:
        for part in parent_path.split("."):
            owner = getattr(owner, part)
    setattr(owner, leaf, tensor)
def getdeepattr(module, full_name):
    """Return the attribute at dotted path ``full_name`` (e.g. ``"a.b.weight"``) of ``module``.

    Raises ``AttributeError`` if any path component is missing.
    """
    target = module
    remaining = full_name
    while True:
        head, dot, remaining = remaining.partition(".")
        target = getattr(target, head)
        if not dot:
            return target
def load_weights_pre_hook(module_name, weights, recursive=False):
    """Build a hook ``inner(module, args)`` that materializes a module's tensors on ``cuda:0``.

    The returned callable matches the ``nn.Module.register_forward_pre_hook``
    signature. For every selected parameter/buffer:

    - tensors still on the ``meta`` device are fetched with
      ``weights.get_tensor("<module_name>.<local name>")`` (just the local name
      when ``module_name`` is empty) and installed as ``nn.Parameter``;
    - any other tensor is moved to ``cuda:0`` and re-wrapped as ``nn.Parameter``
      only if it requires grad.

    Args:
        module_name: Dotted name of the hooked module; prefix for ``weights`` lookups.
        weights: Object exposing ``get_tensor(name)``.
        recursive: If False, only tensors whose local name contains exactly one
            dot are handled (i.e. tensors owned directly by child modules). If
            True, every parameter and buffer of the subtree is handled.
    """

    def _selected(local_name):
        return recursive or local_name.count(".") == 1

    def inner(module, args):
        print(f"Pre hook {module_name}")
        # Snapshot the names (parameters first, then buffers) before mutating the module.
        local_names = list(
            dict.fromkeys(
                [key for key, _ in module.named_parameters() if _selected(key)]
                + [key for key, _ in module.named_buffers() if _selected(key)]
            )
        )
        for local_name in local_names:
            current = getdeepattr(module, local_name)
            if current.device == torch.device("meta"):
                full_name = f"{module_name}.{local_name}" if module_name else local_name
                setdeepattr(
                    module, local_name, nn.Parameter(weights.get_tensor(full_name))
                )
                continue
            moved = current.to(device=torch.device("cuda:0"))
            setdeepattr(
                module,
                local_name,
                nn.Parameter(moved) if current.requires_grad else moved,
            )

    return inner
def load_weights_post_hook(module_name, weights, recursive=False):
    """Build a hook ``inner(module, args, output)`` that moves a module's tensors to CPU.

    The returned callable matches the ``nn.Module.register_forward_hook``
    signature and returns ``output`` unchanged. Every selected parameter/buffer
    is replaced by ``nn.Parameter(<tensor on cpu>)``.

    Args:
        module_name: Dotted name of the hooked module (only used in the log line).
        weights: Not used by the hook; accepted for symmetry with
            ``load_weights_pre_hook``.
        recursive: If False, only tensors whose local name contains exactly one
            dot are handled. If True, every parameter and buffer of the subtree
            is handled.
    """

    def _selected(local_name):
        return recursive or local_name.count(".") == 1

    def inner(module, args, output):
        print(f"Post hook {module_name}")
        # Snapshot the names (parameters first, then buffers) before mutating the module.
        local_names = list(
            dict.fromkeys(
                [key for key, _ in module.named_parameters() if _selected(key)]
                + [key for key, _ in module.named_buffers() if _selected(key)]
            )
        )
        for local_name in local_names:
            on_cpu = getdeepattr(module, local_name).to(device=torch.device("cpu"))
            setdeepattr(module, local_name, nn.Parameter(on_cpu))
        return output

    return inner
def quantize(
    model_id: str,
    bits: int,
    groupsize: int,
    output_dir: str,
    revision: str,
    trust_remote_code: bool,
    upload_to_model_id: Optional[str],
    percdamp: float,
    act_order: bool,
    sym: bool,
):
    """GPTQ-quantize ``model_id`` and write the packed checkpoint to ``output_dir``.

    How it works:
    - The model skeleton is built under ``init_empty_weights``.
    - Weights are read from the ``.safetensors`` files of ``revision`` on demand:
      forward pre/post hooks load/unload tensors lazily, and per-module
      ``load()``/``unload()`` helpers are used by ``sequential``.
    - Calibration uses 128 samples of length 2048 from ``wikitext2``.

    Outputs written to ``output_dir``:
    - the (possibly sharded) safetensors weights;
    - a config carrying ``quantization_config``;
    - the tokenizer.

    The folder is uploaded when ``upload_to_model_id`` is set.

    Args:
        model_id: Hub id or local path of the model to quantize.
        bits: Quantization bit width.
        groupsize: GPTQ group size.
        output_dir: Destination directory (created if missing).
        revision: Revision used for the weights, config and tokenizer.
        trust_remote_code: Forwarded to the transformers loaders.
        upload_to_model_id: Optional Hub repo to upload ``output_dir`` to.
        percdamp: GPTQ damping fraction.
        act_order: GPTQ activation-order flag (stored as ``desc_act``).
        sym: Symmetric quantization flag.
    """
    print("loading model")
    # Fetch the config at the same revision as the weights so they cannot disagree.
    config = AutoConfig.from_pretrained(
        model_id,
        revision=revision,
        trust_remote_code=trust_remote_code,
    )

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(
            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
        )
    model = model.eval()

    print("LOADED model")
    files = weight_files(model_id, revision, extension=".safetensors")
    process_group, _, _ = initialize_torch_distributed()
    weights = Weights(
        files,
        device=torch.device("cuda:0"),
        dtype=torch.float16,
        process_group=process_group,
        aliases={"embed_tokens.weight": ["lm_head.weight"]},
        weights_loader=DefaultWeightsLoader(UnquantizedWeight),
    )

    # Factories (rather than inline lambdas) so each closure binds its own
    # `module` and `name` instead of the loop's last values.
    def load(module, name):
        def _load():
            load_weights_pre_hook(name, weights, recursive=True)(module, None)

        return _load

    def unload(module, name):
        def _unload():
            load_weights_post_hook(name, weights, recursive=True)(
                module, None, None
            )

        return _unload

    hooks = []
    for name, module in model.named_modules():
        # Explicit helpers used by `sequential` to (un)load a whole decoder block.
        module.load = load(module, name)
        module.unload = unload(module, name)
        # Lazy-loading hooks, active only while `sequential` captures the first
        # block's inputs (it removes them afterwards).
        hooks.append(
            module.register_forward_pre_hook(load_weights_pre_hook(name, weights))
        )
        hooks.append(
            module.register_forward_hook(load_weights_post_hook(name, weights))
        )

    model.seqlen = 2048

    dataset = "wikitext2"
    nsamples = 128
    seed = None

    # NOTE(review): get_loaders receives no revision; if it loads a tokenizer from
    # `model_id`, that tokenizer comes from the default revision. Confirm.
    dataloader, _testloader = get_loaders(
        dataset,
        nsamples=nsamples,
        seed=seed,
        model_id=model_id,
        seqlen=model.seqlen,
        trust_remote_code=trust_remote_code,
    )

    tick = time.time()
    quantizers = sequential(
        model,
        dataloader,
        DEV,
        nsamples,
        bits,
        groupsize,
        percdamp=percdamp,
        act_order=act_order,
        hooks=hooks,
        sym=sym,
    )
    print(time.time() - tick)

    pack(model, quantizers, bits, groupsize)
    from safetensors.torch import save_file
    from huggingface_hub import split_torch_state_dict_into_shards

    state_dict = model.state_dict()
    state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}

    max_shard_size = "10GB"
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict,
        filename_pattern="model.safetensors",
        max_shard_size=max_shard_size,
    )
    index = None
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
    shards = state_dict_split.filename_to_tensors
    os.makedirs(output_dir, exist_ok=True)
    for shard_file, shard in shards.items():
        save_file(
            shard,
            os.path.join(output_dir, shard_file),
            metadata={
                "format": "pt",
                "quantized": "gptq",
                "origin": "text-generation-inference",
            },
        )
    if index is None:
        path_to_weights = os.path.join(output_dir, "model.safetensors")
        logger.info(f"Model weights saved in {path_to_weights}")
    else:
        save_index_file = "model.safetensors.index.json"
        save_index_file = os.path.join(output_dir, save_index_file)
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        logger.info(
            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
            f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
            f"index located at {save_index_file}."
        )
    config = AutoConfig.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    config.quantization_config = {
        "bits": bits,
        "group_size": groupsize,
        "damp_percent": percdamp,
        "desc_act": act_order,
        "static_groups": False,
        "sym": sym,
        "quant_method": "gptq",
    }
    config.save_pretrained(output_dir)
    logger.info("Saved config")
    logger.info("Saving tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    tokenizer.save_pretrained(output_dir)
    logger.info("Saved tokenizer")

    if upload_to_model_id:
        api = HfApi()

        api.upload_folder(
            folder_path=output_dir, repo_id=upload_to_model_id, repo_type="model"
        )
================================================
FILE: backends/gaudi/server/text_generation_server/layers/gptq/utils.py
================================================
import torch
# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
def torch_snr_error(
    y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
    """
    Compute the noise-to-signal power ratio between ``y_pred`` and ``y_real``.

    Despite the "SNR" name, the value grows with the error (0 means identical)::

        SNR(pred, real) = sum((pred - real) ^ 2) / (sum(real ^ 2) + 1e-7)

    Inputs of rank >= 2 are flattened from dim 1 onward, so one ratio is computed
    per row along dim 0. A 1-D input is treated as a single row.

    Args:
        y_pred (torch.Tensor): approximated tensor.
        y_real (torch.Tensor): reference tensor; must have the same shape as ``y_pred``.
        reduction (str, optional): case-insensitive; ``"mean"`` (default) or
            ``"sum"`` over rows, or ``"none"`` for the per-row ratios.

    Raises:
        ValueError: if the shapes differ.
        ValueError: if ``reduction`` is not one of the supported values.

    Returns:
        torch.Tensor: a scalar for ``"mean"``/``"sum"``, otherwise a 1-D tensor
        with one ratio per row.
    """
    if y_pred.shape != y_real.shape:
        raise ValueError(
            f"Can not compute snr loss for tensors with different shape. "
            f"({y_pred.shape} and {y_real.shape})"
        )
    mode = str(reduction).lower()

    if y_pred.ndim == 1:
        y_pred, y_real = y_pred.unsqueeze(0), y_real.unsqueeze(0)

    pred_rows = y_pred.flatten(start_dim=1)
    real_rows = y_real.flatten(start_dim=1)

    noise_power = (pred_rows - real_rows).pow(2).sum(dim=-1)
    signal_power = real_rows.pow(2).sum(dim=-1)
    ratios = noise_power / (signal_power + 1e-7)

    reducers = {"mean": torch.mean, "sum": torch.sum, "none": lambda t: t}
    if mode not in reducers:
        raise ValueError("Unsupported reduction method.")
    return reducers[mode](ratios)
================================================
FILE: backends/gaudi/server/text_generation_server/layers/layernorm.py
================================================
import torch
from torch import nn
from accelerate import init_empty_weights
# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    """Build a LayerNorm from the checkpoint tensors ``{prefix}.weight`` and ``{prefix}.bias``.

    Intended to be patched onto ``torch.nn.LayerNorm`` as ``load``. The module is
    constructed under ``init_empty_weights`` so its default parameters are not
    materialized; the loaded tensors are then installed as parameters.
    """
    weight, bias = [
        weights.get_tensor(f"{prefix}.{suffix}") for suffix in ("weight", "bias")
    ]
    with init_empty_weights():
        layer = cls(weight.shape, eps=eps)
    layer.weight, layer.bias = torch.nn.Parameter(weight), torch.nn.Parameter(bias)
    return layer
@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    """Build a bias-free LayerNorm from the checkpoint tensor ``{prefix}.weight``.

    Intended to be patched onto ``torch.nn.LayerNorm`` as ``load_no_bias``. The
    module is constructed under ``init_empty_weights``; the loaded weight is
    installed as a parameter and ``bias`` is set to ``None``.
    """
    loaded_weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        layer = cls(loaded_weight.shape, eps=eps)
    layer.weight = torch.nn.Parameter(loaded_weight)
    layer.bias = None
    return layer
# Expose the loaders above as alternate constructors on torch.nn.LayerNorm.
for _attr_name, _loader in (
    ("load", load_layer_norm),
    ("load_no_bias", load_layer_norm_no_bias),
):
    setattr(torch.nn.LayerNorm, _attr_name, _loader)
del _attr_name, _loader
class FastLayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` with an optional residual add before normalization.

    ``forward`` returns ``(normed, residual)``. When a residual is given it is
    added *in place* into ``hidden_states``, and that same (summed) tensor is
    returned as the new residual. Otherwise the input itself is returned as the
    residual.
    """

    def forward(self, hidden_states, residual=None):
        if residual is None:
            return super().forward(hidden_states), hidden_states
        # In-place add: the caller's `hidden_states` tensor now holds the sum.
        hidden_states += residual
        return super().forward(hidden_states), hidden_states
class FastRMSNorm(nn.Module):
    """RMS normalization with an optional residual add, in plain PyTorch.

    ``forward`` returns ``(weight * normed, residual)``. When a residual is given
    it is added *in place* into ``hidden_states``, and that summed tensor is
    returned as the new residual. Normalization is computed in float32 and cast
    back to the weight's dtype.
    """

    def __init__(self, weight: torch.Tensor, eps: float):
        super().__init__()
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    @classmethod
    def load(cls, prefix, weights, eps=1e-6):
        """Build the norm from the checkpoint tensor ``{prefix}.weight``."""
        return cls(weights.get_tensor(f"{prefix}.weight"), eps)

    def forward(self, hidden_states, residual=None):
        if residual is None:
            residual = hidden_states
        else:
            hidden_states += residual
            residual = hidden_states

        upcast = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(
            upcast.pow(2).mean(-1, keepdim=True) + self.variance_epsilon
        )
        normed = (upcast * inv_rms).to(self.weight.dtype)
        return self.weight * normed, residual
================================================
FILE: backends/gaudi/server/text_generation_server/layers/linear.py
================================================
import torch
from torch.nn import functional as F
class FastLinear(torch.nn.Module):
    """Minimal linear layer calling ``F.linear`` with frozen weight and optional bias.

    Both tensors are wrapped as ``nn.Parameter(..., requires_grad=False)``;
    ``bias`` is ``None`` when no bias tensor is given.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.bias = (
            torch.nn.Parameter(bias, requires_grad=False) if bias is not None else None
        )

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load ``{prefix}.weight`` (and ``{prefix}.bias`` when ``bias`` is truthy).

        ``config`` is not used by this method.
        """
        weight = weights.get_tensor(f"{prefix}.weight")
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight, bias_tensor)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)
def get_linear(weight, bias):
    """Return a linear layer for ``weight``/``bias``.

    Bare tensors (weights loaded through methods that are not
    quantization-aware) become a ``FastLinear``; any other weight object is
    asked to build its own layer through ``weight.get_linear(bias)``. We may
    want to change this in the future.
    """
    if not isinstance(weight, torch.Tensor):
        return weight.get_linear(bias)
    return FastLinear(weight, bias)
================================================
FILE: backends/gaudi/server/text_generation_server/layers/lora.py
================================================
from typing import TYPE_CHECKING, Optional, List
import torch
import torch.distributed
from torch import nn
from torch.distributed import ProcessGroup
from text_generation_server.utils.sgmv import (
add_lora_a_bgmv,
add_lora_b_bgmv,
has_sgmv,
lora_a_sgmv_cutlass,
lora_b_sgmv_cutlass,
orient_for_rank,
)
if TYPE_CHECKING:
from text_generation_server.adapters import AdapterBatchData
from text_generation_server.adapters.lora import BatchLoraWeights
class LoraLinear(nn.Module):
    """Base class adding LoRA adapter contributions to a wrapped layer's output.

    Subclasses implement ``forward`` and ``collect_lora_a``, which combines
    ``input @ A`` across ranks before the ``B`` projection when the process
    group has more than one rank.
    """

    def __init__(
        self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
    ):
        super().__init__()
        # Wrapped layer whose output receives the LoRA contribution.
        self.base_layer = base_layer
        # Index used to select this layer's slice of the stacked adapter weights.
        self.layer_id = layer_id
        self.process_group = process_group

    def forward_layer_type(
        self,
        result: torch.Tensor,
        input: torch.Tensor,
        adapter_data: "AdapterBatchData",
        layer_type: str,
        start_idx: int,
        end_idx: int,
    ) -> torch.Tensor:
        """Add the LoRA contribution of ``layer_type`` into ``result[:, start_idx:end_idx]``.

        Paths:
        - Vectorized path, used when SGMV kernels are available and
          ``data.can_vectorize(...)``: SGMV when ``data.use_sgmv`` (prefill),
          BGMV otherwise (decode).
        - Fallback path: loops over adapters and uses a masked dense matmul.

        Returns ``result`` (updated in place). It is returned unchanged when
        ``adapter_data`` is None or no adapter matches.
        """
        if adapter_data is None:
            return result
        data: Optional["BatchLoraWeights"] = adapter_data.data.get(layer_type)

        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
            # In tensor-parallel configurations, each GPU processes a specific segment of the output.
            # The 'result' tensor represents the full output, which can vary in size based on
            # the layer type (e.g., attention vs. feed-forward layers). We define the current
            # segment using start_idx and end_idx. If the segment size doesn't match this GPU's
            # slice of 'result', we create a zero tensor of the correct size for LoRA computation.
            # This approach ensures accurate LoRA application across various layer sizes and
            # configurations, adapting to different model architectures and parallelization strategies.
            #
            # Example scenarios where this is necessary:
            # 1. The adapter's size doesn't evenly divide across GPUs.
            # 2. We're processing the last segment which might be smaller.
            # 3. Different projection layers (q, k, v) have different sizes.
            if end_idx - start_idx != result.shape[1]:
                proj = torch.zeros_like(result[:, start_idx:end_idx])
            else:
                # `proj` aliases `result`; the kernels below presumably write into it in place.
                proj = result

            for r, rank_segments in data.rank_data.items():
                lora_a_ptr = rank_segments.lora_a_ptr
                lora_b_ptr = rank_segments.lora_b_ptr

                if lora_a_ptr is None or lora_b_ptr is None:
                    raise ValueError("LoRA data is missing")

                if data.use_sgmv:
                    # Use SGMV for prefill
                    v = lora_a_sgmv_cutlass(
                        input,
                        rank_segments.tmp_shrink,
                        lora_a_ptr,
                        rank_segments.segment_starts,
                        rank_segments.segment_ends,
                        self.layer_id,
                        r,
                    )

                    if self.process_group.size() > 1:
                        v = self.collect_lora_a(v)

                    lora_b_sgmv_cutlass(
                        proj,
                        v,
                        rank_segments.tmp_expand,
                        lora_b_ptr,
                        rank_segments.segment_starts,
                        rank_segments.segment_ends,
                        self.layer_id,
                    )
                else:
                    # Use BGMV for decode
                    v = torch.zeros(
                        (input.size(0), r), dtype=input.dtype, device=input.device
                    )
                    # TODO: error with [-1, 0], but not [0, -1]
                    add_lora_a_bgmv(
                        v,
                        input,
                        lora_a_ptr,
                        rank_segments.indices,
                        self.layer_id,
                    )

                    if self.process_group.size() > 1:
                        v = self.collect_lora_a(v)

                    add_lora_b_bgmv(
                        proj,
                        v,
                        lora_b_ptr,
                        rank_segments.indices,
                        self.layer_id,
                    )

            # Copy the scratch buffer back only when it was separate from `result`.
            if end_idx - start_idx != result.shape[1]:
                result[:, start_idx:end_idx] += proj
        else:
            # Fallback: one dense LoRA product per adapter, masked to its rows.
            for adapter_index in adapter_data.meta.adapter_set:
                if data is not None and data.has_adapter(adapter_index):
                    adapter_mask = (
                        (adapter_data.meta.adapter_indices == adapter_index)
                        .to(input.dtype)
                        .view(-1, 1)
                    )
                    layer_result = self.forward_lora(
                        input, data, adapter_index, adapter_mask
                    )
                    result[:, start_idx:end_idx] += layer_result

        return result

    def forward_lora(
        self,
        input: torch.Tensor,
        data: "BatchLoraWeights",
        adapter_index: int,
        adapter_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Dense LoRA product ``((input @ A) @ B) * adapter_mask`` for one adapter.

        ``adapter_mask`` is a column (``view(-1, 1)`` in the caller) that zeroes
        rows not using this adapter.
        """
        lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]

        # Presumably transposes A if needed so its rank dim matches B's first dim.
        lora_a = orient_for_rank(lora_a, lora_b.size(0))

        a_out = input @ lora_a
        if self.process_group.size() > 1:
            a_out = self.collect_lora_a(a_out)

        result = (a_out @ lora_b) * adapter_mask
        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        """Combine ``input @ A`` across ranks; implemented by subclasses."""
        raise NotImplementedError("Implemented in subclasses")
class TensorParallelMultiAdapterLinear(LoraLinear):
    """Column-parallel linear layer whose output fuses one or more adapter targets.

    ``layer_names`` lists the adapter layer types (e.g. q/k/v projections) stacked
    along the output dimension of ``base_layer``. ``sizes`` gives their unsharded
    output widths; each slice is divided by the process-group size to locate
    this rank's segment.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        layer_id: int,
        layer_names: List[str],
        sizes: List[int],
        process_group: ProcessGroup,
    ):
        super().__init__(base_layer, layer_id, process_group)
        self.layer_names = layer_names
        self.sizes = sizes

    @classmethod
    def load(
        cls,
        base_layer: nn.Module,
        layer_id: int,
        layer_names: List[str],
        sizes: List[int],
        process_group: ProcessGroup,
    ):
        return TensorParallelMultiAdapterLinear(
            base_layer, layer_id, layer_names, sizes, process_group
        )

    def forward(
        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
    ) -> torch.Tensor:
        result = self.base_layer(input)

        # noop if no layer names are provided (e.g. for models without adapters)
        if self.layer_names is None:
            return result

        # handle models like Bloom that have inputs of shape
        # (batch_size, sequence_length, hidden_size)
        # we need to reshape them to (batch_size * sequence_length, hidden_size)
        # for the LoRA computation, then reshape back
        prev_shape = result.shape
        is_3d = len(input.shape) >= 3
        if is_3d:
            input = input.reshape(-1, input.shape[-1])
            result = result.reshape(-1, result.shape[-1])

        offset = 0
        for i, layer_name in enumerate(self.layer_names):
            start_idx = offset // self.process_group.size()
            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
            # ensures correct slicing of the result tensor, accommodating variations like grouped-query
            # attention where k_proj and v_proj differ from q_proj. This allows precise application of
            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
            # different projection sizes across layers and model architectures.
            if self.sizes is not None:
                offset += self.sizes[i]
                end_idx = offset // self.process_group.size()
            else:
                end_idx = result.shape[1]

            result = self.forward_layer_type(
                result, input, adapter_data, layer_name, start_idx, end_idx
            )

        if is_3d:
            result = result.reshape(prev_shape)

        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
        # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
        #
        # TODO(travis): this is not very efficient as we do an all-gather for every adapter,
        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
        gathered_tensors = [
            torch.empty_like(a_out) for _ in range(self.process_group.size())
        ]
        # Gather over this layer's process group: the output list is sized by it,
        # so using the default (world) group would mismatch when they differ.
        torch.distributed.all_gather(
            gathered_tensors, a_out, group=self.process_group
        )
        return torch.cat(gathered_tensors, dim=1)
class TensorParallelAdapterRowLinear(LoraLinear):
    """Row-parallel linear layer with LoRA adapters applied to this rank's output slice."""

    def __init__(self, base_layer, layer_id, layer_name, process_group):
        super().__init__(base_layer, layer_id, process_group)
        self.layer_name = layer_name

    @classmethod
    def load(cls, base_layer, layer_id, layer_name, process_group):
        return cls(base_layer, layer_id, layer_name, process_group)

    def forward(
        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
    ) -> torch.Tensor:
        result = self.base_layer(input)
        if self.layer_name is None:
            return result

        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
        world_size = self.process_group.size()
        rank = self.process_group.rank()
        stride = result.shape[-1] // world_size

        # forward_layer_type adds the LoRA contribution into `result` in place,
        # so its return value is not needed here.
        self.forward_layer_type(
            result,
            input,
            adapter_data,
            self.layer_name,
            rank * stride,
            (rank + 1) * stride,
        )
        return result

    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
        # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
        # An all-reduce between X@A and (X@A)@B keeps the ranks aligned.
        #
        # TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
        torch.distributed.all_reduce(a_out, group=self.process_group)
        return a_out
================================================
FILE: backends/gaudi/server/text_generation_server/layers/medusa.py
================================================
import torch
from torch import nn
from typing import Tuple, Optional
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.layers.linear import FastLinear
from text_generation_server.layers.tensor_parallel import (
TensorParallelHead,
TensorParallelColumnLinear,
)
class ResBlock(torch.nn.Module):
    """Residual block ``x + SiLU(linear(x))`` loaded from ``{prefix}.linear`` (with bias)."""

    def __init__(self, config, prefix, weights):
        super().__init__()
        self.linear = FastLinear.load(
            config, prefix=f"{prefix}.linear", weights=weights, bias=True
        )
        self.act = torch.nn.SiLU()

    def forward(self, x):
        activated = self.act(self.linear(x))
        return x + activated
class MedusaModel(torch.nn.Module):
    """Collection of ``get_speculate()`` Medusa heads, loaded from prefixes ``"0"``, ``"1"``, ...

    ``forward`` stacks the per-head outputs along dim 1, or returns ``None``
    when there are no heads.
    """

    def __init__(self, config, medusa_config, weights):
        super().__init__()
        heads = []
        for index in range(get_speculate()):
            heads.append(
                MedusaHead(config, medusa_config, prefix=f"{index}", weights=weights)
            )
        self.heads = torch.nn.ModuleList(heads)

    def forward(self, x):
        if len(self.heads) == 0:
            return None
        per_head = [head(x) for head in self.heads]
        return torch.stack(per_head, dim=1)
class MedusaHead(torch.nn.Module):
    """One Medusa head: ``medusa_num_layers`` ``ResBlock``s followed by a bias-free output projection.

    The blocks load from ``{prefix}.0`` ... ``{prefix}.{n-1}`` and the output
    projection from ``{prefix}.{n}``.
    """

    def __init__(self, config, medusa_config, prefix, weights):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            ResBlock(config, prefix=f"{prefix}.{index}", weights=weights)
            for index in range(medusa_config["medusa_num_layers"])
        )
        out_index = len(self.blocks)
        self.out = FastLinear.load(
            config, prefix=f"{prefix}.{out_index}", weights=weights, bias=False
        )

    def forward(self, x):
        hidden = x
        for block in self.blocks:
            hidden = block(hidden)
        return self.out(hidden)
class MedusaHeadV1(nn.Module):
    """LM head paired with a ``MedusaModel`` producing speculative logits."""

    def __init__(self, lm_head, medusa):
        super().__init__()
        self.lm_head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        """Build the head from ``config.speculator`` (``path`` and ``model_paths``).

        Every key of each speculator safetensors file is registered in
        ``weights.routing``. A key already routed to a different file raises
        ``RuntimeError``.
        """
        from pathlib import Path
        from safetensors import safe_open
        import json

        speculator = config.speculator

        path = speculator["path"]

        # Read the Medusa config once, up front. Re-reading it inside the loop
        # would rebind the path variable to the parsed dict and break on the
        # second entry of `model_paths`.
        medusa_config_path = str(Path(path) / "config.json")
        with open(medusa_config_path, "r") as f:
            medusa_config = json.load(f)

        routing = weights.routing
        for fname in speculator["model_paths"]:
            filename = str(Path(path) / fname)
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing and routing[k] != filename:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename

        medusa = MedusaModel(config, medusa_config, weights)
        lm_head = TensorParallelHead.load(config, prefix, weights)
        return MedusaHeadV1(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(input)
        # If we have too many tokens, we skip speculative logits
        if input.shape[0] > 128:
            return logits, None

        speculative_logits = self.medusa(input)
        return logits, speculative_logits
class MedusaHeadV2(nn.Module):
    """Medusa head with all single-layer heads fused into one column-parallel linear.

    Weights are read from ``<speculator path>/medusa_lm_head.safetensors``; the
    config must have ``medusa_num_layers == 1``.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        from pathlib import Path
        from safetensors import safe_open
        import json

        speculator_path = config.speculator["path"]

        medusa_config = str(Path(speculator_path) / "config.json")
        filename = str(Path(speculator_path) / "medusa_lm_head.safetensors")

        with open(medusa_config, "r") as f:
            medusa_config = json.load(f)
        # Register every Medusa tensor in the weights routing table; a key already
        # routed to another file is an error.
        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                if k in routing and routing[k] != filename:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        self.n_medusa_heads = get_speculate()

        # The fused layout below only supports one ResBlock per head.
        assert medusa_config["medusa_num_layers"] == 1
        # Fuse "{i}.0.linear" of every head into a single linear along dim 0.
        self.linear = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.process_group = weights.process_group
        self.world_size = self.process_group.size()
        self.rank = self.process_group.rank()

        self.act = torch.nn.SiLU()

        self.lm_head = TensorParallelHead.load(config, prefix, weights)

    def forward(self, x):
        """Return ``(logits, speculative_logits)``; the latter is None for > 128 rows.

        NOTE(review): assumes ``x`` is 2-D ``(tokens, hidden)`` (it is sliced as
        ``x[:, start:stop]``) — confirm against callers.
        """
        # If we have too many tokens, we skip speculative logits
        if x.shape[0] > 128:
            logits = self.lm_head(x)
            return logits, None

        # This rank's slice of the hidden dim (ceil division).
        # NOTE(review): if hidden size is not divisible by world_size the last
        # rank's slice is shorter, and the all_gather below assumes equal sizes.
        size = x.shape[-1]
        block_size = (size + self.world_size - 1) // self.world_size
        start = self.rank * block_size
        stop = (self.rank + 1) * block_size

        x_block = x[:, start:stop]

        # Compute all medusa heads at once, then reshape so the n_medusa_heads dim
        # sits just before the hidden dim (dim 1 for a 2-D input).
        medusa_res = self.act(self.linear(x)).reshape(
            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
        )

        # Apply all residual medusa heads
        output = x[:, start:stop].unsqueeze(-2) + medusa_res

        # Gather medusa heads
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)

        # Stack x and medusa residual x
        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)

        # Compute lm head on x + medusa residual x
        logits = self.lm_head(stacked_x)

        # Finally, split logits from speculative logits
        logits, speculative_logits = torch.split(
            logits, [1, self.n_medusa_heads], dim=-2
        )
        # Squeeze added dimension
        logits = logits.squeeze(-2)

        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/layers/mlp.py
================================================
import torch
import math
from torch import nn
from torch.nn import functional as F
from typing import Optional, Tuple
from text_generation_server.layers import TensorParallelEmbedding, FastLinear
from text_generation_server.layers.tensor_parallel import TensorParallelHead
from text_generation_server.utils.speculate import get_speculate
class MLPSpeculatorLayerNorm(nn.Module):
    """
    RMS-style normalization followed by an elementwise affine transform.

    Computes ``weight * (x / sqrt(mean(x ** 2, dim=-1) + eps)) + bias``. Unlike
    ``nn.LayerNorm``, the mean is not subtracted.

    Args
    ----
    prefix : str
        Checkpoint prefix; ``{prefix}.weight`` and ``{prefix}.bias`` are read from ``weights``.
    config :
        Not used by this class.
    weights :
        Object exposing ``get_tensor(name)``.
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
    """

    def __init__(
        self,
        prefix,
        config,
        weights,
        eps=1e-06,
    ):
        super().__init__()
        # Stored as plain tensors (not registered parameters).
        self.weight = weights.get_tensor(f"{prefix}.weight")
        self.bias = weights.get_tensor(f"{prefix}.bias")
        self.eps = eps

    def forward(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        normalized = (x * inv_rms).type_as(x)
        return self.weight * normalized + self.bias
# 1 / sqrt(2)
INV_SQRT2 = 2**-0.5


def simple_norm(x: torch.Tensor, eps=1e-06):
    """RMS-normalize ``x`` over its last dim (no affine) and scale by ``1/sqrt(2)``."""
    inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (x * inv_rms).type_as(x) * INV_SQRT2
class MLPSpeculatorModelTied(torch.nn.Module):
    """MLP speculator whose embedding, head and norm are shared by all predicted positions.

    Two projections are used: ``proj0`` for the first step and ``proj1`` for
    every later step. ``forward`` returns probabilities shaped
    ``(b, n_predict, vocab_size)``.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        self.config = config
        self.n_predict = get_speculate()
        self.hidden_size = config.hidden_size

        self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights)
        self.proj0 = FastLinear.load(
            config,
            prefix=f"{prefix}.proj.0",
            weights=weights,
            bias=False,
        )
        self.proj1 = FastLinear.load(
            config,
            prefix=f"{prefix}.proj.1",
            weights=weights,
            bias=False,
        )
        self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False)
        self.ln = MLPSpeculatorLayerNorm(
            prefix=f"{prefix}.ln.0",
            config=config,
            weights=weights,
        )

        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
        self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
        self.activation = nn.GELU()
        self.vsize = config.vocab_size
        self.inner_dim = config.speculator_config["inner_dim"]
        # Only top-1 candidates are followed per step.
        self.top_k_tokens_per_head = [1] * self.n_predict
        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
            self.inner_dim / 2
        )
        # Fold the embedding scale into the table once, instead of scaling `z`
        # in forward (see the commented-out line there).
        self.emb.weight *= self.emb_weight

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
    ):
        # NOTE(review): shapes below assume hidden_states is (b, d) and input_ids is (b,),
        # consistent with `b = state.size(0)` and `input_ids.unsqueeze(0)` — confirm.
        top_k_tokens_per_head = self.top_k_tokens_per_head

        # k indicates # of candidates
        # h indicates # of generated tokens
        state = hidden_states
        b = state.size(0)
        ind = input_ids.unsqueeze(0)
        all_probs = torch.empty(
            b, self.n_predict, self.vsize, device=state.device
        )  # b k h v
        assert (
            len(top_k_tokens_per_head) == self.n_predict
        ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
        for i in range(self.n_predict):
            # Project and predict
            z = self.emb(ind)
            # z = z.mul(self.emb_weight)  # b k d
            if i == 0:
                state = self.proj0(state) * self.state_weight + z
            else:
                state = self.proj1(state) * self.state_weight + z
            state = self.activation(self.ln(state))  # b k d
            probs = F.log_softmax(self.head(state), dim=-1)  # b k v
            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'

            # Update candidate set with new predictions

            # Update distribution set with new logits
            all_probs[:, i] = probs.exp()

            # Update state, log_probs and ind for new predictions
            state = state.unsqueeze(2).expand(
                -1, -1, top_k_tokens_per_head[i], -1
            )  # b k k' d
            state = state.reshape(-1, b, state.size(3))  # b kk' d
            ind = preds.view(-1, b)  # b kk'

        speculative_logits = all_probs
        return speculative_logits
class MLPSpeculatorModel(torch.nn.Module):
    """Untied MLP speculator: one embedding, projection, norm and head per predicted position.

    Position ``i`` loads ``{prefix}.emb.i``, ``{prefix}.proj.i``,
    ``{prefix}.ln.i`` and ``{prefix}.head.i``. ``forward`` returns
    probabilities shaped ``(b, n_predict, vocab_size)``.
    """

    def __init__(self, config, prefix, weights):
        super().__init__()
        self.config = config
        self.n_predict = get_speculate()
        self.hidden_size = config.hidden_size
        self.emb = nn.ModuleList(
            [
                TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
                for i in range(self.n_predict)
            ]
        )
        # nn.ModuleList (not a plain list) so the projections are registered
        # submodules like emb/head/ln: visible to .to(), state_dict(), named_modules().
        self.proj = nn.ModuleList(
            [
                FastLinear.load(
                    config,
                    prefix=f"{prefix}.proj.{i}",
                    weights=weights,
                    bias=False,
                )
                for i in range(self.n_predict)
            ]
        )
        self.head = nn.ModuleList(
            [
                FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False)
                for i in range(self.n_predict)
            ]
        )
        self.ln = nn.ModuleList(
            [
                MLPSpeculatorLayerNorm(
                    prefix=f"{prefix}.ln.{i}",
                    config=config,
                    weights=weights,
                )
                for i in range(self.n_predict)
            ]
        )

        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
        self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
        self.activation = nn.GELU()
        self.vsize = config.vocab_size
        self.inner_dim = config.speculator_config["inner_dim"]
        # Only top-1 candidates are followed per step.
        self.top_k_tokens_per_head = [1] * self.n_predict
        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
            self.inner_dim / 2
        )
        # Fold the embedding scale into every embedding table (instead of scaling
        # `z` in forward). `self.emb` is a ModuleList with no `weight` attribute,
        # so each table is scaled individually, as the tied variant does for its
        # single table.
        for embedding in self.emb:
            embedding.weight *= self.emb_weight

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
    ):
        top_k_tokens_per_head = self.top_k_tokens_per_head

        # k indicates # of candidates
        # h indicates # of generated tokens
        state = hidden_states
        b = state.size(0)
        ind = input_ids.unsqueeze(0)
        all_probs = torch.empty(
            b, self.n_predict, self.vsize, device=state.device
        )  # b k h v
        assert (
            len(top_k_tokens_per_head) == self.n_predict
        ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
        for i in range(self.n_predict):
            # Project and predict
            z = self.emb[i](ind)
            # z = z.mul(self.emb_weight)  # b k d
            state = self.proj[i](state) * self.state_weight + z
            state = self.activation(self.ln[i](state))  # b k d
            probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'

            # Update candidate set with new predictions

            # Update distribution set with new logits
            all_probs[:, i] = probs.exp()

            # Update state, log_probs and ind for new predictions
            state = state.unsqueeze(2).expand(
                -1, -1, top_k_tokens_per_head[i], -1
            )  # b k k' d
            state = state.reshape(-1, b, state.size(3))  # b kk' d
            ind = preds.view(-1, b)  # b kk'

        speculative_logits = all_probs
        return speculative_logits
class MLPSpeculatorHead(nn.Module):
    """LM head that additionally produces MLP-speculator predictions.

    ``forward`` returns the base logits plus the speculator output, or
    ``None`` in place of the latter when the input has more than 128 rows.
    """

    def __init__(self, lm_head, mlp_speculator, scale_input: bool):
        super().__init__()
        self.lm_head = lm_head
        self.mlp_speculator = mlp_speculator
        # When set, the speculator is fed ``simple_norm(input)`` instead of
        # the raw hidden states.
        self.scale_input = scale_input

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(input)
        # Too many tokens: skip speculation entirely.
        if input.shape[0] > 128:
            return logits, None

        greedy_ids = logits.argmax(dim=-1)
        speculator_input = simple_norm(input) if self.scale_input else input
        return logits, self.mlp_speculator(speculator_input, greedy_ids)

    @staticmethod
    def load(config, prefix: str, weights):
        """Register the speculator safetensors files and build the head.

        Every tensor name found in ``config.speculator["model_paths"]`` is
        added to ``weights.routing``. A name already routed to a different
        file raises ``RuntimeError``.
        """
        from pathlib import Path
        from safetensors import safe_open

        base_dir = Path(config.speculator["path"])
        for fname in config.speculator["model_paths"]:
            filename = str(base_dir / fname)
            routing = weights.routing
            with safe_open(filename, framework="pytorch") as handle:
                for key in handle.keys():
                    if key in routing and routing[key] != filename:
                        raise RuntimeError(
                            f"Key {key} was found in multiple files: {filename} and {routing[key]}"
                        )
                    routing[key] = filename

        speculator_cls = (
            MLPSpeculatorModelTied
            if config.speculator_config.get("tie_weights", False)
            else MLPSpeculatorModel
        )
        mlp_speculator = speculator_cls(config, "speculator", weights)
        # This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator
        scale_input = config.speculator_config.get("scale_input", False)
        lm_head = TensorParallelHead.load(config, prefix, weights)
        return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)
================================================
FILE: backends/gaudi/server/text_generation_server/layers/moe/__init__.py
================================================
from typing import Optional, Protocol, runtime_checkable
import torch
import torch.nn as nn
from loguru import logger
from transformers.activations import ACT2FN
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
Weights,
UnquantizedWeight,
)
from .fused_moe import fused_topk, grouped_topk
# NOTE: we use a protocol here because multiple inheritance is awkward.
# We need `Module`, and a `Module` -> abstract class -> concrete class
# inheritance chain is wacky.
@runtime_checkable
class MoELayer(Protocol):
    """Structural interface shared by the MoE layer implementations.

    Being ``@runtime_checkable``, ``isinstance(obj, MoELayer)`` only checks
    that the members exist, not that their signatures match.
    """

    # Build a layer for ``n_experts`` experts whose weights live under
    # ``prefix``. ``n_expert_group``/``topk_group`` configure grouped top-k
    # routing; ``scoring_func``/``e_score_correction_bias`` configure how the
    # router logits are scored (an implementation may reject values it does
    # not support, e.g. via assertions).
    # NOTE(review): ``e_score_correction_bias`` is annotated as a float, but
    # ``grouped_topk`` treats it as a tensor -- confirm the intended type.
    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
        scoring_func: Optional[str] = None,
        e_score_correction_bias: Optional[float] = None,
    ): ...

    # Apply the experts to ``x`` using the router logits ``gating_output``.
    def forward(
        self, x: torch.Tensor, *, gating_output: torch.Tensor
    ) -> torch.Tensor: ...
class DenseMoELayer(nn.Module):
    """
    Layer for MoE that applies *all* experts to each token and then weights
    their outputs based on the calculated routing. This layer is much slower
    than `SparseMoELayer` and should only be used when no fused kernels are
    available (e.g. for unsupported quantizers).
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
        scoring_func: Optional[str] = None,
        e_score_correction_bias: Optional[float] = None,
    ):
        """Load every expert's gate/up/down projection.

        Raises:
            AssertionError: if a scoring function or correction bias is
                requested (only plain softmax top-k routing is implemented),
                or if only one of ``n_expert_group``/``topk_group`` is set.
        """
        super().__init__()

        assert scoring_func is None, "scoring func is not handled"
        assert e_score_correction_bias is None, "scoring correction bias is not handled"

        log_once(
            logger.info,
            "No fused layers are available for this model type, using (slower) dense MoE layer",
        )

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.n_experts = n_experts
        self.renormalize = renormalize
        self.topk = topk
        self.topk_group = topk_group

        if "gelu" in hidden_act:
            self.act = lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh"
                    if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
                    else "none"
                ),
            )
        elif "silu" in hidden_act:
            self.act = torch.nn.functional.silu
        else:
            self.act = ACT2FN[hidden_act]

        # NOTE(review): these are plain Python lists, so the per-expert linear
        # layers are not registered as submodules (no state_dict entries, not
        # moved by `.to()`). Kept as-is: switching to nn.ModuleList would change
        # this module's parameter/state_dict layout.
        self.gate_proj = [
            TensorParallelColumnLinear.load(
                None,
                prefix=f"{prefix}.{i}.{gate_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]
        self.up_proj = [
            TensorParallelColumnLinear.load(
                None,
                prefix=f"{prefix}.{i}.{up_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]
        self.down_proj = [
            TensorParallelRowLinear.load(
                None,
                prefix=f"{prefix}.{i}.{down_proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """
        x: (sequence_length, model_dim), or any shape whose last dim is model_dim
        gating_output: router logits with last dim n_experts and one row per
            token of ``x`` once flattened

        Returns a tensor with the same shape as ``x``. The down projections are
        called with ``reduce=False`` and no reduction happens here --
        presumably the caller performs the tensor-parallel all-reduce (confirm).
        """
        # Flatten leading dims so routing works on (tokens, ...) rows; the
        # original shape is restored on the way out.
        input_shape = x.shape
        x = x.view(-1, input_shape[-1])
        gating_output = gating_output.reshape(-1, gating_output.shape[-1])

        if self.n_expert_group is not None and self.topk_group is not None:
            topk_weights, topk_ids = grouped_topk(
                x,
                gating_output,
                self.topk,
                renormalize=self.renormalize,
                num_expert_group=self.n_expert_group,
                topk_group=self.topk_group,
            )
        else:
            topk_weights, topk_ids = fused_topk(
                x, gating_output, self.topk, self.renormalize
            )
        topk_weights = topk_weights.to(x.dtype)

        # Dense (tokens, n_experts) routing matrix: zero for unselected experts.
        routing_weights = torch.zeros(
            topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
        )
        routing_weights.scatter_(1, topk_ids.long(), topk_weights)

        out = torch.zeros_like(x)
        for i in range(self.n_experts):
            h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
            h = self.down_proj[i](h, reduce=False)
            out += h * routing_weights[:, i].view(-1, 1)

        return out.view(input_shape)
class SparseMoELayer(nn.Module):
    """
    Layer for MoE that uses fused kernels to only apply the active experts
    for each token (rather than applying all experts and selecting the
    outputs of active experts).

    Dispatches to ``FP8SparseMoELayer`` when the loader is a
    ``HybridFP8UnquantLoader`` with ``to_fp8`` set, and to
    ``UnquantizedSparseMoELayer`` otherwise.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        """Select and construct the concrete fused MoE implementation.

        Raises:
            ValueError: if ``weights.loader`` is not supported (see
                ``is_supported``).
        """
        super().__init__()

        # Single source of truth for which loaders are supported.
        if not self.is_supported(weights):
            raise ValueError(
                f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized and FP8 weights"
            )

        if isinstance(weights.loader, HybridFP8UnquantLoader) and weights.loader.to_fp8:
            cls = FP8SparseMoELayer
        else:
            cls = UnquantizedSparseMoELayer

        log_once(
            logger.info,
            "Using MoE layer with fused gemm",
        )

        self.moe = cls(
            n_expert_group=n_expert_group,
            n_experts=n_experts,
            prefix=prefix,
            renormalize=renormalize,
            topk=topk,
            topk_group=topk_group,
            weights=weights,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            down_proj_name=down_proj_name,
        )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        return self.moe(x, gating_output=gating_output)

    @staticmethod
    def is_supported(weights: Weights) -> bool:
        """Return True if ``weights.loader`` can feed a fused MoE layer."""
        # NOTE(review): if DefaultWeightsLoader is constructed with a class
        # (e.g. `DefaultWeightsLoader(UnquantizedWeight)`), `weight_class` is a
        # type and this `isinstance` is always False -- `issubclass` was probably
        # intended. Not changed here: UnquantizedSparseMoELayer reads
        # `weights_loader.weight_block_size`, which DefaultWeightsLoader may not
        # provide -- confirm before switching.
        return (
            isinstance(weights.loader, DefaultWeightsLoader)
            and isinstance(weights.loader.weight_class, UnquantizedWeight)
        ) or isinstance(weights.loader, HybridFP8UnquantLoader)
================================================
FILE: backends/gaudi/server/text_generation_server/layers/moe/fp8.py
================================================
from typing import Optional
import torch
import torch.nn as nn
import os
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
Fp8Weight,
fp8_quantize,
quant_dtype,
normalize_e4m3fn_to_native_float8,
dynamic_quant,
dequant_block_fp8_weight_naive,
)
from text_generation_server.layers.moe.fused_moe import select_experts
import habana_frameworks.torch as htorch
class FP8SparseMoELayer(nn.Module):
    """Sparse MoE layer running FP8 expert weights through the HPU fused
    ``torch.ops.hpu.mixture_of_experts`` op.

    With expert parallelism (env ``USE_EXPERT_PARALLEL``, default ``"true"``)
    each rank loads a contiguous slice of the experts. EP is disabled when a
    rank would get fewer than 4 experts. Without EP every rank loads all
    experts through the ``*_col``/``*_row`` loaders (presumably
    tensor-parallel sharded).
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.weight_block_size = weights.weights_loader.weight_block_size
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias
        self.world_size = weights.process_group.size()
        self.rank = weights.process_group.rank()
        self.ep_rank = self.rank
        self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
        if (n_experts + self.world_size - 1) // self.world_size < 4:
            self.use_ep = False
        if self.use_ep:
            n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size
            self.ep_offset = self.ep_rank * n_experts_per_rank
            # The last rank may own fewer than n_experts_per_rank experts.
            n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)
        else:
            self.ep_offset = 0
        # Number of experts whose weights are stored on this rank.
        self.n_local_experts = n_experts

        (
            self.gate_up_proj,
            self.gate_up_proj_weight_scale,
            self.gate_up_proj_input_scale,
        ) = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
            use_ep=self.use_ep,
            ep_offset=self.ep_offset,
        )

        self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
            _load_expert_weights_row(
                prefix=prefix,
                n_experts=n_experts,
                name=down_proj_name,
                weights=weights,
                use_ep=self.use_ep,
                ep_offset=self.ep_offset,
            )
        )
        if self.weight_block_size is not None:
            # Block-quantized checkpoints: dequantize, then re-quantize with
            # dynamic_quant.
            self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant(
                dequant_block_fp8_weight_naive(
                    self.gate_up_proj,
                    self.gate_up_proj_weight_scale,
                    self.weight_block_size,
                )
            )
            self.down_proj, self.down_proj_weight_scale = dynamic_quant(
                dequant_block_fp8_weight_naive(
                    self.down_proj, self.down_proj_weight_scale, self.weight_block_size
                )
            )
            self.gate_up_proj_weight_scale, self.down_proj_weight_scale = (
                self.gate_up_proj_weight_scale.squeeze(-1),
                self.down_proj_weight_scale.squeeze(-1),
            )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """Route ``x`` with ``gating_output`` and apply this rank's experts."""
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=gating_output,
            use_grouped_topk=self.n_expert_group is not None,
            top_k=self.topk,
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.n_expert_group,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
        )
        x_fp8, x_scale = dynamic_quant(x, single_scale=True)

        # Slice bounds come from the experts actually stored on this rank.
        # Using ceil(total_experts / world_size) instead would index past the
        # local weight tensors (and report an out-of-range experts_max) on the
        # last rank whenever the expert count is not divisible by world_size.
        num_local_experts = self.n_local_experts
        moe_n_slice = 1
        n_expert_slice = (num_local_experts + moe_n_slice - 1) // moe_n_slice
        for i in range(moe_n_slice):
            min_expert = i * n_expert_slice
            max_expert = min((i + 1) * n_expert_slice, num_local_experts)
            w13_list_slice = [
                self.gate_up_proj[j, ...] for j in range(min_expert, max_expert)
            ]
            w2_list_slice = [
                self.down_proj[j, ...] for j in range(min_expert, max_expert)
            ]
            w13_weight_scale = [
                self.gate_up_proj_weight_scale[j, ...]
                for j in range(min_expert, max_expert)
            ]
            w2_weight_scale = [
                self.down_proj_weight_scale[j, ...]
                for j in range(min_expert, max_expert)
            ]

            # experts_min/max are global expert ids (inclusive range) owned by
            # this slice.
            current_hidden_states = torch.ops.hpu.mixture_of_experts(
                hidden_states=x_fp8,
                expert_routing_table=topk_ids.to(torch.int64),
                router_weights=topk_weights.to(x.dtype),
                w12=w13_list_slice,
                w3=w2_list_slice,
                d_scale_hidden_states=x_scale,
                d_scale_w12=w13_weight_scale,
                d_scale_w3=w2_weight_scale,
                permuted_weights=True,
                activation="silu",
                experts_min=min_expert + self.ep_offset,
                experts_max=max_expert + self.ep_offset - 1,
            )
            htorch.core.mark_step()
            if i == 0:
                final_hidden_states = current_hidden_states
            else:
                final_hidden_states.add_(current_hidden_states)
        return final_hidden_states
def _load_expert_weights(
    get_weight_fn,
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
    ep_offset: int = 0,
) -> torch.Tensor:
    """Stack ``n_experts`` FP8 expert weights into contiguous tensors.

    ``get_weight_fn(prefix, expert_id, name, weights)`` must return an
    ``Fp8Weight``; expert ids are ``ep_offset .. ep_offset + n_experts - 1``.
    Weights already stored in an e4m3 FP8 dtype are normalised; anything else
    is quantized with ``fp8_quantize(..., scalar=True)``.

    Returns:
        A 3-tuple (despite the annotation): stacked weights (``quant_dtype``),
        stacked float32 weight scales, and the largest input scale seen (or
        ``None`` if no expert had one).
    """
    stacked_weights = None
    stacked_scales = None
    max_input_scale = None

    for local_idx in range(n_experts):
        expert = get_weight_fn(prefix, local_idx + ep_offset, name, weights)
        assert isinstance(expert, Fp8Weight)

        # Allocate the stacked outputs lazily, once the per-expert shape is known.
        if stacked_weights is None:
            stacked_weights = torch.empty(
                (n_experts, *expert.weight.shape),
                dtype=quant_dtype,
                device=expert.weight.device,
            )
        if stacked_scales is None:
            stacked_scales = torch.empty(
                (n_experts, *expert.weight_scale.shape),
                dtype=torch.float32,
                device=expert.weight.device,
            )

        if expert.weight.dtype not in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
            stacked_weights[local_idx], stacked_scales[local_idx] = fp8_quantize(
                expert.weight, scalar=True
            )
            continue

        (
            stacked_weights[local_idx],
            stacked_scales[local_idx],
            input_scale,
        ) = normalize_e4m3fn_to_native_float8(
            expert.weight, expert.weight_scale, expert.input_scale
        )
        if input_scale is not None and (
            max_input_scale is None or input_scale > max_input_scale
        ):
            max_input_scale = input_scale

    assert stacked_weights is not None
    return stacked_weights, stacked_scales, max_input_scale
def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
    use_ep: bool = False,
    ep_offset: int = 0,
) -> torch.Tensor:
    """Load fused gate/up FP8 projections for ``n_experts`` experts.

    With ``use_ep`` each expert is loaded whole via ``get_multi_weights``,
    starting at ``ep_offset``. Otherwise ``get_multi_weights_col`` is used and
    the offset is ignored. Returns the 3-tuple produced by
    ``_load_expert_weights``.
    """

    def fused_names(prefix, i):
        return [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"]

    if use_ep:

        def load_one(prefix, i, name, weights):
            return weights.get_multi_weights(fused_names(prefix, i), 0)

        offset = ep_offset
    else:

        def load_one(prefix, i, name, weights):
            return weights.get_multi_weights_col(fused_names(prefix, i), 0)

        offset = 0

    return _load_expert_weights(
        load_one,
        prefix=prefix,
        n_experts=n_experts,
        name=None,
        weights=weights,
        ep_offset=offset,
    )
def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
    use_ep: bool = False,
    ep_offset: int = 0,
) -> torch.Tensor:
    """Load the ``name`` FP8 projection (e.g. down_proj) for ``n_experts`` experts.

    With ``use_ep`` each expert is loaded whole via ``get_weights`` starting at
    ``ep_offset``. Otherwise ``get_weights_row`` is used and the offset is
    ignored. Returns the 3-tuple produced by ``_load_expert_weights``.
    """
    if use_ep:

        def load_one(prefix, i, name, weights):
            return weights.get_weights(f"{prefix}.{i}.{name}")

        offset = ep_offset
    else:

        def load_one(prefix, i, name, weights):
            return weights.get_weights_row(f"{prefix}.{i}.{name}")

        offset = 0

    return _load_expert_weights(
        load_one,
        prefix=prefix,
        n_experts=n_experts,
        name=name,
        weights=weights,
        ep_offset=offset,
    )
================================================
FILE: backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py
================================================
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Optional
import torch
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Grouped top-k expert routing (DeepSeek style).

    Experts are split into ``num_expert_group`` equal groups. The best
    ``topk_group`` groups are kept per token, and the ``topk`` best experts are
    chosen among them.

    With ``e_score_correction_bias``, the biased scores drive selection, and
    groups are ranked by the sum of their two best scores. The routing weights
    still come from the unbiased scores. Without it, groups are ranked by
    their best score.

    Returns:
        ``(weights, ids)``, each of shape ``(num_tokens, topk)``, as float32
        and int32. Order within a row is unspecified (``sorted=False``).

    Raises:
        ValueError: for a ``scoring_func`` other than softmax/sigmoid.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    logits = gating_output.float()
    bias = None if e_score_correction_bias is None else e_score_correction_bias.float()

    if scoring_func == "softmax":
        scores = torch.softmax(logits, dim=-1)
    elif scoring_func == "sigmoid":
        scores = logits.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.shape[0]

    # `selection_scores` decide which experts win; `scores` stays unbiased.
    if bias is None:
        selection_scores = scores
        per_group = selection_scores.view(num_token, num_expert_group, -1)
        group_scores = per_group.max(dim=-1).values
    else:
        selection_scores = scores + bias.unsqueeze(0)
        per_group = selection_scores.view(num_token, num_expert_group, -1)
        group_scores = per_group.topk(2, dim=-1)[0].sum(dim=-1)

    chosen_groups = torch.topk(
        group_scores, k=topk_group, dim=-1, sorted=False
    ).indices  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores).scatter_(1, chosen_groups, 1)

    experts_per_group = selection_scores.shape[-1] // num_expert_group
    expert_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, experts_per_group)
        .reshape(num_token, -1)
    )  # [n, e]
    masked_scores = selection_scores.masked_fill(~expert_mask.bool(), float("-inf"))

    if bias is None:
        topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1, sorted=False)
    else:
        topk_ids = torch.topk(masked_scores, k=topk, dim=-1, sorted=False).indices
        topk_weights = scores.gather(1, topk_ids)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Plain softmax top-k routing.

    Args:
        hidden_states: unused; kept for signature parity with ``grouped_topk``.
        gating_output: router logits; the last dimension indexes experts.
        topk: number of experts selected per token.
        renormalize: rescale the selected weights to sum to 1 per token.

    Returns:
        ``(weights, ids)`` with shape ``gating_output.shape[:-1] + (topk,)``;
        the weights are float32.
    """
    # Softmax over the expert axis (last dim), the same axis topk uses below.
    # The previous `dim=1` only coincided with it for 2-D router logits.
    topk_weights = torch.nn.functional.softmax(
        gating_output, dim=-1, dtype=torch.float32
    )
    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
    if renormalize:
        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
):
    """Dispatch to grouped or plain top-k routing.

    ``scoring_func`` and ``e_score_correction_bias`` are only used by the
    grouped path; the plain path always uses softmax.

    Returns:
        ``(topk_weights, topk_ids)`` from the selected routing function.
    """
    if not use_grouped_topk:
        topk_weights, topk_ids = fused_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
        )
        return topk_weights, topk_ids

    # Grouped routing (used by DeepSeek-V2 style models) needs both group knobs.
    assert topk_group is not None
    assert num_expert_group is not None
    topk_weights, topk_ids = grouped_topk(
        hidden_states=hidden_states,
        gating_output=router_logits,
        topk=top_k,
        renormalize=renormalize,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
    )
    return topk_weights, topk_ids
================================================
FILE: backends/gaudi/server/text_generation_server/layers/moe/unquantized.py
================================================
from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.utils.weights import UnquantizedWeight, Weights
from text_generation_server.layers.moe.fused_moe import select_experts
from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
import habana_frameworks.torch as htorch
import torch.nn.functional as F
import os
class UnquantizedSparseMoELayer(nn.Module):
    """Sparse MoE layer for unquantized weights using ``VllmMixtureOfExpertsOp``.

    With expert parallelism (env ``USE_EXPERT_PARALLEL``, default ``"true"``)
    each rank loads a contiguous slice of the experts. EP is disabled when a
    rank would get fewer than 4 experts.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
        # Read from the loader but not used by this class.
        self.weight_block_size = weights.weights_loader.weight_block_size
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias
        self.rank = weights.process_group.rank()
        self.world_size = weights.process_group.size()
        self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
        if (n_experts + self.world_size - 1) // self.world_size < 4:
            self.use_ep = False
        if self.use_ep:
            n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size
            self.ep_offset = self.rank * n_experts_per_rank
            n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)
            experts_min = self.ep_offset
            experts_max = self.ep_offset + n_experts - 1
        else:
            self.ep_offset = 0
            experts_min = 0
            experts_max = n_experts - 1

        self.gate_up_proj = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
            use_ep=self.use_ep,
            ep_offset=self.ep_offset,
        )

        self.down_proj = _load_expert_weights_row(
            prefix=prefix,
            n_experts=n_experts,
            name=down_proj_name,
            weights=weights,
            use_ep=self.use_ep,
            ep_offset=self.ep_offset,
        )

        # The op handles the inclusive global expert-id range [experts_min, experts_max].
        self.MoeOp = VllmMixtureOfExpertsOp(n_experts, experts_min, experts_max)
        for i in range(n_experts):
            self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
            self.MoeOp.w2_list[i].set_weight(self.down_proj[i])

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """Route ``x`` with ``gating_output`` and apply the experts.

        Routing goes through ``select_experts``, the same helper
        ``FP8SparseMoELayer`` uses. ``renormalize``, grouped top-k, the scoring
        function and the correction bias are therefore honoured instead of
        always renormalising a plain softmax top-k.
        """
        htorch.core.mark_step()
        routing_weights, selected_experts = select_experts(
            hidden_states=x,
            router_logits=gating_output,
            use_grouped_topk=self.n_expert_group is not None,
            top_k=self.topk,
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.n_expert_group,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
        )
        routing_weights = routing_weights.to(x.dtype)

        final_hidden_states = self.MoeOp(
            hidden_states=x,
            # grouped_topk returns int32 ids; keep the int64 table used before.
            expert_routing_table=selected_experts.to(torch.int64),
            router_weights=routing_weights,
            permuted_weights=True,
            activation="silu",
        )

        return final_hidden_states.view(-1, x.shape[1])
def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
    use_ep: bool = False,
    ep_offset: int = 0,
) -> torch.Tensor:
    """Stack fused gate/up projections of ``n_experts`` unquantized experts.

    With ``use_ep`` expert ``ep_offset + i`` is loaded whole via
    ``get_multi_weights``; otherwise expert ``i`` is loaded with
    ``get_multi_weights_col``. Returns a tensor of shape
    ``(n_experts, *per_expert_weight_shape)``.
    """
    stacked = None
    for local_idx in range(n_experts):
        if use_ep:
            expert_idx = local_idx + ep_offset
            weight = weights.get_multi_weights(
                [
                    f"{prefix}.{expert_idx}.{gate_proj_name}",
                    f"{prefix}.{expert_idx}.{up_proj_name}",
                ],
                0,
            )
        else:
            weight = weights.get_multi_weights_col(
                [
                    f"{prefix}.{local_idx}.{gate_proj_name}",
                    f"{prefix}.{local_idx}.{up_proj_name}",
                ],
                0,
            )

        assert isinstance(weight, UnquantizedWeight)

        if stacked is None:
            stacked = torch.empty(
                (n_experts, *weight.weight.shape),
                dtype=weight.weight.dtype,
                device=weight.weight.device,
            )
        stacked[local_idx] = weight.weight

    assert stacked is not None
    return stacked
def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
    use_ep: bool = False,
    ep_offset: int = 0,
) -> torch.Tensor:
    """Stack the ``name`` projection of ``n_experts`` unquantized experts.

    With ``use_ep`` expert ``ep_offset + i`` is loaded whole via
    ``get_weights``; otherwise expert ``i`` is loaded with ``get_weights_row``.
    Returns a tensor of shape ``(n_experts, *per_expert_weight_shape)``.
    """
    stacked = None
    for local_idx in range(n_experts):
        if use_ep:
            weight = weights.get_weights(f"{prefix}.{local_idx + ep_offset}.{name}")
        else:
            weight = weights.get_weights_row(f"{prefix}.{local_idx}.{name}")

        assert isinstance(weight, UnquantizedWeight)

        if stacked is None:
            stacked = torch.empty(
                (n_experts, *weight.weight.shape),
                dtype=weight.weight.dtype,
                device=weight.weight.device,
            )
        stacked[local_idx] = weight.weight

    assert stacked is not None
    return stacked
================================================
FILE: backends/gaudi/server/text_generation_server/layers/rotary.py
================================================
import os
import math
import torch
from torch import nn
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return inv_freq
def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = {
"type": os.environ["ROPE_SCALING"],
"factor": float(os.environ["ROPE_FACTOR"]),
}
return rope_scaling
return getattr(config, "rope_scaling", None)
class PositionRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) backed by precomputed cos/sin tables.

    Tables covering positions ``[0, max_position_embeddings)`` are built at
    construction. ``forward`` rotates ``query`` and ``key`` in place with the
    HPU ``apply_rotary_pos_emb`` kernel.
    """

    def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
        super().__init__()
        self.inv_freq = inv_freq
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        # Linear position scaling: positions are divided by this factor when the
        # tables are built (None disables it).
        self.scaling_factor = scaling_factor
        self.dynamic_args = None
        self._update_cos_sin_cache(
            torch.float32, inv_freq.device, max_position_embeddings
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        """Apply RoPE to ``query`` and ``key`` in place (returns None).

        ``cos``/``sin`` are the half-width tables returned by ``get_cos_sin``.
        Only the first ``2 * cos.shape[-1]`` features of each head are rotated.
        """
        num_tokens = query.shape[0]
        head_size = query.shape[-1]
        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
        # to query hidden dimension, so the original tensors need to be
        # expanded
        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
        # and expansion of cos/sin tensors via concatenation
        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
        rotary_dim = cos.shape[-1]
        query_shape = query.shape
        query = query.view(num_tokens, -1, head_size)
        query_rot = query[..., :rotary_dim]
        query_pass = query[..., rotary_dim:]
        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))

        key_shape = key.shape
        key = key.view(num_tokens, -1, head_size)
        key_rot = key[..., :rotary_dim]
        key_pass = key[..., rotary_dim:]
        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))

    @classmethod
    def static(cls, config, dim, base, device):
        """Build a rotary embedding from ``dim``/``base`` and the config's RoPE scaling.

        Dispatches on ``rope_scaling["rope_type"]`` (falling back to the legacy
        ``"type"`` key) to the specialised embedding classes.

        Raises:
            NotImplementedError: for an unknown rope scaling type.
        """
        inv_freq = _create_inv_freq(dim, base, device)
        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if not hasattr(config, "max_position_embeddings") and hasattr(
            config, "max_seq_len"
        ):
            # handling for dbrx
            config.max_position_embeddings = config.max_seq_len
        if rope_scaling is not None:
            # `rope_type` is now standard in transformers, but some existing models
            # have `type` instead.
            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))

            if rope_type == "linear":
                # Linear scaling divides positions by `factor` when the cos/sin
                # tables are built; without this the factor was silently ignored
                # (`load` already applies it).
                scaling_factor = rope_scaling["factor"]
            elif rope_type == "default":
                pass
            elif rope_type == "mrope":
                mrope_section = rope_scaling["mrope_section"]
                if mrope_section is not None:
                    return RotaryPositionEmbeddingMultimodalSections(
                        inv_freq,
                        scaling_factor,
                        mrope_section,
                        config.max_position_embeddings,
                    )
            elif rope_type == "dynamic":
                scaling_factor = rope_scaling["factor"]
                return DynamicPositionRotaryEmbedding(
                    dim=dim,
                    max_position_embeddings=config.max_position_embeddings,
                    base=base,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
            elif rope_type == "llama3":
                inv_freq = apply_llama3_scaling(
                    inv_freq,
                    scaling_factor=rope_scaling["factor"],
                    low_freq_factor=rope_scaling["low_freq_factor"],
                    high_freq_factor=rope_scaling["high_freq_factor"],
                    original_max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                )

                return cls(inv_freq, scaling_factor, config.max_position_embeddings)

            elif rope_type == "yarn":
                scaling_factor = rope_scaling["factor"]
                mscale = rope_scaling.get("mscale", 1.0)
                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                return YarnPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                    base=base,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                    extrapolation_factor=1,
                    attn_factor=1,
                    beta_fast=32,
                    beta_slow=1,
                    mscale=mscale,
                    mscale_all_dim=mscale_all_dim,
                )
            elif rope_type in ["su", "longrope"]:
                short_factor = torch.tensor(
                    rope_scaling["short_factor"], dtype=torch.float32, device=device
                )
                short_inv_freq = 1.0 / (
                    short_factor
                    * base
                    ** (
                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
                        / dim
                    )
                )
                long_factor = torch.tensor(
                    rope_scaling["long_factor"], dtype=torch.float32, device=device
                )
                long_inv_freq = 1.0 / (
                    long_factor
                    * base
                    ** (
                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)
                        / dim
                    )
                )

                original_max_position_embeddings = (
                    config.original_max_position_embeddings
                )
                max_position_embeddings = config.max_position_embeddings
                if max_position_embeddings <= original_max_position_embeddings:
                    scaling_factor = 1.0
                else:
                    scale = max_position_embeddings / original_max_position_embeddings
                    scaling_factor = math.sqrt(
                        1 + math.log(scale) / math.log(original_max_position_embeddings)
                    )

                # if short_mscale and long_mscale are provided we need to scale the freqs
                # using the Phi3LongRoPEScaledRotaryEmbedding
                if ("short_mscale" in rope_scaling) and ("long_mscale" in rope_scaling):
                    short_mscale = rope_scaling["short_mscale"]
                    long_mscale = rope_scaling["long_mscale"]
                    return Phi3LongRoPEScaledRotaryEmbedding(
                        short_inv_freq=short_inv_freq,
                        long_inv_freq=long_inv_freq,
                        max_position_embeddings=config.max_position_embeddings,
                        short_mscale=short_mscale,
                        long_mscale=long_mscale,
                        original_max_position_embeddings=original_max_position_embeddings,
                    )

                return SuRotaryEmbedding(
                    short_inv_freq=short_inv_freq,
                    long_inv_freq=long_inv_freq,
                    scaling_factor=scaling_factor,
                    original_max_position_embeddings=original_max_position_embeddings,
                    max_position_embeddings=config.max_position_embeddings,
                )
            else:
                # Use the resolved type: configs with only `rope_type` have no
                # `type` key, so indexing it here used to raise KeyError.
                raise NotImplementedError(
                    f"rope scaling type {rope_type} is not implemented or invalid"
                )
        return cls(inv_freq, scaling_factor, config.max_position_embeddings)

    @classmethod
    def load(cls, config, prefix, weights):
        """Build a rotary embedding from a checkpointed ``{prefix}.inv_freq`` tensor.

        Raises:
            NotImplementedError: for a rope scaling type other than
                linear/dynamic/yarn.
        """
        # XXX: Always load this in float32 !
        dtype = weights.dtype
        weights.dtype = torch.float32
        inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
        weights.dtype = dtype

        scaling_factor = None
        rope_scaling = _get_rope_config(config)
        if rope_scaling is not None:
            scaling_factor = rope_scaling["factor"]
            # Accept both the standard `rope_type` key and the legacy `type` key,
            # matching `static`.
            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
            if rope_type == "linear":
                pass
            elif rope_type == "dynamic":
                return DynamicPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=config.max_position_embeddings,
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                )
            elif rope_type == "yarn":
                mscale = rope_scaling.get("mscale", 1.0)
                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                return YarnPositionRotaryEmbedding(
                    dim=2 * inv_freq.shape[0],
                    max_position_embeddings=rope_scaling[
                        "original_max_position_embeddings"
                    ],
                    base=10000.0,
                    device=inv_freq.device,
                    scaling_factor=scaling_factor,
                    extrapolation_factor=1,
                    attn_factor=1,
                    beta_fast=32,
                    beta_slow=1,
                    mscale=mscale,
                    mscale_all_dim=mscale_all_dim,
                )
            else:
                raise NotImplementedError(
                    f"rope scaling type {rope_type} is not implemented or invalid"
                )
        return cls(inv_freq, scaling_factor, config.max_position_embeddings)

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        # The `is None` check covers the first call with seqlen == 0, which
        # would otherwise dereference the still-empty cache.
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            if self.scaling_factor is not None:
                t /= self.scaling_factor
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor):
        """Gather cos/sin rows for ``position_ids``, each with a unit dim at axis 1."""
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)

        # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
        return cos.unsqueeze(1), sin.unsqueeze(1)
class SuRotaryEmbedding(PositionRotaryEmbedding):
    """RoPE with separate short/long inverse frequencies ("su" / "longrope").

    Positions below ``original_max_position_embeddings`` use ``short_inv_freq``.
    Positions at or beyond it use ``long_inv_freq``. Both tables are
    multiplied by ``scaling_factor``.
    """

    def __init__(
        self,
        short_inv_freq,
        long_inv_freq,
        scaling_factor,
        original_max_position_embeddings,
        max_position_embeddings,
    ):
        # Deliberately skips PositionRotaryEmbedding.__init__ (which expects a
        # single inv_freq) and initialises nn.Module directly.
        super(PositionRotaryEmbedding, self).__init__()
        self.short_inv_freq = short_inv_freq
        self.long_inv_freq = long_inv_freq
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        # Cache bookkeeping, populated by _update_cos_sin_cache below.
        self._seq_len_cached = 0
        self._cos_cached = self._sin_cached = None
        self._cos_k_cached = self._sin_k_cached = None
        self.dynamic_args = None

        self._update_cos_sin_cache(
            torch.float32, short_inv_freq.device, max_position_embeddings
        )

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Rebuild when the requested length grows, on first use, or when the
        # device/dtype changed (possibly due to tracing for instance).
        stale = (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        )
        if not stale:
            return

        self._seq_len_cached = seqlen
        split = self.original_max_position_embeddings
        positions = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
        freqs = torch.cat(
            [
                torch.outer(
                    positions[:split], self.short_inv_freq.to(device=positions.device)
                ),
                torch.outer(
                    positions[split:], self.long_inv_freq.to(device=positions.device)
                ),
            ]
        )
        self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
        self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)
class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
    """LongRoPE-style rotary embedding with separate short/long scaling.

    Positions below ``original_max_position_embeddings`` use
    ``short_inv_freq`` with ``short_mscale``. Later positions use
    ``long_inv_freq`` with ``long_mscale``.

    NOTE(review): the mscale factors multiply the rotation *angles* here, not
    the resulting cos/sin values. Other LongRoPE implementations (e.g. the
    Hugging Face transformers Phi-3 model) scale cos/sin instead. Confirm this
    is intended.
    """

    def __init__(
        self,
        short_inv_freq: torch.Tensor,
        long_inv_freq: torch.Tensor,
        max_position_embeddings: int,
        short_mscale: float,
        long_mscale: float,
        original_max_position_embeddings: int,
    ):
        # Deliberately bypass PositionRotaryEmbedding.__init__; only the
        # nn.Module initialisation is wanted.
        super(PositionRotaryEmbedding, self).__init__()
        self.short_inv_freq = short_inv_freq
        self.long_inv_freq = long_inv_freq
        self.max_position_embeddings = max_position_embeddings
        self.short_mscale = short_mscale
        self.long_mscale = long_mscale
        self.original_max_position_embeddings = original_max_position_embeddings

        # cache
        self._seq_len_cached = 0
        self._cos_cached = self._sin_cached = None
        self._cos_k_cached = self._sin_k_cached = None
        self.dynamic_args = None

        self._update_cos_sin_cache(
            torch.float32, short_inv_freq.device, max_position_embeddings
        )

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        """Rebuild cos/sin tables when the cache is too short, empty, or mismatched."""
        stale = (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        )
        if not stale:
            return

        self._seq_len_cached = seqlen
        boundary = self.original_max_position_embeddings
        positions = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)

        short_angles = (
            torch.outer(positions[:boundary], self.short_inv_freq.to(device=positions.device))
            * self.short_mscale
        )
        long_angles = (
            torch.outer(positions[boundary:], self.long_inv_freq.to(device=positions.device))
            * self.long_mscale
        )

        # Allocated with torch's default dtype (no dtype argument) and filled
        # piecewise, so both halves are cast on assignment.
        angles = torch.empty((seqlen, short_angles.shape[1]), device=device)
        angles[:boundary] = short_angles
        angles[boundary:] = long_angles

        self._cos_cached = torch.cos(angles).to(dtype)
        self._sin_cached = torch.sin(angles).to(dtype)
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
    """Rotary embedding that enlarges its rotary base for long sequences.

    When a table longer than ``max_position_embeddings`` is requested, the base
    is rescaled from ``scaling_factor`` and ``seqlen`` and the inverse
    frequencies are regenerated before the cos/sin cache is rebuilt.
    """

    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
        inv_freq = _create_inv_freq(dim, base, device)
        # These attributes are read by ``_update_cos_sin_cache``. They are set
        # *before* the parent constructor so they are available if it builds
        # the initial cache through that method (same ordering as
        # YarnPositionRotaryEmbedding).
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        super().__init__(inv_freq, scaling_factor, max_position_embeddings)

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        """Rebuild the cos/sin tables, rescaling the base beyond the trained length.

        Args:
            dtype: dtype the cos/sin tables are cast to.
            device: device on which the position index is created.
            seqlen: number of positions the tables must cover.

        NOTE(review): once enlarged, ``self.inv_freq`` is kept even if a later
        rebuild (e.g. on a device change) uses a shorter ``seqlen``.
        """
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            if seqlen > self.max_position_embeddings:
                newbase = self.base * (
                    (self.scaling_factor * seqlen / self.max_position_embeddings)
                    - (self.scaling_factor - 1)
                ) ** (self.dim / (self.dim - 2))
                self.inv_freq = _create_inv_freq(
                    self.dim, newbase, self.inv_freq.device
                )
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Return the fractional rotary dimension index that completes
    ``num_rotations`` full turns over ``max_position_embeddings`` positions.

    Solves ``2*pi * base**(2*d/dim) == max_position_embeddings / num_rotations``
    for ``d``.
    """
    ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    numerator = dim * math.log(ratio)
    denominator = 2 * math.log(base)
    return numerator / denominator
# Find dim range bounds based on rotations
def find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Return integer ``(low, high)`` dimension bounds for the YaRN ramp.

    ``low`` is the floor of the correction dimension for ``low_rot``
    rotations and ``high`` the ceiling for ``high_rot``. ``low`` is clamped
    below at 0 and ``high`` above at ``dim - 1``.
    """

    def corrected(rotations, rounding):
        return rounding(
            find_correction_dim(rotations, dim, base, max_position_embeddings)
        )

    low = corrected(low_rot, math.floor)
    high = corrected(high_rot, math.ceil)
    # Clamp values just in case
    return max(low, 0), min(high, dim - 1)
def linear_ramp_mask(min, max, dim):
    """Return a float32 ramp of length ``dim`` going 0 -> 1 between ``min`` and ``max``.

    Entries at indices <= ``min`` are 0 and entries at indices >= ``max`` are 1.
    Values in between are linear. When ``min == max`` the upper bound is
    nudged by 0.001 to avoid a division by zero.
    """
    upper = max + 0.001 if min == max else max  # Prevent singularity
    ramp = (torch.arange(dim, dtype=torch.float32) - min) / (upper - min)
    return ramp.clamp(0, 1)
def get_mscale(scale: float = 1.0, mscale: float = 1.0):
    """Return the YaRN magnitude correction ``0.1 * mscale * ln(scale) + 1``.

    Scales of 1 or less need no correction and yield exactly 1.0.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
    """Rotary embedding using YaRN frequency blending.

    Inverse frequencies are blended between the unscaled ("extrapolation")
    frequencies and frequencies divided by ``scaling_factor``
    ("interpolation"). The blend uses a linear ramp over the dimension range
    returned by ``find_correction_range(beta_fast, beta_slow, ...)``. The
    cos/sin tables are multiplied by ``self.mscale``.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings,
        base,
        device,
        scaling_factor,
        *,
        extrapolation_factor,
        attn_factor,
        beta_fast,
        beta_slow,
        mscale: float,
        mscale_all_dim: float,
    ):
        inv_freq = _create_inv_freq(dim, base, device)
        # Assigned before the parent constructor so they are available if it
        # builds the initial cache through ``_update_cos_sin_cache``.
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale_all_dim = mscale_all_dim
        self.scaling_factor = scaling_factor
        self.mscale = float(
            get_mscale(self.scaling_factor, mscale)
            / get_mscale(self.scaling_factor, mscale_all_dim)
            * self.attn_factor
        )  # Get n-d magnitude scaling corrected for interpolation
        super().__init__(inv_freq, scaling_factor, max_position_embeddings)

    def _update_cos_sin_cache(self, dtype, device, seqlen):
        """Re-derive the blended inverse frequencies and rebuild cos/sin tables.

        Runs when ``seqlen`` exceeds the cached length, the cache is empty, or
        the cache's device/dtype differ from the request (possibly due to
        tracing for instance).
        """
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            # The blended frequencies depend only on constructor arguments
            # and are re-derived on every rebuild.
            inv_freq_extrapolation = _create_inv_freq(
                self.dim, self.base, self.inv_freq.device
            )
            periods = 1.0 / inv_freq_extrapolation
            inv_freq_interpolation = 1.0 / (self.scaling_factor * periods)

            low, high = find_correction_range(
                self.beta_fast,
                self.beta_slow,
                self.dim,
                self.base,
                self.max_position_embeddings,
            )
            # Get n-d rotational scaling corrected for extrapolation. The ramp
            # is placed on the device of the frequencies it is blended with,
            # so a rebuild for a different ``device`` does not mix devices.
            inv_freq_mask = (
                1
                - linear_ramp_mask(low, high, self.dim // 2)
                .float()
                .to(inv_freq_extrapolation.device)
            ) * self.extrapolation_factor
            self.inv_freq = (
                inv_freq_interpolation * (1 - inv_freq_mask)
                + inv_freq_extrapolation * inv_freq_mask
            )

            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
def apply_llama3_scaling(
    freqs: torch.Tensor,
    *,
    scaling_factor: int,
    low_freq_factor: int,
    high_freq_factor: int,
    original_max_position_embeddings: int,
):
    """Rescale rotary inverse frequencies with the Llama 3 scheme.

    Each frequency's wavelength ``2*pi / freq`` selects its treatment:

    * shorter than ``original_max_position_embeddings / high_freq_factor``:
      kept unchanged;
    * longer than ``original_max_position_embeddings / low_freq_factor``:
      divided by ``scaling_factor``;
    * otherwise: linearly blended between the two.

    The computation is vectorised. The previous per-element Python loop forced
    one host round-trip per frequency.

    Args:
        freqs: inverse frequencies to rescale.
        scaling_factor: divisor applied to low-frequency components.
        low_freq_factor: sets the long-wavelength threshold.
        high_freq_factor: sets the short-wavelength threshold.
        original_max_position_embeddings: context length the thresholds are
            relative to.

    Returns:
        A new tensor with the same shape, dtype and device as ``freqs``.

    Raises:
        AssertionError: if some frequency falls in the blended band while the
            two thresholds coincide.
    """
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor

    wavelen = 2 * math.pi / freqs
    keep = wavelen < high_freq_wavelen
    scale_down = ~keep & (wavelen > low_freq_wavelen)
    blend = ~keep & ~scale_down
    if bool(blend.any()):
        assert low_freq_wavelen != high_freq_wavelen

    # Values outside the blend band are discarded by torch.where below.
    smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    blended = (1 - smooth) * freqs / scaling_factor + smooth * freqs
    new_freqs = torch.where(
        keep, freqs, torch.where(scale_down, freqs / scaling_factor, blended)
    )
    return new_freqs.to(dtype=freqs.dtype, device=freqs.device)
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
    """Rotary embedding whose frequency slots are split into position sections.

    ``sections[i]`` is the number of consecutive frequency slots driven by the
    i-th position row. For every slot, ``get_cos_sin`` gathers the cos/sin value
    computed from that slot's section position.
    """

    def __init__(
        self,
        inv_freq: torch.Tensor,
        scaling_factor: float,
        sections: list,
        max_position_embeddings,
    ):
        self.sections = sections
        self._cos_cached = None
        self._sin_cached = None
        # Section id of every frequency slot, shaped (1, 1, sum(sections)).
        self.section_indices = (
            torch.arange(len(self.sections))
            .repeat_interleave(torch.tensor(self.sections))
            .view(1, 1, -1)
            .to(inv_freq.device)
        )
        super().__init__(inv_freq, scaling_factor, max_position_embeddings)

    def _update_cos_sin_cache(
        self, dtype: torch.dtype, device: torch.device, seqlen: int
    ):
        # always cache the cos/sin for the full sequence length to avoid
        # recomputing if the sequence length is smaller than the cached one
        if (
            seqlen > self._seq_len_cached
            # The constructor leaves the cache as ``None``; rebuild instead of
            # dereferencing it below.
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
            self._sections = self.section_indices.expand(seqlen, -1, -1)

    def get_cos_sin(
        self,
        position_ids: torch.Tensor,
    ):
        """Return per-slot cos/sin values of shape (slen, 1, n_slots).

        NOTE(review): assumes ``position_ids`` is (slen, len(sections)) and
        ``slen`` does not exceed the cached sequence length. Confirm with the
        callers.
        """
        slen = position_ids.shape[0]
        cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
        sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
        return cos, sin
================================================
FILE: backends/gaudi/server/text_generation_server/layers/speculative.py
================================================
import torch
import json
from typing import Tuple, Optional
from text_generation_server.layers.tensor_parallel import TensorParallelHead
from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2
from text_generation_server.layers.mlp import MLPSpeculatorHead
class SpeculativeHead(torch.nn.Module):
    """LM head that is either a plain tensor-parallel head or a speculator.

    ``load`` sets exactly one of ``head`` / ``speculator``. ``forward`` returns
    ``(logits, speculative_logits)``; the second item is ``None`` when no
    speculator is used.
    """

    def __init__(self, lm_head, speculator):
        super().__init__()
        self.head = lm_head
        self.speculator = speculator

    @staticmethod
    def load(config, prefix: str, weights):
        """Build the head described by ``config``.

        If ``config.speculator`` is set, ``<path>/config.json`` is read and
        stored on ``config.speculator_config``. That config selects the
        speculator:

        * architecture ``MLPSpeculatorPreTrainedModel``: an ``MLPSpeculatorHead``;
        * no ``architectures`` entry: Medusa (V1, falling back to V2).

        Without a speculator a ``TensorParallelHead`` is loaded.

        Raises:
            NotImplementedError: if the speculator config names an unsupported
                architecture. Previously this produced a head with neither an
                LM head nor a speculator, which failed on the first forward.
        """
        speculator = config.speculator
        if not speculator:
            lm_head = TensorParallelHead.load(config, prefix, weights)
            return SpeculativeHead(lm_head, None)

        speculator_path = config.speculator["path"]
        speculator_config = str(speculator_path / "config.json")
        with open(speculator_config, "r") as f:
            speculator_config = json.load(f)
        config.speculator_config = speculator_config

        # Only the architecture *lookup* selects the Medusa fallback. A
        # KeyError raised while loading an MLP speculator must propagate
        # instead of being mistaken for a missing ``architectures`` key.
        try:
            architecture = speculator_config["architectures"][0]
        except KeyError:
            try:
                speculator = MedusaHeadV1.load(config, prefix, weights)
            except Exception:
                speculator = MedusaHeadV2(config, prefix, weights)
        else:
            if architecture == "MLPSpeculatorPreTrainedModel":
                speculator = MLPSpeculatorHead.load(config, prefix, weights)
            else:
                raise NotImplementedError(
                    f"Speculator architecture {architecture} is not supported"
                )
        return SpeculativeHead(None, speculator)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if self.speculator is not None:
            return self.speculator(input)
        assert self.head is not None
        logits = self.head(input)
        return logits, None
================================================
FILE: backends/gaudi/server/text_generation_server/layers/tensor_parallel.py
================================================
import torch
from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
import habana_frameworks.torch as htorch
class LayerConcat(torch.nn.Module):
    """
    Apply multiple layers to the input and concatenate their
    outputs.
    """

    def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):
        """
        `dim` is the dimension along which layer outputs are concatenated.

        `layers` is materialised into a ``torch.nn.ModuleList``. This registers
        the sub-layers with this module, so ``parameters()``, ``state_dict()``
        and ``.to()`` see them. It also stops a one-shot iterable (e.g. a
        generator) from being exhausted after the first forward pass.
        """
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)
        self.dim = dim

    def forward(self, x: torch.Tensor):
        outputs = [layer(x) for layer in self.layers]
        return torch.cat(outputs, self.dim)
class SuperLayer(torch.nn.Module):
    """Minimal ``nn.Module`` wrapper that delegates to ``self.linear.forward``.

    ``forward`` is called directly rather than through ``__call__``, so hooks
    registered on the wrapped module are not run.
    """

    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        delegate = self.linear.forward
        return delegate(x)
class TensorParallelHead(SuperLayer):
    """LM head whose output (row) dimension may be sharded across ranks.

    When ``should_gather`` is true, each rank computes logits for its slice of
    weight rows and the per-rank results are all-gathered into the full output.
    """

    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        """Load ``{prefix}.weight`` and decide whether outputs must be gathered.

        - exl2: try the plain tensor first, then the quantized column weights;
          gather whenever there is more than one rank.
        - several ranks: shard the rows (dim 0). If sharding raises
          ``AssertionError``, load the full tensor and do not gather.
        - single rank: load the full tensor, no gathering.
        """
        if config.quantize == "exl2":
            try:
                # If the piece and LM head embeddings are shared, we have
                # non-quantized weights...
                weight = weights.get_tensor(f"{prefix}.weight")
            except Exception:
                # ...otherwise they are quantized.
                weight = weights.get_weights_col(prefix)
            should_gather = weights.process_group.size() > 1
        elif weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
                # If the vocab size is not divisible by number of shards
                # just load the entire thing.
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False
        return TensorParallelHead(
            get_linear(weight, bias=None),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Compute logits, gathering the per-rank slices when ``should_gather``."""
        if not self.should_gather:
            return super().forward(input)
        world_size = self.process_group.size()
        # Fast path (2-D input, FastLinear): matmul straight into a buffer laid
        # out so that all_gather_into_tensor, which concatenates the rank
        # buffers along dim 0, directly yields the gathered result.
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            out_dim = self.linear.weight.shape[0]
            if input.shape[0] == 1:
                # One row: (1, out_dim) rank buffers concatenate into a
                # contiguous (1, out_dim * world_size) output.
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                # Several rows: compute into a transposed (out_dim, batch)
                # buffer so rank outputs stack along dim 0; transposed back
                # before returning.
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T
            torch.mm(input, self.linear.weight.T, out=local_out)
            # mark_step before the collective, presumably the same Habana lazy
            # collectives workaround noted in TensorParallelRowLinear.forward.
            htorch.core.mark_step()
            torch.distributed.all_gather_into_tensor(
                world_out, gather_input, group=self.process_group
            )
            if input.shape[0] == 1:
                return world_out
            return world_out.T
        # Generic path: gather full per-rank outputs and concatenate them on
        # the last dimension.
        output = super().forward(input)
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]
        htorch.core.mark_step()
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output
class TensorParallelColumnLinear(SuperLayer):
    """Column-parallel linear layer: each rank holds a slice of the output rows."""

    @classmethod
    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
        """Load weights whose gate and up projections are packed in one tensor.

        Bias is not supported; requesting it raises ``NotImplementedError``.
        """
        weight = weights.get_weights_col_packed_gate_up(prefix)
        if bias:
            raise NotImplementedError("packed_gate_up only implemented without bias")
        return cls(get_linear(weight, None))

    @classmethod
    def load_qkv(
        cls,
        config,
        prefix: str,
        weights,
        bias: bool,
        num_heads: int,
        num_key_value_heads: int,
    ):
        """Load a packed QKV projection, split per rank by head counts.

        Bias is not supported; requesting it raises ``NotImplementedError``.
        """
        weight = weights.get_weights_col_packed_qkv(
            prefix,
            num_heads=num_heads,
            num_key_value_heads=num_key_value_heads,
        )
        if bias:
            raise NotImplementedError("packed_qkv only implemented for baichuan")
        return cls(get_linear(weight, None))

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load one column-sharded projection, with its bias sharded on dim 0."""
        weight = weights.get_weights_col(prefix)
        sharded_bias = (
            weights.get_sharded(f"{prefix}.bias", dim=0) if bias else None
        )
        return cls(get_linear(weight, sharded_bias))

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        """Load several projections and fuse them into one layer.

        For exl2, each prefix becomes its own linear and the outputs are
        concatenated at run time by ``LayerConcat``. Otherwise the weights (and
        biases, if requested) are concatenated along ``dim`` at load time.
        """
        if config.quantize == "exl2":
            parts = []
            for prefix in prefixes:
                part_weight = weights.get_weights_col(prefix)
                part_bias = weights.get_tensor(f"{prefix}.bias") if bias else None
                parts.append(get_linear(part_weight, part_bias))
            return cls(LayerConcat(parts))

        weight = weights.get_multi_weights_col(prefixes, dim=dim)
        merged_bias = None
        if bias:
            merged_bias = torch.cat(
                [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes],
                dim=dim,
            )
        return cls(get_linear(weight, merged_bias))
class TensorParallelRowLinear(SuperLayer):
    """Row-parallel linear layer whose partial outputs are summed across ranks."""

    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Load a row-sharded weight; the bias is only loaded on rank 0."""
        weight = weights.get_weights_row(prefix)
        # Loading the bias only on the first rank means it is added once when
        # forward() sums the per-rank partial outputs with all_reduce.
        use_bias = bias and weights.process_group.rank() == 0
        row_bias = weights.get_tensor(f"{prefix}.bias") if use_bias else None
        return cls(
            get_linear(weight, row_bias),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
        """Apply the local shard; all-reduce the result when ``reduce`` and TP > 1."""
        out = super().forward(input)
        needs_reduce = self.process_group.size() > 1 and reduce
        if not needs_reduce:
            return out
        # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
        # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
        # (which is required for tensor parallel HPUGraph inference)
        htorch.core.mark_step()
        torch.distributed.all_reduce(out, group=self.process_group)
        return out
class TensorParallelEmbedding(torch.nn.Module):
    """Vocabulary-parallel embedding.

    Each rank owns token ids in ``[min_id, max_id)`` and stores one extra zero
    row at index ``null_idx``. Ids outside the local range are redirected to
    that zero row. When ``reduce`` is true and there are several ranks, the
    partial results are all-reduced.
    """

    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight_name = f"{prefix}.weight"
        local_weight = weights.get_partial_sharded(weight_name, dim=0)
        vocab_size = weights.get_shape(weight_name)[0]

        group = weights.process_group
        world_size = group.size()
        rank = group.rank()

        shard_rows = (vocab_size + world_size - 1) // world_size
        self.min_id = rank * shard_rows
        self.max_id = min(vocab_size, (rank + 1) * shard_rows)
        # Usually shard_rows, might be less in non even vocab_size.
        self.null_idx = local_weight.shape[0]
        self.process_group = weights.process_group
        self.reduce = reduce
        # Append one all-zero row, the target for masked (out-of-shard) ids.
        self.weight = torch.nn.Parameter(F.pad(local_weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Map ids into [0, max_id - min_id); anything outside goes to the
        # zero row at null_idx.
        outside_shard = (input < self.min_id) | (input >= self.max_id)
        local_ids = torch.where(outside_shard, self.null_idx, input - self.min_id)
        out = torch.nn.functional.embedding(local_ids, self.weight)
        if not (self.reduce and self.process_group.size() > 1):
            return out
        # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
        # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
        # (which is required for tensor parallel HPUGraph inference)
        htorch.core.mark_step()
        torch.distributed.all_reduce(out, group=self.process_group)
        return out
================================================
FILE: backends/gaudi/server/text_generation_server/models/__init__.py
================================================
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables
import torch
import os
from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from pathlib import Path
from typing import List, Dict
import enum
# Needed to properly setup habana_frameworks
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
PhiMoEConfig,
)
from text_generation_server.utils.adapter import (
AdapterParameters,
build_layer_weight_lookup,
load_and_merge_adapters,
AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights
from text_generation_server.utils.log import log_master
# Names exported by ``from text_generation_server.models import *``.
# NOTE(review): "CausalLM", "Seq2SeqLM" and "get_model_with_lora_adapters" are
# not defined in the part of this module reviewed here. Confirm they exist,
# otherwise a star-import fails.
__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]
# Batch classes of vision-language models; reassigned below depending on
# whether the flash-attention models import successfully.
VLM_BATCH_TYPES = set()
# Set to False below if importing the flash-attention models fails.
FLASH_ATTENTION = True
# Import the flash-attention model implementations. Any ImportError disables
# them all: FLASH_ATTENTION becomes False and these names stay undefined, so
# later uses are guarded by ``if FLASH_ATTENTION``.
try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
        FlashDeepseekV3ForCausalLM,
        DeepseekV3Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_llama4_modeling import (
        Llama4ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
        Gemma3ForConditionalGeneration,
        FlashGemma3ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
    from text_generation_server.models.custom_modeling.flash_mllama import (
        FlashMllamaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_llava_next import (
        FlashLlavaNextForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen3_modeling import (
        Qwen3ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen3_moe_modeling import (
        Qwen3MoeForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
        FlashGPTJForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.idefics3 import (
        Idefics3ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.qwen2_vl import (
        Qwen2VLForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.qwen2_5_vl import (
        Qwen2_5VLForConditionalGeneration,
        Qwen2_5_VLConfig,
        Qwen2_5_VLProcessor,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    # Degrade gracefully: log the reason and disable every flash model path.
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False
    VLM_BATCH_TYPES = set()
if FLASH_ATTENTION:
    # ``__all__`` must hold *names* (strings). Appending the objects themselves
    # makes ``from text_generation_server.models import *`` raise TypeError.
    __all__.append("FlashCausalLM")
    from text_generation_server.models.flash_vlm_causal_lm import (
        FlashVlmCausalLMBatch,
    )

    # Batch classes identifying vision-language model batches.
    VLM_BATCH_TYPES = {
        FlashVlmCausalLMBatch,
        FlashMllamaCausalLMBatch,
    }
    __all__.append("VLM_BATCH_TYPES")
class ModelType(enum.Enum):
    """Model families known to this backend.

    Each value is a dict with:

    * ``type``: compared against the checkpoint's ``model_type``;
    * ``name``: a display name;
    * ``url``: a reference link;
    * optionally ``"multimodal": True``.
    """

    DEEPSEEK_V2 = {
        "type": "deepseek_v2",
        "name": "Deepseek V2",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
    }
    DEEPSEEK_V3 = {
        "type": "deepseek_v3",
        "name": "Deepseek V3",
        "url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
    }
    IDEFICS2 = {
        "type": "idefics2",
        "name": "Idefics 2",
        "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
        "multimodal": True,
    }
    IDEFICS3 = {
        "type": "idefics3",
        "name": "Idefics 3",
        "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
        "multimodal": True,
    }
    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    }
    LLAMA4 = {
        "type": "llama4",
        "name": "Llama4",
        # Previously a copy of the Llama 3.1 collection link.
        "url": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GRANITE = {
        "type": "granite",
        "name": "Granite",
        "url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    PALIGEMMA = {
        "type": "paligemma",
        "name": "PaliGemma",
        "url": "https://huggingface.co/google/paligemma-3b-pt-224",
    }
    GEMMA2 = {
        "type": "gemma2",
        "name": "Gemma2",
        "url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
    }
    GEMMA3 = {
        "type": "gemma3",
        "name": "Gemma3",
        "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
    }
    GEMMA3_TEXT = {
        "type": "gemma3_text",
        "name": "Gemma3 Text",
        "url": "https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
    }
    COHERE = {
        "type": "cohere",
        "name": "Cohere",
        "url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    }
    DBRX = {
        "type": "dbrx",
        "name": "Dbrx",
        "url": "https://huggingface.co/databricks/dbrx-instruct",
    }
    MAMBA = {
        "type": "mamba",
        "name": "Mamba",
        "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
    }
    MIXTRAL = {
        "type": "mixtral",
        "name": "Mixtral",
        "url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    }
    GPT_BIGCODE = {
        "type": "gpt_bigcode",
        "name": "Gpt Bigcode",
        "url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    PHI_MOE = {
        "type": "phimoe",
        "name": "PhiMoe",
        "url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    FALCON = {
        "type": "falcon",
        "name": "Falcon",
        "url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
    }
    STARCODER2 = {
        "type": "starcoder2",
        "name": "StarCoder 2",
        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        "url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    }
    QWEN2_VL = {
        "type": "qwen2_vl",
        "name": "Qwen 2 VL",
        "url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
    }
    QWEN2_5_VL = {
        "type": "qwen2_5_vl",
        "name": "Qwen 2.5 VL",
        "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
    }
    QWEN3 = {
        "type": "qwen3",
        "name": "Qwen 3",
        "url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f",
    }
    QWEN3_MOE = {
        "type": "qwen3_moe",
        "name": "Qwen 3 Moe",
        "url": "https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f",
    }
    GALACTICA = {
        "type": "galactica",
        "name": "Galactica",
        "url": "https://huggingface.co/facebook/galactica-120b",
    }
    SANTACODER = {
        "type": "santacoder",
        "name": "SantaCoder",
        "url": "https://huggingface.co/bigcode/santacoder",
    }
    GPT2 = {
        "type": "gpt2",
        "name": "Gpt2",
        "url": "https://huggingface.co/openai-community/gpt2",
    }
    GPT_NEOX = {
        "type": "gpt_neox",
        "name": "Gpt Neox",
        "url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
    }
    GPTJ = {
        "type": "gptj",
        "name": "Gptj",
        "url": "https://huggingface.co/EleutherAI/gpt-j-6b",
    }
    MLLAMA = {
        "type": "mllama",
        "name": "Mllama",
        "url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
        "multimodal": True,
    }
# Expose each ModelType member as a module-level constant holding its "type"
# string (e.g. ``LLAMA == "llama"``); get_model compares ``model_type`` to
# these. At module scope ``locals()`` is the module namespace, so this defines
# globals that linters cannot see (hence the ``# ruff: noqa: F821`` header).
__GLOBALS = locals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]
# Integer flag read from the SDP_ON_BF16 environment variable (default 0).
SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
# Disable gradients
torch.set_grad_enabled(False)
def get_model(
model_id: str,
lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
speculate: Optional[int],
dtype: Optional[torch.dtype],
kv_cache_dtype: Optional[str],
trust_remote_code: bool,
max_input_tokens: int,
) -> Model:
global FLASH_ATTENTION
if speculate is not None:
set_speculate(speculate)
else:
set_speculate(0)
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
model_type = config_dict.get("model_type", None)
speculator = None
if "medusa_num_heads" in config_dict:
medusa_model_id = model_id
medusa_revision = revision
model_id = config_dict["base_model_name_or_path"]
revision = "main"
speculate_medusa = config_dict["medusa_num_heads"]
if speculate is not None:
if speculate > speculate_medusa:
raise RuntimeError(
f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
)
else:
set_speculate(speculate)
else:
set_speculate(speculate_medusa)
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
# Reload model type from parent.
model_type = config_dict.get("model_type", None)
is_local = Path(medusa_model_id).exists()
if not is_local:
medusa_config = hf_hub_download(
medusa_model_id, revision=medusa_revision, filename="config.json"
)
hf_hub_download(
medusa_model_id,
revision=medusa_revision,
filename="medusa_lm_head.safetensors",
)
speculator = {
"path": Path(medusa_config).parent,
"model_paths": ["medusa_lm_head.safetensors"],
}
else:
speculator = {
"path": Path(medusa_model_id),
"model_paths": ["medusa_lm_head.safetensors"],
}
method = "medusa"
elif model_type == "mlp_speculator":
mlp_model_id = model_id
mlp_revision = revision
model_id = config_dict["base_model_name_or_path"]
revision = "main"
speculate_mlp = config_dict["n_predict"]
if speculate is not None:
if speculate > speculate_mlp:
raise RuntimeError(
f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
)
else:
set_speculate(speculate)
else:
set_speculate(speculate_mlp)
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
# Reload model type from parent.
model_type = config_dict.get("model_type", None)
is_local = Path(mlp_model_id).exists()
extension = ".safetensors"
if not is_local:
mlp_speculator_config = hf_hub_download(
mlp_model_id, revision=mlp_revision, filename="config.json"
)
api = HfApi()
info = api.model_info(mlp_model_id, revision=mlp_revision)
filenames = [
s.rfilename
for s in info.siblings
if s.rfilename.endswith(extension)
and len(s.rfilename.split("/")) == 1
and "arguments" not in s.rfilename
and "args" not in s.rfilename
and "training" not in s.rfilename
]
for filename in filenames:
hf_hub_download(
mlp_model_id,
revision=mlp_revision,
filename=filename,
)
speculator_dir_path = Path(mlp_speculator_config).parent
# if these are downloaded, they get converted to safetensors
filenames.extend(
[p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
)
speculator = {
"path": Path(mlp_speculator_config).parent,
"model_paths": filenames,
}
else:
speculator = Path(mlp_model_id)
filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
speculator = {"path": speculator, "model_paths": filenames}
method = "mlp_speculator"
else:
method = "n-gram"
speculate = get_speculate()
if speculate > 0:
logger.info(f"Using speculation {method} with {speculate} input ids.")
model_type = config_dict["model_type"]
if kv_cache_dtype == "fp8_e4m3fn":
kv_cache_dtype = torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
kv_cache_dtype = torch.float8_e5m2
else:
kv_cache_dtype = dtype
if FLASH_ATTENTION:
if model_type == DEEPSEEK_V2:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV2Config,
head_size=head_size,
)
elif model_type == DEEPSEEK_V3:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV3Config,
head_size=head_size,
)
elif (
model_type == GPT_BIGCODE
or model_type == GPT2
and model_id.startswith("bigcode/")
):
return FlashCausalLM(
model_id=model_id,
model_class=FlashSantacoderForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
num_kv_heads=1,
)
elif model_type == GPT2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPT2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GPTJ:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPTJForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GPT_NEOX:
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
GPTNeoXConfig,
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPTNeoXForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=GPTNeoXConfig,
)
elif model_type == PHI:
return FlashCausalLM(
model_id=model_id,
model_class=FlashPhiForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == PHI_MOE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
config_class=PhiMoEConfig,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == LLAMA4:
print(f"Llama4 model detected: {model_id}")
return FlashVlmCausalLM(
model_id=model_id,
model_class=Llama4ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif model_type == BAICHUAN:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GEMMA:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemmaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GEMMA2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemma2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GEMMA3:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Gemma3ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif model_type == GEMMA3_TEXT:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemma3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == COHERE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashCohereForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == DBRX:
return FlashCausalLM(
model_id=model_id,
model_class=FlashDbrxForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Dbrx works better in bfloat16.
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DbrxConfig,
)
elif (
model_type in ["RefinedWeb", "RefinedWebModel", FALCON]
and not sharded
and not config_dict.get("alibi", False)
):
return FlashCausalLM(
model_id=model_id,
model_class=FlashRWForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
)
elif model_type == MISTRAL:
return FlashCausalLM(
model_id=model_id,
model_class=FlashMistralForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == MIXTRAL:
return FlashCausalLM(
model_id=model_id,
model_class=FlashMixtralForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == STARCODER2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashStarcoder2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN2:
return FlashCausalLM(
model_id=model_id,
model_class=Qwen2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN2_VL:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Qwen2VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# TODO: Fix bug in rust image_text_replacement implementation
support_chunking=False,
)
elif model_type == QWEN2_5_VL:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Qwen2_5VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=Qwen2_5_VLConfig,
processor_class=Qwen2_5_VLProcessor,
# TODO: Fix bug in rust image_text_replacement implementation
support_chunking=False,
)
elif model_type == QWEN3:
return FlashCausalLM(
model_id=model_id,
model_class=Qwen3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN3_MOE:
return FlashCausalLM(
model_id=model_id,
model_class=Qwen3MoeForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == MLLAMA:
return FlashMllamaCausalLM(
model_id=model_id,
model_class=FlashMllamaForConditionalGeneration,
batch_class=FlashMllamaCausalLMBatch,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
support_chunking=False,
)
elif model_type == IDEFICS2:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Idefics2ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
)
elif model_type == IDEFICS3:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Idefics3ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 1456}},
)
elif model_type == PALIGEMMA:
return FlashVlmCausalLM(
model_id=model_id,
model_class=PaliGemmaForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == LLAVA_NEXT:
return FlashVlmCausalLM(
model_class=FlashLlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
)
raise ValueError(f"Unsupported model type {model_type}")
# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters
# this provides a post model loading hook to load adapters into the model after the model has been loaded
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[torch.dtype],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    """Load a model with ``get_model`` and then load each LoRA adapter into it.

    Args:
        model_id: Base model identifier forwarded to ``get_model``.
        lora_adapters: Adapters to load after the base model. ``None`` is treated
            the same as an empty list (no adapters).
        revision, sharded, quantize, speculate, dtype, kv_cache_dtype,
        trust_remote_code, max_input_tokens: Forwarded unchanged to ``get_model``.
        adapter_to_index: Mapping updated in place. The adapter at position ``i``
            of ``lora_adapters`` is registered under index ``i + 1`` (index 0 is
            presumably reserved for the base model; confirm with the router).

    Returns:
        The model returned by ``get_model``, with every adapter's weights added
        to ``model.layer_to_adapter_weights`` and its index recorded in
        ``model.loaded_adapters``.
    """
    # The parameter is Optional; treat None as "no adapters" instead of crashing
    # when iterating below.
    lora_adapters = lora_adapters or []
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        kv_cache_dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if not lora_adapters:
        return model

    # Everything below is loop-invariant across adapters, so compute it once.
    target_to_layer = build_layer_weight_lookup(model.model)
    weight_names = tuple(v[0] for v in target_to_layer.values())
    adapter_layers = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "qkv_proj",
    ]

    for index, adapter in enumerate(lora_adapters):
        # The AdapterParameters object allows for merging multiple adapters into a single adapter.
        # At the moment, we only support loading a single adapter into the model, but we keep the
        # AdapterParameters object for easier extension in the future.
        adapter_parameters = AdapterParameters(
            adapter_info=[adapter],
            # when merging multiple adapters we can weight them differently
            # if this is not set, all adapters will be weighted equally
            # see: text_generation_server.utils.merges.strategies for impl
            weights=None,
            merge_strategy=0,
            density=1.0,
            majority_sign_method=0,
        )
        adapter_index = index + 1
        adapter_to_index[adapter.id] = adapter_index
        adapter_ids = ",".join(a.id for a in adapter_parameters.adapter_info)
        logger.info(f"Loading adapter weights into model: {adapter_ids}")
        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            model.model_id,
            adapter_parameters,
            adapter_index,
            weight_names,
            False,
        )
        # prepare_weights is handed this set and may remove names from it; what
        # remains afterwards is reported as unused.
        unused_weight_names = adapter_weight_names.copy()
        for layer_name in adapter_layers:
            nlayers = (
                1 if layer_name == "lm_head" else len(model.model.model.layers)
            )
            adapter_weights = LoraWeights.prepare_weights(
                config=adapter_config,
                module_map=module_map,
                layer_type=layer_name,
                unused_weight_names=unused_weight_names,
                nlayers=nlayers,
                dtype=model.dtype,
                world_size=model.world_size,
                process_group=model.process_group,
                target_to_layer=target_to_layer,
            )
            if adapter_weights is None:
                continue
            model.layer_to_adapter_weights[layer_name].add_adapter(
                adapter_index, adapter_weights
            )
        if len(unused_weight_names) > 0:
            # Report only the adapter(s) just loaded: the leftover names belong
            # to them, not to every adapter in `lora_adapters`.
            logger.warning(
                f"{adapter_ids} unused adapter weights: {unused_weight_names}"
            )
        if adapter_tokenizer is not None:
            model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
        model.loaded_adapters.add(adapter_index)
    return model
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/__init__.py
================================================
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py
================================================
# coding=utf-8
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BLOOM model."""
import math
import os
import warnings
from typing import Optional, Tuple, Union
import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import LayerNorm
from torch.nn import functional as F
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
)
from transformers import BloomConfig, PreTrainedModel
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
SpeculativeHead,
)
# Use the fused BLOOM attention CUDA kernel only when all of these hold:
#   * CUDA is available,
#   * the user has not set DISABLE_CUSTOM_KERNELS=True,
#   * the optional `custom_kernels` package can be imported.
# Otherwise keep the flag False and fall back to the pure-PyTorch attention.
CUSTOM_KERNELS_ENABLED = False
if torch.cuda.is_available() and os.environ.get("DISABLE_CUSTOM_KERNELS", "False") != "True":
    try:
        from custom_kernels import fused_bloom_attention_cuda
    except ImportError:
        pass
    else:
        CUSTOM_KERNELS_ENABLED = True
# Module-level metadata constants. None of them is referenced in the visible part
# of this file; presumably kept for parity with the upstream transformers BLOOM
# implementation (TODO confirm whether anything still imports them).
_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
_CONFIG_FOR_DOC = "BloomConfig"
# Hub model ids of the published BLOOM checkpoints.
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"bigscience/bigscience-small-testing",
"bigscience/bloom-560m",
"bigscience/bloom-1b1",
"bigscience/bloom-1b7",
"bigscience/bloom-3b",
"bigscience/bloom-7b1",
"bigscience/bloom",
]
def _make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.ones(
(target_length, target_length + past_key_values_length),
dtype=torch.bool,
device=device,
)
mask = mask.triu(1 + past_key_values_length)
expanded_mask = mask.unsqueeze(0).expand(
batch_size, target_length, target_length + past_key_values_length
)
return expanded_mask
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
"""
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
"""
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, tgt_length, src_length)
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
    """
    Build the ALiBi (Attention with Linear Biases) position bias for every head.

    Link to paper: https://arxiv.org/abs/2108.12409. The ALiBi tensor is not causal as
    the original paper describes. It relies on the translation invariance of softmax
    (``softmax(l + a) == softmax(l)`` for a constant ``a``). Reference implementation:
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742

    TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

    Each head ``h`` gets a slope ``m_h`` from a geometric sequence. When ``num_heads``
    is not a power of two, the extra heads take interleaved slopes from the
    next power of two. The bias for token ``j`` is ``m_h * pos_j``. Here ``pos_j`` counts
    only attended tokens, and padded positions are forced to 0.

    Args:
        attention_mask (`torch.Tensor`):
            Token-wise attention mask of shape (batch_size, max_seq_len); non-zero means attend.
        num_heads (`int`, *required*):
            Number of attention heads (before any tensor-parallel sharding).

    Returns:
        A float32 tensor of shape (batch_size, num_heads, max_seq_len). Callers are
        responsible for slicing per shard, reshaping, and casting to the model dtype.
    """
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
        device=attention_mask.device,
        dtype=torch.float32,
    )
    powers = torch.arange(
        1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
    )
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != num_heads:
        # Remaining heads take the odd-indexed slopes of the 2*closest_power_of_2 sequence.
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
            device=attention_mask.device,
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(
            1,
            1 + 2 * num_remaining_heads,
            2,
            device=attention_mask.device,
            dtype=torch.int32,
        )
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    # Note: alibi will be added to the attention bias that is applied to the query-key product of attention.
    # Positions are counted over attended tokens only (cumsum of the mask), with padding
    # zeroed out, so left padding does not shift the positions of real tokens.
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    # (num_heads, 1) * (batch_size, 1, seq_length) -> (batch_size, num_heads, seq_length)
    alibi = slopes[..., None] * arange_tensor
    return alibi
# @torch.jit.script
def dropout_add(
    x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool
) -> torch.Tensor:
    """
    Apply dropout to ``x`` and add the result onto ``residual``.

    Args:
        x (`torch.tensor`, *required*):
            tensor that dropout is applied to
        residual (`torch.tensor`, *required*):
            residual tensor added to the dropped-out ``x``
        prob (`float`, *required*):
            dropout probability
        training (`bool`, *required*):
            training mode; when False, dropout is the identity
    Returns:
        ``residual + dropout(x)`` as a new tensor.
    """
    return residual + F.dropout(x, p=prob, training=training)
# @torch.jit.script  # disabled: scripting this function misbehaves for unknown reasons.
def _split_heads(
fused_qkv: torch.Tensor, num_heads: int, head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
storage as `fused_qkv`
Args:
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
Returns:
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
value: [batch_size, seq_length, num_heads, head_dim]
"""
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)
query_layer = query_layer.transpose(1, 2).reshape(
batch_size * num_heads, seq_length, head_dim
)
key_layer = key_layer.permute(0, 2, 3, 1).reshape(
batch_size * num_heads, head_dim, seq_length
)
value_layer = value_layer.transpose(1, 2).reshape(
batch_size * num_heads, seq_length, head_dim
)
return query_layer, key_layer, value_layer
# @torch.jit.script
def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
"""
Merge heads together over the last dimenstion
Args:
x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
Returns:
torch.tensor: [batch_size, seq_length, num_heads * head_dim]
"""
# What we want to achieve is:
# batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
batch_size_and_num_heads, seq_length, _ = x.shape
batch_size = batch_size_and_num_heads // num_heads
# First view to decompose the batch size
# batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
x = x.view(batch_size, num_heads, seq_length, head_dim)
# batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
x = x.permute(0, 2, 1, 3)
# batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
return x.reshape(batch_size, seq_length, num_heads * head_dim)
class BloomAttention(nn.Module):
    """BLOOM self-attention with ALiBi biases, loaded for tensor parallelism.

    The fused QKV projection is loaded with ``TensorParallelColumnLinear`` and the
    output projection with ``TensorParallelRowLinear``. After ``__init__``,
    ``self.num_heads`` is the number of heads local to this rank
    (``config.n_head // process_group.size()``).
    """

    def __init__(self, prefix, config: BloomConfig, weights):
        super().__init__()
        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact
        self.process_group = weights.process_group
        self.hidden_size = config.hidden_size
        self.num_heads = config.n_head
        self.head_dim = self.hidden_size // self.num_heads
        self.split_size = self.hidden_size
        self.hidden_dropout = config.hidden_dropout
        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
                f" {self.num_heads})."
            )
        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.beta = 1.0
        process_group = weights.process_group
        if self.num_heads % process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {process_group.size()})"
            )
        # From here on, num_heads counts only the heads owned by this rank.
        self.num_heads = self.num_heads // process_group.size()
        self.query_key_value = TensorParallelColumnLinear.load(
            config=config,
            prefix=f"{prefix}.query_key_value",
            weights=weights,
            bias=True,
        )
        self.dense = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.dense", weights=weights, bias=True
        )
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    @staticmethod
    def compute_attention(
        fused_qkv: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        head_mask: Optional[torch.Tensor],
        beta: float,
        inv_norm_factor: float,
        num_heads: int,
        use_cache: bool,
    ):
        """Pure-PyTorch attention used when the fused CUDA kernel is unavailable.

        Args:
            fused_qkv: [batch_size, q_length, 3 * num_heads * head_dim].
            layer_past: Optional ``(past_key, past_value)``; they are viewed as
                ``[-1, head_dim, past_length]`` and ``[-1, past_length, head_dim]`` and
                concatenated in front of the new keys/values.
            alibi: Bias used as the ``baddbmm`` input. It must broadcast to
                ``[batch_size * num_heads, q_length, kv_length]``.
            attention_mask: Boolean mask; True entries are filled with the dtype minimum.
            head_mask: Optional multiplicative mask applied to the attention probabilities.
            beta, inv_norm_factor: ``baddbmm`` scaling factors for ``alibi`` and ``Q @ K``.
            num_heads: Number of heads in ``fused_qkv`` (local to this rank).
            use_cache: When True, also return the concatenated ``(key, value)``.

        Returns:
            ``(context_layer, present, attention_probs)``, where ``context_layer`` is
            ``[batch_size, q_length, num_heads * head_dim]``, ``present`` is
            ``(key, value)`` or None, and ``attention_probs`` is
            ``[batch_size * num_heads, q_length, kv_length]``.
        """
        _, _, three_times_hidden_size = fused_qkv.shape
        head_dim = three_times_hidden_size // (3 * num_heads)
        ### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?
        # query/value: [batch_size * num_heads, q_length, head_dim]
        # key:         [batch_size * num_heads, head_dim, q_length]
        (query_layer, key_layer, value_layer) = _split_heads(
            fused_qkv, num_heads=num_heads, head_dim=head_dim
        )
        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
            # - key: [batch_size * self.num_heads, head_dim, kv_length]
            # - value: [batch_size * self.num_heads, kv_length, head_dim]
            past_key = past_key.view(-1, *past_key.shape[-2:])
            key_layer = torch.cat((past_key, key_layer), dim=2)
            past_value = past_value.view(-1, *past_value.shape[-2:])
            value_layer = torch.cat((past_value, value_layer), dim=1)
        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None
        # [batch_size * num_heads, q_length, kv_length]
        # = beta * alibi + inv_norm_factor * (Q @ K)
        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
        attention_scores = alibi.baddbmm(
            batch1=query_layer,
            batch2=key_layer,
            beta=beta,
            alpha=inv_norm_factor,
        )
        # Softmax is computed in fp32 and cast back to the input dtype.
        input_dtype = attention_scores.dtype
        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
        if input_dtype == torch.float16:
            attention_scores = attention_scores.to(torch.float)
        # Masked positions get the smallest finite value of the score dtype.
        attn_weights = attention_scores.masked_fill_(
            attention_mask, torch.finfo(attention_scores.dtype).min
        )
        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            input_dtype
        )
        # Dropout on attention probabilities is intentionally not applied here.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask
        # matmul: [batch_size * num_heads, q_length, head_dim].
        # The result is written into query_layer's buffer, which is no longer needed.
        context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)
        # merge heads: [batch_size, q_length, num_heads * head_dim]
        context_layer = _merge_heads(
            context_layer, num_heads=num_heads, head_dim=head_dim
        )
        return context_layer, present, attention_probs

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Run attention, apply the output projection, and add ``residual``.

        Returns ``(output, present)``, followed by ``attention_probs`` when
        ``output_attentions`` is True.
        """
        fused_qkv = self.query_key_value(
            hidden_states
        )  # [batch_size, seq_length, 3 x hidden_size]
        if layer_past is not None:
            # Flatten any leading dims so the cache is [-1, ..., ...] for both code paths.
            past_key, past_value = layer_past
            layer_past = (
                past_key.view(-1, *past_key.shape[-2:]),
                past_value.view(-1, *past_value.shape[-2:]),
            )
        if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:
            assert self.training is False, "Only foward pass was implemented"
            assert (
                attention_mask.shape[-1] < 4096
            ), "Custom kernel support only up to 4096 tokens"
            (
                context_layer,
                present,
                attention_probs,
            ) = fused_bloom_attention_cuda.forward(
                fused_qkv,
                layer_past,
                alibi,
                attention_mask,
                head_mask,
                self.beta,
                self.inv_norm_factor,
                self.num_heads,
                use_cache,
            )
        else:
            context_layer, present, attention_probs = self.compute_attention(
                fused_qkv=fused_qkv,
                layer_past=layer_past,
                alibi=alibi,
                attention_mask=attention_mask,
                head_mask=head_mask,
                beta=self.beta,
                inv_norm_factor=self.inv_norm_factor,
                num_heads=self.num_heads,
                use_cache=use_cache,
            )
        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:
            # NOTE(review): this path slices with the full `hidden_size` and reads
            # `self.dense.weight`. Confirm that both are valid for a tensor-parallel
            # TensorParallelRowLinear before relying on it.
            slices = self.hidden_size / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + F.linear(
                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            output_tensor = self.dense(context_layer)
        # Dropout is skipped (inference only); the residual is added in place.
        output_tensor += residual
        outputs = (output_tensor, present)
        if output_attentions:
            outputs += (attention_probs,)
        return outputs
class BloomMLP(nn.Module):
    """BLOOM feed-forward block: h -> 4h (tanh-approximated GELU) -> h, plus residual.

    The up projection is loaded with ``TensorParallelColumnLinear`` and the down
    projection with ``TensorParallelRowLinear``.
    """

    def __init__(self, prefix, config: BloomConfig, weights):
        super().__init__()
        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact
        up_prefix = f"{prefix}.dense_h_to_4h"
        down_prefix = f"{prefix}.dense_4h_to_h"
        self.dense_h_to_4h = TensorParallelColumnLinear.load(
            config=config, prefix=up_prefix, weights=weights, bias=True
        )
        self.dense_4h_to_h = TensorParallelRowLinear.load(
            config=config, prefix=down_prefix, weights=weights, bias=True
        )
        self.gelu_impl = torch.nn.GELU(approximate="tanh")
        self.hidden_dropout = config.hidden_dropout

    def forward(
        self, hidden_states: torch.Tensor, residual: torch.Tensor
    ) -> torch.Tensor:
        """Return ``down(gelu(up(hidden_states))) + residual``; dropout is not applied."""
        activated = self.gelu_impl(self.dense_h_to_4h(hidden_states))
        if not (self.pretraining_tp > 1 and self.slow_but_exact):
            projected = self.dense_4h_to_h(activated)
        else:
            # Emulate the original pretraining TP split by summing per-slice matmuls
            # (no bias is added on this path).
            projected = torch.zeros_like(residual)
            width = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
            for shard in range(self.pretraining_tp):
                lo = int(shard * width)
                hi = int((shard + 1) * width)
                projected = projected + F.linear(
                    activated[:, :, lo:hi],
                    self.dense_4h_to_h.weight[:, lo:hi],
                )
        projected += residual
        return projected
class BloomBlock(nn.Module):
    """One BLOOM decoder layer: LayerNorm -> attention -> LayerNorm -> MLP.

    Weights are read under the ``h.{layer_id}`` prefix.
    ``config.apply_residual_connection_post_layernorm`` selects whether each residual
    is taken from the layer-normed tensor or from the un-normed input.
    """

    def __init__(self, layer_id: int, config: BloomConfig, weights):
        super().__init__()
        prefix = f"h.{layer_id}"
        eps = config.layer_norm_epsilon
        self.input_layernorm = LayerNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=eps
        )
        self.num_heads = config.n_head
        self.self_attention = BloomAttention(
            prefix=f"{prefix}.self_attention", config=config, weights=weights
        )
        self.post_attention_layernorm = LayerNorm.load(
            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=eps
        )
        self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )
        self.hidden_dropout = config.hidden_dropout

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Return ``(hidden_states, present, attentions)``.

        ``present`` is included only when ``use_cache`` is True, and ``attentions``
        only when ``output_attentions`` is True.
        """
        residual_from_norm = self.apply_residual_connection_post_layernorm

        normed = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attention(
            normed,
            normed if residual_from_norm else hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attention_output, extras = attn_outputs[0], attn_outputs[1:]

        normed = self.post_attention_layernorm(attention_output)
        mlp_output = self.mlp(normed, normed if residual_from_norm else attention_output)

        # The first extra is the attention's `present` cache slot; drop it when not caching.
        if not use_cache:
            extras = extras[1:]
        return (mlp_output,) + extras
class BloomPreTrainedModel(PreTrainedModel):
    """Base class for BLOOM models: config class, weight prefix and KV-cache layout helpers."""

    config_class = BloomConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["BloomBlock"]

    @staticmethod
    def _convert_to_standard_cache(
        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Unfold BLOOM's ``[batch_size * num_heads, ...]`` cache into the common
        ``[batch_size, num_heads, ...]`` layout.

        key:   [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
        value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]

        The dimensions are read from the first layer's key and applied to every layer.
        """
        fused_heads, head_dim, seq_len = past_key_value[0][0].shape
        heads = fused_heads // batch_size
        key_shape = (batch_size, heads, head_dim, seq_len)
        value_shape = (batch_size, heads, seq_len, head_dim)
        return tuple(
            (layer[0].view(key_shape), layer[1].view(value_shape))
            for layer in past_key_value
        )

    @staticmethod
    def _convert_to_bloom_cache(
        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Fold a ``[batch_size, num_heads, ...]`` cache into BLOOM's
        ``[batch_size * num_heads, ...]`` layout.

        key:   [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
        value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]

        The dimensions are read from the first layer's key and applied to every layer.
        """
        bsz, heads, head_dim, seq_len = past_key_value[0][0].shape
        fused_heads = bsz * heads
        key_shape = (fused_heads, head_dim, seq_len)
        value_shape = (fused_heads, seq_len, head_dim)
        return tuple(
            (layer[0].view(key_shape), layer[1].view(value_shape))
            for layer in past_key_value
        )
# Tensor-parallel BLOOM decoder stack: word embeddings + embedding LayerNorm,
# `config.num_hidden_layers` BloomBlocks, and a final LayerNorm (`ln_f`).
class BloomModel(BloomPreTrainedModel):
def __init__(self, config: BloomConfig, weights):
super().__init__(config)
self.embed_dim = config.hidden_size
# Total (unsharded) head count; forward() slices the alibi tensor down to this
# rank's `num_heads // tp_world_size` heads.
self.num_heads = config.n_head
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.word_embeddings = TensorParallelEmbedding(
prefix="word_embeddings", weights=weights
)
self.word_embeddings_layernorm = LayerNorm.load(
prefix="word_embeddings_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
# Transformer blocks
self.h = nn.ModuleList(
[
BloomBlock(layer_id=layer_id, config=config, weights=weights)
for layer_id in range(config.num_hidden_layers)
]
)
# Final Layer Norm
self.ln_f = LayerNorm.load(
prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon
)
def _prepare_attn_mask(
self,
attention_mask: torch.Tensor,
input_shape: Tuple[int, int],
past_key_values_length: int,
) -> torch.BoolTensor:
# Builds the combined attention mask: a causal mask (only when more than one
# new token is processed, i.e. src_length > 1) OR-ed with the expanded padding
# mask. With a single new token only the padding mask is used.
# NOTE(review): the `|` combination suggests True marks positions to be masked
# out; `_make_causal_mask` / `_expand_mask` are defined elsewhere — confirm.
# create causal mask
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
device=device,
past_key_values_length=past_key_values_length,
)
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
def set_input_embeddings(self, new_embeddings: torch.Tensor):
# Replaces the input embedding module used by forward() when inputs_embeds is not given.
self.word_embeddings = new_embeddings
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
# Runs the decoder stack. Exactly one of input_ids / inputs_embeds must be given.
# Returns a BaseModelOutputWithPastAndCrossAttentions, or (when return_dict is
# falsy) a tuple of the non-None values among
# (hidden_states, presents, all_hidden_states, all_self_attentions).
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
# Unspecified output flags fall back to the model config.
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# No cache yet: one None per layer so the zip() below still pairs every block.
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
# Bloom-format cached keys are [batch_size * num_heads, head_dim, seq_length],
# so the cached length is the last dimension.
past_key_values_length = past_key_values[0][0].shape[-1]
seq_length_with_past = seq_length_with_past + past_key_values_length
# Default mask: attend to every cached + new position.
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), device=hidden_states.device
)
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = build_alibi_tensor(attention_mask, self.num_heads)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
# Keep only this rank's contiguous block of alibi heads, then fuse (batch, heads)
# into the [batch_size * heads, 1, seq_length_with_past] layout the blocks use.
# The causal mask is repeated once per local head to match.
if hasattr(self, "tp_rank"):
assert self.num_heads % self.tp_world_size == 0
block_size = self.num_heads // self.tp_world_size
alibi = alibi[
:, self.tp_rank * block_size : (self.tp_rank + 1) * block_size
]
alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
else:
# NOTE(review): __init__ always sets tp_rank, so this branch looks unreachable
# for instances built through __init__.
alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)
causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)
alibi = alibi.to(hidden_states.dtype)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)
# Block outputs: (hidden_states, [present if use_cache], [attentions if output_attentions]).
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (
outputs[2 if use_cache else 1],
)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
# Causal-LM wrapper around BloomModel whose LM head is a SpeculativeHead; forward()
# returns the speculative logits alongside the regular output (see forward()).
class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.transformer = BloomModel(config, weights)
# NOTE(review): `prefix` is not used; the LM head is always loaded from the
# "word_embeddings" tensor name (presumably tied to the input embeddings — confirm).
self.lm_head = SpeculativeHead.load(
config,
prefix="word_embeddings",
weights=weights,
)
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
# Builds the kwargs dict for forward() during generation. With a cache, only the
# last token is fed, and a standard-format cache is converted to bloom's format.
# only last token for input_ids if past is not None
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
# the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
# A standard-format key has batch_size (not batch_size * num_heads) as its leading dim.
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Accepted for API compatibility but currently ignored: no loss is computed and the returned `loss` is
always `None`.
Returns:
When `return_dict` (taken from `config.use_return_dict` if `None`) is truthy: a 2-tuple
`(CausalLMOutputWithCrossAttentions, speculative_logits)`.
Otherwise: the plain tuple `(logits, *transformer_outputs[1:])`; `speculative_logits` is NOT included.
"""
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits, speculative_logits = self.lm_head(hidden_states)
# Loss is never computed here; `labels` is ignored.
loss = None
if not return_dict:
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return (
CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
),
speculative_logits,
)
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/clip.py
================================================
from typing import Optional, Tuple
import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import (
_create_4d_causal_attention_mask,
_prepare_4d_attention_mask,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPooling,
)
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from text_generation_server.layers import (
TensorParallelEmbedding,
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
class CLIPVisionEmbeddings(nn.Module):
    """Class-token + patch + position embeddings for the CLIP vision tower.

    Tensors read from ``weights`` under ``prefix``:
      * ``class_embedding`` (kept as a plain, unsharded tensor),
      * ``patch_embedding.weight`` (installed into a bias-free Conv2d whose kernel
        size and stride both equal ``config.patch_size``),
      * ``position_embedding`` (via ``TensorParallelEmbedding``).
    """

    def __init__(self, prefix, config: CLIPVisionConfig, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # TODO Should we TP this ?
        self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding")

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )
        patch_kernel = weights.get_tensor(f"{prefix}.patch_embedding.weight")
        self.patch_embedding.weight = nn.Parameter(patch_kernel, requires_grad=False)

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1
        self.position_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.position_embedding", weights=weights
        )
        all_positions = torch.arange(self.num_positions, device=weights.device)
        self.register_buffer(
            "position_ids", all_positions.expand((1, -1)), persistent=False
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        n_images = pixel_values.shape[0]
        conv_dtype = self.patch_embedding.weight.dtype
        # Conv output [*, width, grid, grid] -> [*, grid * grid, width]
        patch_tokens = self.patch_embedding(pixel_values.to(dtype=conv_dtype))
        patch_tokens = patch_tokens.flatten(2).transpose(1, 2)
        cls_tokens = self.class_embedding.expand(n_images, 1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
        return tokens + self.position_embedding(self.position_ids)
class CLIPTextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings for CLIP text.

    Both tables are freshly constructed ``nn.Embedding`` modules; nothing is
    loaded from a checkpoint in this class.
    """

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        width = config.hidden_size
        max_positions = config.max_position_embeddings
        self.token_embedding = nn.Embedding(config.vocab_size, width)
        self.position_embedding = nn.Embedding(max_positions, width)
        # Row vector (1, max_positions) of default position indices. It is
        # non-persistent, so it is not part of the state dict.
        self.register_buffer(
            "position_ids",
            torch.arange(max_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        # Sequence length comes from input_ids when given, else from inputs_embeds.
        if input_ids is not None:
            seq_length = input_ids.shape[-1]
        else:
            seq_length = inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        return inputs_embeds + self.position_embedding(position_ids)
class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    q/k/v are loaded as one fused ``TensorParallelColumnLinear`` and the output
    projection as a ``TensorParallelRowLinear``. After the divisibility check,
    ``num_heads`` and ``embed_dim`` are divided by the process-group size, so the
    attributes hold this rank's share; ``head_size`` is unchanged.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = self.embed_dim // self.num_heads
        if self.head_size * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # From here on, num_heads / embed_dim are per-rank (sharded) sizes.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.embed_dim = self.embed_dim // weights.process_group.size()
        self.scale = self.head_size**-0.5
        self.dropout = config.attention_dropout

        self.qkv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=True,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """[bsz, seq_len, num_heads * head_size] -> contiguous [bsz, num_heads, seq_len, head_size]."""
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_size)
            .transpose(1, 2)
            .contiguous()
        )

    def _add_mask(
        self,
        attn_weights: torch.Tensor,
        mask: torch.Tensor,
        bsz: int,
        tgt_len: int,
        src_len: int,
    ) -> torch.Tensor:
        """Add an additive ``(bsz, 1, tgt_len, src_len)`` mask to flattened scores.

        ``attn_weights`` is ``(bsz * num_heads, tgt_len, src_len)``; the mask is
        broadcast over heads and the result is flattened back.

        Raises:
            ValueError: if ``mask`` does not have size ``(bsz, 1, tgt_len, src_len)``.
        """
        expected = (bsz, 1, tgt_len, src_len)
        if mask.size() != expected:
            raise ValueError(
                f"Attention mask should be of size {expected}, but is {mask.size()}"
            )
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + mask
        return attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        Both masks are optional and additive, of size ``(bsz, 1, tgt_len, src_len)``;
        the causal mask is applied first. Returns ``(attn_output, None)``: the
        attention probabilities are never returned.
        """
        bsz, tgt_len, _ = hidden_states.size()

        # get query proj (fused q/k/v, split along the feature dim)
        qkv = self.qkv(hidden_states)
        query_states, key_states, value_states = qkv.split(
            [self.head_size * self.num_heads] * 3,
            dim=2,
        )
        query_states = query_states * self.scale
        key_states = self._shape(key_states, -1, bsz)
        value_states = self._shape(value_states, -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_size)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            attn_weights = self._add_mask(
                attn_weights, causal_attention_mask, bsz, tgt_len, src_len
            )
        if attention_mask is not None:
            attn_weights = self._add_mask(
                attn_weights, attention_mask, bsz, tgt_len, src_len
            )

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_probs = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.bmm(attn_probs, value_states)

        # The error message reports the same (flattened) size that is checked.
        expected_output = (bsz * self.num_heads, tgt_len, self.head_size)
        if attn_output.size() != expected_output:
            raise ValueError(
                f"`attn_output` should be of size {expected_output}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, None
class CLIPMLP(nn.Module):
    """Feed-forward block: column-parallel ``fc1`` -> ``config.hidden_act`` -> row-parallel ``fc2``."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        fc1_prefix = f"{prefix}.fc1"
        fc2_prefix = f"{prefix}.fc2"
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=fc1_prefix, config=config, weights=weights, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=fc2_prefix, config=config, weights=weights, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)
class CLIPEncoderLayer(nn.Module):
    """Pre-norm encoder layer: LayerNorm -> self-attention -> residual add,
    then LayerNorm -> MLP -> residual add.

    Weights are loaded from ``{prefix}.self_attn``, ``{prefix}.layer_norm1``,
    ``{prefix}.mlp`` and ``{prefix}.layer_norm2``.
    """

    def __init__(self, prefix, config: CLIPConfig, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
        )
        self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        causal_attention_mask: Optional[torch.Tensor],
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor` or `None`): additive padding mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`torch.FloatTensor` or `None`): additive causal mask with the same size
                and convention as `attention_mask`; applied before it.

        Returns:
            `torch.Tensor`: the updated hidden states (input plus the attention and MLP residual branches).
        """
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        # The attention module always returns (output, None); the second item is unused.
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return residual + hidden_states
class CLIPPreTrainedModel(nn.Module):
    """
    Shared base for the CLIP wrapper models. It is a plain ``nn.Module`` that only
    carries Hugging Face-style class attributes (the config class, the base-model
    prefix and the gradient-checkpointing flag); it defines no methods of its own.
    """

    config_class = CLIPConfig
    base_model_prefix = "clip"
    supports_gradient_checkpointing = True
# Docstring templates that reference `PreTrainedModel`, `from_pretrained` and HF
# processors, presumably carried over from the upstream transformers CLIP code
# (verify). None of the four constants below is referenced by the classes in this
# module; they are kept as-is because they are runtime string values.
CLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
CLIP_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
"""
CLIP_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
"""
CLIP_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
"""
class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        prefix: weight-name prefix; layer `i` is loaded from `{prefix}.layers.{i}`.
        config: CLIPConfig
        weights: weight loader passed unchanged to every layer.
    """

    def __init__(self, prefix, config: CLIPConfig, weights):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                CLIPEncoderLayer(
                    prefix=f"{prefix}.layers.{i}", config=config, weights=weights
                )
                for i in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded input sequence fed to the first layer.
            attention_mask (`torch.Tensor` of size `(batch_size, 1, tgt_len, src_len)`, *optional*):
                Additive padding mask (large negative values at masked positions), already expanded to 4D.
                A 2D `[0, 1]` mask is NOT accepted here; the attention layers reject any other size.
            causal_attention_mask (`torch.Tensor` of size `(batch_size, 1, tgt_len, src_len)`, *optional*):
                Additive causal mask with the same convention; applied before `attention_mask`.

        Returns:
            `torch.Tensor`: the hidden states produced by the last layer. This is a plain tensor, not a
            tuple or model-output object, so callers must not index it with `[0]`.
        """
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
            )
        return hidden_states
class CLIPTextTransformer(nn.Module):
    """CLIP text tower: token/position embeddings -> causal encoder -> final LayerNorm.

    NOTE(review): only the encoder is loaded from ``weights`` (under
    ``{prefix}.encoder``); ``embeddings`` and ``final_layer_norm`` are freshly
    constructed modules here — confirm this is intended before relying on outputs.
    """

    def __init__(self, prefix: str, config: CLIPTextConfig, weights=None):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # Used by `get_pooled_output` to locate the EOS position.
        self.eos_token_id = config.eos_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        r"""
        Args:
            input_ids: token ids; required.
            attention_mask: optional 2D padding mask, expanded to an additive 4D mask here.
            position_ids: optional explicit positions for the position embeddings.

        Returns:
            `torch.Tensor`: the final-layer-normed hidden states for the whole batch.
            Use `get_pooled_output` to obtain the per-sequence EOS-token representation.

        Raises:
            ValueError: if `input_ids` is None.
        """
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = _create_4d_causal_attention_mask(
            input_shape, hidden_states.dtype, device=hidden_states.device
        )
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(
                attention_mask, hidden_states.dtype
            )

        # CLIPEncoder returns the last hidden-states tensor itself (not a tuple), so
        # it must not be indexed: `[0]` would keep only the first batch element.
        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
        )
        return self.final_layer_norm(last_hidden_state)

    def get_pooled_output(
        self, last_hidden_state: torch.Tensor, input_ids: torch.Tensor
    ) -> torch.Tensor:
        """Select one hidden state per sequence at its end-of-text position.

        Args:
            last_hidden_state: output of `forward`, indexed as `[batch, position, ...]`.
            input_ids: the token ids that produced it, one row per sequence.

        Returns:
            The hidden state at the selected position for every row.
        """
        batch_index = torch.arange(
            last_hidden_state.shape[0], device=last_hidden_state.device
        )
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        ids = input_ids.to(dtype=torch.int, device=last_hidden_state.device)
        if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added.
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            token_index = ids.argmax(dim=-1)
        else:
            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible).
            # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
            token_index = (ids == self.eos_token_id).int().argmax(dim=-1)
        return last_hidden_state[batch_index, token_index]
class CLIPTextModel(CLIPPreTrainedModel):
    """Thin wrapper exposing a `CLIPTextTransformer` as `self.text_model`.

    Args:
        prefix: weight-name prefix forwarded to `CLIPTextTransformer`.
        config: the text configuration.
        weights: weight loader forwarded to `CLIPTextTransformer`. It is optional
            for signature compatibility, but the encoder layers need a real loader.
    """

    config_class = CLIPTextConfig
    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

    def __init__(self, prefix, config: CLIPTextConfig, weights=None):
        # CLIPPreTrainedModel is a plain nn.Module: its __init__ accepts no config
        # argument, and nn.Module has no `post_init()` hook, so neither is called.
        super().__init__()
        self.config = config
        self.text_model = CLIPTextTransformer(prefix, config, weights=weights)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        r"""
        Runs the text tower.

        Returns:
            Exactly what `CLIPTextTransformer.forward` returns (a hidden-states tensor,
            not a model-output object with `.last_hidden_state` / `.pooler_output`).
        """
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
class CLIPVisionTransformer(nn.Module):
    """CLIP vision tower: patch embeddings -> `pre_layrnorm` -> encoder.

    No post-LayerNorm or pooled output is computed here: the returned
    `BaseModelOutputWithPooling` has only `last_hidden_state` populated.
    Weights are loaded from `{prefix}.embeddings`, `{prefix}.pre_layrnorm` and
    `{prefix}.encoder`.
    """

    def __init__(self, prefix, config: CLIPVisionConfig, weights):
        super().__init__()
        self.config = config
        self.embeddings = CLIPVisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        # The attribute / weight name `pre_layrnorm` is spelled as in the checkpoint.
        self.pre_layrnorm = nn.LayerNorm.load(
            prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
        )
        self.encoder = CLIPEncoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Returns:
            `BaseModelOutputWithPooling` whose `last_hidden_state` is the encoder output.

        Raises:
            ValueError: if `pixel_values` is None.
        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        normed = self.pre_layrnorm(self.embeddings(pixel_values))
        encoded = self.encoder(inputs_embeds=normed)
        return BaseModelOutputWithPooling(last_hidden_state=encoded)
class CLIPVisionModel(CLIPPreTrainedModel):
    """Thin wrapper exposing a `CLIPVisionTransformer` as `self.vision_model`.

    Args:
        config: the vision configuration.
        weights: weight loader forwarded to `CLIPVisionTransformer`. It is optional
            for signature compatibility, but loading requires a real loader.
        prefix: weight-name prefix for the tower. NOTE(review): the default
            `"vision_model"` assumes the standalone-vision-model key layout — verify
            against the checkpoint.
    """

    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"
    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(
        self, config: CLIPVisionConfig, weights=None, prefix: str = "vision_model"
    ):
        # CLIPPreTrainedModel is a plain nn.Module: its __init__ accepts no config
        # argument, and nn.Module has no `post_init()` hook, so neither is called.
        super().__init__()
        self.config = config
        # CLIPVisionTransformer takes (prefix, config, weights); the previous
        # single-argument call could not succeed.
        self.vision_model = CLIPVisionTransformer(
            prefix=prefix, config=config, weights=weights
        )

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding convolution of the vision tower."""
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Runs the vision tower.

        Returns:
            Exactly what `CLIPVisionTransformer.forward` returns: a
            `BaseModelOutputWithPooling` with only `last_hidden_state` populated.
        """
        return self.vision_model(
            pixel_values=pixel_values,
        )
class CLIPModel(nn.Module):
def __init__(self, prefix, config: CLIPConfig, weights):
super().__init__()
text_config = config.text_config
vision_config = config.vision_config
self.projection_dim = config.projection_dim
self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size
self.text_model = CLIPTextTransformer(text_config)
self.vision_model = CLIPVisionTransformer(vision_config)
self.visual_projection = nn.Linear(
self.vision_embed_dim, self.projection_dim, bias=False
)
self.text_projection = nn.Linear(
self.text_embed_dim, self.projection_dim, bias=False
)
self.logit_scale = nn.Parameter(
torch.tensor(self.config.logit_scale_init_value)
)
# Initialize weights and apply final processing
self.post_init()
def get_text_features(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
r"""
Returns:
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
applying the projection layer to the pooled output of [`CLIPTextModel`].
Examples:
```python
>>> from transformers import AutoTokenizer, CLIPModel
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
>>> text_features = model.get_text_features(**inputs)
```"""
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
)
pooled_output = text_outputs[1]
text_features = self.text_projection(pooled_output)
return text_features
def get_image_features(
self,
pixel_values: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""
Returns:
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
applying the projection layer to the pooled output of [`CLIPVisionModel`].
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, CLIPModel
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> image_features = model.get_image_features(**inputs)
```"""
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
vision_outputs = self.vision_model(
pixel_values=pixel_values,
)
pooled_output = vision_outputs[1] # pooled_output
image_features = self.visual_projection(pooled_output)
return image_features
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
):
    r"""Compute CLIP image-text similarity logits.

    Args:
        input_ids, attention_mask, position_ids: forwarded to ``self.text_model``.
        pixel_values: forwarded to ``self.vision_model``.

    Returns:
        A plain tuple ``(logits_per_image, logits_per_text)``. This is not an
        output object, so attribute access such as ``outputs.logits_per_image``
        does not work. ``logits_per_text`` is
        ``text_embeds @ image_embeds.T * exp(logit_scale)`` computed on the
        L2-normalised projected embeddings, and ``logits_per_image`` is its
        transpose.

    Example:
    ```python
    >>> logits_per_image, logits_per_text = model(
    ...     input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values
    ... )
    >>> probs = logits_per_image.softmax(dim=1)  # label probabilities per image
    ```
    """
    vision_outputs = self.vision_model(
        pixel_values=pixel_values,
    )
    text_outputs = self.text_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
    )
    # Element 1 of each tower's output is the pooled representation.
    image_embeds = vision_outputs[1]
    image_embeds = self.visual_projection(image_embeds)
    text_embeds = text_outputs[1]
    text_embeds = self.text_projection(text_embeds)
    # normalized features
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
    # cosine similarity as logits; logit_scale is stored in log space
    logit_scale = self.logit_scale.exp()
    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
    logits_per_image = logits_per_text.t()
    return logits_per_image, logits_per_text
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
================================================
# coding=utf-8
# Copyright 2024 Cohere team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch
class CohereRotary(PositionRotaryEmbedding):
    """Rotary embedding for Cohere, applied with Habana's PAIRWISE rotary mode.

    ``forward`` writes the rotated values back into ``query`` and ``key`` with
    ``copy_`` and returns ``None``.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        n_tokens = query.shape[0]
        head_dim = query.shape[-1]
        mode = RotaryPosEmbeddingMode.PAIRWISE
        # Duplicate each cos/sin entry so both members of a rotated pair see it.
        cos = torch.repeat_interleave(cos, 2, dim=-1)
        sin = torch.repeat_interleave(sin, 2, dim=-1)
        rot_dim = cos.shape[-1]
        # Only the first `rot_dim` features of each head are rotated; the rest pass through.
        for tensor in (query, key):
            original_shape = tensor.shape
            heads = tensor.view(n_tokens, -1, head_dim)
            untouched = heads[..., rot_dim:]
            rotated = apply_rotary_pos_emb(
                heads[..., :rot_dim], cos, sin, None, 0, mode
            )
            heads.copy_(
                torch.cat((rotated, untouched), dim=-1).reshape(original_shape)
            )
class CohereLayerNorm(nn.Module):
    """Bias-free LayerNorm with a separate scale per head (Cohere q/k norm).

    The weight is loaded sharded along dim 0 and must be 2-D; ``forward``
    reshapes its input to ``(-1, weight.shape[0], weight.shape[1])``,
    normalises over the last dim in float32, scales by the weight, and returns
    a ``(-1, weight.shape[1])`` tensor in the input dtype.
    """

    def __init__(self, prefix, weights, eps):
        super().__init__()
        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
        self.weight = nn.Parameter(weight)
        # Fake weights: not read by forward, kept as an attribute.
        self.ones = weight.new_ones(weight.shape[1])
        self.eps = eps

    def forward(self, hidden_states):
        n_groups, group_size = self.weight.shape[0], self.weight.shape[1]
        grouped = hidden_states.reshape(-1, n_groups, group_size)
        original_dtype = grouped.dtype
        x = grouped.to(torch.float32)
        centered = x - x.mean(-1, keepdim=True)
        var = centered.pow(2).mean(-1, keepdim=True)
        scaled = self.weight.to(torch.float32) * (centered * torch.rsqrt(var + self.eps))
        return scaled.view(-1, group_size).to(original_dtype)
def load_attention(config, prefix, weights):
    """Load the fused q/k/v projection for one Cohere attention layer.

    When the KV head count equals the query head count, the three projections
    are concatenated along dim 0 via ``TensorParallelColumnLinear.load_multi``.
    Otherwise (grouped-query attention) loading is delegated to ``_load_gqa``.
    """
    if config.num_attention_heads == config.num_key_value_heads:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.{proj}" for proj in ("q_proj", "k_proj", "v_proj")],
            dim=0,
            weights=weights,
            bias=config.attention_bias,
        )
    return _load_gqa(config, prefix, weights)
def _load_gqa(config, prefix: str, weights):
    """Load a fused q/k/v projection when KV heads differ from query heads.

    The q/k/v weights are loaded column-sharded and concatenated along dim 0.
    Unquantized weights are cast to ``weights.dtype``, moved to
    ``weights.device`` and shape-checked against the per-shard head counts.
    Optional q/k/v biases are sharded and concatenated the same way.

    Returns:
        A ``TensorParallelColumnLinear`` wrapping the fused projection.

    Raises:
        AssertionError: if the hidden size or head counts do not divide evenly,
            or if an unquantized fused weight has an unexpected shape.
    """
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0
    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
    )
    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)
        # Only a plain dense tensor has a meaningful (out, in) shape to check.
        world_size = weights.process_group.size()
        head_size = config.hidden_size // config.num_attention_heads
        num_heads = config.num_attention_heads // world_size
        num_key_value_heads = config.num_key_value_heads // world_size
        expected_shape = [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ]
        # Report the same per-shard shape that is actually checked.
        assert (
            list(weight.weight.shape) == expected_shape
        ), f"{list(weight.weight.shape)} != {expected_shape}"
    if config.attention_bias:
        w = [
            weights.get_sharded(f"{p}.bias", dim=0)
            for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
        ]
        bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
    else:
        bias = None
    return TensorParallelColumnLinear(get_linear(weight, bias=bias))
class FlashCohereAttention(torch.nn.Module):
    """Tensor-parallel Cohere self-attention for the HPU backend.

    Loads a fused q/k/v projection (``load_attention``), optional per-head q/k
    LayerNorms (when ``config.use_qk_norm``), and a row-parallel ``o_proj``.
    ``forward`` applies rotary embeddings in place, stores the new key/value
    in ``kv_cache``, then runs ``attention`` for prefill (``cu_seqlen_prefill``
    is not None) or ``paged_attention`` for decode.
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights,
        rotary_emb,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads
        # Rotary module shared across layers (FlashCohereModel passes one instance to all of them).
        self.rotary_emb = rotary_emb
        self.softmax_scale = self.head_size**-0.5
        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        # From here on, head counts are per tensor-parallel shard.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )
        self.query_key_value = load_attention(config, prefix, weights)
        self.kv_scales = get_kv_scales(weights, f"{prefix}")
        self.use_qk_norm = config.use_qk_norm
        if self.use_qk_norm:
            self.q_norm = CohereLayerNorm(
                prefix=f"{prefix}.q_norm",
                weights=weights,
                eps=config.layer_norm_eps,
            )
            self.k_norm = CohereLayerNorm(
                prefix=f"{prefix}.k_norm",
                weights=weights,
                eps=config.layer_norm_eps,
            )
        else:
            self.q_norm = None
            self.k_norm = None
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=config.attention_bias,
        )
        # Query heads per KV head; kv_head_mapping[i] == i // num_groups
        # (presumably the KV head used by query head i inside paged_attention).
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Run attention for one layer.

        Returns the ``o_proj`` output with ``reduce=False``; the cross-shard
        all-reduce is done by ``FlashCohereLayer`` after adding the MLP branch.
        """
        qkv = self.query_key_value(hidden_states)
        # Split the fused projection along the feature dim (assumes qkv is (tokens, features)).
        query, key, value = qkv.split(
            [
                self.head_size * self.num_heads,
                self.head_size * self.num_key_value_heads,
                self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        if self.use_qk_norm:
            # CohereLayerNorm normalises each head separately; flatten to (-1, head_size) first.
            query = query.reshape(-1, self.head_size)
            key = key.reshape(-1, self.head_size)
            query = self.q_norm(query.contiguous())
            key = self.k_norm(key.contiguous())
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_key_value_heads, self.head_size)
        value = value.view(-1, self.num_key_value_heads, self.head_size)
        # In place: CohereRotary writes the rotated values back into query and key.
        self.rotary_emb(query, key, cos, sin)
        # Must run after the rotary update so the cache holds rotated keys.
        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )
        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )
        return self.o_proj(
            attn_output.view(-1, self.num_heads * self.head_size), reduce=False
        )
class CohereMLP(nn.Module):
    """Gated feed-forward block computing ``down_proj(act(gate) * up)``.

    Gate and up projections are fused into one column-parallel linear.
    ``down_proj`` is called with ``reduce=False``; ``FlashCohereLayer``
    all-reduces the summed attention + MLP output instead.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        if "gelu" in act:
            # GELU variants map onto torch's exact ("none") or tanh approximation.
            approximate = (
                "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(x, approximate=approximate)
        else:
            self.act = ACT2FN[act]
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        world_size = weights.process_group.size()
        self.intermediate_size = config.intermediate_size // world_size

    def forward(self, hidden_states):
        fused = self.gate_up_proj(hidden_states).view(-1, 2, self.intermediate_size)
        gate, up = fused[:, 0], fused[:, 1]
        return self.down_proj(self.act(gate) * up, reduce=False)
class FlashCohereLayer(nn.Module):
    """One Cohere decoder layer with parallel attention and MLP branches.

    Both branches read the same ``input_layernorm`` output. Their results are
    summed and all-reduced once when the process group spans several ranks.
    Returns ``(layer_output, residual)``, where ``residual`` is the second
    value returned by ``input_layernorm``.
    """

    def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
        super().__init__()
        layer_prefix = f"{prefix}.layers.{layer_id}"
        self.self_attn = FlashCohereAttention(
            prefix=f"{layer_prefix}.self_attn",
            config=config,
            weights=weights,
            rotary_emb=rotary_emb,
        )
        self.mlp = CohereMLP(
            prefix=f"{layer_prefix}.mlp", config=config, weights=weights
        )
        self.input_layernorm = FastLayerNorm.load_no_bias(
            prefix=f"{layer_prefix}.input_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )
        self.process_group = weights.process_group

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        normed, next_residual = self.input_layernorm(hidden_states, residual)
        attn_branch = self.self_attn(
            normed,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        # Both sub-modules skip their own reduction, so one all-reduce covers the sum.
        combined = attn_branch + self.mlp(normed)
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(combined, group=self.process_group)
        return combined, next_residual
class FlashCohereModel(torch.nn.Module):
    """Cohere decoder stack: token embedding, ``config.num_hidden_layers``
    ``FlashCohereLayer`` blocks and a final ``FastLayerNorm.load_no_bias`` norm.

    A single ``CohereRotary`` instance is built here and shared by every layer.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        # Rotary dimension is the full (unsharded) per-head size.
        rotary_emb = CohereRotary.static(
            config=config,
            dim=config.hidden_size // config.num_attention_heads,
            base=config.rope_theta,
            device=weights.device,
        )
        self.layers = nn.ModuleList(
            [
                FlashCohereLayer(
                    prefix,
                    layer_id,
                    config,
                    weights,
                    rotary_emb,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = FastLayerNorm.load_no_bias(
            prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps
        )
        self.gradient_checkpointing = False
        # Per-shard attention geometry copied from the first layer
        # (presumably read by the caller, e.g. for KV-cache sizing — confirm).
        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        # NOTE(review): elements are used via `.store(...)` in attention, so they are
        # KV-cache objects rather than plain tuples; this annotation looks stale.
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer, and return the final normed hidden states."""
        if hpu_attention_meta is not None:
            # Refresh the HPU paged-attention block mapping, sized by input_ids.shape[0].
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_tokens(input_ids)
        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
        residual = None
        # In HPU lazy mode, mark_step() cuts the accumulated graph: once before the
        # layer loop and once after every layer.
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class FlashCohereForCausalLM(torch.nn.Module):
    """Cohere causal LM: ``FlashCohereModel`` backbone plus a ``SpeculativeHead``.

    Both the main and the speculative logits are multiplied in place by
    ``config.logit_scale``.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        prefix = f"{prefix}.model" if prefix else "model"
        self.model = FlashCohereModel(prefix, config, weights)
        try:
            self.lm_head = SpeculativeHead.load(
                config,
                prefix="lm_head",
                weights=weights,
            )
        except RuntimeError:
            # Loading `lm_head` failed; reuse the input embedding matrix instead
            # (presumably a checkpoint with tied embeddings).
            self.lm_head = SpeculativeHead.load(
                config,
                prefix=f"{prefix}.embed_tokens",
                weights=weights,
            )
        self.logit_scale = config.logit_scale

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        # Optionally keep only the rows whose logits are needed.
        selected = hidden if lm_head_indices is None else hidden[lm_head_indices]
        logits, speculative_logits = self.lm_head(selected)
        logits *= self.logit_scale
        if speculative_logits is not None:
            speculative_logits *= self.logit_scale
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
================================================
# coding=utf-8
# Copyright 2022 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from vllm_hpu_extension.ops import DynamicFusedMOE
import habana_frameworks.torch as htorch
class DbrxAttentionConfig(PretrainedConfig):
    """Attention sub-config of a DBRX checkpoint (``attn_config``).

    Extra keyword arguments are forwarded to ``PretrainedConfig`` first. Any of
    them other than ``model_type`` then raises ``ValueError``.
    """

    def __init__(
        self,
        attn_pdrop: float = 0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta
        # `model_type` is tolerated; anything else left over is an error.
        kwargs.pop("model_type", None)
        if kwargs:
            raise ValueError(f"Found unknown {kwargs=}")
class DbrxFFNConfig(PretrainedConfig):
    """Mixture-of-experts FFN sub-config of a DBRX checkpoint (``ffn_config``).

    ``ffn_act_fn`` defaults to ``{"name": "silu"}``. ``uniform_expert_assignment=True``
    and any extra keyword other than ``model_type`` raise ``ValueError``.
    """

    def __init__(
        self,
        ffn_act_fn: Optional[dict] = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1,
        uniform_expert_assignment: bool = False,
        **kwargs: Any,
    ):
        super().__init__()
        self.ffn_act_fn = {"name": "silu"} if ffn_act_fn is None else ffn_act_fn
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.uniform_expert_assignment = uniform_expert_assignment
        if uniform_expert_assignment:
            raise ValueError("`uniform_expert_assignment = True` is not supported")
        kwargs.pop("model_type", None)
        if kwargs:
            raise ValueError(f"Found unknown {kwargs=}")
class DbrxConfig(PretrainedConfig):
    """Top-level DBRX model config.

    ``attribute_map`` exposes the generic names ``hidden_size``,
    ``num_attention_heads`` and ``num_hidden_layers`` for DBRX's ``d_model``,
    ``n_heads`` and ``n_layers``. ``attn_config`` / ``ffn_config`` may be
    config objects, plain dicts, or ``None`` (defaults). Tied word embeddings
    are rejected.
    """

    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        router_aux_loss_coef: float = 0.05,
        **kwargs: Any,
    ):
        # Sub-configs: dicts (as read from config.json) are expanded, None means defaults.
        if isinstance(attn_config, dict):
            self.attn_config = DbrxAttentionConfig(**attn_config)
        elif attn_config is None:
            self.attn_config = DbrxAttentionConfig()
        else:
            self.attn_config = attn_config
        if isinstance(ffn_config, dict):
            self.ffn_config = DbrxFFNConfig(**ffn_config)
        elif ffn_config is None:
            self.ffn_config = DbrxFFNConfig()
        else:
            self.ffn_config = ffn_config
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError("tie_word_embeddings is not supported for Dbrx models.")
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def num_key_value_heads(self):
        """KV head count, read from the nested ``attn_config.kv_n_heads``.

        A property rather than an ``attribute_map`` entry because the value is
        not a top-level attribute.
        """
        return self.attn_config.kv_n_heads
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
    """Return a 0-dim tensor as a one-element 1-D view; return other tensors unchanged."""
    if x.dim() == 0:
        return x.view(1)
    return x
def load_attention(config, prefix, weights):
    """Load DBRX's fused ``Wqkv`` projection as one bias-free column-parallel linear."""
    attn_cfg = config.attn_config
    return TensorParallelColumnLinear.load_qkv(
        config,
        prefix=f"{prefix}.Wqkv",
        weights=weights,
        bias=False,
        num_heads=config.n_heads,
        num_key_value_heads=attn_cfg.kv_n_heads,
    )
def _load_experts(config, prefix, weights):
world_size = weights.process_group.size()
rank = weights.process_group.rank()
assert (
config.ffn_config.ffn_hidden_size % world_size == 0
), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
expert_size = config.ffn_config.ffn_hidden_size
block_size = expert_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = torch.empty(
(config.ffn_config.moe_num_experts * block_size, config.d_model),
dtype=weights.dtype,
device=weights.device,
)
slice_ = weights._get_slice(f"{prefix}")
for i in range(config.ffn_config.moe_num_experts):
offset = i * expert_size
expert_slice = slice_[start + offset : stop + offset]
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
dtype=weights.dtype
).to(device=weights.device)
return tensor
def _load_experts_quantized(config, prefix, weights, cls):
world_size = weights.process_group.size()
rank = weights.process_group.rank()
assert (
config.ffn_config.ffn_hidden_size % world_size == 0
), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
expert_size = config.ffn_config.ffn_hidden_size
block_size = expert_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
slice_ = weights._get_slice(f"{prefix}")
experts = []
for i in range(config.ffn_config.moe_num_experts):
if config.quantize in ["gptq", "awq"]:
raise NotImplementedError(
"Dbrx does not support gptq/awq quantization yet."
)
else:
offset = i * expert_size
expert_slice = (
slice_[start + offset : stop + offset]
.to(dtype=weights.dtype)
.to(device=weights.device)
)
if cls == TensorParallelRowLinear:
expert_slice = expert_slice.t().contiguous()
linear = get_linear(expert_slice, None)
experts.append(cls(linear, weights.process_group))
else:
linear = get_linear(expert_slice, None)
experts.append(cls(linear))
return experts
class DbrxAttention(torch.nn.Module):
    """Tensor-parallel DBRX self-attention for the HPU backend.

    Uses a fused ``Wqkv`` projection (``load_attention``), optional clamping of
    the fused activations to ``[-clip_qkv, clip_qkv]``, rotary embeddings on
    query and key, and a bias-free row-parallel ``out_proj``. Prefill
    (``cu_seqlen_prefill`` is not None) uses ``attention``; decode uses
    ``paged_attention``.
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights,
        rotary_emb,
    ):
        super().__init__()
        self.clip_qkv = config.attn_config.clip_qkv
        self.num_heads = config.n_heads
        self.hidden_size = config.d_model
        self.head_size = self.hidden_size // self.num_heads
        # Rotary module shared across layers (DbrxModel passes one instance to all of them).
        self.rotary_emb = rotary_emb
        self.softmax_scale = self.head_size**-0.5
        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        # From here on, head counts are per tensor-parallel shard.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.attn_config.kv_n_heads // weights.process_group.size()
        )
        self.query_key_value = load_attention(config, prefix, weights)
        self.kv_scales = get_kv_scales(weights, f"{prefix}")
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=False,
        )
        # Query heads per KV head; kv_head_mapping[i] == i // num_groups
        # (presumably the KV head used by query head i inside paged_attention).
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Run attention for one layer and return the ``o_proj`` output.

        ``o_proj`` is called with its default ``reduce`` argument; DbrxLayer
        performs no extra all-reduce after attention.
        """
        qkv = self.query_key_value(hidden_states)
        if self.clip_qkv is not None:
            qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
        # Split into query and a packed key/value block (assumes qkv is (tokens, features)).
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        # kv[:, 0] holds keys and kv[:, 1] holds values.
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
        # torch.select returns a view of the keys. The rotary call's return value is
        # discarded, so it is expected to update query and the key view in place.
        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
        # Must run after the rotary update so the cache holds rotated keys.
        kv_cache.store(
            key=kv[:, 0],
            value=kv[:, 1],
            slots=slots,
            kv_scales=self.kv_scales,
        )
        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query,
                key=kv[:, 0],
                value=kv[:, 1],
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )
        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
class DbrxNormAttentionNorm(nn.Module):
    """DBRX attention sub-block: ``norm_1`` -> ``DbrxAttention`` -> ``norm_2``.

    Both norms are bias-free ``FastLayerNorm`` layers with a fixed eps of 1e-5
    rather than a config value. ``forward`` returns the pair produced by
    ``norm_2``.
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights,
        rotary_emb,
    ):
        super().__init__()
        eps = 1e-5
        self.norm_1 = FastLayerNorm.load_no_bias(
            prefix=f"{prefix}.norm_1", weights=weights, eps=eps
        )
        self.self_attn = DbrxAttention(
            prefix=f"{prefix}.attn",
            config=config,
            weights=weights,
            rotary_emb=rotary_emb,
        )
        self.norm_2 = FastLayerNorm.load_no_bias(
            prefix=f"{prefix}.norm_2", weights=weights, eps=eps
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        pre_normed, pre_res = self.norm_1(hidden_states, residual)
        attn_out = self.self_attn(
            pre_normed,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        post_normed, post_res = self.norm_2(attn_out, pre_res)
        return post_normed, post_res
@torch.jit.script
def select_experts(
    gate_logits: torch.Tensor, top_k: int, moe_normalize_expert_weights: int
):
    """Choose the ``top_k`` experts per token from router logits.

    Softmax is taken over dim 1 in float32. For 2-D ``(tokens, n_experts)``
    logits, both returned tensors are flattened row-major to length
    ``tokens * top_k``. When ``moe_normalize_expert_weights`` is non-zero, each
    token's kept weights are divided by their p-norm with that ``p``. Not
    referenced elsewhere in this module.
    """
    probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
    top_weights, top_experts = torch.topk(probs, top_k, dim=-1)
    if moe_normalize_expert_weights:
        norm = torch.norm(
            top_weights, p=moe_normalize_expert_weights, dim=-1, keepdim=True
        )
        top_weights = top_weights / norm
    return top_experts.view(-1), top_weights.view(-1)
@torch.jit.script
def round_up(x: torch.Tensor, value: int):
    """For non-negative ``x``, round each element up to the next multiple of ``value``.

    Uses truncating division, so negative inputs are not rounded toward +inf.
    """
    padded = x + (value - 1)
    return torch.div(padded, value, rounding_mode="trunc") * value
class BlockSparseMoE(nn.Module):
    """DBRX mixture-of-experts FFN backed by the fused HPU ``DynamicFusedMOE`` op.

    DbrxLayer selects this class when ``config.quantize is None``. Expert
    weights are loaded sharded along the FFN dimension. Each expert's gate
    (``w1``) and up (``v1``) projections are concatenated and, together with
    ``w2``, registered with the fused op.
    """

    def __init__(self, prefix, config: DbrxConfig, weights):
        super().__init__()
        # NOTE(review): neither `moe_normalize_expert_weights` nor `self.act` is
        # read by forward(); routing normalisation and the activation are whatever
        # DynamicFusedMOE applies internally — confirm they match the checkpoint.
        self.moe_normalize_expert_weights = (
            config.ffn_config.moe_normalize_expert_weights
        )
        self.hidden_dim = config.d_model
        self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()
        self.num_experts = config.ffn_config.moe_num_experts
        self.top_k = config.ffn_config.moe_top_k
        act = config.ffn_config.ffn_act_fn["name"]
        if "gelu" in act:
            self.act = lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        elif "silu" in act:
            self.act = torch.nn.functional.silu
        else:
            self.act = ACT2FN[act]
        # gating
        self.gate = FastLinear.load(
            config, f"{prefix}.router.layer", weights, bias=False
        )
        # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
        w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view(
            self.num_experts, self.ffn_dim, self.hidden_dim
        )
        v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view(
            self.num_experts, self.ffn_dim, self.hidden_dim
        )
        # (num_experts, 2 * ffn_dim, hidden_dim): gate rows followed by up rows.
        self.wv1 = torch.cat([w1, v1], dim=1)
        # (num_experts, hidden_dim, ffn_dim) after the transpose.
        self.w2 = (
            _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
            .transpose(1, 2)
            .contiguous()
        )
        self.process_group = weights.process_group
        self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
        # Hand each expert's fused gate/up and down weights to the fused op.
        for i in range(self.num_experts):
            self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i])
            self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route ``x`` through the fused MoE op and return a tensor shaped like ``x``."""
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(x)
        out = self.hpu_fused_moe(x, router_logits, self.top_k)
        # Reduce sum
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out.view(*x.shape)
class DenseMoE(nn.Module):
    """DBRX mixture-of-experts FFN evaluated densely over per-expert linear layers.

    DbrxLayer selects this class when ``config.quantize`` is set. Every expert
    runs on every token. Its output is weighted by the router probabilities,
    with non-top-k experts masked to 0 and the weights optionally re-normalised,
    and the weighted outputs are summed.
    """

    def __init__(self, prefix, config: DbrxConfig, weights):
        super().__init__()
        self.moe_normalize_expert_weights = (
            config.ffn_config.moe_normalize_expert_weights
        )
        self.hidden_dim = config.d_model
        self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()
        self.num_experts = config.ffn_config.moe_num_experts
        self.top_k = config.ffn_config.moe_top_k
        act = config.ffn_config.ffn_act_fn["name"]
        if "gelu" in act:
            self.act = lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        elif "silu" in act:
            self.act = torch.nn.functional.silu
        else:
            self.act = ACT2FN[act]
        # gating
        self.gate = FastLinear.load(
            config, f"{prefix}.router.layer", weights, bias=False
        )
        # Per-expert layers: w1 = gate, v1 = up (column-parallel), w2 = down (row-parallel).
        self.w1 = _load_experts_quantized(
            config,
            prefix=f"{prefix}.experts.mlp.w1",
            weights=weights,
            cls=TensorParallelColumnLinear,
        )
        self.w2 = _load_experts_quantized(
            config,
            prefix=f"{prefix}.experts.mlp.w2",
            weights=weights,
            cls=TensorParallelRowLinear,
        )
        self.v1 = _load_experts_quantized(
            config,
            prefix=f"{prefix}.experts.mlp.v1",
            weights=weights,
            cls=TensorParallelColumnLinear,
        )
        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the dense MoE to ``x``.

        ``x`` has shape ``(..., model_dim)`` and is flattened to
        ``(tokens, model_dim)`` internally. The result has the same leading
        dimensions as ``x`` and last dim ``hidden_dim``, consistent with
        ``BlockSparseMoE``.
        """
        # optional reshape
        input_shape = x.shape
        x = x.view(-1, input_shape[-1])
        # gate_logits: (sequence_length, n_experts)
        gate_logits = self.gate(x)
        # all_probs: (sequence_length, n_experts) and upcast for softmax
        weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
        if self.top_k < self.num_experts:
            _, not_selected_experts = torch.topk(
                weights,
                self.num_experts - self.top_k,
                largest=False,
                sorted=False,
                dim=1,
            )
            # Mask not selected experts
            weights.scatter_(1, not_selected_experts, 0)
        # Re-normalize
        if self.moe_normalize_expert_weights:
            weights = weights / torch.norm(
                weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
            )
        weights = weights.to(x.dtype)
        # Final output tensor
        out = x.new_zeros(x.shape[0], self.hidden_dim)
        for i in range(self.num_experts):
            h = self.act(self.w1[i](x)) * self.v1[i](x)
            # Skip the per-expert reduction; one all-reduce of the sum follows.
            h = self.w2[i](h, reduce=False)
            # Add expert output to out with masking
            out += h * weights[:, i].view(-1, 1)
        # Reduce sum
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        # Restore the caller's leading dimensions instead of returning a flattened tensor.
        return out.view(*input_shape[:-1], self.hidden_dim)
class DbrxLayer(nn.Module):
    """One DBRX transformer block: normed attention followed by a mixture-of-experts FFN.

    Unquantized checkpoints use ``BlockSparseMoE`` (fused HPU op); quantized
    ones use ``DenseMoE``. ``forward`` returns ``(moe_output, residual)``.
    """

    def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
        super().__init__()
        block_prefix = f"{prefix}.blocks.{layer_id}"
        self.attn = DbrxNormAttentionNorm(
            prefix=f"{block_prefix}.norm_attn_norm",
            config=config,
            weights=weights,
            rotary_emb=rotary_emb,
        )
        if config.quantize is None:
            moe_cls = BlockSparseMoE
        else:
            moe_cls = DenseMoE
        self.moe = moe_cls(f"{block_prefix}.ffn", config, weights)

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        normed_attn, attn_res = self.attn(
            hidden_states,
            residual,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        return self.moe(normed_attn), attn_res
class DbrxModel(torch.nn.Module):
    """DBRX decoder stack: ``wte`` embedding, ``config.n_layers`` ``DbrxLayer``
    blocks and a final ``FastLayerNorm.load_no_bias`` norm (``norm_f``).

    A single ``PositionRotaryEmbedding`` is built here and shared by every layer.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.wte", weights=weights
        )
        # Rotary dimension is the full (unsharded) per-head size.
        rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=config.d_model // config.n_heads,
            base=config.attn_config.rope_theta,
            device=weights.device,
        )
        self.layers = nn.ModuleList(
            [
                DbrxLayer(
                    prefix,
                    layer_id,
                    config,
                    weights,
                    rotary_emb,
                )
                for layer_id in range(config.n_layers)
            ]
        )
        self.norm = FastLayerNorm.load_no_bias(
            prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
        )
        # Per-shard attention geometry copied from the first layer
        # (presumably read by the caller, e.g. for KV-cache sizing — confirm).
        self.head_size = self.layers[0].attn.self_attn.head_size
        self.num_heads = self.layers[0].attn.self_attn.num_heads
        self.num_key_value_heads = self.layers[0].attn.self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        # NOTE(review): elements are used via `.store(...)` in attention, so they are
        # KV-cache objects rather than plain tuples; this annotation looks stale.
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer, and return the final normed hidden states."""
        if hpu_attention_meta is not None:
            # Refresh the HPU paged-attention block mapping, sized by input_ids.shape[0].
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_tokens(input_ids)
        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
        residual = None
        # In HPU lazy mode, mark_step() cuts the accumulated graph: once before the
        # layer loop and once after every layer.
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class FlashDbrxForCausalLM(torch.nn.Module):
    """DBRX causal LM: a ``DbrxModel`` backbone followed by a ``SpeculativeHead``."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        backbone_prefix = f"{prefix}.transformer" if prefix else "transformer"
        self.model = DbrxModel(backbone_prefix, config, weights)
        # NOTE: the head is loaded from "lm_head" regardless of `prefix`.
        self.lm_head = SpeculativeHead.load(
            config,
            prefix="lm_head",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the backbone and LM head; returns ``(logits, speculative_logits)``.

        When ``lm_head_indices`` is given, only those rows of the hidden states
        are passed to the head. ``adapter_data`` is accepted but not used here.
        """
        hidden = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden = hidden[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden)
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
================================================
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Type
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from text_generation_server.layers import (
FastLinear,
SpeculativeHead,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
Fp8Linear,
)
from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention_mla,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights
import habana_frameworks.torch as htorch
def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
    """Return ``layer``'s weight laid out as (output, input).

    Non-FP8 layers simply expose ``layer.weight``. For ``Fp8Linear`` the weight
    is recovered by feeding a bfloat16 identity matrix (sized to the last
    dimension of ``qweight``) through the layer and transposing the result.
    """
    if not isinstance(layer, Fp8Linear):
        return layer.weight
    identity = torch.eye(
        layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
    )
    dequantized = layer(identity)
    # Release the identity matrix before returning.
    del identity
    # standardize to (output, input)
    return dequantized.T
class DeepseekV2Config(PretrainedConfig):
    """Configuration for DeepSeek-V2 models.

    Each constructor argument is stored as an attribute of the same name.
    ``pad_token_id``, ``bos_token_id``, ``eos_token_id``, ``tie_word_embeddings``
    and any extra ``**kwargs`` are forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: if ``tie_word_embeddings`` is truthy (unsupported) or if
            ``ep_size != 1``.
    """

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=2,
        n_routed_experts=160,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method="gready",
        n_group=8,
        topk_group=3,
        num_experts_per_tok=6,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func="softmax",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # `tie_word_embeddings` is a named parameter, so a caller-supplied value
        # binds there and never reaches `kwargs`. Popping it from `kwargs`
        # always returned the default and silently accepted tied checkpoints;
        # check the argument itself instead.
        if tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Deepseek V2 models."
            )
        if ep_size != 1:
            raise ValueError(
                f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
            )
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            # Only reachable when embeddings are untied.
            tie_word_embeddings=False,
            **kwargs,
        )
class DeepseekV2Attention(torch.nn.Module):
# Multi-head latent attention (MLA) for DeepSeek-V2.
# The KV cache stores, per token, a `kv_lora_rank`-wide normed KV latent
# concatenated with a `qk_rope_head_dim`-wide rotary key (no separate value).
# Prefill expands the latent with `kv_b_proj` and runs regular attention;
# decode folds the key up-projection (W_UK) into the query and the value
# up-projection (W_UV) into the output, so attention runs on the latent cache.
def __init__(
self,
prefix: str,
config,
weights: Weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.kv_lora_rank = config.kv_lora_rank
self.q_lora_rank = config.q_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
# q/k head = non-rotary part + rotary part; v may have a different width.
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
# Common width that q, k and v are padded to for the prefill attention call.
self.head_pad_size = max(self.head_size, self.value_head_size)
self.rotary_emb = rotary_emb
mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
)
# Softmax scale: 1/sqrt(head_size) multiplied by mscale squared.
self.softmax_scale = self.head_size**-0.5 * mscale * mscale
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
# From here on, head counts are per tensor-parallel shard.
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
# Query projection: a single column-parallel linear when q_lora_rank is None,
# otherwise low-rank q_a_proj -> q_a_layernorm -> q_b_proj.
if self.q_lora_rank is None:
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=config.attention_bias,
)
else:
self.q_a_proj = get_linear(
weight=weights.get_weights(f"{prefix}.q_a_proj"),
bias=(
weights.get_tensor(f"{prefix}.q_a_proj.bias")
if config.attention_bias
else None
),
)
self.q_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.q_a_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.q_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_b_proj",
weights=weights,
bias=config.attention_bias,
)
# Produces [compressed KV latent | rotary key]; split apart in forward().
self.kv_a_proj_with_mqa = get_linear(
weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
bias=(
weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
if config.attention_bias
else None
),
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.kv_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.kv_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
# Query-head -> KV-head index map, passed to paged_attention_mla on decode.
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
# Split kv_b_proj's (output, input) weight (dequantized first if FP8) into
# per-head key (W_UK) and value (W_UV) up-projections for the decode path.
# After .T and .view it is (kv_lora_rank, num_heads, nope + v).
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.value_head_size,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.value_head_size], dim=-1
)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
# Decode-path query: project, split into non-rotary/rotary parts, and map the
# non-rotary part into latent space with W_UK so it can be scored against the
# cached latent. Returns ((B, N, kv_lora_rank), (B, N, qk_rope_head_dim)).
def _q_proj_and_k_up_proj(self, x):
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
q_nope, q_pe = (
q_proj(x)
.view(-1, self.num_heads, self.head_size)
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
# Decode-path output: expand the latent attention output per head with W_UV,
# flatten heads, and apply o_proj.
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
return self.o_proj(x)
# Prefill when cu_seqlen_prefill is not None, decode otherwise. Both paths
# first write this step's [latent | rotary key] into kv_cache.
def forward(
self,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: KVCache,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
# [0] takes the normed tensor from FastRMSNorm's returned tuple.
if self.q_lora_rank is None:
hidden_states_or_q_c = hidden_states
else:
hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
# A single rotary key per token, shared by all heads.
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
# Prefill
if cu_seqlen_prefill is not None:
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
query = q_proj(hidden_states_or_q_c)
query = query.view(-1, self.num_heads, self.head_size)
query_nope, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
else:
query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
# De-interleave the rotary dims: pairs (x0, x1), (x2, x3), ... become
# [x0, x2, ..., x1, x3, ...]. Presumably this matches the half-split layout
# rotary_emb expects; confirm against PositionRotaryEmbedding.
batch_size, heads, head_dim = query_pe.shape
query_pe = (
query_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
batch_size, heads, head_dim = key_pe.shape
key_pe = (
key_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
# Return value is unused; rotary_emb presumably rotates query_pe/key_pe in place.
self.rotary_emb(query_pe, key_pe, cos, sin)
# Cache entry per token: [normed KV latent | rotary key], with a unit head dim.
latent_vec_k = torch.concat(
(kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
)
latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
kv_cache.store(
key=latent_vec_k,
value=None,
slots=slots,
kv_scales=self.kv_scales,
)
if cu_seqlen_prefill is not None:
# Expand the latent into full per-head keys (non-rotary part) and values.
kv = self.kv_b_proj(kv_c_normed).view(
-1,
self.num_key_value_heads,
self.qk_nope_head_dim + self.value_head_size,
)
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
)
# query_pe was rebound above, so the rotated rope part is copied back into query.
query[..., self.qk_nope_head_dim :] = query_pe
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
# key_pe has a single head and broadcasts across all heads here.
key[..., self.qk_nope_head_dim :] = key_pe
# We need to pad the heads because Flash Attention does not support
# qk and v with different head sizes.
query = torch.nn.functional.pad(
query, (0, self.head_pad_size - self.head_size), value=0
)
key = torch.nn.functional.pad(
key, (0, self.head_pad_size - self.head_size), value=0
)
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
)
# Drop the padding added to value before the output projection.
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
else:
# Decode
# Query laid out as [W_UK-absorbed non-rotary part | rotary part] to match
# the cached [latent | rotary key] layout.
query = torch.cat([query_nope, query_pe], dim=-1)
attn_output = paged_attention_mla(
query,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
kv_lora_rank=self.kv_lora_rank,
)
attn_output = self._v_up_proj_and_o_proj(attn_output)
return attn_output
class DeepseekV2MLP(nn.Module):
    """SiLU-gated feed-forward block with a fused gate/up projection.

    ``intermediate_size`` is the unsharded width; it is divided by the
    tensor-parallel world size to get the per-shard width.
    """

    def __init__(self, prefix: str, config, weights, intermediate_size: int):
        super().__init__()
        self.hidden_act = config.hidden_act
        # Bail out because MoE only supports silu.
        if self.hidden_act != "silu":
            raise NotImplementedError(
                "Currently only `silu` is supported as an activation for Deepseek V2."
            )
        self.act = ACT2FN[self.hidden_act]
        # gate_proj and up_proj are stacked along dim 0 into a single linear.
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.{name}" for name in ("gate_proj", "up_proj")],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.intermediate_size = intermediate_size // weights.process_group.size()
        # TODO: This is a hotfix to be removed & properly refactored.
        self.quantize = config.quantize

    def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
        """Compute ``down_proj(act(gate) * up)``.

        ``reduce`` is forwarded to the row-parallel down projection; the MoE
        block calls this with ``reduce=False`` and performs its own all-reduce.
        """
        fused = self.gate_up_proj(hidden_states).view(-1, 2, self.intermediate_size)
        gate = fused[:, 0]
        up = fused[:, 1]
        return self.down_proj(self.act(gate) * up, reduce=reduce)
class DeepseekV2MoE(nn.Module):
    """DeepSeek-V2 mixture-of-experts block: routed experts plus optional shared experts.

    The routed-expert output is scaled by ``config.routed_scaling_factor``
    before the shared-expert output is added. Under tensor parallelism the
    summed partial results are combined with a single all-reduce.
    """

    def __init__(
        self,
        prefix,
        config: DeepseekV2Config,
        moe_layer_cls: Type[MoELayer],
        weights,
    ):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.moe_intermediate_size = (
            config.moe_intermediate_size // weights.process_group.size()
        )
        self.routed_scaling_factor = config.routed_scaling_factor
        # Gating
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
        self.moe_layer = moe_layer_cls(
            prefix=f"{prefix}.experts",
            n_experts=config.n_routed_experts,
            n_expert_group=config.n_group,
            renormalize=config.norm_topk_prob,
            topk=config.num_experts_per_tok,
            topk_group=config.topk_group,
            weights=weights,
        )
        assert isinstance(self.moe_layer, MoELayer)
        if config.n_shared_experts is not None:
            self.shared_experts = DeepseekV2MLP(
                prefix=f"{prefix}.shared_experts",
                config=config,
                weights=weights,
                intermediate_size=config.moe_intermediate_size
                * config.n_shared_experts,
            )
        else:
            self.shared_experts = None
        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route ``x`` through the experts; returns a tensor shaped like ``x``."""
        # Shared experts skip their own all-reduce; the sum is reduced once below.
        if self.shared_experts is not None:
            shared_output = self.shared_experts(x, reduce=False)
        else:
            shared_output = None
        router_logits = self.gate(x)
        out = self.moe_layer(x, gating_output=router_logits)
        # The MoE layer is never given `routed_scaling_factor`, so it must be
        # applied here; DeepSeek's reference MoE scales the routed-expert
        # contribution by it (shared experts are unscaled). Skipping the
        # multiply for 1.0 keeps such configs bit-identical.
        if self.routed_scaling_factor != 1.0:
            out = out * self.routed_scaling_factor
        if shared_output is not None:
            out = out + shared_output
        # Reduce sum
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out.view(*x.shape)
class DeepseekV2Layer(nn.Module):
    """DeepSeek-V2 decoder layer: pre-normed MLA attention, then an MoE or dense MLP."""

    def __init__(self, prefix, layer_id, config, weights, rotary_emb):
        super().__init__()
        layer_prefix = f"{prefix}.layers.{layer_id}"
        self.self_attn = DeepseekV2Attention(
            prefix=f"{layer_prefix}.self_attn",
            config=config,
            weights=weights,
            rotary_emb=rotary_emb,
        )
        # MoE is used when routed experts exist, the layer is at or past
        # `first_k_dense_replace`, and layer_id is a multiple of `moe_layer_freq`.
        uses_moe = (
            config.n_routed_experts is not None
            and layer_id >= config.first_k_dense_replace
            and layer_id % config.moe_layer_freq == 0
        )
        if not uses_moe:
            self.mlp = DeepseekV2MLP(
                prefix=f"{layer_prefix}.mlp",
                config=config,
                weights=weights,
                intermediate_size=config.intermediate_size,
            )
        else:
            if SparseMoELayer.is_supported(weights):
                moe_layer_cls = SparseMoELayer
            else:
                moe_layer_cls = DenseMoELayer
            self.mlp = DeepseekV2MoE(
                f"{layer_prefix}.mlp", config, moe_layer_cls, weights
            )
        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{layer_prefix}.input_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{layer_prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cu_seqlen_prefill: torch.Tensor,
        kv_cache,
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ):
        """Run the layer; returns ``(mlp_output, residual)`` for the next layer.

        Each FastRMSNorm call takes ``(x, residual)`` and returns a
        ``(normed, residual)`` pair.
        """
        attn_in, residual = self.input_layernorm(hidden_states, residual)
        attn_out = self.self_attn(
            attn_in,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual
class DeepseekV2Model(torch.nn.Module):
    """DeepSeek-V2 decoder stack: embedding, ``num_hidden_layers`` layers, final RMSNorm."""

    def __init__(self, prefix: str, config, weights: Weights):
        super().__init__()
        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        # Rotary embedding covers only the rope slice of each q/k head and is
        # shared by all layers.
        shared_rotary = PositionRotaryEmbedding.static(
            config=config,
            dim=config.qk_rope_head_dim,
            base=config.rope_theta,
            device=weights.device,
        )
        decoder_layers = []
        for layer_id in range(config.num_hidden_layers):
            decoder_layers.append(
                DeepseekV2Layer(prefix, layer_id, config, weights, shared_rotary)
            )
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )
        first_attn = self.layers[0].self_attn
        self.head_size = first_attn.head_size
        self.num_heads = first_attn.num_heads
        self.num_key_value_heads = first_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer and return the final-normed hidden states.

        ``kv_cache[idx]`` is handed to layer ``idx``.
        """
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden = self.embed_tokens(input_ids)
        # cos/sin computed once and reused by every layer.
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
        residual = None
        # HPU lazy mode: mark_step() before the first layer and after each layer.
        is_lazy = htorch.utils.internal.is_lazy()
        if is_lazy:
            htorch.core.mark_step()
        for idx, decoder_layer in enumerate(self.layers):
            hidden, residual = decoder_layer(
                hidden,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[idx],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            if is_lazy:
                htorch.core.mark_step()
        hidden, _ = self.norm(hidden, residual)
        return hidden
class FlashDeepseekV2ForCausalLM(torch.nn.Module):
    """DeepSeek-V2 causal LM: ``DeepseekV2Model`` backbone plus a ``SpeculativeHead``."""

    def __init__(self, prefix: str, config, weights: Weights):
        super().__init__()

        def scoped(name: str) -> str:
            # Prepend `prefix` only when one was given.
            return f"{prefix}.{name}" if prefix else name

        self.model = DeepseekV2Model(scoped("model"), config, weights)
        self.lm_head = SpeculativeHead.load(
            config,
            prefix=scoped("lm_head"),
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, speculative_logits)`` for the given batch.

        When ``lm_head_indices`` is set, only those hidden-state rows go through
        the head. ``adapter_data`` is accepted but not used here.
        """
        hidden = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden = hidden[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden)
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
================================================
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Type
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from text_generation_server.layers import (
FastLinear,
SpeculativeHead,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
Fp8Linear,
)
from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention_mla,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights
import habana_frameworks.torch as htorch
def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
    """Return ``layer``'s weight laid out as (output, input).

    For ``Fp8Linear`` the weight is reconstructed by applying the layer to a
    bfloat16 identity matrix (sized to ``qweight``'s last dimension) and
    transposing; any other layer just returns ``layer.weight``.
    """
    if isinstance(layer, Fp8Linear):
        probe = torch.eye(
            layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
        )
        recovered = layer(probe)
        # Release the identity matrix before returning.
        del probe
        # standardize to (output, input)
        return recovered.T
    return layer.weight
class DeepseekV3Config(PretrainedConfig):
    """Configuration for DeepSeek-V3 models.

    Each constructor argument is stored as an attribute of the same name.
    ``pad_token_id``, ``bos_token_id``, ``eos_token_id``, ``tie_word_embeddings``
    and any extra ``**kwargs`` are forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: if ``tie_word_embeddings`` is truthy (unsupported) or if
            ``ep_size != 1``.
    """

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=2,
        n_routed_experts=160,
        ep_size=1,
        routed_scaling_factor=1.0,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method="gready",
        n_group=8,
        topk_group=3,
        num_experts_per_tok=6,
        moe_layer_freq=1,
        first_k_dense_replace=0,
        norm_topk_prob=False,
        scoring_func="softmax",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # `tie_word_embeddings` is a named parameter, so a caller-supplied value
        # binds there and never reaches `kwargs`. Popping it from `kwargs`
        # always returned the default and silently accepted tied checkpoints;
        # check the argument itself instead.
        if tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Deepseek V3 models."
            )
        if ep_size != 1:
            raise ValueError(
                f"Currently only ep_size == 1 is supported for Deepseek V3 models, was {ep_size}"
            )
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            # Only reachable when embeddings are untied.
            tie_word_embeddings=False,
            **kwargs,
        )
class DeepseekV3Attention(torch.nn.Module):
# Multi-head latent attention (MLA) for DeepSeek-V3.
# The KV cache stores, per token, a `kv_lora_rank`-wide normed KV latent
# concatenated with a `qk_rope_head_dim`-wide rotary key (no separate value).
# Prefill expands the latent with `kv_b_proj` and runs regular attention;
# decode folds the key up-projection (W_UK) into the query and the value
# up-projection (W_UV) into the output, so attention runs on the latent cache.
def __init__(
self,
prefix: str,
config,
weights: Weights,
rotary_emb,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.kv_lora_rank = config.kv_lora_rank
self.q_lora_rank = config.q_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
# q/k head = non-rotary part + rotary part; v may have a different width.
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
# Common width that q, k and v are padded to for the prefill attention call.
self.head_pad_size = max(self.head_size, self.value_head_size)
self.rotary_emb = rotary_emb
mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
)
# Softmax scale: 1/sqrt(head_size) multiplied by mscale squared.
self.softmax_scale = self.head_size**-0.5 * mscale * mscale
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
# From here on, head counts are per tensor-parallel shard.
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
# Query projection: a single column-parallel linear when q_lora_rank is None,
# otherwise low-rank q_a_proj -> q_a_layernorm -> q_b_proj.
if self.q_lora_rank is None:
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=config.attention_bias,
)
else:
self.q_a_proj = get_linear(
weight=weights.get_weights(f"{prefix}.q_a_proj"),
bias=(
weights.get_tensor(f"{prefix}.q_a_proj.bias")
if config.attention_bias
else None
),
)
self.q_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.q_a_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.q_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_b_proj",
weights=weights,
bias=config.attention_bias,
)
# Produces [compressed KV latent | rotary key]; split apart in forward().
self.kv_a_proj_with_mqa = get_linear(
weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
bias=(
weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
if config.attention_bias
else None
),
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.kv_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.kv_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
# Query-head -> KV-head index map, passed to paged_attention_mla on decode.
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
# Split kv_b_proj's (output, input) weight (dequantized first if FP8) into
# per-head key (W_UK) and value (W_UV) up-projections for the decode path.
# After .T and .view it is (kv_lora_rank, num_heads, nope + v).
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.value_head_size,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.value_head_size], dim=-1
)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
# Decode-path query: project, split into non-rotary/rotary parts, and map the
# non-rotary part into latent space with W_UK so it can be scored against the
# cached latent. Returns ((B, N, kv_lora_rank), (B, N, qk_rope_head_dim)).
def _q_proj_and_k_up_proj(self, x):
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
q_nope, q_pe = (
q_proj(x)
.view(-1, self.num_heads, self.head_size)
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
# Decode-path output: expand the latent attention output per head with W_UV,
# flatten heads, and apply o_proj.
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
return self.o_proj(x)
# Prefill when cu_seqlen_prefill is not None, decode otherwise. Both paths
# first write this step's [latent | rotary key] into kv_cache.
def forward(
self,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: KVCache,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
# [0] takes the normed tensor from FastRMSNorm's returned tuple.
if self.q_lora_rank is None:
hidden_states_or_q_c = hidden_states
else:
hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
# A single rotary key per token, shared by all heads.
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]
# Prefill
if cu_seqlen_prefill is not None:
q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
query = q_proj(hidden_states_or_q_c)
query = query.view(-1, self.num_heads, self.head_size)
query_nope, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
else:
query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
# De-interleave the rotary dims: pairs (x0, x1), (x2, x3), ... become
# [x0, x2, ..., x1, x3, ...]. Presumably this matches the half-split layout
# rotary_emb expects; confirm against PositionRotaryEmbedding.
batch_size, heads, head_dim = query_pe.shape
query_pe = (
query_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
batch_size, heads, head_dim = key_pe.shape
key_pe = (
key_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
# Return value is unused; rotary_emb presumably rotates query_pe/key_pe in place.
self.rotary_emb(query_pe, key_pe, cos, sin)
# Cache entry per token: [normed KV latent | rotary key], with a unit head dim.
latent_vec_k = torch.concat(
(kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
)
latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))
kv_cache.store(
key=latent_vec_k,
value=None,
slots=slots,
kv_scales=self.kv_scales,
)
if cu_seqlen_prefill is not None:
# Expand the latent into full per-head keys (non-rotary part) and values.
kv = self.kv_b_proj(kv_c_normed).view(
-1,
self.num_key_value_heads,
self.qk_nope_head_dim + self.value_head_size,
)
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
)
# query_pe was rebound above, so the rotated rope part is copied back into query.
query[..., self.qk_nope_head_dim :] = query_pe
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
# key_pe has a single head and broadcasts across all heads here.
key[..., self.qk_nope_head_dim :] = key_pe
# We need to pad the heads because Flash Attention does not support
# qk and v with different head sizes.
query = torch.nn.functional.pad(
query, (0, self.head_pad_size - self.head_size), value=0
)
key = torch.nn.functional.pad(
key, (0, self.head_pad_size - self.head_size), value=0
)
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
)
# Drop the padding added to value before the output projection.
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
else:
# Decode
# Query laid out as [W_UK-absorbed non-rotary part | rotary part] to match
# the cached [latent | rotary key] layout.
query = torch.cat([query_nope, query_pe], dim=-1)
attn_output = paged_attention_mla(
query,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
kv_lora_rank=self.kv_lora_rank,
)
attn_output = self._v_up_proj_and_o_proj(attn_output)
return attn_output
class DeepseekV3MLP(nn.Module):
    """Gated SiLU feed-forward block: fused gate/up projection, then a down projection.

    Used as the dense MLP of DeepSeek-V3 layers and as the shared-expert MLP
    inside ``DeepseekV3MoE``.

    Args:
        prefix: Weight-name prefix; ``gate_proj``, ``up_proj`` and ``down_proj``
            are loaded from under it.
        config: Model config; ``hidden_act`` and ``quantize`` are read.
        weights: Weight loader; its ``process_group`` defines the tensor-parallel split.
        intermediate_size: Unsharded intermediate width; it is divided by the
            tensor-parallel world size to get the per-rank width.

    Raises:
        NotImplementedError: if ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(self, prefix: str, config, weights, intermediate_size: int):
        super().__init__()
        self.hidden_act = config.hidden_act
        if self.hidden_act != "silu":
            # Bail out because MoE only supports silu.
            raise NotImplementedError(
                "Currently only `silu` is supported as an activation for Deepseek V3."
            )
        self.act = ACT2FN[self.hidden_act]
        # gate_proj and up_proj are loaded as a single fused column-parallel linear.
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        # Per-rank width of each half (gate / up) of the fused projection output.
        self.intermediate_size = intermediate_size // weights.process_group.size()
        # TODO: This is a hotfix to be removed & properly refactored.
        self.quantize = config.quantize

    def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
        """Compute ``down_proj(act(gate) * up)``.

        ``reduce`` is forwarded to ``down_proj``. ``DeepseekV3MoE`` passes
        False here and all-reduces the combined output itself; presumably
        False skips the row-parallel reduction (confirm against
        ``TensorParallelRowLinear``).
        """
        gate_up_states = self.gate_up_proj(hidden_states)
        # (tokens, 2 * I) -> (tokens, 2, I): index 0 is gate, index 1 is up.
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(
            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
        )
class DeepseekV3MoE(nn.Module):
    """Mixture-of-experts MLP: routed experts plus optional always-on shared experts.

    The routed part is executed by ``moe_layer_cls``. The shared experts (if
    configured) run without their own reduction, and a single all-reduce over
    the tensor-parallel group combines the summed result.
    """

    def __init__(
        self,
        prefix,
        config: DeepseekV3Config,
        moe_layer_cls: Type[MoELayer],
        weights,
    ):
        super().__init__()
        tp_size = weights.process_group.size()
        self.hidden_dim = config.hidden_size
        self.moe_intermediate_size = config.moe_intermediate_size // tp_size
        self.routed_scaling_factor = config.routed_scaling_factor

        # Router: one logit per routed expert.
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
        # Only the "noaux_tc" top-k method carries a per-expert correction bias.
        self.gate.e_score_correction_bias = (
            torch.zeros(config.n_routed_experts, device=weights.device)
            if config.topk_method == "noaux_tc"
            else None
        )

        self.moe_layer = moe_layer_cls(
            prefix=f"{prefix}.experts",
            n_experts=config.n_routed_experts,
            n_expert_group=config.n_group,
            renormalize=config.norm_topk_prob,
            topk=config.num_experts_per_tok,
            topk_group=config.topk_group,
            weights=weights,
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias,
        )
        assert isinstance(self.moe_layer, MoELayer)

        n_shared = config.n_shared_experts
        self.shared_experts = (
            None
            if n_shared is None
            else DeepseekV3MLP(
                prefix=f"{prefix}.shared_experts",
                config=config,
                weights=weights,
                intermediate_size=config.moe_intermediate_size * n_shared,
            )
        )
        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route ``x`` through the experts and return a tensor shaped like ``x``."""
        # reduce=False: the one all-reduce below covers shared + routed outputs.
        shared_output = (
            None if self.shared_experts is None else self.shared_experts(x, reduce=False)
        )
        out = self.moe_layer(x, gating_output=self.gate(x))
        if shared_output is not None:
            out = out + shared_output
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out.view(*x.shape)
class DeepseekV3Layer(nn.Module):
    """One DeepSeek-V3 decoder layer: self-attention followed by a dense or MoE MLP.

    A layer gets the MoE MLP when routed experts are configured, it lies at or
    after ``first_k_dense_replace``, and its index is a multiple of
    ``moe_layer_freq``. Otherwise it gets a dense ``DeepseekV3MLP``.
    """

    def __init__(self, prefix, layer_id, config, weights, rotary_emb):
        super().__init__()
        layer_prefix = f"{prefix}.layers.{layer_id}"
        self.self_attn = DeepseekV3Attention(
            prefix=f"{layer_prefix}.self_attn",
            config=config,
            weights=weights,
            rotary_emb=rotary_emb,
        )

        use_moe = (
            config.n_routed_experts is not None
            and layer_id >= config.first_k_dense_replace
            and layer_id % config.moe_layer_freq == 0
        )
        if use_moe:
            # Prefer the sparse kernel when the loaded weights support it.
            if SparseMoELayer.is_supported(weights):
                moe_layer_cls = SparseMoELayer
            else:
                moe_layer_cls = DenseMoELayer
            self.mlp = DeepseekV3MoE(
                f"{layer_prefix}.mlp", config, moe_layer_cls, weights
            )
        else:
            self.mlp = DeepseekV3MLP(
                prefix=f"{layer_prefix}.mlp",
                config=config,
                weights=weights,
                intermediate_size=config.intermediate_size,
            )

        eps = config.rms_norm_eps
        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{layer_prefix}.input_layernorm", weights=weights, eps=eps
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{layer_prefix}.post_attention_layernorm",
            weights=weights,
            eps=eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cu_seqlen_prefill: torch.Tensor,
        kv_cache,
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ):
        """Return ``(mlp_output, residual)`` for the next layer."""
        # Each norm takes the running residual and returns (normed, new_residual).
        normed, residual = self.input_layernorm(hidden_states, residual)
        attn_output = self.self_attn(
            normed,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        normed, residual = self.post_attention_layernorm(attn_output, residual)
        mlp_output = self.mlp(normed)
        return mlp_output, residual
class DeepseekV3Model(torch.nn.Module):
    """Token embedding, a stack of ``DeepseekV3Layer`` modules, and a final RMSNorm."""

    def __init__(self, prefix: str, config, weights: Weights):
        super().__init__()
        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        # A single rotary embedding (over the rope part of the head) shared by all layers.
        rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=config.qk_rope_head_dim,
            base=config.rope_theta,
            device=weights.device,
        )
        decoder_layers = []
        for idx in range(config.num_hidden_layers):
            decoder_layers.append(
                DeepseekV3Layer(prefix, idx, config, weights, rotary_emb)
            )
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        first_attn = self.layers[0].self_attn
        self.head_size = first_attn.head_size
        self.num_heads = first_attn.num_heads
        self.num_key_value_heads = first_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Run all layers and return the final normalized hidden states."""
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_tokens(input_ids)
        # cos/sin are computed once here rather than inside every layer.
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
        residual = None

        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[idx],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                # mark_step between layers in HPU lazy mode.
                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class FlashDeepseekV3ForCausalLM(torch.nn.Module):
    """DeepSeek-V3 causal LM: ``DeepseekV3Model`` followed by a speculative LM head."""

    def __init__(self, prefix: str, config, weights: Weights):
        super().__init__()
        # An empty prefix means the checkpoint keys start at "model." / "lm_head.".
        model_prefix = f"{prefix}.model" if prefix else "model"
        head_prefix = f"{prefix}.lm_head" if prefix else "lm_head"
        self.model = DeepseekV3Model(model_prefix, config, weights)
        self.lm_head = SpeculativeHead.load(
            config,
            prefix=head_prefix,
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, speculative_logits)``; ``adapter_data`` is accepted but unused."""
        hidden = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        # Optionally keep only the rows whose logits are needed.
        if lm_head_indices is not None:
            hidden = hidden[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden)
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
================================================
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class Gemma2Config(PretrainedConfig):
    """Configuration container for Gemma 2 models.

    The explicit arguments are stored as attributes. ``pad_token_id``,
    ``bos_token_id``, ``eos_token_id``, ``tie_word_embeddings`` and any extra
    ``**kwargs`` are forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): fields read by the modeling code but not declared here
    (e.g. ``query_pre_attn_scalar``, ``sliding_window``,
    ``hidden_activation``) presumably arrive through ``**kwargs``; confirm
    against the checkpoint's config.
    """

    def __init__(
        self,
        vocab_size=256128,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # Backward compatibility: a missing KV-head count means one KV head per query head.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Vocabulary and model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.head_dim = head_dim

        # Attention.
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # Positions / rotary embedding.
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # Activation, normalization, initialization, caching.
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range
        self.use_cache = use_cache

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
class Gemma2FastRMSNorm(FastRMSNorm):
    """RMSNorm variant whose stored scale is ``checkpoint_weight + 1``.

    The scale is loaded as float32, normalization runs in float32, and the
    result is cast back to the loader's original dtype.
    """

    @classmethod
    def load(cls, prefix: str, weights, eps=1e-6):
        """Build the norm from ``{prefix}.weight``, temporarily loading in float32."""
        saved_dtype = weights.dtype
        # Switch the loader to float32 just for this tensor, then restore it.
        weights.dtype = torch.float32
        scale = weights.get_tensor(f"{prefix}.weight") + 1
        weights.dtype = saved_dtype
        norm = cls(scale, eps)
        # dtype the forward pass casts its output back to.
        norm.dtype = saved_dtype
        return norm

    # perform the multiplication in full precision and downcast after
    def forward(self, hidden_states, residual=None):
        """Return ``(normed, residual)``.

        When ``residual`` is given, it is added to ``hidden_states`` in place,
        and that summed tensor is returned as the new residual.
        """
        if residual is not None:
            hidden_states += residual
        residual = hidden_states

        x = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        normed = x * inv_rms * self.weight
        return normed.to(self.dtype), residual
def load_attention(config, prefix: str, weights):
    """Load the fused q/k/v projection of one attention block.

    Uses ``_load_gqa`` when query and key/value head counts differ.
    Otherwise the three projections are loaded as one column-parallel linear
    without bias.
    """
    if config.num_attention_heads == config.num_key_value_heads:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.{name}" for name in ("q_proj", "k_proj", "v_proj")],
            dim=0,
            weights=weights,
            bias=False,
        )
    return _load_gqa(config, prefix, weights)
def _load_gqa(config, prefix: str, weights):
    """Load and validate the fused q/k/v weight for grouped-query attention.

    The three projections are fetched column-sharded. Unquantized weights are
    cast to the loader's dtype and device. The fused weight must have
    ``(num_heads + 2 * num_key_value_heads) * head_dim`` rows, using per-shard
    head counts, and ``hidden_size`` columns.

    Raises:
        AssertionError: if the head count is not divisible by the
            tensor-parallel world size, or the fused weight has an unexpected shape.
    """
    world_size = weights.process_group.size()
    assert config.num_attention_heads % world_size == 0
    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
    )
    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

    head_size = config.head_dim
    num_heads = config.num_attention_heads // world_size
    num_key_value_heads = config.num_key_value_heads // world_size
    # Report the same per-shard shape that is actually checked; the message
    # previously used the unsharded KV-head count.
    expected_shape = [
        (num_heads + 2 * num_key_value_heads) * head_size,
        config.hidden_size,
    ]
    actual_shape = list(weight.weight.shape)
    assert actual_shape == expected_shape, f"{actual_shape} != {expected_shape}"
    return TensorParallelColumnLinear(get_linear(weight, bias=None))
class FlashGemma2Attention(torch.nn.Module):
    """Gemma 2 self-attention with fused QKV, rotary embeddings and an optional sliding window.

    Prefill (``cu_seqlen_prefill`` set) calls ``attention`` on the freshly
    projected key/value. Decode calls ``paged_attention`` on the KV cache.
    Both receive the ``softcap`` and ``window_size_left`` configured here.

    Args:
        prefix: Weight-name prefix of this attention module.
        config: Model config (head counts, ``head_dim``, ``query_pre_attn_scalar``,
            ``attn_logit_softcapping``, ``sliding_window``).
        weights: Weight loader; ``process_group`` gives the tensor-parallel split.
        layer_id: Layer index, forwarded to the adapter-aware linear wrappers.
        causal: Stored on the instance; not read elsewhere in this class.
        is_sliding: If True, use ``config.sliding_window`` as the window size, else -1.
        rotary_emb: Rotary embedding module; its return value is ignored in
            ``forward``, so it is expected to modify query/key in place.

    Raises:
        ValueError: if ``num_attention_heads`` is not divisible by the TP world size.
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights,
        layer_id,
        causal: bool,
        is_sliding: bool,
        rotary_emb,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_size = config.head_dim
        self.causal = causal
        if is_sliding:
            self.window_size = config.sliding_window
        else:
            # -1 is passed as window_size_left; presumably "no window" (confirm in the kernels).
            self.window_size = -1
        self.rotary_emb = rotary_emb
        # Scale comes from config.query_pre_attn_scalar rather than head_size**-0.5.
        self.softmax_scale = config.query_pre_attn_scalar**-0.5
        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        # From here on, head counts are per tensor-parallel shard.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )
        # Attention-logit soft-capping value, passed to both attention paths.
        self.softcap = config.attn_logit_softcapping

        query_key_value = load_attention(config, prefix, weights)
        # NOTE(review): sizes use the unsharded head counts; presumably what
        # TensorParallelMultiAdapterLinear expects for adapter slicing (confirm).
        self.query_key_value = TensorParallelMultiAdapterLinear.load(
            query_key_value,
            layer_id,
            ["q_proj", "k_proj", "v_proj"],
            sizes=[
                self.head_size * config.num_attention_heads,
                self.head_size * config.num_key_value_heads,
                self.head_size * config.num_key_value_heads,
            ],
            process_group=weights.process_group,
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelAdapterRowLinear.load(
            o_proj,
            layer_id,
            "o_proj",
            process_group=weights.process_group,
        )
        # Maps each query head to its KV head: [0,0,..,1,1,..] with num_groups repeats each.
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        hpu_attention_meta,
    ):
        """Project, apply rotary, write K/V to the cache, attend, and apply o_proj."""
        qkv = self.query_key_value(hidden_states, adapter_data)
        # Split the fused projection into the query part and the packed key/value part.
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        # kv[:, 0] is the key and kv[:, 1] is the value.
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        # torch.select returns a view, so rotating it in place updates kv[:, 0].
        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        # Cache this step's (rotated) key and the value before attending.
        kv_cache.store(
            key=kv[:, 0],
            value=kv[:, 1],
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query,
                key=kv[:, 0],
                value=kv[:, 1],
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
                window_size_left=self.window_size,
                softcap=self.softcap,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                softcap=self.softcap,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
                window_size_left=self.window_size,
            )

        # Merge heads back into (tokens, num_heads * head_size) for the output projection.
        return self.o_proj(
            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
        )
class Gemma2MLP(nn.Module):
    """Gated feed-forward block with adapter-aware fused gate/up and down projections.

    Activation is ``config.hidden_activation``. Names containing "gelu" use
    ``torch.nn.functional.gelu``, with the tanh approximation for
    ``gelu_fast`` / ``gelu_pytorch_tanh``. Any other name is looked up in ``ACT2FN``.
    """

    def __init__(self, prefix, config, weights, layer_id):
        super().__init__()
        act = config.hidden_activation
        if "gelu" in act:
            approximate = (
                "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(x, approximate=approximate)
        else:
            self.act = ACT2FN[act]

        # Fuse gate and up proj
        fused_gate_up = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
            fused_gate_up,
            layer_id,
            ["gate_proj", "up_proj"],
            sizes=[config.intermediate_size, config.intermediate_size],
            process_group=weights.process_group,
        )

        row_down = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.down_proj = TensorParallelAdapterRowLinear.load(
            row_down,
            layer_id,
            "down_proj",
            process_group=weights.process_group,
        )
        # Per-rank width of each half of the fused gate/up output.
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states, adapter_data):
        """Return ``down_proj(act(gate) * up)``."""
        fused = self.gate_up_proj(hidden_states, adapter_data)
        # (tokens, 2 * I) -> (tokens, 2, I): index 0 = gate, index 1 = up.
        fused = fused.view(-1, 2, self.intermediate_size)
        gate, up = fused[:, 0], fused[:, 1]
        return self.down_proj(self.act(gate) * up, adapter_data)
class FlashGemma2Layer(nn.Module):
    """Gemma 2 decoder layer.

    It normalizes before and after both the attention and the MLP sub-blocks.
    ``forward`` returns ``(mlp_output, residual)``; the next layer's input norm
    adds the two back together.
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights,
        layer_id,
        causal: bool,
        is_sliding: bool,
        rotary_emb,
    ):
        super().__init__()
        self.self_attn = FlashGemma2Attention(
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
            layer_id=layer_id,
            causal=causal,
            is_sliding=is_sliding,
            rotary_emb=rotary_emb,
        )
        self.mlp = Gemma2MLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
        )

        def load_norm(name):
            return Gemma2FastRMSNorm.load(
                prefix=f"{prefix}.{name}", weights=weights, eps=config.rms_norm_eps
            )

        self.input_layernorm = load_norm("input_layernorm")
        self.post_attention_layernorm = load_norm("post_attention_layernorm")
        self.pre_feedforward_layernorm = load_norm("pre_feedforward_layernorm")
        self.post_feedforward_layernorm = load_norm("post_feedforward_layernorm")

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        hpu_attention_meta,
    ):
        """Run attention and MLP; return ``(mlp_output, residual_after_attention)``."""
        # The input norm folds the incoming residual in and returns the summed stream.
        attn_input, stream = self.input_layernorm(hidden_states, residual)
        attn_output = self.self_attn(
            attn_input,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data,
            hpu_attention_meta,
        )
        # The attention output is normalized *before* being added to the stream.
        attn_normed, _ = self.post_attention_layernorm(attn_output)
        stream = attn_normed + stream

        mlp_input, _ = self.pre_feedforward_layernorm(stream)
        mlp_output = self.mlp(mlp_input, adapter_data)
        mlp_normed, _ = self.post_feedforward_layernorm(mlp_output)
        return mlp_normed, stream
class FlashGemma2Model(torch.nn.Module):
    """Stack of ``FlashGemma2Layer`` modules with a shared rotary embedding and a final norm.

    It takes already-embedded inputs. Even-indexed layers use sliding-window
    attention and odd-indexed layers use global attention.
    """

    def __init__(self, prefix: str, config, weights, causal: bool):
        super().__init__()
        group = weights.process_group
        self.tp_rank = group.rank()
        self.tp_world_size = group.size()

        # A single rotary embedding instance shared by every layer.
        shared_rotary = PositionRotaryEmbedding.static(
            config=config,
            dim=config.head_dim,
            base=config.rope_theta,
            device=weights.device,
        )
        self.layers = nn.ModuleList(
            FlashGemma2Layer(
                prefix=f"{prefix}.layers.{idx}",
                config=config,
                weights=weights,
                layer_id=idx,
                causal=causal,
                is_sliding=(idx % 2 == 0),
                rotary_emb=shared_rotary,
            )
            for idx in range(config.num_hidden_layers)
        )
        self.norm = Gemma2FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        first_attn = self.layers[0].self_attn
        self.head_size = first_attn.head_size
        self.num_heads = first_attn.num_heads
        self.num_key_value_heads = first_attn.num_key_value_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        adapter_data: Optional[torch.Tensor],
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Run all layers over ``inputs_embeds`` and return the final normalized states."""
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )
        hidden_states = inputs_embeds
        # cos/sin are computed once here instead of inside each layer.
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
        residual = None

        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[idx],
                slots,
                seqlen,
                adapter_data,
                hpu_attention_meta,
            )
            if lazy_mode:
                # mark_step between layers in HPU lazy mode.
                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class FlashGemma2ForCausalLM(torch.nn.Module):
    """Gemma 2 causal LM with scaled embeddings and soft-capped final logits.

    The embedding table is multiplied by ``sqrt(hidden_size)`` once at load
    time. Logits are mapped to ``softcap * tanh(logits / softcap)``, using
    ``config.final_logit_softcapping``; speculative logits are returned unchanged.
    """

    def __init__(self, prefix: str, config, weights, *, causal: bool = True):
        super().__init__()
        model_prefix = f"{prefix}.model" if prefix else "model"

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{model_prefix}.embed_tokens", weights=weights
        )
        # Scale the embedding table in place so lookups need no extra multiply.
        self.embed_tokens.weight *= config.hidden_size**0.5

        self.model = FlashGemma2Model(
            prefix=model_prefix, config=config, weights=weights, causal=causal
        )
        # Tied models reuse the embedding weights for the LM head.
        head_name = "embed_tokens" if config.tie_word_embeddings else "lm_head"
        self.lm_head = SpeculativeHead.load(
            prefix=f"{model_prefix}.{head_name}",
            config=config,
            weights=weights,
        )
        self.softcap = config.final_logit_softcapping
        assert isinstance(self.softcap, float)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(soft_capped_logits, speculative_logits)``."""
        hidden = self.model(
            self.embed_tokens(input_ids),
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden = hidden[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden)

        # softcap * tanh(logits / softcap); the division mutates the head output in place.
        logits.div_(self.softcap)
        logits = torch.tanh(logits).mul_(self.softcap)
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py
================================================
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from typing import Optional, List, Tuple
import copy
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
#
SpeculativeHead,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
import torch
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight
from transformers.activations import ACT2FN
from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
set_block_mapping,
HPUPagedAttentionMetadata,
)
import habana_frameworks.torch as htorch
ATTENTION_TYPE_GLOBAL = "global"
ATTENTION_TYPE_LOCAL = "local_sliding"
class Gemma3FastRMSNorm(FastRMSNorm):
    """RMSNorm variant storing ``checkpoint_weight + 1`` in float32.

    Normalization runs in float32, and the output is cast back to the
    loader's original dtype.
    """

    @classmethod
    def load(cls, prefix: str, weights, eps=1e-6):
        """Create the norm from ``{prefix}.weight`` loaded with the loader set to float32."""
        previous_dtype = weights.dtype
        weights.dtype = torch.float32
        try_weight = weights.get_tensor(f"{prefix}.weight")
        weights.dtype = previous_dtype
        instance = cls(try_weight + 1, eps)
        instance.dtype = previous_dtype
        return instance

    # perform the multiplication in full precision and downcast after
    def forward(self, hidden_states, residual=None):
        """Return ``(normed, residual)``.

        If ``residual`` is given, it is added into ``hidden_states`` in place
        and that sum becomes the returned residual.
        """
        if residual is not None:
            hidden_states += residual
        residual = hidden_states

        as_fp32 = hidden_states.to(torch.float32)
        mean_sq = as_fp32.pow(2).mean(-1, keepdim=True)
        scaled = as_fp32 * torch.rsqrt(mean_sq + self.variance_epsilon)
        scaled = scaled * self.weight
        return scaled.to(self.dtype), residual
def load_attention(config, prefix: str, weights):
    """Load the fused q/k/v projection of one attention block.

    Grouped-query configs (query heads != KV heads) are delegated to
    ``_load_gqa``. Otherwise q/k/v are loaded as one bias-free
    column-parallel linear.
    """
    is_gqa = config.num_attention_heads != config.num_key_value_heads
    if is_gqa:
        return _load_gqa(config, prefix, weights)
    projection_names = ("q_proj", "k_proj", "v_proj")
    return TensorParallelColumnLinear.load_multi(
        config,
        prefixes=[f"{prefix}.{name}" for name in projection_names],
        dim=0,
        weights=weights,
        bias=False,
    )
def _load_gqa(config, prefix: str, weights):
    """Load and validate the fused q/k/v weight for grouped-query attention.

    Projections are fetched column-sharded. Unquantized weights are cast to
    the loader's dtype and device. The fused weight must be
    ``[(num_heads + 2 * num_key_value_heads) * head_dim, hidden_size]``,
    using per-shard head counts.

    Raises:
        AssertionError: if heads are not divisible by the tensor-parallel
            world size, or the fused weight has an unexpected shape.
    """
    world_size = weights.process_group.size()
    assert config.num_attention_heads % world_size == 0
    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
    )
    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

    head_size = config.head_dim
    num_heads = config.num_attention_heads // world_size
    num_key_value_heads = config.num_key_value_heads // world_size
    # The failure message must show the same per-shard shape that is checked;
    # it previously used the unsharded KV-head count.
    expected_shape = [
        (num_heads + 2 * num_key_value_heads) * head_size,
        config.hidden_size,
    ]
    actual_shape = list(weight.weight.shape)
    assert actual_shape == expected_shape, f"{actual_shape} != {expected_shape}"
    return TensorParallelColumnLinear(get_linear(weight, bias=None))
class FlashGemma3Attention(torch.nn.Module):
    """Gemma 3 self-attention: fused QKV, per-head q/k RMSNorm, rotary, optional sliding window.

    Sliding layers use ``local_rotary_emb`` and ``config.sliding_window``.
    Global layers use ``global_rotary_emb`` and a window of -1. Prefill calls
    ``attention`` on the fresh key/value; decode calls ``paged_attention`` on
    the KV cache.

    Args:
        prefix: Weight-name prefix of this attention module.
        config: Model config (head counts, ``head_dim``, ``query_pre_attn_scalar``,
            ``sliding_window``, ``rms_norm_eps``).
        weights: Weight loader; ``process_group`` gives the tensor-parallel split.
        layer_id: Layer index, forwarded to the adapter-aware linear wrappers.
        causal: Stored on the instance; not read elsewhere in this class.
        is_sliding: Selects the local rotary embedding and the sliding window.
        local_rotary_emb / global_rotary_emb: Rotary modules. The return value
            is ignored in ``forward``, so they are expected to rotate in place.

    Raises:
        ValueError: if ``num_attention_heads`` is not divisible by the TP world size.
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights,
        layer_id,
        causal: bool,
        is_sliding: bool,
        local_rotary_emb,
        global_rotary_emb,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_size = config.head_dim
        self.causal = causal
        if is_sliding:
            self.window_size = config.sliding_window
            self.rotary_emb = local_rotary_emb
        else:
            # -1 is passed as window_size_left; presumably "no window" (confirm in the kernels).
            self.window_size = -1
            self.rotary_emb = global_rotary_emb

        # None when query_pre_attn_scalar is unset; presumably the attention
        # kernels then fall back to their default scale (confirm).
        self.softmax_scale = (
            config.query_pre_attn_scalar**-0.5
            if config.query_pre_attn_scalar is not None
            else None
        )
        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        # From here on, head counts are per tensor-parallel shard.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )
        # Logit soft-capping is disabled here (the config value is intentionally not used).
        self.softcap = None  # config.attn_logit_softcapping

        query_key_value = load_attention(config, prefix, weights)
        # NOTE(review): sizes use unsharded head counts; presumably what
        # TensorParallelMultiAdapterLinear expects for adapter slicing (confirm).
        self.query_key_value = TensorParallelMultiAdapterLinear.load(
            query_key_value,
            layer_id,
            ["q_proj", "k_proj", "v_proj"],
            sizes=[
                self.head_size * config.num_attention_heads,
                self.head_size * config.num_key_value_heads,
                self.head_size * config.num_key_value_heads,
            ],
            process_group=weights.process_group,
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelAdapterRowLinear.load(
            o_proj,
            layer_id,
            "o_proj",
            process_group=weights.process_group,
        )
        # Maps each query head to its KV head (each KV head repeated num_groups times).
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)
        # RMSNorms over head_dim, applied to every query / key head vector.
        self.q_norm = Gemma3FastRMSNorm.load(
            prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.k_norm = Gemma3FastRMSNorm.load(
            prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
        )
        # Not read anywhere in this class's visible code.
        self.enable_gqa = self.num_heads != self.num_key_value_heads

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        hpu_attention_meta,
    ):
        """Project, normalize q/k per head, rotate, cache K/V, attend, apply o_proj."""
        qkv = self.query_key_value(hidden_states, adapter_data)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )

        # kv[:, 0] is the key and kv[:, 1] is the value, each flattened over heads.
        kv = kv.view(-1, 2, self.num_key_value_heads * self.head_size)
        key = kv[:, 0]
        value = kv[:, 1]

        # One row per head vector, so the norms (which reduce over the last dim)
        # normalize each head independently.
        query = query.reshape(-1, self.head_size)
        key = key.reshape(-1, self.head_size)

        query, _ = self.q_norm(query.contiguous())
        key, _ = self.k_norm(key.contiguous())

        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_key_value_heads, self.head_size)
        value = value.view(-1, self.num_key_value_heads, self.head_size)

        # Rotary result is not reassigned: query/key are expected to be rotated in place.
        self.rotary_emb(query, key, cos, sin)

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
                window_size_left=self.window_size,
                softcap=self.softcap,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                softcap=self.softcap,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
                window_size_left=self.window_size,
            )

        # Merge heads back into (tokens, num_heads * head_size) for the output projection.
        return self.o_proj(
            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
        )
class Gemma3MLP(nn.Module):
    """Gated feed-forward block with adapter-aware fused gate/up and down projections.

    ``config.hidden_activation`` names containing "gelu" use
    ``torch.nn.functional.gelu``, with ``approximate="tanh"`` for
    ``gelu_fast`` / ``gelu_pytorch_tanh`` and ``"none"`` otherwise. Other
    names come from ``ACT2FN``.
    """

    def __init__(self, prefix, config, weights, layer_id):
        super().__init__()
        activation_name = config.hidden_activation
        if "gelu" not in activation_name:
            self.act = ACT2FN[activation_name]
        else:
            use_tanh = activation_name in ["gelu_fast", "gelu_pytorch_tanh"]
            mode = "tanh" if use_tanh else "none"
            self.act = lambda x: torch.nn.functional.gelu(x, approximate=mode)

        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
            TensorParallelColumnLinear.load_multi(
                config,
                prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                weights=weights,
                dim=0,
                bias=False,
            ),
            layer_id,
            ["gate_proj", "up_proj"],
            sizes=[config.intermediate_size, config.intermediate_size],
            process_group=weights.process_group,
        )
        self.down_proj = TensorParallelAdapterRowLinear.load(
            TensorParallelRowLinear.load(
                config,
                prefix=f"{prefix}.down_proj",
                weights=weights,
                bias=False,
            ),
            layer_id,
            "down_proj",
            process_group=weights.process_group,
        )
        # Per-rank width of each half of the fused gate/up output.
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states, adapter_data):
        """Return ``down_proj(act(gate) * up)``."""
        halves = self.gate_up_proj(hidden_states, adapter_data).view(
            -1, 2, self.intermediate_size
        )
        gated = self.act(halves[:, 0]) * halves[:, 1]
        return self.down_proj(gated, adapter_data)
class FlashGemma3Layer(nn.Module):
    """Gemma 3 decoder layer: norms before and after both the attention and the MLP.

    ``forward`` returns ``(mlp_output, residual)``; the next layer's input norm
    adds them together.
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights,
        layer_id,
        causal: bool,
        is_sliding: bool,
        local_rotary_emb,
        global_rotary_emb,
    ):
        super().__init__()
        self.self_attn = FlashGemma3Attention(
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
            layer_id=layer_id,
            causal=causal,
            is_sliding=is_sliding,
            local_rotary_emb=local_rotary_emb,
            global_rotary_emb=global_rotary_emb,
        )
        self.mlp = Gemma3MLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
        )

        norm_names = (
            "input_layernorm",
            "post_attention_layernorm",
            "pre_feedforward_layernorm",
            "post_feedforward_layernorm",
        )
        for norm_name in norm_names:
            setattr(
                self,
                norm_name,
                Gemma3FastRMSNorm.load(
                    prefix=f"{prefix}.{norm_name}",
                    weights=weights,
                    eps=config.rms_norm_eps,
                ),
            )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        hpu_attention_meta,
    ):
        """Run attention and MLP; return ``(mlp_output, residual_after_attention)``."""
        normed_in, skip = self.input_layernorm(hidden_states, residual)
        attn_out = self.self_attn(
            normed_in,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data,
            hpu_attention_meta,
        )
        # Normalize the attention output first, then add it to the skip stream.
        post_attn, _ = self.post_attention_layernorm(attn_out)
        skip = post_attn + skip

        ff_in, _ = self.pre_feedforward_layernorm(skip)
        ff_out, _ = self.post_feedforward_layernorm(self.mlp(ff_in, adapter_data))
        return ff_out, skip
class FlashGemma3Model(torch.nn.Module):
    """Gemma3 decoder stack with interleaved sliding-window and global layers.

    Two rotary embeddings are built:
    * a "local" one using ``config.rope_local_base_freq`` and forced to
      ``rope_type="default"`` (no scaling), and
    * a "global" one using ``config.rope_theta`` and the model's own ``rope_scaling``.

    Layer ``layer_id`` is sliding-window unless ``layer_id + 1`` is a multiple of
    ``config.sliding_window_pattern``.
    """

    def __init__(self, prefix: str, config, weights, causal: bool):
        super().__init__()
        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()

        # The local embedding must not inherit the global rope_scaling, so copy the
        # config instead of mutating the shared one.
        local_config = copy.deepcopy(config)
        local_config.rope_scaling = dict(rope_type="default")
        local_rotary_emb = PositionRotaryEmbedding.static(
            config=local_config,
            dim=config.head_dim,
            base=config.rope_local_base_freq,
            device=weights.device,
        )
        global_rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=config.head_dim,
            base=config.rope_theta,
            device=weights.device,
        )

        self.layers = nn.ModuleList(
            [
                FlashGemma3Layer(
                    prefix=f"{prefix}.layers.{layer_id}",
                    config=config,
                    weights=weights,
                    layer_id=layer_id,
                    causal=causal,
                    is_sliding=bool((layer_id + 1) % config.sliding_window_pattern),
                    local_rotary_emb=local_rotary_emb,
                    global_rotary_emb=global_rotary_emb,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = Gemma3FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        adapter_data: Optional[torch.Tensor],
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Run all decoder layers over ``inputs_embeds`` and apply the final norm.

        ``kv_cache[i]`` is handed to layer ``i``. Returns the normed hidden states.
        """
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )

        hidden_states = inputs_embeds
        residual = None

        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()

        # Layers share rotary-embedding objects (this model builds two: local and
        # global). Compute cos/sin once per distinct object instead of per layer.
        # Keyed by id(); the objects stay alive through self.layers.
        cos_sin_by_emb = {}
        for i, layer in enumerate(self.layers):
            rotary_emb = layer.self_attn.rotary_emb
            key = id(rotary_emb)
            if key not in cos_sin_by_emb:
                cos_sin_by_emb[key] = rotary_emb.get_cos_sin(position_ids)
            cos, sin = cos_sin_by_emb[key]

            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                adapter_data,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class FlashGemma3ForCausalLM(torch.nn.Module):
    """Gemma3 text-only causal LM: scaled token embeddings, decoder stack, LM head.

    Weights are read under ``model`` (or ``{prefix}.model``). When
    ``config.tie_word_embeddings`` is set, the LM head reuses ``embed_tokens``.
    """

    def __init__(self, prefix: str, config, weights, *, causal: bool = True):
        super().__init__()
        embed_scale = config.hidden_size**0.5
        prefix = f"{prefix}.model" if prefix else "model"

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        # Fold the sqrt(hidden_size) embedding scale into the table once, at load time.
        self.embed_tokens.weight *= embed_scale

        self.model = FlashGemma3Model(
            prefix=prefix, config=config, weights=weights, causal=causal
        )

        head_prefix = (
            f"{prefix}.embed_tokens"
            if config.tie_word_embeddings
            else f"{prefix}.lm_head"
        )
        self.lm_head = SpeculativeHead.load(
            prefix=head_prefix,
            config=config,
            weights=weights,
        )

        # Final-logit soft-capping is disabled for this model (the value is not read
        # from config.attn_logit_softcapping).
        self.softcap = None

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Embed ``input_ids``, run the decoder and return ``(logits, speculative_logits)``.

        When ``lm_head_indices`` is given, only those rows of the hidden states go
        through the LM head.
        """
        hidden = self.model(
            self.embed_tokens(input_ids),
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden = hidden[lm_head_indices]
        return self.lm_head(hidden)
class Gemma3MultimodalInputProjection(torch.nn.Module):
    """Pool vision-tower patch features and project them into the text embedding space.

    The ``patches_per_image x patches_per_image`` patch grid is average-pooled down to
    ``tokens_per_side x tokens_per_side`` tokens, RMS-normed, and multiplied by
    ``mm_input_projection_weight``.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        # Resolve the projection weight relative to `prefix`, like the norm below.
        # The only caller passes prefix="multi_modal_projector", so the loaded key
        # ("multi_modal_projector.mm_input_projection_weight") is unchanged.
        self.mm_input_projection_weight = weights.get_tensor(
            f"{prefix}.mm_input_projection_weight"
        )
        self.mm_soft_emb_norm = Gemma3FastRMSNorm.load(
            prefix=f"{prefix}.mm_soft_emb_norm",
            weights=weights,
            eps=config.vision_config.layer_norm_eps,
        )

        # Patch grid side length and pooled token grid side length.
        self.patches_per_image = int(
            config.vision_config.image_size // config.vision_config.patch_size
        )
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(
            kernel_size=self.kernel_size, stride=self.kernel_size
        )

    def forward(self, vision_outputs: torch.Tensor):
        """Project ``vision_outputs`` of shape (batch, num_patches, hidden).

        ``num_patches`` must equal ``patches_per_image ** 2`` for the reshape below.
        Returns (batch, tokens_per_side ** 2, proj_dim) in the input's dtype.
        """
        batch_size, _, hidden_size = vision_outputs.shape

        # (B, P, H) -> (B, H, side, side) so AvgPool2d pools over the patch grid.
        patch_grid = (
            vision_outputs.transpose(1, 2)
            .reshape(
                batch_size, hidden_size, self.patches_per_image, self.patches_per_image
            )
            .contiguous()
        )

        # (B, H, t, t) -> (B, t*t, H)
        pooled = self.avg_pool(patch_grid).flatten(2).transpose(1, 2)

        normed, _ = self.mm_soft_emb_norm(pooled)
        projected = torch.matmul(normed, self.mm_input_projection_weight)
        return projected.type_as(vision_outputs)
class Gemma3ForConditionalGeneration(nn.Module):
    """Gemma3 text model with an optional vision tower.

    When ``config.vision_config`` is set, images go through the vision tower, a
    post layer norm and ``Gemma3MultimodalInputProjection``. The resulting features
    replace the embeddings of tokens equal to ``config.image_token_index``.
    Otherwise only the text model is loaded.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config

        if config.vision_config is not None:
            config.vision_config.quantize = config.quantize

            # NOTE(review): these two loaders ignore `prefix` (unlike the vision and
            # text towers below), so their paths only line up with the towers' paths
            # when prefix is empty — confirm.
            self.post_vision_model_layernorm = nn.LayerNorm.load(
                prefix="vision_tower.vision_model.post_layernorm",
                weights=weights,
                eps=config.vision_config.layer_norm_eps,
            )
            self.multimodal_projector = Gemma3MultimodalInputProjection(
                prefix="multi_modal_projector",
                config=config,
                weights=weights,
            )

            text_config = config.text_config
            text_config.speculator = config.speculator
            text_config.quantize = config.quantize

            self.vision_model = load_vision_model(
                prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
                config=config.vision_config,
                weights=weights,
            )
            self.text_model = load_text_model(
                prefix="language_model" if not prefix else f"{prefix}.language_model",
                config=config.text_config,
                weights=weights,
            )
        else:
            config.text_config.quantize = config.quantize
            config.text_config.speculator = config.speculator
            self.text_model = load_text_model(
                prefix=prefix,
                config=config.text_config,
                weights=weights,
            )

        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )
        self.dtype = weights.dtype

    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Encode ``pixel_values`` into a flat (num_image_tokens, dim) feature matrix.

        Only ``pixel_values`` is used; the other arguments exist for interface
        compatibility. Requires the vision tower (``config.vision_config`` set).
        """
        pixel_values = pixel_values.to(dtype=self.dtype)
        image_outputs = self.vision_model(pixel_values)
        vision_outputs = self.post_vision_model_layernorm(
            image_outputs.last_hidden_state
        )
        image_features = self.multimodal_projector(vision_outputs)
        image_features = image_features.view(-1, image_features.shape[-1])
        return image_features

    def get_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        vision_embeds: torch.Tensor = None,
    ):
        """Embed ``input_ids`` and, if given, overwrite image-token rows with ``vision_embeds``.

        The number of rows in ``vision_embeds`` (after flattening) must equal the
        number of ``config.image_token_index`` tokens in ``input_ids``.
        """
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if vision_embeds is not None:
            # Replace the image token embeddings with the vision features
            image_token_mask = (input_ids == self.config.image_token_index).to(
                input_ids.device
            )
            inputs_embeds[image_token_mask] = vision_embeds.view(
                -1, vision_embeds.shape[-1]
            )
        return inputs_embeds

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the text decoder on precomputed embeddings.

        Returns ``(logits, speculative_logits)``. ``attention_mask`` is accepted for
        interface compatibility but unused: the decoder takes no mask argument. The
        local sliding-window mask once derived from it here was never consumed, so
        it is no longer built.
        """
        if cu_seqlen_prefill is not None:
            # NOTE(review): shifts prefill positions by one *in place*, so the
            # caller's tensor is modified too — confirm this offset is intended.
            position_ids += 1

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=seqlen,
            hpu_attention_meta=hpu_attention_meta,
            adapter_data=adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
================================================
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class GemmaConfig(PretrainedConfig):
    """Configuration for Gemma decoder-only models.

    Architecture fields are stored as attributes. The token-id fields and
    ``tie_word_embeddings`` are forwarded to ``PretrainedConfig.__init__`` together
    with any extra keyword arguments.
    """

    def __init__(
        self,
        vocab_size=256128,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # Model dimensions.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers

        # Attention heads; configs without num_key_value_heads fall back to one KV
        # head per query head (kept for backward compatibility).
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )

        # Activation, init and normalization.
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache

        # Rotary embeddings and attention options.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
class GemmaFastRMSNorm(FastRMSNorm):
    """RMS norm that computes in float32, with the stored scale being ``weight + 1``."""

    @classmethod
    def load(cls, prefix: str, weights, eps=1e-6):
        """Load ``{prefix}.weight`` in float32 and build a norm scaled by ``weight + 1``.

        ``weights.dtype`` is switched to float32 only for the read and is restored
        even if loading raises, so a failure cannot leave the loader in float32.
        The returned module's ``dtype`` is the loader's original dtype, which
        ``forward`` casts its output to.
        """
        dtype = weights.dtype
        weights.dtype = torch.float32
        try:
            weight = weights.get_tensor(f"{prefix}.weight") + 1
        finally:
            weights.dtype = dtype
        new = cls(weight, eps)
        new.dtype = dtype
        return new

    # perform the multiplication in full precision and downcast after
    def forward(self, hidden_states, residual=None):
        """Normalize ``hidden_states`` (optionally after adding ``residual``).

        If ``residual`` is given it is added to ``hidden_states`` *in place*. The
        pre-norm sum (or the input itself when ``residual`` is None) is returned as
        the new residual. Returns ``(normed.to(self.dtype), residual)``.
        """
        if residual is not None:
            hidden_states += residual
        # Always carry the pre-norm activations forward as the next residual.
        residual = hidden_states

        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = hidden_states * self.weight
        return hidden_states.to(self.dtype), residual
def load_attention(config, prefix: str, weights):
    """Load the fused q/k/v projection for a Gemma attention block.

    With one KV head per query head the three projections are loaded via
    ``load_multi`` (no bias). Otherwise, for grouped-query attention, ``_load_gqa``
    is used.
    """
    if config.num_attention_heads == config.num_key_value_heads:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
    return _load_gqa(config, prefix, weights)
def _load_gqa(config, prefix: str, weights):
    """Load a fused, column-sharded q/k/v projection for grouped-query attention.

    For unquantized weights the tensor is cast to ``weights.dtype``, moved to
    ``weights.device`` and shape-checked against this shard's
    ``(num_heads + 2 * num_kv_heads) * head_dim`` rows. Other weight containers are
    passed through unchecked, since they are not guaranteed to expose a plain
    ``.weight`` tensor.
    """
    world_size = weights.process_group.size()
    assert config.num_attention_heads % world_size == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
    )

    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.head_dim
        num_heads = config.num_attention_heads // world_size
        num_key_value_heads = config.num_key_value_heads // world_size
        # Report the same (per-shard) expectation that is checked.
        expected = [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ]
        assert (
            list(weight.weight.shape) == expected
        ), f"{list(weight.weight.shape)} != {expected}"

    return TensorParallelColumnLinear(get_linear(weight, bias=None))
class FlashGemmaAttention(torch.nn.Module):
    """Gemma self-attention with a fused QKV projection and an HPU KV cache.

    Heads are sharded across the tensor-parallel group. Prefill uses ``attention``
    and decode uses ``paged_attention``.
    """

    def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_size = config.head_dim
        self.causal = causal
        self.rotary_emb = rotary_emb
        self.softmax_scale = self.head_size**-0.5

        world_size = weights.process_group.size()
        if self.num_heads % world_size != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {world_size})"
            )
        # Per-shard head counts.
        self.num_heads = self.num_heads // world_size
        self.num_key_value_heads = config.num_key_value_heads // world_size

        self.query_key_value = load_attention(config, prefix, weights)
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        # Maps each query head to the KV head it reads from (GQA grouping).
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Attend over ``hidden_states`` and return the ``o_proj`` output.

        ``cu_seqlen_prefill`` being non-None selects the prefill path.
        """
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        # kv[:, 0] is the key, kv[:, 1] the value.
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        # Return value is discarded, so the rotary op presumably updates query/key in place.
        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        kv_cache.store(
            key=kv[:, 0],
            value=kv[:, 1],
            slots=slots,
            kv_scales=self.kv_scales,
        )

        if cu_seqlen_prefill is not None:
            # Prefill (sdpa)
            attn_output = attention(
                query=query,
                key=kv[:, 0],
                value=kv[:, 1],
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
                causal=self.causal,
            )
        else:
            # Decode
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
class GemmaMLP(nn.Module):
    """Gated MLP: ``down_proj(act(gate) * up)`` with gate/up fused into one projection."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        act = config.hidden_act
        if "gelu" in act:
            # GELU variants go through torch directly; the "fast" and
            # "pytorch_tanh" flavours use the tanh approximation.
            def _gelu(x):
                return torch.nn.functional.gelu(
                    x,
                    approximate=(
                        "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                    ),
                )

            self.act = _gelu
        else:
            self.act = ACT2FN[act]

        # Gate and up projections are loaded as a single column-parallel matmul.
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        # Per-shard intermediate width.
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states):
        """Apply the gated MLP to ``hidden_states``."""
        gate_and_up = self.gate_up_proj(hidden_states).view(
            -1, 2, self.intermediate_size
        )
        gate, up = gate_and_up[:, 0], gate_and_up[:, 1]
        return self.down_proj(self.act(gate) * up)
class FlashGemmaLayer(nn.Module):
    """One Gemma decoder layer (pre-norm attention followed by a pre-norm MLP)."""

    def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb):
        super().__init__()
        eps = config.rms_norm_eps
        self.self_attn = FlashGemmaAttention(
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
            causal=causal,
            rotary_emb=rotary_emb,
        )
        self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.input_layernorm = GemmaFastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=eps
        )
        self.post_attention_layernorm = GemmaFastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=eps
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Run the layer; returns ``(mlp_output, residual_after_attention)``.

        Residual additions happen inside the norms: each norm adds its second
        argument and returns the sum as the new residual.
        """
        attn_in, residual_in = self.input_layernorm(hidden_states, residual)
        attn_out = self.self_attn(
            attn_in,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        # Fused "add residual + RMS norm" before the MLP.
        mlp_in, residual_out = self.post_attention_layernorm(attn_out, residual_in)
        return self.mlp(mlp_in), residual_out
class FlashGemmaModel(torch.nn.Module):
    """Stack of ``FlashGemmaLayer`` blocks sharing one rotary embedding, plus a final norm."""

    def __init__(self, prefix: str, config, weights, causal: bool):
        super().__init__()
        group = weights.process_group
        self.tp_rank = group.rank()
        self.tp_world_size = group.size()

        shared_rotary = PositionRotaryEmbedding.static(
            config=config,
            dim=config.head_dim,
            base=config.rope_theta,
            device=weights.device,
        )

        layer_list = []
        for layer_id in range(config.num_hidden_layers):
            layer_list.append(
                FlashGemmaLayer(
                    prefix=f"{prefix}.layers.{layer_id}",
                    config=config,
                    weights=weights,
                    causal=causal,
                    rotary_emb=shared_rotary,
                )
            )
        self.layers = nn.ModuleList(layer_list)

        self.norm = GemmaFastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        first_attn = self.layers[0].self_attn
        self.head_size = first_attn.head_size
        self.num_heads = first_attn.num_heads
        self.num_key_value_heads = first_attn.num_key_value_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        adapter_data: Optional[torch.Tensor],
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Run every layer (layer ``i`` gets ``kv_cache[i]``) and return the normed output.

        ``adapter_data`` is accepted for interface parity but not used here.
        """
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )

        hidden_states = inputs_embeds

        # All layers share the same rotary embedding, so cos/sin are computed once.
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()

        for idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[idx],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        normed, _ = self.norm(hidden_states, residual)
        return normed
class FlashGemmaForCausalLM(torch.nn.Module):
    """Gemma causal LM: scaled token embeddings, decoder stack and (optionally tied) LM head."""

    def __init__(self, prefix: str, config, weights, *, causal: bool = True):
        super().__init__()
        embed_scale = config.hidden_size**0.5
        prefix = f"{prefix}.model" if prefix else "model"

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        # Fold the sqrt(hidden_size) scale into the embedding table once.
        self.embed_tokens.weight *= embed_scale

        self.model = FlashGemmaModel(
            prefix=prefix, config=config, weights=weights, causal=causal
        )

        head_prefix = (
            f"{prefix}.embed_tokens"
            if config.tie_word_embeddings
            else f"{prefix}.lm_head"
        )
        self.lm_head = SpeculativeHead.load(
            prefix=head_prefix,
            config=config,
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, speculative_logits)`` for ``input_ids``.

        If ``lm_head_indices`` is given, only those rows are projected to logits.
        """
        hidden = self.model(
            self.embed_tokens(input_ids),
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden = hidden[lm_head_indices]
        return self.lm_head(hidden)
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
================================================
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
import habana_frameworks.torch as htorch
def load_qkv(config, prefix: str, weights, head_size, num_heads):
    """Load GPT-2's packed ``c_attn`` projection, dispatching on quantization.

    GPTQ checkpoints go through ``_load_qkv_gptq``. Everything else goes through
    ``_load_qkv``, which expects a transposed dense matrix.
    """
    if config.quantize != "gptq":
        return _load_qkv(config, prefix, weights, head_size, num_heads)
    return _load_qkv_gptq(config, prefix, weights)
def _load_qkv_gptq(config, prefix: str, weights):
    """Load a GPTQ-quantized packed ``c_attn`` (q|k|v) projection plus its bias.

    The weight is sharded by ``get_weights_col_packed_qkv`` with as many KV heads as
    query heads. The bias is stored as [q | k | v]; this rank's block is taken from
    each third and concatenated.
    """
    world_size = weights.process_group.size()
    rank = weights.process_group.rank()

    # Weights
    weight = weights.get_weights_col_packed_qkv(
        f"{prefix}.c_attn",
        config.num_attention_heads,
        config.num_attention_heads,
    )

    # Bias
    slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
    total_size = slice_.get_shape()[0]
    assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
    single_size = total_size // 3
    assert single_size % world_size == 0
    block_size = single_size // world_size
    start = rank * block_size
    stop = start + block_size
    bias = torch.cat(
        [slice_[start + i * single_size : stop + i * single_size] for i in range(3)],
        dim=0,
    )
    # Cast to the model dtype as well, matching `_load_qkv`; otherwise a bias stored
    # in a different dtype would be added to the layer output as-is.
    bias = bias.to(dtype=weights.dtype, device=weights.device)
    return TensorParallelColumnLinear(get_linear(weight, bias))
def _load_qkv(config, prefix: str, weights, head_size, num_heads):
    """Load QKV from a single, transposed matrix.

    In the stored ``c_attn.weight``, q, k and v are concatenated along dim 1. For
    each of the three parts this rank's column block is taken, the blocks are
    concatenated, and the result is transposed. The bias, stored as [q | k | v], is
    sharded the same way. Both are cast to ``weights.dtype`` and moved to
    ``weights.device``.
    """
    world_size = weights.process_group.size()
    rank = weights.process_group.rank()

    # Weights
    slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
    total_size = slice_.get_shape()[1]
    assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
    single_size = total_size // 3
    assert single_size % world_size == 0
    block_size = single_size // world_size
    start = rank * block_size
    stop = start + block_size
    weight = torch.cat(
        [slice_[:, start + i * single_size : stop + i * single_size] for i in range(3)],
        dim=1,
    ).T
    weight = weight.to(dtype=weights.dtype).to(device=weights.device)

    # Bias
    slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
    total_size = slice_.get_shape()[0]
    assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
    single_size = total_size // 3
    assert single_size % world_size == 0
    block_size = single_size // world_size
    start = rank * block_size
    stop = start + block_size
    bias = torch.cat(
        [slice_[start + i * single_size : stop + i * single_size] for i in range(3)],
        dim=0,
    )
    bias = bias.to(dtype=weights.dtype).to(device=weights.device)

    # Report the bias shape (previously the message printed the weight shape).
    expected = [3 * num_heads * head_size]
    assert list(bias.shape) == expected, f"{list(bias.shape)} != {expected}"

    return TensorParallelColumnLinear(get_linear(weight, bias))
def load_row(config, prefix: str, weights, bias: bool):
    """load_row, but with transposed weight matrices.

    GPTQ weights come from ``get_weights_row``. Dense weights are sharded on dim 0
    and transposed. When ``bias`` is requested, the bias tensor is loaded only on
    rank 0; other ranks pass None (presumably so the row-parallel reduction adds it
    exactly once — confirm against TensorParallelRowLinear).
    """
    weight = (
        weights.get_weights_row(prefix)
        if config.quantize == "gptq"
        else weights.get_sharded(f"{prefix}.weight", dim=0).T
    )

    # Bias is only loaded on the first rank.
    bias_tensor = (
        weights.get_tensor(f"{prefix}.bias")
        if bias and weights.process_group.rank() == 0
        else None
    )

    return TensorParallelRowLinear(
        get_linear(weight, bias_tensor), process_group=weights.process_group
    )
def load_col(config, prefix: str, weights, bias: bool):
    """load_col, but with transposed weight matrices.

    GPTQ weights are loaded with ``get_multi_weights_col(..., dim=1)``. Dense weights
    are sharded on dim 1 and transposed. The optional bias is sharded on dim 0.
    """
    if config.quantize != "gptq":
        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
    else:
        weight = weights.get_multi_weights_col([prefix], dim=1)

    bias_tensor = weights.get_sharded(f"{prefix}.bias", dim=0) if bias else None
    return TensorParallelColumnLinear(get_linear(weight, bias_tensor))
class FlashGPT2Attention(torch.nn.Module):
    """GPT-2 multi-head self-attention (no rotary; positions come from ``wpe``).

    Heads are sharded across the tensor-parallel group. Every query head has its own
    KV head (``kv_head_mapping`` is the identity).
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads
        self.softmax_scale = self.head_size**-0.5

        world_size = weights.process_group.size()
        if self.num_heads % world_size != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {world_size})"
            )
        self.num_heads = self.num_heads // world_size

        self.query_key_value = load_qkv(
            config,
            prefix=prefix,
            weights=weights,
            head_size=self.head_size,
            num_heads=self.num_heads,
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.o_proj = load_row(
            config,
            prefix=f"{prefix}.c_proj",
            weights=weights,
            bias=True,
        )

        self.kv_head_mapping = torch.arange(
            0, self.num_heads, dtype=torch.int32, device=weights.device
        )

    def forward(
        self,
        hidden_states,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Attend over ``hidden_states`` and return the ``o_proj`` output.

        ``cu_seqlen_prefill`` being non-None selects the prefill path.
        """
        query, key, value = self.query_key_value(hidden_states).split(
            self.head_size * self.num_heads, dim=1
        )
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_heads, self.head_size)
        value = value.view(-1, self.num_heads, self.head_size)

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        if cu_seqlen_prefill is not None:
            # Prefill (sdpa)
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )
        else:
            # Decode
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
class GPT2MLP(nn.Module):
    """GPT-2 feed-forward block: ``c_proj(act(c_fc(x)))``."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        act = config.activation_function
        if "gelu" in act:
            # GELU variants call torch directly; "fast"/"pytorch_tanh" use tanh approx.
            def _gelu(x):
                return torch.nn.functional.gelu(
                    x,
                    approximate=(
                        "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                    ),
                )

            self.act = _gelu
        else:
            self.act = ACT2FN[act]

        self.c_fc = load_col(
            config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
        )
        self.c_proj = load_row(
            config,
            prefix=f"{prefix}.c_proj",
            weights=weights,
            bias=True,
        )

        # n_inner defaults to 4 * hidden_size; stored per shard.
        full_intermediate = (
            4 * config.hidden_size if config.n_inner is None else config.n_inner
        )
        self.intermediate_size = full_intermediate // weights.process_group.size()

    def forward(self, hidden_states):
        """Apply the feed-forward block to ``hidden_states``."""
        return self.c_proj(self.act(self.c_fc(hidden_states)))
class FlashGPT2Layer(nn.Module):
    """GPT-2 block: pre-LayerNorm attention and MLP, each with its own residual add."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        eps = config.layer_norm_epsilon
        self.self_attn = FlashGPT2Attention(
            prefix=f"{prefix}.attn", config=config, weights=weights
        )
        self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.input_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.ln_1", weights=weights, eps=eps
        )
        self.post_attention_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.ln_2", weights=weights, eps=eps
        )

    def forward(
        self,
        hidden_states,
        residual,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Run the block and return ``(output, hidden_after_attention)``.

        The incoming ``residual`` is ignored: residual adds are completed inside the
        layer, so ``output`` is already the full hidden state.
        """
        attn_output = self.self_attn(
            self.input_layernorm(hidden_states),
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        after_attn = attn_output + hidden_states

        mlp_output = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_output, after_attn
class FlashGPT2Model(torch.nn.Module):
    """Stack of ``FlashGPT2Layer`` blocks (``h.{i}``) followed by the final ``ln_f`` norm."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        group = weights.process_group
        self.tp_rank = group.rank()
        self.tp_world_size = group.size()

        base = f"{prefix}." if prefix else ""
        self.layers = nn.ModuleList(
            FlashGPT2Layer(
                prefix=f"{base}h.{layer_id}",
                config=config,
                weights=weights,
            )
            for layer_id in range(config.num_hidden_layers)
        )
        self.norm = nn.LayerNorm.load(
            prefix=f"{base}ln_f",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )

        self.gradient_checkpointing = False

        first_attn = self.layers[0].self_attn
        self.head_size = first_attn.head_size
        self.num_heads = first_attn.num_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Run every layer (layer ``i`` gets ``kv_cache[i]``) and apply ``ln_f``.

        ``position_ids`` is unused here; position embeddings are added by the caller.
        """
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )

        hidden_states, residual = inputs_embeds, None

        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()

        for idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cu_seqlen_prefill,
                kv_cache[idx],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        return self.norm(hidden_states)
class FlashGPT2ForCausalLM(torch.nn.Module):
    """GPT-2 causal LM: token (``wte``) + position (``wpe``) embeddings, decoder, tied head."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        wte_prefix = f"{prefix}.wte" if prefix else "wte"
        wpe_prefix = f"{prefix}.wpe" if prefix else "wpe"

        self.embed_tokens = TensorParallelEmbedding(prefix=wte_prefix, weights=weights)
        self.embed_positions = TensorParallelEmbedding(
            prefix=wpe_prefix, weights=weights
        )
        self.model = FlashGPT2Model(prefix, config, weights)
        # The LM head reuses the token embedding matrix.
        self.lm_head = SpeculativeHead.load(
            config,
            prefix=wte_prefix,
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, speculative_logits)``.

        If ``lm_head_indices`` is given, only those rows are projected to logits.
        ``adapter_data`` is unused.
        """
        embeds = self.embed_tokens(input_ids) + self.embed_positions(position_ids)
        hidden = self.model(
            embeds,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta=hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden = hidden[lm_head_indices]
        return self.lm_head(hidden)
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
================================================
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch
def load_attention(config, prefix: str, weights):
    """Load the fused, column-sharded q/k/v projection (without bias).

    The ``q_proj``, ``k_proj`` and ``v_proj`` weights under ``prefix`` are
    concatenated along dim 0 into a single linear layer.
    """
    projection_names = ("q_proj", "k_proj", "v_proj")
    return TensorParallelColumnLinear.load_multi(
        config,
        prefixes=[f"{prefix}.{name}" for name in projection_names],
        dim=0,
        weights=weights,
        bias=False,
    )
def load_row(config, prefix: str, weights, bias: bool):
    """Load a row-parallel linear layer whose weight is sharded across ranks.

    Args:
        config: Model config. It is not read here; it is kept for call-site
            symmetry with the other loaders.
        prefix: Checkpoint prefix of the layer (``{prefix}.bias`` is read for
            the bias).
        weights: Weight loader exposing ``get_weights_row``, ``get_tensor`` and
            ``process_group``.
        bias: Whether the layer has a bias. Only rank 0 loads it; every other
            rank builds the layer with no bias. NOTE(review): presumably this is
            so the bias is added exactly once around the row-parallel reduction.
            Confirm against ``TensorParallelRowLinear``.

    Returns:
        A ``TensorParallelRowLinear`` bound to ``weights.process_group``.
    """
    weight = weights.get_weights_row(prefix)
    # The bias is only loaded on the first rank process.
    bias_tensor = (
        weights.get_tensor(f"{prefix}.bias")
        if bias and weights.process_group.rank() == 0
        else None
    )
    linear = get_linear(weight, bias_tensor)
    return TensorParallelRowLinear(linear, process_group=weights.process_group)
class GPTJRotary(PositionRotaryEmbedding):
    """GPT-J rotary embedding using Habana's fused PAIRWISE kernel.

    ``query`` and ``key`` are modified in place; nothing is returned.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        n_tokens = query.shape[0]
        head_size = query.shape[-1]
        # Duplicate every frequency so cos/sin line up with the (even, odd)
        # channel pairs that the PAIRWISE mode rotates.
        sin = torch.repeat_interleave(sin, 2, dim=-1)
        cos = torch.repeat_interleave(cos, 2, dim=-1)
        rotary_dim = cos.shape[-1]

        def _rotate_inplace(tensor: torch.Tensor) -> None:
            # Rotate the first ``rotary_dim`` channels, keep the rest, and write
            # the result back into ``tensor``'s storage.
            original_shape = tensor.shape
            per_head = tensor.view(n_tokens, -1, head_size)
            rotated = apply_rotary_pos_emb(
                per_head[..., :rotary_dim],
                cos,
                sin,
                None,
                0,
                RotaryPosEmbeddingMode.PAIRWISE,
            )
            merged = torch.cat((rotated, per_head[..., rotary_dim:]), dim=-1)
            per_head.copy_(merged.reshape(original_shape))

        _rotate_inplace(query)
        _rotate_inplace(key)
class FlashGPTJAttention(torch.nn.Module):
    """GPT-J multi-head self-attention for the HPU backend.

    The flow is:
    1. A fused, column-sharded q/k/v projection.
    2. A (possibly partial) rotary embedding.
    3. The keys/values are written to the KV cache.
    4. Attention runs as either a prefill (``attention``) or a decode step
       (``paged_attention``).

    Heads are split evenly across tensor-parallel ranks.
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights,
        rotary_emb,
    ):
        super().__init__()
        world_size = weights.process_group.size()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        # The per-head width comes from the unsharded head count.
        self.head_size = self.hidden_size // self.num_heads
        self.softmax_scale = self.head_size**-0.5
        self.rotary_dim = config.rotary_dim

        if self.num_heads % world_size:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {world_size}"
            )
        # From here on, ``num_heads`` counts only the heads owned by this rank.
        self.num_heads //= world_size

        self.query_key_value = load_attention(config, prefix=prefix, weights=weights)
        self.kv_scales = get_kv_scales(weights, f"{prefix}")
        self.o_proj = load_row(
            config, prefix=f"{prefix}.out_proj", weights=weights, bias=False
        )
        # Identity mapping: k/v have as many heads as q.
        self.kv_head_mapping = torch.arange(
            0, self.num_heads, dtype=torch.int32, device=weights.device
        )
        self.rotary_emb = rotary_emb

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        width = self.head_size * self.num_heads
        q, k, v = self.query_key_value(hidden_states).split(width, dim=1)
        query, key, value = [
            t.view(-1, self.num_heads, self.head_size) for t in (q, k, v)
        ]

        # The rotary embedding updates query/key in place (only the leading
        # ``rotary_dim`` channels when it is set).
        if self.rotary_dim is None:
            self.rotary_emb(query, key, cos, sin)
        else:
            span = slice(None, self.rotary_dim)
            self.rotary_emb(query[..., span], key[..., span], cos, sin)

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        is_prefill = cu_seqlen_prefill is not None
        if is_prefill:
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
class GPTJMLP(nn.Module):
    """GPT-J feed-forward: ``fc_out(act(fc_in(x)))``.

    Any activation whose name contains "gelu" uses ``torch.nn.functional.gelu``:
    the tanh approximation for ``gelu_fast`` and ``gelu_pytorch_tanh``, the exact
    form otherwise. Other activations are looked up in ``ACT2FN``.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        act = config.activation_function
        if "gelu" in act:
            approximation = (
                "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(
                x, approximate=approximation
            )
        else:
            self.act = ACT2FN[act]

        self.fc_in = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.fc_in", weights=weights, bias=True
        )
        self.fc_out = load_row(
            config,
            prefix=f"{prefix}.fc_out",
            weights=weights,
            bias=True,
        )

    def forward(self, hidden_states):
        expanded = self.act(self.fc_in(hidden_states))
        return self.fc_out(expanded)
class FlashGPTJLayer(nn.Module):
    """GPT-J decoder layer with a single ``ln_1`` norm.

    Attention and MLP both read the same normalized input. Their outputs are
    summed, and the running residual is returned alongside for the next layer.
    """

    def __init__(self, prefix: str, config, weights, rotary_emb):
        super().__init__()
        self.self_attn = FlashGPTJAttention(
            prefix=f"{prefix}.attn",
            config=config,
            weights=weights,
            rotary_emb=rotary_emb,
        )
        self.mlp = GPTJMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.input_layernorm = FastLayerNorm.load(
            prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        normed, residual = self.input_layernorm(hidden_states, residual)
        attn_out = self.self_attn(
            normed,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        mlp_out = self.mlp(normed)
        return attn_out + mlp_out, residual
class FlashGPTJModel(torch.nn.Module):
    """GPT-J transformer body.

    It consists of the token embedding (``wte``), decoder layers that share one
    ``GPTJRotary`` instance, and the final ``ln_f`` LayerNorm.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        self.config = config

        def _scoped(name: str) -> str:
            # Uniform prefix handling for every sub-module: checkpoint keys are
            # unprefixed when no prefix is given.
            return f"{prefix}.{name}" if prefix else name

        self.wte = TensorParallelEmbedding(prefix=_scoped("wte"), weights=weights)
        rotary_emb = GPTJRotary.static(
            config=config,
            dim=config.rotary_dim,
            base=10000,
            device=weights.device,
        )
        self.layers = nn.ModuleList(
            [
                FlashGPTJLayer(
                    prefix=_scoped(f"h.{layer_id}"),
                    config=config,
                    weights=weights,
                    rotary_emb=rotary_emb,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = FastLayerNorm.load(
            prefix=_scoped("ln_f"),
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.gradient_checkpointing = False
        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads

    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run all layers, and apply ``ln_f``.

        Returns the final normalized hidden states.
        """
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )

        hidden_states = self.wte(input_ids)
        # Compute cos/sin once per forward instead of in every layer.
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            # In lazy mode, cut the graph after each layer.
            if lazy_mode:
                htorch.core.mark_step()

        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states
class FlashGPTJForCausalLM(torch.nn.Module):
    """GPT-J causal LM: the ``transformer`` body plus an ``lm_head`` projection."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        transformer_prefix = f"{prefix}.transformer" if prefix else "transformer"
        self.model = FlashGPTJModel(transformer_prefix, config, weights)
        # NOTE(review): the head is always loaded from "lm_head", regardless of
        # ``prefix``. Confirm that this is intended for nested checkpoints.
        self.lm_head = SpeculativeHead.load(
            config,
            prefix="lm_head",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, speculative_logits)``.

        When ``lm_head_indices`` is given, only the selected rows are projected.
        ``adapter_data`` is accepted but unused.
        """
        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta=hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
================================================
# coding=utf-8
# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import math
import torch.utils.checkpoint
from torch import nn
import torch.nn.functional as F
import habana_frameworks.torch as htorch
from transformers.cache_utils import Cache
from transformers.activations import ACT2FN
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_outputs import (
BaseModelOutput,
)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
SpeculativeHead,
FastLinear,
)
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.attention import (
KVCache,
paged_attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaAttention,
)
def reshape_for_broadcast(freqs: torch.Tensor, target):
    """View ``freqs`` so that it broadcasts against a tensor of shape ``target``.

    Axis 1 and the last axis keep their sizes from ``target``; every other axis
    becomes 1. ``view`` raises if ``freqs`` has a different number of elements
    than the resulting shape.
    """
    last_axis = len(target) - 1
    broadcast_shape = [
        size if axis in (1, last_axis) else 1 for axis, size in enumerate(target)
    ]
    return freqs.view(*broadcast_shape)
def apply_rotary_emb(
    query: torch.Tensor,
    key: torch.Tensor,
    freqs_ci: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate adjacent channel pairs of ``query`` and ``key`` by ``freqs_ci``.

    ``freqs_ci`` stacks cos and sin on its last axis (size 2). Each pair
    ``(x, y)`` of the last dimension becomes
    ``(x*cos - y*sin, x*sin + y*cos)``. The math runs in float32, and the
    results are cast back to the inputs' dtype with their original shapes.
    """
    query_shape, key_shape = query.shape, key.shape
    cos_emb, sin_emb = freqs_ci.split(1, dim=-1)
    if query.dim() == 3:
        # Add a leading batch axis so the broadcast helper sees a 4-D target.
        query, key = query.unsqueeze(0), key.unsqueeze(0)

    q_pairs = query.float().reshape(*query.shape[:-1], -1, 2)
    k_pairs = key.float().reshape(*key.shape[:-1], -1, 2)
    target = q_pairs.shape[:-1]
    cos_emb = reshape_for_broadcast(cos_emb, target)
    sin_emb = reshape_for_broadcast(sin_emb, target)

    def _rotate(pairs: torch.Tensor) -> torch.Tensor:
        real, imag = pairs.unbind(-1)
        rotated_real = real * cos_emb - imag * sin_emb
        rotated_imag = real * sin_emb + imag * cos_emb
        return torch.stack([rotated_real, rotated_imag], dim=-1).flatten(-2)

    query_out = _rotate(q_pairs).view(*query_shape)
    key_out = _rotate(k_pairs).view(*key_shape)
    return query_out.type_as(query), key_out.type_as(key)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head ``n_rep`` times along dim 1.

    This is equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``
    and maps ``(batch, num_key_value_heads, seqlen, head_dim)`` to
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. The input is
    returned unchanged (same object) when ``n_rep == 1``.
    """
    b, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        widened = hidden_states.unsqueeze(2).expand(b, kv_heads, n_rep, seq_len, dim)
        return widened.reshape(b, kv_heads * n_rep, seq_len, dim)
    return hidden_states
class Llama4TextExperts(nn.Module):
    """Batched expert weights for the Llama4 MoE block.

    Each expert's intermediate dimension is sharded across tensor-parallel ranks,
    and the partial outputs are summed with an all-reduce.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.process_group = weights.process_group
        self.num_experts = config.num_local_experts
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size
        # Packed [gate | up] projection sharded on the last dim. block_sizes=2
        # presumably keeps the gate and up halves separated per shard, which
        # the ``chunk(2)`` in forward relies on. TODO confirm in get_packed_sharded.
        self.gate_up_proj = nn.Parameter(
            weights.get_packed_sharded(f"{prefix}.gate_up_proj", dim=-1, block_sizes=2),
            requires_grad=False,
        )
        self.down_proj = nn.Parameter(
            weights.get_sharded(f"{prefix}.down_proj", dim=1), requires_grad=False
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply each expert's gated MLP to its own block of rows.

        The input must already be grouped by expert: it is viewed as
        ``(num_experts, rows_per_expert, hidden_size)``, so row block ``e``
        belongs to expert ``e``. This is compute-heavy, because the caller routes
        every token through every expert.

        Args:
            hidden_states (torch.Tensor): ``(num_experts * rows_per_expert,
                hidden_size)`` expert-major activations.

        Returns:
            torch.Tensor: ``(num_experts * rows_per_expert, hidden_size)``. When
            the process group has more than one rank, this is summed across
            ranks.
        """
        gate_up_proj = self.gate_up_proj.view(self.num_experts, -1, 2 * self.expert_dim)
        down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1)
        hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
        gate_up = torch.bmm(hidden_states, gate_up_proj)
        gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors
        next_states = torch.bmm((up * self.act_fn(gate)), down_proj)
        next_states = next_states.view(-1, self.hidden_size)
        # Sum the partial (intermediate-sharded) outputs across ranks.
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(next_states, group=self.process_group)
        return next_states
# Gated MLP (adapted from Phi3MLP): used by dense decoder layers and as the MoE shared expert.
class Llama4TextMLP(nn.Module):
    """Gated MLP: ``down_proj(act(gate(x)) * up(x))``.

    ``gate`` and ``up`` come from one fused, column-sharded projection.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config=config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config=config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x, adapter_data=None):
        """Apply the gated MLP to ``x``.

        Args:
            x: Input activations.
            adapter_data: Ignored. It is accepted so this module can be called
                the same way as ``Llama4TextMoe`` (``feed_forward(h,
                adapter_data)``) from dense decoder layers. Previously that
                call raised a TypeError.
        """
        gate_up_states = self.gate_up_proj(x)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1])
class Llama4TextL2Norm(torch.nn.Module):
    """Parameter-free RMS normalization over the last dimension.

    It computes ``x / sqrt(mean(x**2) + eps)`` in float32 and casts the result
    back to the input dtype.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def _norm(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        normalized = self._norm(x.float())
        return normalized.type_as(x)

    def extra_repr(self):
        return f"eps={self.eps}"
class Llama4TextMoe(nn.Module):
    """Llama4 mixture-of-experts feed-forward block.

    The block works in four steps:
    1. A top-k router scores each token.
    2. Non-top-k logits are set to -inf, so their sigmoid routing score is 0.
    3. Every expert runs over every (score-weighted) token.
    4. The expert outputs are scatter-added onto the shared expert's output.
    """

    def __init__(
        self,
        prefix,
        config,
        weights,
    ):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_local_experts
        self.experts = Llama4TextExperts(
            config=config, prefix=f"{prefix}.experts", weights=weights
        )
        self.router = FastLinear.load(
            config=config, prefix=f"{prefix}.router", weights=weights, bias=False
        )
        self.shared_expert = Llama4TextMLP(
            config=config, prefix=f"{prefix}.shared_expert", weights=weights
        )
        self.process_group = weights.process_group

    def forward(self, hidden_states, adapter_data):
        """Route ``hidden_states`` through the experts and the shared expert.

        Args:
            hidden_states: Token activations. They are viewed as
                ``(num_tokens, hidden_size)``.
            adapter_data: Unused. It is accepted for call-site compatibility
                with the decoder layer.

        Returns:
            ``(num_tokens, hidden_size)``: the shared-expert output plus the sum
            of all score-weighted expert outputs.
        """
        hidden_states = hidden_states.view(-1, self.hidden_dim)
        num_tokens = hidden_states.shape[0]

        # Presumably (num_tokens, num_experts). TODO confirm the router's output width.
        router_logits = self.router(hidden_states)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
        # Keep only each token's top-k logits; everything else is -inf so that
        # the sigmoid below maps it to a routing score of exactly 0. The result
        # is transposed to expert-major order.
        router_scores = (
            torch.full_like(router_logits, float("-inf"))
            .scatter_(1, router_indices, router_top_value)
            .transpose(0, 1)
        )
        router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

        # Row indices that select every token once per expert (expert-major),
        # broadcast across the hidden dimension for gather/scatter_add.
        token_index = (
            torch.arange(num_tokens, device=hidden_states.device)
            .view(1, -1)
            .expand(router_scores.size(0), -1)
            .reshape(-1, 1)
            .expand(-1, self.hidden_dim)
        )
        routed_in = torch.gather(input=hidden_states, dim=0, index=token_index)
        # Scale every expert's copy of a token by that expert's routing score.
        routed_in = routed_in * router_scores.reshape(-1, 1)
        routed_out = self.experts(routed_in)

        out = self.shared_expert(hidden_states)
        # Undo the gather: accumulate every expert's contribution back onto its
        # token row. Running all experts on all tokens trades compute for
        # simplicity; expert parallelism would scale better.
        out.scatter_add_(
            dim=0, index=token_index, src=routed_out.view(-1, self.hidden_dim)
        )
        return out
class Llama4TextRotaryEmbedding(nn.Module):
    """Compute Llama4 rotary cos/sin tables for given positions.

    ``forward`` returns ``freqs_cis``, with cos and sin stacked on a trailing
    axis of size 2 and scaled by ``attention_scaling``, in ``x``'s dtype and on
    ``x``'s device.
    """

    def __init__(self, config, device=None):
        super().__init__()
        # The rope type is chosen only by the presence of ``rope_scaling``:
        # "llama3" scaling when it is set, plain "default" RoPE otherwise.
        # No "rope_type"/"type" key is read from the config.
        self.rope_type = "llama3" if config.rope_scaling is not None else "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def forward(self, x, position_ids):
        """Return ``freqs_cis`` for ``position_ids``.

        ``position_ids`` is indexed as 2-D ``(batch, seq)``. The result has
        shape ``(batch, seq, len(inv_freq), 2)``, where ``[..., 0]`` is cos and
        ``[..., 1]`` is sin. ``x`` only supplies the output dtype and device.
        """
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        device_type = (
            x.device.type
            if isinstance(x.device.type, str) and x.device.type != "mps"
            else "cpu"
        )
        inv_freq_expanded = inv_freq_expanded.to(device_type)
        position_ids_expanded = position_ids_expanded.to(device_type)
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
            freqs_cis = (
                torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
                * self.attention_scaling
            )
        return freqs_cis.to(dtype=x.dtype, device=x.device)
class Llama4TextAttention(FlashLlamaAttention):
    """Llama4 text self-attention built on top of ``FlashLlamaAttention``.

    It reuses the parent's fused ``query_key_value`` / ``o_proj`` projections,
    ``kv_scales``, ``kv_head_mapping``, ``num_heads``, ``num_key_value_heads``
    and ``softmax_scale``. NOTE(review): these are assumed to be set by
    ``FlashLlamaAttention.__init__``; confirm.

    Llama4-specific behavior:
    - RoPE is skipped on every 4th layer.
    - Optional L2 qk-norm is applied on RoPE layers.
    - Attention-temperature tuning is applied on the no-RoPE layers.
    - Prefill uses torch SDPA with an explicit mask; decode uses
      ``paged_attention``.
    """

    def __init__(self, prefix, config, weights, layer_idx):
        # The parent receives ``None`` as its rotary embedding. This class
        # applies its own rotary via ``apply_rotary_emb`` in forward.
        super().__init__(layer_idx, prefix, config, weights, None)
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.scaling = self.head_dim**-0.5
        self.attn_scale = config.attn_scale
        self.floor_scale = config.floor_scale
        self.attn_temperature_tuning = config.attn_temperature_tuning
        self.attention_dropout = config.attention_dropout
        # RoPE is skipped on every 4th layer (layer_idx 3, 7, 11, ...).
        self.use_rope = int((layer_idx + 1) % 4 != 0)
        if self.config.use_qk_norm and self.use_rope:
            self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_ci,
        cu_seqlen_prefill,
        kv_cache: KVCache,
        slots,
        seqlen,
        adapter_data,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and return the ``o_proj`` output.

        ``cu_seqlen_prefill`` being non-None selects the prefill (SDPA) path;
        otherwise the decode (``paged_attention``) path runs. ``position_ids``
        must be provided when temperature tuning is active on a no-RoPE layer.
        """
        # Number of sequences in the batch.
        bs = seqlen.input_lengths.shape[0]
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        qkv = self.query_key_value(hidden_states, adapter_data)
        query_states, key_states, value_states = qkv.split(
            [
                self.head_dim * self.num_heads,
                self.head_dim * self.num_key_value_heads,
                self.head_dim * self.num_key_value_heads,
            ],
            dim=-1,
        )
        query_states = query_states.view(hidden_shape)
        key_states = key_states.view(hidden_shape)
        value_states = value_states.view(hidden_shape)
        if self.use_rope:  # the 16E model skips rope for long context on certain layers
            query_states, key_states = apply_rotary_emb(
                query_states, key_states, freqs_ci
            )
        if hasattr(self, "qk_norm"):  # the 128E model does not use qk_norm
            query_states = self.qk_norm(query_states)
            key_states = self.qk_norm(key_states)
        # Keys/values are cached after RoPE and qk-norm. Temperature tuning below
        # only scales the query, so it does not affect the cache.
        kv_cache.store(
            key=key_states,
            value=value_states,
            slots=slots,
            kv_scales=self.kv_scales,
        )
        # Use temperature tuning (https://arxiv.org/abs/2501.19399) on the no-RoPE layers.
        if self.attn_temperature_tuning and not self.use_rope:
            attn_scales = (
                torch.log(
                    torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0
                )
                * self.attn_scale
                + 1.0
            )
            attn_scales = attn_scales.view(*input_shape, 1, 1)
            query_states = (query_states * attn_scales).to(query_states.dtype)
        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            # Tokens are regrouped as (bs, seq, heads, head_dim). NOTE(review):
            # this assumes every sequence in the batch has the same (padded)
            # length; otherwise the view fails.
            query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(
                1, 2
            )
            key = key_states.view(
                bs, -1, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            value = value_states.view(
                bs, -1, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            # Expand grouped KV heads to match the query head count.
            key = repeat_kv(key, self.num_key_value_groups)
            value = repeat_kv(value, self.num_key_value_groups)
            causal_mask = attention_mask
            if attention_mask is not None and causal_mask.ndim == 4:
                causal_mask = causal_mask[:, :, :, : key.shape[-2]]
            # Without an explicit mask, let SDPA apply causal masking itself.
            is_causal = query.shape[2] > 1 and causal_mask is None
            # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
            # Reference: https://github.com/pytorch/pytorch/issues/112577.
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=causal_mask,
                dropout_p=0,
                scale=self.scaling,
                is_causal=is_causal,
            )
            attn_output = attn_output.transpose(1, 2).contiguous()
        # Decode
        else:
            attn_output = paged_attention(
                query_states,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )
        # Merge the heads back to (*input_shape, heads * head_dim) before o_proj.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output, adapter_data)
        return attn_output
class Llama4TextDecoderLayer(nn.Module):
    """Pre-norm Llama4 decoder layer.

    It applies ``x + attn(norm(x))`` followed by ``x + ffn(norm(x))``, where
    ``ffn`` is either a MoE block or a dense MLP depending on
    ``config.moe_layers``.
    """

    def __init__(self, prefix, config, weights, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Llama4TextAttention(
            f"{prefix}.self_attn", config, weights, layer_idx
        )
        self.use_chunked_attention = int((layer_idx + 1) % 4 != 0)  # <=> use rope
        self.is_moe_layer = layer_idx in config.moe_layers
        if self.is_moe_layer:  # the 128E model interleaves dense / sparse
            self.feed_forward = Llama4TextMoe(f"{prefix}.feed_forward", config, weights)
        else:
            self.feed_forward = Llama4TextMLP(f"{prefix}.feed_forward", config, weights)
        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states,
        freqs_ci,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        attention_mask: Optional[torch.Tensor] = None,
        chunk_causal_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
    ) -> torch.Tensor:
        """Run attention and feed-forward with residual connections.

        On RoPE layers, ``chunk_causal_mask`` replaces ``attention_mask`` when it
        is provided. Returns the updated hidden states, with the same shape as
        the input.
        """
        residual = hidden_states
        hidden_states, _ = self.input_layernorm(hidden_states)
        # Use the local (chunked) attention mask for RoPE layers.
        if self.use_chunked_attention and chunk_causal_mask is not None:
            attention_mask = chunk_causal_mask
        attention_states = self.self_attn(
            hidden_states,
            freqs_ci,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data,
            attention_mask=attention_mask,
            position_ids=position_ids,
            hpu_attention_meta=hpu_attention_meta,
        )
        hidden_states = residual + attention_states

        # Fully connected
        residual = hidden_states
        hidden_states, _ = self.post_attention_layernorm(hidden_states)
        if self.is_moe_layer:
            hidden_states = self.feed_forward(hidden_states, adapter_data)
        else:
            # The dense MLP does not take adapter data.
            hidden_states = self.feed_forward(hidden_states)
        hidden_states = residual + hidden_states.view(residual.shape)
        return hidden_states
class Llama4TextModel(nn.Module):
def __init__(self, prefix, config, weights):
    """Build the embedding, decoder layers, final RMSNorm and rotary table."""
    super().__init__()
    self.config = config
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size
    self.embed_tokens = TensorParallelEmbedding(
        prefix=f"{prefix}.embed_tokens", weights=weights
    )
    decoder_layers = [
        Llama4TextDecoderLayer(
            prefix=f"{prefix}.layers.{idx}",
            config=config,
            weights=weights,
            layer_idx=idx,
        )
        for idx in range(config.num_hidden_layers)
    ]
    self.layers = nn.ModuleList(decoder_layers)
    self.norm = FastRMSNorm.load(
        prefix=f"{prefix}.norm",
        weights=weights,
        eps=config.rms_norm_eps,
    )
    # A single rotary table is shared by all layers; it is evaluated once per forward.
    self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
    self.gradient_checkpointing = False
def forward(
    self,
    inputs_embeds: torch.Tensor,
    position_ids: torch.Tensor,
    cu_seqlen_prefill: Optional[torch.Tensor],
    kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
    slots: torch.Tensor,
    seqlen: Seqlen,
    adapter_data,
    hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run all decoder layers over packed ``inputs_embeds`` and apply the final norm.

    ``inputs_embeds`` is indexed as ``(bs * seq_len, hidden)``, where ``bs`` is
    the number of entries in ``seqlen.input_lengths``. The causal and chunked
    masks and the rotary table are computed once here and shared by every
    layer.
    """
    if hpu_attention_meta is not None:
        hpu_attention_meta = set_block_mapping(
            hpu_attention_meta, inputs_embeds.shape[0]
        )
    hidden_states = inputs_embeds
    bs = seqlen.input_lengths.shape[0]
    # Integer division keeps ``seq_len`` integral, so ``cache_position`` (and the
    # default ``position_ids``) are integer tensors rather than float ones.
    seq_len = inputs_embeds.shape[0] // bs
    cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)
    causal_mask, chunk_causal_mask = self._update_causal_mask(
        attention_mask,
        inputs_embeds.view(bs, seq_len, -1),
        cache_position,
        None,
        output_attentions=False,
        use_cache=False,
    )
    freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))
    lazy_mode = htorch.utils.internal.is_lazy()
    if lazy_mode:
        htorch.core.mark_step()
    for i, layer in enumerate(self.layers):
        hidden_states = layer(
            hidden_states,
            freqs_ci,
            cu_seqlen_prefill,
            kv_cache[i],
            slots,
            seqlen,
            adapter_data,
            attention_mask=causal_mask,
            chunk_causal_mask=chunk_causal_mask,
            position_ids=position_ids,
            hpu_attention_meta=hpu_attention_meta,
        )
        # In lazy mode, cut the graph after each layer.
        if lazy_mode:
            htorch.core.mark_step()
    hidden_states, _ = self.norm(hidden_states)
    return hidden_states
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool = False,
    chunked_attention_mask=None,
    use_cache=True,
):
    """Build ``(causal_mask, chunked_attention_mask)`` for the configured attention backend.

    Args:
        attention_mask: Optional 2-D or 4-D mask. It may be None.
        input_tensor: Embeddings indexed as ``(batch, sequence_length, ...)``;
            they supply the batch size, sequence length, dtype and device.
        cache_position: Positions of the input tokens. Element 0 is used as
            the first cache position.
        past_key_values: Optional cache. When given, its max cache shape sets
            the full cache length.
        output_attentions: Disables the SDPA unmasking step when True.
        chunked_attention_mask: Returned as-is unless the full cache length
            exceeds ``config.attention_chunk_size``, in which case it is
            recomputed.
        use_cache: When False, ``key_length`` is the full cache length.

    Returns:
        ``(causal_mask, chunked_attention_mask)``. Both are None for
        implementations other than sdpa/flex_attention/eager (and for
        flash_attention_2 without padding).
    """
    if self.config._attn_implementation == "flash_attention_2":
        if attention_mask is not None and (attention_mask == 0.0).any():
            return (
                attention_mask,
                attention_mask,
            )  # flash does not support chunked attn TODO support flash
        return None, None
    if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]:
        return None, None
    sequence_length = input_tensor.shape[1]
    attention_chunk_size = self.config.attention_chunk_size
    first_cache_position = cache_position[0]
    if past_key_values is not None:
        full_cache_length = past_key_values.get_max_cache_shape() or sequence_length
    else:
        full_cache_length = (
            attention_mask.shape[-1]
            if attention_mask is not None
            else sequence_length
        )
    # cond1: the input starts past the first chunk.
    # cond2: the input starts inside the first chunk but extends beyond it.
    cond1 = first_cache_position >= attention_chunk_size
    cond2 = (first_cache_position < attention_chunk_size) & (
        first_cache_position + sequence_length > attention_chunk_size
    )
    key_length = (
        torch.where(
            cond1,
            attention_chunk_size + sequence_length - 1,
            torch.where(
                cond2, first_cache_position + sequence_length, attention_chunk_size
            ),
        )
        if use_cache
        else full_cache_length
    )
    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
    dtype, device = input_tensor.dtype, input_tensor.device
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=max(full_cache_length, attention_chunk_size),
        dtype=dtype,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
        device=device,
    )
    if full_cache_length > self.config.attention_chunk_size:
        start_idx = max(first_cache_position - attention_chunk_size + 1, 0)
        end_idx = start_idx + key_length
        chunked_attention_mask = self.create_chunked_attention_mask(
            self.config.attention_chunk_size,
            start=start_idx,  # same offset as with flex
            end=end_idx,
            device=device,
        )
        # NOTE(review): ``attention_mask`` may be None here (e.g. when called with
        # None and sequence_length > attention_chunk_size), which would raise a
        # TypeError on this subscript. Confirm callers always pass a mask for
        # long inputs.
        local_attention_mask = attention_mask[
            :, start_idx:end_idx
        ]  # offset here as well
        # It may be smaller than attention_chunk_size -> pad it
        requires_padding = local_attention_mask.shape[-1] < attention_chunk_size
        if requires_padding:
            local_attention_mask = nn.functional.pad(
                local_attention_mask,
                (0, attention_chunk_size - local_attention_mask.shape[-1]),
            )
        # Depending on the padding, take the query tokens from the end or the cache_position
        if not requires_padding:
            chunked_attention_mask = chunked_attention_mask[
                None, None, -sequence_length:, :
            ]
        else:
            chunked_attention_mask = chunked_attention_mask[
                None, None, cache_position, :
            ]
        chunked_attention_mask = chunked_attention_mask.expand(
            input_tensor.shape[0], -1, -1, -1
        )
        # Combine the chunk structure with the per-batch padding mask.
        chunked_attention_mask = (
            chunked_attention_mask * local_attention_mask[:, None, None, :]
        )
        if self.config._attn_implementation == "eager":
            # Eager expects an additive mask: 0 where allowed, dtype-min elsewhere.
            min_dtype = torch.finfo(dtype).min
            chunked_attention_mask = torch.where(
                chunked_attention_mask == 0, min_dtype, 0.0
            ).to(dtype)
    if (
        self.config._attn_implementation == "sdpa"
        and attention_mask is not None
        and attention_mask.device.type in ["cuda", "xpu", "npu"]
        and attention_mask.ndim == 4
        and not output_attentions  # Only unmask for 4d masks
    ):
        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        min_dtype = torch.finfo(dtype).min
        causal_mask = AttentionMaskConverter._unmask_unattended(
            causal_mask, min_dtype
        )
    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
    if (
        self.config._attn_implementation == "sdpa"
        and chunked_attention_mask is not None
    ):
        # SDPA receives boolean masks here.
        chunked_attention_mask = chunked_attention_mask.bool()
        causal_mask = causal_mask.bool()
        if AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=first_cache_position,
            is_training=self.training,
        ):
            causal_mask = None
    return causal_mask, chunked_attention_mask
def create_chunked_attention_mask(
    self, attention_chunk_size: int, start: int, end: int, device: torch.device
) -> torch.Tensor:
    """
    Build a boolean block-diagonal causal mask for chunked (local) attention.

    Position ``i`` may attend to position ``j`` only when both positions fall into
    the same chunk of ``attention_chunk_size`` tokens and ``j <= i``. For a chunk
    size of 3 over six tokens the result is:

        'What'      : 0 ■ ⬚ ⬚ ⬚ ⬚ ⬚ |
        '▁is'       : 1 ■ ■ ⬚ ⬚ ⬚ ⬚ |
        '▁ch'       : 2 ■ ■ ■ ⬚ ⬚ ⬚ |
        'unked'     : 3 ⬚ ⬚ ⬚ ■ ⬚ ⬚ |
        '▁attention': 4 ⬚ ⬚ ⬚ ■ ■ ⬚ |
        '?'         : 5 ⬚ ⬚ ⬚ ■ ■ ■ |

    The mask can be combined with an already created attention mask.

    Args:
        attention_chunk_size: number of tokens per chunk.
        start: first absolute position (inclusive).
        end: last absolute position (exclusive).
        device: device the mask is created on.

    Returns:
        A ``(end - start, end - start)`` boolean tensor where ``True`` means "may attend".
    """
    positions = torch.arange(start, end, device=device)
    chunk_ids = positions // attention_chunk_size
    # Row = query position, column = key position.
    same_chunk = chunk_ids.unsqueeze(1) == chunk_ids.unsqueeze(0)
    not_in_future = positions.unsqueeze(0) <= positions.unsqueeze(1)
    return (same_chunk & not_in_future).to(device)
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    The built mask is additive: allowed positions hold `0` and masked positions hold
    `torch.finfo(dtype).min`.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
        **kwargs:
            Accepted and ignored.

    Returns:
        `torch.Tensor`: the 4D mask, or `attention_mask` itself when it is already 4D.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        min_dtype = torch.finfo(dtype).min
        # Start fully masked: (query_len, key_len) filled with min_dtype.
        causal_mask = torch.full(
            (sequence_length, target_length),
            fill_value=min_dtype,
            dtype=dtype,
            device=device,
        )
        if sequence_length != 1:
            # Keep min_dtype only strictly above the diagonal.
            causal_mask = torch.triu(causal_mask, diagonal=1)
        # Multiplying by the boolean comparison zeroes (unmasks) every key position at
        # or before each query's cache position; later keys keep min_dtype.
        causal_mask *= torch.arange(
            target_length, device=device
        ) > cache_position.to(device).reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = (
                causal_mask.clone()
            )  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            # NOTE(review): assumes the 2D mask uses 1 = attend and 0 = padding. The sum
            # is 0 exactly where the causal mask allows a key (0) that the 2D mask marks
            # as padding (0); those entries are re-masked with min_dtype below.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
                :, None, None, :
            ].to(device)
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[
                :, :, :, :mask_length
            ].masked_fill(padding_mask, min_dtype)
    return causal_mask
class Llama4ForCausalLM(nn.Module):
    """Llama4 text decoder (``Llama4TextModel``) followed by a speculative-capable LM head."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.model = Llama4TextModel(
            prefix=f"{prefix}.model", config=config, weights=weights
        )
        self.vocab_size = config.vocab_size
        self.lm_head = SpeculativeHead.load(config, f"{prefix}.lm_head", weights)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        adapter_data: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run the decoder and project its hidden states to logits.

        When ``lm_head_indices`` is given, only those rows of the decoder output are
        passed to the LM head. Returns the ``(logits, speculative_logits)`` pair
        produced by ``self.lm_head``.
        """
        decoder_states = self.model(
            inputs_embeds,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data=adapter_data,
            hpu_attention_meta=hpu_attention_meta,
            attention_mask=attention_mask,
        )
        if lm_head_indices is None:
            head_input = decoder_states
        else:
            head_input = decoder_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(head_input)
        return logits, speculative_logits
class Llama4VisionMLP2(torch.nn.Module):
    """
    Bias-free two-layer MLP of the vision adapter.

    Computes ``gelu(fc2(dropout(gelu(fc1(x)))))``; fc1 is column-parallel and fc2 is
    row-parallel.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.fc1 = TensorParallelColumnLinear.load(
            config=config, prefix=f"{prefix}.fc1", weights=weights, bias=False
        )
        self.fc2 = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.fc2", weights=weights, bias=False
        )
        self.activation_fn = nn.GELU()  # ACT2FN[config.hidden_act]
        self.dropout = config.projector_dropout

    def forward(self, hidden_states):
        activated = self.activation_fn(self.fc1(hidden_states))
        activated = F.dropout(activated, p=self.dropout, training=self.training)
        # NOTE(review): GELU is applied a second time after fc2 (previously marked
        # "TODO: check if we need to apply activation again"); confirm against the
        # reference Llama4 implementation before changing.
        return self.activation_fn(self.fc2(activated))
class Llama4MultiModalProjector(nn.Module):
    """
    Single bias-free linear layer applied to flattened vision features.

    Its output is what gets scattered into the text input embeddings at image-token
    positions.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.linear_1 = FastLinear.load(
            config=config, prefix=f"{prefix}.linear_1", weights=weights, bias=False
        )

    def forward(self, image_features):
        return self.linear_1(image_features)
def pixel_shuffle(input_tensor, shuffle_ratio):
    """
    Regroup a square grid of patch features, trading spatial resolution for channels.

    The ``num_patches`` patches are treated as a ``side x side`` grid. Each spatial axis
    is scaled by ``shuffle_ratio`` and the channel axis by ``1 / shuffle_ratio**2``, so
    the element count is preserved. For example, a ratio of 0.5 merges every 2x2
    neighbourhood into one patch with 4x the channels.

    Args:
        input_tensor: tensor of shape ``[batch_size, num_patches, channels]``.
        shuffle_ratio: spatial scaling factor.

    Returns:
        Tensor of shape
        ``[batch_size, num_patches * shuffle_ratio**2, channels / shuffle_ratio**2]``.

    Raises:
        ValueError: if ``num_patches`` is not a perfect square.
    """
    batch_size, num_patches, _ = input_tensor.shape
    # math.isqrt is exact; int(math.sqrt(n)) goes through a float.
    patch_size = math.isqrt(num_patches)
    if patch_size * patch_size != num_patches:
        # Without this check the ``view`` below can succeed with a meaningless layout
        # whenever the element count happens to divide evenly.
        raise ValueError(
            f"pixel_shuffle expects a square number of patches, got {num_patches}"
        )
    input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
    batch_size, height, width, channels = input_tensor.size()
    # Fold the width axis into channels.
    reshaped_tensor = input_tensor.view(
        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
    )
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
    # Fold the (now leading) height axis into channels as well.
    reshaped_tensor = reshaped_tensor.view(
        batch_size,
        int(height * shuffle_ratio),
        int(width * shuffle_ratio),
        int(channels / (shuffle_ratio**2)),
    )
    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
    output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
    return output_tensor
class Llama4VisionPixelShuffleMLP(nn.Module):
    """Vision adapter: pixel-shuffle the patch grid, then project with ``Llama4VisionMLP2``."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        ratio = config.pixel_shuffle_ratio
        self.pixel_shuffle_ratio = ratio
        self.inner_dim = int(config.projector_input_dim // (ratio**2))
        self.output_dim = config.projector_output_dim
        self.mlp = Llama4VisionMLP2(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
        shuffled = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
        return self.mlp(shuffled)
# TODO: the vision encoder uses its own RoPE variant, defined below.
def vision_reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor):
    """
    Return a view of ``freqs_ci`` that broadcasts against ``query``.

    The view keeps ``query``'s sizes on dimension 1 and on its last dimension and
    uses 1 for every other dimension. ``freqs_ci`` must hold exactly that many
    elements.
    """
    last_axis = query.ndim - 1
    target_shape = [
        size if axis in (1, last_axis) else 1
        for axis, size in enumerate(query.shape)
    ]
    return freqs_ci.view(*target_shape)
class Llama4VisionAttention(nn.Module):
    """
    Non-causal multi-head self-attention of the Llama4 vision encoder.

    Heads are sharded across the tensor-parallel process group. Q/K/V come from one
    fused column-parallel projection, and the output projection is row-parallel.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.process_group = weights.process_group
        # Backward-compatible alias for the previously misspelled attribute name.
        self.progress_group = self.process_group
        # Number of attention heads held by this shard.
        self.num_heads = config.num_attention_heads // self.process_group.size()
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = 1
        self.attention_dropout = config.attention_dropout
        self.qkv_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_ci: torch.Tensor,  # stacked (cos, sin) components rather than complex numbers
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: assumed ``(batch, seq_len, hidden_size)``; the fused QKV
                output is split along dim 2.
            freqs_ci: rotary table forwarded to ``apply_rotary_emb``.
            attention_mask: optional mask forwarded to
                ``F.scaled_dot_product_attention``.

        Returns:
            The ``o_proj`` output, with the same leading dimensions as
            ``hidden_states``.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        qkv = self.qkv_proj(hidden_states)
        local_width = self.head_dim * self.num_heads
        query_states, key_states, value_states = qkv.split(
            [local_width, local_width, local_width],
            dim=2,
        )
        query_states = query_states.view(hidden_shape)
        key_states = key_states.view(hidden_shape)
        value_states = value_states.view(hidden_shape)
        query_states, key_states = apply_rotary_emb(
            query_states, key_states, freqs_ci=freqs_ci
        )
        # SDPA expects (batch, heads, seq, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            is_causal=False,
            dropout_p=0,
        )
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output
class Llama4VisionMLP(nn.Module):
    """Feed-forward block of a vision encoder layer: ``fc2(gelu(fc1(x)))``, both with bias."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = nn.GELU()  # ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
class Llama4VisionEncoderLayer(nn.Module):
    """
    Pre-norm vision transformer block.

    ``x = x + attn(norm1(x))`` followed by ``x = x + mlp(norm2(x))``.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Llama4VisionAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = Llama4VisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )
        self.input_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
        )
        self.post_attention_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        freqs_ci: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Apply the block; returns a 1-tuple holding the updated hidden state."""
        attn_delta = self.self_attn(
            self.input_layernorm(hidden_state),
            freqs_ci=freqs_ci,
            attention_mask=attention_mask,
        )
        hidden_state = hidden_state + attn_delta
        mlp_delta = self.mlp(self.post_attention_layernorm(hidden_state))
        hidden_state = hidden_state + mlp_delta
        return (hidden_state,)
class Llama4VisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Llama4VisionEncoderLayer`].

    Args:
        config: Llama4VisionConfig
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Llama4VisionEncoderLayer(
                    prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_ci: torch.Tensor,  # TODO move this to an attribute instead of keeping it around
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run every encoder layer in order.

        Returns:
            The final hidden-states tensor (only the first element of each layer's
            output tuple is carried forward).
        """
        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_state=hidden_states,
                attention_mask=attention_mask,
                freqs_ci=freqs_ci,
            )
            hidden_states = layer_outputs[0]
        return hidden_states
class Llama4UnfoldConvolution(nn.Module):
    """
    Patch embedding of the vision tower.

    Extracts ``patch_size`` patches with stride ``patch_size`` (non-overlapping) and
    projects each flattened patch with a bias-free linear layer.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        patch = config.patch_size
        kernel_size = (patch, patch) if isinstance(patch, int) else patch
        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
        self.linear = FastLinear.load(
            config=config, prefix=f"{prefix}.linear", weights=weights, bias=False
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Unfold yields (batch, C * kh * kw, num_patches); put patches on dim 1.
        patches = self.unfold(hidden_states).permute(0, 2, 1)
        return self.linear(patches)
class Llama4VisionRotaryEmbedding(nn.Module):
    """
    Precomputed 2D rotary-embedding table for the vision encoder.

    Every patch of the ``idx x idx`` grid (``idx = image_size // patch_size``) gets
    angles derived from its (x, y) grid coordinate. One extra trailing row stands for
    the class token; all of its angles are zeroed, so its (cos, sin) is (1, 0).
    The table is built once in ``__init__`` and stored in ``self.freqs_ci``.
    """

    def __init__(self, config, weights):
        super().__init__()
        # Calculate image grid indices
        idx = config.image_size // config.patch_size
        img_idx = torch.arange(
            idx**2, dtype=torch.int32, device=weights.device
        ).reshape(idx**2, 1)
        # Append one row for the class token and tag it with a negative id so it can
        # be masked out below.
        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
        img_idx[-1, -1] = -2  # ID_CLS_TOKEN
        # Calculate x and y coordinates
        frequencies_x = img_idx % idx  # x coordinates
        frequencies_y = torch.div(img_idx, idx, rounding_mode="floor")  # y coordinates
        # Calculate frequency components
        freq_dim = config.hidden_size // config.num_attention_heads // 2
        # freq_dim // 2 inverse frequencies: rope_theta ** (-k / freq_dim), k = 0, 2, 4, ...
        rope_freq = 1.0 / (
            config.rope_theta
            ** (
                torch.arange(0, freq_dim, 2, device=weights.device)[
                    : (freq_dim // 2)
                ].float()
                / freq_dim
            )
        )
        # Compute frequencies for x and y directions (coordinates shifted by +1)
        freqs_x = (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
        freqs_x = freqs_x.repeat_interleave(2, dim=-1)
        freqs_y = (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
        freqs_y = freqs_y.repeat_interleave(2, dim=-1)
        # Combine frequencies and mask special tokens (the class-token row gets angle 0)
        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
        freq_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
        self.freqs_ci = freq_cis  # shape (idx**2 + 1, 1, 2 * (freq_dim // 2), 2); last axis is (cos, sin)

    def forward(self, hidden_states):
        """
        Returns the rotary embedding components (cosθ, sinθ) for the given hidden states.

        Only the dtype and device of ``hidden_states`` are used; the precomputed table
        is returned cast to them.
        """
        return self.freqs_ci.to(dtype=hidden_states.dtype, device=hidden_states.device)
class Llama4VisionModel(nn.Module):
    """
    Llama4 vision tower.

    Pipeline: unfold patch embedding -> append class token -> learned positional
    embedding -> pre LayerNorm -> rotary transformer encoder -> post LayerNorm ->
    drop class token -> pixel-shuffle vision adapter.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels
        # Grid patches plus one class token.
        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        # NOTE(review): not referenced anywhere in this class.
        self.scale = config.hidden_size**-0.5
        self.patch_embedding = Llama4UnfoldConvolution(
            prefix=f"{prefix}.patch_embedding", config=config, weights=weights
        )
        self.class_embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
        )
        self.positional_embedding_vlm = nn.Parameter(
            weights.get_tensor(f"{prefix}.positional_embedding_vlm"),
            requires_grad=False,
        )
        self.rotary_embedding = Llama4VisionRotaryEmbedding(config, weights)
        # layer norms
        self.layernorm_pre = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_pre", weights=weights, eps=config.norm_eps
        )
        self.layernorm_post = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_post", weights=weights, eps=config.norm_eps
        )
        # encoders
        self.model = Llama4VisionEncoder(
            prefix=f"{prefix}.model", config=config, weights=weights
        )
        self.vision_adapter = Llama4VisionPixelShuffleMLP(
            prefix=f"{prefix}.vision_adapter", config=config, weights=weights
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Encode image tiles into adapter-projected patch features.

        Args:
            pixel_values: tensor of shape
                ``(batch_size_times_num_tiles, num_channels, height, width)``.
            attention_mask: accepted but unused; the encoder is always called with
                ``attention_mask=None``.
            return_dict: accepted but unused.

        Returns:
            The ``vision_adapter`` output for every tile, with the class token removed
            before the adapter.
        """
        # num_concurrent_media and num_chunks are both currently 1
        batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape
        num_concurrent_media = 1
        num_chunks = 1
        hidden_state = self.patch_embedding(pixel_values)
        _, num_patches, hidden_dim = hidden_state.shape
        # Add cls token (appended at the END of the patch sequence)
        hidden_state = hidden_state.reshape(
            batch_size_times_num_tiles * num_concurrent_media * num_chunks,
            num_patches,
            hidden_dim,
        )
        class_embedding = self.class_embedding.expand(
            hidden_state.shape[0], 1, hidden_state.shape[-1]
        )
        hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
        num_patches += 1
        # Position embeddings
        hidden_state = hidden_state.reshape(
            batch_size_times_num_tiles * num_concurrent_media,
            num_chunks,
            num_patches,
            hidden_dim,
        )
        positional_embedding = self.positional_embedding_vlm.to(
            dtype=hidden_state.dtype, device=hidden_state.device
        )
        hidden_state = hidden_state + positional_embedding
        hidden_state = self.layernorm_pre(hidden_state)
        # Back to (tiles, num_patches + 1, hidden_dim) for the encoder.
        hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim)
        # Only the dtype/device of pixel_values are used to cast the rotary table.
        freqs_ci = self.rotary_embedding(pixel_values)
        hidden_state = self.model(
            hidden_state,
            attention_mask=None,
            freqs_ci=freqs_ci,
        )
        hidden_state = self.layernorm_post(hidden_state)
        # Drop the trailing class token before the adapter.
        hidden_state = hidden_state[:, :-1, :]
        # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
        hidden_state = self.vision_adapter(hidden_state)
        return hidden_state
class Llama4ForConditionalGeneration(nn.Module):
    """
    Llama4 multimodal model: vision tower, linear projector and text decoder.

    Projected image features replace the embeddings of ``config.image_token_index``
    tokens before the text model runs.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        self.config = config
        # Quantization is disabled for the vision config; the text config inherits the
        # top-level quantization and speculation settings.
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator
        config.text_config._attn_implementation = None
        # NOTE(review): `prefix` is not used; submodule prefixes are fixed strings.
        self.vision_model = Llama4VisionModel(
            prefix="vision_model", config=config.vision_config, weights=weights
        )
        self.multi_modal_projector = Llama4MultiModalProjector(
            prefix="multi_modal_projector", config=config, weights=weights
        )
        self.text_model = Llama4ForCausalLM(
            prefix="language_model", config=config.text_config, weights=weights
        )
        self.vocab_size = config.text_config.vocab_size
        self.pad_token_id = (
            self.config.pad_token_id if self.config.pad_token_id is not None else -1
        )
        self.dtype = weights.dtype
        self.device = weights.device

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        vision_feature_layer: Union[int, List[int]],
        vision_feature_select_strategy: str,
        **kwargs,
    ):
        """
        Obtains image last hidden states from the vision tower.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
                The tensors corresponding to the input images.
            vision_feature_layer (`Union[int, List[int]]`):
                Accepted for API compatibility; not used here. The full vision tower
                output is always returned.
            vision_feature_select_strategy (`str`):
                The feature selection strategy; must be one of `"default"` or `"full"`.
            **kwargs:
                Accepted and ignored.

        Returns:
            The output of `self.vision_model(pixel_values)`.

        Raises:
            ValueError: if `vision_feature_select_strategy` is not `"default"` or `"full"`.
        """
        if vision_feature_select_strategy not in ["default", "full"]:
            # Report the argument itself: this class has no
            # `vision_feature_select_strategy` attribute, so referencing one raised
            # AttributeError instead of the intended ValueError.
            raise ValueError(
                f"Unexpected select feature strategy: {vision_feature_select_strategy}"
            )
        hidden_state = self.vision_model(pixel_values)
        return hidden_state

    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """
        Encode images and project them into the text embedding space.

        Returns a 2D tensor with one row per vision feature vector, i.e. the vision
        output flattened over all leading dimensions and passed through
        `multi_modal_projector`. `pixel_attention_mask` and `image_grid_thw` are
        accepted but not used.
        """
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=self.config.vision_config.vision_feature_layer,
            vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy,
            image_sizes=image_sizes,
        )
        vision_flat = image_features.view(-1, image_features.size(-1))
        image_features = self.multi_modal_projector(vision_flat)
        return image_features

    def get_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        vision_embeds: torch.Tensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: Optional[torch.LongTensor] = None,
    ):
        """
        Embed `input_ids` and splice in vision embeddings.

        When `vision_embeds` is given, the embeddings at every `config.image_token_index`
        position are replaced, in order, by the rows of `vision_embeds`.
        `pixel_values` and `image_sizes` are accepted but not used.

        Raises:
            ValueError: if the number of image tokens differs from `vision_embeds.size(0)`.
        """
        inputs_embeds = self.text_model.model.embed_tokens(input_ids)
        if vision_embeds is not None:
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            original_inputs_embeds_shape = inputs_embeds.shape
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
                -1
            )
            final_mask = special_image_mask.to(inputs_embeds.device)
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
            final_mask_1d = final_mask[..., 0].reshape(-1)
            num_tokens_to_fill = final_mask_1d.sum()
            if num_tokens_to_fill != vision_embeds.size(0):
                raise ValueError(
                    f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
                    f"but multi_modal_projector returned {vision_embeds.size(0)}"
                )
            expanded_mask = final_mask_1d.unsqueeze(-1).expand(
                -1, inputs_embeds.size(-1)
            )
            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
            inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
        return inputs_embeds

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        cu_seqlen_prefill: Optional[torch.Tensor] = None,
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
        slots: torch.Tensor = None,
        seqlen: Seqlen = None,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        **lm_kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the text model on precomputed `inputs_embeds`; `lm_kwargs` are ignored."""
        # Positional order matches Llama4ForCausalLM.forward, whose parameter order
        # differs from this signature (adapter_data comes before lm_head_indices there).
        logits, speculative_logits = self.text_model(
            inputs_embeds,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
            adapter_data,
            lm_head_indices,
            attention_mask,
        )
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
================================================
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import List, Optional, Tuple, Type
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import (
KVCache,
get_kv_scales,
)
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
FastLayerNorm,
)
from text_generation_server.layers import (
FastLinear,
)
from text_generation_server.utils.weights import (
Weights,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
def load_attention(config, prefix: str, weights, layer_id):
    """
    Load the fused QKV projection of an attention layer, wrapped for multi-adapter use.

    Checkpoint layouts by ``config.model_type``:
      * ``phi3``: one fused ``qkv_proj`` weight.
      * ``baichuan``: one fused ``W_pack`` weight.
      * anything else: separate ``q_proj``/``k_proj``/``v_proj`` weights concatenated
        along dim 0, with per-projection ``sizes`` passed to the adapter wrapper.

    Returns:
        The ``TensorParallelMultiAdapterLinear`` wrapping the loaded base layer.
    """
    # Only defined in granite.
    bias = getattr(config, "attention_bias", False)
    head_size = config.hidden_size // config.num_attention_heads
    model_type = config.model_type

    if model_type in ("phi3", "baichuan"):
        fused_name = "qkv_proj" if model_type == "phi3" else "W_pack"
        fused_prefix = f"{prefix}.{fused_name}"
        base_layer = TensorParallelColumnLinear.load_qkv(
            config,
            prefix=fused_prefix,
            weights=weights,
            bias=bias,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
        )
        # NOTE(review): phi3 registers the short layer name while baichuan registers
        # the full prefix; preserved as-is. Confirm which form the adapter lookup expects.
        layer_names = [fused_name] if model_type == "phi3" else [fused_prefix]
        sizes = None
    else:
        layer_names = ["q_proj", "k_proj", "v_proj"]
        q_width = head_size * config.num_attention_heads
        kv_width = head_size * config.num_key_value_heads
        sizes = [q_width, kv_width, kv_width]
        base_layer = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.{name}" for name in layer_names],
            dim=0,
            weights=weights,
            bias=bias,
        )

    return TensorParallelMultiAdapterLinear.load(
        base_layer=base_layer,
        layer_id=layer_id,
        layer_names=layer_names,
        sizes=sizes,
        process_group=weights.process_group,
    )
@contextmanager
def no_fp8(weights: Weights):
    """
    Temporarily disable fp8 auto-conversion while loading weights.

    If ``weights.weights_loader`` is a ``HybridFP8UnquantLoader`` with ``to_fp8``
    enabled, a replacement with the same ``activation_scale_ub`` and ``to_fp8=False``
    is installed through ``weights.use_loader`` for the body of the ``with`` block.
    Otherwise the current loader is installed unchanged.
    """
    current = weights.weights_loader
    needs_override = isinstance(current, HybridFP8UnquantLoader) and current.to_fp8
    effective = (
        HybridFP8UnquantLoader(current.activation_scale_ub, to_fp8=False)
        if needs_override
        else current
    )
    with weights.use_loader(effective):
        yield
class FlashLlamaAttention(torch.nn.Module):
    """
    Llama-style grouped-query self-attention for the HPU backend.

    The block uses a fused, adapter-aware QKV projection and applies rotary embeddings.
    It writes keys/values to the paged KV cache, then calls ``attention`` during prefill
    (``cu_seqlen_prefill`` set) or ``paged_attention`` during decode.
    """

    def __init__(
        self,
        index: int,
        prefix: str,
        config,
        weights,
        rotary_emb,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads
        self.rotary_emb = rotary_emb
        # `config.attention_multiplier` is used in Granite
        self.softmax_scale = getattr(
            config, "attention_multiplier", self.head_size**-0.5
        )
        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        if config.num_key_value_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        # From here on, head counts are per tensor-parallel shard.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )
        self.query_key_value = load_attention(config, prefix, weights, index)
        self.index = index
        self.kv_scales = get_kv_scales(weights, f"{prefix}")
        o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=getattr(config, "attention_bias", False),
        )
        self.o_proj = TensorParallelAdapterRowLinear.load(
            o_proj,
            index,
            "o_proj",
            process_group=weights.process_group,
        )
        # Query heads per KV head (grouped-query attention).
        self.num_groups = self.num_heads // self.num_key_value_heads
        # kv_head_mapping[h] = index of the KV head used by local query head h.
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache: KVCache,
        slots,
        seqlen,
        adapter_data,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ):
        """
        Compute attention for the given tokens and return the ``o_proj`` output.

        ``cu_seqlen_prefill`` selects the path: not ``None`` -> prefill via ``attention``;
        ``None`` -> decode via ``paged_attention``.
        """
        qkv = self.query_key_value(hidden_states, adapter_data)
        # Split the fused projection into queries and the packed (key, value) part.
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
        # Rotary embedding on queries and keys (kv[:, 0]). The return value is discarded,
        # so this presumably rotates in place — confirm in PositionRotaryEmbedding.
        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
        # Write this step's keys/values to the cache slots.
        kv_cache.store(
            key=kv[:, 0],
            value=kv[:, 1],
            slots=slots,
            kv_scales=self.kv_scales,
        )
        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query,
                key=kv[:, 0],
                value=kv[:, 1],
                kv_scales=self.kv_scales,
                kv_cache=kv_cache,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )
        return self.o_proj(
            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
        )
class Phi3MoE(nn.Module):
    """
    Mixture-of-experts feed-forward block used for ``phimoe`` models.

    A bias-free linear router scores the experts; ``moe_layer_cls`` runs the top-k
    experts with renormalized routing weights.
    """

    def __init__(
        self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights
    ):
        super().__init__()
        # gating
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
        self.moe = moe_layer_cls(
            prefix=f"{prefix}.experts",
            n_experts=config.num_local_experts,
            n_expert_group=None,
            renormalize=True,
            topk=config.num_experts_per_tok,
            topk_group=None,
            weights=weights,
            gate_proj_name="w1",
            up_proj_name="w3",
            down_proj_name="w2",
        )
        self.process_group = weights.process_group

    def forward(self, x, adapter_data) -> torch.Tensor:
        """Route and run experts; ``adapter_data`` is accepted for interface parity and unused."""
        expert_out = self.moe(x, gating_output=self.gate(x))
        # With tensor parallelism, sum the per-shard results.
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(expert_out, group=self.process_group)
        return expert_out.view(*x.shape)
class LlamaMLP(nn.Module):
    """
    Gated feed-forward block: ``down_proj(act(gate) * up)``.

    The gate and up projections are loaded as one fused column-parallel layer, and all
    projections are wrapped for multi-adapter use.
    """

    def __init__(self, prefix, config, weights, index):
        super().__init__()
        self.hidden_act = config.hidden_act
        if "gelu" in self.hidden_act:
            # Map gelu variant names onto torch's `approximate` argument.
            self.act = lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh"
                    if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
                    else "none"
                ),
            )
        else:
            self.act = ACT2FN[self.hidden_act]

        bias = getattr(config, "mlp_bias", False)
        if config.model_type == "phi3":
            # phi3 checkpoints store gate and up as one `gate_up_proj` weight.
            layer_names = None
            sizes = None
            gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                config,
                prefix=f"{prefix}.gate_up_proj",
                weights=weights,
                bias=bias,
            )
        else:
            layer_names = ["gate_proj", "up_proj"]
            sizes = [config.intermediate_size, config.intermediate_size]
            gate_up_proj = TensorParallelColumnLinear.load_multi(
                config,
                prefixes=[f"{prefix}.{name}" for name in layer_names],
                weights=weights,
                dim=0,
                bias=bias,
            )
        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
            gate_up_proj,
            index,
            layer_names=layer_names,
            sizes=sizes,
            process_group=weights.process_group,
        )

        base_down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=bias,
        )
        self.down_proj = TensorParallelAdapterRowLinear.load(
            base_down_proj,
            index,
            "down_proj",
            process_group=weights.process_group,
        )
        # Per-shard width of each of the gate/up halves.
        self.intermediate_size = config.intermediate_size // weights.process_group.size()
        # TODO: This is a hotfix to be removed & properly refactored.
        self.quantize = config.quantize
        self.hidden_size = config.hidden_size

    def forward(self, hidden_states, adapter_data):
        fused = self.gate_up_proj(hidden_states, adapter_data).view(
            -1, 2, self.intermediate_size
        )
        gate, up = fused[:, 0], fused[:, 1]
        return self.down_proj(self.act(gate) * up, adapter_data)
class FlashLlamaLayer(nn.Module):
    """
    One decoder layer: input norm -> self-attention -> post-attention norm -> MLP.

    ``phimoe`` configs use a ``Phi3MoE`` block with ``FastLayerNorm``. Every other config
    uses ``LlamaMLP`` with ``FastRMSNorm``. The attention weights are loaded inside
    ``no_fp8``, which disables fp8 auto-conversion.
    """

    def __init__(self, index, prefix, config, weights, rotary_emb):
        super().__init__()
        with no_fp8(weights):
            self.self_attn = FlashLlamaAttention(
                index=index,
                prefix=f"{prefix}.self_attn",
                config=config,
                weights=weights,
                rotary_emb=rotary_emb,
            )
        if config.model_type == "phimoe":
            moe_layer_cls = (
                SparseMoELayer
                if SparseMoELayer.is_supported(weights)
                else DenseMoELayer
            )
            self.mlp = Phi3MoE(
                f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
            )
            # with moe the layernorms are not rmsnorms and they have bias
            self.input_layernorm = FastLayerNorm.load(
                prefix=f"{prefix}.input_layernorm",
                weights=weights,
                eps=config.rms_norm_eps,
            )
            self.post_attention_layernorm = FastLayerNorm.load(
                prefix=f"{prefix}.post_attention_layernorm",
                weights=weights,
                eps=config.rms_norm_eps,
            )
        else:
            self.mlp = LlamaMLP(
                prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
            )
            self.input_layernorm = FastRMSNorm.load(
                prefix=f"{prefix}.input_layernorm",
                weights=weights,
                eps=config.rms_norm_eps,
            )
            self.post_attention_layernorm = FastRMSNorm.load(
                prefix=f"{prefix}.post_attention_layernorm",
                weights=weights,
                eps=config.rms_norm_eps,
            )
        # Used in Granite
        # This could eventually be baked into the weights like we do for the embeddings/lm_head
        # but this would mean modifying the lora code
        self.residual_multiplier = getattr(config, "residual_multiplier", None)

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        cross_attention_states,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ):
        """
        Run the layer.

        ``cross_attention_states`` is accepted but not used by this layer.

        Returns:
            ``(mlp_output, attn_res)``: the MLP output and the residual returned by the
            post-attention norm. These are presumably fed back as
            ``hidden_states``/``residual`` of the next layer — confirm in the model's
            forward.
        """
        # The norms take (input, residual) and return (normed, residual); presumably
        # they fold the residual add in — confirm in layers/layernorm.
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data,
            hpu_attention_meta=hpu_attention_meta,
        )
        if self.residual_multiplier is not None:
            attn_output *= self.residual_multiplier
        normed_attn_res_output, attn_res = self.post_attention_layernorm(
            attn_output, res
        )
        mlp_output = self.mlp(normed_attn_res_output, adapter_data)
        if self.residual_multiplier is not None:
            mlp_output *= self.residual_multiplier
        return mlp_output, attn_res
class FlashLlamaModel(torch.nn.Module):
    """Llama-family decoder stack (embeddings excluded) with a final RMS norm.

    The first and last decoder layers are loaded under ``no_fp8`` so they skip
    FP8 quantization. Middle layers whose index appears in
    ``config.cross_attention_layers`` (mllama) are ``FlashLlamaCrossLayer``
    instances; all other layers are ``FlashLlamaLayer``.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()

        self.layers = nn.ModuleList()
        self.cross_attention_layers = getattr(config, "cross_attention_layers", [])

        # Setting defaults for baichuan custom config which doesn't apply them.
        config.rope_theta = getattr(config, "rope_theta", 10000)
        config.num_key_value_heads = getattr(
            config, "num_key_value_heads", config.num_attention_heads
        )

        # A single rotary embedding instance is shared by every layer.
        rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=config.hidden_size // config.num_attention_heads,
            base=config.rope_theta,
            device=weights.device,
        )

        num_layers = config.num_hidden_layers

        # Skip fp8 quant for the first layer.
        with no_fp8(weights):
            self.layers.append(
                FlashLlamaLayer(
                    index=0,
                    prefix=f"{prefix}.layers.0",
                    config=config,
                    weights=weights,
                    rotary_emb=rotary_emb,
                )
            )

        # Middle layers. NOTE: the first and last layers are always
        # FlashLlamaLayer, even if their index is in cross_attention_layers.
        for layer_id in range(1, num_layers - 1):
            if layer_id in self.cross_attention_layers:
                # Imported lazily; presumably to avoid a circular import.
                from text_generation_server.models.custom_modeling.flash_mllama import (
                    FlashLlamaCrossLayer,
                )

                self.layers.append(
                    FlashLlamaCrossLayer(
                        index=layer_id,
                        prefix=(f"{prefix}.layers.{layer_id}"),
                        config=config,
                        weights=weights,
                    )
                )
            else:
                self.layers.append(
                    FlashLlamaLayer(
                        index=layer_id,
                        prefix=(f"{prefix}.layers.{layer_id}"),
                        config=config,
                        weights=weights,
                        rotary_emb=rotary_emb,
                    )
                )

        # Skip fp8 quant for the last layer. With a single-layer model, layer 0
        # is also the last layer and must not be loaded a second time: that
        # would create more layers than config.num_hidden_layers, and forward()
        # indexes kv_cache once per layer.
        if num_layers > 1:
            with no_fp8(weights):
                last_layer_id = num_layers - 1
                self.layers.append(
                    FlashLlamaLayer(
                        index=last_layer_id,
                        prefix=(f"{prefix}.layers.{last_layer_id}"),
                        config=config,
                        weights=weights,
                        rotary_emb=rotary_emb,
                    )
                )

        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        adapter_data,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        cross_attention_states=None,
    ) -> torch.Tensor:
        """Run all decoder layers and the final norm; returns the hidden states.

        ``kv_cache[i]`` is passed to layer ``i``. In HPU lazy mode a
        ``mark_step`` is issued before the first layer and after every layer.
        """
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )
        hidden_states = inputs_embeds

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                adapter_data,
                cross_attention_states,
                hpu_attention_meta=hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states
class FlashLlamaForCausalLM(torch.nn.Module):
    """Llama-family causal LM: token embeddings, decoder stack and LM head.

    Applies the Granite-specific ``embedding_multiplier`` (folded into the
    embedding weights) and ``logits_scaling`` (folded into the LM head weights
    when possible, otherwise applied to the logits in ``forward``).
    """

    def __init__(self, prefix: str, config, weights, name=None):
        if name is None:
            name = "model"
        super().__init__()
        # Embeddings are loaded with FP8 quantization disabled.
        with no_fp8(weights):
            self.embed_tokens = TensorParallelEmbedding(
                prefix=(
                    f"{name}.embed_tokens"
                    if not prefix
                    else f"{prefix}.{name}.embed_tokens"
                ),
                weights=weights,
            )
        self.model = FlashLlamaModel(
            prefix=name if not prefix else f"{prefix}.{name}",
            config=config,
            weights=weights,
        )
        if config.tie_word_embeddings:
            suffix = "model.embed_tokens"
        else:
            suffix = "lm_head"

        # Used in Granite
        embedding_multiplier = getattr(config, "embedding_multiplier", None)
        if embedding_multiplier is not None:
            self.embed_tokens.weight.data *= embedding_multiplier
        prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
        with no_fp8(weights):
            self.lm_head = SpeculativeHead.load(
                config,
                prefix,
                weights,
            )

        # Used in Granite
        self.logits_scaling = getattr(config, "logits_scaling", None)
        # True once the scaling has been folded into the LM head weights. It is
        # initialized unconditionally because forward() reads it whenever
        # logits_scaling is set, including when lm_head.head is None.
        self.logits_scaled = False
        if self.logits_scaling is not None and self.lm_head.head is not None:
            try:
                # Scale the weights directly
                self.lm_head.head.linear.weight.data /= self.logits_scaling
                self.logits_scaled = True
            except Exception:
                # Best effort: presumably the head does not expose a plain
                # `linear.weight` (e.g. quantized); forward() scales the logits.
                self.logits_scaled = False

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        cross_attention_states=None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Embed ``input_ids``, run the decoder and LM head.

        Returns ``(logits, speculative_logits)``. When ``lm_head_indices`` is
        given, only those rows of the hidden states are projected.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = self.model(
            inputs_embeds,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data=adapter_data,
            cross_attention_states=cross_attention_states,
            hpu_attention_meta=hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)

        # Used in Granite: scaling not folded into the weights is applied here.
        if self.logits_scaling is not None and not self.logits_scaled:
            logits /= self.logits_scaling
            if speculative_logits is not None:
                speculative_logits /= self.logits_scaling

        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py
================================================
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Llava-NeXT model."""
from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Compute the patch-grid shape produced by any-resolution preprocessing.

    Args:
        image_size (`tuple`):
            Size of the input image as (height, width).
        grid_pinpoints (`List`):
            Candidate resolutions; each item is a tuple or list `(height, width)`.
        patch_size (`int`):
            Side length of one image patch.

    Returns:
        tuple: Grid shape as (height, width), measured in patches.

    Raises:
        ValueError: if `grid_pinpoints` is not a list.
    """
    if isinstance(grid_pinpoints, list):
        best_height, best_width = select_best_resolution(image_size, grid_pinpoints)
        return best_height // patch_size, best_width // patch_size
    raise ValueError("grid_pinpoints should be a list of tuples or lists")
def unpad_image(tensor, original_size):
    """
    Strip the symmetric padding added when an image was resized into a padded canvas.

    Args:
        tensor (`torch.Tensor`):
            The image tensor, assumed to be of shape (num_channels, height, width).
        original_size (`tuple`):
            The original size of the image (height, width).

    Returns:
        `torch.Tensor`: A slice of `tensor` with the padding rows or columns removed.
    """
    orig_h, orig_w = original_size
    cur_h, cur_w = tensor.shape[1:]

    if orig_w / orig_h > cur_w / cur_h:
        # Relatively wider than the canvas: width was fitted, height was padded.
        kept_h = int(orig_h * (cur_w / orig_w))
        pad = (cur_h - kept_h) // 2
        return tensor[:, pad : cur_h - pad, :]

    # Otherwise height was fitted and any padding sits on the width axis.
    kept_w = int(orig_w * (cur_h / orig_h))
    pad = (cur_w - kept_w) // 2
    return tensor[:, :, pad : cur_w - pad]
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
    """Two-layer projection of vision features into the text embedding space.

    ``linear_1`` is column-parallel and ``linear_2`` row-parallel (both with
    bias), with ``config.projector_hidden_act`` applied in between.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.linear_1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
        )
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
        )

    def forward(self, image_features):
        return self.linear_2(self.act(self.linear_1(image_features)))
class FlashLlavaNextForConditionalGeneration(nn.Module):
    """LLaVA-NeXT: vision tower + multimodal projector + text language model.

    ``get_vision_embeds`` turns pixel values into image embeddings,
    ``get_inputs_embeds`` writes them into the image-token positions of the
    token embeddings, and ``forward`` runs the text model on the result.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        config.vision_config.quantize = config.quantize
        vision_config = config.vision_config
        # Rather than running the whole vision tower and selecting
        # hidden_states[vision_feature_layer], truncate the tower so that its
        # last layer is the requested feature layer (e.g. -2 drops the final
        # layer), and don't pool.
        # NOTE(review): for non-negative indices this keeps
        # vision_feature_layer + 1 layers; confirm this matches the
        # hidden_states indexing convention of the loaded vision model.
        if config.vision_feature_layer < 0:
            vision_config.num_hidden_layers += config.vision_feature_layer + 1
        else:
            vision_config.num_hidden_layers = config.vision_feature_layer + 1
        self.vision_tower = load_vision_model(
            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
            config=config.vision_config,
            weights=weights,
        )

        self.multi_modal_projector = LlavaNextMultiModalProjector(
            prefix="multi_modal_projector", config=config, weights=weights
        )

        # Learned separator embedding appended to image features in
        # get_vision_embeds.
        self.image_newline = weights.get_tensor("image_newline")

        self.vocab_size = config.text_config.vocab_size
        self.config = config
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator
        self.text_model = load_text_model(
            prefix="language_model" if not prefix else f"{prefix}.language_model",
            config=config.text_config,
            weights=weights,
        )

        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def _merge_input_ids_with_image_features(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
    ):
        """In place merges in vision_embeddings with inputs_embeds.

        Every position where ``input_ids == config.image_token_index`` is
        overwritten with the rows of ``image_features`` (flattened to
        ``(-1, image_features.shape[-1])``). The number of image-token
        positions is expected to match the number of feature rows.

        Raises:
            RuntimeError: if the features cannot be written into those
                positions; the underlying error is chained.
        """
        mask = torch.where(input_ids == self.config.image_token_index)
        # Let's pray we have enabled enough slots !
        try:
            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        except Exception as e:
            raise RuntimeError(
                f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
            ) from e

        return inputs_embeds

    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Encode images into one flat sequence of embeddings.

        Args:
            pixel_values: unpacked as
                ``(num_images, num_patches, channels, height, width)``.
            pixel_attention_mask: unused; kept for interface compatibility.
            image_sizes: per-image original sizes, indexed by image; needed
                for any image with more than one patch.
            image_grid_thw: unused; kept for interface compatibility.

        Returns:
            Tensor of shape ``(total_image_tokens, hidden)``: the features of
            all images concatenated in order. Images may contribute different
            token counts.
        """
        num_images, num_patches, channels, height, width = pixel_values.shape
        pixel_values = pixel_values.view(
            num_images * num_patches, channels, height, width
        )
        image_features = self.vision_tower(pixel_values)

        # The tower was truncated in __init__, so its last hidden state already
        # is the requested feature layer.
        selected_image_feature = image_features.last_hidden_state

        if self.config.vision_feature_select_strategy == "default":
            # Drop the first token (presumably the CLS token).
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.config.vision_feature_select_strategy == "full":
            pass
        else:
            raise RuntimeError(
                f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
            )

        image_features = self.multi_modal_projector(selected_image_feature)

        # Split back into one chunk of `num_patches` patch features per image;
        # the first patch of each chunk is treated as the base image below.
        split_sizes = [num_patches] * num_images
        image_features = torch.split(image_features, split_sizes, dim=0)

        # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
        height = width = (
            self.config.vision_config.image_size // self.config.vision_config.patch_size
        )

        new_image_features = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]

                if height * width != base_image_feature.shape[0]:
                    raise ValueError(
                        "The number of patches is not consistent with the image size."
                    )

                # Dimensions are intentionally swapped to be bug-compatible with
                # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                image_feature = image_feature.view(
                    num_patch_height, num_patch_width, height, width, -1
                )
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                # Append the newline embedding at the end of every row.
                image_feature = torch.cat(
                    (
                        image_feature,
                        self.image_newline[:, None, None].expand(
                            *image_feature.shape[:-1], 1
                        ),
                    ),
                    dim=-1,
                )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                image_feature = torch.cat(
                    (image_feature, self.image_newline[None]), dim=0
                )
            new_image_features.append(image_feature)

        # Unpadding depends on each image's aspect ratio, so images in the same
        # batch can yield different token counts; torch.stack would fail there.
        # Concatenating along the token axis is identical to stack + flatten
        # when all lengths match.
        return torch.cat(new_image_features, dim=0)

    def get_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        vision_embeds: torch.Tensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: Optional[torch.LongTensor] = None,
    ):
        """Embed ``input_ids`` and, if given, write ``vision_embeds`` into the
        image-token positions. ``pixel_values`` and ``image_sizes`` are unused."""
        inputs_embeds = self.text_model.embed_tokens(input_ids)

        if vision_embeds is not None:
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            inputs_embeds = self._merge_input_ids_with_image_features(
                input_ids, inputs_embeds, vision_embeds
            )
        return inputs_embeds

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ):
        """Run the text decoder on precomputed embeddings and apply the LM head.

        Returns ``(logits, speculative_logits)``. ``attention_mask`` is unused.
        """
        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=seqlen,
            hpu_attention_meta=hpu_attention_meta,
            adapter_data=adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
================================================
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
import habana_frameworks.torch as htorch
class MistralConfig(PretrainedConfig):
    """Configuration for Mistral models.

    ``num_key_value_heads=None`` falls back to ``num_attention_heads``. Token ids,
    ``tie_word_embeddings`` and any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "mistral"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        sliding_window=None,
        **kwargs,
    ):
        # Backward compatibility: configs without an explicit KV head count.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.sliding_window = sliding_window

        # Activation, normalization and rotary settings.
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta

        # Initialization and runtime flags.
        self.initializer_range = initializer_range
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
class MistralAttention(torch.nn.Module):
    """Tensor-parallel Mistral self-attention with adapter-aware projections.

    q/k/v are loaded as one fused column-parallel projection and o_proj as a
    row-parallel projection; both are wrapped for per-layer adapters. Heads are
    split evenly across the process group.
    """

    def __init__(self, prefix: str, config, weights, layer_id, rotary_emb):
        super().__init__()
        # Sliding-window size passed to the attention kernels; -1 when the
        # config has no sliding_window (presumably meaning "no window").
        self.max_past = (
            config.sliding_window if config.sliding_window is not None else -1
        )
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        # An explicit head_dim in the config overrides hidden_size // num_heads.
        if getattr(config, "head_dim", None) is not None:
            self.head_size = config.head_dim
        else:
            self.head_size = self.hidden_size // self.num_heads

        self.rotary_emb = rotary_emb

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        # From here on, head counts are per tensor-parallel shard.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        query_key_value = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        # Adapter sizes are the unsharded q/k/v output widths.
        self.query_key_value = TensorParallelMultiAdapterLinear.load(
            query_key_value,
            layer_id,
            ["q_proj", "k_proj", "v_proj"],
            sizes=[
                self.head_size * config.num_attention_heads,
                self.head_size * config.num_key_value_heads,
                self.head_size * config.num_key_value_heads,
            ],
            process_group=weights.process_group,
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelAdapterRowLinear.load(
            o_proj,
            layer_id,
            "o_proj",
            process_group=weights.process_group,
        )
        # Maps each local query head to the index of the KV head it shares,
        # e.g. 2 KV heads with 2 groups -> [0, 0, 1, 1].
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        hpu_attention_meta,
    ):
        """Compute attention for a flat batch of tokens.

        Prefill (``cu_seqlen_prefill`` is not None) uses ``attention``; decode
        uses ``paged_attention`` over ``kv_cache``. The new keys/values are
        stored into ``kv_cache`` at ``slots`` in both cases.
        """
        qkv = self.query_key_value(hidden_states, adapter_data)
        # Split the fused projection into queries and interleaved k/v.
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        # Rotary embedding is applied to query and key (kv[:, 0]); this call's
        # return value is discarded, so it presumably rotates them in place,
        # which must happen before the keys are written to the cache.
        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        kv_cache.store(
            key=kv[:, 0],
            value=kv[:, 1],
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query,
                key=kv[:, 0],
                value=kv[:, 1],
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
                window_size_left=self.max_past,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
                window_size_left=self.max_past,
            )

        return self.o_proj(
            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
        )
class MistralMLP(nn.Module):
    """Gated MLP ``down_proj(act(gate_proj(x)) * up_proj(x))`` with adapter support.

    gate_proj and up_proj are loaded as one fused column-parallel projection.
    """

    def __init__(self, prefix: str, config, weights, layer_id):
        super().__init__()
        self.hidden_act = config.hidden_act
        if "gelu" in self.hidden_act:
            # GELU variants call torch directly so the tanh approximation can
            # be selected for the "fast"/"pytorch_tanh" flavours.
            gelu_mode = (
                "tanh"
                if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
                else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(x, approximate=gelu_mode)
        else:
            self.act = ACT2FN[self.hidden_act]

        # Fuse gate and up proj
        fused_gate_up = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
            fused_gate_up,
            layer_id,
            ["gate_proj", "up_proj"],
            sizes=[config.intermediate_size, config.intermediate_size],
            process_group=weights.process_group,
        )

        row_down = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.down_proj = TensorParallelAdapterRowLinear.load(
            row_down,
            layer_id,
            "down_proj",
            process_group=weights.process_group,
        )

        # Per-shard width of each half (gate / up) of the fused projection.
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

        # TODO: This is a hotfix to be removed & properly refactored.
        self.quantize = config.quantize

    def forward(self, hidden_states, adapter_data):
        fused = self.gate_up_proj(hidden_states, adapter_data)
        gate, up = fused.view(-1, 2, self.intermediate_size).unbind(dim=1)
        return self.down_proj(self.act(gate) * up, adapter_data)
class MistralLayer(nn.Module):
    """One Mistral decoder block: RMS-normed self-attention then RMS-normed MLP.

    The norms take ``(hidden, residual)`` and return ``(normed, residual)``;
    ``forward`` returns ``(mlp_output, residual)`` for the next layer.
    """

    def __init__(self, prefix: str, config, weights, layer_id, rotary_emb):
        super().__init__()
        norm_eps = config.rms_norm_eps

        self.self_attn = MistralAttention(
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
            layer_id=layer_id,
            rotary_emb=rotary_emb,
        )
        self.mlp = MistralMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
        )

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=norm_eps
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=norm_eps,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        hpu_attention_meta,
    ):
        normed, res = self.input_layernorm(hidden_states, residual)

        attn_output = self.self_attn(
            normed,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data,
            hpu_attention_meta,
        )

        # faster post attention rms norm
        mlp_input, attn_res = self.post_attention_layernorm(attn_output, res)
        return self.mlp(mlp_input, adapter_data), attn_res
class MistralModel(torch.nn.Module):
    """Mistral decoder stack (embeddings excluded) followed by a final RMS norm."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()

        group = weights.process_group
        self.tp_rank = group.rank()
        self.tp_world_size = group.size()

        # An explicit head_dim in the config overrides hidden_size // num_heads.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // config.num_attention_heads

        # One rotary embedding shared by every layer.
        rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=head_dim,
            base=config.rope_theta,
            device=weights.device,
        )

        decoder_layers = []
        for layer_id in range(config.num_hidden_layers):
            decoder_layers.append(
                MistralLayer(
                    prefix=f"{prefix}.layers.{layer_id}",
                    config=config,
                    weights=weights,
                    layer_id=layer_id,
                    rotary_emb=rotary_emb,
                )
            )
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        self.gradient_checkpointing = False

        first_attn = self.layers[0].self_attn
        self.head_size = first_attn.head_size
        self.num_heads = first_attn.num_heads
        self.num_key_value_heads = first_attn.num_key_value_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        adapter_data: Optional[torch.Tensor] = None,
    ):
        """Run every layer (layer ``i`` gets ``kv_cache[i]``) and the final norm.

        In HPU lazy mode a ``mark_step`` is issued before the first layer and
        after every layer.
        """
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )

        # cos/sin are computed once here instead of inside every layer.
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()

        hidden_states, residual = inputs_embeds, None
        for idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[idx],
                slots,
                seqlen,
                adapter_data,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        normed, _ = self.norm(hidden_states, residual)
        return normed
class FlashMistralForCausalLM(torch.nn.Module):
    """Mistral causal LM: token embeddings, ``MistralModel`` and LM head."""

    def __init__(self, prefix: str, config, weights, name=None):
        if name is None:
            name = "model"
        super().__init__()
        self.embed_tokens = TensorParallelEmbedding(
            prefix=(
                f"{name}.embed_tokens"
                if not prefix
                else f"{prefix}.{name}.embed_tokens"
            ),
            weights=weights,
        )
        self.model = MistralModel(
            prefix=name if not prefix else f"{prefix}.{name}",
            config=config,
            weights=weights,
        )
        self.lm_head = SpeculativeHead.load(
            config,
            # TODO dirty hack for idefics2.
            prefix=(
                "lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
            ),
            weights=weights,
        )
        # Sliding-window size from the config (None when absent), plus a
        # device tensor copy. Not read within this class; presumably consumed
        # by the caller.
        self.max_past = config.sliding_window
        self.max_past_tensor = (
            torch.tensor(config.sliding_window, device=weights.device)
            if self.max_past is not None
            else None
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Embed ``input_ids``, run the decoder and return the LM head output.

        The LM head output is the ``(logits, speculative_logits)`` pair produced
        by ``SpeculativeHead``. When ``lm_head_indices`` is given, only those
        rows of the hidden states are projected.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = self.model(
            inputs_embeds,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
            adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        return self.lm_head(hidden_states)
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
================================================
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Type
import torch
import torch.distributed
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from text_generation_server.layers import (
FastLinear,
SpeculativeHead,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
set_block_mapping,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class MixtralConfig(PretrainedConfig):
    """Configuration for Mixtral (sparse mixture-of-experts) models.

    ``num_key_value_heads=None`` falls back to ``num_attention_heads``.
    ``num_local_experts`` experts exist per layer, of which
    ``num_experts_per_tok`` are used per token. Token ids,
    ``tie_word_embeddings`` and extra keyword arguments go to
    ``PretrainedConfig.__init__``.
    """

    model_type = "mixtral"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        sliding_window=None,
        num_experts_per_tok=2,
        num_local_experts=8,
        **kwargs,
    ):
        # Backward compatibility: configs without an explicit KV head count.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.sliding_window = sliding_window

        # Activation, normalization and rotary settings.
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta

        # Initialization and runtime flags.
        self.initializer_range = initializer_range
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache

        # Mixture-of-experts routing.
        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` as at least 1-D: a 0-d tensor is viewed as shape ``(1,)``, anything else is returned as-is."""
    if x.dim() == 0:
        return x.view(1)
    return x
def load_attention(config, prefix: str, weights):
    """Load the fused q/k/v projection for one attention layer.

    When every query head has its own KV head (plain multi-head attention) the
    three projections are concatenated with ``TensorParallelColumnLinear.load_multi``.
    Otherwise (grouped-query attention) loading is delegated to ``_load_gqa``.
    """
    if config.num_attention_heads == config.num_key_value_heads:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
    return _load_gqa(config, prefix, weights)
def _load_gqa(config, prefix: str, weights):
    """Load the fused q/k/v projection for grouped-query attention.

    The q/k/v weights are column-sharded across ``weights.process_group`` and
    concatenated along dim 0. For unquantized weights the result is cast to
    ``weights.dtype`` and moved to ``weights.device``. Its shape is then checked
    against the per-shard head counts.

    Raises:
        AssertionError: if the hidden size is not divisible by the head count,
            the head count is not divisible by the shard count, or the loaded
            unquantized weight does not have the expected per-shard shape.
    """
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
    )

    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.hidden_size // config.num_attention_heads
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
        # Per-shard rows: the local query heads plus one K and one V block per local KV head.
        expected_shape = [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ]
        # Report the same per-shard expectation that is actually checked
        # (previously the message used the un-sharded KV head count).
        assert (
            list(weight.weight.shape) == expected_shape
        ), f"{list(weight.weight.shape)} != {expected_shape}"

    return TensorParallelColumnLinear(get_linear(weight, bias=None))
def _load_experts(config, prefix: str, mat, weights):
if config.quantize is not None:
raise NotImplementedError("Mixtral does not support weight quantization yet.")
assert mat in ["w1", "w2", "w3"]
world_size = weights.process_group.size()
rank = weights.process_group.rank()
assert (
config.intermediate_size % world_size == 0
), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
block_size = config.intermediate_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = torch.empty(
(config.num_local_experts * block_size, config.hidden_size),
dtype=weights.dtype,
device=weights.device,
)
for i in range(config.num_local_experts):
slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
if mat == "w2":
expert_slice = slice_[:, start:stop].t().contiguous()
else:
expert_slice = slice_[start:stop]
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
dtype=weights.dtype
).to(device=weights.device)
return tensor
class MixtralAttention(torch.nn.Module):
    """Mixtral self-attention with a fused QKV projection and tensor-parallel heads.

    Query heads and KV heads are divided evenly across ``weights.process_group``.
    During prefill (``cu_seqlen_prefill`` is not None) attention is computed with
    ``attention`` over the freshly projected keys/values. During decode it is
    computed with ``paged_attention`` against ``kv_cache``.
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights,
        rotary_emb,
    ):
        super().__init__()
        # Forwarded to ``attention`` as ``window_size_left``. -1 is used when the
        # config has no sliding window (presumably meaning "unbounded").
        self.max_past = (
            config.sliding_window if config.sliding_window is not None else -1
        )
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        self.rotary_emb = rotary_emb

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()})"
            )
        # From here on the head counts are per shard.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        # Query head i reads KV head i // num_groups: [0]*g + [1]*g + ...
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Project, rotate, cache and attend, returning the ``o_proj`` output.

        ``hidden_states`` is split along dim 1 after projection, so it is assumed
        to be 2-D (tokens, hidden) — TODO confirm against callers.
        """
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        # Return value is discarded, so rotary embedding is presumably applied in place.
        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        kv_cache.store(
            key=kv[:, 0],
            value=kv[:, 1],
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query,
                key=kv[:, 0],
                value=kv[:, 1],
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
                window_size_left=self.max_past,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@torch.jit.script
def select_experts(gate_logits: torch.Tensor, top_k: int):
    """Pick the ``top_k`` experts per row and return ``(expert_ids, weights)``, both flattened.

    The softmax is taken over dim 1 in float32. The chosen probabilities are
    renormalised so that each row's ``top_k`` weights sum to 1.
    """
    probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
    top_weights, top_experts = torch.topk(probs, top_k, dim=-1)
    top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
    return top_experts.view(-1), top_weights.view(-1)
@torch.jit.script
def round_up(x: torch.Tensor, value: int):
    """Round each element of ``x`` up to a multiple of ``value`` (integer division truncates toward zero)."""
    padded = x + (value - 1)
    return torch.div(padded, value, rounding_mode="trunc") * value
class MixtralMoE(nn.Module):
    """Router plus expert block of a Mixtral layer.

    A bias-free linear gate produces per-expert router logits, which drive
    ``moe_layer_cls`` (``w1``/``w3``/``w2`` as gate/up/down projections).
    Partial outputs are summed across the tensor-parallel group when it has more
    than one rank.
    """

    def __init__(
        self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights
    ):
        super().__init__()

        # Router producing (num_tokens, n_experts) logits.
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

        self.moe = moe_layer_cls(
            n_expert_group=None,
            n_experts=config.num_local_experts,
            prefix=f"{prefix}.experts",
            renormalize=True,
            topk=config.num_experts_per_tok,
            topk_group=None,
            weights=weights,
            gate_proj_name="w1",
            up_proj_name="w3",
            down_proj_name="w2",
        )
        assert isinstance(self.moe, MoELayer)

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route ``x`` through the experts and return a tensor shaped like ``x``."""
        gating = self.gate(x)
        mixed = self.moe(x, gating_output=gating)

        # Each rank holds a shard of the experts' output; sum the shards.
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(mixed, group=self.process_group)

        return mixed.view(*x.shape)
class MixtralLayer(nn.Module):
    """One Mixtral decoder layer: RMSNorm -> self-attention -> RMSNorm -> sparse MoE.

    The norms take and return a running residual, so ``forward`` returns
    ``(moe_output, residual)`` for the next layer to fold in.
    """

    def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
        super().__init__()
        layer_prefix = f"{prefix}.layers.{layer_id}"

        self.self_attn = MixtralAttention(
            prefix=f"{layer_prefix}.self_attn",
            config=config,
            weights=weights,
            rotary_emb=rotary_emb,
        )

        # Use the sparse MoE kernel when the weights support it, else the dense fallback.
        if SparseMoELayer.is_supported(weights):
            moe_layer_cls = SparseMoELayer
        else:
            moe_layer_cls = DenseMoELayer
        self.moe = MixtralMoE(
            f"{layer_prefix}.block_sparse_moe", config, moe_layer_cls, weights
        )

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{layer_prefix}.input_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{layer_prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        normed, res = self.input_layernorm(hidden_states, residual)

        attended = self.self_attn(
            normed,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )

        # The post-attention norm also folds the attention output into the residual.
        normed_attended, new_residual = self.post_attention_layernorm(attended, res)

        return self.moe(normed_attended), new_residual
class MixtralModel(torch.nn.Module):
    """Mixtral decoder stack: token embedding, ``num_hidden_layers`` MixtralLayer blocks, final RMSNorm."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        model_prefix = f"{prefix}.model" if prefix else "model"

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{model_prefix}.embed_tokens",
            weights=weights,
        )

        # A single rotary embedding object is shared by every layer.
        shared_rotary = PositionRotaryEmbedding.static(
            config=config,
            dim=config.hidden_size // config.num_attention_heads,
            base=config.rope_theta,
            device=weights.device,
        )

        self.layers = nn.ModuleList(
            [
                MixtralLayer(model_prefix, layer_id, config, weights, shared_rotary)
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = FastRMSNorm.load(
            prefix=f"{model_prefix}.norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

        # Per-shard attention geometry, taken from the first layer.
        first_attn = self.layers[0].self_attn
        self.head_size = first_attn.head_size
        self.num_heads = first_attn.num_heads
        self.num_key_value_heads = first_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Run the whole decoder stack and return the final normalised hidden states."""
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )

        hidden = self.embed_tokens(input_ids)

        # cos/sin are computed once here instead of inside every layer.
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        # In HPU lazy mode, mark_step() is issued before the stack and after each layer.
        lazy = htorch.utils.internal.is_lazy()
        if lazy:
            htorch.core.mark_step()

        residual = None
        for idx, block in enumerate(self.layers):
            hidden, residual = block(
                hidden,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[idx],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            if lazy:
                htorch.core.mark_step()

        hidden, _ = self.norm(hidden, residual)
        return hidden
class FlashMixtralForCausalLM(torch.nn.Module):
    """Mixtral decoder with a (speculative) language-model head on top."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()

        self.model = MixtralModel(prefix, config, weights)

        head_prefix = f"{prefix}.lm_head" if prefix else "lm_head"
        self.lm_head = SpeculativeHead.load(
            config,
            prefix=head_prefix,
            weights=weights,
        )

        self.max_past = config.sliding_window
        if self.max_past is None:
            self.max_past_tensor = None
        else:
            self.max_past_tensor = torch.tensor(
                config.sliding_window, device=weights.device
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return LM-head logits, optionally restricted to the rows in ``lm_head_indices``.

        ``adapter_data`` is accepted for interface compatibility but is not used here.
        """
        hidden = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden = hidden[lm_head_indices]
        return self.lm_head(hidden)
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
================================================
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
import torch.nn.functional as F
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
FastLinear,
)
from text_generation_server.layers.attention import (
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import habana_frameworks.torch as htorch
def _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask: torch.Tensor,
num_patches: int,
target_length: int,
dtype: torch.dtype,
) -> torch.Tensor:
# Expand aspect ratio mask to target_length
batch_size, max_num_tiles = aspect_ratio_mask.shape
attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
attention_mask = attention_mask.repeat(1, 1, target_length, 1)
# Mask padding patches
pad_patches = target_length - num_patches
attention_mask[:, :, -pad_patches:] = 0
# Invert the mask (0 -> 1, 1 -> 0)
attention_mask = 1 - attention_mask
# Reshape to 2D and create 4D attention mask
# (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
attention_mask = attention_mask.reshape(
batch_size, max_num_tiles * target_length, 1
)
attention_mask = (
attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
)
attention_mask = attention_mask.unsqueeze(1)
return attention_mask
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = (
causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
return causal_mask
def _prepare_cross_attention_mask(
cross_attention_mask: torch.Tensor,
num_vision_tokens: int,
dtype: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
# reshape so it can be used by attn module
batch_size, text_total_length, *_ = cross_attention_mask.shape
cross_attention_mask = cross_attention_mask.repeat_interleave(
num_vision_tokens, dim=3
)
cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
cross_attention_mask = cross_attention_mask.unsqueeze(1)
# invert the mask
inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
cross_attention_mask = inverted_cross_attn_mask.masked_fill(
inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
)
# apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's
# last dimension contains negative infinity values, otherwise it's 1
negative_inf_value = torch.finfo(dtype).min
full_text_row_masked_out_mask = (
(cross_attention_mask != negative_inf_value)
.any(dim=-1)
.type_as(cross_attention_mask)[..., None]
)
cross_attention_mask *= full_text_row_masked_out_mask
return cross_attention_mask, full_text_row_masked_out_mask
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
class MllamaVisionMLP(nn.Module):
    """Vision-encoder feed-forward block: fc1 -> activation -> fc2.

    fc1 is column-parallel and fc2 row-parallel; both carry a bias.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
class MllamaVisionSdpaAttention(nn.Module):
    """Bidirectional multi-head self-attention for the vision encoder.

    Uses a fused, bias-free QKV projection (heads split across the process group)
    and Habana's fused SDPA kernel with an optional additive ``attention_mask``.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.head_dim = config.hidden_size // config.attention_heads
        # Per-shard head count.
        self.num_heads = config.attention_heads // weights.process_group.size()

        self.qkv_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over ``hidden_state`` (3-D: batch, seq, hidden) and return the projected output."""
        projected = self.qkv_proj(hidden_state)
        width = self.head_dim * self.num_heads
        q, k, v = projected.split([width, width, width], dim=2)

        batch_size, q_seq_len, _ = q.shape
        _, kv_seq_len, _ = k.shape

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q = q.view(batch_size, q_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        fused_attention = ModuleFusedSDPA(FusedSDPA)
        context = fused_attention(
            q,
            k,
            v,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            scale=None,
            softmax_mode="None",
            recompute_mode=None,
            valid_sequence_lengths=None,
        )

        context = context.transpose(1, 2).contiguous()
        context = context.reshape(batch_size, q_seq_len, -1)
        return self.o_proj(context)
class MllamaVisionEncoderLayer(nn.Module):
    """Pre-norm vision transformer layer, optionally with tanh gates on both residual branches.

    With ``is_gated`` the attention and MLP outputs are scaled by
    ``tanh(gate_attn)`` / ``tanh(gate_ffn)`` before being added to the residual.
    Otherwise they are added unscaled.
    """

    def __init__(self, *, prefix, config, weights, is_gated: bool):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = MllamaVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

        ln_eps = 1e-05
        self.input_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=ln_eps
        )
        self.post_attention_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=ln_eps
        )

        # Gate parameters only exist (and are only loaded) for gated layers.
        if is_gated:
            self.gate_attn = nn.Parameter(
                weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
            )
            self.gate_ffn = nn.Parameter(
                weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
            )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Attention branch.
        attn_out = self.self_attn(
            self.input_layernorm(hidden_state), attention_mask=attention_mask
        )
        if self.is_gated:
            attn_scale = self.gate_attn.tanh()
        else:
            attn_scale = 1
        hidden_state = hidden_state + attn_scale * attn_out

        # Feed-forward branch.
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_state))
        if self.is_gated:
            ffn_scale = self.gate_ffn.tanh()
        else:
            ffn_scale = 1
        return hidden_state + ffn_scale * ffn_out
class MllamaVisionEncoder(nn.Module):
    """Stack of ``num_layers`` MllamaVisionEncoderLayer blocks that records every intermediate state."""

    def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
        super().__init__()
        self.config = config
        # NOTE(review): this is a plain Python list, not nn.ModuleList, so the
        # layers are not registered as submodules (not visible to .modules(),
        # .parameters(), .to(), ...). Kept as-is; confirm whether intentional.
        layers = []
        for layer_no in range(num_layers):
            layers.append(
                MllamaVisionEncoderLayer(
                    prefix=f"{prefix}.layers.{layer_no}",
                    config=config,
                    weights=weights,
                    is_gated=is_gated,
                )
            )
        self.layers = layers

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Return ``(final_hidden_states, states)``.

        ``states[0]`` is the input and ``states[i]`` is the output of layer ``i``.
        """
        states = [hidden_states]

        # In HPU lazy mode, mark_step() is issued before the stack and after each layer.
        lazy = htorch.utils.internal.is_lazy()
        if lazy:
            htorch.core.mark_step()

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask,
            )
            states.append(hidden_states)
            if lazy:
                htorch.core.mark_step()

        return hidden_states, states
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Adds a tanh-gated, per-tile embedding looked up from ``aspect_ratio_ids``.

    The looked-up embedding is reshaped to ``(-1, max_num_tiles, 1, hidden_size)``,
    so it is broadcast over the patch axis of ``hidden_state``.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id

        self.embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.embedding", weights=weights
        )
        self.gate = nn.Parameter(
            weights.get_tensor(f"{prefix}.gate"), requires_grad=False
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        per_tile = self.embedding(aspect_ratio_ids).reshape(
            -1, self.max_num_tiles, 1, self.hidden_size
        )
        # Always gated.
        return hidden_state + per_tile * self.gate.tanh()
class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Adds gated global position embeddings plus gated per-tile position embeddings.

    The shared patch-position embedding is pre-scaled by ``1 - tanh(gate)`` at
    load time. The tile-specific embedding looked up from ``aspect_ratio_ids`` is
    scaled by ``tanh(gate)`` at run time.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        # +1 accounts for the class token prepended to the patches.
        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(
            weights.get_tensor(f"{prefix}.gate"), requires_grad=False
        )

        # position embedding
        raw_embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
        )
        self.gated_position_embedding = (1 - self.gate.tanh()) * raw_embedding
        self.tile_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.tile_embedding", weights=weights
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        shared_pos = self.gated_position_embedding.view(
            1, 1, self.num_patches, self.hidden_size
        )
        hidden_state = hidden_state + shared_pos

        tile_pos = self.tile_embedding(aspect_ratio_ids).reshape(
            hidden_state.shape[0],
            self.max_num_tiles,
            self.num_patches,
            self.hidden_size,
        )
        return hidden_state + self.gate.tanh() * tile_pos
class MllamaVisionModel(nn.Module):
    """Mllama vision tower: patch embedding, tile/position embeddings, local and global encoders.

    ``forward`` returns, per image tile, the global encoder output concatenated
    (on the last dim) with the selected intermediate states of the local encoder.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels
        # Indices into the local encoder's recorded states whose outputs are
        # concatenated onto the final output.
        self.intermediate_layers_indices = config.intermediate_layers_indices

        # Patches per tile, +1 for the class token prepended in forward().
        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5
        self.dtype = weights.dtype

        # Non-overlapping patch_size x patch_size patches (stride == kernel size).
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
            bias=False,
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )

        self.class_embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
        )

        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            prefix=f"{prefix}.gated_positional_embedding",
            config=config,
            weights=weights,
        )

        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.pre_tile_positional_embedding",
            config=config,
            weights=weights,
        )
        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.post_tile_positional_embedding",
            config=config,
            weights=weights,
        )

        ## layer norms
        self.layernorm_pre = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_pre",
            weights=weights,
            # torch default
            eps=1e-05,
        )
        self.layernorm_post = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_post",
            weights=weights,
            # torch default
            eps=1e-05,
        )

        ## encoders
        # Local (ungated) encoder, then a gated global encoder over all tiles.
        self.transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.transformer",
            config=config,
            weights=weights,
            is_gated=False,
            num_layers=config.num_hidden_layers,
        )
        self.global_transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.global_transformer",
            config=config,
            weights=weights,
            is_gated=True,
            num_layers=config.num_global_layers,
        )

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend the class embedding to every sequence along dim 1 of a 3-D ``hidden_state``."""
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(
        self,
        pixel_values: torch.Tensor,
        aspect_ratio_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode image tiles.

        Args:
            pixel_values: 6-D tensor
                ``(batch_size, num_concurrent_media, num_tiles, num_channels, height, width)``.
            aspect_ratio_ids: reshaped to ``(batch_size * num_concurrent_media, -1)``.
            attention_mask: per-tile mask reshaped to
                ``(batch_size * num_concurrent_media, -1)``, or ``None`` to skip masking.

        Returns:
            ``(batch_size, num_concurrent_media, num_tiles, num_patches, dim * (1 + k))``
            where ``num_patches`` includes the class token and ``k`` is the number
            of selected intermediate layers.
        """
        (
            batch_size,
            num_concurrent_media,
            num_tiles,
            num_channels,
            height,
            width,
        ) = pixel_values.shape

        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels, height, width
        )
        aspect_ratio_ids = aspect_ratio_ids.reshape(
            batch_size * num_concurrent_media, -1
        )

        # patch embedding
        patch_embeds = self.patch_embedding(pixel_values)
        # (N, dim, h', w') -> (N, h'*w', dim)
        hidden_state = patch_embeds.flatten(2).transpose(1, 2)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, -1, dim
        )
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )

        # apply cls token
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media * num_tiles, num_patches, dim
        )
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, num_patches, dim
        )
        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Compute the number of tokens to pad so the patch axis is a multiple of 8
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # Compute padding tuple for pad function
        padding = (
            0,
            0,
            0,
            num_padding_patches,
        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        # Pad the tensor
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        # Used later to strip the padding again; None keeps the whole axis.
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(
                batch_size * num_concurrent_media, -1
            )
            attention_mask = _prepare_aspect_ratio_attention_mask(
                aspect_ratio_mask=attention_mask,
                num_patches=self.num_patches,
                target_length=hidden_state.shape[2],
                dtype=self.dtype,
            )

        # Flatten tiles x patches into one sequence per media item.
        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
        hidden_state, all_intermediate_hidden_states = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        # all_intermediate_hidden_states[0] is the encoder input and [i] the
        # output of local layer i; keep the configured ones, stacked on a new last dim.
        intermediate_hidden_states = [
            hidden_state
            for idx, hidden_state in enumerate(all_intermediate_hidden_states)
            if idx in self.intermediate_layers_indices
        ]
        intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles * (num_patches + num_padding_patches),
            dim,
        )
        hidden_state, _ = self.global_transformer(
            hidden_state, attention_mask=attention_mask
        )
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        # Drop the padded patch positions.
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, dim
        )
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            -1,
        )
        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
        return hidden_state
class MllamaTextCrossAttention(nn.Module):
    """Multi-headed cross-attention from text tokens to vision states ('Attention Is All You Need').

    Queries come from the text hidden states, keys/values from the vision
    ``cross_attention_states``. Queries and keys are RMS-normalised
    (``q_norm``/``k_norm``) before attention. Heads are split across the
    tensor-parallel process group.
    """

    def __init__(self, *, prefix, config, weights, layer_idx):
        super().__init__()
        self.config = config
        self.num_heads = self.config.num_attention_heads
        self.num_key_value_heads = self.config.num_key_value_heads
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_size = config.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.layer_idx = layer_idx

        # From here on the head counts are per shard.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            self.num_key_value_heads // weights.process_group.size()
        )

        self.q_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=False,
        )
        self.k_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.k_proj",
            weights=weights,
            bias=False,
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.v_proj",
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        self.q_norm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.k_norm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.softmax_scale = self.head_size**-0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Cross-attend ``hidden_states`` to the vision states and return the ``o_proj`` output.

        Args:
            hidden_states: text hidden states. They are viewed as
                ``(bs, -1, num_heads, head_size)`` after projection, so their rows
                must split evenly into ``bs`` sequences.
            cross_attention_states: despite the annotation, unpacked as a 3-tuple
                ``(states, lengths, indices)``. ``bs`` is ``lengths.size(0)`` and
                ``indices`` is not used here.

        Returns:
            2-D tensor ``(-1, num_heads * head_size)`` after ``o_proj``.

        NOTE(review): ``scale=None`` is passed to the fused kernel (presumably its
        default ``1/sqrt(head_size)``); ``self.softmax_scale`` is not used here.
        """
        (
            cross_attention_states,
            cross_attention_len,
            _indices,
        ) = cross_attention_states
        bs = cross_attention_len.size(0)

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
        query_states = self.q_norm(query_states)

        key_states = self.k_proj(cross_attention_states)
        value_states = self.v_proj(cross_attention_states)
        key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size)
        value_states = value_states.view(
            bs, -1, self.num_key_value_heads, self.head_size
        )
        key_states = self.k_norm(key_states)

        # (bs, seq, heads, head_size) -> (bs, heads, seq, head_size) for SDPA.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        fsdpa_op = ModuleFusedSDPA(FusedSDPA)
        attn_output = fsdpa_op(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
            scale=None,
            softmax_mode="None",
            recompute_mode=None,
            valid_sequence_lengths=None,
        )
        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
        attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

        return attn_output
# Adapted from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
class MllamaTextMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate) * up)``.

    ``gate_proj`` and ``up_proj`` are loaded as one fused column-parallel linear.
    ``down_proj`` is row-parallel. ``intermediate_size`` is the per-shard width.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated MLP to ``x`` of any rank, keeping its leading dims.

        The fused projection output is viewed as ``(*x.shape[:-1], 2, intermediate_size)``.
        Index 0 on the pair axis is the gate half and index 1 the up half, matching
        the prefix order given to ``load_multi``. Indexing the pair axis with
        ``[..., k, :]`` is correct for 2-D inputs and for higher-rank inputs alike.
        """
        shape = x.shape
        gate_up_states = self.gate_up_proj(x)
        gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
        result = self.down_proj(
            self.act_fn(gate_up_states[..., 0, :]) * gate_up_states[..., 1, :]
        )
        return result
class FlashLlamaCrossLayer(torch.nn.Module):
    """Cross-attention decoder block with tanh-gated attention and MLP branches.

    Only the rows of ``hidden_states`` listed in ``cross_attention_states[-1]``
    are processed; all other rows are passed through unchanged.
    """

    def __init__(self, *, prefix, config, weights, index) -> None:
        super().__init__()
        self.cross_attn = MllamaTextCrossAttention(
            prefix=f"{prefix}.cross_attn",
            config=config,
            weights=weights,
            layer_idx=index,
        )
        self.input_layernorm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.cross_attn_attn_gate = self._load_gate(
            weights, f"{prefix}.cross_attn_attn_gate"
        )
        self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.post_attention_layernorm = MllamaTextRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.cross_attn_mlp_gate = self._load_gate(
            weights, f"{prefix}.cross_attn_mlp_gate"
        )
        self.layer_idx = index

    @staticmethod
    def _load_gate(weights, name):
        """Load a frozen (non-trainable) gate parameter by checkpoint name."""
        return torch.nn.Parameter(weights.get_tensor(name), requires_grad=False)

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        cross_attention_states,  # [ IB, ...]
        hpu_attention_meta,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the gated cross-attention block on the selected token rows.

        Returns ``(hidden_states, None)`` once the residual has been folded in,
        or the inputs untouched when there is no vision context.
        """
        # No vision context for this step: pass everything through.
        if cross_attention_states is None:
            return hidden_states, residual

        # Fold the pending residual into the stream (in place).
        if residual is not None:
            hidden_states += residual

        token_rows = cross_attention_states[-1]
        # `[:]` is a view sharing storage, so the scatter at the end also
        # writes into the incoming tensor.
        full_states = hidden_states[:]
        selected = hidden_states[token_rows]

        attn_out = self.cross_attn(
            hidden_states=self.input_layernorm(selected),
            cross_attention_states=cross_attention_states,
        )
        selected = selected + self.cross_attn_attn_gate.tanh() * attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(selected))
        selected = selected + self.cross_attn_mlp_gate.tanh() * mlp_out

        full_states[token_rows] = selected
        return full_states, None
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
class MllamaTextRMSNorm(nn.Module):
    """RMS normalisation over the last dimension, computed in float32.

    The normalised values are cast back to the input dtype before being
    scaled by ``weight``.
    """

    def __init__(self, weight, eps):
        super().__init__()
        self.weight = weight
        self.variance_epsilon = eps

    @classmethod
    def load(cls, *, prefix, weights, eps):
        """Build the norm from the frozen ``{prefix}.weight`` checkpoint tensor."""
        raw = weights.get_tensor(f"{prefix}.weight")
        return cls(weight=nn.Parameter(raw, requires_grad=False), eps=eps)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normed = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.variance_epsilon}"
class FlashMllamaForConditionalGeneration(nn.Module):
    """Mllama vision-language model: vision encoder, projector and text model.

    ``vision_forward`` turns pixels into cross-attention states; ``forward``
    runs the text model with those states packed alongside their lengths
    and token indices.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        vision_cfg = config.vision_config
        text_cfg = config.text_config

        # The vision tower is loaded unquantized; the text model inherits the
        # model-level quantization. Both share the speculator setting.
        vision_cfg.quantize = None
        vision_cfg.speculator = config.speculator
        text_cfg.quantize = config.quantize
        text_cfg.speculator = config.speculator
        text_cfg._attn_implementation = "sdpa"
        self.hidden_size = text_cfg.hidden_size

        # NOTE(review): `prefix` is not used; checkpoint keys are fixed.
        self.vision_model = MllamaVisionModel(
            prefix="vision_model", config=vision_cfg, weights=weights
        )
        self.multi_modal_projector = FastLinear.load(
            prefix="multi_modal_projector", config=config, weights=weights, bias=True
        )
        self.text_model = FlashLlamaForCausalLM(
            prefix="language_model", config=text_cfg, weights=weights
        )
        self.config = config
        self.dtype = weights.dtype
        self.device = weights.device

    def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
        """Encode images and project them to text hidden size.

        Returns a tensor viewed as ``(pixel_values.shape[0], -1, hidden_size)``.

        Raises:
            ValueError: if ``aspect_ratio_ids`` is missing.
        """
        if aspect_ratio_ids is None:
            raise ValueError(
                "`aspect_ratio_ids` must be provided if `pixel_values` is provided"
            )
        num_images = pixel_values.shape[0]
        vision_states = self.vision_model(
            pixel_values, aspect_ratio_ids, aspect_ratio_mask
        )
        projected = self.multi_modal_projector(vision_states)
        projected = projected.reshape(-1, vision_states.shape[-2], self.hidden_size)
        hidden_dim = projected.shape[-1]
        return projected.view(num_images, -1, hidden_dim)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor],
        adapter_data: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        indices=None,
        cross_attention_len: Optional[torch.Tensor] = None,
    ):
        """Run the text model, packing cross-attention inputs into one tuple.

        When ``cross_attention_states`` is given, the text model receives
        ``(cross_attention_states, cross_attention_len, indices)``; otherwise
        it receives ``None``.
        """
        packed = (
            None
            if cross_attention_states is None
            else (cross_attention_states, cross_attention_len, indices)
        )
        return self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=seqlen,
            hpu_attention_meta=hpu_attention_meta,
            lm_head_indices=lm_head_indices,
            adapter_data=adapter_data,
            cross_attention_states=packed,
        )
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
================================================
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class GPTNeoXConfig(TransformersGPTNeoXConfig):
    """Transformers GPT-NeoX config with ``num_key_value_heads`` aliased.

    Reading ``num_key_value_heads`` resolves to ``num_attention_heads``, so
    code expecting a KV-head count sees one KV head per attention head.
    """

    attribute_map = dict(num_key_value_heads="num_attention_heads")
def load_row(config, prefix: str, weights, bias: bool):
    """Load a row-parallel linear layer for a GPT-NeoX sub-module.

    The bias (when requested) is read only on process-group rank 0; other
    ranks get no bias, so it is counted once when rank outputs are summed.

    Returns:
        The bare linear when ``config.use_parallel_residual`` is set (the
        layer then all-reduces attention and MLP outputs together itself),
        otherwise a ``TensorParallelRowLinear`` wrapping it.
    """
    weight = weights.get_weights_row(prefix)
    use_bias = bias and weights.process_group.rank() == 0
    linear = get_linear(
        weight, weights.get_tensor(f"{prefix}.bias") if use_bias else None
    )
    if not config.use_parallel_residual:
        return TensorParallelRowLinear(linear, process_group=weights.process_group)
    return linear
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
    """Load the fused query/key/value projection of a GPT-NeoX attention layer.

    The checkpoint rows are laid out per head as ``(num_heads, 3, head_size)``;
    they are regrouped to ``(3, num_heads, head_size)`` so the projection's
    output splits into contiguous q, k and v blocks.

    Returns:
        The bare linear when ``config.use_parallel_residual`` is set,
        otherwise a ``TensorParallelColumnLinear`` wrapping it.
    """
    weight = weights.get_multi_weights_col([prefix], dim=0)
    if isinstance(weight, UnquantizedWeight):
        # Only non-quantized weights can be reordered row-wise.
        per_head = weight.weight.view(num_heads, 3, head_size, hidden_size)
        weight.weight = per_head.permute(1, 0, 2, 3).reshape(-1, hidden_size)

    # NOTE(review): the bias is always reordered, while the weight is only
    # reordered when unquantized — confirm quantized loaders already return
    # rows in q/k/v order.
    bias = weights.get_sharded(f"{prefix}.bias", dim=0)
    bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)

    linear = get_linear(weight, bias)
    return linear if config.use_parallel_residual else TensorParallelColumnLinear(linear)
class FlashNeoxAttention(torch.nn.Module):
    """GPT-NeoX multi-head self-attention with partial rotary embeddings.

    Heads are sharded across the tensor-parallel group. Only the first
    ``rotary_dim = int(rotary_pct * head_size)`` channels of each query/key
    head are rotated. Prefill uses ``attention``; decode uses
    ``paged_attention`` against the KV cache.

    Raises:
        ValueError: if ``num_attention_heads`` is not divisible by the
            process-group size.
    """

    def __init__(self, config, prefix, weights, rotary_emb):
        super().__init__()
        num_heads = config.num_attention_heads
        hidden_size = config.hidden_size

        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_size = hidden_size // num_heads

        self.rotary_dim = int(config.rotary_pct * self.head_size)

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()})"
            )
        # From here on num_heads is the per-rank head count.
        self.num_heads = self.num_heads // weights.process_group.size()

        self.rotary_emb = rotary_emb

        self.softmax_scale = self.head_size ** (-0.5)

        self.query_key_value = load_qkv(
            config,
            prefix=f"{prefix}.query_key_value",
            weights=weights,
            num_heads=self.num_heads,
            head_size=self.head_size,
            hidden_size=self.hidden_size,
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")
        self.dense = load_row(
            config, prefix=f"{prefix}.dense", weights=weights, bias=True
        )
        # One KV head per query head (no grouping).
        self.kv_head_mapping = torch.arange(
            0, self.num_heads, dtype=torch.int32, device=weights.device
        )

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Attend over a flattened batch of tokens.

        A non-None ``cu_seqlen_prefill`` selects the prefill path; otherwise
        the decode path reads keys/values from ``kv_cache``. Keys and values
        are stored into ``kv_cache`` at ``slots`` in both cases.

        Returns:
            ``dense`` applied to the output viewed as ``(-1, num_heads * head_size)``.
        """
        qkv = self.query_key_value(hidden_states)
        # (-1, 3, heads, head_size): index 0 = query, 1 = key, 2 = value.
        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)

        # Compute rotary embeddings on rotary_ndims
        query_rot = qkv[:, 0][..., : self.rotary_dim]
        query_pass = qkv[:, 0][..., self.rotary_dim :]
        key_rot = qkv[:, 1][..., : self.rotary_dim]
        key_pass = qkv[:, 1][..., self.rotary_dim :]

        # Inplace rotary
        self.rotary_emb(query_rot, key_rot, cos, sin)
        # Write the rotated channels back next to the pass-through channels.
        qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
        qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)

        kv_cache.store(
            key=qkv[:, 1],
            value=qkv[:, 2],
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=qkv[:, 0],
                key=qkv[:, 1],
                value=qkv[:, 2],
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                qkv[:, 0],
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
class FlashMLP(nn.Module):
    """GPT-NeoX feed-forward block: ``dense_4h_to_h(act(dense_h_to_4h(x)))``."""

    def __init__(self, config, prefix, weights):
        super().__init__()
        self.act = self._make_activation(config.hidden_act)

        self.dense_h_to_4h = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
        )
        self.dense_4h_to_h = load_row(
            config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
        )

    @staticmethod
    def _make_activation(name):
        """Pick the activation for ``hidden_act``.

        Names containing "gelu" use torch's gelu: the tanh approximation for
        "gelu_fast"/"gelu_pytorch_tanh", the exact form otherwise. Any other
        name is looked up in ``ACT2FN``.
        """
        if "gelu" not in name:
            return ACT2FN[name]
        approximate = "tanh" if name in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
        return lambda x: torch.nn.functional.gelu(x, approximate=approximate)

    def forward(self, hidden_states):
        expanded = self.act(self.dense_h_to_4h(hidden_states))
        return self.dense_4h_to_h(expanded)
class FlashNeoXLayer(nn.Module):
    """One GPT-NeoX transformer block.

    With ``use_parallel_residual``, attention and MLP both read the same
    input; their sum is all-reduced across ranks and added back to it.
    Otherwise attention and MLP run sequentially, with the residual threaded
    through the fused layer norms and returned to the caller.
    """

    def __init__(self, layer_id, config, weights, rotary_emb):
        super().__init__()
        eps = config.layer_norm_eps
        # NOTE(review): the checkpoint prefix is fixed to "gpt_neox.layers.*"
        # and does not follow any outer model prefix.
        prefix = f"gpt_neox.layers.{layer_id}"

        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = FastLayerNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=eps
        )
        self.post_attention_layernorm = FastLayerNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=eps,
        )
        self.attention = FlashNeoxAttention(
            config,
            prefix=f"{prefix}.attention",
            weights=weights,
            rotary_emb=rotary_emb,
        )
        self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
        self.process_group = weights.process_group

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Return ``(output, residual_for_next_layer)``.

        In the parallel-residual path the incoming ``residual`` is ignored and
        ``None`` is returned, because the skip connection is added here.
        """
        attn_args = (cos, sin, cu_seqlen_prefill, kv_cache, slots, seqlen, hpu_attention_meta)

        if not self.use_parallel_residual:
            normed, residual = self.input_layernorm(hidden_states, residual)
            attn_out = self.attention(normed, *attn_args)
            normed, residual = self.post_attention_layernorm(attn_out, residual)
            return self.mlp(normed), residual

        attn_in, _ = self.input_layernorm(hidden_states)
        attn_out = self.attention(attn_in, *attn_args)
        mlp_in, _ = self.post_attention_layernorm(hidden_states)
        combined = self.mlp(mlp_in) + attn_out
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(combined, group=self.process_group)
        return combined + hidden_states, None
class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
    """Common ``PreTrainedModel`` base for the flash GPT-NeoX classes.

    Binds the NeoX config class and the ``gpt_neox`` base prefix; gradient
    checkpointing is declared unsupported.
    """

    base_model_prefix = "gpt_neox"
    config_class = GPTNeoXConfig
    _no_split_modules = None
    supports_gradient_checkpointing = False
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
    """GPT-NeoX decoder stack: embeddings, ``num_hidden_layers`` blocks, final LayerNorm."""

    def __init__(self, prefix: str, config, weights):
        super().__init__(config)
        self.config = config

        self.embed_in = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_in", weights=weights
        )

        # A single rotary module shared by all layers, covering only the
        # rotary_pct fraction of each head.
        head_dim = config.hidden_size // config.num_attention_heads
        rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=int(config.rotary_pct * head_dim),
            base=config.rotary_emb_base,
            device=weights.device,
        )

        # NOTE(review): FlashNeoXLayer builds its own "gpt_neox.layers.{i}"
        # prefix, so `prefix` only applies to embed_in and final_layer_norm.
        self.layers = nn.ModuleList(
            FlashNeoXLayer(layer_id, config, weights, rotary_emb)
            for layer_id in range(config.num_hidden_layers)
        )
        self.final_layer_norm = FastLayerNorm.load(
            prefix=f"{prefix}.final_layer_norm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

        self.gradient_checkpointing = False

        first_attention = self.layers[0].attention
        self.head_size = first_attention.head_size
        self.num_heads = first_attention.num_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer, and apply the final LayerNorm.

        In HPU lazy mode ``htorch.core.mark_step()`` is called once before the
        layers and once after each layer.
        """
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )

        hidden_states = self.embed_in(input_ids)

        # cos/sin are computed once here instead of inside each layer.
        cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)

        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()

        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[layer_idx],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        hidden_states, _ = self.final_layer_norm(hidden_states, residual)
        return hidden_states
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
    """GPT-NeoX backbone plus the ``embed_out`` language-model head."""

    def __init__(self, prefix, config, weights):
        super().__init__(config)
        backbone_prefix = f"{prefix}.gpt_neox" if prefix else "gpt_neox"
        self.gpt_neox = FlashGPTNeoXModel(backbone_prefix, config, weights)
        # NOTE(review): the head's checkpoint key is fixed to "embed_out"
        # regardless of `prefix`.
        self.embed_out = SpeculativeHead.load(
            config, prefix="embed_out", weights=weights
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and project to logits.

        When ``lm_head_indices`` is given, only those rows are projected.
        ``adapter_data`` is accepted but not used.
        """
        hidden_states = self.gpt_neox(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        selected = (
            hidden_states if lm_head_indices is None else hidden_states[lm_head_indices]
        )
        return self.embed_out(selected)
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
================================================
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
class PaliGemmaForConditionalGeneration(nn.Module):
    """PaliGemma: vision tower + linear projector feeding a text model.

    Every checkpoint key (vision tower, its post-layernorm, projector and
    language model) is resolved relative to ``prefix``; an empty or ``None``
    prefix means the keys sit at the checkpoint root.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()

        def scoped(name):
            # Apply the optional model prefix uniformly to every checkpoint key.
            return name if not prefix else f"{prefix}.{name}"

        config.vision_config.quantize = config.quantize
        self.vision_tower = load_vision_model(
            prefix=scoped("vision_tower"),
            config=config.vision_config,
            weights=weights,
        )
        # NOTE(review): `LayerNorm.load` is not part of torch itself;
        # presumably it is attached to nn.LayerNorm elsewhere in
        # text_generation_server — confirm.
        self.post_vision_tower_layernorm = nn.LayerNorm.load(
            prefix=scoped("vision_tower.vision_model.post_layernorm"),
            weights=weights,
            eps=config.vision_config.layer_norm_eps,
        )

        self.multi_modal_projector = TensorParallelColumnLinear.load(
            config,
            prefix=scoped("multi_modal_projector.linear"),
            weights=weights,
            bias=True,
        )

        self.vocab_size = config.text_config.vocab_size
        self.config = config

        text_config = config.text_config
        text_config.speculator = config.speculator
        text_config.quantize = config.quantize
        self.text_model = load_text_model(
            prefix=scoped("language_model"),
            config=config.text_config,
            weights=weights,
        )
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )
        self.dtype = weights.dtype

    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Encode images into projected embeddings flattened to ``(-1, dim)``.

        ``pixel_attention_mask``, ``image_sizes`` and ``image_grid_thw`` are
        accepted but not used here.
        """
        pixel_values = pixel_values.to(dtype=self.dtype)
        image_outputs = self.vision_tower(pixel_values)
        last_hidden_state = self.post_vision_tower_layernorm(
            image_outputs.last_hidden_state
        )
        image_features = self.multi_modal_projector(last_hidden_state)
        image_features = image_features.view(-1, image_features.shape[-1])
        return image_features

    def get_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        vision_embeds: torch.Tensor = None,
    ):
        """Embed tokens and overwrite image-token positions with ``vision_embeds``.

        The number of rows in ``vision_embeds`` must match the number of
        ``config.image_token_index`` tokens in ``input_ids``.
        """
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if vision_embeds is not None:
            mask = input_ids == self.config.image_token_index
            inputs_embeds[mask] = vision_embeds
        return inputs_embeds

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the text model on pre-built embeddings and return ``(logits, speculative_logits)``."""
        # TODO This is odd but apparently pali gemma position ids start at 1.
        # NOTE: `+=` shifts the caller's tensor in place during prefill.
        # NOTE(review): later decode positions may rely on this shift — confirm
        # before making it out-of-place.
        if cu_seqlen_prefill is not None:
            position_ids += 1

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=seqlen,
            hpu_attention_meta=hpu_attention_meta,
            adapter_data=adapter_data,
        )

        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)

        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
================================================
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
import habana_frameworks.torch as htorch
class PhiConfig(PretrainedConfig):
    """Configuration for Phi causal language models.

    Relative to llama, the defaults use a GELU activation, LayerNorm epsilon
    (``layer_norm_eps``), a residual dropout (``resid_pdrop``) and partial
    rotary embeddings (``partial_rotary_factor`` of each head's channels).
    Token ids and ``tie_word_embeddings`` are forwarded to ``PretrainedConfig``.
    """

    def __init__(
        self,
        vocab_size=51200,
        hidden_size=2560,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="gelu_fast",  # llama uses silu
        layer_norm_eps=1e-05,  # llama uses rms norm
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        resid_pdrop=0.1,  # llama has no residual dropout
        partial_rotary_factor=0.5,  # key difference between llama and phi
        **kwargs,
    ):
        # Model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        # Phi-specific architecture choices.
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.resid_pdrop = resid_pdrop
        self.rope_theta = rope_theta
        self.partial_rotary_factor = partial_rotary_factor

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
# Same as llama, except that Phi's q/k/v projections use bias=True.
def load_attention(config, prefix, weights):
    """Load the fused q/k/v projection of a Phi attention layer.

    When query and key/value head counts differ, loading is delegated to
    ``_load_gqa``; otherwise q, k and v (with biases) are loaded and
    concatenated along dim 0 by ``TensorParallelColumnLinear.load_multi``.
    """
    if config.num_attention_heads == config.num_key_value_heads:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.{name}" for name in ("q_proj", "k_proj", "v_proj")],
            dim=0,
            weights=weights,
            bias=True,
        )
    return _load_gqa(config, prefix, weights)
def _load_gqa(config, prefix: str, weights):
    """Load the fused q/k/v projection, with biases, for grouped-query attention.

    Weights come from ``get_multi_weights_col``. Biases are sharded per
    projection and concatenated in q, k, v order, so the bias handed to
    ``get_linear`` is a tensor rather than a flag. The concatenated bias
    length is checked against ``(num_heads + 2 * num_key_value_heads) * head_size``
    for this rank.

    Raises:
        AssertionError: if head counts do not divide evenly, or the bias
            size does not match the expected per-rank layout.
    """
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0

    prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
    # Returns a loader weight object (possibly quantized); it is passed to
    # get_linear as-is rather than treated as a plain tensor.
    weight = weights.get_multi_weights_col(prefixes=prefixes, dim=0)

    head_size = config.hidden_size // config.num_attention_heads
    num_heads = config.num_attention_heads // weights.process_group.size()
    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
    expected_rows = (num_heads + 2 * num_key_value_heads) * head_size

    # this is the same as llama except for Phi uses bias=True
    bias = torch.cat(
        [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes], dim=0
    )
    assert bias.shape[0] == expected_rows, f"{list(bias.shape)} != {[expected_rows]}"

    return TensorParallelColumnLinear(get_linear(weight, bias=bias))
class FlashPhiAttention(torch.nn.Module):
    """Phi self-attention: llama-style attention with biased projections and partial rotary.

    Only the first ``rotary_dim = int(partial_rotary_factor * head_size)``
    channels of each query/key head are rotated. Heads are sharded across the
    tensor-parallel group, and query heads map onto KV heads in groups of
    ``num_heads // num_key_value_heads``.

    Raises:
        ValueError: if ``num_attention_heads`` is not divisible by the
            process-group size.
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights,
        rotary_emb,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        self.softmax_scale = self.head_size**-0.5
        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)

        self.rotary_emb = rotary_emb

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()})"
            )

        # Per-rank head counts from here on.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        # in llama the dense layer is called "o_proj" and has bias=False
        self.dense = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.dense",
            weights=weights,
            bias=True,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        # Query head i reads KV head i // num_groups.
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Attend over a flattened batch of tokens.

        A non-None ``cu_seqlen_prefill`` selects the prefill path; otherwise
        the decode path reads keys/values from ``kv_cache``. Keys and values
        are stored into ``kv_cache`` at ``slots`` in both cases.

        Returns:
            ``dense`` applied to the output viewed as ``(-1, num_heads * head_size)``.
        """
        # Compute query, key, value and split
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )

        # Reshape query and key for rotary embeddings
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        # NOTE: this is the main difference between Llama and Phi.
        # Llama rotates the whole query and key heads; Phi uses PARTIAL rotary
        # embeddings, applied only to the first `rotary_dim` channels
        # (partial_rotary_factor * head_size) of each head.
        #
        # Apply partial positional embeddings in place
        self.rotary_emb(
            query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin
        )

        # Reshape key and value and cache
        kv_cache.store(
            key=kv[:, 0],
            value=kv[:, 1],
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            attn_output = attention(
                query=query,
                key=kv[:, 0],
                value=kv[:, 1],
                kv_scales=self.kv_scales,
                kv_cache=kv_cache,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
class PhiMLP(nn.Module):
    """Phi feed-forward block: ``fc2(act(fc1(x)))`` with biases and no gating."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.act = self._make_activation(config.hidden_act)

        # Checkpoint names are fc1/fc2 (llama: up_proj/down_proj, bias=False);
        # they are stored under the llama-style attribute names.
        self.up_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.fc1",
            weights=weights,
            bias=True,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.fc2",
            weights=weights,
            bias=True,
        )

    @staticmethod
    def _make_activation(name):
        """Pick the activation for ``hidden_act``.

        Names containing "gelu" use torch's gelu: the tanh approximation for
        "gelu_fast"/"gelu_pytorch_tanh", the exact form otherwise. Any other
        name is looked up in ``ACT2FN``.
        """
        if "gelu" not in name:
            return ACT2FN[name]
        approximate = "tanh" if name in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
        return lambda x: torch.nn.functional.gelu(x, approximate=approximate)

    def forward(self, hidden_states):
        # Unlike llama there is no gate/up split, so no intermediate `view`.
        activated = self.act(self.up_proj(hidden_states))
        return self.down_proj(activated)
class FlashPhiLayer(nn.Module):
    """Phi decoder block: one LayerNorm feeding attention and MLP in parallel.

    The two branch outputs (each passed through ``resid_dropout``) are summed.
    The skip connection is not added here; the residual returned by the
    fused LayerNorm is handed to the next layer.
    """

    def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):
        super().__init__()
        layer_prefix = f"{prefix}.layers.{layer_id}"
        self.self_attn = FlashPhiAttention(
            prefix=f"{layer_prefix}.self_attn",
            config=config,
            weights=weights,
            rotary_emb=rotary_emb,
        )
        self.mlp = PhiMLP(prefix=f"{layer_prefix}.mlp", config=config, weights=weights)
        self.input_layernorm = FastLayerNorm.load(
            prefix=f"{layer_prefix}.input_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )
        self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Return ``(attn_branch + mlp_branch, residual_for_next_layer)``."""
        normed, res = self.input_layernorm(hidden_states, residual)

        attn_output = self.self_attn(
            normed,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        attn_branch = self.resid_dropout(attn_output)
        mlp_branch = self.resid_dropout(self.mlp(normed))
        return attn_branch.add(mlp_branch), res
class FlashPhiModel(torch.nn.Module):
    """Phi decoder stack: embeddings, ``num_hidden_layers`` blocks, final LayerNorm.

    All checkpoint keys, including the final LayerNorm, are resolved under
    ``prefix``.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        # One rotary module shared by every layer, covering only the
        # partial_rotary_factor fraction of each head.
        rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=int(
                config.partial_rotary_factor
                * (config.hidden_size // config.num_attention_heads)
            ),
            base=config.rope_theta,
            device=weights.device,
        )
        self.layers = nn.ModuleList(
            [
                FlashPhiLayer(
                    prefix,
                    layer_id,
                    config,
                    weights,
                    rotary_emb,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

        # Resolve under the same prefix as embed_tokens and the layers. With
        # the default prefix "model" (from FlashPhiForCausalLM) this is
        # "model.final_layernorm".
        self.norm = FastLayerNorm.load(
            prefix=f"{prefix}.final_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer, and apply the final LayerNorm.

        In HPU lazy mode ``htorch.core.mark_step()`` is called once before the
        layers and once after each layer.
        """
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states
class FlashPhiForCausalLM(torch.nn.Module):
    """Phi backbone plus the ``lm_head`` projection."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        backbone_prefix = f"{prefix}.model" if prefix else "model"
        self.model = FlashPhiModel(backbone_prefix, config, weights)
        # NOTE(review): the head's checkpoint key is fixed to "lm_head"
        # regardless of `prefix`.
        self.lm_head = SpeculativeHead.load(
            config,
            prefix="lm_head",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and apply ``lm_head``.

        When ``lm_head_indices`` is given, only those rows are projected.
        ``adapter_data`` is accepted but not used.
        """
        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        selected = (
            hidden_states if lm_head_indices is None else hidden_states[lm_head_indices]
        )
        return self.lm_head(selected)
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py
================================================
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Phi-MoE model."""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger from the transformers logging utilities; not referenced elsewhere in this module.
logger = logging.get_logger(__name__)
# Checkpoint name -> URL of its hosted `config.json` for the reference Phi-3.5-MoE model.
# Not referenced elsewhere in this module.
PHIMOE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "microsoft/Phi-3.5-MoE-instruct": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/resolve/main/config.json",
}
class PhiMoEConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a Phi-MoE model. It is used to instantiate a Phi-MoE
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the
    [microsoft/Phi-3.5-MoE-instruct](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32064):
            Vocabulary size of the PhiMoE model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed to the model.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 6400):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If explicitly set to `None`, it falls back to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
            The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention
            allows sequence of up to 4096*32 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`dict`, *optional*):
            The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
            contain exactly the following six keys: `type`, `short_factor`, `long_factor`, `short_mscale`,
            `long_mscale` and `original_max_position_embeddings`. The `type` must be `longrope`, the `short_mscale`
            and `long_mscale` must be numbers, the `short_factor` and `long_factor` must be lists of numbers with
            length `hidden_size // num_attention_heads // 2` (half the attention head size), and
            `original_max_position_embeddings` must be an integer.
        sliding_window (`int`, *optional*, defaults to `None`):
            Sliding window attention window size. Stored as given; `None` means no window is configured.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_experts_per_tok (`int`, *optional*, defaults to 2):
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        num_local_experts (`int`, *optional*, defaults to 16):
            Number of experts per Sparse MLP layer.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        router_jitter_noise (`float`, *optional*, defaults to 0.01):
            Amount of noise to add to the router.
        input_jitter_noise (`float`, *optional*, defaults to 0.0):
            Input jitter noise amount; stored on the config (this class does not use it itself).
        attention_bias (`bool`, *optional*, defaults to `False`):
            Stored on the config; presumably whether the attention projections carry a bias term.
        lm_head_bias (`bool`, *optional*, defaults to `False`):
            Stored on the config; presumably whether the LM head carries a bias term.

    ```python
    >>> # Initializing a Phi-3.5-MoE style configuration
    >>> configuration = PhiMoEConfig.from_pretrained("microsoft/Phi-3.5-MoE-instruct")
    >>> configuration.num_local_experts
    ```"""

    model_type = "phimoe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32064,
        hidden_size=4096,
        intermediate_size=6400,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=1e6,
        rope_scaling=None,
        sliding_window=None,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_local_experts=16,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        router_jitter_noise=0.01,
        input_jitter_noise=0.0,
        attention_bias=False,
        lm_head_bias=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.attention_bias = attention_bias
        self.lm_head_bias = lm_head_bias

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout

        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.router_jitter_noise = router_jitter_noise
        self.input_jitter_noise = input_jitter_noise

        self.rope_scaling = rope_scaling
        # Validated before PretrainedConfig.__init__ so bad configs fail early.
        self._rope_scaling_validation()

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        `None` means "no RoPE scaling" and is accepted. Otherwise `rope_scaling` must be a dict with exactly
        six keys describing a `longrope` scaling (see the class docstring).

        Raises:
            ValueError: if `rope_scaling` is not a six-key dict, if `type` is not `longrope`, if either factor
                list is not a list of numbers of length `hidden_size // num_attention_heads // 2`, if either
                mscale is not a number, or if `original_max_position_embeddings` is not an integer.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 6:
            raise ValueError(
                "`rope_scaling` must be a dictionary with six fields, `type`, `short_factor`, `long_factor`, "
                f"`short_mscale`, `long_mscale` and `original_max_position_embeddings`, got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
        rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
        rope_scaling_short_mscale = self.rope_scaling.get("short_mscale", None)
        rope_scaling_long_mscale = self.rope_scaling.get("long_mscale", None)
        original_max_position_embeddings = self.rope_scaling.get(
            "original_max_position_embeddings", None
        )
        # One factor per rotary frequency, i.e. half of the attention head size.
        expected_factor_len = self.hidden_size // self.num_attention_heads // 2

        if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}"
            )

        if not (
            isinstance(rope_scaling_short_factor, list)
            and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
        ):
            raise ValueError(
                f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
            )
        if len(rope_scaling_short_factor) != expected_factor_len:
            raise ValueError(
                f"`rope_scaling`'s short_factor field must have length {expected_factor_len}, got {len(rope_scaling_short_factor)}"
            )

        if not (
            isinstance(rope_scaling_long_factor, list)
            and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
        ):
            raise ValueError(
                f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
            )
        if len(rope_scaling_long_factor) != expected_factor_len:
            raise ValueError(
                f"`rope_scaling`'s long_factor field must have length {expected_factor_len}, got {len(rope_scaling_long_factor)}"
            )

        if not isinstance(rope_scaling_short_mscale, (int, float)):
            raise ValueError(
                f"`rope_scaling`'s short_mscale field must be a number, got {rope_scaling_short_mscale}"
            )

        if not isinstance(rope_scaling_long_mscale, (int, float)):
            raise ValueError(
                f"`rope_scaling`'s long_mscale field must be a number, got {rope_scaling_long_mscale}"
            )

        if not isinstance(original_max_position_embeddings, int):
            raise ValueError(
                f"`rope_scaling`'s original_max_position_embeddings field must be an integer, got {original_max_position_embeddings}"
            )
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
================================================
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
import habana_frameworks.torch as htorch
def load_attention(config, prefix, weights):
    """Load the fused, biased q/k/v projection of a Qwen2 attention layer.

    When the model uses grouped-query attention (fewer key/value heads than
    query heads), loading is delegated to ``_load_gqa``, which validates the
    head/shard layout first. Otherwise the three projections are loaded
    directly as one column-parallel linear (concatenated along dim 0).
    """
    if config.num_attention_heads == config.num_key_value_heads:
        projection_names = [
            f"{prefix}.{name}" for name in ("q_proj", "k_proj", "v_proj")
        ]
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=projection_names,
            dim=0,
            weights=weights,
            bias=True,
        )
    return _load_gqa(config, prefix, weights)
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
class Qwen2Attention(torch.nn.Module):
    """Qwen2 self-attention built on a fused q/k/v projection.

    Query heads and key/value heads are divided evenly across the
    tensor-parallel process group. ``max_past`` is the attention window passed
    to the kernels: ``config.sliding_window`` when ``config.use_sliding_window``
    is truthy and a window size is set, ``-1`` otherwise. Prefill calls
    ``attention``; decode calls ``paged_attention`` against the KV cache.
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights,
        rotary_emb,
    ):
        super().__init__()
        window_enabled = (
            config.use_sliding_window and config.sliding_window is not None
        )
        self.max_past = config.sliding_window if window_enabled else -1

        # Head geometry is derived from the full (unsharded) head count.
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads
        self.rotary_emb = rotary_emb
        self.softmax_scale = self.head_size**-0.5

        shard_count = weights.process_group.size()
        if self.num_heads % shard_count != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {shard_count}"
            )
        # From here on, head counts are per shard.
        self.num_heads //= shard_count
        self.num_key_value_heads = config.num_key_value_heads // shard_count

        self.query_key_value = load_attention(config, prefix, weights)
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        # Local query head i reads key/value head i // num_groups.
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Attend over the current tokens (prefill) or the KV cache (decode).

        ``cu_seqlen_prefill`` being non-``None`` selects the prefill path. New
        keys/values are written to ``kv_cache`` in both cases.
        """
        query_width = self.head_size * self.num_heads
        kv_width = 2 * self.head_size * self.num_key_value_heads
        query, kv = self.query_key_value(hidden_states).split(
            [query_width, kv_width], dim=1
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        # Rotary embedding is applied in place to the query and to the key slice of kv.
        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
        key, value = kv[:, 0], kv[:, 1]

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        if cu_seqlen_prefill is None:
            # Decode
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
                window_size_left=self.max_past,
            )
        else:
            # Prefill (sdpa)
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
                window_size_left=self.max_past,
            )

        flat_width = self.num_heads * self.head_size
        return self.o_proj(attn_output.view(-1, flat_width))
class Qwen2MLP(nn.Module):
    """Gated feed-forward block computing ``down_proj(act(gate) * up)``.

    ``gate_proj`` and ``up_proj`` are loaded as a single fused column-parallel
    projection; its output is viewed as two halves of ``intermediate_size``
    (the per-shard intermediate width).
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        if "gelu" in act:
            # Any GELU flavour goes through torch's gelu; the "fast" and
            # "pytorch_tanh" variants use the tanh approximation.
            approximate = (
                "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(
                x, approximate=approximate
            )
        else:
            self.act = ACT2FN[act]

        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states):
        """Apply the gated MLP to ``hidden_states``."""
        fused = self.gate_up_proj(hidden_states).view(-1, 2, self.intermediate_size)
        gate, up = fused[:, 0], fused[:, 1]
        return self.down_proj(self.act(gate) * up)
class Qwen2Layer(nn.Module):
    """One Qwen2 decoder layer: normed self-attention, then a normed MLP, each with a skip add."""

    def __init__(self, prefix, layer_id, config, weights, rotary_emb):
        super().__init__()
        layer_prefix = f"{prefix}.layers.{layer_id}"
        self.self_attn = Qwen2Attention(
            prefix=f"{layer_prefix}.self_attn",
            config=config,
            weights=weights,
            rotary_emb=rotary_emb,
        )
        self.mlp = Qwen2MLP(
            prefix=f"{layer_prefix}.mlp", config=config, weights=weights
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{layer_prefix}.input_layernorm",
            weights=weights,
            eps=norm_eps,
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{layer_prefix}.post_attention_layernorm",
            weights=weights,
            eps=norm_eps,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Return the layer output.

        The incoming ``residual`` argument is not read: each norm call returns
        a second value that is used as the skip connection for its sub-block.
        """
        attn_in, skip = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            attn_in,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )

        # faster post attention rms norm
        mlp_in, skip = self.post_attention_layernorm(attn_out + skip)
        return self.mlp(mlp_in) + skip
class Qwen2Model(torch.nn.Module):
    """Stack of ``Qwen2Layer`` decoder layers followed by a final RMSNorm.

    Weights are read under ``model`` (empty ``prefix``) or ``<prefix>.model``.
    ``forward`` takes already-embedded inputs; the token embedding lives in
    the causal-LM wrapper.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        model_prefix = f"{prefix}.model" if prefix else "model"

        group = weights.process_group
        self.tp_rank = group.rank()
        self.tp_world_size = group.size()

        # One rotary embedding instance is shared by every layer.
        shared_rotary = PositionRotaryEmbedding.static(
            config=config,
            dim=config.hidden_size // config.num_attention_heads,
            base=config.rope_theta,
            device=weights.device,
        )
        self.layers = nn.ModuleList(
            Qwen2Layer(model_prefix, layer_id, config, weights, shared_rotary)
            for layer_id in range(config.num_hidden_layers)
        )
        self.norm = FastRMSNorm.load(
            prefix=f"{model_prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        self.gradient_checkpointing = False

        first_attn = self.layers[0].self_attn
        self.head_size = first_attn.head_size
        self.num_heads = first_attn.num_heads
        self.num_key_value_heads = first_attn.num_key_value_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Run every decoder layer (layer ``i`` uses ``kv_cache[i]``) and the final norm."""
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )

        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        # In HPU lazy mode, a graph step is marked before the loop and after each layer.
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()

        hidden_states = inputs_embeds
        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states,
                None,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[layer_idx],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        normed, _ = self.norm(hidden_states)
        return normed
class Qwen2ForCausalLM(torch.nn.Module):
    """Qwen2 causal language model: token embedding, ``Qwen2Model`` decoder and LM head.

    With tied embeddings the LM head is read from ``model.embed_tokens``
    (prefixed by ``prefix`` when given); otherwise from ``lm_head``.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__()

        self.model = Qwen2Model(prefix, config, weights)

        head_suffix = "model.embed_tokens" if config.tie_word_embeddings else "lm_head"
        self.lm_head = SpeculativeHead.load(
            config,
            prefix=f"{prefix}.{head_suffix}" if prefix else head_suffix,
            weights=weights,
        )

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
            weights=weights,
        )

        self.max_past = config.sliding_window
        if self.max_past is None:
            self.max_past_tensor = None
        else:
            self.max_past_tensor = torch.tensor(
                config.sliding_window, device=weights.device
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run the decoder and return LM-head logits.

        ``lm_head_indices`` optionally selects the rows passed to the LM head;
        ``adapter_data`` is accepted but not used.
        """
        embedded = self.embed_tokens(input_ids)
        hidden_states = self.model(
            embedded,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        return self.lm_head(hidden_states)
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
================================================
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, List
import torch
from torch import nn
import habana_frameworks.torch as htorch
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers import (
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
SpeculativeHead,
)
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from .flash_qwen2_modeling import Qwen2MLP
from text_generation_server.layers.rotary import PositionRotaryEmbedding
class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, as used by Qwen3.

    Loads a fused, bias-free q/k/v projection. The head size comes from
    ``config.head_dim`` when present (else ``hidden_size // num_attention_heads``).
    Queries and keys go through ``q_norm`` / ``k_norm`` on their
    ``(..., heads, head_dim)`` view before rotary embedding.

    Sliding-window attention is enabled for this layer only when
    ``config.use_sliding_window`` is truthy, ``config.sliding_window`` is not
    ``None`` and ``layer_idx >= config.max_window_layers``. ``max_past`` (the
    window handed to the attention kernels) follows that decision and is ``-1``
    (no window) otherwise.
    """

    def __init__(self, config, prefix, weights, layer_idx, rotary_emb):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.num_heads = config.num_attention_heads
        self.attention_dropout = config.attention_dropout
        self.softmax_scale = self.head_dim**-0.5
        self.rotary_emb = rotary_emb

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        # Per-shard head counts from here on.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )

        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        # Local query head i reads key/value head i // num_groups.
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

        self.q_norm = FastRMSNorm.load(
            prefix=f"{prefix}.q_norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.k_norm = FastRMSNorm.load(
            prefix=f"{prefix}.k_norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

        self.sliding_window = config.sliding_window
        if not (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            self.sliding_window = None

        # The kernel window must follow the per-layer decision above: a config
        # that carries a `sliding_window` value with `use_sliding_window=False`
        # (or a layer below `max_window_layers`) uses full attention.
        self.max_past = (
            self.sliding_window if self.sliding_window is not None else -1
        )

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` and return the ``o_proj`` output.

        New keys/values are stored in ``kv_cache``. A non-``None``
        ``cu_seqlen_prefill`` selects the prefill path (``attention``); otherwise
        the decode path (``paged_attention``) is used.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        qkv = self.query_key_value(hidden_states)
        query_states, key_states, value_states = qkv.split(
            [
                self.head_dim * self.num_heads,
                self.head_dim * self.num_key_value_heads,
                self.head_dim * self.num_key_value_heads,
            ],
            dim=1,
        )

        query_states, _ = self.q_norm(query_states.view(hidden_shape))
        key_states, _ = self.k_norm(key_states.view(hidden_shape))
        value_states = value_states.view(hidden_shape)
        # Rotary embedding is applied in place to queries and keys.
        self.rotary_emb(query_states, key_states, cos, sin)

        kv_cache.store(
            key=key_states,
            value=value_states,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query_states,
                key=key_states,
                value=value_states,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
                window_size_left=self.max_past,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query_states,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
                window_size_left=self.max_past,
            )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        return self.o_proj(attn_output)
class Qwen3DecoderLayer(nn.Module):
    """One Qwen3 decoder layer: pre-norm self-attention and pre-norm MLP with skip adds."""

    def __init__(self, config, prefix, weights, layer_idx: int, rotary_emb):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3Attention(
            config=config,
            prefix=f"{prefix}.self_attn",
            weights=weights,
            layer_idx=layer_idx,
            rotary_emb=rotary_emb,
        )
        self.mlp = Qwen2MLP(config=config, prefix=f"{prefix}.mlp", weights=weights)

        norm_eps = config.rms_norm_eps
        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=norm_eps
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=norm_eps,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ) -> torch.Tensor:
        """Return the layer output.

        The ``residual`` argument is not read. The skip connections are taken
        from this layer's own input and from the post-attention sum.
        """
        attn_in, _ = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            attn_in,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        after_attn = hidden_states + attn_out

        mlp_in, _ = self.post_attention_layernorm(after_attn)
        return after_attn + self.mlp(mlp_in)
class Qwen3Model(nn.Module):
    """Stack of ``Qwen3DecoderLayer`` layers under ``prefix`` plus a final RMSNorm.

    ``forward`` receives already-embedded inputs and returns normed hidden states.
    """

    def __init__(self, config, prefix: str, weights):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        # One rotary embedding instance is shared by every layer.
        shared_rotary = PositionRotaryEmbedding.static(
            config=config,
            dim=head_dim,
            base=config.rope_theta,
            device=weights.device,
        )

        self.layers = nn.ModuleList(
            Qwen3DecoderLayer(
                config=config,
                prefix=f"{prefix}.layers.{idx}",
                weights=weights,
                layer_idx=idx,
                rotary_emb=shared_rotary,
            )
            for idx in range(config.num_hidden_layers)
        )
        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Run each decoder layer (layer ``i`` uses ``kv_cache[i]``), then the final norm."""
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )

        # Position embeddings are computed once and shared across layers.
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        # In HPU lazy mode, a graph step is marked before the loop and after each layer.
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()

        hidden_states = inputs_embeds
        for idx, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(
                hidden_states,
                None,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[idx],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        normed, _ = self.norm(hidden_states)
        return normed
class Qwen3ForCausalLM(nn.Module):
    """Qwen3 causal language model: token embedding, ``Qwen3Model`` decoder and LM head.

    Weight names are resolved relative to ``prefix`` in the same way as the
    Qwen2 causal-LM wrapper:
    - the decoder is read under ``model``, or ``<prefix>.model`` when ``prefix`` is set;
    - the LM head is read from ``lm_head``, or ``model.embed_tokens`` when embeddings are tied;
    - the embedding is read from ``model.embed_tokens``, or ``<prefix>.embed_tokens``.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        # Honor `prefix` for the decoder too. An empty/None prefix still maps to
        # "model", so existing callers are unaffected.
        model_prefix = f"{prefix}.model" if prefix else "model"
        self.model = Qwen3Model(config=config, prefix=model_prefix, weights=weights)
        self.vocab_size = config.vocab_size

        if config.tie_word_embeddings:
            suffix = "model.embed_tokens"
        else:
            suffix = "lm_head"

        self.lm_head = SpeculativeHead.load(
            config,
            prefix=f"{prefix}.{suffix}" if prefix else suffix,
            weights=weights,
        )

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run the decoder and return LM-head logits.

        ``lm_head_indices`` optionally selects the rows passed to the LM head;
        ``adapter_data`` is accepted but not used.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        # The decoder returns the final-normed hidden states only.
        hidden_states = self.model(
            inputs_embeds,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )

        # Only project the requested rows to logits.
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.lm_head(hidden_states)
        return logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py
================================================
# coding=utf-8
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Type
import torch
from torch import nn
import torch.nn.functional as F
from text_generation_server.layers.attention import (
attention,
paged_attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers import (
TensorParallelEmbedding,
TensorParallelColumnLinear,
TensorParallelRowLinear,
SpeculativeHead,
FastLinear,
)
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from .flash_qwen2_modeling import Qwen2MLP
from .flash_qwen3_modeling import Qwen3Attention
from transformers.activations import ACT2FN
from text_generation_server.layers.rotary import PositionRotaryEmbedding
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split at ``x.shape[-1] // 2`` into ``(front, back)``
    and the result is ``cat(-back, front)`` along that dimension.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into ``cos`` and ``sin`` so they broadcast against
            ``q`` and ``k``. For example, with ``cos``/``sin`` shaped
            ``[batch_size, seq_len, head_dim]``, use 1 for ``q``/``k`` shaped
            ``[batch_size, heads, seq_len, head_dim]`` and 2 for
            ``[batch_size, seq_len, heads, head_dim]``.
    Returns:
        `tuple(torch.Tensor)` of the rotated query and key tensors,
        each computed as ``t * cos + rotate_half(t) * sin``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
class Qwen3MoeAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Variant whose q/k/v/o projections are loaded with ``FastLinear`` instead of
    the tensor-parallel linear layers; ``Qwen3MoeDecoderLayer`` selects it when
    ``num_key_value_heads // world_size`` is 0. Queries and keys are
    RMS-normalised per head (``q_norm`` / ``k_norm``) before the rotary
    embedding is applied.
    """

    def __init__(self, config, prefix, weights, layer_idx, rotary_emb):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = FastLinear.load(
            config, f"{prefix}.q_proj", weights, bias=config.attention_bias
        )
        self.k_proj = FastLinear.load(
            config, f"{prefix}.k_proj", weights, bias=config.attention_bias
        )
        self.v_proj = FastLinear.load(
            config, f"{prefix}.v_proj", weights, bias=config.attention_bias
        )
        self.o_proj = FastLinear.load(
            config, f"{prefix}.o_proj", weights, bias=config.attention_bias
        )
        self.rotary_emb = rotary_emb

        self.q_norm = FastRMSNorm.load(
            prefix=f"{prefix}.q_norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.k_norm = FastRMSNorm.load(
            prefix=f"{prefix}.k_norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

        # Effective sliding window: only active when `use_sliding_window` is
        # set, a window size is configured, and this layer index is at or
        # beyond `max_window_layers`.
        self.sliding_window = getattr(config, "sliding_window", None)
        if not (
            getattr(config, "use_sliding_window", False)
            and self.sliding_window is not None
            and layer_idx >= config.max_window_layers
        ):
            self.sliding_window = None
        # Pass the *resolved* window to the attention kernels (-1 is used as
        # "no window"). Deriving this from `config.sliding_window` directly
        # would window attention even when `use_sliding_window` is False.
        self.max_past = (
            self.sliding_window if self.sliding_window is not None else -1
        )

        self.kv_scales = get_kv_scales(weights, f"{prefix}")
        # Entry i is i // num_key_value_groups: query head i reads that KV head.
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_key_value_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the output projection.

        Keys/values of this step are written into ``kv_cache`` at ``slots``.
        When ``cu_seqlen_prefill`` is given the prefill ``attention`` kernel is
        used, otherwise ``paged_attention`` reads from the cache (decode).
        The result has the leading shape of ``hidden_states``.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, split into heads, and RMS-normalise q/k per head.
        query_states, _ = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))
        key_states, _ = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))
        value_states = self.v_proj(hidden_states).view(hidden_shape)

        # The return value is ignored: the rotary embedding is presumably
        # applied in place to query_states / key_states.
        self.rotary_emb(query_states, key_states, cos, sin)

        kv_cache.store(
            key=key_states,
            value=value_states,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query_states,
                key=key_states,
                value=value_states,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.scaling,
                window_size_left=self.max_past,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query_states,
                kv_cache,
                self.kv_head_mapping,
                self.scaling,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
                window_size_left=self.max_past,
            )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output
class Qwen3MoE(nn.Module):
    """Sparse mixture-of-experts feed-forward block backed by a fused ``MoELayer``.

    A linear router (``gate``) produces one logit per expert; ``moe`` selects
    the top ``num_experts_per_tok`` experts per token and combines them.
    """

    def __init__(self, prefix, config, moe_layer_cls: Type[MoELayer], weights):
        super().__init__()

        # Router: hidden state -> one logit per expert.
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)

        self.moe = moe_layer_cls(
            n_expert_group=None,
            n_experts=config.num_experts,
            prefix=f"{prefix}.experts",
            # Honour `norm_topk_prob` (as Qwen3MoeSparseMoeBlock does) instead
            # of always renormalising; keep the previous behaviour (True) when
            # the config does not define the field.
            renormalize=getattr(config, "norm_topk_prob", True),
            topk=config.num_experts_per_tok,
            topk_group=None,
            weights=weights,
        )
        assert isinstance(self.moe, MoELayer)

        self.process_group = weights.process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route ``x`` through its top-k experts; returns a tensor shaped like ``x``."""
        router_logits = self.gate(x)
        out = self.moe(x, gating_output=router_logits)

        # Sum results across tensor-parallel ranks (each rank presumably holds
        # a partial result from its shard of the experts).
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)

        return out.view(*x.shape)
class Qwen3MoeMLP(nn.Module):
    """Gated MLP ``down(act(gate(x)) * up(x))`` sharded over the intermediate dim.

    ``gate_proj`` and ``up_proj`` are fused into one column-parallel layer
    (gate columns first, then up columns); ``down_proj`` is row-parallel.

    Args:
        prefix: weight prefix of this MLP (``{prefix}.gate_proj`` etc.).
        config: model config (``hidden_size``, ``intermediate_size``,
            ``hidden_act``).
        weights: weight loader; its ``process_group`` size is the TP degree.
        intermediate_size: optional override of ``config.intermediate_size``
            (e.g. ``moe_intermediate_size`` for an expert).
    """

    def __init__(self, prefix, config, weights, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        full_intermediate_size = (
            intermediate_size
            if intermediate_size is not None
            else config.intermediate_size
        )

        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        # Per-rank slice of the *requested* intermediate size. Using
        # config.intermediate_size here broke experts built with
        # moe_intermediate_size (the view in forward no longer matched).
        self.intermediate_size = (
            full_intermediate_size // weights.process_group.size()
        )

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated MLP to ``x`` (assumed 2-D: tokens x hidden)."""
        gate_up_states = self.gate_up_proj(x)
        # Split the fused output into its gate (index 0) and up (index 1) parts.
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(
            self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
        )
class Qwen3MoeSparseMoeBlock(nn.Module):
    """Unfused sparse MoE block: one ``Qwen3MoeMLP`` per expert, top-k routing.

    Each token's router softmax is truncated to its ``num_experts_per_tok``
    best experts (optionally renormalised when ``norm_topk_prob`` is set), and
    the selected experts' outputs are summed with those weights.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        # Router: hidden state -> one logit per expert.
        self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
        self.experts = nn.ModuleList(
            [
                Qwen3MoeMLP(
                    prefix=f"{prefix}.experts.{i}",
                    config=config,
                    weights=weights,
                    intermediate_size=config.moe_intermediate_size,
                )
                for i in range(self.num_experts)
            ]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Mix the outputs of each token's top-k experts.

        Args:
            hidden_states: 2-D tensor ``(num_tokens, hidden_dim)`` (the shape
                unpacking below requires exactly two dimensions).

        Returns:
            Tensor with the same shape and dtype as ``hidden_states``.
        """
        input_shape = hidden_states.shape
        _, hidden_dim = hidden_states.shape

        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        # Softmax and top-k renormalisation run in float32 for numerical
        # stability; the weights are cast back to the input dtype afterwards.
        routing_weights = torch.nn.functional.softmax(
            router_logits, dim=1, dtype=torch.float
        )
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (input_shape), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # expert_mask[e, k, t] == 1 iff token t chose expert e as its k-th pick.
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            # idx: top-k slot; top_x: token index routed to this expert.
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Gather the routed tokens, run the expert, and scale each output
            # by that token's routing weight for this expert.
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = (
                expert_layer(current_state) * routing_weights[top_x, idx, None]
            )

            # Scatter-add the weighted outputs back to their token rows.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(hidden_states.dtype)
            )
        final_hidden_states = final_hidden_states.reshape(input_shape)
        return final_hidden_states
class Qwen3MoeDecoderLayer(nn.Module):
    """One Qwen3-MoE transformer block.

    Pre-norm self-attention followed by a pre-norm feed-forward (a ``Qwen3MoE``
    block on sparse layers, a dense ``Qwen2MLP`` otherwise), each wrapped in a
    residual add.
    """

    def __init__(self, config, prefix, weights, layer_idx: int, rotary_emb):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Use the sharded Qwen3Attention when every rank receives at least one
        # KV head; otherwise fall back to Qwen3MoeAttention.
        kv_heads_per_rank = config.num_key_value_heads // weights.process_group.size()
        attention_cls = Qwen3Attention if kv_heads_per_rank > 0 else Qwen3MoeAttention
        self.self_attn = attention_cls(
            config,
            prefix=f"{prefix}.self_attn",
            weights=weights,
            layer_idx=layer_idx,
            rotary_emb=rotary_emb,
        )

        moe_layer_cls = (
            SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
        )
        is_sparse_layer = layer_idx not in config.mlp_only_layers and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        mlp_prefix = f"{prefix}.mlp"
        if is_sparse_layer:
            self.mlp = Qwen3MoE(mlp_prefix, config, moe_layer_cls, weights)
        else:
            self.mlp = Qwen2MLP(config=config, prefix=mlp_prefix, weights=weights)

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ) -> torch.Tensor:
        """Run attention and feed-forward sub-blocks; returns the new hidden states.

        ``residual`` defaults to ``hidden_states`` when None.
        """
        if residual is None:
            residual = hidden_states

        # Attention sub-block.
        normed, _ = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            normed,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        hidden_states = residual + attn_out

        # Feed-forward sub-block.
        normed, _ = self.post_attention_layernorm(hidden_states)
        return hidden_states + self.mlp(normed)
class Qwen3MoeModel(nn.Module):
    """Qwen3-MoE decoder stack: ``num_hidden_layers`` layers plus a final RMSNorm.

    Consumes pre-computed input embeddings; a single rotary embedding table is
    shared by all layers.
    """

    def __init__(self, config, prefix: str, weights):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        shared_rotary = PositionRotaryEmbedding.static(
            config=config,
            dim=head_dim,
            base=config.rope_theta,
            device=weights.device,
        )

        decoder_layers = []
        for layer_idx in range(config.num_hidden_layers):
            decoder_layers.append(
                Qwen3MoeDecoderLayer(
                    config=config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    weights=weights,
                    layer_idx=layer_idx,
                    rotary_emb=shared_rotary,
                )
            )
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Run all decoder layers over ``inputs_embeds`` and apply the final norm.

        ``kv_cache[i]`` is handed to layer ``i``.
        """
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, inputs_embeds.shape[0]
            )

        hidden_states = inputs_embeds

        # cos/sin are computed once and reused by every layer.
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids,
        )

        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states,
                None,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[layer_idx],
                slots,
                seqlen,
                hpu_attention_meta,
            )

        normed, _ = self.norm(hidden_states)
        return normed
class Qwen3MoeForCausalLM(nn.Module):
    """Qwen3-MoE causal LM: token embedding, decoder stack, and LM head.

    When ``config.tie_word_embeddings`` is set the head weights are read from
    ``model.embed_tokens``; otherwise from ``lm_head``.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__()

        self.model = Qwen3MoeModel(config=config, prefix="model", weights=weights)
        self.vocab_size = config.vocab_size

        head_name = "model.embed_tokens" if config.tie_word_embeddings else "lm_head"
        self.lm_head = SpeculativeHead.load(
            config,
            prefix=f"{prefix}.{head_name}" if prefix else head_name,
            weights=weights,
        )

        embed_prefix = f"{prefix}.embed_tokens" if prefix else "model.embed_tokens"
        self.embed_tokens = TensorParallelEmbedding(
            prefix=embed_prefix,
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run the decoder, and return the LM head output.

        When ``lm_head_indices`` is given, only those rows of the final hidden
        states are passed to the head. ``adapter_data`` is unused.
        """
        hidden_states = self.model(
            self.embed_tokens(input_ids),
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )

        # Only compute logits for the rows the caller asked for.
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]

        return self.lm_head(hidden_states)
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
================================================
from typing import List, Optional, Tuple
import torch
import torch.distributed
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from text_generation_server.layers import (
SpeculativeHead,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastLayerNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import (
attention,
paged_attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
import habana_frameworks.torch as htorch
def load_row(config, prefix: str, weights, bias: bool):
    """Load the row-sharded linear layer stored under ``prefix``.

    The bias, when requested, is loaded on rank 0 only, so that it is
    presumably added exactly once after the outputs are summed across ranks.
    With ``config.parallel_attn`` the bare linear is returned (the Falcon
    layers then all-reduce attention + MLP together); otherwise it is wrapped
    in ``TensorParallelRowLinear``.
    """
    weight = weights.get_weights_row(prefix)

    # Only the first rank carries the bias.
    with_bias = bias and weights.process_group.rank() == 0
    bias_tensor = weights.get_tensor(f"{prefix}.bias") if with_bias else None

    linear = get_linear(weight, bias_tensor)
    if not config.parallel_attn:
        return TensorParallelRowLinear(linear, process_group=weights.process_group)
    return linear
class RWConfig(PretrainedConfig):
    """Configuration for RefinedWeb / Falcon checkpoints.

    Accepts the legacy ``n_embed`` / ``n_layer`` / ``n_head`` / ``n_head_kv``
    spellings through ``**kwargs``; ``attribute_map`` exposes the standard
    Hugging Face names on top of them. ALiBi is rejected; rotary embeddings
    are always used.
    """

    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
        "num_key_value_heads": "n_head_kv",
    }

    def __init__(
        self,
        model_type="RefinedWeb",
        vocab_size=250880,
        hidden_size=64,
        num_hidden_layers=None,
        num_attention_heads=None,
        num_ln_in_prallel_attention=None,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        num_kv_heads=None,
        multi_query=False,
        alibi=False,
        new_decoder_architecture=None,
        bias=False,
        parallel_attn=False,
        rope_theta=10_000.0,
        **kwargs,
    ):
        if alibi:
            raise NotImplementedError(
                "alibi is not supported by this version of the model"
            )

        self.model_type = model_type
        # Position encoding: rotary only.
        self.alibi = False
        self.rotary = True
        self.rope_theta = rope_theta
        self.max_position_embeddings = 2048

        self.vocab_size = vocab_size

        # Legacy key spellings are only consumed when the modern argument is absent.
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = n_embed if n_embed is not None else hidden_size
        if num_hidden_layers is None:
            num_hidden_layers = kwargs.pop("n_layer", 2)
        self.n_layer = num_hidden_layers
        if num_attention_heads is None:
            num_attention_heads = kwargs.pop("n_head", 8)
        self.n_head = num_attention_heads

        self.layer_norm_epsilon = layer_norm_epsilon
        self.num_ln_in_parallel_attn = num_ln_in_prallel_attention
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.bias = bias
        self.parallel_attn = parallel_attn

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        # KV heads: explicit argument, else legacy `n_head_kv`, else derived
        # from `multi_query` (1 KV head) or full multi-head attention.
        if num_kv_heads is None:
            num_kv_heads = kwargs.pop("n_head_kv", None)
        if num_kv_heads is None:
            num_kv_heads = 1 if multi_query else self.n_head
        self.n_head_kv = num_kv_heads

        # Unless given explicitly, only "RefinedWeb" uses the new architecture.
        if new_decoder_architecture is None:
            new_decoder_architecture = model_type == "RefinedWeb"
        self.new_decoder_architecture = new_decoder_architecture

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
class FlashRWAttention(torch.nn.Module):
    """Self-attention for the original (non "new decoder architecture") Falcon layout.

    The fused ``query_key_value`` projection output holds all query heads of
    this shard followed by ``num_heads_kv`` key heads and ``num_heads_kv``
    value heads (see the split/view in ``forward``). The output projection is
    loaded with ``load_row``.
    """

    def __init__(
        self,
        config,
        prefix: str,
        weights,
        rotary_emb,
    ):
        super().__init__()
        self.num_heads = config.n_head
        self.num_heads_kv = config.n_head_kv

        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        self.rope_theta = config.rope_theta
        self.rotary_emb = rotary_emb

        self.softmax_scale = self.head_size ** (-0.5)

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        # From here on `num_heads` is the per-shard query-head count.
        # NOTE(review): `num_heads_kv` is NOT divided by the shard count; confirm
        # how the fused K/V columns are sharded when world size > 1.
        self.num_heads = self.num_heads // weights.process_group.size()

        self.query_key_value = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.query_key_value",
            weights=weights,
            bias=config.bias,
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")
        self.dense = load_row(
            config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
        )

        # Query-head -> KV-head index used by paged attention during decode.
        if self.num_heads_kv == 1:
            # Multi-query: every query head reads KV head 0.
            self.kv_head_mapping = torch.zeros(
                self.num_heads, dtype=torch.int32, device=weights.device
            )
        else:
            # One KV head per query head. NOTE(review): assumes num_heads_kv
            # equals the per-shard num_heads — TODO confirm.
            self.kv_head_mapping = torch.arange(
                0, self.num_heads, dtype=torch.int32, device=weights.device
            )

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Attend over ``hidden_states``; returns the ``dense`` projection output.

        Prefill (``cu_seqlen_prefill`` given) uses ``attention``; decode uses
        ``paged_attention`` against ``kv_cache``.
        """
        qkv = self.query_key_value(hidden_states)

        # Split query from key_value
        query, kv = qkv.split(
            [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv],
            dim=1,
        )

        # Prepare query and key_value for indexing
        query = query.view(-1, self.num_heads, self.head_size)
        # kv[:, 0] are the keys, kv[:, 1] the values.
        kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)

        # Inplace rotary: torch.select returns a view, so the keys inside `kv`
        # are rotated as well.
        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        kv_cache.store(
            key=kv[:, 0],
            value=kv[:, 1],
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # sdpa
            attn_output = attention(
                query=query,
                key=kv[:, 0],
                value=kv[:, 1],
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        # Merge heads back into one feature dimension before the output projection.
        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
class FlashRWLargeAttention(torch.nn.Module):
    """Grouped-KV self-attention for the Falcon "new decoder architecture".

    Heads are organised in ``n_head_kv`` groups; in the fused projection each
    group holds ``num_heads`` query heads followed by one key and one value
    head (see the view/split in ``forward``). Tensor parallelism shards whole
    groups, so the shard count must divide the group count.
    """

    def __init__(
        self,
        config,
        prefix: str,
        weights,
        rotary_emb,
    ):
        super().__init__()

        hidden_size = config.hidden_size
        num_heads = config.n_head
        # num_heads_kv = config.n_head_kv
        num_groups = config.n_head_kv

        self.hidden_size = hidden_size
        self.head_size = hidden_size // num_heads
        self.num_groups = num_groups
        self.rope_theta = config.rope_theta
        self.rotary_emb = rotary_emb

        self.softmax_scale = self.head_size ** (-0.5)

        # self.num_groups = num_heads // (num_heads_kv * 2)
        # Query heads per KV group (not the total query-head count).
        self.num_heads = num_heads // self.num_groups
        # self.num_heads_kv = num_heads_kv // self.num_groups
        process_group = weights.process_group

        if process_group.size() > self.num_groups:
            raise NotImplementedError(
                "Tensor Parallelism is not implemented for world_size > n groups"
            )
        if self.num_groups % process_group.size() != 0:
            raise NotImplementedError(
                f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
            )

        # From here on `num_groups` is the per-shard group count.
        self.num_groups = self.num_groups // process_group.size()

        self.query_key_value = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.query_key_value",
            weights=weights,
            bias=config.bias,
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")
        self.dense = load_row(
            config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
        )

        # Query head i (of num_groups * num_heads) reads KV head i // num_heads,
        # i.e. the KV head of its own group.
        self.kv_head_mapping = torch.arange(
            0, self.num_groups, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_heads)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Attend over ``hidden_states``; returns the ``dense`` projection output.

        Prefill (``cu_seqlen_prefill`` given) uses ``attention``; decode uses
        ``paged_attention`` against ``kv_cache``.
        """
        qkv = self.query_key_value(hidden_states)
        # Shape: (-1, groups, query heads per group + 2, head_size); the last
        # two slots of each group are its key and value heads.
        qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)

        # Split on group dimension
        query, kv = qkv.split(
            [self.num_heads, 2],
            dim=2,
        )
        # Merge groups and heads
        query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)

        # Inplace rotary: torch.select returns a view, so the keys inside `kv`
        # are rotated as well.
        self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)

        # kv[:, :, 0] / kv[:, :, 1] are strided views; copied to contiguous
        # tensors before storing (presumably required by kv_cache.store).
        kv_cache.store(
            key=kv[:, :, 0].contiguous(),
            value=kv[:, :, 1].contiguous(),
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attn_output = attention(
                query=query,
                key=kv[:, :, 0],
                value=kv[:, :, 1],
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )

        # Merge all heads back into one feature dimension before the output projection.
        return self.dense(
            attn_output.view(-1, self.num_groups * self.num_heads * self.head_size)
        )
class FlashMLP(nn.Module):
    """Falcon feed-forward: column-parallel up-projection, GELU, down-projection.

    The down-projection is built with ``load_row`` (bare linear or
    ``TensorParallelRowLinear`` depending on ``config.parallel_attn``).
    """

    def __init__(self, config, prefix: str, weights):
        super().__init__()
        self.act = torch.nn.functional.gelu

        self.dense_h_to_4h = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias
        )
        self.dense_4h_to_h = load_row(
            config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias
        )

    def forward(self, hidden_states):
        """Return ``dense_4h_to_h(gelu(dense_h_to_4h(hidden_states)))``."""
        expanded = self.dense_h_to_4h(hidden_states)
        return self.dense_4h_to_h(self.act(expanded))
class FlashRWLayer(nn.Module):
    """Decoder layer for the original Falcon architecture.

    With ``parallel_attn`` the attention and MLP both consume the same
    normalised input and their outputs are summed (and all-reduced across
    ranks here). Otherwise attention and MLP run sequentially with a
    ``post_attention_layernorm`` in between. Returns ``(output, residual)``.
    """

    def __init__(
        self,
        layer_id,
        prefix: str,
        config,
        weights,
        rotary_emb,
    ):
        super().__init__()

        self.parallel_attn = config.parallel_attn

        layer_prefix = f"{prefix}.h.{layer_id}"

        # NOTE: Falcon 180B uses the ln_attn prefix
        ln_name = "ln_attn" if config.num_hidden_layers == 80 else "input_layernorm"
        self.input_layernorm = FastLayerNorm.load(
            prefix=f"{layer_prefix}.{ln_name}",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.self_attention = FlashRWAttention(
            config,
            prefix=f"{layer_prefix}.self_attention",
            weights=weights,
            rotary_emb=rotary_emb,
        )

        if self.parallel_attn:
            self.post_attention_layernorm = None
        else:
            self.post_attention_layernorm = FastLayerNorm.load(
                prefix=f"{layer_prefix}.post_attention_layernorm",
                weights=weights,
                eps=config.layer_norm_epsilon,
            )

        self.mlp = FlashMLP(
            config,
            prefix=f"{layer_prefix}.mlp",
            weights=weights,
        )

        self.process_group = weights.process_group

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Run the layer; returns ``(output, residual)`` for the next layer's norm."""
        normed, residual = self.input_layernorm(hidden_states, residual)
        attn_output = self.self_attention(
            normed,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )

        if self.parallel_attn:
            # load_row returned bare linears, so the partial outputs of
            # attention + MLP are summed across ranks here in one all-reduce.
            combined = self.mlp(normed) + attn_output
            if self.process_group.size() > 1:
                torch.distributed.all_reduce(combined, group=self.process_group)
            return combined, residual

        if self.post_attention_layernorm is not None:
            attn_output, residual = self.post_attention_layernorm(
                attn_output, residual
            )
        return self.mlp(attn_output), residual
class FlashRWLayerNorm(nn.Module):
    """Input layer norm(s) of a ``FlashRWLargeLayer``.

    With one norm (``input_layernorm``) its output feeds both attention and
    MLP; with two norms, ``ln_attn`` feeds attention and ``ln_mlp`` feeds the
    MLP.

    Raises:
        ValueError: if the resolved number of layer norms is not 1 or 2.
    """

    def __init__(self, config, prefix: str, weights):
        super().__init__()
        # Falcon2 includes the number of layer norms in the config.
        # RWConfig always defines `num_ln_in_parallel_attn` but stores None when
        # the checkpoint omits it, so a getattr default alone never applied and
        # None fell through to the ValueError below. Resolve None like
        # transformers' Falcon implementation: two norms for the new decoder
        # architecture (the only one this module is used with), one otherwise.
        num_ln = getattr(config, "num_ln_in_parallel_attn", None)
        if num_ln is None:
            num_ln = 2 if getattr(config, "new_decoder_architecture", False) else 1

        # Falcon 180B uses the ln_attn prefix and has 2 layer norms
        if config.num_hidden_layers == 80:
            num_ln = 2
        self.num_ln = num_ln

        if self.num_ln == 1:
            self.input_ln = FastLayerNorm.load(
                prefix=f"{prefix}.input_layernorm",
                weights=weights,
                eps=config.layer_norm_epsilon,
            )
        elif self.num_ln == 2:
            self.ln_attn = FastLayerNorm.load(
                prefix=f"{prefix}.ln_attn",
                weights=weights,
                eps=config.layer_norm_epsilon,
            )
            self.ln_mlp = FastLayerNorm.load(
                prefix=f"{prefix}.ln_mlp",
                weights=weights,
                eps=config.layer_norm_epsilon,
            )
        else:
            raise ValueError("Number of layer norms can either be 1 or 2.")

    def forward(
        self,
        hidden_states,
        residual,
    ):
        """Return ``(attn_input, mlp_input, residual)``.

        The norm layers take ``(hidden_states, residual)`` and return the
        normalised tensor plus the updated residual; with two norms, ``ln_mlp``
        normalises that updated residual.
        """
        if self.num_ln == 1:
            ln_hidden_states, residual = self.input_ln(hidden_states, residual)
            return ln_hidden_states, ln_hidden_states, residual
        elif self.num_ln == 2:
            ln_attn, residual = self.ln_attn(hidden_states, residual)
            ln_mlp, _ = self.ln_mlp(residual)
            return ln_attn, ln_mlp, residual
class FlashRWLargeLayer(nn.Module):
    """Decoder layer for the Falcon new decoder architecture (parallel attention only).

    Attention and MLP run side by side on the outputs of ``FlashRWLayerNorm``;
    their sum is all-reduced across ranks. Returns ``(output, residual)``.
    """

    def __init__(self, layer_id, prefix: str, config, weights, rotary_emb):
        super().__init__()
        layer_prefix = f"{prefix}.h.{layer_id}"

        self.ln_layer = FlashRWLayerNorm(config, layer_prefix, weights)

        self.self_attention = FlashRWLargeAttention(
            config,
            prefix=f"{layer_prefix}.self_attention",
            weights=weights,
            rotary_emb=rotary_emb,
        )
        assert config.parallel_attn, "This version doesn't support non parallel_attn"

        self.mlp = FlashMLP(config, prefix=f"{layer_prefix}.mlp", weights=weights)

        self.process_group = weights.process_group

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        """Run the layer; returns ``(output, residual)`` for the next layer's norm."""
        attn_in, mlp_in, residual = self.ln_layer(hidden_states, residual)

        attn_output = self.self_attention(
            attn_in,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )

        # Partial outputs from both branches are reduced across ranks together.
        combined = attn_output + self.mlp(mlp_in)
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(combined, group=self.process_group)

        return combined, residual
class FlashRWPreTrainedModel(PreTrainedModel):
    """Common base binding the Falcon flash models to ``RWConfig``."""

    config_class = RWConfig
class FlashRWModel(FlashRWPreTrainedModel):
    """Falcon backbone: word embeddings, decoder layers (``h``), final ``ln_f``.

    ``FlashRWLargeLayer`` is used when ``config.new_decoder_architecture`` is
    set, ``FlashRWLayer`` otherwise. All layers share one rotary embedding.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__(config)
        self.config = config

        self.word_embeddings = TensorParallelEmbedding(
            prefix=f"{prefix}.word_embeddings", weights=weights
        )
        rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=config.hidden_size // config.n_head,
            base=config.rope_theta,
            device=weights.device,
        )

        use_new_architecture = config.new_decoder_architecture
        layer_cls = FlashRWLargeLayer if use_new_architecture else FlashRWLayer
        self.h = nn.ModuleList(
            [
                layer_cls(layer_id, prefix, config, weights, rotary_emb)
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        first_attention = self.h[0].self_attention
        # KV-head count taken from the first layer: its per-shard group count
        # for the new architecture, `num_heads_kv` otherwise.
        if use_new_architecture:
            self.cache_size = first_attention.num_groups
        else:
            self.cache_size = first_attention.num_heads_kv

        self.ln_f = FastLayerNorm.load(
            prefix=f"{prefix}.ln_f",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )

        self.head_size = first_attention.head_size

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer, and apply ``ln_f``.

        ``kv_cache[i]`` is handed to layer ``i``. In HPU lazy mode a graph step
        is marked before the first layer and after each layer.
        """
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.word_embeddings(input_ids)

        # cos/sin are computed once here instead of in every layer.
        cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)

        residual = None
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()

        for layer_idx, decoder_layer in enumerate(self.h):
            hidden_states, residual = decoder_layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[layer_idx],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        normed, _ = self.ln_f(hidden_states, residual)
        return normed
class FlashRWForCausalLM(FlashRWPreTrainedModel):
    """Falcon causal LM: ``FlashRWModel`` backbone plus the ``lm_head``.

    The backbone lives under ``transformer`` (or ``{prefix}.transformer``); the
    head is always loaded from ``lm_head``.
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__(config)

        transformer_prefix = f"{prefix}.transformer" if prefix else "transformer"
        self.transformer = FlashRWModel(transformer_prefix, config, weights)

        self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and return the LM head output.

        When ``lm_head_indices`` is given only those rows are passed to the
        head. ``adapter_data`` is unused.
        """
        hidden_states = self.transformer(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        return self.lm_head(hidden_states)
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
================================================
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
SpeculativeHead,
TensorParallelEmbedding,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
import habana_frameworks.torch as htorch
def load_multi_mqa(
    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
    """Load the fused multi-query attention projection under ``prefix``.

    GPTQ-quantized checkpoints (``config.quantize == "gptq"``) go through
    ``_load_multi_mqa_gptq``; everything else through ``_load_multi_mqa``.
    """
    loader_fn = (
        _load_multi_mqa_gptq if config.quantize == "gptq" else _load_multi_mqa
    )
    return loader_fn(config, prefix, weights, bias, head_size, num_heads, hidden_size)
def _load_multi_mqa_gptq(
    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
    """Load a GPTQ/AWQ-quantized fused multi-query ``c_attn`` projection.

    Query columns are split evenly across tensor-parallel ranks, while the
    trailing key/value columns (``2 * head_size`` wide, the single shared KV
    head) are kept whole on every rank and appended to the local query shard.
    Only checkpoints with a fused, non-transposed ``c_attn`` are supported.

    Raises:
        NotImplementedError: if the checkpoint layout or the quantization
            method (neither "gptq" nor "awq") is unsupported.
    """
    from text_generation_server.layers.gptq import GPTQWeight

    if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
        world_size = weights.process_group.size()
        rank = weights.process_group.rank()

        # qweight: shard the query columns, keep the trailing K/V columns.
        slice_ = weights._get_slice(f"{prefix}.c_attn.qweight")
        shape = slice_.get_shape()
        block_size = (shape[1] - 2 * head_size) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        assert (shape[1] - 2 * head_size) % world_size == 0
        q_tensor = slice_[:, start:stop]
        kv_tensor = slice_[:, -2 * head_size :]
        qweight = torch.cat([q_tensor, kv_tensor], dim=1)
        qweight = qweight.to(device=weights.device)

        # scales: same column layout as qweight.
        slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
        shape = slice_.get_shape()
        block_size = (shape[1] - 2 * head_size) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        assert (shape[1] - 2 * head_size) % world_size == 0
        q_tensor = slice_[:, start:stop]
        kv_tensor = slice_[:, -2 * head_size :]
        scales = torch.cat([q_tensor, kv_tensor], dim=1)
        scales = scales.to(device=weights.device)

        # qzeros are packed, so the K/V width is scaled by `4 // 32`.
        # NOTE(review): this factor presumably assumes 4-bit zeros packed in
        # 32-bit words; other bit widths would mis-slice — confirm.
        slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
        shape = slice_.get_shape()
        block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        assert 2 * head_size % (32 // 4) == 0
        q_tensor = slice_[:, start:stop]
        kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
        qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
        qzeros = qzeros.to(device=weights.device)

        loader = weights.weights_loader
        assert isinstance(loader, GPTQWeightsLoader)
        loader._get_gptq_params(weights)
        if loader.quant_method == "gptq":
            g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
            g_idx = g_idx.to(device=weights.device)
        elif loader.quant_method == "awq":
            g_idx = None
            from text_generation_server.layers.awq.conversion_utils import (
                fast_awq_to_gptq,
            )

            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
        else:
            # Any other method would leave `g_idx` unbound further down.
            raise NotImplementedError(
                f"Unsupported quant_method {loader.quant_method!r} for santacoder"
            )

        from text_generation_server.layers.gptq import HAS_EXLLAMA

        weight = GPTQWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            bits=loader.bits,
            groupsize=loader.groupsize,
            use_awq_kernel=loader.quantize == "awq",
            use_exllama=HAS_EXLLAMA,
        )

        if bias:
            # Bias follows the same layout: sharded queries + shared K/V.
            slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
            shape = slice_.get_shape()
            block_size = (shape[0] - 2 * head_size) // world_size
            assert (shape[0] - 2 * head_size) % world_size == 0
            start = rank * block_size
            stop = (rank + 1) * block_size
            q_tensor = slice_[start:stop]
            kv_tensor = slice_[-2 * head_size :]
            bias = torch.cat([q_tensor, kv_tensor], dim=0)
            bias = bias.to(device=weights.device)
        else:
            # Pass None (not False) for "no bias", as _load_multi_mqa does.
            bias = None

        return TensorParallelColumnLinear(get_linear(weight, bias))
    else:
        raise NotImplementedError("Gptq loading with santacoder is not implemented")
def _load_multi_mqa(
    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
    """Load the fused multi-query attention projection (queries + shared K/V).

    Two checkpoint layouts are supported:

    * a single fused ``c_attn`` tensor whose trailing ``2 * head_size`` entries
      of the output dimension hold the key/value projection; only the leading
      query part is split across tensor-parallel ranks;
    * separate ``q_attn`` (sharded) and ``kv_attn`` (loaded whole) tensors.

    Args:
        config: model config; ``config.transpose`` means the checkpoint stores
            weights transposed relative to ``[out_features, hidden_size]``.
        prefix: weight-name prefix of the attention module.
        weights: weights accessor (tensors, dtype/device, process group).
        bias: whether the projection bias must be loaded.
        head_size: size of one attention head.
        num_heads: number of query heads held by this rank.
        hidden_size: model hidden size.

    Returns:
        ``TensorParallelColumnLinear`` with weight shape
        ``[(num_heads + 2) * head_size, hidden_size]``.
    """
    kv_size = 2 * head_size
    bias_tensor = None

    if any("c_attn" in k for k in weights.routing.keys()):
        world_size = weights.process_group.size()
        rank = weights.process_group.rank()

        slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
        shape = slice_.get_shape()
        if config.transpose:
            # Output features live on dim 1: shard the query columns, keep K/V.
            assert (shape[1] - kv_size) % world_size == 0
            block_size = (shape[1] - kv_size) // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            q_tensor = slice_[:, start:stop]
            kv_tensor = slice_[:, -kv_size:]
            weight = torch.cat([q_tensor, kv_tensor], dim=1).T
        else:
            assert (shape[0] - kv_size) % world_size == 0
            block_size = (shape[0] - kv_size) // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            q_tensor = slice_[start:stop]
            kv_tensor = slice_[-kv_size:]
            weight = torch.cat([q_tensor, kv_tensor], dim=0)

        if bias:
            slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
            shape = slice_.get_shape()
            assert (shape[0] - kv_size) % world_size == 0
            block_size = (shape[0] - kv_size) // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            q_tensor = slice_[start:stop]
            kv_tensor = slice_[-kv_size:]
            bias_tensor = torch.cat([q_tensor, kv_tensor], dim=0)
    else:
        if config.transpose:
            w = [
                weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
                weights.get_tensor(f"{prefix}.kv_attn.weight").T,
            ]
        else:
            w = [
                weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
                weights.get_tensor(f"{prefix}.kv_attn.weight"),
            ]
        # Both parts are [out_features, hidden_size] at this point, so they are
        # stacked along the output dimension (dim 0) in either layout.
        weight = torch.cat(w, dim=0)
        if bias:
            b = [
                weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
                weights.get_tensor(f"{prefix}.kv_attn.bias"),
            ]
            bias_tensor = torch.cat(b, dim=0)

    weight = weight.to(dtype=weights.dtype).to(device=weights.device)
    expected_weight = [(num_heads + 2) * head_size, hidden_size]
    assert (
        list(weight.shape) == expected_weight
    ), f"{weight.shape} != {expected_weight}"

    if bias_tensor is not None:
        bias_tensor = bias_tensor.to(dtype=weights.dtype).to(device=weights.device)
        expected_bias = [(num_heads + 2) * head_size]
        assert (
            list(bias_tensor.shape) == expected_bias
        ), f"{bias_tensor.shape} != {expected_bias}"

    return TensorParallelColumnLinear(get_linear(weight, bias_tensor))
def load_col(config, prefix: str, weights, bias: bool):
    """Build a column-parallel linear layer from ``{prefix}.weight``/``.bias``.

    With ``config.transpose`` the stored weight is sharded on dim 1 and then
    transposed; otherwise the generic multi-column loader is used. When
    ``bias`` is true the bias is sharded along dim 0, else no bias is used.
    """
    weight = (
        weights.get_sharded(f"{prefix}.weight", dim=1).T
        if config.transpose
        else weights.get_multi_weights_col([prefix], dim=0)
    )
    bias_tensor = weights.get_sharded(f"{prefix}.bias", dim=0) if bias else None
    return TensorParallelColumnLinear(get_linear(weight, bias_tensor))
def load_row(config, prefix: str, weights, bias: bool):
    """Build a row-parallel linear layer from ``{prefix}.weight``/``.bias``.

    With ``config.transpose`` the stored weight is sharded on dim 0 and then
    transposed; otherwise the generic row loader is used. The bias is only
    loaded on rank 0 (all other ranks get ``None``) — presumably so it is
    added once after the row-parallel outputs are combined.
    """
    weight = (
        weights.get_sharded(f"{prefix}.weight", dim=0).T
        if config.transpose
        else weights.get_weights_row(prefix)
    )
    load_bias = bias and weights.process_group.rank() == 0
    bias_tensor = weights.get_tensor(f"{prefix}.bias") if load_bias else None
    return TensorParallelRowLinear(
        get_linear(weight, bias_tensor), process_group=weights.process_group
    )
class FlashMQAttention(torch.nn.Module):
    """Multi-query attention: many query heads share one key/value head.

    Query heads are divided across tensor-parallel ranks. The key/value part
    of the fused projection is viewed as a single head, and every query head
    maps to key/value head 0 (``kv_head_mapping`` is all zeros).
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        world_size = weights.process_group.size()
        total_heads = config.num_attention_heads
        hidden_size = config.hidden_size

        self.hidden_size = hidden_size
        self.head_size = hidden_size // total_heads
        if total_heads % world_size != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {total_heads} "
                f"and `num_shards`: {world_size}"
            )
        self.num_heads = total_heads // world_size
        self.softmax_scale = self.head_size ** (-0.5)

        self.c_attn = load_multi_mqa(
            config,
            prefix=prefix,
            weights=weights,
            bias=True,
            head_size=self.head_size,
            hidden_size=hidden_size,
            num_heads=self.num_heads,
        )
        self.c_proj = load_row(
            config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")
        # Every local query head reads key/value head 0.
        self.kv_head_mapping = torch.zeros(
            self.num_heads, dtype=torch.int32, device=weights.device
        )

    def forward(
        self,
        hidden_states,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        q_width = self.head_size * self.num_heads
        query, key_value = self.c_attn(hidden_states).split(
            [q_width, 2 * self.head_size], dim=1
        )
        query = query.view(-1, self.num_heads, self.head_size)
        # (tokens, key/value, one shared head, head_size)
        key_value = key_value.view(-1, 2, 1, self.head_size)
        key = key_value[:, 0]
        value = key_value[:, 1]

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        if cu_seqlen_prefill is None:
            # Decode: read keys/values back from the paged cache.
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
            )
        else:
            # Prefill: attend over the freshly computed keys/values.
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
            )

        return self.c_proj(attn_output.view(-1, q_width))
class MLP(nn.Module):
    """SantaCoder feed-forward block: ``c_proj(act(c_fc(x)))``.

    GeLU variants are evaluated with ``torch.nn.functional.gelu``; the
    ``gelu_fast`` and ``gelu_pytorch_tanh`` names use the tanh approximation.
    Any other activation name is looked up in ``ACT2FN``.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.activation_function
        if "gelu" in act:
            approximate = (
                "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(x, approximate=approximate)
        else:
            self.act = ACT2FN[act]

        self.c_fc = load_col(
            config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
        )
        self.c_proj = load_row(
            config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
        )

    def forward(self, hidden_states):
        return self.c_proj(self.act(self.c_fc(hidden_states)))
class Block(nn.Module):
    """One SantaCoder decoder layer: ``ln_1`` -> attention -> ``ln_2`` -> MLP.

    The norms take and return the running residual stream, so residual
    bookkeeping happens inside ``FastLayerNorm`` rather than in this module.
    """

    def __init__(self, prefix: str, layer_id, config, weights):
        super().__init__()
        layer_prefix = f"{prefix}.h.{layer_id}"
        eps = config.layer_norm_epsilon

        self.ln_1 = FastLayerNorm.load(
            prefix=f"{layer_prefix}.ln_1", weights=weights, eps=eps
        )
        self.ln_2 = FastLayerNorm.load(
            prefix=f"{layer_prefix}.ln_2", weights=weights, eps=eps
        )
        self.self_attn = FlashMQAttention(
            prefix=f"{layer_prefix}.attn", config=config, weights=weights
        )
        self.mlp = MLP(prefix=f"{layer_prefix}.mlp", config=config, weights=weights)

    def forward(
        self,
        hidden_states,
        residual,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        hpu_attention_meta,
    ):
        normed, residual = self.ln_1(hidden_states, residual)
        attn_out = self.self_attn(
            normed,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        normed, residual = self.ln_2(attn_out, residual)
        return self.mlp(normed), residual
class FlashSantacoderModel(nn.Module):
    """SantaCoder transformer trunk: embeddings, decoder blocks, final norm.

    All weights, including the final ``ln_f``, are read under ``prefix``
    (e.g. ``"transformer"`` or ``"<outer>.transformer"``).
    """

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        self.config = config
        self.process_group = weights.process_group

        # Both embeddings are created with reduce=False; forward() sums them and
        # performs a single all_reduce when running on more than one rank.
        self.wte = TensorParallelEmbedding(
            prefix=f"{prefix}.wte",
            weights=weights,
            reduce=False,
        )
        self.wpe = TensorParallelEmbedding(
            prefix=f"{prefix}.wpe",
            weights=weights,
            reduce=False,
        )

        self.layers = nn.ModuleList(
            [
                Block(
                    prefix,
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        # Previously hard-coded to "transformer.ln_f", which ignored `prefix`
        # and broke loading under any non-default prefix.
        self.ln_f = FastLayerNorm.load(
            prefix=f"{prefix}.ln_f", weights=weights, eps=config.layer_norm_epsilon
        )

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Run the trunk and return the final-normed hidden states."""
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)

        if self.process_group.size() > 1:
            torch.distributed.all_reduce(hidden_states, group=self.process_group)

        residual = None
        # In HPU lazy mode, mark_step() is called before the layers and after
        # each layer (presumably to bound the size of each lazily built graph).
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states
class FlashSantacoderForCausalLM(nn.Module):
    """SantaCoder causal LM: transformer trunk plus an LM head tied to ``wte``."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        prefix = f"{prefix}.transformer" if prefix else "transformer"
        # GPT2-architecture checkpoints are loaded with the transposed layout
        # (see load_col / load_row / _load_multi_mqa).
        config.transpose = config.architectures[0].startswith("GPT2")
        self.model = FlashSantacoderModel(prefix, config, weights)
        # The LM head is loaded from the token-embedding weights.
        self.lm_head = SpeculativeHead.load(
            config, prefix=f"{prefix}.wte", weights=weights
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return LM-head logits, optionally only for ``lm_head_indices`` rows.

        ``adapter_data`` is accepted for interface compatibility but unused.
        """
        hidden = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden = hidden[lm_head_indices]
        return self.lm_head(hidden)
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
================================================
# coding=utf-8
# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
set_block_mapping,
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
FastLayerNorm,
FastRMSNorm,
)
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch
class Starcoder2Config(PretrainedConfig):
    """Hyper-parameters of a Starcoder2 model.

    ``mlp_type`` picks the feed-forward block (keys of
    ``STARCODER2_MLP_CLASSES``) and ``norm_type`` the normalization layer
    (keys of ``STARCODER2_NORMALIZATION_CLASSES``). A ``num_key_value_heads``
    of ``None`` falls back to ``num_attention_heads``.
    """

    model_type = "starcoder2"

    def __init__(
        self,
        vocab_size=49152,
        hidden_size=3072,
        intermediate_size=12288,
        num_hidden_layers=30,
        num_attention_heads=24,
        num_key_value_heads=2,
        mlp_type="default",
        hidden_act="gelu_pytorch_tanh",
        max_position_embeddings=4096,
        initializer_range=0.018042,
        norm_type="layer_norm",
        norm_epsilon=1e-5,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        rope_theta=10000.0,
        sliding_window=None,
        attention_dropout=0.0,
        residual_dropout=0.0,
        embedding_dropout=0.0,
        use_bias: bool = True,
        **kwargs,
    ):
        # Backward compatibility: configs without a KV head count use MHA.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Model geometry.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        # Block variants and numerics.
        self.mlp_type = mlp_type
        self.hidden_act = hidden_act
        self.norm_type = norm_type
        self.norm_epsilon = norm_epsilon
        self.use_bias = use_bias
        self.initializer_range = initializer_range

        # Attention / positional settings.
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        # Dropouts and runtime flags.
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.embedding_dropout = embedding_dropout
        self.use_cache = use_cache

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
def load_attention(config, prefix, weights, layer_id):
    """Load the fused Q/K/V projection and wrap it for adapter support.

    Grouped-query checkpoints (fewer KV heads than attention heads) go through
    ``_load_gqa``; otherwise q/k/v are fused with ``load_multi``. ``sizes``
    lists the unsharded q, k and v output widths.
    """
    prefixes = [f"{prefix}.{name}" for name in ("q_proj", "k_proj", "v_proj")]
    head_size = config.hidden_size // config.num_attention_heads
    kv_width = head_size * config.num_key_value_heads
    sizes = [head_size * config.num_attention_heads, kv_width, kv_width]

    if config.num_attention_heads == config.num_key_value_heads:
        base_layer = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=prefixes,
            dim=0,
            weights=weights,
            bias=config.use_bias,
        )
    else:
        base_layer = _load_gqa(config, prefix, weights)

    return TensorParallelMultiAdapterLinear.load(
        base_layer=base_layer,
        layer_id=layer_id,
        layer_names=prefixes,
        sizes=sizes,
        process_group=weights.process_group,
    )
def _load_gqa(config, prefix: str, weights):
    """Load the fused Q/K/V projection for grouped-query attention.

    q/k/v weights are concatenated column-parallel via
    ``get_multi_weights_col``. For unquantized weights the tensor is cast to
    the target dtype/device and its per-rank shape is validated; other weight
    types are passed through unchanged.

    Returns:
        ``TensorParallelColumnLinear`` wrapping the fused weight (and bias
        when ``config.use_bias``).
    """
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0

    prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
    weight = weights.get_multi_weights_col(prefixes=prefixes, dim=0)

    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

        # Only unquantized weights are guaranteed to expose a plain `.weight`
        # tensor, so the shape check is confined to this branch.
        world_size = weights.process_group.size()
        head_size = config.hidden_size // config.num_attention_heads
        num_heads = config.num_attention_heads // world_size
        num_key_value_heads = config.num_key_value_heads // world_size
        expected = [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ]
        actual = list(weight.weight.shape)
        assert actual == expected, f"{actual} != {expected}"

    if config.use_bias:
        w = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
        bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
    else:
        bias = None

    return TensorParallelColumnLinear(get_linear(weight, bias=bias))
class Starcoder2Attention(torch.nn.Module):
    """Starcoder2 self-attention with rotary embeddings and optional GQA.

    Query and key/value heads are divided across tensor-parallel ranks.
    ``max_past`` carries ``config.sliding_window`` (or ``-1`` when it is unset)
    into both the prefill and decode attention calls.
    """

    def __init__(
        self,
        index: int,
        prefix: str,
        config,
        weights,
        rotary_emb,
    ):
        super().__init__()
        world_size = weights.process_group.size()
        total_heads = config.num_attention_heads

        self.max_past = (
            -1 if config.sliding_window is None else config.sliding_window
        )
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // total_heads
        self.rotary_emb = rotary_emb
        self.softmax_scale = self.head_size**-0.5

        if total_heads % world_size != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {total_heads} "
                f"and `num_shards`: {world_size}"
            )
        self.num_heads = total_heads // world_size
        self.num_key_value_heads = config.num_key_value_heads // world_size

        self.query_key_value = load_attention(config, prefix, weights, index)
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=getattr(config, "use_bias", False),
        )
        self.o_proj = TensorParallelAdapterRowLinear.load(
            o_proj,
            index,
            "o_proj",
            process_group=weights.process_group,
        )

        # Query head i reads key/value head i // num_groups.
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        hpu_attention_meta,
    ):
        q_width = self.head_size * self.num_heads
        kv_width = 2 * self.head_size * self.num_key_value_heads
        qkv = self.query_key_value(hidden_states, adapter_data)
        query, kv = qkv.split([q_width, kv_width], dim=1)
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        # The rotary call's return value is discarded, so it is expected to
        # update `query` and the key half of `kv` in place.
        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
        key = kv[:, 0]
        value = kv[:, 1]

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        if cu_seqlen_prefill is None:
            # Decode: attend over the paged cache.
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                seqlen,
                kv_scales=self.kv_scales,
                hpu_attention_meta=hpu_attention_meta,
                window_size_left=self.max_past,
            )
        else:
            # Prefill: attend over the keys/values computed above.
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                softmax_scale=self.softmax_scale,
                window_size_left=self.max_past,
            )

        return self.o_proj(attn_output.view(-1, q_width), adapter_data)
class Starcoder2MLP(nn.Module):
    """Plain (non-gated) Starcoder2 feed-forward block: ``c_proj(act(c_fc(x)))``.

    Both projections are wrapped in adapter-aware tensor-parallel layers.
    """

    def __init__(self, prefix, config, weights, index):
        super().__init__()
        act = config.hidden_act
        if "gelu" in act:
            approximate = (
                "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(x, approximate=approximate)
        else:
            self.act = ACT2FN[act]

        # Single up-projection (no gate), then the down-projection.
        up = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.c_fc",
            weights=weights,
            bias=config.use_bias,
        )
        down = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.c_proj",
            weights=weights,
            bias=config.use_bias,
        )

        # NOTE(review): `sizes` has two entries for a single layer name; kept
        # unchanged here — confirm against TensorParallelMultiAdapterLinear.
        self.c_fc = TensorParallelMultiAdapterLinear.load(
            up,
            layer_id=index,
            layer_names=[f"{prefix}.c_fc"],
            sizes=[config.intermediate_size, config.intermediate_size],
            process_group=weights.process_group,
        )
        self.c_proj = TensorParallelAdapterRowLinear.load(
            down,
            index,
            "c_proj",
            process_group=weights.process_group,
        )

    def forward(self, hidden_states, adapter_data):
        activated = self.act(self.c_fc(hidden_states, adapter_data))
        return self.c_proj(activated, adapter_data)
class Starcoder2GatedMLP(nn.Module):
    """Gated Starcoder2 feed-forward block: ``down(act(gate(x)) * up(x))``.

    ``gate_proj`` and ``up_proj`` are fused into one column-parallel layer
    whose output is split per rank into the gate and up halves.
    """

    def __init__(self, index, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        if "gelu" in act:
            approximate = (
                "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(x, approximate=approximate)
        else:
            self.act = ACT2FN[act]

        fused_names = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
        fused = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=fused_names,
            weights=weights,
            dim=0,
            bias=config.use_bias,
        )
        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
            fused,
            index,
            layer_names=fused_names,
            sizes=[config.intermediate_size, config.intermediate_size],
            process_group=weights.process_group,
        )

        down = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=config.use_bias,
        )
        self.down_proj = TensorParallelAdapterRowLinear.load(
            down,
            index,
            "down_proj",
            process_group=weights.process_group,
        )

        # Per-rank width of each of the gate/up halves.
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states, adapter_data):
        fused = self.gate_up_proj(hidden_states, adapter_data)
        fused = fused.view(-1, 2, self.intermediate_size)
        gate, up = fused[:, 0], fused[:, 1]
        return self.down_proj(self.act(gate) * up, adapter_data)
# `config.norm_type` -> normalization layer class.
STARCODER2_NORMALIZATION_CLASSES = dict(
    layer_norm=FastLayerNorm,
    rms_norm=FastRMSNorm,
)

# `config.mlp_type` -> feed-forward block class.
STARCODER2_MLP_CLASSES = dict(
    default=Starcoder2MLP,
    gated=Starcoder2GatedMLP,
)
class Starcoder2Layer(nn.Module):
    """One Starcoder2 decoder layer: norm -> attention -> norm -> MLP.

    Weights are read from ``model.layers.{layer_id}``; this path is fixed and
    does not depend on the enclosing model's prefix. The MLP and norm classes
    are selected by ``config.mlp_type`` and ``config.norm_type``.
    """

    def __init__(self, layer_id, config, weights, rotary_emb):
        super().__init__()
        layer_prefix = f"model.layers.{layer_id}"

        self.self_attn = Starcoder2Attention(
            prefix=f"{layer_prefix}.self_attn",
            config=config,
            weights=weights,
            index=layer_id,
            rotary_emb=rotary_emb,
        )
        mlp_cls = STARCODER2_MLP_CLASSES[config.mlp_type]
        self.mlp = mlp_cls(
            prefix=f"{layer_prefix}.mlp", config=config, weights=weights, index=layer_id
        )

        norm_cls = STARCODER2_NORMALIZATION_CLASSES[config.norm_type]
        self.input_layernorm = norm_cls.load(
            prefix=f"{layer_prefix}.input_layernorm",
            weights=weights,
            eps=config.norm_epsilon,
        )
        self.post_attention_layernorm = norm_cls.load(
            prefix=f"{layer_prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.norm_epsilon,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        slots,
        seqlen,
        adapter_data,
        hpu_attention_meta,
    ):
        # The norms take and return the running residual stream.
        normed, res = self.input_layernorm(hidden_states, residual)
        attn_output = self.self_attn(
            normed,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data,
            hpu_attention_meta,
        )
        normed_attn, attn_res = self.post_attention_layernorm(attn_output, res)
        return self.mlp(normed_attn, adapter_data), attn_res
class Starcoder2Model(torch.nn.Module):
    """Starcoder2 trunk: token embedding, decoder layers and final norm."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        group = weights.process_group
        self.tp_rank = group.rank()
        self.tp_world_size = group.size()

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        # A single rotary embedding instance is shared by every layer.
        shared_rotary = PositionRotaryEmbedding.static(
            config=config,
            dim=config.hidden_size // config.num_attention_heads,
            base=config.rope_theta,
            device=weights.device,
        )
        self.layers = nn.ModuleList(
            [
                Starcoder2Layer(layer_id, config, weights, shared_rotary)
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon
        )

        self.gradient_checkpointing = False

        first_attn = self.layers[0].self_attn
        self.head_size = first_attn.head_size
        self.num_heads = first_attn.num_heads
        self.num_key_value_heads = first_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        adapter_data,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        """Run all layers and return the final-normed hidden states."""
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_tokens(input_ids)

        # cos/sin are computed once here instead of inside every layer.
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)

        residual = None
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                slots,
                seqlen,
                adapter_data,
                hpu_attention_meta,
            )
            if lazy_mode:
                htorch.core.mark_step()

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class FlashStarcoder2ForCausalLM(torch.nn.Module):
    """Starcoder2 causal LM: trunk plus LM head.

    The head is loaded from ``lm_head``; if that raises ``RuntimeError``
    (presumably because the checkpoint has no separate head), the token
    embedding weights are used instead.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        prefix = f"{prefix}.model" if prefix else "model"
        self.model = Starcoder2Model(prefix, config, weights)

        try:
            self.lm_head = SpeculativeHead.load(
                config,
                prefix="lm_head",
                weights=weights,
            )
        except RuntimeError:
            self.lm_head = SpeculativeHead.load(
                config,
                prefix=f"{prefix}.embed_tokens",
                weights=weights,
            )

        self.max_past = config.sliding_window
        self.max_past_tensor = (
            None
            if self.max_past is None
            else torch.tensor(config.sliding_window, device=weights.device)
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return LM-head logits, optionally only for ``lm_head_indices`` rows."""
        hidden = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            slots,
            seqlen,
            adapter_data,
            hpu_attention_meta,
        )
        if lm_head_indices is not None:
            hidden = hidden[lm_head_indices]
        return self.lm_head(hidden)
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py
================================================
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Idefics2 model."""
from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
import math
from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
)
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along dim 1.

    Same result as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. With
    ``n_rep == 1`` the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(
        batch, kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
class Idefics2VisionEmbeddings(nn.Module):
"""
This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
resolution.
The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
which allows treating images in their native aspect ratio and without the need to resize them to the same
fixed size. In particular, we start from the original pre-trained SigLIP model
(which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
"""
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
)
self.patch_embedding.bias = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
)
self.num_patches_per_side = self.image_size // self.patch_size
self.num_patches = self.num_patches_per_side**2
self.num_positions = self.num_patches
self.position_embedding = TensorParallelEmbedding(
prefix=f"{prefix}.position_embedding", weights=weights
)
def forward(
self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
max_nb_patches_h, max_nb_patches_w = (
max_im_h // self.patch_size,
max_im_w // self.patch_size,
)
boundaries = torch.arange(
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
)
position_ids = torch.full(
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(
fractional_coords_h, boundaries, right=True
)
bucket_coords_w = torch.bucketize(
fractional_coords_w, boundaries, right=True
)
pos_ids = (
bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
class Idefics2VisionAttention(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = self.embed_dim // self.num_heads
if self.head_size * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_size**-0.5
self.dropout = config.attention_dropout
self.num_heads = self.num_heads // weights.process_group.size()
self.embed_dim = self.embed_dim // weights.process_group.size()
self.qkv = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
self.out_proj = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
)
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, q_len, _ = hidden_states.size()
qkv = self.qkv(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_size * self.num_heads,
self.head_size * self.num_heads,
self.head_size * self.num_heads,
],
dim=2,
)
query_states = query_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
key_states = key_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
value_states = value_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
)
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class Idefics2VisionMLP(nn.Module):
    """Feed-forward block of the vision encoder: `fc2(act(fc1(x)))`.

    `fc1` is loaded as a column-parallel linear and `fc2` as a row-parallel
    linear, both with bias. The activation is looked up by name from
    `config.hidden_act`.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        shared = dict(config=config, weights=weights, bias=True)
        self.fc1 = TensorParallelColumnLinear.load(prefix=f"{prefix}.fc1", **shared)
        self.fc2 = TensorParallelRowLinear.load(prefix=f"{prefix}.fc2", **shared)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
class Idefics2EncoderLayer(nn.Module):
    """Pre-norm transformer block of the vision encoder.

    Computes `h = x + attn(ln1(x))`, then `h + mlp(ln2(h))`.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Idefics2VisionAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        eps = config.layer_norm_eps
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", eps=eps, weights=weights
        )
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", eps=eps, weights=weights
        )
        self.mlp = Idefics2VisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

    # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        attn_out = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
        )
        # Residual connections keep the skip path as the left operand.
        hidden_states = hidden_states + attn_out
        return hidden_states + self.mlp(self.layer_norm2(hidden_states))
class Idefics2Encoder(nn.Module):
    """Sequential stack of `config.num_hidden_layers` Idefics2EncoderLayer blocks."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        blocks = []
        for layer_idx in range(config.num_hidden_layers):
            blocks.append(
                Idefics2EncoderLayer(
                    prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights
                )
            )
        self.layers = nn.ModuleList(blocks)

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Feed `inputs_embeds` through every layer, passing the same mask to each."""
        out = inputs_embeds
        for block in self.layers:
            out = block(out, attention_mask)
        return out
class Idefics2VisionTransformer(nn.Module):
    """Vision tower: patch embeddings -> encoder stack -> final LayerNorm."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embeddings = Idefics2VisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        self.encoder = Idefics2Encoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
        self.post_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
    ):
        """Encode `pixel_values` (batch, channels, height, width).

        When `patch_attention_mask` is omitted, every patch of the
        `(height // patch_size, width // patch_size)` grid is treated as valid.
        """
        bsz = pixel_values.size(0)
        if patch_attention_mask is None:
            step = self.config.patch_size
            grid = (bsz, pixel_values.size(2) // step, pixel_values.size(3) // step)
            patch_attention_mask = torch.ones(grid).to(
                dtype=torch.bool, device=pixel_values.device
            )

        embedded = self.embeddings(
            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
        )

        flat_mask = patch_attention_mask.view(bsz, -1)
        # Building the 4-D mask is only needed when some patch is padding; an
        # all-valid mask is equivalent to no mask at all.
        if torch.any(~flat_mask):
            encoder_mask = _prepare_4d_attention_mask(flat_mask, embedded.dtype)
        else:
            encoder_mask = None

        encoded = self.encoder(inputs_embeds=embedded, attention_mask=encoder_mask)
        return self.post_layernorm(encoded)
class Idefics2MLP(nn.Module):
    """Gated MLP computing `down(act(gate(x)) * up(x))`.

    The gate and up projections are loaded as one fused column-parallel linear
    whose output is split into two equal halves.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.text_config.hidden_act
        if "gelu" in act:
            # The gelu family maps onto torch's gelu; "gelu_fast" and
            # "gelu_pytorch_tanh" use the tanh approximation.
            approximate = (
                "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(x, approximate=approximate)
        else:
            self.act = ACT2FN[act]
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )

    def forward(self, hidden_states):
        lead_dims = hidden_states.shape[:-1]
        fused = self.gate_up_proj(hidden_states)
        half = fused.shape[-1] // 2
        gate, up = fused.view(-1, 2, half).unbind(dim=1)
        return self.down_proj(self.act(gate) * up).view(*lead_dims, -1)
class Idefics2RMSNorm(nn.Module):
    """RMS normalisation with a learned per-feature scale (same as T5LayerNorm).

    There is no mean subtraction and no bias. Normalisation happens in fp32,
    and the result is cast back to the input dtype before scaling.
    """

    def __init__(self, prefix, weights, eps):
        super().__init__()
        scale = weights.get_tensor(f"{prefix}.weight")
        self.weight = nn.Parameter(scale, requires_grad=False)
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(
            x32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon
        )
        return self.weight * (x32 * inv_rms).to(orig_dtype)
class Idefics2PerceiverAttention(nn.Module):
    """Attention from perceiver latents to the concatenation `[context; latents]`.

    Queries are computed from the latents only. Keys and values are computed
    from the context and the latents concatenated along the sequence axis.
    When there are fewer key/value heads than query heads, key/value heads are
    expanded with `repeat_kv`. Head counts are divided by the process-group
    size, so each rank computes its own subset of heads.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.layer_idx = None
        self.hidden_size = config.text_config.hidden_size
        self.num_heads = config.perceiver_config.resampler_n_heads
        self.head_size = config.perceiver_config.resampler_head_dim
        self.num_key_value_heads = config.perceiver_config.num_key_value_heads
        # Ratio taken on the unsharded counts; it is unchanged by sharding as
        # long as both counts divide evenly by the world size.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # Kept for config parity; forward() applies no attention dropout.
        self.attention_dropout = config.perceiver_config.attention_dropout
        # Per-rank head counts.
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            self.num_key_value_heads // weights.process_group.size()
        )

        self.q_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.q_proj",
            weights=weights,
            bias=False,
        )
        # k_proj and v_proj are fused into one column-parallel projection.
        self.kv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
        )
        self.is_causal = False

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend from `latents` over `[context; latents]`.

        Args:
            latents: query source of shape `(batch, n_latents, hidden)`.
            context: additional key/value source of shape `(batch, context_len, hidden)`.
            attention_mask: optional additive mask of shape
                `(batch, 1, n_latents, context_len + n_latents)`, added to the scores.

        Returns:
            The `o_proj` output computed from a
            `(batch, n_latents, num_heads * head_size)` tensor. Only this single
            tensor is returned.

        Raises:
            ValueError: if the scores, the mask or the attention output have an
                unexpected shape.
        """
        bsz, q_len, _ = latents.size()
        kv_seq_len = q_len + context.size()[1]

        hidden_states = torch.concat([context, latents], dim=-2)

        query_states = self.q_proj(latents)
        kv = self.kv(hidden_states)
        key_states, value_states = kv.split(
            [
                self.head_size * self.num_key_value_heads,
                self.head_size * self.num_key_value_heads,
            ],
            dim=2,
        )

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_size
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, kv_seq_len, self.num_key_value_heads, self.head_size
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, kv_seq_len, self.num_key_value_heads, self.head_size
        ).transpose(1, 2)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_size)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)

        attn_output = self.o_proj(attn_output)
        return attn_output
class Idefics2PerceiverLayer(nn.Module):
    """One perceiver block: attention from latents to `[context; latents]`, then a gated MLP.

    Both sub-layers are pre-normalised with RMSNorm and wrapped in residual
    connections on the latent stream. The context is normalised but is not
    updated by this block.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.hidden_size = config.text_config.hidden_size
        self.n_latents = config.perceiver_config.resampler_n_latents
        self.depth = config.perceiver_config.resampler_depth
        self.rms_norm_eps = config.text_config.rms_norm_eps

        self.input_latents_norm = Idefics2RMSNorm(
            prefix=f"{prefix}.input_latents_norm",
            weights=weights,
            eps=self.rms_norm_eps,
        )
        self.input_context_norm = Idefics2RMSNorm(
            prefix=f"{prefix}.input_context_norm",
            weights=weights,
            eps=self.rms_norm_eps,
        )
        self.self_attn = Idefics2PerceiverAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.post_attention_layernorm = Idefics2RMSNorm(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=self.rms_norm_eps,
        )
        self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            latents (`torch.FloatTensor`): latent queries of shape `(batch, n_latents, embed_dim)`.
            context (`torch.FloatTensor`): context features of shape `(batch, context_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): additive attention mask of shape
                `(batch, 1, n_latents, context_len + n_latents)`. It covers the concatenated
                `[context; latents]` keys and is added to the attention scores. Presumably it is
                0 where attended and a large negative value where masked, as built by
                `_prepare_4d_attention_mask`; confirm.

        Returns:
            Updated latents, with the same shape as `latents`.
        """
        residual = latents

        latents = self.input_latents_norm(latents)
        context = self.input_context_norm(context)

        latents = self.self_attn(
            latents=latents,
            context=context,
            attention_mask=attention_mask,
        )
        latents = residual + latents
        residual = latents

        latents = self.post_attention_layernorm(latents)
        latents = self.mlp(latents)
        latents = residual + latents

        return latents
class Idefics2PerceiverResampler(nn.Module):
    """Compresses a variable-length context into `resampler_n_latents` tokens.

    A learned latent array, loaded from `{prefix}.latents` and stored as a plain
    tensor rather than a registered parameter, is broadcast over the batch and
    refined by `resampler_depth` perceiver layers. A final RMSNorm follows.
    """

    def __init__(self, prefix, config, weights) -> None:
        super().__init__()
        self.hidden_size = config.text_config.hidden_size
        self.hidden_act = config.perceiver_config.hidden_act
        self.n_latents = config.perceiver_config.resampler_n_latents
        self.depth = config.perceiver_config.resampler_depth
        self.rms_norm_eps = config.text_config.rms_norm_eps

        self.latents = weights.get_tensor(f"{prefix}.latents")

        blocks = [
            Idefics2PerceiverLayer(
                prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights
            )
            for layer_idx in range(self.depth)
        ]
        self.layers = nn.ModuleList(blocks)
        self.norm = Idefics2RMSNorm(
            prefix=f"{prefix}.norm",
            weights=weights,
            eps=config.text_config.rms_norm_eps,
        )

    def forward(
        self,
        context: torch.Tensor,
        attention_mask,
    ) -> torch.Tensor:
        """Resample `context` under a 2-D `(batch, context_len)` padding mask."""
        bsz = context.shape[0]
        # (n_latents, hidden) -> (batch, n_latents, hidden), a broadcast view.
        latents = self.latents.unsqueeze(0).expand((bsz, *self.latents.size()))

        # Keys are `[context; latents]`; the latent keys are always attendable.
        latent_keep = attention_mask.new_ones(
            (attention_mask.size(0), latents.size(1))
        )
        full_mask = torch.cat([attention_mask, latent_keep], dim=-1)
        additive_mask = _prepare_4d_attention_mask(
            full_mask, latents.dtype, tgt_len=self.n_latents
        )

        hidden = latents
        for block in self.layers:
            hidden = block(hidden, context, attention_mask=additive_mask)
        return self.norm(hidden)
class Idefics2Connector(nn.Module):
    """Bridges the vision encoder and the text model.

    A gated MLP projects every patch feature, and a perceiver resampler then
    compresses the projected sequence into a fixed number of latent tokens.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.modality_projection = Idefics2MLP(
            prefix=f"{prefix}.modality_projection", config=config, weights=weights
        )
        self.perceiver_resampler = Idefics2PerceiverResampler(
            prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights
        )

    def forward(self, image_hidden_states, attention_mask):
        projected = self.modality_projection(image_hidden_states)
        return self.perceiver_resampler(
            context=projected, attention_mask=attention_mask
        )
class Idefics2ForConditionalGeneration(nn.Module):
    """Idefics2 vision-language model: vision tower + connector + text decoder.

    Images are encoded by an unquantized vision transformer and compressed by
    the connector into `perceiver_config.resampler_n_latents` embeddings per
    image. Those embeddings are written into the text embeddings at the
    positions holding `config.image_token_id`.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator

        vision_config = config.vision_config
        self.text_model = load_text_model(
            prefix="model" if not prefix else f"{prefix}.model",
            config=config.text_config,
            weights=weights,
            name="text_model",
        )
        self.dtype = weights.dtype

        # The vision and connector models are not quantized.
        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
            self.vision_model = Idefics2VisionTransformer(
                prefix=(
                    f"{prefix}.model.vision_model" if prefix else "model.vision_model"
                ),
                config=vision_config,
                weights=weights,
            )

            # NOTE: mutates the caller's config. The text model above was
            # already loaded with the original quantize setting.
            config.quantize = None
            self.connector = Idefics2Connector(
                prefix=f"{prefix}.model.connector" if prefix else "model.connector",
                config=config,
                weights=weights,
            )

        self.config = config
        self.image_seq_len = config.perceiver_config.resampler_n_latents
        self.image_token_id = config.image_token_id
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def _merge_input_ids_with_image_features(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
    ):
        """In place merges in vision_embeddings with inputs_embeds.

        Rows of `image_features`, flattened to `(-1, hidden)`, are written in order
        to every position where `input_ids == config.image_token_id`. The counts
        should match. Returns the (mutated) `inputs_embeds`.
        """
        # torch.where(cond) is used instead of indexing with the `==` boolean
        # mask directly, to work around an issue in HPU graphs.
        mask = torch.where(input_ids == self.config.image_token_id)
        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        return inputs_embeds

    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Encode images into connector outputs, one block of rows per real image.

        Args:
            pixel_values: `(batch, num_images, channels, height, width)`. Images
                that are entirely 0 are treated as padding and dropped.
            pixel_attention_mask: optional per-pixel validity mask of shape
                `(batch, num_images, height, width)`. When None, every pixel of
                every real image is valid.
            image_sizes, image_grid_thw: unused; accepted for interface
                compatibility.

        Returns:
            A `(num_real_images * tokens_per_image, hidden)` tensor, ordered by
            batch item and then by image.
        """
        assert pixel_values is not None
        batch_size, num_images, num_channels, height, width = pixel_values.shape
        patch_size = self.config.vision_config.patch_size

        # fp16 compatibility: cast the whole batch once rather than per iteration.
        all_pixel_values = pixel_values.to(dtype=self.dtype)
        all_pixel_mask = pixel_attention_mask

        # A ones kernel with stride == patch_size counts valid pixels per patch.
        # It replaces Tensor.unfold, which HPU does not support.
        conv_kernel = torch.ones(
            [1, 1, patch_size, patch_size],
            dtype=all_pixel_values.dtype,
            device=all_pixel_values.device,
        )

        all_states = []
        for i in range(batch_size):
            images = all_pixel_values[i : i + 1].view(
                num_images, *all_pixel_values.shape[2:]
            )

            # Remove padding images - padding images are full 0.
            nb_values_per_image = images.shape[1:].numel()
            real_images_inds = (images == 0.0).sum(
                dim=(-1, -2, -3)
            ) != nb_values_per_image
            images = images[real_images_inds].contiguous()

            # Test the caller's argument, not a loop variable: a mask built for a
            # previous batch item must not be mistaken for a user-supplied one.
            if all_pixel_mask is None:
                pixel_mask = torch.ones(
                    size=(images.size(0), images.size(2), images.size(3)),
                    dtype=torch.bool,
                    device=images.device,
                )
            else:
                # Remove padding images from the mask as well.
                pixel_mask = all_pixel_mask[i : i + 1]
                pixel_mask = pixel_mask.view(num_images, *pixel_mask.shape[2:])
                pixel_mask = pixel_mask[real_images_inds].contiguous()

            patches_subgrid = torch.nn.functional.conv2d(
                pixel_mask.unsqueeze(1).to(conv_kernel.dtype),
                conv_kernel,
                stride=patch_size,
            ).squeeze(1)
            # A patch is valid if it contains at least one valid pixel.
            patch_attention_mask = torch.gt(patches_subgrid, 0)

            # Get sequence from the vision encoder
            image_hidden_states = self.vision_model(
                pixel_values=images,
                patch_attention_mask=patch_attention_mask,
            )

            # Modality projection & resampling
            image_hidden_states = self.connector(
                image_hidden_states,
                attention_mask=patch_attention_mask.view(images.size(0), -1),
            )
            all_states.append(image_hidden_states)

        # cat along the image axis gives the same flattened order as stacking,
        # and also handles batch items with different numbers of real images.
        image_hidden_states = torch.cat(all_states, dim=0)
        return image_hidden_states.view(-1, image_hidden_states.shape[-1])

    def get_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        vision_embeds: torch.Tensor = None,
    ):
        """Embed `input_ids`; if `vision_embeds` is given, scatter it into the image-token slots."""
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if vision_embeds is not None:
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            inputs_embeds = self._merge_input_ids_with_image_features(
                input_ids, inputs_embeds, vision_embeds
            )
        return inputs_embeds

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ):
        """Run the text decoder and LM head; returns `(logits, speculative_logits)`.

        `attention_mask` is accepted but not used here.
        """
        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=seqlen,
            hpu_attention_meta=hpu_attention_meta,
            adapter_data=adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py
================================================
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Idefics3 model."""
from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
)
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along dim 1.

    Equivalent to `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`:
    (batch, num_key_value_heads, seqlen, head_dim) -> (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When `n_rep == 1` the input tensor itself is returned unchanged.
    """
    bsz, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # expand() adds a zero-stride repeat axis; reshape() then materialises it.
    repeated = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, head_dim)
    return repeated.reshape(bsz, kv_heads * n_rep, seq_len, head_dim)
class Idefics3VisionEmbeddings(nn.Module):
    """
    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
    resolution.

    The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
    which allows treating images in their native aspect ratio and without the need to resize them to the same
    fixed size. In particular, we start from the original pre-trained SigLIP model
    (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patch projection: kernel size == stride == patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        # Overwrite the freshly initialised conv parameters with the checkpoint tensors.
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )
        self.patch_embedding.bias = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
        )

        # Position table for a square grid of `image_size // patch_size` patches per
        # side; forward() maps smaller patch grids onto this grid by bucketing.
        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        # NOTE(review): `num_positions` is not read anywhere in this class.
        self.num_positions = self.num_patches
        self.position_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.position_embedding", weights=weights
        )

    def forward(
        self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
    ) -> torch.Tensor:
        """Embed patches and add position embeddings for variable-size images.

        Args:
            pixel_values: `(batch, channels, max_height, max_width)` images.
            patch_attention_mask: per-image boolean grid of valid patches, indexed
                `[batch][row][col]`. Presumably it has shape
                `(batch, max_height // patch_size, max_width // patch_size)`; confirm.

        Returns:
            `(batch, num_patches, embed_dim)` patch embeddings plus position
            embeddings. Padded patches receive position id 0.
        """
        batch_size, _, max_im_h, max_im_w = pixel_values.shape

        patch_embeds = self.patch_embedding(pixel_values)
        # (batch, embed_dim, h, w) -> (batch, h * w, embed_dim)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        max_nb_patches_h, max_nb_patches_w = (
            max_im_h // self.patch_size,
            max_im_w // self.patch_size,
        )
        # Bucket boundaries splitting [0, 1) into `num_patches_per_side` equal bins.
        boundaries = torch.arange(
            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
        )
        # Built on the default (CPU) device; entries not overwritten below stay 0.
        position_ids = torch.full(
            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
        )

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            # Valid rows / columns are counted along the first column / first row.
            # NOTE(review): this assumes the valid patches form a rectangle anchored
            # at the top-left corner of the grid — confirm with the preprocessor.
            nb_patches_h = p_attn_mask[:, 0].sum()
            nb_patches_w = p_attn_mask[0].sum()

            # Fractional coordinates in [0, 1) of each valid row / column ...
            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

            # ... mapped to the row / column index of the full-size position grid.
            bucket_coords_h = torch.bucketize(
                fractional_coords_h, boundaries, right=True
            )
            bucket_coords_w = torch.bucketize(
                fractional_coords_w, boundaries, right=True
            )

            # Row-major flattening into ids for the square position table.
            pos_ids = (
                bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
            ).flatten()
            # The mask is moved to CPU so it can index the CPU `position_ids` tensor.
            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids

        position_ids = position_ids.to(self.position_embedding.weight.device)
        embeddings = embeddings + self.position_embedding(position_ids)
        return embeddings
class Idefics3VisionAttention(nn.Module):
    """Multi-head self-attention over vision patches, sharded across ranks.

    `q_proj`/`k_proj`/`v_proj` are loaded as one fused column-parallel `qkv`
    projection. `num_heads` and `embed_dim` are divided by the process-group
    size after the divisibility check.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        heads = config.num_attention_heads
        self.embed_dim = hidden
        self.num_heads = heads
        self.head_size = hidden // heads
        if self.head_size * heads != hidden:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_size**-0.5
        self.dropout = config.attention_dropout

        world_size = weights.process_group.size()
        self.num_heads = heads // world_size
        self.embed_dim = hidden // world_size

        self.qkv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.out_proj = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
        )
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over `hidden_states` (batch, seq_len, hidden).

        `attention_mask`, if given, is an additive `(batch, 1, seq_len, seq_len)`
        mask. Raises ValueError on unexpected intermediate shapes.
        """
        batch_size, q_len, _ = hidden_states.size()
        width = self.head_size * self.num_heads
        query_states, key_states, value_states = [
            part.view(batch_size, q_len, self.num_heads, self.head_size).transpose(1, 2)
            for part in self.qkv(hidden_states).split([width, width, width], dim=2)
        ]
        k_v_seq_len = key_states.shape[-2]

        logits = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
        want = (batch_size, self.num_heads, q_len, k_v_seq_len)
        if logits.size() != want:
            raise ValueError(
                f"Attention weights should be of size {want}, but is"
                f" {logits.size()}"
            )

        if attention_mask is not None:
            want_mask = (batch_size, 1, q_len, k_v_seq_len)
            if attention_mask.size() != want_mask:
                raise ValueError(
                    f"Attention mask should be of size {want_mask}, but is {attention_mask.size()}"
                )
            logits = logits + attention_mask

        # Normalise in fp32, then return to the working dtype.
        weights_ = nn.functional.softmax(logits, dim=-1, dtype=torch.float32)
        weights_ = weights_.to(query_states.dtype)
        weights_ = nn.functional.dropout(weights_, p=self.dropout, training=self.training)

        attended = torch.matmul(weights_, value_states)
        want_out = (batch_size, self.num_heads, q_len, self.head_size)
        if attended.size() != want_out:
            raise ValueError(
                f"`attn_output` should be of size {want_out}, but is"
                f" {attended.size()}"
            )

        attended = attended.transpose(1, 2).contiguous()
        attended = attended.reshape(batch_size, q_len, self.embed_dim)
        return self.out_proj(attended)
class Idefics3VisionMLP(nn.Module):
    """Vision feed-forward block: column-parallel `fc1`, activation, row-parallel `fc2`."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
class Idefics3EncoderLayer(nn.Module):
    """Pre-norm vision transformer block: attention and MLP, each with a residual."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Idefics3VisionAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
        )
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
        )
        self.mlp = Idefics3VisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

    # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        normed = self.layer_norm1(hidden_states)
        attn_out = self.self_attn(hidden_states=normed, attention_mask=attention_mask)
        after_attn = hidden_states + attn_out
        mlp_out = self.mlp(self.layer_norm2(after_attn))
        return after_attn + mlp_out
class Idefics3Encoder(nn.Module):
    """Applies `config.num_hidden_layers` Idefics3EncoderLayer blocks in order."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            Idefics3EncoderLayer(
                prefix=f"{prefix}.layers.{n}", config=config, weights=weights
            )
            for n in range(config.num_hidden_layers)
        )

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Run every layer on the running hidden states, sharing one attention mask."""
        state = inputs_embeds
        for layer in self.layers:
            state = layer(state, attention_mask)
        return state
class Idefics3VisionTransformer(nn.Module):
    """Idefics3 vision tower: embeddings, encoder stack, then a final LayerNorm."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embeddings = Idefics3VisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        self.encoder = Idefics3Encoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
        self.post_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
    ):
        """Encode `pixel_values` (batch, channels, height, width) into patch features.

        If no patch mask is given, all patches of the patch grid are valid.
        """
        n_images = pixel_values.size(0)
        if patch_attention_mask is None:
            p = self.config.patch_size
            patch_attention_mask = torch.ones(
                (n_images, pixel_values.size(2) // p, pixel_values.size(3) // p)
            ).to(dtype=torch.bool, device=pixel_values.device)

        features = self.embeddings(
            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
        )

        valid = patch_attention_mask.view(n_images, -1)
        # An all-valid mask is equivalent to no mask, so skip building it.
        has_padding = torch.any(~valid)
        mask_4d = (
            _prepare_4d_attention_mask(valid, features.dtype) if has_padding else None
        )

        features = self.encoder(inputs_embeds=features, attention_mask=mask_4d)
        return self.post_layernorm(features)
class Idefics3SimpleMLP(nn.Module):
    """Bias-free linear map from pixel-shuffled vision features to the text hidden size.

    Input features have size `vision_config.hidden_size * scale_factor**2`
    (Idefics3Connector's pixel shuffle folds `scale_factor**2` patches into one
    token). Output features have size `text_config.hidden_size`.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        input_size = config.vision_config.hidden_size * (config.scale_factor**2)
        output_size = config.text_config.hidden_size
        # Cast the tensor *before* wrapping it. `Parameter.to(dtype)` returns a
        # plain Tensor whenever the dtype actually changes, and nn.Module refuses
        # to assign a plain Tensor to the registered `weight` parameter.
        proj = nn.Parameter(
            weights.get_tensor(f"{prefix}.modality_projection.proj.weight").to(
                weights.dtype
            ),
            requires_grad=False,
        )
        self.proj = nn.Linear(input_size, output_size, bias=False)
        self.proj.weight = proj

    def forward(self, x):
        return self.proj(x)
class Idefics3Connector(nn.Module):
    """Pixel-shuffles vision tokens, then projects them to the text hidden size."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.modality_projection = Idefics3SimpleMLP(prefix, config, weights)
        self.scale_factor = config.scale_factor

    def pixel_shuffle(self, x, scale_factor=2):
        """Fold each `scale_factor x scale_factor` block of patches into one token.

        `x` is `(batch, seq, dim)` with `seq` a perfect square (a square patch
        grid). The result is `(batch, seq / scale_factor**2, dim * scale_factor**2)`.
        """
        bsz, seq, embed_dim = x.size()
        side = int(seq**0.5)
        reduced = int(side / scale_factor)
        packed = embed_dim * (scale_factor**2)

        # Merge `scale_factor` horizontally adjacent patches into the channel axis.
        grid = x.view(bsz, side, side, embed_dim)
        grid = grid.view(bsz, side, reduced, embed_dim * scale_factor)
        # Transpose, then merge `scale_factor` vertically adjacent rows the same way.
        grid = grid.permute(0, 2, 1, 3).reshape(bsz, reduced, reduced, packed)
        grid = grid.permute(0, 2, 1, 3)
        return grid.reshape(bsz, int(seq / (scale_factor**2)), packed)

    def forward(self, image_hidden_states):
        shuffled = self.pixel_shuffle(image_hidden_states, self.scale_factor)
        return self.modality_projection(shuffled)
class Idefics3ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
    """Build the Idefics3 text decoder, the vision tower and the connector.

    NOTE: mutates `config` and its sub-configs in place. Quantization is
    disabled for the vision config, the speculator is shared with both
    sub-configs, and `config.quantize` is reset to None after the text model
    is loaded so that the connector is built unquantized.
    """
    super().__init__()
    vision_config = config.vision_config
    text_config = config.text_config
    # Only the text model may be quantized; both sub-models share the speculator.
    vision_config.quantize = None
    vision_config.speculator = config.speculator
    text_config.quantize = config.quantize
    text_config.speculator = config.speculator
    # Deliberately disabled: setting `config.text_config.tie_word_embeddings = True`
    # would load `.embed_tokens.weight` instead of `.lm_head.weight`, since
    # Idefics3 uses `embed_tokens` for the final prediction.

    model_prefix = f"{prefix}.model" if prefix else "model"
    self.text_model = load_text_model(
        prefix=model_prefix,
        config=text_config,
        weights=weights,
        name="text_model",
    )
    self.dtype = weights.dtype

    # The vision and connector models are not quantized.
    with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
        self.vision_model = Idefics3VisionTransformer(
            prefix=f"{model_prefix}.vision_model",
            config=vision_config,
            weights=weights,
        )
        config.quantize = None
        self.connector = Idefics3Connector(
            prefix=f"{model_prefix}.connector",
            config=config,
            weights=weights,
        )

    self.config = config
    self.image_token_id = config.image_token_id
    self.pad_token_id = (
        -1 if config.pad_token_id is None else config.pad_token_id
    )
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
# mask = input_ids == self.config.image_token_index
# - replace `==` with torch.where to fix the issue in hpu graph
mask = torch.where(input_ids == self.config.image_token_id)
# Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(dtype=self.dtype) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask/pP p
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
"""
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
"""
# hpu does none support unfold
conv_kernel = torch.ones(
[1, 1, patch_size, patch_size],
dtype=pixel_values.dtype,
device=pixel_values.device,
)
patches_subgrid = torch.nn.functional.conv2d(
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
conv_kernel,
stride=patch_size,
).squeeze(1)
patch_attention_mask = torch.gt(patches_subgrid, 0)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
return image_hidden_states.view(-1, image_hidden_states.shape[-1])
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if vision_embeds is not None:
# When we generate, we don't want to replace the potential image_token_id that we generated by images
# that simply don't exist
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, vision_embeds
)
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_indices=None,
):
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py
================================================
import torch
import torch.distributed
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from torch import nn
from typing import Optional, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
import torch.nn.functional as F
from text_generation_server.layers import (
SpeculativeHead,
TensorParallelEmbedding,
FastLinear,
)
from text_generation_server.layers.layernorm import FastRMSNorm
from einops import rearrange
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
import math
from dataclasses import dataclass
@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
    # Not read by the Mamba modules in this file; presumably the sequence
    # length the state caches were sized for.
    max_seqlen: int
    # Not read by the Mamba modules in this file; presumably the batch
    # capacity of the state caches.
    max_batch_size: int
    # Per-layer causal-conv states, indexed by layer id; MambaModel.forward
    # copies each layer's new state back into these in place.
    conv_states: torch.Tensor
    # Per-layer SSM states, indexed and updated like `conv_states`.
    ssm_states: torch.Tensor
    # Tokens processed so far; a value > 0 makes MambaBlock use its
    # single-token `step` path. Advanced by MambaModel.forward.
    seqlen_offset: int
class MambaConfig(PretrainedConfig):
    """Configuration of the Mamba language model.

    Args:
        vocab_size: Vocabulary size.
        d_model: Width of the residual stream.
        d_state: State size used by the selective scan.
        n_layer: Number of residual Mamba blocks.
        layer_norm_epsilon: Epsilon of the RMS norms.
        tie_word_embeddings, pad_token_id, bos_token_id, eos_token_id:
            Forwarded to ``PretrainedConfig``.
        expand: Expansion factor of the mixer; ``d_inner = expand * d_model``.
        dt_rank: Rank of the ``dt`` projection; ``"auto"`` means
            ``ceil(d_model / 16)``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    def __init__(
        self,
        vocab_size=50280,
        d_model=768,
        d_state=16,
        n_layer=32,
        layer_norm_epsilon=1e-5,
        tie_word_embeddings=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        expand=2,
        dt_rank="auto",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.layer_norm_epsilon = layer_norm_epsilon
        self.d_model = d_model
        # Honour `expand` (previously hard-coded to 2); the default is unchanged.
        self.d_inner = expand * d_model
        self.d_conv = 4
        self.d_state = d_state
        self.expand = expand
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
class MambaBlock(nn.Module):
    """Mamba mixer: in_proj -> causal conv1d -> selective scan -> out_proj.

    ``forward`` runs the whole-sequence (prefill) path; once
    ``inference_params.seqlen_offset > 0`` it delegates to ``step`` for
    single-token decoding with this layer's cached states.
    """

    def __init__(self, prefix, config, weights, layer_id):
        super().__init__()
        # Index of this layer into the per-layer state caches.
        self.layer_id = layer_id
        self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
        self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
        self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
        # NOTE(review): not referenced anywhere in this class; the code below
        # uses `dt_proj.weight` and passes `dt_proj.bias` to the kernels separately.
        self.dt_proj_no_bias = FastLinear.load(
            config, f"{prefix}.dt_proj", weights, bias=False
        )
        self.out_proj = FastLinear.load(
            config, f"{prefix}.out_proj", weights, bias=False
        )
        # Only `.weight` / `.bias` of this layer are used, handed to the
        # causal_conv1d kernels; the layer itself is never called.
        self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
        # A = -exp(A_log), precomputed once in float32.
        self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
        self.D = weights.get_tensor(f"{prefix}.D")
        # The kernels receive the activation by name.
        self.activation = "silu"
        self.dt_rank = config.dt_rank
        self.d_state = config.d_state
        self.d_conv = config.d_conv
        # NOTE(review): not referenced in this class.
        self.act = nn.SiLU()

    # inference_params
    def forward(self, hidden_states: torch.Tensor, inference_params=None):
        """Run the mixer; returns ``(output, conv_state, ssm_state)``.

        Assumes ``hidden_states`` is 3-D ``(batch, seqlen, d_model)`` (it is
        unpacked that way below) -- TODO confirm. ``inference_params`` is
        required despite its ``None`` default: ``seqlen_offset`` is read
        unconditionally.
        """
        if inference_params.seqlen_offset > 0:
            # Decode path: continue from this layer's cached states.
            conv_state = inference_params.conv_states[self.layer_id]
            ssm_state = inference_params.ssm_states[self.layer_id]
            out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
            return out, conv_state, ssm_state
        # Prefill path over the whole prompt.
        _, seqlen, _ = hidden_states.shape
        # (b, l, c) -> (b, c, l): channels-first, as the conv kernel expects.
        projected_states = self.in_proj(hidden_states).transpose(1, 2)
        # assert projected_states.shape == [batch_size, 2 * dstate, seqlen], f"{projected_states.shape} [{batch_size}, {dstate}, {seqlen}]"
        # Split channels into the conv/scan input `x` and `z` (passed to the scan as `z=`).
        x, z = projected_states.chunk(2, dim=1)
        # Cache the last `d_conv` positions of x: F.pad left-pads when
        # seqlen < d_conv and crops (negative pad) when seqlen > d_conv.
        conv_state = F.pad(x, (self.d_conv - seqlen, 0))
        x = causal_conv1d_fn(
            x=x,
            weight=self.conv1d.weight.squeeze(1),
            bias=self.conv1d.bias,
            activation=self.activation,
        )
        # We're careful here about the layout, to avoid extra transposes.
        # We want dt to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
        dt, B, C = torch.split(
            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
        )
        # Bias-free dt projection; the bias is applied inside the scan via `delta_bias`.
        dt = self.dt_proj.weight @ dt.t()
        dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        y, last_state = selective_scan_fn(
            x,
            dt,
            self.negA,
            B,
            C,
            self.D.float(),
            z=z,
            delta_bias=self.dt_proj.bias.float(),
            delta_softplus=True,
            return_last_state=True,
        )
        y = rearrange(y, "b d l -> b l d")
        attn_outputs = self.out_proj(y)
        return attn_outputs, conv_state, last_state

    def step(self, hidden_states, conv_state, ssm_state):
        """Single-token decode step; returns ``(out, conv_state, ssm_state)``.

        ``hidden_states`` is squeezed at dim 1 (presumably ``(batch, 1, d_model)``)
        and the output is unsqueezed back. The states are handed to the
        causal_conv1d / selective_state_update kernels (presumably updated in
        place by them) and clones are returned for the caller to store.
        """
        xz = self.in_proj(hidden_states.squeeze(1))
        x, z = xz.chunk(2, dim=-1)  # (B D)
        x = causal_conv1d_update(
            x,
            conv_state,
            self.conv1d.weight.squeeze(1),
            self.conv1d.bias,
            self.activation,
        )
        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Bias-free dt projection; the bias is passed as `dt_bias` below.
        dt = F.linear(dt, self.dt_proj.weight)
        A = self.negA
        y = selective_state_update(
            ssm_state,
            x,
            dt,
            A,
            B,
            C,
            self.D,
            z=z,
            dt_bias=self.dt_proj.bias,
            dt_softplus=True,
        )
        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state.clone(), ssm_state.clone()
class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper around a single ``MambaBlock``.

    The incoming ``residual`` (if any) is added to ``hidden_states`` before
    normalization, and that un-normalized sum is handed back as the new
    residual for the next block.
    """

    def __init__(self, prefix, config, weights, layer_id):
        super().__init__()
        self.mamba_block = MambaBlock(
            prefix=f"{prefix}.mixer",
            config=config,
            weights=weights,
            layer_id=layer_id,
        )
        self.layer_norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        inference_params: Optional[Any] = None,
    ):
        """Return ``(mixer_output, new_residual, conv_state, last_ssm_state)``."""
        if residual is None:
            residual = hidden_states
        else:
            residual = hidden_states + residual
        full_shape = residual.shape
        # The norm is applied to a 2-D view; the original shape is restored
        # before entering the mixer.
        normed, _ = self.layer_norm(residual.view(-1, full_shape[-1]))
        mixed, conv_state, last_ssm_state = self.mamba_block(
            normed.view(*full_shape), inference_params
        )
        return mixed, residual, conv_state, last_ssm_state
class MambaModel(nn.Module):
    """Mamba LM: embeddings -> residual Mamba blocks -> final RMS norm -> LM head.

    Recurrent per-layer state lives in the ``inference_params`` object passed
    to ``forward`` and is updated in place there.
    """

    def __init__(self, config, weights):
        super().__init__()
        prefix = "backbone"
        # Checkpoints name the token embedding either `embeddings` or `embedding`.
        try:
            self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embeddings", weights)
        except RuntimeError:
            self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
        self.blocks = nn.ModuleList(
            ResidualBlock(f"{prefix}.layers.{layer}", config, weights, layer_id=layer)
            for layer in range(config.n_layer)
        )
        self.norm_f = FastRMSNorm.load(
            f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
        )
        # The LM head is loaded from the embedding weights (same two name variants).
        try:
            self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights)
        except RuntimeError:
            self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
        self.config = config

    def forward(
        self, input_ids: torch.Tensor, inference_params=None, residual=None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, speculative_logits)`` for ``input_ids``.

        ``inference_params`` is required in practice: each block's new conv and
        SSM state is copied into ``conv_states[i]`` / ``ssm_states[i]``, and
        ``seqlen_offset`` is advanced by ``input_ids.size(1)``.
        """
        hidden_states = self.embed_tokens(input_ids)
        for layer_idx, block in enumerate(self.blocks):
            hidden_states, residual, new_conv, new_ssm = block(
                hidden_states, residual, inference_params
            )
            # Persist this layer's recurrent state for the next call.
            inference_params.conv_states[layer_idx].copy_(new_conv)
            inference_params.ssm_states[layer_idx].copy_(new_ssm)

        if residual is not None:
            hidden_states = hidden_states + residual
        normed, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
        logits, speculative_logits = self.lm_head(normed.view(residual.shape))
        # Advance the offset so the next call takes the single-token path.
        inference_params.seqlen_offset += input_ids.size(1)
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
================================================
# coding=utf-8
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2.5 VL model."""
from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import numpy as np
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
import torch.nn.functional as F
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
TensorParallelEmbedding,
SpeculativeHead,
)
from text_generation_server.layers.attention import (
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
from typing import Union
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.video_utils import VideoInput
from transformers.processing_utils import (
ProcessingKwargs,
ProcessorMixin,
Unpack,
VideosKwargs,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
    # Video kwargs accepted by Qwen2_5_VLProcessor: `fps` is either a single
    # frame rate for all videos or one value per video.
    fps: Union[List[float], float]
class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
    # Keyword arguments accepted by Qwen2_5_VLProcessor.__call__.
    videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
    # Defaults: no tokenizer padding, and a video frame rate of 2.0 fps.
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "videos_kwargs": {"fps": 2.0},
    }
class Qwen2_5_VLProcessor(ProcessorMixin):
    r"""
    Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
    [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
    [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self, image_processor=None, tokenizer=None, chat_template=None, **kwargs
    ):
        # Fall back to the default Qwen2.5-VL special tokens when the tokenizer
        # does not expose them as attributes.
        self.image_token = (
            "<|image_pad|>"
            if not hasattr(tokenizer, "image_token")
            else tokenizer.image_token
        )
        self.video_token = (
            "<|video_pad|>"
            if not hasattr(tokenizer, "video_token")
            else tokenizer.video_token
        )
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the vision inputs, this method forwards the `images` / `videos` and `kwargs` arguments to
        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] when they are not `None`. Every image (video) token
        in `text` is expanded to one token per merged patch. The caller's `text` list is not modified.
        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.
        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
            - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            Qwen2_5_VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(
                images=images, videos=None, **output_kwargs["images_kwargs"]
            )
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if videos is not None:
            videos_inputs = self.image_processor(
                images=None, videos=videos, **output_kwargs["images_kwargs"]
            )
            video_grid_thw = videos_inputs["video_grid_thw"]

            fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
            if isinstance(fps, (int, float)):
                second_per_grid_ts = [
                    self.image_processor.temporal_patch_size / fps
                ] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [
                    self.image_processor.temporal_patch_size / tmp for tmp in fps
                ]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
        else:
            videos_inputs = {}
            video_grid_thw = None

        if not isinstance(text, list):
            text = [text]
        # Work on a copy: the token expansion below rewrites entries in place
        # and must not mutate the caller's list.
        text = list(text)

        if image_grid_thw is not None:
            self._expand_vision_tokens(text, self.image_token, image_grid_thw)
        if video_grid_thw is not None:
            self._expand_vision_tokens(text, self.video_token, video_grid_thw)

        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})

    def _expand_vision_tokens(self, text, token, grid_thw):
        """Expand every ``token`` in the list ``text`` (in place), one per merged patch.

        The n-th occurrence of ``token`` across all samples is replaced by
        ``grid_thw[n].prod() // merge_size**2`` copies of ``token``. A temporary
        ``<|placeholder|>`` marker keeps freshly inserted copies from being
        matched again by the ``while`` loop.
        """
        merge_length = self.image_processor.merge_size**2
        index = 0
        for i in range(len(text)):
            while token in text[i]:
                text[i] = text[i].replace(
                    token,
                    "<|placeholder|>" * (grid_thw[index].prod() // merge_length),
                    1,
                )
                index += 1
            text[i] = text[i].replace("<|placeholder|>", token)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(self, generated_outputs):
        """
        Post-process the output of the model to decode the text.
        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
                or `(sequence_length,)`.
        Returns:
            `List[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    @property
    def model_input_names(self):
        """Tokenizer + image-processor input names (deduplicated, order kept),
        plus ``second_per_grid_ts``."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        names_from_processor = list(
            dict.fromkeys(tokenizer_input_names + image_processor_input_names)
        )
        return names_from_processor + ["second_per_grid_ts"]
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
class Qwen2_5_VLVisionConfig(PretrainedConfig):
    """Configuration of the Qwen2.5-VL vision tower.

    ``fullatt_block_indexes`` lists the indices of the blocks that use full
    attention (presumably the others use windowed attention sized by
    ``window_size``). The value is copied into a fresh list, so instances
    never share or mutate a common default.
    """

    model_type = "qwen2_5_vl"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth=32,
        hidden_size=3584,
        hidden_act="silu",
        intermediate_size=3420,
        num_heads=16,
        in_channels=3,
        patch_size=14,
        spatial_merge_size=2,
        spatial_patch_size=14,
        temporal_patch_size=2,
        tokens_per_second=4,
        window_size=112,
        out_hidden_size=3584,
        fullatt_block_indexes=(7, 15, 23, 31),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_patch_size = spatial_patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.tokens_per_second = tokens_per_second
        self.window_size = window_size
        # Copy so a shared default (or a caller's list) is never aliased.
        self.fullatt_block_indexes = (
            list(fullatt_block_indexes) if fullatt_block_indexes is not None else None
        )
        self.out_hidden_size = out_hidden_size
class Qwen2_5_VLConfig(PretrainedConfig):
    """Configuration of the Qwen2.5-VL text decoder and its vision tower.

    ``vision_config`` may be a dict of ``Qwen2_5_VLVisionConfig`` arguments, an
    already built ``Qwen2_5_VLVisionConfig``, or ``None`` (vision defaults).
    ``rope_scaling`` is copied before the legacy ``"type"`` key is normalized,
    so the caller's dict is left untouched.
    """

    def __init__(
        self,
        vocab_size=152064,
        hidden_size=8192,
        intermediate_size=29568,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=80,
        attention_dropout=0.0,
        vision_config=None,
        rope_scaling=None,
        **kwargs,
    ):
        if isinstance(vision_config, Qwen2_5_VLVisionConfig):
            self.vision_config = vision_config
        elif vision_config is not None:
            self.vision_config = Qwen2_5_VLVisionConfig(**vision_config)
        else:
            # Previously the attribute was left unset, which broke any later access.
            self.vision_config = Qwen2_5_VLVisionConfig()

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        # Copy so the normalization below does not mutate the caller's dict.
        self.rope_scaling = dict(rope_scaling) if rope_scaling is not None else None

        # BC: if there is a 'type' field, move it to 'rope_type', and change
        # 'mrope' to 'default' because `mrope` does default RoPE calculations.
        # One can set it to "linear"/"dynamic" etc. to have scaled RoPE.
        # TODO: @raushan update config in the hub
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            if self.rope_scaling["type"] == "mrope":
                self.rope_scaling["type"] = "default"
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
class Qwen2_5VLAttention(nn.Module):
    """Multi-head self-attention of the Qwen2.5-VL vision tower (HPU version).

    Heads are sharded across ``weights.process_group``. Attention runs through
    Habana's fused SDPA with a block-diagonal boolean mask built from
    ``cu_seqlens``, so tokens only attend within their own segment.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        # Width of this shard's q/k/v (hidden size split across ranks).
        self.embed_dim = config.hidden_size // weights.process_group.size()
        # NOTE(review): not used below; forward recomputes the per-head size
        # as embed_dim // num_heads.
        self.head_dim = config.hidden_size // config.num_heads
        self.num_heads = config.num_heads // weights.process_group.size()
        # Fused q/k/v projection, loaded without bias ...
        self.qkv = TensorParallelColumnLinear.load_qkv(
            config,
            prefix=f"{prefix}.qkv",
            weights=weights,
            bias=False,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_heads,
        )
        # ... with the sharded bias attached manually.
        self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
        self.proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.proj",
            weights=weights,
            bias=True,
        )
        # NOTE(review): not used in forward, which calls fused SDPA with scale=None.
        self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)

    def forward(
        self,
        hidden_state: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        """Attend within the segments delimited by ``cu_seqlens``.

        ``hidden_state`` is presumably 2-D ``(seq_len, hidden)`` (q/k/v are
        split on dim 1 and reshaped using ``shape[0]``) -- TODO confirm.
        ``max_seqlen`` sizes the square mask, so it presumably equals
        ``seq_len`` -- TODO confirm against the caller.
        """
        # apply the qkv linear layer to the hidden state
        qkv = self.qkv(hidden_state)
        query, key, value = qkv.split(
            [self.embed_dim, self.embed_dim, self.embed_dim], dim=1
        )
        # reshape the query, key, and value tensors to (seq, heads, head_dim)
        _shape = (
            hidden_state.shape[0],
            self.num_heads,
            self.embed_dim // self.num_heads,
        )
        query = query.view(*_shape)
        key = key.view(*_shape)
        value = value.view(*_shape)
        # apply rotary positional embeddings to the first `rotary_dim`
        # channels; the rest pass through. The result is written back in place
        # into query/key (views of the qkv output).
        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
        rotary_dim = cos.shape[-1]
        query_rot = query[..., :rotary_dim]
        query_pass = query[..., rotary_dim:]
        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
        key_rot = key[..., :rotary_dim]
        key_pass = key[..., rotary_dim:]
        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
        # execute sdpa on (heads, seq, head_dim)
        causal = False
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        value = value.transpose(0, 1)
        # NOTE(review): the fused-SDPA wrapper is constructed on every call.
        fsdpa_op = ModuleFusedSDPA(FusedSDPA)
        # Block-diagonal mask: True only inside [cu_seqlens[i-1], cu_seqlens[i]).
        attention_mask = torch.zeros(
            [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[
                :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
            ] = True
        attn_output = fsdpa_op(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=causal,
            scale=None,
            softmax_mode="None",
            recompute_mode=None,
            valid_sequence_lengths=None,
        )
        attn_output = attn_output.transpose(0, 1)
        # reshape output to original dimensions
        attn_output = attn_output.reshape(hidden_state.shape[0], -1)
        attn_output = self.proj(attn_output)
        return attn_output
class Qwen2_5VLVisionMLP(nn.Module):
    """Gated MLP of the vision blocks: ``down(act(gate(x)) * up(x))``.

    ``gate``/``up`` are column-parallel and ``down`` row-parallel layers, so
    the intermediate activations are presumably sharded across
    ``weights.process_group``.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.activation_fn = ACT2FN[config.hidden_act]
        num_shards = weights.process_group.size()
        self.intermediate_size = config.intermediate_size // num_shards
        self.up = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.up_proj", weights=weights, config=config, bias=True
        )
        self.gate = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.gate_proj", weights=weights, config=config, bias=True
        )
        self.down = TensorParallelRowLinear.load(
            prefix=f"{prefix}.down_proj", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gated = self.activation_fn(self.gate(hidden_states))
        return self.down(gated * self.up(hidden_states))
class Qwen2_5VLVisionBlock(nn.Module):
    """Pre-norm transformer block of the vision tower.

    ``x = x + attn(norm1(x))`` followed by ``x = x + mlp(norm2(x))``.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.attn = Qwen2_5VLAttention(
            prefix=f"{prefix}.attn", config=config, weights=weights
        )
        self.norm1 = FastRMSNorm.load(
            prefix=f"{prefix}.norm1", weights=weights, eps=1e-6
        )
        self.norm2 = FastRMSNorm.load(
            prefix=f"{prefix}.norm2", weights=weights, eps=1e-6
        )
        self.mlp = Qwen2_5VLVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

    def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
        normed, _ = self.norm1(hidden_states)
        hidden_states = hidden_states + self.attn(
            normed, cu_seqlens, cos, sin, max_seqlen
        )
        normed, _ = self.norm2(hidden_states)
        return hidden_states + self.mlp(normed)
class Qwen2_5VLPatchMerger(nn.Module):
    """Merge every ``spatial_merge_size**2`` consecutive patch rows into one token.

    Pipeline: RMSNorm -> reshape to ``(-1, hidden_size * merge_area)`` ->
    Linear -> GELU -> Linear. The caller is expected to order patches so that
    consecutive rows are spatial neighbours -- TODO confirm.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        merge_area = config.spatial_merge_size**2
        self.hidden_size = config.hidden_size * merge_area
        self.patch_merger_ln_q = FastRMSNorm.load(
            prefix=f"{prefix}.ln_q", weights=weights, eps=1e-6
        )
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states) -> torch.Tensor:
        normed, _ = self.patch_merger_ln_q(hidden_states)
        merged = normed.view(-1, self.hidden_size)
        return self.fc2(F.gelu(self.fc1(merged)))
class Qwen2_5VisionModel(nn.Module):
def __init__(self, *, prefix, config, weights):
    """Build the patch embedding, rotary frequencies, vision blocks and merger."""
    super().__init__()
    self.spatial_merge_size = config.spatial_merge_size

    # Non-overlapping 3-D patchify: kernel == stride == (t, p, p).
    patch_kernel = [config.temporal_patch_size, config.patch_size, config.patch_size]
    self.patch_embedding = nn.Conv3d(
        in_channels=config.in_channels,
        out_channels=config.hidden_size,
        kernel_size=patch_kernel,
        stride=patch_kernel,
        bias=False,
    )
    # Replace the freshly initialised conv weight with the checkpoint tensor.
    checkpoint_weight = weights.get_tensor(f"{prefix}.patch_embed.proj.weight")
    self.patch_embedding.weight = nn.Parameter(checkpoint_weight, requires_grad=False)

    # Rotary inverse frequencies over half of the head dimension (forward
    # stacks an h- and a w-position per patch).
    rotary_half = (config.hidden_size // config.num_heads) // 2
    exponents = torch.arange(0, rotary_half, 2, dtype=torch.float) / rotary_half
    self.register_buffer("inv_freq", 1.0 / (10000.0**exponents), persistent=False)

    self.blocks = nn.ModuleList(
        Qwen2_5VLVisionBlock(
            prefix=f"{prefix}.blocks.{block_idx}",
            config=config,
            weights=weights,
        )
        for block_idx in range(config.depth)
    )
    self.merger = Qwen2_5VLPatchMerger(
        prefix=f"{prefix}.merger",
        config=config,
        weights=weights,
    )

    # Geometry consumed by forward / get_window_index.
    self.temporal_patch_size = config.temporal_patch_size
    self.spatial_patch_size = config.spatial_patch_size
    self.in_channels = config.in_channels
    self.embed_dim = config.hidden_size
    self.window_size = config.window_size
    self.patch_size = config.patch_size
    self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size
    self.fullatt_block_indexes = config.fullatt_block_indexes
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
    """Prepend ``self.class_embedding`` as an extra first token of every sequence.

    NOTE(review): ``__init__`` never assigns ``self.class_embedding`` and no
    visible code calls this method, so calling it as-is would raise
    ``AttributeError``. It is presumably left over from another vision model.
    """
    batch_size, _, hidden_size = hidden_state.shape
    cls_tokens = self.class_embedding.expand(batch_size, 1, hidden_size)
    return torch.cat((cls_tokens, hidden_state), dim=1)
def get_window_index(self, grid_thw):
    """Compute the window-attention permutation and the window boundaries.

    For each ``(t, h, w)`` row of ``grid_thw``, the grid of merged patches
    (``h // spatial_merge_size`` x ``w // spatial_merge_size`` per frame) is
    padded and tiled into square windows with a side of
    ``window_size // spatial_merge_size // patch_size`` merged patches.

    Returns:
        window_index: 1-D tensor ordering merged patches so that each window's
            patches are contiguous; offsets accumulate across images.
        cu_window_seqlens: Python list of cumulative window lengths in
            un-merged patches (scaled by ``spatial_merge_unit``), starting at 0.
            Padding always adds at least one cell per axis, so windows made
            only of padding appear as zero-length (repeated) entries.
    """
    pad_value = -100
    merge = self.spatial_merge_size
    win = self.window_size // merge // self.patch_size

    index_chunks: list = []
    cu_window_seqlens: list = [0]
    offset = 0
    for grid_t, grid_h, grid_w in grid_thw:
        rows, cols = grid_h // merge, grid_w // merge
        n_merged = grid_t * rows * cols
        grid = torch.arange(n_merged).reshape(grid_t, rows, cols)

        pad_rows = win - rows % win
        pad_cols = win - cols % win
        win_rows = (rows + pad_rows) // win
        win_cols = (cols + pad_cols) // win

        # Pad with a sentinel, then regroup into (t, n_windows, win, win) tiles.
        tiles = F.pad(grid, (0, pad_cols, 0, pad_rows), "constant", pad_value)
        tiles = tiles.reshape(grid_t, win_rows, win, win_cols, win)
        tiles = tiles.permute(0, 1, 3, 2, 4).reshape(
            grid_t, win_rows * win_cols, win, win
        )

        window_lens = (tiles != pad_value).sum([2, 3]).reshape(-1)
        flat = tiles.reshape(-1)
        index_chunks.append(flat[flat != pad_value] + offset)

        boundaries = (
            window_lens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
        )
        cu_window_seqlens.extend(boundaries.tolist())
        offset += n_merged.item()

    return torch.cat(index_chunks, dim=0), cu_window_seqlens
def forward(
    self,
    pixel_values: torch.Tensor,
    grid_thw: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
    """Encode image patches with the Qwen2.5-VL vision tower.

    Args:
        pixel_values: patch pixels; viewed below as (-1, in_channels,
            temporal_patch_size, spatial_patch_size, spatial_patch_size) and
            cast to the patch-embedding weight dtype.
        grid_thw: one (t, h, w) patch-grid row per image. Although it defaults
            to None, it is iterated and indexed unconditionally below, so it
            must be provided.

    Returns:
        The output of ``self.merger``, with rows re-ordered by the inverse of
        the window permutation so they follow the original patch order.
    """
    # reshape the input tensor for processing
    shape = (
        -1,
        self.in_channels,
        self.temporal_patch_size,
        self.spatial_patch_size,
        self.spatial_patch_size,
    )
    pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
    hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
    # TODO: revisit to see if we can avoid some of these reshapes
    # find the position ids for the input tensor based on the grid_thw
    pos_ids = []
    for t, h, w in grid_thw:
        # Row coordinate of every patch, re-ordered so that each
        # spatial_merge_size x spatial_merge_size block is contiguous.
        hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
        hpos_ids = hpos_ids.reshape(
            h // self.spatial_merge_size,
            self.spatial_merge_size,
            w // self.spatial_merge_size,
            self.spatial_merge_size,
        )
        hpos_ids = hpos_ids.permute(0, 2, 1, 3)
        hpos_ids = hpos_ids.flatten()
        # Column coordinate of every patch, re-ordered the same way.
        wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
        wpos_ids = wpos_ids.reshape(
            h // self.spatial_merge_size,
            self.spatial_merge_size,
            w // self.spatial_merge_size,
            self.spatial_merge_size,
        )
        wpos_ids = wpos_ids.permute(0, 2, 1, 3)
        wpos_ids = wpos_ids.flatten()
        # The same (row, col) pairs are repeated for each of the t slices.
        pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
    pos_ids = torch.cat(pos_ids, dim=0)
    max_grid_size = grid_thw[:, 1:].max()
    # apply the positional embeddings to the position ids
    # Rotary angles for every coordinate up to the largest h or w, then
    # gathered per patch for both its row and column coordinate.
    seq = torch.arange(
        max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
    )
    rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
    rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
    window_index, cu_window_seqlens = self.get_window_index(grid_thw)
    # Permute tokens, in groups of spatial_merge_unit, into window order; the
    # rotary embeddings are permuted identically so they stay aligned.
    seq_len = hidden_states.shape[0]
    patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
    og_shape = (seq_len, -1)
    hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view(
        og_shape
    )
    rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view(
        og_shape
    )
    # Only the device changes here; the angles keep inv_freq's dtype.
    rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
    cos = rotary_pos_emb.cos()
    sin = rotary_pos_emb.sin()
    cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
    sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
    cu_window_seqlens = torch.tensor(
        cu_window_seqlens,
        device="cpu",
        dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
    )
    # All-padding windows from get_window_index have length 0 and produce
    # repeated boundaries; unique_consecutive removes those duplicates.
    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to(
        hidden_states.device
    )
    # create a cu_seqlens tensor to be used in the attention mask
    # (one segment of h * w tokens per temporal slice of each image)
    cu_seqlens = torch.repeat_interleave(
        grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
    ).cumsum(dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
    # NOTE(review): this longest full-segment length is passed to both the
    # full-attention and the windowed blocks below.
    max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
    # iteratively apply the blocks to the hidden states
    # In HPU lazy mode a mark_step is issued before the loop and after every
    # block — presumably to keep each lazily-built graph per-layer sized.
    lazy_mode = htorch.utils.internal.is_lazy()
    if lazy_mode:
        htorch.core.mark_step()
    for layer_num, block in enumerate(self.blocks):
        # NOTE: qwen2_5_vl.py has a concept of full attention blocks
        # that are applied at specific layers.
        if layer_num in self.fullatt_block_indexes:
            cu_seqlens_now = cu_seqlens
        else:
            cu_seqlens_now = cu_window_seqlens
        hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen)
        if lazy_mode:
            htorch.core.mark_step()
    # apply the final patch merger to the hidden states
    hidden_states = self.merger(hidden_states)
    # Undo the window permutation applied before the blocks.
    reverse_indices = torch.argsort(window_index)
    hidden_states = hidden_states[reverse_indices, :]
    return hidden_states
class Qwen2_5VLForConditionalGeneration(nn.Module):
    """Qwen2.5-VL: vision tower + Qwen2 text model + (speculative) LM head.

    Image features produced by ``get_vision_embeds`` are written into the
    token embeddings at image-token positions by ``get_inputs_embeds``;
    ``forward`` then runs the text model on those embeddings with the
    (seq_len, 3) position ids built by ``get_position_ids``.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        # The vision tower is loaded unquantized and shares the speculator setting.
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
        # returns rope_scaling.type == "default" for Qwen2_5-VL model at the moment
        # NOTE(review): the check reads the "type" key but the update writes
        # "rope_type"; "type" itself stays "default". Confirm downstream code
        # reads "rope_type".
        if (
            hasattr(config, "rope_scaling")
            and config.rope_scaling is not None
            and config.rope_scaling.get("type", None) == "default"
        ):
            config.rope_scaling.update({"rope_type": "mrope"})
        self.hidden_size = config.hidden_size
        self.vision_start_token_id = config.vision_start_token_id
        self.vision_end_token_id = config.vision_end_token_id
        self.image_token_id = config.image_token_id
        self.video_token_id = config.video_token_id
        self.spatial_merge_size = config.vision_config.spatial_merge_size
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.visual = Qwen2_5VisionModel(
            prefix="visual", config=config.vision_config, weights=weights
        )
        self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
        # With tied embeddings the LM head is loaded from the embedding weights.
        if config.tie_word_embeddings:
            suffix = "model.embed_tokens"
        else:
            suffix = "lm_head"
        self.lm_head = SpeculativeHead.load(
            config,
            prefix=suffix if not prefix else f"{prefix}.{suffix}",
            weights=weights,
        )
        self.device = weights.device

    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
    # modified to first find segments then initialize position ids for each segment
    # Steps:
    # locate all vision and text segments
    # calculate `vision_segment_lengths` for each vision segment to be used as offset
    # calculate `text_segment_lengths` for each text segment to be used as offset
    # create position ids for each vision segment based on the image grid
    # create position ids for each text segment
    # combine all the position ids
    # the final segment is the difference between the last vision segment and the end of the input
    # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
    def get_position_ids(
        self,
        input_ids: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Build (input_ids_len, 3) multimodal rotary position ids.

        Without ``image_grid_thw`` every token gets its sequence index in all
        three components. Otherwise text runs get sequential positions and
        each image's tokens get (t, h, w) grid coordinates (h and w divided
        by ``spatial_merge_size``) shifted by a running offset.

        Assumes 1-D ``input_ids`` holding matched vision-start / vision-end
        token pairs, one per row of ``image_grid_thw`` — TODO confirm callers
        guarantee this.

        NOTE(review): the offset after each image advances by that image's
        merged width only (``image_grid_thw[:-1, 2]``), not by its height or
        temporal extent — confirm this is intended for grids with h > w.
        """
        if image_grid_thw is None:
            return (
                torch.arange(input_ids.shape[0], device=input_ids.device)
                .unsqueeze(1)
                .repeat(1, 3)
            )
        spatial_merge_size = self.spatial_merge_size
        vision_start_token_id = self.vision_start_token_id
        vision_end_token_id = self.vision_end_token_id
        device = input_ids.device
        dtype = input_ids.dtype
        input_ids_len = input_ids.shape[0]
        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
        # Text run i spans from the previous vision-end (or index 0) up to and
        # including vision-start i.
        prev_vision_end = torch.cat(
            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
        )
        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
        vision_widths_max = torch.cat(
            [
                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                image_grid_thw[:-1, 2] // spatial_merge_size,
            ]
        )
        # Starting position of each image, and of the text run before it.
        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
        # create position ids for each vision segment based on the image grid
        llm_pos_ids_list = []
        for i, _ in enumerate(vision_segments):
            t, h, w = (
                image_grid_thw[i][0],
                image_grid_thw[i][1] // spatial_merge_size,
                image_grid_thw[i][2] // spatial_merge_size,
            )
            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
            w_indices = torch.arange(w, device=device).repeat(t * h)
            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
            # offset by the position of the last vision segment
            im = image_position_ids + vision_segment_lengths[i]
            llm_pos_ids_list.append(im)
        # create position ids for each text segment
        text_ranges = [
            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
            + text_segment_lengths[i]
            for i, seq_len in enumerate(text_lengths_between_vision)
        ]
        # Interleave: text_0, image_0, text_1, image_1, ...
        full_llm_pos_ids_list = [
            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
        ]
        # Trailing text (from the last vision-end token onward) continues after
        # the largest position used so far.
        max_s = full_llm_pos_ids_list[-1].max() + 1
        final_text_len = input_ids_len - vision_ends[-1]
        if final_text_len > 0:
            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
            full_llm_pos_ids_list.append(m + max_s)
        position_ids = (
            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
        )
        return position_ids

    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Run the vision tower; ``pixel_attention_mask`` and ``image_sizes``
        are accepted for interface compatibility but not used here."""
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
        return image_embeds

    def get_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        vision_embeds: Optional[torch.Tensor] = None,
    ):
        """Embed ``input_ids`` and, if given, overwrite the embeddings at
        ``image_token_id`` positions with ``vision_embeds`` (in place, in
        order; assumes one vision row per image token — TODO confirm)."""
        inputs_embeds = self.embed_tokens(input_ids)
        # apply the visual model to the pixel values if they are provided
        if vision_embeds is not None:
            mask = torch.where(input_ids == self.image_token_id)
            inputs_embeds[mask] = vision_embeds
        return inputs_embeds

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor],
        attention_mask: Optional[torch.BoolTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_indices=None,
    ):
        """Run the text model and LM head; returns (logits, speculative_logits).

        ``attention_mask``, ``adapter_data`` and ``image_indices`` are accepted
        for interface compatibility but not used here.
        """
        hidden_states = self.text_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=seqlen,
            hpu_attention_meta=hpu_attention_meta,
        )
        # Only compute logits for the selected positions when requested.
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py
================================================
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2 VL model."""
from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import numpy as np
from transformers.activations import ACT2FN
import torch.nn.functional as F
from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
TensorParallelEmbedding,
SpeculativeHead,
)
from text_generation_server.layers.attention import (
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
import habana_frameworks.torch as htorch
class Qwen2VLAttention(nn.Module):
    """Self-attention for the Qwen2-VL vision tower, run through HPU FusedSDPA.

    Tokens of all images/frames are packed along dim 0 of ``hidden_state``;
    ``cu_seqlens`` holds cumulative segment boundaries and attention is
    restricted to tokens of the same segment with a block-diagonal mask.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        # Per-shard width and head count (split across the tensor-parallel group).
        self.embed_dim = config.embed_dim // weights.process_group.size()
        # Width of one head. Derived from embed_dim so it agrees with the
        # q/k/v reshape in forward (embed_dim // num_heads); the former
        # `config.hidden_size // config.num_heads` used a different config field.
        self.head_dim = config.embed_dim // config.num_heads
        self.num_heads = config.num_heads // weights.process_group.size()
        self.qkv = TensorParallelColumnLinear.load_qkv(
            config,
            prefix=f"{prefix}.qkv",
            weights=weights,
            bias=False,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_heads,
        )
        # The fused qkv bias is loaded separately, sharded along dim 0.
        self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
        self.proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.proj",
            weights=weights,
            bias=True,
        )
        self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)

    def forward(
        self,
        hidden_state: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        """Attend within each packed segment.

        Args:
            hidden_state: packed tokens, (seq_len, embed_dim) as implied by
                the split on dim 1 and the per-head view below.
            cu_seqlens: cumulative segment boundaries, e.g. [0, n0, n0 + n1].
            cos, sin: rotary tables; their last dim is the rotary width.
            max_seqlen: length of the longest segment. Kept for interface
                compatibility; the mask is sized to the whole packed sequence.

        Returns:
            (seq_len, hidden) output of the row-parallel projection.
        """
        # apply the qkv linear layer to the hidden state
        qkv = self.qkv(hidden_state)
        query, key, value = qkv.split(
            [self.embed_dim, self.embed_dim, self.embed_dim], dim=1
        )
        # reshape the query, key, and value tensors
        _shape = (
            hidden_state.shape[0],
            self.num_heads,
            self.embed_dim // self.num_heads,
        )
        query = query.view(*_shape)
        key = key.view(*_shape)
        value = value.view(*_shape)
        # apply rotary positional embeddings to the first `rotary_dim`
        # channels of each head; the remaining channels pass through
        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
        rotary_dim = cos.shape[-1]
        query_rot = query[..., :rotary_dim]
        query_pass = query[..., rotary_dim:]
        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))
        key_rot = key[..., :rotary_dim]
        key_pass = key[..., rotary_dim:]
        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))
        # execute sdpa on (num_heads, seq_len, head_dim) tensors
        causal = False
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        value = value.transpose(0, 1)
        fsdpa_op = ModuleFusedSDPA(FusedSDPA)
        # Block-diagonal mask over the *whole* packed sequence. Sizing it by
        # max_seqlen (the longest single segment) only matches the query
        # length when exactly one segment is packed; with several images or
        # temporal slices it would be smaller than the seq_len x seq_len
        # score matrix.
        seq_len = hidden_state.shape[0]
        attention_mask = torch.zeros(
            [1, seq_len, seq_len], device=query.device, dtype=torch.bool
        )
        for i in range(1, len(cu_seqlens)):
            start, end = cu_seqlens[i - 1], cu_seqlens[i]
            attention_mask[:, start:end, start:end] = True
        attn_output = fsdpa_op(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=causal,
            scale=None,
            softmax_mode="None",
            recompute_mode=None,
            valid_sequence_lengths=None,
        )
        attn_output = attn_output.transpose(0, 1)
        # reshape output to original dimensions
        attn_output = attn_output.reshape(hidden_state.shape[0], -1)
        attn_output = self.proj(attn_output)
        return attn_output
class Qwen2VLVisionMLP(nn.Module):
    """Feed-forward sub-layer of a Qwen2-VL vision block: ``fc2(act(fc1(x)))``.

    The activation is looked up in ``ACT2FN`` by ``config.hidden_act``; both
    projections are tensor-parallel linears with bias.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.activation_fn = ACT2FN[config.hidden_act]
        # Column-parallel up-projection, then row-parallel down-projection.
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
class Qwen2VLVisionBlock(nn.Module):
    """Pre-norm transformer block of the Qwen2-VL vision tower.

    ``norm1 -> attention (+ skip) -> norm2 -> MLP (+ skip)``.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.attn = Qwen2VLAttention(
            prefix=f"{prefix}.attn", config=config, weights=weights
        )
        self.norm1 = FastLayerNorm.load(
            prefix=f"{prefix}.norm1", weights=weights, eps=1e-6
        )
        self.norm2 = FastLayerNorm.load(
            prefix=f"{prefix}.norm2", weights=weights, eps=1e-6
        )
        self.mlp = Qwen2VLVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

    def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:
        # The norm yields a pair; its second element is the skip input for
        # the attention sub-layer.
        normed, skip = self.norm1(hidden_states)
        hidden_states = self.attn(normed, cu_seqlens, cos, sin, max_seqlen) + skip
        # For the MLP sub-layer the pre-norm hidden states are the skip input.
        normed, _ = self.norm2(hidden_states)
        return hidden_states + self.mlp(normed)
class Qwen2VLPatchMerger(nn.Module):
    """Merge each spatial_merge_size x spatial_merge_size group of patch tokens.

    After layer norm, consecutive rows are concatenated into vectors of width
    ``embed_dim * spatial_merge_size**2`` and projected by a two-layer GELU MLP.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.hidden_size = config.embed_dim * (config.spatial_merge_size**2)
        self.patch_merger_ln_q = FastLayerNorm.load(
            prefix=f"{prefix}.ln_q", weights=weights, eps=1e-6
        )
        # Checkpoint keys mlp.0 / mlp.2 (mlp.1 is the parameter-free GELU).
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.mlp.0", config=config, weights=weights, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.mlp.2", config=config, weights=weights, bias=True
        )

    def forward(self, hidden_states) -> torch.Tensor:
        normed = self.patch_merger_ln_q(hidden_states)[0]
        grouped = normed.view(-1, self.hidden_size)
        return self.fc2(F.gelu(self.fc1(grouped)))
class Qwen2VisionModel(nn.Module):
    """Qwen2-VL vision tower: 3-D patch embedding, rotary-position transformer
    blocks restricted to per-frame segments, and a final patch merger."""

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.spatial_merge_size = config.spatial_merge_size
        # Non-overlapping (temporal, height, width) patches: kernel == stride.
        kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
        # NOTE(review): the conv is built from config.in_chans / config.patch_size
        # while forward reshapes with config.in_channels / config.spatial_patch_size
        # — presumably equal values; confirm.
        self.patch_embedding = nn.Conv3d(
            in_channels=config.in_chans,
            out_channels=config.embed_dim,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
        )
        head_dim = config.embed_dim // config.num_heads
        # TODO: replace with static positional embeddings once implemented
        # Rotary inverse frequencies over half a head (row and column
        # coordinates each take one half).
        theta = 10000.0
        dim = head_dim // 2
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.blocks = nn.ModuleList(
            [
                Qwen2VLVisionBlock(
                    prefix=f"{prefix}.blocks.{i}",
                    config=config,
                    weights=weights,
                )
                for i in range(config.depth)
            ]
        )
        self.merger = Qwen2VLPatchMerger(
            prefix=f"{prefix}.merger",
            config=config,
            weights=weights,
        )
        self.temporal_patch_size = config.temporal_patch_size
        self.spatial_patch_size = config.spatial_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.embed_dim

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend a class token along dim 1.

        NOTE(review): ``self.class_embedding`` is never assigned in
        ``__init__`` above, so calling this raises AttributeError; it appears
        unused.
        """
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Encode image patches.

        Args:
            pixel_values: patch pixels, viewed as (-1, in_channels,
                temporal_patch_size, spatial_patch_size, spatial_patch_size).
            grid_thw: one (t, h, w) patch-grid row per image; must be provided
                (it is iterated unconditionally).

        Returns:
            The patch-merger output for all images, in patch order.
        """
        # reshape the input tensor for processing
        shape = (
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.spatial_patch_size,
            self.spatial_patch_size,
        )
        pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
        hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
        # TODO: revisit to see if we can avoid some of these reshapes
        # find the position ids for the input tensor based on the grid_thw
        pos_ids = []
        for t, h, w in grid_thw:
            # Row coordinate per patch, ordered so each merge block is contiguous.
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()
            # Column coordinate per patch, ordered the same way.
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            # Repeated for each of the t temporal slices.
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        # apply the positional embeddings to the position ids
        seq = torch.arange(
            max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
        cos = rotary_pos_emb.cos()
        sin = rotary_pos_emb.sin()
        cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
        sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
        # create a cu_seqlens tensor to be used in the attention mask
        # (one segment of h * w tokens per temporal slice of each image)
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
        # iteratively apply the blocks to the hidden states
        # In HPU lazy mode a mark_step is issued before the loop and after each
        # block — presumably to keep each lazily-built graph per-layer sized.
        lazy_mode = htorch.utils.internal.is_lazy()
        if lazy_mode:
            htorch.core.mark_step()
        for block in self.blocks:
            hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen)
            if lazy_mode:
                htorch.core.mark_step()
        # apply the final patch merger to the hidden states
        hidden_states = self.merger(hidden_states)
        return hidden_states
class Qwen2VLForConditionalGeneration(nn.Module):
    """Qwen2-VL: vision tower + Qwen2 text model + (speculative) LM head.

    Image features from ``get_vision_embeds`` are written into the token
    embeddings at image-token positions by ``get_inputs_embeds``; ``forward``
    runs the text model on those embeddings with the (seq_len, 3) position ids
    built by ``get_position_ids``.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        # The vision tower is loaded unquantized and shares the speculator setting.
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        # set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
        # returns rope_scaling.type == "default" for Qwen2-VL model at the moment
        # NOTE(review): the check reads the "type" key but the update writes
        # "rope_type"; confirm downstream code reads "rope_type".
        if (
            hasattr(config, "rope_scaling")
            and config.rope_scaling is not None
            and config.rope_scaling.get("type", None) == "default"
        ):
            config.rope_scaling.update({"rope_type": "mrope"})
        self.hidden_size = config.hidden_size
        self.vision_start_token_id = config.vision_start_token_id
        self.vision_end_token_id = config.vision_end_token_id
        self.image_token_id = config.image_token_id
        self.video_token_id = config.video_token_id
        self.spatial_merge_size = config.vision_config.spatial_merge_size
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.visual = Qwen2VisionModel(
            prefix="visual", config=config.vision_config, weights=weights
        )
        self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
        # With tied embeddings the LM head is loaded from the embedding weights.
        if config.tie_word_embeddings:
            suffix = "model.embed_tokens"
        else:
            suffix = "lm_head"
        self.lm_head = SpeculativeHead.load(
            config,
            prefix=suffix if not prefix else f"{prefix}.{suffix}",
            weights=weights,
        )
        # NOTE(review): self.norm is loaded but not used by forward() in this
        # class — confirm whether it is needed.
        self.norm = FastRMSNorm.load(
            prefix="model.norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.device = weights.device

    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
    # modified to first find segments then initialize position ids for each segment
    # Steps:
    # locate all vision and text segments
    # calculate `vision_segment_lengths` for each vision segment to be used as offset
    # calculate `text_segment_lengths` for each text segment to be used as offset
    # create position ids for each vision segment based on the image grid
    # create position ids for each text segment
    # combine all the position ids
    # the final segment is the difference between the last vision segment and the end of the input
    # combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
    def get_position_ids(
        self,
        input_ids: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Build (input_ids_len, 3) multimodal rotary position ids.

        Without ``image_grid_thw`` every token gets its sequence index in all
        three components. Otherwise text runs get sequential positions and
        each image's tokens get (t, h, w) grid coordinates (h and w divided by
        ``spatial_merge_size``) shifted by a running offset.

        Assumes 1-D ``input_ids`` with matched vision-start / vision-end
        tokens, one pair per row of ``image_grid_thw`` — TODO confirm.

        NOTE(review): the offset after each image advances by that image's
        merged width only (``image_grid_thw[:-1, 2]``), not by its height or
        temporal extent — confirm this is intended for grids with h > w.
        """
        if image_grid_thw is None:
            return (
                torch.arange(input_ids.shape[0], device=input_ids.device)
                .unsqueeze(1)
                .repeat(1, 3)
            )
        spatial_merge_size = self.spatial_merge_size
        vision_start_token_id = self.vision_start_token_id
        vision_end_token_id = self.vision_end_token_id
        device = input_ids.device
        dtype = input_ids.dtype
        input_ids_len = input_ids.shape[0]
        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
        # Text run i spans from the previous vision-end (or index 0) up to and
        # including vision-start i.
        prev_vision_end = torch.cat(
            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
        )
        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
        vision_widths_max = torch.cat(
            [
                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                image_grid_thw[:-1, 2] // spatial_merge_size,
            ]
        )
        # Starting position of each image, and of the text run before it.
        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
        # create position ids for each vision segment based on the image grid
        llm_pos_ids_list = []
        for i, _ in enumerate(vision_segments):
            t, h, w = (
                image_grid_thw[i][0],
                image_grid_thw[i][1] // spatial_merge_size,
                image_grid_thw[i][2] // spatial_merge_size,
            )
            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
            w_indices = torch.arange(w, device=device).repeat(t * h)
            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
            # offset by the position of the last vision segment
            im = image_position_ids + vision_segment_lengths[i]
            llm_pos_ids_list.append(im)
        # create position ids for each text segment
        text_ranges = [
            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
            + text_segment_lengths[i]
            for i, seq_len in enumerate(text_lengths_between_vision)
        ]
        # Interleave: text_0, image_0, text_1, image_1, ...
        full_llm_pos_ids_list = [
            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
        ]
        # Trailing text (from the last vision-end token onward) continues after
        # the largest position used so far.
        max_s = full_llm_pos_ids_list[-1].max() + 1
        final_text_len = input_ids_len - vision_ends[-1]
        if final_text_len > 0:
            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
            full_llm_pos_ids_list.append(m + max_s)
        position_ids = (
            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
        )
        return position_ids

    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Run the vision tower; ``pixel_attention_mask`` and ``image_sizes``
        are accepted for interface compatibility but not used here."""
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
        return image_embeds

    def get_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        vision_embeds: Optional[torch.Tensor] = None,
    ):
        """Embed ``input_ids`` and, if given, overwrite the embeddings at
        ``image_token_id`` positions with ``vision_embeds`` (in place, in
        order; assumes one vision row per image token — TODO confirm)."""
        inputs_embeds = self.embed_tokens(input_ids)
        # apply the visual model to the pixel values if they are provided
        if vision_embeds is not None:
            mask = torch.where(input_ids == self.image_token_id)
            inputs_embeds[mask] = vision_embeds
        return inputs_embeds

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: Seqlen,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor],
        attention_mask: Optional[torch.BoolTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_indices=None,
    ):
        """Run the text model and LM head; returns (logits, speculative_logits).

        ``attention_mask``, ``adapter_data`` and ``image_indices`` are accepted
        for interface compatibility but not used here.
        """
        hidden_states = self.text_model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            slots=slots,
            seqlen=seqlen,
            hpu_attention_meta=hpu_attention_meta,
        )
        # Only compute logits for the selected positions when requested.
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/siglip.py
================================================
from typing import Optional, Tuple
import warnings
import math
import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPooling,
)
from transformers import SiglipConfig, SiglipVisionConfig
from torch.nn.init import _calculate_fan_in_and_fan_out
from text_generation_server.layers.tensor_parallel import (
TensorParallelEmbedding,
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
class SiglipVisionEmbeddings(nn.Module):
    """Patchify images with a strided conv and add learned position embeddings.

    There is one position per patch, ``(image_size // patch_size) ** 2`` in
    total; the conv weight and bias come from the checkpoint and are frozen.
    """

    def __init__(self, prefix, config: SiglipVisionConfig, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # kernel == stride == patch_size: non-overlapping patches.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        conv_weight = weights.get_tensor(f"{prefix}.patch_embedding.weight")
        conv_bias = weights.get_tensor(f"{prefix}.patch_embedding.bias")
        self.patch_embedding.weight = nn.Parameter(conv_weight, requires_grad=False)
        self.patch_embedding.bias = nn.Parameter(conv_bias, requires_grad=False)
        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.position_embedding", weights=weights
        )
        all_positions = torch.arange(self.num_positions, device=weights.device)
        self.register_buffer(
            "position_ids", all_positions.expand((1, -1)), persistent=False
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        # conv output [*, width, grid, grid] -> [*, grid * grid, width]
        patches = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
        return patches + self.position_embedding(self.position_ids)
class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Heads and the projected width are split across the tensor-parallel
    group; ``forward`` returns ``(attn_output, attn_weights)``.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.head_size = self.head_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # Per-shard head count and width (head_dim itself is unchanged).
        self.num_heads = self.num_heads // weights.process_group.size()
        self.embed_dim = self.embed_dim // weights.process_group.size()
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.k_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
        )
        self.v_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.v_proj", weights=weights, bias=True
        )
        self.q_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.q_proj", weights=weights, bias=True
        )
        self.out_proj = TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """(bsz, seq, heads * head_dim) -> contiguous (bsz, heads, seq, head_dim)."""
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: (bsz, tgt_len, hidden).
            attention_mask: optional additive mask of shape
                (bsz, 1, tgt_len, src_len).

        Returns:
            attn_output: (bsz, tgt_len, hidden) after ``out_proj``.
            attn_weights: softmax probabilities, (bsz * num_heads, tgt_len, src_len).

        Raises:
            ValueError: if an intermediate or the mask has an unexpected shape.
        """
        bsz, tgt_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)
        src_len = key_states.size(1)
        # scale post matmul
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) * self.scale
        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                + attention_mask
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(attn_weights.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)
        # Report the same 3-D shape that is actually checked.
        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_size)}, but is"
                f" {attn_output.size()}"
            )
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights
class SiglipMLP(nn.Module):
    """Feed-forward sub-layer of a SigLIP encoder layer: ``fc2(act(fc1(x)))``."""

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        # fc1: config.hidden_size -> config.intermediate_size
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
        )
        # fc2: config.intermediate_size -> config.hidden_size
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
class SiglipEncoderLayer(nn.Module):
    """One pre-norm SigLIP encoder layer: self-attention then MLP, each with a residual.

    Weights are read from ``{prefix}.self_attn``, ``{prefix}.layer_norm1``,
    ``{prefix}.mlp`` and ``{prefix}.layer_norm2``.
    """

    def __init__(self, prefix, config: SiglipConfig, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        norm_eps = config.layer_norm_eps
        self.self_attn = SiglipAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", weights=weights, eps=norm_eps
        )
        self.mlp = SiglipMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", weights=weights, eps=norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.FloatTensor]:
        """Apply the layer.

        Returns:
            ``(hidden_states, None)``. The attention weights computed by ``self_attn``
            are discarded, and ``None`` is returned in their place.
        """
        # Attention sub-block: normalize, attend, add back the un-normalized input.
        attended, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
        )
        hidden_states = hidden_states + attended

        # MLP sub-block: normalize, project, add back the un-normalized input.
        transformed = self.mlp(self.layer_norm2(hidden_states))
        hidden_states = hidden_states + transformed
        return hidden_states, None
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling.

    A probe query of shape ``(1, 1, hidden_size)`` is repeated over the batch and
    attends over the input sequence. A residual MLP (applied after a LayerNorm)
    follows, and the single pooled position is returned, shape ``(batch, hidden_size)``.

    NOTE(review): ``probe``, ``attention`` and ``layernorm`` are freshly initialized
    here and are NOT loaded from ``weights``. Only the MLP is loaded. This head is
    therefore not usable with pretrained checkpoints as-is. TODO: load
    ``{prefix}.probe``, ``{prefix}.attention.*`` and ``{prefix}.layernorm`` from
    ``weights``.
    """

    def __init__(self, prefix, config: SiglipVisionConfig, weights):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True
        )
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # The MLP lives in the `mlp` submodule of the head (e.g. `head.mlp.fc1`), so its
        # weights must be read from `{prefix}.mlp`, not from `{prefix}` itself.
        self.mlp = SiglipMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

    def forward(self, hidden_state):
        batch_size = hidden_state.shape[0]
        # One probe query per batch element: (batch, 1, hidden_size).
        probe = self.probe.repeat(batch_size, 1, 1)

        # MultiheadAttention returns (output, weights); keep only the output.
        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        # Drop the singleton query dimension.
        return hidden_state[:, 0]
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
lower = norm_cdf((a - mean) / std)
upper = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * lower - 1, 2 * upper - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
    tensor: torch.Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
) -> torch.Tensor:
    r"""Fill ``tensor`` in place with values from a truncated normal distribution.

    The values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to a
    bounded interval. Sampling uses the inverse-CDF method (no rejection/redraw
    loop), with a final clamp to the bounds.

    NOTE: this 'tf' variant behaves closer to the TensorFlow / JAX implementation. The
    bounds [a, b] are applied while sampling the *standard* normal (mean=0, std=1.0),
    and the result is subsequently scaled by ``std`` and shifted by ``mean``. The
    output therefore lies in ``[mean + std * a, mean + std * b]`` (for ``std > 0``).

    Args:
        tensor: an n-dimensional `torch.Tensor`, filled in place.
        mean: the mean of the normal distribution.
        std: the standard deviation of the normal distribution.
        a: the minimum cutoff value, in standard-normal units.
        b: the maximum cutoff value, in standard-normal units.

    Returns:
        ``tensor`` itself (modified in place), matching the annotated return type.
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
    return tensor
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialize ``tensor`` in place with variance ``scale / denom`` (Flax-style).

    Args:
        tensor: tensor to fill; fan-in/fan-out come from
            ``_calculate_fan_in_and_fan_out``.
        scale: numerator of the target variance.
        mode: ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"`` (mean of the two); selects
            ``denom``.
        distribution: ``"truncated_normal"``, ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: for an unknown ``mode`` or ``distribution``.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Without this branch `denom` would be unbound and fail with UnboundLocalError.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # U(-bound, bound) has variance bound^2 / 3.
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
    """LeCun-normal init: ``variance_scaling_`` with scale 1.0, fan_in, truncated normal."""
    variance_scaling_(tensor, 1.0, "fan_in", "truncated_normal")
def default_flax_embed_init(tensor):
    """Flax default embedding init: ``variance_scaling_`` with scale 1.0, fan_in, normal."""
    variance_scaling_(tensor, 1.0, "fan_in", "normal")
class SiglipEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`SiglipEncoderLayer`] modules, loaded from
    `{prefix}.layers.{i}` and applied in order.

    Args:
        config: SiglipConfig
    """

    def __init__(self, prefix, config: SiglipConfig, weights):
        super().__init__()
        self.config = config
        stacked = []
        for layer_idx in range(config.num_hidden_layers):
            stacked.append(
                SiglipEncoderLayer(
                    prefix=f"{prefix}.layers.{layer_idx}", config=config, weights=weights
                )
            )
        self.layers = nn.ModuleList(stacked)

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Run every layer in sequence and return the final hidden states."""
        hidden_states = inputs_embeds
        for layer in self.layers:
            # Each layer returns (hidden_states, None); only the first item is kept.
            hidden_states = layer(hidden_states, attention_mask)[0]
        return hidden_states
class SiglipVisionTransformer(nn.Module):
    """SigLIP vision tower: patch embeddings (`{prefix}.embeddings`) + encoder
    (`{prefix}.encoder`). No pooling head is applied."""

    def __init__(self, prefix, config: SiglipVisionConfig, weights):
        super().__init__()
        self.config = config
        self.embeddings = SiglipVisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        self.encoder = SiglipEncoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
    ):
        """Encode ``pixel_values``.

        Returns:
            ``BaseModelOutputWithPooling`` with only ``last_hidden_state`` set.

        Raises:
            ValueError: if ``pixel_values`` is None.
        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # NOTE: the embedding output matches the transformers implementation exactly;
        # values diverge slightly inside our encoder layers.
        embedded = self.embeddings(pixel_values)
        last_hidden_state = self.encoder(inputs_embeds=embedded)
        return BaseModelOutputWithPooling(last_hidden_state=last_hidden_state)
================================================
FILE: backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py
================================================
def load_text_model(prefix, config, weights, name=None):
    """Instantiate the causal-LM text backbone matching ``config.model_type``.

    Model classes are imported lazily inside each branch. ``name`` is forwarded only
    to the llama and mistral constructors.

    Raises:
        RuntimeError: for an unsupported ``config.model_type``.
    """
    model_type = config.model_type

    if model_type == "llama":
        from text_generation_server.models.custom_modeling.flash_llama_modeling import (
            FlashLlamaForCausalLM,
        )

        return FlashLlamaForCausalLM(prefix, config, weights, name=name)

    if model_type == "mistral":
        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
            FlashMistralForCausalLM,
        )

        return FlashMistralForCausalLM(prefix, config, weights, name=name)

    if model_type in ("gemma", "paligemma"):
        # PaliGemma's text backbone is a plain Gemma model.
        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
            FlashGemmaForCausalLM,
        )

        return FlashGemmaForCausalLM(prefix, config, weights)

    if model_type == "gemma2":
        from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
            FlashGemma2ForCausalLM,
        )

        return FlashGemma2ForCausalLM(prefix, config, weights)

    if model_type in ("gemma3", "gemma3_text"):
        from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
            FlashGemma3ForCausalLM,
        )

        return FlashGemma3ForCausalLM(prefix, config, weights)

    raise RuntimeError(f"Unsupported model type {model_type}")
def load_vision_model(prefix, config, weights):
    """Instantiate the vision tower for ``config.model_type`` under ``{prefix}.vision_model``.

    Supports ``clip_vision_model`` (CLIP) and ``siglip_vision_model`` /
    ``gemma3_vision`` (SigLIP); model classes are imported lazily.

    Raises:
        RuntimeError: for an unsupported ``config.model_type``.
    """
    model_type = config.model_type
    vision_prefix = f"{prefix}.vision_model"

    if model_type == "clip_vision_model":
        from text_generation_server.models.custom_modeling.clip import (
            CLIPVisionTransformer,
        )

        return CLIPVisionTransformer(
            prefix=vision_prefix, config=config, weights=weights
        )
    elif model_type in ("siglip_vision_model", "gemma3_vision"):
        from text_generation_server.models.custom_modeling.siglip import (
            SiglipVisionTransformer,
        )

        # TODO: ensure that using the prefix doesn't break any existing models
        # that rely on the old prefix (update the old models if necessary)
        return SiglipVisionTransformer(
            prefix=vision_prefix,
            config=config,
            weights=weights,
        )

    raise RuntimeError(f"Unsupported model type {model_type}")
================================================
FILE: backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
================================================
import math
import os
import time
import torch
import torch.distributed
import numpy as np
from loguru import logger
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
PreTrainedTokenizerBase,
AutoConfig,
AutoTokenizer,
GenerationConfig,
)
from typing import (
Any,
Iterable,
Optional,
Tuple,
List,
Type,
Dict,
Union,
)
import torch.nn.functional as F
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.models import Model
from text_generation_server.utils.log import log_master
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
pad_next_token_chooser_parameters,
)
from text_generation_server.models.types import (
Batch,
Tokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import (
BLOCK_SIZE,
REQUEST_LOGPROBS,
TGI_WIGGLE_ROOM,
get_adapter_to_index,
)
from text_generation_server.layers.attention import (
KVCache,
KVCompressCache,
Seqlen,
HPUPagedAttentionMetadata,
trim_attn_metadata,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
)
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
from text_generation_server.utils.import_utils import (
empty_cache,
synchronize,
get_free_memory,
)
from text_generation_server.utils.prefill_chunking import (
get_max_prefill_tokens,
)
import vllm_hpu_extension.environment as environment
import habana_frameworks.torch as htorch
import itertools
from vllm_hpu_extension.bucketing.common import get_bucketing_context
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
# Module-level OpenTelemetry tracer; used below via `@tracer.start_as_current_span(...)`.
tracer = trace.get_tracer(__name__)
def generate_block_metadata(
    dtype,
    use_contiguous_pa,
    slots,
    block_tables,
    bucketing_ctx,
    slots_in_window=None,
    block_bucket_size=None,
):
    """Build flattened paged-attention block metadata (``HPUPagedAttentionMetadata``
    inputs) for continued decoding.

    Args:
        dtype: dtype of the returned ``block_usage`` tensor.
        use_contiguous_pa: if True, entries are scattered so that position ``bid`` holds
            the entry for block id ``bid`` (unused positions get filler values: block 0,
            group -1, usage 1). Otherwise entries stay in order and are padded at the
            end with the same filler values.
        slots: one slot index per sequence. ``slot % BLOCK_SIZE + 1`` is used as the
            number of filled entries in that sequence's last block (presumably the slot
            written at this step -- confirm with callers).
        block_tables: per-sequence lists of block ids.
        bucketing_ctx: optional object whose ``get_padded_decode_num_blocks`` rounds the
            block count up to a bucket size.
        slots_in_window: optional tensor of slot ids. When given, a boolean mask of shape
            ``(num_blocks, BLOCK_SIZE)`` marks which slots of each listed block are in it.
        block_bucket_size: optional explicit bucket size; computed when None.

    Returns:
        ``(block_list, block_groups, block_usage, slots_in_window_mask,
        block_bucket_size)``. The first three are CPU tensors (int, int, ``dtype``).
        The mask is None when ``slots_in_window`` is None.
    """

    def flatten(in_list):
        return list(itertools.chain(*in_list))

    def gather_list(input, indices, v):
        # Pick input[i] for each index; None positions get the filler value v.
        return [input[i] if i is not None else v for i in indices]

    def pad_list(input, k, v):
        # Pad with v up to the next multiple of k.
        input_len = len(input)
        target_len = (input_len + k - 1) // k * k
        padding = target_len - input_len
        return input + [v] * padding

    last_block_usage = [slot % BLOCK_SIZE + 1 for slot in slots]
    # Sequence index of every block; all blocks but the last are full.
    block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]
    block_usage = [
        [BLOCK_SIZE] * (len(bt) - 1) + [lbu]
        for bt, lbu in zip(block_tables, last_block_usage)
        if bt
    ]

    block_list = flatten(block_tables)
    block_groups = flatten(block_groups)
    block_usage = flatten(block_usage)

    assert len(block_list) == len(block_groups)
    assert len(block_list) == len(block_usage)
    if use_contiguous_pa:
        if block_bucket_size is None:
            block_bucket_size = max(max(block_list) + 1, len(block_list))
            if bucketing_ctx is not None:
                block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
                    block_bucket_size
                )
        # indices[bid] = position of block `bid` in the flattened lists (or None).
        indices: List[Any]
        indices = [None] * block_bucket_size
        for i, bid in enumerate(block_list):
            indices[bid] = i
        block_list = gather_list(block_list, indices, 0)
        block_groups = gather_list(block_groups, indices, -1)
        block_usage = gather_list(block_usage, indices, 1)
    else:
        if block_bucket_size is None:
            block_bucket_size = len(block_list)
            if bucketing_ctx is not None:
                block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(
                    block_bucket_size
                )
        block_list = pad_list(block_list, block_bucket_size, 0)
        block_groups = pad_list(block_groups, block_bucket_size, -1)
        block_usage = pad_list(block_usage, block_bucket_size, 1)

    slots_in_window_mask = None
    if slots_in_window is not None:
        # Slot id of every position of every listed block, shape (num_blocks, BLOCK_SIZE):
        # slot = block_id * BLOCK_SIZE + offset. Built by broadcasting instead of a
        # Python list of num_blocks * BLOCK_SIZE ints.
        slot_list = torch.tensor(block_list, dtype=torch.int64).view(
            -1, 1
        ) * BLOCK_SIZE + torch.arange(BLOCK_SIZE, dtype=torch.int64)
        slots_in_window_mask = torch.isin(slot_list, slots_in_window)
        # Rows with no in-window slot get their first slot enabled, presumably so no
        # block is fully masked out (e.g. all -inf before softmax) -- TODO confirm.
        empty_rows = ~slots_in_window_mask.any(dim=1)
        slots_in_window_mask[empty_rows, 0] = True

    block_list = torch.tensor(block_list, dtype=torch.int, device="cpu")
    block_groups = torch.tensor(block_groups, dtype=torch.int, device="cpu")
    block_usage = torch.tensor(block_usage, dtype=dtype, device="cpu")
    return (
        block_list,
        block_groups,
        block_usage,
        slots_in_window_mask,
        block_bucket_size,
    )
@dataclass
class FlashCausalLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
# request id -> idx in list mapping
requests_idx_mapping: Dict[int, int]
# Decoder values
# Can be a list for easy filtering
# If `input_ids` is a list, it needs to be materialized to a tensor first
input_ids: Union[torch.Tensor, List[List[int]]]
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
position_ids: Optional[torch.Tensor]
speculative_ids: Optional[torch.Tensor]
# Set when creating the batch
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
slot_indices: Optional[torch.Tensor]
# list of length b of list of length s_i // block_size
block_tables: List[List[int]]
# tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: torch.Tensor
# list of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch
# used for filtering
cu_slots: torch.Tensor
max_input_length: int
max_current_length: int
# Whether this batch contains at least one request that is prefilling
prefilling: bool
# Whether each request is prefilling
prefilling_mask: List[bool]
# Prefill metadata tensors to efficiently compute logprobs
# tensor of length b + 1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor]
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
# as we only keep SLIDING_WINDOW values instead of the whole tensor
prefill_cache_indices: Optional[torch.Tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_head_indices: Optional[torch.Tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_next_token_indices: Optional[torch.tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_cu_outlens: Optional[List[int]]
# Will be set by `generate_token` and reset after each prefill forward
prefill_logprob_tokens: List[Optional[Tokens]]
# All tokens
all_input_ids: List[List[int]]
all_input_ids_tensor: torch.Tensor
# Lengths of all generations present in the batch
input_lengths: List[int]
# size [b], containing the number of blocks that can be retrieved from the cache
cache_lengths: List[int]
prompt_lengths: List[int]
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
input_lengths_tensor: Optional[torch.Tensor]
cache_lengths_tensor: Optional[torch.Tensor]
prompt_lengths_tensor: torch.Tensor
prefix_offsets: List[Optional[int]]
read_offsets: List[Optional[int]]
# Generation helpers
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Adapter metadata for each request
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
adapter_meta: Optional[AdapterBatchMetadata]
# Number of blocks in this batch
num_blocks: int
# Maximum number of blocks
max_blocks: int
hpu_attn_meta: Optional[HPUPagedAttentionMetadata]
next_token_logits: Optional[torch.Tensor]
speculative_logits: Optional[torch.Tensor]
valid_indices: Optional[List[int]]
def to_pb(self) -> generate_pb2.CachedBatch:
    """Summarize this batch as a ``generate_pb2.CachedBatch`` message.

    ``current_tokens`` is the total number of ids in ``input_ids``. It is summed per
    request while ``input_ids`` is still a list of lists, otherwise it is the tensor's
    length.
    """
    if isinstance(self.input_ids, list):
        current_tokens = sum([len(ids) for ids in self.input_ids])
    else:
        current_tokens = len(self.input_ids)

    return generate_pb2.CachedBatch(
        id=self.batch_id,
        request_ids=[request.id for request in self.requests],
        size=len(self),
        max_tokens=self.num_blocks * BLOCK_SIZE,
        current_tokens=current_tokens,
    )
@classmethod
def batch_tokenized_inputs(
    cls, requests: Iterable[generate_pb2.Request], tokenizer
):
    """Tokenize the text of every request.

    Each request's chunks are joined with ``concat_text_chunks`` and tokenized with
    truncation to ``r.truncate`` and ``r.add_special_tokens``.

    Returns:
        One list of token ids per request, in request order.
    """
    all_input_ids = []
    for r in requests:
        inputs = concat_text_chunks(r.input_chunks.chunks)
        input_ids = tokenizer(
            inputs,
            truncation=True,
            max_length=r.truncate,
            add_special_tokens=r.add_special_tokens,
        )["input_ids"]
        all_input_ids.append(input_ids)
    return all_input_ids
@classmethod
def from_tokenized(
    cls,
    pb: generate_pb2.Batch,
    tokenizer: PreTrainedTokenizerBase,
    batch_tokenized_inputs,
    dtype: torch.dtype,
    device: torch.device,
) -> "FlashCausalLMBatch":
    """Build a prefilling batch from already-tokenized requests.

    Args:
        pb: protobuf batch; ``pb.requests`` is zipped with ``batch_tokenized_inputs``.
        tokenizer: used for stopping criteria and the next-token chooser.
        batch_tokenized_inputs: one list of token ids per request, in request order.
        dtype: forwarded to ``HeterogeneousNextTokenChooser.from_pb``.
        device: forwarded to ``HeterogeneousNextTokenChooser.from_pb``.

    Returns:
        A batch with ``prefilling=True``. Tensors built here live on CPU. Position, slot
        and length tensors are left as None, to be set by ``prepare_for_prefill``.

    Raises:
        AssertionError: if ``REQUEST_LOGPROBS`` is enabled (prefill logprobs are not
            supported), or if a request's ``cache_len`` is not smaller than its prompt.
    """
    cache_lengths = []
    input_lengths = []
    prompt_lengths = []
    prefix_offsets = []
    read_offsets = []
    all_input_ids = []
    all_postfix_ids = []
    requests_idx_mapping = {}
    slots = []
    cu_slots = [0]

    next_token_chooser_parameters = []
    stopping_criterias = []
    top_n_tokens = []

    num_blocks = 0
    max_input_length = 0
    max_current_length = 0
    max_length = 0
    max_blocks = 0

    block_tables = []

    # Number of speculative tokens to reserve room for; identical for every request.
    speculative_length = get_speculate()
    speculative_length = 0 if speculative_length is None else speculative_length

    # Parse batch
    for i, (r, tokenized_input) in enumerate(
        zip(pb.requests, batch_tokenized_inputs)
    ):
        ### XXX: This consumes so much memory on long requests
        ### Deactivating it by default seems like the best course.
        if not REQUEST_LOGPROBS:
            r.prefill_logprobs = False
        else:
            assert False, "prefill_logprobs not supported yet"
        # request id -> idx in list mapping
        requests_idx_mapping[r.id] = i

        prompt_length = len(tokenized_input)
        prompt_lengths.append(prompt_length)

        cache_length = r.cache_len

        assert (
            cache_length <= prompt_length
        ), f"Prefix {cache_length} vs input {prompt_length}"
        if cache_length == prompt_length:
            assert False, "unreachable"

        # No chunking here: every id after the cached prefix is part of this prefill.
        postfix_ids = tokenized_input[cache_length:]
        input_length = len(postfix_ids)

        input_lengths.append(input_length)

        prefix_offsets.append(prompt_length - 5)
        read_offsets.append(prompt_length)

        all_postfix_ids.append(postfix_ids)
        all_input_ids.append(tokenized_input)

        next_token_chooser_parameters.append(r.parameters)

        stopping_criteria = StoppingCriteria.from_pb(
            r.stopping_parameters, tokenizer
        )
        max_new_tokens = stopping_criteria.max_new_tokens
        stopping_criterias.append(stopping_criteria)
        top_n_tokens.append(r.top_n_tokens)

        # Paged attention
        # Tokens that need to be mapped to blocks; remove one as the first generated
        # token does not have a past.
        block_tokens = prompt_length + max_new_tokens - 1 + speculative_length

        # blocks and slots can be empty (for example in warmup)
        if not r.blocks:
            needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
            request_blocks = list(range(num_blocks, num_blocks + needed_blocks))
            request_slots = [
                s
                for b in request_blocks
                for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
            ]
        else:
            request_blocks = r.blocks
            request_slots = r.slots

        block_tables.append(request_blocks)

        slots.extend(request_slots)
        cu_slots.append(len(slots))

        cache_lengths.append(cache_length)
        num_blocks += len(request_blocks)

        # Update
        max_blocks = max(max_blocks, len(request_blocks))
        max_input_length = max(max_input_length, input_length)
        max_current_length = max(max_current_length, cache_length + input_length)
        max_length = max(
            max_length,
            prompt_length + max_new_tokens + speculative_length,
        )

    next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
        next_token_chooser_parameters, dtype, device, tokenizer
    )

    # Padded all_input_ids_tensor
    all_input_ids_tensor = np.zeros(
        (len(all_input_ids), max_length), dtype=np.int64
    )
    for i, input_ids in enumerate(all_input_ids):
        all_input_ids_tensor[i, : len(input_ids)] = input_ids

    # put on cpu temporarily, move to hpu in prepare_for_prefill
    all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)

    top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)

    # Zero-initialized (not torch.empty) so positions past a request's last block hold
    # block id 0 instead of uninitialized memory.
    block_tables_tensor = torch.zeros(
        (len(block_tables), max_blocks),
        dtype=torch.int32,
    )
    for i, request_blocks in enumerate(block_tables):
        block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)

    prompt_lengths_tensor = torch.tensor(prompt_lengths, dtype=torch.int32)

    slots = torch.tensor(slots, dtype=torch.int64)
    cu_slots = torch.tensor(cu_slots, dtype=torch.int64)

    return cls(
        batch_id=pb.id,
        requests=pb.requests,
        requests_idx_mapping=requests_idx_mapping,
        input_ids=all_postfix_ids,
        block_tables=block_tables,
        block_tables_tensor=block_tables_tensor,
        cache_lengths=cache_lengths,
        max_input_length=max_input_length,
        max_current_length=max_current_length,
        prefilling=True,
        prefilling_mask=[True] * len(pb.requests),
        prefill_logprob_tokens=[None] * len(pb.requests),
        input_lengths=input_lengths,
        prompt_lengths=prompt_lengths,
        prefix_offsets=prefix_offsets,
        read_offsets=read_offsets,
        all_input_ids=all_input_ids,
        all_input_ids_tensor=all_input_ids_tensor,
        next_token_chooser=next_token_chooser,
        stopping_criterias=stopping_criterias,
        top_n_tokens=top_n_tokens,
        top_n_tokens_tensor=top_n_tokens_tensor,
        num_blocks=num_blocks,
        max_blocks=max_blocks,
        speculative_ids=None,
        prompt_lengths_tensor=prompt_lengths_tensor,
        # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
        position_ids=None,
        cu_seqlen_prefill=None,
        prefill_cache_indices=None,
        slot_indices=None,
        slots=slots,
        cu_slots=cu_slots,
        prefill_head_indices=None,
        prefill_next_token_indices=None,
        prefill_cu_outlens=None,
        cache_lengths_tensor=None,
        input_lengths_tensor=None,
        adapter_meta=None,
        hpu_attn_meta=None,
        next_token_logits=None,
        speculative_logits=None,
        valid_indices=None,
    )
@classmethod
def from_pb(
    cls,
    pb: generate_pb2.Batch,
    tokenizer: PreTrainedTokenizerBase,
    dtype: torch.dtype,
    device: torch.device,
) -> "FlashCausalLMBatch":
    """Tokenize ``pb``'s requests and build a batch via ``from_tokenized``.

    Raises:
        AssertionError: if ``pb`` contains no requests.
    """
    assert len(pb.requests) > 0
    tokenized = cls.batch_tokenized_inputs(pb.requests, tokenizer)
    return cls.from_tokenized(
        pb,
        tokenizer,
        tokenized,
        dtype,
        device,
    )
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
    """Return a batch restricted to ``request_ids``, in the given order.

    Per-request lists, block tables, slots and (when not prefilling) the decode-state
    tensors are gathered with the kept indices. Several fields are passed through
    unfiltered: ``all_input_ids_tensor``, ``next_token_chooser``, ``top_n_tokens``,
    ``top_n_tokens_tensor``, ``speculative_ids`` and the cached logits. The kept row
    indices are stored in ``valid_indices``, presumably so a later step can filter
    them lazily (TODO confirm).

    Args:
        request_ids: ids of requests to keep; each must be in ``requests_idx_mapping``.

    Returns:
        ``self`` when ``len(request_ids) == len(self)`` (assumed to be the same
        requests), otherwise a new batch.

    Raises:
        ValueError: if ``request_ids`` is empty.
        KeyError: if an id is not part of this batch.
    """
    if len(request_ids) == 0:
        raise ValueError("Batch must have at least one request")
    # We assume that if len(requests) == len(self) then the requests are the same
    if len(request_ids) == len(self):
        return self

    device = self.block_tables_tensor.device

    # New values after filtering
    requests_idx_mapping = {}

    # Used to index into tensors
    indices = []

    # slots to keep after filtering
    slot_filtering_indices = torch.zeros(self.slots.shape[0], dtype=torch.bool)

    # Create on CPU to only move to GPU once instead of at every copy
    slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
    max_input_length = 0
    max_current_length = 0

    requests = []
    block_tables = []
    all_input_ids = []
    input_ids = []

    prompt_lengths = []
    input_lengths = []
    cache_lengths = []
    prefix_offsets = []
    read_offsets = []
    cu_slots = [0]

    prefilling_mask = []
    prefill_logprob_tokens = []

    stopping_criterias = []

    adapter_set = set()
    # Adapter id -> adapter index mapping; loop-invariant, so fetched once.
    ADAPTER_TO_INDEX = get_adapter_to_index()

    num_blocks = 0
    max_blocks = 0
    cumulative_slot_tokens = 0

    for i, request_id in enumerate(request_ids):
        idx = self.requests_idx_mapping[request_id]
        indices.append(idx)
        requests_idx_mapping[request_id] = i

        requests.append(self.requests[idx])

        # Prefilling
        request_prefilling = self.prefilling_mask[idx]
        prefilling_mask.append(request_prefilling)

        # Get length
        request_input_length = self.input_lengths[idx]
        request_cache_length = self.cache_lengths[idx]
        max_input_length = max(max_input_length, request_input_length)
        max_current_length = max(
            max_current_length, request_cache_length + request_input_length
        )

        all_input_ids.append(self.all_input_ids[idx])

        prompt_lengths.append(self.prompt_lengths[idx])
        input_lengths.append(request_input_length)
        cache_lengths.append(request_cache_length)
        prefix_offsets.append(self.prefix_offsets[idx])
        read_offsets.append(self.read_offsets[idx])

        stopping_criteria = self.stopping_criterias[idx]
        stopping_criterias.append(stopping_criteria)

        prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])

        adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
        adapter_set.add(adapter_index)

        request_block_table = self.block_tables[idx]
        num_blocks += len(request_block_table)
        block_tables.append(request_block_table)

        start_slot = self.cu_slots[idx]
        end_slot = self.cu_slots[idx + 1]
        slot_length = end_slot - start_slot

        # Set slice
        slot_filtering_indices[start_slot:end_slot] = True

        cu_slots.append(cumulative_slot_tokens + slot_length)

        # Input ids if the request was part of a prefilling batch
        # If the batch was decoding we can index into the tensor directly later
        if self.prefilling:
            input_ids.append(self.input_ids[idx])
        else:
            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_slot_tokens + request_cache_length

        cumulative_slot_tokens += slot_length
        max_blocks = max(max_blocks, len(request_block_table))

    block_tables_tensor = self.block_tables_tensor[indices]
    prompt_lengths_tensor = self.prompt_lengths_tensor[indices]

    cu_slots = torch.tensor(cu_slots, dtype=torch.int64)

    slots = self.slots[slot_filtering_indices]

    if self.prefilling:
        # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
        position_ids = None
        slot_indices = None
        cache_lengths_tensor = None
        input_lengths_tensor = None
        adapter_meta = None
    else:
        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        cache_lengths_tensor = self.cache_lengths_tensor[indices]

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)
        if self.adapter_meta is not None:
            adapter_indices = self.adapter_meta.adapter_indices[indices]
            adapter_segments, adapter_segment_indices = find_segments(
                adapter_indices
            )
            adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_segment_indices,
            )
        else:
            adapter_meta = None
    htorch.core.mark_step()
    return type(self)(
        batch_id=self.batch_id,
        requests=requests,
        requests_idx_mapping=requests_idx_mapping,
        input_ids=input_ids,
        position_ids=position_ids,
        cu_seqlen_prefill=None,
        prefill_cache_indices=None,
        slot_indices=slot_indices,
        block_tables=block_tables,
        block_tables_tensor=block_tables_tensor,
        slots=slots,
        cu_slots=cu_slots,
        max_input_length=max_input_length,
        max_current_length=max_current_length,
        prefilling=self.prefilling,
        prefilling_mask=prefilling_mask,
        prefill_head_indices=None,
        prefill_next_token_indices=None,
        prefill_cu_outlens=None,
        prefill_logprob_tokens=prefill_logprob_tokens,
        prompt_lengths=prompt_lengths,
        prompt_lengths_tensor=prompt_lengths_tensor,
        input_lengths=input_lengths,
        input_lengths_tensor=input_lengths_tensor,
        cache_lengths=cache_lengths,
        cache_lengths_tensor=cache_lengths_tensor,
        prefix_offsets=prefix_offsets,
        read_offsets=read_offsets,
        all_input_ids=all_input_ids,
        all_input_ids_tensor=self.all_input_ids_tensor,
        next_token_chooser=self.next_token_chooser,
        stopping_criterias=stopping_criterias,
        top_n_tokens=self.top_n_tokens,
        top_n_tokens_tensor=self.top_n_tokens_tensor,
        num_blocks=num_blocks,
        max_blocks=max_blocks,
        speculative_ids=self.speculative_ids,
        adapter_meta=adapter_meta,
        hpu_attn_meta=None,
        valid_indices=indices,
        next_token_logits=self.next_token_logits,
        speculative_logits=self.speculative_logits,
    )
# Merge several FlashCausalLMBatch objects into one batch.
#
# Args:
#   batches: batches to merge. NOTE: batches[0] is partially reused and
#     mutated in place -- its requests_idx_mapping dict gains the other
#     batches' entries, its all_input_ids_tensor receives the other batches'
#     rows via index_copy_, and (decode path) its input_ids tensor may be
#     reused as the merged input_ids. In the prefilling path every batch's
#     tensor input_ids is replaced by a Python list.
#   padded_total_bs: when equal to batches[0].input_ids.shape[0] (decode
#     path), batches[0].input_ids is reused instead of allocating a new
#     tensor of size total_batch_size.
# Returns:
#   A new FlashCausalLMBatch. When any input batch is still prefilling,
#   input_ids is a Python list and position_ids / slot_indices /
#   input_lengths_tensor / cache_lengths_tensor are None; they are filled in
#   later by prepare_for_prefill.
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(
cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0
) -> "FlashCausalLMBatch":
# Batch attributes
requests = []
requests_idx_mapping = {}
prefilling = False
num_blocks = 0
total_batch_size = 0
total_slots = 0
max_blocks = 0
max_length = 0
max_input_length = 0
max_current_length = 0
ADAPTER_TO_INDEX = get_adapter_to_index()
# First pass: compute the sizes/maxima needed to allocate merged tensors.
for b in batches:
total_batch_size += len(b)
max_blocks = max(max_blocks, b.max_blocks)
total_slots += len(b.slots)
num_blocks += b.num_blocks
speculative_length = (
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
)
max_input_length = max(max_input_length, b.max_input_length)
max_current_length = max(max_current_length, b.max_current_length)
# NOTE(review): max_length is accumulated here but never read later in
# this method; it appears to be dead code. Confirm before removing.
max_length = max(
max_length,
max(
prompt_length
+ stopping_criteria.max_new_tokens
+ speculative_length
for prompt_length, stopping_criteria in zip(
b.prompt_lengths, b.stopping_criterias
)
),
)
prefilling = prefilling or b.prefilling
slots = batches[0].slots.new_empty(total_slots)
cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64)
if prefilling:
input_ids = []
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids = None
slot_indices = None
cache_lengths_tensor = None
input_lengths_tensor = None
adapter_meta = None
adapter_segment_builder = None
else:
# Decode path: reuse batches[0].input_ids when it already has
# padded_total_bs rows, otherwise allocate total_batch_size entries.
if padded_total_bs == batches[0].input_ids.shape[0]:
input_ids = batches[0].input_ids
else:
input_ids = batches[0].input_ids.new_empty(total_batch_size)
if (
batches[0].position_ids is not None
and batches[0].position_ids.dim() == 2
):
# Qwen2_vl case:
position_ids = batches[0].position_ids.new_empty(
(total_batch_size, batches[0].position_ids.shape[-1])
)
else:
position_ids = batches[0].position_ids.new_empty(total_batch_size)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
total_batch_size
)
cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
total_batch_size
)
if ADAPTER_TO_INDEX:
total_indices_size = sum(
b.adapter_meta.adapter_indices.shape[0] for b in batches
)
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
total_indices_size
)
adapter_segment_builder = SegmentConcatBuilder()
adapter_set = set()
prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
total_batch_size
)
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
(total_batch_size, max_blocks)
)
# Reused (not copied) from batches[0]; later batches' rows are written
# into it below. NOTE(review): assumes it already has at least
# total_batch_size rows -- confirm this is guaranteed by earlier padding.
all_input_ids_tensor = batches[0].all_input_ids_tensor
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
block_tables = []
cache_lengths = []
all_input_ids = []
prompt_lengths = []
input_lengths = []
prefix_offsets = []
read_offsets = []
prefill_logprob_tokens = []
next_token_chooser_parameters = []
fsm_grammar_states = []
stopping_criterias = []
top_n_tokens = []
prefilling_mask = []
# Cumulative length
cumulative_batch_size = 0
cumulative_slots = 0
cumulative_adapter_indices_size = 0
# Second pass: copy each batch's data into its row range of the merged batch.
for i, batch in enumerate(batches):
requests.extend(batch.requests)
valid_bsize = len(batch)
if i == 0:
# Aliases batches[0]'s dict; later batches are added to it.
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + cumulative_batch_size
start_index = cumulative_batch_size
end_index = cumulative_batch_size + valid_bsize
# Row indices [start_index, end_index) of this batch in the merged batch.
index = torch.tensor(list(range(start_index, end_index)), device="cpu")
top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)
# Batch 0's rows are already in place because its tensor is reused.
if i > 0:
all_input_ids_tensor.index_copy_(
0,
index.to(batch.all_input_ids_tensor.device),
batch.all_input_ids_tensor[:valid_bsize, :],
)
block_tables_tensor[
start_index:end_index, : batch.block_tables_tensor.shape[1]
] = batch.block_tables_tensor[:, :max_blocks]
prompt_lengths_tensor.index_copy_(0, index, batch.prompt_lengths_tensor)
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
slot_index = torch.tensor(
list(range(slots_start_index, slots_end_index)),
device=batch.slots.device,
)
slots.index_copy_(0, slot_index, batch.slots)
# Slot offsets are shifted because slots of all batches are concatenated.
cu_slots[start_index + 1 : end_index + 1] = (
batch.cu_slots[1:] + cumulative_slots
)
if not prefilling:
# Skip the copy for batch 0 when its input_ids tensor was reused as-is.
if padded_total_bs != batches[0].input_ids.shape[0] or i > 0:
input_ids.index_copy_(
0, index.to(input_ids.device), batch.input_ids[:valid_bsize]
)
position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])
slot_indices.index_copy_(
0, index, batch.slot_indices + cumulative_slots
)
input_lengths_tensor.index_copy_(
0, index, batch.input_lengths_tensor[:valid_bsize]
)
cache_lengths_tensor.index_copy_(
0, index, batch.cache_lengths_tensor[:valid_bsize]
)
if ADAPTER_TO_INDEX:
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = (
cumulative_adapter_indices_size
+ batch.adapter_meta.adapter_indices.shape[0]
)
adapter_indices[adapter_start_index:adapter_end_index] = (
batch.adapter_meta.adapter_indices
)
cumulative_adapter_indices_size = adapter_end_index
adapter_set.update(batch.adapter_meta.adapter_set)
adapter_segment_builder.concat(
batch.adapter_meta.adapter_segments,
batch.adapter_meta.segment_indices,
)
else:
# Prefilling: input_ids is a Python list. Tensor input_ids are
# converted in place to one single-token list per row.
if isinstance(batch.input_ids, torch.Tensor):
batch.input_ids = batch.input_ids.view(-1, 1).tolist()
input_ids.extend(batch.input_ids)
prefilling_mask.extend(batch.prefilling_mask)
block_tables.extend(batch.block_tables)
cache_lengths.extend(batch.cache_lengths)
all_input_ids.extend(batch.all_input_ids)
prompt_lengths.extend(batch.prompt_lengths)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
# Update
cumulative_slots += len(batch.slots)
cumulative_batch_size += len(batch)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype,
device=batches[0].next_token_chooser.device,
tokenizer=batches[0].next_token_chooser.tokenizer,
fsm_grammar_states=fsm_grammar_states,
)
# We skip computing the speculative_ids when the batch size is too large, so
# we must check that all batches have them, otherwise they must be discarded
if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
else:
speculative_ids = None
# adapter_segment_builder is None in the prefilling path. In the decode
# path without adapters it is never bound, but `and` short-circuits first.
if ADAPTER_TO_INDEX and adapter_segment_builder is not None:
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
prefill_cache_indices=None,
slot_indices=slot_indices,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
cache_lengths=cache_lengths,
cache_lengths_tensor=cache_lengths_tensor,
slots=slots,
cu_slots=cu_slots,
max_input_length=max_input_length,
max_current_length=max_current_length,
prefilling=prefilling,
prefilling_mask=prefilling_mask,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
prefill_logprob_tokens=prefill_logprob_tokens,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
# adapter_meta is only evaluated when ADAPTER_TO_INDEX is truthy, so the
# decode path without adapters never reads the unbound name.
adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,
hpu_attn_meta=None,
next_token_logits=None,
speculative_logits=None,
valid_indices=None,
)
# Build HPU paged-attention metadata for a decode step and pad the per-request
# decode tensors up to a (possibly bucketed) batch size.
#
# Args:
#   dtype, use_contiguous_pa: forwarded to generate_block_metadata.
#   bucketing_ctx: when not None, picks the padded decode batch size and is
#     forwarded to generate_block_metadata; when None, no batch padding occurs.
#   pad_token_id: fill value for padded input_ids entries.
#   sliding_window: when not None, additional *_in_window metadata is built so
#     attention is limited to positions within `sliding_window` of the
#     current token.
# Side effects (no return value): sets self.hpu_attn_meta and replaces
#   input_ids, position_ids, input_lengths_tensor and cache_lengths_tensor with
#   copies padded to padded_bs rows. next_token_chooser is rebuilt with padded
#   parameters when its size differs from padded_bs.
def prepare_for_decode(
self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id, sliding_window
):
# Number of blocks each request touches: cache_length // BLOCK_SIZE + 1.
block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]
block_tables = []
for i, bt in enumerate(self.block_tables):
block_tables.append(bt[0 : block_num[i]])
if bucketing_ctx is not None:
padded_bs = bucketing_ctx.get_padded_decode_batch_size(
self.input_ids.shape[0]
)
else:
padded_bs = self.input_ids.shape[0]
# Slot that each request's current token is written to.
slots = self.slots[self.slot_indices]
block_list, block_groups, block_usage, _, block_bucket_size = (
generate_block_metadata(
dtype,
use_contiguous_pa,
slots,
block_tables,
bucketing_ctx,
)
)
meta = HPUPagedAttentionMetadata(
block_list=_async_h2d_tensor_copy(block_list),
block_groups=_async_h2d_tensor_copy(block_groups),
block_usage=_async_h2d_tensor_copy(block_usage),
block_mapping=None,
attn_bias=None,
)
if sliding_window is not None:
# Keep only the trailing blocks of each request that can still fall
# inside the window (formula depends on the current slot's offset
# within its block).
block_tables_in_window = []
for i, bt in enumerate(self.block_tables):
block_num_in_window = (
sliding_window + 2 * BLOCK_SIZE - 2 - slots[i] % BLOCK_SIZE
) // BLOCK_SIZE
block_tables_in_window.append(
bt[max(0, block_num[i] - block_num_in_window) : block_num[i]]
)
# Collect, per request, the cached slots whose distance from the
# current position is < sliding_window.
slots_in_window = []
for i, indice in enumerate(self.slot_indices):
start_idx = indice - self.cache_lengths[i]
mask = (
indice
- torch.arange(
start_idx,
indice + 1,
device=self.slots.device,
)
) < sliding_window
slots_in_window.append(self.slots[start_idx : indice + 1][mask])
slots_in_window = torch.cat(slots_in_window, dim=0)
(
block_list_in_window,
block_groups_in_window,
block_usage_in_window,
slots_in_window_mask,
_,
) = generate_block_metadata(
dtype,
use_contiguous_pa,
slots,
block_tables_in_window,
bucketing_ctx,
slots_in_window,
block_bucket_size,
)
meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)
self.hpu_attn_meta = trim_attn_metadata(meta)
# Pad per-request tensors to padded_bs rows (input_ids with pad_token_id,
# position ids with 1, lengths with 0).
self.input_ids = F.pad(
self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=pad_token_id
)
if self.position_ids.dim() == 2:
# Qwen VL case
self.position_ids = F.pad(
self.position_ids,
(0, 0, 0, padded_bs - self.position_ids.shape[0]),
value=1,
)
else:
self.position_ids = F.pad(
self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1
)
self.input_lengths_tensor = F.pad(
self.input_lengths_tensor,
(0, padded_bs - self.input_lengths_tensor.shape[0]),
value=0,
)
self.cache_lengths_tensor = F.pad(
self.cache_lengths_tensor,
(0, padded_bs - self.cache_lengths_tensor.shape[0]),
value=0,
)
# Rebuild the token chooser so its per-row parameters match padded_bs.
if len(self.next_token_chooser.do_sample) != padded_bs:
next_token_chooser_parameters = []
next_token_chooser_parameters.extend([r.parameters for r in self.requests])
pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)
# update past grammar states
fsm_grammar_states = [0] * padded_bs
for i, req in enumerate(self.requests):
fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
self.next_token_chooser.dtype,
self.next_token_chooser.device,
self.next_token_chooser.tokenizer,
fsm_grammar_states,
)
# Convert a prefilling batch into the padded, fixed-shape tensors used by the
# HPU prefill forward pass.
#
# Each request's tokens are left-padded with pad_token_id to
# max_padded_input_len. The batch is then padded with extra_pad_bs empty rows
# up to max_padded_bs. prefill_cache_indices becomes a boolean mask over the
# flattened padded input_ids marking real tokens. prefill_head_indices /
# prefill_next_token_indices are shifted by the cumulative left padding.
#
# Args:
#   max_padded_input_len: padded per-request sequence length.
#   max_padded_bs: padded batch size.
#   max_total_tokens: minimum width of the rebuilt all_input_ids_tensor.
#   pad_token_id: fill value for padded input_ids / all_input_ids_tensor.
# Side effects (no return value): sets input_ids, input_lengths_tensor,
#   cu_seqlen_prefill, cache_lengths_tensor, position_ids, slot_indices,
#   prefill_cu_outlens, prefill_cache_indices, prefill_head_indices,
#   prefill_next_token_indices, all_input_ids_tensor, possibly
#   next_token_chooser, and (when adapters are configured) adapter_meta.
def prepare_for_prefill(
self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
):
# Prepare values if we need to continue prefilling
# Speculation must be ignored while we prefill even with chunking
# it simplifies everything
assert self.speculative_ids is None
# device = self.block_tables_tensor.device
# hpu does not support varlen for prefill, use sdpa instead. so need to pad input_tensor, position
# padding to left to work with sliding window
# use prefill_cache_indices to indicate the valid kv slot, update prefill_next_token_indices to indicate
# the right logit position
input_ids_padded_length = []
# need extra pad to match warmup seq
extra_pad = max_padded_input_len - self.max_input_length
extra_pad_bs = max_padded_bs - len(self)
device = "hpu"
# Three input layouts: a list of per-request token lists (multi-request),
# a single-request list, or an already-flattened tensor.
if isinstance(self.input_ids, list) and len(self) > 1:
input_ids_padded_length = []
input_ids = []
for input_id in self.input_ids:
padded = self.max_input_length - len(input_id) + extra_pad
if padded > 0:
input_id = [pad_token_id] * padded + input_id
input_ids.append(input_id)
input_ids_padded_length.append(padded)
input_ids = np.concatenate(input_ids, dtype=np.int64)
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
elif isinstance(self.input_ids, list):
input_ids = self.input_ids[0]
input_ids_padded_length.append(extra_pad)
input_ids = [pad_token_id] * extra_pad + input_ids
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
else:
input_ids = torch.full(
(max_padded_input_len * len(self),),
pad_token_id,
dtype=torch.int64,
device=self.input_ids.device,
)
src_pos = 0
# Right-align each request's tokens inside its max_padded_input_len slot.
for i in range(len(self)):
end_pos = (i + 1) * max_padded_input_len
start_pos = end_pos - self.input_lengths[i]
input_ids[start_pos:end_pos] = self.input_ids[
src_pos : src_pos + self.input_lengths[i]
]
input_ids_padded_length.append(
max_padded_input_len - self.input_lengths[i]
)
src_pos += self.input_lengths[i]
self.input_ids = input_ids
# Append extra_pad_bs fully padded rows to reach max_padded_bs.
self.input_ids = F.pad(
self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=pad_token_id
)
self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32)
self.input_lengths_tensor = F.pad(
self.input_lengths_tensor, (0, extra_pad_bs), value=0
)
cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(max_padded_bs + 1)
torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
self.cache_lengths_tensor = torch.tensor(self.cache_lengths, dtype=torch.int32)
self.cache_lengths_tensor = F.pad(
self.cache_lengths_tensor, (0, extra_pad_bs), value=0
)
position_ids = []
slot_indices = []
prefill_cache_indices = []
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_cu_outlens = [0]
# Cumulative length
cumulative_length = 0
cumulative_slot_tokens = 0
prefill_out_cumulative_length = 0
adapter_indices_list = []
adapter_set = set()
for i, (
r,
cache_length,
input_length,
prompt_length,
request_prefilling,
blocks,
) in enumerate(
zip(
self.requests,
self.cache_lengths,
self.input_lengths,
self.prompt_lengths,
self.prefilling_mask,
self.block_tables,
)
):
next_chunk_length = input_length
# Position ids
# Real positions continue from cache_length; left padding gets value 1.
request_position_ids = torch.arange(
cache_length, cache_length + input_length, dtype=torch.int32
)
request_position_ids = F.pad(
request_position_ids, (input_ids_padded_length[i], 0), value=1
)
position_ids.append(request_position_ids)
if not r.slots:
request_slots = [
s
for b in blocks
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]
else:
request_slots = r.slots
request_slot_indices = torch.arange(
cache_length + cumulative_slot_tokens,
cache_length + cumulative_slot_tokens + input_length,
dtype=torch.int64,
)
slot_indices.append(request_slot_indices)
# Update
cumulative_slot_tokens += len(request_slots)
# Create tensor to slice into the kv tensor in prefill
# hpu need request_prefill_cache_indices to skip padding in kv cache
# NOTE: sliding_window is always input_length here, so the
# `is not None` check below always passes. max(0, input_length -
# sliding_window) is 0, so the indices cover all of this request's
# real (unpadded) tokens.
sliding_window = input_length
# Skip past this request's left padding before indexing its tokens.
cumulative_length += input_ids_padded_length[i]
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - sliding_window),
cumulative_length + input_length,
dtype=torch.int64,
)
# Prefill logprobs is ignored if the request is done prefilling
prefill_logprobs = r.prefill_logprobs and request_prefilling
all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
if prefill_logprobs:
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
prefill_out_cumulative_length += input_length
else:
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
prefill_cache_indices.append(request_prefill_cache_indices)
# NOTE(review): get_adapter_to_index() is re-evaluated for every
# request, and ADAPTER_TO_INDEX is bound only inside this loop. It is
# read again after the loop (adapter_meta section), which would raise
# NameError for an empty batch. Hoisting this call above the loop
# would fix both.
ADAPTER_TO_INDEX = get_adapter_to_index()
if ADAPTER_TO_INDEX:
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(
torch.full((next_chunk_length,), adapter_index)
)
adapter_set.add(adapter_index)
# Update
cumulative_length += next_chunk_length
# Mixed case: some requests want prefill logprobs and some do not, so
# explicit head / next-token indices are needed.
if not all_prefill_logprobs and not no_prefill_logprobs:
prefill_head_indices = []
prefill_next_token_indices = []
# Cumulative length
cumulative_length = 0
prefill_out_cumulative_length = 0
for i, (
r,
input_length,
request_prefilling,
) in enumerate(
zip(
self.requests,
self.input_lengths,
self.prefilling_mask,
)
):
# Prefill logprobs is ignored if the request is done prefilling
prefill_logprobs = r.prefill_logprobs and request_prefilling
if prefill_logprobs:
prefill_head_indices.append(
torch.arange(
cumulative_length,
cumulative_length + input_length,
dtype=torch.int32,
)
)
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1],
dtype=torch.int32,
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_out_cumulative_length += 1
# Update
cumulative_length += input_length
if len(self) > 1:
if position_ids:
position_ids = torch.cat(position_ids)
if slot_indices:
slot_indices = torch.cat(slot_indices)
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
if position_ids:
position_ids = position_ids[0]
if slot_indices:
slot_indices = slot_indices[0]
prefill_cache_indices = prefill_cache_indices[0]
self.position_ids = position_ids
self.position_ids = F.pad(
self.position_ids, (0, extra_pad_bs * max_padded_input_len), value=1
)
self.slot_indices = slot_indices
self.prefill_cu_outlens = prefill_cu_outlens
# Boolean mask over the padded, flattened input_ids: True for real tokens.
self.prefill_cache_indices = torch.zeros_like(
self.input_ids, dtype=torch.bool, device="cpu"
)
self.prefill_cache_indices[prefill_cache_indices] = True
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = self.cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.cat(prefill_head_indices)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64
)
self.prefill_head_indices = prefill_head_indices
self.prefill_next_token_indices = prefill_next_token_indices
# Running total of left padding up to and including each request. It is
# added to the unpadded indices above so they point into the padded layout.
input_ids_padded_length_tensor = torch.cumsum(
torch.tensor(input_ids_padded_length, dtype=torch.int32),
dim=-1,
).to(torch.int32)
input_ids_padded_length_tensor = F.pad(
input_ids_padded_length_tensor, (0, extra_pad_bs), value=0
)
# NOTE(review): input_ids_padded_length_tensor has max_padded_bs entries
# (one per row). In the mixed-logprobs branch, prefill_head_indices has one
# entry per output position and prefill_next_token_indices has len(self)
# entries, so these elementwise adds only line up when the lengths happen
# to match. Confirm the mixed case is unreachable or fix it.
if self.prefill_head_indices is not None:
self.prefill_head_indices = (
self.prefill_head_indices + input_ids_padded_length_tensor
)
if self.prefill_next_token_indices is not None:
self.prefill_next_token_indices = (
self.prefill_next_token_indices + input_ids_padded_length_tensor
)
# Re-allocate all_input_ids_tensor with max_padded_bs rows and at least
# max_total_tokens columns, copying the existing rows in.
all_input_ids_tensor = torch.full(
(max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
pad_token_id,
dtype=torch.int64,
device="hpu",
)
for i in range(len(self)):
all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = (
self.all_input_ids_tensor[i]
)
self.all_input_ids_tensor = all_input_ids_tensor
if len(self.next_token_chooser.do_sample) != max_padded_bs:
next_token_chooser_parameters = []
next_token_chooser_parameters.extend([r.parameters for r in self.requests])
pad_next_token_chooser_parameters(
next_token_chooser_parameters, max_padded_bs
)
# update past grammar states
fsm_grammar_states = [0] * max_padded_bs
for i, req in enumerate(self.requests):
fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]
self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
self.next_token_chooser.dtype,
self.next_token_chooser.device,
self.next_token_chooser.tokenizer,
fsm_grammar_states,
)
if ADAPTER_TO_INDEX:
if adapter_set:
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
adapter_segments, adapter_segment_indices = find_segments(
adapter_indices
)
else:
# No adapter index collected: a single segment of adapter 0
# covering every input_ids position.
adapter_indices = torch.zeros_like(self.input_ids)
adapter_segments = [0, len(adapter_indices)]
adapter_segment_indices = [len(adapter_indices) - 1]
adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)
self.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
def __len__(self):
return len(self.requests)
# Projection-layer names presumably eligible for adapter (e.g. LoRA) weights.
ADAPTER_LAYERS = list(
    (
        # attention projections
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        # MLP projections
        "gate_proj",
        "up_proj",
        "down_proj",
    )
)
# Layers presumably treated as row-parallel under tensor parallelism.
ROW_PARALLEL = set(("o_proj", "down_proj", "lm_head"))
class FlashCausalLM(Model):
# Load tokenizer, config and weights, build `model_class` on the HPU device,
# derive per-rank attention head counts / head size, and read HPU graph and
# warmup options from the environment.
#
# NOTE(review): `default_dtype` and `lora_adapter_ids` are not read anywhere
# in this constructor. When `dtype` is None it falls back to torch.bfloat16,
# not `default_dtype`. `config_class` is annotated as PreTrainedTokenizerBase
# but defaults to AutoConfig.
def __init__(
self,
model_id: str,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
lora_adapter_ids: Optional[list] = [],
tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
config_class: PreTrainedTokenizerBase = AutoConfig,
default_dtype=torch.float16,
aliases=None,
# Used for Santacoder override of config
num_kv_heads: Optional[int] = None,
# Deepseek V2 uses different QK and V dims.
head_size: Optional[int] = None,
skip_special_tokens: bool = True,
kv_cache_dtype: Optional[torch.dtype] = None,
support_chunking: bool = True,
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
# A CPU (gloo) group is created for multi-rank runs; align_workers uses it
# for host-side all_reduce.
if world_size > 1:
self.process_group_cpu = torch.distributed.new_group(backend="gloo")
device = torch.device("hpu")
dtype = torch.bfloat16 if dtype is None else dtype
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
# Best effort: any failure loading the generation config is ignored.
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = config_class.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(quantize, model_id, revision)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
aliases=aliases,
weights_loader=weights_loader,
)
prefix = None
model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
# VLM models define the config we care about in their text_config
text_config = getattr(config, "text_config", None)
if text_config is not None:
config = text_config
# Normalize sliding_window: a missing attribute or use_sliding_window=False
# both mean "no sliding window".
if getattr(config, "sliding_window", None) is None:
config.sliding_window = None
if getattr(config, "use_sliding_window", True) is False:
config.sliding_window = None
self.num_layers = config.num_hidden_layers
self.num_heads = config.num_attention_heads // self.process_group.size()
self.config = config
# Validation is done in the model itself
if num_kv_heads is None:
num_kv_heads = getattr(config, "num_key_value_heads", None)
# GPT-2 workaround
if num_kv_heads is None:
num_kv_heads = getattr(config, "n_head", None)
if num_kv_heads is None:
raise ValueError("Cannot get the number of key/value heads")
# Split KV heads across ranks. If there are fewer KV heads than ranks, the
# integer division gives 0 and the full count is kept instead.
self.num_kv_heads = (
num_kv_heads // self.process_group.size()
if num_kv_heads // self.process_group.size() > 0
else num_kv_heads
)
assert self.num_kv_heads > 0
if head_size is None:
# Some models use GQA and different sizes for o_proj
# and q_proj, that allows for that.
if getattr(config, "head_dim", None) is not None:
self.head_size = config.head_dim
else:
self.head_size = config.hidden_size // config.num_attention_heads
else:
self.head_size = head_size
self.cuda_graphs = {}
self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
self.bucketing_ctx = None
self.max_total_tokens = None
self.max_input_tokens = None
htorch.core.hpu_set_env()
if htorch.utils.internal.is_lazy():
htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
environment.set_model_config(self.config)
# Environment toggles read once at construction time.
self.use_contiguous_pa = (
os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true"
)
self.limit_hpu_graph = (
os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
)
self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
self.max_seq_len_to_capture = 8192
# Pad token fallback chain: config.pad_token_id, then the first / only
# config.eos_token_id, then tokenizer.eos_token_id, then 0.
if tokenizer.pad_token_id is None:
if config.pad_token_id is not None:
tokenizer.pad_token_id = config.pad_token_id
elif config.eos_token_id is not None:
tokenizer.pad_token_id = (
config.eos_token_id[0]
if isinstance(config.eos_token_id, list)
else config.eos_token_id
)
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.pad_token_id = 0
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
support_chunking=support_chunking,
)
@property
def batch_type(self) -> Type[FlashCausalLMBatch]:
    """Batch class that this model builds, concatenates and consumes."""
    batch_cls = FlashCausalLMBatch
    return batch_cls
def max_past(self) -> Optional[int]:
    """Return the wrapped model's ``max_past`` attribute.

    Returns:
        ``self.model.max_past`` when the model defines it, otherwise ``None``.
        The previous ``-> int`` annotation did not cover the ``None`` case.
    """
    return getattr(self.model, "max_past", None)
def init_kv_cache(
    self,
    num_blocks: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """(Re)allocate one KV-cache object per layer into ``self.kv_cache``.

    DeepSeek V2/V3 models get a ``KVCompressCache`` whose head size is
    ``kv_lora_rank + qk_rope_head_dim``; ``num_heads`` / ``head_size`` are
    not used for them. Every other model gets a ``KVCache`` built from
    ``num_heads`` and ``head_size``.
    """
    # Release the previous caches before calling empty_cache().
    self.kv_cache = []
    empty_cache()
    uses_compressed_cache = self.config.model_type in ("deepseek_v3", "deepseek_v2")
    layer_caches = []
    for _layer in range(num_layers):
        if uses_compressed_cache:
            layer_cache = KVCompressCache(
                num_blocks=num_blocks,
                head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
                dtype=dtype,
                device=device,
            )
        else:
            layer_cache = KVCache(
                num_blocks=num_blocks,
                num_heads=num_heads,
                head_size=head_size,
                dtype=dtype,
                device=device,
            )
        layer_caches.append(layer_cache)
    self.kv_cache = layer_caches
def warmup(
    self,
    batch: FlashCausalLMBatch,
    max_input_tokens: Optional[int],
    max_total_tokens: Optional[int],
):
    """Size the KV cache from measured free memory and set up HPU bucketing.

    Runs one forward pass on ``batch`` to measure peak memory; the batch is
    expected to be the biggest batch the server could ever receive. It then
    reallocates the KV cache with as many blocks as the remaining memory
    allows. Unless VLLM_SKIP_WARMUP=true, HPU graphs are also warmed up via
    ``warmup_hpu_graph``.

    Args:
        batch: warmup batch; its ``num_blocks`` sizes the temporary KV cache.
        max_input_tokens: defaults to ``max_total_tokens - 1`` when None.
        max_total_tokens: defaults to ``sum(batch.input_lengths)`` when None.

    Returns:
        ``(KV-cache capacity in tokens, max_input_tokens, max_total_tokens)``.

    Raises:
        RuntimeError: if MAX_BATCH_SIZE is unset, or if allocating the
            temporary KV cache / running the warmup batch fails. In the latter
            case the original exception is chained as ``__cause__``.
    """
    if os.environ.get("MAX_BATCH_SIZE") is None:
        raise RuntimeError(
            "MAX_BATCH_SIZE is not set, it should be set in the launcher "
            "using `--max-batch-size xxx`"
        )
    # The warmup batch is the biggest batch we could ever receive
    self.kv_cache = []
    empty_cache()

    self.graphed_buckets = set()
    # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
    # Calculate the number of blocks that can be allocated with the free memory
    dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
    if self.config.model_type in ["deepseek_v3", "deepseek_v2"]:
        # Single compressed cache tensor per layer (see init_kv_cache).
        cache_block_size = BLOCK_SIZE * (
            self.config.kv_lora_rank + self.config.qk_rope_head_dim
        )
    else:
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        # Separate K and V.
        cache_block_size = cache_block_size * 2
    total_cache_size = self.num_layers * cache_block_size * dtype_size

    free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
    self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
    graph_reserved_mem = (
        float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
        if htorch.utils.internal.is_lazy()
        else 0
    )
    mem_used_from_graph = int(
        (free_memory - self.mem_reserved) * graph_reserved_mem
    )
    log_master(
        logger.info,
        f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
    )
    if max_total_tokens is None:
        max_total_tokens = sum(batch.input_lengths)

    if max_input_tokens is None:
        max_input_tokens = max_total_tokens - 1

    self.max_total_tokens = max_total_tokens
    self.max_input_tokens = max_input_tokens

    # Read these before the try block so the error message below can always
    # reference them, even when KV-cache allocation itself is what fails.
    # Previously a failure in init_kv_cache surfaced as a NameError instead.
    batch_num_blocks = batch.num_blocks
    num_tokens = batch.to_pb().current_tokens
    try:
        self.init_kv_cache(
            batch.num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.kv_cache_dtype,
            self.device,
        )
        synchronize(self.device)
        _, _batch, _ = self.generate_token([batch])
    except Exception as e:
        raise RuntimeError(
            f"Not enough memory to handle {num_tokens} prefill tokens. "
            f"You need to decrease `--max-batch-prefill-tokens`"
        ) from e

    synchronize(self.device)
    free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)

    kv_memory = free_memory - self.mem_reserved - mem_used_from_graph
    num_blocks = (
        # Headroom is presumably applied via TGI_WIGGLE_ROOM in get_free_memory.
        int(kv_memory // total_cache_size)
        # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
        + batch_num_blocks
    )

    log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")

    self.kv_cache = []
    empty_cache()
    self.init_kv_cache(
        num_blocks,
        self.num_layers,
        self.num_kv_heads,
        self.head_size,
        self.kv_cache_dtype,
        self.device,
    )
    self.max_batch_prefill_tokens = get_max_prefill_tokens()
    max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
    HPUBucketingContext = get_bucketing_context()
    # need to warmup one more step since block is allocated from 1
    # os.getenv returns a str when the variable is set. Convert it so the
    # arithmetic below does not become str repetition / a TypeError.
    block_step = int(os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE))
    max_total_tokens_aligned = math.ceil(
        max_total_tokens / BLOCK_SIZE
    ) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs)
    model_max_length = self.tokenizer.model_max_length
    max_position_embeddings = getattr(
        self.config, "max_position_embeddings", model_max_length
    )

    self.bucketing_ctx = HPUBucketingContext(
        max_num_seqs,
        max_num_seqs,  # self.max_num_prefill_seqs, #TODO
        BLOCK_SIZE,
        max_num_seqs * max_total_tokens_aligned,
        False,
        min(model_max_length, max_position_embeddings),
        max_input_tokens,
        max_total_tokens_aligned,
    )
    max_blocks = max(
        BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
    )
    self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
    synchronize(self.device)
    if self.skip_warmup:
        self.bucketing_ctx.generate_prompt_buckets()
        self.bucketing_ctx.generate_decode_buckets(
            self.bucketing_ctx.num_hpu_blocks
        )
        log_master(
            logger.info, "skip warmup hpu graph, not recommmended, may cause OOM"
        )
    else:
        self.warmup_hpu_graph(batch)
    del _batch, batch
    return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def log_warmup(self, prefilling, i, max_i, batch_size, seq_len):
    """Log progress for warmup step ``i`` (0-based) out of ``max_i``.

    ``seq_len`` is reported as ``seq_len`` for prefill and ``num_blocks``
    for decode. ``bypass`` is True when the bucket is not recorded in
    ``self.graphed_buckets``.
    """
    free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
    if prefilling:
        phase, dim = "Prompt", "seq_len"
    else:
        phase, dim = "Decode", "num_blocks"
    bypass = (batch_size, seq_len, prefilling) not in self.graphed_buckets
    step = i + 1
    log_master(
        logger.info,
        f"[Warmup][{phase}][{step}/{max_i}] batch_size:{batch_size} "
        f"{dim}:{seq_len} bypass:{bypass} free_mem:{free_mem}"
        ", this may take a while...",
    )
def use_graphs(self, prefill, seq_len, batch_size):
    """Decide whether a (batch_size, seq_len, prefill) shape runs as an HPU graph.

    Prefill is never graphed when LIMIT_HPU_GRAPH is enabled. With
    VLLM_SKIP_WARMUP every other shape is graphed. Otherwise only shapes
    recorded in ``self.graphed_buckets`` during warmup qualify.
    """
    prefill_disallowed = self.limit_hpu_graph and prefill
    if prefill_disallowed:
        return False
    bucket_key = (batch_size, seq_len, prefill)
    return True if self.skip_warmup else bucket_key in self.graphed_buckets
def align_workers(self, value, op):
    """Reduce ``value`` across all ranks with ``op`` on the CPU process group.

    With a single worker the value is returned unchanged; otherwise the
    all-reduced scalar is returned as a Python number.
    """
    if self.world_size > 1:
        reduced = torch.tensor(value, device="cpu")
        torch.distributed.all_reduce(reduced, op=op, group=self.process_group_cpu)
        return reduced.item()
    return value
def warmup_hpu_graph(self, batch):
    """Warm up every prefill and decode bucket, registering graphed ones.

    Prefill buckets are visited in order of increasing ``batch_size * seq_len``;
    decode buckets in order of decreasing batch size. A bucket is added to
    ``self.graphed_buckets`` only while the extrapolated graph memory for it stays
    below the remaining budget (and, for prefill, ``batch_size * seq_len`` does
    not exceed ``self.max_seq_len_to_capture``). Every eligible bucket is still
    run ``warmup_times`` times, graphed or not; only graphed buckets consume the
    budget.

    The prefill budget is ``VLLM_GRAPH_PROMPT_RATIO`` (default 0.3) of the free
    device memory minus ``self.mem_reserved``. The decode budget is re-measured
    after the prefill phase. Both budgets, and the memory used by each bucket,
    are reduced across ranks so every rank takes the same graph/bypass decision.

    Args:
        batch: template batch; the warmup helpers read tensor dtypes from it.
    """
    prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
    free_mem = HabanaMemoryProfiler.current_free_device_memory()
    graph_free_mem = free_mem - self.mem_reserved
    # Smallest budget over all ranks, so capture decisions agree across ranks.
    graph_free_mem = self.align_workers(
        graph_free_mem, torch.distributed.ReduceOp.MIN
    )
    prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
    decode_available_memory = graph_free_mem - prompt_available_memory
    msg = (
        f"Using {format_bytes(graph_free_mem)}"
        f"/{format_bytes(free_mem)} "
        "of free device memory for HPUGraphs, "
        f"{format_bytes(prompt_available_memory)} for prompt and "
        f"{format_bytes(decode_available_memory)} for decode "
        f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
    )
    log_master(logger.info, msg)
    start_time = time.time()
    warmup_shape_count = 0
    warmup_times = 3

    # ---- Prefill buckets ------------------------------------------------------
    self.bucketing_ctx.generate_prompt_buckets()

    def ordering_function_min_tokens(b):
        return (b[0] * b[1], b[1], b[0])

    buckets = sorted(
        self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens
    )
    # Running totals used to extrapolate graph memory per token. The 0.001 seed
    # avoids dividing by zero before the first graphed bucket is measured.
    total_batch_seq = 0.001
    total_mem = 0
    available_mem = prompt_available_memory
    msg = (
        f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
        f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
    )
    log_master(logger.info, msg)
    for i, (batch_size, seq_len) in enumerate(buckets):
        if batch_size * seq_len > self.max_batch_prefill_tokens:
            continue
        # Graph memory usage is proportional to seq dimension in a batch
        batch_seq = batch_size * seq_len
        mem_estimate = batch_seq / total_batch_seq * total_mem
        graphed_bucket = (batch_size, seq_len, True)
        if not (
            mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
        ):
            if graphed_bucket not in self.graphed_buckets:
                self.graphed_buckets.add(graphed_bucket)
        warmup_shape_count += 1
        self.log_warmup(True, i, len(buckets), batch_size, seq_len)
        with HabanaMemoryProfiler() as mem_prof:
            for _ in range(warmup_times):
                self.warmup_prefill(seq_len, batch_size, batch)
                synchronize(self.device)
        # Largest consumption over all ranks keeps the budgets identical.
        used_mem = self.align_workers(
            mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
        )
        if graphed_bucket in self.graphed_buckets:
            available_mem -= used_mem
            total_mem += used_mem
            total_batch_seq += batch_seq

    log_master(logger.info, "Prefill warmup successful.\n")

    # ---- Decode buckets -------------------------------------------------------
    def ordering_function_max_bs(b):
        return (-b[0], b[1])

    self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
    buckets = sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
    free_mem = HabanaMemoryProfiler.current_free_device_memory()
    total_batch_seq = 0.001
    total_mem = 0
    # Reduce with MIN exactly like the prefill budget; a per-rank value would
    # let ranks graph different decode buckets.
    available_mem = self.align_workers(
        free_mem - self.mem_reserved, torch.distributed.ReduceOp.MIN
    )
    log_master(
        logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
    )
    for i, (batch_size, block_num) in enumerate(buckets):
        # warmup_decode needs at least one KV-cache block per sequence.
        if batch_size > block_num:
            continue
        # Graph memory usage is proportional to seq dimension in a batch
        batch_seq = batch_size
        mem_estimate = batch_seq / total_batch_seq * total_mem
        graphed_bucket = (batch_size, block_num, False)
        if not mem_estimate >= available_mem:
            if graphed_bucket not in self.graphed_buckets:
                self.graphed_buckets.add(graphed_bucket)
        warmup_shape_count += 1
        self.log_warmup(False, i, len(buckets), batch_size, block_num)
        with HabanaMemoryProfiler() as mem_prof:
            for _ in range(warmup_times):
                self.warmup_decode(batch_size, block_num, batch)
                synchronize(self.device)
        used_mem = self.align_workers(
            mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
        )
        if graphed_bucket in self.graphed_buckets:
            available_mem -= used_mem
            total_mem += used_mem
            total_batch_seq += batch_seq

    log_master(logger.info, "Decode warmup successful.\n")

    log_master(
        logger.info,
        f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
    )
def warmup_prefill(
    self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
):
    """Run one dummy prefill forward pass for a ``(batch_size, prompt_len)`` bucket.

    Every sequence is ``prompt_len`` zero token ids at positions
    ``0..prompt_len-1``. Each sequence owns its own run of
    ``prompt_len // BLOCK_SIZE + 1`` consecutive KV-cache blocks, starting at
    block 0.

    Args:
        prompt_len: padded prompt length of the bucket.
        batch_size: number of sequences in the bucket.
        batch: template batch; only the dtypes of ``input_ids``,
            ``position_ids`` and ``slots`` are read from it.
    """
    input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat(
        batch_size
    )
    position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat(
        batch_size
    )
    # Sequence ``s`` owns blocks [s * blocks_per_seq, (s + 1) * blocks_per_seq).
    # Its first slot is therefore s * blocks_per_seq * BLOCK_SIZE, and it uses
    # the next prompt_len slots. Tensor arithmetic replaces a per-slot Python loop.
    blocks_per_seq = prompt_len // BLOCK_SIZE + 1
    seq_start_slots = torch.arange(batch_size).unsqueeze(-1) * (
        blocks_per_seq * BLOCK_SIZE
    )
    slots = (
        (seq_start_slots + torch.arange(prompt_len)).reshape(-1).to(batch.slots.dtype)
    )
    input_lengths = torch.ones(batch_size, dtype=torch.int32) * prompt_len
    cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
    torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
    seqlen = Seqlen(
        input_lengths=_async_h2d_tensor_copy(input_lengths),
    )
    # NOTE(review): every entry is prompt_len - 1 rather than a per-sequence
    # offset into the flattened tokens; only the shape matters for warmup — confirm.
    lm_head_indices = input_lengths - 1
    kwargs = {}
    if htorch.utils.internal.is_lazy():
        kwargs["bypass_hpu_graphs"] = not self.use_graphs(
            True, prompt_len, batch_size
        )
    if self.sliding_window is not None:
        attn_mask = seqlen.make_sliding_window_bias(
            input_lengths.tolist(),
            self.sliding_window,
            self.dtype,
            prompt_len,
            batch_size,
        )
        seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)

    # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
    self.model.forward(
        input_ids=_async_h2d_tensor_copy(input_ids),
        position_ids=_async_h2d_tensor_copy(position_ids),
        cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
        kv_cache=self.kv_cache,
        slots=_async_h2d_tensor_copy(slots),
        seqlen=trim_seqlen_metadata(seqlen),
        lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
        adapter_data=None,
        hpu_attention_meta=None,
        **kwargs,
    )
def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch):
    """Run one dummy decode step: ``batch_size`` sequences over ``block_num`` blocks.

    The blocks are split as evenly as possible between the sequences, and the
    first sequence takes the remainder. Each sequence decodes a single token,
    written into the last slot of its last block. Callers must ensure
    ``block_num >= batch_size``; a sequence with zero blocks makes
    ``block_array[-1]`` raise ``IndexError``.

    Args:
        batch_size: number of sequences in the decode bucket.
        block_num: total number of KV-cache blocks used by the bucket.
        batch: template batch; only the dtypes of ``input_ids``,
            ``position_ids`` and ``slots`` are read from it.
    """
    input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
    position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
    blocks = [block_num // batch_size for _ in range(batch_size)]
    blocks[0] += block_num % batch_size
    block_tables = []
    slots = []
    start_idx = 0
    slot_indices = []
    # Give each sequence a contiguous run of blocks and target the last slot of
    # its last block, so all block_num blocks are referenced.
    for i in range(batch_size):
        block_array = list(range(start_idx, start_idx + blocks[i]))
        slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
        slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)
        block_tables.append(block_array)
        start_idx += blocks[i]
    input_lengths = torch.ones(batch_size, dtype=torch.int32)
    seqlen = Seqlen(
        input_lengths=_async_h2d_tensor_copy(input_lengths),
    )
    block_list, block_groups, block_usage, _, block_bucket_size = (
        generate_block_metadata(
            self.dtype,
            self.use_contiguous_pa,
            slots,
            block_tables,
            self.bucketing_ctx,
        )
    )
    meta = HPUPagedAttentionMetadata(
        block_list=_async_h2d_tensor_copy(block_list),
        block_groups=_async_h2d_tensor_copy(block_groups),
        block_usage=_async_h2d_tensor_copy(block_usage),
        block_mapping=None,
        attn_bias=None,
    )
    if self.sliding_window is not None:
        # Blocks needed to cover the window (ceiling division); loop-invariant.
        block_num_in_window = (self.sliding_window + BLOCK_SIZE - 1) // BLOCK_SIZE
        # Keep only each sequence's trailing blocks that can fall in the window.
        block_tables_in_window = [
            bt[max(0, blocks[i] - block_num_in_window) : blocks[i]]
            for i, bt in enumerate(block_tables)
        ]
        slots_in_window = []
        start_idx = 0
        for i, indice in enumerate(slot_indices):
            seq_slots = torch.arange(start_idx, indice + 1)
            # Keep slots fewer than sliding_window positions behind the decode slot.
            mask = (indice - seq_slots) < self.sliding_window
            slots_in_window.append(seq_slots[mask])
            start_idx += blocks[i] * BLOCK_SIZE
        slots_in_window = torch.cat(slots_in_window, dim=0)
        (
            block_list_in_window,
            block_groups_in_window,
            block_usage_in_window,
            slots_in_window_mask,
            _,
        ) = generate_block_metadata(
            self.dtype,
            self.use_contiguous_pa,
            slots,
            block_tables_in_window,
            self.bucketing_ctx,
            slots_in_window,
            block_bucket_size,
        )
        meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
        meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
        meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
        meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)

    hpu_attention_meta = trim_attn_metadata(meta)
    slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
    kwargs = {}
    if htorch.utils.internal.is_lazy():
        # Key on the block_list length produced by generate_block_metadata.
        # forward() uses the same key for decode batches.
        kwargs["bypass_hpu_graphs"] = not self.use_graphs(
            False, hpu_attention_meta.block_list.shape[0], batch_size
        )
    # Decode step: no cu_seqlen_prefill is passed. The attention metadata is
    # supplied through hpu_attention_meta instead.
    self.model.forward(
        input_ids=_async_h2d_tensor_copy(input_ids),
        position_ids=_async_h2d_tensor_copy(position_ids),
        cu_seqlen_prefill=None,
        kv_cache=self.kv_cache,
        slots=_async_h2d_tensor_copy(slots_tensor),
        seqlen=trim_seqlen_metadata(seqlen),
        lm_head_indices=None,
        adapter_data=None,
        hpu_attention_meta=hpu_attention_meta,
        **kwargs,
    )
def forward(
    self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Run the model on the prepared ``batch``.

    When ``batch.speculative_ids`` is set, each sequence is expanded to
    ``speculative_length + 1`` query positions: its current token followed by
    its speculated ids. Positions, slot indices and input lengths are expanded
    to match.

    ``adapter_data`` is accepted, but the model is always called with
    ``adapter_data=None`` (see the TODO below). HPU-graph use is decided by
    ``self.use_graphs`` keyed on ``(batch_size, prompt_len, batch.prefilling)``.
    For decode, ``prompt_len`` is the length of ``batch.hpu_attn_meta.block_list``.

    Returns:
        ``(logits, speculative_logits)`` as returned by ``self.model.forward``.
    """
    # Model Forward
    if batch.speculative_ids is not None:
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = self.kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_current_length
        lm_head_indices = batch.prefill_head_indices
        speculative_ids = batch.speculative_ids
        B, speculative_length = speculative_ids.shape
        new_length = speculative_length + 1
        # Row-major flatten of [current token, speculated ids...] per sequence.
        new_input_ids = torch.cat(
            [input_ids.unsqueeze(-1), speculative_ids], dim=1
        ).reshape(-1)
        arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
        arange_int = arange.to(dtype=torch.int32)
        new_position_ids = (
            position_ids.unsqueeze(-1).expand(B, new_length) + arange
        ).view(-1)
        # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
        # then update the slots with the additional indices to ensure we're grabbing the ones that have been
        # allocated
        slot_indices = (
            batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
        ).view(-1)
        slots = batch.slots[slot_indices]
        input_lengths = (
            input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
        ).view(-1)
        # Add Copy the block tables for all members
        block_tables = (
            block_tables.unsqueeze(1)
            .expand(B, new_length, -1)
            .reshape(B * new_length, -1)
            .contiguous()
        )
        max_s = max_s + speculative_length
        input_ids = new_input_ids
        position_ids = new_position_ids
    else:
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = self.kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_current_length
        lm_head_indices = batch.prefill_head_indices

    # NOTE(review): max_s and block_tables are computed in both branches, but
    # neither is passed to self.model.forward below. They look like dead code;
    # confirm before removing.
    if cu_seqlen_prefill is None and self.max_past() is not None:
        # In decode, not prefill, we're actually overwriting the KV-cache
        # in a circular buffer mode.
        # This makes sure the max_s for the decode pass is correct.
        max_s = min(self.max_past(), max_s)
    # Pad slots to one entry per input token (zero-filled). With
    # prefill_cache_indices, slots are scattered to those token positions;
    # otherwise they fill the leading positions.
    if batch.prefill_cache_indices is not None:
        slots_pad = torch.zeros_like(input_ids, device=slots.device)
        slots_pad[batch.prefill_cache_indices] = slots
        slots = slots_pad
    else:
        slots_pad = torch.zeros_like(input_ids, device=slots.device)
        slots_pad[: slots.shape[0]] = slots
        slots = slots_pad
    seqlen = Seqlen(
        input_lengths=_async_h2d_tensor_copy(input_lengths),
    )
    kwargs = {}
    batch_size = input_lengths.shape[0]
    # Second element of the HPU-graph bucket key: the per-sequence padded
    # prompt length for prefill, the block-list length for decode (the same
    # keys the warmup helpers use).
    prompt_len = (
        input_ids.shape[0] // batch_size
        if batch.prefilling
        else batch.hpu_attn_meta.block_list.shape[0]
    )
    if htorch.utils.internal.is_lazy():
        kwargs["bypass_hpu_graphs"] = not self.use_graphs(
            batch.prefilling, prompt_len, batch_size
        )
    if self.sliding_window is not None and batch.prefilling:
        attn_mask = seqlen.make_sliding_window_bias(
            input_lengths.tolist(),
            self.sliding_window,
            self.dtype,
            prompt_len,
            batch_size,
        )
        seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)
    # NOTE(review): input_ids is passed without _async_h2d_tensor_copy, unlike
    # the other inputs; presumably it is already on device — confirm.
    logits, speculative_logits = self.model.forward(
        input_ids=input_ids,
        position_ids=_async_h2d_tensor_copy(position_ids),
        cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
        kv_cache=kv_cache,
        slots=_async_h2d_tensor_copy(slots),
        seqlen=trim_seqlen_metadata(seqlen),
        lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
        # TODO not support adapter now, need the add in the future
        adapter_data=None,
        hpu_attention_meta=batch.hpu_attn_meta,
        **kwargs,
    )
    return logits, speculative_logits
@tracer.start_as_current_span("generate_token")
def generate_token(
    self, batches: List[FlashCausalLMBatch]
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
    """Run one generation step over ``batches``, pipelining host work with the device.

    The step has three stages:

    1. For every batch that still holds logits from the previous call
       (``batch.next_token_logits``), choose the next tokens and write them
       into the batch's tensors. Per-request bookkeeping is queued in
       ``requests_to_generate``; device results stay as tensors for now.
    2. Concatenate ``batches``; when a bucketing context exists, the decode
       batch size is padded. Prepare the result for prefill or decode and run
       ``self.forward``. The logits are stored on the batch and are consumed
       by stage 1 of the *next* call.
    3. Convert the stage-1 tensors to Python lists and build one
       ``Generation`` per request. Generations are only built on the rank with
       ``request_id % world_size == rank``.

    Returns:
        ``(generations, batch, (forward_ns, decode_ns))``.
        ``batch`` is ``None`` when stage 3 handled at least one request and all
        of them stopped. ``forward_ns`` is measured up to the device sync after
        the forward; ``decode_ns`` covers the rest.
    """
    # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
    # Stage 1. Collect next token ids of any previously started generations
    start = time.time_ns()
    prev_batches = []
    requests_to_generate = []
    for batch_id, batch in enumerate(batches):
        if batch.next_token_logits is not None:
            prefill = batch.prefilling
            if batch.prefilling:
                batch.prefilling = False
                batch.prefilling_mask = [False] * len(batch)

            speculate = get_speculate()
            (
                next_input_ids,
                next_token_logprobs,
                logprobs,
                accepted_ids,
                speculative_ids,
            ) = batch.next_token_chooser(
                batch.all_input_ids_tensor[
                    : batch.next_token_logits.shape[0], : batch.max_current_length
                ],
                batch.next_token_logits,
                speculate,
                batch.speculative_ids,
                batch.speculative_logits,
            )

            batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
                batch.top_n_tokens,
                _async_h2d_tensor_copy(batch.top_n_tokens_tensor),
                logprobs,
                accepted_ids,
            )
            # Compact the rows listed in valid_indices to the front of every
            # per-request tensor/list. Presumably these are the requests that
            # survived a filter while the batch kept its padded size — confirm.
            if batch.valid_indices is not None:
                # TODO speculative decoding handling missing
                index = torch.arange(
                    0,
                    len(batch.valid_indices),
                    device=batch.all_input_ids_tensor.device,
                )
                batch.all_input_ids_tensor.index_copy_(
                    0, index, batch.all_input_ids_tensor[batch.valid_indices]
                )
                padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
                    len(batch.valid_indices)
                )
                next_input_ids.index_copy_(
                    0, index, next_input_ids[batch.valid_indices]
                )
                next_input_ids = next_input_ids[:padded_total_bs]

                next_token_logprobs.index_copy_(
                    0, index, next_token_logprobs[batch.valid_indices]
                )
                accepted_ids.index_copy_(
                    0, index, accepted_ids[batch.valid_indices]
                )
                if speculative_ids is not None:
                    speculative_ids = speculative_ids[batch.valid_indices]
                batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[
                    batch.valid_indices
                ]
                top_n_tokens = []
                batch_top_token_ids_v = []
                batch_top_token_logprobs_v = []
                for i in batch.valid_indices:
                    top_n_tokens.append(batch.top_n_tokens[i])
                    batch_top_token_ids_v.append(batch_top_token_ids[i])
                    batch_top_token_logprobs_v.append(batch_top_token_logprobs[i])
                batch_top_token_ids = batch_top_token_ids_v
                batch_top_token_logprobs = batch_top_token_logprobs_v
                batch.top_n_tokens = top_n_tokens
                batch.next_token_chooser = batch.next_token_chooser.filter(
                    batch.valid_indices
                )
                batch.valid_indices = None

            # Since we are done prefilling, all the tensors that were concatenating values for all the requests
            # instantly become of shape [BATCH_SIZE]
            if prefill:
                # Index of the last prompt token of each request in the flattened prefill tensors.
                indices = batch.cu_seqlen_prefill[1:] - 1
                # pad in left
                if batch.prefill_cache_indices is not None:
                    batch.position_ids = batch.position_ids[
                        batch.prefill_cache_indices
                    ][indices]
                else:
                    batch.position_ids = batch.position_ids[indices]

                batch.slot_indices = batch.slot_indices[indices[: len(batch)]]
                if batch.adapter_meta is not None:
                    batch.adapter_meta.adapter_indices = (
                        batch.adapter_meta.adapter_indices[indices]
                    )
            # For each member of the batch
            # Cumulative length
            if batch.speculative_logits is not None:
                # Speculative path: a request may accept several tokens this step.
                cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
                torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
                for i in range(len(batch)):
                    batch.all_input_ids_tensor[
                        i,
                        batch.cache_lengths[i]
                        + batch.input_lengths[i] : batch.cache_lengths[i]
                        + batch.input_lengths[i]
                        + accepted_ids[i],
                    ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
                # The next input of each request is its last accepted token.
                batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
                accepted_ids = accepted_ids.cpu()
                if batch.position_ids.dim() == 2:
                    # Qwen2_vl case:
                    batch.position_ids += accepted_ids.unsqueeze(-1)
                else:
                    batch.position_ids += accepted_ids
                batch.cache_lengths_tensor += (
                    batch.input_lengths_tensor + accepted_ids - 1
                )
                batch.input_lengths_tensor = torch.ones_like(
                    batch.input_lengths_tensor
                )
                batch.slot_indices += accepted_ids[: len(batch)]
            else:
                # Non-speculative path: write one new token per row at
                # position cache_length + input_length.
                index = batch.cache_lengths_tensor + batch.input_lengths_tensor
                # Extend the index with zeros up to the (possibly padded) number of sampled rows.
                index = F.pad(
                    index, (0, next_input_ids.shape[0] - index.shape[0]), value=0
                )
                index = index.to(batch.all_input_ids_tensor.device)
                batch_idx = torch.arange(
                    0,
                    index.shape[0],
                    dtype=torch.long,
                    device=batch.all_input_ids_tensor.device,
                )
                batch.all_input_ids_tensor.index_put_(
                    (batch_idx, index.long()), next_input_ids
                )
                batch.input_ids = next_input_ids
                batch.position_ids += 1
                batch.cache_lengths_tensor += batch.input_lengths_tensor
                batch.input_lengths_tensor = torch.ones_like(
                    batch.input_lengths_tensor
                )
                batch.slot_indices += 1

            batch.speculative_ids = speculative_ids

            # Does a HPU <-> CPU sync internally
            if prefill and batch.adapter_meta is not None:
                # adjust segment lengths to account for all request lengths being 1 during decoding
                adapter_segments, _ = find_segments(
                    batch.adapter_meta.adapter_indices
                )
                batch.adapter_meta.adapter_segments = torch.tensor(
                    adapter_segments,
                    dtype=torch.int32,
                    device=batch.adapter_meta.adapter_segments.device,
                )
            # Device tensors are kept as-is here; they are converted to lists
            # after the next forward has been launched (stage 2 -> 3).
            prev_batches.append(
                {
                    "next_token_ids": next_input_ids,
                    "next_token_logprobs": next_token_logprobs,
                    "accepted_ids": accepted_ids,
                }
            )
            idx = len(prev_batches) - 1

            for req_idx, req in enumerate(batch.requests):
                new_input_length = 1
                if batch.speculative_logits is not None:
                    new_cache_length = (
                        batch.cache_lengths[req_idx]
                        + batch.input_lengths[req_idx]
                        + accepted_ids[req_idx]
                        - 1
                    )
                else:
                    new_cache_length = (
                        batch.cache_lengths[req_idx] + batch.input_lengths[req_idx]
                    )
                batch.cache_lengths[req_idx] = new_cache_length
                batch.max_input_length = max(
                    batch.max_input_length, new_input_length
                )
                batch.input_lengths[req_idx] = new_input_length
                current_length = new_cache_length + new_input_length
                batch.max_current_length = max(
                    batch.max_current_length, current_length
                )

                requests_to_generate.append(
                    {
                        "idx": idx,
                        "request_id": req.id,
                        "prefix_offset": batch.prefix_offsets[req_idx],
                        "read_offset": batch.read_offsets[req_idx],
                        "stopping_criteria": batch.stopping_criterias[req_idx],
                        "all_input_ids": batch.all_input_ids[req_idx],
                        "do_sample": batch.next_token_chooser.do_sample[req_idx],
                        "seed": batch.next_token_chooser.seeds[req_idx],
                        "top_n_tokens": batch.top_n_tokens[req_idx],
                        "top_token_ids": batch_top_token_ids[req_idx],
                        "top_token_logprobs": batch_top_token_logprobs[req_idx],
                    }
                )
            if prefill:
                # We do not need prefill tensors anymore
                batch.cu_seqlen_prefill = None
                batch.prefill_cache_indices = None
                batch.prefill_cu_outlens = None
                batch.prefill_head_indices = None
                batch.prefill_next_token_indices = None
            batch.next_token_logits = None
            # NOTE(review): this overwrites the speculative_ids assigned above.
            # Unless something re-sets it before stage 2, the speculative
            # branches below and in forward() are never taken — confirm intended.
            batch.speculative_ids = None

    htorch.core.mark_step()
    # Stage 2. Prepare new batch for speculative scheduling
    if len(batches) > 1:
        if self.bucketing_ctx is not None:
            total_batch_size = 0
            for b in batches:
                total_batch_size += len(b)
            padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(
                total_batch_size
            )
            batch = self.batch_type.concatenate(
                batches, padded_total_bs=padded_total_bs
            )
        else:
            batch = self.batch_type.concatenate(batches)
    else:
        batch = batches[0]
    prefill = batch.prefilling
    if prefill:
        if self.bucketing_ctx is not None:
            batch.prepare_for_prefill(
                self.bucketing_ctx.get_padded_prompt_seq_len(
                    batch.max_input_length
                ),
                self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),
                self.max_total_tokens,
                self.tokenizer.pad_token_id,
            )
        else:
            batch.prepare_for_prefill(
                batch.max_input_length,
                len(batch),
                self.max_total_tokens,
                self.tokenizer.pad_token_id,
            )
    else:
        batch.prepare_for_decode(
            self.dtype,
            self.use_contiguous_pa,
            self.bucketing_ctx,
            self.tokenizer.pad_token_id,
            self.sliding_window,
        )
    # Optional hook; subclasses that define set_inputs_embeds get to prepare the batch here.
    if hasattr(self, "set_inputs_embeds") and callable(self.set_inputs_embeds):
        self.set_inputs_embeds(batch)
    prefill_logprobs = batch.prefill_next_token_indices is not None
    # Update adapter indices for speculative tokens (if present)
    adapter_meta = batch.adapter_meta
    if adapter_meta is not None:
        if batch.speculative_ids is not None:
            B, speculative_length = batch.speculative_ids.shape
            new_length = speculative_length + 1
            adapter_indices = (
                adapter_meta.adapter_indices.unsqueeze(-1)
                .expand(B, new_length)
                .reshape(-1)
            )
            adapter_segments = adapter_meta.adapter_segments * new_length
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_meta.adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_meta.segment_indices,
            )

        # Assign pointers to adapter weights
        # TODO(travis): don't update this if indices haven't changed
        adapter_data = AdapterBatchData.from_meta(
            adapter_meta,
            self.layer_to_adapter_weights,
            prefill,
            batch.prefill_head_indices,
        )
    else:
        adapter_data = None

    out, speculative_logits = self.forward(batch, adapter_data)

    # Keep the logits on the batch; they are sampled in stage 1 of the next call.
    if prefill:
        batch.next_token_logits = (
            out[batch.prefill_next_token_indices] if prefill_logprobs else out
        )
        if speculative_logits is not None:
            speculative_logits = (
                speculative_logits[batch.prefill_next_token_indices]
                if prefill_logprobs
                else speculative_logits
            )
    else:
        prefill_logprobs = None
        batch.next_token_logits = out
    batch.speculative_logits = speculative_logits

    # HPU->CPU sync
    htorch.core.mark_step()
    start_decode = time.time_ns()
    for prev_batch in prev_batches:
        prev_batch["next_token_logprobs"] = prev_batch[
            "next_token_logprobs"
        ].tolist()
        prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist()
        prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist()
    htorch.core.mark_step()
    # Stage 3. Finish and return previous generations
    # Results
    generations: List[Generation] = []
    # True only if there is at least one request and every one of them stops.
    stopped = len(requests_to_generate) > 0
    # Reset max_input_length
    batch.max_input_length = 0
    # For each member of the batch
    # Per previous batch: read cursor into next_token_ids (indexs) and into
    # accepted_ids (idx_accept_ids).
    indexs = [0] * len(prev_batches)
    idx_accept_ids = [0] * len(prev_batches)
    for i, req_data in enumerate(requests_to_generate):
        idx = req_data["idx"]
        request_id = req_data["request_id"]
        prefix_offset = req_data["prefix_offset"]
        read_offset = req_data["read_offset"]
        stopping_criteria = req_data["stopping_criteria"]
        all_input_ids = req_data["all_input_ids"]
        do_sample = req_data["do_sample"]
        seed = req_data["seed"]
        top_n_tokens = req_data["top_n_tokens"]
        n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
        top_token_ids = req_data["top_token_ids"]
        top_token_logprobs = req_data["top_token_logprobs"]
        # Append next token to all tokens
        next_token_texts = []
        # Number of accepted tokens after the stop token; they are dropped.
        left = 0

        if n_accepted_ids > 1:
            log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")

        current_stopped = False
        index = indexs[idx]
        for j in range(index, index + n_accepted_ids):
            # Generated token
            next_token_id = prev_batches[idx]["next_token_ids"][j]
            all_input_ids.append(next_token_id)
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids,
                prefix_offset,
                read_offset,
            )
            next_token_texts.append(next_token_text)

            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if stop:
                left = index + n_accepted_ids - j - 1
                current_stopped = True
                break
            else:
                current_stopped = False
        stopped = stopped and current_stopped

        _next_token_ids = prev_batches[idx]["next_token_ids"][
            index : index + n_accepted_ids - left
        ]
        _next_token_logprobs = prev_batches[idx]["next_token_logprobs"][
            index : index + n_accepted_ids - left
        ]

        # Shard generations
        # All generations will be appended in the rust sharded client
        if request_id % self.world_size == self.rank:
            if stop:
                # Decode generated tokens
                output_text, _, _ = self.decode_token(
                    all_input_ids,
                    prefix_offset=len(all_input_ids)
                    - stopping_criteria.current_tokens
                    - 1,
                    read_offset=len(all_input_ids)
                    - stopping_criteria.current_tokens,
                    skip_special_tokens=True,
                )
                generated_text = GeneratedText(
                    output_text,
                    stopping_criteria.current_tokens,
                    reason,
                    seed if do_sample else None,
                )
            else:
                generated_text = None

            if top_n_tokens > 0:
                all_top_tokens = []
                for top_token_ids, top_token_logprobs in zip(
                    top_token_ids, top_token_logprobs
                ):
                    toptoken_texts = self.tokenizer.batch_decode(
                        top_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    special_toptokens = [
                        token_id in self.all_special_ids
                        for token_id in top_token_ids
                    ]
                    top_tokens = Tokens(
                        top_token_ids,
                        top_token_logprobs,
                        toptoken_texts,
                        special_toptokens,
                    )
                    all_top_tokens.append(top_tokens)
                top_tokens = all_top_tokens
            else:
                top_tokens = None

            generation = Generation(
                request_id,
                None,
                Tokens(
                    _next_token_ids,
                    _next_token_logprobs,
                    next_token_texts,
                    [nid in self.all_special_ids for nid in _next_token_ids],
                ),
                generated_text,
                top_tokens,
            )

            generations.append(generation)

        # accept each new token for this specific request since we may
        # have more than one new token per request with speculative decoding
        for next_token_id in _next_token_ids:
            batch.next_token_chooser = (
                batch.next_token_chooser.advance_grammar_single(
                    i, next_token_id
                )
            )

        # Update values
        indexs[idx] += n_accepted_ids
        idx_accept_ids[idx] += 1

        batch.prefix_offsets[i] = prefix_offset
        batch.read_offsets[i] = read_offset
        batch.all_input_ids[i] = all_input_ids
    htorch.core.mark_step()
    if stopped:
        # No need to return a batch if we know that all requests stopped
        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, None, (forward_ns, decode_ns)

    forward_ns = start_decode - start
    decode_ns = time.time_ns() - start_decode
    return generations, batch, (forward_ns, decode_ns)
================================================
FILE: backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py
================================================
import torch
from PIL import Image
from io import BytesIO
from dataclasses import dataclass
from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
generate_block_metadata,
)
from text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE
from loguru import logger
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import (
Seqlen,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
HPUPagedAttentionMetadata,
trim_attn_metadata,
)
import habana_frameworks.torch as htorch
import time
from text_generation_server.utils.import_utils import (
synchronize,
)
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
# OpenTelemetry tracer for this module (used for span decorators).
tracer = trace.get_tracer(__name__)
# Special tokens used to build the image placeholder text for the Idefics2 and
# Idefics3 processors. They must be non-empty: empty strings turn the
# placeholder text and image_text_replacement_fixup's replace() into no-ops.
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
IDEFICS2_IMAGE_TOKEN = "<image>"
IDEFICS3_IMAGE_TOKEN = "<image>"
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
def prompt_split_image_llama4(aspect_ratio, num_patches_per_chunk):
    """
    Create a structured string representation of Llama 4 image tokens.

    Args:
        aspect_ratio: ``(ratio_h, ratio_w)``, the number of tile rows and tile
            columns the image was split into.
        num_patches_per_chunk: number of ``<|patch|>`` tokens emitted for each
            tile and for the global image.

    Returns:
        The placeholder string. It consists of:

        - ``<|image_start|>``;
        - only when ``ratio_h * ratio_w > 1``, ``ratio_h`` rows of tiles, with
          tiles separated by ``<|tile_x_separator|>`` and every row terminated
          by ``<|tile_y_separator|>``;
        - ``<|image|>`` followed by the global image's patches;
        - ``<|image_end|>``.
    """
    patches = "<|patch|>" * num_patches_per_chunk
    parts = ["<|image_start|>"]
    ratio_h, ratio_w = aspect_ratio
    if ratio_h * ratio_w > 1:
        tile_row = "<|tile_x_separator|>".join([patches] * ratio_w)
        parts.extend([tile_row + "<|tile_y_separator|>"] * ratio_h)
    parts.extend(["<|image|>", patches, "<|image_end|>"])
    return "".join(parts)
# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
def _prompt_split_image(
    *,
    image_seq_len: int,
    image_rows: int,
    image_cols: int,
    fake_token_around_image: str,
    image_token: str,
    global_img_token: str,
):
    """Prompt with expanded image tokens for when the image is split into patches.

    Each of the ``image_rows x image_cols`` tiles is rendered as the fake token,
    then a ``<row_R_col_C>`` tag (R and C are 1-based), then ``image_seq_len``
    image tokens. A newline follows every row of tiles. The global image comes
    last: a newline, the fake token, the global image token, ``image_seq_len``
    image tokens and a closing fake token.
    """
    tile_tokens = image_token * image_seq_len
    rows = []
    for n_h in range(image_rows):
        row = "".join(
            f"{fake_token_around_image}<row_{n_h + 1}_col_{n_w + 1}>{tile_tokens}"
            for n_w in range(image_cols)
        )
        rows.append(row + "\n")
    global_image = (
        f"\n{fake_token_around_image}"
        f"{global_img_token}"
        f"{tile_tokens}"
        f"{fake_token_around_image}"
    )
    return "".join(rows) + global_image
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple`):
            The size of the input image in the format (height, width).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format
        ``(num_patches_height, num_patches_width)``. These are the height and
        width of the best-fitting resolution chosen by ``select_best_resolution``,
        each divided by ``patch_size``. (Upstream docstrings say "(width,
        height)", but the values come out height-first; callers that rely on
        the swapped order must unpack accordingly.)

    Raises:
        ValueError: if ``grid_pinpoints`` is not a list.
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")

    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size
def image_text_replacement(processor, image_input, config) -> Tuple[str, str]:
    """Build the placeholder text that stands in for one image in the prompt.

    Args:
        processor: the model processor. Only
            ``processor.image_processor.do_image_splitting`` is read, for idefics2.
        image_input: processed inputs for a single image. The keys read depend
            on the model type: ``rows``/``cols`` (idefics3), ``image_sizes``
            (llava_next), ``image_grid_thw`` (qwen2_vl / qwen2_5_vl), and
            ``aspect_ratios``/``pixel_values`` (llama4).
        config: the model config; ``config.model_type`` selects the format.

    Returns:
        ``(replacement_text, marker_token)``. ``marker_token`` is a token
        contained in ``replacement_text``; presumably the caller uses it to
        locate the image in the prompt — TODO confirm against caller.

    Raises:
        RuntimeError: for an unsupported ``config.model_type``.
    """
    model_type = config.model_type
    if model_type == "idefics2":
        image_seq_len = 64
        image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
        if processor.image_processor.do_image_splitting:
            image_str *= 5
        return image_str, IDEFICS2_FAKE_TOKEN
    elif model_type == "idefics3":
        # TODO: implement this in a more general way
        n_rows = image_input["rows"][0][0]
        n_cols = image_input["cols"][0][0]
        image_seq_len = int(
            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
            / (config.scale_factor**2)
        )
        image_str = _prompt_split_image(
            image_seq_len=image_seq_len,
            image_rows=n_rows,
            image_cols=n_cols,
            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
            image_token=IDEFICS3_IMAGE_TOKEN,
            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
        )
        return image_str, IDEFICS3_FAKE_IMAGE_TOKEN
    elif model_type == "llava_next":
        height, width = image_input["image_sizes"][0]
        num_features = get_number_of_features(height, width, config)

        log_master(
            logger.info,
            f"Found {num_features} features in image of resolution {height}x{width}",
        )
        return "<image>" * num_features, "<image>"
    elif model_type == "paligemma":
        return "<image>" * config.text_config.num_image_tokens, "<image>"
    elif model_type in ("qwen2_vl", "qwen2_5_vl"):
        # Both Qwen VL variants use the same placeholder layout.
        grid_t, grid_h, grid_w = image_input["image_grid_thw"][0]
        num_pads = grid_t * grid_h * grid_w // 4
        padding = "<|image_pad|>" * num_pads
        return f"<|vision_start|>{padding}<|vision_end|>", "<|vision_start|>"
    elif model_type == "gemma3":
        # TODO: get correct number of features via reviewing the Gemma3 architecture
        # and calculating the number of image tokens
        num_pads = 256
        padding = "<image_soft_token>" * num_pads
        return f"\n\n<start_of_image>{padding}<end_of_image>\n\n", "<start_of_image>"
    elif model_type == "llama4":
        patch_size = config.vision_config.patch_size
        pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio
        downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
        aspect_ratios = image_input["aspect_ratios"][0]
        image_height, image_width = image_input["pixel_values"][0].shape[-2:]

        num_patches_per_chunk = int(
            (image_height // patch_size)
            * (image_width // patch_size)
            // downsample_ratio
        )
        tokens_for_this_image = prompt_split_image_llama4(
            aspect_ratios, num_patches_per_chunk
        )

        return tokens_for_this_image, "<|image_start|>"
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def image_text_replacement_fixup(config, text: str) -> str:
    """Post-process a prompt after image placeholders were inserted.

    Only idefics2 needs fixing. Its per-image placeholder starts and ends with
    ``IDEFICS2_FAKE_TOKEN``, so consecutive images produce a doubled fake
    token, which is collapsed into a single one. Other model types are
    returned unchanged.
    """
    if config.model_type != "idefics2":
        return text
    return text.replace(IDEFICS2_FAKE_TOKEN * 2, IDEFICS2_FAKE_TOKEN)
def preprocess_text(config, text: str) -> str:
    """Apply model-specific text formatting before tokenization.

    PaliGemma prompts are prefixed with ``<bos>`` and terminated by a newline.
    An empty prefix would silently drop the BOS token. All other model types
    are returned unchanged.
    """
    if config.model_type == "paligemma":
        return "<bos>" + text + "\n"
    return text
def preprocess_image(config, img):
    """Apply model-specific image preprocessing.

    - qwen2_vl / qwen2_5_vl: images at most 20 pixels wide are upscaled 2x in
      both dimensions.
    - paligemma: the image is converted to RGB.
    - llava_next, gemma3 and llama4 receive the bare image. Every other model
      type receives a one-element list.
    """
    model_type = config.model_type
    if model_type in {"qwen2_vl", "qwen2_5_vl"} and img.width <= 20:
        img = img.resize((img.width * 2, img.height * 2))
    if model_type == "paligemma":
        img = img.convert("RGB")
    # TODO: check if wrapping in a list is needed for the remaining models
    return img if model_type in {"llava_next", "gemma3", "llama4"} else [img]
def get_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    """Count the feature tokens left after removing letterbox/pillarbox padding.

    The feature grid is ``npatches * num_patch_height`` by
    ``npatches * num_patch_width``. The image's aspect ratio is compared with
    the grid's:

    - if the image is relatively wider, the rows that would be padding
      (top/bottom) are removed;
    - otherwise the padding columns (left/right) are removed.

    Returns:
        ``(unpadded_features, newline_features)``, where
        ``newline_features`` equals the remaining grid height.
    """
    grid_h = npatches * num_patch_height
    grid_w = npatches * num_patch_width

    if original_width / original_height > grid_w / grid_h:
        scaled_h = (original_height * grid_w) // original_width
        grid_h -= 2 * ((grid_h - scaled_h) // 2)
    else:
        scaled_w = (original_width * grid_h) // original_height
        grid_w -= 2 * ((grid_w - scaled_w) // 2)

    return (grid_h * grid_w, grid_h)
def get_number_of_features(height: int, width: int, config) -> int:
    """Total number of image feature tokens for an any-res (LLaVA-NeXT) image.

    The total is the unpadded grid features, plus one newline token per grid
    row, plus the base patch (``npatches ** 2``) that covers the whole image.
    Grid pinpoints, image size and patch size are read from ``config``.
    """
    pinpoints = config.image_grid_pinpoints
    vision_cfg = config.vision_config
    image_size = vision_cfg.image_size
    patch_size = vision_cfg.patch_size
    assert image_size % patch_size == 0
    npatches = image_size // patch_size
    # Dimensions are intentionally swapped to be bug-compatible with
    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
    grid_width, grid_height = get_anyres_image_grid_shape(
        [height, width], pinpoints, image_size
    )
    unpadded, newlines = get_unpadded_features(
        height, width, npatches, grid_height, grid_width
    )
    base = npatches**2
    return unpadded + newlines + base
def scatter_image_embeds(
    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> torch.Tensor:
    """Spread ``embeds`` back onto the full token span of an image.

    Rows of the result where ``is_embed`` is True receive ``embeds`` in
    order. Every other row is filled with NaN. When ``is_embed`` is None,
    ``embeds`` is returned unchanged.
    """
    if is_embed is None:
        return embeds
    full_shape = (is_embed.shape[0], embeds.shape[-1])
    scattered = torch.full(
        full_shape, torch.nan, dtype=embeds.dtype, device=embeds.device
    )
    scattered[is_embed.to(embeds.device)] = embeds
    return scattered
def gather_image_embeds(
    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]
) -> Optional[torch.Tensor]:
    """Keep only the rows of ``embeds`` flagged by ``is_embed``.

    Returns ``embeds`` untouched when ``is_embed`` is None, and None when
    the selection is empty.
    """
    if is_embed is None:
        return embeds
    mask = is_embed.to(embeds.device)
    selected = embeds[mask]
    if selected.numel() == 0:
        return None
    return selected
@dataclass
class ImagePositions:
    """Where one image's expanded tokens sit inside a request's input_ids."""

    # Index in input_ids where the image's token run starts (the matched
    # image-start token). NOTE(review): producers may pass a 0-d tensor
    # here rather than a Python int; confirm at call sites.
    offset: int
    # Number of tokens produced by tokenizing the image's replacement text.
    length: int
    # Image id within the request; key into image_inputs / encoder_cache.
    id: int
    # How many of those `length` tokens equal config.image_token_index.
    num_placeholder_tokens: int
    # Boolean mask over the `length` tokens marking placeholder tokens;
    # None when every token in the run is a placeholder.
    is_embed: Optional[torch.Tensor] = None
class FlashVlmCausalLMBatch(FlashCausalLMBatch):
    """Flash causal-LM batch that also carries per-request image state.

    Per request (in batch order) it keeps:
    - the image-processor output for every image;
    - where each image's tokens sit in the prompt (``ImagePositions``);
    - a cache of vision-encoder outputs keyed by image id.

    Because ``cache_length`` / ``input_length`` are checked per step, an
    image's embeddings can be consumed across several prefill chunks.
    """

    # Per request: processor output per image id. An entry is set to None
    # once it has been scheduled for encoding.
    image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]]
    # Per request: token locations per image id. An entry is set to None once
    # all of its embeddings have been gathered.
    image_positions: Optional[List[List[ImagePositions]]]
    # Per request: image id -> scattered vision-encoder output.
    encoder_cache: Optional[List[Dict[int, torch.Tensor]]]
    # NOTE(review): during prefill this holds
    # (request index, image id, processor output) tuples, not tensors.
    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]
    image_grid_thw: Optional[torch.Tensor]
    # (request index, image id) pairs evicted by free_encoder_cache().
    cache_entries_to_free: List[Tuple[int, int]]
    has_image_inputs: bool = False
    inputs_embeds: Optional[torch.Tensor] = None

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches, padded_total_bs: int = 0):
        """Merge ``batches`` and concatenate their per-request image state.

        A sub-batch whose list is None contributes a single None entry.
        Per-step tensors are reset; ``prepare_for_prefill`` rebuilds them.
        """
        batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
        batch.image_inputs = []
        batch.image_positions = []
        batch.encoder_cache = []
        for b in batches:
            if b.image_inputs is not None:
                batch.image_inputs.extend(b.image_inputs)
            else:
                batch.image_inputs.append(None)
            if b.image_positions is not None:
                batch.image_positions.extend(b.image_positions)
            else:
                batch.image_positions.append(None)
            if b.encoder_cache is not None:
                batch.encoder_cache.extend(b.encoder_cache)
            else:
                batch.encoder_cache.append(None)

        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        batch.inputs_embeds = None
        # To be filled in prepare_for_prefill
        batch.has_image_inputs = False
        batch.cache_entries_to_free = []
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        """Keep only ``request_ids``, carrying their image state along.

        Raises:
            ValueError: if ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")

        image_inputs = []
        image_positions = []
        encoder_cache = []
        for request_id in request_ids:
            idx = self.requests_idx_mapping[request_id]
            image_inputs.append(self.image_inputs[idx])
            image_positions.append(self.image_positions[idx])
            encoder_cache.append(self.encoder_cache[idx])

        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        batch.inputs_embeds = None
        batch.image_inputs = image_inputs
        batch.image_positions = image_positions
        batch.encoder_cache = encoder_cache

        # To be filled in prepare_for_prefill
        batch.has_image_inputs = False
        batch.cache_entries_to_free = []
        return batch

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
    ):
        """Tokenize each request, expanding image chunks into image tokens.

        Returns:
            ``(batch_tokenized_inputs, batch_image_inputs, batch_image_positions)``,
            one entry per request. The image entries are None for text-only
            requests.
        """
        kwargs = {}
        if (
            hasattr(processor, "image_processor_class")
            and processor.image_processor_class == "Idefics3ImageProcessor"
        ):
            kwargs["return_row_col_info"] = True

        max_length = 0
        vocab = tokenizer.get_vocab()

        # Some configs only expose image_token_id; normalise the attribute name.
        if not hasattr(config, "image_token_index"):
            config.image_token_index = config.image_token_id

        batch_tokenized_inputs: List[List[int]] = []
        batch_image_inputs: List[Optional[List[dict]]] = []
        batch_image_positions: List[Optional[List[ImagePositions]]] = []

        for r in requests:
            text_parts = []
            image_inputs = []
            # Entries are [image_id, image_start_token_str, image_text].
            image_texts = []

            image_id = 0

            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    text = preprocess_text(config, chunk.text)
                    text_parts.append(text)
                elif chunk_type == "image":
                    img = Image.open(BytesIO(chunk.image.data))
                    img = preprocess_image(config, img)

                    image_input = processor.image_processor(
                        [img], return_tensors="pt", **kwargs
                    )
                    image_inputs.append(image_input)

                    img_text, img_start_token_str = image_text_replacement(
                        processor, image_input, config
                    )
                    text_parts.append(img_text)

                    image_texts.append([image_id, img_start_token_str, img_text])
                    image_id += 1
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

            full_text = image_text_replacement_fixup(config, "".join(text_parts))
            input_ids = tokenizer(
                full_text,
                truncation=True,
                max_length=r.truncate,
                add_special_tokens=(
                    r.add_special_tokens if config.model_type != "paligemma" else False
                ),
            )["input_ids"]
            max_length = max(max_length, len(input_ids))

            if len(image_inputs) > 0:
                img_start_token = vocab[image_texts[0][1]]
                image_positions = cls.get_image_positions(
                    input_ids, image_texts, img_start_token, config, tokenizer
                )
            else:
                image_inputs = None
                image_positions = None

            batch_tokenized_inputs.append(input_ids)
            batch_image_inputs.append(image_inputs)
            batch_image_positions.append(image_positions)

        return batch_tokenized_inputs, batch_image_inputs, batch_image_positions

    @classmethod
    def get_image_positions(
        cls,
        input_ids: List[int],
        image_texts: List[Tuple[int, str, str]],
        img_start_token: int,
        config,
        tokenizer: PreTrainedTokenizerBase,
    ) -> List[ImagePositions]:
        """Locate each image's token run inside ``input_ids``.

        Each image text is re-tokenized. It is matched at the next
        ``img_start_token`` occurrence at or after the end of the previous
        image.

        Raises:
            AssertionError: if an image was truncated or its tokens are not
                found where expected.
        """
        image_positions = []
        num_images = len(image_texts)

        input_ids_t = torch.as_tensor(input_ids)
        img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]
        num_tokens = input_ids_t.numel()

        last_pos = 0
        for i in range(num_images):
            image_id, img_start_token_str, img_text = image_texts[i]
            img_text = image_text_replacement_fixup(config, img_text)

            if config.model_type == "gemma3":
                img_text = img_text.replace("\n\n", "")

            tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[
                "input_ids"
            ][0]
            length = tokens.numel()

            assert (
                length <= num_tokens
            ), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens"

            pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)
            # Without this check an out-of-range `pos` surfaces as an opaque
            # IndexError on the next line.
            assert (
                pos < img_start_token_pos.numel()
            ), "Image tokens not found in input_ids"
            # Store a plain int (as ImagePositions.offset declares) rather
            # than a 0-d tensor.
            index = int(img_start_token_pos[pos])
            assert torch.equal(
                input_ids_t[index : index + length], tokens
            ), "Image tokens not found in input_ids"

            is_embed = tokens == config.image_token_index
            num_placeholder_tokens = int(is_embed.sum())
            # A mask that selects everything carries no information.
            if num_placeholder_tokens == length:
                is_embed = None

            pos = ImagePositions(
                offset=index,
                length=length,
                id=image_id,
                num_placeholder_tokens=num_placeholder_tokens,
                is_embed=is_embed,
            )

            image_positions.append(pos)
            last_pos = index + length

            # idefics2: the fake token ending this image also starts the next
            # one. Point the next search at last_pos and strip the duplicated
            # start token from the next image's text.
            if (
                config.model_type == "idefics2"
                and i + 1 != num_images
                and input_ids[last_pos] == config.image_token_index
            ):
                fake_token = last_pos - 1
                fake_token_index = torch.searchsorted(
                    img_start_token_pos, fake_token, right=False
                )
                img_start_token_pos[fake_token_index] = last_pos
                image_texts[i + 1][2] = image_texts[i + 1][2][
                    len(img_start_token_str) :
                ]

        return image_positions

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashVlmCausalLMBatch":
        """Build a batch from a protobuf batch, tokenizing text and images."""
        batch_tokenized_inputs, image_inputs, image_positions = (
            cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        batch.image_inputs = image_inputs
        batch.image_positions = image_positions
        batch.encoder_cache = [{} for _ in range(len(pb.requests))]
        # `image_inputs` always has one entry per request, so the previous
        # `if len(image_inputs):` guard was always true; reset unconditionally.
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        return batch

    def prepare_for_prefill(
        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
    ):
        """Schedule for encoding every image that overlaps this prefill step.

        Fills ``self.pixel_values`` with ``(request index, image id,
        processor output)`` tuples for images not yet in the encoder cache.
        Sets ``has_image_inputs`` when any image overlaps the step.
        """
        super().prepare_for_prefill(
            max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
        )

        self.has_image_inputs = False
        self.cache_entries_to_free = []

        self.pixel_values = []

        assert (
            len(self.cache_lengths)
            == len(self.input_lengths)
            == len(self.prefilling_mask)
        ), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"

        for i, (
            cache_length,
            input_length,
            request_prefilling,
        ) in enumerate(
            zip(
                self.cache_lengths,
                self.input_lengths,
                self.prefilling_mask,
            )
        ):
            if not request_prefilling or self.image_positions[i] is None:
                continue

            for image_position in self.image_positions[i]:
                if image_position is None:
                    continue
                start_pos = image_position.offset
                length = image_position.length

                if start_pos >= cache_length + input_length:
                    # No encoder input required at this step
                    break
                if start_pos + length <= cache_length:
                    # The encode input is already processed
                    continue

                self.has_image_inputs = True

                if image_position.id not in self.encoder_cache[i]:
                    image_inputs = self.image_inputs[i][image_position.id]
                    self.pixel_values.append((i, image_position.id, image_inputs))

                    # Remove the image from the image_inputs
                    self.image_inputs[i][image_position.id] = None

        if not self.has_image_inputs:
            self.pixel_values = None
            self.pixel_attention_mask = None
            self.image_sizes = None
            self.image_grid_thw = None
        else:
            image_grid_thw_list = [
                x[2]["image_grid_thw"]
                for x in self.pixel_values
                if "image_grid_thw" in x[2]
            ]
            if image_grid_thw_list:
                self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
            else:
                self.image_grid_thw = None

    def update_encoder_cache(self, encoder_outputs, request_id, img_pos):
        """Cache ``encoder_outputs`` for image ``img_pos.id`` of request ``request_id``."""
        self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(
            encoder_outputs, img_pos.is_embed
        )

    def gather_vision_embeds(self):
        """Collect the cached vision embeddings needed by the current step.

        For every prefilling request, take the slice of each overlapping
        image's cached encoder output. Images consumed to their end are
        queued for eviction, and their position entry is set to None.

        Returns:
            The concatenated embeddings on ``input_ids``' device, or None if
            nothing is needed.
        """
        device = self.input_ids.device
        chunks = []
        for (
            i,
            cache_length,
            input_length,
            request_prefilling,
        ) in zip(
            range(len(self.requests)),
            self.cache_lengths,
            self.input_lengths,
            self.prefilling_mask,
        ):
            if not request_prefilling or self.image_positions[i] is None:
                continue

            for image_position in self.image_positions[i]:
                if image_position is None:
                    continue
                start_pos = image_position.offset
                length = image_position.length

                if start_pos >= cache_length + input_length:
                    # No encoder input required at this step
                    break
                if start_pos + length <= cache_length:
                    # The encode input is already processed
                    continue

                # Part of this image's token run covered by the current chunk.
                start_idx = max(cache_length - start_pos, 0)
                end_idx = min(cache_length - start_pos + input_length, length)

                assert (
                    image_position.id in self.encoder_cache[i]
                ), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}"
                encoder_output = self.encoder_cache[i][image_position.id]

                is_embed = image_position.is_embed
                if is_embed is not None:
                    is_embed = is_embed[start_idx:end_idx]

                # Per-image, per-step trace. Kept at DEBUG (it ran at INFO via
                # a loop-local import) to avoid flooding production logs.
                logger.debug(
                    f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}"
                )

                embeds = gather_image_embeds(
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                if embeds is not None:
                    chunks.append(embeds)

                if end_idx == length:
                    self.cache_entries_to_free.append((i, image_position.id))
                    self.image_positions[i][image_position.id] = None

        if len(chunks) == 0:
            return None
        return torch.cat(chunks, dim=0).to(device)

    def free_encoder_cache(self):
        """Evict the encoder-cache entries queued by ``gather_vision_embeds``."""
        for i, image_id in self.cache_entries_to_free:
            self.encoder_cache[i].pop(image_id, None)

        self.cache_entries_to_free = []
class FlashVlmCausalLM(FlashCausalLM):
    """Flash causal LM for vision-language models on HPU.

    Adds processor loading, vision-encoder invocation and input-embedding
    construction on top of ``FlashCausalLM``. The underlying model is always
    called with ``inputs_embeds``.
    """

    def __init__(
        self,
        model_id: str,
        *,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        batch_class=FlashVlmCausalLMBatch,
        revision,
        trust_remote_code: bool,
        support_chunking: bool = False,
        **kwargs,
    ):
        """Load the processor, then defer to ``FlashCausalLM.__init__``.

        Raises:
            NotImplementedError: if prefix caching is enabled.
        """
        if PREFIX_CACHING:
            raise NotImplementedError("Vlm do not work with prefix caching yet")
        if processor_kwargs is None:
            processor_kwargs = {}
        self.processor = processor_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **processor_kwargs,
        )
        self.batch_class = batch_class
        super().__init__(
            model_id=model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            support_chunking=support_chunking,
            **kwargs,
        )

    @property
    def batch_type(self) -> Type[FlashVlmCausalLMBatch]:
        return self.batch_class

    def max_past(self) -> Optional[int]:
        """The text model's ``max_past`` attribute, or None when it has none.

        ``forward`` uses it to cap ``max_s`` during decode.
        """
        return getattr(self.model.text_model, "max_past", None)

    def warmup_decode(
        self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch
    ):
        """Run one dummy decode step of ``batch_size`` requests over ``block_num`` blocks."""
        input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
        position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
        if batch.position_ids is not None and batch.position_ids.dim() == 2:
            # qwen2_vl and qwen2_5_vl case
            position_ids = position_ids.unsqueeze(-1).repeat(
                (1, batch.position_ids.shape[-1])
            )
        # Spread the blocks evenly; the first request takes the remainder.
        blocks = [block_num // batch_size for _ in range(batch_size)]
        blocks[0] += block_num % batch_size
        block_tables = []
        slots = []
        start_idx = 0
        slot_indices = []

        # fetch the last blocked to warmup block num
        for i in range(batch_size):
            block_array = list(range(start_idx, start_idx + blocks[i]))
            slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
            block_tables.append(block_array)
            slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)
            start_idx += blocks[i]
        input_lengths = torch.ones(batch_size, dtype=torch.int32)

        seqlen = Seqlen(
            input_lengths=_async_h2d_tensor_copy(input_lengths),
        )
        block_list, block_groups, block_usage, _, block_bucket_size = (
            generate_block_metadata(
                self.dtype,
                self.use_contiguous_pa,
                slots,
                block_tables,
                self.bucketing_ctx,
            )
        )
        meta = HPUPagedAttentionMetadata(
            block_list=_async_h2d_tensor_copy(block_list),
            block_groups=_async_h2d_tensor_copy(block_groups),
            block_usage=_async_h2d_tensor_copy(block_usage),
            block_mapping=None,
            attn_bias=None,
        )
        if self.sliding_window is not None:
            block_tables_in_window = []
            for i, bt in enumerate(block_tables):
                block_num_in_window = (
                    self.sliding_window + BLOCK_SIZE - 1
                ) // BLOCK_SIZE
                block_tables_in_window.append(
                    bt[max(0, blocks[i] - block_num_in_window) : blocks[i]]
                )
            slots_in_window = []
            start_idx = 0
            for i, indice in enumerate(slot_indices):
                mask = (
                    indice - torch.arange(start_idx, indice + 1)
                ) < self.sliding_window
                slots_in_window.append(torch.arange(start_idx, indice + 1)[mask])
                start_idx += blocks[i] * BLOCK_SIZE
            slots_in_window = torch.cat(slots_in_window, dim=0)
            (
                block_list_in_window,
                block_groups_in_window,
                block_usage_in_window,
                slots_in_window_mask,
                _,
            ) = generate_block_metadata(
                self.dtype,
                self.use_contiguous_pa,
                slots,
                block_tables_in_window,
                self.bucketing_ctx,
                slots_in_window,
                block_bucket_size,
            )
            meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)
            meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)
            meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)
            meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)

        hpu_attention_meta = trim_attn_metadata(meta)
        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
        inputs_embeds = self.get_inputs_embeds(
            input_ids=input_ids.to(self.device),
        )
        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            inputs_embeds=inputs_embeds,
            position_ids=_async_h2d_tensor_copy(position_ids),
            cu_seqlen_prefill=None,
            kv_cache=self.kv_cache,
            slots=_async_h2d_tensor_copy(slots_tensor),
            seqlen=trim_seqlen_metadata(seqlen),
            hpu_attention_meta=hpu_attention_meta,
            lm_head_indices=None,
            attention_mask=None,
        )

    def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):
        """Warm up HPU graphs for decode buckets within the free-memory budget.

        Prefill is not warmed up because image sizes change the shapes.
        """
        free_mem = HabanaMemoryProfiler.current_free_device_memory()
        graph_free_mem = free_mem - self.mem_reserved
        graph_free_mem = self.align_workers(
            graph_free_mem, torch.distributed.ReduceOp.MIN
        )
        decode_available_memory = graph_free_mem
        msg = (
            f"Using {format_bytes(graph_free_mem)}"
            f"/{format_bytes(free_mem)} "
            "of free device memory for HPUGraphs, "
            f"{format_bytes(decode_available_memory)} for decode "
        )

        log_master(logger.info, msg)
        start_time = time.time()
        warmup_shape_count = 0
        warmup_times = 3

        # only warmup decode, for prefill, image pixal size may change, make the warmup useless
        def ordering_function_max_bs(b):
            return (-b[0], b[1])

        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
        buckets = list(
            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
        )
        # Start at a tiny non-zero value so the first estimate avoids 0/0.
        total_batch_seq = 0.001
        total_mem = 0
        available_mem = decode_available_memory
        log_master(
            logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
        )
        for i, (batch_size, block_num) in enumerate(buckets):
            if batch_size > block_num:
                continue
            # Graph memory usage is proportional to seq dimension in a batch
            batch_seq = batch_size
            mem_estimate = batch_seq / total_batch_seq * total_mem
            graphed_bucket = (batch_size, block_num, False)
            if not mem_estimate >= available_mem:
                if graphed_bucket not in self.graphed_buckets:
                    self.graphed_buckets.add(graphed_bucket)
            warmup_shape_count += 1
            self.log_warmup(False, i, len(buckets), batch_size, block_num)
            with HabanaMemoryProfiler() as mem_prof:
                for _ in range(warmup_times):
                    self.warmup_decode(batch_size, block_num, batch)
                    synchronize(self.device)
            used_mem = self.align_workers(
                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
            )
            if graphed_bucket in self.graphed_buckets:
                available_mem -= used_mem
                total_mem += used_mem
                total_batch_seq += batch_seq

        log_master(logger.info, "Decode warmup successful.\n")

        log_master(
            logger.info,
            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
        )

    def get_vision_embeds(
        self,
        pixel_values: torch.Tensor,
        pixel_attention_mask: torch.Tensor,
        image_sizes: torch.Tensor,
        image_grid_thw: torch.Tensor,
    ):
        """Run the model's vision tower on one image's processor outputs."""
        embeds = self.model.get_vision_embeds(
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            image_sizes=image_sizes,
            image_grid_thw=image_grid_thw,
        )
        return embeds

    def get_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        vision_embeds: Optional[torch.Tensor] = None,
    ):
        """Embed ``input_ids``, merging in ``vision_embeds`` when given."""
        return self.model.get_inputs_embeds(
            input_ids=input_ids,
            vision_embeds=vision_embeds,
        )

    def encode_images(self, batch):
        """Encode every image scheduled in ``batch.pixel_values`` into the encoder cache."""
        if batch.pixel_values is not None:
            device = batch.input_ids.device
            for request_id, image_id, image_input in batch.pixel_values:
                pixel_values = image_input["pixel_values"].to(device)

                if "pixel_attention_mask" in image_input:
                    pixel_attention_mask = image_input["pixel_attention_mask"].to(
                        device
                    )
                else:
                    pixel_attention_mask = None

                if "image_sizes" in image_input:
                    image_sizes = image_input["image_sizes"].to(device)
                else:
                    image_sizes = None

                if "image_grid_thw" in image_input:
                    image_grid_thw = image_input["image_grid_thw"]
                else:
                    image_grid_thw = None

                encoder_outputs = self.get_vision_embeds(
                    pixel_values=pixel_values,
                    pixel_attention_mask=pixel_attention_mask,
                    image_sizes=image_sizes,
                    image_grid_thw=image_grid_thw,
                )
                batch.update_encoder_cache(
                    encoder_outputs,
                    request_id,
                    batch.image_positions[request_id][image_id],
                )

        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None

    def set_inputs_embeds(self, batch):
        """Compute ``batch.inputs_embeds``, encoding and merging images if needed."""
        if batch.has_image_inputs:
            self.encode_images(batch)
            vision_embeds = batch.gather_vision_embeds()
            batch.has_image_inputs = False
        else:
            vision_embeds = None

        inputs_embeds = self.get_inputs_embeds(
            batch.input_ids, vision_embeds=vision_embeds
        )

        batch.inputs_embeds = inputs_embeds

    def forward(
        self,
        batch: FlashVlmCausalLMBatch,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one model step on ``batch``.

        Returns:
            ``(logits, speculative_logits)``.

        The non-speculative path expects ``batch.inputs_embeds`` to have been
        set by ``set_inputs_embeds``. Consumed encoder-cache entries are freed
        after the model call.
        """
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Add Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
            # Bug fix: this branch never bound `inputs_embeds`, so the model
            # call below raised NameError. batch.inputs_embeds only covers the
            # non-expanded ids, so embed the expanded (token + speculative)
            # ids instead. NOTE(review): this assumes speculation only runs
            # during decode, where no vision embeddings are needed; confirm.
            inputs_embeds = self.get_inputs_embeds(input_ids)
        else:
            input_ids = batch.input_ids
            inputs_embeds = batch.inputs_embeds
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = self.kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_current_length
            lm_head_indices = batch.prefill_head_indices

        if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
            if position_ids.dim() == 1 and batch.prefilling:
                position_ids = self.model.get_position_ids(
                    input_ids.cpu(), batch.image_grid_thw
                )
                batch.position_ids = position_ids

        attention_mask = None
        attention_mask_forward = None
        if self.model.config.model_type == "llama4":
            attention_mask = (input_ids != self.tokenizer.pad_token_id).long()
            attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1)

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)
        # Pad `slots` to one entry per input id.
        if batch.prefill_cache_indices is not None:
            slots_pad = torch.zeros_like(input_ids, device=slots.device)
            slots_pad[batch.prefill_cache_indices] = slots
            slots = slots_pad
        else:
            slots_pad = torch.zeros_like(input_ids, device=slots.device)
            slots_pad[: slots.shape[0]] = slots
            slots = slots_pad

        seqlen = Seqlen(
            input_lengths=_async_h2d_tensor_copy(input_lengths),
        )
        kwargs = {}
        batch_size = input_lengths.shape[0]
        prompt_len = (
            input_ids.shape[0] // batch_size
            if batch.prefilling
            else batch.hpu_attn_meta.block_list.shape[0]
        )
        if htorch.utils.internal.is_lazy():
            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
                batch.prefilling, prompt_len, batch_size
            )
        if self.sliding_window is not None:
            attn_mask = seqlen.make_sliding_window_bias(
                input_lengths.tolist(),
                self.sliding_window,
                self.dtype,
                prompt_len,
                batch_size,
            )
            seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)
        logits, speculative_logits = self.model.forward(
            inputs_embeds=inputs_embeds,
            position_ids=_async_h2d_tensor_copy(position_ids),
            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
            kv_cache=kv_cache,
            slots=_async_h2d_tensor_copy(slots),
            seqlen=trim_seqlen_metadata(seqlen),
            hpu_attention_meta=batch.hpu_attn_meta,
            lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
            attention_mask=attention_mask_forward,
            **kwargs,
        )
        batch.image_grid_thw = None
        batch.free_encoder_cache()
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/globals.py
================================================
import os
from typing import Dict, Optional
from loguru import logger
from text_generation_server.utils.log import log_master
# Boolean flags parsed from env vars: true when the value is "1" or "true"
# (case-insensitive).
REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
ATTENTION = os.getenv("ATTENTION", "paged")
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
    "1",
    "true",
}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")

# Only the "paged" attention backend is accepted here.
_expected = {"paged"}
assert (
    ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")

# Must lie strictly between 0 and 1.
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
assert 0 < TGI_WIGGLE_ROOM < 1

# This is overridden by the cli
BLOCK_SIZE: int = 128

# This is overridden at model loading (see set_model_id). A module-level
# `global` statement was a no-op and has been dropped.
MODEL_ID: Optional[str] = None
def set_model_id(model_id: str) -> None:
    """Record the loaded model's id in the module-level ``MODEL_ID``."""
    global MODEL_ID
    MODEL_ID = model_id
# Module-wide mapping from adapter identifier (str) to adapter index.
# It is None until set_adapter_to_index() is called.
# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
def set_adapter_to_index(adapter_to_index: Dict[str, int]) -> None:
    """Replace the module-level ``ADAPTER_TO_INDEX`` mapping.

    The dict is stored as-is; no copy is made.
    """
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index
def get_adapter_to_index():
    """Return the current module-level ``ADAPTER_TO_INDEX`` mapping (may be None)."""
    # Reading a module global needs no `global` declaration.
    return ADAPTER_TO_INDEX
================================================
FILE: backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
================================================
import torch
import numpy as np
from typing import Iterable, Optional, Tuple, List, Dict
from text_generation_server.pb.generate_pb2 import Request
from io import BytesIO
from PIL import Image
from dataclasses import dataclass, field
from opentelemetry import trace
from transformers import (
PreTrainedTokenizerBase,
)
from text_generation_server.models.flash_causal_lm import (
generate_block_metadata,
)
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
FlashVlmCausalLM,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.layers.attention import (
Seqlen,
trim_seqlen_metadata,
_async_h2d_tensor_copy,
HPUPagedAttentionMetadata,
trim_attn_metadata,
)
import habana_frameworks.torch as htorch
from loguru import logger
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.import_utils import (
synchronize,
)
import torch.nn.functional as F
from text_generation_server.utils.log import log_master
import time
import os
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
# Module-level OpenTelemetry tracer; used by the @tracer.start_as_current_span
# decorators on the batch methods in this module.
tracer = trace.get_tracer(__name__)
@dataclass
class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
    """Batch for Mllama, whose images feed cross-attention states.

    Unlike the parent VLM batch, this batch does not track ``ImagePositions``
    or an encoder cache. It carries the processed pixel tensors and
    ``cross_attention_states`` rows, with ``image_indices[k]`` giving the
    request index that owns image row ``k``.
    """

    # Request index owning each image row. The old default of `42` was a
    # meaningless placeholder (instances always overwrite it); an empty list
    # is the correct neutral value.
    image_indices: List[int] = field(default_factory=list)
    aspect_ratio_ids: Optional[torch.Tensor] = None
    aspect_ratio_mask: Optional[torch.Tensor] = None
    cross_attention_states: Optional[torch.Tensor] = None

    def prepare_for_prefill(
        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
    ):
        """Text-only prefill preparation.

        Deliberately skips ``FlashVlmCausalLMBatch.prepare_for_prefill``, since
        Mllama does not use the image-position machinery.
        """
        super(FlashVlmCausalLMBatch, self).prepare_for_prefill(
            max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches, padded_total_bs: int = 0):
        """Merge ``batches`` and stack their cross-attention states.

        Image indices of later batches are shifted by the number of images in
        the batches before them.
        """
        batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)
        batch.pixel_values = None
        batch.pixel_attention_mask = None

        offset = 0
        image_indices = []
        attention_states = []
        for b in batches:
            if b.cross_attention_states is not None:
                attention_states.append(b.cross_attention_states)
            image_indices.extend([i + offset for i in b.image_indices])
            offset += len(b.image_indices)
        if len(attention_states) > 0:
            assert len(image_indices) > 0
            batch.cross_attention_states = torch.cat(attention_states, dim=0)
            batch.image_indices = image_indices
        else:
            batch.cross_attention_states = None
            batch.image_indices = []
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        """Keep only ``request_ids`` and the cross-attention rows of their images."""
        assert self.image_indices is not None
        batch = super(FlashVlmCausalLMBatch, self).filter(request_ids)

        # Old batch positions of the kept requests; a set gives O(1) membership.
        kept = {self.requests_idx_mapping[request_id] for request_id in request_ids}

        offset = 0
        new_image_indices = []
        prev_i = None
        # `offset` tracks the row of the current image in the old
        # cross_attention_states; it advances whenever the owning request
        # changes. NOTE(review): this assumes images are grouped by request.
        for i in self.image_indices:
            if i in kept:
                new_image_indices.append(offset)
            if i != prev_i:
                offset += 1
            prev_i = i

        batch.image_indices = new_image_indices
        if len(new_image_indices) > 0:
            assert max(new_image_indices) < self.cross_attention_states.shape[0]
            assert offset <= self.cross_attention_states.shape[0]
            batch.cross_attention_states = self.cross_attention_states[
                new_image_indices
            ]
        else:
            batch.cross_attention_states = None

        batch.pixel_values = None
        return batch

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[Request], tokenizer, processor, config
    ):
        """Tokenize requests and collect at most one image per request.

        Returns:
            ``(batch_tokenized_inputs, image_inputs)``. ``image_inputs`` is
            either None or a dict of stacked ``pixel_values`` (plus
            ``aspect_ratio_ids`` / ``aspect_ratio_mask`` when the processor
            provides them) and ``image_indices``.
        """
        image_inputs = []
        texts = []
        image_indices = []
        batch_tokenized_inputs = []

        for i, r in enumerate(requests):
            # Each input is encoded into a list, where each element of this input list is either a string or a URL
            curr_text = ""
            curr_image = None
            curr_i = None
            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    curr_text += chunk.text
                elif chunk_type == "image":
                    image = Image.open(BytesIO(chunk.image.data))
                    # TODO unsure about BOS
                    curr_text += "<|image|>"
                    image_input = processor.image_processor(image, return_tensors="pt")
                    # Only the last image of a request is kept.
                    curr_image = image_input
                    curr_i = i
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")
            texts.append(curr_text)
            if curr_image is not None:
                image_inputs.append(curr_image)
                image_indices.append(curr_i)

            input_ids = tokenizer(
                curr_text,
                truncation=True,
                max_length=r.truncate,
                add_special_tokens=r.add_special_tokens,
            )["input_ids"]
            batch_tokenized_inputs.append(input_ids)
        if image_inputs:
            image_input = image_inputs[0]
            new_image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
            }
            if "aspect_ratio_ids" in image_input:
                new_image_inputs["aspect_ratio_ids"] = torch.cat(
                    [img["aspect_ratio_ids"] for img in image_inputs], dim=0
                )
            if "aspect_ratio_mask" in image_input:
                new_image_inputs["aspect_ratio_mask"] = torch.cat(
                    [img["aspect_ratio_mask"] for img in image_inputs], dim=0
                )
            image_inputs = new_image_inputs
            image_inputs["image_indices"] = image_indices
        else:
            image_inputs = None

        if image_inputs is not None:
            assert len(image_indices) == image_inputs["pixel_values"].shape[0]

        return batch_tokenized_inputs, image_inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashVlmCausalLMBatch":
        """Build a batch from a protobuf batch and move image tensors to ``device``."""
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        # XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
        batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
            max=config.text_config.vocab_size - 1
        )
        if isinstance(batch.input_ids, list):
            if len(batch) > 1:
                input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
            else:
                input_ids = batch.input_ids[0]
            batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)

        batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)

        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(
                device=device, dtype=dtype
            )
            # batch_tokenized_inputs only adds these keys when the processor
            # returns them; don't raise KeyError when they are absent.
            aspect_ratio_ids = image_inputs.get("aspect_ratio_ids")
            aspect_ratio_mask = image_inputs.get("aspect_ratio_mask")
            batch.aspect_ratio_ids = (
                aspect_ratio_ids.to(device=device)
                if aspect_ratio_ids is not None
                else None
            )
            batch.aspect_ratio_mask = (
                aspect_ratio_mask.to(device=device)
                if aspect_ratio_mask is not None
                else None
            )
            batch.image_indices = image_inputs["image_indices"]
        else:
            batch.pixel_values = None
            batch.aspect_ratio_ids = None
            batch.aspect_ratio_mask = None
            batch.image_indices = []
        assert batch.image_indices is not None
        return batch
def generate_cross_attention_states(
    cross_attention_states, image_indices, input_lengths, pad_seq_len, prefilling
):
    """Compute cross-attention indices and the per-image input lengths.

    Returns ``(None, None)`` when ``cross_attention_states`` is None.
    When prefilling, each image index ``i`` expands to the padded token range
    ``[pad_seq_len * i, pad_seq_len * (i + 1))``. Otherwise the indices are
    used as-is. The second element is ``input_lengths`` gathered at
    ``image_indices``, which therefore must be an index tensor.
    """
    if cross_attention_states is None:
        return None, None

    if not prefilling:
        indices = image_indices[:]
    else:
        ranges = [
            torch.arange(pad_seq_len * idx, pad_seq_len * (idx + 1))
            for idx in image_indices
        ]
        indices = torch.cat(ranges, dim=0)
    selected_lengths = input_lengths.index_select(0, image_indices)
    return indices, selected_lengths
class FlashMllamaCausalLM(FlashVlmCausalLM):
    """HPU causal-LM runner for Mllama-style vision-language models.

    Image features come from ``self.model.vision_forward`` and are passed to the
    text model as ``cross_attention_states``. The model consumes ``input_ids``
    directly; no precomputed input embeddings are used.
    """

    def set_inputs_embeds(self, batch):
        """Clear ``batch.inputs_embeds``; this model is driven by ``input_ids``."""
        # Set the input embeddings to None, as we are using the input_ids for the model
        batch.inputs_embeds = None

    def warmup_decode(
        self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
    ):
        """Run one synthetic decode step for ``batch_size`` rows over ``block_num`` KV blocks.

        Blocks are split evenly across rows, with the remainder going to row 0. Each
        row writes its single token into the last slot of its last block. The
        warmup batch's cross-attention states are replicated for every row.
        """
        input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)
        position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)
        blocks = [block_num // batch_size for _ in range(batch_size)]
        blocks[0] += block_num % batch_size
        block_tables = []
        slots = []
        start_idx = 0
        # fetch the last blocked to warmup block num
        for i in range(batch_size):
            block_array = list(range(start_idx, start_idx + blocks[i]))
            slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)
            block_tables.append(block_array)
            start_idx += blocks[i]
        input_lengths = torch.ones(batch_size, dtype=torch.int32)
        seqlen = Seqlen(
            input_lengths=_async_h2d_tensor_copy(input_lengths),
        )
        block_list, block_groups, block_usage, _, block_bucket_size = (
            generate_block_metadata(
                self.dtype,
                self.use_contiguous_pa,
                slots,
                block_tables,
                self.bucketing_ctx,
            )
        )
        meta = HPUPagedAttentionMetadata(
            block_list=_async_h2d_tensor_copy(block_list),
            block_groups=_async_h2d_tensor_copy(block_groups),
            block_usage=_async_h2d_tensor_copy(block_usage),
            block_mapping=None,
            attn_bias=None,
        )
        hpu_attention_meta = trim_attn_metadata(meta)
        # Replicate the warmup batch's image rows/features for every decode row.
        image_indices = torch.tensor(batch.image_indices)
        image_indices = image_indices.repeat(batch_size)
        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
        indices, cross_attention_len = generate_cross_attention_states(
            cross_attention_states, image_indices, input_lengths, 1, False
        )
        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
        kwargs = {}
        if htorch.utils.internal.is_lazy():
            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
                False, hpu_attention_meta.block_list.shape[0], batch_size
            )
        self.model.forward(
            input_ids=_async_h2d_tensor_copy(input_ids),
            position_ids=_async_h2d_tensor_copy(position_ids),
            cu_seqlen_prefill=None,
            kv_cache=self.kv_cache,
            slots=_async_h2d_tensor_copy(slots_tensor),
            seqlen=trim_seqlen_metadata(seqlen),
            hpu_attention_meta=hpu_attention_meta,
            lm_head_indices=None,
            adapter_data=None,
            cross_attention_states=cross_attention_states,
            indices=_async_h2d_tensor_copy(indices),
            cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
            **kwargs,
        )

    def warmup_prefill(
        self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch
    ):
        """Run one synthetic prefill of ``batch_size`` rows, each ``prompt_len`` tokens long.

        Every row gets ``prompt_len // BLOCK_SIZE + 1`` consecutive blocks, and its
        first ``prompt_len`` slots are used. The warmup batch's cross-attention
        states are replicated for every row.
        """
        input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat(
            batch_size
        )
        position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat(
            batch_size
        )
        max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
        block_tables = torch.arange(max_bt, dtype=torch.int32).reshape(batch_size, -1)
        slot_acc = []
        for i in range(batch_size):
            slots = []
            for b in block_tables[i]:
                slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
            slot_acc.extend(slots[:prompt_len])
        slots = torch.tensor(slot_acc, dtype=batch.slots.dtype)
        input_lengths = (
            torch.ones(
                batch_size,
                dtype=torch.int32,
            )
            * prompt_len
        )
        cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)
        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
        lm_head_indices = input_lengths - 1
        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        image_indices = torch.tensor(batch.image_indices)
        image_indices = image_indices.repeat(batch_size)
        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
        indices, cross_attention_len = generate_cross_attention_states(
            cross_attention_states, image_indices, input_lengths, prompt_len, True
        )
        seqlen = Seqlen(
            input_lengths=_async_h2d_tensor_copy(input_lengths),
        )
        kwargs = {}
        if htorch.utils.internal.is_lazy():
            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
                True, prompt_len, batch_size
            )
        self.model.forward(
            input_ids=_async_h2d_tensor_copy(input_ids),
            position_ids=_async_h2d_tensor_copy(position_ids),
            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
            kv_cache=self.kv_cache,
            slots=_async_h2d_tensor_copy(slots),
            seqlen=trim_seqlen_metadata(seqlen),
            hpu_attention_meta=None,
            lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
            adapter_data=None,
            cross_attention_states=cross_attention_states,
            indices=_async_h2d_tensor_copy(indices),
            cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
            **kwargs,
        )

    def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
        """Warm up every prefill and decode bucket, registering graphable ones.

        Prefill budget: ``VLLM_GRAPH_PROMPT_RATIO`` (default 0.3) of the free device
        memory minus ``self.mem_reserved``, min-reduced across workers.

        Decode budget: re-read from the current free memory minus
        ``self.mem_reserved``. The earlier ``decode_available_memory`` figure is
        only logged.

        Each bucket is run ``warmup_times`` times. A bucket is added to
        ``self.graphed_buckets`` only when its estimated memory (extrapolated
        linearly from the buckets graphed so far) fits the remaining budget.
        """
        prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
        free_mem = HabanaMemoryProfiler.current_free_device_memory()
        graph_free_mem = free_mem - self.mem_reserved
        graph_free_mem = self.align_workers(
            graph_free_mem, torch.distributed.ReduceOp.MIN
        )
        prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
        decode_available_memory = graph_free_mem - prompt_available_memory
        msg = (
            f"Using {format_bytes(graph_free_mem)}"
            f"/{format_bytes(free_mem)} "
            "of free device memory for HPUGraphs, "
            f"{format_bytes(prompt_available_memory)} for prompt and "
            f"{format_bytes(decode_available_memory)} for decode "
            f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
        )
        log_master(logger.info, msg)
        start_time = time.time()
        warmup_shape_count = 0
        warmup_times = 3
        self.bucketing_ctx.generate_prompt_buckets()

        def ordering_function_min_tokens(b):
            return (b[0] * b[1], b[1], b[0])

        buckets = list(
            sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
        )
        # Start at a tiny non-zero value to avoid dividing by zero in mem_estimate.
        total_batch_seq = 0.001
        total_mem = 0
        available_mem = prompt_available_memory
        msg = (
            f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
            f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
        )
        log_master(logger.info, msg)
        for i, (batch_size, seq_len) in enumerate(buckets):
            if batch_size * seq_len > self.max_batch_prefill_tokens:
                continue
            # Graph memory usage is proportional to seq dimension in a batch
            batch_seq = batch_size * seq_len
            mem_estimate = batch_seq / total_batch_seq * total_mem
            graphed_bucket = (batch_size, seq_len, True)
            if not (
                mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
            ):
                if graphed_bucket not in self.graphed_buckets:
                    self.graphed_buckets.add(graphed_bucket)
            warmup_shape_count += 1
            self.log_warmup(True, i, len(buckets), batch_size, seq_len)
            with HabanaMemoryProfiler() as mem_prof:
                for _ in range(warmup_times):
                    self.warmup_prefill(seq_len, batch_size, batch)
                    synchronize(self.device)
            used_mem = self.align_workers(
                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
            )
            if graphed_bucket in self.graphed_buckets:
                available_mem -= used_mem
                total_mem += used_mem
                total_batch_seq += batch_seq

        log_master(logger.info, "Prefill warmup successful.\n")

        def ordering_function_max_bs(b):
            return (-b[0], b[1])

        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)
        buckets = list(
            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
        )
        free_mem = HabanaMemoryProfiler.current_free_device_memory()
        total_batch_seq = 0.001
        total_mem = 0
        available_mem = free_mem - self.mem_reserved
        log_master(
            logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
        )
        for i, (batch_size, block_num) in enumerate(buckets):
            # Each row needs at least one block.
            if batch_size > block_num:
                continue
            # Graph memory usage is proportional to seq dimension in a batch
            batch_seq = batch_size
            mem_estimate = batch_seq / total_batch_seq * total_mem
            graphed_bucket = (batch_size, block_num, False)
            if not mem_estimate >= available_mem:
                if graphed_bucket not in self.graphed_buckets:
                    self.graphed_buckets.add(graphed_bucket)
            warmup_shape_count += 1
            self.log_warmup(False, i, len(buckets), batch_size, block_num)
            with HabanaMemoryProfiler() as mem_prof:
                for _ in range(warmup_times):
                    self.warmup_decode(batch_size, block_num, batch)
                    synchronize(self.device)
            used_mem = self.align_workers(
                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
            )
            if graphed_bucket in self.graphed_buckets:
                available_mem -= used_mem
                total_mem += used_mem
                total_batch_seq += batch_seq

        log_master(logger.info, "Decode warmup successful.\n")

        log_master(
            logger.info,
            f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
        )

    def forward(
        self,
        batch: FlashMllamaCausalLMBatch,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one model step (prefill or decode) for ``batch``.

        When ``batch.pixel_values`` is set, the vision tower runs first and its
        output is cached on ``batch.cross_attention_states``; ``pixel_values`` is
        cleared afterwards, so later steps reuse the cached states.

        ``adapter_data`` is accepted for interface compatibility but is not
        forwarded; the model always receives ``adapter_data=None``.

        Returns:
            ``(logits, speculative_logits)`` as returned by ``self.model.forward``.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = self.kv_cache
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        lm_head_indices = batch.prefill_head_indices
        if batch.speculative_ids is not None:
            # Expand every row into 1 + speculative_length consecutive tokens.
            speculative_ids = batch.speculative_ids
            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
            input_ids = new_input_ids
            position_ids = new_position_ids

        if batch.pixel_values is not None:
            cross_attention_states = self.model.vision_forward(
                pixel_values=batch.pixel_values,
                aspect_ratio_ids=batch.aspect_ratio_ids,
                aspect_ratio_mask=batch.aspect_ratio_mask,
            )
            batch.cross_attention_states = cross_attention_states

        cross_attention_states = batch.cross_attention_states

        kwargs = {}
        if htorch.utils.internal.is_lazy():
            batch_size = input_lengths.shape[0]
            seqlen = (
                input_ids.shape[0] // batch_size
                if batch.prefilling
                else batch.hpu_attn_meta.block_list.shape[0]
            )
            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
                batch.prefilling, seqlen, batch_size
            )
        # Scatter slots into a buffer shaped like input_ids (padding slots stay 0).
        if batch.prefill_cache_indices is not None:
            slots_pad = torch.zeros_like(input_ids, device=slots.device)
            slots_pad[batch.prefill_cache_indices] = slots
            slots = slots_pad
        else:
            slots_pad = torch.zeros_like(input_ids, device=slots.device)
            slots_pad[: slots.shape[0]] = slots
            slots = slots_pad
        orig_bs = len(batch)
        padded_bs = batch.input_lengths_tensor.shape[0]
        padded_input_len = input_ids.view(padded_bs, -1).shape[-1]
        image_indices = torch.tensor(batch.image_indices)

        # Pad the image features and row indices up to the padded batch size.
        if cross_attention_states is not None:
            cross_attention_states = F.pad(
                cross_attention_states,
                (0, 0, 0, 0, 0, (padded_bs - orig_bs)),
                value=0,
            )
        if len(image_indices) != 0:
            pad_indices = torch.arange(orig_bs, padded_bs)
            image_indices = torch.cat((image_indices, pad_indices), dim=0)

        indices, cross_attention_len = generate_cross_attention_states(
            cross_attention_states,
            image_indices,
            input_lengths,
            padded_input_len,
            batch.prefilling,
        )
        seqlen = Seqlen(
            input_lengths=_async_h2d_tensor_copy(input_lengths),
        )
        logits, speculative_logits = self.model.forward(
            input_ids=input_ids,
            position_ids=_async_h2d_tensor_copy(position_ids),
            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),
            kv_cache=kv_cache,
            slots=_async_h2d_tensor_copy(slots),
            seqlen=trim_seqlen_metadata(seqlen),
            hpu_attention_meta=batch.hpu_attn_meta,
            lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),
            # TODO list
            adapter_data=None,
            cross_attention_states=cross_attention_states,
            indices=_async_h2d_tensor_copy(indices),
            cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),
            **kwargs,
        )
        # The vision features are cached on the batch, so the pixels are no longer needed.
        if batch.pixel_values is not None:
            batch.pixel_values = None
        return logits, speculative_logits
================================================
FILE: backends/gaudi/server/text_generation_server/models/model.py
================================================
import inspect
import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase
from text_generation_server.models.types import Batch, Generation
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
from text_generation_server.pb import generate_pb2
# Default value of Model's ``adapter_id``; presumably denotes "no adapter, base
# weights only" — confirm against the adapter loading code.
BASE_MODEL_ADAPTER_ID = "__base_model__"
# Type variable for the concrete Batch subclass a Model operates on.
B = TypeVar("B", bound=Batch)
class Model(ABC):
    """Abstract base for served models.

    It holds the module (switched to eval mode), the tokenizer and generation
    metadata. Subclasses implement ``batch_type`` and ``generate_token``. The base
    class provides incremental detokenization (``decode_token``) and a check that
    no parameter is left on the ``meta`` device.
    """

    def __init__(
        self,
        model_id: str,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
        adapter_id: str = BASE_MODEL_ADAPTER_ID,
        support_chunking: bool = False,
    ):
        self.model_id = model_id
        self.model = model.eval()
        self.tokenizer = tokenizer

        # all_special_ids is not set correctly if the rust tokenizer is unpacked
        # TODO report this to transformers.
        added_special_ids = {
            token_id
            for token_id, token in tokenizer.added_tokens_decoder.items()
            if token.special
        }
        self.all_special_ids = set(tokenizer.all_special_ids) | added_special_ids

        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device
        self.rank = rank
        self.world_size = world_size
        # A window of -1 is normalized to "no sliding window".
        self.sliding_window = None if sliding_window == -1 else sliding_window

        self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
            LayerAdapterWeights
        )
        self.loaded_adapters = set()
        self.static_adapter_id = adapter_id

        self.speculate = get_speculate() if speculate is None else speculate

        forward_params = inspect.signature(model.forward).parameters
        self.has_position_ids = "position_ids" in forward_params

        self.check_initialized()

    @property
    def info(self) -> InfoResponse:
        """Describe this model to the router.

        ``window_size`` is always reported as ``None``.

        Raises:
            NotImplementedError: if the model needs padding and has a sliding window.
        """
        if self.requires_padding and self.sliding_window is not None:
            raise NotImplementedError("sliding_window is not implemented with padding")

        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
            window_size=None,
            speculate=self.speculate,
            block_size=BLOCK_SIZE,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        raise NotImplementedError

    @abstractmethod
    def generate_token(
        self, batch: B
    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
        raise NotImplementedError

    def warmup(
        self, batch: generate_pb2.WarmupRequest
    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
        """Run a single ``generate_token`` on ``batch``; reports no limits."""
        self.generate_token(batch)
        return None, None, None

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
        skip_special_tokens: bool = False,
    ) -> Tuple[str, int, int]:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
        # Decoding a prefix alongside the new ids defeats cleanup heuristics that
        # add or drop spaces depending on the neighbouring ids.
        prefix_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
        )
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
        )

        # Nothing new yet, or a trailing replacement char that may be an
        # unfinished byte-fallback sequence: keep the offsets and wait.
        if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
            return "", prefix_offset, read_offset
        return new_text[len(prefix_text) :], read_offset, len(all_input_ids)

    def check_initialized(self):
        """Raise RuntimeError listing every parameter still on the ``meta`` device."""
        meta_device = torch.device("meta")
        uninitialized_parameters = [
            name
            for name, param in self.model.named_parameters()
            if param.data.device == meta_device
        ]
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )
================================================
FILE: backends/gaudi/server/text_generation_server/models/seq2seq_lm.py
================================================
import torch
import torch.distributed
import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
PreTrainedTokenizerBase,
AutoConfig,
)
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
GeneratedText,
Batch,
Generation,
Tokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
# Module-level OpenTelemetry tracer, used by the @tracer.start_as_current_span decorators in this module.
tracer = trace.get_tracer(__name__)
@dataclass
class Seq2SeqLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
requests_idx_mapping: Dict[int, int]
# Encoder values
input_ids: Optional[torch.Tensor]
attention_mask: torch.Tensor
# Decoder values
decoder_input_ids: torch.Tensor
decoder_attention_mask: Optional[torch.Tensor]
encoder_last_hidden_state: Optional[torch.Tensor]
# All tokens
all_decoder_input_ids: List[torch.Tensor]
# Seq2SeqLM keeps track of both encoder and decoder attention keys and values
past_key_values: Optional[List[Tuple]]
# Lengths of all generations present in the batch
input_lengths: List[int]
decoder_input_lengths: List[int]
prefix_offsets: List[int]
read_offsets: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding
max_input_length: int
max_decoder_input_length: int
padding_right_offset: int
# Maximum number of tokens this batch will grow to
max_tokens: int
def to_pb(self) -> generate_pb2.CachedBatch:
"""Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf"""
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "Seq2SeqLMBatch":
"""Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
inputs = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
decoder_input_lengths = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
# Parse batch
max_truncation = 0
padding_right_offset = 0
max_decode_tokens = 0
for i, r in enumerate(pb.requests):
inputs.append(concat_text_chunks(r.input_chunks.chunks))
requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1)
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max(
padding_right_offset, stopping_criteria.max_new_tokens
)
# Tokenize batch
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",
padding=True,
return_token_type_ids=False,
truncation=True,
max_length=max_truncation,
).to(device)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
# Decoder sequence only contains the bos_token
decoder_input_ids = (
torch.tensor(tokenizer.bos_token_id, device=device)
.repeat(len(pb.requests))
.view(-1, 1)
)
for _ in pb.requests:
prefix_offsets.append(0)
read_offsets.append(1)
all_decoder_input_ids = decoder_input_ids.view(-1).split(1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=tokenized_inputs["input_ids"],
attention_mask=tokenized_inputs["attention_mask"],
decoder_input_ids=decoder_input_ids,
all_decoder_input_ids=list(all_decoder_input_ids),
decoder_attention_mask=None,
encoder_last_hidden_state=None,
past_key_values=None,
input_lengths=input_lengths.tolist(),
decoder_input_lengths=decoder_input_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(),
max_decoder_input_length=1,
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> Optional["Seq2SeqLMBatch"]:
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self):
return self
keep_indices = []
# New values after filtering
requests_idx_mapping = {}
requests = []
input_lengths = []
decoder_input_lengths = []
prefix_offsets = []
read_offsets = []
all_decoder_input_ids = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_input_length = 0
max_decoder_input_length = 0
padding_right_offset = 0
total_remaining_decode_tokens = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
keep_indices.append(idx)
requests.append(self.requests[idx])
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
all_decoder_input_ids.append(self.all_decoder_input_ids[idx])
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(max_input_length, request_input_length)
request_decoder_input_length = self.decoder_input_lengths[idx]
decoder_input_lengths.append(request_decoder_input_length)
max_decoder_input_length = max(
max_decoder_input_length, request_decoder_input_length
)
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
total_remaining_decode_tokens += remaining_decode_tokens
padding_right_offset = max(padding_right_offset, remaining_decode_tokens)
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
self.decoder_input_ids = self.decoder_input_ids[keep_indices]
self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]
if self.decoder_attention_mask is not None:
self.decoder_attention_mask = self.decoder_attention_mask[
keep_indices,
-(self.padding_right_offset + max_decoder_input_length) : (
self.decoder_attention_mask.shape[1] - self.padding_right_offset
)
+ padding_right_offset,
]
self.encoder_last_hidden_state = self.encoder_last_hidden_state[
keep_indices, -max_input_length:
]
# Ensure that past_key_values tensors can be updated in-place
if type(self.past_key_values[0]) is tuple:
self.past_key_values = [
[t for t in layer] for layer in self.past_key_values
]
decoder_past_seq_len = max_decoder_input_length - 1
for layer in self.past_key_values:
layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]
layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]
layer[2] = layer[2][keep_indices, :, -max_input_length:]
layer[3] = layer[3][keep_indices, :, -max_input_length:]
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = (
len(request_ids) * (max_input_length + max_decoder_input_length)
+ remaining_decode_tokens
)
self.requests = requests
self.requests_idx_mapping = requests_idx_mapping
self.input_ids = None
self.all_decoder_input_ids = all_decoder_input_ids
self.input_lengths = input_lengths
self.decoder_input_lengths = decoder_input_lengths
self.prefix_offsets = prefix_offsets
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length
self.max_decoder_input_length = max_decoder_input_length
self.padding_right_offset = padding_right_offset
self.max_tokens = max_tokens
return self
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
"""Concatenate multiple batches together by padding internal torch tensors"""
# Used for padding
total_batch_size = 0
max_input_length = 0
max_decoder_input_length = 0
padding_right_offset = 0
for batch in batches:
total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length)
max_decoder_input_length = max(
max_decoder_input_length, batch.max_decoder_input_length
)
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes
requests = []
requests_idx_mapping = {}
all_decoder_input_ids = []
input_lengths = []
decoder_input_lengths = []
prefix_offsets = []
read_offsets = []
next_token_choosers = []
stopping_criterias = []
top_n_tokens = []
max_tokens = 0
# Batch tensors
attention_mask = None
decoder_input_ids = None
decoder_attention_mask = None
encoder_last_hidden_state = None
top_n_tokens_tensor = None
past_key_values = []
# Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes
start_index = 0
for i, batch in enumerate(batches):
# Extend all list attributes
requests.extend(batch.requests)
all_decoder_input_ids.extend(batch.all_decoder_input_ids)
input_lengths.extend(batch.input_lengths)
decoder_input_lengths.extend(batch.decoder_input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + start_index
# Slicing end index for this batch
end_index = start_index + len(batch)
# We only concatenate batches that did at least one step
if batch.encoder_last_hidden_state is None:
raise ValueError("Batch encoder_last_hidden_state cannot be None")
# Create padded tensor
if attention_mask is None:
attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_input_length),
)
# Copy to correct indices
attention_mask[start_index:end_index, -batch.max_input_length :] = (
batch.attention_mask[:, -batch.max_input_length :]
)
# Create padded tensor
if decoder_input_ids is None:
decoder_input_ids = batch.decoder_input_ids.new_zeros(
(total_batch_size, 1),
)
# Copy to correct indices
decoder_input_ids[start_index:end_index] = batch.decoder_input_ids
# Create padded tensor
if decoder_attention_mask is None:
# As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
decoder_attention_mask = batch.attention_mask.new_zeros(
(total_batch_size, max_decoder_input_length + padding_right_offset),
)
# If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
# this batch. All generations are of length `batch.max_decoder_input_length`.
left_offset = max_decoder_input_length - batch.max_decoder_input_length
if batch.decoder_attention_mask is None:
decoder_attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
] = 1
# If it exists, we need to index
else:
batch_left_offset = (
batch.decoder_attention_mask.shape[1]
- batch.max_decoder_input_length
- batch.padding_right_offset
)
decoder_attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
] = batch.decoder_attention_mask[
:,
batch_left_offset : -batch.padding_right_offset,
]
# Create padded tensor
if encoder_last_hidden_state is None:
encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
(
total_batch_size,
max_input_length,
batch.encoder_last_hidden_state.shape[-1],
),
)
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# Copy to correct indices
encoder_last_hidden_state[
start_index:end_index, -batch.max_input_length :, :
] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]
batch.encoder_last_hidden_state = None
# Ensure that we can update tensors in-place
if isinstance(batch.past_key_values[0], tuple):
batch.past_key_values = [
[t for t in layer] for layer in batch.past_key_values
]
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length
- batch.max_input_length
+ max_decoder_input_length
- batch.max_decoder_input_length
) * len(batch)
start_index = end_index
# Determine shapes for new past kv tensors
first_past_kvs = batches[0].past_key_values
_, num_heads, _, head_dim = first_past_kvs[0][0].shape
padded_dec_t_shape = (
total_batch_size,
num_heads,
(max_decoder_input_length - 1),
head_dim,
)
padded_enc_t_shape = (
total_batch_size,
num_heads,
max_input_length,
head_dim,
)
# Iterate over attention layers
for j in range(len(first_past_kvs)):
past_key_values.append([])
# Decoder past
for k in range(0, 2):
# Initialize tensors
padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)
past_key_values[j].append(padded_past_values)
start_index = 0
for batch in batches:
t = batch.past_key_values[j][k]
# Clear reference to the original tensor
batch.past_key_values[j][k] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the past keys and values to remove the padding from previous batches
past_seq_len = batch.max_decoder_input_length - 1
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[
:, :, -past_seq_len:, :
]
del t
start_index = end_index
# Encoder past
for k in range(2, 4):
# Initialize tensors
padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)
past_key_values[j].append(padded_past_values)
start_index = 0
for batch in batches:
t = batch.past_key_values[j][k]
# Clear reference to the original tensor
batch.past_key_values[j][k] = None
# Slicing end index for this batch
end_index = start_index + len(batch)
# We slice the past keys and values to remove the padding from previous batches
padded_past_values[
start_index:end_index, :, -batch.max_input_length :, :
] = t[:, :, -batch.max_input_length :, :]
del t
start_index = end_index
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=None,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
all_decoder_input_ids=all_decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_last_hidden_state=encoder_last_hidden_state,
past_key_values=past_key_values,
input_lengths=input_lengths,
decoder_input_lengths=decoder_input_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length,
max_decoder_input_length=max_decoder_input_length,
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
def __len__(self):
    """Return how many requests this batch currently holds."""
    pending_requests = self.requests
    return len(pending_requests)
class Seq2SeqLM(Model):
    """Encoder-decoder (seq2seq) text-generation model.

    `batch_type` is `Seq2SeqLMBatch`; each `generate_token` call runs one forward
    pass and appends one decoder token to every request in the batch.
    """

    def __init__(
        self,
        model_id: str,
        model_class,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        default_dtype=torch.float16,
        trust_remote_code: bool = False,
        config_class=AutoConfig,
        tokenizer_class=AutoTokenizer,
        aliases=None,
    ):
        """Load config, tokenizer and `.safetensors` weights on HPU and build the model.

        Args:
            model_id: Hub id or local path of the model.
            model_class: Callable invoked as `model_class(config, weights)`.
            revision: Optional hub revision.
            quantize: Quantization method; stored on `self` and on the config.
                For "awq"/"gptq" the GPTQ parameters are loaded into the weights.
            speculator: Stored on the config as `config.speculator`.
            dtype: Weight dtype; `torch.bfloat16` when None.
            default_dtype: Not read by this body (kept for signature compatibility).
            trust_remote_code: Forwarded to config and tokenizer loading.
            config_class: Class whose `from_pretrained` loads the config.
            tokenizer_class: Class whose `from_pretrained` loads the tokenizer.
            aliases: Forwarded to `Weights` (presumably alternate tensor names -- confirm).
        """
        self.quantize = quantize
        self.process_group, rank, world_size = initialize_torch_distributed()
        # This backend always targets Gaudi HPU; bfloat16 is the default dtype.
        device = torch.device("hpu")
        dtype = torch.bfloat16 if dtype is None else dtype
        config = config_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator
        # Pad and truncate on the left side.
        tokenizer = tokenizer_class.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        # `generate_token` reports `bos_token_id` as the prefill token, so make it
        # the decoder start token.
        tokenizer.bos_token_id = config.decoder_start_token_id
        weights_loader = get_loader(
            quantize=quantize, model_id=model_id, revision=revision
        )
        # Synchronize all ranks before and after weight loading.
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases=aliases,
            weights_loader=weights_loader,
        )
        if config.quantize in ["awq", "gptq"]:
            weights._set_gptq_params(model_id, revision)
        model = model_class(config, weights)
        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @classmethod
    def fallback(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Build an instance around a stock `AutoModelForSeq2SeqLM`, bypassing `__init__`.

        Uses CUDA (float16 by default) when available, otherwise CPU (float32 by
        default). With more than one CUDA device the model is loaded with
        `device_map="auto"`; with exactly one it is moved with `.cuda()`.

        Raises:
            RuntimeError: if `speculator` is set.
            ValueError: if `quantize` is set and CUDA is unavailable.
        """
        if speculator:
            raise RuntimeError("Speculator decoding is not enabled for AutoModel")
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map=(
                "auto"
                if torch.cuda.is_available() and torch.cuda.device_count() > 1
                else None
            ),
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = model.config.decoder_start_token_id
        # Allocate without running `__init__` (which would load weights again).
        self = cls.__new__(
            cls,
        )
        # In a classmethod, zero-arg `super()` is bound to `cls`, not an instance,
        # so the new instance must be passed explicitly.
        super().__init__(
            self,
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )
        self.quantize = quantize
        return self

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        """Batch class used with this model."""
        return Seq2SeqLMBatch

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """Run one model step with the KV cache enabled.

        Returns:
            `(logits, speculative_logits, encoder_last_hidden_state, past_key_values)`.
            `speculative_logits` is None when the model returns a plain
            transformers output rather than a `(outputs, speculative_logits)` tuple.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        if isinstance(outputs, tuple):
            # Our custom models
            outputs, speculative_logits = outputs
        else:
            # Generic transformers models
            speculative_logits = None
        return (
            outputs.logits,
            speculative_logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:
        """Generate one token for every request in `batch`.

        NOTE(review): `TextGenerationService.Prefill/Decode` in server.py call
        `generate_token` with a *list* of batches, while this signature takes a
        single batch -- confirm whether Seq2SeqLM is reachable on this backend.

        Returns:
            `(generations, next_batch, (forward_ns, decode_ns))`. `next_batch` is
            None once every request has stopped; otherwise it is `batch`,
            updated in place. `generations` only holds requests whose index `i`
            satisfies `i % world_size == rank`.
        """
        start = time.time_ns()
        if batch.decoder_attention_mask is not None:
            # slice to the correct shape
            decoder_attention_mask = batch.decoder_attention_mask[
                :, : -batch.padding_right_offset
            ]
        else:
            decoder_attention_mask = None
        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
        # internally...
        if batch.encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [batch.encoder_last_hidden_state]
        else:
            encoder_last_hidden_state = None
        logits, speculative_logits, encoder_last_hidden_state, past = self.forward(
            batch.input_ids,
            batch.attention_mask,
            batch.decoder_input_ids,
            decoder_attention_mask,
            encoder_last_hidden_state,
            batch.past_key_values,
        )
        # Speculation is not active for seq2seq
        accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens,
            batch.top_n_tokens_tensor,
            torch.log_softmax(logits[:, -1], -1),
            accepted_ids,
        )
        start_decode = time.time_ns()
        # Finished requests
        generations: List[Generation] = []
        # Flipped to False as soon as any request still needs more tokens.
        stopped = True
        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_decoder_input_ids,
            batch.top_n_tokens,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )
        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            decoder_input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_decoder_input_ids,
            top_n_tokens,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_decoder_input_ids.view(1, -1), logits[-1:, :]
            )
            # Append next token to decoder tokens
            all_decoder_input_ids = torch.cat(
                [all_decoder_input_ids, next_token_id.squeeze(1)]
            )
            new_decoder_input_length = decoder_input_length + 1
            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_decoder_input_ids, prefix_offset, read_offset
            )
            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)
            if not stop:
                stopped = False
            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Slice with decoder_input_length to remove padding
                    # Decode all tokens
                    output_text, _, _ = self.decode_token(
                        all_decoder_input_ids,
                        prefix_offset=len(all_decoder_input_ids)
                        - decoder_input_length
                        - 1,
                        read_offset=len(all_decoder_input_ids) - decoder_input_length,
                        skip_special_tokens=True,
                    )
                    # Get seed
                    if isinstance(next_token_chooser.choice, Sampling):
                        seed = next_token_chooser.choice.seed
                    else:
                        seed = None
                    generated_text = GeneratedText(
                        output_text, stopping_criteria.current_tokens, reason, seed
                    )
                else:
                    generated_text = None
                # Prefill
                # Only the decoder start (bos) token is reported as prefill, with a NaN logprob.
                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
                    prefill_tokens = Tokens(
                        [self.tokenizer.bos_token_id],
                        [float("nan")],
                        [self.tokenizer.bos_token],
                        [False],
                    )
                else:
                    prefill_tokens = None
                if top_n_tokens > 0:
                    all_top_tokens = []
                    # NOTE: the loop variables shadow the per-request `top_token_ids`/
                    # `top_token_logprobs`; they are rebound on the next outer iteration.
                    for top_token_ids, top_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in top_token_ids
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None
                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        [next_token_id_squeezed],
                        [next_token_logprob],
                        [next_token_text],
                        [next_token_id_squeezed.item() in self.all_special_ids],
                    ),
                    generated_text,
                    top_tokens,
                )
                generations.append(generation)
            # Update values
            # Batch state is updated for every request, not only the ones this rank reports.
            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
                next_token_id_squeezed.item()
            )
            batch.decoder_input_ids[i] = next_token_id
            batch.all_decoder_input_ids[i] = all_decoder_input_ids
            batch.input_lengths[i] = input_length
            batch.decoder_input_lengths[i] = new_decoder_input_length
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.max_input_length = max(batch.max_input_length, input_length)
            batch.max_decoder_input_length = max(
                batch.max_decoder_input_length, new_decoder_input_length
            )
        # We finished all generations in the batch; there is no next batch
        if stopped:
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)
        # We don't need input_ids after the prefill forward
        batch.input_ids = None
        batch.encoder_last_hidden_state = encoder_last_hidden_state
        batch.past_key_values = past
        # Update decoder_attention_mask as we added a new token to input_ids
        if batch.decoder_attention_mask is not None:
            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
        batch.padding_right_offset -= 1
        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)
================================================
FILE: backends/gaudi/server/text_generation_server/models/types.py
================================================
import torch
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from transformers import PreTrainedTokenizerBase
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
class Batch(ABC):
    """Abstract interface implemented by every model-specific batch type.

    A batch is built from a protobuf request (`from_pb`), reported back as a
    `generate_pb2.CachedBatch` (`to_pb`), narrowed to a subset of request ids
    (`filter`) and merged with other batches (`concatenate`).
    """

    @abstractmethod
    def __len__(self):
        """Number of requests in the batch."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "Batch":
        """Build a batch from its protobuf form using `tokenizer`, on `device` with `dtype`."""
        raise NotImplementedError

    @abstractmethod
    def to_pb(self) -> generate_pb2.CachedBatch:
        """Describe this batch as a `CachedBatch` message."""
        raise NotImplementedError

    @abstractmethod
    def filter(self, request_ids: List[int]) -> "Batch":
        """Return a batch keeping only the requests listed in `request_ids`."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def concatenate(cls, batches: List["Batch"]) -> "Batch":
        """Merge `batches` into a single batch."""
        raise NotImplementedError
@dataclass
class GeneratedText:
    """Final output of a finished request."""

    # Decoded output text.
    text: str
    # Count of tokens generated for the request.
    generated_tokens: int
    # Why generation stopped.
    finish_reason: FinishReason
    # Sampling seed, or None when no seed applies.
    seed: Optional[int]

    def to_pb(self) -> generate_pb2.GeneratedText:
        """Convert to a `generate_pb2.GeneratedText` message."""
        payload = {
            "text": self.text,
            "generated_tokens": self.generated_tokens,
            "finish_reason": self.finish_reason,
            "seed": self.seed,
        }
        return generate_pb2.GeneratedText(**payload)
@dataclass
class Tokens:
    """Parallel per-token lists: ids, log-probabilities, decoded texts, special flags."""

    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

    def to_pb(self) -> generate_pb2.Tokens:
        """Convert to `generate_pb2.Tokens` (`token_ids` is sent as the `ids` field)."""
        payload = {
            "ids": self.token_ids,
            "logprobs": self.logprobs,
            "texts": self.texts,
            "is_special": self.is_special,
        }
        return generate_pb2.Tokens(**payload)

    def __len__(self):
        """Number of tokens, i.e. the length of `token_ids`."""
        ids = self.token_ids
        return len(ids)
@dataclass
class Generation:
    """Per-request result of one generation step."""

    request_id: int
    # Prompt tokens, when reported for this step.
    prefill_tokens: Optional[Tokens]
    # Token(s) produced by this step.
    tokens: Tokens
    # Set only once the request has finished.
    generated_text: Optional[GeneratedText]
    # Optional for now, since it's not yet supported for every model.
    top_tokens: Optional[List[Tokens]]

    def to_pb(self) -> generate_pb2.Generation:
        """Convert to `generate_pb2.Generation`; absent optional parts are passed as None."""
        prefill_pb = None
        if self.prefill_tokens is not None:
            prefill_pb = self.prefill_tokens.to_pb()

        tokens_pb = self.tokens.to_pb()

        generated_pb = None
        if self.generated_text is not None:
            generated_pb = self.generated_text.to_pb()

        top_pb = None
        if self.top_tokens is not None:
            top_pb = [entry.to_pb() for entry in self.top_tokens]

        return generate_pb2.Generation(
            request_id=self.request_id,
            prefill_tokens=prefill_pb,
            tokens=tokens_pb,
            generated_text=generated_pb,
            top_tokens=top_pb,
        )
================================================
FILE: backends/gaudi/server/text_generation_server/pb/.gitignore
================================================
*.py
*.pyi
*.py-e
================================================
FILE: backends/gaudi/server/text_generation_server/server.py
================================================
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import asyncio
import os
import torch
import time
import signal
from grpc import aio
from loguru import logger
from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id, ATTENTION
from text_generation_server.models.globals import set_adapter_to_index
from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.tokens import make_tokenizer_optional
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
from text_generation_server.models import VLM_BATCH_TYPES
from text_generation_server.utils.version import (
is_driver_compatible,
MIN_TGI_GAUDI_SYNAPSE_VERSION,
)
class SignalHandler:
    """Clear `KEEP_PROCESSING` when SIGINT or SIGTERM arrives.

    Constructing an instance installs `exit_gracefully` as the handler for both
    signals; the serve loop polls `KEEP_PROCESSING` to know when to stop.
    """

    KEEP_PROCESSING = True

    def __init__(self):
        for handled_signal in (signal.SIGINT, signal.SIGTERM):
            signal.signal(handled_signal, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        """Signal handler: log the signal number and request shutdown."""
        print(f"Exiting gracefully: Signal {signum}")
        # Instance attribute shadows the class-level default.
        self.KEEP_PROCESSING = False
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    """gRPC servicer exposing a loaded `Model` to the router.

    Batches returned by Prefill, Decode and FilterBatch are stored in `cache`
    and looked up again by id on later FilterBatch/Decode calls.
    """

    def __init__(
        self,
        model: Model,
        cache: Cache,
        server_urls: List[str],
    ):
        self.cache = cache
        self.model = model
        # Quantize is resolved during model loading
        self.quantize = model.quantize
        self.server_urls = server_urls
        # Inference mode is deliberately NOT forced for the service lifetime.
        # `torch._C._InferenceMode(True)` does not work well with GLOO (used on
        # CPU), and on HPU it interferes with autograd op dispatch, leaving
        # aten::matmul unoptimized. Investigate before re-enabling it.

    def _batch_from_pb(self, pb_batch):
        """Deserialize a protobuf batch into `self.model.batch_type`.

        VLM batch types also need the processor and the model config, so they go
        through `from_pb_processor`; every other type uses `from_pb`.
        """
        batch_type = self.model.batch_type
        if batch_type in VLM_BATCH_TYPES:
            # Hack, i would rather use kwargs in the `from_pb` call
            return batch_type.from_pb_processor(
                pb_batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        return batch_type.from_pb(
            pb_batch, self.model.tokenizer, self.model.dtype, self.model.device
        )

    async def Info(self, request, context):
        return self.model.info

    async def Health(self, request, context):
        # Allocating a small tensor on HPU exercises the device.
        if self.model.device.type == "hpu":
            torch.zeros((2, 2)).to("hpu")
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
        """Delete one cached batch when `request.id` is set, otherwise clear the cache."""
        if request.HasField("id"):
            self.cache.delete(request.id)
        else:
            self.cache.clear()
        return generate_pb2.ClearCacheResponse()

    async def FilterBatch(self, request, context):
        """Replace a cached batch with one restricted to `request.request_ids`.

        Raises:
            ValueError: if the batch id is not in the cache.
        """
        batch = self.cache.pop(request.batch_id)
        if batch is None:
            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
        filtered_batch = batch.filter(request.request_ids)
        self.cache.set(filtered_batch)
        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
        """Warm the model up and report the token limits it can support.

        With paged attention the warmup batch is deserialized and
        `max_input_tokens`/`max_total_tokens` are forwarded (None when unset);
        otherwise the raw request is passed to `model.warmup`.
        """
        if ATTENTION == "paged":
            set_max_prefill_tokens(request.max_prefill_tokens)
            batch = self._batch_from_pb(request.batch)
            # Override default values with None for clearer semantics.
            max_input_tokens = (
                request.max_input_tokens
                if request.HasField("max_input_tokens")
                else None
            )
            max_total_tokens = (
                request.max_total_tokens
                if request.HasField("max_total_tokens")
                else None
            )
            max_supported_total_tokens, max_input_tokens, max_total_tokens = (
                self.model.warmup(batch, max_input_tokens, max_total_tokens)
            )
        else:
            max_supported_total_tokens, max_input_tokens, max_total_tokens = (
                self.model.warmup(request)
            )
        # W/A for the skip tokenizer path
        # We need to call make_tokenizer_optional after the warmup,
        # because router is not aware of that feature
        make_tokenizer_optional(self.model.tokenizer)
        return generate_pb2.WarmupResponse(
            max_supported_total_tokens=max_supported_total_tokens,
            max_input_tokens=max_input_tokens,
            max_total_tokens=max_total_tokens,
        )

    async def Prefill(self, request, context):
        """Run the first generation step for a new batch and cache the continuation."""
        start = time.time_ns()
        batch = self._batch_from_pb(request.batch)
        generations, next_batch, timings = self.model.generate_token([batch])
        self.cache.set(next_batch)
        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )

    async def Decode(self, request, context):
        """Run one generation step over the cached batches named in the request.

        Raises:
            ValueError: if no batch is given or a batch id is not in the cache.
        """
        start = time.time_ns()
        if len(request.batches) == 0:
            raise ValueError("Must provide at least one batch")
        batches = []
        for batch_pb in request.batches:
            batch = self.cache.pop(batch_pb.id)
            if batch is None:
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)
        # `batches` cannot be empty here: `request.batches` is non-empty and every
        # missing id raised above.
        generations, next_batch, timings = self.model.generate_token(batches)
        self.cache.set(next_batch)
        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            concat_ns=None,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )
def serve(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
    max_input_tokens: int,
):
    """Load the model and serve it over gRPC on a Unix domain socket.

    Blocks (via `asyncio.run`) until SIGINT or SIGTERM is received.

    Args:
        dtype: "bfloat16" or None selects `torch.bfloat16`; any other value
            selects `torch.float`.
        uds_path: Socket path prefix; shard `r` listens on `unix://{uds_path}-{r}`.
        sharded: When True, `RANK` and `WORLD_SIZE` are read from the environment.
        max_input_tokens: Forwarded to `get_model_with_lora_adapters`.
        Remaining arguments are forwarded to model loading.
    """

    async def serve_inner(
        model_id: str,
        lora_adapters: Optional[List[AdapterInfo]],
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
        speculate: Optional[int] = None,
        dtype: Optional[str] = None,
        kv_cache_dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        if not is_driver_compatible():
            logger.warning(
                f"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}, this could result in failures"
            )
        unix_socket_template = "unix://{}-{}"
        adapter_to_index = {}
        logger.info("Server:server_inner: sharded ={}".format(sharded))
        if sharded:
            rank = int(os.environ["RANK"])
            logger.info("Server:server_inner: rank ={}".format(rank))
            server_urls = [
                unix_socket_template.format(uds_path, shard_rank)
                for shard_rank in range(int(os.environ["WORLD_SIZE"]))
            ]
            local_url = server_urls[rank]
        else:
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]
        logger.info(
            "Server:server_inner: data type = {}, local_url = {}".format(
                dtype, local_url
            )
        )
        # The previous `dtype == "bfloat16" or None` never matched None (`or None`
        # is always falsy), so an unset dtype silently fell back to float32.
        if dtype is None or dtype == "bfloat16":
            data_type = torch.bfloat16
        else:
            data_type = torch.float
        if revision == "None":
            revision = None
        try:
            model = get_model_with_lora_adapters(
                model_id,
                lora_adapters,
                revision,
                sharded,
                quantize,
                speculate,
                data_type,
                kv_cache_dtype,
                trust_remote_code,
                max_input_tokens,
                adapter_to_index,
            )
        except Exception:
            logger.exception("Error when initializing model")
            raise
        set_adapter_to_index(adapter_to_index)
        server = aio.server(
            interceptors=[
                ExceptionInterceptor(),
                UDSOpenTelemetryAioServerInterceptor(),
            ],
            options=[
                # Set the maximum possible message length: i32::MAX
                ("grpc.max_receive_message_length", (1 << 31) - 1)
            ],
        )
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
            TextGenerationService(model, Cache(), server_urls), server
        )
        SERVICE_NAMES = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(local_url)
        await server.start()
        logger.info("Server started at {}".format(local_url))
        signal_handler = SignalHandler()
        # Keep the event loop alive until a shutdown signal flips the flag.
        while signal_handler.KEEP_PROCESSING:
            await asyncio.sleep(0.5)

    set_model_id(model_id)
    asyncio.run(
        serve_inner(
            model_id,
            lora_adapters,
            revision,
            sharded,
            quantize,
            speculate,
            dtype,
            kv_cache_dtype,
            trust_remote_code,
        )
    )
================================================
FILE: backends/gaudi/server/text_generation_server/tracing.py
================================================
import grpc
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.grpc._aio_server import (
OpenTelemetryAioServerInterceptor,
)
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
BatchSpanProcessor,
)
class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):
    """OpenTelemetry async gRPC server interceptor for servers bound to a Unix socket.

    Overrides `_start_span` to build the span attributes itself and to tag the
    transport as "unix".
    """

    def __init__(self):
        super().__init__(trace.get_tracer(__name__))

    def _start_span(self, handler_call_details, context, set_status_on_exception=False):
        """
        Rewrite _start_span method to support Unix Domain Socket gRPC contexts.

        Span attributes: the RPC system ("grpc"), an OK status code, the service
        and method split from the full method path (when present), the caller's
        user agent (when sent in metadata), and the transport set to "unix".
        """
        full_method = handler_call_details.method
        attributes = {
            SpanAttributes.RPC_SYSTEM: "grpc",
            SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0],
        }

        if full_method:
            # "/package.Service/Method" -> ("package.Service", "Method")
            service, method = full_method.lstrip("/").split("/", 1)
            attributes[SpanAttributes.RPC_METHOD] = method
            attributes[SpanAttributes.RPC_SERVICE] = service

        metadata = dict(context.invocation_metadata())
        if "user-agent" in metadata:
            attributes["rpc.user_agent"] = metadata["user-agent"]

        # We use gRPC over a UNIX socket
        attributes[SpanAttributes.NET_TRANSPORT] = "unix"

        return self._tracer.start_as_current_span(
            name=full_method,
            kind=trace.SpanKind.SERVER,
            attributes=attributes,
            set_status_on_exception=set_status_on_exception,
        )
def setup_tracing(otlp_service_name: str, otlp_endpoint: str):
    """Install a global tracer provider that batches spans to an OTLP gRPC endpoint.

    Args:
        otlp_service_name: Value recorded as the `service.name` resource attribute.
        otlp_endpoint: OTLP collector endpoint; the connection is made with `insecure=True`.
    """
    service_resource = Resource.create(attributes={"service.name": otlp_service_name})
    processor = BatchSpanProcessor(
        OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
    )
    trace.set_tracer_provider(TracerProvider(resource=service_resource))
    # Attach to whichever provider is globally active after the set call.
    trace.get_tracer_provider().add_span_processor(processor)
================================================
FILE: backends/gaudi/server/text_generation_server/utils/__init__.py
================================================
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.weights import Weights
from text_generation_server.utils.peft import download_and_unload_peft
from text_generation_server.utils.hub import (
weight_files,
weight_hub_files,
download_weights,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
from text_generation_server.utils.tokens import (
NextTokenChooser,
HeterogeneousNextTokenChooser,
StoppingCriteria,
StopSequenceCriteria,
FinishReason,
Sampling,
Greedy,
make_tokenizer_optional,
is_tokenizer_transparent,
pad_next_token_chooser_parameters,
)
# Public names re-exported by this package, grouped by the module they come from.
__all__ = [
    # conversion, distributed setup and weights
    "convert_file",
    "convert_files",
    "initialize_torch_distributed",
    "Weights",
    # hub / peft
    "weight_files",
    "weight_hub_files",
    "download_weights",
    "download_and_unload_peft",
    "EntryNotFoundError",
    "LocalEntryNotFoundError",
    "RevisionNotFoundError",
    # tokens
    "HeterogeneousNextTokenChooser",
    "Greedy",
    "NextTokenChooser",
    "Sampling",
    "StoppingCriteria",
    "StopSequenceCriteria",
    "FinishReason",
    "make_tokenizer_optional",
    "is_tokenizer_transparent",
    "pad_next_token_chooser_parameters",
]
================================================
FILE: backends/gaudi/server/text_generation_server/utils/adapter.py
================================================
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/utils/adapter.py
# License: Apache License Version 2.0, January 2004
import warnings
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Set, Tuple, Optional, List
from safetensors.torch import load_file
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
from text_generation_server.utils.merges.strategies import merge_adapters
from text_generation_server.utils import hub
from text_generation_server.adapters.lora import LoraConfig
if TYPE_CHECKING:
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
# Reserved adapter id; `_load_and_merge` refuses to merge an adapter with this id.
BASE_MODEL_ADAPTER_ID = "__base_model__"
@dataclass
class AdapterInfo:
    """One LoRA adapter entry as parsed from the `--lora-adapters` spec."""

    # Adapter id (hub repo id or local name).
    id: str
    # Local directory holding the adapter, or None to resolve `id` from the hub.
    path: Optional[str]
    # Optional hub revision.
    revision: Optional[str] = None
@dataclass
class AdapterParameters:
    """Adapters to load together with their merge settings.

    `weights`, `merge_strategy`, `density` and `majority_sign_method` are handed
    to `merge_adapters` when more than one adapter is involved.
    """

    adapter_info: Tuple[AdapterInfo]
    weights: Tuple[float]
    merge_strategy: NotImplemented
    density: float
    majority_sign_method: NotImplemented
@dataclass
class AdapterSource:
    """Where an adapter comes from: adapter id, model id and revision."""

    adapter_id: str
    model_id: str
    revision: str
def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:
    """Parse a comma-separated LoRA adapter spec into `AdapterInfo` records.

    Each entry (whitespace-stripped) is `id`, `id=path`, `id@revision` or
    `id=path@revision`.

    Returns:
        An empty list for a falsy spec, otherwise one `AdapterInfo` per entry.

    Raises:
        ValueError: if an entry has more than one `=` or `@`, or does not match
            the accepted forms.
    """
    if not lora_adapters:
        return []

    entry_pattern = re.compile(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$")
    parsed = []
    for raw_entry in lora_adapters.split(","):
        entry = raw_entry.strip()
        has_extra_separators = entry.count("=") > 1 or entry.count("@") > 1
        match = None if has_extra_separators else entry_pattern.match(entry)
        if match is None:
            raise ValueError(f"Invalid LoRA adapter format: {entry}")
        adapter_id, path, revision = match.groups()
        parsed.append(AdapterInfo(id=adapter_id, path=path, revision=revision))
    return parsed
def load_and_merge_adapters(
    model_id: str,
    adapter_parameters: AdapterParameters,
    adapter_index: int,
    weight_names: Tuple[str],
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    """Load the requested adapter(s), merging them when there is more than one.

    A single adapter is loaded directly via `load_module_map`. Otherwise the
    parameters are wrapped in an `AdapterParametersContainer` (hashed by
    `adapter_index`) and passed to the cached `_load_and_merge`.
    """
    infos = adapter_parameters.adapter_info
    if len(infos) != 1:
        container = AdapterParametersContainer(adapter_parameters, adapter_index)
        return _load_and_merge(model_id, container, weight_names, trust_remote_code)

    only_adapter = next(iter(infos))
    return load_module_map(
        model_id,
        only_adapter.revision,
        only_adapter.id,
        only_adapter.path,
        weight_names,
        trust_remote_code,
    )
@dataclass
class AdapterParametersContainer:
    """Hashable wrapper so `AdapterParameters` can be an `lru_cache` key.

    The hash is the adapter index; equality keeps the dataclass-generated
    field comparison.
    """

    adapter_parameters: AdapterParameters
    adapter_index: int

    def __hash__(self) -> int:
        index = self.adapter_index
        return index
@lru_cache(maxsize=32)
def _load_and_merge(
    model_id: str,
    adapter_params: AdapterParametersContainer,
    weight_names: Tuple[str],
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    """Load every adapter in `adapter_params` and merge them with `merge_adapters`.

    Returns:
        `(module_map, adapter_config, weight_names_union, tokenizer)`, where the
        tokenizer is the first non-None tokenizer returned by the adapters.

    Raises:
        ValueError: if an adapter is the base-model adapter, or if there is
            nothing to merge.
    """
    params = adapter_params.adapter_parameters
    pending_merges = []
    all_weight_names = set()
    first_tokenizer = None

    for info in params.adapter_info:
        if info.id == BASE_MODEL_ADAPTER_ID:
            raise ValueError("Base model adapter cannot be merged.")
        module_map, adapter_config, names, tokenizer = load_module_map(
            model_id,
            info.revision,
            info.id,
            info.path,
            weight_names,
            trust_remote_code,
        )
        pending_merges.append((module_map, adapter_config))
        all_weight_names.update(names)
        if first_tokenizer is None:
            first_tokenizer = tokenizer

    if not pending_merges:
        raise ValueError("No adapters to merge.")

    merged_map, merged_config = merge_adapters(pending_merges, params)
    return merged_map, merged_config, all_weight_names, first_tokenizer
def check_architectures(
    model_id: str,
    adapter_id: str,
    adapter_config: "AdapterConfig",
    trust_remote_code: bool = False,
):
    """Compare the adapter's base-model architecture with the served model's.

    Returns silently when the adapter declares no base model. If either config
    cannot be loaded, it warns and assumes compatibility. If the architectures
    match, it warns that the adapter was trained on a different base model.

    Raises:
        ValueError: if the architectures differ.
    """
    try:
        base_model = adapter_config.base_model_name_or_path
        if not base_model:
            # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
            return
        expected_config = AutoConfig.from_pretrained(
            model_id, trust_remote_code=trust_remote_code
        )
        model_config = AutoConfig.from_pretrained(
            base_model, trust_remote_code=trust_remote_code
        )
    except Exception as e:
        warnings.warn(
            f"Unable to check architecture compatibility for adapter '{adapter_id}' "
            f"against model '{model_id}'. Assuming they are compatible. Error: {e}"
        )
        return

    if model_config.architectures != expected_config.architectures:
        # TODO(travis): revisit this when we support clasification heads which will not use CausalLM
        raise ValueError(
            f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
            f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
            f"Use --model-id '{base_model}' instead."
        )

    warnings.warn(
        f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
        f"If you encounter issues, use --model-id '{base_model}' instead."
    )
@lru_cache(maxsize=128)
def load_module_map(
    model_id: str,
    revision: str,
    adapter_id: str,
    adapter_path: Optional[str],
    weight_names: Tuple[str],
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    """Load one LoRA adapter and map its weights onto the model's weight names.

    The adapter comes from `adapter_path` when given, otherwise from the hub
    cache for `adapter_id`/`revision`. Hub adapters whose declared base model
    differs from `model_id` are checked with `check_architectures`.

    Returns:
        `(module_map, adapter_config, adapter_weight_names, adapter_tokenizer)`.
        The tokenizer is None when the adapter does not provide one.

    Raises:
        FileNotFoundError: if no `.safetensors` adapter weights are found.
    """
    adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
    if not adapter_path and adapter_config.base_model_name_or_path != model_id:
        check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)

    if adapter_path:
        adapter_filenames = hub._weight_files_from_dir(
            adapter_path, extension=".safetensors"
        )
    else:
        adapter_filenames = hub._cached_weight_files(
            adapter_id, revision=revision, extension=".safetensors"
        )

    # throw an error if no adapter weights are found
    if not adapter_filenames:
        raise FileNotFoundError(
            f"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'."
        )

    try:
        adapter_tokenizer = AutoTokenizer.from_pretrained(
            adapter_config.config_path,
            trust_remote_code=trust_remote_code,
        )
    except Exception:
        # Adapter does not have a tokenizer, so fallback to base model tokenizer
        adapter_tokenizer = None

    # Load all shards; a later shard overrides an earlier one on duplicate names.
    adapter_weights = {
        tensor_name: tensor
        for filename in adapter_filenames
        for tensor_name, tensor in load_file(filename).items()
    }

    # map the model weights to the relevant adapter weights (LoRA A and B matrices)
    module_map, adapter_weight_names = adapter_config.map_weights_for_model(
        adapter_weights, weight_names
    )
    return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
def get_attn_weights(i, layer):
    """Map attention projections of layer `i` to `(weight path, module)` pairs.

    Keys are `(i, name)`. q/k/v and the combined `qkv_proj` all point at the
    fused `query_key_value` module; `o_proj` points at `self_attn.o_proj`.
    """
    attention = layer.self_attn
    fused_qkv = attention.query_key_value
    prefix = f"model.layers.{i}.self_attn"

    weights = {
        (i, f"{part}_proj"): (f"{prefix}.{part}_proj", fused_qkv)
        for part in ("q", "k", "v")
    }
    # also add the qkv_proj weight for the adapter
    weights[(i, "qkv_proj")] = (f"{prefix}.qkv_proj", fused_qkv)
    weights[(i, "o_proj")] = (f"{prefix}.o_proj", attention.o_proj)
    return weights
def get_mlp_weights(i, layer):
    """Map MLP projections of layer `i` to `(weight path, module)` pairs.

    A fused `gate_up_proj` is registered under both `gate_proj` and `up_proj`.
    Otherwise separate `gate_proj`/`up_proj` are registered when present.
    `down_proj` is registered whenever present. A layer without `mlp` yields {}.
    """
    if not hasattr(layer, "mlp"):
        return {}

    mlp = layer.mlp
    prefix = f"model.layers.{i}.mlp"
    weights = {}

    if hasattr(mlp, "gate_up_proj"):
        # handle combined gate_up_proj (e.g., for some LLaMA variants)
        fused = mlp.gate_up_proj
        weights[(i, "gate_proj")] = (f"{prefix}.gate_proj", fused)
        weights[(i, "up_proj")] = (f"{prefix}.up_proj", fused)
    else:
        # handle separate gate_proj and up_proj (e.g., for Gemma)
        for proj_name in ("gate_proj", "up_proj"):
            if hasattr(mlp, proj_name):
                weights[(i, proj_name)] = (
                    f"{prefix}.{proj_name}",
                    getattr(mlp, proj_name),
                )

    if hasattr(mlp, "down_proj"):
        weights[(i, "down_proj")] = (f"{prefix}.down_proj", mlp.down_proj)

    return weights
def build_layer_weight_lookup(model):
    """Map model layers to the weights a LoRA adapter may update.

    Returns a dict keyed by `(layer_index, projection_name)` whose values are
    `(weight path, module)` tuples. It covers each decoder layer's attention and
    MLP projections, plus `lm_head` under the key `(0, "lm_head")` when one is
    found. The adapter uses this mapping to know which weights to update.

    The decoder is taken from `model.language_model.model`, else
    `model.text_model.model`, else `model.model`. `lm_head` is looked up on that
    decoder first and then on `model`.
    """
    if hasattr(model, "language_model"):
        m = model.language_model.model
    elif hasattr(model, "text_model"):
        m = model.text_model.model
    else:
        m = model.model

    layer_weights = {}
    for i, layer in enumerate(m.layers):
        attn_weights = get_attn_weights(i, layer)
        mlp_weights = get_mlp_weights(i, layer)
        layer_weights.update(attn_weights)
        layer_weights.update(mlp_weights)

    lm_head = None
    if hasattr(m, "lm_head"):
        lm_head = m.lm_head
    elif hasattr(model, "lm_head"):
        lm_head = model.lm_head

    # Test against None rather than truthiness: a head module that defines
    # __len__ (e.g. a container module) would otherwise be treated as missing.
    if lm_head is not None:
        layer_weights[(0, "lm_head")] = ("lm_head", lm_head)

    return layer_weights
================================================
FILE: backends/gaudi/server/text_generation_server/utils/chunks.py
================================================
from typing import Iterable
from loguru import logger
from text_generation_server.pb import generate_pb2
def concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str:
    """
    Return the text of the single text chunk in ``chunks``.

    Non-text chunks are logged at debug level and dropped.

    Raises:
        NotImplementedError: if there is no text chunk, or more than one.
    """
    text = None
    for chunk in chunks:
        chunk_type = chunk.WhichOneof("chunk")
        if chunk_type != "text":
            # We cannot reject this, e.g. warmup sends an image chunk.
            logger.debug(f"Encountered non-text chunk type {chunk_type}")
            continue
        if text is not None:
            raise NotImplementedError("Request contained more than one text chunk")
        text = chunk.text
    if text is None:
        raise NotImplementedError("Request without a text chunk")
    return text
================================================
FILE: backends/gaudi/server/text_generation_server/utils/convert.py
================================================
import datetime
import torch
import os
from loguru import logger
from pathlib import Path
from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
from typing import List, Dict
from collections import defaultdict
def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor],
    *,
    preferred_names: List[str] = None,
    discard_names: List[str] = None,
) -> Dict[str, List[str]]:
    """Choose one name to keep for every group of tensors that share storage.

    Returns a mapping ``kept_name -> [aliases to drop]``. Groups consisting of a single
    name produce no entry. A single-name group whose tensor does not cover its whole
    storage is replaced in ``state_dict`` by a clone, so it is mutated in that case.

    Selection order: the lexicographically smallest complete name, then the smallest
    complete name not in ``discard_names``, then the smallest complete name in
    ``preferred_names`` (later rules override earlier ones when they match).

    Raises:
        RuntimeError: if a multi-name group has no name covering the entire storage.
    """
    preferred_set = set() if preferred_names is None else set(preferred_names)
    discard_set = set() if discard_names is None else set(discard_names)

    to_remove = defaultdict(list)
    for shared in _find_shared_tensors(state_dict):
        complete_names = {name for name in shared if _is_complete(state_dict[name])}
        if not complete_names:
            if len(shared) != 1:
                raise RuntimeError(
                    f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
                )
            # Force contiguous
            only_name = next(iter(shared))
            state_dict[only_name] = state_dict[only_name].clone()
            complete_names = {only_name}

        keep_name = min(complete_names)

        # Mechanism to preferentially select keys to keep coming from the on-disk
        # file, to allow loading models saved with a different choice of keep_name.
        not_discarded = complete_names - discard_set
        if not_discarded:
            keep_name = min(not_discarded)
        if preferred_set:
            explicitly_preferred = preferred_set & complete_names
            if explicitly_preferred:
                keep_name = min(explicitly_preferred)

        # Append one by one so groups with nothing to drop create no dict entry.
        for name in sorted(shared):
            if name != keep_name:
                to_remove[keep_name].append(name)
    return to_remove
def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
    """
    Convert a pytorch file to a safetensors file.

    Duplicate (storage-sharing) tensors are removed. Each dropped alias is recorded in
    the safetensors metadata as pointing at the kept name. This might not respect the
    *transformers* convention, which forces loaders to check for potentially different
    keys when looking for specific tensors (tensor sharing is made explicit).

    After saving, the file is reloaded and every tensor is compared with the original.

    Raises:
        RuntimeError: if a reloaded tensor differs from the saved one.
    """
    loaded = torch.load(pt_file, map_location="cpu", weights_only=True)
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)

    metadata = {"format": "pt"}
    for kept_name, aliases in to_removes.items():
        for alias in aliases:
            # The first kept name recorded for an alias wins.
            metadata.setdefault(alias, kept_name)
            del loaded[alias]

    # Force tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}
    os.makedirs(os.path.dirname(sf_file), exist_ok=True)
    save_file(loaded, sf_file, metadata=metadata)

    # Round-trip check so a silently corrupted conversion is caught immediately.
    reloaded = load_file(sf_file)
    for k, pt_tensor in loaded.items():
        if not torch.equal(pt_tensor, reloaded[k]):
            raise RuntimeError(f"The output tensors do not match for key {k}")
def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):
    """Convert each ``pt_files[i]`` to ``sf_files[i]``, logging per-file timing.

    Files whose name contains "arguments", "args" or "training" are skipped.
    """
    assert len(pt_files) == len(sf_files)
    N = len(pt_files)
    skip_markers = ("arguments", "args", "training")
    # We do this instead of using tqdm because we want to parse the logs with the launcher
    for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
        # Skip blacklisted files
        if any(marker in pt_file.name for marker in skip_markers):
            continue
        start = datetime.datetime.now()
        convert_file(pt_file, sf_file, discard_names)
        elapsed = datetime.datetime.now() - start
        logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")
================================================
FILE: backends/gaudi/server/text_generation_server/utils/debug.py
================================================
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import os
import glob
import time
import habana_frameworks.torch as htorch
import numpy as np
# Reference timestamp for dbg_trace offsets; set lazily on the first trace.
START_TS = None
# Trace output path; tracing is disabled when the variable is unset.
DBG_TRACE_FILENAME = os.getenv("DBG_TRACE_FILENAME")
# When graph visualization is requested, clear dumps left over from a previous run.
if os.environ.get("GRAPH_VISUALIZATION") is not None:
    for stale_dump in glob.glob(".graph_dumps/*"):
        os.remove(stale_dump)
def to_gb_rounded(mem: float) -> float:
    """Convert a byte count to GiB (1024**3 bytes), rounded to two decimals.

    Args:
        mem (float): memory in bytes
    Returns:
        float: memory in GiB, rounded with ``np.round(..., 2)``
    """
    bytes_per_gib = 1 << 30
    return np.round(mem / bytes_per_gib, 2)
def count_hpu_graphs():
    """Count files matching ``.graph_dumps/*PreGraph*`` relative to the current directory."""
    return sum(1 for _ in glob.iglob(".graph_dumps/*PreGraph*"))
def dbg_trace(tag, txt):
    """Append one debug line with timing, graph count and HPU memory stats.

    Active only when ``DBG_TRACE_FILENAME`` is set and the ``RANK`` environment
    variable (read at call time, default 0) is 0. The line records the seconds since
    the first trace, the number of captured graph dumps, the current and peak HPU
    memory in GB, then ``tag`` and ``txt``.
    """
    global START_TS
    if DBG_TRACE_FILENAME is None or int(os.getenv("RANK", 0)) != 0:
        return
    if START_TS is None:
        START_TS = time.perf_counter()
    time_offset = time.perf_counter() - START_TS
    mem_stats = htorch.hpu.memory.memory_stats()
    mem_used = to_gb_rounded(mem_stats["InUse"])
    max_mem_used = to_gb_rounded(mem_stats["MaxInUse"])
    # Use a context manager so the handle is closed after every write; passing
    # `file=open(...)` directly leaked one file descriptor per trace call.
    with open(DBG_TRACE_FILENAME, "a") as trace_file:
        print(
            f"ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB "
            f"mmu:{max_mem_used:.1f}GB | {tag} | {txt}",
            flush=True,
            file=trace_file,
        )
================================================
FILE: backends/gaudi/server/text_generation_server/utils/dist.py
================================================
import os
import torch
from torch.distributed import ProcessGroup
from datetime import timedelta
from loguru import logger
# Tensor Parallelism settings, read once from the environment at import time.
RANK = int(os.environ.get("RANK", "0"))
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1"))
MEMORY_FRACTION = float(os.environ.get("HPU_MEMORY_FRACTION", "0.8"))
class FakeBarrier:
    """Work handle returned by FakeGroup collectives; waiting on it does nothing."""

    def wait(self):
        return None
class FakeGroup(ProcessGroup):
    """ProcessGroup stand-in that performs no communication.

    ``allreduce`` and ``barrier`` are no-ops. ``allgather`` only supports one input
    tensor and copies the local tensor's data into the output slots. Every collective
    returns a ``FakeBarrier``.
    """

    def __init__(self, rank, size):
        self._rank = rank
        self._size = size
        super().__init__(rank, size)

    def rank(self):
        return self._rank

    def size(self):
        return self._size

    def _get_backend_name(self):
        return "fake"

    def barrier(self, *args, **kwargs):
        return FakeBarrier()

    def allreduce(self, *args, **kwargs):
        return FakeBarrier()

    def allgather(self, inputs, local_tensor, **kwargs):
        assert (
            len(inputs[0]) == len(local_tensor) == 1
        ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
        source = local_tensor[0].data
        for output_slots in inputs:
            output_slots[0].data = source
        return FakeBarrier()
def initialize_torch_distributed():
    """Return ``(process_group, rank, world_size)`` for this process.

    A ``FakeGroup`` is returned when WORLD_SIZE is 1, or when the DEBUG environment
    variable equals "1". Otherwise the HCCL backend is initialized (120 s timeout)
    unless torch.distributed is already initialized, and the WORLD group is returned.
    """
    if WORLD_SIZE == 1 or os.getenv("DEBUG", None) == "1":
        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE

    if torch.distributed.is_initialized():
        logger.warning("torch.distributed is already initialized.")
    else:
        # Call the init process.
        torch.distributed.init_process_group(
            backend="hccl",
            world_size=WORLD_SIZE,
            rank=RANK,
            timeout=timedelta(seconds=120),
        )
    return torch.distributed.group.WORLD, RANK, WORLD_SIZE
================================================
FILE: backends/gaudi/server/text_generation_server/utils/hub.py
================================================
import time
import os
from datetime import timedelta
from loguru import logger
from pathlib import Path
from typing import Optional, List
from huggingface_hub import file_download, hf_api, HfApi, hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import (
LocalEntryNotFoundError,
EntryNotFoundError,
RevisionNotFoundError, # noqa # Import here to ease try/except in other part of the lib
)
# Optional directory that replaces the HF cache as the source of weight files.
WEIGHTS_CACHE_OVERRIDE = os.environ.get("WEIGHTS_CACHE_OVERRIDE")
# Offline mode when HF_HUB_OFFLINE is "true"/"1"/"yes" (case-insensitive).
HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "0").lower() in ("true", "1", "yes")
def _cached_weight_files(
    model_id: str, revision: Optional[str], extension: str
) -> List[str]:
    """Guess weight files from the cached revision snapshot directory.

    Returns an empty list when no snapshot directory is cached for the revision.
    """
    snapshot_dir = _get_cached_revision_directory(model_id, revision)
    return _weight_files_from_dir(snapshot_dir, extension) if snapshot_dir else []
def _weight_hub_files_from_model_info(
    info: hf_api.ModelInfo, extension: str
) -> List[str]:
    """List top-level repo files ending in ``extension``.

    Files in subdirectories are excluded, as are names containing "arguments",
    "args" or "training".
    """
    excluded = ("arguments", "args", "training")
    return [
        sibling.rfilename
        for sibling in info.siblings
        if sibling.rfilename.endswith(extension)
        and "/" not in sibling.rfilename
        and not any(word in sibling.rfilename for word in excluded)
    ]
def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_hub_files_from_model_info, that's also what is
# done there with the len(s.rfilename.split("/")) == 1 condition
root, _, files = next(os.walk(str(d)))
filenames = [
os.path.join(root, f)
for f in files
if f.endswith(extension)
and "arguments" not in f
and "args" not in f
and "training" not in f
]
return filenames
def _get_cached_revision_directory(
    model_id: str, revision: Optional[str]
) -> Optional[Path]:
    """Return the cached snapshot directory for ``model_id`` at ``revision``.

    ``revision`` defaults to "main". A symbolic ref such as "main" is resolved via the
    cache's ``refs/`` file when one exists. Returns None when the repo, the snapshots
    folder, or the resolved revision is not in the cache; no other revision is substituted.
    """
    revision = "main" if revision is None else revision
    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
        file_download.repo_folder_name(repo_id=model_id, repo_type="model")
    )
    if not repo_cache.is_dir():
        # No cache for this model
        return None

    refs_dir = repo_cache / "refs"
    snapshots_dir = repo_cache / "snapshots"

    # Resolve refs (for instance to convert main to the associated commit sha)
    if refs_dir.is_dir():
        ref_file = refs_dir / revision
        if ref_file.exists():
            with ref_file.open() as handle:
                revision = handle.read()

    if not snapshots_dir.exists():
        return None
    if revision not in os.listdir(snapshots_dir):
        return None
    return snapshots_dir / revision
def weight_hub_files(
    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
    """Get the weight filenames for ``model_id`` from the Hub.

    In offline mode the names are guessed from the local cache snapshot instead.

    Raises:
        EntryNotFoundError: if no file with ``extension`` is found.
    """
    api = HfApi()
    if HF_HUB_OFFLINE:
        filenames = _cached_weight_files(model_id, revision, extension)
    else:
        # Online case, fetch model info from the Hub
        model_info = api.model_info(model_id, revision=revision)
        filenames = _weight_hub_files_from_model_info(model_info, extension)
    if filenames:
        return filenames
    raise EntryNotFoundError(
        f"No {extension} weights found for model {model_id} and revision {revision}.",
        None,
    )
def try_to_load_from_cache(
    model_id: str, revision: Optional[str], filename: str
) -> Optional[Path]:
    """Return the cached path of ``filename`` for the revision, or None if absent."""
    snapshot_dir = _get_cached_revision_directory(model_id, revision)
    if not snapshot_dir:
        return None
    candidate = snapshot_dir / filename
    if candidate.is_file():
        return candidate
    return None
def weight_files(
    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[Path]:
    """Resolve the local paths of the model's weight files.

    If ``model_id`` is an existing directory, its top-level files with ``extension``
    are returned. Otherwise the filenames are looked up on the Hub (or in the offline
    cache). If no ``.safetensors`` exist there, the names are derived from the ``.bin``
    files instead, because they may have been converted locally without being pushed.
    The files are then resolved in ``WEIGHTS_CACHE_OVERRIDE`` when it is set, and in
    the HF cache otherwise.

    Raises:
        FileNotFoundError: no local weights, or a file missing from the override dir.
        LocalEntryNotFoundError: a file is not in the HF cache (download it first).
        EntryNotFoundError: no weights on the Hub for a non-safetensors extension.
    """
    # Local model
    local_dir = Path(model_id)
    if local_dir.exists() and local_dir.is_dir():
        found = _weight_files_from_dir(local_dir, extension)
        if not found:
            raise FileNotFoundError(
                f"No local weights found in {model_id} with extension {extension}"
            )
        return [Path(f) for f in found]

    try:
        filenames = weight_hub_files(model_id, revision, extension)
    except EntryNotFoundError as e:
        if extension != ".safetensors":
            raise e
        # Try to see if there are pytorch weights
        pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
        # NOTE(review): str.lstrip removes any leading characters from the set
        # "pytorch_", not the literal prefix (e.g. "consolidated" -> "nsolidated").
        # Left as-is because the local .bin -> .safetensors conversion presumably uses
        # the same expression and the names must match; verify before changing.
        filenames = [
            f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames
        ]

    if WEIGHTS_CACHE_OVERRIDE is not None:
        override_dir = Path(WEIGHTS_CACHE_OVERRIDE)
        overridden = []
        for filename in filenames:
            p = override_dir / filename
            if not p.exists():
                raise FileNotFoundError(
                    f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
                )
            overridden.append(p)
        return overridden

    cached = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(
            model_id, revision=revision, filename=filename
        )
        if cache_file is None:
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_id} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_id}` first."
            )
        cached.append(cache_file)
    return cached
def download_weights(
    filenames: List[str], model_id: str, revision: Optional[str] = None
) -> List[Path]:
    """Download the given files of ``model_id`` from the Hub and return their local paths.

    Files already in the cache are not downloaded again. Each download is tried up to
    5 times, sleeping 5 seconds between attempts; the last failure is re-raised.
    Progress with an ETA is logged after every file.
    """

    def download_file(fname, tries=5, backoff: int = 5):
        cached = try_to_load_from_cache(model_id, revision, fname)
        if cached is not None:
            logger.info(f"File {fname} already present in cache.")
            return Path(cached)

        for idx in range(tries):
            try:
                logger.info(f"Download file: {fname}")
                stime = time.time()
                local_file = hf_hub_download(
                    filename=fname,
                    repo_id=model_id,
                    revision=revision,
                    local_files_only=HF_HUB_OFFLINE,
                )
                logger.info(
                    f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - stime))}."
                )
                return Path(local_file)
            except Exception as e:
                is_last_attempt = idx + 1 == tries
                if is_last_attempt:
                    raise e
                logger.error(e)
                logger.info(f"Retrying in {backoff} seconds")
                time.sleep(backoff)
                logger.info(f"Retry {idx + 1}/{tries - 1}")

    # We do this instead of using tqdm because we want to parse the logs with the launcher
    start_time = time.time()
    files = []
    total = len(filenames)
    for i, filename in enumerate(filenames):
        files.append(download_file(filename))
        elapsed = timedelta(seconds=int(time.time() - start_time))
        remaining = total - (i + 1)
        eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
        logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
    return files
================================================
FILE: backends/gaudi/server/text_generation_server/utils/import_utils.py
================================================
import torch
def get_hpu_free_memory(device, memory_fraction):
    """Return free HPU memory as the first element of ``torch.hpu.mem_get_info()``.

    ``device`` and ``memory_fraction`` are accepted but unused in this implementation.
    """
    free_bytes, _total_bytes = torch.hpu.mem_get_info()
    return free_bytes


def synchronize_hpu(device):
    """Block on ``torch.hpu.synchronize()``; ``device`` is unused."""
    torch.hpu.synchronize()


def noop(*args, **kwargs):
    """Accept any arguments and do nothing."""
    return None


# Generic names bound to the HPU implementations (cache emptying is a no-op here).
empty_cache = noop
synchronize = synchronize_hpu
get_free_memory = get_hpu_free_memory
================================================
FILE: backends/gaudi/server/text_generation_server/utils/kernels.py
================================================
import importlib
from loguru import logger
from hf_kernels import load_kernel as hf_load_kernel
from text_generation_server.utils.log import log_once
def load_kernel(*, module: str, repo_id: str):
    """
    Return a kernel module.

    ``module`` is imported first (e.g. for local development) and a one-time info log
    is emitted when that succeeds. If the import raises ``ModuleNotFoundError``, the
    kernel is loaded with ``hf_load_kernel(repo_id=repo_id)`` instead (the original
    docs describe this as a locked Hub kernel).
    """
    try:
        local_module = importlib.import_module(module)
        log_once(logger.info, f"Using local module for `{module}`")
        return local_module
    except ModuleNotFoundError:
        return hf_load_kernel(repo_id=repo_id)


# Public API of this module.
__all__ = ["load_kernel"]
================================================
FILE: backends/gaudi/server/text_generation_server/utils/log.py
================================================
from functools import lru_cache
from text_generation_server.utils.dist import RANK
@lru_cache(10)
def log_once(log, msg: str, master=True):
    """Emit ``msg`` through ``log`` only once per distinct ``(log, msg, master)`` call.

    Deduplication comes from an ``lru_cache`` of size 10, so a message can repeat once
    it has been evicted. With ``master=True`` the message goes through ``log_master``.
    """
    if not master:
        log(msg)
        return
    log_master(log, msg)
def log_master(log, msg: str):
    """Call ``log(msg)`` only when the imported ``RANK`` is 0."""
    if RANK != 0:
        return
    log(msg)
================================================
FILE: backends/gaudi/server/text_generation_server/utils/logits_process.py
================================================
import math
import torch
import habana_frameworks.torch.core as htcore
from loguru import logger
from typing import Dict
from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
from functools import lru_cache
from typing import List, Optional, DefaultDict
import time
from transformers import (
LogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
)
# CUDA graph memory-pool handle, or None when CUDA is unavailable.
if torch.cuda.is_available():
    mempool = torch.cuda.graph_pool_handle()
else:
    mempool = None
class StaticWarper:
    """Fixed chain of HF logits warpers, captured once into an HPU graph and replayed.

    Only warpers whose parameter is active are added, in this order: temperature
    (!= 1.0), top-k (!= 0), top-p (< 1.0), typical-p (< 1.0). Calling the instance
    returns ``(warped_scores, log_softmax(warped_scores))``. Both are static buffers
    that are overwritten by the next call.
    """

    def __init__(
        self,
        temperature=1.0,
        top_k=None,
        top_p=None,
        typical_p=None,
    ):
        warpers = []
        if temperature is not None and temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(float(temperature)))
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
        if typical_p is not None and typical_p < 1.0:
            warpers.append(TypicalLogitsWarper(mass=typical_p))
        self.warpers = warpers

        # The graph and its static I/O buffers are created lazily on the first call.
        self.hpu_graph = None
        self.static_scores = None
        self.static_warped_scores = None
        self.static_next_logprob = None

    def _capture(self, scores):
        # Buffers must keep the same storage across replays, so they are allocated
        # once from the first input and later inputs are copied into them.
        self.static_scores = scores.clone().contiguous()
        self.static_warped_scores = scores.clone().contiguous()
        self.static_next_logprob = scores.clone().contiguous()

        self.hpu_graph = htcore.hpu.HPUGraph()
        with htcore.hpu.graph(self.hpu_graph):
            warped = self.static_scores
            for warper in self.warpers:
                warped = warper(None, warped)
            self.static_warped_scores.copy_(warped)
            # Compute logprobs
            self.static_next_logprob.copy_(
                torch.log_softmax(self.static_warped_scores, -1)
            )

    def __call__(self, scores):
        if self.hpu_graph is None:
            self._capture(scores)
        self.static_scores.copy_(scores)
        self.hpu_graph.replay()
        return self.static_warped_scores, self.static_next_logprob
@lru_cache(10)
def static_warper(
    temperature: Optional[float],
    top_k: Optional[int],
    top_p: Optional[float],
    typical_p: Optional[float],
) -> StaticWarper:
    """Return a ``StaticWarper`` for these parameters.

    Up to 10 instances are memoized so their captured graphs are reused.
    """
    params = dict(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
    return StaticWarper(**params)
class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] applying a per-row repetition penalty to tokens present in `input_ids`.

    For those tokens, negative scores are multiplied by the penalty and non-negative
    scores are divided by it. 1.0 means no penalty; see
    [this paper](https://arxiv.org/pdf/1909.05858.pdf). `scores` is updated in place
    and inputs are not validated.

    Args:
        penalty (`List[float]`): one repetition penalty per batch row.
        dtype (`torch.dtype`): dtype of the penalty tensor.
        device (`torch.device`): device of the penalty tensor.
    """

    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
        self.penalty = penalty
        # Column vector so each row's penalty broadcasts over its gathered scores.
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        seen = torch.gather(scores, 1, input_ids)
        penalized = torch.where(
            seen < 0, seen * self.penalty_tensor, seen / self.penalty_tensor
        )
        scores.scatter_(1, input_ids, penalized)
        return scores

    def filter(self, indices):
        """Keep rows ``indices``; return None when no remaining row is penalized."""
        self.penalty = [self.penalty[i] for i in indices]
        if all(p == 1.0 for p in self.penalty):
            return None
        self.penalty_tensor = self.penalty_tensor[indices]
        return self
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    Frequency penalty as defined by OpenAI, with one penalty for the whole batch.

    For each position of `input_ids`, the gathered score is scaled (multiplied by the
    penalty if negative, divided otherwise), negated, and scatter-added back into
    `scores` in place. Positions whose token id is 0 contribute nothing.

    Args:
        penalty (`float`):
            The parameter for frequency penalty. 0.0 means no penalty.
            NOTE(review): with 0.0, non-negative scores are divided by zero here;
            callers presumably skip this processor in that case — confirm.
    """

    def __init__(self, penalty: float):
        self.penalty = penalty

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        seen = torch.gather(scores, 1, input_ids)
        delta = -torch.where(seen < 0, seen * self.penalty, seen / self.penalty)
        # Token id 0 is treated as padding and receives no adjustment.
        delta *= input_ids.ne(0)
        return scores.scatter_add_(1, input_ids, delta)
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    Per-row frequency penalty as defined by OpenAI in
    https://platform.openai.com/docs/guides/text-generation/parameter-details

    Each vocab score is reduced in place by ``penalty * (count of that id in the row /
    row length)``.

    Args:
        penalty (`List[float]`):
            One frequency penalty per batch row. 0.0 means no penalty.
        dtype (`torch.dtype`): dtype of the penalty tensor.
        device (`torch.device`): device of the penalty tensor.
    """

    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
        self.penalty = penalty
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        batch_size, input_size = input_ids.size()
        vocab_size = scores.size(1)

        # Relative frequency of each vocab id within each row of input_ids.
        token_freq = torch.zeros(
            batch_size, vocab_size, dtype=scores.dtype, device=scores.device
        )
        increments = torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device)
        token_freq.scatter_add_(1, input_ids, increments)
        token_freq /= input_size

        scores -= token_freq * self.penalty_tensor
        return scores

    def filter(self, indices):
        """Keep rows ``indices``; return None when every remaining penalty is 0.0."""
        self.penalty = [self.penalty[i] for i in indices]
        if all(p == 0.0 for p in self.penalty):
            return None
        self.penalty_tensor = self.penalty_tensor[indices]
        return self
class HeterogeneousTemperatureLogitsWarper:
    r"""
    Per-row temperature scaling: each row of `scores` is divided in place by its temperature.

    Inputs are not validated.

    Args:
        temperature (`List[float]`): one temperature per batch row.
        dtype (`torch.dtype`): dtype of the temperature tensor.
        device (`torch.device`): device of the temperature tensor.
    """

    def __init__(
        self, temperature: List[float], dtype: torch.dtype, device: torch.device
    ):
        self.temperature = temperature
        self.temperature_tensor = torch.tensor(
            temperature, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # div_ works in place and returns `scores` itself.
        return scores.div_(self.temperature_tensor)

    def filter(self, indices):
        """Keep rows ``indices``; return None when every remaining temperature is 1.0."""
        self.temperature = [self.temperature[i] for i in indices]
        if all(t == 1.0 for t in self.temperature):
            return None
        self.temperature_tensor = self.temperature_tensor[indices]
        return self
class HeterogeneousTopPLogitsWarper(LogitsProcessor):
    """
    [`LogitsWarper`] performing per-row top-p (nucleus) filtering.

    Only the smallest set of most probable tokens whose probabilities add up to the
    row's `top_p` or more is kept. The others are set to `filter_value` in place.
    Inputs are not validated.

    Args:
        top_p (`List[float]`): one top-p value per batch row.
        dtype (`torch.dtype`): dtype used for the threshold tensor.
        device (`torch.device`): device used for the threshold tensor.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(
        self,
        top_p: List[float],
        dtype: torch.dtype,
        device: torch.device,
        filter_value: float = -math.inf,
        min_tokens_to_keep: int = 1,
    ):
        self.top_p = top_p
        # Work with 1 - top_p so the lowest-probability tail can be cut from ascending order.
        self.top_p_opposite = 1 - torch.tensor(
            top_p, dtype=dtype, device=device
        ).unsqueeze(1)
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        ascending_logits, ascending_idx = torch.sort(scores, descending=False)
        cumulative = ascending_logits.softmax(dim=-1)
        # Row-by-row cumsum, kept from the original which notes it is faster this way.
        for row in range(cumulative.shape[0]):
            cumulative[row] = cumulative[row].cumsum(dim=-1)

        # Drop the low-probability tail whose cumulative mass is <= 1 - top_p.
        drop_sorted = cumulative <= self.top_p_opposite
        # Keep at least min_tokens_to_keep (the highest-probability end).
        drop_sorted[..., -self.min_tokens_to_keep :] = 0

        # Map the mask back from sorted order to vocab order.
        drop = drop_sorted.scatter(1, ascending_idx, drop_sorted)
        return scores.masked_fill_(drop, self.filter_value)

    def filter(self, indices):
        """Keep rows ``indices``; return None when no remaining row has top_p < 1.0."""
        self.top_p = [self.top_p[i] for i in indices]
        if not any(p < 1.0 for p in self.top_p):
            return None
        self.top_p_opposite = self.top_p_opposite[indices]
        return self
class HeterogeneousTopKLogitsWarper(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that performs per-row top-k, i.e. keeps only the k highest scores.

    Runs in place where possible and does not validate inputs. A `top_k` of 0 disables
    filtering for that row. A `top_k` larger than the vocabulary keeps every token.

    Args:
        top_k (`List[int]`):
            Per-row number of highest-score vocabulary tokens to keep.
        device (`torch.device`): device for the index/mask tensors.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(
        self,
        top_k: List[int],
        device: torch.device,
        filter_value: float = -math.inf,
        min_tokens_to_keep: int = 1,
    ):
        self.top_k = top_k
        self.max_top_k = max(top_k)
        # value - 1 as we will use top_k to index and python uses 0 based numbering
        self.top_k_tensor = torch.tensor(
            [max(x - 1, min_tokens_to_keep - 1) for x in top_k],
            dtype=torch.int64,
            device=device,
        ).unsqueeze(1)

        # 0 is a special value that disables top_k warping for this member of the batch
        disabled = [x == 0 for x in top_k]
        if any(disabled):
            self.top_k_disabled_mask = torch.tensor(
                disabled, dtype=torch.bool, device=device
            ).view(-1, 1)
        else:
            self.top_k_disabled_mask = None

        self.filter_value = filter_value

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # If max_top_k exceeds the vocab, clamp it or torch.topk will fail
        if scores.size(-1) < self.max_top_k:
            max_top_k = scores.size(-1)
            # top_k_tensor holds 0-based positions (k - 1), so the largest valid index
            # into the topk result is max_top_k - 1. Clamping to max_top_k made the
            # gather below index one past the end.
            top_k = torch.clamp_max(self.top_k_tensor, max_top_k - 1)
        else:
            max_top_k = self.max_top_k
            top_k = self.top_k_tensor

        # Get the kth score for each member of the batch
        kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)

        # Mask member of kth_scores that do not want to use top_k warping
        if self.top_k_disabled_mask is not None:
            kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value)

        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = scores < kth_scores
        scores.masked_fill_(indices_to_remove, self.filter_value)
        return scores

    def filter(self, indices):
        """Keep rows ``indices``; return None when top-k is disabled for every remaining row."""
        self.top_k = [self.top_k[i] for i in indices]
        disabled = [x == 0 for x in self.top_k]

        if not all(disabled):
            self.top_k_tensor = self.top_k_tensor[indices]
            self.max_top_k = max(self.top_k)

            if self.top_k_disabled_mask is not None:
                self.top_k_disabled_mask = (
                    self.top_k_disabled_mask[indices] if any(disabled) else None
                )

            return self
        return None
class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that performs typical decoding. See [Typical Decoding for Natural Language
    Generation](https://arxiv.org/abs/2202.00666) for more information.

    Runs in place where possible and does not validate inputs.

    Args:
        mass (`List[float]`):
            Per-row typical_p value between 0 and 1 inclusive; 1.0 disables the warper
            for that row.
        dtype (`torch.dtype`): dtype of the mass tensor.
        device (`torch.device`): device of the mass/mask tensors.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(
        self,
        mass: List[float],
        dtype: torch.dtype,
        device: torch.device,
        filter_value: float = -math.inf,
        min_tokens_to_keep: int = 1,
    ):
        self.mass = mass
        self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)

        # 1 is a special value that disables typical_p warping for this member of the batch
        disabled = [x == 1.0 for x in mass]
        if any(disabled):
            self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)
        else:
            self.disabled_mask = None

        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # calculate entropy
        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
        p = torch.exp(normalized)
        ent = -(normalized * p).nansum(-1, keepdim=True)

        # shift and sort
        shifted_scores = torch.abs((-normalized) - ent)
        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
        sorted_logits = scores.gather(-1, sorted_indices)
        probs = sorted_logits.softmax(dim=-1)
        # Row-by-row cumsum; the original notes this is way faster for some reason.
        for i in range(probs.shape[0]):
            probs[i] = probs[i].cumsum(dim=-1)

        # Remove tokens with cumulative mass above the threshold
        last_ind = (probs < self.mass_tensor).sum(dim=1)
        # If rounding keeps the cumulative sum below `mass` at every position, the count
        # equals the vocab size and the gather below would index out of range. Clamp into
        # [0, vocab - 1], as transformers' TypicalLogitsWarper does.
        last_ind.clamp_(min=0, max=scores.shape[-1] - 1)
        if self.disabled_mask is not None:
            last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)

        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
            1, last_ind.view(-1, 1)
        )
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )

        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)

        return warped_scores

    def filter(self, indices):
        """Keep rows ``indices``; return None when every remaining row has mass 1.0."""
        self.mass = [self.mass[i] for i in indices]
        disabled = [x == 1.0 for x in self.mass]

        if not all(disabled):
            self.mass_tensor = self.mass_tensor[indices]

            if self.disabled_mask is not None:
                self.disabled_mask = (
                    self.disabled_mask[indices] if any(disabled) else None
                )

            return self

        return None
class HeterogeneousProcessorWrapper(LogitsProcessor):
    r"""
    Wrapper applying row-specific, non-heterogeneous logit warpers/processors.

    Args:
        processors (`Dict[int, LogitsProcessor]`):
            Maps a batch row index to the processor applied to that row only.
    """

    def __init__(
        self,
        processors: Dict[int, LogitsProcessor],
    ):
        self.processors = processors

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        for row, processor in self.processors.items():
            # Slices keep the batch dimension so processors see a batch of one.
            row_slice = slice(row, row + 1)
            scores[row_slice] = processor(input_ids[row_slice], scores[row_slice])
        return scores

    def filter(self, indices):
        """Re-key processors to the new row positions; return None if none remain."""
        remapped = {
            new_row: self.processors[old_row]
            for new_row, old_row in enumerate(indices)
            if old_row in self.processors
        }
        if not remapped:
            return None
        self.processors = remapped
        return self
class GrammarLogitProcessor(LogitsProcessor):
    """Constrain logits to the tokens allowed by a compiled grammar FSM.

    A grammar state of -1 means "no constraint": logits are returned unchanged and the
    state is never advanced.
    """

    fsm_state: DefaultDict[int, int]
    fsm: RegexFSM

    def __init__(self, tokenizer, device, grammar, grammar_type):
        self.device = device
        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
            grammar_type, grammar, self.tokenizer
        )

    def __call__(
        self,
        logits: torch.Tensor,
        fsm_grammar_state: int,
    ):
        if fsm_grammar_state == -1 or self.fsm is None:
            return logits
        allowed = self.fsm.allowed_token_ids(fsm_grammar_state)
        # -inf everywhere except the allowed token columns; a new tensor is returned.
        mask = torch.full_like(logits, -math.inf)
        mask[:, allowed] = 0
        return logits + mask

    def advance(self, next_token_id, fsm_grammar_state):
        return GrammarLogitProcessor._advance(
            next_token_id, fsm_grammar_state, self.fsm
        )

    @staticmethod
    def _advance(next_token_id, fsm_grammar_state, fsm):
        if fsm_grammar_state == -1:
            return fsm_grammar_state
        return fsm.next_state(fsm_grammar_state, next_token_id)

    # TODO: move grammar compilation into the router
    @staticmethod
    @lru_cache(maxsize=32, typed=True)
    def _cached_compile_fsm(grammar_type, schema, tokenizer):
        start_time = time.time()
        if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
            schema = build_regex_from_schema(schema)
        # GRAMMAR_TYPE_REGEX (and anything else) is used as a regex unchanged.
        fsm = RegexFSM(schema, tokenizer)
        logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
        return fsm

    @staticmethod
    @lru_cache(maxsize=32, typed=True)
    def _cached_adapt_tokenizer(tokenizer):
        """Adapt tokenizer to work with the FSM.

        Adds the ``vocabulary``, ``special_tokens`` and ``convert_token_to_string``
        attributes expected by Outlines. ``convert_token_to_string`` re-adds the
        leading space that Llama-style tokenizers drop for SentencePiece-prefixed
        tokens and ``<0x20>``. The tokenizer is modified in place and returned.
        """
        start_time = time.time()
        tokenizer.vocabulary = tokenizer.get_vocab()
        tokenizer.special_tokens = set(tokenizer.all_special_tokens)

        def convert_token_to_string(token: str) -> str:
            from transformers.file_utils import SPIECE_UNDERLINE

            text = tokenizer.convert_tokens_to_string([token])
            # A hack to handle missing spaces to HF's Llama tokenizers
            needs_space = token.startswith(SPIECE_UNDERLINE) or token == "<0x20>"
            return " " + text if needs_space else text

        tokenizer.convert_token_to_string = convert_token_to_string
        logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
        return tokenizer
class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
    """Grammar-constrained logit masking for a batch where each row may carry
    its own grammar, or none at all.

    ``self.fsms[i]`` is the compiled FSM for batch row ``i``, or ``None`` when
    that request has no grammar (empty grammar string).
    """

    def __init__(self, tokenizer, device, grammars, grammar_types):
        self.device = device
        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
        self.fsms = []
        for grammar, grammar_type in zip(grammars, grammar_types):
            if len(grammar) == 0:
                # Empty grammar string: this row is unconstrained.
                self.fsms.append(None)
                continue
            fsm = GrammarLogitProcessor._cached_compile_fsm(
                grammar_type, grammar, self.tokenizer
            )
            self.fsms.append(fsm)

    def __call__(
        self,
        logits: torch.Tensor,
        fsm_grammar_states: List[int],
    ):
        """Mask disallowed tokens row by row, modifying ``logits`` in place.

        Rows without a grammar are left unchanged. Every other row gets
        ``-inf`` on each token its FSM does not allow from
        ``fsm_grammar_states[i]``. Returns ``logits``.

        NOTE(review): unlike GrammarLogitProcessor.__call__, a state of -1 is
        passed straight to ``allowed_token_ids`` here rather than skipped.
        """
        mask = torch.full_like(logits, -math.inf)
        for i in range(logits.shape[0]):
            fsm = self.fsms[i]
            if fsm is None:
                continue
            allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
            mask[i, allowed_tokens] = 0
            logits[i] += mask[i]
        return logits

    def advance_batch(self, next_token_ids, fsm_grammar_states):
        """Return the next FSM state for every row of the batch.

        Rows without a grammar keep their current state, exactly as
        ``advance_at_index`` does. Calling the FSM of such a row would
        dereference ``None``.
        """
        return [
            self.advance_at_index(next_token_ids[i], fsm_grammar_states[i], i)
            for i in range(len(next_token_ids))
        ]

    def advance_at_index(self, next_token_id, fsm_grammar_state, index):
        """Return the next FSM state for row ``index`` (unchanged if it has no grammar)."""
        if self.fsms[index] is None:
            return fsm_grammar_state
        return GrammarLogitProcessor._advance(
            next_token_id, fsm_grammar_state, self.fsms[index]
        )

    def filter(self, indices):
        """Keep only the FSMs of the rows listed in ``indices`` (in that order); returns ``self``."""
        new_fsms = []
        for i in indices:
            new_fsms.append(self.fsms[i])
        self.fsms = new_fsms
        return self
================================================
FILE: backends/gaudi/server/text_generation_server/utils/merges/strategies.py
================================================
import copy
from abc import ABC
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union
from text_generation_server.utils.merges.utils import (
calculate_majority_sign_mask,
disjoint_merge,
prune,
)
import torch
if TYPE_CHECKING:
from text_generation_server.adapters.lora import LoraConfig
from text_generation_server.utils.adapter import ModuleMap
class AdapterParameters:
    """Plain container describing which adapters to merge and how.

    Values are stored exactly as given; no validation happens here.
    """

    def __init__(
        self, adapter_ids, weights, merge_strategy, density, majority_sign_method
    ):
        (
            self.adapter_ids,
            self.weights,
            self.merge_strategy,
            self.density,
            self.majority_sign_method,
        ) = (adapter_ids, weights, merge_strategy, density, majority_sign_method)
def _apply_weights(
tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
) -> torch.Tensor:
if isinstance(tensors, torch.Tensor):
t = tensors
else:
t = torch.stack(tensors, dim=0)
# element-wise weighting of each task tensor
# need to unsqueeze weights to match task tensor dimensions
# for multiplication to apply element-wise
while len(t.shape) > len(w.shape):
w = w.unsqueeze(-1)
return t * w
class MergeStrategy(ABC):
    """Interface for combining several adapters' task tensors into one tensor.

    Subclasses override ``merge``; the base implementation always raises
    ``NotImplementedError``.
    """

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        """Return the merge of ``task_tensors`` (one per adapter) using ``weights``."""
        raise NotImplementedError
class LinearMerge(MergeStrategy):
    """Plain weighted sum of the task tensors (no pruning, no sign election)."""

    def __init__(self, **kwargs):
        # Options such as density / majority_sign_method are accepted so that
        # every strategy can be built from the same config dict, but are unused.
        del kwargs

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        """Return ``sum_i weights[i] * task_tensors[i]``."""
        return _apply_weights(task_tensors, weights).sum(dim=0)
class TiesMerge(MergeStrategy):
    """TIES merge: magnitude-prune each task tensor, elect a sign per element,
    then average the weighted values that agree with the elected sign."""

    def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
        self.density = density
        self.majority_sign_method = majority_sign_method

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        pruned = torch.stack(
            [prune(t, self.density, method="magnitude") for t in task_tensors],
            dim=0,
        )
        # The sign vote is taken on the pruned tensors *before* weighting.
        sign_mask = calculate_majority_sign_mask(
            pruned, method=self.majority_sign_method
        )
        return disjoint_merge(_apply_weights(pruned, weights), sign_mask)
class DareLinearMerge(MergeStrategy):
    """DARE + linear: randomly drop entries of each task tensor (rescale
    requested), then take the weighted sum."""

    def __init__(self, density: float, **kwargs):
        self.density = density

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        sparse = [
            prune(t, self.density, method="random", rescale=True)
            for t in task_tensors
        ]
        return _apply_weights(sparse, weights).sum(dim=0)
class DareTiesMerge(MergeStrategy):
    """DARE + TIES: randomly drop entries (rescale requested), elect a sign per
    element, then average the weighted values agreeing with that sign."""

    def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
        self.density = density
        self.majority_sign_method = majority_sign_method

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        sparse = torch.stack(
            [
                prune(t, self.density, method="random", rescale=True)
                for t in task_tensors
            ],
            dim=0,
        )
        # Sign election uses the unweighted sparse tensors.
        sign_mask = calculate_majority_sign_mask(
            sparse, method=self.majority_sign_method
        )
        return disjoint_merge(_apply_weights(sparse, weights), sign_mask)
# Lower-case strategy name -> MergeStrategy implementation.
strategy_registry: Dict[str, Type[MergeStrategy]] = dict(
    linear=LinearMerge,
    ties=TiesMerge,
    dare_linear=DareLinearMerge,
    dare_ties=DareTiesMerge,
)
def merge_adapters(
    adapters: List[Tuple["ModuleMap", "LoraConfig"]],
    merge_params: AdapterParameters,
) -> Tuple["ModuleMap", "LoraConfig"]:
    """Merge several LoRA adapters into a single ``(ModuleMap, LoraConfig)``.

    Args:
        adapters: ``(module_map, lora_config)`` per adapter. ``module_map`` is
            ``weight_name -> k -> (tensor, param_name)``.
        merge_params: supplies the per-adapter ``weights`` (all ones when
            empty/None) and ``density``.

    Returns:
        The merged module map (``weight_name -> k -> (merged_tensor,
        param_name)``) and a LoraConfig whose ``target_modules`` is the union
        of all inputs.

    Raises:
        ValueError: from ``_validate_lora_configs`` when the configs are incompatible.

    NOTE(review): the strategy is hard-coded to "linear" and the sign method
    to "total"; ``merge_params.merge_strategy`` and
    ``merge_params.majority_sign_method`` are currently ignored.
    """
    # strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower()
    strategy_name = "linear"

    weights = merge_params.weights
    # NOTE(review): ``not weights`` assumes a list/None; a multi-element
    # tensor here would raise on truth-value testing.
    if not weights:
        weights = torch.ones(len(adapters))
    else:
        weights = torch.tensor(weights)

    merge_config = {
        "density": merge_params.density,
        # "majority_sign_method": MajoritySignMethodEnum.Name(
        #     merge_params.majority_sign_method
        # ).lower(),
        "majority_sign_method": "total",
    }
    merge_strategy = strategy_registry[strategy_name](**merge_config)

    module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict(
        lambda: defaultdict(lambda: defaultdict(list))
    )
    lora_configs = []
    # weight_name -> indices of the adapters that contain it, so each merge
    # only uses the weights of adapters that actually provide the tensor.
    weight_name_to_adapter_idx = defaultdict(list)

    # input is list of (module_map, lora_config) tuples
    # convert into dict[k][param_name] -> list of tensors
    for idx, (module_map, lora_config) in enumerate(adapters):
        for weight_name, data in module_map.items():
            weight_name_to_adapter_idx[weight_name].append(idx)
            for k, (param_data, param_name) in data.items():
                module_maps[weight_name][k][param_name].append(param_data)
        lora_configs.append(lora_config)

    # validate lora configs are compatible
    _validate_lora_configs(lora_configs)

    # merge tensors for each module such that we have a single ModuleMap:
    # dict[k] -> merged tensor
    merged_module_map: "ModuleMap" = defaultdict(dict)
    for weight_name, data in module_maps.items():
        indices = weight_name_to_adapter_idx[weight_name]
        param_weights = weights[indices]
        for k, param_data in data.items():
            # If several param_names exist under the same k, the last one wins.
            for param_name, tensors in param_data.items():
                merged_tensor = merge_strategy.merge(tensors, param_weights)
                merged_module_map[weight_name][k] = (merged_tensor, param_name)

    # merge lora configs
    merged_lora_config = _merge_lora_configs(lora_configs)

    return merged_module_map, merged_lora_config
def _validate_lora_configs(lora_configs: List["LoraConfig"]):
# check that all configs have the same rank
ranks = set(lora_config.r for lora_config in lora_configs)
if len(ranks) > 1:
raise ValueError(
f"unable to merge adapters, lora configs have different ranks: {ranks}"
)
if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs):
raise ValueError(
"unable to merge adapters, lora configs have no target modules"
)
def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig":
merged_lora_config = copy.copy(lora_configs[0])
# merge target modules as a union operation
merged_target_modules = sorted(
set(
module
for lora_config in lora_configs
for module in lora_config.target_modules
)
)
merged_lora_config.target_modules = merged_target_modules
return merged_lora_config
================================================
FILE: backends/gaudi/server/text_generation_server/utils/merges/utils.py
================================================
# coding=utf-8
# From: https://github.com/huggingface/peft/pull/1364
# Copyright 2024-present the HuggingFace Inc. team.
# Modifications by Predibase, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal
import torch
def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:
    """
    Keep the ``int(density * numel)`` largest-magnitude entries of `tensor` and
    zero out the rest.

    Args:
        tensor (`torch.Tensor`):The tensor to prune.
        density (`float`):The fraction of values to preserve. Should be in [0,1].
    """
    flat_abs = tensor.abs().reshape(-1)
    keep = int(density * flat_abs.shape[0])
    keep_idx = torch.topk(flat_abs, k=keep, largest=True).indices
    mask = torch.zeros_like(tensor).reshape(-1)
    mask[keep_idx] = 1
    return tensor * mask.reshape(tensor.shape)
def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
    """
    Randomly keep each entry of `tensor` with probability `density` and zero
    out the rest (DARE-style drop).

    Args:
        tensor (`torch.Tensor`):The tensor to prune.
        density (`float`):The probability of keeping each value. Should be in [0,1].
        rescale (`bool`):Whether to divide the kept values by `density` so the
            expected value of every entry matches the original tensor. Skipped
            when `density` is 0, because every entry is already dropped and
            dividing would produce NaN.
    """
    mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
    pruned_tensor = tensor * mask
    if rescale and density > 0:
        # torch.div is out-of-place: the result must be assigned back.
        pruned_tensor = torch.div(input=pruned_tensor, other=density)
    return pruned_tensor
def prune(
    tensor: torch.Tensor,
    density: float,
    method: Literal["magnitude", "random"],
    rescale: bool = False,
) -> torch.Tensor:
    """
    Prune the values of task tensors based on the `method`.

    Args:
        tensor (`torch.Tensor`):The tensor to prune.
        density (`float`):The fraction of values to preserve. Values >= 1 return
            `tensor` unchanged; negative values raise `ValueError`.
        method (`str`):The method to use to prune. Should be one of ["magnitude", "random"].
        rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.
            Only used by the "random" method.

    Raises:
        ValueError: if `density` is negative or `method` is unknown.
    """
    if density >= 1:
        return tensor
    elif density < 0:
        raise ValueError(f"Density should be >= 0, got {density}")
    if method == "magnitude":
        return magnitude_based_pruning(tensor, density)
    elif method == "random":
        return random_pruning(tensor, density, rescale=rescale)
    else:
        raise ValueError(f"Unknown method {method}")
def calculate_majority_sign_mask(
    tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"
):
    """
    Boolean mask of the entries whose sign matches the per-element majority
    sign across the task tensors (stacked on dimension 0).

    "total" votes with magnitudes (the sign of the summed values);
    "frequency" counts signs. Ties (including all-zero columns) elect +1.

    Args:
        tensor (`torch.Tensor`):The tensor to get the mask from.
        method (`str`):The method to use to get the mask. Should be one of ["total", "frequency"].
    """
    signs = tensor.sign()
    if method == "total":
        votes = (signs * tensor.abs()).sum(dim=0)
    elif method == "frequency":
        votes = signs.sum(dim=0)
    else:
        raise RuntimeError(f'Unimplemented mask method "{method}"')

    elected = torch.where(votes >= 0, 1, -1)
    return signs == elected
def disjoint_merge(task_tensors, majority_sign_mask):
    """Mean over dim 0 of the entries selected by ``majority_sign_mask``.

    Positions selected by no task yield 0 (the divisor is clamped to at least 1).
    """
    kept_count = torch.clamp(majority_sign_mask.sum(dim=0), min=1.0)
    kept_sum = (task_tensors * majority_sign_mask).sum(dim=0)
    return kept_sum / kept_count
================================================
FILE: backends/gaudi/server/text_generation_server/utils/peft.py
================================================
import os
from typing import Union
from loguru import logger
import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
def download_and_unload_peft(model_id, revision, trust_remote_code):
    """Load a PEFT adapter, merge it into its base model and save the result.

    A causal-LM PEFT model is tried first; if that raises, a seq2seq PEFT
    model is loaded instead, and its exception propagates on failure. The
    merged model (safetensors), its config and the base model's tokenizer
    are written into the ``model_id`` directory, which is created if
    missing.
    """
    load_kwargs = dict(
        revision=revision,
        torch_dtype=torch.float16,
        trust_remote_code=trust_remote_code,
        low_cpu_mem_usage=True,
    )

    logger.info("Trying to load a Peft model. It might take a while without feedback")
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    except Exception:
        model = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id, **load_kwargs)
    logger.info("Peft model detected.")

    logger.info("Merging the lora weights.")
    base_model_id = model.peft_config["default"].base_model_name_or_path
    model = model.merge_and_unload()

    os.makedirs(model_id, exist_ok=True)
    cache_dir = model_id
    logger.info(f"Saving the newly created merged model to {cache_dir}")
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id, trust_remote_code=trust_remote_code
    )
    model.save_pretrained(cache_dir, safe_serialization=True)
    model.config.save_pretrained(cache_dir)
    tokenizer.save_pretrained(cache_dir)
def download_peft(
    model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool
):
    """Fetch a PEFT model (causal-LM first, seq2seq as fallback) without merging it.

    The loaded model is discarded; the call is made only for its download side effect.
    """
    load_kwargs = dict(
        revision=revision,
        torch_dtype=torch.float16,
        trust_remote_code=trust_remote_code,
        low_cpu_mem_usage=True,
    )
    try:
        _ = AutoPeftModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    except Exception:
        _ = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id, **load_kwargs)
    logger.info("Peft model downloaded.")
================================================
FILE: backends/gaudi/server/text_generation_server/utils/prefill_chunking.py
================================================
from typing import Optional
# Process-wide settings; None until the corresponding setter is called.
SUPPORT_CHUNKING: Optional[bool] = None
MAX_PREFILL_TOKENS: Optional[int] = None


def set_support_chunking(support_chunking: bool):
    """Store the process-wide chunked-prefill support flag."""
    global SUPPORT_CHUNKING
    SUPPORT_CHUNKING = support_chunking


def get_support_chunking() -> bool:
    """Return the flag last stored by ``set_support_chunking`` (None if never set)."""
    return SUPPORT_CHUNKING


def set_max_prefill_tokens(max_prefill_tokens: int):
    """Store the process-wide maximum number of prefill tokens."""
    global MAX_PREFILL_TOKENS
    MAX_PREFILL_TOKENS = max_prefill_tokens


def get_max_prefill_tokens() -> int:
    """Return the value last stored by ``set_max_prefill_tokens`` (None if never set)."""
    return MAX_PREFILL_TOKENS
================================================
FILE: backends/gaudi/server/text_generation_server/utils/quantization.py
================================================
import json
import os
from dataclasses import dataclass
from typing import Optional, List
from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import (
WeightsLoader,
)
# TODO: Split this config to have a single config type per quant method
@dataclass
class _QuantizerConfig:
    """Quantization settings produced by ``_get_quantizer_config`` for
    non-fbgemm-fp8 checkpoints (also used when the model is not quantized)."""

    # Weight bit width (defaults to 4 when no config provides it).
    bits: int
    # ``checkpoint_format`` from ``quantization_config``; None when absent.
    checkpoint_format: Optional[str]
    # ``desc_act`` flag from the config (False when absent).
    desc_act: bool
    # Quantization group size; -1 when not specified.
    groupsize: int
    # Quantization method name, e.g. "gptq" (default) or "awq".
    quant_method: str
    # Symmetric quantization; derived as ``not zero_point`` for AWQ-style configs.
    sym: bool
    # ``weight_block_size`` from ``quantization_config``; forwarded to the FP8 loader.
    weight_block_size: Optional[List[int]]
    # Module names listed in the config as not to be converted.
    modules_to_not_convert: List[str]
@dataclass
class _FP8QuantizerConfig:
    """Settings for checkpoints whose ``quantization_config.quant_method`` is ``fbgemm_fp8``."""

    # Read from ``quantization_config.activation_scale_ub`` and forwarded to the
    # FP8 loader; presumably an upper bound on activation scales (``_ub``) — confirm.
    activation_scale_ub: float
def _get_config_json(model_id: str, revision: Optional[str], filename: str):
if os.path.exists(
os.path.join(
model_id,
)
):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(model_id, filename=filename, revision=revision)
with open(filename, "r") as f:
return json.load(f)
# We should probably do this with Pydantic JSON deserialization,
# but for now we'll stay close to the old _set_gptq_params.
def _get_quantizer_config(model_id, revision):
    """Read quantization settings for ``model_id`` from its config files.

    Sources are tried in order, each one only if the previous one raised:

    1. ``config.json`` -> ``quantization_config`` (an ``fbgemm_fp8`` method
       returns an ``_FP8QuantizerConfig`` immediately);
    2. ``quantize_config.json`` (``bits`` / ``group_size`` keys);
    3. ``quant_config.json`` (``w_bit`` / ``q_group_size`` keys).

    Any exception (missing file, missing key, download error) moves on to the
    next source; if all of them fail the defaults below are used. Otherwise a
    ``_QuantizerConfig`` is returned.

    NOTE(review): fields are assigned one at a time inside each try, so a
    source that fails part-way leaves the fields it already set in place
    when falling back to the next source.
    """
    # Defaults used when no config file provides a value.
    bits = 4
    groupsize = -1
    quant_method = "gptq"
    checkpoint_format = None
    sym = False
    desc_act = False
    weight_block_size = None
    modules_to_not_convert = []

    filename = "config.json"
    try:
        data = _get_config_json(model_id, revision, filename)
        # FP8 config
        if data["quantization_config"]["quant_method"] == "fbgemm_fp8":
            return _FP8QuantizerConfig(
                activation_scale_ub=data["quantization_config"]["activation_scale_ub"]
            )
        weight_block_size = data["quantization_config"].get("weight_block_size", None)

        if "zero_point" in data["quantization_config"]:
            sym = not data["quantization_config"]["zero_point"]
            quant_method = "awq"
        elif "sym" in data["quantization_config"]:
            sym = data["quantization_config"]["sym"]

        bits = data["quantization_config"]["bits"]
        groupsize = data["quantization_config"]["group_size"]
        # Order is important here, desc_act is missing on some real models
        quant_method = data["quantization_config"]["quant_method"]
        checkpoint_format = data["quantization_config"].get("checkpoint_format")
        desc_act = data["quantization_config"].get("desc_act", False)
        modules_to_not_convert = data["quantization_config"].get(
            "modules_to_not_convert", []
        )
        # The key may be present with an explicit null value.
        if modules_to_not_convert is None:
            modules_to_not_convert = []
    except Exception:
        filename = "quantize_config.json"
        try:
            data = _get_config_json(model_id, revision, filename)
            bits = data["bits"]
            groupsize = data["group_size"]

            if "zero_point" in data:
                sym = not data["zero_point"]
                quant_method = "awq"
            elif "sym" in data:
                sym = data["sym"]

            desc_act = data["desc_act"]
            if "version" in data and data["version"] == "GEMM":
                quant_method = "awq"
        except Exception:
            filename = "quant_config.json"
            try:
                data = _get_config_json(model_id, revision, filename)
                bits = data["w_bit"]
                groupsize = data["q_group_size"]
                desc_act = data["desc_act"]
                if "version" in data and data["version"] == "GEMM":
                    quant_method = "awq"
            except Exception:
                # No usable config: keep the defaults / partial values.
                pass

    return _QuantizerConfig(
        bits=bits,
        groupsize=groupsize,
        quant_method=quant_method,
        checkpoint_format=checkpoint_format,
        sym=sym,
        desc_act=desc_act,
        weight_block_size=weight_block_size,
        modules_to_not_convert=modules_to_not_convert,
    )
def get_loader(
    quantize: Optional[str], model_id: str, revision: Optional[str]
) -> WeightsLoader:
    """Build the weights loader matching the ``quantize`` option.

    Args:
        quantize: quantization method requested by the user, or ``None``.
        model_id: local model directory or Hugging Face Hub id.
        revision: Hub revision, if any.

    Returns:
        A ``WeightsLoader`` for the requested method.

    Raises:
        ValueError: for ``awq``/``gptq`` when the checkpoint's config is an
            fbgemm FP8 config, or when ``quantize`` is not supported.
    """
    if quantize == "compressed-tensors":
        config = _get_config_json(model_id, revision, "config.json")
        from text_generation_server.layers.compressed_tensors import (
            CompressedTensorsLoader,
        )

        return CompressedTensorsLoader(config)

    quantizer_config = _get_quantizer_config(model_id, revision)
    if quantize in {"awq", "gptq"}:
        from text_generation_server.layers.gptq import GPTQWeightsLoader

        # TODO: improve check once we have one config type per quantize value
        if not isinstance(quantizer_config, _QuantizerConfig):
            raise ValueError(
                f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
            )

        return GPTQWeightsLoader(
            bits=quantizer_config.bits,
            desc_act=quantizer_config.desc_act,
            groupsize=quantizer_config.groupsize,
            quant_method=quantizer_config.quant_method,
            quantize=quantize,
            sym=quantizer_config.sym,
            modules_to_not_convert=quantizer_config.modules_to_not_convert,
        )
    elif quantize == "fp8" or quantize is None:
        from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

        # _get_quantizer_config returns either a _QuantizerConfig or an
        # _FP8QuantizerConfig, and each defines only one of the two options
        # below. Read each option only from the type that has it; otherwise
        # fbgemm-fp8 checkpoints fail with an AttributeError on
        # ``weight_block_size``.
        activation_scale_ub = None
        weight_block_size = None
        if isinstance(quantizer_config, _FP8QuantizerConfig):
            activation_scale_ub = quantizer_config.activation_scale_ub
        else:
            weight_block_size = quantizer_config.weight_block_size

        return HybridFP8UnquantLoader(
            activation_scale_ub,
            to_fp8=quantize == "fp8",
            weight_block_size=weight_block_size,
        )
    else:
        raise ValueError(f"Unknown quantization method: {quantize}")
================================================
FILE: backends/gaudi/server/text_generation_server/utils/segments.py
================================================
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/utils/segments.py
# License: Apache License Version 2.0, January 2004
from typing import List, Tuple, Union
import torch
def find_segments(
    adapter_indices: Union[torch.Tensor, List[int]],
) -> Tuple[List[int], List[int]]:
    """Split adapter indices into runs of equal consecutive values.

    Returns ``(boundaries, run_values)``. ``boundaries`` starts with 0 and has
    one more entry than ``run_values``. Run ``k`` covers positions
    ``boundaries[k]:boundaries[k + 1]``, all using adapter ``run_values[k]``.
    Empty input yields ``([0], [])``.
    """
    if isinstance(adapter_indices, torch.Tensor):
        # Convert once on the host; per-element .item() on a device tensor is slow.
        adapter_indices = adapter_indices.cpu().tolist()

    boundaries = [0]
    run_values = []
    total = len(adapter_indices)
    for pos in range(1, total + 1):
        # A run ends at the end of the input or where the value changes.
        if pos == total or adapter_indices[pos] != adapter_indices[pos - 1]:
            boundaries.append(pos)
            run_values.append(adapter_indices[pos - 1])
    return boundaries, run_values
class SegmentConcatBuilder:
    """Combine per-batch adapter segments into the segments of the concatenated batch.

    Each ``concat`` call takes one batch's segment boundaries (a tensor
    starting at 0) and the adapter index of each segment, in the format
    produced by ``find_segments``. ``build`` returns the combined
    ``(boundaries_tensor, segment_indices)``.
    """

    def __init__(self):
        self.adapter_segment_indices = []
        self.adapter_segment_tensors = []

    def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
        # Update adapter segments
        if self.adapter_segment_tensors:
            if len(segment_indices) == 0:
                # An empty batch contributes no segments. Skipping it avoids
                # indexing segment_indices[0] below and keeps the last stored
                # tensor non-empty for the next batch's offset.
                return

            # Because we have already processed at least one batch, remove the 0 start index
            # from this batch denoting the beginning of the segment, then offset all segment
            # positions by the value of the last segment in the previous batch to account for
            # the concatenation.
            adapter_segments = (
                adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]
            )

            if (
                self.adapter_segment_indices
                and self.adapter_segment_indices[-1] == segment_indices[0]
            ):
                # If the last segment in the previous batch is the same as the first segment in this batch,
                # then we merge them together into a single segment. In effect, this means removing it from
                # the segment indices of this batch, and extending the segment span by removing the segment
                # end index from the previous batch.
                segment_indices = segment_indices[1:]
                self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]

        self.adapter_segment_indices.extend(segment_indices)
        self.adapter_segment_tensors.append(adapter_segments)

    def build(self) -> Tuple[torch.Tensor, List[int]]:
        """Return the concatenated boundaries tensor and the per-segment adapter indices."""
        return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices
================================================
FILE: backends/gaudi/server/text_generation_server/utils/sgmv.py
================================================
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/utils/sgmv.py
# License: Apache License Version 2.0, January 2004
import os
import warnings
from functools import lru_cache
from typing import List, Tuple
import torch
import torch.nn.functional as F
try:
    import punica_kernels as _kernels

    # Kernels imported; SGMV can still be switched off by setting the
    # DISABLE_SGMV environment variable to any non-empty value.
    HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
except ImportError:
    warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
    # NOTE: helpers below that call _kernels.* must not run in this case.
    _kernels = None
    HAS_SGMV = False

# Minimum rank (per shard) that pad_rank pads small LoRA ranks up to.
MIN_SGMV_RANK = 8
# Inclusive rank range handled by the custom sgmv_shrink kernel.
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128
# Ranks above the minimum are padded up to a multiple of this.
SGMV_BLOCK_SIZE = 16
# NOTE(review): not referenced in this file; presumably the maximum rank
# supported by the BGMV kernel — confirm.
BGMV_MAX_RANK = 64


def has_sgmv() -> bool:
    """Return True when the Punica kernels imported and DISABLE_SGMV is unset/empty."""
    return HAS_SGMV
def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    """Zero-pad dimension ``dim`` of ``t`` to a rank the SGMV kernels accept.

    Tensor parallelism divides the effective rank by ``world_size``, so the
    minimum is ``MIN_SGMV_RANK * world_size``. Ranks at or below it are
    padded to exactly that value. Larger ranks are rounded up to the next
    multiple of ``SGMV_BLOCK_SIZE``. Padding is appended at the end of
    ``dim``. ``t`` is returned unchanged when SGMV is unavailable or no
    padding is needed.
    """
    if not has_sgmv():
        return t

    floor_rank = MIN_SGMV_RANK * world_size
    rank = t.size(dim)
    if rank <= floor_rank:
        wanted = floor_rank
    else:
        wanted = (rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
    if wanted == rank:
        return t

    # F.pad takes (left, right) pairs starting from the LAST dimension, so
    # dimension ``dim`` owns pair number (ndim - dim - 1); fill its right side.
    # See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
    pad_spec = [0] * (2 * t.dim())
    pad_spec[2 * (t.dim() - dim - 1) + 1] = wanted - rank
    return F.pad(t, tuple(pad_spec), mode="constant", value=0.0)
def use_cutlass_shrink(lora_rank: int) -> bool:
    """True when ``lora_rank`` is below the custom sgmv_shrink range, so the
    CUTLASS kernel is used for the shrink step."""
    return MIN_RANK_CUSTOM > lora_rank
def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
    """Swap the first two dims of ``t`` when ``rank`` is in the custom-kernel
    range [MIN_RANK_CUSTOM, MAX_RANK_CUSTOM]; otherwise return ``t`` as-is."""
    in_custom_range = MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM
    return t.transpose(0, 1) if in_custom_range else t
# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py
def add_lora_sgmv_cutlass(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.Tensor,
    s_end: torch.Tensor,
    layer_idx: int,
    lora_rank: int,
):
    """
    Semantics:
        y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
            Weight matrix shape: `[num_layers, R, H1]`.
        wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
            Weight matrix shape: `[num_layers, R, H2]`.
        s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
        s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
        layer_idx: Layer index of the weight matrices.
        lora_rank: LoRA rank R. Ranks in [MIN_RANK_CUSTOM, MAX_RANK_CUSTOM]
            use the custom shrink kernel; others take the legacy
            CUTLASS-only path.
    """
    if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
        # Custom SGMV shrink only supports rank 16, 32, 64, 128
        _add_lora_sgmv_cutlass_legacy(
            y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank
        )
        return

    # Fixed 8 MiB scratch buffer for the custom shrink kernel.
    tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
    # Scratch buffer for the CUTLASS expand kernel, sized by the kernel for S segments.
    tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
    # Intermediate x @ A^T, shape (B, lora_rank).
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
def _add_lora_sgmv_cutlass_legacy(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
):
    """Fallback used by add_lora_sgmv_cutlass for ranks outside the custom-shrink range.

    Both the shrink step (``x`` -> ``v``) and the expand step (``v`` -> ``y``)
    go through ``_kernels.sgmv_cutlass``, sharing one scratch buffer sized by
    the kernel for ``wa_ptr.size(0)`` segments.
    """
    tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
    # Intermediate of shape (B, lora_rank).
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
@lru_cache(maxsize=1)
def get_tmp_tensor(device: torch.device) -> torch.Tensor:
    """Return an 8 MiB uint8 scratch buffer on ``device``.

    Memoised with maxsize=1, so only the most recently requested device's
    buffer is kept.
    """
    scratch_bytes = 8 * 1024 * 1024
    return torch.empty((scratch_bytes,), dtype=torch.uint8, device=device)
@lru_cache(maxsize=32)
def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
    """Return a uint8 scratch buffer sized by the CUTLASS kernel for ``size`` segments.

    Memoised per ``(size, device)``. Requires the Punica kernels to be importable.
    """
    return torch.empty(
        (_kernels.sgmv_cutlass_tmp_size(size),), dtype=torch.uint8, device=device
    )
def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor:
    """Return an uninitialised uint8 buffer of ``size`` bytes without querying the kernels."""
    shape = (size,)
    return torch.empty(shape, dtype=torch.uint8, device=device)
def get_tmp_expand_size(size: int) -> int:
    """Return the scratch-buffer size (bytes) the CUTLASS kernel needs for ``size`` segments."""
    needed = _kernels.sgmv_cutlass_tmp_size(size)
    return needed
def get_tmp_tensors(
    nsegments: int, lora_rank: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the ``(shrink, expand)`` scratch buffers for one LoRA pass.

    - SGMV unavailable: ``_kernels`` may be ``None``, so the buffer is
      allocated without querying the kernels and shared by both steps.
    - SGMV available, rank below the custom-shrink range: both steps use
      CUTLASS and share one buffer sized for ``nsegments``.
    - SGMV available otherwise: the custom shrink gets the fixed 8 MiB
      buffer and the CUTLASS expand gets the size-specific buffer.
    """
    if not has_sgmv():
        # get_tmp_tensor_for_size would call _kernels.sgmv_cutlass_tmp_size,
        # which fails when the Punica kernels could not be imported.
        tmp = get_tmp_tensor_for_size_no_kernels(nsegments, device)
        return tmp, tmp

    if use_cutlass_shrink(lora_rank):
        tmp = get_tmp_tensor_for_size(nsegments, device)
        return tmp, tmp

    return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device)
def lora_a_sgmv_cutlass(
    x: torch.Tensor,
    tmp: torch.Tensor,
    wa_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
) -> torch.Tensor:
    """Run the LoRA-A (shrink) step into a new zero-initialised ``(x.size(0), lora_rank)`` tensor.

    Ranks in [MIN_RANK_CUSTOM, MAX_RANK_CUSTOM] use the custom
    ``sgmv_shrink`` kernel; all other ranks use ``sgmv_cutlass``.
    """
    out = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    kernel = (
        _kernels.sgmv_shrink
        if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM
        else _kernels.sgmv_cutlass
    )
    kernel(out, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    return out
def lora_b_sgmv_cutlass(
    y: torch.Tensor,
    v: torch.Tensor,
    tmp: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
):
    """Run the LoRA-B (expand) step from ``v`` into ``y`` with the CUTLASS SGMV kernel.

    Returns nothing; the kernel writes its result into ``y``.
    """
    kernel_args = (y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
    _kernels.sgmv_cutlass(*kernel_args)
def add_lora_a_bgmv(
    v: torch.Tensor,
    x: torch.Tensor,
    wa_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    """Shrink step of the BGMV LoRA path (``x`` -> ``v``).

    Together with ``add_lora_b_bgmv`` this implements, per row ``i``::

        y[i] += (
            x[i].unsqueeze(0)
            @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
            @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
            * scale
        ).squeeze(0)

    Args:
        v: Shape ``[B, R]``. Temporary vector receiving the shrink result.
        x: Shape ``[B, H1]``. Input vectors.
        wa_T_all: Shape ``[None, L, R, H1]``. All of the transposed LoRA A matrices.
        indicies: Shape ``[B]``. Index of the LoRA weights for each row.
        layer_idx: Layer index of the LoRA weights.

    The kernel is always called with a scale of 1.0.
    """
    scale = 1.0
    _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, scale)


def add_lora_b_bgmv(
    y: torch.Tensor,
    v: torch.Tensor,
    wb_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    """Expand step of the BGMV LoRA path (``v`` -> ``y``); see ``add_lora_a_bgmv``.

    Args:
        y: Shape ``[B, H2]``. Output vectors. Will be changed in-place.
        v: Shape ``[B, R]``. Temporary vector produced by the shrink step.
        wb_T_all: Shape ``[None, L, H2, R]``. All of the transposed LoRA B matrices.
        indicies: Shape ``[B]``. Index of the LoRA weights for each row.
        layer_idx: Layer index of the LoRA weights.

    The kernel is always called with a scale of 1.0.
    """
    scale = 1.0
    _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, scale)
def segmented_matmul(
    y: torch.Tensor,
    x: torch.Tensor,
    w: List[torch.Tensor],
    b: List[torch.Tensor],
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
):
    """Loop fallback for a segmented linear layer, writing into ``y`` in place.

    For each segment ``i`` whose row range ``[s_start[i], s_end[i])`` is
    non-empty, ``y[s_start[i]:s_end[i]]`` is overwritten with
    ``F.linear(x[same rows], w[i], b[i])``. Empty segments are skipped.
    """
    for seg, weight in enumerate(w):
        lo, hi = s_start[seg], s_end[seg]
        if hi - lo <= 0:
            continue
        y[lo:hi] = F.linear(x[lo:hi], weight, b[seg])
================================================
FILE: backends/gaudi/server/text_generation_server/utils/speculate.py
================================================
# Process-wide speculation setting; None until set_speculate is called.
SPECULATE = None


def get_speculate() -> int:
    """Return the value last stored by ``set_speculate`` (None if never set)."""
    return SPECULATE


def set_speculate(speculate: int):
    """Store the process-wide speculation setting returned by ``get_speculate``."""
    global SPECULATE
    SPECULATE = speculate
================================================
FILE: backends/gaudi/server/text_generation_server/utils/tokens.py
================================================
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import re
from typing import List, Optional, Tuple, Set, Union
import torch
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
from text_generation_server.utils.logits_process import (
FrequencyPenaltyLogitsProcessor,
GrammarLogitProcessor,
HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousFrequencyPenaltyLogitsProcessor,
HeterogeneousTemperatureLogitsWarper,
HeterogeneousTopKLogitsWarper,
HeterogeneousTopPLogitsWarper,
HeterogeneousTypicalLogitsWarper,
HeterogeneousGrammarLogitProcessor,
static_warper,
)
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
import os
class NextTokenChooser:
    """Per-request next-token selection.

    Applies, in this order, the optional watermark, repetition-penalty,
    frequency-penalty and grammar processors to the scores, then the optional
    static warper (temperature / top-k / top-p / typical-p), and finally picks
    a token with ``Sampling`` (seeded) or ``Greedy``.
    """

    def __init__(
        self,
        watermark: bool = False,
        temperature: float = 1.0,
        repetition_penalty: float = 1.0,
        frequency_penalty: float = 0.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
        do_sample: bool = False,
        seed: int = 0,
        device: str = "cpu",
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        grammar: str = "",
        grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
        fsm_grammar_state: int = 0,
    ):
        self.watermark_processor = (
            WatermarkLogitsProcessor(device=device) if watermark else None
        )
        # A penalty of 1.0 (or a falsy 0/None) is treated as "off".
        self.repetition_processor = (
            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
            if repetition_penalty and repetition_penalty != 1.0
            else None
        )
        # A frequency penalty of 0.0 (or None) is treated as "off".
        self.frequency_processor = (
            FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
            if frequency_penalty and frequency_penalty != 0.0
            else None
        )
        # An empty grammar string means unconstrained generation.
        self.grammar_processor = (
            GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
            if grammar != ""
            else None
        )
        self.tokenizer = tokenizer

        # True when any warper setting differs from its neutral value.
        has_warpers = (
            (temperature is not None and temperature != 1.0)
            or (top_k is not None and top_k != 0)
            or (top_p is not None and top_p < 1.0)
            or (typical_p is not None and typical_p < 1.0)
        )
        if has_warpers:
            self.static_warper = static_warper(
                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
            )
        else:
            self.static_warper = None

        # Any active warper forces sampling even when do_sample is False.
        sampling = do_sample or has_warpers

        self.choice = Sampling(seed, device) if sampling else Greedy()
        # Current grammar FSM state; updated by advance_grammar.
        self.fsm_grammar_state = fsm_grammar_state
        self.grammar = grammar

    def __call__(self, input_ids, scores):
        """Process ``scores`` and choose the next token.

        Returns ``(next_id, next_logprob)``. ``next_id`` is a ``(1, 1)``
        tensor chosen from the last row of the processed scores.
        ``next_logprob`` is ``log_softmax`` of the processed scores, or the
        second value returned by the static warper when one is configured.
        """
        if self.watermark_processor is not None:
            scores = self.watermark_processor(input_ids, scores)
        if self.repetition_processor is not None:
            scores = self.repetition_processor(input_ids, scores)
        if self.frequency_processor is not None:
            scores = self.frequency_processor(input_ids, scores)
        if self.grammar_processor is not None:
            scores = self.grammar_processor(scores, self.fsm_grammar_state)

        if self.static_warper is None:
            next_logprob = torch.log_softmax(scores, -1)
        else:
            scores, next_logprob = self.static_warper(scores)

        next_id = self.choice(scores[-1]).view(1, 1)

        return next_id, next_logprob

    def advance_grammar(self, next_id: int):
        """Advance the grammar FSM state with ``next_id`` (no-op without a grammar); returns ``self``."""
        if self.grammar_processor is not None:
            self.fsm_grammar_state = self.grammar_processor.advance(
                next_id, self.fsm_grammar_state
            )
        return self

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.NextTokenChooserParameters,
        device: torch.device,
        tokenizer: PreTrainedTokenizerBase,
    ) -> "NextTokenChooser":
        """Build a chooser from protobuf ``NextTokenChooserParameters``.

        NOTE: ``fsm_grammar_state`` is not read from ``pb`` and keeps its default of 0.
        """
        return NextTokenChooser(
            watermark=pb.watermark,
            temperature=pb.temperature,
            repetition_penalty=pb.repetition_penalty,
            frequency_penalty=pb.frequency_penalty,
            top_k=pb.top_k,
            top_p=pb.top_p,
            typical_p=pb.typical_p,
            do_sample=pb.do_sample,
            seed=pb.seed,
            device=device,
            tokenizer=tokenizer,
            grammar=pb.grammar,
            grammar_type=pb.grammar_type,
        )
class StopSequenceCriteria:
def __init__(self, stop_sequence: str):
stop_sequence = re.escape(stop_sequence)
self.regex = re.compile(f"{stop_sequence}$")
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
return True
return False
class StoppingCriteria:
def __init__(
self,
eos_token_ids: Optional[Union[Set[int], int]],
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens: int = 20,
ignore_eos_token: bool = False,
):
if eos_token_ids is None:
eos_token_ids = set()
elif isinstance(eos_token_ids, int):
eos_token_ids = set([eos_token_ids])
elif isinstance(eos_token_ids, set):
eos_token_ids = eos_token_ids
else:
raise RuntimeError(
f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
)
self.eos_token_ids = eos_token_ids
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
if os.getenv("TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN", "false") == "true":
self.ignore_eos_token = True
else:
self.ignore_eos_token = ignore_eos_token
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if isinstance(last_token, torch.Tensor):
last_token = last_token.item()
if not self.ignore_eos_token and last_token in self.eos_token_ids:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
if self.stop_sequence_criterias:
self.current_output += last_output
# There is no need to keep an output that is too long
if len(self.current_output) > 300:
# Slice to -200 to avoid doing it all the time
self.current_output = self.current_output[-200:]
for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output):
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None
@classmethod
def from_pb(
cls,
pb: generate_pb2.StoppingCriteriaParameters,
tokenizer: PreTrainedTokenizerBase,
) -> "StoppingCriteria":
stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
# TODO Hack because eos_token_id cannot be what we want.
eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
return StoppingCriteria(
eos_token_id,
stop_sequence_criterias,
pb.max_new_tokens,
pb.ignore_eos_token,
)
def create_n_gram_speculation(
input_ids: torch.Tensor,
next_ids: torch.Tensor,
accepted_ids: torch.Tensor,
speculate: int,
verbose: bool,
):
# Very trivial approach, find first match in the string.
# This is much less refined than actual n-gram but seems to work
# relatively OK in grounded mode and is by far much faster with
# much less worst case complexity as everything happens on device.
B = accepted_ids.shape[0]
device = input_ids.device
seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]
indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1
all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(
speculate, device=device
)
all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)
speculative_ids = input_ids.gather(dim=-1, index=all_indices)
return speculative_ids
class HeterogeneousNextTokenChooser:
def __init__(
self,
dtype: torch.dtype,
device: torch.device,
watermark: List[bool],
temperature: List[float],
repetition_penalty: List[float],
frequency_penalty: List[float],
top_k: List[int],
top_p: List[float],
typical_p: List[float],
do_sample: List[bool],
seeds: List[int],
tokenizer: PreTrainedTokenizerBase,
grammars: List[str],
grammar_types: List[int],
fsm_grammar_states: List[int],
quantization_enabled: bool,
):
warpers = []
# TODO: enable watermark with FP8 quantization
self.watermark_processor = (
HeterogeneousProcessorWrapper(
{
i: WatermarkLogitsProcessor(device=device)
for i, do_watermark in enumerate(watermark)
if do_watermark
}
)
if any(watermark) and not quantization_enabled
else None
)
self.repetition_processor = (
HeterogeneousRepetitionPenaltyLogitsProcessor(
repetition_penalty, dtype, device
)
if any([x != 1.0 for x in repetition_penalty])
else None
)
self.frequency_processor = (
HeterogeneousFrequencyPenaltyLogitsProcessor(
frequency_penalty, dtype, device
)
if any([x != 0.0 for x in frequency_penalty])
else None
)
self.grammar_processor = (
HeterogeneousGrammarLogitProcessor(
tokenizer, device, grammars, grammar_types
)
if any([grammar != "" for grammar in grammars])
else None
)
if any(x != 1.0 for x in temperature):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
]
warpers.append(
HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
)
if any(x != 0 for x in top_k):
do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))
if any(x < 1.0 for x in top_p):
do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))
if any(x < 1.0 for x in typical_p):
do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
self.warpers = warpers
if any(do_sample):
self.choice = HeterogeneousSampling(do_sample, seeds, device)
else:
self.choice = Greedy()
self.seeds = seeds
self.do_sample = do_sample
self.dtype = dtype
self.device = device
self.tokenizer = tokenizer
self.fsm_grammar_states = fsm_grammar_states
self.grammars = grammars
self.grammar_types = grammar_types
def __call__(
self,
input_ids: torch.Tensor,
scores: torch.Tensor,
speculate: int,
speculated_ids: Optional[torch.Tensor] = None,
speculative_scores: Optional[torch.Tensor] = None,
verbose=False,
):
if speculated_ids is not None:
B = scores.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1
scores = scores.view(B, S, -1)
else:
B = scores.shape[0]
S = 1
scores = scores.view(B, S, -1)
next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
for j in range(S):
_scores = scores[:, j]
if self.watermark_processor is not None:
_scores = self.watermark_processor(input_ids, _scores)
if self.repetition_processor is not None:
_scores = self.repetition_processor(input_ids, _scores)
if self.frequency_processor is not None:
_scores = self.frequency_processor(input_ids, _scores)
if self.grammar_processor is not None:
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
_next_ids = self.choice(_scores)
scores[:, j] = _scores
next_ids[:, j] = _next_ids
next_ids = next_ids.view(B * S)
allscores = scores.view(B * S, -1)
alllogprobs = torch.log_softmax(allscores, -1)
if speculated_ids is not None:
accepted_ids = []
B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1
indices = []
for i in range(B):
_next_ids = next_ids[i * S : (i + 1) * S]
_speculated_ids = speculated_ids[i]
validate_speculative = _next_ids[:-1] == _speculated_ids
index = i * S
accepted = 1
# First is always valid
indices.append(index)
for valid in validate_speculative.tolist():
if valid:
index += 1
accepted += 1
indices.append(index)
else:
break
accepted_ids.append(accepted)
accepted_ids = torch.tensor(
accepted_ids, device=input_ids.device, dtype=input_ids.dtype
)
next_ids = next_ids[indices]
logprobs = alllogprobs[indices]
indices = torch.arange(B, device=input_ids.device) * S
if speculative_scores is not None:
speculative_scores = speculative_scores[indices + accepted_ids - 1]
else:
accepted_ids = torch.ones_like(next_ids)
logprobs = alllogprobs
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
if speculate > 0:
if speculative_scores is not None:
# Medusa provided some scores
speculative_ids = Greedy()(speculative_scores)
else:
# n-gram
speculative_ids = create_n_gram_speculation(
input_ids, next_ids, accepted_ids, speculate, verbose
)
else:
speculative_ids = None
return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
def advance_grammar(self, next_ids: List[int]):
if self.grammar_processor is not None:
other_new_states = self.grammar_processor.advance_batch(
next_ids, self.fsm_grammar_states
)
self.fsm_grammar_states = other_new_states
return self
def advance_grammar_single(self, grammar_state_index: int, next_id: int):
if self.grammar_processor is not None:
self.fsm_grammar_states[grammar_state_index] = (
self.grammar_processor.advance_at_index(
next_id,
self.fsm_grammar_states[grammar_state_index],
grammar_state_index,
)
)
return self
def advance_grammar_single_with_past_state(
self, grammar_state_index: int, next_id: torch.Tensor, past_state: int
):
if self.grammar_processor is not None:
next_id = next_id.item()
self.fsm_grammar_states[grammar_state_index] = (
self.grammar_processor.advance_at_index(
next_id,
past_state,
grammar_state_index,
)
)
return self
def filter(self, indices):
if self.watermark_processor is not None:
self.watermark_processor = self.watermark_processor.filter(indices)
if self.repetition_processor is not None:
self.repetition_processor = self.repetition_processor.filter(indices)
if self.frequency_processor is not None:
self.frequency_processor = self.frequency_processor.filter(indices)
if self.grammar_processor is not None:
self.grammar_processor = self.grammar_processor.filter(indices)
filtered_warpers = []
for warper in self.warpers:
filtered_warper = warper.filter(indices)
if filtered_warper is not None:
filtered_warpers.append(filtered_warper)
self.warpers = filtered_warpers
self.seeds = [self.seeds[i] for i in indices]
self.do_sample = [self.do_sample[i] for i in indices]
new_grammars = []
new_fsm_grammar_states = []
new_grammar_types = []
for i in indices:
new_grammars.append(self.grammars[i])
new_fsm_grammar_states.append(self.fsm_grammar_states[i])
new_grammar_types.append(self.grammar_types[i])
self.grammars = new_grammars
self.fsm_grammar_states = new_fsm_grammar_states
self.grammar_types = new_grammar_types
if any(self.do_sample):
self.choice.filter(indices)
else:
self.choice = Greedy()
return self
@classmethod
def from_pb(
cls,
pb: List[generate_pb2.NextTokenChooserParameters],
dtype: torch.dtype,
device: torch.device,
tokenizer: PreTrainedTokenizerBase,
fsm_grammar_states: Optional[List[int]] = None,
quantization_enabled: bool = False,
) -> "HeterogeneousNextTokenChooser":
return HeterogeneousNextTokenChooser(
watermark=[pb_.watermark for pb_ in pb],
temperature=[pb_.temperature for pb_ in pb],
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
top_k=[pb_.top_k for pb_ in pb],
top_p=[pb_.top_p for pb_ in pb],
typical_p=[pb_.typical_p for pb_ in pb],
do_sample=[pb_.do_sample for pb_ in pb],
seeds=[pb_.seed for pb_ in pb],
device=device,
dtype=dtype,
tokenizer=tokenizer,
grammars=[pb_.grammar for pb_ in pb],
grammar_types=[pb_.grammar_type for pb_ in pb],
fsm_grammar_states=(
fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
),
quantization_enabled=quantization_enabled,
)
def pad_next_token_chooser_parameters(
parameters: List[generate_pb2.NextTokenChooserParameters],
expected_size: int,
) -> List[generate_pb2.NextTokenChooserParameters]:
# disable all logits processors to minimize padding overhead
empty_parameters = generate_pb2.NextTokenChooserParameters(
temperature=1.0,
top_k=0,
top_p=1.0,
typical_p=1.0,
do_sample=False,
seed=0,
repetition_penalty=1.0,
frequency_penalty=0.0,
watermark=False,
grammar="",
grammar_type=0,
)
parameters.extend([empty_parameters] * (expected_size - len(parameters)))
return parameters
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
if device in ["hpu", torch.device("hpu")]:
import habana_frameworks.torch.hpu.random as htrandom
self.generator = htrandom.default_generators[0].manual_seed(seed)
else:
self.generator = torch.Generator("cpu")
self.generator.manual_seed(seed)
self.seed = seed
def __call__(self, logits):
probs = torch.nn.functional.softmax(logits, -1)
# Avoid GPU<->CPU sync done by torch multinomial
# See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
return probs.div_(q).argmax()
class Greedy:
def __call__(self, logits):
return logits.argmax(dim=-1)
class HeterogeneousSampling:
r"""
Mixed greedy and probabilistic sampling. Compute both and pick the right one for each sample.
"""
def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):
self.seeds = seeds
self.greedy_indices = []
self.sampling_mapping = {}
for i, (sample, seed) in enumerate(zip(do_sample, seeds)):
if sample:
self.sampling_mapping[i] = Sampling(seed, device)
else:
self.greedy_indices.append(i)
self.greedy = Greedy()
def __call__(self, logits):
out = torch.zeros(logits.shape[0], dtype=torch.int64, device=logits.device)
if self.greedy_indices:
# Computing for all indices is faster than slicing
torch.argmax(logits, -1, out=out)
for i, sampling in self.sampling_mapping.items():
out[i] = sampling(logits[i])
return out
def filter(self, indices):
new_greedy_indices = []
new_sampling_mapping = {}
for i, idx in enumerate(indices):
if idx in self.sampling_mapping:
new_sampling_mapping[i] = self.sampling_mapping[idx]
else:
new_greedy_indices.append(i)
self.greedy_indices = new_greedy_indices
self.sampling_mapping = new_sampling_mapping
return self
def batch_top_tokens(
top_n_tokens: List[int],
top_n_tokens_tensor: torch.Tensor,
logprobs: torch.Tensor,
accepted_ids: torch.Tensor,
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
"""Find the top n most likely tokens for a batch of generations.
When multiple tokens have equal probabilities and they don't all fit, the
remaining tokens are also returned.
"""
max_top_n = max(top_n_tokens)
# Early exit when top_n_tokens is not used
if max_top_n == 0:
return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
batch_size = accepted_ids.shape[0]
speculate_size = logprobs.shape[0] // batch_size
top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
# Ensure top_n doesn't exceed vocab size
top_n_tokens = [
min(tok, logprobs.size(-1))
for tok in top_n_tokens
for _ in range(speculate_size)
]
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset
sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values
nth_highest = torch.gather(
sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
)
nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
# Find the new "fuzzy" top n values
top_n_indices = (logprobs >= nth_highest).nonzero()
_, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
# Take a new topk for these new max n values
top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)
top_n_ishes = top_n_ishes.tolist()
top_indices = top_k.indices.tolist()
top_values = top_k.values.tolist()
batch_top_token_ids = []
batch_top_token_logprobs = []
accepted_ids_list = accepted_ids.tolist()
for i, n_accepted_ids in enumerate(accepted_ids_list):
start = speculate_size * i
stop = speculate_size * (i + 1)
_top_indices = top_indices[start:stop]
_top_values = top_values[start:stop]
_top_n_ishes = top_n_ishes[start:stop]
_top_n_tokens = top_n_tokens[start:stop]
_top_indices = _top_indices[:n_accepted_ids]
_top_values = _top_values[:n_accepted_ids]
_top_n_ishes = _top_n_ishes[:n_accepted_ids]
_top_n_tokens = _top_n_tokens[:n_accepted_ids]
row_top_token_ids = []
row_top_token_logprobs = []
for idxs, vals, n, req_n in zip(
_top_indices, _top_values, _top_n_ishes, _top_n_tokens
):
indices = idxs[:n] if req_n > 0 else []
values = vals[:n] if req_n > 0 else []
row_top_token_ids.append(indices)
row_top_token_logprobs.append(values)
batch_top_token_ids.append(row_top_token_ids)
batch_top_token_logprobs.append(row_top_token_logprobs)
return batch_top_token_ids, batch_top_token_logprobs
def make_tokenizer_optional(tokenizer):
class _(type(tokenizer)):
def __call__(
self,
text,
return_tensors,
padding,
return_token_type_ids,
truncation,
max_length,
):
assert (
return_tensors == "pt"
), "inccorrect input arguments when calling TransparentTokenizer"
assert (
padding == "max_length" or padding == "longest"
), "inccorrect input arguments when calling TransparentTokenizer"
assert (
not return_token_type_ids
), "inccorrect input arguments when calling TransparentTokenizer"
assert (
truncation
), "inccorrect input arguments when calling TransparentTokenizer"
def str_token_to_int(i):
if i == "?":
return tokenizer.pad_token_id
else:
return int(i)
all_tokens = [
[str_token_to_int(i.strip()) for i in inner_text.split(",")]
for inner_text in text
]
if padding == "longest":
max_length = max(len(tokens) for tokens in all_tokens)
return {
"input_ids": torch.tensor(
[
[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens
for tokens in all_tokens
]
),
"attention_mask": torch.tensor(
[
[0] * (max_length - len(tokens)) + [1] * len(tokens)
for tokens in all_tokens
]
),
}
def decode(
self,
token_ids,
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = None,
**kwargs,
) -> str:
# I don't think this method is used anywhere and should be removed when doing refactoring
return ",".join(str(i) for i in to_py_obj(token_ids)) # noqa: F821
if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
tokenizer.__class__ = _
tokenizer.is_transparent = True
def is_tokenizer_transparent(tokenizer):
return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True
================================================
FILE: backends/gaudi/server/text_generation_server/utils/version.py
================================================
from packaging.version import Version
from packaging import version
import subprocess
def get_driver_version():
"""
Returns the driver version.
"""
# Enable console printing for `hl-smi` check
output = subprocess.run(
"hl-smi",
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env={"ENABLE_CONSOLE": "true"},
)
if output.returncode == 0 and output.stdout:
return version.parse(
output.stdout.split("\n")[2]
.replace(" ", "")
.split(":")[1][:-1]
.split("-")[0]
)
return None
MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.19.0")
def is_driver_compatible():
driver_version = get_driver_version()
if driver_version is not None:
if driver_version < MIN_TGI_GAUDI_SYNAPSE_VERSION:
return False
return True
================================================
FILE: backends/gaudi/server/text_generation_server/utils/watermark.py
================================================
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from transformers import LogitsProcessor
from typing import List, Union
GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5))
DELTA = float(os.getenv("WATERMARK_DELTA", 2.0))
class WatermarkLogitsProcessor(LogitsProcessor):
def __init__(
self,
gamma: float = GAMMA,
delta: float = DELTA,
hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
device: str = "cpu",
):
# watermarking parameters
self.gamma = gamma
self.delta = delta
self.rng = torch.Generator(device="cpu")
self.hash_key = hash_key
def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):
if isinstance(input_ids, list):
assert (
len(input_ids) >= 1
), "requires at least a 1 token prefix sequence to seed rng"
prev_token = input_ids[-1]
else:
assert len(input_ids) == 1
input_ids = input_ids[0]
assert (
input_ids.shape[-1] >= 1
), "requires at least a 1 token prefix sequence to seed rng"
prev_token = input_ids[-1].item()
self.rng.manual_seed(self.hash_key * prev_token)
def _get_greenlist_ids(
self,
input_ids: Union[List[int], torch.LongTensor],
max_value: int,
device: torch.device,
) -> List[int]:
# seed the rng using the previous tokens/prefix
self._seed_rng(input_ids)
greenlist_size = int(max_value * self.gamma)
vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng)
greenlist_ids = vocab_permutation[:greenlist_size]
return greenlist_ids
@staticmethod
def _calc_greenlist_mask(
scores: torch.FloatTensor, greenlist_token_ids
) -> torch.BoolTensor:
green_tokens_mask = torch.zeros_like(scores)
green_tokens_mask[-1, greenlist_token_ids] = 1
final_mask = green_tokens_mask.bool()
return final_mask
@staticmethod
def _bias_greenlist_logits(
scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
) -> torch.Tensor:
scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
return scores
def __call__(
self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor
) -> torch.FloatTensor:
greenlist_ids = self._get_greenlist_ids(
input_ids, scores.shape[-1], scores.device
)
green_tokens_mask = self._calc_greenlist_mask(
scores=scores, greenlist_token_ids=greenlist_ids
)
scores = self._bias_greenlist_logits(
scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
)
return scores
================================================
FILE: backends/gaudi/server/text_generation_server/utils/weights.py
================================================
import torch
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Optional, Union, Type
from safetensors import safe_open
from dataclasses import dataclass
class WeightsLoader(ABC):
"""
Instances of this type implement higher-level weight loading.
At a low-level, every weight is stored in the Safetensors format.
The interpretation of weights may be different however, for instance
could be packed, quantized weights. Loaders are responsible for
interpreting the raw tensors, sharding tensors in a manner compatible
with the format, etc.
"""
@abstractmethod
def get_weights(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply without tensor paralllism.
"""
...
@abstractmethod
def get_weights_col_packed(
self,
weights: "Weights",
prefix: str,
block_sizes: Union[int, List[int]],
):
"""
Get the packed weights at the given prefix with column-splitting for
tensor parallelism. This method should be used when multiple different
weights are packed into a tensor, for instance, query/key/value
weights or a gate/up projection.
The `block_sizes` determines the proportions of the packed tensors.
The columns are split in equally sized blocks when `block_sizes` is an
`int`, or in blocks proportional given to the sizes. For instance
`[2, 1, 1]` will divide an input with dimensionality `1024` in
`[512, 256, 256]`.
"""
...
def get_weights_col(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply column-splitting for tensor
paralllism.
"""
return weights.get_multi_weights_col([prefix], 0)
@abstractmethod
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
"""
Get the weights at the given prefixes, column-split them for tensor
parallelim, and then concatenate the weights along the given dimension.
"""
...
@abstractmethod
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
"""
Get the weights at the given prefixes, column-split them for tensor
parallelim, and then concatenate the weights along the given dimension.
"""
...
@abstractmethod
def get_weights_row(self, weights: "Weights", prefix: str):
"""
Get the weights at the given prefix and apply row-splitting for tensor
parallism.
"""
...
class Weight(ABC):
"""Instances of this type implement unquantized/quantized/to-be
quantized weights."""
@abstractmethod
def get_linear(self, bias: torch.Tensor):
"""Create a linear layer from this weight."""
...
@dataclass
class UnquantizedWeight(Weight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
from text_generation_server.layers.linear import FastLinear
return FastLinear(self.weight, bias)
class DefaultWeightsLoader(WeightsLoader):
"""Weight loader that loads (unquantized) Torch tensors."""
def __init__(self, weight_class: Type[UnquantizedWeight]):
"""Create a loader. Weights will be wrapped using the given `weights_class`,
normally this will be `UnquantizedWeight`, but a quantizer-specific class
such as `Fp8Weight` can be used to quantize the weights during loading.
"""
self.weight_class = weight_class
"""
Loader that uses tensors as-is with the exception of applying sharding
and/or concatenation.
"""
def get_weights(self, weights: "Weights", prefix: str):
return weights.get_tensor(f"{prefix}.weight")
def get_weights_col_packed(
self,
weights: "Weights",
prefix: str,
block_sizes: Union[int, List[int]],
):
return self.weight_class(
weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
),
)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
return self.weight_class(torch.cat(w, dim=dim))
def get_weights_row(self, weights: "Weights", prefix: str):
return self.weight_class(
weights.get_sharded(f"{prefix}.weight", dim=1),
)
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_tensor(f"{p}.weight") for p in prefixes]
return self.weight_class(torch.cat(w, dim=dim))
class Weights:
def __init__(
self,
filenames: List[Path],
device,
dtype,
process_group,
weights_loader: WeightsLoader,
aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None,
):
routing = {}
for filename in filenames:
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
if aliases is None:
aliases = {}
self.aliases = aliases
self.routing = routing
self.device = device
self.dtype = dtype
self.process_group = process_group
self.prefix = prefix
self.weights_loader = weights_loader
self._handles = {}
def _get_handle(self, filename):
if filename not in self._handles:
f = safe_open(filename, framework="pytorch")
self._handles[filename] = f
return self._handles[filename]
def get_filename(self, tensor_name: str) -> (str, str):
names = [tensor_name]
if self.prefix is not None:
prefixed = f"{self.prefix}.{tensor_name}"
names.append(prefixed)
for name in names:
filename = self.routing.get(name, None)
if filename is not None:
return str(filename), name
aliases = self.aliases.get(name, [])
for alias in aliases:
filename = self.routing.get(alias, None)
if filename is not None:
return str(filename), alias
raise RuntimeError(f"weight {tensor_name} does not exist")
def _get_slice(self, tensor_name: str):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
return slice_
def has_tensor(self, tensor_name: str):
try:
self.get_filename(tensor_name)
except Exception:
return False
return True
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()
def get_tensor(
self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
) -> torch.Tensor:
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32. Exl2 uses int16
# as well. FP8 uses torch.float8_e4m3fn
if (
tensor.dtype
not in [
torch.float8_e4m3fn,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
]
and to_dtype
):
tensor = tensor.to(dtype=self.dtype)
if to_device:
tensor = tensor.to(device=self.device)
return tensor
def get_partial_sharded(
self, tensor_name: str, dim: int, to_device=True, to_dtype=True
):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
world_size = self.process_group.size()
rank = self.process_group.rank()
size = slice_.get_shape()[dim]
block_size = (size + world_size - 1) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
if dim == 0:
tensor = slice_[start:stop]
elif dim == 1:
tensor = slice_[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32. exl2 uses int16.
# FP8 uses torch.float8_e4m3fn.
if (
tensor.dtype
not in (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32)
and to_dtype
):
tensor = tensor.to(dtype=self.dtype)
if to_device:
tensor = tensor.to(device=self.device)
return tensor
def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
world_size = self.process_group.size()
size = slice_.get_shape()[dim]
assert (
size % world_size == 0
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(
tensor_name, dim, to_device=to_device, to_dtype=to_dtype
)
def get_packed_sharded(
self,
tensor_name: str,
dim: int,
block_sizes: Union[int, List[int]],
to_dtype=True,
) -> torch.Tensor:
"""
Get a shard from a tensor that packs multiple tensors.
When a tensor packs multiple tensors (such as QKV or an up
projection + gate projection), sharding with `get_sharded` is not
safe since it would not split the packed tensors across shards.
This method shards a tensor, such that the packed tensors are
split across shards.
The columns are split in equally sized blocks when blocks is an `int`, or
in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
convenient for e.g. splitting QKV without knowing the storage details of
quantized weights.
"""
slice_ = self._get_slice(tensor_name)
total_size = slice_.get_shape()[dim]
block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)
world_size = self.process_group.size()
rank = self.process_group.rank()
tensors_slices = []
block_offset = 0
for block_size in block_sizes:
assert (
block_size % world_size == 0
), f"Prepacked tensor cannot be sharded across {world_size} shards"
shard_block_size = block_size // world_size
start = rank * shard_block_size
stop = (rank + 1) * shard_block_size
tensors_slices += range(block_offset + start, block_offset + stop)
block_offset += block_size
if dim == 0:
tensor = slice_[tensors_slices, ...]
elif dim == 1 or dim == -2:
tensor = slice_[:, tensors_slices, ...]
elif dim == 2 or dim == -1:
tensor = slice_[..., tensors_slices]
else:
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
tensor = tensor.to(device=self.device)
# Avoid casting quantizer dtypes.
if (
tensor.dtype
not in [
torch.float8_e4m3fn,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
]
and to_dtype
):
tensor = tensor.to(dtype=self.dtype)
return tensor
def get_weights(self, prefix: str):
return self.weights_loader.get_weights(self, prefix)
def get_weights_col_packed_qkv(
self,
prefix: str,
num_heads: int,
num_key_value_heads: int,
):
return self.get_weights_col_packed(
prefix, [num_heads, num_key_value_heads, num_key_value_heads]
)
def get_weights_col_packed_gate_up(self, prefix: str):
return self.get_weights_col_packed(prefix, 2)
def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
"""
The columns are split in equally sized blocks when blocks is an `int`, or
in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
convenient for e.g. splitting QKV without knowing the storage details of
quantized weights.
"""
return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)
def get_weights_col(self, prefix: str):
return self.weights_loader.get_weights_col(self, prefix)
def get_multi_weights_col(self, prefixes: List[str], dim: int):
return self.weights_loader.get_multi_weights_col(self, prefixes, dim)
def get_tensor_shard(self, var, dim):
world_size = self.process_group.size()
rank = self.process_group.rank()
block_size = var.size()[dim] // world_size
start = rank * block_size
stop = (rank + 1) * block_size
if dim == 0:
tensor = var[start:stop]
elif dim == 1:
tensor = var[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
def get_weights_row(self, prefix: str):
return self.weights_loader.get_weights_row(self, prefix)
def get_multi_weights(self, prefixes: List[str], dim: int):
return self.weights_loader.get_multi_weights(self, prefixes, dim)
@contextmanager
def use_loader(self, weights_loader: WeightsLoader):
"""
This method is a context manager that can be used to use `Weights` with
a different loader for the duration of the context.
"""
old_loader = self.weights_loader
self.weights_loader = weights_loader
try:
yield
finally:
self.weights_loader = old_loader
@property
def loader(self):
return self.weights_loader
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
"""
Convert block count or proportions to block sizes.
This function accepts
- The number of blocks (int), in which case the block size is
total_size//blocks; or
- A list of block sizes (List[int]).
In the latter case, if sum(blocks) < total_size, the ratios between
the block sizes will be preserved. For instance, if blocks is
[2, 1, 1] and total_size is 1024, the returned block sizes are
[512, 256, 256].
"""
if isinstance(blocks, list):
total_blocks = sum(blocks)
assert (
total_size % total_blocks == 0
), f"Cannot split {total_size} in proportional blocks: {blocks}"
part_size = total_size // total_blocks
return [part_size * block for block in blocks]
else:
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
single_size = total_size // blocks
return [single_size] * blocks
================================================
FILE: backends/gaudi/tgi-entrypoint.sh
================================================
#!/bin/bash
ldconfig 2>/dev/null || echo 'unable to refresh ld cache, not a big deal in most cases'
# Check if --sharded argument is present in the command line arguments
if [[ "$*" == *"--sharded true"* ]]; then
echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'
export PT_HPU_ENABLE_LAZY_COLLECTIVES=1
fi
text-generation-launcher $@
================================================
FILE: backends/grpc-metadata/Cargo.toml
================================================
[package]
name = "grpc-metadata"
version = "0.1.0"
edition = "2021"
[dependencies]
opentelemetry = "^0.20"
tonic = "^0.10"
tracing = "^0.1"
tracing-opentelemetry = "^0.21"
================================================
FILE: backends/grpc-metadata/src/lib.rs
================================================
//! A crate to extract and inject a OpenTelemetry context from and to a gRPC request.
//! Inspired by: https://github.com/open-telemetry/opentelemetry-rust gRPC examples
use opentelemetry::global;
use opentelemetry::propagation::Injector;
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Inject context in the metadata of a gRPC request.
struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap);
impl Injector for MetadataInjector<'_> {
/// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs
fn set(&mut self, key: &str, value: String) {
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
if let Ok(val) = value.parse() {
self.0.insert(key, val);
}
}
}
}
/// Get a context from the global context and inject the span into a gRPC request's metadata.
fn inject(metadata: &mut tonic::metadata::MetadataMap) {
global::get_text_map_propagator(|propagator| {
propagator.inject_context(
&tracing::Span::current().context(),
&mut MetadataInjector(metadata),
)
})
}
pub trait InjectTelemetryContext {
fn inject_context(self) -> Self;
}
impl InjectTelemetryContext for tonic::Request {
fn inject_context(mut self) -> Self {
inject(self.metadata_mut());
self
}
}
================================================
FILE: backends/llamacpp/Cargo.toml
================================================
[package]
name = "text-generation-router-llamacpp"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[build-dependencies]
bindgen = "0.71.1"
pkg-config = "0.3.31"
[dependencies]
async-trait = "0.1.85"
clap = "4.5.27"
hf-hub.workspace = true
num_cpus = "1.16.0"
text-generation-router = { path = "../../router" }
thiserror = "2.0.11"
tokenizers.workspace = true
tokio = { version = "1.43.0", features = ["process"] }
tokio-stream = "0.1.17"
tracing = "0.1.41"
================================================
FILE: backends/llamacpp/README.md
================================================
# Llamacpp backend
If all your dependencies are installed at the system level, running
cargo build should be sufficient. However, if you want to experiment
with different versions of llama.cpp, some additional setup is required.
## Install llama.cpp
LLAMACPP_PREFIX=$(pwd)/llama.cpp.out
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
cmake -B build \
-DCMAKE_INSTALL_PREFIX="$LLAMACPP_PREFIX" \
-DLLAMA_BUILD_COMMON=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_EXAMPLES=OFF \
-DLLAMA_BUILD_SERVER=OFF
cmake --build build --config Release -j
cmake --install build
## Build TGI
PKG_CONFIG_PATH="$LLAMACPP_PREFIX/lib/pkgconfig" cargo build
================================================
FILE: backends/llamacpp/build.rs
================================================
use bindgen::callbacks::{ItemInfo, ParseCallbacks};
use std::env;
use std::path::PathBuf;
#[derive(Debug)]
struct PrefixStripper;
impl ParseCallbacks for PrefixStripper {
fn generated_name_override(&self, item_info: ItemInfo<'_>) -> Option {
item_info.name.strip_prefix("llama_").map(str::to_string)
}
}
fn main() {
if let Some(cuda_version) = option_env!("CUDA_VERSION") {
let mut version: Vec<&str> = cuda_version.split('.').collect();
if version.len() > 2 {
version.pop();
}
let cuda_version = format!("cuda-{}", version.join("."));
pkg_config::Config::new().probe(&cuda_version).unwrap();
}
let llama = pkg_config::Config::new().probe("llama").unwrap();
for path in &llama.link_paths {
println!("cargo:rustc-link-arg=-Wl,-rpath,{}", path.display());
}
if cfg!(target_os = "linux") {
println!("cargo:rustc-link-arg=-Wl,--disable-new-dtags");
}
let bindings = bindgen::Builder::default()
.clang_args(
llama
.include_paths
.iter()
.map(|p| format!("-I{}", p.display())),
)
.header_contents("llama_bindings.h", "#include ")
.prepend_enum_name(false)
.parse_callbacks(Box::new(PrefixStripper))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.generate()
.expect("Unable to generate bindings");
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("llamacpp.rs"))
.expect("Couldn't write bindings!");
}
================================================
FILE: backends/llamacpp/requirements.txt
================================================
transformers==4.49
huggingface-hub==0.28.1
hf-transfer==0.1.9
torch==2.6.0
================================================
FILE: backends/llamacpp/src/backend.rs
================================================
use crate::llamacpp;
use async_trait::async_trait;
use std::ffi::CString;
use std::mem::replace;
use std::str::FromStr;
use std::sync::{mpsc, Once};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, Token};
use thiserror::Error;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::{oneshot, watch};
use tokio::task::{spawn, spawn_blocking};
use tokio::time::{timeout, Duration, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::instrument;
use tracing::{debug, error, info, trace, warn};
#[derive(Debug, Clone, Copy)]
pub enum LlamacppSplitMode {
GPU(usize),
Layer,
Row,
}
impl FromStr for LlamacppSplitMode {
type Err = String;
fn from_str(s: &str) -> Result {
match s.to_lowercase().as_str() {
"layer" => Ok(LlamacppSplitMode::Layer),
"row" => Ok(LlamacppSplitMode::Row),
_ => match s.parse::() {
Ok(n) => Ok(LlamacppSplitMode::GPU(n)),
Err(_) => Err("Choose a GPU number or `layer` or `row`".to_string()),
},
}
}
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
pub enum LlamacppNuma {
Disabled,
Distribute,
Isolate,
Numactl,
Mirror,
}
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
pub enum LlamacppGGMLType {
F32,
F16,
Q4_0,
Q4_1,
Q5_0,
Q5_1,
Q8_0,
Q8_1,
Q2_K,
Q3_K,
Q4_K,
Q5_K,
Q6_K,
Q8_K,
IQ2_XXS,
IQ2_XS,
IQ3_XXS,
IQ1_S,
IQ4_NL,
IQ3_S,
IQ2_S,
IQ4_XS,
I8,
I16,
I32,
I64,
F64,
IQ1_M,
BF16,
TQ1_0,
TQ2_0,
}
// TODO: macro
impl LlamacppGGMLType {
fn to_ggml_type(self) -> llamacpp::ggml_type {
match self {
LlamacppGGMLType::F32 => llamacpp::GGML_TYPE_F32,
LlamacppGGMLType::F16 => llamacpp::GGML_TYPE_F16,
LlamacppGGMLType::Q4_0 => llamacpp::GGML_TYPE_Q4_0,
LlamacppGGMLType::Q4_1 => llamacpp::GGML_TYPE_Q4_1,
LlamacppGGMLType::Q5_0 => llamacpp::GGML_TYPE_Q5_0,
LlamacppGGMLType::Q5_1 => llamacpp::GGML_TYPE_Q5_1,
LlamacppGGMLType::Q8_0 => llamacpp::GGML_TYPE_Q8_0,
LlamacppGGMLType::Q8_1 => llamacpp::GGML_TYPE_Q8_1,
LlamacppGGMLType::Q2_K => llamacpp::GGML_TYPE_Q2_K,
LlamacppGGMLType::Q3_K => llamacpp::GGML_TYPE_Q3_K,
LlamacppGGMLType::Q4_K => llamacpp::GGML_TYPE_Q4_K,
LlamacppGGMLType::Q5_K => llamacpp::GGML_TYPE_Q5_K,
LlamacppGGMLType::Q6_K => llamacpp::GGML_TYPE_Q6_K,
LlamacppGGMLType::Q8_K => llamacpp::GGML_TYPE_Q8_K,
LlamacppGGMLType::IQ2_XXS => llamacpp::GGML_TYPE_IQ2_XXS,
LlamacppGGMLType::IQ2_XS => llamacpp::GGML_TYPE_IQ2_XS,
LlamacppGGMLType::IQ3_XXS => llamacpp::GGML_TYPE_IQ3_XXS,
LlamacppGGMLType::IQ1_S => llamacpp::GGML_TYPE_IQ1_S,
LlamacppGGMLType::IQ4_NL => llamacpp::GGML_TYPE_IQ4_NL,
LlamacppGGMLType::IQ3_S => llamacpp::GGML_TYPE_IQ3_S,
LlamacppGGMLType::IQ2_S => llamacpp::GGML_TYPE_IQ2_S,
LlamacppGGMLType::IQ4_XS => llamacpp::GGML_TYPE_IQ4_XS,
LlamacppGGMLType::I8 => llamacpp::GGML_TYPE_I8,
LlamacppGGMLType::I16 => llamacpp::GGML_TYPE_I16,
LlamacppGGMLType::I32 => llamacpp::GGML_TYPE_I32,
LlamacppGGMLType::I64 => llamacpp::GGML_TYPE_I64,
LlamacppGGMLType::F64 => llamacpp::GGML_TYPE_F64,
LlamacppGGMLType::IQ1_M => llamacpp::GGML_TYPE_IQ1_M,
LlamacppGGMLType::BF16 => llamacpp::GGML_TYPE_BF16,
LlamacppGGMLType::TQ1_0 => llamacpp::GGML_TYPE_TQ1_0,
LlamacppGGMLType::TQ2_0 => llamacpp::GGML_TYPE_TQ2_0,
}
}
}
pub struct LlamacppConfig {
pub model_gguf: String,
pub max_batch_total_tokens: usize,
pub max_physical_batch_total_tokens: usize,
pub max_batch_size: usize,
pub batch_timeout: Duration,
pub n_threads: usize,
pub n_threads_batch: usize,
pub n_gpu_layers: usize,
pub split_mode: LlamacppSplitMode,
pub numa: LlamacppNuma,
pub defrag_threshold: f32,
pub use_mmap: bool,
pub use_mlock: bool,
pub offload_kqv: bool,
pub flash_attention: bool,
pub type_k: LlamacppGGMLType,
pub type_v: LlamacppGGMLType,
}
#[derive(Debug)]
struct LlamacppRequest {
input_ids: Vec,
top_k: i32,
top_p: f32,
typical_p: f32,
min_keep: usize,
temp: f32,
seed: u32,
penalty_last_n: i32,
penalty_repeat: f32,
penalty_freq: f32,
penalty_present: f32,
max_new_tokens: usize,
tx: UnboundedSender>,
time: Instant,
}
pub struct LlamacppBackend {
tx: UnboundedSender,
status: watch::Receiver,
}
impl LlamacppRequest {
fn new(
from: &ValidGenerateRequest,
tx: UnboundedSender>,
) -> Option {
from.input_ids.as_ref().map(|input_ids| LlamacppRequest {
input_ids: input_ids.iter().map(|&x| x as i32).collect(),
top_k: from.parameters.top_k as _,
top_p: from.parameters.top_p as _,
typical_p: from.parameters.typical_p as _,
min_keep: 0, // disabled
temp: from.parameters.temperature as _,
seed: from.parameters.seed as _,
penalty_last_n: 64, // 0 = disabled, -1 = context size
penalty_repeat: from.parameters.repetition_penalty as _,
penalty_freq: from.parameters.frequency_penalty as _,
penalty_present: 0.0, // disabled
max_new_tokens: from.stopping_parameters.max_new_tokens as _,
tx,
time: Instant::now(),
})
}
}
struct Llamacpp {
model: *mut llamacpp::llama_model,
ctx: *mut llamacpp::llama_context,
vocab: *const llamacpp::llama_vocab,
logprobs: Vec,
batch: llamacpp::llama_batch,
}
extern "C" fn llamacpp_log_callback(
level: llamacpp::ggml_log_level,
msg: *const std::os::raw::c_char,
_user_data: *mut std::os::raw::c_void,
) {
let cmsg = unsafe { std::ffi::CStr::from_ptr(msg) };
let rmsg = cmsg.to_string_lossy().trim_end_matches('\n').to_string();
match level {
llamacpp::GGML_LOG_LEVEL_DEBUG => debug!(target: "llamacpp", "{}", rmsg),
llamacpp::GGML_LOG_LEVEL_INFO => info!(target: "llamacpp", "{}", rmsg),
llamacpp::GGML_LOG_LEVEL_WARN => warn!(target: "llamacpp", "{}", rmsg),
llamacpp::GGML_LOG_LEVEL_ERROR => error!(target: "llamacpp", "{}", rmsg),
_ => trace!(target: "llamacpp", "{}", rmsg),
}
}
impl Llamacpp {
fn new(conf: LlamacppConfig) -> Result {
let gguf = CString::new(conf.model_gguf)?;
let model = unsafe {
let mut params = llamacpp::model_default_params();
params.n_gpu_layers = conf.n_gpu_layers as _;
params.split_mode = match conf.split_mode {
LlamacppSplitMode::GPU(_) => llamacpp::LLAMA_SPLIT_MODE_NONE,
LlamacppSplitMode::Layer => llamacpp::LLAMA_SPLIT_MODE_LAYER,
LlamacppSplitMode::Row => llamacpp::LLAMA_SPLIT_MODE_ROW,
};
params.main_gpu = match conf.split_mode {
LlamacppSplitMode::GPU(n) => n as _,
_ => 0,
};
params.use_mmap = conf.use_mmap;
params.use_mlock = conf.use_mlock;
llamacpp::model_load_from_file(gguf.as_ptr(), params)
};
if model.is_null() {
return Err(BackendError::Llamacpp("Failed to load model".to_string()));
}
let ctx = unsafe {
let mut params = llamacpp::context_default_params();
params.n_ctx = conf.max_batch_total_tokens as _;
params.n_batch = conf.max_batch_total_tokens as _;
params.n_ubatch = conf.max_physical_batch_total_tokens as _;
params.n_seq_max = conf.max_batch_size as _;
params.n_threads = conf.n_threads as _;
params.n_threads_batch = conf.n_threads_batch as _;
params.defrag_thold = conf.defrag_threshold;
params.offload_kqv = conf.offload_kqv;
params.flash_attn = conf.flash_attention;
params.type_k = conf.type_k.to_ggml_type();
params.type_v = conf.type_v.to_ggml_type();
params.no_perf = true;
llamacpp::init_from_model(model, params)
};
if ctx.is_null() {
return Err(BackendError::Llamacpp("Failed to init context".to_string()));
}
let vocab = unsafe { llamacpp::model_get_vocab(model) };
if vocab.is_null() {
return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
}
let n_tokens = unsafe { llamacpp::vocab_n_tokens(vocab) };
let mut logprobs = Vec::with_capacity(n_tokens as usize);
for token in 0..n_tokens {
logprobs.push(llamacpp::llama_token_data {
id: token,
logit: 0.0,
p: 0.0,
});
}
let batch = unsafe { llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1) };
Ok(Llamacpp {
model,
ctx,
vocab,
logprobs,
batch,
})
}
fn decode(&mut self) -> i32 {
unsafe { llamacpp::decode(self.ctx, self.batch) }
}
fn clear_kv_cache(&mut self, seq_id: llamacpp::llama_seq_id) {
unsafe {
llamacpp::kv_cache_seq_rm(self.ctx, seq_id, -1, -1);
}
}
fn batch_push(
&mut self,
token: llamacpp::llama_token,
pos: llamacpp::llama_pos,
seq_id: llamacpp::llama_seq_id,
logits: bool,
) -> usize {
let n = self.batch.n_tokens as usize;
unsafe {
*self.batch.token.add(n) = token;
*self.batch.pos.add(n) = pos;
*self.batch.n_seq_id.add(n) = 1;
*(*self.batch.seq_id.add(n)).add(0) = seq_id;
*self.batch.logits.add(n) = logits as i8;
}
self.batch.n_tokens += 1;
n
}
}
impl Drop for Llamacpp {
fn drop(&mut self) {
if !self.ctx.is_null() {
unsafe { llamacpp::free(self.ctx) };
}
if !self.model.is_null() {
unsafe { llamacpp::model_free(self.model) };
}
unsafe { llamacpp::batch_free(self.batch) };
}
}
struct LlamacppSampler {
chain: *mut llamacpp::llama_sampler,
}
impl LlamacppSampler {
fn new(req: &LlamacppRequest) -> Option {
let chain = unsafe {
let params = llamacpp::sampler_chain_default_params();
llamacpp::sampler_chain_init(params)
};
if chain.is_null() {
error!("Failed to init sampler");
return None;
}
let (top_k, top_p, typical_p, temp, penalties, dist) = unsafe {
(
llamacpp::sampler_init_top_k(req.top_k),
llamacpp::sampler_init_top_p(req.top_p, req.min_keep),
llamacpp::sampler_init_typical(req.typical_p, req.min_keep),
llamacpp::sampler_init_temp(req.temp),
llamacpp::sampler_init_penalties(
req.penalty_last_n,
req.penalty_repeat,
req.penalty_freq,
req.penalty_present,
),
llamacpp::sampler_init_dist(req.seed),
)
};
let all = &[
("top_k", top_k),
("top_p", top_p),
("typical_p", typical_p),
("temp", temp),
("penalties", penalties),
("dist", dist),
];
let mut failed = false;
for (k, v) in all {
if v.is_null() {
error!("Failed to init {k} sampler");
failed = true;
} else {
unsafe { llamacpp::sampler_chain_add(chain, *v) };
}
}
if failed {
unsafe { llamacpp::sampler_free(chain) };
None
} else {
Some(LlamacppSampler { chain })
}
}
fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (llamacpp::llama_token, f32) {
let logits = unsafe { llamacpp::get_logits_ith(llamacpp.ctx, idx as _) };
for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {
*logprob = llamacpp::llama_token_data {
id: token as _,
logit: unsafe { *logits.add(token) },
p: 0.0,
};
}
let mut view = llamacpp::llama_token_data_array {
data: llamacpp.logprobs.as_mut_ptr(),
size: llamacpp.logprobs.len(),
selected: -1,
sorted: false,
};
unsafe {
llamacpp::sampler_apply(self.chain, &mut view);
let logprob = *view.data.offset(view.selected as _);
llamacpp::sampler_accept(self.chain, logprob.id);
(logprob.id, logprob.p.ln())
}
}
}
impl Drop for LlamacppSampler {
fn drop(&mut self) {
if !self.chain.is_null() {
unsafe { llamacpp::sampler_free(self.chain) };
}
}
}
struct LlamacppSeq {
id: usize,
batch_pos: usize,
token: llamacpp::llama_token,
pos: llamacpp::llama_pos,
sampler: LlamacppSampler,
text: String,
n_new_tokens: usize,
running: bool,
}
static INIT: Once = Once::new();
impl LlamacppBackend {
pub fn new(
conf: LlamacppConfig,
tokenizer: Tokenizer,
) -> (
Self,
oneshot::Receiver>,
watch::Sender,
) {
// Setup llama & export logs, once and for all
INIT.call_once(|| unsafe {
llamacpp::log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
llamacpp::backend_init();
llamacpp::numa_init(match conf.numa {
LlamacppNuma::Disabled => llamacpp::GGML_NUMA_STRATEGY_DISABLED,
LlamacppNuma::Distribute => llamacpp::GGML_NUMA_STRATEGY_DISTRIBUTE,
LlamacppNuma::Isolate => llamacpp::GGML_NUMA_STRATEGY_ISOLATE,
LlamacppNuma::Numactl => llamacpp::GGML_NUMA_STRATEGY_NUMACTL,
LlamacppNuma::Mirror => llamacpp::GGML_NUMA_STRATEGY_MIRROR,
});
});
let (status_tx, status_rx) = watch::channel(false);
let (shutdown_tx, shutdown_rx) = watch::channel(false);
let (ok_tx, ok_rx) = oneshot::channel();
let (tx, mut rx) = unbounded_channel::();
let (sync_tx, sync_rx) = mpsc::channel();
spawn(async move {
let mut n_tokens = 0;
let mut requests = Vec::with_capacity(conf.max_batch_size);
let flush = |requests: &mut Vec<_>, n_tokens: &mut usize| {
if !requests.is_empty() {
let _ =
sync_tx.send(replace(requests, Vec::with_capacity(conf.max_batch_size)));
*n_tokens = 0;
}
};
loop {
match timeout(conf.batch_timeout, rx.recv()).await {
Ok(Some(request)) => {
let n_tokens_to_add = request.input_ids.len();
if n_tokens + n_tokens_to_add > conf.max_batch_total_tokens {
flush(&mut requests, &mut n_tokens);
}
n_tokens += n_tokens_to_add;
requests.push(request);
if requests.len() == conf.max_batch_size {
flush(&mut requests, &mut n_tokens);
}
}
Ok(None) => break, // closed
Err(_) => flush(&mut requests, &mut n_tokens), // timeout
}
}
});
spawn_blocking(move || {
let mut llamacpp = match Llamacpp::new(conf) {
Ok(v) => {
let _ = ok_tx.send(Ok(()));
v
}
Err(e) => {
let _ = ok_tx.send(Err(e));
return;
}
};
let vocab = tokenizer.get_added_vocabulary();
// health() returns true
let _ = status_tx.send(true);
while let Ok(requests) = sync_rx.recv() {
if *shutdown_rx.borrow() {
break;
}
let start_time = Instant::now();
let mut seqs: Vec = Vec::with_capacity(requests.len());
llamacpp.batch.n_tokens = 0;
for (seq_id, request) in requests.iter().enumerate() {
debug!("Request: {:?}", request);
// TODO remove this
let sampler = match LlamacppSampler::new(request) {
Some(sampler) => sampler,
_ => {
let _ = request.tx.send(Err(InferError::IncompleteGeneration));
continue;
}
};
let last_pos = request.input_ids.len() - 1;
for (pos, &token_id) in request.input_ids.iter().enumerate() {
llamacpp.batch_push(
token_id as llamacpp::llama_token,
pos as llamacpp::llama_pos,
seq_id as llamacpp::llama_seq_id,
pos == last_pos, // check samplers
);
}
seqs.push(LlamacppSeq {
id: seq_id,
batch_pos: llamacpp.batch.n_tokens as usize - 1,
token: llamacpp::LLAMA_TOKEN_NULL,
pos: last_pos as llamacpp::llama_pos + 1,
sampler,
text: String::with_capacity(1024),
n_new_tokens: 0,
running: true,
});
}
while llamacpp.batch.n_tokens > 0 {
if llamacpp.decode() != 0 {
warn!("llama_decode failed, clearing kv cache");
llamacpp.clear_kv_cache(-1);
for seq in seqs.iter_mut() {
let _ = requests[seq.id]
.tx
.send(Err(InferError::IncompleteGeneration));
seq.running = false;
}
break;
}
for seq in seqs.iter_mut() {
if !seq.running {
continue;
}
let (next, logprob) = seq.sampler.sample(&mut llamacpp, seq.batch_pos);
seq.n_new_tokens += 1;
seq.token = next;
let piece = match tokenizer.decode(&[next as u32], false) {
Ok(piece) => piece,
Err(e) => {
error!("Failed to decode token: {e}");
let _ = requests[seq.id]
.tx
.send(Err(InferError::IncompleteGeneration));
seq.running = false;
continue;
}
};
let special = vocab.is_special_token(&piece);
if !special {
seq.text.push_str(&piece);
}
let token = Token {
id: next as _,
text: piece,
logprob,
special,
};
let finish: Option = {
if unsafe { llamacpp::vocab_is_eog(llamacpp.vocab, next) } {
Some(FinishReason::EndOfSequenceToken)
} else if seq.n_new_tokens == requests[seq.id].max_new_tokens {
Some(FinishReason::Length)
} else {
None
}
};
if let Some(reason) = finish {
let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text: GeneratedText {
text: seq.text.clone(),
generated_tokens: seq.n_new_tokens as _,
finish_reason: reason,
seed: Some(requests[seq.id].seed as _),
},
start: start_time,
queued: requests[seq.id].time,
}));
seq.running = false;
continue;
}
let _ = requests[seq.id]
.tx
.send(Ok(InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
}));
}
// generate a new batch
llamacpp.batch.n_tokens = 0;
for seq in seqs.iter_mut() {
if seq.running {
seq.batch_pos =
llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
seq.pos += 1;
} else {
llamacpp.clear_kv_cache(seq.id as _);
}
}
}
}
});
(
Self {
tx,
status: status_rx,
},
ok_rx,
shutdown_tx,
)
}
}
#[async_trait]
impl Backend for LlamacppBackend {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
) -> Result>, InferError> {
debug!(?request);
let (tx, rx) = unbounded_channel::>();
match LlamacppRequest::new(&request, tx) {
Some(v) => match self.tx.send(v) {
Err(e) => Err(InferError::GenerationError(e.to_string())),
_ => Ok(UnboundedReceiverStream::new(rx)),
},
_ => Err(InferError::GenerationError("Bad request".to_string())),
}
}
async fn health(&self, _: bool) -> bool {
*self.status.borrow()
}
fn name(&self) -> &'static str {
"llamacpp"
}
}
#[derive(Debug, Error)]
pub enum BackendError {
#[error("CString error: {0}")]
CStringError(#[from] std::ffi::NulError),
#[error("Llamacpp error: {0}")]
Llamacpp(String),
}
================================================
FILE: backends/llamacpp/src/llamacpp.rs
================================================
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
#![allow(dead_code)]
include!(concat!(env!("OUT_DIR"), "/llamacpp.rs"));
================================================
FILE: backends/llamacpp/src/main.rs
================================================
mod backend;
mod llamacpp;
mod quantize;
use quantize::QuantizeType;
use backend::{
BackendError, LlamacppBackend, LlamacppConfig, LlamacppGGMLType, LlamacppNuma,
LlamacppSplitMode,
};
use clap::Parser;
use hf_hub::api::tokio::ApiBuilder;
use hf_hub::{Repo, RepoType};
use std::path::Path;
use text_generation_router::{logging, server, usage_stats};
use thiserror::Error;
use tokenizers::Tokenizer;
use tokio::process::Command;
use tokio::sync::oneshot::error::RecvError;
use tracing::{error, warn};
/// Backend Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
/// Name of the model to load.
#[clap(long, env)]
model_id: String,
/// Revision of the model.
#[clap(default_value = "main", long, env)]
revision: String,
/// Path to the GGUF model file for inference.
#[clap(long, env)]
model_gguf: Option,
/// Number of threads to use for generation.
#[clap(long, env)]
n_threads: Option,
/// Number of threads to use for batch processing.
#[clap(long, env)]
n_threads_batch: Option,
/// Number of layers to store in VRAM.
#[clap(default_value = "0", long, env)]
n_gpu_layers: usize,
/// Split the model across multiple GPUs.
#[clap(default_value = "layer", long, env)]
split_mode: LlamacppSplitMode,
/// Defragment the KV cache if holes/size > threshold.
#[clap(default_value = "-1.0", long, env)]
defrag_threshold: f32,
/// Enable NUMA optimizations.
#[clap(default_value = "disabled", value_enum, long, env)]
numa: LlamacppNuma,
/// Use memory mapping for the model.
#[clap(long, env)]
disable_mmap: bool,
/// Use memory locking to prevent swapping.
#[clap(long, env)]
use_mlock: bool,
/// Enable offloading of KQV operations to the GPU.
#[clap(long, env)]
disable_offload_kqv: bool,
/// Enable flash attention for faster inference. (EXPERIMENTAL)
#[clap(long, env)]
disable_flash_attention: bool,
/// Data type used for K cache.
#[clap(default_value = "f16", value_enum, long, env)]
type_k: LlamacppGGMLType,
/// Data type used for V cache.
#[clap(default_value = "f16", value_enum, long, env)]
type_v: LlamacppGGMLType,
/// Number of tokenizer workers used for payload validation and truncation.
#[clap(default_value = "2", long, env)]
validation_workers: usize,
/// Maximum number of concurrent requests.
#[clap(long, env)]
max_concurrent_requests: Option,
/// Maximum number of input tokens per request.
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
/// Maximum number of total tokens (input + output) per request.
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
/// Maximum number of tokens in a batch.
#[clap(long, env)]
max_batch_total_tokens: Option,
/// Maximum number of tokens in a physical batch.
#[clap(long, env)]
max_physical_batch_total_tokens: Option,
/// Maximum number of requests per batch.
#[clap(long, env)]
max_batch_size: Option,
/// IP address to listen on.
#[clap(default_value = "0.0.0.0", long)]
hostname: String,
/// Port to listen on.
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "9000", long, short, env)]
prometheus_port: u16,
/// Enable JSON output format.
#[clap(long, env)]
json_output: bool,
/// OTLP endpoint for telemetry data.
#[clap(long, env)]
otlp_endpoint: Option,
/// Service name for OTLP telemetry.
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
/// Allowed origins for CORS.
#[clap(long, env)]
cors_allow_origin: Option>,
/// Path to the tokenizer configuration file.
#[clap(long, env)]
tokenizer_config_path: Option,
/// Disable grammar support.
#[clap(long, env)]
disable_grammar_support: bool,
/// Maximum number of inputs per request.
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
/// Level of usage statistics collection.
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
/// Maximum payload size in bytes.
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
/// Maximum image fetch size in bytes.
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}
/// Entry point of the llama.cpp router binary.
///
/// Resolves the command-line arguments into a consistent threading and batching
/// configuration, loads the tokenizer from the Hugging Face Hub, obtains a GGUF model
/// (converting and Q4_0-quantizing the Hub checkpoint into `models/<model_id>/model.gguf`
/// when `args.model_gguf` is not set and that file does not exist yet), starts the
/// llama.cpp backend and then serves the HTTP API.
///
/// # Errors
///
/// Returns a [`RouterError`] when argument validation fails, when the Hub, tokenizer,
/// GGUF conversion or quantization steps fail, or when the backend or the web server
/// report an error.
#[tokio::main]
async fn main() -> Result<(), RouterError> {
    let args = Args::parse();

    logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output);

    // For the thread/batch-size options, both "unset" and 0 select a default derived
    // from the previously resolved value.
    let n_threads = match args.n_threads {
        Some(0) | None => num_cpus::get(),
        Some(threads) => threads,
    };
    let n_threads_batch = match args.n_threads_batch {
        Some(0) | None => n_threads,
        Some(threads) => threads,
    };
    let max_batch_size = match args.max_batch_size {
        Some(0) | None => n_threads_batch,
        Some(size) => size,
    };
    let max_batch_total_tokens = args
        .max_batch_total_tokens
        .unwrap_or(max_batch_size * args.max_total_tokens);
    let max_physical_batch_total_tokens = args
        .max_physical_batch_total_tokens
        .unwrap_or(max_batch_total_tokens);
    let max_concurrent_requests = args
        .max_concurrent_requests
        .unwrap_or(max_batch_size * 2);

    if args.max_input_tokens >= args.max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if args.max_total_tokens > max_batch_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
        ));
    }
    if max_batch_size * args.max_total_tokens > max_batch_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
        ));
    }

    // Hub client configured from the usual Hugging Face environment variables.
    let api_builder = || {
        let mut builder = ApiBuilder::new().with_progress(true);
        if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
            builder = builder.with_cache_dir(cache_dir.into());
        }
        if let Ok(token) = std::env::var("HF_TOKEN") {
            builder = builder.with_token(token.into());
        }
        if let Ok(origin) = std::env::var("HF_HUB_USER_AGENT_ORIGIN") {
            builder = builder.with_user_agent("origin", origin.as_str());
        }
        builder
    };
    let api_repo = api_builder().build()?.repo(Repo::with_revision(
        args.model_id.clone(),
        RepoType::Model,
        args.revision.clone(),
    ));

    let tokenizer_path = api_repo.get("tokenizer.json").await?;
    let tokenizer = Tokenizer::from_file(&tokenizer_path)?;

    let model_gguf = if let Some(model_gguf) = args.model_gguf {
        model_gguf
    } else {
        let model_gguf = format!("models/{}/model.gguf", args.model_id);
        let model_gguf_path = Path::new(&model_gguf);

        if !model_gguf_path.exists() {
            let tmp_gguf = "models/tmp.gguf";

            if let Some(parent) = model_gguf_path.parent() {
                std::fs::create_dir_all(parent)?;
            }
            // NOTE(review): assumes hf-hub downloads every repository file next to
            // `tokenizer.json`, so that directory is what the converter is pointed at.
            let cache_path = tokenizer_path
                .parent()
                .expect("a file downloaded from the Hub always has a parent directory");
            for sibling in api_repo.info().await?.siblings {
                let _ = api_repo.get(&sibling.rfilename).await?;
            }
            let status = Command::new("convert_hf_to_gguf.py")
                .arg("--outfile")
                .arg(tmp_gguf)
                .arg(cache_path)
                .spawn()?
                .wait()
                .await?;

            if !status.success() {
                let exit_code = status.code().unwrap_or(-1);
                error!("Failed to generate GGUF, exit code: {}", exit_code);
                return Err(RouterError::CommandError(exit_code));
            }
            let quantized =
                quantize::model(tmp_gguf, &model_gguf, QuantizeType::MostlyQ4_0, n_threads);

            // The unquantized GGUF is only an intermediate artifact (as large as the full
            // model): remove it, best effort, instead of leaving it on disk.
            if let Err(e) = std::fs::remove_file(tmp_gguf) {
                warn!("Failed to remove temporary GGUF file {}: {}", tmp_gguf, e);
            }
            quantized.map_err(RouterError::QuantizeError)?;
        }
        model_gguf
    };

    let (backend, ok, shutdown) = LlamacppBackend::new(
        LlamacppConfig {
            model_gguf,
            n_threads,
            n_threads_batch,
            n_gpu_layers: args.n_gpu_layers,
            split_mode: args.split_mode,
            defrag_threshold: args.defrag_threshold,
            numa: args.numa,
            use_mmap: !args.disable_mmap,
            use_mlock: args.use_mlock,
            flash_attention: !args.disable_flash_attention,
            type_k: args.type_k,
            type_v: args.type_v,
            offload_kqv: !args.disable_offload_kqv,
            max_batch_total_tokens,
            max_physical_batch_total_tokens,
            max_batch_size,
            batch_timeout: tokio::time::Duration::from_millis(5),
        },
        tokenizer,
    );
    // Wait for the backend start-up result: the outer `?` handles a failed receive,
    // the inner one the backend's own error.
    ok.await??;

    if cfg!(debug_assertions) {
        warn!("Graceful shutdown disabled!");
        // In debug builds, Ctrl-C is forwarded to the backend through `shutdown`.
        let _ = tokio::task::spawn(async move {
            let _ = tokio::signal::ctrl_c().await;
            let _ = shutdown.send(true);
        });
    }

    server::run(
        backend,
        max_concurrent_requests,
        0, // max_best_of
        0, // max_stop_sequences
        0, // max_top_n_tokens
        args.max_input_tokens,
        args.max_total_tokens,
        args.validation_workers,
        None,          // api_key
        args.model_id, // tokenizer_name
        args.tokenizer_config_path,
        Some(args.revision),
        false, // trust_remote_code
        args.hostname,
        args.port,
        args.cors_allow_origin,
        false, // ngrok,
        None,  // ngrok_authtoken,
        None,  // ngrok_edge,
        args.disable_grammar_support,
        args.max_client_batch_size,
        args.usage_stats,
        args.payload_limit,
        args.max_image_fetch_size,
        args.prometheus_port,
    )
    .await?;
    Ok(())
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("Tokenizer error: {0}")]
Tokenizer(#[from] tokenizers::Error),
#[error("Backend error: {0}")]
Backend(#[from] BackendError),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Recv error: {0}")]
RecvError(#[from] RecvError),
#[error("Io error: {0}")]
IoError(#[from] std::io::Error),
#[error("Var error: {0}")]
VarError(#[from] std::env::VarError),
#[error("Quantize error: {0}")]
QuantizeError(String),
#[error("Command error: {0}")]
CommandError(i32),
#[error("HF hub error: {0}")]
HubError(#[from] hf_hub::api::tokio::ApiError),
}
================================================
FILE: backends/llamacpp/src/quantize.rs
================================================
use crate::llamacpp;
use std::ffi::CString;
/// Quantization formats accepted by the `model` function of this module.
///
/// Each discriminant is handed verbatim to llama.cpp as the `ftype` quantization
/// parameter, so the values must follow llama.cpp's file-type numbering.
#[derive(Copy, Clone, Debug)]
#[repr(u32)]
pub enum QuantizeType {
    /// llama.cpp's `MOSTLY_Q4_0` file type.
    MostlyQ4_0 = 2,
}
pub fn model(
input_path: &str,
output_path: &str,
ftype: QuantizeType,
n_threads: usize,
) -> Result<(), String> {
let c_input_path =
CString::new(input_path).map_err(|e| format!("Failed to convert input path: {}", e))?;
let c_output_path =
CString::new(output_path).map_err(|e| format!("Failed to convert output path: {}", e))?;
let result = unsafe {
let mut params = llamacpp::model_quantize_default_params();
params.nthread = n_threads as _;
params.ftype = ftype as _;
params.quantize_output_tensor = true;
llamacpp::model_quantize(c_input_path.as_ptr(), c_output_path.as_ptr(), ¶ms)
};
if result == 0 {
Ok(())
} else {
Err(format!("Quantization failed, error code: {}", result))
}
}
================================================
FILE: backends/neuron/Cargo.toml
================================================
[workspace]
members = [
"backends/v2",
"backends/grpc-metadata",
"launcher",
"router"
]
default-members = [
"backends/v2",
"backends/grpc-metadata",
"launcher",
"router"
]
resolver = "2"
[workspace.package]
version = "3.0.0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.20.0", features = ["http"] }
hf-hub = { version = "0.4.2", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
minijinja = { version = "2.2.0", features = ["json"] }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
[profile.release]
incremental = true
[profile.release-binary]
inherits = "release"
debug = 1
incremental = true
panic = "abort"
[profile.release-opt]
inherits = "release"
debug = 0
incremental = false
lto = "fat"
opt-level = 3
codegen-units = 1
================================================
FILE: backends/neuron/Makefile
================================================
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
root_dir := "${mkfile_dir}/../.."

.PHONY: image install_server test_server

# Workspace version, extracted from the `version = "..."` line(s) of the top-level Cargo.toml
VERSION := $(shell gawk 'match($$0, /^version = "(.*)"/, a) {print a[1]}' ${root_dir}/Cargo.toml)

# Build the neuron image, tagged both with the workspace version and as latest-neuron
image:
	docker build --rm -f ${root_dir}/Dockerfile.neuron \
	             --ulimit nofile=100000:100000 \
	             --build-arg VERSION=$(VERSION) \
	             -t text-generation-inference:$(VERSION)-neuron ${root_dir}
	docker tag text-generation-inference:$(VERSION)-neuron text-generation-inference:latest-neuron

# Build and install the python server package; $(MAKE) propagates flags and the jobserver
install_server:
	$(MAKE) -C ${mkfile_dir}/server install VERSION:=${VERSION}

test_server: install_server
	python -m pip install -r ${mkfile_dir}/tests/requirements.txt
	python -m pytest -sv ${mkfile_dir}/tests/server
================================================
FILE: backends/neuron/README.md
================================================
# Text-generation-inference - Neuron backend for AWS Trainium and Inferentia2
## Description
This is the TGI backend for the AWS Neuron Trainium and Inferentia families of chips.
This backend is composed of:
- the AWS Neuron SDK,
- the legacy v2 TGI launcher and router,
- a Neuron-specific inference server for text generation.
## Usage
Please refer to the official [documentation](https://huggingface.co/docs/text-generation-inference/backends/neuron).
## Build your own image
The simplest way to build TGI with the neuron backend is to use the provided `Makefile`:
```shell
$ make -C backends/neuron image
```
Alternatively, you can build the image directly from the top directory using a command similar to the one defined
in the `Makefile` under the `image` target.
================================================
FILE: backends/neuron/server/.gitignore
================================================
build
================================================
FILE: backends/neuron/server/Makefile
================================================
# Initialize base variables
SHELL := /bin/bash
pkg_name := text_generation_server
BUILDDIR ?= $(CURDIR)/build
VERSION ?= 0.0.1
mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
pkg_dir := $(BUILDDIR)/$(pkg_name)
py_version := $(subst -,.,${VERSION})
pkg_dist := ${BUILDDIR}/dist/${pkg_name}-$(py_version).tar.gz

# These targets never produce a file of the same name
.PHONY: clean package install

clean:
	rm -rf $(BUILDDIR)/*

${BUILDDIR}:
	install -d $@

# List static sources to be deployed in the package
src_dir := $(mkfile_dir)/$(pkg_name)
sources := $(wildcard $(src_dir)/*.py)
deployed_sources := $(subst $(src_dir), $(pkg_dir), $(sources))

# Static files are just copied
define COPY
	cp -f $< $@
endef

# We use a PHONY target to represent the VERSION
.PHONY: VERSION

# The trick is to compare the value of the variable with the content of a file in the build directory:
# the file is only rewritten when the requested version differs from the recorded one
VERSION: ${BUILDDIR}
	@if [[ `cat ${BUILDDIR}/VERSION 2>&1` != '$(VERSION)' ]]; then echo -n $(VERSION) >${BUILDDIR}/VERSION; fi

# Depending on the PHONY VERSION target makes sure the pyproject.toml is regenerated if the version changes
$(BUILDDIR)/pyproject.toml: $(mkfile_dir)/pyproject.toml VERSION
	mkdir -p $(BUILDDIR)
	$(COPY)
	sed -i -e 's/version = "VERSION"/version = \"${VERSION}\"/' $@

$(pkg_dir)/%.py: $(src_dir)/%.py
	mkdir -p $(pkg_dir)
	$(COPY)

# Generated files are produced by grpcio tools

# If not provided, get local proto files
ifndef PROTODIR
PROTODIR := $(mkfile_dir)/../../../proto
endif

# Three python files are generated for each protobuf
protobufs := $(PROTODIR)/generate.proto
pkg_pb_dir := $(pkg_dir)/pb
generated_sources_base := $(foreach proto, $(protobufs), $(proto:.proto=_pb2.py))
generated_sources := $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base))
generated_sources += $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base:.py=.pyi))
generated_sources += $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base:.py=_grpc.py))

# A single protoc invocation produces all three files; the grpc stub is then patched
# to use a package-relative import of the generated pb2 module
$(pkg_pb_dir)/%_pb2.py $(pkg_pb_dir)/%_pb2.pyi $(pkg_pb_dir)/%_pb2_grpc.py: $(PROTODIR)/%.proto
	mkdir -p $(pkg_pb_dir)
	python -m grpc_tools.protoc -I$(PROTODIR) --python_out=$(pkg_pb_dir) \
		--grpc_python_out=$(pkg_pb_dir) --mypy_out=$(pkg_pb_dir) $^
	sed -i -e 's/^\(import.*pb2\)/from . \1/g' $(pkg_pb_dir)/$*_pb2_grpc.py

${pkg_dist}: $(BUILDDIR)/pyproject.toml $(deployed_sources) $(generated_sources)
	python -m build $(BUILDDIR)

package: ${pkg_dist}

install: ${pkg_dist}
	python3 -m pip uninstall -y ${pkg_name}
	python3 -m pip install ${pkg_dist}
================================================
FILE: backends/neuron/server/build-requirements.txt
================================================
build
grpcio-tools==1.53.0
mypy-protobuf
================================================
FILE: backends/neuron/server/pyproject.toml
================================================
[build-system]
requires = ["setuptools>=78.1"]
build-backend = "setuptools.build_meta"
[project]
name = "text-generation-server"
version = "VERSION"
authors = [{name="David Corvoysier", email="david@huggingface.co" }]
description = "TGI compatible inference server for AWS Neuronx platforms"
dependencies = [
'protobuf > 3.20.1, < 4',
'grpcio == 1.57.0',
'grpcio-status == 1.48.2',
'grpcio-reflection == 1.48.2',
'grpc-interceptor == 0.15.2',
'typer == 0.6.1',
'safetensors',
'loguru == 0.6.0',
'optimum-neuron[neuronx] >= 0.0.28',
]
[tool.setuptools]
packages = ["text_generation_server", "text_generation_server.pb"]
[project.scripts]
text-generation-server = 'text_generation_server.cli:app'
================================================
FILE: backends/neuron/server/text_generation_server/cli.py
================================================
import sys
from typing import Optional
import typer
from loguru import logger
# Typer application holding the `serve` and `download_weights` commands registered
# below; it is the `text-generation-server` console-script entry point declared in
# pyproject.toml.
app = typer.Typer()
@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    trust_remote_code: Optional[bool] = None,
    uds_path: str = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
    otlp_service_name: str = "text-generation-inference.server",
    max_input_tokens: Optional[int] = None,
):
    """This is the main entry-point for the server CLI.

    Args:
        model_id (`str`):
            The *model_id* of a model on the HuggingFace hub or the path to a local model.
        revision (`Optional[str]`, defaults to `None`):
            The revision of the model on the HuggingFace hub.
        sharded (`bool`):
            Whether the model must be sharded or not. Kept for compatibility with the
            text-generation-launcher, but must be set to False.
        trust_remote_code (`Optional[bool]`, defaults to `None`):
            Kept for compatibility with text-generation-launcher. Ignored.
        uds_path (`str`):
            The local path on which the server will expose its google RPC services.
        logger_level (`str`):
            The server logger level. Defaults to *INFO*.
        json_output (`bool`):
            Use JSON format for log serialization.
        otlp_endpoint (`Optional[str]`, defaults to `None`):
            The Open Telemetry endpoint to use. Accepted for compatibility, currently unused.
        otlp_service_name (`str`, defaults to `"text-generation-inference.server"`):
            The name to use when pushing data to the Open Telemetry endpoint.
            Accepted for compatibility, currently unused.
        max_input_tokens (`Optional[int]`, defaults to `None`):
            The maximum number of input tokens each request should contain.
            Accepted for compatibility, currently unused.

    Raises:
        ValueError: if `sharded` is True.
    """
    if sharded:
        raise ValueError("Sharding is not supported.")
    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )
    if trust_remote_code is not None:
        logger.warning(
            "'trust_remote_code' argument is not supported and will be ignored."
        )
    # Import here after the logger is added to log potential import exceptions.
    # Aliased so it does not shadow this command's own name.
    from .server import serve as run_server

    run_server(model_id, revision, uds_path)
@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    logger_level: str = "INFO",
    json_output: bool = False,
    auto_convert: Optional[bool] = None,
    extension: Optional[str] = None,
    trust_remote_code: Optional[bool] = None,
    merge_lora: Optional[bool] = None,
):
    """Download the model weights.
    This command will be called by text-generation-launcher before serving the model.
    """
    # Replace loguru's default sink with a stdout sink restricted to this package.
    # (The docstring above is kept verbatim: typer shows it as the command help.)
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )
    # Launcher options accepted for compatibility only, in the order they are reported.
    unsupported_options = (
        (extension, "'extension' argument is not supported and will be ignored."),
        (
            trust_remote_code,
            "'trust_remote_code' argument is not supported and will be ignored.",
        ),
        (auto_convert, "'auto_convert' argument is not supported and will be ignored."),
        (merge_lora, "'merge_lora' argument is not supported and will be ignored."),
    )
    for option_value, warning_message in unsupported_options:
        if option_value is not None:
            logger.warning(warning_message)
    # Import here after the logger is added to log potential import exceptions
    from .model import fetch_model

    fetch_model(model_id, revision)
================================================
FILE: backends/neuron/server/text_generation_server/generator.py
================================================
import copy
import logging
import time
from abc import ABC
from enum import Enum
from typing import List, Optional, Tuple
import torch
from loguru import logger
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from optimum.neuron.configuration_utils import NeuronConfig
from transformers.generation import GenerationConfig
from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.generation import TokenSelector
from .model import get_export_kwargs_from_env
from .pb.generate_pb2 import (
Batch,
CachedBatch,
FinishReason,
GeneratedText,
Generation,
InfoResponse,
Request,
Tokens,
)
# optimum-neuron logging is silenced down to CRITICAL records only, because its
# warnings seem to block the server after a while.
optimum_logger = logging.getLogger("optimum.neuron")
optimum_logger.setLevel(logging.CRITICAL)
class Generator(ABC):
    """An abstract class to represent the workhorse behind TextGenerationService.

    Ideally, it should not rely on protobuf constructs, but in a first step it does.
    Implementations would typically need a model and a tokenizer to implement the Generator methods.
    Every method of this base class raises `NotImplementedError`: concrete generators
    must override them.
    """

    @property
    def info(self) -> InfoResponse:
        """This should simply return the expected InfoResponse"""
        raise NotImplementedError

    def warmup(self, batch: Batch) -> int:
        """Verify if the hardware can support the target load.

        Args:
            batch (`Batch`):
                A batch corresponding to the maximum number of concurrent requests.

        Return:
            The maximum number of tokens the model supports.
        """
        raise NotImplementedError

    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        """Prefill is called whenever new requests need to be added.

        When this method returns successfully, a decode method will follow
        with both the current and newly prefilled batch(es).

        Args:
            batch (`Batch`):
                A batch containing the new requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """
        raise NotImplementedError

    def decode(
        self, batches: List[CachedBatch]
    ) -> Tuple[List[Generation], CachedBatch]:
        """Decode after a prefill or another decode.

        Args:
            batches (`List[CachedBatch]`):
                The cached batches returned by the previous decode and/or prefill(s).

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """
        raise NotImplementedError

    def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
        """Remove requests that are not listed from the specified batch"""
        raise NotImplementedError

    def clear(self):
        """Remove all requests from the generator"""
        raise NotImplementedError

    @classmethod
    def from_pretrained(cls, model_id: str, revision: Optional[str]):
        """Factory method "a la transformers" """
        raise NotImplementedError
class Slot:
    """Represents a slot in a static batch.

    A slot hosts at most one request at a time. It keeps that request's token ids and
    attention mask, the token selector used to pick new tokens and the text decoded so far.
    """

    class State(Enum):
        # No request is assigned: the slot can receive a new one.
        EMPTY = 0
        # A request is assigned but the slot is excluded from generation (see `pause`).
        PAUSE = 1
        # A request is assigned and the slot takes part in generation.
        READY = 2

    def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
        self._id = id
        self._tokenizer = tokenizer
        self.clear()

    def clear(self):
        """Clear the slot and mark it as available."""
        self._state = Slot.State.EMPTY
        self._batch_id = None
        self._request_id = None
        self._inputs = ""
        self._truncate = 0
        self._generation_config = None
        # Per-request values otherwise only created by `assign`: reset them too, so a
        # cleared slot always defines them and never keeps the previous request's values.
        self.seed = None
        self._max_new_tokens = None
        self._tokens = []
        self._mask = torch.tensor([])
        self._selector = None
        self._generated_tokens = 0
        self._next_text_token_start = 0
        self._next_text_token_end = 0
        self._generated_text = ""
        self._next_text = ""

    @property
    def id(self) -> int:
        return self._id

    @property
    def state(self) -> "Slot.State":
        return self._state

    @property
    def batch_id(self) -> Optional[int]:
        # None while the slot is empty
        return self._batch_id

    @property
    def request_id(self) -> Optional[int]:
        # None while the slot is empty
        return self._request_id

    @property
    def cached_text(self) -> str:
        # Request inputs followed by the generated text, excluding the text of the most
        # recently generated token (kept apart in `_next_text`, see `append`).
        return self._inputs + self._generated_text

    @property
    def generation_config(self) -> GenerationConfig:
        return self._generation_config

    @property
    def generated_tokens(self) -> int:
        return self._generated_tokens

    def assign(
        self, batch_id: int, request: Request, generation_config: GenerationConfig
    ):
        """Assign a request to a slot.

        Args:
            batch_id (`int`):
                The id of the batch the request is prefilled with.
            request (`Request`):
                The request to be assigned. Contains the inputs and tokens selection parameters.
            generation_config (`transformers.GenerationConfig`):
                The base generation config (might be modified by the request generation parameters).
                It is deep-copied: the caller's instance is left untouched.
        """
        self._state = Slot.State.READY
        self._batch_id = batch_id
        self._request_id = request.id
        self._inputs = request.inputs
        if request.truncate:
            self._truncate = request.truncate
        self._generation_config = copy.deepcopy(generation_config)
        # Update generation config with request parameters
        self._generation_config.do_sample = request.parameters.do_sample
        if self._generation_config.do_sample:
            # A zero request value means "not set": the base config value is kept
            if request.parameters.temperature != 0:
                self._generation_config.temperature = request.parameters.temperature
            if request.parameters.top_k != 0:
                self._generation_config.top_k = request.parameters.top_k
            if request.parameters.top_p != 0:
                self._generation_config.top_p = request.parameters.top_p
            if request.parameters.typical_p != 0:
                self._generation_config.typical_p = request.parameters.typical_p
        else:
            # Set the sampling parameters to emulate greedy decoding when using on-device sampling
            self._generation_config.temperature = 1.0
            self._generation_config.top_k = 1
            self._generation_config.top_p = 1.0
            self._generation_config.typical_p = 1.0
        if request.parameters.repetition_penalty != 0:
            self._generation_config.repetition_penalty = (
                request.parameters.repetition_penalty
            )
        self.seed = request.parameters.seed
        self._generation_config.max_new_tokens = (
            request.stopping_parameters.max_new_tokens
        )
        self._max_new_tokens = self._generation_config.max_new_tokens
        stop_strings = request.stopping_parameters.stop_sequences
        if stop_strings:
            self._generation_config.stop_strings = stop_strings

    def reset(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        selector: TokenSelector,
    ):
        """Reset the slot for the next generation.

        Args:
            input_ids: (`torch.LongTensor`):
                The new input_ids to use to generate the next token.
            attention_mask: (`torch.LongTensor`):
                The new attention_mask to use to generate the next token.
            selector: (`optimum.neuron.generation.TokenSelector`):
                An object implementing the updated token selection logic.
        """
        self._tokens = input_ids.clone()
        self._next_text_token_start = 0
        self._next_text_token_end = torch.numel(self._tokens)
        self._next_text = ""
        self._mask = attention_mask.clone()
        self._selector = selector

    def pause(self):
        """Mark the current slot as paused for generation.

        Note that the KV cache for this slot will still be filled.
        """
        self._state = Slot.State.PAUSE

    def resume(self):
        """Mark the slot as ready for generation."""
        self._state = Slot.State.READY

    def _decode_next_tokens(
        self,
    ) -> str:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers.

        Returns the text added by the tokens appended since the last successful call,
        or an empty string when nothing printable was produced yet.
        """
        # We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        new_text = self._tokenizer.decode(
            self._tokens[self._next_text_token_start :], skip_special_tokens=False
        )
        if new_text.endswith("�"):
            # utf-8 char at the end means it's a potential unfinished byte sequence
            # from byte fallback tokenization.
            return ""

        # Compare the generated text with the one using only the tokens producing the last one
        last_text = self._tokenizer.decode(
            self._tokens[self._next_text_token_start : self._next_text_token_end],
            skip_special_tokens=False,
        )
        if len(new_text) == len(last_text):
            # Nothing new was actually generated
            return ""
        # Return the decoded text and store its token offsets
        self._next_text_token_start = self._next_text_token_end
        self._next_text_token_end = torch.numel(self._tokens)
        return new_text[len(last_text) :]

    def append(self, next_token: int) -> str:
        """Append a new generated token to this slot

        The new token is added to the list of generated tokens, which impacts
        directly the generated_text and stopped property.

        The new token is however not added immediately to the slot inputs: it will
        be added later on when it has effectively been used to produce the next token.

        Args:
            next_token (`int`):
                The newly generated token.

        Return:
            The corresponding decoded text (if any).
        """
        self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])])
        self._mask = torch.cat([self._mask, torch.LongTensor([1])])
        self._generated_tokens += 1
        next_text = self._decode_next_tokens()
        # Now that a new token has been generated, we can append the previous one to the generated text
        self._generated_text += self._next_text
        self._next_text = next_text
        return next_text

    def select(
        self, input_ids: torch.LongTensor, logits: torch.Tensor
    ) -> torch.LongTensor:
        """Select the next token from the candidate logits.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation (not used in all generation modes).
            logits (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                The logits corresponding to the generated tokens.

        Return:
            `torch.LongTensor`: A scalar `torch.LongTensor` containing the selected token.
        """
        return self._selector.select(input_ids, logits)[0]

    @property
    def stopped(self) -> bool:
        # Transformers stopping criteria expects a batch of input ids.
        # NOTE(review): returns whatever the selector's stopping criteria return for this
        # single-sequence batch (possibly a boolean tensor rather than a bool).
        input_ids = torch.unsqueeze(self._tokens, dim=0)
        return self._selector.stopping_criteria(input_ids, None)

    @property
    def generated_text(self) -> str:
        return self._generated_text + self._next_text

    @property
    def next_token(self) -> Optional[int]:
        # None before the first `reset`; afterwards the last element of the `_tokens`
        # tensor (a 0-d tensor rather than a Python int).
        return None if len(self._tokens) == 0 else self._tokens[-1]

    @property
    def attention_mask(self) -> torch.LongTensor:
        return self._mask

    @property
    def max_token(self) -> int:
        return self._generation_config.max_length

    @property
    def max_new_tokens(self) -> int:
        # The current value of max_new_tokens: might be different of the target max_new_tokens
        # if the slot has been paused and resumed.
        return self._generation_config.max_new_tokens

    @property
    def truncate(self) -> int:
        return self._truncate
class NeuronGenerator(Generator):
"""A Generator for Neuron models."""
def __init__(
    self,
    model: NeuronModelForCausalLM,
    tokenizer: PreTrainedTokenizerBase,
):
    """Bind the generator to a compiled Neuron model and its tokenizer.

    Raises:
        ValueError: if `model` is not a `NeuronModelForCausalLM`, or if it was compiled
            with a batch size above 1 without continuous batching.
    """
    self.model = model
    if not isinstance(self.model, NeuronModelForCausalLM):
        raise ValueError("The model must be a NeuronModelForCausalLM.")
    neuron_config = model.neuron_config
    needs_continuous_batching = neuron_config.batch_size > 1
    if needs_continuous_batching and not neuron_config.continuous_batching:
        raise ValueError(
            "The neuron model must be compiled with continuous_batching=True."
        )
    # Decoder-only architecture: pad with EOS, and pad/truncate on the left side.
    # Note that the caller's tokenizer object is configured in place.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    self.tokenizer = tokenizer
    self.special_tokens = self.tokenizer.all_special_ids
    # One slot per entry of the static batch the model was compiled for
    self.slots = [Slot(slot_id, tokenizer) for slot_id in range(neuron_config.batch_size)]
    self.batch_id = 0
@property
def on_device_sampling(self) -> bool:
    # False when the compiled model configuration has no `on_device_sampling` attribute
    neuron_config = self.model.neuron_config
    return getattr(neuron_config, "on_device_sampling", False)
@property
def info(self) -> InfoResponse:
    """Describe the served model: padded inputs, model dtype and the "xla" device type."""
    # Models whose config records no torch dtype are reported as float32
    model_dtype = getattr(self.model.config, "torch_dtype", "float32")
    response = InfoResponse(
        requires_padding=True,
        dtype=str(model_dtype),
        device_type="xla",
    )
    return response
def warmup(self, batch: Batch) -> int:
    """Verify if the hardware can support the target load.

    The warmup batch is prefilled once, then every slot is cleared again.

    Args:
        batch (`Batch`):
            A batch corresponding to the maximum number of concurrent requests.

    Return:
        The maximum number of tokens the model supports, i.e. the compiled
        batch size times the compiled sequence length.

    Raises:
        ValueError: if the batch holds more requests than the compiled batch size.
    """
    neuron_config = self.model.neuron_config
    batch_size = neuron_config.batch_size
    # Only the request count is checked against the static capacity of the model
    if len(batch.requests) > batch_size:
        raise ValueError(
            f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
        )
    self.prefill(batch)
    self.clear()
    return batch_size * neuron_config.sequence_length
def max_prefill_length(self) -> int:
    """Token limit used when tokenizing prefill inputs.

    This is `max_context_length` when the compiled configuration defines it,
    otherwise `sequence_length`.
    """
    neuron_config = self.model.neuron_config
    if not hasattr(neuron_config, "max_context_length"):
        return neuron_config.sequence_length
    return neuron_config.max_context_length
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
    """Prefill new requests.

    Each new request is assigned to an empty slot and its inputs are tokenized, padded
    and truncated to fit the static model dimensions. The model is then run on the new
    slots only. Slots that were already generating are paused during the prefill and
    resumed afterwards.

    Args:
        batch (`Batch`):
            A batch containing the new requests.

    Return:
        A list of `Generation` for each request and a `CachedBatch` containing all pending requests.

    Raises:
        ValueError: if there are fewer empty slots than new requests.
    """
    slots = {state: [] for state in Slot.State}
    for slot in self.slots:
        slots[slot.state].append(slot)
    active_slots = slots[Slot.State.READY]
    empty_slots = slots[Slot.State.EMPTY]
    if len(empty_slots) < len(batch.requests):
        raise ValueError(
            f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
            f" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}."
        )
    # Assign each request to an empty slot
    logger.debug(
        f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)"
    )
    new_slots = []
    for request in batch.requests:
        slot = empty_slots.pop()
        slot.assign(self.batch_id, request, self.model.generation_config)
        new_slots.append(slot)
        logger.debug(
            f"Request {slot.request_id} assigned to slot {slot.id} with max_new_tokens {slot.max_new_tokens}"
        )
    prefill_slots = new_slots
    seq_ids = torch.tensor([slot.id for slot in prefill_slots])
    # Reconstruct the full inputs (without padding) as seen by the model.
    # This comprises:
    # - the inputs for new requests,
    # - only when rebuilding the cache, the inputs and the generated text that has already
    #   been cached (i.e. excluding the last generated token) for unfinished requests.
    inputs = []
    max_prefill_length = self.max_prefill_length()
    max_length = 0
    for slot in prefill_slots:
        inputs.append(slot.cached_text)
        # Apply truncation, making sure we fit into static dimensions: a request without
        # truncation needs the full prefill length, otherwise its truncate value capped
        # by the prefill length. The whole batch is tokenized with the largest need.
        # (Capping matters: a truncate value >= the prefill length must not be ignored,
        # otherwise max_length could stay at 0 and the inputs be truncated away.)
        if slot.truncate == 0:
            slot_max_length = max_prefill_length
        else:
            slot_max_length = min(slot.truncate, max_prefill_length)
        max_length = max(max_length, slot_max_length)
    # Tokenize with padding and truncation
    padded_inputs = self.tokenizer(
        inputs,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    input_ids = padded_inputs.input_ids
    attention_mask = padded_inputs.attention_mask
    # One row per prefilled slot, columns are (top_k, top_p, temperature)
    sampling_params = (
        torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None
    )
    # Pause previously active slots during generation
    for slot in active_slots:
        slot.pause()
    # Each slot must be reset with the padded inputs and masks
    for i, slot in enumerate(prefill_slots):
        if slot.state != Slot.State.EMPTY:
            if slot.truncate > 0 and slot.truncate < input_ids.shape[-1]:
                # Apply per-request truncation
                input_ids[i, : -slot.truncate] = self.tokenizer.pad_token_id
                attention_mask[i, : -slot.truncate] = 0
            slot_input_ids = input_ids[i : i + 1, :]
            # Padded input ids are also required to set logits processors and stopping criterias
            selector = TokenSelector.create(
                slot_input_ids,
                slot.generation_config,
                self.model,
                self.model.neuron_config.sequence_length,
                tokenizer=self.tokenizer,
                seed=slot.seed,
            )
            slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
            slot_attention_mask = attention_mask[i]
            slot.reset(slot_input_ids, slot_attention_mask, selector)
            if sampling_params is not None:
                sampling_params[i, 0] = slot.generation_config.top_k
                sampling_params[i, 1] = slot.generation_config.top_p
                sampling_params[i, 2] = slot.generation_config.temperature
    # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
    # as they have already been generated and sent back in the last decode.
    model_inputs = self.model.prepare_inputs_for_prefill(
        input_ids,
        attention_mask=attention_mask,
        seq_ids=seq_ids,
        sampling_params=sampling_params,
    )
    tokens_or_logits = self.model(**model_inputs)[0]
    generation, next_batch = self._generate_token(
        prefill_slots, self.batch_id, tokens_or_logits, input_ids
    )
    self.batch_id += 1
    # Reactivate previously active slots for the next decode
    for slot in active_slots:
        slot.resume()
    logger.debug("Model ready for decoding")
    if next_batch is not None:
        logger.debug(
            f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}"
        )
    return generation, next_batch
def decode(
    self, batches: List[CachedBatch]
) -> Tuple[List[Generation], CachedBatch]:
    """Run one decoding step for every prefilled request still pending.

    Args:
        batches (`List[CachedBatch]`):
            The batch returned by the last decode together with the batch(es)
            returned by any prefill issued since then.
    Return:
        A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
    """
    # Prefill always merges new requests into the shared slots, so decoding just
    # continues over every READY slot; the first batch id is reused for the result.
    next_batch_id = batches[0].id
    request_ids = [rid for batch in batches for rid in batch.request_ids]

    # Release READY slots whose request is no longer tracked by the caller.
    cleared_request_ids = []
    for slot in self.slots:
        if slot.state != slot.State.READY or slot.request_id in request_ids:
            continue
        cleared_request_ids.append(slot.request_id)
        slot.clear()
    if cleared_request_ids:
        logger.info(
            f"Clearing slot for requests {cleared_request_ids} as they are not requested."
        )

    decode_slots = [slot for slot in self.slots if slot.state == slot.State.READY]
    if len(decode_slots) < len(request_ids):
        raise ValueError(
            "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
        )
    seq_ids = torch.tensor([slot.id for slot in decode_slots])

    # Rebuild the model inputs: one row per slot, holding only the last generated
    # token (earlier tokens are already cached), plus the right-padded masks.
    n_slots = len(decode_slots)
    input_ids = torch.full(
        [n_slots, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64
    )
    max_length = max(
        (slot.attention_mask.size(-1) for slot in decode_slots), default=0
    )
    attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
    sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None
    for row, slot in enumerate(decode_slots):
        if slot.state == Slot.State.EMPTY:
            continue
        input_ids[row, 0] = slot.next_token
        mask_length = slot.attention_mask.size(-1)
        attention_mask[row, :mask_length] = slot.attention_mask
        if sampling_params is not None:
            config = slot.generation_config
            sampling_params[row, 0] = config.top_k
            sampling_params[row, 1] = config.top_p
            sampling_params[row, 2] = config.temperature

    model_inputs = self.model.prepare_inputs_for_decode(
        input_ids,
        attention_mask=attention_mask,
        seq_ids=seq_ids,
        sampling_params=sampling_params,
    )
    tokens_or_logits = self.model(**model_inputs)[0]
    return self._generate_token(
        decode_slots, next_batch_id, tokens_or_logits, input_ids
    )
def _generate_token(
    self,
    slots: List[Slot],
    next_batch_id: int,
    tokens_or_logits: torch.Tensor,
    input_ids: torch.LongTensor,
) -> Tuple[List[Generation], CachedBatch]:
    """Select the next token of every READY slot and build the responses.

    Slots that finish during this step (EOS token, or `slot.stopped`) report their
    generated text and are cleared.

    Args:
        slots: the slots passed to the model; row `i` of `tokens_or_logits` and of
            `input_ids` belongs to `slots[i]`.
        next_batch_id: id given to the returned `CachedBatch`.
        tokens_or_logits: model output. When `self.on_device_sampling` is set, row `i`
            is used directly as the next token; otherwise it is indexed as
            `[i : i + 1, -1, :]` (presumably (batch, seq, vocab) logits — confirm).
        input_ids: the input ids that were fed to the model, one row per slot.

    Returns:
        The list of `Generation` (one per READY slot), and a `CachedBatch` of every
        request still READY across `self.slots`, or `None` when none remains.
    """
    generations = []
    # Becomes True when at least one request keeps generating after this step.
    active_slots = False
    for i, slot in enumerate(slots):
        if slot.state != Slot.State.READY:
            continue
        request_id = slot.request_id
        slot_input_ids = input_ids[i : i + 1, :]
        if self.on_device_sampling:
            # The token was already selected by the model (on-device sampling).
            next_token = tokens_or_logits[i]
        else:
            # Select the token on host from the logits of the last position.
            next_token_logits = tokens_or_logits[i : i + 1, -1, :]
            next_token = slot.select(slot_input_ids, next_token_logits)
        next_token_text = slot.append(next_token)
        generated_text = None
        finish_reason = None
        if next_token == self.tokenizer.eos_token_id:
            finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
        elif slot.stopped:
            # Reaching the token budget is reported as LENGTH, anything else as a stop sequence.
            if slot.generated_tokens == slot.max_new_tokens:
                finish_reason = FinishReason.FINISH_REASON_LENGTH
            else:
                finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE
        if finish_reason is not None:
            # We must include the generated text for each finished sequence in the response
            generated_text = GeneratedText(
                text=slot.generated_text,
                generated_tokens=slot.generated_tokens,
                finish_reason=finish_reason,
            )
            logger.debug(
                f"Decode complete for request {request_id} with {slot.generated_tokens} tokens"
            )
            # mark the slot as available (request_id was captured above, before clearing)
            slot.clear()
        else:
            active_slots = True
        generations.append(
            Generation(
                request_id=request_id,
                prefill_tokens=None,
                tokens=Tokens(
                    ids=[next_token],
                    # NOTE(review): log-probabilities are not computed; 0 is a placeholder.
                    logprobs=[0],
                    texts=[next_token_text],
                    is_special=[next_token in self.special_tokens],
                ),
                generated_text=generated_text,
            )
        )
    batch = None
    if active_slots:
        # Whatever initial batch these requests came from, we always return all pending requests in a single batch
        request_ids = [
            slot.request_id for slot in self.slots if slot.state == Slot.State.READY
        ]
        batch = self._cached_batch(next_batch_id, request_ids)
    else:
        logger.debug("No more pending requests")
    return generations, batch
def _cached_batch(self, batch_id: int, request_ids: List):
    """Wrap a list of pending request ids into a `CachedBatch`.

    Every request is accounted for the full compiled sequence length, so
    `max_tokens` is the number of requests times that length.
    """
    n_requests = len(request_ids)
    sequence_length = self.model.neuron_config.sequence_length
    return CachedBatch(
        id=batch_id,
        request_ids=request_ids,
        size=n_requests,
        max_tokens=n_requests * sequence_length,
    )
def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch:
    """Remove requests that are not listed from the specified batch

    Every slot whose request is not in `keep_request_ids` is cleared.

    Args:
        batch_id (`int`):
            The id of a cached batch; it is reused as the id of the returned batch.
        keep_request_ids (`List[int]`):
            The list of requests that must be kept.
    Return:
        A `CachedBatch` containing the pending requests.
    """
    # Use a set for O(1) membership tests. The original sequence is still passed to
    # `_cached_batch` so the returned batch keeps the caller's ordering.
    keep = set(keep_request_ids)
    keep_slot_ids = [slot.id for slot in self.slots if slot.request_id in keep]
    self._clear(keep_slot_ids)
    return self._cached_batch(batch_id, keep_request_ids)
def clear(self, batch_id: Optional[int] = None):
    """Remove a subset or all requests from the generator

    Args:
        batch_id: when given, only slots whose `batch_id` matches are cleared;
            when `None`, every non-empty slot is cleared.
    """
    if batch_id is None:
        keep_ids = []
    else:
        keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
    return self._clear(keep_ids)
def _clear(self, keep_slot_ids: List):
    """Clear every non-empty slot whose id is not in `keep_slot_ids`."""
    for slot in self.slots:
        if slot.state == Slot.State.EMPTY or slot.id in keep_slot_ids:
            continue
        logger.debug(f"Removing slot {slot.id} with request {slot.request_id}")
        slot.clear()
@classmethod
def from_pretrained(cls, model_id: str, revision: Optional[str] = None):
    """Instantiate a NeuronGenerator.

    If `model_id` does not provide a neuron config, the model is exported to neuron
    using the parameters returned by `get_export_kwargs_from_env()`. Otherwise the
    pre-exported neuron model is loaded directly.

    Args:
        model_id (`str`):
            A hub model id or the path to a local model. This path must also contain a Tokenizer.
        revision (`Optional[str]`, defaults to `None`):
            The revision of the model on the HuggingFace hub.
    Returns:
        A NeuronGenerator.
    """
    try:
        neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
    except Exception as e:
        # Use an f-string: loguru (the logger used by the sibling server modules)
        # formats with str.format() and silently drops %-style arguments. An
        # f-string is correct for stdlib logging as well.
        logger.debug(
            f"NeuronConfig.from_pretrained failed for model {model_id}, revision {revision}: {e}"
        )
        neuron_config = None
    start = time.time()
    if neuron_config is None:
        # Not a neuron model: export it using the parameters set in the environment.
        export_kwargs = get_export_kwargs_from_env()
        logger.info(f"Exporting model to neuron with config: {export_kwargs}.")
        model = NeuronModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            low_cpu_mem_usage=True,
            export=True,
            **export_kwargs,
        )
    else:
        logger.info(
            "Loading model on neuron devices (this can take a few minutes)."
        )
        model = NeuronModelForCausalLM.from_pretrained(
            model_id, low_cpu_mem_usage=True, revision=revision
        )
    end = time.time()
    logger.info(f"Model successfully loaded in {end - start:.2f} s.")
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
    return cls(model, tokenizer)
================================================
FILE: backends/neuron/server/text_generation_server/interceptor.py
================================================
from typing import Any, Callable
import grpc
from google.rpc import code_pb2, status_pb2
from grpc_interceptor.server import AsyncServerInterceptor
from grpc_status import rpc_status
from loguru import logger
class ExceptionInterceptor(AsyncServerInterceptor):
    """Server interceptor that logs servicer exceptions and aborts with INTERNAL."""

    async def intercept(
        self,
        method: Callable,
        request_or_iterator: Any,
        context: grpc.ServicerContext,
        method_name: str,
    ) -> Any:
        """Await the wrapped servicer method, turning any exception into a gRPC status.

        The exception is logged with its traceback, then the call is aborted with
        code INTERNAL and the exception message as status message.
        """
        try:
            return await method(request_or_iterator, context)
        except Exception as err:
            # Only keep the RPC name from the full "/package.Service/Method" path.
            short_name = method_name.split("/")[-1]
            logger.exception(f"Method {short_name} encountered an error.")
            status = status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
            await context.abort_with_status(rpc_status.to_status(status))
================================================
FILE: backends/neuron/server/text_generation_server/model.py
================================================
import os
import shutil
import time
from typing import Optional
from huggingface_hub import snapshot_download
from huggingface_hub.constants import HF_HUB_CACHE
from loguru import logger
from optimum.neuron.cache import get_hub_cached_entries
from optimum.neuron.configuration_utils import NeuronConfig
from .tgi_env import check_env_and_neuron_config_compatibility
def get_export_kwargs_from_env():
    """Build neuron export parameters from the environment.

    `MAX_BATCH_SIZE`, `MAX_TOTAL_TOKENS` and `HF_NUM_CORES` are parsed as integers
    and `HF_AUTO_CAST_TYPE` is kept as a string; unset variables map to `None`.
    """

    def _int_from_env(name):
        raw = os.environ.get(name, None)
        return None if raw is None else int(raw)

    return {
        "batch_size": _int_from_env("MAX_BATCH_SIZE"),
        "sequence_length": _int_from_env("MAX_TOTAL_TOKENS"),
        "num_cores": _int_from_env("HF_NUM_CORES"),
        "auto_cast_type": os.environ.get("HF_AUTO_CAST_TYPE", None),
    }
def is_cached(model_id):
    """Return True if the hub cache holds an export of `model_id` usable on this host.

    Compatibility is checked against the local setup and environment, including
    the neuronx-cc compiler version.
    """
    return any(
        check_env_and_neuron_config_compatibility(entry, check_compiler_version=True)
        for entry in get_hub_cached_entries(model_id)
    )
def log_cache_size():
    """Log total and free disk space of the hub cache directory.

    Raises:
        ValueError: if the hub cache directory does not exist.
    """
    path = HF_HUB_CACHE
    if not os.path.exists(path):
        raise ValueError(f"The cache directory ({path}) does not exist.")
    usage = shutil.disk_usage(path)
    gb = 2**30
    logger.info(
        f"Cache disk [{path}]: total = {usage.total / gb:.2f} G, free = {usage.free / gb:.2f} G"
    )
def fetch_model(
    model_id: str,
    revision: Optional[str] = None,
) -> str:
    """Fetch a neuron model.

    A model that already provides a neuron config is used as-is; hub models are
    prefetched into the local cache. Any other model must have a compatible export
    in the hub cache; its weights are prefetched so they are in cache for export.

    Args:
        model_id (`str`):
            The *model_id* of a model on the HuggingFace hub or the path to a local model.
        revision (`Optional[str]`, defaults to `None`):
            The revision of the model on the HuggingFace hub.
    Returns:
        A string corresponding to the model_id or path.
    Raises:
        SystemError: if no neuron device is visible on the host.
        ValueError: if the model needs exporting but no compatible cached artifacts
            exist, or if the hub cache directory does not exist.
    """
    if not os.path.isdir("/sys/class/neuron_device/"):
        raise SystemError("No neuron cores detected on the host.")
    if os.path.isdir(model_id) and revision is not None:
        logger.warning(
            "Revision {} ignored for local model at {}".format(revision, model_id)
        )
        revision = None
    # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model)
    # Note that the model may already be present in the cache.
    try:
        neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)
    except Exception as e:
        # loguru formats messages with str.format() and drops %-style arguments:
        # use an f-string so the values actually appear in the log.
        logger.debug(
            f"NeuronConfig.from_pretrained failed for model {model_id}, revision {revision}: {e}"
        )
        neuron_config = None
    if neuron_config is not None:
        if os.path.isdir(model_id):
            return model_id
        # Prefetch the neuron model from the Hub
        logger.info(
            f"Fetching revision [{revision}] for neuron model {model_id} under {HF_HUB_CACHE}"
        )
        log_cache_size()
        return snapshot_download(model_id, revision=revision, ignore_patterns="*.bin")
    # Model needs to be exported: look for compatible cached entries on the hub
    if not is_cached(model_id):
        hub_cache_url = "https://huggingface.co/aws-neuron/optimum-neuron-cache"
        neuron_export_url = "https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi"
        # Each sentence ends with an explicit separator so the concatenated message stays readable.
        error_msg = (
            f"No cached version found for {model_id} with {get_export_kwargs_from_env()}. "
            f"You can start a discussion to request it on {hub_cache_url}. "
            f"Alternatively, you can export your own neuron model as explained in {neuron_export_url}"
        )
        raise ValueError(error_msg)
    logger.warning(
        f"{model_id} is not a neuron model: it will be exported using cached artifacts."
    )
    if os.path.isdir(model_id):
        return model_id
    # Prefetch weights, tokenizer and generation config so that they are in cache
    log_cache_size()
    start = time.time()
    snapshot_path = snapshot_download(
        model_id, revision=revision, ignore_patterns="*.bin"
    )
    end = time.time()
    logger.info(f"Model weights fetched in {end - start:.2f} s.")
    log_cache_size()
    return snapshot_path
================================================
FILE: backends/neuron/server/text_generation_server/server.py
================================================
import asyncio
from pathlib import Path
from typing import List
from grpc import aio
from grpc_reflection.v1alpha import reflection
from loguru import logger
from .generator import Generator, NeuronGenerator
from .interceptor import ExceptionInterceptor
from .pb import generate_pb2, generate_pb2_grpc
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    """gRPC servicer forwarding each TGI router call to a `Generator`."""

    def __init__(self, generator: Generator, server_urls: List[str]):
        self.generator = generator
        self.server_urls = server_urls

    async def Info(self, request, context):
        """Return the generator's `info`."""
        return self.generator.info

    async def Health(self, request, context):
        """Unconditionally answer with an empty `HealthResponse`."""
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
        """Advertise the URLs this server listens on."""
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
        """Clear one batch when `id` is set, otherwise every request."""
        if not request.HasField("id"):
            self.generator.clear()
        else:
            self.generator.clear(request.id)
        return generate_pb2.ClearCacheResponse()

    async def FilterBatch(self, request, context):
        """Keep only the listed requests of a cached batch."""
        kept = self.generator.filter(request.batch_id, request.request_ids)
        return generate_pb2.FilterBatchResponse(batch=kept)

    async def Warmup(self, request, context):
        """Warm the generator up and report the supported token budget."""
        supported_tokens = self.generator.warmup(request.batch)
        return generate_pb2.WarmupResponse(max_supported_total_tokens=supported_tokens)

    async def Prefill(self, request, context):
        """Prefill a new batch of requests."""
        generations, cached = self.generator.prefill(request.batch)
        return generate_pb2.PrefillResponse(generations=generations, batch=cached)

    async def Decode(self, request, context):
        """Decode one token for the given cached batches."""
        generations, cached = self.generator.decode(request.batches)
        return generate_pb2.DecodeResponse(generations=generations, batch=cached)
def serve(
    model_id: str,
    revision: str,
    uds_path: Path,
):
    """Load the model and serve the TGI gRPC API on a unix socket until terminated."""

    async def serve_inner(model_id: str, revision: str):
        # The socket path gets a "-0" suffix; it is the only URL advertised
        # through ServiceDiscovery.
        local_url = "unix://{}-{}".format(uds_path, 0)
        try:
            generator = NeuronGenerator.from_pretrained(model_id, revision)
        except Exception:
            logger.exception("Error when initializing model")
            raise

        server = aio.server(interceptors=[ExceptionInterceptor()])
        servicer = TextGenerationService(generator, [local_url])
        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(servicer, server)

        service_names = (
            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(service_names, server)
        server.add_insecure_port(local_url)

        await server.start()
        logger.info("Server started at {}".format(local_url))
        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            logger.info("Signal received. Shutting down")
            await server.stop(0)

    asyncio.run(serve_inner(model_id, revision))
================================================
FILE: backends/neuron/server/text_generation_server/tgi_env.py
================================================
#!/usr/bin/env python
import argparse
import logging
import os
import sys
from typing import Any, Dict, List, Optional
from optimum.neuron.modeling_decoder import get_available_cores
from optimum.neuron.cache import get_hub_cached_entries
from optimum.neuron.configuration_utils import NeuronConfig
from optimum.neuron.utils.version_utils import get_neuronxcc_version
from optimum.neuron.utils import map_torch_dtype
logger = logging.getLogger(__name__)
# Router-side limits that must fit within the compiled neuron config.
tgi_router_env_vars = [
    "MAX_BATCH_SIZE",
    "MAX_TOTAL_TOKENS",
    "MAX_INPUT_TOKENS",
    "MAX_BATCH_PREFILL_TOKENS",
]
# Server-side export parameters (number of cores and auto-cast dtype).
tgi_server_env_vars = ["HF_NUM_CORES", "HF_AUTO_CAST_TYPE"]
# By the end of this script all env vars should be specified properly
tgi_env_vars = tgi_server_env_vars + tgi_router_env_vars
# Host properties probed once, at import time, and used by the compatibility checks.
available_cores = get_available_cores()
neuronxcc_version = get_neuronxcc_version()
def parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace:
    """Parse TGI launcher arguments and mirror them into the environment.

    Options default to their environment variables. Unknown arguments are ignored.
    The model id is always exported as `MODEL_ID`. Each numeric limit is exported
    only when strictly positive, and the revision only when set.

    Args:
        argv: arguments to parse; defaults to `sys.argv` when empty or None.
    Raises:
        Exception: if no model id is provided.
    """
    argv = argv or sys.argv
    parser = argparse.ArgumentParser()
    # All these are params passed to tgi and intercepted here
    parser.add_argument(
        "--max-input-tokens",
        type=int,
        default=os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0)),
    )
    parser.add_argument(
        "--max-total-tokens", type=int, default=os.getenv("MAX_TOTAL_TOKENS", 0)
    )
    parser.add_argument(
        "--max-batch-size", type=int, default=os.getenv("MAX_BATCH_SIZE", 0)
    )
    parser.add_argument(
        "--max-batch-prefill-tokens",
        type=int,
        default=os.getenv("MAX_BATCH_PREFILL_TOKENS", 0),
    )
    parser.add_argument("--model-id", type=str, default=os.getenv("MODEL_ID"))
    parser.add_argument("--revision", type=str, default=os.getenv("REVISION"))
    args, _unknown = parser.parse_known_args(argv)

    if not args.model_id:
        raise Exception(
            "No model id provided ! Either specify it using --model-id cmdline or MODEL_ID env var"
        )
    # Command-line values override the environment so router and server agree.
    os.environ["MODEL_ID"] = args.model_id
    positive_limits = (
        ("MAX_TOTAL_TOKENS", args.max_total_tokens),
        ("MAX_INPUT_TOKENS", args.max_input_tokens),
        ("MAX_BATCH_SIZE", args.max_batch_size),
        ("MAX_BATCH_PREFILL_TOKENS", args.max_batch_prefill_tokens),
    )
    for env_name, value in positive_limits:
        if value > 0:
            os.environ[env_name] = str(value)
    if args.revision:
        os.environ["REVISION"] = str(args.revision)
    return args
def neuron_config_to_env(neuron_config):
    """Write shell `export` statements for the neuron config to `$ENV_FILEPATH`.

    Batch size, sequence length, tp degree and dtype come from the config.
    `MAX_INPUT_TOKENS` and `MAX_BATCH_PREFILL_TOKENS` keep their environment values
    when set. Otherwise they default to half the sequence length, and to the batch
    size times the input tokens, respectively.

    Raises:
        Exception: if the derived max input tokens would be 0.
    """
    if isinstance(neuron_config, NeuronConfig):
        neuron_config = neuron_config.to_dict()
    with open(os.environ["ENV_FILEPATH"], "w") as env_file:

        def export(name, value):
            env_file.write("export {}={}\n".format(name, value))

        export("MAX_BATCH_SIZE", neuron_config["batch_size"])
        export("MAX_TOTAL_TOKENS", neuron_config["sequence_length"])
        export("HF_NUM_CORES", neuron_config["tp_degree"])
        # Older configs expose "auto_cast_type", newer ones "torch_dtype".
        if "auto_cast_type" in neuron_config:
            dtype_key = "auto_cast_type"
        else:
            dtype_key = "torch_dtype"
        export("HF_AUTO_CAST_TYPE", neuron_config[dtype_key])

        max_input_tokens = os.getenv("MAX_INPUT_TOKENS")
        if not max_input_tokens:
            max_input_tokens = int(neuron_config["sequence_length"]) // 2
            if max_input_tokens == 0:
                raise Exception("Model sequence length should be greater than 1")
        export("MAX_INPUT_TOKENS", max_input_tokens)

        max_batch_prefill_tokens = os.getenv("MAX_BATCH_PREFILL_TOKENS")
        if not max_batch_prefill_tokens:
            max_batch_prefill_tokens = int(neuron_config["batch_size"]) * int(
                max_input_tokens
            )
        export("MAX_BATCH_PREFILL_TOKENS", max_batch_prefill_tokens)
def sort_neuron_configs(dictionary):
    """Sort key ranking configs by descending tp degree, then descending batch size."""
    tp_degree = dictionary["tp_degree"]
    batch_size = dictionary["batch_size"]
    return -tp_degree, -batch_size
def lookup_compatible_cached_model(
    model_id: str, revision: Optional[str]
) -> Optional[Dict[str, Any]]:
    """Return the best hub-cached neuron config of `model_id` usable here, if any.

    Candidates must pass the environment/host compatibility check (compiler version
    included). Among them, the highest tp degree wins, then the largest batch size.
    `revision` is only used for logging.
    """
    # Reuse the same mechanic as the one in use to configure the tgi server part
    # The only difference here is that we stay as flexible as possible on the compatibility part
    entries = get_hub_cached_entries(model_id)
    logger.debug(
        "Found %d cached entries for model %s, revision %s",
        len(entries),
        model_id,
        revision,
    )
    all_compatible = [
        entry
        for entry in entries
        if check_env_and_neuron_config_compatibility(entry, check_compiler_version=True)
    ]
    if not all_compatible:
        logger.debug(
            "No compatible cached entry found for model %s, env %s, available cores %s, neuronxcc version %s",
            model_id,
            get_env_dict(),
            available_cores,
            neuronxcc_version,
        )
        return None
    logger.info("%d compatible neuron cached models found", len(all_compatible))
    ranked = sorted(all_compatible, key=sort_neuron_configs)
    return ranked[0]
def check_env_and_neuron_config_compatibility(
    neuron_config_dict: Dict[str, Any], check_compiler_version: bool
) -> bool:
    """Check that a neuron config can run on this host with the current environment.

    Args:
        neuron_config_dict: dictionary view of a neuron config, e.g. a hub cache
            entry or `NeuronConfig.to_dict()`.
        check_compiler_version: when True, also require the config to have been
            compiled with the local neuronx-cc version.

    Returns:
        True when every check passes. Otherwise False, with the reason logged at
        debug level.
    """
    logger.debug(
        "Checking the provided neuron config %s is compatible with the local setup and provided environment",
        neuron_config_dict,
    )
    # Local setup compat checks
    if neuron_config_dict["tp_degree"] > available_cores:
        logger.debug(
            "Not enough neuron cores available to run the provided neuron config"
        )
        return False
    if (
        check_compiler_version
        and neuron_config_dict["neuronxcc_version"] != neuronxcc_version
    ):
        logger.debug(
            "Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)",
            neuronxcc_version,
            neuron_config_dict["neuronxcc_version"],
        )
        return False
    # Environment compat checks: each limit requested through the environment must
    # fit within what the neuron config was compiled for.
    batch_size = os.getenv("MAX_BATCH_SIZE", None)
    if batch_size is not None and neuron_config_dict["batch_size"] < int(batch_size):
        logger.debug(
            "The provided MAX_BATCH_SIZE (%s) is higher than the neuron config batch size (%s)",
            batch_size,
            neuron_config_dict["batch_size"],
        )
        return False
    max_total_tokens = os.getenv("MAX_TOTAL_TOKENS", None)
    if max_total_tokens is not None and neuron_config_dict["sequence_length"] < int(
        max_total_tokens
    ):
        logger.debug(
            "The provided MAX_TOTAL_TOKENS (%s) is higher than the neuron config sequence length (%s)",
            max_total_tokens,
            neuron_config_dict["sequence_length"],
        )
        return False
    num_cores = os.getenv("HF_NUM_CORES", None)
    if num_cores is not None and neuron_config_dict["tp_degree"] < int(num_cores):
        logger.debug(
            "The provided HF_NUM_CORES (%s) is higher than the neuron config tp degree (%s)",
            num_cores,
            neuron_config_dict["tp_degree"],
        )
        return False
    auto_cast_type = os.getenv("HF_AUTO_CAST_TYPE", None)
    if auto_cast_type is not None:
        config_key = (
            "auto_cast_type"
            if "auto_cast_type" in neuron_config_dict
            else "torch_dtype"
        )
        neuron_config_value = map_torch_dtype(str(neuron_config_dict[config_key]))
        env_value = map_torch_dtype(auto_cast_type)
        if env_value != neuron_config_value:
            logger.debug(
                "The provided auto cast type and the neuron config param differ (%s != %s)",
                env_value,
                neuron_config_value,
            )
            return False
    max_input_tokens = int(
        os.getenv("MAX_INPUT_TOKENS", os.getenv("MAX_INPUT_LENGTH", 0))
    )
    if max_input_tokens > 0:
        # neuron_config_dict is a dict: the context length must be looked up as a key.
        # hasattr() only inspects attributes, so it was always False and the context
        # length was ignored. A missing or None value falls back to sequence_length.
        sequence_length = neuron_config_dict.get("max_context_length")
        if sequence_length is None:
            sequence_length = neuron_config_dict["sequence_length"]
        if max_input_tokens >= sequence_length:
            logger.debug(
                "Specified max input tokens is not compatible with config sequence length ( %s >= %s)",
                max_input_tokens,
                sequence_length,
            )
            return False
    return True
def get_env_dict() -> Dict[str, str]:
    """Snapshot the TGI-related environment variables (unset ones map to None)."""
    return {name: os.getenv(name) for name in tgi_env_vars}
def get_neuron_config_for_model(
    model_name_or_path: str, revision: Optional[str] = None
) -> NeuronConfig:
    """Resolve the neuron config to use for a model.

    If the model ships its own neuron config, it is returned after checking it is
    compatible with the environment; the compiler version is not checked. Otherwise
    the best compatible hub-cached entry is returned as a dict, or None.

    Raises:
        Exception: if the model's own neuron config is incompatible with the environment.
    """
    try:
        neuron_config = NeuronConfig.from_pretrained(
            model_name_or_path, revision=revision
        )
    except Exception as e:
        logger.debug(
            "NeuronConfig.from_pretrained failed for model %s, revision %s: %s",
            model_name_or_path,
            revision,
            e,
        )
        neuron_config = None

    if neuron_config is None:
        return lookup_compatible_cached_model(model_name_or_path, revision)

    is_compatible = check_env_and_neuron_config_compatibility(
        neuron_config.to_dict(), check_compiler_version=False
    )
    if not is_compatible:
        env_dict = get_env_dict()
        msg = (
            "Invalid neuron config and env. Config {}, env {}, available cores {}, neuronxcc version {}"
        ).format(neuron_config, env_dict, available_cores, neuronxcc_version)
        logger.error(msg)
        raise Exception(msg)
    return neuron_config
================================================
FILE: backends/neuron/tests/conftest.py
================================================
# Load the shared fixtures (neuron_model_config, neuron_model_path) from fixtures/model.py.
pytest_plugins = ["fixtures.model"]
================================================
FILE: backends/neuron/tests/fixtures/model.py
================================================
import copy
import logging
import subprocess
import sys
from tempfile import TemporaryDirectory
import os
import pytest
from transformers import AutoTokenizer
from optimum.neuron.cache import synchronize_hub_cache
# Log INFO and above to stdout, with timestamp and source location.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s [%(filename)s.%(funcName)s:%(lineno)d] %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger(__file__)
# Hub repo passed to synchronize_hub_cache() after each export, and exposed to the
# tests through CUSTOM_CACHE_REPO.
OPTIMUM_CACHE_REPO_ID = "optimum-internal-testing/neuron-testing-cache"
# All model configurations below will be added to the neuron_model_config fixture
# Each "export_kwargs" entry becomes a `--<key> <value>` option of `optimum-cli export neuron`.
MODEL_CONFIGURATIONS = {
    "llama": {
        "model_id": "unsloth/Llama-3.2-1B-Instruct",
        "export_kwargs": {
            "batch_size": 4,
            "sequence_length": 4096,
            "num_cores": 2,
            "auto_cast_type": "bf16",
        },
    },
    "qwen2": {
        "model_id": "Qwen/Qwen2.5-0.5B",
        "export_kwargs": {
            "batch_size": 4,
            "sequence_length": 4096,
            "num_cores": 2,
            "auto_cast_type": "bf16",
        },
    },
    "granite": {
        "model_id": "ibm-granite/granite-3.1-2b-instruct",
        "export_kwargs": {
            "batch_size": 4,
            "sequence_length": 4096,
            "num_cores": 2,
            "auto_cast_type": "bf16",
        },
    },
}
def export_model(model_id, export_kwargs, neuron_model_path):
    """Export `model_id` to `neuron_model_path` with `optimum-cli export neuron`.

    Each `export_kwargs` item is passed as a `--<key> <value>` option.

    Raises:
        ValueError: if the export command exits with a non-zero status.
    """
    export_command = [
        "optimum-cli",
        "export",
        "neuron",
        "-m",
        model_id,
        "--task",
        "text-generation",
    ]
    for kwarg, value in export_kwargs.items():
        export_command += [f"--{kwarg}", str(value)]
    export_command.append(neuron_model_path)
    logger.info(f"Exporting {model_id} with {export_kwargs}")
    try:
        subprocess.run(export_command, check=True)
    except subprocess.CalledProcessError as e:
        raise ValueError(f"Failed to export model: {e}")
@pytest.fixture(scope="session", params=MODEL_CONFIGURATIONS.keys())
def neuron_model_config(request):
    """Expose a pre-trained neuron model

    The fixture is parametrized over every key of MODEL_CONFIGURATIONS. It exports a
    model locally and yields a dictionary containing:
    - a configuration name,
    - the original model id,
    - the export parameters,
    - the neuron model local path.
    For each exposed model, the local directory is maintained for the duration of the
    test session and cleaned up afterwards.

    After the export, the hub cache is synchronized to OPTIMUM_CACHE_REPO_ID and the
    CUSTOM_CACHE_REPO environment variable is set to that repo before yielding.
    """
    config_name = request.param
    # Deep copy so the keys added below never mutate the shared MODEL_CONFIGURATIONS.
    model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])
    model_id = model_config["model_id"]
    export_kwargs = model_config["export_kwargs"]
    with TemporaryDirectory() as neuron_model_path:
        export_model(model_id, export_kwargs, neuron_model_path)
        synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)
        # Save the tokenizer next to the exported model so the directory is self-contained.
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.save_pretrained(neuron_model_path)
        del tokenizer
        # Add dynamic parameters to the model configuration
        model_config["neuron_model_path"] = neuron_model_path
        # Also add model configuration name to allow tests to adapt their expectations
        model_config["name"] = config_name
        # Yield instead of returning to keep a reference to the temporary directory.
        # It will go out of scope and be released only once all tests needing the fixture
        # have been completed.
        logger.info(f"{config_name} ready for testing ...")
        # NOTE(review): presumably tells optimum-neuron which hub cache repo to use; the
        # variable is never restored after the session — confirm this is intended.
        os.environ["CUSTOM_CACHE_REPO"] = OPTIMUM_CACHE_REPO_ID
        yield model_config
        logger.info(f"Done with {config_name}")
@pytest.fixture(scope="module")
def neuron_model_path(neuron_model_config):
    """Local path of the exported neuron model for the current configuration."""
    model_path = neuron_model_config["neuron_model_path"]
    yield model_path
================================================
FILE: backends/neuron/tests/prune_test_models.py
================================================
from argparse import ArgumentParser
from huggingface_hub import HfApi
def main():
    """Delete the hub repos returned by a search for the neuron TGI testing prefix.

    Each deletion is confirmed interactively unless `--yes` is passed.
    """
    parser = ArgumentParser()
    parser.add_argument("--yes", action="store_true", default=False)
    args = parser.parse_args()
    api = HfApi()
    for model in api.list_models(search="optimum-internal-testing/neuron-tgi-testing"):
        if not args.yes:
            answer = input(f"Do you want to delete {model.id} [y/N] ?")
            if answer != "y":
                continue
        api.delete_repo(model.id)
        print(f"Deleted {model.id}.")


if __name__ == "__main__":
    main()
================================================
FILE: backends/neuron/tests/pytest.ini
================================================
[pytest]
asyncio_mode = auto
================================================
FILE: backends/neuron/tests/requirements.txt
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
text-generation >= 0.6.0
pytest >= 7.4.0
pytest-asyncio >= 0.21.1
requests < 2.32.0
docker >= 6.1.3
Levenshtein
================================================
FILE: backends/neuron/tests/server/helpers.py
================================================
from text_generation_server.generator import NeuronGenerator
from text_generation_server.pb.generate_pb2 import (
Batch,
NextTokenChooserParameters,
Request,
StoppingCriteriaParameters,
)
def create_request(
    id: int,
    inputs: str,
    truncate: int = 0,
    max_new_tokens: int = 20,
    do_sample: bool = False,
    top_k: int = 50,
    top_p: float = 0.9,
    temperature: float = 1.0,
    seed: int = 42,
    repetition_penalty: float = 1.0,
):
    """Build a TGI `Request` with the given sampling and stopping parameters."""
    return Request(
        id=id,
        inputs=inputs,
        truncate=truncate,
        parameters=NextTokenChooserParameters(
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            seed=seed,
            repetition_penalty=repetition_penalty,
        ),
        stopping_parameters=StoppingCriteriaParameters(max_new_tokens=max_new_tokens),
    )
def check_prefill(
    input_text,
    expected_token_id,
    expected_token_text,
    do_sample,
    batch_size,
    model_path,
):
    """Verify that prefilling `batch_size` identical requests generates the expected first token.

    Every request must produce `expected_token_id` / `expected_token_text`, and the
    returned batch must account for the full model length per request.
    """
    generator = NeuronGenerator.from_pretrained(model_path)
    assert generator.model.batch_size >= batch_size
    max_new_tokens = 20
    # Give each request its own id: the generator tracks pending requests by id.
    requests = [
        create_request(
            id=i,
            inputs=input_text,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
        )
        for i in range(batch_size)
    ]
    max_length = generator.model.max_length
    batch = Batch(
        id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
    )
    generations, next_batch = generator.prefill(batch)
    assert next_batch.size == batch_size
    # Whatever was passed as max_tokens, the server will correct it
    # because of static batching
    assert next_batch.max_tokens == batch_size * max_length
    assert len(generations) == batch_size
    for g in generations:
        tokens = g.tokens
        assert tokens.ids == [expected_token_id]
        assert tokens.texts == [expected_token_text]
def check_decode_single(
    input_text, max_new_tokens, generated_text, do_sample, model_path
):
    """Verify that decoding one request for `max_new_tokens` tokens yields `generated_text`."""
    generator = NeuronGenerator.from_pretrained(model_path)
    request = create_request(
        id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
    )
    max_length = generator.model.max_length
    generations, next_batch = generator.prefill(
        Batch(id=0, requests=[request], size=1, max_tokens=max_length)
    )
    # Prefill already produced the first token, so max_new_tokens - 1 decode steps remain.
    remaining_steps = max_new_tokens - 1
    for _ in range(remaining_steps):
        assert next_batch.size == 1
        assert next_batch.max_tokens == max_length
        assert len(generations) == 1
        assert len(generations[0].tokens.ids) == 1
        generations, next_batch = generator.decode([next_batch])
    # The request is finished: no pending batch remains.
    assert next_batch is None
    assert len(generations) == 1
    output = generations[0].generated_text
    assert output.generated_tokens == max_new_tokens
    # presumably FINISH_REASON_LENGTH (budget reached) — confirm against generate.proto
    assert output.finish_reason == 0
    assert output.text == generated_text
def check_decode_multiple(model_path):
    """Verify that two requests added to the batch at different generation steps
    generate the same outputs (continuous batching).

    Request 0 is prefilled and decoded alone for a few steps. Request 1 (same
    prompt) is then prefilled into its own batch, and both are decoded together
    until request 0 completes. Request 1 is finally decoded to completion. Both
    token streams and final texts must be identical.
    """
    generator = NeuronGenerator.from_pretrained(model_path)
    assert generator.model.batch_size > 1
    input_text = "Once upon a time"
    max_new_tokens = 20
    # Prefill a single request, remembering the generated token
    tokens = {0: [], 1: []}
    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
    max_length = generator.model.max_length
    batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
    generations, next_batch = generator.prefill(batch)
    assert next_batch.size == 1
    assert len(generations) == 1
    g = generations[0]
    tokens[g.request_id].append(g.tokens.ids[0])
    assert len(tokens[0]) == 1
    # Decode a few tokens
    gen_tokens = 4
    for _ in range(gen_tokens - 1):
        generations, next_batch = generator.decode([next_batch])
        assert len(generations) == 1
        g = generations[0]
        tokens[g.request_id].append(g.tokens.ids[0])
    assert len(tokens[0]) == gen_tokens
    assert next_batch.size == 1
    # Add a second request
    request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)
    batch = Batch(id=1, requests=[request], size=1, max_tokens=max_length)
    generations, next_batch_1 = generator.prefill(batch)
    assert next_batch_1.size == 1
    # We should have generated only a single token
    assert len(generations) == 1
    g = generations[0]
    tokens[g.request_id].append(g.tokens.ids[0])
    assert len(tokens[0]) == gen_tokens
    assert len(tokens[1]) == 1
    # Decode more tokens until we reach the maximum for the first request
    batches = [next_batch, next_batch_1]
    for _ in range(max_new_tokens - gen_tokens):
        generations, next_batch = generator.decode(batches)
        for g in generations:
            tokens[g.request_id].append(g.tokens.ids[0])
        batches = [next_batch]
    # Verify we now only have one pending request
    assert next_batch.size == 1
    assert len(tokens[0]) == max_new_tokens
    assert len(tokens[1]) == max_new_tokens - gen_tokens + 1
    # Verify we have the output for the first request
    generated_text = None
    for g in generations:
        if g.request_id == 0:
            output = g.generated_text
            assert output.text != ""
            assert output.generated_tokens == max_new_tokens
            generated_text = output.text
    # Fail clearly here rather than with a NameError at the final comparison
    assert generated_text is not None, "No final output was produced for request 0"
    # Continue decoding until the end of the second request
    for _ in range(gen_tokens - 1):
        generations, next_batch = generator.decode([next_batch])
        assert len(generations) == 1
        g = generations[0]
        tokens[g.request_id].append(g.tokens.ids[0])
    assert next_batch is None
    output = generations[0].generated_text
    assert output.generated_tokens == max_new_tokens
    assert tokens[0] == tokens[1]
    assert output.text == generated_text
================================================
FILE: backends/neuron/tests/server/test_cached_model.py
================================================
import os
import pytest
from text_generation_server.generator import NeuronGenerator
from text_generation_server.model import fetch_model, is_cached
@pytest.fixture(scope="module")
def cached_model_id(neuron_model_config) -> str:
    """
    Fixture to provide a cached model ID for testing.

    This assumes the model is already cached in the local environment.
    The TGI environment variables describing the export configuration are set
    for the duration of the module. Afterwards they are restored to their
    previous values, or removed if they were not set before, so a value that
    existed before the fixture ran is not clobbered.
    """
    export_kwargs = neuron_model_config["export_kwargs"]
    env_overrides = {
        "MAX_BATCH_SIZE": str(export_kwargs["batch_size"]),
        "MAX_TOTAL_TOKENS": str(export_kwargs["sequence_length"]),
        "HF_AUTO_CAST_TYPE": export_kwargs["auto_cast_type"],
        "HF_NUM_CORES": str(export_kwargs["num_cores"]),
    }
    # Remember pre-existing values so teardown can put them back
    saved_env = {name: os.environ.get(name) for name in env_overrides}
    os.environ.update(env_overrides)
    try:
        yield neuron_model_config["model_id"]
    finally:
        for name, previous in saved_env.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
def test_model_is_cached(cached_model_id):
    """The model id provided by the fixture must be reported as cached."""
    cached = is_cached(cached_model_id)
    assert cached, f"Model {cached_model_id} is not cached"
def test_fetch_cached_model(cached_model_id: str):
    """Fetching a cached model must return the path of an existing directory."""
    model_path = fetch_model(cached_model_id)
    path_exists = os.path.exists(model_path)
    assert path_exists, f"Model {cached_model_id} was not fetched successfully"
    is_directory = os.path.isdir(model_path)
    assert is_directory, f"Model {cached_model_id} is not a directory"
def test_generator_from_cached_model(cached_model_id: str):
    """A generator built from the cached model must expose a model and a tokenizer."""
    generator = NeuronGenerator.from_pretrained(model_id=cached_model_id)
    assert generator is not None, "Generator could not be created from cached model"
    model = generator.model
    assert model is not None, "Generator model is not initialized"
    tokenizer = generator.tokenizer
    assert tokenizer is not None, "Generator tokenizer is not initialized"
================================================
FILE: backends/neuron/tests/server/test_continuous_batching.py
================================================
from helpers import create_request
from text_generation_server.generator import NeuronGenerator
from text_generation_server.pb.generate_pb2 import Batch
def test_continuous_batching_two_requests(neuron_model_config):
    """Verify that two requests added to the batch at different generation steps
    generate the same outputs (continuous batching).

    Request 0 is prefilled and decoded alone for a few steps. Request 1 (same
    prompt) is then prefilled into its own batch, and both are decoded together
    until request 0 completes. Request 1 is finally decoded to completion. Both
    token streams and final texts must be identical.
    """
    neuron_model_path = neuron_model_config["neuron_model_path"]
    generator = NeuronGenerator.from_pretrained(neuron_model_path)
    assert generator.model.neuron_config.batch_size > 1
    input_text = "Once upon a time"
    max_new_tokens = 20
    # Prefill a single request, remembering the generated token
    tokens = {0: [], 1: []}
    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)
    max_length = generator.model.neuron_config.sequence_length
    batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
    generations, next_batch = generator.prefill(batch)
    assert next_batch.size == 1
    assert len(generations) == 1
    g = generations[0]
    tokens[g.request_id].append(g.tokens.ids[0])
    assert len(tokens[0]) == 1
    # Decode a few tokens
    gen_tokens = 4
    for _ in range(gen_tokens - 1):
        generations, next_batch = generator.decode([next_batch])
        assert len(generations) == 1
        g = generations[0]
        tokens[g.request_id].append(g.tokens.ids[0])
    assert len(tokens[0]) == gen_tokens
    assert next_batch.size == 1
    # Add a second request
    request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)
    batch = Batch(id=1, requests=[request], size=1, max_tokens=max_length)
    generations, next_batch_1 = generator.prefill(batch)
    assert next_batch_1.size == 1
    # We should have generated only a single token
    assert len(generations) == 1
    g = generations[0]
    tokens[g.request_id].append(g.tokens.ids[0])
    assert len(tokens[0]) == gen_tokens
    assert len(tokens[1]) == 1
    # Decode more tokens until we reach the maximum for the first request
    batches = [next_batch, next_batch_1]
    for _ in range(max_new_tokens - gen_tokens):
        generations, next_batch = generator.decode(batches)
        for g in generations:
            tokens[g.request_id].append(g.tokens.ids[0])
        batches = [next_batch]
    # Verify we now only have one pending request
    assert next_batch.size == 1
    assert len(tokens[0]) == max_new_tokens
    assert len(tokens[1]) == max_new_tokens - gen_tokens + 1
    # Verify we have the output for the first request
    generated_text = None
    for g in generations:
        if g.request_id == 0:
            output = g.generated_text
            assert output.text != ""
            assert output.generated_tokens == max_new_tokens
            generated_text = output.text
    # Fail clearly here rather than with a NameError at the final comparison
    assert generated_text is not None, "No final output was produced for request 0"
    # Continue decoding until the end of the second request
    for _ in range(gen_tokens - 1):
        generations, next_batch = generator.decode([next_batch])
        assert len(generations) == 1
        g = generations[0]
        tokens[g.request_id].append(g.tokens.ids[0])
    assert next_batch is None
    output = generations[0].generated_text
    assert output.generated_tokens == max_new_tokens
    assert tokens[0] == tokens[1]
    assert output.text == generated_text
================================================
FILE: backends/neuron/tests/server/test_decode.py
================================================
from helpers import create_request
from text_generation_server.generator import NeuronGenerator
from text_generation_server.pb.generate_pb2 import Batch
def test_decode(neuron_model_config):
    """Decode a single request in sample and greedy modes.

    Only greedy output is compared against the per-model reference text.
    The generator is cleared after each mode.
    """
    config_name = neuron_model_config["name"]
    neuron_model_path = neuron_model_config["neuron_model_path"]
    generator = NeuronGenerator.from_pretrained(neuron_model_path)
    # Reference greedy outputs, keyed by model configuration name
    greedy_references = {
        "llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility",
        "qwen2": " I was sitting in my room, staring at the clock, when a knock at the door. I",
        "granite": "\n\nThis opening line is from George Orwell's dystopian novel, \"1",
    }
    for do_sample in (True, False):
        mode = "sample" if do_sample else "greedy"
        print(f"{config_name}[{mode}]")
        generated_text = _test_decode(config_name, generator, do_sample)
        if not do_sample:
            assert generated_text == greedy_references[config_name]
        generator.clear()
def _test_decode(config_name, generator, do_sample):
    """Run one request through prefill followed by ``max_new_tokens - 1`` decodes.

    ``config_name`` is not used in this helper. Returns the final generated text
    after checking that the batch is exhausted and the generation stopped after
    exactly ``max_new_tokens`` tokens with finish reason 0.
    """
    prompt = (
        "It was a bright cold day in April, and the clocks were striking thirteen."
    )
    max_new_tokens = 20
    request = create_request(
        id=0,
        inputs=prompt,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=0.9,
    )
    max_length = generator.model.neuron_config.sequence_length
    generations, next_batch = generator.prefill(
        Batch(id=0, requests=[request], size=1, max_tokens=max_length)
    )
    # Prefill produced the first token, so max_new_tokens - 1 decode steps remain
    steps_left = max_new_tokens - 1
    while steps_left > 0:
        assert next_batch.size == 1
        assert next_batch.max_tokens == max_length
        assert len(generations) == 1
        assert len(generations[0].tokens.ids) == 1
        generations, next_batch = generator.decode([next_batch])
        steps_left -= 1
    assert next_batch is None
    assert len(generations) == 1
    final_output = generations[0].generated_text
    assert final_output.generated_tokens == max_new_tokens
    assert final_output.finish_reason == 0
    return final_output.text
================================================
FILE: backends/neuron/tests/server/test_generator_slot.py
================================================
import pytest
import torch
from text_generation_server.generator import Slot
from text_generation_server.pb.generate_pb2 import Request
from transformers import AutoTokenizer, GenerationConfig
# Hub ids of the tokenizers exercised by the streaming tests
TOKENIZERS = ["NousResearch/Llama-2-7b-hf", "gpt2"]
@pytest.fixture(params=TOKENIZERS)
def tokenizer(request):
    """Yield each tokenizer from TOKENIZERS configured for left padding with EOS."""
    tok = AutoTokenizer.from_pretrained(request.param)
    tok.padding_side = "left"
    # Reuse the EOS token as padding token
    tok.pad_token_id = tok.eos_token_id
    return tok
@pytest.mark.parametrize(
    "input_text, generated_text",
    [
        [
            "It was a bright cold day in April, and the clocks were striking thirteen.",
            " Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,"
            " slipped quickly through the glass doors of Victory Mansions, though not quickly enough"
            " to prevent a swirl of gritty dust from entering along with him.",
        ],
        ["This sentence is written in chinese:", "我很感谢你的热情"],
        ["Some text might contain a lot of emojis like 😃", "😍💪 👉 👀"],
    ],
    ids=["spaces", "chinese-utf8", "emojis"],
)
def test_decode_streaming(tokenizer, input_text, generated_text):
    """Feeding known tokens one by one through ``Slot.append`` must stream exactly
    the continuation obtained by decoding prompt and continuation in one go."""
    slot = Slot(0, tokenizer)
    slot.assign(0, Request(id=0, inputs=input_text), GenerationConfig())
    assert slot.cached_text == input_text
    encoded = tokenizer(
        input_text,
        padding="max_length",
        max_length=len(input_text) + 1,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"][0]
    attention_mask = encoded["attention_mask"][0]
    generated_tokens = tokenizer(generated_text, add_special_tokens=False)["input_ids"]
    # The tokenizer may normalize text (e.g. add spaces), so the reference
    # continuation is derived from decoding the full token sequence.
    full_sequence = torch.cat([input_ids, torch.tensor(generated_tokens)])
    full_text = tokenizer.decode(full_sequence, skip_special_tokens=True)
    regenerated_text = full_text[len(input_text) :]
    # Initialize the slot with the inputs
    slot.reset(input_ids, attention_mask, selector=None)
    assert slot.generated_tokens == 0
    # Simulate generation with known tokens instead of calling select()
    streamed_pieces = []
    for step, token in enumerate(generated_tokens, start=1):
        streamed_pieces.append(slot.append(token))
        assert slot.generated_tokens == step
    assert "".join(streamed_pieces) == regenerated_text
================================================
FILE: backends/neuron/tests/server/test_info.py
================================================
from text_generation_server.generator import NeuronGenerator
def test_info(neuron_model_path):
    """Check the static information reported by a Neuron generator."""
    info = NeuronGenerator.from_pretrained(neuron_model_path).info
    assert info.requires_padding is True
    assert info.device_type == "xla"
    assert info.window_size == 0
    assert info.speculate == 0
================================================
FILE: backends/neuron/tests/server/test_prefill.py
================================================
from helpers import create_request
from text_generation_server.generator import NeuronGenerator
from text_generation_server.pb.generate_pb2 import Batch
def test_prefill(neuron_model_config):
    """Prefill a single request and a batch of four, in sample and greedy modes.

    The generator is cleared after every combination.
    """
    config_name = neuron_model_config["name"]
    generator = NeuronGenerator.from_pretrained(neuron_model_config["neuron_model_path"])
    max_batch_size = 4
    assert generator.model.neuron_config.batch_size >= max_batch_size
    for num_requests in (1, max_batch_size):
        for do_sample in (True, False):
            mode = "sample" if do_sample else "greedy"
            print(f"[{mode}]: {num_requests} requests")
            _test_prefill(config_name, generator, num_requests, do_sample)
            generator.clear()
def _test_prefill(config_name, generator, batch_size, do_sample):
    """Prefill ``batch_size`` identical requests and check the resulting batch.

    In greedy mode every request must produce the per-model reference first
    token (both id and text).
    """
    max_new_tokens = 20
    input_text = (
        "It was a bright cold day in April, and the clocks were striking thirteen."
    )
    requests = [
        create_request(
            id=request_id,
            inputs=input_text,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
        )
        for request_id in range(batch_size)
    ]
    # Let's be pessimistic when estimating max_tokens
    max_length = generator.max_prefill_length()
    max_tokens = batch_size * max_length
    batch = Batch(id=0, requests=requests, size=batch_size, max_tokens=max_tokens)
    generations, next_batch = generator.prefill(batch)
    assert next_batch.size == batch_size
    # Whatever was passed as max_tokens, the server will correct it
    # because of static batching
    assert next_batch.max_tokens == max_tokens
    assert len(generations) == batch_size
    expected_id, expected_text = {
        "llama": [578, " The"],
        "qwen2": [358, " I"],
        "granite": [203, "\n"],
    }[config_name]
    # Only greedy mode is expected to be reproducible
    if do_sample:
        return
    for generation in generations:
        assert generation.tokens.ids[0] == expected_id
        assert generation.tokens.texts[0] == expected_text
def test_prefill_truncate(neuron_model_config):
    """Prefill a full batch of identical prompts with per-request truncation.

    The first request is not truncated; request ``i`` (for ``i >= 1``) is
    truncated to ``3 * i`` tokens. The first generated token of each request is
    compared against per-model reference texts.
    """
    config_name = neuron_model_config["name"]
    neuron_model_path = neuron_model_config["neuron_model_path"]
    generator = NeuronGenerator.from_pretrained(neuron_model_path)
    batch_size = generator.model.neuron_config.batch_size
    # We apply truncation to all requests but the first one
    truncate = [
        None,
    ] + [i * 3 for i in range(1, batch_size)]
    input_text = (
        "Two gin-scented tears trickled down the sides of his nose."
        " But it was all right, everything was all right, the struggle was finished."
        " He had won the victory over himself. He loved Big Brother."
    )
    requests = []
    for i in range(batch_size):
        requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))
    max_length = generator.max_prefill_length()
    batch = Batch(
        id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length
    )
    generations, _ = generator.prefill(batch)
    # Even if the input text is identical for all requests, the first generated token might
    # be different because of the truncation
    expectations = {
        "llama": [" He", "iens", "\x08", " He"],
        "qwen2": [" He", "<|endoftext|>", " ", " The"],
        "granite": ["\n", "\n", "\n", "\n"],
    }[config_name]
    # Without these checks a short result would silently verify fewer requests,
    # and a larger batch would fail with an opaque IndexError.
    assert len(generations) == batch_size
    assert batch_size <= len(
        expectations
    ), f"No reference tokens for batch size {batch_size} (have {len(expectations)})"
    for i, g in enumerate(generations):
        tokens = g.tokens
        assert (
            tokens.texts[0] == expectations[i]
        ), f"Request {i} expected [{expectations[i]}], got [{tokens.texts[0]}]"
================================================
FILE: backends/neuron/tests/test_entry_point.py
================================================
import os
import pytest
from tempfile import TemporaryDirectory
from optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig
from optimum.neuron.utils import map_torch_dtype
from text_generation_server.tgi_env import (
get_neuron_config_for_model,
lookup_compatible_cached_model,
neuron_config_to_env,
)
def test_get_neuron_config_for_model(neuron_model_config):
    """Check that the neuron config resolved for the model matches its export parameters.

    The TGI environment variables are set from the export parameters only while
    the config is resolved. Their previous values are restored (or the variables
    removed) afterwards so they do not leak into other tests.
    """
    neuron_model_path = neuron_model_config["neuron_model_path"]
    export_kwargs = neuron_model_config["export_kwargs"]
    env_overrides = {
        "MAX_BATCH_SIZE": str(export_kwargs["batch_size"]),
        "MAX_TOTAL_TOKENS": str(export_kwargs["sequence_length"]),
        "HF_AUTO_CAST_TYPE": export_kwargs["auto_cast_type"],
        "HF_NUM_CORES": str(export_kwargs["num_cores"]),
    }
    saved_env = {name: os.environ.get(name) for name in env_overrides}
    os.environ.update(env_overrides)
    try:
        neuron_config = get_neuron_config_for_model(neuron_model_path)
    finally:
        for name, previous in saved_env.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
    assert neuron_config is not None
    assert neuron_config.batch_size == export_kwargs["batch_size"]
    assert neuron_config.sequence_length == export_kwargs["sequence_length"]
    assert neuron_config.tp_degree == export_kwargs["num_cores"]
    # NxD configs expose the dtype as torch_dtype, other configs as auto_cast_type
    if isinstance(neuron_config, NxDNeuronConfig):
        assert map_torch_dtype(neuron_config.torch_dtype) == map_torch_dtype(
            export_kwargs["auto_cast_type"]
        )
    else:
        assert map_torch_dtype(neuron_config.auto_cast_type) == map_torch_dtype(
            export_kwargs["auto_cast_type"]
        )
@pytest.mark.parametrize("model_id", ["unsloth/Llama-3.2-1B-Instruct"])
def test_lookup_compatible_cached_model(model_id: str):
    """A compatible cached neuron config must be found for the given model id."""
    assert lookup_compatible_cached_model(model_id, None) is not None
def test_neuron_config_to_env(neuron_model_config) -> None:
    """Check that ``neuron_config_to_env`` writes the expected ``export`` lines.

    ``ENV_FILEPATH`` is pointed at a file inside a temporary directory for the
    duration of the test and restored afterwards. Otherwise it would keep
    referencing a deleted path.
    """
    neuron_model_path = neuron_model_config["neuron_model_path"]
    neuron_config = get_neuron_config_for_model(neuron_model_path)
    previous_env_filepath = os.environ.get("ENV_FILEPATH")
    try:
        with TemporaryDirectory() as temp_dir:
            os.environ["ENV_FILEPATH"] = os.path.join(temp_dir, "env.sh")
            neuron_config_to_env(neuron_config)
            with open(os.environ["ENV_FILEPATH"], "r") as env_file:
                env_content = env_file.read()
            assert f"export MAX_BATCH_SIZE={neuron_config.batch_size}" in env_content
            assert (
                f"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}"
                in env_content
            )
            assert f"export HF_NUM_CORES={neuron_config.tp_degree}" in env_content
            if hasattr(neuron_config, "torch_dtype"):
                # e.g. "torch.bfloat16" -> "bfloat16"
                auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split(
                    "."
                )[-1]
            else:
                auto_cast_type = neuron_config.auto_cast_type
            assert f"export HF_AUTO_CAST_TYPE={auto_cast_type}" in env_content
    finally:
        if previous_env_filepath is None:
            os.environ.pop("ENV_FILEPATH", None)
        else:
            os.environ["ENV_FILEPATH"] = previous_env_filepath
================================================
FILE: backends/neuron/tgi-entrypoint.sh
================================================
#!/bin/bash
set -e -o pipefail -u

# tgi_entry_point.py writes `export VAR=value` lines into this file (via ENV_FILEPATH).
# Assign before exporting so a mktemp failure is not masked and aborts under `set -e`.
ENV_FILEPATH=$(mktemp)
export ENV_FILEPATH
# Clean up the temporary file if anything fails before the final `exec`.
trap 'rm -f "${ENV_FILEPATH}"' EXIT

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

# Quote "$@" so launcher arguments containing spaces are forwarded intact.
"${SCRIPT_DIR}/tgi_entry_point.py" "$@"
# shellcheck source=/dev/null
source "${ENV_FILEPATH}"

# `exec` replaces this shell, so the EXIT trap would never fire: remove the file now.
rm -f "${ENV_FILEPATH}"
trap - EXIT
exec text-generation-launcher "$@"
================================================
FILE: backends/neuron/tgi_entry_point.py
================================================
#!/usr/bin/env python
import logging
import os
import sys
from text_generation_server.tgi_env import (
available_cores,
get_env_dict,
get_neuron_config_for_model,
neuron_config_to_env,
neuronxcc_version,
parse_cmdline_and_set_env,
tgi_env_vars,
)
logger = logging.getLogger(__name__)
def main():
    """
    Determine default TGI env variables for a neuron precompiled model and write
    them out through ``neuron_config_to_env``.

    If every variable listed in ``tgi_env_vars`` is already set, nothing is
    changed and the process exits with status 0.

    :raises Exception: when no compatible neuron config is found for the model.
    """
    args = parse_cmdline_and_set_env()
    # all() stops at the first unset variable, like the original early break
    if all(os.getenv(env_var) for env_var in tgi_env_vars):
        logger.info(
            "All env vars %s already set, skipping, user know what they are doing",
            tgi_env_vars,
        )
        sys.exit(0)
    neuron_config = get_neuron_config_for_model(args.model_id, args.revision)
    if neuron_config:
        neuron_config_to_env(neuron_config)
        return
    template = (
        "No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}"
    )
    msg = template.format(get_env_dict(), available_cores, neuronxcc_version)
    logger.error(msg)
    raise Exception(msg)
if __name__ == "__main__":
    main()
================================================
FILE: backends/trtllm/CMakeLists.txt
================================================
cmake_minimum_required(VERSION 3.20)
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
    cmake_policy(SET CMP0135 NEW)
endif ()
project(tgi-trtllm-backend VERSION 1.0.0)
set(CMAKE_CXX_STANDARD 23)
include(FetchContent)
include(ExternalProject)
include(CheckCXXCompilerFlag)
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF)
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
# We are using nvidia-ml to query device information at runtime to enable some architecture-specific features
find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
find_package(MPI REQUIRED)
#### External dependencies ####
include(cmake/json.cmake)
include(cmake/spdlog.cmake)
include(cmake/trtllm.cmake)
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
    set(TGI_TRTLLM_BACKEND_DEBUG ON)
    add_compile_definitions(TGI_TRTLLM_BACKEND_DEBUG=1)
    # Must be a single NAME=VALUE item. As two separate items it would instead
    # define SPDLOG_ACTIVE_LEVEL and SPDLOG_LEVEL_TRACE as unrelated macros (both 1).
    add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_TRACE)
endif ()
if (TGI_TRTLLM_BACKEND_BUILD_USE_LLD)
    message(STATUS "Using lld linker")
    add_link_options("-fuse-ld=lld")
endif ()
# Let's build TRTLLM as part of CMake
add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
# Tell CMake never to try overriding the RPATH for executorWorker, as it has no information on how to do so
set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE)
# TGI TRTLLM Backend definition
# Static library built from the C++ sources under csrc/
add_library(tgi_trtllm_backend_impl STATIC csrc/hardware.hpp csrc/backend.hpp csrc/backend.cpp)
# Directory-wide include path for the TensorRT headers (defaults to <TGI_TRTLLM_BACKEND_TRT_ROOT>/include)
include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
# Private include path(s) used only when compiling the backend implementation itself
target_include_directories(tgi_trtllm_backend_impl PRIVATE
$
# $
)
# TensorRT-LLM headers are PUBLIC: they propagate to every target linking the backend
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
# CUDA runtime and NVML are private implementation dependencies
target_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml)
# JSON and logging libraries are PUBLIC, so dependents link against them too
target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog)
# TensorRT-LLM runtime, its TensorRT plugins and the NVRTC wrapper
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
# This installs all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make them easy to link / find
install(TARGETS tgi_trtllm_backend_impl)
#install(TARGETS cutlass_src fb_gemm_src fpA_intB_gemm_src gemm_swiglu_sm90_src kernels_src)
install(TARGETS decoder_attention_0 decoder_attention_1)
install(TARGETS tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention_src executorWorker)
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} TYPE LIB)
# Test the variable by name: `NOT ${VAR}` expands to an invalid `if (NOT )`
# whenever TGI_TRTLLM_BACKEND_DEBUG is undefined (e.g. a plain Release configure).
if (NOT TGI_TRTLLM_BACKEND_DEBUG)
    install(FILES ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
endif ()
#### Unit Tests ####
if (TGI_TRTLLM_BACKEND_BUILD_TESTS AND CMAKE_BUILD_TYPE MATCHES "Debug")
    message(STATUS "Building tests")
    option(TGI_TRTLLM_BACKEND_ENABLE_ASAN "Enable AddressSanitizer" OFF)
    option(TGI_TRTLLM_BACKEND_ENABLE_UBSAN "Enable UndefinedSanitizer" OFF)
    FetchContent_Declare(
            Catch2
            URL https://github.com/catchorg/Catch2/archive/refs/tags/v3.7.1.tar.gz
    )
    FetchContent_MakeAvailable(Catch2)
    # Detect whether the compiler can warn when it cannot apply named return value optimization (NRVO)
    check_cxx_compiler_flag("-Wnrvo" COMPILER_SUPPORT_WARNING_ON_NRVO)
    if (COMPILER_SUPPORT_WARNING_ON_NRVO)
        message(STATUS "Enabling non-NRVO detection")
        target_compile_options(tgi_trtllm_backend_impl PRIVATE -Wnrvo)
    endif ()
    target_compile_options(tgi_trtllm_backend_impl PRIVATE -Wall)
    cmake_path(GET TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH PARENT_PATH TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH)
    message(STATUS "Adding linking path: ${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}")
    add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp)
    # target_compile_options(tgi_trtllm_backend_tests PRIVATE -Werror)
    target_link_directories(tgi_trtllm_backend_tests PRIVATE "${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}")
    target_include_directories(tgi_trtllm_backend_tests PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
    target_include_directories(tgi_trtllm_backend_tests PUBLIC "csrc/")
    target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml)
    target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)
    target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)
    # Sanitizers must be passed when compiling too: a link-only flag adds the
    # runtime but leaves the test code uninstrumented.
    if (TGI_TRTLLM_BACKEND_ENABLE_ASAN)
        message(STATUS "Enabled AddressSanitizer")
        target_compile_options(tgi_trtllm_backend_tests PRIVATE -fsanitize=address)
        target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=address)
    endif ()
    if (TGI_TRTLLM_BACKEND_ENABLE_UBSAN)
        message(STATUS "Enabled UndefinedSanitizer")
        target_compile_options(tgi_trtllm_backend_tests PRIVATE -fsanitize=undefined)
        target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined)
    endif ()
    install(TARGETS tgi_trtllm_backend_tests)
    # list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
    # include(CTest)
    # include(Catch)
    # catch_discover_tests(tgi_trtllm_backend_tests)
endif ()
================================================
FILE: backends/trtllm/Cargo.toml
================================================
[package]
name = "text-generation-backends-trtllm"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
# Runtime dependencies, listed alphabetically.
[dependencies]
async-trait = "0.1"
clap = { version = "4.5", features = ["derive"] }
cxx = "1.0"
hashbrown = "0.15"
hf-hub.workspace = true
pyo3.workspace = true
text-generation-router = { path = "../../router" }
thiserror = "1.0.63"
tokenizers.workspace = true
tokio = { version = "1.43.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.17"
tracing = "0.1"
# Used by build.rs to drive the CMake build, the cxx bridge and pkg-config probing.
[build-dependencies]
cmake = "0.1"
cxx-build = { version = "1.0", features = ["parallel"] }
pkg-config = "0.3"
================================================
FILE: backends/trtllm/README.md
================================================
# Text Generation Inference - TensorRT-LLM Backend Implementation
## Description
This folder provides the sources of the TensorRT-LLM backend implementation, powered by the new TensorRT-LLM Executor API.
## Simplified Request Sequence
```mermaid
sequenceDiagram
actor User
participant TextGenerationInference.HttpServer
participant TextGenerationInference.TensorRtLlmBackend
participant TextGenerationInference.TensorRtLlmWorkerThread
participant TensorRtLlm.Executor
participant Nvidia.Gpu
User ->> TextGenerationInference.HttpServer: POST /generate
TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters
TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher
activate Nvidia.Gpu
TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the pool for execution
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Respond with a unique request identifier
rect rgb(10, 92, 54)
loop every 100us
rect rgb(15, 81, 50)
alt Acquire lock to query executor
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll the number of new token(s) generated for the request
else There are new generated tokens
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Return decoded token information and potential error (omitted)
rect rgb(11, 110, 79)
alt Generated token is final
TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection
else Generated token is not final
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded
end
end
end
end
deactivate Nvidia.Gpu
end
end
```
================================================
FILE: backends/trtllm/build.rs
================================================
use cxx_build::CFG;
use pkg_config;
use std::env;
use std::env::consts::ARCH;
use std::path::{absolute, PathBuf};
use std::sync::LazyLock;
/// Extra CMake-built libraries the backend links against (a `d` suffix is appended in debug builds).
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 1] = ["spdlog"];
/// Semicolon-separated CUDA architectures, captured from `CUDA_ARCH_LIST` when this build script is compiled.
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
const CUDA_REQUIRED_VERSION: &str = "12.8";
const MPI_REQUIRED_VERSION: &str = "4.1";
/// Install prefix for the CMake build (`CMAKE_INSTALL_PREFIX`), if provided.
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
/// TensorRT installation root (`TENSORRT_ROOT_DIR`), if provided.
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
/// Whether the build runs with sccache's GitHub Actions backend, i.e. `SCCACHE_GHA_ENABLED`
/// is `on`, `true` or `1` (case-insensitive).
///
/// This is a `static`, not a `const`: a `const LazyLock` is a brand-new, uninitialised value at
/// every use site, so its closure would re-run on each dereference.
static IS_GHA_BUILD: LazyLock<bool> = LazyLock::new(|| {
    option_env!("SCCACHE_GHA_ENABLED")
        .is_some_and(|value| matches!(value.to_lowercase().as_str(), "on" | "true" | "1"))
});
// Dependencies
const BACKEND_DEPS: &str = "tgi_trtllm_backend_impl";
const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
    ("dylib", "tensorrt_llm"),
    ("dylib", "tensorrt_llm_nvrtc_wrapper"),
    ("dylib", "nvinfer_plugin_tensorrt_llm"),
    ("dylib", "decoder_attention_0"),
    ("dylib", "decoder_attention_1"),
];
/// Locate a system library with pkg-config. It first tries the plain `name`, then the
/// versioned package name `"{name}-{version}"`.
///
/// # Panics
/// Panics when neither package can be found. The message includes the pkg-config error from
/// the versioned lookup (same format as `Result::expect`).
macro_rules! probe {
    ($name: expr, $version: expr) => {{
        // Bind once so `$name` is evaluated a single time even though it is used repeatedly.
        let name = $name;
        if pkg_config::probe_library(name).is_err() {
            let versioned = format!("{}-{}", name, $version);
            // Build the panic message lazily, only on the failure path.
            if let Err(err) = pkg_config::probe_library(&versioned) {
                panic!("Failed to locate {}: {:?}", name, err);
            }
        }
    }};
}
/// Select one of two CMake flag values.
///
/// Returns `true_case` when `switch` is set and `false_case` otherwise.
fn get_compiler_flag(
    switch: bool,
    true_case: &'static str,
    false_case: &'static str,
) -> &'static str {
    if switch {
        true_case
    } else {
        false_case
    }
}
/// Map cargo's target OS / arch / ABI to the library architecture string passed to CMake as
/// `CMAKE_LIBRARY_ARCHITECTURE` (e.g. `x86_64-linux-gnu`).
///
/// # Panics
/// Panics if a `CARGO_CFG_TARGET_*` variable is missing (cargo sets them for build scripts), or
/// if the OS / ABI / architecture combination is not one of the supported ones.
fn get_library_architecture() -> &'static str {
    // Named `target_*` so the locals do not shadow the `env` module.
    let target_os = env::var("CARGO_CFG_TARGET_OS")
        .expect("CARGO_CFG_TARGET_OS must be set by cargo for build scripts");
    let target_arch = env::var("CARGO_CFG_TARGET_ARCH")
        .expect("CARGO_CFG_TARGET_ARCH must be set by cargo for build scripts");
    let target_env = env::var("CARGO_CFG_TARGET_ENV")
        .expect("CARGO_CFG_TARGET_ENV must be set by cargo for build scripts");
    match target_os.as_str() {
        "linux" => {
            if target_env != "gnu" {
                panic!("unsupported linux ABI {target_env}, only 'gnu' is supported")
            }
            match target_arch.as_str() {
                "x86_64" => "x86_64-linux-gnu",
                "aarch64" => "aarch64-linux-gnu",
                _ => panic!("unsupported linux architecture {target_arch}"),
            }
        }
        "windows" => {
            if target_env != "msvc" {
                panic!("unsupported windows ABI {target_env}, only 'msvc' is supported")
            }
            match target_arch.as_str() {
                "x86_64" => "x86_64-windows-msvc",
                _ => panic!("unsupported windows architecture {target_arch}"),
            }
        }
        _ => panic!("unsupported OS {target_os}"),
    }
}
/// Configure, build and install the C++ backend through CMake (Ninja generator), then emit the
/// cargo link directives for the static libraries it produced.
///
/// Returns the absolute install prefix and the CMake `_deps` folder holding the fetched
/// third-party sources (the FFI layer takes its include paths from the latter).
fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
    let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
    let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("75-real;80-real;86-real;89-real;90-real");

    // A relative install prefix is resolved against OUT_DIR
    let requested_prefix = PathBuf::from(INSTALL_PREFIX.unwrap_or("/usr/local/tgi"));
    let install_path = if requested_prefix.is_absolute() {
        requested_prefix
    } else {
        absolute(out_dir).expect("cannot happen").join(requested_prefix)
    };

    let is_gha_build = *IS_GHA_BUILD;
    let build_type = if is_debug { "Debug" } else { "Release" };

    let mut cmake_config = cmake::Config::new(".");
    cmake_config
        .uses_cxx11()
        .generator("Ninja")
        .profile(build_type)
        .env("OPT_LEVEL", opt_level)
        .define("CMAKE_INSTALL_PREFIX", &install_path)
        .define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
        .define("CMAKE_LIBRARY_ARCHITECTURE", get_library_architecture())
        .define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
        .define(
            "TGI_TRTLLM_BACKEND_DEBUG",
            get_compiler_flag(is_debug, "ON", "OFF"),
        )
        .define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path);

    // C++ unit tests are built for debug and CI builds
    if is_debug || is_gha_build {
        cmake_config.define("TGI_TRTLLM_BACKEND_BUILD_TESTS", "ON");
    }
    if option_env!("USE_LLD_LINKER").is_some() {
        println!("cargo:warning=Using lld linker");
        cmake_config.define("TGI_TRTLLM_BACKEND_BUILD_USE_LLD", "ON");
    }

    // Sanitizers: opt-in (through environment variables) for debug builds, always on for CI builds
    let sanitizers = [
        (
            option_env!("ENABLE_ASAN").is_some(),
            "cargo:warning=Enabling Address Sanitizer",
            "TGI_TRTLLM_BACKEND_ENABLE_ASAN",
        ),
        (
            option_env!("ENABLE_UBSAN").is_some(),
            "cargo:warning=Enabling Undefined Sanitizer",
            "TGI_TRTLLM_BACKEND_ENABLE_UBSAN",
        ),
    ];
    for (requested, announcement, cmake_flag) in sanitizers {
        if (is_debug && requested) || is_gha_build {
            println!("{}", announcement);
            cmake_config.define(cmake_flag, "ON");
        }
    }

    if let Some(nvcc_host_compiler) = option_env!("CMAKE_CUDA_HOST_COMPILER") {
        cmake_config.define("CMAKE_CUDA_HOST_COMPILER", nvcc_host_compiler);
    }

    // Reuse the rustc wrapper (e.g. sccache) as compiler launcher for every CMake language
    if let Some(wrapper) = option_env!("RUSTC_WRAPPER") {
        println!("cargo:warning=Using caching tool: {wrapper}");
        for launcher in [
            "CMAKE_C_COMPILER_LAUNCHER",
            "CMAKE_CXX_COMPILER_LAUNCHER",
            "CMAKE_CUDA_COMPILER_LAUNCHER",
        ] {
            cmake_config.define(launcher, wrapper);
        }
    }

    // Allow to override which Python to use ...
    if let Some(python3) = option_env!("Python3_EXECUTABLE") {
        cmake_config.define("Python3_EXECUTABLE", python3);
    }

    cmake_config.build();

    // Additional transitive CMake dependencies, found under `_deps/<name>-build`;
    // debug builds link the `d`-suffixed variant of each library
    let deps_folder = out_dir.join("build").join("_deps");
    for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
        let library = if is_debug {
            format!("{}d", dependency)
        } else {
            String::from(dependency)
        };
        let search_path = deps_folder.join(format!("{}-build", dependency));
        println!("cargo:rustc-link-search={}", search_path.display());
        println!("cargo:rustc-link-lib=static={}", library);
    }

    // Emit linkage information from the artifacts we just built
    for lib_dir in ["lib", "lib64"].map(|sub_dir| install_path.join(sub_dir)) {
        println!(r"cargo:warning=Adding link search path: {}", lib_dir.display());
        println!(r"cargo:rustc-link-search={}", lib_dir.display());
    }

    (install_path, deps_folder)
}
fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
CFG.include_prefix = "backends/trtllm";
cxx_build::bridge("src/lib.rs")
.static_flag(true)
.std("c++23")
.include(deps_folder.join("spdlog-src").join("include"))
.include(deps_folder.join("json-src").join("include"))
.include(deps_folder.join("trtllm-src").join("cpp").join("include"))
.include("/usr/local/cuda/include")
.include("/usr/local/tensorrt/include")
.include("csrc/")
.file("csrc/ffi.hpp")
.define(
"TGI_TRTLLM_BACKEND_DEBUG",
get_compiler_flag(is_debug, "ON", "OFF"),
)
.compile("tgi_trtllm_backend");
println!("cargo:rerun-if-changed=CMakeLists.txt");
println!("cargo:rerun-if-changed=cmake/trtllm.cmake");
println!("cargo:rerun-if-changed=cmake/json.cmake");
println!("cargo:rerun-if-changed=cmake/spdlog.cmake");
println!("cargo:rerun-if-changed=csrc/backend.hpp");
println!("cargo:rerun-if-changed=csrc/backend.cpp");
println!("cargo:rerun-if-changed=csrc/hardware.hpp");
println!("cargo:rerun-if-changed=csrc/ffi.hpp");
}
/// Build-script entry point: builds the CMake backend and the cxx FFI layer, then emits the
/// linker directives for every native dependency (MPI, CUDA, NCCL, TensorRT, TensorRT-LLM).
fn main() {
    // Misc variables
    let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
    let build_profile = env::var("PROFILE").unwrap();
    let (is_debug, opt_level) = match build_profile.as_str() {
        "debug" | "dev" => (true, "0"),
        _ => (false, "3"),
    };

    // Build the backend
    let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);

    // Build the FFI layer calling the backend above
    build_ffi_layer(&deps_folder, is_debug);

    // MPI (OpenMPI) through pkg-config
    probe!("ompi", MPI_REQUIRED_VERSION);

    // Probe CUDA & co. with pkg-config
    for name in CUDA_TRANSITIVE_DEPS {
        probe!(name, CUDA_REQUIRED_VERSION);
    }

    // NCCL is slightly trickier because it might not have a pkgconfig installed
    let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH);
    let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default);
    println!(r"cargo:rustc-link-search=native={}", nccl_library_path);
    println!("cargo:rustc-link-lib=dylib=nccl");

    // TensorRT: TENSORRT_ROOT_DIR is the installation root (it is also passed to CMake as
    // TGI_TRTLLM_BACKEND_TRT_ROOT by `build_backend`); the shared libraries live under `lib`.
    // Previously an explicit TENSORRT_ROOT_DIR was used as-is, pointing the search path at the root.
    let tensorrt_library_path =
        PathBuf::from(TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt")).join("lib");
    println!(
        r"cargo:rustc-link-search=native={}",
        tensorrt_library_path.display()
    );
    println!("cargo:rustc-link-lib=dylib=nvinfer");

    // TensorRT-LLM
    for (link_type, name) in TENSORRT_LLM_TRANSITIVE_DEPS {
        println!("cargo:rustc-link-lib={}={}", link_type, name);
    }

    // Backend
    println!("cargo:rustc-link-lib=static={}", BACKEND_DEPS);
}
================================================
FILE: backends/trtllm/cmake/json.cmake
================================================
# nlohmann/json, fetched as a release tarball; the FFI layer also includes its headers from
# `_deps/json-src/include`
FetchContent_Declare(
        json
        URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.tar.gz
)
FetchContent_MakeAvailable(json)
================================================
FILE: backends/trtllm/cmake/spdlog.cmake
================================================
set(SPDLOG_USE_FMT ON)
set(SPDLOG_BUILD_SHARED OFF)
set(SPDLOG_FMT_EXTERNAL OFF)

# Compile-time spdlog level: SPDLOG_TRACE/SPDLOG_DEBUG calls below it are compiled out.
# Each definition must be a single NAME=VALUE argument; passing `SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE`
# as two arguments defines two unrelated macros (both to 1) instead of selecting the level.
# CMAKE_BUILD_TYPE is referenced by name so that an empty build type does not break the `if` syntax.
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
    add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_TRACE)
else ()
    add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_DEBUG)
endif ()

fetchcontent_declare(
        spdlog
        # DOWNLOAD_EXTRACT_TIMESTAMP
        URL https://github.com/gabime/spdlog/archive/refs/tags/v1.15.0.tar.gz
)
fetchcontent_makeavailable(spdlog)
================================================
FILE: backends/trtllm/cmake/trtllm.cmake
================================================
set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR})

# TensorRT-LLM build options: C++ library only (no python bindings, benchmarks or tests)
set(USE_CXX11_ABI ON)
set(BUILD_PYT OFF)
set(BUILD_PYBIND OFF)
set(BUILD_MICRO_BENCHMARKS OFF)
set(BUILD_BENCHMARKS OFF)
set(BUILD_TESTS OFF)
set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST})

message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")

set(ENABLE_UCX OFF)
# CMAKE_BUILD_TYPE is referenced by name: `${CMAKE_BUILD_TYPE}` expands to nothing when unset and breaks the `if`
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
    set(FAST_BUILD ON)
    set(NVTX_DISABLE ON)
    set(INDEX_RANGE_CHECK ON)
else ()
    set(FAST_BUILD OFF)
    set(FAST_MATH ON)
    set(NVTX_DISABLE OFF)
    set(INDEX_RANGE_CHECK OFF)
endif ()

find_package(Python3 REQUIRED Interpreter)

fetchcontent_declare(
        trtllm
        GIT_REPOSITORY https://github.com/nvidia/TensorRT-LLM.git
        GIT_TAG v0.17.0
        GIT_SHALLOW ON
        DOWNLOAD_EXTRACT_TIMESTAMP
)
fetchcontent_makeavailable(trtllm)

message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")

# Materialize the Git LFS tracked files of the checkout (presumably including the precompiled libraries referenced
# below). Failures are reported instead of being silently ignored.
foreach (TGI_TRTLLM_GIT_LFS_COMMAND install pull)
    execute_process(COMMAND git lfs ${TGI_TRTLLM_GIT_LFS_COMMAND}
            WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/"
            RESULT_VARIABLE TGI_TRTLLM_GIT_LFS_RESULT)
    if (NOT TGI_TRTLLM_GIT_LFS_RESULT EQUAL 0)
        message(WARNING "`git lfs ${TGI_TRTLLM_GIT_LFS_COMMAND}` failed (${TGI_TRTLLM_GIT_LFS_RESULT}) in ${trtllm_SOURCE_DIR}")
    endif ()
endforeach ()

# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here
set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name")
set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}"
        CACHE INTERNAL "nvrtc wrapper library path")

# The same Executor Static library (static prefix + static suffix)
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name")
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path")
================================================
FILE: backends/trtllm/cmake/utils/detect_cuda_arch.cu
================================================
================================================
FILE: backends/trtllm/csrc/backend.cpp
================================================
#include
#include
#include "backend.hpp"
#include "hardware.hpp"
namespace huggingface::tgi::backends::trtllm {
/**
 * Build the executor parallelism configuration from the engine's `config.json`
 * (`/pretrained_config/mapping/world_size`):
 * - world_size <= 1: leader communication mode, no orchestrator config
 * - world_size > 1: orchestrator communication mode, with an orchestrator config referencing executor_worker_path_
 * NOTE(review): the positional make_optional arguments (true, executor_worker_path_, nullptr, true) are assumed to be
 * (isOrchestrator, workerExecutablePath, orchLeaderComm, spawnProcesses) of tle::OrchestratorConfig — confirm against
 * the pinned TensorRT-LLM version.
 */
tle::ParallelConfig backend_workspace_t::parallel_config() const {
// Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
const auto world_size = config_["/pretrained_config/mapping/world_size"_json_pointer].get();
auto mode = tle::CommunicationMode::kLEADER;
std::optional orchestratorConfig = std::nullopt;
if (world_size > 1) {
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
mode = tle::CommunicationMode::kORCHESTRATOR;
orchestratorConfig = std::make_optional(true, executor_worker_path_, nullptr,
true);
} else {
SPDLOG_INFO("Detected single engine deployment, using leader mode");
}
// Communication type is always MPI; the two std::nullopt leave the device ids / participant ids unset (assumed order)
return tle::ParallelConfig(tle::CommunicationType::kMPI, mode, std::nullopt, std::nullopt, orchestratorConfig);
}
/**
 * Assemble the executor configuration: single beam, parallelism inferred from the engine config,
 * KV-cache config built with `true`, chunked context only on Ampere or newer GPUs, and the
 * max-utilization capacity scheduler.
 */
tle::ExecutorConfig backend_workspace_t::executor_config() const {
    // Query the local GPU first: some executor features depend on the detected architecture
    const auto capabilities = hardware::cuda::compute_capabilities_t();
    const bool supports_chunked_context = capabilities.is_at_least_ampere();

    // No beam search: one beam per request
    tle::ExecutorConfig config(/* maxBeamWidth = */ 1);
    config.setParallelConfig(parallel_config());
    config.setKvCacheConfig(tle::KvCacheConfig(true));
    config.setEnableChunkedContext(supports_chunked_context);
    config.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION));
    return config;
}
// `workspace` is declared before `executor_` in backend_t, so it is fully constructed (config files parsed)
// by the time executor_factory_initializer reads it to build the executor.
backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path)
: workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {}
// Number of responses the executor has ready to be pulled (tle::Executor::getNumResponsesReady)
size_t backend_t::num_tokens_ready() const noexcept {
return executor_.getNumResponsesReady();
}
// Enqueue a streaming request on the executor and return the id it assigned.
// NOTE(review): this function is noexcept and never produces the error alternative; an exception thrown by
// enqueueRequest would therefore call std::terminate rather than surface EXECUTOR_SCHEDULING_FAILED.
std::expected
backend_t::submit(std::span token_ids, const generation_params_t g_params,
const sampling_params_t s_params) noexcept {
SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params);
return executor_.enqueueRequest(tle::Request{
{token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens
static_cast(g_params.max_new_tokens), // maximum number of tokens to generate
true, // streaming (assumed from tle::Request's constructor order — confirm)
(tle::SamplingConfig) s_params,
tle::OutputConfig{ /* returnLogProbs= */ true},
std::nullopt, // presumably endId: end of generation relies on the stop words below
std::nullopt, // presumably padId
std::nullopt, // presumably positionIds
std::nullopt, // presumably badWords
workspace.generation_config().stop_words // stop words derived from generation_config.json
});
}
// Retrieve the responses produced by the executor. awaitResponses() is called without a timeout, so it presumably
// blocks until a response is available; the FFI layer only calls this when num_tokens_ready() > 0.
std::vector backend_t::pull_tokens() noexcept {
SPDLOG_TRACE(FMT_STRING("Pulling out tokens ({:d} available)"), num_tokens_ready());
return executor_.awaitResponses();
}
// Ask the executor to cancel the given in-flight request
void backend_t::cancel(request_id_t request_id) noexcept {
SPDLOG_TRACE(FMT_STRING("Cancelling request: {:d}"), request_id);
executor_.cancelRequest(request_id);
}
}
================================================
FILE: backends/trtllm/csrc/backend.hpp
================================================
#ifndef TGI_BACKEND_TRTLLM
#define TGI_BACKEND_TRTLLM
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
namespace huggingface::tgi::backends::trtllm {
// Short alias for the TensorRT-LLM executor API
namespace tle = tensorrt_llm::executor;
using json = nlohmann::json;
// Identifier returned when a request is enqueued on the executor
using request_id_t = uint64_t;
// Token id type used by the executor (tle::TokenIdType)
using token_id_t = tle::TokenIdType;
/**
 * Represent the parameters used for generation
 */
struct generation_params_t {
// Upper bound on the number of generated tokens, forwarded to tle::Request by backend_t::submit
uint32_t max_new_tokens;
};
/**
 * Represent the parameters used to sample tokens from the logit distribution
 */
struct sampling_params_t {
uint32_t top_k;
float_t top_p;
float_t repetition_penalty;
float_t frequency_penalty;
float_t temperature;
uint64_t seed;
/**
 * Convert to the executor's sampling configuration. Arguments are positional; the trailing comments give the
 * tle::SamplingConfig constructor parameter each one is assumed to fill (confirm against the pinned TensorRT-LLM).
 */
constexpr explicit operator tle::SamplingConfig() const {
return tle::SamplingConfig{
1, // beamWidth
top_k, // topK
top_p, // topP
std::nullopt, // topPMin
std::nullopt, // topPResetIds
std::nullopt, // topPDecay
seed, // seed
temperature, // temperature
std::nullopt, // minTokens
std::nullopt, // beamSearchDiversityRate
repetition_penalty, // repetitionPenalty
std::nullopt, // presencePenalty
frequency_penalty, // frequencyPenalty
std::nullopt // lengthPenalty
};
}
};
/**
* Represent possible values from transformers generation `generation_config.json`.
* It usually stores default sampling parameters to use, such as top_p, temperature, etc.
*/
struct generation_config_t {
float_t top_p;
float_t temperature;
std::list> stop_words;
constexpr explicit generation_config_t(const json &config) :
top_p(config.value("top_p", 1.0f)), temperature(config.value("temperature", 1.0f)), stop_words(0) {
if (config.contains("/eos_token_id"_json_pointer) && config["/eos_token_id"_json_pointer].is_array()) {
const auto &eos_token_id = config["/eos_token_id"_json_pointer];
std::for_each(eos_token_id.begin(), eos_token_id.end(), [this](const auto token_id) {
stop_words.emplace_back(1, token_id.template get());
});
SPDLOG_DEBUG("Detected {:d} predefined stop_words from generation_config.json", stop_words.size());
}
}
};
/**
 * Helper class representing various items which are stored within the TensorRT-LLM engines folder and
 * can be retrieved at runtime (`config.json`, `generation_config.json`, executorWorker path).
 */
class backend_workspace_t {
private:
    /**
     * Read and parse a JSON document from the local filesystem.
     * @throws std::runtime_error when the file cannot be opened (instead of an opaque "empty input" parse error)
     * @throws nlohmann::json::parse_error when the content is not valid JSON
     */
    constexpr static auto as_json = [](const std::filesystem::path &path) -> json {
        std::ifstream config_f(path);
        if (!config_f.is_open()) {
            throw std::runtime_error("Failed to open " + path.string());
        }
        return json::parse(config_f);
    };

    // Declaration order matters: config_ and generation_config_ are initialized from engines_folder_
    std::filesystem::path engines_folder_;
    std::filesystem::path executor_worker_path_;
    json config_;
    generation_config_t generation_config_;

public:
    backend_workspace_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path) :
            engines_folder_(engines_folder),
            executor_worker_path_(executor_worker_path),
            config_(as_json(engines_folder / "config.json")),
            generation_config_(as_json(engines_folder / "generation_config.json")) {};

    backend_workspace_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path) :
            engines_folder_(std::move(engines_folder)),
            executor_worker_path_(std::move(executor_worker_path)),
            // Read from the member: the argument has been moved-from at this point
            config_(as_json(engines_folder_ / "config.json")),
            generation_config_(as_json(engines_folder_ / "generation_config.json")) {};

    /**
     * Path to the folder containing the TensorRT-LLM engines
     * @return local filesystem path to the folder
     */
    [[nodiscard]] constexpr const std::filesystem::path &engines_folder() const { return engines_folder_; }

    /**
     * Hugging Face transformers' generated `generation_config_t` mapping information stored in the
     * `generation_config.json` holding default generation parameters.
     * @return `generation_config_t`
     */
    [[nodiscard]] constexpr const generation_config_t &generation_config() const { return generation_config_; }

    /**
     * Factory method returning new `tensorrt_llm::executor::ParallelConfig` instance used
     * to initialize `tensorrt_llm::executor::Executor` with multi-instance communication information
     * @return `tensorrt_llm::executor::ParallelConfig` instance
     */
    [[nodiscard]] tle::ParallelConfig parallel_config() const;

    /**
     * Factory method returning new `tensorrt_llm::executor::ExecutorConfig` instance used
     * to initialize `tensorrt_llm::executor::Executor`
     * @return `tensorrt_llm::executor::ExecutorConfig` instance
     */
    [[nodiscard]] tle::ExecutorConfig executor_config() const;
};
/**
 * Error raised by the underlying backend implementation
 * (error alternative of backend_t::submit's result; the FFI layer logs a warning when it receives one)
 */
enum backend_error_t {
EXECUTOR_NOT_READY = 3,
EXECUTOR_SCHEDULING_FAILED = 4,
};
/**
 * Actual TensorRT-LLM backend implementation interacting with TensorRT-LLM Executor service to
 * - schedule new request
 * - pull status of submitted request(s)
 * - cancel submitted request(s)
 */
class backend_t {
private:
// Must stay declared before executor_: the executor is built from the workspace in the constructor
backend_workspace_t workspace;
tle::Executor executor_;
public:
backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path);
// Rvalue overload: delegates to the lvalue constructor (the named rvalue references are lvalues here)
backend_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path)
: backend_t(engines_folder, executor_worker_path) {};
/**
 * Submit a new request to the executor
 * @param token_ids Prompt token ids (copied into the request)
 * @param generation_params Generation limits (max new tokens)
 * @param sampling_params Sampling parameters converted to tle::SamplingConfig
 * @return Either newly submitted request's id or the error why it failed to submit
 */
[[nodiscard("Discarded executor request_id needs to be assigned")]]
std::expected
submit(std::span token_ids, generation_params_t generation_params,
sampling_params_t sampling_params) noexcept;
/**
 * Query the number of responses ready to be pulled across all in-flight generations
 * (implemented with tle::Executor::getNumResponsesReady)
 * @return number of ready responses
 */
[[nodiscard("Pulling out the number of tokens")]]
size_t num_tokens_ready() const noexcept;
/**
 * Pull out newly generated tokens from the executor
 * @return the responses returned by tle::Executor::awaitResponses
 */
[[nodiscard("")]]
std::vector pull_tokens() noexcept;
/**
 * Cancel the specified request on the executor's set
 * @param request_id Request's Identifier to remove from the in-flight executor
 */
void cancel(request_id_t) noexcept;
};
/**
 * Build the TensorRT-LLM executor for a decoder-only model from the engines stored in the workspace folder,
 * configured with the workspace-derived executor configuration.
 */
const auto executor_factory_initializer = [](const backend_workspace_t &workspace) -> tle::Executor {
    const auto &engines_folder = workspace.engines_folder();
    const auto model_type = tensorrt_llm::executor::ModelType::kDECODER_ONLY;
    return tle::Executor(engines_folder, model_type, workspace.executor_config());
};
}
/**
 * Helper structures to define formatting strategies for various types in the backend
 * (used when logging the parameters of submitted requests, e.g. in backend_t::submit)
 */
template<>
struct fmt::formatter : formatter {
// Render as `generation_params_t{ max_new_tokens=N }`
auto format(huggingface::tgi::backends::trtllm::generation_params_t const &c,
format_context &ctx) const -> format_context::iterator {
return fmt::format_to(ctx.out(), "generation_params_t{{ max_new_tokens={:d} }}", c.max_new_tokens);
}
};
template<>
struct fmt::formatter : formatter {
// Render every sampling field; floating-point values with 3 decimals
auto format(huggingface::tgi::backends::trtllm::sampling_params_t const &c,
format_context &ctx) const -> format_context::iterator {
return fmt::format_to(
ctx.out(),
"sampling_params_t{{ top_k={:d}, top_p={:.3f}, repetition_penalty={:.3f}, frequency_penalty={:.3f}, temperature={:.3f}, seed={:d} }}",
c.top_k, c.top_p, c.repetition_penalty, c.frequency_penalty, c.temperature, c.seed
);
}
};
#endif
================================================
FILE: backends/trtllm/csrc/ffi.hpp
================================================
#ifndef TGI_BACKEND_TRTLLM_FFI
#define TGI_BACKEND_TRTLLM_FFI
#include
#include
#include
#include
#include
#include
#include
#include
namespace rust::behavior {
template
static void trycatch(Try &&func, Fail &&fail) noexcept try {
func();
} catch (tensorrt_llm::common::TllmException &e) {
fail(e.what());
}
}
// Declared ahead of the cxx-generated header included right below, which refers to this opaque type
// (bridged to Rust as `TensorRtLlmBackendImpl` in src/lib.rs); the definition follows later in this file.
namespace huggingface::tgi::backends::trtllm {
class tensorrt_llm_backend_t;
}
#include "backends/trtllm/src/lib.rs.h"
namespace huggingface::tgi::backends::trtllm {
// Guards the one-time process initialization performed by create_backend_from_engine_folder.
// `inline` so that the header can be included by several translation units without duplicate definitions.
inline std::once_flag backend_initialized_flag;
/**
 * Map TensorRT-LLM's finish reason onto the finish reason shared with Rust.
 * TensorRT-LLM may report reasons the Rust enum does not model (NOTE(review): e.g. cancellation or time-out in recent
 * releases — a cancelled request, see backend_t::cancel, is one way to get there). Such values previously hit
 * std::unreachable(), which is undefined behavior; they now map to kNOT_FINISHED, the same reason the error path of
 * as_generation_step reports.
 */
constexpr finish_reason_t as_finish_reason_t(const tle::FinishReason reason) noexcept {
    switch (reason) {
        case tle::FinishReason::kNOT_FINISHED:
            return finish_reason_t::kNOT_FINISHED;
        case tle::FinishReason::kSTOP_WORDS:
            return finish_reason_t::kSTOP_WORDS;
        case tle::FinishReason::kEND_ID:
            return finish_reason_t::kEND_ID;
        case tle::FinishReason::kLENGTH:
            return finish_reason_t::kLENGTH;
        default:
            return finish_reason_t::kNOT_FINISHED;
    }
}
/**
 * Convert one executor response into the generation_step_t shared with Rust.
 * - Success: the first token of the first beam, the last log-prob of the first beam, finality and finish reason.
 * - Error: a final step flagged with has_error and carrying the executor's error message.
 * Log-probs are requested by backend_t::submit (OutputConfig returnLogProbs), so `logProbs` is expected to be set.
 */
static auto as_generation_step = [](const tle::Response &r) {
    const auto reqId = r.getRequestId();
    if (!r.hasError()) [[likely]] {
        // Bind by reference: avoids copying the result (and its token / log-prob vectors) for every generated token
        const auto &result = r.getResult();
        const auto &log_probs = result.logProbs.value()[0];
        return generation_step_t{
                reqId,
                static_cast<uint32_t>(result.outputTokenIds[0][0]),
                log_probs.back(),
                result.isFinal,
                as_finish_reason_t(result.finishReasons[0]),
                false,
                std::string()
        };
    } else {
        return generation_step_t{
                reqId,
                0,
                0.0f,
                true,
                finish_reason_t::kNOT_FINISHED,
                true,
                r.getErrorMsg()
        };
    }
};
class tensorrt_llm_backend_t {
private:
backend_t inner_;
public:
tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path)
: inner_(engine_folder, executor_worker_path) {}
size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); }
request_id_t submit(
rust::Slice tokens,
uint32_t max_new_tokens,
uint32_t top_k,
float_t top_p,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed
) {
// This is enabled only if using add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_TRACE)
SPDLOG_TRACE(FMT_STRING("[FFI] Submitting {:d} prompt tokens to the executor"));
// Submit the request to the executor and get back a potential request_id used to track request status
const auto signed_tokens = std::vector(tokens.begin(), tokens.end());
const auto maybe_request_id = inner_.submit(
signed_tokens,
{max_new_tokens},
{top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed}
);
// If we do have a value, let's return the request_id
if (maybe_request_id.has_value()) [[likely]] {
return *maybe_request_id;
} else {
SPDLOG_WARN("[FFI] Failed to submit request to the executor");
return maybe_request_id.error();
}
}
std::unique_ptr> pull_tokens() noexcept {
if (num_tokens_ready() > 0) [[likely]] {
const auto responses = inner_.pull_tokens();
SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size());
// Transform tle::Response to generation_step_t
#ifdef __cpp_lib_ranges_to_container
auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to();
#else
auto steps = std::vector();
steps.reserve(responses.size());
std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step);
#endif
return std::make_unique>(steps);
} else {
return std::make_unique>();
}
}
void cancel(request_id_t request_id) noexcept {
SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id);
inner_.cancel(request_id);
}
};
/**
 * Configure spdlog's runtime log level.
 * Without TGI_TRTLLM_BACKEND_DEBUG: when TRTLLM_LOG_LEVEL is set, "debug" (case-insensitive) selects the debug level
 * and any other value selects info; when unset, spdlog's current level is kept.
 * With TGI_TRTLLM_BACKEND_DEBUG defined: always log at debug level.
 */
void initialize_logging() {
#ifndef TGI_TRTLLM_BACKEND_DEBUG
    const char *requested_level = std::getenv("TRTLLM_LOG_LEVEL");
    if (requested_level == nullptr) {
        return;
    }

    std::string normalized(requested_level);
    for (auto &c: normalized) {
        c = static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
    }
    spdlog::set_level(normalized == "debug" ? spdlog::level::debug : spdlog::level::info);
#else
    spdlog::set_level(spdlog::level::debug);
#endif
}
void initialize_tensorrt_llm_backend() {
SPDLOG_INFO("Initializing TGI - TensoRT-LLM Backend (v{})", tle::version());
// Initialize everyone
initialize_logging();
nvmlInit_v2();
initTrtLlmPlugins();
const auto numGpus = huggingface::tgi::hardware::cuda::get_device_count();
if (numGpus.has_value()) {
SPDLOG_INFO("[FFI] Detected {:d} Nvidia GPU(s)", *numGpus);
} else {
SPDLOG_WARN("[FFI] Failed to detected Nvidia GPU(s) on the system");
// todo: throw
}
}
/**
 * Entry point called from Rust: build a backend from the engines folder and the executorWorker binary path.
 * Process-wide initialization (initialize_tensorrt_llm_backend) runs only on the first call (std::call_once).
 * @param engines_folder UTF-8 path to the folder holding the TensorRT-LLM engines and their JSON configs
 * @param executor_worker_path UTF-8 path to the executorWorker binary
 * @return owning pointer to the new backend
 */
std::unique_ptr
create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
return std::make_unique(
std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),
std::filesystem::path::format::auto_format),
std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),
std::filesystem::path::format::auto_format)
);
}
}
#endif
================================================
FILE: backends/trtllm/csrc/hardware.hpp
================================================
#ifndef TGI_HARDWARE_CUDA
#define TGI_HARDWARE_CUDA
#include
#include
#include
namespace huggingface::tgi::hardware::cuda {
static constexpr auto VOLTA = std::make_tuple(7u, 0u);
static constexpr auto TURING = std::make_tuple(7u, 5u);
static constexpr auto AMPERE = std::make_tuple(8u, 0u);
static constexpr auto HOPPER = std::make_tuple(9u, 0u);
static constexpr auto ADA_LOVELACE = std::make_tuple(8u, 9u);
/**
* Get the number of GPUs on the local machine
* @return std::nullopt if no device is available, otherwise >= 1
*/
inline std::optional get_device_count() {
uint32_t numGpus = 0;
if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
return numGpus;
}
return std::nullopt;
}
/**
* Store information about the version of the CUDA Compute Capabilities detected on the device
*/
struct compute_capabilities_t {
int32_t major;
int32_t minor;
compute_capabilities_t(): compute_capabilities_t(0) {}
explicit compute_capabilities_t(size_t device_idx): major(-1), minor(-1) {
nvmlDevice_t device;
if (nvmlDeviceGetHandleByIndex_v2(device_idx, &device) == NVML_SUCCESS) {
nvmlDeviceGetCudaComputeCapability(device, &major, &minor);
}
};
compute_capabilities_t(int32_t major, int32_t minor): major(major), minor(minor) {}
/**
* Evaluate if the underlying capabilities is at least greater or equals to the provided 2-tuple (major, minor)
* @param sm Architecture version (major, minor)
* @return True if greater or equals to the underlying compute capabilities
*/
[[nodiscard]] constexpr auto is_at_least(std::tuple sm) const -> decltype(auto) { return std::tie(major, minor) >= sm; }
/**
* Check if the capabilities match at least Volta architecture (sm_70)
* @return true if at least Volta (>= sm_70), false otherwise
*/
[[nodiscard]] constexpr bool is_at_least_volta() const { return is_at_least(VOLTA); }
/**
* Check if the capabilities match at least Turing architecture (sm_75)
* @return true if at least Turing (>= sm_75), false otherwise
*/
[[nodiscard]] constexpr bool is_at_least_turing() const { return is_at_least(TURING); }
/**
* Check if the capabilities match at least Ampere architecture (sm_80)
* @return true if at least Ampere (>= sm_80), false otherwise
*/
[[nodiscard]] constexpr bool is_at_least_ampere() const { return is_at_least(AMPERE); }
/**
* Check if the capabilities match at least Ada Lovelace architecture (sm_89)
* @return true if at least Ada Lovelace (>= sm_89), false otherwise
*/
[[nodiscard]] constexpr bool is_at_least_ada_lovelace() const { return is_at_least(ADA_LOVELACE); }
/**
* Check if the capabilities match at least Hopper architecture (sm_90)
* @return true if at least Hopper (>= sm_90), false otherwise
*/
[[nodiscard]] constexpr bool is_at_least_hopper() const { return is_at_least(HOPPER); }
};
}
#endif
================================================
FILE: backends/trtllm/scripts/install_tensorrt.sh
================================================
#!/bin/bash

set -ex

# Default versions (TensorRT 10.8 on CUDA 12.8); several can be overridden with --NAME=value arguments
TRT_VER_BASE="10.8.0"
TRT_VER_FULL="${TRT_VER_BASE}.43"
CUDA_VER="12.8"
CUDNN_VER="9.7.0.66-1"
NCCL_VER="2.25.1-1+cuda${CUDA_VER}"
CUBLAS_VER="${CUDA_VER}.3.14-1"
NVRTC_VER="${CUDA_VER}.61-1"

for i in "$@"; do
    case $i in
        # --TRT_VER takes the full TensorRT version (e.g. 10.8.0.43); the download folder uses its first three
        # components. The value used to be stored in an unused variable and silently ignored.
        --TRT_VER=?*) TRT_VER_FULL="${i#*=}"; TRT_VER_BASE="$(echo "${TRT_VER_FULL}" | cut -d. -f1-3)";;
        --CUDA_VER=?*) CUDA_VER="${i#*=}";;
        --CUDNN_VER=?*) CUDNN_VER="${i#*=}";;
        --NCCL_VER=?*) NCCL_VER="${i#*=}";;
        --CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";;
        *) ;;
    esac
    shift
done

# Refuse to run against a pre-installed CUDA toolkit of another version
NVCC_VERSION_OUTPUT=$(nvcc --version)
if [[ "$(echo "$NVCC_VERSION_OUTPUT" | grep -oP "\d+\.\d+" | head -n 1)" != "${CUDA_VER}" ]]; then
    echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}."
    exit 1
fi
# Install the pinned cuDNN / NCCL / cuBLAS / NVRTC packages from NVIDIA's CUDA apt repository (Debian/Ubuntu)
install_ubuntu_requirements() {
    apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates

    # NVIDIA's repository architecture names: x86_64, and sbsa for aarch64
    ARCH=$(uname -m)
    if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
    if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
    curl -fsSLO "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${ARCH}/cuda-keyring_1.1-1_all.deb"
    dpkg -i cuda-keyring_1.1-1_all.deb
    # Remove the CUDA apt source list of the current architecture (the name was hard-coded to x86_64, so `set -e`
    # aborted on sbsa); -f tolerates its absence. NOTE(review): presumably removed to avoid a duplicate CUDA apt source.
    rm -f "/etc/apt/sources.list.d/cuda-ubuntu2404-${ARCH}.list"
    apt-get update

    # Drop pre-installed versions so the pinned ones below can be installed.
    # Package globs are quoted so the shell never expands them against files of the working directory.
    if [[ $(apt list --installed | grep libcudnn9) ]]; then
        apt-get remove --purge -y --allow-change-held-packages 'libcudnn9*'
    fi
    if [[ $(apt list --installed | grep libnccl) ]]; then
        apt-get remove --purge -y --allow-change-held-packages 'libnccl*'
    fi
    if [[ $(apt list --installed | grep libcublas) ]]; then
        apt-get remove --purge -y --allow-change-held-packages 'libcublas*'
    fi
    if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then
        apt-get remove --purge -y --allow-change-held-packages 'cuda-nvrtc-dev*'
    fi

    # Package names embed the CUDA version with dashes, e.g. 12-8
    CUBLAS_CUDA_VERSION=$(echo "$CUDA_VER" | sed 's/\./-/g')
    apt-get install -y --no-install-recommends "libcudnn9-cuda-12=${CUDNN_VER}" "libcudnn9-dev-cuda-12=${CUDNN_VER}"
    apt-get install -y --no-install-recommends "libnccl2=${NCCL_VER}" "libnccl-dev=${NCCL_VER}"
    apt-get install -y --no-install-recommends "libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER}" "libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER}"
    # NVRTC static library doesn't exist in NGC PyTorch container.
    NVRTC_CUDA_VERSION=$(echo "$CUDA_VER" | sed 's/\./-/g')
    apt-get install -y --no-install-recommends "cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER}"
    apt-get clean
    rm -rf /var/lib/apt/lists/*
}
# Install the pinned NCCL / cuBLAS packages on CentOS
install_centos_requirements() {
    # Package names embed the CUDA version with dashes, e.g. 12-8
    CUBLAS_CUDA_VERSION="${CUDA_VER//./-}"
    yum -y update
    yum -y install epel-release
    # Package globs are quoted so the shell never expands them against files of the working directory
    yum remove -y 'libnccl*' && yum -y install "libnccl-${NCCL_VER}" "libnccl-devel-${NCCL_VER}"
    yum remove -y 'libcublas*' && yum -y install "libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}" "libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}"
    yum clean all
}
# Download the TensorRT tarball (or $RELEASE_URL_TRT when provided) and install it under /usr/local/tensorrt.
# The target architecture comes from $TRT_TARGETARCH, defaulting to the host's.
install_tensorrt() {
    TRT_CUDA_VERSION="12.8"

    if [ -z "$RELEASE_URL_TRT" ];then
        ARCH=${TRT_TARGETARCH}
        if [ -z "$ARCH" ];then ARCH=$(uname -m);fi
        if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
        if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
        # aarch64 tarballs are published per Ubuntu release, x86_64 ones as generic Linux
        if [ "$ARCH" = "aarch64" ];then TRT_OS="Ubuntu-24.04"; else TRT_OS="Linux";fi
        RELEASE_URL_TRT="https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${TRT_VER_BASE}/tars/TensorRT-${TRT_VER_FULL}.${TRT_OS}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz"
    fi

    wget --no-verbose "${RELEASE_URL_TRT}" -O /tmp/TensorRT.tar
    tar -xf /tmp/TensorRT.tar -C /usr/local/
    mv "/usr/local/TensorRT-${TRT_VER_FULL}" /usr/local/tensorrt
    rm -rf /tmp/TensorRT.tar
}
# Install the distribution-specific requirements, then TensorRT itself; unknown distributions abort
ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
case "$ID" in
    debian | ubuntu)
        install_ubuntu_requirements
        ;;
    centos)
        install_centos_requirements
        ;;
    *)
        echo "Unable to determine OS..."
        exit 1
        ;;
esac
install_tensorrt
================================================
FILE: backends/trtllm/scripts/setup_sccache.py
================================================
from argparse import ArgumentParser
# Environment variable -> sccache S3 setting it configures
AWS_S3_CACHING_VARIABLES = {
    "AWS_ACCESS_KEY_ID": "aws_access_key_id",
    "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
    "AWS_SESSION_TOKEN": "aws_session_token",
    "SCCACHE_REGION": "s3_region",
    "SCCACHE_BUCKET": "s3_bucket_name",
}

# Names of the module-level mappings above whose keys are remote-cache environment variables
ALL_CACHING_STORAGE_VARIABLES = {"AWS_S3_CACHING_VARIABLES"}


def setup_sccache_locally():
    """Fall back to a local cache by removing every remote-storage variable from ``os.environ``.

    Only the current process (and processes it spawns afterwards) see the removal.
    """
    from os import environ

    print("Setting up Local Caching Layer")
    module_scope = globals()
    for storage_name in ALL_CACHING_STORAGE_VARIABLES:
        for envvar in module_scope[storage_name]:
            if envvar not in environ:
                continue
            print(f"Deleted {envvar} from environment variables")
            del environ[envvar]


def setup_sccache_for_s3():
    """Report every S3 caching variable that is unset or empty (nothing is modified)."""
    from os import environ

    print("Setting up AWS S3 Caching Layer")
    missing = [name for name in AWS_S3_CACHING_VARIABLES if not environ.get(name)]
    for name in missing:
        print(f"Missing definition for environment variable {name}")
if __name__ == "__main__":
    parser = ArgumentParser("TensorRT-LLM Build Caching Setup")
    parser.add_argument(
        "--is-gha-build",
        type=str,
        default="FALSE",
        help="Indicate if the build is from Github Actions",
    )

    # Parse args: on / true / 1 (case-insensitive) selects the remote S3 cache, anything else the local one
    args = parser.parse_args()
    args.is_gha_build = args.is_gha_build.lower() in {"on", "true", "1"}

    setup = setup_sccache_for_s3 if args.is_gha_build else setup_sccache_locally
    setup()
================================================
FILE: backends/trtllm/src/errors.rs
================================================
use std::path::PathBuf;
use thiserror::Error;
use text_generation_router::server;
/// Errors reported by the TensorRT-LLM backend crate.
///
/// Covers startup validation (paths, CLI arguments), failures reported by the runtime or by
/// the tokenizer resolution, and errors converted through `#[from]` (web server, `std::io`).
#[derive(Debug, Error)]
pub enum TensorRtLlmBackendError {
/// The engine folder given to the backend does not exist on disk.
#[error("Provided engine folder {0} doesn't exist")]
EngineFolderDoesntExists(PathBuf),
/// The executor worker binary path does not exist on disk.
#[error("Provided executorWorker binary path {0} doesn't exist")]
ExecutorWorkerNotFound(PathBuf),
/// Message reported by the TensorRT-LLM runtime (e.g. when creating the backend).
#[error("TensorRT-LLM Runtime error: {0}")]
Runtime(String),
/// The tokenizer could not be obtained in a usable form.
#[error("Tokenizer error: {0}")]
Tokenizer(String),
/// Inconsistent or invalid arguments (token limits, worker count, paths, ...).
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
/// Error returned by the router web server.
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
/// Any `std::io::Error` converts into this variant. Judging by the message, it is meant for
/// Tokio runtime start-up failures — NOTE(review): confirm where it is actually produced.
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}
================================================
FILE: backends/trtllm/src/lib.rs
================================================
pub use looper::TensorRtLlmBackendV2;
pub mod errors;
mod looper;
mod utils;
/// `cxx` bridge to the C++ TensorRT-LLM backend declared in `backends/trtllm/csrc/ffi.hpp`.
/// Shared items live in the C++ namespace `huggingface::tgi::backends::trtllm`.
#[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")]
mod ffi {
/// Reason a request stopped (or has not yet stopped) generating, shared with C++ as
/// `finish_reason_t` with explicit `u8` discriminants.
#[cxx_name = "finish_reason_t"]
#[derive(Debug, Clone, Copy)]
pub enum FinishReason {
/// The request is not finished.
#[cxx_name = "kNOT_FINISHED"]
NotFinished = 0u8,
/// The request finished because the end id was generated.
#[cxx_name = "kEND_ID"]
EndTokenId = 1u8,
/// The request finished because a stop word was generated.
#[cxx_name = "kSTOP_WORDS"]
StopWords = 2u8,
/// The request finished because the maximum number of tokens was reached.
#[cxx_name = "kLENGTH"]
MaxLength = 3u8,
}
/// Struct used as shared type between rust and C++ to represent the result
/// of a single decoding iteration
#[cxx_name = "generation_step_t"]
#[derive(Debug, Clone)]
pub struct GenerationStep {
/// Identifier of the request this step belongs to, as returned by `submit`.
request_id: u64,
/// Token generated at this step.
token_id: u32,
/// Log-probability of `token_id`.
log_prob: f32,
/// Whether this is the last step of the request.
is_final: bool,
/// Why the request stopped; presumably `NotFinished` while `is_final` is false.
finish_reason: FinishReason,
/// Whether this step carries an error instead of a token.
has_error: bool,
/// Error description, meaningful only when `has_error` is set.
error_msg: String,
}
unsafe extern "C++" {
// C++ declarations of the items below
include!("backends/trtllm/csrc/ffi.hpp");
/// Represent an instance of the underlying TensorRT-LLM backend
#[cxx_name = "tensorrt_llm_backend_t"]
type TensorRtLlmBackendImpl;
/// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
///
/// # Arguments
///
/// * `engine_folder`: Path to the folder containing all the TRTLLM engines
/// * `executor_worker`: Path to the TRTLLM executor worker
///
/// returns: the backend owned through a `UniquePtr`, or the C++ exception raised while
/// creating it, surfaced as a `cxx::Exception`.
fn create_backend_from_engine_folder(
engine_folder: &str,
executor_worker: &str,
) -> Result>;
/// Amount of generated output waiting to be collected with `pull_tokens`; the looper
/// only pulls when this is non-zero.
fn num_tokens_ready(self: &TensorRtLlmBackendImpl) -> usize;
/// Submit a tokenized prompt along with its sampling and stopping parameters.
///
/// returns: the id identifying the request in subsequent `GenerationStep`s and in
/// `cancel`, or the C++ exception raised while scheduling it.
fn submit(
self: Pin<&mut TensorRtLlmBackendImpl>,
tokens: &[u32],
max_new_tokens: u32,
top_k: u32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64,
) -> Result;
/// Collect the generation steps produced by the executor (presumably those
/// accumulated since the previous call).
fn pull_tokens(
self: Pin<&mut TensorRtLlmBackendImpl>,
) -> Result>>;
/// Ask the executor to cancel the request identified by `request_id`.
fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64);
}
}
use ffi::FinishReason;
use text_generation_router::FinishReason as InferFinishReason;
impl From for InferFinishReason {
fn from(reason: FinishReason) -> Self {
match reason {
FinishReason::StopWords => InferFinishReason::StopSequence,
FinishReason::MaxLength => InferFinishReason::Length,
FinishReason::EndTokenId => InferFinishReason::EndOfSequenceToken,
_ => panic!("Cannot convert {reason:?} to text_generation_router::FinishReason"),
}
}
}
================================================
FILE: backends/trtllm/src/looper.rs
================================================
use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::HashMap;
use std::hint;
use std::ops::Deref;
use std::path::Path;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::TryAcquireError;
use tokio::task::spawn_blocking;
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn};
use text_generation_router::infer::InferError::{GenerationError, ValidationError};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidationError::{
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
};
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
use text_generation_router::Token;
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{
create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,
};
use crate::utils::first_line;
/// Result type used by the looper and the backend, with `InferError` as the error type.
type InferResult = Result;
/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
struct GenerationContext {
/// The validated request received from the router
request: ValidGenerateRequest,
/// Channel on which every decoded token (or error) is pushed back to the client
streamer: UnboundedSender>,
/// Ids generated so far; detokenized all at once when the final step arrives
tokens: Vec,
/// When the looper observed the first generation step (approximates the executor start time)
start: Option,
/// When the request was handed over to the looper by `schedule`
queued: Instant,
}
/// Successful generation step: a `GenerationStep` stripped of its error fields.
#[derive(Debug, Copy, Clone)]
struct DecodedToken {
/// Generated token id
id: u32,
/// Log-probability of the token
log_prob: f32,
/// Whether this is the last token of the request
is_final: bool,
/// Why generation stopped; converted into the router's `FinishReason` on the final step
finish_reason: FinishReason,
}
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
type Error = InferError;
fn try_from(step: &'step GenerationStep) -> Result {
if !step.has_error {
Ok(Self {
id: step.token_id,
log_prob: step.log_prob,
is_final: step.is_final,
finish_reason: step.finish_reason,
})
} else {
Err(GenerationError(step.error_msg.clone()))
}
}
}
fn executor_status_looper(
max_inflight_requests: usize,
tokenizer: Tokenizer,
mut backend: UniquePtr,
mut backlog: UnboundedReceiver,
) {
// Track the tuple (request_id, stream) for each request
let mut in_flights =
HashMap::::with_capacity(max_inflight_requests * 2);
'scheduler: loop {
// Is there any request pending to be scheduled?
let awaiting_requests = backlog.len();
for _ in 0..awaiting_requests {
// Retrieve all the requests
if let Some(ctx) = backlog.blocking_recv() {
// Submit all the request to the executor and move the context to the in-flight tracker
let request = &ctx.request;
let generation_params = &request.parameters;
let stopping_params = &request.stopping_parameters;
let input_ids = request.input_ids.as_deref();
// Submit to the TensorRT-LLM executor for scheduling
match backend.pin_mut().submit(
&input_ids.unwrap(), // This is checked beforehand in validate()
stopping_params.max_new_tokens,
generation_params.top_k,
generation_params.top_p,
generation_params.temperature,
generation_params.repetition_penalty,
generation_params.frequency_penalty,
generation_params.seed,
) {
Ok(request_id) => {
// Insert the context linked to the generated request id in the tracker
debug!("[in-flight] Added {}", request_id);
in_flights.insert(request_id, ctx);
}
Err(e) => {
// Return to the caller
let what = e.to_string();
error!(error = what.as_str(), "Failed to schedule request");
let err = Err(InferError::Overloaded(TryAcquireError::NoPermits));
if let Err(_) = ctx.streamer.send(err) {
error!("Failed to send back error to the client");
}
}
};
} else {
break 'scheduler;
}
}
if backend.num_tokens_ready() > 0 {
let mut backend = backend.pin_mut();
match backend.as_mut().pull_tokens() {
Ok(responses) => {
// Iterate through all the decoded token
for step in responses.deref() {
if let Some(ctx) = in_flights.get_mut(&step.request_id) {
// Update the starting timestamp if not set
// This value might not be the actual real starting time of the request
// on the executor side - Need to expose more info from the executor to
// retrieve this value
// TODO : Expose actual real starting time for a request on FFI layer
if ctx.start.is_none() {
ctx.start = Some(Instant::now());
}
// Try to map the generation step to a DecodedToken
let response = match DecodedToken::try_from(step) {
Ok(decoded_token) => {
post_process_decoded_token(&tokenizer, ctx, decoded_token)
}
Err(err) => Err(err),
};
// Attempt to send back the response to the client
if let Err(_) = ctx.streamer.send(response) {
// Client has dropped, remove from tracked requests
debug!(
"Client dropped - removing request {} from tracked requests",
step.request_id
);
backend.as_mut().cancel(step.request_id);
let _ = in_flights.remove(&step.request_id);
}
} else {
warn!("Untracked request {}", step.request_id,);
}
}
}
Err(ref err) => {
error!("Failed to get responses from the executor: {}.", err.what());
break 'scheduler;
}
}
}
// Hint the CPU we are spin-locking
hint::spin_loop();
}
}
fn post_process_decoded_token(
tokenizer: &Tokenizer,
ctx: &mut GenerationContext,
decoded_token: DecodedToken,
) -> InferResult {
match tokenizer.decode(&[decoded_token.id], false) {
Ok(text) => {
let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
let token = Token {
id: decoded_token.id,
text,
logprob: decoded_token.log_prob,
special: is_special,
};
// Append the token to the tracked generated tokens
ctx.tokens.push(token.id);
// Map the correct response depending on the step is final or not
let out = if !decoded_token.is_final {
InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
}
} else {
let text = tokenizer.decode(&ctx.tokens, true);
let generated_text = GeneratedText {
text: text.unwrap(),
generated_tokens: ctx.tokens.len() as u32,
finish_reason: decoded_token.finish_reason.into(),
seed: None,
};
InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text,
start: ctx.start.unwrap(),
queued: ctx.queued,
}
};
Ok(out)
}
Err(err) => Err(GenerationError(err.to_string())),
}
}
fn ensure_paths_exist, PP: AsRef>(
engine_folder: P,
executor_worker_path: PP,
) -> Result<(String, String), TensorRtLlmBackendError> {
// Retrieve paths as &str for the backend creation
let engine_folder = engine_folder.as_ref();
let executor_worker_path = executor_worker_path.as_ref();
// Ensure the engine folder exists
if !engine_folder.exists() {
let err = TensorRtLlmBackendError::EngineFolderDoesntExists(engine_folder.to_path_buf());
error!("Path validation failed: {}", err,);
return Err(err);
}
// Ensure executor worker binary exists
if !executor_worker_path.exists() {
let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf());
error!("Path validation failed: {}", err,);
return Err(err);
}
let engine_folder = String::from(
engine_folder
.to_str()
.expect("Failed to convert engine_folder to valid UTF-8"),
);
let executor_worker_path = String::from(
executor_worker_path
.to_str()
.expect("Failed to convert executor_worker_path to valid UTF-8"),
);
Ok((engine_folder, executor_worker_path))
}
// SAFETY: `Send` is required to move the `UniquePtr` holding the C++ backend into the
// `spawn_blocking` closure of `TensorRtLlmBackendV2::new`. After that move, the backend is only
// used from that single looper thread.
// NOTE(review): this assumes the C++ `tensorrt_llm_backend_t` has no thread affinity (e.g.
// thread-local state) — confirm.
unsafe impl Send for TensorRtLlmBackendImpl {}
/// Handle on the TensorRT-LLM backend: the sending half of the queue consumed by the executor
/// looper thread. `Backend::schedule` forwards every request through it.
pub struct TensorRtLlmBackendV2(UnboundedSender);
impl TensorRtLlmBackendV2 {
pub fn new + Send, PP: AsRef + Send>(
tokenizer: Tokenizer,
engine_folder: P,
executor_worker_path: PP,
max_inflight_requests: usize,
) -> Result {
let (engine_folder, executor_worker_path) =
ensure_paths_exist(engine_folder, executor_worker_path)?;
// Allocate the IPC layer to communicate with the backend
let (executor_sender, executor_receiver) = unbounded_channel();
// Create the FFI backend
let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path)
.map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
// Executor looper is responsible for scheduling and pulling requests state at regular interval
spawn_blocking(move || {
executor_status_looper(max_inflight_requests, tokenizer, backend, executor_receiver)
});
Ok(TensorRtLlmBackendV2(executor_sender))
}
fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
if request.input_ids.is_none() {
return Err(ValidationError(UnsupportedModality("No token provided")));
}
if request.top_n_tokens > 1 {
return Err(ValidationError(TopNTokensDisabled));
}
// TODO: Is it really needed? How can it be validated before?
if request.parameters.grammar.is_some() {
return Err(ValidationError(Grammar));
}
match request.inputs.len() {
0 => Err(ValidationError(EmptyInput)),
2.. => Err(GenerationError(
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(_) => Ok(()),
Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
},
}
}
}
#[async_trait]
impl Backend for TensorRtLlmBackendV2 {
fn schedule(
&self,
request: ValidGenerateRequest,
) -> Result>, InferError> {
Self::validate(&request)?;
// Open-up the stream to send tokens
let (streamer, receiver) = unbounded_channel::>();
// Send the context to the executor for scheduling
let queued = Instant::now();
match self.0.send(GenerationContext {
request,
streamer,
tokens: Vec::with_capacity(256),
start: None,
queued,
}) {
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
Err(_) => Err(GenerationError(
"Failed to submit request to the backend".into(),
)),
}
}
async fn health(&self, _: bool) -> bool {
true
}
fn name(&self) -> &'static str {
"TensorRT-LLM"
}
}
================================================
FILE: backends/trtllm/src/main.rs
================================================
use std::path::{Path, PathBuf};
use clap::Parser;
use hf_hub::api::tokio::{Api, ApiBuilder};
use hf_hub::{Cache, Repo, RepoType};
use tracing::info;
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackendV2;
use text_generation_router::server::{
get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer,
};
use text_generation_router::usage_stats::UsageStatsLevel;
use text_generation_router::{server, Tokenizer};
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "9000", long, short, env)]
prometheus_port: u16,
#[clap(long, env, required = true)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option,
#[clap(long, env)]
revision: Option,
#[clap(long, env)]
model_id: String,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option>,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env)]
auth_token: Option,
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
executor_worker: PathBuf,
#[clap(default_value = "on", long, env)]
usage_stats: UsageStatsLevel,
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
#[clap(default_value = "1073741824", long, env)]
max_image_fetch_size: usize,
}
/// Resolve the tokenizer to serve for `tokenizer_name`, which is a local directory or a Hub
/// repository id.
///
/// Model configuration files are located in one of three places:
/// * the local directory, when `tokenizer_name` is an existing directory and no `revision` is
///   given, or when the Hub API client cannot be built;
/// * the Hub cache, when `HF_HUB_OFFLINE=1`;
/// * the Hub API otherwise, honoring `HF_TOKEN` / `HUGGING_FACE_HUB_TOKEN`,
///   `HUGGINGFACE_HUB_CACHE` and `HF_HUB_USER_AGENT_ORIGIN`.
///
/// The tokenizer is resolved by the Python helper `py_resolve_tokenizer`, falling back to
/// `legacy_tokenizer_handle` on `config.json`. Then `out/tokenizer.json` (relative to the working
/// directory) is loaded as a Rust tokenizer; if that fails, a `Tokenizer::Python` descriptor is
/// returned instead.
///
/// # Panics
///
/// If both the Python resolution and the legacy fallback fail.
///
/// Always returns `Some`.
async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option {
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance
let local_path = Path::new(tokenizer_name);
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
.with_progress(false)
.with_token(authorization_token);
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
}
if let Ok(origin) = std::env::var("HF_HUB_USER_AGENT_ORIGIN") {
builder = builder.with_user_agent("origin", origin.as_str());
}
builder
};
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
// Where the configuration files are looked up from
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
.map_err(|_| ())
.map(|cache_dir| Cache::new(cache_dir.into()))
.unwrap_or_else(|_| Cache::default());
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
}
}
} else {
Type::None
};
// Load tokenizer and model info
// Only `config_filename` is used below (as input to the legacy tokenizer fallback)
let (
config_filename,
_tokenizer_config_filename,
_preprocessor_config_filename,
_processor_config_filename,
_model_info,
) = match api {
Type::None => (
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or_else(|| "main").to_string(),
));
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
Some(model_info)
} else {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
};
(
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main").to_string(),
));
(
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
None,
)
}
};
let tokenizer: Tokenizer = {
use pyo3::prelude::*;
pyo3::Python::with_gil(|py| -> PyResult<()> {
py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), false)?;
Ok(())
})
.inspect_err(|err| {
tracing::error!("Failed to import python tokenizer {err}");
})
.or_else(|err| {
let out = legacy_tokenizer_handle(config_filename.as_ref());
out.ok_or(err)
})
.expect("We cannot load a tokenizer");
// NOTE(review): relative to the working directory; presumably written by the resolution
// helpers above — confirm.
let filename = "out/tokenizer.json";
if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {
Tokenizer::Rust(tok)
} else {
Tokenizer::Python {
tokenizer_name: tokenizer_name.to_string(),
revision: revision.map(|revision| revision.to_string()),
trust_remote_code: false,
}
}
};
Some(tokenizer)
}
/// Entry point: parse and validate the CLI arguments, load the Rust tokenizer, create the
/// TensorRT-LLM backend and serve it through the router's web server.
///
/// # Errors
///
/// * `ArgumentValidation` when the token limits are inconsistent, `validation_workers` is 0, or
///   the executor worker path does not exist
/// * `Tokenizer` when only a Python tokenizer could be resolved
/// * any error from backend creation or from `server::run`
#[tokio::main]
async fn main() -> Result<(), TensorRtLlmBackendError> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_tokens,
        max_total_tokens,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        hostname,
        port,
        prometheus_port,
        tokenizer_name,
        tokenizer_config_path,
        revision,
        model_id,
        validation_workers,
        json_output,
        otlp_endpoint,
        otlp_service_name,
        cors_allow_origin,
        max_client_batch_size,
        auth_token,
        executor_worker,
        usage_stats,
        payload_limit,
        max_image_fetch_size,
    } = args;

    // Initialize logging / tracing before anything can fail
    text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);

    // Validate args
    if max_input_tokens >= max_total_tokens {
        return Err(TensorRtLlmBackendError::ArgumentValidation(
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }
    if validation_workers == 0 {
        return Err(TensorRtLlmBackendError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }
    if !executor_worker.exists() {
        // Name the actual CLI option so users know which flag to fix
        return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
            "`executor_worker` specified path doesn't exist: {}",
            executor_worker.display()
        )));
    }

    // Create the backend
    match get_tokenizer(&tokenizer_name, revision.as_deref())
        .await
        .expect("Failed to retrieve tokenizer implementation")
    {
        Tokenizer::Python { .. } => Err(TensorRtLlmBackendError::Tokenizer(
            "Failed to retrieve Rust based tokenizer".to_string(),
        )),
        Tokenizer::Rust(tokenizer) => {
            info!("Successfully retrieved tokenizer {}", &tokenizer_name);
            let backend = TensorRtLlmBackendV2::new(
                tokenizer,
                model_id,
                executor_worker,
                max_concurrent_requests,
            )?;

            info!("Successfully created backend");

            // Run server
            server::run(
                backend,
                max_concurrent_requests,
                max_best_of,
                max_stop_sequences,
                max_top_n_tokens,
                max_input_tokens,
                max_total_tokens,
                validation_workers,
                auth_token,
                tokenizer_name,
                tokenizer_config_path,
                revision,
                false,
                hostname,
                port,
                cors_allow_origin,
                false,
                None,
                None,
                true,
                max_client_batch_size,
                usage_stats,
                payload_limit,
                max_image_fetch_size,
                prometheus_port,
            )
            .await?;
            Ok(())
        }
    }
}
================================================
FILE: backends/trtllm/src/utils.rs
================================================
/// Extract the first line of the provided string reference.
///
/// Lines are split as [`str::lines`] does: on `\n` or `\r\n`, without the terminator.
/// If `s` contains no line at all (i.e. it is empty), the content of `fail` is returned
/// instead. A string starting with a line break yields an empty first line, not `fail`.
///
/// # Arguments
///
/// * `s`: The string buffer to extract the first line from
/// * `fail`: The content returned when `s` has no line
///
/// returns: String
///
/// # Examples
///
/// ```ignore
/// // `first_line` is crate-private, so a doctest cannot reach it
/// let s = "My name is Morgan.\n I'm working at Hugging Face.";
/// assert_eq!(first_line(s, "No line in string"), "My name is Morgan.");
/// assert_eq!(first_line("", "No line in string"), "No line in string");
/// ```
#[inline]
pub(crate) fn first_line(s: &str, fail: &str) -> String {
    s.lines().next().unwrap_or(fail).to_string()
}
================================================
FILE: backends/trtllm/tests/test_backend.cpp
================================================
//
// Created by mfuntowicz on 12/3/24.
//
#include
#include
#include
#include "backend.hpp"
using namespace huggingface::tgi::backends::trtllm;
// Every supported field is provided: temperature and top_p must be read from the config, and
// each eos_token_id must become its own single-token stop word, in order.
TEST_CASE("parse generation_config.json all set", "[generation_config_t]")
{
const json config_j = {{"temperature", 0.6},
{"top_p", 0.95},
{"eos_token_id", {1, 2, 3}}};
const auto generation_config = generation_config_t(config_j);
REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6));
REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(0.95, 1e-6));
// Stop words
REQUIRE_FALSE(generation_config.stop_words.empty());
REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list>{{1},
{2},
{3}})) {
// Currently we do not support multi-tokens stop words
REQUIRE(lhs.size() == 1);
REQUIRE(rhs.size() == 1);
REQUIRE_THAT(lhs, Catch::Matchers::UnorderedEquals(rhs));
}
}
// Only eos_token_id is provided: temperature and top_p must fall back to 1.0, while the stop
// words are still built from eos_token_id.
TEST_CASE("parse generation_config.json default", "[generation_config_t]")
{
const json config_j = {{"eos_token_id", {1, 2, 3}}};
const auto generation_config = generation_config_t(config_j);
REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
REQUIRE_FALSE(generation_config.stop_words.empty());
REQUIRE(generation_config.stop_words.size() == config_j["/eos_token_id"_json_pointer].size());
for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list>{{1},
{2},
{3}})) {
// Currently we do not support multi-tokens stop words
REQUIRE(lhs.size() == 1);
REQUIRE(rhs.size() == 1);
REQUIRE_THAT(lhs, Catch::Matchers::UnorderedEquals(rhs));
}
}
TEST_CASE("parse generation_config.json empty", "[generation_config_t]")
{
    // An empty eos_token_id yields no stop word and default sampling values
    const json config_j = {{"eos_token_id", {}}};
    const auto generation_config = generation_config_t(config_j);
    REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
    REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
    REQUIRE(generation_config.stop_words.empty());

    // A config with no field at all must behave the same: this second half has to parse
    // `config_j2`, not `config_j` again. `json::object()` is used because `json j = {}`
    // produces a null value rather than an empty object.
    const json config_j2 = json::object();
    const auto generation_config2 = generation_config_t(config_j2);
    REQUIRE_THAT(generation_config2.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));
    REQUIRE_THAT(generation_config2.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));
    REQUIRE(generation_config2.stop_words.empty());
}
// A single-engine deployment (world_size == 1) must run in leader mode. The engine config and
// the expected mode used to be swapped with the "multi" case below.
TEST_CASE("parallel_config single", "[backend_workspace_t]")
{
    // Generate temporary folder
    const auto tmp_p = std::filesystem::temp_directory_path();
    const auto config_p = tmp_p / "config.json";
    const auto generation_config_p = tmp_p / "generation_config.json";

    // Generate content
    std::ofstream o_config(config_p);
    o_config << R"({"pretrained_config": {"mapping": {"world_size": 1}}})"_json;
    o_config.close();

    std::ofstream o_generation_config(generation_config_p);
    o_generation_config << R"({"eos_token_id": []})"_json;
    o_generation_config.close();

    const auto workspace = backend_workspace_t(tmp_p.generic_string(), tmp_p.generic_string());
    const auto parallel = workspace.parallel_config();
    REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kLEADER);
    REQUIRE(parallel.getCommunicationType() == tle::CommunicationType::kMPI);

    std::filesystem::remove(config_p);
    std::filesystem::remove(generation_config_p);
}
// A sharded deployment (world_size > 1) must run in orchestrator mode. The engine config and
// the expected mode used to be swapped with the "single" case above.
TEST_CASE("parallel_config multi", "[backend_workspace_t]")
{
    // Generate temporary folder
    const auto tmp_p = std::filesystem::temp_directory_path();
    const auto config_p = tmp_p / "config.json";
    const auto generation_config_p = tmp_p / "generation_config.json";

    // Generate content
    std::ofstream o_config(config_p);
    o_config << R"({"pretrained_config": {"mapping": {"world_size": 2}}})"_json;
    o_config.close();

    std::ofstream o_generation_config(generation_config_p);
    o_generation_config << R"({"eos_token_id": []})"_json;
    o_generation_config.close();

    const auto workspace = backend_workspace_t(tmp_p.generic_string(), tmp_p.generic_string());
    const auto parallel = workspace.parallel_config();
    REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kORCHESTRATOR);
    REQUIRE(parallel.getCommunicationType() == tle::CommunicationType::kMPI);

    std::filesystem::remove(config_p);
    std::filesystem::remove(generation_config_p);
}
// TODO: not implemented — this test case has no assertion and always passes.
TEST_CASE("executor_config", "[backend_workspace_t]")
{
}
// Converting sampling_params_t to tle::SamplingConfig must forward every field, each of which
// must end up set on the resulting config.
// NOTE(review): the aggregate initializer relies on the declaration order of sampling_params_t's
// fields (top_k first and seed last, judging by the values) — confirm against backend.hpp.
TEST_CASE("sampling_params_t to tle::SamplingConfig", "[backend_t]")
{
const sampling_params_t params = {40, 0.95, 0.9, 1.0, 0.6, 2014};
const auto config = static_cast(params);
REQUIRE(config.getTopK().has_value());
REQUIRE(config.getTopK().value() == params.top_k);
REQUIRE(config.getSeed().has_value());
REQUIRE(config.getSeed().value() == params.seed);
REQUIRE(config.getTopP().has_value());
REQUIRE_THAT(*config.getTopP(), Catch::Matchers::WithinAbs(params.top_p, 1e-6f));
REQUIRE(config.getRepetitionPenalty().has_value());
REQUIRE_THAT(*config.getRepetitionPenalty(), Catch::Matchers::WithinAbs(params.repetition_penalty, 1e-6f));
REQUIRE(config.getFrequencyPenalty().has_value());
REQUIRE_THAT(*config.getFrequencyPenalty(), Catch::Matchers::WithinAbs(params.frequency_penalty, 1e-6f));
REQUIRE(config.getTemperature().has_value());
REQUIRE_THAT(*config.getTemperature(), Catch::Matchers::WithinAbs(params.temperature, 1e-6f));
}
================================================
FILE: backends/trtllm/tests/test_hardware.cpp
================================================
//
// Created by mfuntowicz on 11/16/24.
//
#include
#include "../csrc/hardware.hpp"
using namespace huggingface::tgi::hardware::cuda;
// Each dedicated is_at_least_<arch>() predicate must hold for the device's own architecture and
// every older one, and be false for newer architectures.
TEST_CASE("is_at_least_") {
    // Volta: compute capability 7.0
    static const auto volta_device = compute_capabilities_t(7, 0);
    REQUIRE(volta_device.is_at_least_volta());
    REQUIRE_FALSE(volta_device.is_at_least_turing());
    REQUIRE_FALSE(volta_device.is_at_least_ampere());
    REQUIRE_FALSE(volta_device.is_at_least_ada_lovelace());
    REQUIRE_FALSE(volta_device.is_at_least_hopper());

    // Turing: compute capability 7.5
    static const auto turing_device = compute_capabilities_t(7, 5);
    REQUIRE(turing_device.is_at_least_volta());
    REQUIRE(turing_device.is_at_least_turing());
    REQUIRE_FALSE(turing_device.is_at_least_ampere());
    REQUIRE_FALSE(turing_device.is_at_least_ada_lovelace());
    REQUIRE_FALSE(turing_device.is_at_least_hopper());

    // Ampere: compute capability 8.0
    static const auto ampere_device = compute_capabilities_t(8, 0);
    REQUIRE(ampere_device.is_at_least_volta());
    REQUIRE(ampere_device.is_at_least_turing());
    REQUIRE(ampere_device.is_at_least_ampere());
    REQUIRE_FALSE(ampere_device.is_at_least_ada_lovelace());
    REQUIRE_FALSE(ampere_device.is_at_least_hopper());

    // Ada Lovelace: compute capability 8.9
    static const auto ada_device = compute_capabilities_t(8, 9);
    REQUIRE(ada_device.is_at_least_volta());
    REQUIRE(ada_device.is_at_least_turing());
    REQUIRE(ada_device.is_at_least_ampere());
    REQUIRE(ada_device.is_at_least_ada_lovelace());
    REQUIRE_FALSE(ada_device.is_at_least_hopper());

    // Hopper: compute capability 9.0
    static const auto hopper_device = compute_capabilities_t(9, 0);
    REQUIRE(hopper_device.is_at_least_volta());
    REQUIRE(hopper_device.is_at_least_turing());
    REQUIRE(hopper_device.is_at_least_ampere());
    REQUIRE(hopper_device.is_at_least_ada_lovelace());
    REQUIRE(hopper_device.is_at_least_hopper());
}
// The generic is_at_least(arch) predicate must hold for the device's own architecture and every
// older one, and be false for newer architectures.
TEST_CASE("is_at_least") {
    // Volta: compute capability 7.0
    static const auto volta_device = compute_capabilities_t(7, 0);
    REQUIRE(volta_device.is_at_least(VOLTA));
    REQUIRE_FALSE(volta_device.is_at_least(TURING));
    REQUIRE_FALSE(volta_device.is_at_least(AMPERE));
    REQUIRE_FALSE(volta_device.is_at_least(ADA_LOVELACE));
    REQUIRE_FALSE(volta_device.is_at_least(HOPPER));

    // Turing: compute capability 7.5
    static const auto turing_device = compute_capabilities_t(7, 5);
    REQUIRE(turing_device.is_at_least(VOLTA));
    REQUIRE(turing_device.is_at_least(TURING));
    REQUIRE_FALSE(turing_device.is_at_least(AMPERE));
    REQUIRE_FALSE(turing_device.is_at_least(ADA_LOVELACE));
    REQUIRE_FALSE(turing_device.is_at_least(HOPPER));

    // Ampere: compute capability 8.0
    static const auto ampere_device = compute_capabilities_t(8, 0);
    REQUIRE(ampere_device.is_at_least(VOLTA));
    REQUIRE(ampere_device.is_at_least(TURING));
    REQUIRE(ampere_device.is_at_least(AMPERE));
    REQUIRE_FALSE(ampere_device.is_at_least(ADA_LOVELACE));
    REQUIRE_FALSE(ampere_device.is_at_least(HOPPER));

    // Ada Lovelace: compute capability 8.9
    static const auto ada_device = compute_capabilities_t(8, 9);
    REQUIRE(ada_device.is_at_least(VOLTA));
    REQUIRE(ada_device.is_at_least(TURING));
    REQUIRE(ada_device.is_at_least(AMPERE));
    REQUIRE(ada_device.is_at_least(ADA_LOVELACE));
    REQUIRE_FALSE(ada_device.is_at_least(HOPPER));

    // Hopper: compute capability 9.0
    static const auto hopper_device = compute_capabilities_t(9, 0);
    REQUIRE(hopper_device.is_at_least(VOLTA));
    REQUIRE(hopper_device.is_at_least(TURING));
    REQUIRE(hopper_device.is_at_least(AMPERE));
    REQUIRE(hopper_device.is_at_least(ADA_LOVELACE));
    REQUIRE(hopper_device.is_at_least(HOPPER));
}
================================================
FILE: backends/v2/Cargo.toml
================================================
# Package metadata for the v2 router backend; version, edition, authors and
# homepage are inherited from the workspace manifest.
[package]
name = "text-generation-router-v2"
description = "Text Generation Webserver"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
# The crate builds both a library (src/lib.rs) and a
# `text-generation-router-v2` binary (src/main.rs).
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-router-v2"
path = "src/main.rs"
[dependencies]
async-trait = "0.1.74"
async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"
# Shared router crate; provides (among other things) the `Backend` trait and the
# request validation types this backend implements against.
text-generation-router = { path = "../../router" }
clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
hf-hub = { workspace = true }
jsonschema = { version = "0.28.0" }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
slotmap = "1.0.7"
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = [
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
] }
tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
init-tracing-opentelemetry = { version = "0.14.1", features = [
"opentelemetry-otlp",
] }
minijinja = { workspace = true }
minijinja-contrib = { workspace = true }
futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
image = "0.25.1"
base64 = { workspace = true }
# Runtime crates for the gRPC client that build.rs generates. Keep their
# versions in step with prost-build / tonic-build in [build-dependencies].
prost = "^0.12"
tonic = "^0.10"
tower = "^0.4"
# Code generators used by build.rs to compile ../../proto/generate.proto into a
# gRPC client (client only, no server) under src/client/pb.
[build-dependencies]
tonic-build = "0.10.1"
prost-build = "0.12.1"
# Each feature forwards to the identically named feature of
# text-generation-router; `ngrok` is enabled by default.
[features]
default = ["ngrok"]
ngrok = ["text-generation-router/ngrok"]
google = ["text-generation-router/google"]
kserve = ["text-generation-router/kserve"]
================================================
FILE: backends/v2/build.rs
================================================
use std::fs;
/// Build script: generates the gRPC client code for `../../proto/generate.proto`.
///
/// The generated sources go into `src/client/pb`. This includes a `mod.rs`
/// include file that pulls the generated modules together. Only the client side
/// is generated; no server stubs are produced.
///
/// # Errors
///
/// Returns an error if the output directory cannot be created.
///
/// # Panics
///
/// Panics if protobuf compilation fails.
fn main() -> Result<(), Box> {
    // Re-run this script whenever anything under the shared proto directory changes.
    println!("cargo:rerun-if-changed=../../proto/");

    let out_dir = "src/client/pb";

    // `create_dir_all` already succeeds when the directory exists, so any error here
    // is a real failure (e.g. permissions). Surface it now instead of letting the
    // code generator fail later with a less obvious message.
    fs::create_dir_all(out_dir)?;

    let mut config = prost_build::Config::new();
    // Passed straight to protoc so that `optional` fields in proto3 files are accepted.
    config.protoc_arg("--experimental_allow_proto3_optional");

    tonic_build::configure()
        .build_client(true)
        .build_server(false)
        .out_dir(out_dir)
        .include_file("mod.rs")
        .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
        .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));

    Ok(())
}
================================================
FILE: backends/v2/src/backend.rs
================================================
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};
/// Text-generation backend (reported as `tgi-v2`).
///
/// Incoming requests are pushed onto a queue that a background batching task
/// drains and sends to a `ShardedClient`.
pub struct BackendV2 {
/// Request queue; a clone of it is handed to the background batching task
queue: Queue,
/// Wakes the batching task whenever `schedule` appends a new entry to the queue
batching_task_notifier: Arc,
/// Client clone, used for health checks to skip the queue
client: ShardedClient,
}
impl BackendV2 {
    /// Builds the backend and starts its background batching task.
    ///
    /// The queue's block size is chosen from the `ATTENTION` environment variable:
    /// `flashinfer` → 1, `flashdecoding` → 256, `paged` → 16. When the variable is
    /// unset (or not valid Unicode), the `paged` value is used.
    ///
    /// The batching parameters are forwarded unchanged to `batching_task`. That task is
    /// started with `tokio::spawn`, so this must be called from within a Tokio runtime.
    /// The backend keeps its own clone of `client` for health checks.
    ///
    /// # Panics
    ///
    /// Panics if `ATTENTION` is set to any value other than the three listed above.
    // Each batching knob is forwarded individually to `batching_task`, hence the long
    // parameter list.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        client: ShardedClient,
        waiting_served_ratio: f32,
        max_batch_prefill_tokens: u32,
        max_batch_total_tokens: u32,
        max_waiting_tokens: usize,
        max_batch_size: Option,
        requires_padding: bool,
        window_size: Option,
        speculate: u32,
    ) -> Self {
        // Infer shared state
        let attention = std::env::var("ATTENTION").unwrap_or_else(|_| "paged".to_string());
        let block_size = match attention.as_str() {
            "flashinfer" => 1,
            "flashdecoding" => 256,
            "paged" => 16,
            // The value comes from the environment, so this arm is reachable: report
            // the offending value instead of an opaque `unreachable!()`.
            other => panic!(
                "Unsupported ATTENTION value `{other}`: expected one of `flashinfer`, `flashdecoding` or `paged`"
            ),
        };
        let queue = Queue::new(requires_padding, block_size, window_size, speculate);
        let batching_task_notifier = Arc::new(Notify::new());

        // Spawn batching background task that contains all the inference logic
        tokio::spawn(batching_task(
            client.clone(),
            waiting_served_ratio,
            max_batch_prefill_tokens,
            max_batch_total_tokens,
            max_waiting_tokens,
            max_batch_size,
            queue.clone(),
            batching_task_notifier.clone(),
        ));

        Self {
            queue,
            batching_task_notifier,
            client,
        }
    }
}
#[async_trait]
impl Backend for BackendV2 {
    /// Enqueues a validated request and returns the stream its responses arrive on.
    ///
    /// The request is stored in the shared queue together with the sending half of a
    /// fresh unbounded channel. The batching task is then woken so it can pick the
    /// entry up. This method never returns an error.
    #[instrument(skip_all)]
    fn schedule(
        &self,
        request: ValidGenerateRequest,
    ) -> Result>, InferError> {
        let (sender, receiver) = mpsc::unbounded_channel();

        let entry = Entry {
            request,
            response_tx: sender,
            span: Span::current(),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
        };
        self.queue.append(entry);

        // Wake the batching task: a new entry is waiting in the queue.
        self.batching_task_notifier.notify_one();

        let stream = UnboundedReceiverStream::new(receiver);
        Ok(stream)
    }

    /// Probes the shards directly through the client clone, bypassing the queue.
    ///
    /// If generation is currently healthy, only the cheaper device check runs.
    /// Otherwise a full model health check is performed.
    async fn health(&self, current_health: bool) -> bool {
        let probe = if current_health {
            self.client.device_health().await
        } else {
            self.client.model_health().await
        };
        probe.is_ok()
    }

    /// Always `true`. Presumably this opts the backend into the router's startup health
    /// check; confirm against the `Backend` trait documentation.
    fn start_health(&self) -> bool {
        true
    }

    /// Identifier reported for this backend.
    fn name(&self) -> &'static str {
        "tgi-v2"
    }
}
/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
pub(crate) async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option,
queue: Queue,
notifier: Arc